# SUMMARY

This notebook demonstrates how to compute the mean and standard deviation (STD) of training and test images using PyTorch. Knowing the mean and STD is helpful for normalizing images within the augmentation pipeline. Computing the mean is easy: we can sum pixel values across batches and divide by the total pixel count. Averaging batch means also works, but only when all batches have the same size. The standard deviation is trickier, because the average of per-batch STDs is not the same as the STD over the whole dataset. Let's see how to do it properly!
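A small made-up example shows the pitfall. With two equal-size toy batches, averaging the per-batch STDs underestimates the STD computed over all pixels at once. The numbers are invented for illustration, and NumPy's default `std()` (population STD, `ddof=0`) is used throughout.

```python
import numpy as np

# Two toy batches of pixel values (made-up numbers, equal batch sizes)
batches = [np.array([1.0, 1.0, 1.0, 1.0]),
           np.array([0.0, 2.0, 0.0, 2.0])]

# Naive approach: average the per-batch standard deviations
naive_std = np.mean([b.std() for b in batches])

# Reference: standard deviation over all pixels at once
true_std = np.concatenate(batches).std()

print(f"mean of batch STDs: {naive_std:.4f}")  # 0.5000
print(f"overall STD:        {true_std:.4f}")   # 0.7071
```

The first batch has no spread at all, so its STD of 0 drags the naive average down, even though the pooled pixels clearly vary.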


### TL;DR

- train images: `mean = 0.9871, std = 0.0888`
- test images:  `mean = 0.9863, std = 0.0921`

# PREPARATIONS

First, we import the relevant libraries and specify some parameters. A GPU is not needed because no model is trained here.

In [None]:
##### PACKAGES

import numpy as np
import pandas as pd

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import cv2

from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
##### PARAMS

device      = torch.device('cpu') 
num_workers = 4
batch_size  = 128
image_size  = 224
data_path   = '/kaggle/input/bms-molecular-translation/'

# DATA PREP

Now, let's set up a Dataset and a DataLoader.

In [None]:
##### DATA IMPORT

def get_train_file_path(image_id):
    '''
    Borrowed from https://www.kaggle.com/ihelon/molecular-translation-exploratory-data-analysis
    '''
    return data_path + 'train/{}/{}/{}/{}.png'.format(image_id[0], image_id[1], image_id[2], image_id)


df              = pd.read_csv(data_path + 'train_labels.csv')
df['file_path'] = df['image_id'].apply(get_train_file_path)
df.head()

In [None]:
##### DATASET

class ImageData(Dataset):
    
    def __init__(self, df, transform):
        super().__init__()
        self.df         = df
        self.file_paths = df['file_path'].values
        self.transform  = transform
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        
        # import
        file_path = self.file_paths[idx]        
        image     = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE) 
        if image is None:
            raise FileNotFoundError(file_path)
            
        # augmentations
        if self.transform:
            image = self.transform(image = image)['image']
            
        return image

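Our augmentation pipeline uses `A.Normalize()` with `mean = 0` and `std = 1`. Because `A.Normalize()` also divides by `max_pixel_value` (255 by default), this simply scales pixel values from `[0, 255]` to `[0, 1]`, so the stats computed below are on the `[0, 1]` scale.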
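To see what this does numerically, here is a small NumPy re-implementation of the formula documented for `A.Normalize()`, `(img - mean * max_pixel_value) / (std * max_pixel_value)`. The helper `normalize_sketch` and the 2x2 image are hypothetical and for illustration only. This is not albumentations' own code.

```python
import numpy as np

def normalize_sketch(img, mean, std, max_pixel_value=255.0):
    # Hypothetical helper re-implementing the documented A.Normalize formula:
    # (img - mean * max_pixel_value) / (std * max_pixel_value)
    img = img.astype(np.float32)
    return (img - mean * max_pixel_value) / (std * max_pixel_value)

# A made-up 2x2 grayscale image with uint8 pixel values
img = np.array([[0, 51], [204, 255]], dtype=np.uint8)

# mean = 0, std = 1: only divides by 255, mapping [0, 255] to [0, 1]
scaled = normalize_sketch(img, mean=0.0, std=1.0)
print(scaled)  # values 0.0, 0.2, 0.8, 1.0 (up to float32 rounding)

# With the training stats from the TL;DR, pixels are standardized instead
standardized = normalize_sketch(img, mean=0.9871, std=0.0888)
print(standardized)
```

In the real pipeline you would pass the computed stats straight to `A.Normalize()` (for example `mean = (0.9871)` and `std = (0.0888)` for training images) and let albumentations apply this formula. Across the dataset, the pixels would then have roughly zero mean and unit STD.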

In [None]:
##### AUGMENTATIONS

augs = A.Compose([A.Resize(height  = image_size, 
                           width   = image_size),
                  A.Normalize(mean = (0), 
                              std  = (1)),
                  ToTensorV2()])

In [None]:
##### EXAMINE SAMPLE BATCH

# dataset
image_dataset = ImageData(df        = df, 
                          transform = augs)

# data loader
image_loader = DataLoader(image_dataset, 
                          batch_size  = batch_size, 
                          shuffle     = False, 
                          num_workers = num_workers)

# display the first four images of the first batch
inputs = next(iter(image_loader))
fig    = plt.figure(figsize = (14, 7))
for i in range(4):
    ax = fig.add_subplot(2, 4, i + 1, xticks = [], yticks = [])
    ax.imshow(inputs[i].numpy()[0, :, :], cmap = 'gray')

# CALCULATIONS

The calculations are done in three steps:

1. Initialize two accumulators: the running sum of pixel values and the running sum of squared pixel values. The first is needed for the mean, the second for the standard deviation.
2. Loop through the batches and add each batch's pixel sum and squared pixel sum to the accumulators.
3. Use the two totals to compute the dataset-level mean and standard deviation.
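Before running this on the real data, here is a minimal NumPy sketch of the same three steps on a small synthetic array. The 10 random 8x8 images and the batch size of 4 are arbitrary choices for illustration. The result is checked against computing the mean and STD over all pixels at once.

```python
import numpy as np

# Synthetic "dataset": 10 single-channel 8x8 images with values in [0, 1)
rng = np.random.default_rng(0)
images = rng.random((10, 1, 8, 8))
batch_size = 4  # the last batch holds only 2 images

# Step 1: accumulators for the pixel sum, squared sum and pixel count
psum, psum_sq, count = 0.0, 0.0, 0

# Step 2: add up per-batch sums
for start in range(0, len(images), batch_size):
    batch = images[start:start + batch_size]
    psum += batch.sum()
    psum_sq += (batch ** 2).sum()
    count += batch.size

# Step 3: dataset-level mean and (population) standard deviation
mean = psum / count
std = np.sqrt(psum_sq / count - mean ** 2)

print(f"streaming: mean = {mean:.4f}, std = {std:.4f}")
print(f"direct:    mean = {images.mean():.4f}, std = {images.std():.4f}")
```

Counting pixels per batch (`batch.size`) is a small variation on the notebook's `len(df) * image_size * image_size`. Both give the same number when every image has the same size, but the per-batch count avoids hard-coding the image size.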

## Training images

In [None]:
##### COMPUTE PIXEL SUM AND SQUARED SUM

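# placeholders (float64 reduces rounding error when summing a very large number of pixels)
psum    = torch.tensor([0.0], dtype = torch.float64)
psum_sq = torch.tensor([0.0], dtype = torch.float64)

# loop through images: inputs has shape [batch, channel, height, width]
for inputs in tqdm(image_loader):
    inputs   = inputs.double()
    psum    += inputs.sum(dim        = [0, 2, 3])
    psum_sq += (inputs ** 2).sum(dim = [0, 2, 3])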

- To get the mean, we divide the sum of pixel values by `count`, the total number of pixels in the dataset. Since every image is resized to a single-channel `image_size` x `image_size` image, `count = len(df) * image_size * image_size`.
- To get the standard deviation, we use `total_std = sqrt(psum_sq / count - total_mean ** 2)`. This works because the variance, defined as the average squared deviation from the mean, can be rewritten as the average of the squared values minus the squared mean. That form only needs the two sums we accumulated. See [this link](https://www.thoughtco.com/sum-of-squares-formula-shortcut-3126266) for details.

![variance equation](https://kozodoi.me/images/copied_from_nb/images/fig_variance.jpg)
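Written out with $N$ = `count`, pixel values $x_1, \dots, x_N$ and $\mu$ = `total_mean`, the expansion is:

$$
\begin{aligned}
\sigma^2 &= \frac{1}{N}\sum_{i=1}^{N}(x_i - \mu)^2 \\
&= \frac{1}{N}\sum_{i=1}^{N}x_i^2 - \frac{2\mu}{N}\sum_{i=1}^{N}x_i + \mu^2 \\
&= \frac{1}{N}\sum_{i=1}^{N}x_i^2 - 2\mu^2 + \mu^2 \\
&= \frac{1}{N}\sum_{i=1}^{N}x_i^2 - \mu^2,
\qquad \mu = \frac{1}{N}\sum_{i=1}^{N}x_i
\end{aligned}
$$

Here $\sum_i x_i$ is `psum` and $\sum_i x_i^2$ is `psum_sq`. This is the population variance (dividing by $N$), which is what the code below computes.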

In [None]:
##### FINAL CALCULATIONS

# pixel count
count = len(df) * image_size * image_size

# mean and STD
total_mean = psum / count
total_var  = (psum_sq / count) - (total_mean ** 2)
total_std  = torch.sqrt(total_var)

# output
print('Training data stats:')
print('- mean: {:.4f}'.format(total_mean.item()))
print('- std:  {:.4f}'.format(total_std.item()))

## Test images

In [None]:
###### DATA IMPORT

df = pd.read_csv(data_path + 'sample_submission.csv')

def get_test_file_path(image_id):
    return data_path + 'test/{}/{}/{}/{}.png'.format(image_id[0], image_id[1], image_id[2], image_id)

df['file_path'] = df['image_id'].apply(get_test_file_path)
df.head()

In [None]:
###### DATASET & DATALOADER

# dataset
image_dataset = ImageData(df        = df, 
                          transform = augs)

# data loader
image_loader = DataLoader(image_dataset, 
                          batch_size  = batch_size, 
                          shuffle     = False, 
                          num_workers = num_workers)

In [None]:
##### CALCULATIONS

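# placeholders (float64 reduces rounding error when summing a very large number of pixels)
psum    = torch.tensor([0.0], dtype = torch.float64)
psum_sq = torch.tensor([0.0], dtype = torch.float64)

# loop through images: inputs has shape [batch, channel, height, width]
for inputs in tqdm(image_loader):
    inputs   = inputs.double()
    psum    += inputs.sum(dim        = [0, 2, 3])
    psum_sq += (inputs ** 2).sum(dim = [0, 2, 3])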
    
# pixel count
count = len(df) * image_size * image_size

# mean and STD
total_mean = psum / count
total_var  = (psum_sq / count) - (total_mean ** 2)
total_std  = torch.sqrt(total_var)

# output
print('Test data stats:')
print('- mean: {:.4f}'.format(total_mean.item()))
print('- std:  {:.4f}'.format(total_std.item()))

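If you use a different image size, change `image_size` in the parameters cell. It controls both the resizing and the pixel count `count`, so the calculations stay consistent. Keep in mind that resizing interpolates pixel values, so the stats can shift slightly with the image size. Good luck!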