## Dataset class:

In [1]:
import torch
import pandas as pd
import numpy as np
from PIL import Image
from skimage.transform import resize
import os

CHANNELS = ['red', 'green', 'blue', 'yellow']
TRAIN_CSV = '../input/train_cells/train.csv'


class CellDataset(object):
    '''Dataset class to fetch HPA cell-level images
    and corresponding weak labels
    '''
    def __init__(self, images, targets, img_root, augmentations=None):
        self.images = images
        self.targets = targets
        self.img_root = img_root
        self.augmentations = augmentations
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_id = self.images[idx] 
        img_channels = self._fetch_channels(img_id)
        features = self._channels_2_array(img_channels)
        features = resize(features, (512, 512))  # Resize cell images for collate function
        # Adjust to channel first indexing for pytorch (speed reasons)
        features = np.transpose(features, (2, 0, 1)).astype(np.float32)
        # Grab target vector
        target = self.targets[idx]
        
        return {'image': torch.tensor(features),
                'target': torch.tensor(target)
                }
    
    def _fetch_channels(self, img_id: str, channel_names=CHANNELS):
        'Return absolute path of segmentation channels of a given image id'
        base = os.path.join(self.img_root, img_id)
        return [base + '_' + i  + '.png' for i in channel_names]
                                         
    def _channels_2_array(self, img_channels):
        'Return 3D array of pixel values of input image channels'
        # Init and reshape single channel array so we can concat other channels
        channel_1 = np.array(Image.open(img_channels[0]))
        shape = channel_1.shape + (1,)  
        pixel_arr = channel_1.reshape(shape)
        # Lay out 4 channels in 3D array for model input
        for channel in img_channels[1:3]:
            channel_values = np.array(Image.open(channel)).reshape(shape)
            pixel_arr = np.concatenate([pixel_arr, channel_values], axis=2)
        return pixel_arr

## Model Class:

In [2]:
from torch import nn
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from sklearn.metrics import average_precision_score
import tez


class ResNet18(tez.Model):
    '''Model class to facilitate transfer learning 
    from a resnet-18 model
    '''
    NUM_CLASSES = 19
    IMG_DIR = '../input/train_cells'
    DROPOUT_RATE = 0.1
    
    def __init__(self, train_df, valid_df, batch_size=16, pretrained=True):
        # Initialise pretrained net and final layers for cell classification
        super().__init__()
        self.convolutions = nn.Sequential(*(list(resnet18(pretrained).children())[0:-1]))
        self.dropout = nn.Dropout(self.DROPOUT_RATE)
        self.dense = nn.Linear(512, self.NUM_CLASSES)
        self.out = nn.Sigmoid()
        self.loss_fn = nn.BCELoss()
        
        # Below should probably be in tez.Model super class but is a quick hack around
        self.train_loader = self.gen_dataloader(train_df, batch_size, shuffle=True)
        self.valid_loader = self.gen_dataloader(valid_df, batch_size, shuffle=False)
        
    def forward(self, image, target=None):
        batch_size = image.shape[0]
        
        # Extracts 512x1 feature vector from pretrained resnet18 conv layers
        x = self.convolutions(image).reshape(batch_size, -1)
        # Fully connected dense layer to 19 class output
        output = self.dense(self.dropout(x))
        # Sigmoid activations on output to infer class probabilities
        output_probs = self.out(output)
        
        if target is not None:
            loss = self.loss_fn(output_probs, target.to(torch.float32))  # why to float32???
            metrics = self.monitor_metrics(output_probs, target)
            return output_probs, loss, metrics
        return output_probs, None, None
    
    def monitor_metrics(self, outputs, targets):
        if targets is None:
            return {}
        targets = targets.cpu().detach().numpy()
        outputs = outputs.cpu().detach().numpy()
        #precision = average_precision_score(targets, outputs, average=None)
        #precision = accuracy_score(targets, outputs)
        precision = 1
        return {"precision": precision}
    
    def fetch_optimizer(self):
        opt = torch.optim.Adam(self.parameters(), lr=3e-4)
        return opt
    
    def fetch_scheduler(self):
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1
        )
        return sch
    
    def gen_dataloader(self, df, bs, shuffle):
        'Return pytorch dataloader generated from cell image dataframe'
        # Extract images and targets as numpy arrays from dataframe tranche
        def extract_as_array(str_):
            list_ = str_.strip('][').split(', ')
            return np.array([int(i) for i in list_])
        images = df['cell_id'].values
        targets = df['Label'].apply(extract_as_array).values
        # Init custom dataset class and pass to pytorch
        dataset = CellDataset(images, targets, self.IMG_DIR)
        return DataLoader(dataset, batch_size=bs, shuffle=shuffle)


## if __name__ == '__main__':

In [3]:
# Select training folds from csv
df = pd.read_csv(TRAIN_CSV, index_col=0)
df_train, df_valid = df.iloc[:30, :] , df.iloc[30:, :]

# Init model 
model = ResNet18(df_train, df_valid, batch_size=16, pretrained=True)

# Model training
model.fit(
    train_dataset=None,
    valid_dataset=None,
    train_bs=16,
    device="cuda",
    epochs=1
)

# Save model (with optimizer and scheduler for future usage)
model.save("model.bin")

100%|██████████████████████████████████████| 2/2 [00:04<00:00,  2.23s/it, loss=0.69, precision=1, stage=train]
100%|█████████████████████████████████████| 1/1 [00:00<00:00,  1.32it/s, loss=0.642, precision=1, stage=valid]
