Install requirements

In [2]:
! pip install -Uqq fastcore
# ! pip install -Uqq kaggle --force
! pip freeze | grep fastcore

In [3]:
from IPython.display import display

In [4]:
import os
from pathlib import Path
import random
from collections import OrderedDict, Counter

In [5]:
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm 

In [6]:
import plotly
import plotly.graph_objs as go
from plotly.offline import iplot, init_notebook_mode
import matplotlib.pyplot as plt
import seaborn as sns

In [7]:
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from fastcore.basics import *
from fastcore.foundation import *
from fastcore.xtras import *
from fastcore.basics import noop

In [8]:
# Seed both global RNGs the notebook draws from before torch is imported:
# NumPy's, and the stdlib `random` module used by `random.choices` and by the
# augmentation code further down.
np.random.seed(1)
random.seed(1)

# Explore data

In [9]:
wdir = Path('../input/aptos2019-blindness-detection')

In [10]:
train_df = pd.read_csv(wdir/'train.csv')

In [11]:
train_df.head()

In [12]:
sns.countplot(data=train_df, x="diagnosis")

In [13]:
train_pth = Path(wdir/'train_images')
test_pth = Path(wdir/'test_images')

In [14]:
train_fns = train_pth.ls()
test_fns = test_pth.ls()
train_fns, test_fns

# Show a set of random images for each class

In [15]:
train_df.head()

Choose 3 random images based on `id_code`

In [16]:
cls=0
random.choices(train_df[train_df.diagnosis==cls].id_code.to_list(), k=3)

In [17]:
def open_img(img_id, pth:Path, ext='png', cspace='RGB', sz=None, pil=False):
  """Load `<pth>/<img_id>.<ext>` and convert it to the `cspace` colour mode.

  When `sz` is given the image is resized to `sz` x `sz`. Returns the PIL
  image if `pil` is true, otherwise the result of `np.asarray` on it.
  """
  picture = Image.open(pth/f'{img_id}.{ext}').convert(cspace)
  if sz is not None:
    picture = picture.resize((sz, sz))
  if pil:
    return picture
  return np.asarray(picture)

In [18]:
img = open_img('000c1434d8d7', train_pth, sz=224)
print(img.shape)
plt.imshow(img)

In [19]:
img = open_img('000c1434d8d7', train_pth, sz=224, pil=True)
img

In [20]:
def annotate_axes(ax, text, fontsize=10, x=0., y=1.1):
  """Write `text` in gray at axes-relative position (`x`, `y`) of `ax`."""
  text_style = dict(transform=ax.transAxes, va="center",
                    fontsize=fontsize, color="gray")
  ax.text(x, y, text, **text_style)

In [21]:
vocab = OrderedDict(NoDR=0, Mild=1, Moderate=2, Severe=3, ProliferativeDR=4)
vocab

In [22]:
vocab.keys()

In [23]:
def show_imgs(df, vocab=None, n_cols=5, n_rows=5, figsize=(10,8), sz=224):
  """Show one row of `n_cols` randomly drawn images per class.

  Row `r` of the grid shows images whose `diagnosis` equals `r`, so `n_rows`
  is the number of classes displayed. Images are drawn with replacement from
  `df.id_code` and read from `train_pth`. Each image is annotated with the
  `r`-th key of `vocab`, or with `class: r` when `vocab` is None.
  """
  # squeeze=False keeps `axs` 2-D even when n_rows or n_cols is 1
  fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
  names = None if vocab is None else list(vocab.keys())
  for cls, row_axes in enumerate(axs):
    data_ls = df[df.diagnosis==cls].id_code.to_list()
    fns = random.choices(data_ls, k=n_cols)
    # the annotation is the same for every image in the row
    text = f'class: {cls}' if names is None else names[cls]
    for ax, fn in zip(row_axes, fns):
      ax.imshow(open_img(fn, train_pth, sz=sz))
      ax.axis('off')
      annotate_axes(ax, text)
  # lay the figure out once, after every row has been drawn
  plt.tight_layout()
  plt.show()

In [24]:
show_imgs(train_df, vocab)

# Check images per class

In [25]:
counts = train_df.diagnosis.value_counts()

In [26]:
counts

In [27]:
ax = plt.pie(counts, autopct="%.1f%%")

In [28]:
vocab

In [29]:
"""
# Manually undersample
block_list = DefaultDict(lambda: False)
counter = DefaultDict(lambda: 0)
new_df = {'id_code': [],
          'diagnosis': [],}

for idx, row in train_df.iterrows():
  if row.diagnosis not in block_list:
    counter[row.diagnosis]+=1
    new_df['id_code'].extend([row.id_code])
    new_df['diagnosis'].extend([row.diagnosis])
  if counter[row.diagnosis] >= 250:
    block_list[row.diagnosis] = True
"""

Assign weights to each sample.

In [30]:
weights = counts/counts.sum()
weights

In [31]:
-np.log(weights)

In [32]:
# Per-sample weight = negative log of the sample's class frequency.
train_df["weights"] = train_df["diagnosis"].map(-np.log(weights))

# Make train and valid sets

In [33]:
train_df.head()

In [34]:
tmp = train_df.sample(frac=1, replace=False) # shuffle

In [35]:
tmp

Split the shuffled data into train and valid sets, then compare their class distributions.

In [36]:
data = tmp.to_records(index=False)

In [37]:
len(data)

In [38]:
# Positional split of the shuffled records: the first 3000 rows are used for
# training and the remainder for validation (not stratified by class).
train_data = data[:3000]
valid_data = data[3000:]

In [39]:
plt.subplot(1,2,1)
sns.histplot(train_data.diagnosis,bins=5)
plt.subplot(1,2,2)
sns.histplot(valid_data.diagnosis, bins=5)

In [40]:
len(train_data), len(valid_data)

# Make Dataset

In [41]:
import torchvision.transforms as T
import torch
from torchvision import models

In [42]:
# imagenet_stats
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

In [43]:
resize_sz = 384

In [44]:
import torchvision.transforms.functional as TF
import random

In [45]:
class ExtraTransforms:
    """Additional transforms based on kaggle discussion.
    https://www.kaggle.com/c/aptos2019-blindness-detection/discussion/108065#latest-624210

    The constructor only stores the ranges. A fresh set of factors is drawn on
    every call, so each image gets its own augmentation instead of one set of
    values being fixed at construction time and reused for the whole dataset.

    Args:
        contrast_range, brightness_range, saturation_range, sharpness_factor:
            half-widths; the applied factor is uniform in [1 - r, 1 + r].
        hue_range: the applied hue shift is uniform in [-hue_range, hue_range].
        blur_sigma, kernel_size: passed unchanged to `TF.gaussian_blur`.
        rot_angles: candidate rotation angles; one is chosen per call.
        do_mirror: when true, the image is horizontally flipped with
            probability 0.5; when false it is never flipped.
    """
    def __init__(self, contrast_range=0.1, brightness_range=0.1, hue_range=0.1, 
                 saturation_range=0.1, sharpness_factor=0.1, blur_sigma=0.1, kernel_size=1,
                 rot_angles=np.arange(0, 100, 20), do_mirror=True):
        self.contrast_range = contrast_range
        self.brightness_range = brightness_range
        self.hue_range = hue_range
        self.saturation_range = saturation_range
        self.sharpness_factor = sharpness_factor
        self.blur_sigma = blur_sigma
        self.kernel_size = kernel_size
        self.rot_angles = rot_angles
        self.do_mirror = do_mirror

    def _sample_params(self):
        """Draw one set of augmentation factors from the configured ranges."""
        return dict(
            contrast=random.uniform(1 - self.contrast_range, 1 + self.contrast_range),
            brightness=random.uniform(1 - self.brightness_range, 1 + self.brightness_range),
            hue=random.uniform(-self.hue_range, self.hue_range),
            saturation=random.uniform(1 - self.saturation_range, 1 + self.saturation_range),
            sharpness=random.uniform(1 - self.sharpness_factor, 1 + self.sharpness_factor),
            angle=float(random.choice(self.rot_angles)),
            mirror=bool(self.do_mirror) and random.random() < 0.5,
        )

    def __call__(self, x):
        params = self._sample_params()
        x = TF.adjust_contrast(x, params['contrast'])
        x = TF.adjust_brightness(x, params['brightness'])
        x = TF.adjust_hue(x, params['hue'])
        x = TF.adjust_saturation(x, params['saturation'])
        x = TF.adjust_sharpness(x, params['sharpness'])
        x = TF.gaussian_blur(x, self.kernel_size, self.blur_sigma)
        x = TF.rotate(x, params['angle'])
        if params['mirror']:
            x = TF.hflip(x)
        return x

In [46]:
def train_augs():
  """Build the training pipeline: extra colour/geometry transforms, tensor
  conversion, resize to `resize_sz`, then normalisation with `mean`/`std`."""
  steps = [
      ExtraTransforms(),
      T.ToTensor(),
      T.Resize(resize_sz),
      T.Normalize(mean=mean, std=std),
  ]
  return T.Compose(steps)

def valid_augs():
  """Build the validation pipeline: tensor conversion, resize to `resize_sz`,
  then normalisation with `mean`/`std` (no random transforms)."""
  steps = [
      T.ToTensor(),
      T.Resize(resize_sz),
      T.Normalize(mean=mean, std=std),
  ]
  return T.Compose(steps)

In [47]:
class APTOSDataset(torch.utils.data.Dataset):
  """Dataset over records whose first field is the image id and whose second
  field is the label.

  `augs` is a zero-argument factory returning the transform to apply; with
  `augs=None` the PIL image is returned as-is. With `test=True` the label is
  not read and `None` is returned in its place.
  """
  def __init__(self, data, pth, augs=None, sz=300, test=False):
    super().__init__()
    self.data, self.pth = data, pth
    self.augs = False if augs is None else augs()
    self.sz, self.test = sz, test

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    record = self.data[idx]
    img = open_img(record[0], self.pth, sz=self.sz, pil=True)
    if self.test:
      label = None
    else:
      label = torch.tensor([record[1]], dtype=torch.long)
    if self.augs:
      img = self.augs(img)
    return img, label

In [48]:
APTOSDataset(train_data, pth=train_pth)

In [49]:
train_ds = APTOSDataset(train_data, pth=train_pth, augs=train_augs, sz=resize_sz)

In [50]:
train_ds[1]

In [51]:
plt.imshow(train_ds[3][0].permute(1,2,0))

In [52]:
valid_ds = APTOSDataset(valid_data, pth=train_pth, augs=valid_augs, sz=resize_sz)

In [53]:
valid_ds[100]

In [54]:
plt.imshow(valid_ds[10][0].permute(1,2,0))

# Make DataLoader

In [55]:
bs=64

In [56]:
len(train_ds)/bs, len(valid_ds)/bs

In [57]:
drop_last=True

In [58]:
# Draw training samples *with* replacement so the per-sample weights actually
# change how often each class is seen. Without replacement and with
# num_samples == len(train_ds), every sample is drawn exactly once per epoch
# and the weights only influence the order.
train_sampler = torch.utils.data.WeightedRandomSampler(weights=train_data.weights, num_samples=len(train_ds), replacement=True)
# Validation visits every sample once, unweighted and in a fixed order.
valid_sampler = torch.utils.data.SequentialSampler(valid_ds)

In [59]:
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=bs, drop_last=drop_last,sampler=train_sampler)

In [60]:
# Keep the last partial batch: validation metrics should cover every sample.
valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=bs, drop_last=False, sampler=valid_sampler)

In [61]:
print(f'train dblock len: {len(train_dl)}')
print(f'valid dblock len: {len(valid_dl)}')

Make a sample batch.

In [62]:
# Pull only the first batch from the training loader; `img` and `label` keep
# that batch after the loop exits.
for i, (img, label) in enumerate(train_dl):
    break

In [63]:
img.shape, label.shape

Check if the samples were correctly chosen.

In [64]:
counts

In [65]:
Counter(label.squeeze().tolist())

# Build Residual block from scratch
![](https://d2l.ai/_images/resnet-block.svg)

L: Identity/Regular Block

R: Convolutional Block

In [66]:
from torch import nn
import torch.nn.functional as F

Random tensor to test the implementation.

In [67]:
x = torch.rand(4, 3, 6, 6)
x.shape

## Making the common block

`ConvBnAct`: Apply `Conv2d`, `BatchNorm2d` and `ReLU` in sequence.

In [68]:
class ConvBnAct(nn.Module):
  """Bias-free `Conv2d`, then optional `BatchNorm2d`, then optional `ReLU`.

  `bn=False` / `act=False` replace the corresponding stage with `noop`.
  `k`, `s`, `p`, `d` are the conv kernel size, stride, padding and dilation.
  """
  def __init__(self, in_ch=3, out_ch=64, k=3, s=1, p=0, d=1, bn=True, act=True):
    super().__init__()
    store_attr('in_ch, out_ch, k, s, p, d, bn, act', self)
    self.conv = nn.Conv2d(self.in_ch, self.out_ch, kernel_size=self.k, stride=self.s,
                          padding=self.p, dilation=self.d, bias=False)
    # the boolean flags stored above are replaced by the actual stages
    if self.bn:
      self.bn = nn.BatchNorm2d(self.out_ch)
    else:
      self.bn = noop
    if self.act:
      self.act_fn = nn.ReLU()
    else:
      self.act_fn = noop

  def forward(self, x):
    out = self.conv(x)
    out = self.bn(out)
    return self.act_fn(out)

In [69]:
params = dict(in_ch=3, out_ch=64, k=7, s=2, p=3)

In [70]:
ConvBnAct(**params)(x).shape

## Making the RegularBlock and ConvBlock 

In [71]:
class RegularBlock(nn.Module):
  """Regular block (no 1x1conv). 
  Adds the input to the output before applying the final act_fn.

  Requires `in_ch == out_ch` so the identity shortcut matches the output of
  the two 3x3 convolutions (stride 1, padding 1).
  """
  def __init__(self, in_ch, out_ch):
    super(RegularBlock, self).__init__()
    assert in_ch==out_ch, "Regular block should have in_ch==out_ch"
    self.conv1 = ConvBnAct(in_ch=in_ch, out_ch=out_ch, k=3, s=1, p=1)
    self.conv2 = ConvBnAct(in_ch=out_ch, out_ch=out_ch, k=3, s=1, p=1, act=False)
  def forward(self, x):
    # The shortcut reuses `x` directly: the conv stages return new tensors and
    # do not modify their input, so copying it first is unnecessary.
    out = self.conv2(self.conv1(x))
    return F.relu(out + x)

In [72]:
RegularBlock(3, 3)(x).shape

In [73]:
class ConvBlock(nn.Module):
  """ConvBlock (with 1x1conv). 
  Adds the 1x1 conv input to the output before applying the final act_fn.

  `s` is the stride of both the first 3x3 conv and the 1x1 shortcut conv, so
  the two branches produce tensors of the same shape.
  """
  def __init__(self, in_ch, out_ch, s=1):
    super(ConvBlock, self).__init__()
    self.conv1 = ConvBnAct(in_ch=in_ch, out_ch=out_ch, k=3, s=s, p=1)
    self.conv2 = ConvBnAct(in_ch=out_ch, out_ch=out_ch, k=3, p=1, act=False)
    self.dsample = ConvBnAct(in_ch=in_ch, out_ch=out_ch, k=1, s=s, act=False)
  def forward(self, x):
    # Both branches read `x` without modifying it, so no copy is needed.
    main = self.conv2(self.conv1(x))
    shortcut = self.dsample(x)
    return F.relu(main + shortcut)

In [74]:
ConvBlock(3, 6, s=2)(x).shape

## First Layer of ResNet18

In [75]:
x = torch.randn(5, 3, 32, 32)

In [76]:
params = dict(in_ch=3, out_ch=64, k=7, s=2, p=3)

In [77]:
ConvBnAct(**params)(x).shape

In [78]:
l1 = nn.Sequential(ConvBnAct(**params), nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

In [79]:
acts = l1(x)
acts.shape

## Residual Blocks 
Each residual stage is a combination of `ConvBlock` and `RegularBlock`. The first stage has no `ConvBlock`; every following stage is a `ConvBlock` followed by `RegularBlock`s.

![](https://d2l.ai/_images/resnet18.svg)

In [80]:
def resnet_block(in_ch, out_ch, n_blocks, first_block=False):
  """Return `n_blocks` `(name, module)` pairs forming one residual stage.

  Block 0 is a stride-2 `ConvBlock` named `conv_blk0` unless `first_block` is
  true; every other block is a `RegularBlock(out_ch, out_ch)` named
  `reg_blk<i>`.
  """
  named_layers = []
  for i in range(n_blocks):
    downsample = (i == 0) and not first_block
    if downsample:
      named_layers.append((f'conv_blk{i}', ConvBlock(in_ch, out_ch, s=2)))
    else:
      named_layers.append((f'reg_blk{i}', RegularBlock(out_ch, out_ch)))
  return named_layers

In [81]:
resnet_block(64, 64, 2, first_block=False)

In [82]:
l2 = nn.Sequential(OrderedDict(resnet_block(64, 64, 2, first_block=True)))
l3 = nn.Sequential(OrderedDict(resnet_block(64, 128, 2)))
l4 = nn.Sequential(OrderedDict(resnet_block(128, 256, 2)))
l5 = nn.Sequential(OrderedDict(resnet_block(256, 512, 2)))

In [83]:
l5

In [84]:
l2(acts).shape

In [85]:
acts = l5(l4(l3(l2(acts))))

## Final Linear layer

In [86]:
vocab

In [87]:
len(vocab)

In [88]:
fc = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)), 
                   nn.Flatten(), 
                   nn.Linear(512, len(vocab)))

In [89]:
fc(acts).shape

# Putting the blocks together

In [90]:
from torch.nn.parameter import Parameter

# https://github.com/filipradenovic/cnnimageretrieval-pytorch.
def gem(x, p=3, eps=1e-6):
    """Generalized-mean pooling over the last two dimensions of `x`.

    Values are clamped to at least `eps`, raised to the power `p`, averaged
    over the full spatial window, and the result is raised to `1/p`.
    """
    window = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    return F.avg_pool2d(powered, window).pow(1./p)

class GeM(nn.Module):
    """Module wrapper around `gem` whose exponent `p` is a learnable parameter
    initialised to the given value; `eps` is the clamp floor passed to `gem`."""
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        initial_p = torch.ones(1) * p
        self.p = Parameter(initial_p)
        self.eps = eps

    def forward(self, x):
        pooled = gem(x, p=self.p, eps=self.eps)
        return pooled

In [91]:
GeM()

In [92]:
class ResNet18(nn.Module):
  """ResNet18 assembled from a `ConvBnAct` stem, four `resnet_block` stages,
  GeM pooling and a linear head with `n_cls` outputs.

  The shape comments below assume an input of [B, 3, 32, 32].
  """
  def __init__(self, n_cls=2):
    super(ResNet18, self).__init__()
    params = dict(in_ch=3, out_ch=64, k=7, s=2, p=3)                                # stem conv: 7x7, stride 2, padding 3
                                                                                    # x = [B, 3, 32, 32]
    self.l1 = nn.Sequential(ConvBnAct(**params), 
                            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))       # -> [B, 64, 8, 8]
    self.l2 = nn.Sequential(OrderedDict(resnet_block(64, 64, 2, first_block=True))) # -> [B, 64, 8, 8]
    self.l3 = nn.Sequential(OrderedDict(resnet_block(64, 128, 2)))                  # -> [B, 128, 4, 4]
    self.l4 = nn.Sequential(OrderedDict(resnet_block(128, 256, 2)))                 # -> [B, 256, 2, 2]
    self.l5 = nn.Sequential(OrderedDict(resnet_block(256, 512, 2)))                 # -> [B, 512, 1, 1]
    self.pool = GeM() # used in place of nn.AdaptiveAvgPool2d((1,1))
    self.flat = nn.Flatten()
    self.fc = nn.Linear(512, n_cls)                                                 # -> [B, n_cls]
  def forward(self, x):
    # stem -> four residual stages -> pool -> flatten -> linear head
    x = self.l5(self.l4(self.l3(self.l2(self.l1(x)))))
    return self.fc(self.flat(self.pool(x)))

In [93]:
model = ResNet18(len(vocab))

In [94]:
logits = model(x)
# logits

In [95]:
torch.softmax(logits, -1).argmax(-1, keepdim=True)

## Comparing with the size of torch implementation of `resnet18` with `n_cls=1000`

In [96]:
! mkdir -p /root/.cache/torch/hub/checkpoints
! cp ../input/pytorch-pretrained-models/resnet18-5c106cde.pth /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth

In [97]:
torch_model = models.resnet18(pretrained=True)

In [98]:
sum([p.nelement()*p.element_size() for p in torch_model.parameters()])/1024/1024 # 44MB

In [99]:
sum([p.nelement()*p.element_size() for p in ResNet18(1000).parameters()])/1024/1024 # 44MB 

# Load the weights 
Update the weights of the model with the ImageNet weights from `torchvision.models`.

In [100]:
model.l1[0]

In [101]:
model.l1[0]

In [102]:
# Copy the stem's pretrained state. `state_dict()` carries the BatchNorm bias
# and running statistics as well as the weights, and `load_state_dict` copies
# the values instead of sharing tensors with `torch_model`.
model.l1[0].conv.load_state_dict(torch_model.conv1.state_dict())  # conv1
model.l1[0].bn.load_state_dict(torch_model.bn1.state_dict())      # bn1

In [103]:
model.l2.reg_blk0.conv1

In [104]:
# Copy the full pretrained state of every block in `layer1` into `l2`:
# conv weights plus BatchNorm weight, bias and running statistics.
for blk, ref in zip(model.l2, torch_model.layer1):
    blk.conv1.conv.load_state_dict(ref.conv1.state_dict())
    blk.conv1.bn.load_state_dict(ref.bn1.state_dict())
    blk.conv2.conv.load_state_dict(ref.conv2.state_dict())
    blk.conv2.bn.load_state_dict(ref.bn2.state_dict())
    # only blocks with a 1x1 shortcut conv have a downsample branch
    if ref.downsample is not None:
        blk.dsample.conv.load_state_dict(ref.downsample[0].state_dict())
        blk.dsample.bn.load_state_dict(ref.downsample[1].state_dict())

In [105]:
model.l3.conv_blk0.dsample

In [106]:
# Copy the full pretrained state of every block in `layer2` into `l3`:
# conv weights plus BatchNorm weight, bias and running statistics.
for blk, ref in zip(model.l3, torch_model.layer2):
    blk.conv1.conv.load_state_dict(ref.conv1.state_dict())
    blk.conv1.bn.load_state_dict(ref.bn1.state_dict())
    blk.conv2.conv.load_state_dict(ref.conv2.state_dict())
    blk.conv2.bn.load_state_dict(ref.bn2.state_dict())
    # only blocks with a 1x1 shortcut conv have a downsample branch
    if ref.downsample is not None:
        blk.dsample.conv.load_state_dict(ref.downsample[0].state_dict())
        blk.dsample.bn.load_state_dict(ref.downsample[1].state_dict())

In [107]:
# Copy the full pretrained state of every block in `layer3` into `l4`:
# conv weights plus BatchNorm weight, bias and running statistics.
for blk, ref in zip(model.l4, torch_model.layer3):
    blk.conv1.conv.load_state_dict(ref.conv1.state_dict())
    blk.conv1.bn.load_state_dict(ref.bn1.state_dict())
    blk.conv2.conv.load_state_dict(ref.conv2.state_dict())
    blk.conv2.bn.load_state_dict(ref.bn2.state_dict())
    # only blocks with a 1x1 shortcut conv have a downsample branch
    if ref.downsample is not None:
        blk.dsample.conv.load_state_dict(ref.downsample[0].state_dict())
        blk.dsample.bn.load_state_dict(ref.downsample[1].state_dict())

In [108]:
# Copy the full pretrained state of every block in `layer4` into `l5`:
# conv weights plus BatchNorm weight, bias and running statistics.
for blk, ref in zip(model.l5, torch_model.layer4):
    blk.conv1.conv.load_state_dict(ref.conv1.state_dict())
    blk.conv1.bn.load_state_dict(ref.bn1.state_dict())
    blk.conv2.conv.load_state_dict(ref.conv2.state_dict())
    blk.conv2.bn.load_state_dict(ref.bn2.state_dict())
    # only blocks with a 1x1 shortcut conv have a downsample branch
    if ref.downsample is not None:
        blk.dsample.conv.load_state_dict(ref.downsample[0].state_dict())
        blk.dsample.bn.load_state_dict(ref.downsample[1].state_dict())

# Test model

In [109]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [110]:
model = model.to(device)

In [111]:
images, labels = next(iter(train_dl))
acts = model(images.to(device))

In [112]:
def loss_func(acts, labels):
    """Cross-entropy between logits `acts` and integer class `labels`.

    `labels` is flattened to 1-D, so both `[B, 1]` and `[B]` targets work and
    a single-sample batch is not collapsed to a 0-d tensor.
    """
    return torch.nn.CrossEntropyLoss()(acts, labels.reshape(-1))

In [113]:
loss_func(acts, labels.to(device))

In [114]:
# Count the classes of the batch just fed through the model (`labels`), not
# of the older sample batch `label` drawn in the DataLoader section.
Counter(labels.squeeze().tolist())

In [115]:
preds = torch.softmax(acts, -1).argmax(-1, keepdim=True)

In [116]:
Counter(preds.squeeze().tolist())

In [117]:
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, f1_score, cohen_kappa_score

In [118]:
def get_f1score(y_true, y_pred):
    """Micro-averaged F1 between label and prediction tensors.

    Both tensors are flattened to 1-D (so a single-sample batch stays a
    sequence), detached and moved to the CPU before scoring.
    """
    return f1_score(y_true.reshape(-1).detach().cpu(), y_pred.reshape(-1).detach().cpu(), average='micro')

In [119]:
get_f1score(labels, preds)

In [120]:
cohen_kappa_score(labels.squeeze().detach().cpu(), preds.squeeze().detach().cpu(), weights='quadratic')

In [121]:
def ckscore(y_true, y_pred):
    """Quadratic-weighted Cohen's kappa between label and prediction tensors.

    Both tensors are flattened to 1-D (so a single-sample batch stays a
    sequence), detached and moved to the CPU before scoring.
    """
    return cohen_kappa_score(y_true.reshape(-1).detach().cpu(), y_pred.reshape(-1).detach().cpu(), weights='quadratic')

In [122]:
loss_func(acts, labels.to(device).squeeze())

# Visualize batches and predictions

In [123]:
def show_images(images, labels=None, preds=None, ncols=2, nrows=3, mean=mean, std=std):
  """Show method to display images from the dataloader batch.

  `images` is a normalised `[N, C, H, W]` batch; it is de-normalised with
  `mean`/`std` before display. When both `labels` and `preds` are given, the
  true label is written on the x axis and the prediction on the y axis, in
  green if they match and red otherwise. The grid has `nrows` x `ncols`
  cells and gains rows if the batch holds more images than that.
  """
  images = images.permute(0,2,3,1).detach().cpu().numpy()
  mean=np.array(mean)
  std=np.array(std)
  # grow the grid instead of failing when the batch does not fit
  nrows = max(nrows, -(-len(images) // ncols))
  plt.figure(figsize=(8,6))
  for i, image in enumerate(images):
    # pyplot's argument order is subplot(nrows, ncols, index)
    plt.subplot(nrows, ncols, i+1, xticks=[], yticks=[])
    image = image * std + mean
    if preds is not None and labels is not None:
      col = 'green' if preds[i]==labels[i] else 'red'
      true_label = f'{labels[i].detach().cpu().numpy()}'
      pred_label = f'{preds[i].detach().cpu().numpy()}'
      plt.xlabel(true_label)
      plt.ylabel(pred_label, color=col)
    # de-normalised pixels are floats, whose valid display range is [0, 1]
    plt.imshow(image.clip(0, 1))
    
  plt.tight_layout()
  plt.show()

In [124]:
show_images(images[:6], labels=labels[:6], preds=labels[:6])

In [125]:
def show_preds(dl, limit=6, **kwargs):
  """Method to show predictions for a batch from dataloader.

  Puts the global `model` in eval mode, predicts on the first batch of `dl`
  and passes the first `limit` images, labels and predictions to
  `show_images`, forwarding `kwargs`.
  """
  model.eval()
  images, labels = next(iter(dl))
  images, labels = images.to(device), labels.to(device)
  # display only: no autograd graph is needed for this forward pass
  with torch.no_grad():
    acts = model(images)
  preds = torch.softmax(acts, -1).argmax(-1, keepdim=True)
  show_images(images[:limit], labels[:limit], preds[:limit], **kwargs)

In [126]:
show_preds(train_dl, 6)

# Begin Training

In [127]:
@torch.no_grad()
def validate_epoch(valid_dl, model, loss_func=loss_func, show=True):
  """Evaluate `model` on `valid_dl` and print kappa, F1 and mean batch loss.

  Kappa and F1 are computed once over all predictions of the epoch rather
  than averaged per batch, since a mean of per-batch scores is not the score
  of the whole set. With `show=True` a few predictions are plotted afterwards.
  """
  model.eval()
  l_valid = len(valid_dl)
  if l_valid == 0:
    # nothing to evaluate; avoid dividing by zero below
    tqdm.write('valid: no batches to evaluate')
    return
  valid_loss = 0.0
  all_labels, all_preds = [], []
  for images, labels in tqdm(valid_dl):
    images, labels = images.to(device), labels.to(device)
    acts = model(images)
    loss = loss_func(acts, labels.reshape(-1))
    valid_loss += loss.item()
    preds = torch.softmax(acts, -1).argmax(-1, keepdim=True)
    all_labels.append(labels.reshape(-1).cpu())
    all_preds.append(preds.reshape(-1).cpu())
  y_true, y_pred = torch.cat(all_labels), torch.cat(all_preds)
  valid_ck = ckscore(y_true, y_pred)
  valid_f1 = get_f1score(y_true, y_pred)
  tqdm.write(f'valid_score:{valid_ck:.4f} valid_f1:{valid_f1:.4f} valid_loss:{(valid_loss/l_valid):.4f}')
  if show:
    show_preds(valid_dl, 3)
In [128]:
def train(train_dl, valid_dl, epochs):
  """Train the global `model` with the global `optim` for `epochs` epochs.

  After each epoch the per-batch averages of kappa, F1 and loss are printed;
  validation runs on every even-numbered epoch.
  """
  for e in range(epochs):
    print(f'epoch {e}')
    model.train()
    loss_sum, ck_sum, f1_sum = 0.0, 0.0, 0.0
    n_batches = len(train_dl)
    for images, labels in tqdm(train_dl):
      images = images.to(device)
      labels = labels.to(device)
      acts = model(images)
      # reset gradients, then backpropagate this batch's loss
      optim.zero_grad()
      loss = loss_func(acts, labels.squeeze())
      loss_sum += loss.item()
      loss.backward()
      optim.step()
      # running metrics on the training predictions
      preds = torch.softmax(acts, -1).argmax(-1, keepdim=True)
      ck_sum += ckscore(labels, preds)
      f1_sum += get_f1score(labels, preds)
    tqdm.write(f'train_score:{(ck_sum/n_batches):.4f} train_f1:{(f1_sum/n_batches):.4f} train_loss:{(loss_sum/n_batches):.4f}')
    if e % 2 != 0:
      continue
    # validate every 2 epochs
    validate_epoch(valid_dl, model, loss_func=loss_func, show=True)
    

In [129]:
# Freeze every parameter whose name does not contain 'fc' (the linear head).
for name, p in model.named_parameters():
    if 'fc' in name:
        continue
    p.requires_grad_(False)

In [132]:
optim = torch.optim.Adam(model.fc.parameters(), lr=1e-2) # only train the head

In [133]:
optim

In [134]:
train(train_dl, valid_dl, 5)

In [135]:
torch.save({
            'epoch': 5,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optim.state_dict()
            }, 'model_e5.pth')

In [136]:
checkpoint = torch.load('model_e5.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optim.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

# Train body with low lr

In [139]:
optim

In [140]:
# The body was frozen earlier with `requires_grad_(False)`. Re-enable its
# gradients before handing it to the optimizer; otherwise these parameters
# never receive a gradient and the new param group trains nothing.
body_params = list(model.parameters())[:-2]  # everything except fc.weight / fc.bias
for body_param in body_params:
    body_param.requires_grad_(True)
optim.add_param_group({'params': body_params, 'lr':2e-6}) # make body trainable

In [143]:
optim.param_groups[0]['lr'] = 2e-4

In [144]:
optim

In [145]:
train(train_dl, valid_dl, 10)

In [146]:
torch.save({
            'epoch': 15,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optim.state_dict()
            }, 'model_e15.pth')

# Inference

In [None]:
checkpoint = torch.load('model_e15.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optim.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

In [None]:
test_pth = Path(wdir/'test_images')

In [None]:
test_df = pd.read_csv(wdir/'test.csv')

In [None]:
test_data = test_df.to_records(index=False)

In [None]:
test_ds = APTOSDataset(test_data, pth=test_pth, augs=valid_augs, sz=resize_sz, test=True)

In [None]:
def collate_test(b):
  """Collate unlabeled test samples: stack the first element of every sample
  along a new leading dimension and return `None` in place of the labels."""
  columns = list(zip(*b))
  stacked = torch.stack(columns[0], 0)
  return stacked, None

In [None]:
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=bs, shuffle=False, drop_last=False, collate_fn=collate_test)

In [None]:
@torch.no_grad()
def infer(test_dl, model):
  """Return the predicted class index for every sample in `test_dl`, in
  loader order, as a flat list of ints. The labels slot of each batch is
  ignored."""
  model.eval()
  result = []
  for images, _ in tqdm(test_dl):
    images = images.to(device)
    acts = model(images)
    # argmax without keepdim is always 1-D, so `tolist()` yields a list even
    # when the last batch holds a single sample
    preds = torch.softmax(acts, -1).argmax(-1)
    result.extend(preds.tolist())
  return result

In [None]:
result = infer(test_dl, model)

In [None]:
sub = {'id_code': test_df.id_code,
       'diagnosis': result}

In [None]:
sub = pd.DataFrame(sub)
sub.to_csv('submission.csv', index=False)