In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

In [3]:
means_cifar10 = [0.4914, 0.4822, 0.4465]
deviations_cifar10 = [0.2023, 0.1994, 0.2010]
means_fashionmnist = [0.5]
deviations_fashionmnist = [0.5]

In [4]:
transform_train_cifar10 = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(means_cifar10, deviations_cifar10),
    ])

transform_test_cifar10 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means_cifar10, deviations_cifar10),
    ])

transform_train_fashionmnist = transforms.Compose([
    transforms.RandomCrop(28, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(means_fashionmnist, deviations_fashionmnist),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # fashionMNIST have grayscale image, convert grayscale to "color"
                                                    # just repeat 1 layer to 3 layer
    ])

transform_test_fashionmnist = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means_fashionmnist, deviations_fashionmnist),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    ])

In [21]:
batch_size_train = 128
batch_size_test = 50

trainset_cifar10 = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train_cifar10)
trainloader_cifar10 = torch.utils.data.DataLoader(
    trainset_cifar10, batch_size=batch_size_train, shuffle=True, num_workers=0)
testset_cifar10 = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test_cifar10)
testloader_cifar10 = torch.utils.data.DataLoader(
    testset_cifar10, batch_size=batch_size_test, shuffle=False, num_workers=0)

trainset_fashionmnist = torchvision.datasets.FashionMNIST(
    root='./data', train=True, download=True, transform=transform_train_fashionmnist)
trainloader_fashionmnist = torch.utils.data.DataLoader(
    trainset_fashionmnist, batch_size=batch_size_train, shuffle=True, num_workers=0)
testset_fashionmnist = torchvision.datasets.FashionMNIST(
    root='./data', train=False, download=True, transform=transform_test_fashionmnist)
testloader_fashionmnist = torch.utils.data.DataLoader(
    testset_fashionmnist, batch_size=batch_size_test, shuffle=False, num_workers=0)

Files already downloaded and verified
Files already downloaded and verified


In [22]:
trainloader_cifar10

<torch.utils.data.dataloader.DataLoader at 0x1710c741f88>

In [23]:
it = iter(trainloader_cifar10)
first = next(it)
second = next(it)

In [24]:
first

[tensor([[[[-2.4291, -2.4291, -2.4291,  ..., -2.4291, -2.4291, -2.4291],
           [-2.4291, -2.4291, -2.4291,  ..., -2.4291, -2.4291, -2.4291],
           [-2.4291, -2.4291, -2.4291,  ..., -2.4291, -2.4291, -2.4291],
           ...,
           [-2.4291,  0.2848,  0.1491,  ...,  0.1879,  0.1879,  0.1879],
           [-2.4291,  0.1491,  0.1491,  ...,  0.1491,  0.1685,  0.1685],
           [-2.4291,  0.1297,  0.1491,  ...,  0.2073,  0.2267,  0.2267]],
 
          [[-2.4183, -2.4183, -2.4183,  ..., -2.4183, -2.4183, -2.4183],
           [-2.4183, -2.4183, -2.4183,  ..., -2.4183, -2.4183, -2.4183],
           [-2.4183, -2.4183, -2.4183,  ..., -2.4183, -2.4183, -2.4183],
           ...,
           [-2.4183,  0.1974,  0.0598,  ...,  0.1188,  0.0991,  0.0991],
           [-2.4183,  0.0598,  0.0598,  ...,  0.0991,  0.0794,  0.0794],
           [-2.4183,  0.0401,  0.0598,  ...,  0.1384,  0.1188,  0.1188]],
 
          [[-2.2214, -2.2214, -2.2214,  ..., -2.2214, -2.2214, -2.2214],
           [-

In [25]:
128*len(trainloader_fashionmnist)

60032

In [27]:
len(trainloader_fashionmnist)

469

In [28]:
len(trainloader_cifar10)

391

In [26]:
128*len(trainloader_cifar10)

50048

In [9]:
for batch_idx, (inputs, targets) in enumerate(trainloader_cifar10):
    print(inputs.shape)
    

torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 3

KeyboardInterrupt: 