# SUMMARY

This notebook demonstrates how to compute the mean and standard deviation of training and test images in `PyTorch`. Knowing the mean and STD can be helpful for normalizing images within the augmentation pipeline. Computing the mean is easy: we add up pixel values over all batches and divide by the total number of pixels. The standard deviation is trickier, because averaging per-batch STDs does not give the overall STD. Let's see how to do it properly!
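To see why averaging per-batch STDs fails while pooled sums work, here is a small NumPy sketch on synthetic data. The random values and the uneven batch sizes are illustrative assumptions, not SETI images.

```python
import numpy as np

# illustrative data: one batch centred at 0, two batches centred at 3
rng = np.random.default_rng(0)
pixels = np.concatenate([rng.normal(0.0, 1.0, 300), rng.normal(3.0, 1.0, 700)])
batches = [pixels[:300], pixels[300:650], pixels[650:]]  # uneven batch sizes

# naive: average the per-batch STDs (ignores the spread between batch means)
naive_std = np.mean([b.std() for b in batches])

# pooled: accumulate sum and squared sum, then combine once at the end
psum = sum(b.sum() for b in batches)
psum_sq = sum((b ** 2).sum() for b in batches)
count = sum(b.size for b in batches)
mean = psum / count
pooled_std = np.sqrt(psum_sq / count - mean ** 2)

print(f"naive:  {naive_std:.4f}")
print(f"pooled: {pooled_std:.4f}")
print(f"exact:  {pixels.std():.4f}")
```

The pooled result matches the STD computed over all pixels at once, while the naive average stays close to the within-batch spread of about 1.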

Note: the original pipeline comes from [this notebook](https://www.kaggle.com/kozodoi/computing-dataset-mean-and-std).


### TL;DR

- train images: `mean = -0.0001, std = 0.9055`
- test images:  `mean = -0.0002, std = 0.8453`

# PREPARATIONS

First, we import the relevant libraries and specify some parameters. A GPU is not needed because no modeling is involved.

In [None]:
##### PACKAGES

import numpy as np
import pandas as pd

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import cv2

from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
##### PARAMS

device      = torch.device('cpu') 
num_workers = 4
batch_size  = 64
image_size  = 300
data_path   = '/kaggle/input/seti-breakthrough-listen/'

# DATA PREP

Now, let's set up a Dataset and a DataLoader.

In [None]:
##### DATA IMPORT

def get_train_file_path(image_id):
    '''
    Borrowed from https://www.kaggle.com/yasufuminakama/seti-nfnet-l0-starter-training
    '''
    return data_path + 'train/{}/{}.npy'.format(image_id[0], image_id)


df              = pd.read_csv(data_path + 'train_labels.csv')
df['file_path'] = df['id'].apply(get_train_file_path)
df.head()

In the `Dataset` class, we stack all cadences of each observation along the time axis. If you merge the original arrays differently (e.g., use only a subset of cadences or stack them as channels), the resulting statistics will differ.
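The shape bookkeeping of this stacking can be sketched with a dummy array. We assume here that each `.npy` file holds 6 cadences of 273 time steps by 256 frequency bins; check `np.load(path).shape` on your own files before relying on these numbers.

```python
import numpy as np

# dummy observation; the (cadences, time, frequency) shape is an assumption
n_cadences, n_time, n_freq = 6, 273, 256
image = np.zeros((n_cadences, n_time, n_freq), dtype=np.float32)

# np.vstack concatenates the cadences along the first (time) axis
stacked = np.vstack(image)
print(stacked.shape)  # (1638, 256)

# transposing puts frequency on the rows and the full time series on the columns
merged = stacked.transpose((1, 0))
print(merged.shape)   # (256, 1638)
```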

In [None]:
##### DATASET

'''
Adapted from https://www.kaggle.com/yasufuminakama/seti-nfnet-l0-starter-training
'''

class ImageData(Dataset):
    
    def __init__(self, df, transform = None):
        self.df         = df
        self.file_names = df['file_path'].values
        self.transform  = transform
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
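        # load one observation (all cadences) as float32
        image = np.load(self.file_names[idx]).astype(np.float32)
        # stack the cadences along the time axis, then transpose to (frequency, time)
        image = np.vstack(image).transpose((1, 0))
        # apply augmentations (resize + tensor conversion) if provided
        if self.transform is not None:
            image = self.transform(image = image)['image']
        return image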

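Our augmentation pipeline only uses `A.Resize()` to resize the images and `ToTensorV2()` to convert them into PyTorch tensors, so the statistics are computed on the resized images.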

In [None]:
##### AUGMENTATIONS

augs = A.Compose([A.Resize(height  = image_size, 
                           width   = image_size),
                  ToTensorV2()])

In [None]:
##### EXAMINE SAMPLE BATCH

# dataset
image_dataset = ImageData(df        = df, 
                          transform = augs)

# data loader
image_loader = DataLoader(image_dataset, 
                          batch_size  = batch_size, 
                          shuffle     = False, 
                          num_workers = num_workers)

# display the first four images of the first batch
inputs = next(iter(image_loader))
fig    = plt.figure(figsize = (14, 7))
for i in range(4):
    fig.add_subplot(2, 4, i + 1, xticks = [], yticks = [])
    plt.imshow(inputs[i].numpy()[0, :, :], cmap = 'gray')

# CALCULATIONS

The computation is done in three steps:

1. Define placeholders to store two running statistics: the sum and the squared sum of pixel values. The former is used to compute the mean, and the latter is needed for the standard deviation.
2. Loop through the batches and add up channel-specific sum and squared sum values.
3. Perform final calculations to obtain data-level mean and standard deviation.
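The three steps above can be wrapped into a reusable helper. The sketch below is our own generalization, and the name `compute_mean_std` is hypothetical. It accumulates in `float64` to limit rounding error and counts pixels from the batch shapes, so it also handles a smaller last batch or several channels. It assumes the loader yields tensors shaped `(batch, channels, height, width)`, as `ImageData` does.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def compute_mean_std(loader):
    """Per-channel mean and population STD over all pixels yielded by `loader`.

    Hypothetical helper: assumes each batch is a tensor (or a list/tuple whose
    first element is a tensor) of shape (batch, channels, height, width).
    """
    psum, psum_sq, count = None, None, 0
    for batch in loader:
        # TensorDataset-based loaders yield lists; ImageData yields bare tensors
        inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
        inputs = inputs.double()  # accumulate in float64
        # step 1: create placeholders once the channel count is known
        if psum is None:
            psum = torch.zeros(inputs.shape[1], dtype=torch.float64)
            psum_sq = torch.zeros(inputs.shape[1], dtype=torch.float64)
        # step 2: per-channel sum and squared sum over batch, height and width
        psum += inputs.sum(dim=[0, 2, 3])
        psum_sq += (inputs ** 2).sum(dim=[0, 2, 3])
        # number of pixels per channel in this batch
        count += inputs.shape[0] * inputs.shape[2] * inputs.shape[3]
    # step 3: combine the running totals (clamp guards against tiny negative rounding)
    mean = psum / count
    var = psum_sq / count - mean ** 2
    return mean, torch.sqrt(var.clamp(min=0))


# usage on random 3-channel data with an uneven last batch (illustrative only)
torch.manual_seed(0)
images = torch.randn(50, 3, 8, 8) * 2 + 1
loader = DataLoader(TensorDataset(images), batch_size=16, shuffle=False)
mean, std = compute_mean_std(loader)
print(mean, std)
```

With the notebook's loader, `compute_mean_std(image_loader)` would return one-element tensors, since our images have a single channel.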

## Training images

In [None]:
##### COMPUTE PIXEL SUM AND SQUARED SUM

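# placeholders for running totals (one entry per channel; our images have a single channel)
psum    = torch.tensor([0.0])
psum_sq = torch.tensor([0.0])

# loop through images and accumulate per-channel sums over batch, height and width
for inputs in tqdm(image_loader):
    psum    += inputs.sum(dim        = [0, 2, 3])
    psum_sq += (inputs ** 2).sum(dim = [0, 2, 3])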

Finally, we make some further calculations:

- mean: divide the sum of pixel values by the total pixel count, i.e., the number of pixels in the dataset, computed as `len(df) * image_size * image_size` (our images have a single channel)
- standard deviation: use the equation `total_std = sqrt(psum_sq / count - total_mean ** 2)`
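A tiny worked example of these two formulas, using four made-up pixel values, can be checked against NumPy's population STD (`ddof=0`, the NumPy default):

```python
import numpy as np

# four made-up pixel values
pixels = np.array([1.0, 2.0, 3.0, 4.0])

psum = pixels.sum()             # 10.0
psum_sq = (pixels ** 2).sum()   # 1 + 4 + 9 + 16 = 30.0
count = pixels.size             # 4

total_mean = psum / count                      # 2.5
total_var = psum_sq / count - total_mean ** 2  # 7.5 - 6.25 = 1.25
total_std = np.sqrt(total_var)                 # about 1.118

print(total_mean, total_var, total_std, pixels.std())
```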

Why do we use such a weird formula for STD? Because the variance equation can be rearranged so that it only needs the sum and the sum of squares of the pixel values, which we can accumulate batch by batch without keeping all pixels in memory. If you are not sure about this, see the derivation in the figure below or [read this](https://www.thoughtco.com/sum-of-squares-formula-shortcut-3126266) for more details.

![variance equation](https://kozodoi.me/images/copied_from_nb/images/fig_variance.jpg)
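For reference, the same identity written out for $N$ pixels $x_i$ with mean $\mu = \frac{1}{N}\sum_i x_i = \frac{\text{psum}}{N}$ (expand the square, then use $\frac{1}{N}\sum_i x_i = \mu$):

$$
\begin{aligned}
\sigma^2 &= \frac{1}{N}\sum_{i=1}^{N}(x_i - \mu)^2
          = \frac{1}{N}\sum_{i=1}^{N}x_i^2 - \frac{2\mu}{N}\sum_{i=1}^{N}x_i + \mu^2 \\
         &= \frac{1}{N}\sum_{i=1}^{N}x_i^2 - 2\mu^2 + \mu^2
          = \frac{\text{psum\_sq}}{N} - \mu^2
\end{aligned}
$$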

In [None]:
##### FINAL CALCULATIONS

# pixel count
count = len(df) * image_size * image_size

# mean and STD
total_mean = psum / count
total_var  = (psum_sq / count) - (total_mean ** 2)
total_std  = torch.sqrt(total_var)

# output
print('Training data stats:')
print('- mean: {:.4f}'.format(total_mean.item()))
print('- std:  {:.4f}'.format(total_std.item()))
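One caveat about this one-pass formula: it subtracts two nearly equal numbers when the mean is large compared with the STD, so `float32` accumulation can lose most of its precision. This is harmless here because the mean of our arrays is close to zero, but for data whose mean is large relative to its spread it is safer to accumulate in `float64` (e.g., `inputs.double()` inside the loop). A NumPy sketch on synthetic data with mean 10,000 and STD 1, values chosen only to exaggerate the effect:

```python
import numpy as np

# synthetic data: large mean, unit spread (chosen to exaggerate rounding effects)
rng = np.random.default_rng(0)
x64 = 10_000 + rng.normal(size=100_000)
x32 = x64.astype(np.float32)

# one-pass formula in float32: E[x^2] and E[x]^2 are both about 1e8,
# where adjacent float32 values are 8 apart, so the difference is badly rounded
m32 = np.mean(x32)
var32 = np.mean(x32 * x32) - m32 * m32

# same formula with float64 values and accumulation
var64 = np.mean(x64 ** 2) - np.mean(x64) ** 2

print(f"float32 one-pass: {var32:.4f}")
print(f"float64 one-pass: {var64:.4f}")
print(f"two-pass np.var:  {np.var(x64):.4f}")
```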

## Test images

In [None]:
##### DATA IMPORT

df = pd.read_csv(data_path + 'sample_submission.csv')

def get_test_file_path(image_id):
    return data_path + 'test/{}/{}.npy'.format(image_id[0], image_id)

df['file_path'] = df['id'].apply(get_test_file_path)
df.head()

In [None]:
##### DATASET & DATALOADER

# dataset
image_dataset = ImageData(df        = df, 
                          transform = augs)

# data loader
image_loader = DataLoader(image_dataset, 
                          batch_size  = batch_size, 
                          shuffle     = False, 
                          num_workers = num_workers)

In [None]:
##### CALCULATIONS

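# placeholders for running totals (one entry per channel; our images have a single channel)
psum    = torch.tensor([0.0])
psum_sq = torch.tensor([0.0])

# loop through images and accumulate per-channel sums over batch, height and width
for inputs in tqdm(image_loader):
    psum    += inputs.sum(dim        = [0, 2, 3])
    psum_sq += (inputs ** 2).sum(dim = [0, 2, 3])
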
# pixel count
count = len(df) * image_size * image_size

# mean and STD
total_mean = psum / count
total_var  = (psum_sq / count) - (total_mean ** 2)
total_std  = torch.sqrt(total_var)

# output
print('Test data stats:')
print('- mean: {:.4f}'.format(total_mean.item()))
print('- std:  {:.4f}'.format(total_std.item()))