In [1]:
import os
import sys
import json
import random

import numpy as np
import pandas as pd
import nibabel as nib

import seaborn as sns
import matplotlib.pyplot as plt

from pathlib import Path
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.functional as F

from torch import GradScaler, autocast
from torch.utils.data import DataLoader

from tqdm import tqdm
# from utils import DataPathWrapper

import time
import utils

import monai
import monai.transforms as mt

import pypickle

import medim
import nibabel as nib
import matplotlib.pyplot as plt

from dataset import get_dataset

import wandb
wandb.login()

import warnings
warnings.simplefilter('ignore')

[34m[1mwandb[0m: Currently logged in as: [33mdteakhperky[0m ([33mdteakhperky-higher-school-of-economics[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
# Directory the notebook is run from, and the BM-MAE directory expected inside it.
CWD_PATH = Path(os.getcwd())
BMMAE_PATH = CWD_PATH / "BM-MAE"

# Make the BM-MAE directory importable so the `bmmae` package used in the next cell resolves.
sys.path.append(str(BMMAE_PATH))

In [3]:
from bmmae.model import ViTEncoder
from bmmae.tokenizers import MRITokenizer

# Setup

In [4]:
# Expose only GPU 0 to this process.
# NOTE(review): this only takes effect if CUDA has not been initialised yet — confirm nothing above touches the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
def seed_everything(seed=0xBAD5EED):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and,
    when available, all CUDA devices), exports ``PYTHONHASHSEED`` and puts
    cuDNN into deterministic mode.

    Args:
        seed: integer seed shared by all generators.
    """
    # Only affects hash randomisation of Python processes started after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Deterministic kernels only. The cuDNN autotuner must be OFF as well:
    # with benchmark=True it may pick a different algorithm from run to run,
    # which defeats the determinism requested on the previous line.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [6]:
%load_ext autoreload
%autoreload 2

# Dataset

In [7]:
# Project-relative locations of the imaging data and of the stored train/validation split.
CWD_PATH = Path(os.getcwd())
DATA_PATH = CWD_PATH / "data"

ATLAS_DATA_PATH = DATA_PATH / "Lumiere" / "atlas_mapping"
SPLIT_PATH = CWD_PATH / "splits" / "split.pkl"

# The "train_patient_ids" / "valid_patient_ids" entries of `split` are read when the datasets are built.
# NOTE(review): presumably a pickle file — unpickling can execute arbitrary code, so only load a split you created yourself.
split = pypickle.load(SPLIT_PATH, verbose="silent")

In [8]:
import monai.transforms as mt
import numpy as np

class CropAroundMaskd(mt.MapTransform):
    """
    Crop every entry in ``keys`` to the bounding box of the non-zero region of a mask.

    The bounding box is computed on the first channel of ``data[mask_key]``,
    grown by ``margin`` voxels on each side and clipped to the mask extent.
    The mask entry itself is only cropped if it is also listed in ``keys``.
    If the mask is empty (no foreground), the data is returned without cropping.

    Args:
        keys: entries to crop; each is indexed as (C, *spatial).
        mask_key: entry holding the mask, indexed as (C, *spatial).
        margin: number of voxels added around the bounding box on each side.
    """

    def __init__(self, keys, mask_key, margin=10):
        super().__init__(keys)
        self.mask_key = mask_key
        self.margin = margin

    def __call__(self, data):
        d = dict(data)
        mask = d[self.mask_key][0]  # first channel only

        # Convert to NumPy explicitly. Calling np.nonzero on a plain torch.Tensor
        # dispatches to Tensor.nonzero(), which returns an (N, ndim) tensor instead
        # of a tuple of per-axis index arrays, so the bounds would be silently wrong.
        if isinstance(mask, torch.Tensor):
            mask = mask.detach().cpu().numpy()
        else:
            mask = np.asarray(mask)

        foreground = np.argwhere(mask)  # (num_foreground_voxels, ndim)
        if foreground.shape[0] == 0:
            # Mask is empty: skip cropping and return the data unchanged.
            return d

        lower = foreground.min(axis=0)
        upper = foreground.max(axis=0) + 1  # exclusive upper bound
        window = tuple(
            slice(max(int(lo) - self.margin, 0), min(int(hi) + self.margin, size))
            for lo, hi, size in zip(lower, upper, mask.shape)
        )

        for key in self.keys:
            # Keep all channels, crop every spatial axis.
            d[key] = d[key][(slice(None), *window)]

        return d


In [9]:
# Dictionary keys of one sample: four MRI modalities plus a segmentation per time point.
# The "*_seg" entries are only used to locate the crop region and are deleted before the sample is returned.
baseline_image_keys = ['baseline_FLAIR', 'baseline_T1', 'baseline_T1CE', 'baseline_T2', 'baseline_seg']
followup_image_keys = ['followup_FLAIR', 'followup_T1', 'followup_T1CE', 'followup_T2', 'followup_seg']
keys = baseline_image_keys + followup_image_keys

# Training pipeline: load -> resample to 1 mm -> crop around each mask -> random flips ->
# intensity normalisation and jitter -> drop masks -> resize to 128^3.
train_transform = mt.Compose([
    mt.LoadImaged(keys=keys),
    mt.EnsureChannelFirstd(keys=keys),
    mt.EnsureTyped(keys=keys),
    # NOTE(review): `("bilinear")` is a plain string, not a tuple, so the "*_seg" keys are
    # resampled bilinearly as well — consider a per-key mode with "nearest" for the masks.
    mt.Spacingd(keys=keys, pixdim=(1.0, 1.0, 1.0), mode=("bilinear")),
    # Crop around tumor mask with margin 10 voxels
    # Baseline and follow-up are cropped independently, each around its own mask; the masks themselves are not cropped.
    CropAroundMaskd(keys=[key for key in baseline_image_keys if "seg" not in key], mask_key="baseline_seg", margin=10),
    CropAroundMaskd(keys=[key for key in followup_image_keys if "seg" not in key], mask_key="followup_seg", margin=10),
    # Now you can apply random crops or other augmentations on the cropped ROI
    # mt.RandCropByPosNegLabeld(
    #     keys=keys,
    #     label_key="baseline_seg",
    #     spatial_size=(128, 128, 128),
    #     pos=1,
    #     num_samples=1,
    #     image_key=baseline_image_keys[0]
    # ),
    # `keys` still contains the "*_seg" entries here, so the flips and intensity transforms below
    # are applied to the masks too; this has no effect on the output because the masks are deleted afterwards.
    mt.RandFlipd(keys=keys, prob=0.5, spatial_axis=0),
    mt.RandFlipd(keys=keys, prob=0.5, spatial_axis=1),
    mt.RandFlipd(keys=keys, prob=0.5, spatial_axis=2),
    mt.NormalizeIntensityd(keys=keys, nonzero=True, channel_wise=True),
    # prob=1.0: intensity scale/shift jitter is applied to every training sample.
    mt.RandScaleIntensityd(keys=keys, factors=0.1, prob=1.0),
    mt.RandShiftIntensityd(keys=keys, offsets=0.1, prob=1.0),
    mt.DeleteItemsd(keys=["baseline_seg", "followup_seg"]),
    # Crops have data-dependent sizes; resize every remaining image to a fixed 128^3 grid.
    mt.ResizeD(keys=[img for img in keys if "seg" not in img], spatial_size=(128, 128, 128)),
]).set_random_state(seed=0xBAD5EED)

# train_transform = mt.Compose([
#     mt.LoadImaged(keys=keys),
#     mt.EnsureChannelFirstd(keys=keys),
#     mt.EnsureTyped(keys=keys),
#     # mt.Orientationd(keys=keys, axcodes="RAS"),
#     # Spatial augmentations
#     # mt.RandRotated(keys=keys, range_x=0.3, prob=0.5),
#     # mt.RandZoomd(keys=keys, min_zoom=0.9, max_zoom=1.1, prob=0.5),
#     # Intensity augmentations
#     # mt.RandGaussianNoised(keys=keys, std=0.01, prob=0.2),
#     # mt.RandAdjustContrastd(keys=keys, gamma=(0.7, 1.3), prob=0.3),
#     # Normalization
#     # mt.ScaleIntensityd(keys=keys),
#     # mt.ScaleIntensityRangePercentilesd(keys=keys, lower=5, upper=95, b_min=0, b_max=1),
#     mt.Spacingd(keys=keys, pixdim=(1.0, 1.0, 1.0), mode=("bilinear")),
#     # mt.CenterSpatialCropd(keys=keys, roi_size=[128, 128, 128]),
#     mt.RandCropByPosNegLabeld(
#         keys=keys,
#         label_key="baseline_seg",  # Use your segmentation mask
#         spatial_size=(128,128,128),
#         pos=1,  # Force tumor presence
#         num_samples=1,
#         image_key=baseline_image_keys[0]
#     ),
#     mt.RandFlipd(keys=keys, prob=0.5, spatial_axis=0),
#     mt.RandFlipd(keys=keys, prob=0.5, spatial_axis=1),
#     mt.RandFlipd(keys=keys, prob=0.5, spatial_axis=2),
#     mt.NormalizeIntensityd(keys=keys, nonzero=True, channel_wise=True),
#     mt.RandScaleIntensityd(keys=keys, factors=0.1, prob=1.0),
#     mt.RandShiftIntensityd(keys=keys, offsets=0.1, prob=1.0),
#     # mt.RandGaussianNoised(keys=[key for key in keys if key not in ["baseline_seg", "followup_seg"]], std=0.01, prob=0.3),  # Exclude segs
#     # mt.ConcatItemsd(keys=keys, name="images"),
#     # mt.DeleteItemsd(keys=keys),
#     # mt.RandFlipd(keys=["images"]),
#     # mt.RandBiasFieldd(keys=["images"], prob=0.5),
#     mt.DeleteItemsd(keys=["baseline_seg", "followup_seg"]),
#     mt.ResizeD(keys=[img for img in keys if "seg" not in img], spatial_size=(128, 128, 128)),
#     # mt.RandCropByPosNegLabeld(
#     #     keys=["images"],
#     #     label_key="baseline_seg",
#     #     spatial_size=(128, 128, 64),
#     #     pos=1, neg=1, num_samples=4,
#     #     image_key="images", prob=0.5
#     # ),
#     # mt.ConcatItemsd(keys=followup_image_keys, name="followup"),
#     # mt.ToTensord(keys=keys)
# ]).set_random_state(seed=0xBAD5EED)

# Validation pipeline: same load / resample / crop / normalise / resize steps as training,
# without the random flips and intensity jitter.
valid_transform = mt.Compose([
    mt.LoadImaged(keys=keys),
    mt.EnsureChannelFirstd(keys=keys),
    mt.EnsureTyped(keys=keys),
    # mt.Orientationd(keys=keys, axcodes="RAS"),
    # mt.ConcatItemsd(keys=baseline_image_keys, name="baseline", dim=0),
    # mt.ConcatItemsd(keys=followup_image_keys, name="followup", dim=0),
    # mt.ScaleIntensityd(keys=keys),
    # NOTE(review): `("bilinear")` is a plain string, so the "*_seg" masks are resampled bilinearly too.
    mt.Spacingd(keys=keys, pixdim=(1.0, 1.0, 1.0), mode=("bilinear")),
    # mt.CenterSpatialCropd(keys=keys, roi_size=[128, 128, 128]),
    # Each time point is cropped around its own mask with a 10-voxel margin.
    CropAroundMaskd(keys=[key for key in baseline_image_keys if "seg" not in key], mask_key="baseline_seg", margin=10),
    CropAroundMaskd(keys=[key for key in followup_image_keys if "seg" not in key], mask_key="followup_seg", margin=10),
    # mt.RandCropByPosNegLabeld(
    #     keys=keys,
    #     label_key="baseline_seg",  # Use your segmentation mask
    #     spatial_size=(128,128,128),
    #     pos=1,  # Force tumor presence
    #     num_samples=1,
    #     image_key=baseline_image_keys[0]
    # ),
    mt.NormalizeIntensityd(keys=keys, nonzero=True, channel_wise=True),
    # mt.ConcatItemsD(keys=keys, name="images", dim=0),
    # The masks are no longer needed once the crop has been taken.
    mt.DeleteItemsd(keys=["baseline_seg", "followup_seg"]),
    mt.ResizeD(keys=[img for img in keys if "seg" not in img], spatial_size=(128, 128, 128)),
    # mt.ToTensord(keys=["baseline", "followup"]),
    # mt.ToTensord(keys=keys)
]).set_random_state(seed=0xBAD5EED)

In [10]:
# Build one dataset per split; each gets the patient ids of its split and its own transform pipeline.
train_dataset = get_dataset(
    data_path=ATLAS_DATA_PATH,
    indices=split["train_patient_ids"],
    transform=train_transform
)

valid_dataset = get_dataset(
    data_path=ATLAS_DATA_PATH,
    indices=split["valid_patient_ids"],
    transform=valid_transform
)

# Sanity check: number of samples per split.
len(train_dataset), len(valid_dataset)

(238, 84)

In [18]:
# Collecting the labels iterates the whole training set (every sample goes through the
# transform pipeline), so the result is cached on disk.
# NOTE(review): the cache is not tied to the split — delete ./labels.pkl whenever the
# training split changes, otherwise stale labels are loaded.
LABELS_CACHE_PATH = "./labels.pkl"

if os.path.exists(LABELS_CACHE_PATH):
    labels = pypickle.load(LABELS_CACHE_PATH, verbose="silent")
else:
    labels = [sample["label"] for sample in train_dataset]
    # Write the cache only when it was just computed; saving unconditionally tried to
    # overwrite the file that had been loaded a moment earlier.
    pypickle.save(LABELS_CACHE_PATH, labels)



False

In [12]:
from torch.utils.data import WeightedRandomSampler

# Class-balanced sampling: every training sample is drawn with a probability inversely
# proportional to the frequency of its class (temperature 1.0 = full re-balancing).
class_counts = torch.bincount(torch.tensor(labels))
temperature = 1.0

# Inverse-frequency class weights. A class id below the maximum label that never occurs
# has count 0; computing 1 / 0 = inf for it would turn the normalised weights into
# NaN / 0 and make the sampler fail, so absent classes get weight 0 instead.
present = class_counts > 0
weights = torch.zeros(class_counts.shape, dtype=torch.float)
weights[present] = (1 / class_counts[present].float()) ** temperature
weights /= weights.sum()

# Per-sample weight = weight of the sample's class.
samples_weights = weights[labels]

# Sampling with replacement, one epoch = as many draws as there are training samples.
sampler = WeightedRandomSampler(
    weights=samples_weights,
    num_samples=len(samples_weights),
    replacement=True
)

In [13]:
# Loader settings shared by the training and validation loaders.
BATCH_SIZE = 2
NUM_WORKERS = 4

# Training batches are drawn through the class-balancing sampler defined above;
# DataLoader does not allow `shuffle=True` together with a sampler, hence the commented-out flag.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    # shuffle=True,
    sampler=sampler,
    num_workers=NUM_WORKERS
)

# Validation is iterated in dataset order.
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS
)

In [14]:
# next(iter(train_loader))["images"].shape, next(iter(valid_loader))["images"].shape

In [15]:
# next(iter(train_loader))

# Training Utilities

In [16]:
@torch.inference_mode()
def test_model_correctness(model, loader):
    images = next(iter(loader))
    baseline_images = {
        key.split("_")[-1]: images[key].to(device) for key in baseline_image_keys if "seg" not in key
    }
    followup_images = {
        key.split("_")[-1]: images[key].to(device) for key in followup_image_keys if "seg" not in key
    }
    model = model.to(device)
    outputs = model(baseline_images, followup_images).cpu().numpy()
    assert outputs.shape[-1] == 4, f'Wrong last output dimension. Expected {4}; Got {outputs.shape[-1]}'
    print('All gucci')

def calculate_receptive_field(params: list[tuple[int, int]]) -> int:
    """Receptive field (along one axis) of a stack of convolution/pooling layers.

    Args:
        params: ``(kernel_size, stride)`` of every layer, listed from input to output.

    Returns:
        Size of the input window that influences a single output element;
        1 for an empty stack.
    """
    receptive_field, jump = 1, 1
    for kernel_size, stride in params:
        # A layer widens the field by (k - 1) steps of the *accumulated* stride of
        # all layers before it, not by its own stride.
        receptive_field += (kernel_size - 1) * jump
        jump *= stride

    return receptive_field

def calculate_params(model: torch.nn.Module) -> str:
    """Number of trainable parameters of ``model``, formatted in millions.

    Only parameters with ``requires_grad=True`` are counted.

    Returns:
        A string such as ``'10.64 M'`` (two decimals).
    """
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return f'{trainable / 1e6:.2f} M'

In [17]:
def train(
    model, 
    loader, 
    loss_fn, 
    optimizer, 
    device, 
    scheduler=None, 
    wandb_logging=False,
    accumulation_steps=4
):
    model.train()
    scaler = GradScaler()
    total_loss = 0
    total_acc = 0
    step_count = 0
    
    optimizer.zero_grad()
    for i, batch in enumerate(tqdm(loader, leave=False)):
        # batch = batch[0]
        batch["label"] = batch["label"].to(device)
        baseline_images = {
            key.split("_")[-1]: batch[key].to(device) for key in baseline_image_keys if "seg" not in key
        }
        followup_images = {
            key.split("_")[-1]: batch[key].to(device) for key in followup_image_keys if "seg" not in key
        }

        with autocast(device_type="cuda"):
            outputs = model(baseline_images, followup_images)
            loss = loss_fn(outputs, batch["label"]) / accumulation_steps  # Scale loss

        scaler.scale(loss).backward()
        total_loss += loss.item() * accumulation_steps

        preds = outputs.argmax(dim=1).cpu().numpy()
        accuracy = np.mean(preds == batch["label"].cpu().numpy())
        total_acc += accuracy
        
        step_count += 1

        if (i + 1) % accumulation_steps == 0 or (i + 1) == len(loader):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            
            if wandb_logging:
                wandb.log({
                    'train_loss': total_loss / step_count,
                    'train_accuracy': total_acc / step_count,
                })
            
            total_loss = 0
            total_acc = 0
            step_count = 0

    if scheduler is not None:
        scheduler.step()

    return {
        'train_loss': total_loss / max(1, step_count),
        'train_accuracy': total_acc / max(1, step_count)
    }
        
@torch.inference_mode()
def evaluate(model, loader, loss_fn, device, scheduler=None, wandb_logging=False):
    """Evaluate ``model`` on ``loader`` and report loss, accuracy and per-class accuracy.

    Args:
        model: called as ``model(baseline_images, followup_images)`` with two dicts
            mapping a modality name to a tensor.
        loader: iterable of batches holding ``"label"`` and the image keys listed in
            ``baseline_image_keys`` / ``followup_image_keys``.
        loss_fn: called as ``loss_fn(outputs, labels)``.
        device: device the batch is moved to.
        scheduler: optional scheduler stepped with the validation accuracy.
            NOTE(review): ``ReduceLROnPlateau`` defaults to ``mode='min'``; because an
            accuracy is passed here it should be created with ``mode='max'``.
        wandb_logging: when true, the metrics are logged to wandb.

    Returns:
        Dict with ``valid_loss`` (mean over all samples, assuming ``loss_fn`` returns a
        per-batch mean), ``valid_accuracy`` (fraction of correctly classified samples)
        and ``valid_accuracy_<i>`` for every class ``i < NUM_CLASSES``. Metrics that are
        undefined (empty loader, class without samples) are NaN.
    """
    model.eval()
    samples_cnt = 0
    correct_cnt = 0
    loss_sum = 0.0
    accuracy_per_class = np.zeros(NUM_CLASSES, dtype=np.float32)
    samples_cnt_per_class = np.zeros(NUM_CLASSES, dtype=np.float32)

    for batch in tqdm(loader, leave=False):
        batch["label"] = batch["label"].to(device)
        baseline_images = {
            key.split("_")[-1]: batch[key].to(device) for key in baseline_image_keys if "seg" not in key
        }
        followup_images = {
            key.split("_")[-1]: batch[key].to(device) for key in followup_image_keys if "seg" not in key
        }

        outputs = model(baseline_images, followup_images)
        loss = loss_fn(outputs, batch["label"])

        # ===================
        # METRICS CALCULATION
        preds = outputs.argmax(dim=1).cpu().numpy()
        targets = batch["label"].cpu().numpy()
        hits = preds == targets
        batch_size = len(targets)

        # Weight by batch size so a smaller last batch does not skew the means.
        loss_sum += loss.item() * batch_size
        correct_cnt += int(hits.sum())
        samples_cnt += batch_size
        for i in range(NUM_CLASSES):
            of_class = targets == i
            accuracy_per_class[i] += hits[of_class].sum()
            samples_cnt_per_class[i] += of_class.sum()
        # ===================
    print(accuracy_per_class, samples_cnt_per_class)

    undefined = float('nan')
    valid_logs = {
        # Mean over the whole validation set, not just the loss of the last batch.
        'valid_loss': loss_sum / samples_cnt if samples_cnt else undefined,
        'valid_accuracy': correct_cnt / samples_cnt if samples_cnt else undefined,
    }
    for i in range(NUM_CLASSES):
        # A class with no validation samples has no defined accuracy: report NaN explicitly.
        valid_logs[f'valid_accuracy_{i}'] = (
            float(accuracy_per_class[i] / samples_cnt_per_class[i])
            if samples_cnt_per_class[i] > 0 else undefined
        )

    if scheduler is not None:
        scheduler.step(valid_logs['valid_accuracy'])
        
    if wandb_logging:
        wandb.log(valid_logs)
    
    return valid_logs

# Model Architecture

In [18]:
# Number of target classes: sizes the classifier output and the per-class validation metrics.
NUM_CLASSES = 4

In [19]:
# class BaselineModel(nn.Module):
#     def __init__(self, encoder: nn.Module, emb_dim: int, dropout: float):
#         super(BaselineModel, self).__init__()
#         self.encoder = encoder
#         self.head = nn.Sequential(
#             nn.Linear(emb_dim, 4 * emb_dim),
#             nn.LayerNorm(4 * emb_dim),
#             nn.ReLU(inplace=True),
#             nn.Dropout(dropout),
#             nn.Linear(4 * emb_dim, NUM_CLASSES),
#             nn.Softmax(dim=-1)
#         )

#     def forward(self, baseline, followup):
#         baseline_embed = self.encoder(baseline.permute(0, 3, 1, 2))
#         followup_embed = self.encoder(followup.permute(0, 3, 1, 2))
#         # growing_embed = self.encoder(followup.permute(0, 3, 1, 2) - baseline.permute(0, 3, 1, 2))

#         # baseline_segm_mask = (baseline[:, 4] > 0)
#         # followup_segm_mask = (followup[:, 4] > 0)

#         # baseline[:, :4] = baseline[:, :4] * baseline_segm_mask.unsqueeze(1)
#         # followup[:, :4] = followup[:, :4] * followup_segm_mask.unsqueeze(1)

#         # baseline_embed = self.encoder(baseline)
#         # followup_embed = self.encoder(followup)
#         # joint_embed = self.encoder(followup - baseline)
#         # joint_embed = torch.cat([baseline_embed, followup_embed], dim=1)
#         joint_embed = followup_embed - baseline_embed
#         outputs = self.head(joint_embed)
#         return outputs

In [20]:
CLS_TOKEN_IDX = 0  # index of the token taken from the encoder output as the scan embedding


class BaselineModel(nn.Module):
    """Two-time-point classifier on top of a shared encoder.

    Both scans are encoded with the same ``encoder``; the token at ``CLS_TOKEN_IDX``
    of each output is used as the scan embedding. The embedding difference is
    projected by ``dif`` and concatenated with both embeddings before the ``clf`` head.

    Args:
        encoder: module exposing ``hidden_size`` whose output is indexed as
            ``output[:, CLS_TOKEN_IDX]``.
        dropout: dropout probability used in both heads.
    """

    def __init__(self, encoder: nn.Module, dropout: float):
        super().__init__()
        width = encoder.hidden_size

        self.encoder = encoder

        # Projects (followup - baseline) from `width` to `2 * width` features.
        difference_layers = [
            nn.Linear(width, 2 * width),
            nn.LayerNorm(2 * width),
            nn.GELU(),
            nn.Dropout(dropout),
        ]
        self.dif = nn.Sequential(*difference_layers)

        # Input is baseline (width) + followup (width) + projected difference (2 * width).
        classifier_layers = [
            nn.Linear(4 * width, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(width, NUM_CLASSES),
            nn.Softmax(dim=-1),  # FOCAL LOSS ONLY
        ]
        self.clf = nn.Sequential(*classifier_layers)

    def forward(self, baseline, followup):
        """Return class probabilities of shape (batch, NUM_CLASSES)."""
        baseline_embed, followup_embed = (
            self.encoder(scan)[:, CLS_TOKEN_IDX] for scan in (baseline, followup)
        )
        delta_embed = self.dif(followup_embed - baseline_embed)
        features = torch.cat((baseline_embed, followup_embed, delta_embed), dim=1)
        return self.clf(features)

In [21]:
# from monai.networks.nets import DenseNet121

NUM_CHANNELS = 10  # NOTE(review): not referenced by any code visible in this notebook — confirm it is still needed
# EMB_DIM = 128
DROPOUT = 0.0  # dropout probability of the BaselineModel heads (passed in get_model)
ENCODER_DROPOUT = 0.00  # dropout_rate handed to ViTEncoder in get_model

def get_model() -> nn.Module:
    """Build the classifier on top of the pretrained BM-MAE encoder.

    Creates one ``MRITokenizer`` per modality and a ``ViTEncoder`` with a class token,
    loads ``pretrained_models/bmmae.pth`` from ``BMMAE_PATH`` into the encoder, freezes
    every encoder parameter whose name does not contain ``"blocks.11"`` and wraps the
    encoder in ``BaselineModel``.

    Returns:
        The assembled ``BaselineModel`` (on the CPU; the caller moves it to a device).

    Raises:
        RuntimeError: if not a single checkpoint entry matches the encoder, i.e. the
            pretrained weights were not loaded at all.
    """
    modalities = ["T1", "T1CE", "T2", "FLAIR"]
    tokenizers = {
        modality: MRITokenizer(
            patch_size=(16, 16, 16),
            img_size=(128, 128, 128),
            hidden_size=768,
        )
        for modality in modalities
    }
    encoder = ViTEncoder(
        modalities=modalities,
        tokenizers=tokenizers,
        cls_token=True,
        dropout_rate=ENCODER_DROPOUT
    )

    # map_location="cpu" keeps the load independent of the device the checkpoint was
    # saved from, so the notebook also runs on CPU-only machines.
    # NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
    state_dict = torch.load(BMMAE_PATH / 'pretrained_models/bmmae.pth', map_location="cpu")

    # strict=False tolerates key mismatches, but a mismatch must not go unnoticed:
    # a checkpoint with a different key layout would otherwise leave a frozen,
    # randomly initialised encoder.
    load_result = encoder.load_state_dict(state_dict, strict=False)
    expected_keys = len(encoder.state_dict())
    if expected_keys and len(load_result.missing_keys) >= expected_keys:
        raise RuntimeError(
            "No pretrained weight matched the encoder; check the checkpoint key layout."
        )
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Pretrained checkpoint loaded with {len(load_result.missing_keys)} missing "
            f"and {len(load_result.unexpected_keys)} unexpected keys"
        )

    # Fine-tune only the last transformer block; everything else stays frozen.
    for name, param in encoder.named_parameters():
        # if all(x not in name for x in ["blocks.11", "cls_token"]):
        if all(x not in name for x in ["blocks.11"]):
            print(f"Freezing {name}")
            param.requires_grad = False
        else:
            print(f"Unfreezing {name}")
            param.requires_grad = True
        
    model = BaselineModel(
        encoder=encoder,
        dropout=DROPOUT
    )

    return model

In [22]:
batch = next(iter(train_loader))

In [23]:
batch["baseline_FLAIR"].shape, batch["label"].shape

(torch.Size([2, 1, 128, 128, 128]), torch.Size([2]))

In [24]:
baseline_model = get_model()
test_model_correctness(baseline_model, train_loader)

Freezing cls_token
Freezing tokenizers.T1.patch_embedding.position_embeddings
Freezing tokenizers.T1.patch_embedding.patch_embeddings.weight
Freezing tokenizers.T1.patch_embedding.patch_embeddings.bias
Freezing tokenizers.T1CE.patch_embedding.position_embeddings
Freezing tokenizers.T1CE.patch_embedding.patch_embeddings.weight
Freezing tokenizers.T1CE.patch_embedding.patch_embeddings.bias
Freezing tokenizers.T2.patch_embedding.position_embeddings
Freezing tokenizers.T2.patch_embedding.patch_embeddings.weight
Freezing tokenizers.T2.patch_embedding.patch_embeddings.bias
Freezing tokenizers.FLAIR.patch_embedding.position_embeddings
Freezing tokenizers.FLAIR.patch_embedding.patch_embeddings.weight
Freezing tokenizers.FLAIR.patch_embedding.patch_embeddings.bias
Freezing norm.weight
Freezing norm.bias
Freezing blocks.0.mlp.linear1.weight
Freezing blocks.0.mlp.linear1.bias
Freezing blocks.0.mlp.linear2.weight
Freezing blocks.0.mlp.linear2.bias
Freezing blocks.0.norm1.weight
Freezing blocks.0.n

In [25]:
calculate_params(baseline_model)

'10.64 M'

# Training

In [26]:
def run_training_pipeline(
    epochs, model, train_loader, valid_loader, loss_fn, optimizer, device,
    scheduler=None, wandb_logging=False,
    checkpoint_freq: int = None, checkpoint_name: str = None
):
    """Train and validate ``model`` for ``epochs`` epochs.

    Args:
        epochs: number of epochs.
        model: model to train; it is moved to ``device``.
        train_loader / valid_loader: loaders handed to ``train`` and ``evaluate``.
        loss_fn: loss shared by training and validation.
        optimizer: optimizer stepped inside ``train``.
        device: device used for the model and the batches.
        scheduler: optional LR scheduler. A ``ReduceLROnPlateau`` is stepped inside
            ``evaluate`` (with the validation accuracy); any other scheduler is stepped
            once per epoch inside ``train``.
        wandb_logging: when true, a wandb run is opened, metrics and the learning rate
            are logged, and the run is closed even if training is interrupted.
        checkpoint_freq: save a checkpoint every ``checkpoint_freq`` epochs and keep only
            the latest one; ``None`` disables checkpointing.
        checkpoint_name: file name prefix of the checkpoints.
            Checkpoints are written below the global ``PATH_TO_MODELS``, which must be
            defined before checkpointing is enabled.

    Returns:
        ``(train_logs, valid_logs)``: one metrics dict per epoch each.

    Raises:
        ValueError: if ``checkpoint_freq`` is given but smaller than 1.
    """
    if checkpoint_freq is not None and checkpoint_freq < 1:
        raise ValueError(f"checkpoint_freq must be >= 1, got {checkpoint_freq}")

    seed_everything()
    model = model.to(device)
    train_logs, valid_logs = [], []

    if wandb_logging:
        wandb.init(
            project="MRI Classification",
            name="FFA",
            config={"desc": "Doing something"}
        )

    # ReduceLROnPlateau needs a metric, so it is stepped after validation;
    # every other scheduler is stepped at the end of the training epoch.
    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        train_scheduler = None
        valid_scheduler = scheduler
    else:
        train_scheduler = scheduler
        valid_scheduler = None

    try:
        for epoch in tqdm(range(epochs)):
            train_log = train(
                model=model,
                loader=train_loader,
                loss_fn=loss_fn,
                optimizer=optimizer,
                device=device,
                scheduler=train_scheduler,
                wandb_logging=wandb_logging
            )
            train_logs.append(train_log)

            valid_log = evaluate(
                model=model,
                loader=valid_loader,
                loss_fn=loss_fn,
                device=device,
                scheduler=valid_scheduler,
                wandb_logging=wandb_logging
            )
            valid_logs.append(valid_log)

            # Save only on every `checkpoint_freq`-th epoch. Saving each epoch and then
            # removing the file from `checkpoint_freq` epochs ago fails on the very first
            # epoch, because that file was never written.
            if checkpoint_freq is not None and (epoch + 1) % checkpoint_freq == 0:
                checkpoint = {
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "train_scheduler_state": None if train_scheduler is None else train_scheduler.state_dict(),
                    "valid_scheduler_state": None if valid_scheduler is None else valid_scheduler.state_dict()
                }
                torch.save(checkpoint, PATH_TO_MODELS / f"{checkpoint_name}_{epoch + 1}.pth")
                # REMOVE PREVIOUS CHECKPOINT (only after the new one has been written)
                previous = PATH_TO_MODELS / f"{checkpoint_name}_{epoch + 1 - checkpoint_freq}.pth"
                if os.path.exists(previous):
                    os.remove(previous)

            if wandb_logging:
                wandb.log({'learning_rate': optimizer.param_groups[0]['lr']})
    finally:
        # Close the wandb run even when training is interrupted or fails.
        if wandb_logging:
            wandb.finish()
    
    return train_logs, valid_logs

In [27]:
from focal_loss.focal_loss import FocalLoss

# Hyper-parameters of the run.
NUM_EPOCHS = 100
WEIGHT_DECAY = 1e-4
MOMENTUM = 0.90  # only referenced by the commented-out SGD optimizer below
LEARNING_RATE = 1e-4

model = get_model()
# optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# A single learning rate for all parameters; frozen encoder parameters are passed too but receive no gradient.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# optimizer = torch.optim.AdamW([
#     {'params': model.encoder.parameters(), 'lr': LEARNING_RATE / 3},
#     {'params': model.dif.parameters(), 'lr': LEARNING_RATE},
#     {'params': model.clf.parameters(), 'lr': LEARNING_RATE}
# ], weight_decay=WEIGHT_DECAY)
# loss_fn = nn.CrossEntropyLoss(weight=CLASS_WEIGHTS.to(device))
# loss_fn = nn.CrossEntropyLoss()
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=NUM_EPOCHS, eta_min=1e-5)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(
#     optimizer,
#     max_lr=LEARNING_RATE,
#     total_steps=NUM_EPOCHS,
#     pct_start=0.25
# )
# Only used by the commented-out warmup + cosine schedule below.
NUM_WARMUP_EPOCHS = int(0.08 * NUM_EPOCHS)
# scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [
#     torch.optim.lr_scheduler.LinearLR(optimizer, 1e-2, 1.0, total_iters=NUM_WARMUP_EPOCHS),  # Warmup
#     torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS - NUM_WARMUP_EPOCHS, eta_min=1e-6)  # Main
# ], [NUM_WARMUP_EPOCHS])
# NOTE(review): BaselineModel ends in a Softmax marked "FOCAL LOSS ONLY", so this loss presumably
# expects probabilities rather than logits — confirm before swapping in nn.CrossEntropyLoss.
loss_fn = FocalLoss(gamma=2.0)
# weights = weights ** 0.5
# loss_fn = FocalLoss(weights=weights.to(device), gamma=2.0)
# No scheduler: the learning rate stays at LEARNING_RATE for the whole run.
scheduler = None
calculate_params(model)

Freezing cls_token
Freezing tokenizers.T1.patch_embedding.position_embeddings
Freezing tokenizers.T1.patch_embedding.patch_embeddings.weight
Freezing tokenizers.T1.patch_embedding.patch_embeddings.bias
Freezing tokenizers.T1CE.patch_embedding.position_embeddings
Freezing tokenizers.T1CE.patch_embedding.patch_embeddings.weight
Freezing tokenizers.T1CE.patch_embedding.patch_embeddings.bias
Freezing tokenizers.T2.patch_embedding.position_embeddings
Freezing tokenizers.T2.patch_embedding.patch_embeddings.weight
Freezing tokenizers.T2.patch_embedding.patch_embeddings.bias
Freezing tokenizers.FLAIR.patch_embedding.position_embeddings
Freezing tokenizers.FLAIR.patch_embedding.patch_embeddings.weight
Freezing tokenizers.FLAIR.patch_embedding.patch_embeddings.bias
Freezing norm.weight
Freezing norm.bias
Freezing blocks.0.mlp.linear1.weight
Freezing blocks.0.mlp.linear1.bias
Freezing blocks.0.mlp.linear2.weight
Freezing blocks.0.mlp.linear2.bias
Freezing blocks.0.norm1.weight
Freezing blocks.0.n

'10.64 M'

In [28]:
import gc

gc.collect()
torch.cuda.empty_cache()

In [29]:
# Launch the full training run with wandb logging enabled and checkpointing disabled
# (checkpoint_freq=None), so no model weights are written to disk by this call.
train_logs, val_logs = run_training_pipeline(
    epochs=NUM_EPOCHS,
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
    scheduler=scheduler,
    wandb_logging=True,
    checkpoint_freq=None,
    checkpoint_name=None,
)

  1%|          | 1/100 [03:44<6:09:48, 224.13s/it]

[ 0.  2.  9. 21.] [ 0. 14. 16. 54.]


  2%|▏         | 2/100 [07:25<6:03:11, 222.36s/it]

[ 0.  1.  3. 12.] [ 0. 14. 16. 54.]


  3%|▎         | 3/100 [11:05<5:58:16, 221.61s/it]

[ 0.  1. 13.  6.] [ 0. 14. 16. 54.]


  4%|▍         | 4/100 [14:47<5:54:18, 221.44s/it]

[ 0.  0.  6. 29.] [ 0. 14. 16. 54.]


  5%|▌         | 5/100 [18:23<5:47:29, 219.47s/it]

[ 0.  0. 11.  8.] [ 0. 14. 16. 54.]


  6%|▌         | 6/100 [21:59<5:42:26, 218.58s/it]

[ 0.  0. 11. 30.] [ 0. 14. 16. 54.]


  7%|▋         | 7/100 [25:36<5:37:37, 217.82s/it]

[ 0.  0. 11. 19.] [ 0. 14. 16. 54.]


  7%|▋         | 7/100 [27:59<6:11:56, 239.96s/it]


KeyboardInterrupt: 