## Constants

In [18]:
CHANNELS = ['red', 'green', 'blue', 'yellow']
TRAIN_CSV = 'D:/HPA_comp/single_cells/train_folds.csv'
IMG_DIR = 'D:/HPA_comp/single_cells'
MEAN_CHANNEL_VALUES = (0.07730, 0.05958, 0.07135)  # RGB
CHANNEL_STD_DEV = (0.12032, 0.08593, 0.14364)

## Dataset class

In [19]:
import torch
import pandas as pd
import numpy as np
import imageio
from skimage.transform import resize
import os


class CellDataset(object):
    '''Dataset class to fetch HPA cell-level images
    and corresponding weak labels
    '''
    def __init__(self, images, targets, img_root, augmentations=None):
        self.images = images
        self.targets = targets
        self.img_root = img_root
        self.augmentations = augmentations
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_id = self.images[idx] 
        img_channels = self._fetch_channels(img_id)
        img = self._channels_2_array(img_channels)
        img = resize(img, (224, 224))  # Always resize cell images for collate function
        # If augmentation pipeline provided, apply augmentations
        if self.augmentations:
            img = self.augmentations(image=img)['image']
        # Adjust to channel first indexing for pytorch (speed reasons)
        features = np.transpose(img, (2, 0, 1)).astype(np.float32)
        target = self.targets[idx]  # Grab target vector
        
        return {'image': torch.tensor(features),
                'target': torch.tensor(target)
                }
    
    def _fetch_channels(self, img_id: str, channel_names=CHANNELS):
        'Return absolute path of segmentation channels of a given image id'
        base = os.path.join(self.img_root, img_id)
        return [base + '_' + i  + '.png' for i in channel_names]
                                         
    def _channels_2_array(self, img_channels):
        'Return 3D array of pixel values of input image channels'
        r = imageio.imread(img_channels[0])
        g = imageio.imread(img_channels[1])
        b = imageio.imread(img_channels[2])
        pixel_arr = np.dstack((r, g, b))
        return pixel_arr

## Model class

In [23]:
from torch import nn
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from sklearn.metrics import average_precision_score
import tez

class ResNet18(tez.Model):
    '''Model class to facilitate transfer learning 
    from a resnet-18 model
    '''
    NUM_CLASSES = 19
    DROPOUT_RATE = 0.1
    IMG_DIR = 'D:/HPA_comp/single_cells'
    
    def __init__(self, train_df, valid_df, batch_size=16, train_aug=None, valid_aug=None, pretrained=True):
        # Initialise pretrained net and final layers for cell classification
        super().__init__()
        self.convolutions = nn.Sequential(*(list(resnet18(pretrained).children())[0:-1]))
        self.dropout = nn.Dropout(self.DROPOUT_RATE)
        self.dense = nn.Linear(512, self.NUM_CLASSES)
        self.out = nn.Sigmoid()
        self.loss_fn = nn.BCELoss()
        
        # Below should probably be in tez.Model super class but is a quick hack around
        # Training time image augmentation stack
        self.train_loader = self.gen_dataloader(train_df, batch_size, shuffle=True, aug=train_aug)
        self.valid_loader = self.gen_dataloader(valid_df, batch_size, shuffle=False, aug=valid_aug)
        
    def forward(self, image, target=None):
        batch_size = image.shape[0]
        
        # Extracts 512x1 feature vector from pretrained resnet18 conv layers
        x = self.convolutions(image).reshape(batch_size, -1)
        # Fully connected dense layer to 19 class output
        output = self.dense(self.dropout(x))
        # Sigmoid activations on output to infer class probabilities
        output_probs = self.out(output)
        
        if target is not None:
            loss = self.loss_fn(output_probs, target.to(torch.float32))  # why to float32???
            metrics = self.monitor_metrics(output_probs, target)
            return output_probs, loss, metrics
        return output_probs, None, None
    
    def monitor_metrics(self, outputs, targets):
        if targets is None:
            return {}
        targets = targets.cpu().detach().numpy()
        outputs = outputs.cpu().detach().numpy()
        precision = average_precision_score(targets, outputs, average=None)
        #accuracy = accuracy_score(targets, outputs)
        #precision = 1
        return {"precision": precision}
    
    def fetch_optimizer(self):
        opt = torch.optim.Adam(self.parameters(), lr=3e-4)
        return opt
    
    def fetch_scheduler(self):
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1
        )
        return sch
    
    def gen_dataloader(self, df, bs, shuffle, aug=None):
        'Return pytorch dataloader generated from cell image dataframe'
        # Extract images and targets as numpy arrays from dataframe tranche
        def extract_as_array(str_):
            list_ = str_.strip('][').split(', ')
            return np.array([int(i) for i in list_])
        images = df['cell_id'].values
        targets = df['Label'].apply(extract_as_array).values
        # Init custom dataset class and pass to pytorch
        dataset = CellDataset(images, targets, self.IMG_DIR, aug)
        return DataLoader(dataset, batch_size=bs, shuffle=shuffle)

## if __name__ == '__main__'

In [21]:
import albumentations as A

# Image augmentation stack 
train_aug = A.Compose([
    A.Transpose(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Normalize(
        mean=MEAN_CHANNEL_VALUES,
        std=CHANNEL_STD_DEV,
        max_pixel_value=1.0,
        p=1.0
    )
])
 
valid_aug = A.Compose([
    A.Normalize(
        mean=MEAN_CHANNEL_VALUES,
        std=CHANNEL_STD_DEV,
        max_pixel_value=1.0,
        p=1.0
    )
])

In [24]:
# Select training folds from csv
dfx = pd.read_csv(TRAIN_CSV, index_col=0)
FOLD = 0
#df_train, df_valid = df.iloc[:30, :] , df.iloc[30:, :]
df_train = dfx[dfx['fold'] != FOLD].reset_index(drop=True)
df_valid = dfx[dfx['fold'] == FOLD].reset_index(drop=True)

# Init model
model = ResNet18(
     df_train, 
     df_valid, 
     batch_size=16, 
     train_aug=train_aug, 
     valid_aug=valid_aug, 
     pretrained=True
)

# Early stopping
from tez.callbacks import EarlyStopping
es = EarlyStopping(
    monitor='valid_loss',
    model_path='../models/model_checkpoint.bin',
    patience=3,
    mode='min',
)

# Model training
model.fit(
    train_dataset=None,  # dataset inits are overriden in the model class above
    valid_dataset=None,  # otherwise tez breaks for me when it tries to do this itself
    train_bs=16,
    device='cuda',
    callbacks=[es],
    epochs=1
)

# Save model (with optimizer and scheduler for future usage)
model.save('../models/trained_model.bin')

100%|████████████████████████████████████████████████████████████████| 1391/1391 [16:20<00:00,  1.42it/s, loss=0.202, precision=1, stage=train]
100%|██████████████████████████████████████████████████████████████████| 346/346 [03:48<00:00,  1.51it/s, loss=0.188, precision=1, stage=valid]


Validation score improved (inf --> 0.1880082305025503). Saving model!


In [27]:
df = pd.read_csv('D:/HPA_comp/single_cells/train.csv', index_col=0)
df.tail(10)

Unnamed: 0,cell_id,cell_number,edge_of_img,parent_image_id,size_x,size_y,Label
27766,15029c6e-bb9c-11e8-b2b9-ac1f6b6435d0_cell_8,8.0,0.0,15029c6e-bb9c-11e8-b2b9-ac1f6b6435d0,468.0,850.0,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
27767,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0_cell_1,1.0,1.0,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0,358.0,494.0,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
27768,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0_cell_2,2.0,0.0,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0,332.0,508.0,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
27769,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0_cell_3,3.0,0.0,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0,336.0,428.0,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
27770,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0_cell_4,4.0,0.0,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0,276.0,376.0,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
27771,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0_cell_5,5.0,0.0,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0,268.0,255.0,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
27772,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0_cell_6,6.0,0.0,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0,628.0,604.0,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
27773,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0_cell_7,7.0,0.0,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0,276.0,267.0,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
27774,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0_cell_8,8.0,1.0,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0,262.0,388.0,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
27775,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0_cell_9,9.0,0.0,14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0,336.0,408.0,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."


In [29]:
df = pd.read_csv('D:/HPA_comp/train/train.csv')
df['ID'] == '14c9a968-bb9c-11e8-b2b9-ac1f6b6435d0'

Unnamed: 0,ID,Label
0,5c27f04c-bb99-11e8-b2b9-ac1f6b6435d0,8|5|0
1,5fb643ee-bb99-11e8-b2b9-ac1f6b6435d0,14|0
2,60b57878-bb99-11e8-b2b9-ac1f6b6435d0,6|1
3,5c1a898e-bb99-11e8-b2b9-ac1f6b6435d0,16|10
4,5b931256-bb99-11e8-b2b9-ac1f6b6435d0,14|0
...,...,...
21801,dd0989c4-bbca-11e8-b2bc-ac1f6b6435d0,14
21802,dd1f7fb8-bbca-11e8-b2bc-ac1f6b6435d0,3|0
21803,dd5cb36a-bbca-11e8-b2bc-ac1f6b6435d0,14|0
21804,df573730-bbca-11e8-b2bc-ac1f6b6435d0,14
