# Import packages

In [None]:
import os
# Set before `torch` is imported further down in this cell.
# NOTE(review): presumably this has to be in the environment before the CUDA
# runtime initialises so that lazy module loading takes effect — confirm.
os.environ['CUDA_MODULE_LOADING']='LAZY'

from types import SimpleNamespace
import pandas as pd
from sklearn.model_selection import KFold

from timm.utils import NativeScaler
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SequentialSampler

from rsna_utils import RsnaDataset
from rsna_nets import Q1Net, Q2Net, load_q1_pretrained
from rsna_engine import train_one_epoch, evaluate

print(torch.__version__)

# Load train.csv file
If you are running on Colab or Kaggle, change the paths below to point to where the dataset is stored.

In [None]:
# Locations of the annotation table and of the image directory.
DF_PATH = "./dataset/train_with_box.csv"
IMG_PATH = "./dataset/positive_images"
train_df = pd.read_csv(DF_PATH)

# One identifier per (patient, laterality) pair: "<patient_id>_<laterality>".
# Built column-wise instead of with a row-wise `apply`: that runs a Python
# lambda per row and first packs every row into a single-dtype Series, which
# is slow and can turn an integer `patient_id` into a float when all columns
# are numeric.
train_df["id"] = (
    train_df["patient_id"].astype(str) + "_" + train_df["laterality"].astype(str)
)

# Define global variables and models

In [None]:
# ---------------------------------------------------------------------------
# Hardware target and optimisation hyper-parameters
# ---------------------------------------------------------------------------
device = "cuda"
lr = 5e-5
weight_decay = 0.05
momentum = 0.9
batch_size = 32
EPOCHES = 10

# ---------------------------------------------------------------------------
# Networks
# ---------------------------------------------------------------------------
# The first network is built, moved to the device, and only then filled from
# the checkpoint file; this order is kept as written.
q1_model = Q1Net(
    stem_chs=[64, 32, 64],
    depths=[3, 4, 20, 3],
    path_dropout=0.2,
)
q1_model.to(device)
load_q1_pretrained("./dataset/nextvit_base_in1k6m_384.pth", q1_model)

q2_model = Q2Net()
q2_model.to(device)

# ---------------------------------------------------------------------------
# Optimisers and learning-rate schedulers (timm factories)
# ---------------------------------------------------------------------------
# Both networks share one settings namespace, read by the timm factories.
args = SimpleNamespace(
    opt="adamw",
    lr=lr,
    weight_decay=weight_decay,
    batch_size=batch_size,
    decay_rate=0.1,
    decay_epochs=30,
    warmup_epochs=5,
    cooldown_epochs=10,
    epochs=EPOCHES,
    min_lr=1e-5,
    warmup_lr=1e-6,
    sched="cosine",
    momentum=momentum,
)

# `create_scheduler` returns a pair; only its first element (the scheduler)
# is kept.
q1_optimizer = create_optimizer(args, q1_model)
q1_lr_scheduler = create_scheduler(args, q1_optimizer)[0]

q2_optimizer = create_optimizer(args, q2_model)
q2_lr_scheduler = create_scheduler(args, q2_optimizer)[0]

criterion = nn.CrossEntropyLoss()

# ---------------------------------------------------------------------------
# Patch-sampling and normalisation settings handed to the training and
# evaluation routines
# ---------------------------------------------------------------------------
inner_iterations = 4
patch_size = 384
patches_per_in_iter = 12
grad_acc_steps = 2
max_num_patches = 32
img_mean = 0
img_std = 1

# Split data into folds

In [None]:
# 1-based index of the fold that is held out for validation.
FOLD_NUM = 1

all_patient_ids = train_df.id.unique()
# Ids that have at least one row with `cancer == 1`.
patient_ids_w_cancer = train_df[train_df.cancer == 1].id.unique()
# Remaining ids, in their original order. Membership is tested against a set
# so the filter is linear instead of scanning the positive array for every id.
_cancer_id_set = set(patient_ids_w_cancer)
patient_ids_wo_cancer = np.array(
    [p for p in all_patient_ids if p not in _cancer_id_set]
)

kf = KFold(n_splits=5)
if not 1 <= FOLD_NUM <= kf.get_n_splits():
    raise ValueError(
        f"FOLD_NUM must be between 1 and {kf.get_n_splits()}, got {FOLD_NUM}"
    )

# Positive and negative ids are split independently, so each fold receives a
# share of both groups.
patients_w_iter = kf.split(patient_ids_w_cancer)
patients_wo_iter = kf.split(patient_ids_wo_cancer)
# Advance both generators in lockstep; after the loop the variables hold the
# index arrays of fold number FOLD_NUM (the `first_fold_` prefix is historical).
for _ in range(FOLD_NUM):
    first_fold_positive_train, first_fold_positive_valid = next(patients_w_iter)
    first_fold_negative_train, first_fold_negative_valid = next(patients_wo_iter)

# Training split: positive ids first, then negative ids, with matching labels.
train_patient_ids = np.concatenate(
    (
        patient_ids_w_cancer[first_fold_positive_train],
        patient_ids_wo_cancer[first_fold_negative_train],
    )
)
train_labels = np.concatenate(
    (
        np.ones(len(first_fold_positive_train), dtype=np.int64),
        np.zeros(len(first_fold_negative_train), dtype=np.int64),
    )
)

# Validation split, laid out the same way.
valid_patient_ids = np.concatenate(
    (
        patient_ids_w_cancer[first_fold_positive_valid],
        patient_ids_wo_cancer[first_fold_negative_valid],
    )
)
valid_labels = np.concatenate(
    (
        np.ones(len(first_fold_positive_valid), dtype=np.int64),
        np.zeros(len(first_fold_negative_valid), dtype=np.int64),
    )
)

# Create PyTorch Loaders

In [None]:
# Options shared by the training and validation loaders.
_loader_kwargs = dict(
    batch_size=batch_size,
    drop_last=False,
    num_workers=2,
    pin_memory=True,
)

# NOTE(review): both loaders use a sequential sampler, so the training set is
# visited in exactly the order of `train_patient_ids` every epoch — confirm
# that no shuffling is expected at this level.
train_dataset = RsnaDataset(train_patient_ids, train_labels, is_train=True)
train_loader = DataLoader(
    train_dataset,
    sampler=SequentialSampler(train_dataset),
    **_loader_kwargs,
)

valid_dataset = RsnaDataset(valid_patient_ids, valid_labels, is_train=False)
valid_loader = DataLoader(
    valid_dataset,
    sampler=SequentialSampler(valid_dataset),
    **_loader_kwargs,
)

# Run training

In [None]:
# One training pass followed by one validation pass per epoch.
for epoch in range(EPOCHES):
    train_one_epoch(
        train_df, IMG_PATH, patch_size, patches_per_in_iter, grad_acc_steps,
        q1_model, q2_model, criterion, train_loader,
        q1_optimizer, q2_optimizer, inner_iterations,
        device, epoch, max_num_patches=max_num_patches,
        img_mean=img_mean, img_std=img_std
    )
    # timm schedulers set the learning rate for the epoch index they receive.
    # Epoch `epoch` has just finished, so request the rate for the next one;
    # passing `epoch` would repeat the same rate and lag the schedule by one.
    q1_lr_scheduler.step(epoch + 1)
    q2_lr_scheduler.step(epoch + 1)

    evaluate(
        train_df, IMG_PATH, valid_loader,
        q1_model, q2_model, patch_size, device,
        max_num_patches=max_num_patches, img_mean=img_mean, img_std=img_std
    )