In [None]:
# Add the timm pytorch image model library from which we can extract CSPNet
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm

In [None]:
# Software library written for data manipulation and analysis.
import pandas as pd

# Python library to interact with the file system.
import os

# Python library for image augmentation
import albumentations as A

# fastai library for computer vision tasks
from fastai.vision.all import *

# Developing and training neural network based deep learning models.
import torch
from torch import nn

In [None]:
# Root directory of the dataset. The stated benefit of this sample is that it is
# more balanced than the original training data.
path = Path('../input/hpa-cell-tiles-sample-balanced-dataset')

In [None]:
# Load the metadata table `cell_df.csv` from the dataset directory and preview
# its first rows. Later cells read its `image_id`, `cell_id` and `image_labels` columns.
df = pd.read_csv(path/'cell_df.csv')
df.head()

In [None]:
# Names of the target labels: the strings "0" through "18" (19 classes in total).
labels = list(map(str, range(19)))

# Add one 0/1 indicator column per label: 1 when the label occurs in the row's
# '|'-separated `image_labels` string, otherwise 0.
for x in labels:
    df[x] = df['image_labels'].apply(lambda r: int(x in r.split('|')))

In [None]:
# Work on a reproducible 10% random sample of the rows, re-indexed from 0.
# Change frac to 1 to train on the entire dataset!
dfs = df.sample(frac=0.1, random_state=42).reset_index(drop=True)
len(dfs)

In [None]:
# Obtain the input image path for one row.
def get_x(r):
    """Return the path of the cell image for row `r`.

    The file name is `<image_id>_<cell_id>.jpg` inside the `cells` folder of
    the dataset directory `path`.
    """
    file_name = r['image_id'] + '_' + str(r['cell_id']) + '.jpg'
    return path / 'cells' / file_name

# Obtain the targets for one row.
def get_y(r):
    """Return the list of label strings for row `r`.

    The row's `image_labels` value is split on the '|' separator.
    """
    label_field = r['image_labels']
    return label_field.split('|')

In [None]:
class AlbumentationsTransform(RandTransform):
    """Apply one Albumentations pipeline to training items and another to all other items.

    `train_aug` and `valid_aug` are callables invoked as `aug(image=array)` that
    return a mapping whose 'image' entry holds the augmented array.
    """

    # split_idx is None so the transform is not restricted to a single split;
    # the split in effect is captured per call in `before_call`.
    # order is 2 so that any resize operations are done before this transform.
    split_idx, order = None, 2

    def __init__(self, train_aug, valid_aug):
        # Stores both pipelines as attributes of the same name.
        store_attr()

    def before_call(self, b, split_idx):
        """Remember the split index of the current call for use in `encodes`."""
        self.idx = split_idx

    def encodes(self, img: PILImage):
        """Augment `img` with the training pipeline when the split index is 0,
        otherwise with the validation pipeline."""
        pipeline = self.train_aug if self.idx == 0 else self.valid_aug
        augmented = pipeline(image=np.array(img))['image']
        return PILImage.create(augmented)

In [None]:
def get_train_aug(size):
    """Build the Albumentations pipeline applied to training images.

    Args:
        size: Passed as both height and width to `A.RandomResizedCrop`.

    Returns:
        An `A.Compose` that chains the transforms below in the listed order.
    """
    return A.Compose([
        # Crop a random region and resize it to size x size
        # (combines RandomCrop and RandomScale).
        A.RandomResizedCrop(size, size),

        # Transpose the input by swapping rows and columns.
        A.Transpose(p=0.5),

        # Flip the input horizontally (around the y-axis).
        A.HorizontalFlip(p=0.5),

        # Flip the input vertically (around the x-axis).
        A.VerticalFlip(p=0.5),

        # Randomly apply affine transforms: translate, scale and rotate the input.
        A.ShiftScaleRotate(p=0.5),

        # Randomly change hue, saturation and value of the input image.
        # NOTE(review): limits of 0.2 look very small if the shifts are applied
        # on an 8-bit scale — confirm the intended magnitude.
        A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),

        # Randomly change brightness and contrast of the input image.
        A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),

        # CoarseDropout of rectangular regions in the image.
        A.CoarseDropout(p=0.5),

        # Second, independent dropout pass. This used to be `A.Cutout(p=0.5)`,
        # which is deprecated and absent from newer albumentations releases;
        # `CoarseDropout` is its replacement. Presumably the default hole
        # count/size match in releases that ship both — TODO confirm for the
        # installed version.
        A.CoarseDropout(p=0.5),
    ])

def get_valid_aug(size):
    """Build the Albumentations pipeline applied to non-training images.

    Args:
        size: Used as both height and width for the crop and the resize.

    Returns:
        An `A.Compose` that centre-crops and then resizes, both always applied.
    """
    # Crop the central part of the input.
    center_crop = A.CenterCrop(size, size, p=1.)

    # Resize the input to the given height and width.
    resize = A.Resize(size, size)

    return A.Compose([center_crop, resize], p=1.)

In [None]:
# The first step, item_tfms, resizes all the images to the same size (this
# happens on the CPU); batch_tfms then happens on the GPU for the entire batch.

# Side length, in pixels, shared by the resize and both augmentation pipelines.
IMAGE_SIZE = 224

# Transforms we need to do for each image in the dataset.
item_tfms = [
    Resize(IMAGE_SIZE),
    AlbumentationsTransform(get_train_aug(IMAGE_SIZE), get_valid_aug(IMAGE_SIZE)),
]

# Transforms that can take place on a batch of images.
batch_tfms = [Normalize.from_stats(*imagenet_stats)]

# Batch size used when building the dataloaders.
bs = 6

In [None]:
# Describe how samples are assembled:
#   blocks     - image input with a multi-label target over `labels`
#   get_x      - obtain the input image path from a row
#   get_y      - obtain the list of targets from a row
#   splitter   - seeded random split into training and validation subsets
#   item_tfms  - per-image transforms
#   batch_tfms - per-batch transforms
dblock = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock(vocab=labels)),
    get_x=get_x,
    get_y=get_y,
    splitter=RandomSplitter(seed=42),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
)

# Build the dataloaders from the sampled dataframe.
dls = dblock.dataloaders(dfs, bs=bs)

In [None]:
# Sanity check: call show_batch() to see what a sample of a batch looks like
# after the transforms have been applied.
dls.show_batch()

In [None]:
class CSPNetModel(nn.Module):
    """timm model whose `head.fc` layer is replaced by a fresh linear layer.

    Args:
        num_classes: Number of outputs of the new linear layer.
        model_name: Name passed to `timm.create_model`; the created model must
            expose `head.fc`.
        pretrained: Forwarded to `timm.create_model`.
    """

    def __init__(self, num_classes=19, model_name='cspresnext50', pretrained=True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Keep the existing input width of the head; only the output size changes.
        in_features = backbone.head.fc.in_features
        backbone.head.fc = nn.Linear(in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped model's output for `x`."""
        return self.model(x)


model = CSPNetModel()


In [None]:
# Group together the dataloaders, the model, and the metric to handle training.
# No loss function is passed explicitly here.
learn = Learner(dls, model, metrics= accuracy_multi)

In [None]:
# Run the learning-rate finder to help choose a good learning rate.
learn.lr_find()

In [None]:
# Train with fine_tune, passing 9 and a learning rate of ~1.2e-3.
# Presumably the learning rate is the value suggested by lr_find above — confirm.
# NOTE(review): the Learner is created without a custom splitter, so the
# "frozen" phase of fine_tune may not actually freeze any layers — confirm.
learn.fine_tune(9,0.0012022644514217973)

In [None]:
# Plot the training and validation losses recorded during training.
learn.recorder.plot_loss()

In [None]:
# Build an interpretation object for the trained learner.
interp = ClassificationInterpretation.from_learner(learn)

# Show the 5 top-loss images, laid out over 5 rows, along with their
# prediction, actual, loss, and probability of actual class.
interp.plot_top_losses(5, nrows=5)