In [1]:
%load_ext autoreload
%autoreload 2

### Imports

In [2]:
import os
import time
import wandb
import torch
import random
import torchvision

import numpy as np
import pandas as pd
import torchmetrics as tm 
# import plotly.express as px
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from torch import nn
from pathlib import Path, PurePath
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, AdamW, RMSprop # optmizers
from sklearn import preprocessing 
# from warmup_scheduler import GradualWarmupScheduler
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau # Learning rate schedulers

import albumentations as A
# from albumentations.pytorch import ToTensorV2

import torch.nn.functional as F

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
from torchmetrics.wrappers import ClasswiseWrapper
from torchmetrics import MetricCollection
from torchmetrics.classification import MultilabelAccuracy, MultilabelPrecision, MultilabelRecall, MultilabelF1Score

import timm

In [3]:
print('timm version', timm.__version__)
print('torch version', torch.__version__)

timm version 1.0.8
torch version 2.3.1


In [4]:
wandb.login(key=os.getenv('wandb_api_key'))

wandb: Currently logged in as: rosu-lucian. Use `wandb login --relogin` to force relogin
wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\Asus\.netrc


True

In [5]:
# detect and define device 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

cuda


In [6]:
# for reproducibility
def seed_torch(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every CUDA device), and configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Exported for child processes; the hash seed of the already-running
    # interpreter cannot be changed after start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs, not only the current one (safe when CUDA is absent).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different kernels from run to run.
    torch.backends.cudnn.benchmark = False

### Config

In [7]:
# TODO: maybe use condition and level for classes
# Active label set: the five conditions plus 'H' for the added healthy class.
classes = ['SCS', 'RNFN', 'LNFN', 'LSS', 'RSS', 'H']

# Alternative label sets kept for experiments (uncomment exactly one):
# classes = ['SCS', 'RNFN', 'LNFN'] + ['H'] # add healthy class
# classes = ['LSS', 'RSS'] + ['H'] # add healthy class
# classes = ['SCSL1L2', 'SCSL2L3', 'SCSL3L4', 'SCSL4L5', 'SCSL5S1', 'RNFNL4L5',
#        'RNFNL5S1', 'RNFNL3L4', 'RNFNL1L2', 'RNFNL2L3', 'LNFNL1L2',
#        'LNFNL4L5', 'LNFNL5S1', 'LNFNL2L3', 'LNFNL3L4', 'LSSL1L2',
#        'RSSL1L2', 'LSSL2L3', 'RSSL2L3', 'LSSL3L4', 'RSSL3L4', 'LSSL4L5',
#        'RSSL4L5', 'LSSL5S1', 'RSSL5S1'] + ['H']

num_classes = len(classes)

# Class name -> integer id, in list order.
class2id = dict(zip(classes, range(num_classes)))

In [8]:
# Raw string: '\d' and '\R' are invalid escape sequences in a normal string
# literal (currently a warning, an error in future Python versions).
train_dir = Path(r'E:\data\RSNA2024')

class CFG:
    """Experiment configuration: data paths, model sizes and training settings.

    All values are class attributes, so ``CFG`` can be passed around as-is or
    instantiated (``CFG()``) to override individual values on the instance.
    Every path is derived from ``ROOT_FOLDER`` so the data location only has
    to be changed in one place.
    """

    # Run bookkeeping (presumably the experiment-tracker project name / run note).
    project = 'rsna-2'
    comment = 'bottleneck'

    ### model
    model_name = 'eca_nfnet_l0' # 'resnet34', 'resnet200d', 'efficientnet_b1_pruned', 'efficientnetv2_m', efficientnet_b7 

    image_size = 256

    ### paths
    ROOT_FOLDER = train_dir
    IMAGES_DIR = ROOT_FOLDER / 'train_images'
    PNG_DIR = ROOT_FOLDER / f'pngs_{image_size}'
    FILES_CSV = ROOT_FOLDER / 'train_files.csv'
    TRAIN_CSV = ROOT_FOLDER / 'train.csv'
    TRAIN_DESC_CSV = ROOT_FOLDER / 'train_series_descriptions.csv'
    COORDS_CSV = ROOT_FOLDER / 'train_label_coordinates.csv'

    # ckpt_path = Path(r"E:\data\RSNA2024\results\ckpt\eca_nfnet_l0 5e-05 10 eps all-labels\ep_03_loss_0.15231.ckpt")
    # Directory holding one '<instance_id>.npy' embedding file per image.
    embeds_path = ROOT_FOLDER / 'embeddings'

    RESULTS_DIR = ROOT_FOLDER / 'results'
    CKPT_DIR = RESULTS_DIR / 'ckpt'

    ### LSTM sizes
    input_dim = 64    # length of one embedding vector
    hidden_dim = 64   # LSTM hidden state size
    target_size = 64  # output size of the linear layer after the LSTM

    classes = classes

    split_fraction = 0.95

    MIXUP = False

    ### training
    BATCH_SIZE = 1

    ### Optimizer
    N_EPOCHS = 10
    USE_SCHD = False
    WARM_EPOCHS = 3
    COS_EPOCHS = N_EPOCHS - WARM_EPOCHS

    # LEARNING_RATE = 5*1e-5 # best
    LEARNING_RATE = 5e-5

    weight_decay = 1e-6 # for adamw

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    ### data loading
    num_workers = 16

    random_seed = 42

# Number of labels, derived from the class list and attached after the class body.
CFG.N_LABELS = len(CFG.classes)

# Seed all RNGs once the configuration (and its seed) is defined.
seed_torch(seed = CFG.random_seed)

In [9]:
CFG.N_LABELS 

6

### Load data

In [10]:
# Load the four CSV tables referenced by the config.
train_df = pd.read_csv(CFG.TRAIN_CSV)
train_desc_df = pd.read_csv(CFG.TRAIN_DESC_CSV)
coords_df = pd.read_csv(CFG.COORDS_CSV)
files_df = pd.read_csv(CFG.FILES_CSV)

# Quick sanity check of the table sizes.
train_df.shape, train_desc_df.shape, coords_df.shape, files_df.shape

((1975, 26), (6294, 3), (48692, 18), (147218, 21))

In [11]:
# Missing gradings are imputed with the code 'N'.
# NOTE(review): this treats an unlabeled level like an 'N' label - confirm that is intended.
# Rebinding (instead of inplace=True) keeps the cell free of in-place mutation.
train_df = train_df.fillna('N')

In [12]:
train_df.head(2)

Unnamed: 0,study_id,SCSL1L2,SCSL2L3,SCSL3L4,SCSL4L5,SCSL5S1,LNFNL1L2,LNFNL2L3,LNFNL3L4,LNFNL4L5,...,LSSL1L2,LSSL2L3,LSSL3L4,LSSL4L5,LSSL5S1,RSSL1L2,RSSL2L3,RSSL3L4,RSSL4L5,RSSL5S1
0,4003253,N,N,N,N,N,N,N,N,M,...,N,N,N,M,N,N,N,N,N,N
1,4646740,N,N,M,S,N,N,N,N,M,...,N,N,N,S,N,N,M,M,M,N


In [13]:
# Fit a label encoder on the first label column (column 0 is study_id)
# and show the grade codes it found.
le = preprocessing.LabelEncoder() 
le.fit(train_df.iloc[:, 1])

le.classes_
# foo = le.fit_transform(train_df.iloc[:,1])

array(['M', 'N', 'S'], dtype=object)

In [14]:
# Encode every label column with the encoder fitted above, so all columns
# share one mapping (M -> 0, N -> 1, S -> 2). Re-fitting per column with
# `fit_transform` would silently give different codes to a column in which
# one of the grades never occurs. `transform` raises on unexpected values
# (including when this cell is re-run on already encoded data).
train_df.iloc[:,1:] = train_df.iloc[:,1:].apply(le.transform)

In [15]:
train_df.head(2)

Unnamed: 0,study_id,SCSL1L2,SCSL2L3,SCSL3L4,SCSL4L5,SCSL5S1,LNFNL1L2,LNFNL2L3,LNFNL3L4,LNFNL4L5,...,LSSL1L2,LSSL2L3,LSSL3L4,LSSL4L5,LSSL5S1,RSSL1L2,RSSL2L3,RSSL3L4,RSSL4L5,RSSL5S1
0,4003253,1,1,1,1,1,1,1,1,0,...,1,1,1,0,1,1,1,1,1,1
1,4646740,1,1,0,2,1,1,1,1,0,...,1,1,1,2,1,1,0,0,0,1


In [16]:
coords_df.sample(2)

Unnamed: 0,study_id,series_id,instance,condition,level,x,y,ss_id,instance_id,cl,series_description,rows,columns,filename,patientposition,x_perc,y_perc,inst_perc
8470,765688458,1017097760,18,RSS,L2L3,113.866365,131.044168,765688458_1017097760,765688458_1017097760_18,RSSL2L3,Axial T2,256,256,E:\data\RSNA2024\pngs_256\765688458_1017097760...,HFS,0.44479,0.511891,0.425
37793,3329250043,134630734,5,RNFN,L5S1,378.126195,537.246654,3329250043_134630734,3329250043_134630734_5,RNFNL5S1,Sagittal T1,760,640,E:\data\RSNA2024\pngs_256\3329250043_134630734...,HFS,0.590822,0.706903,0.190476


In [17]:
coords_df.condition.unique()

array(['SCS', 'RNFN', 'LNFN', 'LSS', 'RSS'], dtype=object)

In [18]:
coords_df.cl.unique()

array(['SCSL1L2', 'SCSL2L3', 'SCSL3L4', 'SCSL4L5', 'SCSL5S1', 'RNFNL4L5',
       'RNFNL5S1', 'RNFNL3L4', 'RNFNL1L2', 'RNFNL2L3', 'LNFNL1L2',
       'LNFNL4L5', 'LNFNL5S1', 'LNFNL2L3', 'LNFNL3L4', 'LSSL1L2',
       'RSSL1L2', 'LSSL2L3', 'RSSL2L3', 'LSSL3L4', 'RSSL3L4', 'LSSL4L5',
       'RSSL4L5', 'LSSL5S1', 'RSSL5S1'], dtype=object)

In [19]:
coords_df.cl.nunique()

25

In [20]:
embed_files = os.listdir(CFG.embeds_path)

len(embed_files)

147218

In [21]:
study_id = 838134337

In [22]:
selected_files = files_df[files_df.study_id == study_id]

selected_files.head(2)

Unnamed: 0,study_id,series_id,image,proj,instancenumber,rows,columns,slicethickness,spacingbetweenslices,patientposition,...,ss_id,instance_id,filename,series_description,cl,condition,inst_min,inst_max,inst,inst_perc
140320,838134337,1285354049,1,19,1,448,448,4.0,4.48,HFS,...,838134337_1285354049,838134337_1285354049_1,E:\data\RSNA2024\pngs_256\838134337_1285354049...,Sagittal T2/STIR,H,H,1,20,0,0.0
140321,838134337,1285354049,10,-20,10,448,448,4.0,4.48,HFS,...,838134337_1285354049,838134337_1285354049_10,E:\data\RSNA2024\pngs_256\838134337_1285354049...,Sagittal T2/STIR,H,H,1,20,9,0.45


In [23]:
# selected_files.groupby('series_description').sort(['proj'])

In [24]:
# Print the number of images per series description for the selected study.
# (Rows are sorted by 'proj' first; the ordering does not affect the counts.)
for name, group in selected_files.sort_values('proj').groupby('series_description'):
    print(name)
    print(group.image.count())

Axial T2
63
Sagittal T1
20
Sagittal T2/STIR
20


In [25]:
# One embedding file per image of the selected study: '<instance_id>.npy'.
study_instance_ids = files_df.loc[files_df.study_id == study_id, 'instance_id']
files = [CFG.embeds_path / f'{instance_id}.npy' for instance_id in study_instance_ids.to_list()]

files[0]

WindowsPath('E:/data/RSNA2024/embeddings/838134337_1285354049_1.npy')

In [26]:
np.load(files[2])

array([0.19990681, 0.13388744, 0.081677  , 0.7449569 , 0.11782445,
       0.71174335, 0.7239998 , 0.16039702, 0.23803055, 0.7323503 ,
       0.78617513, 0.9582367 , 0.12478404, 0.7033064 , 0.5421161 ,
       0.8636893 , 0.80034846, 0.81664217, 0.45245513, 0.88537997,
       0.676709  , 0.3442399 , 0.8683454 , 0.93591374, 0.22409391,
       0.9467593 , 0.42132142, 0.6168075 , 0.2643018 , 0.14936103,
       0.6911781 , 0.8419221 , 0.40297708, 0.65192336, 0.58653873,
       0.2654458 , 0.7226292 , 0.4850632 , 0.88341177, 0.21312091,
       0.5018644 , 0.28670755, 0.5729891 , 0.67702264, 0.93779194,
       0.33582947, 0.75234455, 0.8283155 , 0.40923488, 0.28497517,
       0.27483886, 0.31631535, 0.3859773 , 0.2297201 , 0.93978065,
       0.1383495 , 0.7521539 , 0.5270227 , 0.29936162, 0.9263395 ,
       0.8865664 , 0.9513482 , 0.71094906, 0.8872427 ], dtype=float32)

In [27]:
train_df[train_df.study_id == study_id].values.flatten().tolist()[1:]

[1, 1, 1, 2, 1, 1, 1, 1, 2, 0, 1, 1, 0, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1]

### Dataset

In [28]:
from dataset import rsna_lstm_dataset

In [29]:
# Build the dataset over all studies and inspect a single sample.
dset = rsna_lstm_dataset(train_df, files_df, CFG)

print(len(dset))

# Second sample: embedding sequence and its label vector.
seq, target = dset[1]
print(seq.shape, target.shape)
print(seq.dtype, target.dtype)

1975
torch.Size([88, 64]) torch.Size([25])
torch.float32 torch.int64


In [30]:
# seq: one study's embedding sequence, shape (seq_len, num_features) -- (88, 64) for this sample
# target: one class index per label column, shape (25,)
target

tensor([1, 1, 0, 2, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 2, 1, 1, 0, 0, 0,
        1])

In [31]:
seq[0]

tensor([0.4239, 0.5589, 0.5724, 0.3233, 0.5413, 0.4105, 0.2450, 0.2933, 0.5783,
        0.3665, 0.2999, 0.3192, 0.6060, 0.4276, 0.7149, 0.6555, 0.4987, 0.2461,
        0.5538, 0.3502, 0.3036, 0.5713, 0.4331, 0.4855, 0.2895, 0.6407, 0.5820,
        0.2301, 0.5089, 0.4737, 0.4656, 0.6288, 0.4501, 0.5102, 0.2545, 0.6618,
        0.2734, 0.5166, 0.4853, 0.6639, 0.4464, 0.7453, 0.3942, 0.4599, 0.4279,
        0.5163, 0.4440, 0.4644, 0.7640, 0.5849, 0.7543, 0.5833, 0.7475, 0.7489,
        0.3839, 0.4694, 0.2550, 0.4448, 0.2281, 0.6770, 0.6206, 0.5454, 0.2657,
        0.5847])

### Data Module

In [32]:
from dataset import rsna_lstm_dataset

In [33]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn_padd(data):
    """Collate variable-length samples into one padded batch.

    Args:
        data: Iterable of ``(sequence, target)`` pairs; sequences may differ
            in their first (time) dimension.

    Returns:
        Tuple ``(features, targets)``: sequences padded to the longest one in
        the batch (batch-first layout) and the targets stacked along a new
        leading dimension.
    """
    sequences, labels = zip(*data)
    padded = pad_sequence(sequences, batch_first=True)
    return padded, torch.stack(labels)

In [34]:
class lstm_datamodule(pl.LightningDataModule):
    """Lightning data module serving padded embedding sequences.

    Args:
        train_df: Label frame used for the training loader.
        val_df: Label frame used for the validation loader.
        files_df: Per-image file index handed to the dataset.
        cfg: Configuration providing ``BATCH_SIZE`` and ``num_workers``; it is
            also forwarded to the dataset.
    """

    def __init__(self, train_df, val_df, files_df, cfg=CFG):
        super().__init__()

        self.cfg = cfg
        self.train_df, self.val_df, self.files_df = train_df, val_df, files_df

        # Both loaders share the configured batch size.
        self.train_bs = self.val_bs = cfg.BATCH_SIZE
        self.num_workers = cfg.num_workers

    def train_dataloader(self):
        """Return a shuffled loader over the training frame."""
        dataset = rsna_lstm_dataset(self.train_df, self.files_df, self.cfg, mode='train')
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.train_bs,
            shuffle=True,
            drop_last=False,
            pin_memory=False,
            num_workers=self.num_workers,
            collate_fn=collate_fn_padd,
        )

    def val_dataloader(self):
        """Return an ordered loader over the validation frame.

        Uses a fixed two persistent workers, independent of ``cfg.num_workers``.
        """
        dataset = rsna_lstm_dataset(self.val_df, self.files_df, self.cfg, mode='val')
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.val_bs,
            shuffle=False,
            drop_last=False,
            pin_memory=False,
            num_workers=2,
            persistent_workers=True,
            collate_fn=collate_fn_padd,
        )

In [None]:
# Smoke test of the data module: hold out the last 100 studies for validation.
t_df = train_df[:-100]
# t_df = pd.concat([meta_df[:-100], ul_df[:-100]], ignore_index=True)
v_df = train_df[-100:]

# Instance of CFG so the overrides below do not touch the class-level defaults.
CFG2 = CFG()
# CFG2 = copy.deepcopy(CFG)
CFG2.BATCH_SIZE = 2
CFG2.num_workers = 2

dm = lstm_datamodule(t_df, v_df, files_df, cfg=CFG2)

# Pull a single padded batch and inspect shapes / dtypes.
x, y = next(iter(dm.train_dataloader()))
x.shape, y.shape, x.dtype, y.dtype

### Loss function

In [None]:
class FocalLossBCE(torch.nn.Module):
    """Weighted sum of BCE-with-logits and sigmoid focal loss.

    Args:
        alpha: ``alpha`` passed to the focal loss.
        gamma: ``gamma`` passed to the focal loss.
        reduction: Reduction applied to both loss terms.
        bce_weight: Multiplier of the BCE term.
        focal_weight: Multiplier of the focal term.
    """

    def __init__(
            self,
            alpha: float = 0.25,
            gamma: float = 2,
            reduction: str = "mean",
            bce_weight: float = 1.0,
            focal_weight: float = 1.0,
    ):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma
        self.reduction = reduction
        self.bce = torch.nn.BCEWithLogitsLoss(reduction=reduction)
        self.bce_weight, self.focal_weight = bce_weight, focal_weight

    def forward(self, logits, targets):
        """Return ``bce_weight * BCE + focal_weight * focal`` for the given logits/targets."""
        focal_term = torchvision.ops.focal_loss.sigmoid_focal_loss(
            inputs=logits,
            targets=targets,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )
        bce_term = self.bce(logits, targets)
        return self.bce_weight * bce_term + self.focal_weight * focal_term

class GeM(torch.nn.Module):
    """Generalized-mean pooling over the two spatial dimensions.

    Args:
        p: Initial value of the learnable pooling exponent.
        eps: Lower clamp applied to the input before raising it to ``p``.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.p = torch.nn.Parameter(torch.ones(1) * p)

    def forward(self, x):
        """Pool a 4-D input ``(bs, ch, h, w)`` down to ``(bs, ch)``."""
        bs, ch, h, w = x.shape
        powered = x.clamp(min=self.eps).pow(self.p)
        # Average over the full spatial extent, then take the p-th root.
        pooled = torch.nn.functional.avg_pool2d(powered, (h, w))
        return pooled.pow(1.0 / self.p).view(bs, ch)

### Model

In [None]:
class LSTMClassifier(pl.LightningModule):
    """LSTM over an embedding sequence followed by 25 three-way classifier heads.

    The last hidden state of the LSTM is projected by a linear layer and fed
    to 25 independent linear heads, each producing 3 logits.

    Args:
        cfg: Configuration providing ``input_dim``, ``hidden_dim`` and
            ``target_size``.
    """

    def __init__(self, cfg=CFG):
        super().__init__()

        self.input_dim = cfg.input_dim
        self.hidden_dim = cfg.hidden_dim
        self.target_size = cfg.target_size

        self.criterion = torch.nn.CrossEntropyLoss()

        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, num_layers=1, batch_first=True)

        self.fc = nn.Linear(self.hidden_dim, self.target_size)

        # nn.ModuleList (not a plain Python list) registers the heads as
        # sub-modules: their weights appear in .parameters(), follow
        # .to(device) and are stored in the state dict. With a plain list the
        # heads would never be optimised or moved to the GPU.
        self.classifiers = nn.ModuleList(
            [nn.Linear(self.target_size, 3) for _ in range(25)]
        )

    def forward(self, sequence):
        """Compute the logits for a batch of sequences.

        Args:
            sequence: Tensor of shape ``(batch, seq_len, input_dim)``
                (the LSTM is built with ``batch_first=True``).

        Returns:
            Logits of shape ``(batch, 3, 25)``: class axis second, one slice
            per head along the last axis.
        """
        # NOTE(review): for zero-padded batches the final hidden state is taken
        # after the padding steps - consider packing the sequences.
        _, (h, _) = self.lstm(sequence)

        # h[-1]: final hidden state of the last layer, (batch, hidden_dim).
        y = self.fc(h[-1])

        # Each head yields (batch, 3); stacking on a new last axis gives
        # (batch, 3, 25), the layout CrossEntropyLoss expects for (batch, 25) targets.
        preds = torch.stack([head(y) for head in self.classifiers], dim=-1)

        return preds

    def step(self, batch, batch_idx, mode='train'):
        """Run a forward pass on ``batch`` and return the cross-entropy loss.

        Args:
            batch: Tuple ``(x, y)`` of input sequences and integer targets.
            batch_idx: Index of the batch (unused).
            mode: Name of the calling phase (currently unused).
        """
        x, y = batch

        preds = self(x)

        loss = self.criterion(preds, y)

        return loss

#### building blocks

In [380]:
seq.shape, seq.view(1, len(seq), -1).shape, target.shape

(torch.Size([88, 64]), torch.Size([1, 88, 64]), torch.Size([25]))

In [385]:
# Scratch cell: run the model's layers step by step to check tensor shapes.
target_size = 64

lstm = nn.LSTM(target_size, target_size, num_layers=1, batch_first=True)
fc = nn.Linear(target_size, target_size)
classifiers = [nn.Linear(target_size, 3) for i in range(25)]

# Random batch of 5 sequences, 88 steps, 64 features.
lstm_out, (h, c) = lstm(torch.randn(5,88,64))
print(lstm_out.shape, h.shape, c.shape)

# Project the final hidden state of the last layer: (5, 64).
y = fc(h[-1])
print(y.shape)

# One (5, 3) logit tensor per head; the comprehension variable `c` is local
# to the comprehension and does not overwrite the cell state `c` above.
preds = [c(y) for c in classifiers]
print(preds[0].shape)
# Stacked along a new leading axis: (25, 5, 3).
preds = torch.stack(preds)

preds.shape

torch.Size([5, 88, 64]) torch.Size([1, 5, 64]) torch.Size([1, 5, 64])
torch.Size([5, 64])
torch.Size([5, 3])


torch.Size([25, 5, 3])

In [382]:
preds.T.shape

torch.Size([3, 2, 25])

#### Test out inputs/outputs

In [393]:
model = LSTMClassifier(CFG)

In [394]:
model.step((seq.view(1, len(seq), -1), target.view(1, len(target))), 0)

tensor(1.1086, grad_fn=<NllLoss2DBackward0>)

In [398]:
y = model(torch.randn(5,88,64))

y.shape, y.softmax(dim=0).shape

(torch.Size([5, 3, 25]), torch.Size([5, 3, 25]))

In [408]:
# y[0].softmax(dim=0)

In [406]:
y.softmax(dim=1)

tensor([[[0.3204, 0.3345, 0.2845, 0.2863, 0.3278, 0.2983, 0.2880, 0.3148,
          0.3732, 0.3555, 0.3143, 0.3387, 0.3531, 0.3479, 0.3106, 0.2951,
          0.3115, 0.2947, 0.3487, 0.3700, 0.3649, 0.3012, 0.3156, 0.3536,
          0.3207],
         [0.3293, 0.3398, 0.3507, 0.3339, 0.3022, 0.3600, 0.3437, 0.3333,
          0.3135, 0.3304, 0.2938, 0.3442, 0.3015, 0.2952, 0.3794, 0.3297,
          0.3115, 0.3461, 0.3402, 0.3220, 0.3145, 0.3895, 0.3419, 0.3348,
          0.3150],
         [0.3503, 0.3257, 0.3648, 0.3798, 0.3700, 0.3417, 0.3682, 0.3519,
          0.3133, 0.3142, 0.3918, 0.3171, 0.3454, 0.3570, 0.3100, 0.3753,
          0.3770, 0.3592, 0.3111, 0.3079, 0.3206, 0.3093, 0.3426, 0.3116,
          0.3644]],

        [[0.3135, 0.3460, 0.2832, 0.3341, 0.3206, 0.3214, 0.2901, 0.3127,
          0.3835, 0.3244, 0.3252, 0.3304, 0.3395, 0.3882, 0.2964, 0.2933,
          0.2921, 0.2930, 0.3521, 0.4045, 0.3615, 0.2916, 0.3266, 0.3539,
          0.3062],
         [0.3414, 0.3445, 0.3357, 

In [35]:
# y.softmax(dim=1).sum(dim=1)

### Split

### Train