In [1]:
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Image folders in the layout torchvision's ImageFolder expects (one sub-directory per class),
# resolved relative to the notebook's working directory.
TRAIN_DIR = 'train'
TEST_DIR = 'validation'

# Shared preprocessing for both splits: collapse to a single grayscale channel, then convert
# the PIL image to a float tensor. Defined once so the two splits cannot drift apart.
grayscale_to_tensor = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

train_dataset = datasets.ImageFolder(root=TRAIN_DIR, transform=grayscale_to_tensor)
test_dataset = datasets.ImageFolder(root=TEST_DIR, transform=grayscale_to_tensor)

# batch_size=1 with no shuffling: the pipeline has no Resize, so images may not share a size
# and could not be stacked into larger batches (presumably the reason — confirm).
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [2]:
def get_mean_and_std(dataloader):
    """Compute the per-channel mean and standard deviation over every pixel a dataloader yields.

    Statistics are accumulated as running pixel sums and a pixel count rather than by
    averaging per-batch means. That way every pixel carries equal weight even when batches
    differ in size (e.g. a smaller final batch) or images differ in height/width.

    Args:
        dataloader: Iterable of ``(data, target)`` pairs where ``data`` is a 4-D tensor
            indexed as (batch, channels, height, width). The target is ignored.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel, in the input's
        floating dtype (the default float dtype if the input is not floating point).
        ``std`` is the population standard deviation sqrt(E[X^2] - E[X]^2).

    Raises:
        ValueError: if the dataloader yields no pixels at all.
    """
    channels_sum = None
    channels_squared_sum = None
    num_pixels = 0
    out_dtype = torch.get_default_dtype()

    for data, _ in dataloader:
        if data.is_floating_point():
            out_dtype = data.dtype
        # Accumulate in float64: summing many float32 pixels loses precision, and
        # E[X^2] - E[X]^2 is prone to cancellation.
        data = data.to(torch.float64)
        # Sum over batch, height and width, but not over the channels.
        batch_sum = data.sum(dim=[0, 2, 3])
        batch_squared_sum = (data ** 2).sum(dim=[0, 2, 3])
        if channels_sum is None:
            channels_sum, channels_squared_sum = batch_sum, batch_squared_sum
        else:
            channels_sum = channels_sum + batch_sum
            channels_squared_sum = channels_squared_sum + batch_squared_sum
        num_pixels += data.shape[0] * data.shape[2] * data.shape[3]

    if num_pixels == 0:
        raise ValueError("dataloader yielded no pixels; cannot compute mean and std")

    mean = channels_sum / num_pixels

    # std = sqrt(E[X^2] - (E[X])^2); clamp tiny negative values caused by rounding to 0
    # so a constant channel yields std 0 instead of NaN.
    variance = (channels_squared_sum / num_pixels - mean ** 2).clamp(min=0)
    std = variance.sqrt()

    return mean.to(out_dtype), std.to(out_dtype)

In [3]:
# Training-set per-channel statistics, kept under names (not just printed) so later cells can reuse them.
train_mean, train_std = get_mean_and_std(train_loader)
train_mean, train_std

(tensor([0.5068]), tensor([0.2553]))


In [4]:
# Validation-set per-channel statistics, for comparison with the training-set values.
test_mean, test_std = get_mean_and_std(test_loader)
test_mean, test_std

(tensor([0.5086]), tensor([0.2549]))