In [1]:
import os
import sys
import warnings
from pathlib import Path

# Make modules in the parent directory importable (presumably where utils,
# dataset, training, models and focal_loss live). Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
PROJECT_ROOT = Path(os.getcwd()).parent.as_posix()
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# NOTE(review): this silences *every* warning for the whole kernel, including
# library deprecation warnings; consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

In [2]:
import torch
import torch.nn as nn

import monai.transforms as mt

from utils import RANDOM_SEED
from dataset import ReorganizeTransform, get_trainval_dataloaders
from training import test_model_correctness, calculate_params, MetricLogger, default_process, run_training
from models import BaselineModel

from focal_loss import FocalLoss

In [3]:
# Reload imported modules automatically before each cell executes, so edits to
# the project's .py modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Setting up GPU

In [4]:
# Pin this kernel to a single physical GPU. This only has an effect if it runs
# before CUDA is initialised in the process (e.g. before torch.cuda.is_available()
# below); re-running it afterwards changes nothing.
# NOTE(review): on a machine with fewer than seven GPUs, index 6 leaves no visible
# device and `device` silently falls back to the CPU.
GPU_INDEX = "6"
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_INDEX

In [5]:
# Use the (single) visible GPU if there is one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

# Datasets

In [6]:
# Root folder of the preprocessed scans, relative to the notebook's working directory.
DATA_PATH = Path.cwd().joinpath("data", "preprocessed_data")

### Transforms

In [7]:
# Modalities loaded for each timepoint. The "*_seg" volumes are handled as label
# maps below: nearest-neighbour resampling and no intensity normalisation.
required_keys = ["T1", "T1CE", "T2", "FLAIR", "seg"]
baseline_keys = [f"baseline_{key}" for key in required_keys]
followup_keys = [f"followup_{key}" for key in required_keys]
all_image_keys = baseline_keys + followup_keys

# Split once so every transform agrees on which volumes are label maps.
seg_keys = [key for key in all_image_keys if key.endswith("_seg")]
intensity_keys = [key for key in all_image_keys if not key.endswith("_seg")]

TARGET_PIXDIM = (1.0, 1.0, 1.0)

transform = mt.Compose([
    mt.LoadImaged(keys=all_image_keys),
    # NOTE(review): mt.EnsureChannelFirstd(keys=all_image_keys) is disabled here,
    # so volumes reach Spacingd without an explicit channel axis even though
    # MONAI's Spacing documents channel-first input. The encoder's NUM_CHANNELS
    # appears to depend on the resulting layout - confirm before re-enabling it.
    mt.EnsureTyped(keys=all_image_keys),
    # Label maps must keep their discrete values, hence nearest-neighbour.
    mt.Spacingd(keys=seg_keys, pixdim=TARGET_PIXDIM, mode="nearest"),
    mt.Spacingd(keys=intensity_keys, pixdim=TARGET_PIXDIM, mode="bilinear"),
    # Z-score only the intensity images. Normalising the label maps would destroy
    # their labels (e.g. every foreground voxel of a binary mask equals the
    # nonzero mean and would become 0), undoing the nearest-neighbour resampling.
    mt.NormalizeIntensityd(keys=intensity_keys, nonzero=True, channel_wise=True),
    ReorganizeTransform(required_keys=required_keys)
]).set_random_state(seed=RANDOM_SEED)

### Dataloaders

In [8]:
# Loader configuration.
TRAIN_BATCH_SIZE = 8
VALID_BATCH_SIZE = 32
NUM_WORKERS = 4
# Resampling options forwarded to get_trainval_dataloaders; see the dataset
# module for their exact semantics.
RESAMPLE = True
RESAMPLE_TEMPERATURE = 1.0

# Both splits share the same `transform` pipeline defined above.
train_loader, valid_loader = get_trainval_dataloaders(
    data_path=DATA_PATH,
    num_workers=NUM_WORKERS,
    train_batch_size=TRAIN_BATCH_SIZE,
    train_transform=transform,
    valid_batch_size=VALID_BATCH_SIZE,
    valid_transform=transform,
    resample=RESAMPLE,
    temperature=RESAMPLE_TEMPERATURE,
)

In [9]:
# (train, valid) loader lengths, i.e. batches per epoch for a standard DataLoader.
tuple(len(loader) for loader in (train_loader, valid_loader))

(36, 3)

# Architecture

In [10]:
# Width of the embedding each encoder produces (ResNet `num_classes` in
# get_encoder and `emb_dim` for BaselineModel).
EMBED_DIM: int = 64
# Input channels of the encoder's replacement first convolution.
# NOTE(review): 218 presumably matches one spatial axis of the preprocessed
# volumes, which the 2D ResNet consumes as channels - confirm against the data.
NUM_CHANNELS: int = 218
# Dropout setting forwarded to BaselineModel.
DROPOUT: float = 0

In [11]:
from torchvision.models import resnet18

def get_encoder(in_channels=None, emb_dim=None) -> nn.Module:
    """Build a randomly initialised ResNet-18 encoder with a custom input stem.

    torchvision's ResNet-18 stem expects 3-channel input, so its first
    convolution is replaced by one that accepts ``in_channels`` channels. The
    final fully connected layer outputs ``emb_dim`` features.

    Args:
        in_channels: Input channels of the new stem convolution. Defaults to
            the notebook constant ``NUM_CHANNELS`` (read at call time).
        emb_dim: Size of the output embedding (ResNet ``num_classes``).
            Defaults to the notebook constant ``EMBED_DIM`` (read at call time).

    Returns:
        A torchvision ResNet-18 ``nn.Module``. ``get_encoder()`` with no
        arguments still works as the zero-argument factory passed to
        ``BaselineModel(encoder=...)``.
    """
    if in_channels is None:
        in_channels = NUM_CHANNELS
    if emb_dim is None:
        emb_dim = EMBED_DIM

    # No `pretrained`/`weights` argument: both the legacy (`pretrained=False`)
    # and current (`weights=None`) torchvision signatures default to random
    # weights, and this avoids the deprecated `pretrained` keyword.
    encoder = resnet18(num_classes=emb_dim)

    # Same geometry as the stock stem (7x7, stride 2, padding 3) with a custom
    # channel count. bias=False like torchvision's own stem: the conv feeds
    # BatchNorm (bn1), which cancels any constant bias anyway.
    stem = nn.Conv2d(
        in_channels=in_channels,
        out_channels=encoder.conv1.out_channels,
        kernel_size=7,
        stride=2,
        padding=3,
        bias=False,
    )
    # ResNet.__init__ already gave its convolutions Kaiming-normal (fan_out,
    # ReLU) init; apply the same scheme to the replacement, which would
    # otherwise keep nn.Conv2d's default initialisation.
    nn.init.kaiming_normal_(stem.weight, mode="fan_out", nonlinearity="relu")
    encoder.conv1 = stem
    return encoder

In [12]:
# Instance used only for the parameter count and the correctness check below;
# the training section builds its own fresh model.
model = BaselineModel(
    image_keys=["T1CE", "T2"],
    encoder=get_encoder,
    emb_dim=EMBED_DIM,
    dropout=DROPOUT,
    logits=False,
)

In [13]:
# Human-readable parameter count (returned as a string such as '23.78 M').
model_param_count = calculate_params(model)
model_param_count

'23.78 M'

In [14]:
# Smoke test of the untrained model on data from train_loader; it prints a
# verdict rather than returning one.
test_model_correctness(
    device=device,
    model=model,
    process_fn=default_process,
    loader=train_loader,
)

MODEL SEEMS TO BE FINE!


# Training

In [15]:
# Optimisation schedule.
NUM_EPOCHS: int = 30
LEARNING_RATE: float = 3e-4
WEIGHT_DECAY: float = 0.0

# Focal-loss settings. LOSS_ALPHA is forwarded to FocalLoss as `weights`;
# None presumably means no per-class weighting - confirm in the focal_loss module.
LOSS_GAMMA: float = 2.0
LOSS_ALPHA = None

In [16]:
# Fresh model for the actual training run, so it does not reuse the instance
# built for the sanity checks above.
model = BaselineModel(
    encoder=get_encoder,
    emb_dim=EMBED_DIM,
    dropout=DROPOUT,
    logits=False,
    image_keys=["T1CE", "T2"],
)

optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
loss_fn = FocalLoss(gamma=LOSS_GAMMA, weights=LOSS_ALPHA)
# Cosine decay from LEARNING_RATE down to 1e-6 over NUM_EPOCHS scheduler steps;
# run_training below is called with scheduler_step="epoch".
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=NUM_EPOCHS,
    eta_min=1e-6,
)

metric_logger = MetricLogger(logits=False)
metric_logger.reset()

In [None]:
train_logs, valid_logs = run_training(
    run_name="test_3",
    # What to optimise and how.
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    scheduler=scheduler,
    scheduler_step="epoch",  # one step per epoch, matching T_max=NUM_EPOCHS
    epochs=NUM_EPOCHS,
    batch_accum=1,
    # Data and evaluation.
    train_loader=train_loader,
    valid_loader=valid_loader,
    process_fn=default_process,
    metric_logger=metric_logger,
    device=device,
    # Logging and checkpointing.
    wandb_logging=False,
    make_checkpoints=True,
    checkpoint_freq=1,
)

  0%|          | 0/30 [00:00<?, ?it/s]