# Imports

In [13]:
# imports
import pandas as pd
import random
import matplotlib.pyplot as plt
import imageio
import numpy as np
import glob
from skmultilearn.model_selection import IterativeStratification

# transformers
from torchvision import transforms
from skimage.transform import rescale, resize, downscale_local_mean
import albumentations

# dataset imports
import os
import torch
import torch.utils.data
import torchvision
import torch
import torch.nn as nn

# evaluation imports
import time
from sklearn import metrics

# model
import copy
from torch.utils.data import DataLoader
from torchvision.models import resnet50, resnet18
import tez
from tez.callbacks import EarlyStopping
import tqdm
from ignite.metrics import Accuracy, Precision, Recall

In [2]:
# Root directory of the pre-segmented cell images; the training CSV lives inside it.
IMG_DIR = '../input/image_subset/cell/'
TRAIN_CSV = IMG_DIR + 'train.csv'
# Per-channel file suffixes: each cell is stored as ``<cell_id>_<channel>.png``.
CHANNELS = 'red green blue yellow'.split()

# Dataset Class

Each image has already been pre-segmented into individual cells. We split the cell-level data into n stratified folds and train one model per fold.

In [3]:
class CellDataset(object):
    '''Map-style dataset yielding HPA cell-level images with their weak labels.

    Parameters
    ----------
    images : sequence of str
        Cell ids. Each id is joined onto ``img_root`` and suffixed with
        ``_<channel>.png`` to locate that cell's channel files.
    targets : indexable
        Label vectors aligned with ``images`` (``targets[idx]`` belongs to
        ``images[idx]``).
    img_root : str
        Directory containing the channel PNGs.
    augmentations : callable, optional
        albumentations-style pipeline, called as ``augmentations(image=arr)``
        and expected to return a dict with an ``'image'`` key. Skipped when falsy.
    '''
    def __init__(self, images, targets, img_root, augmentations=None):
        self.images, self.targets = images, targets
        self.img_root, self.augmentations = img_root, augmentations

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        '''Return ``{'image': float32 CHW tensor, 'target': tensor}`` for one cell.'''
        paths = self._fetch_channels(self.images[idx])
        # Every sample gets the same spatial size so the default collate
        # function can stack a batch.
        pixels = resize(self._channels_2_array(paths), (512, 512))
        if self.augmentations:
            pixels = self.augmentations(image=pixels)['image']
        # HWC -> CHW, the layout PyTorch convolutions consume.
        chw = np.transpose(pixels, (2, 0, 1)).astype(np.float32)
        label = self.targets[idx]
        return {'image': torch.tensor(chw),
                'target': torch.tensor(label)}

    def _fetch_channels(self, img_id: str, channel_names=CHANNELS):
        'Return the paths ``<img_root>/<img_id>_<channel>.png``, one per channel name.'
        stem = os.path.join(self.img_root, img_id)
        return [''.join((stem, '_', channel, '.png')) for channel in channel_names]

    def _channels_2_array(self, img_channels):
        '''Depth-stack the first three channel files into one array.

        Only indices 0-2 are read (red, green, blue under the default CHANNELS
        order); any further paths, e.g. yellow, are ignored. Assumes each PNG is
        single-channel (2-D), giving an (H, W, 3) array -- TODO confirm.
        '''
        return np.dstack([imageio.imread(img_channels[k]) for k in range(3)])

# Model Class

In [4]:
class ResNet50(tez.Model):
    '''Multi-label cell classifier built on a torchvision ResNet-50 backbone.

    The backbone's final fully connected layer is dropped and replaced by
    dropout plus a linear layer producing ``NUM_CLASSES`` logits. A sigmoid then
    turns the logits into independent per-class probabilities.

    Args:
        train_dl: DataLoader used for training (stored as ``train_loader``).
        valid_dl: DataLoader used for validation (stored as ``valid_loader``).
        metric: ignite-style metric (e.g. ``Precision``), reset and evaluated
            per batch in ``monitor_metrics``.
        batch_size: unused; kept so existing callers keep working.
        pretrained: forwarded to torchvision's ``resnet50`` builder.
    '''
    NUM_CLASSES = 19
    IMG_DIR = '../input/image_subset/cell/'  # NOTE(review): not used inside this class
    DROPOUT_RATE = 0.1
    THRESHOLD = 0.5  # probability above which a class counts as predicted in metrics

    def __init__(self, train_dl, valid_dl, metric, batch_size=16, pretrained=True):
        super().__init__()
        # Pass by keyword: newer torchvision builders accept keyword arguments only.
        backbone = resnet50(pretrained=pretrained)
        # Size the head from the backbone itself. A hard-coded width that
        # disagrees with the backbone raises "mat1 dim 1 must match mat2 dim 0".
        in_features = backbone.fc.in_features
        # Keep every child module except the final classifier (fc).
        self.convolutions = nn.Sequential(*list(backbone.children())[:-1])
        self.dropout = nn.Dropout(self.DROPOUT_RATE)
        self.dense = nn.Linear(in_features, self.NUM_CLASSES)
        self.out = nn.Sigmoid()
        # Sigmoid fused into the loss: numerically stabler than Sigmoid + BCELoss.
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.metric = metric

        # Pre-built loaders. fit() is called with train_dataset=None, so tez is
        # expected to use these instead of building its own.
        # NOTE(review): relies on tez internals.
        self.train_loader = train_dl
        self.valid_loader = valid_dl

    def forward(self, image, target=None):
        '''Return ``(probabilities, loss, metrics)`` for a batch of images.

        ``loss`` and ``metrics`` are ``None`` when ``target`` is not given.
        '''
        # Flatten everything after the batch dimension into one feature vector per image.
        features = torch.flatten(self.convolutions(image), 1)
        logits = self.dense(self.dropout(features))
        output_probs = self.out(logits)

        if target is None:
            return output_probs, None, None
        # BCE losses need floating-point targets; dataset targets may be integer tensors.
        loss = self.loss_fn(logits, target.to(logits.dtype))
        batch_metrics = self.monitor_metrics(output_probs, target)
        return output_probs, loss, batch_metrics

    def monitor_metrics(self, outputs, targets):
        '''Evaluate ``self.metric`` on thresholded ``outputs`` for this batch only.

        Returns ``{"precision": value}``, or ``{}`` when ``targets`` is None. The
        value is 0 when the metric cannot be computed for the batch.
        '''
        from ignite.exceptions import NotComputableError

        if targets is None:
            return {}
        self.metric.reset()
        predictions = (outputs > self.THRESHOLD).float()
        try:
            self.metric.update((predictions, targets))
            # float() accepts both a 0-d tensor and a plain Python number.
            value = float(self.metric.compute())
        except (NotComputableError, ZeroDivisionError):
            # Nothing to score in this batch (e.g. empty numerator/denominator).
            value = 0
        return {"precision": value}

    def fetch_optimizer(self):
        '''Adam over all parameters with a fixed learning rate of 3e-4.'''
        opt = torch.optim.Adam(self.parameters(), lr=3e-4)
        return opt

    def fetch_scheduler(self):
        '''Cosine annealing with warm restarts every 10 scheduler steps, floor 1e-6.'''
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1
        )
        return sch

# Augmentations

### For pixel normalisation

#### RGB

In [5]:
# Per-channel (R, G, B) statistics used by the Normalize transforms.
# NOTE(review): presumably computed on images scaled to [0, 1], consistent with
# max_pixel_value=1.0 in the Normalize calls -- confirm against how they were derived.
MEAN_CHANNEL_VALUES, CHANNEL_STD_DEV = (
    (0.07843, 0.05381, 0.06853),    # mean   (R, G, B)
    (0.12131, 0.080155, 0.142555),  # stddev (R, G, B)
)

In [6]:
def _normalise():
    '''Build a fresh Normalize transform from the per-channel RGB statistics.

    max_pixel_value=1.0 means inputs are treated as already on a [0, 1] scale
    (assumed to match the float output of the dataset's resize -- confirm).
    '''
    return albumentations.augmentations.transforms.Normalize(
        mean=MEAN_CHANNEL_VALUES,
        std=CHANNEL_STD_DEV,
        max_pixel_value=1.0,
    )

# Training: normalise, then random transpose / horizontal / vertical flips (p=0.5 each).
train_aug = albumentations.Compose([
    _normalise(),
    albumentations.Transpose(p=0.5),
    albumentations.HorizontalFlip(p=0.5),
    albumentations.VerticalFlip(p=0.5),
])

# Validation: normalisation only, no random transforms.
valid_aug = albumentations.Compose([_normalise()])

# Iterative Stratified K-Fold Splits

In [7]:
def create_split_df_cell(df, nfolds=2, order=2):
    # deep copy so changes can propogate
    df_copy = copy.deepcopy(df)
    # define label rows
    labels = [str(i) for i in range(19)]
        # add OHE columns
    for i in range(19):
        # Label column contains string not np.array
        df_copy['{}'.format(i)] = df.Label.apply(lambda x: (int(x.strip('[]').replace(', ', '')[i])))
        
    df_copy = df_copy.set_index("cell_id")
    
    split_df = df_copy.iloc[:][labels]
    
    split_df = split_df.groupby(split_df.index).sum() 

    X, y = split_df.index.values, split_df.values

    k_fold = IterativeStratification(n_splits=nfolds, order=order)

    splits = list(k_fold.split(X, y))

    fold_splits = np.zeros(df.shape[0]).astype(np.int32)

    for i in range(nfolds):
        fold_splits[splits[i][1]] = i

    split_df['Split'] = fold_splits    

    df_folds = []

    for fold in range(nfolds):

        df_fold = split_df.copy()
            
        train_df = df_fold[df_fold.Split != fold].drop('Split', axis=1).reset_index()
        
        val_df = df_fold[df_fold.Split == fold].drop('Split', axis=1).reset_index()
        
        df_folds.append((train_df, val_df))

    return df_folds

In [8]:
def get_split_dataloaders(split, batch_size, train_aug=None, valid_aug=None):
    '''Wrap one ``(train_df, val_df)`` fold in CellDatasets and DataLoaders.

    Each frame must have a ``cell_id`` column and label columns ``"0".."18"``.
    Images are read from ``IMG_DIR``. The training loader shuffles; the
    validation loader keeps row order.

    Returns:
        ``(train_loader, valid_loader)``.
    '''
    label_cols = list(map(str, range(19)))
    train_frame, valid_frame = split

    def _to_dataset(frame, aug):
        # Cell ids and their aligned label matrix for one side of the split.
        return CellDataset(frame.cell_id.values,
                           np.array(frame.loc[:, label_cols]),
                           IMG_DIR,
                           augmentations=aug)

    train_set = _to_dataset(train_frame, train_aug)
    valid_set = _to_dataset(valid_frame, valid_aug)

    return (DataLoader(train_set, batch_size, shuffle=True),
            DataLoader(valid_set, batch_size))

In [9]:
print('Reading CSV and devising folds...')
# Needs at least the `cell_id` and `Label` columns consumed by create_split_df_cell.
df = pd.read_csv(TRAIN_CSV)
# List of (train_df, val_df) pairs, one per fold.
splits = create_split_df_cell(df, nfolds=2, order=2)
print('Done!')

Reading CSV and devising folds...
Done!


In [10]:
# Train and checkpoint one model per fold.
batch_size = 4
metric = Precision()
# Fall back to CPU so the cell still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

for i, split in enumerate(splits):
    metric.reset()
    print('Fold {}'.format(i))
    # Per-fold loaders, with the normalisation/augmentation stacks defined above.
    train_dl, val_dl = get_split_dataloaders(split, batch_size, train_aug, valid_aug)

    # Loaders are pre-built, so every batch-size argument is kept consistent with them.
    model = ResNet50(train_dl,
                     val_dl,
                     metric,
                     batch_size=batch_size,
                     pretrained=False)

    # Stop when validation loss has not improved for 3 epochs.
    es = EarlyStopping(
        monitor='valid_loss',
        model_path='../models/early_split_{}.bin'.format(i),
        patience=3,
        mode='min',
    )

    model.fit(
        train_dataset=None,  # loaders are supplied through the model (train_loader/valid_loader)
        valid_dataset=None,  # so tez does not build its own from datasets
        train_bs=batch_size,
        device=device,
        callbacks=[es],
        epochs=1,
    )

    # Save model (with optimizer and scheduler for future usage)
    model.save('../models/final_final_model_split_{}.bin'.format(i))

Fold 0


  0%|                                                                      | 0/61228 [00:01<?, ?it/s]


RuntimeError: mat1 dim 1 must match mat2 dim 0