# Histopathologic Cancer Detection Competition

This notebook is a PyTorch baseline for iterative testing.

In [0]:
# Comment out if in Colab
# %reload_ext autoreload
# %autoreload 2
# %matplotlib inline

# Colab drive mount
# from google.colab import drive
# drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
import torch.nn as nn
import numpy as np
import os
import pandas as pd
import torch
from pathlib import Path
from torch.optim import Adam
from torchvision.models.resnet import BasicBlock
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from PIL import Image
from matplotlib import pyplot as plt
from torchvision.models.resnet import ResNet
from sklearn.metrics import roc_auc_score
from torch import Tensor
from torchvision import transforms
from torch.autograd import Variable
from pathlib import Path
from sklearn.metrics import roc_auc_score
from tqdm import tqdm_notebook
import random
from typing import Optional, Tuple, List, Union
import logging
from functools import partial

## Data Prep

Our data consists of a subset of *histopathological images of lymph nodes stained with hematoxylin and eosin (H&E)* from the Patch Camelyon (PCAM) 2016 dataset: https://github.com/basveeling/pcam
- 220k training images and 57k evaluation 96x96 images, the only difference from PCAM being that duplicates are removed
- Images from the Camelyon16 Challenge (https://camelyon16.grand-challenge.org/Data/) were digitized at 2 different centers at 40x
    - PCAM uses 10x under sampling to increase FOV--resultant pixel resolution is 2.43 microns
- Training set is ~0.405 positive (~0.595 negative), where positive denotes at least 1 pixel of tumor tissue in the 32x32 center region of the image--**tumor tissue outside the 32x32 region does not influence the label**
    - Thus we should crop images towards the center region, but not so tightly that we lose useful surrounding features
    

#### Downloading Data 

Dataset is available at https://www.kaggle.com/c/histopathologic-cancer-detection/data

In [0]:
# Replace '/content/drive/My Drive/histopath' with your own respective working directory,
# or set the HISTOPATH_DIR environment variable to point elsewhere without editing this cell.
# %cd /content/drive/My Drive/histopath 
path = Path(os.environ.get('HISTOPATH_DIR', '/content/drive/My Drive/histopath'))

/content/drive/My Drive/histopath


In [0]:
# unzipping WSI patch id's
# !gunzip 'patch_ids.csv.gz'

In [0]:
# more unzipping
# !unzip -q train_labels.csv.zip
# !unzip -q test.zip -d test 
# !unzip -q sample_submission.csv.zip

# permissions
# !chmod 644 train_labels.csv
# !chmod 644 sample_submission.csv

chmod: cannot access 'sample_submission.csv': No such file or directory


In [0]:
# Dataset locations, all resolved relative to the working directory `path`
TRN_PATH = path / 'train'
TEST_PATH = path / 'test'
LABELS = path / 'train_labels.csv'
# NOTE(review): the unzip cell above mentions 'patch_ids.csv.gz' while this reads
# 'patch_id_wsi.csv' -- confirm the file name on disk.
WSI_LABELS = path / 'patch_id_wsi.csv'
# True when PyTorch can see a CUDA device
USE_GPU = torch.cuda.is_available()

def fixSeed(seed):
  """Seed the Python, NumPy and PyTorch random number generators with `seed`.

  When USE_GPU is true the CUDA generators (current device and all devices)
  are seeded as well.
  """
  for seeder in (random.seed, np.random.seed, torch.manual_seed):
    seeder(seed)
  if not USE_GPU:
    return
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)


# Leave as None to fall back to the default seed of 42
fixedSeed = None
fixedSeed = 42 if fixedSeed is None else fixedSeed
fixSeed(fixedSeed)

### Train/Val Split

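- Here we split by the Whole Slide Image (WSI) that each patch was extracted from, so that patches from one slide never end up in both the training and the validation set and the model cannot overfit to slide-specific characteristics. Patches without a WSI mapping are split randomly 90/10.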


In [0]:
# Class counts of the training labels -- note the imbalance
trn_labels = pd.read_csv(LABELS)
trn_labels.loc[:, 'label'].value_counts()

0    130908
1     89117
Name: label, dtype: int64

In [0]:
# adapted from https://www.kaggle.com/c/histopathologic-cancer-detection/discussion/84132
def return_tumor_or_not(dic, one_id):
  """Return the label stored under `one_id` in the id -> label mapping `dic`.

  Raises:
    KeyError: if `one_id` is not a key of `dic`.
  """
  label = dic[one_id]
  return label

def create_dict():
  """Build a mapping from patch id to integer label from the LABELS csv.

  The first csv column is used as the key and the second column, cast to
  int, as the value. If an id occurs more than once the last row wins.

  Returns:
    dict: {patch id: int label}
  """
  df = pd.read_csv(LABELS)
  # Iterate the two columns directly instead of doing a scalar iloc lookup per
  # row, which is very slow on a table with hundreds of thousands of rows.
  ids = df.iloc[:, 0]
  labels = df.iloc[:, 1]
  return {one_id: int(tumor_or_not) for one_id, tumor_or_not in zip(ids, labels)}

def find_missing(train_ids, cv_ids):
  """Return the ids listed in the LABELS csv that are in neither split.

  Args:
    train_ids: ids already assigned to the training split.
    cv_ids: ids already assigned to the validation split.

  Returns:
    list: the unassigned ids in sorted order.
  """
  all_ids = set(pd.read_csv(LABELS)['id'].values)
  wsi_ids = set(train_ids).union(cv_ids)
  # Sort before returning: the iteration order of a set of strings changes
  # between interpreter sessions (hash randomisation), which would make the
  # seeded shuffle applied to this list by the caller non-reproducible.
  return sorted(all_ids - wsi_ids)


def generate_split():
  """Create a 90/10 train/validation split grouped by whole-slide image.

  Patch ids from the WSI_LABELS csv are grouped by the integer found in the
  fourth underscore-separated field of the WSI name (second column). The
  groups are shuffled with the global NumPy RNG and the first 90% of the
  groups go to training. Ids present in LABELS but absent from WSI_LABELS
  are shuffled and split 90/10 individually. Summary counts are printed.

  Returns:
    tuple: (train_ids, cv_ids, train_labels, cv_labels)
  """
  wsi_table = pd.read_csv(WSI_LABELS)

  # Group patch ids by slide number
  patches_by_slide = {}
  for patch_id, wsi_name in zip(wsi_table.iloc[:, 0], wsi_table.iloc[:, 1]):
    slide_no = int(wsi_name.split('_')[3])
    patches_by_slide.setdefault(slide_no, []).append(patch_id)

  # Shuffle the slides and cut them 90/10
  slide_numbers = list(patches_by_slide.keys())
  np.random.shuffle(slide_numbers)
  slide_cut = int(len(slide_numbers) * 0.9)

  train_ids = []
  for slide_no in slide_numbers[:slide_cut]:
    train_ids += patches_by_slide[slide_no]

  cv_ids = []
  for slide_no in slide_numbers[slide_cut:]:
    cv_ids += patches_by_slide[slide_no]

  dic = create_dict()

  # Patches without a slide mapping are split at the patch level
  missing_ids = find_missing(train_ids, cv_ids)
  np.random.shuffle(missing_ids)
  missing_cut = int(len(missing_ids) * 0.9)
  train_ids += missing_ids[:missing_cut]
  cv_ids += missing_ids[missing_cut:]

  train_labels = [return_tumor_or_not(dic, one_id) for one_id in train_ids]
  cv_labels = [return_tumor_or_not(dic, one_id) for one_id in cv_ids]
  train_tumor = sum(train_labels)
  cv_tumor = sum(cv_labels)
  total = len(train_ids) + len(cv_ids)

  print("Amount of train labels: {}, {}/{}".format(len(train_ids), train_tumor, len(train_ids)-train_tumor))
  print("Amount of cv labels: {}, {}/{}".format(len(cv_ids), cv_tumor, len(cv_ids) - cv_tumor))
  print("Percentage of cv labels: {}".format(len(cv_ids)/total))

  return train_ids, cv_ids, train_labels, cv_labels



### DataLoader & Transforms 

H&E staining produces dark blue and various shades of pink:

- Hematoxylin is dark blue and binds to negatively charged compounds, e.g. nucleic acids
- Eosin is pink and binds to positively charged compounds, e.g. amino-acid side chains in most proteins, cytoplasm, and extracellular features

 We experimented with stain normalization but this did not improve performance. Normalizing over the image values proved to be more effective than stain normalization. 

In [0]:
class BaseDataset(Dataset):
  """Pairs an input dataset with a label dataset, optionally transforming inputs.

  Args:
    x_ds: indexable dataset of inputs; also determines the length.
    y_ds: indexable dataset of targets, aligned with `x_ds` by index.
    x_tfms: optional callable applied to each input before it is returned.
  """

  def __init__(self,
                x_ds: Dataset,
                y_ds: Dataset,
                x_tfms: Optional = None):
    self.x_ds = x_ds
    self.y_ds = y_ds
    self.x_tfms = x_tfms

  def __len__(self) -> int:
    return len(self.x_ds)

  def __getitem__(self, index: int) -> Tuple:
    """Return the `(x, y)` pair at `index`, with `x_tfms` applied to `x` if set."""
    x = self.x_ds[index]
    y = self.y_ds[index]
    if self.x_tfms is not None:
      x = self.x_tfms(x)
    # The return must sit outside the `if`; otherwise a dataset built without
    # transforms silently yields None for every item.
    return x, y
      
class WSILabelDataset(Dataset):
  """Dataset view over an in-memory sequence of labels."""

  def __init__(self, labels: List):
    # Kept by reference; the sequence is not copied
    self.labels = labels

  def __len__(self) -> int:
    n_labels = len(self.labels)
    return n_labels

  def __getitem__(self, index: int) -> int:
    label = self.labels[index]
    return label

class WSIDataset(Dataset):
  """Dataset that loads one image per index from a list of file paths.

  Args:
    img_paths: paths (str or Path) of the image files, one per sample.
  """

  def __init__(self, img_paths: List):
    self.img_paths = img_paths

  def __len__(self) -> int:
    return len(self.img_paths)

  def __getitem__(self, index: int) -> Image.Image:
    """Open the image at `index` and return it fully loaded as an RGB image."""
    # Image.open is lazy and keeps the file handle open until the pixel data
    # is read. Converting inside the context manager forces the read and
    # releases the handle, and guarantees the 3 channels that the 3-value
    # mean/std normalisation used downstream expects.
    with Image.open(self.img_paths[index]) as img:
      return img.convert('RGB')

In [0]:
# Batch size and worker count passed to both DataLoaders
bs = 256
n_workers = 0
# Per-channel [mean, std] handed to transforms.Normalize below.
# NOTE(review): these values match the commonly used ImageNet statistics rather
# than statistics computed from this dataset -- confirm this is intended.
kaggle_stats = [[0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]]
size = 96 # initial size
lr = 4e-2 # determined via fast.ai lr_find()
# more intensive transforms seemed to help validation score
tfms = transforms.Compose([
    transforms.Resize((size, size)),
    # Exactly one of the colour jitters below is applied per sample
    transforms.RandomChoice([
      transforms.ColorJitter(brightness=0.5),
      transforms.ColorJitter(contrast=0.5), 
      transforms.ColorJitter(saturation=0.5),
      transforms.ColorJitter(hue=0.5),
      transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), 
      transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3), 
      transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), 
    ]),
    # Exactly one of eight fixed flip / right-angle rotation combinations is
    # applied per sample; RandomRotation((0,0)) is the "leave unchanged" option
    transforms.RandomChoice([
      transforms.RandomRotation((0,0)),
      transforms.RandomHorizontalFlip(p=1),
      transforms.RandomVerticalFlip(p=1),
      transforms.RandomRotation((90,90)),
      transforms.RandomRotation((180,180)),
      transforms.RandomRotation((270,270)),
      transforms.Compose([
          transforms.RandomHorizontalFlip(p=1),
          transforms.RandomRotation((90,90)),
      ]),
      transforms.Compose([
          transforms.RandomHorizontalFlip(p=1),
          transforms.RandomRotation((270,270)),
      ]) 
    ]),
    # Convert to a tensor, then normalise each channel with kaggle_stats
    transforms.ToTensor(),
    transforms.Normalize(
      mean=kaggle_stats[0],
      std=kaggle_stats[1]
    )
])


In [0]:
# Split ids/labels by whole-slide image, then resolve ids to image files
train_ids, val_ids, train_labels, val_labels = generate_split()
fn_paths = [TRN_PATH/(name+'.tif') for name in train_ids]
# Validation patches are held out from the training folder, so they need the
# same TRN_PATH prefix (bare file names would not resolve).
valid_paths = [TRN_PATH/(name+'.tif') for name in val_ids]

trn_img_ds = WSIDataset(fn_paths)
val_img_ds = WSIDataset(valid_paths)
trn_lbl_ds = WSILabelDataset(train_labels)
val_lbl_ds = WSILabelDataset(val_labels)

# Inputs and labels are matched purely by index, so their lengths must agree
assert len(trn_img_ds) == len(trn_lbl_ds)
assert len(val_img_ds) == len(val_lbl_ds)

trn_ds = BaseDataset(trn_img_ds,trn_lbl_ds, x_tfms=tfms)
# Validation images must be paired with the validation labels, not the training ones.
# NOTE(review): validation still goes through the random training augmentations
# in `tfms`; consider a deterministic Resize/ToTensor/Normalize pipeline here.
val_ds = BaseDataset(val_img_ds,val_lbl_ds, x_tfms=tfms)
trn_dl = DataLoader(trn_ds,batch_size=bs,shuffle=True,num_workers=n_workers)
val_dl = DataLoader(val_ds,batch_size=bs,shuffle=False,num_workers=n_workers)

Amount of train labels: 197458, 81040/116418
Amount of cv labels: 22567, 8077/14490
Percentage of cv labels: 0.10256561754346097


In [0]:
def to_t(tensor):
  """Convert `tensor` to a float tensor, placed on the GPU when USE_GPU is set.

  Existing tensors are cast with `.type(torch.FloatTensor)`; anything else is
  passed to the `torch.FloatTensor` constructor.
  """
  if torch.is_tensor(tensor):
    as_float = tensor.type(torch.FloatTensor)
  else:
    as_float = torch.FloatTensor(tensor)
  return to_gpu(as_float) if USE_GPU else as_float


def to_numpy(tensor: 'Union[Tensor, Image.Image, np.ndarray]') -> np.ndarray:
  """Convert a NumPy array, torch tensor or PIL image to a NumPy array.

  Args:
    tensor: the value to convert. Subclasses of the supported types are
      accepted too (e.g. the format-specific image classes returned by
      `Image.open`, or `nn.Parameter`).

  Returns:
    np.ndarray: a copy for arrays and images; for tensors, the detached data
    moved to the CPU.

  Raises:
    ValueError: if `tensor` is none of the supported types.
  """
  # isinstance instead of exact type equality: `np.array` is a function, not a
  # type, and exact matching rejected every subclass.
  if isinstance(tensor, np.ndarray):
    return np.array(tensor)
  if isinstance(tensor, Tensor):
    return tensor.detach().cpu().numpy()
  if isinstance(tensor, Image.Image):
    return np.array(tensor)
  raise ValueError(f'to_numpy does not support objects of type {type(tensor)!r}')

## Model & Training Loop


Heavily influenced by this series of posts on training optimization:
https://myrtle.ai/how-to-train-your-resnet-8-bag-of-tricks/

Here we use plain PyTorch; in practice, fast.ai's One Cycle scheduler and progressive resizing (96 → 128 → 192) were also added to the mix.

In [0]:
def to_gpu(tensor):
  """Return `tensor.cuda()` when USE_GPU is true, otherwise `tensor` unchanged."""
  if USE_GPU:
    return tensor.cuda()
  return tensor
def resnet9(pretrained=True,output_dim: int = 1, **kwargs):
  """Constructs a small ResNet with one BasicBlock per stage.

  Args:
      pretrained (bool): accepted for API compatibility but never read; the
          network is always built with freshly initialised weights.
      output_dim (int): number of outputs of the final linear layer.
      **kwargs: forwarded to the torchvision `ResNet` constructor.

  Returns:
      The model, moved to the GPU when one is in use.
  """
  net = ResNet(BasicBlock, [1,1,1,1], **kwargs)
  n_feats = net.fc.in_features
  # Swap the pooling layer and the classification head
  net.avgpool = nn.AdaptiveAvgPool2d(1)
  net.fc = nn.Linear(n_feats, output_dim)
  return to_gpu(net)
m = resnet9(output_dim=1)

In [0]:
def one_epoch(
    model,
    trn_dl,
    val_dl,
    loss_func,
    opt_func,
    trn_loss_writer,
    val_loss_writer,
    do_step_trig,
    trn_loss_trig,
    val_loss_trig):
  """Train `model` for one pass over `trn_dl`, reporting metrics periodically.

  Args:
    model: the network to train.
    trn_dl: iterable of `(x, y)` training batches.
    val_dl: iterable of `(x, y)` validation batches, scored via `predict`.
    loss_func: callable `(output, target) -> loss` supporting `.backward()`.
    opt_func: optimizer exposing `step()` and `zero_grad()`.
    trn_loss_writer: callable `(targets, predictions, iteration)` invoked with
      the training batches accumulated since its previous invocation.
    val_loss_writer: callable `(targets, predictions, iteration)` invoked with
      the validation results.
    do_step_trig: callable `(iteration) -> bool`; when true the optimizer
      steps and gradients are cleared (gradients accumulate otherwise).
    trn_loss_trig: callable `(iteration) -> bool` gating `trn_loss_writer`.
    val_loss_trig: callable `(iteration) -> bool` gating validation.

  Returns:
    The trained model (the same object that was passed in).
  """
  model.train()
  y_targ_trn, y_pred_trn = [], []
  for i, (x, y) in enumerate(trn_dl):
    # Neither inputs nor targets need gradients; only the parameters do
    x = to_t(x)
    out = model(x)
    # The loss needs the target to have the same shape as the model output
    y = to_t(y).reshape(out.shape)
    y_targ_trn.append(to_numpy(y))
    y_pred_trn.append(to_numpy(out))
    losses = loss_func(out, y)
    losses.backward()
    if do_step_trig(i):
      opt_func.step()
      opt_func.zero_grad()
    if trn_loss_trig(i):
      trn_loss_writer(y_targ_trn, y_pred_trn, i)
      y_targ_trn, y_pred_trn = [], []
    if val_loss_trig(i):
      y_targ, y_pred = predict(model, val_dl)
      val_loss_writer(y_targ, y_pred, i)
      # predict() switches the model to eval mode; resume training mode
      model.train()
  return model

def predict(model, dl):
  """Run `model` over every batch of `dl` without tracking gradients.

  The model is switched to eval mode and left in that mode on return.

  Args:
    model: the network to evaluate.
    dl: iterable of `(x, y)` batches.

  Returns:
    tuple: `(y_targ, y_pred)`, two lists with one NumPy array per batch; each
    target array is reshaped to the shape of the matching output array.
  """
  model.eval()
  y_targ, y_pred = [], []
  with torch.no_grad():
    for x, y in dl:
      out = model(to_t(x))
      # Give targets the same shape as the outputs so the per-batch arrays
      # can be stacked consistently even when the last batch is smaller
      y_targ.append(to_numpy(to_t(y)).reshape(tuple(out.shape)))
      y_pred.append(to_numpy(out))
  return y_targ, y_pred

def iter_trig(iter_num, step_size):
  """Decide whether a periodic action should fire at iteration `iter_num`.

  Fires on every iteration when `step_size` is 1; otherwise fires on each
  positive iteration number that is a multiple of `step_size`.
  """
  fires_every_iter = step_size == 1
  return bool(fires_every_iter or (iter_num > 0 and iter_num % step_size == 0))

# Competition uses AUC as ranking metric
# INFO records are dropped by default, so make sure they are actually shown
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

def auc_metric(y_targ, y_pred, iter_num):
  """Log the ROC AUC of the accumulated predictions.

  Args:
    y_targ: sequence of per-batch target arrays.
    y_pred: sequence of per-batch score arrays, aligned with `y_targ`.
    iter_num: iteration number included in the log message.

  Returns:
    float: the AUC, or -1 when it cannot be computed (e.g. no batches yet, or
    only one class present in the targets).
  """
  try:
    # Flatten each batch first so batches of different sizes or shapes
    # ((N,) vs (N, 1)) concatenate cleanly
    targets = np.concatenate([np.ravel(batch) for batch in y_targ])
    scores = np.concatenate([np.ravel(batch) for batch in y_pred])
    auc = roc_auc_score(targets, scores)
  except ValueError:
    # Only the "metric undefined for this data" case is tolerated; other
    # errors (typos, missing names) must surface instead of reading as -1
    auc = -1
  logger.info(f'iter #: {iter_num}, auc: {auc}')
  return auc

def init_trigs(step_size=1, val=10, trn=10):
  """Build the three iteration triggers used by the training loop.

  Each trigger is `iter_trig` with its period bound, so it only needs the
  iteration number.

  Returns:
    tuple: (optimizer-step trigger with period `step_size`,
            training-metric trigger with period `trn`,
            validation trigger with period `val`)
  """
  periods = (step_size, trn, val)
  return tuple(partial(iter_trig, step_size=period) for period in periods)

In [0]:
# Fresh model, optimizer and loss for the training run
m = resnet9(output_dim=1)
optimizer = Adam(m.parameters(), lr=lr)
loss = nn.BCEWithLogitsLoss()
# Training and validation metrics are both reported through auc_metric
loss_writer_trn = loss_writer_val = auc_metric
# Step the optimizer every iteration; report both metrics every 10 iterations
do_step_trig, trn_loss_trig, val_loss_trig = init_trigs(step_size=1, val=10, trn=10)


## Training

Our single best model was progressively resized from 96 -> 128 -> 192.
On each resizing, lr was divided by a factor of 10.
If the validation loss did not decrease after ~5 cycles, lr was divided by 2. Then if loss still did not improve, previous best model was loaded.

In [0]:
# Run a single training epoch; one_epoch hands back the trained model
learn = one_epoch(
    model=m,
    trn_dl=trn_dl,
    val_dl=val_dl,
    loss_func=loss,
    opt_func=optimizer,
    trn_loss_writer=loss_writer_trn,
    val_loss_writer=loss_writer_val,
    do_step_trig=do_step_trig,
    trn_loss_trig=trn_loss_trig,
    val_loss_trig=val_loss_trig,
)

## Classification Interpretation

In [0]:
# Score the validation set and line up label, prediction and file per patch.
# 'Pred' holds the raw model output (a logit), not a probability.
y_targ, y_pred = predict(m, val_dl)
preds = pd.DataFrame({
    # val_labels is the list returned by generate_split()
    'Label': np.asarray(val_labels).reshape(-1),
    'Pred': np.concatenate([np.ravel(batch) for batch in y_pred]),
    'fn': valid_paths,
})
# Maximum number of patches shown per plot
plot_limit = 20
def show_preds(fns, w=2, h=2, n_cols=5):
  """Plot the images at `fns` on a grid with `n_cols` columns.

  Args:
    fns: sequence of image file paths.
    w: width of each grid cell in inches.
    h: height of each grid cell in inches.
    n_cols: number of images per row.
  """
  # Rows must be an integer; use ceiling division and keep at least one row
  n_rows = max(1, -(-len(fns) // n_cols))
  fig = plt.figure(figsize=(n_cols * w, n_rows * h))
  for chart, fn in enumerate(fns, 1):
    ax = fig.add_subplot(n_rows, n_cols, chart)
    # Open each file only while it is drawn so handles are not kept open
    with Image.open(fn) as img:
      ax.imshow(np.array(img))
    ax.axis('off')
  # Adjust the layout only after every subplot has been added
  fig.tight_layout()


In [0]:
# False negative imgs: tumor patches (Label == 1) with the lowest predicted scores
false_neg = preds[preds['Label']==1].sort_values('Pred', ascending=True)['fn'].values[:plot_limit]
show_preds(false_neg)

In [0]:
# False positive imgs: normal patches (Label == 0) with the highest predicted scores
false_pos = preds[preds['Label']==0].sort_values('Pred', ascending=False)['fn'].values[:plot_limit]
show_preds(false_pos)

In [0]:
# True positive imgs: tumor patches (Label == 1) with the highest predicted scores
true_pos = preds[preds['Label']==1].sort_values('Pred', ascending=False)['fn'].values[:plot_limit]
show_preds(true_pos)

In [0]:
# True negative imgs: normal patches (Label == 0) with the lowest predicted scores
true_neg = preds[preds['Label']==0].sort_values('Pred', ascending=True)['fn'].values[:plot_limit]
show_preds(true_neg)