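## CIFAR10 and MNIST

Compute the per-channel mean and standard deviation of each training split (images loaded with `transforms.ToTensor()`), e.g. as the arguments for `transforms.Normalize`.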

In [None]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from tqdm import tqdm
from torch.utils.data import DataLoader

In [None]:
# CIFAR10: load the training split as tensors and wrap it in a loader so the
# per-channel statistics can be computed batch by batch.

# NOTE(review): `device` is not used by any visible cell; kept in case later cells rely on it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

train_dataset = datasets.CIFAR10(
  root='data/',
  train=True,
  transform=transforms.ToTensor(),
  download=True,
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

def get_mean_std(loader, show_progress=True):
  """Compute the per-channel mean and (population) standard deviation of a dataset.

  Sums are accumulated over every value the loader yields, so each sample
  contributes equally no matter how it is batched. Averaging per-batch means
  instead would over-weight a short final batch.

  Args:
    loader: iterable yielding ``(data, target)`` pairs. ``data`` is reduced
      over dims 0, 2 and 3, so it is assumed to be a 4-D tensor laid out as
      (batch, channels, height, width).
    show_progress: if True (default), wrap ``loader`` in a tqdm progress bar.

  Returns:
    ``(mean, std)``: 1-D tensors with one entry per channel, in the dtype of
    the data the loader yielded.

  Raises:
    ValueError: if ``loader`` yields no values at all.
  """
  batches = tqdm(loader) if show_progress else loader
  channels_sum = 0      # running per-channel sum of x
  channels_sum_sq = 0   # running per-channel sum of x ** 2
  num_values = 0        # values seen per channel (batch * height * width)
  out_dtype = None

  for data, _ in batches:
    # Accumulate in float64: image datasets can contribute tens of millions of
    # values per channel, enough to lose precision in a float32 running sum.
    data64 = data.double()
    channels_sum += data64.sum(dim=[0, 2, 3])
    channels_sum_sq += (data64 ** 2).sum(dim=[0, 2, 3])
    num_values += data.numel() // data.shape[1]
    out_dtype = data.dtype

  if num_values == 0:
    raise ValueError("get_mean_std: loader yielded no data")

  mean = channels_sum / num_values
  # Var = E[x^2] - E[x]^2. Rounding can push it a hair below zero, which would
  # make the square root NaN, so clamp first.
  var = (channels_sum_sq / num_values - mean ** 2).clamp(min=0)
  return mean.to(out_dtype), var.sqrt().to(out_dtype)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting data/cifar-10-python.tar.gz to data/


In [None]:
# Per-channel statistics of the CIFAR10 training split.
mean, std = get_mean_std(train_loader)
print("\n")
for label, value in (("Mean: ", mean), ("STD: ", std)):
  print(label, value)

100%|██████████| 782/782 [00:06<00:00, 121.11it/s]



Mean:  tensor([0.4914, 0.4822, 0.4465])
STD:  tensor([0.2470, 0.2435, 0.2616])





In [None]:
# MNIST: same procedure for the (single-channel) MNIST training split.

train_dataset = datasets.MNIST(
  root='MNIST/',
  train=True,
  transform=transforms.ToTensor(),
  download=True,
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [None]:
# Per-channel statistics of the MNIST training split, printed like the CIFAR10 cell.
mean, std = get_mean_std(train_loader)
print("\n")
for label, value in zip(("Mean: ", "STD: "), (mean, std)):
  print(label, value)

100%|██████████| 938/938 [00:04<00:00, 197.28it/s]



Mean:  tensor([0.1307])
STD:  tensor([0.3081])



