# Normalisation

In [None]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import imageio
from skimage.transform import resize
import os

In [None]:
# Directory containing the per-cell channel images
IMG_DIR = '../input/image_subset/cell/'
# Table of training cells that is read with pandas further down
TRAIN_CSV = '../input/image_subset/cell/train.csv'
# Channel suffixes used in image file names: <cell_id>_<channel>.png
CHANNELS = ['red', 'green', 'blue', 'yellow']

In [None]:
class CellDataset(object):
    '''Map-style dataset serving HPA cell-level images with their weak labels

    Indexing returns a dict holding an 'image' float32 tensor in
    channel-first layout and a 'target' tensor built from the entry of
    `targets` at the same position.
    '''

    def __init__(self, images, targets, img_root, augmentations=None):
        # `images` and `targets` are indexed in parallel by __getitem__
        self.img_root = img_root
        self.images = images
        self.targets = targets
        # Optional callable invoked as augmentations(image=...)['image']
        self.augmentations = augmentations

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        channel_paths = self._fetch_channels(self.images[idx])
        pixels = self._channels_2_array(channel_paths)
        # Every image is brought to 224x224 so samples can be collated
        pixels = resize(pixels, (224, 224))
        # The augmentation pipeline is optional; skip it when not supplied
        if self.augmentations:
            pixels = self.augmentations(image=pixels)['image']
        # Channel-last (H, W, C) -> channel-first (C, H, W) for pytorch
        chw = np.transpose(pixels, (2, 0, 1)).astype(np.float32)
        label = self.targets[idx]

        sample = {}
        sample['image'] = torch.tensor(chw)
        sample['target'] = torch.tensor(label)
        return sample

    def _fetch_channels(self, img_id: str, channel_names=CHANNELS):
        '''Return one file path per channel name for the given image id

        Paths have the form <img_root>/<img_id>_<channel>.png; they are only
        as absolute as `img_root` is.
        '''
        stem = os.path.join(self.img_root, img_id)
        paths = []
        for name in channel_names:
            paths.append(''.join((stem, '_', name, '.png')))
        return paths

    def _channels_2_array(self, img_channels):
        '''Return 3D array stacking the first three channel images depth-wise

        Only the paths at positions 0, 1 and 2 are read (red, green, blue
        with the default channel order); any further path is ignored.
        '''
        planes = [imageio.imread(img_channels[pos]) for pos in range(3)]
        return np.dstack(planes)

In [None]:
def gen_dataloader(df, img_dir, bs, shuffle, aug=None):
    '''Return pytorch dataloader generated from cell image dataframe

    df:      dataframe with a 'cell_id' column (image ids) and a 'Label'
             column holding stringified integer lists such as '[0, 1, 0]'
    img_dir: root directory handed to CellDataset to locate the images
    bs:      batch size of the returned DataLoader
    shuffle: whether the DataLoader shuffles the samples
    aug:     optional augmentation pipeline forwarded to CellDataset

    Raises ValueError if a label entry contains a non-integer token.
    '''
    def extract_as_array(str_):
        'Parse a stringified integer list into a numpy integer array'
        inner = str_.strip().strip('][').strip()
        if not inner:
            # '[]' used to crash on int(''); an empty label is an empty array
            return np.array([], dtype=int)
        # int() ignores surrounding whitespace, so '1,0' and '1, 0' both parse
        return np.array([int(token) for token in inner.split(',')])

    # Extract images and targets as numpy arrays from dataframe tranche
    images = df['cell_id'].values
    targets = df['Label'].apply(extract_as_array).values
    # Init custom dataset class and pass to pytorch
    dataset = CellDataset(images, targets, img_dir, aug)
    return DataLoader(dataset, batch_size=bs, shuffle=shuffle)

In [None]:
def grab_pixel_aggs(dataloader, sample_size):
    '''Return dataframe of per-image channel max, min, mean and std

    dataloader:  iterable of dicts whose 'image' entry is a 4D tensor
                 indexed by (B, C, H, W); only the first image of each
                 batch is inspected, so use a batch size of 1 to cover
                 every sample
    sample_size: maximum number of batches to inspect

    The result has one row per inspected batch and the columns
    <channel>_max, <channel>_min, <channel>_mean and <channel>_std for
    the red, green and blue channels (tensor channels 0, 1 and 2). The
    columns are present even when no batch was inspected.
    '''
    channels = ['red', 'green', 'blue']
    columns = [channel_name + suffix
               for channel_name in channels
               for suffix in ('_max', '_min', '_mean', '_std')]
    # Rows are collected in a list and turned into a frame once:
    # DataFrame.append was removed in pandas 2.0 and copied the frame on
    # every call.
    rows = []

    for count, sample in enumerate(dataloader):
        # Check before processing so exactly sample_size rows are produced
        if count >= sample_size:
            break
        image_tensor = sample['image']  # indexed by (B, C, H, W)
        aggs = {}
        # Grab cell image channel aggregates
        for idx, channel_name in enumerate(channels):
            channel = image_tensor[0, idx, :, :]
            aggs[channel_name + '_max'] = channel.max().item()
            aggs[channel_name + '_min'] = channel.min().item()
            aggs[channel_name + '_mean'] = channel.mean().item()
            aggs[channel_name + '_std'] = channel.std().item()
        rows.append(aggs)
    return pd.DataFrame(rows, columns=columns)

In [None]:
# Load the training table; its first CSV column becomes the index
df = pd.read_csv(TRAIN_CSV, index_col=0)
# Batch size 1: grab_pixel_aggs only looks at the first image of a batch.
# Shuffling makes the inspected images a random sample of the table.
loader = gen_dataloader(
    df,
    img_dir=IMG_DIR,
    bs=1,
    shuffle=True,
)
aggs_df = grab_pixel_aggs(dataloader=loader, sample_size=1000)
aggs_df.head()

## Magic Numbers

In [None]:
# Average the per-image statistics over all sampled images.
# NOTE: each "std dev" below is the mean of the per-image standard
# deviations, not the standard deviation of the pooled pixels.
for channel in ('red', 'green', 'blue'):
    label = channel.capitalize()
    print(f'{label} channel mean:   {getattr(aggs_df, channel + "_mean").mean()}')
    print(f'{label} std dev:   {getattr(aggs_df, channel + "_std").mean()}')
print('Global max:   ?')