In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torch.optim.lr_scheduler import MultiStepLR
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities import rank_zero_info
from torchmetrics.classification.accuracy import Accuracy

from collections import OrderedDict

import numpy as np
import random
import os
import h5py

In [2]:
pl.seed_everything(42)

Global seed set to 42


42

In [3]:
PATH_TO_TRAIN = "/deep/group/aihc-bootcamp-fall2021/lymphoma/processed/data_splits/train.hdf5"
PATH_TO_VAL = "/deep/group/aihc-bootcamp-fall2021/lymphoma/processed/data_splits/val.hdf5"
PATH_TO_TEST = "/deep/group/aihc-bootcamp-fall2021/lymphoma/processed/data_splits/test.hdf5"

In [12]:
class HDF5Dataset(Dataset):
    
    def __init__(self, hdf5_path: str):
        
        self.hdf5_path = hdf5_path
        
        self.h5data = h5py.File(self.hdf5_path, "r")
        
        self.cores = list(self.h5data.keys())
        
    def __len__(self):
        return len(self.cores)
    
    def __getitem__(self, idx):
        
        patient_id = self.cores[idx]
        
        patches = self.h5data[patient_id][()]
        label = self.h5data[patient_id].attrs["label"]    
        
        return random.sample([transforms.ToTensor()(im) for im in patches], 8), torch.tensor(label)

In [13]:
# Datasets
train_dataset = HDF5Dataset(PATH_TO_TRAIN)
val_dataset = HDF5Dataset(PATH_TO_VAL)
test_dataset = HDF5Dataset(PATH_TO_TEST)

train_loader = DataLoader(train_dataset, batch_size=1, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=1, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=4)


In [14]:
train_dataset[0]

TypeError: 'NoneType' object is not subscriptable

In [7]:
######### Fine-tune Model ##########
# Model adapted from net.py as the template model to load pretrained weights
class TripletNet_Finetune(pl.LightningModule):

    def __init__(self):
        super(TripletNet_Finetune, self).__init__()

        # set the model
        model = models.resnet18(pretrained=False)
        model.fc = torch.nn.Sequential()
        self.model = model
        self.fc = nn.Sequential(nn.Linear(512*2, 512),
                                 nn.ReLU(True), nn.Linear(512, 256))

    def forward(self, i):

        E1 = self.model(i)
        E2 = self.model(i)
        E3 = self.model(i)

        # Pairwise concatenation of features
        E12 = torch.cat((E1, E2), dim=1)
        E23 = torch.cat((E2, E3), dim=1)
        E13 = torch.cat((E1, E3), dim=1)

        f12 = self.fc(E12)
        f23 = self.fc(E23)
        f13 = self.fc(E13)

        features = torch.cat((f12, f23, f13), dim=1)

        return features


In [8]:
PATH_TO_PRETRAINED = '/deep/group/aihc-bootcamp-fall2021/lymphoma/models/Camelyon16_pretrained_model.pt'

In [9]:
class SupervisedBaseline(pl.LightningModule):
    def __init__(self, 
                 num_classes, 
                 batch_size: int, 
                 lr: float, 
                 num_workers: int, 
                 finetune: bool = False):
        
        super().__init__()
        self.batch_size = batch_size
        self.lr = lr
        self.num_workers = num_workers
        
        # Load pre-trained network:
        model = TripletNet_Finetune()

        state_dict = torch.load(PATH_TO_PRETRAINED) ## TODO: change this to pytorch lightning

        # create new OrderedDict that does not contain `module`
        new_state_dict = OrderedDict()
        for k, v in state_dict['model'].items():
            name = k[7:]  # remove `module.`
            v.requires_grad = not finetune
            new_state_dict[name] = v
            
        # load pretrained weights onto TripletNet_Finetune model
        model.load_state_dict(new_state_dict)

        
        # if we finetune - only train the classifier, as opposed to e2e - freeze the network

        self.feature_extractor = model
        
        # set the linear classifier
        # use the classifier setup in the paper
        self.classifier = nn.Sequential(nn.Linear(256*3, num_classes))
        
        self.criterion = nn.CrossEntropyLoss()
        
        # use separate metric instance for train, val and test step
        # to ensure a proper reduction over the epoch
        self.train_accuracy = Accuracy()
        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()
        
        # ensures params passed to LightningModule will be saved to ckpt
        # allows to access params with 'self.hparams' attribute
        # self.save_hyperparameters()
        
    def forward(self, x):
        # Forward step
        x = self.feature_extractor(x).flatten(1)   # representations
        x = self.classifier(x)                     # classifications
        return x
    
    def aggregate(self, y_hats):
        # TODO: confirm argmax/max
        return torch.max(y_hats, dim=0)[0].unsqueeze(0)
        
    def infer(self, bag, y):
        y_hats = []
        for x in bag:
            y_hats.append(self(x).squeeze())
            
        y_hat = self.aggregate(torch.stack(y_hats, dim=0))
        
        return y_hat

    def configure_optimizers(self):
        # only train parameters that are not frozen
        parameters = self.parameters()
        trainable_parameters = list(filter(lambda p: p.requires_grad, parameters))
        
        optimizer = torch.optim.Adam(trainable_parameters, lr=self.lr)
        return optimizer
    
    def training_step(self, batch, batch_idx):
        bag, y = batch
        y_hat = self.infer(bag, y)
        loss = self.criterion(y_hat, y)
        acc = self.train_accuracy(y_hat, y)
    
        # log train metrics
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        return {"loss": loss, "preds": y_hat, "targets": y}
    
    def validation_step(self, batch, batch_idx):
        bag, y = batch
        y_hat = self.infer(bag, y)
        loss = self.criterion(y_hat, y)
        acc = self.val_accuracy(y_hat, y)
    
        # log validation metrics
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        return {"loss": loss, "preds": y_hat, "targets": y}
    
    def test_step(self, batch, batch_idx):
        bag, y = batch
        y_hat = self.infer(bag, y)
        loss = self.criterion(y_hat, y)
        acc = self.test_accuracy(y_hat, y)
        
        # log test metrics
        self.log("test/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "preds": y_hat, "targets": y}

In [10]:
model = SupervisedBaseline(num_classes=9, batch_size=1, lr=1e-3, num_workers=1, finetune=True)

In [11]:
# training
trainer = pl.Trainer(
    gpus=-1, num_nodes=1, num_processes=8,
    precision=16, limit_train_batches=0.5, accelerator="dp",
    max_epochs=100
)
trainer.fit(model, train_loader, val_loader)

Using native 16bit precision.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Set SLURM handle signals.

  | Name              | Type                | Params
----------------------------------------------------------
0 | feature_extractor | TripletNet_Finetune | 11.8 M
1 | classifier        | Sequential          | 6.9 K 
2 | criterion         | CrossEntropyLoss    | 0     
3 | train_accuracy    | Accuracy            | 0     
4 | val_accuracy      | Accuracy            | 0     
5 | test_accuracy     | Accuracy            | 0     
----------------------------------------------------------
11.8 M    Trainable params
0         Non-trainable params
11.8 M    Total params
47.358    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: -1it [00:00, ?it/s]



ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/sailhome/vrishk/miniconda3/envs/aihc/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/sailhome/vrishk/miniconda3/envs/aihc/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/sailhome/vrishk/miniconda3/envs/aihc/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_785648/254545935.py", line 21, in __getitem__
    return random.sample([transforms.ToTensor()(im) for im in patches], 8), torch.tensor(label)
  File "/sailhome/vrishk/miniconda3/envs/aihc/lib/python3.8/random.py", line 363, in sample
    raise ValueError("Sample larger than population or is negative")
ValueError: Sample larger than population or is negative


In [None]:
trainer.test(model, test_loader)