# Version 2 of the 2.5D model

In [None]:
# !pip install -q tensorflow-io

In [None]:
import os, gc, sys, copy, pickle

from pathlib import Path
from collections import defaultdict, Counter
import glob
from tqdm.auto import tqdm
tqdm.pandas()

import math
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from joblib import Parallel, delayed
import multiprocessing as mp

import albumentations as A
import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

from torch.utils.data import WeightedRandomSampler
from sklearn.utils.class_weight import compute_class_weight
from sklearn import model_selection

from transformers import get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup

import timm

import cv2
cv2.setNumThreads(0)
import PIL
import pydicom
from IPython import display
import warnings
warnings.filterwarnings("ignore")
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"

In [None]:
def seeding(SEED):
    """Seed every RNG used by the notebook so runs are repeatable.

    Covers Python's ``random``, NumPy's global RNG and PyTorch (CPU and, when
    present, all CUDA devices). On GPU, cuDNN is also switched to its
    deterministic, non-benchmarking mode.

    Note: setting ``PYTHONHASHSEED`` here only affects child processes; the
    current interpreter's hash seed is fixed at startup.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    os.environ['PYTHONHASHSEED'] = str(SEED)

    if not torch.cuda.is_available():
        print('seeding done!!!')
        return

    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    print('seeding done!!!')

def flush():
    """Collect Python garbage and, on GPU, drop cached CUDA blocks and peak stats."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

In [None]:
# timm.list_pretrained("densenet*")
# net = timm.create_model("resnet18d", pretrained=True, num_classes=0)
# x = torch.randn(2, 3, 224, 224)
# net.forward_features(x).size()

# Config

In [None]:
# Severity grades per target (Normal/Mild, Moderate, Severe -> 0, 1, 2).
N_GRADES = 3
# Target columns per view: axial <- subarticular stenosis (L/R x 5 levels),
# sagittal T1 <- neural foraminal narrowing (L/R x 5 levels),
# sagittal T2/STIR <- spinal canal stenosis (5 levels).
AXIAL_LABELS, SAGT1_LABELS, SAGT2_LABELS = 10, 10, 5

CONFIG = dict(
    load_kernel = None,          # not read anywhere in this notebook
    load_last = True,            # not read anywhere in this notebook
    n_folds = 5,                 # not read here; folds come from rsna_folds.csv
    backbone = "convnext_nano.in12k_ft_in1k", # alternatives: tf_efficientnet_b0.ns_jft_in1k, convnext_pico.d1_in1k
    img_size = 224,              # every slice is resized to img_size x img_size
    n_slice_per_c = 10,          # slices sampled per series (other values noted: 10, 12, 16)
    axial_chans = 15,            # *_chans: not read anywhere in this notebook
    axial_labels = AXIAL_LABELS,
    axial_classes = N_GRADES * AXIAL_LABELS,   # logits per slice from the axial head

    sagT1_chans = 15,
    sagT1_labels = SAGT1_LABELS,
    sagT1_classes = N_GRADES * SAGT1_LABELS,   # logits per slice from the "coronal" head

    sagT2_chans = 15,
    sagT2_labels = SAGT2_LABELS,
    sagT2_classes = N_GRADES * SAGT2_LABELS,   # logits per slice from the "sagittal" head

    n_classes = N_GRADES * (AXIAL_LABELS + SAGT1_LABELS + SAGT2_LABELS),

    drop_rate = 0.,              # backbone dropout and LSTM inter-layer dropout
    drop_rate_last = 0.3,        # dropout inside each classification head
    drop_path_rate = 0.,         # backbone stochastic depth
    p_mixup = 0.5,               # not read anywhere in this notebook
    p_rand_order_v1 = 0.2,       # probability of shuffling slice order (train mode only)
    lr = 1e-4,

    epochs = 40,
    batch_size = 8,
    warmup = 1,                  # warmup length in epochs for the cosine schedule
    num_cycles = 0.475,          # cosine cycles passed to the scheduler
    # Always a torch.device (previously a plain "cpu" string on CPU-only hosts).
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu"),
    seed = 42,
    log_wandb = True,
    with_clip = False,           # clip gradients by value (1.0) when True
    use_lstm = True,             # NOTE(review): not consulted by the model; the LSTM is always used
)

if CONFIG['log_wandb']:
    name = CONFIG['backbone'].split('.')[0]
    CONFIG['project_name'] = f"RSNA2024-25Dv1-{name}"
    CONFIG['artifact_name'] = "RSNA2024Model"
    import wandb
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    secret_value_0 = user_secrets.get_secret("WANDB_API_KEY")
    wandb.login(key=secret_value_0)

seeding(CONFIG['seed'])

# Load data

In [None]:
# Root of the competition data (CSV files plus train_images/<study>/<series>/*.dcm).
DATA_PATH = Path("/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification")

train_main = pd.read_csv(DATA_PATH / "train.csv")
train_desc = pd.read_csv(DATA_PATH / "train_series_descriptions.csv")
train_labels = pd.read_csv(DATA_PATH / "train_label_coordinates.csv")

# Grade strings -> class ids. Missing grades become -100, which is the default
# ``ignore_index`` of nn.CrossEntropyLoss, so they drop out of the loss.
label2id = {'Normal/Mild': 0, 'Moderate':1, 'Severe':2}
train_main = train_main.fillna(-100).replace(label2id)

In [None]:
# Every column after the first (study_id) is a prediction target.
TARGET_COLS = train_main.columns.tolist()[1:]

# Target column -> its position in TARGET_COLS, split by the view that
# predicts it (only the keys are used by the dataset).
AXIAL_COLS = {col: i for i, col in enumerate(TARGET_COLS) if 'subarticular_stenosis' in col}
SAGT1_COLS = {col: i for i, col in enumerate(TARGET_COLS) if 'neural_foraminal_narrowing' in col}
SAGT2_COLS = {col: i for i, col in enumerate(TARGET_COLS) if 'spinal_canal_stenosis' in col}

CONDITIONS = [
    'Spinal Canal Stenosis',
    'Left Neural Foraminal Narrowing',
    'Right Neural Foraminal Narrowing',
    'Left Subarticular Stenosis',
    'Right Subarticular Stenosis',
]

LEVELS = [
    'L1/L2',
    'L2/L3',
    'L3/L4',
    'L4/L5',
    'L5/S1',
]

In [None]:
# Per-study fold assignments; run() splits on its `fold` column and the dataset
# reads `study_id` and the target columns from it.
# NOTE(review): labels are read from this file, not from train_main - presumably
# encoded the same way (0/1/2, -100 for missing); confirm.
FOLDS_CSV = "/kaggle/input/rsna2024-data-split/rsna_folds.csv"
df = pd.read_csv(FOLDS_CSV)

In [None]:
def normalize_pixels_to_uint8(data, monochrome1=False):
    """Min-max scale a pixel array to uint8 in [0, 255].

    Args:
        data: array-like of raw pixel values (any numeric dtype, may be signed).
        monochrome1: if true, intensities are inverted first. In DICOM
            MONOCHROME1 the lowest stored value is displayed as white.

    Returns:
        uint8 array of the same shape; a constant image maps to all zeros.
    """
    data = np.asarray(data, dtype=np.float32)
    if monochrome1:
        data = data.max() - data
    # Shift to start at 0 so signed inputs cannot wrap around on the uint8 cast.
    data = data - data.min()
    peak = data.max()
    if peak > 0:
        data = data / peak
    return (data * 255).astype(np.uint8)


def load_dicom(path):
    """Read a DICOM file and return its pixels as a min-max scaled uint8 image.

    MONOCHROME1 images are inverted so that high values are always bright.
    """
    dicom = pydicom.dcmread(path)
    return normalize_pixels_to_uint8(
        dicom.pixel_array,
        monochrome1=dicom.PhotometricInterpretation == "MONOCHROME1",
    )

def window_pixels_to_uint8(pixel_array, WC=100, WW=500):
    """Apply an intensity window and rescale the result to uint8 [0, 255].

    Args:
        pixel_array: raw pixel values (any numeric dtype).
        WC: window centre.
        WW: window width; values are clipped to [WC - WW//2, WC + WW//2].

    Returns:
        uint8 array; the clipped image is min-max scaled, and a constant
        clipped image maps to all zeros instead of NaN-derived garbage.
    """
    upper, lower = WC + WW // 2, WC - WW // 2
    # Work in float so clipping unsigned inputs to a negative bound is well defined.
    img = np.clip(np.asarray(pixel_array, dtype=np.float32), lower, upper)
    img = img - img.min()
    peak = img.max()
    if peak > 0:
        img = img / peak
    return (img * 255.0).astype(np.uint8)


def get_windowed_image(path, WC=100, WW=500):
    """Read a DICOM file and return a windowed uint8 image (WC = centre, WW = width)."""
    dcm = pydicom.dcmread(path)
    return window_pixels_to_uint8(dcm.pixel_array, WC=WC, WW=WW)

# Dataset

In [None]:
class Spine25DDataset(Dataset):
    """2.5D lumbar-spine dataset returning three slice stacks per study.

    For row ``idx`` of ``data``, the study's 'Sagittal T2/STIR', 'Sagittal T1'
    and 'Axial T2' series are located through ``desc``. Images are always read
    from ``DATA_PATH/train_images``. ``CONFIG["n_slice_per_c"]`` slices are
    sampled at evenly spaced positions from each series. Each slice is resized
    to ``CONFIG["img_size"]`` square, replicated to 3 channels and optionally
    augmented. A missing series, or an unreadable slice, yields zeros.

    Naming: the "coronal" output is built from the *Sagittal T1* series and the
    "sagittal" output from *Sagittal T2/STIR*.

    Args:
        data: DataFrame with a ``study_id`` column. Unless ``mode == 'test'``,
            it also needs the target columns named in ``AXIAL_COLS``,
            ``SAGT1_COLS`` and ``SAGT2_COLS``.
        desc: DataFrame with ``study_id``, ``series_id`` and
            ``series_description`` columns.
        mode: ``'train'`` enables random slice-order shuffling. ``'test'``
            returns images only. Any other value returns images and targets.
        transform: callable used as ``transform(image=img)['image']``.
            ``None`` disables augmentation.
    """

    def __init__(self, data, desc, mode='train', transform=None):
        self.data = data
        self.desc = desc
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def get_img_paths(self, study_id, series_desc):
        """Return .dcm paths of every ``series_desc`` series of ``study_id``.

        Series are concatenated in ``desc`` row order. Within a series, files
        are sorted by their integer file stem.
        """
        pdf = self.desc[self.desc['study_id'] == study_id]
        pdf_ = pdf[pdf['series_description'] == series_desc]
        allimgs = []
        for series_id in pdf_['series_id']:
            pimgs = glob.glob(f"{str(DATA_PATH)}/train_images/{study_id}/{series_id}/*.dcm")
            allimgs.extend(sorted(pimgs, key=lambda p: int(os.path.basename(p).split('.')[0])))
        return allimgs

    def read_dcm(self, src_path):
        return load_dicom(src_path)

    def get_images(self, nslides, image_paths):
        """Load up to ``nslides`` slices into a ``(nslides, H, W, 3)`` uint8 array.

        Loading is best-effort. A slice that fails to read, resize or transform
        stays zero. Only ``Exception`` subclasses are caught, so interrupts
        still propagate.
        """
        H, W = CONFIG['img_size'], CONFIG['img_size']
        images = np.zeros((nslides, H, W, 3), dtype=np.uint8)
        for i, path in enumerate(image_paths[:nslides]):
            try:
                img = self.read_dcm(path)
                # cv2.resize takes dsize as (width, height).
                img = cv2.resize(img, (W, H)).astype(np.uint8)
                img = img[..., None].repeat(3, -1)
                # Previously a missing transform raised inside the try and zeroed every slice.
                if self.transform is not None:
                    img = self.transform(image=img)['image']
                images[i] = img
            except Exception:
                # Corrupt or unreadable slice: keep the zero placeholder.
                continue
        return images

    def _load_view(self, study_id, series_desc, n_slices):
        """Sample ``n_slices`` evenly spaced slices of one series (zeros if it is absent)."""
        paths = self.get_img_paths(study_id, series_desc)
        if not paths:
            size = CONFIG['img_size']
            return np.zeros((n_slices, size, size, 3), dtype=np.uint8)
        picks = np.quantile(list(range(len(paths))), np.linspace(0., 1., n_slices)).round().astype(int)
        return self.get_images(nslides=n_slices, image_paths=[paths[i] for i in picks])

    def __getitem__(self, idx):
        # Read study_id from its own column. A row taken with iloc from an
        # all-numeric frame is upcast to float, which would turn the image
        # directory into ".../<id>.0/" and silently give empty volumes.
        study_id = self.data['study_id'].iloc[idx]
        n_slices = CONFIG["n_slice_per_c"]

        # Same load order as before (T2/STIR, T1, Axial), so random
        # augmentations draw from the RNG in the same sequence.
        sagittal = self._load_view(study_id, 'Sagittal T2/STIR', n_slices)
        coronal = self._load_view(study_id, 'Sagittal T1', n_slices)
        axial = self._load_view(study_id, 'Axial T2', n_slices)

        # (n, H, W, 3) uint8 -> (n, 3, H, W) float32 in [0, 1].
        axial, coronal, sagittal = (
            torch.tensor(v.transpose(0, 3, 1, 2).astype(np.float32) / 255.0).float()
            for v in (axial, coronal, sagittal)
        )

        if self.mode == 'test':
            return {"axial": axial, "coronal": coronal, "sagittal": sagittal}

        # Positional row lookup, consistent with study_id above.
        row = self.data.iloc[idx]
        # Every slice carries the study-level label vector: (n_slices, n_labels).
        axial_label, coronal_label, sagittal_label = (
            torch.tensor(np.tile(row[list(cols)].to_numpy().astype(int), (n_slices, 1))).float()
            for cols in (AXIAL_COLS, SAGT1_COLS, SAGT2_COLS)
        )

        if self.mode == 'train' and random.random() < CONFIG['p_rand_order_v1']:
            # Labels are identical across slices, so shuffling images alone is safe.
            axial_indices = torch.randperm(axial.size(0))
            coronal_indices = torch.randperm(coronal.size(0))
            sagittal_indices = torch.randperm(sagittal.size(0))
            axial = axial[axial_indices]
            coronal = coronal[coronal_indices]
            sagittal = sagittal[sagittal_indices]

        return {"axial": axial, "coronal": coronal, "sagittal": sagittal,
                "axial_target": axial_label, "coronal_target": coronal_label, "sagittal_target": sagittal_label}

## Transformations and Dataloaders

In [None]:
def get_transforms(height, width):
    """Build the albumentations pipelines for training and evaluation.

    Returns ``{"train": ..., "eval": ...}``. Both pipelines start with a resize
    to ``(height, width)``. The training pipeline adds:
    - geometric and photometric jitter,
    - one blur/noise op,
    - one distortion op,
    - a single cutout hole.

    NOTE(review): the dataset applies this per slice, so each slice of a volume
    gets independent random parameters. For axial slices, HorizontalFlip
    presumably swaps patient left/right while the Left/Right targets are
    unchanged. Confirm both are intended.
    """
    spatial_and_intensity = [
        A.Resize(height=height, width=width),
        A.Perspective(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.7, p=0.7),
        A.ShiftScaleRotate(shift_limit=0.3, scale_limit=0.3, rotate_limit=30, border_mode=4, p=0.7),
    ]
    blur_or_noise = A.OneOf(
        [
            A.MotionBlur(blur_limit=3),
            A.MedianBlur(blur_limit=3),
            A.GaussianBlur(blur_limit=3),
            A.GaussNoise(var_limit=(3.0, 9.0)),
        ],
        p=0.5,
    )
    distortion = A.OneOf(
        [
            A.OpticalDistortion(distort_limit=1.),
            A.GridDistortion(num_steps=5, distort_limit=1.),
            A.ElasticTransform(alpha=3),
        ],
        p=0.5,
    )
    cutout = A.CoarseDropout(max_holes=1, max_height=int(height * 0.375), max_width=int(width * 0.375), p=0.5)

    train_tsfm = A.Compose([*spatial_and_intensity, blur_or_noise, distortion, cutout])
    valid_tsfm = A.Compose([A.Resize(height=height, width=width)])
    return {"train": train_tsfm, "eval": valid_tsfm}


def get_dataloaders(data, desc, cfg, split="train"):
    """Build the DataLoader for one split.

    Args:
        data: per-study DataFrame passed to ``Spine25DDataset``.
        desc: series-description DataFrame passed to ``Spine25DDataset``.
        cfg: config dict; reads ``img_size`` and ``batch_size``.
        split: ``'train'`` gives a shuffled loader with train augmentation and
            ``drop_last``. ``'valid'`` and ``'test'`` give an ordered loader
            with eval transforms, double batch size and labels. Note that
            ``'test'`` still builds the dataset in ``'valid'`` mode.

    Raises:
        ValueError: for any other ``split`` (a subclass of ``Exception``, as before).
    """
    if split not in ('train', 'valid', 'test'):
        raise ValueError("Split should be 'train' or 'valid' or 'test'!!!")

    img_size = cfg['img_size']
    tsfm = get_transforms(height=img_size, width=img_size)
    is_train = split == 'train'
    ds = Spine25DDataset(
        data=data,
        desc=desc,
        mode='train' if is_train else 'valid',
        transform=tsfm['train'] if is_train else tsfm['eval'],
    )

    # os.cpu_count() may return None. prefetch_factor is only accepted when
    # worker processes are used.
    num_workers = os.cpu_count() or 0
    worker_kwargs = {'prefetch_factor': 2} if num_workers > 0 else {}
    return DataLoader(
        ds,
        batch_size=cfg['batch_size'] if is_train else 2 * cfg['batch_size'],
        shuffle=is_train,
        num_workers=num_workers,
        drop_last=is_train,
        pin_memory=True,
        **worker_kwargs,
    )

In [None]:
# dls = get_dataloaders(data=train_main, desc=train_desc, cfg=CONFIG, split='train')
# b = next(iter(dls))

In [None]:
# gc.collect()
# k = 0
# fig, axes = plt.subplots(2, 5, figsize=(12, 12))
# axes = axes.flatten()
# sag = b['sagittal'][k].detach().cpu().numpy().transpose(0,2,3,1)

# for i in range(10):
#     axes[i].imshow(sag[i, ...])
#     axes[i].axis(False)
# plt.tight_layout()
# plt.show()
# # b['axial'].shape
# del sag

# Model

In [None]:
class BaseModel(nn.Module):
    """Single-view 2D encoder: a timm backbone followed by global average pooling.

    ``forward`` maps a batch of images to ``(N, nb_fts)`` pooled features,
    where ``nb_fts`` is the backbone's ``num_features``.
    """

    def __init__(self, backbone, in_chans=3, pretrained=False, increase_stride=False):
        super().__init__()

        # num_classes=0 drops timm's classifier; only forward_features is used.
        self.encoder = timm.create_model(
            backbone,
            pretrained=pretrained,
            in_chans=in_chans,
            num_classes=0,
            features_only=False,
            drop_rate=CONFIG["drop_rate"],
            drop_path_rate=CONFIG["drop_path_rate"],
        )
        self.encoder.name = backbone
        self.nb_fts = self.encoder.num_features
        self.gap = nn.AdaptiveAvgPool2d(1)

        if increase_stride:
            self.increase_stride()

    def increase_stride(self):
        """Set the stride of the encoder's first conv to 4.

        Supported for efficientnet and nfnet backbones. Any other backbone
        raises ``NotImplementedError``.
        """
        name = self.encoder.name
        if "efficientnet" in name:
            stem_conv = self.encoder.conv_stem
        elif "nfnet" in name:
            stem_conv = self.encoder.stem.conv1
        else:
            raise NotImplementedError
        stem_conv.stride = (4, 4)

    def forward(self, x):
        feature_map = self.encoder.forward_features(x)
        # (N, C, 1, 1) -> (N, C)
        return torch.flatten(self.gap(feature_map), 1)
    
    
class Clf(nn.Module):
    """Three-view 2.5D classifier.

    Pipeline: per-view CNN encoders, then per-slice feature concatenation,
    then a 2-layer BiLSTM over slices, then three per-view heads.

    Inputs are assumed to be ``(batch, n_slices, 3, H, W)`` as produced by
    ``Spine25DDataset`` (TODO confirm for other callers). Shapes are read from
    the input, not from CONFIG. Outputs are ``(batch, n_slices, n_classes)``
    logits per view. Submodule names match the original, so saved state
    dicts still load.
    """

    # Hidden size of each LSTM direction; the heads consume 2 * LSTM_HIDDEN features.
    LSTM_HIDDEN = 256

    def __init__(self, backbone, pretrained=False, increase_stride=False):
        super().__init__()
        self.axial_encoder = BaseModel(backbone=backbone, in_chans=3,
                                       pretrained=pretrained, increase_stride=increase_stride)
        self.coronal_encoder = BaseModel(backbone=backbone, in_chans=3,
                                         pretrained=pretrained, increase_stride=increase_stride)
        self.sagittal_encoder = BaseModel(backbone=backbone, in_chans=3,
                                          pretrained=pretrained, increase_stride=increase_stride)

        self.in_chans = 3
        self.nb_fts = 3 * self.axial_encoder.nb_fts
        # NOTE(review): CONFIG["use_lstm"] is not consulted; the LSTM is always used.
        self.lstm = nn.LSTM(self.nb_fts, self.LSTM_HIDDEN, num_layers=2, dropout=CONFIG["drop_rate"],
                            bidirectional=True, batch_first=True)
        head_in = 2 * self.LSTM_HIDDEN  # bidirectional output width
        self.axial_head = self.get_head(CONFIG['axial_classes'], in_features=head_in)
        self.coronal_head = self.get_head(CONFIG['sagT1_classes'], in_features=head_in)
        self.sagittal_head = self.get_head(CONFIG['sagT2_classes'], in_features=head_in)

    def get_head(self, n_classes, in_features=512):
        """Linear -> BatchNorm -> Dropout -> LeakyReLU -> Linear classification head."""
        return nn.Sequential(
            nn.Linear(in_features, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(CONFIG["drop_rate_last"]),
            nn.LeakyReLU(0.1),
            nn.Linear(256, n_classes),
        )

    def extract_features(self, x, view='axial'):
        """Encode every slice of ``x`` with the ``view`` encoder -> ``(batch, n_slices, F)``."""
        encoders = {
            'axial': self.axial_encoder,
            'coronal': self.coronal_encoder,
            'sagittal': self.sagittal_encoder,
        }
        if view not in encoders:
            raise NotImplementedError
        bs = x.shape[0]
        # Fold slices into the batch; reshape also tolerates non-contiguous input.
        frames = x.reshape(-1, self.in_chans, x.shape[-2], x.shape[-1])
        feat = encoders[view](frames)
        return feat.view(bs, -1, feat.shape[-1])

    def forward(self, ax, cor, sag):
        bs = ax.shape[0]
        fts = torch.cat([
            self.extract_features(ax, view='axial'),
            self.extract_features(cor, view='coronal'),
            self.extract_features(sag, view='sagittal'),
        ], dim=-1)
        fts, _ = self.lstm(fts)  # (bs, n_slices, 2 * LSTM_HIDDEN)
        n_slices = fts.shape[1]
        fts = fts.reshape(bs * n_slices, -1)
        y_ax = self.axial_head(fts).view(bs, n_slices, -1)
        y_cor = self.coronal_head(fts).view(bs, n_slices, -1)
        y_sag = self.sagittal_head(fts).view(bs, n_slices, -1)
        return y_ax, y_cor, y_sag

In [None]:
from collections import Counter, defaultdict

class MetricMonitor:
    """Running averages of named scalar metrics, printable as "name: avg | ...".

    ``metrics[name]`` holds ``val`` (running sum), ``count`` and ``avg``.
    """

    def __init__(self, float_precision=4):
        self.float_precision = float_precision
        self.reset()

    def reset(self):
        self.metrics = defaultdict(lambda: {"val": 0, "count": 0, "avg": 0})

    def update(self, metric_name, val):
        """Add one observation. 0-dim tensors and NumPy scalars are converted to Python numbers.

        Converting keeps autograd graphs and device tensors out of the running
        sum; previously a training loss tensor chained every step's graph.
        """
        if hasattr(val, "item"):
            val = val.item()
        metric = self.metrics[metric_name]

        metric["val"] += val
        metric["count"] += 1
        metric["avg"] = metric["val"] / metric["count"]

    def __str__(self):
        return " | ".join(
            [
                "{metric_name}: {avg:.{float_precision}f}".format(
                    metric_name=metric_name, avg=metric["avg"], float_precision=self.float_precision
                )
                for (metric_name, metric) in self.metrics.items()
            ]
        )

## Training arguments

In [None]:
gc.collect()

def get_loss(logits, labels, criterion):
    """Apply ``criterion`` over every (sample, slice, label) position.

    Each consecutive triple of logits forms one 3-way prediction and lines up
    with one flattened label. Labels are cast to int64 class ids as required
    by CrossEntropyLoss.
    """
    flat_logits = logits.view(-1, 3)
    flat_labels = labels.view(-1).to(torch.int64)
    return criterion(flat_logits, flat_labels)

# get_loss(o3, b['sagittal_target'], criterion)
def shared_step(model, batch, criterion):
    """Forward one batch through the three-view model.

    Moves the three image stacks and their targets to ``CONFIG["device"]``,
    then returns ``{"loss": mean of the axial, coronal and sagittal losses}``.
    """
    device = CONFIG["device"]
    image_keys = ("axial", "coronal", "sagittal")
    target_keys = ("axial_target", "coronal_target", "sagittal_target")
    images = [batch[k].to(device, non_blocking=True) for k in image_keys]
    targets = [batch[k].to(device, non_blocking=True) for k in target_keys]

    per_view_logits = model(*images)
    view_losses = [
        get_loss(logits, target, criterion)
        for logits, target in zip(per_view_logits, targets)
    ]
    total = view_losses[0] + view_losses[1] + view_losses[2]
    return {
        "loss": total / 3
    }

# shared_step(net, b, criterion)
def train(train_loader, model, criterion, optimizer, epoch, scaler, scheduler=None):
    """Run one mixed-precision training epoch.

    Args:
        train_loader: DataLoader of training batches.
        model: the three-view model (optionally wrapped in ``nn.DataParallel``).
        criterion: loss module forwarded to ``shared_step``.
        optimizer: stepped once per batch through ``scaler``.
        epoch: 1-based epoch number (used for logging only).
        scaler: ``GradScaler`` handling fp16 loss scaling.
        scheduler: optional LR scheduler stepped once per batch.

    Returns:
        The last step's metrics plus ``"train/epoch_loss"`` (mean batch loss),
        all as Python numbers.

    Side effects:
        Increments ``CONFIG['example_ct']`` and ``CONFIG['step_ct']``. When
        ``CONFIG['log_wandb']`` is set, logs every step except the last; the
        caller logs that one together with the validation metrics.
    """
    metric_monitor = MetricMonitor()
    model.train()
    stream = tqdm(train_loader)
    n_batches = len(train_loader)
    steps_per_epoch = CONFIG['n_steps_per_epoch']
    train_loss = 0.0
    _train_metrics = {}
    for i, batch in enumerate(stream, start=1):
        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            loss = shared_step(model, batch, criterion)['loss']

        # One Python float per step: no graph kept alive by the metrics.
        step_loss = loss.item()
        metric_monitor.update("Loss", step_loss)
        train_loss += step_loss
        CONFIG['example_ct'] += len(batch["axial"])
        scaler.scale(loss).backward()

        if CONFIG['with_clip']:
            # Gradients must be unscaled before they are clipped by value.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)

        lr = optimizer.param_groups[0]['lr']
        scaler.step(optimizer)
        scaler.update()

        _train_metrics = {
            "train/step_loss": step_loss,
            # Fractional epoch: (epoch - 1) + i / steps_per_epoch.
            # Previously this used the total epoch count and was never divided.
            "train/epoch": (i + steps_per_epoch * (epoch - 1)) / steps_per_epoch,
            "train/example_ct": CONFIG['example_ct'],
            "lr": lr,
        }

        if CONFIG['log_wandb'] and i < n_batches:
            wandb.log(_train_metrics)

        CONFIG['step_ct'] += 1
        if scheduler is not None:
            scheduler.step()

        stream.set_description(
            "Epoch: {epoch}. Train.      {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)
        )

    _train_metrics['train/epoch_loss'] = train_loss / n_batches

    flush()
    return _train_metrics


def validate(val_loader, model, criterion, epoch):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Returns:
        ``{"valid/step_loss": last batch loss, "valid/epoch_loss": mean batch
        loss}`` as Python floats. Before, these were device tensors; the
        caller's comparisons and formatting work unchanged with floats.
    """
    metric_monitor = MetricMonitor()
    model.eval()
    stream = tqdm(val_loader)
    valid_loss = 0.0
    step_loss = None

    with torch.no_grad():
        for batch in stream:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                loss = shared_step(model, batch, criterion)['loss']

            step_loss = loss.item()
            metric_monitor.update("Loss", step_loss)
            valid_loss += step_loss

            stream.set_description(
                "Epoch: {epoch}. Validation. {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)
            )

    _valid_metrics = {
        "valid/step_loss": step_loss,
        "valid/epoch_loss": valid_loss / len(val_loader),
    }
    flush()
    return _valid_metrics

In [None]:
def train_and_validate(model, train_dataset, val_dataset, desc, fold=0):
    """Train ``model`` on one fold, checkpoint on best validation loss, early-stop.

    Checkpoints are written to
    ``rsna_2024_lumbar_spine_fold_{fold}_epoch_{epoch}.pth``. They always hold
    the bare model's state dict, with no ``module.`` prefix even under
    DataParallel. When wandb logging is on, each checkpoint is also logged as
    a model artifact. Training stops once validation loss has not improved
    for more than ``ceil(ES_RATIO * epochs)`` consecutive epochs.
    """
    seeding(CONFIG['seed'])

    wandb_run = None
    if CONFIG['log_wandb']:
        # Hyper-parameters belong in wandb.init(config=...). Assigning
        # wandb.config after training never reached the run.
        run_config = {k: str(v) if isinstance(v, torch.device) else v for k, v in CONFIG.items()}
        wandb_run = wandb.init(
            project=CONFIG["project_name"],
            resume="allow",
            config=run_config,
        )

    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        device_ids = list(range(torch.cuda.device_count()))
        print(f"\nUsing {len(device_ids)} GPUs to train ...\n")
        model = nn.DataParallel(model, device_ids=device_ids)

    model = model.to(CONFIG["device"])
    train_loader = get_dataloaders(data=train_dataset, desc=desc, cfg=CONFIG, split="train")
    valid_loader = get_dataloaders(data=val_dataset, desc=desc, cfg=CONFIG, split="valid")

    # With drop_last=True the loader yields floor(N / batch_size) batches.
    # Size the schedule from the loader itself, not from ceil(N / batch_size).
    CONFIG['n_steps_per_epoch'] = max(len(train_loader), 1)
    CONFIG['example_ct'] = 0
    CONFIG['step_ct'] = 0

    # Weighted cross entropy: Moderate and Severe count 2x and 4x Normal/Mild.
    class_weights = torch.tensor([1, 2, 4], dtype=torch.float32)
    criterion = nn.CrossEntropyLoss(weight=class_weights).to(CONFIG["device"])

    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"])
    scaler = torch.cuda.amp.GradScaler()

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=CONFIG["warmup"] * CONFIG['n_steps_per_epoch'],
        num_training_steps=CONFIG["epochs"] * CONFIG['n_steps_per_epoch'],
        num_cycles=CONFIG["num_cycles"],
    )

    best_metric = np.inf
    es = 0
    ES_RATIO = 0.3 if CONFIG["epochs"] <= 30 else 0.25
    patience = math.ceil(ES_RATIO * CONFIG["epochs"])
    weights_file = "rsna_2024_lumbar_spine_fold_{fold}_epoch_{epoch}.pth"
    for epoch in range(1, CONFIG["epochs"] + 1):
        _train_metrics = train(train_loader, model, criterion, optimizer, epoch, scaler, scheduler=scheduler)
        _valid_metrics = validate(valid_loader, model, criterion, epoch)

        val_loss = _valid_metrics['valid/epoch_loss']
        if CONFIG['log_wandb']:
            wandb.log({**_train_metrics, **_valid_metrics})

        if val_loss < best_metric:
            print(f"Best metric: ({best_metric:.6f} --> {val_loss:.6f}). Saving model ...")
            checkpoint = weights_file.format(fold=fold, epoch=epoch)
            # DataParallel wraps as soon as there are 2+ GPUs. The old
            # `device_count() > 2` check saved prefixed keys on 2-GPU machines.
            bare_model = model.module if isinstance(model, nn.DataParallel) else model
            torch.save(bare_model.state_dict(), checkpoint)
            best_metric = val_loss
            if wandb_run is not None:
                artifact = wandb.Artifact(f"{CONFIG['artifact_name']}_{fold}", type="model")
                artifact.add_file(checkpoint)
                wandb_run.log_artifact(artifact)
            es = 0
        else:
            es += 1

        if es > patience:
            print(f"Early stopping on epoch {epoch} ...")
            break

    if wandb_run is not None:
        wandb.finish()

    del model, train_loader, valid_loader
    flush()

In [None]:
def run(fold=0):
    """Train a fresh pretrained ``Clf`` on every fold except ``fold``; validate on ``fold``."""
    model = Clf(backbone=CONFIG['backbone'], pretrained=True, increase_stride=False)
    in_valid_fold = df['fold'] == fold
    train_ds = df[~in_valid_fold].reset_index(drop=True)
    valid_ds = df[in_valid_fold].reset_index(drop=True)
    train_and_validate(model, train_ds, valid_ds, train_desc, fold=fold)
    gc.collect()
    flush()


run(fold=0)