## Constants

In [1]:
# --- Data locations (local absolute paths; adjust per machine) ---
IMG_DIR = 'D:/HPA_comp/single_cells'
TRAIN_CSV = 'D:/HPA_comp/single_cells/train_folds.csv'

# --- Image channels ---
# Suffixes of the per-cell PNG files, in the order they are looked up.
CHANNELS = ['red', 'green', 'blue', 'yellow']

# --- Normalisation statistics, one entry per network input channel (RGB order) ---
MEAN_CHANNEL_VALUES = (0.07730, 0.05958, 0.07135)
CHANNEL_STD_DEV = (0.12032, 0.08593, 0.14364)

In [2]:
from statistics import mean

import numpy as np
from sklearn.metrics import auc, average_precision_score


def calc_prec_rec(y_pred_, y_target, thresh):
    '''Return (precision, recall) for a single label at one decision threshold.

    Parameters
    ----------
    y_pred_ : 1D array-like
        Predicted scores for one label. It is not modified.
    y_target : 1D array-like
        Ground-truth labels for the same samples, assumed binary (0/1).
    thresh : float
        Scores >= thresh count as positive predictions.

    Returns
    -------
    tuple
        (precision, recall). Precision is None when nothing is predicted
        positive (tp + fp == 0). Recall is None when the label has no
        positives (tp + fn == 0). An empty input therefore gives (None, None).
    '''
    # Binarise with a single comparison. The previous two-pass in-place
    # rewrite (>= thresh -> 1, then < thresh -> 0) turned every positive
    # prediction back into 0 whenever thresh > 1.
    pred_pos = np.asarray(y_pred_) >= thresh
    true_pos = np.asarray(y_target) > 0

    # count_nonzero always returns an integer, even for empty arrays.
    tp = int(np.count_nonzero(pred_pos & true_pos))
    fp = int(np.count_nonzero(pred_pos & ~true_pos))
    fn = int(np.count_nonzero(~pred_pos & true_pos))

    precision = None if (tp + fp) == 0 else tp / (tp + fp)
    recall = None if (tp + fn) == 0 else tp / (tp + fn)
    return precision, recall

def mAUC(y_preds, y_targets, n_iters=10):
    '''Mean area under a coarse precision-recall curve, averaged over labels.

    For every label column, precision and recall are evaluated at `n_iters`
    evenly spaced thresholds from 0 to 0.9. Threshold points where either
    value is undefined are dropped. The area is then computed with sklearn's
    `auc`, with recall on the x-axis. Labels with fewer than two usable
    points are skipped.

    Returns np.nan when no label contributes an area.
    '''
    per_label_area = []
    for col in range(y_targets.shape[1]):
        scores = y_preds[:, col]
        truth = y_targets[:, col]

        candidates = (calc_prec_rec(scores, truth, t)
                      for t in np.linspace(0, 0.9, n_iters))
        points = [[p, r] for p, r in candidates
                  if p is not None and r is not None]

        # At least two points are needed to span any area.
        if len(points) < 2:
            continue

        curve = np.array(points)
        per_label_area.append(auc(curve[:, 1], curve[:, 0]))

    return mean(per_label_area) if per_label_area else np.nan

def skl_mAP(y_preds, y_targets):
    '''Mean of sklearn's average_precision_score over labels that have positives.

    Label columns whose targets sum to zero are left out, which avoids
    sklearn's division-by-zero runtime warning. Returns np.nan when every
    label is left out.
    '''
    label_scores = [
        average_precision_score(y_targets[:, col], y_preds[:, col])
        for col in range(y_targets.shape[1])
        if y_targets[:, col].sum() != 0
    ]
    return mean(label_scores) if label_scores else np.nan

## Dataset class

In [3]:
import torch
import pandas as pd
import numpy as np
import imageio
from skimage.transform import resize
import os


class CellDataset(object):
    '''Map-style dataset yielding HPA single-cell images and their weak labels.

    Each item is assembled from per-channel PNG files. It is resized to a
    fixed size, optionally augmented, and returned channel-first.

    Parameters
    ----------
    images : sequence of str
        Cell image ids. Files are looked up as ``<img_root>/<id>_<channel>.png``.
    targets : sequence of array-like
        Label vector for each entry of ``images``, in the same order.
    img_root : str
        Directory holding the per-channel PNG files.
    augmentations : callable, optional
        Albumentations-style transform. It is called as
        ``augmentations(image=img)`` and must return a dict with an
        ``'image'`` entry.
    img_size : tuple of int, optional
        ``(height, width)`` that every image is resized to. Defaults to
        ``(224, 224)``. A fixed size is needed so that the default DataLoader
        collate function can stack images into a batch.
    '''
    def __init__(self, images, targets, img_root, augmentations=None, img_size=(224, 224)):
        self.images = images
        self.targets = targets
        self.img_root = img_root
        self.augmentations = augmentations
        self.img_size = img_size

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        '''Return ``{'image': channel-first float32 tensor, 'target': tensor}`` for item ``idx``.'''
        img_id = self.images[idx]
        img_channels = self._fetch_channels(img_id)
        img = self._channels_2_array(img_channels)
        # Always resize: the default collate function needs equal-sized images.
        img = resize(img, self.img_size)
        # If an augmentation pipeline was provided, apply it (HWC layout).
        if self.augmentations:
            img = self.augmentations(image=img)['image']
        # HWC -> CHW: PyTorch convolution layers expect channel-first input.
        features = np.transpose(img, (2, 0, 1)).astype(np.float32)
        target = self.targets[idx]

        return {'image': torch.tensor(features),
                'target': torch.tensor(target)
                }

    def _fetch_channels(self, img_id: str, channel_names=CHANNELS):
        '''Return the PNG path of every channel in ``channel_names`` for ``img_id``.

        Paths have the form ``<img_root>/<img_id>_<channel>.png``. Only the
        first three of them are read by ``_channels_2_array``.
        '''
        base = os.path.join(self.img_root, img_id)
        return [f'{base}_{channel}.png' for channel in channel_names]

    def _channels_2_array(self, img_channels):
        '''Read the first three channel files and stack them depth-wise.

        With single-channel (grayscale) PNGs this gives an ``H x W x 3``
        array. Any further paths, e.g. the yellow channel, are ignored.
        '''
        r = imageio.imread(img_channels[0])
        g = imageio.imread(img_channels[1])
        b = imageio.imread(img_channels[2])
        return np.dstack((r, g, b))

## Model class

In [4]:
from torch import nn
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import tez

#from custom_metrics import mAUC, skl_mAP
#from custom_loss import FocalLoss


class ResNet18(tez.Model):
    '''Multi-label cell classifier on top of a torchvision ResNet-18 backbone.

    The backbone without its final fc layer is used as a 512-feature
    extractor. It is followed by dropout, a linear layer to NUM_CLASSES
    outputs and a sigmoid, and the model is trained with BCELoss. The
    train/valid DataLoaders are built in ``__init__`` rather than by tez.
    '''
    NUM_CLASSES = 19
    DROPOUT_RATE = 0.1
    IMG_DIR = 'D:/HPA_comp/single_cells'

    def __init__(self, train_df, valid_df, batch_size, train_aug=None, valid_aug=None,
                 pretrained=True, img_dir=None):
        '''
        Parameters
        ----------
        train_df, valid_df : pandas.DataFrame
            Frames with a 'cell_id' column and a 'Label' column. 'Label' holds
            the target vector as a string such as "[0, 1, 0]".
        batch_size : int
            Batch size for both DataLoaders.
        train_aug, valid_aug : callable, optional
            Albumentations-style pipelines applied to each image.
        pretrained : bool
            Forwarded to torchvision's ``resnet18`` to load pretrained weights.
        img_dir : str, optional
            Image directory. Defaults to the class attribute ``IMG_DIR``.
        '''
        super().__init__()
        if img_dir is not None:
            self.IMG_DIR = img_dir

        # Pass as a keyword: newer torchvision only accepts a positional flag
        # through a deprecated compatibility shim.
        backbone = resnet18(pretrained=pretrained)
        # Everything except the final fc layer; forward() flattens the result to 512 features.
        self.convolutions = nn.Sequential(*list(backbone.children())[:-1])
        self.dropout = nn.Dropout(self.DROPOUT_RATE)
        self.dense = nn.Linear(512, self.NUM_CLASSES)
        self.out = nn.Sigmoid()
        # BCELoss consumes probabilities, hence the explicit sigmoid above.
        # Alternatives tried: nn.KLDivLoss(), FocalLoss() from custom_loss.
        self.loss_fn = nn.BCELoss()

        # Workaround: DataLoaders are built here instead of by tez, so fit() is
        # called with train_dataset=None / valid_dataset=None.
        self.train_loader = self.gen_dataloader(train_df, batch_size, shuffle=True, aug=train_aug)
        self.valid_loader = self.gen_dataloader(valid_df, batch_size, shuffle=False, aug=valid_aug)

    def forward(self, image, target=None):
        '''Return (probabilities, loss, metrics). Loss and metrics are None without a target.'''
        batch_size = image.shape[0]

        # Backbone features flattened to (batch_size, 512)
        features = self.convolutions(image).reshape(batch_size, -1)
        # Dropout + dense layer to NUM_CLASSES logits, then per-class sigmoid probabilities
        logits = self.dense(self.dropout(features))
        output_probs = self.out(logits)

        if target is None:
            return output_probs, None, None
        # Targets arrive as integer tensors (labels are parsed as ints), but
        # BCELoss needs a floating-point target of the same dtype as its input.
        loss = self.loss_fn(output_probs, target.to(output_probs.dtype))
        metrics = self.monitor_metrics(output_probs, target)
        return output_probs, loss, metrics

    def monitor_metrics(self, outputs, targets):
        '''Return batch-level {'mAUC', 'mAP'} computed on CPU numpy copies.'''
        if targets is None:
            return {}
        targets = targets.cpu().detach().numpy()
        outputs = outputs.cpu().detach().numpy()
        return {'mAUC': mAUC(outputs, targets, n_iters=10),
                'mAP': skl_mAP(outputs, targets)}

    def fetch_optimizer(self):
        'Adam over all parameters, lr=3e-4.'
        opt = torch.optim.Adam(self.parameters(), lr=3e-4)
        return opt

    def fetch_scheduler(self):
        'Cosine annealing with warm restarts every T_0=10 scheduler steps, down to eta_min=1e-6.'
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1
        )
        return sch

    def gen_dataloader(self, df, bs, shuffle, aug=None):
        'Return a DataLoader over the cells in ``df`` (columns "cell_id" and "Label").'
        def extract_as_array(str_):
            # Parse a serialised label vector such as "[0, 1, 0]". Splitting on ","
            # and letting int() strip whitespace also accepts "[0,1,0]" and
            # surrounding spaces, which the old exact ", " split rejected.
            items = str_.strip('[] ').split(',')
            return np.array([int(i) for i in items], dtype=int)

        images = df['cell_id'].values
        targets = df['Label'].apply(extract_as_array).values
        dataset = CellDataset(images, targets, self.IMG_DIR, aug)
        return DataLoader(dataset, batch_size=bs, shuffle=shuffle)

## Training run (script entry point, i.e. `if __name__ == '__main__'`)

In [5]:
import albumentations as A

def _normalize_to_dataset_stats():
    '''Build a per-channel standardisation step using the dataset statistics.

    max_pixel_value=1.0 means mean/std are applied on the image's own scale
    (presumably floats in [0, 1] after resizing).
    '''
    return A.Normalize(
        mean=MEAN_CHANNEL_VALUES,
        std=CHANNEL_STD_DEV,
        max_pixel_value=1.0,
        p=1.0,
    )


# Training: random transposes and flips, then normalisation.
train_aug = A.Compose([
    A.Transpose(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    _normalize_to_dataset_stats(),
])

# Validation: normalisation only, so evaluation inputs are deterministic.
valid_aug = A.Compose([_normalize_to_dataset_stats()])

In [6]:
import random

from tez.callbacks import EarlyStopping

# --- Run configuration ---
FOLD = 0
BATCH_SIZE = 64      # used for the model's DataLoaders and passed to tez
N_SAMPLES = 1000     # quick check of the training loop; set to None to use every row
SEED = 42
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_DIR = '../models'

# Seed the global RNGs so DataLoader shuffling and dropout are repeatable
# (the Python/NumPy seeds also cover libraries that draw from those RNGs).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Select training/validation folds from the csv
dfx = pd.read_csv(TRAIN_CSV, index_col=0)
if N_SAMPLES is not None:
    dfx = dfx.iloc[:N_SAMPLES, :]

df_train = dfx[dfx['fold'] != FOLD].reset_index(drop=True)
df_valid = dfx[dfx['fold'] == FOLD].reset_index(drop=True)

# Init model (it builds its own train/valid DataLoaders)
model = ResNet18(
    df_train,
    df_valid,
    batch_size=BATCH_SIZE,
    train_aug=train_aug,
    valid_aug=valid_aug,
    pretrained=True,
)

# Create the checkpoint directory up front, in case the save calls below
# do not create missing directories themselves.
os.makedirs(MODEL_DIR, exist_ok=True)

# Early stopping on validation loss
es = EarlyStopping(
    monitor='valid_loss',
    model_path=f'{MODEL_DIR}/model_checkpoint.bin',
    patience=3,
    mode='min',
)

# Model training. The datasets are None because ResNet18.__init__ already
# built the DataLoaders; letting tez build them itself broke in this setup.
model.fit(
    train_dataset=None,
    valid_dataset=None,
    train_bs=BATCH_SIZE,
    device=DEVICE,
    callbacks=[es],
    epochs=1,
)

# Save model (with optimizer and scheduler for future usage)
model.save(f'{MODEL_DIR}/trained_model.bin')

  0%|                                                                                                                   | 0/13 [00:00<?, ?it/s]

(64, 19)
(64, 19)


  8%|████▌                                                       | 1/13 [00:06<00:49,  4.13s/it, loss=0.72, mAP=0.16, mAUC=0.0984, stage=train]

(64, 19)
(64, 19)


 23%|█████████████▊                                              | 3/13 [00:09<00:30,  3.01s/it, loss=0.69, mAP=0.166, mAUC=0.102, stage=train]

(64, 19)
(64, 19)


 31%|██████████████████▍                                         | 4/13 [00:12<00:27,  3.01s/it, loss=0.663, mAP=0.181, mAUC=0.11, stage=train]

(64, 19)
(64, 19)


 38%|██████████████████████▋                                    | 5/13 [00:15<00:23,  2.92s/it, loss=0.639, mAP=0.177, mAUC=0.108, stage=train]

(64, 19)
(64, 19)


 38%|██████████████████████▋                                    | 5/13 [00:17<00:23,  2.92s/it, loss=0.614, mAP=0.191, mAUC=0.112, stage=train]

(64, 19)
(64, 19)


 54%|███████████████████████████████▊                           | 7/13 [00:20<00:16,  2.83s/it, loss=0.592, mAP=0.197, mAUC=0.115, stage=train]

(64, 19)
(64, 19)


 54%|███████████████████████████████▊                           | 7/13 [00:21<00:18,  3.14s/it, loss=0.592, mAP=0.197, mAUC=0.115, stage=train]


KeyboardInterrupt: 