In [1]:
import os
import json
import random

import numpy as np
import pandas as pd
import nibabel as nib

import seaborn as sns
import matplotlib.pyplot as plt

from pathlib import Path
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.functional as F

from torch import GradScaler, autocast
from torch.utils.data import DataLoader

from tqdm import tqdm

import time
import utils

import monai
import monai.transforms as mt

import pypickle

import medim
import nibabel as nib
import matplotlib.pyplot as plt

from dataset import get_dataset

import wandb
wandb.login()

import warnings
warnings.simplefilter('ignore')

[34m[1mwandb[0m: Currently logged in as: [33mdteakhperky[0m ([33mdteakhperky-higher-school-of-economics[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Setup

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
def seed_everything(seed=0xBAD5EED):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

    generator = torch.Generator()
    generator.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Dataset

In [4]:
CWD_PATH = Path(os.getcwd())
DATA_PATH = CWD_PATH / "data"

ATLAS_DATA_PATH = DATA_PATH / "Lumiere" / "atlas_mapping"
SPLIT_PATH = CWD_PATH / "splits" / "split.pkl"

split = pypickle.load(SPLIT_PATH, verbose="silent")

In [5]:
baseline_image_keys = ['baseline_FLAIR', 'baseline_T1', 'baseline_T1CE', 'baseline_T2', 'baseline_seg']
followup_image_keys = ['followup_FLAIR', 'followup_T1', 'followup_T1CE', 'followup_T2', 'followup_seg']
keys = baseline_image_keys + followup_image_keys

train_transform = mt.Compose([
    mt.LoadImaged(keys=keys),
    mt.EnsureChannelFirstd(keys=keys),
    # mt.Orientationd(keys=keys, axcodes="RAS"),
    # Spatial augmentations
    # mt.RandRotated(keys=keys, range_x=0.3, prob=0.5),
    # mt.RandFlipd(keys=keys, prob=0.5),
    # mt.RandZoomd(keys=keys, min_zoom=0.9, max_zoom=1.1, prob=0.5),
    # Intensity augmentations
    # mt.RandGaussianNoised(keys=keys, std=0.01, prob=0.2),
    # mt.RandAdjustContrastd(keys=keys, gamma=(0.7, 1.3), prob=0.3),
    # Normalization
    mt.ScaleIntensityd(keys=keys),
    mt.ConcatItemsd(keys=keys, name="images"),
    mt.DeleteItemsd(keys=keys),
    mt.RandFlipd(keys=["images"]),
    mt.ResizeD(keys=["images"], spatial_size=(121, 128, 121)),
    # mt.RandCropByPosNegLabeld(
    #     keys=["images"],
    #     label_key="baseline_seg",
    #     spatial_size=(128, 128, 64),
    #     pos=1, neg=1, num_samples=4,
    #     image_key="images", prob=0.5
    # ),
    # mt.ConcatItemsd(keys=followup_image_keys, name="followup"),
    mt.ToTensord(keys=["images"])
])

valid_transform = mt.Compose([
    mt.LoadImaged(keys=keys),
    mt.EnsureChannelFirstd(keys=keys),
    # mt.Orientationd(keys=keys, axcodes="RAS"),
    # mt.ConcatItemsd(keys=baseline_image_keys, name="baseline", dim=0),
    # mt.ConcatItemsd(keys=followup_image_keys, name="followup", dim=0),
    mt.ConcatItemsD(keys=keys, name="images", dim=0),
    mt.DeleteItemsd(keys=keys),
    mt.ResizeD(keys=["images"], spatial_size=(121, 128, 121)),
    # mt.ToTensord(keys=["baseline", "followup"]),
    mt.ToTensord(keys=["images"])
])

In [6]:
train_dataset = get_dataset(
    data_path=ATLAS_DATA_PATH,
    indices=split["train_patient_ids"],
    transform=train_transform
)

valid_dataset = get_dataset(
    data_path=ATLAS_DATA_PATH,
    indices=split["valid_patient_ids"],
    transform=valid_transform
)

len(train_dataset), len(valid_dataset)

(238, 84)

In [7]:
# train_dataset[0]["images"].shape, valid_dataset[0]["images"].shape

In [8]:
BATCH_SIZE = 4
NUM_WORKERS = 4

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS
)

valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS
)

In [9]:
# next(iter(train_loader))["images"].shape, next(iter(valid_loader))["images"].shape

In [10]:
# next(iter(train_loader))

# Training Utilities

In [46]:
@torch.inference_mode()
def test_model_correctness(model, loader):
    images = next(iter(loader))["images"]
    images = images.to(device)
    model = model.to(device)
    outputs = model(images[:, :5], images[:, 5:]).cpu().numpy()
    assert outputs.shape[-1] == 4, f'Wrong last output dimension. Expected {4}; Got {outputs.shape[-1]}'
    print('All gucci')

def calculate_receptive_field(params: list[tuple[int, int]]) -> int:
    r, s = 1, 1
    for kernel_size, stride in params:
        r += (kernel_size - 1) * stride
        s *= stride
    
    return r

def calculate_params(model: torch.nn.Module) -> int:
    return f'{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f} M'

In [47]:
def train(
    model, 
    loader, 
    loss_fn, 
    optimizer, 
    device, 
    scheduler=None, 
    wandb_logging=False,
    accumulation_steps=1
):
    model.train()
    scaler = GradScaler()
    total_loss = 0
    total_acc = 0
    step_count = 0
    
    optimizer.zero_grad()
    for i, batch in enumerate(tqdm(loader, leave=False)):
        batch["images"] = batch["images"].to(device)
        batch["label"] = batch["label"].to(device)

        with autocast(device_type="cuda"):
            outputs = model(batch["images"][:, :5], batch["images"][:, 5:])
            loss = loss_fn(outputs, batch["label"]) / accumulation_steps  # Scale loss

        scaler.scale(loss).backward()
        total_loss += loss.item() * accumulation_steps
        
        preds = outputs.argmax(dim=1).cpu().numpy()
        accuracy = np.mean(preds == batch["label"].cpu().numpy())
        total_acc += accuracy
        
        step_count += 1

        if (i + 1) % accumulation_steps == 0 or (i + 1) == len(loader):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            
            if wandb_logging:
                wandb.log({
                    'train_loss': total_loss / step_count,
                    'train_accuracy': total_acc / step_count,
                    'learning_rate': optimizer.param_groups[0]['lr']
                })
            
            total_loss = 0
            total_acc = 0
            step_count = 0

    if scheduler is not None:
        scheduler.step()

    return {
        'train_loss': total_loss / max(1, step_count),
        'train_accuracy': total_acc / max(1, step_count)
    }
        
@torch.inference_mode()
def evaluate(model, loader, loss_fn, device, scheduler=None, wandb_logging=False):
    model.eval()
    batches_cnt = 0
    accuracy_acum = 0
    accuracy_per_class = np.zeros(NUM_CLASSES, dtype=np.float32)
    samples_cnt_per_class = np.zeros(NUM_CLASSES, dtype=np.float32)

    for batch in tqdm(loader, leave=False):
        batch["images"] = batch["images"].to(device)
        batch["label"] = batch["label"].to(device)

        outputs = model(batch["images"][:, :5], batch["images"][:, 5:])
        loss = loss_fn(outputs, batch["label"])

        # ===================
        # METRICS CALCULATION
        preds = outputs.argmax(dim=1).cpu().numpy()
        accuracy = np.mean(preds == batch["label"].cpu().numpy())
        for i in range(NUM_CLASSES):
            accuracy_per_class[i] += np.sum(preds[batch["label"].cpu().numpy() == i] == i) if np.any(batch["label"].cpu().numpy() == i) else 0
            samples_cnt_per_class[i] += np.sum(batch["label"].cpu().numpy() == i)
        accuracy_acum += accuracy
        batches_cnt += 1
        # ===================
    
    valid_logs = {
        'valid_loss': loss.item(),
        'valid_accuracy': accuracy_acum / batches_cnt,
        'valid_accuracy_0': accuracy_per_class[0] / samples_cnt_per_class[0],
        'valid_accuracy_1': accuracy_per_class[1] / samples_cnt_per_class[1],
        'valid_accuracy_2': accuracy_per_class[2] / samples_cnt_per_class[2],
        'valid_accuracy_3': accuracy_per_class[3] / samples_cnt_per_class[3],
    }

    if scheduler is not None:
        scheduler.step()
        
    if wandb_logging:
        wandb.log(valid_logs)
    
    return valid_logs

# Model Architecture

In [48]:
NUM_CLASSES = 4

In [49]:
class BaselineModel(nn.Module):
    def __init__(self, encoder: nn.Module, emb_dim: int, dropout: float):
        super(BaselineModel, self).__init__()
        self.encoder = encoder
        self.head = nn.Sequential(
            nn.Linear(2 * emb_dim, 4 * emb_dim),
            nn.LayerNorm(4 * emb_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(4 * emb_dim, NUM_CLASSES)
        )

    def forward(self, baseline, followup):
        baseline_embed = self.encoder(baseline)
        followup_embed = self.encoder(followup)
        # joint_embed = self.encoder(followup - baseline)
        joint_embed = torch.cat([baseline_embed, followup_embed], dim=1)
        outputs = self.head(joint_embed)
        return outputs

In [None]:
NUM_CHANNELS = 5
EMB_DIM = 128
DROPOUT = 0.50

def get_model() -> nn.Module:
    encoder = monai.networks.nets.DenseNet169(
        spatial_dims=2,
        in_channels=5,
        out_channels=EMB_DIM,
        # att_dropout=0.1
        pretrained=True,
    )
    baseline_model = BaselineModel(
        encoder=encoder,
        emb_dim=EMB_DIM,
        dropout=DROPOUT
    )
    return baseline_model

In [51]:
baseline_model = get_model()
test_model_correctness(baseline_model, train_loader)

NotImplementedError: Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does notprovide pretrained models for more than two spatial dimensions.

In [52]:
calculate_params(baseline_model)

NameError: name 'baseline_model' is not defined

# Training

In [18]:
def run_training_pipeline(
    epochs, model, train_loader, valid_loader, loss_fn, optimizer, device,
    scheduler=None, wandb_logging=False,
    checkpoint_freq: int = None, checkpoint_name: str = None
):
    seed_everything()
    model = model.to(device)
    train_logs, valid_logs = [], []

    if wandb_logging:
        wandb.init(
            project="MRI Classification",
            name="FFA",
            config={"desc": "Doing something"}
        )

    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        train_scheduler = None
        valid_scheduler = scheduler
    else:
        train_scheduler = scheduler
        valid_scheduler = None
    
    for epoch in tqdm(range(epochs)):
        train_log = train(
            model=model,
            loader=train_loader,
            loss_fn=loss_fn,
            optimizer=optimizer,
            device=device,
            scheduler=train_scheduler,
            wandb_logging=wandb_logging
        )
        train_logs.append(train_log)

        valid_log = evaluate(
            model=model,
            loader=valid_loader,
            loss_fn=loss_fn,
            device=device,
            scheduler=valid_scheduler,
            wandb_logging=wandb_logging
        )
        valid_logs.append(valid_log)

        if checkpoint_freq is not None:
            checkpoint = {
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "train_scheduler_state": None if train_scheduler is None else train_scheduler.state_dict(),
                "valid_scheduler_state": None if valid_scheduler is None else valid_scheduler.state_dict()
            }
            torch.save(checkpoint, PATH_TO_MODELS / f"{checkpoint_name}_{epoch + 1}.pth")
            # REMOVE PREVIOUS CHECKPOINT
            if epoch + 1 != checkpoint_freq:
                os.remove(PATH_TO_MODELS / f"{checkpoint_name}_{epoch + 1 - checkpoint_freq}.pth")

        if wandb_logging:
            wandb.log({'learning_rate': optimizer.param_groups[0]['lr']})
    
    if wandb_logging:
        wandb.finish()
    
    return train_logs, valid_logs

In [19]:
NUM_EPOCHS = 25
WEIGHT_DECAY = 1e-5
MOMENTUM = 0.90
LEARNING_RATE = 1e-4

TEMP = 0.80
CLASS_WEIGHTS = 1 / (torch.tensor([0.04, 0.08, 0.23, 0.65])) ** TEMP

model = get_model()
# optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# loss_fn = nn.CrossEntropyLoss(weight=CLASS_WEIGHTS.to(device))
loss_fn = nn.CrossEntropyLoss()
scheduler = None

In [20]:
train_logs, val_logs = run_training_pipeline(
    epochs=NUM_EPOCHS,
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
    scheduler=scheduler,
    wandb_logging=True,
    checkpoint_freq=None,
    checkpoint_name=None,
)

 88%|████████▊ | 22/25 [2:09:59<15:56, 318.68s/it]  IOStream.flush timed out
100%|██████████| 25/25 [2:28:06<00:00, 355.45s/it]
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▅▄▁▅▄▄▅▅▆▄▅▄▃▃▅▃█▅█▄▅▄▄▃▂▅▄▆▇▅▇▅▄▆▆▅▆▄▆▅
train_loss,█▄▅▃▃▃▃▄▄▄▄▄▃▃▃▅▃▃▃▄▄▂▂▄▂▃▅▅▅▃▂▂▃▃▅▂▂▃▄▁
valid_accuracy,▅█▇███▆▇▇▄▇▇▆▆▅█▇█▆▄▆▆▆▄▁
valid_accuracy_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_accuracy_2,▅▁▂▁▁▂▅▂▃▆▄▁▆▁▅▂▄▄▅▄▄▄█▆█
valid_accuracy_3,▅█▇███▆▇▇▄▆▇▅▇▅█▇▇▆▅▆▆▄▄▁
valid_loss,▆▆▅▅▄▃▄▆▄▆▇▆▆█▅▅▅▅▁▄▄▂▅▁▄

0,1
learning_rate,0.0001
train_accuracy,0.6875
train_loss,1.10683
valid_accuracy,0.44048
valid_accuracy_1,0.0
valid_accuracy_2,0.4375
valid_accuracy_3,0.55556
valid_loss,1.27537
