## This is a tutorial to compute image mean and std for image normalization

In [2]:
import torch
import torchvision
from torchvision import transforms,datasets
from tqdm import tqdm

In [3]:
augs = transforms.Compose([
    transforms.Resize((64, 64)),   
])

image_dataset = datasets.ImageFolder(root = './images/',
                                    transform = augs)

image_loader = torch.utils.data.DataLoader(image_dataset,
                                  batch_size = 32,
                                  shuffle = True)

In [5]:
image_dataset

Dataset ImageFolder
    Number of datapoints: 253
    Root location: ./images/
    StandardTransform
Transform: Compose(
               Resize(size=(64, 64), interpolation=bilinear, max_size=None, antialias=None)
           )

### Compute mean/std

In [4]:
psum = torch.zeros(3, dtype =torch.float32)
psum_sq = torch.zeros(3, dtype=torch.float32)

for inputs, labels in tqdm(image_loader):
    psum += inputs.sum(axis = [0,2,3])
    psum_sq += (inputs**2).sum(axis = [0,2,3])

# pixel count
batch_size = 32
img_width = 64
img_height = 64

count = len(image_loader)*batch_size*img_width*img_height

# mean and std
mean = psum / count
var = psum_sq / count - mean**2
std = torch.sqrt(var)

print('mean:\t', mean.numpy())
print('std:\t', std.numpy())

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 14.63it/s]

mean:	 [0.2451813  0.24542268 0.24551283]
std:	 [0.2235546  0.22381799 0.22402456]





## Customize a Dataset object

In [49]:
from torch.utils.data import Dataset, DataLoader
from skimage.io import imread
from torchvision.io import read_image
import pandas as pd
import os

class CustomImageDataset(Dataset):
    '''A custom Dataset class must implement three functions:
    __init__, __len__, __getitem__'''
    def __init__(self,data,directory, transform = None):
        self.data = data
        self.directory = directory
        self.transform = transform
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        
        img_path = os.path.join(self.directory, self.data.iloc[idx, 0])
        img = read_image(img_path, mode = torchvision.io.ImageReadMode.GRAY)
        label = self.data.iloc[idx, 1]
        if self.transform:
            img = self.transform(img)
        return img, label

In [50]:
def imagefolder_df(directory):
    path = [f for f in os.listdir(directory) if not f.startswith('.')]
    if path and os.path.isdir(os.path.join(directory,path[0])):
        files_col = []
        label_col = []
        for label in path:
            files = [os.path.join(label, f) 
                     for f in os.listdir(os.path.join(directory, label))]
            label_ = [label] * len(files)
            files_col.extend(files)
            label_col.extend(label_)
            
        df = pd.DataFrame({'files': files_col, 'labels': label_col})
    else:
        df = pd.DataFrame({'files':path})
    return df