In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
import pytorch_lightning as pl

import matplotlib.pyplot as plt

%matplotlib inline

import numpy as np
import os
import h5py

In [2]:
pl.seed_everything(42)

Global seed set to 42


42

In [3]:
PATH_TO_TRAIN = "/deep/group/aihc-bootcamp-fall2021/lymphoma/processed/tma_patches/png/"
PATH_TO_VAL = "/deep/group/aihc-bootcamp-fall2021/lymphoma/processed/tma_patches/val.hdf5"
PATH_TO_TEST = "/deep/group/aihc-bootcamp-fall2021/lymphoma/processed/tma_patches/test.hdf5"

In [4]:
class HDF5Dataset(Dataset):
    
    def __init__(self, hdf5_path: str):
        
        self.hdf5_path = hdf5_path
        
        self.h5data = h5py.File(self.hdf5_path, "r")
        
        self.cores = list(self.h5data.keys())
        
    def __len__(self):
        return len(self.cores)
    
    def __getitem__(self, idx):
        
        patient_id = self.cores[idx]
        
        patches = self.h5data[patient_id][()]
        label = self.h5data[patient_id].attrs["label"]    
        
        return [transforms.ToTensor()(im) for im in patches], torch.tensor(label)

In [5]:
# Datasets

train_dataset = HDF5Dataset(PATH_TO_TEST)
val_dataset = HDF5Dataset(PATH_TO_VAL)

train_loader = DataLoader(train_dataset, batch_size=1, num_workers=1)
val_loader = DataLoader(val_dataset, batch_size=1, num_workers=1)

In [6]:
class SupervisedBaseline(pl.LightningModule):
    def __init__(self):
        super().__init__()

        # init a pretrained resnet
        backbone = models.resnet18(pretrained=True)
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)
        
    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        
        return x
    
    def aggregate(self, y_hats):
        return torch.max(y_hats, dim=0)[0].unsqueeze(0)
        
    
    def infer(self, bag, y):
        y_hats = []
        for x in bag:
            y_hats.append(self(x).squeeze())
            
        y_hat = self.aggregate(torch.stack(y_hats, dim=0))
        
        loss = nn.CrossEntropyLoss()(y_hat, y)
        
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def training_step(self, batch, batch_idx):
        bag, y = batch
        
        return self.infer(bag, y)
    
    def validation_step(self, batch, batch_idx):
        bag, y = batch
        return self.infer(bag, y)
    
    def test_step(self, batch, batch_idx):
        bag, y = batch
        return self.infer(bag, y)
    

In [7]:
# model
model = SupervisedBaseline()

In [None]:
# training
trainer = pl.Trainer(
    gpus=2, num_nodes=1, num_processes=8,
    precision=16, limit_train_batches=0.5, accelerator="dp"
)
trainer.fit(model, train_loader, val_loader)

Using native 16bit precision.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Set SLURM handle signals.

  | Name              | Type       | Params
-------------------------------------------------
0 | feature_extractor | Sequential | 11.2 M
1 | classifier        | Linear     | 5.1 K 
-------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
Global seed set to 42
  rank_zero_warn(


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]