In [141]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Imports

In [3]:
import os
import time
import wandb
import torch
import random
import torchvision

import numpy as np
import pandas as pd
import torchmetrics as tm 
# import plotly.express as px
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from torch import nn
from pathlib import Path, PurePath
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, AdamW, RMSprop # optmizers
from sklearn import preprocessing 
# from warmup_scheduler import GradualWarmupScheduler
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau # Learning rate schedulers

import albumentations as A
# from albumentations.pytorch import ToTensorV2

import torch.nn.functional as F

import torchmetrics as tm

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
from torchmetrics.wrappers import ClasswiseWrapper
from torchmetrics import MetricCollection
from torchmetrics.classification import MultilabelAccuracy, MultilabelPrecision, MultilabelRecall, MultilabelF1Score
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassF1Score

import timm

In [4]:
print('timm version', timm.__version__)
print('torch version', torch.__version__)

timm version 1.0.8
torch version 2.3.1


In [5]:
wandb.login(key=os.getenv('wandb_api_key'))

wandb: Currently logged in as: rosu-lucian. Use `wandb login --relogin` to force relogin
wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\Asus\.netrc


True

In [6]:
# detect and define device 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

cuda


In [7]:
# for reproducibility
def seed_torch(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

### Config

In [8]:
# TODO: maybe use condition and level for classes
classes = ['SCS', 'RNFN', 'LNFN', 'LSS', 'RSS'] + ['H'] # add healthy class

# classes = ['SCS', 'RNFN', 'LNFN'] + ['H'] # add healthy class

# classes = ['LSS', 'RSS'] + ['H'] # add healthy class

classes = ['SCSL1L2', 'SCSL2L3', 'SCSL3L4', 'SCSL4L5', 'SCSL5S1', 'RNFNL4L5',
       'RNFNL5S1', 'RNFNL3L4', 'RNFNL1L2', 'RNFNL2L3', 'LNFNL1L2',
       'LNFNL4L5', 'LNFNL5S1', 'LNFNL2L3', 'LNFNL3L4', 'LSSL1L2',
       'RSSL1L2', 'LSSL2L3', 'RSSL2L3', 'LSSL3L4', 'RSSL3L4', 'LSSL4L5',
       'RSSL4L5', 'LSSL5S1', 'RSSL5S1']

num_classes = len(classes)
class2id = {b: i for i, b in enumerate(classes)}

num_classes

25

In [9]:
train_dir = Path('E:\data\RSNA2024')

class CFG:

    project = 'lstm_new'
    comment = 'undersample_healthy'

    ### model
    model_name = 'lstm' # 'resnet34', 'resnet200d', 'efficientnet_b1_pruned', 'efficientnetv2_m', efficientnet_b7 

    image_size = 256
    
    ROOT_FOLDER = train_dir
    IMAGES_DIR = ROOT_FOLDER / 'train_images'
    PNG_DIR = ROOT_FOLDER / f'pngs_{image_size}'
    FILES_CSV = ROOT_FOLDER / 'train_files.csv'
    PREDS_CSV = ROOT_FOLDER / 'predictions.csv'
    TRAIN_CSV = ROOT_FOLDER / 'train.csv'
    TRAIN_DESC_CSV = ROOT_FOLDER / 'train_series_descriptions.csv'
    COORDS_CSV = ROOT_FOLDER / 'train_label_coordinates.csv'

    # ckpt_path = Path(r"E:\data\RSNA2024\results\ckpt\eca_nfnet_l0 5e-05 10 eps all-labels\ep_03_loss_0.15231.ckpt")
    embeds_path = Path(r"E:\data\RSNA2024\embeddings")
    stacked_path = Path(r"E:\data\RSNA2024\embeddings_stacked")

    RESULTS_DIR = train_dir / 'results'
    CKPT_DIR = RESULTS_DIR / 'ckpt'

    weighted_loss = True
    class_weights = [2.06762982, 0.42942998, 5.32804575]
    # class_weights = [1, 0.2, 1.5]
    
    num_layers=8

    input_dim = 128
    hidden_dim = 128
    target_size = 128

    classes = classes

    split_fraction = 0.95

    MIXUP = False

    ### training
    BATCH_SIZE = 16
    
    ### Optimizer
    N_EPOCHS = 30
    USE_SCHD = False
    WARM_EPOCHS = 3
    COS_EPOCHS = N_EPOCHS - WARM_EPOCHS

    # LEARNING_RATE = 5*1e-5 # best
    LEARNING_RATE = 5e-5
    
    weight_decay = 1e-6 # for adamw

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    ### split train and validation sets
    num_workers = 4

    random_seed = 42

# CFG.N_LABELS = len(CFG.classes)
CFG.N_LABELS = 3

seed_torch(seed = CFG.random_seed)

In [10]:
CFG.N_LABELS 

3

### Load data

In [11]:
train_df = pd.read_csv(CFG.TRAIN_CSV)
train_desc_df = pd.read_csv(CFG.TRAIN_DESC_CSV)
coords_df = pd.read_csv(CFG.COORDS_CSV)
files_df = pd.read_csv(CFG.FILES_CSV)
preds_df = pd.read_csv(CFG.PREDS_CSV)

train_df.shape, train_desc_df.shape, coords_df.shape, files_df.shape, preds_df.shape

((1975, 26), (6294, 3), (48692, 18), (147218, 21), (147218, 16))

In [12]:
train_desc_df['ss_id'] = train_desc_df.apply(lambda row: f'{row.study_id}_{row.series_id}', axis=1)
preds_df['ss_id'] = preds_df.apply(lambda row: f'{row.study_id}_{row.series_id}', axis=1)

train_desc_df.head(2)

Unnamed: 0,study_id,series_id,series_description,ss_id
0,4003253,702807833,Sagittal T2/STIR,4003253_702807833
1,4003253,1054713880,Sagittal T1,4003253_1054713880


In [159]:
coords_df.condition.unique()

array(['SCS', 'RNFN', 'LNFN', 'LSS', 'RSS'], dtype=object)

In [13]:
preds_df.head(2)

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id
0,100206310_1012284084_1,0.0,0.0,0.0,0.0,0.0,1.0,0.008687,0.005704,0.003396,0.008272,0.010076,0.996422,100206310,1012284084,1,100206310_1012284084
1,100206310_1012284084_10,0.0,0.0,0.0,0.0,0.0,1.0,0.009786,0.005203,0.004467,0.010467,0.013595,0.994842,100206310,1012284084,10,100206310_1012284084


In [14]:
files_df.head(2)

Unnamed: 0,study_id,series_id,image,proj,instancenumber,rows,columns,slicethickness,spacingbetweenslices,patientposition,...,ss_id,instance_id,filename,series_description,cl,condition,inst_min,inst_max,inst,inst_perc
0,100206310,1012284084,1,-394,1,320,320,3.5,3.5,HFS,...,100206310_1012284084,100206310_1012284084_1,E:\data\RSNA2024\pngs_256\100206310_1012284084...,Axial T2,H,H,1,60,0,0.0
1,100206310,1012284084,10,-427,10,320,320,3.5,3.5,HFS,...,100206310_1012284084,100206310_1012284084_10,E:\data\RSNA2024\pngs_256\100206310_1012284084...,Axial T2,H,H,1,60,9,0.15


In [15]:
preds_df = pd.merge(preds_df, train_desc_df.loc[:, ['ss_id', 'series_description']],  how='inner', left_on=['ss_id'], right_on=['ss_id'])

preds_df.sample(2)

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id,series_description
27165,1724850177_3909740603_2,0.0,0.0,0.0,0.0,0.0,1.0,0.009437,0.004596,0.004386,0.00776,0.00909,0.996385,1724850177,3909740603,2,1724850177_3909740603,Axial T2
109560,3871485386_2006212660_8,0.0,0.0,0.0,0.0,0.0,1.0,0.00679,0.003646,0.002439,0.013407,0.013598,0.997847,3871485386,2006212660,8,3871485386_2006212660,Sagittal T1


In [16]:
preds_df = pd.merge(preds_df, files_df.loc[:, ['instance_id', 'proj']],  how='inner', left_on=['ids'], right_on=['instance_id'])

preds_df.sample(2)

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id,series_description,instance_id,proj
135360,696494906_2429429854_11,0.0,0.0,0.0,0.0,0.0,1.0,0.010126,0.004597,0.00354,0.014541,0.013506,0.976901,696494906,2429429854,11,696494906_2429429854,Axial T2,696494906_2429429854_11,-418
144302,937772835_1503533501_34,0.0,0.0,0.0,0.0,0.0,1.0,0.010032,0.005043,0.005893,0.011268,0.014073,0.989045,937772835,1503533501,34,937772835_1503533501,Axial T2,937772835_1503533501_34,-77


In [17]:
# train_desc_df[train_desc_df['series_id'] == 3909740603]

In [18]:
train_df.fillna('N', inplace=True)
train_df.head(2)

Unnamed: 0,study_id,SCSL1L2,SCSL2L3,SCSL3L4,SCSL4L5,SCSL5S1,LNFNL1L2,LNFNL2L3,LNFNL3L4,LNFNL4L5,...,LSSL1L2,LSSL2L3,LSSL3L4,LSSL4L5,LSSL5S1,RSSL1L2,RSSL2L3,RSSL3L4,RSSL4L5,RSSL5S1
0,4003253,N,N,N,N,N,N,N,N,M,...,N,N,N,M,N,N,N,N,N,N
1,4646740,N,N,M,S,N,N,N,N,M,...,N,N,N,S,N,N,M,M,M,N


In [19]:
le = preprocessing.LabelEncoder() 
le.fit(train_df.iloc[:, 1])

le.classes_
# foo = le.fit_transform(train_df.iloc[:,1])

array(['M', 'N', 'S'], dtype=object)

In [20]:
from sklearn.utils.class_weight import compute_class_weight

In [21]:
foo = train_df.iloc[:, 1:].to_numpy()

foo.shape

(1975, 25)

In [22]:
foo.flatten().shape

(49375,)

In [23]:
compute_class_weight(class_weight='balanced', classes=np.unique(foo.flatten()), y=foo.flatten())

array([2.06762982, 0.42942998, 5.32804575])

In [24]:
train_df.iloc[:,1:] = train_df.iloc[:,1:].apply(le.fit_transform)

In [25]:
train_df.head(2)

Unnamed: 0,study_id,SCSL1L2,SCSL2L3,SCSL3L4,SCSL4L5,SCSL5S1,LNFNL1L2,LNFNL2L3,LNFNL3L4,LNFNL4L5,...,LSSL1L2,LSSL2L3,LSSL3L4,LSSL4L5,LSSL5S1,RSSL1L2,RSSL2L3,RSSL3L4,RSSL4L5,RSSL5S1
0,4003253,1,1,1,1,1,1,1,1,0,...,1,1,1,0,1,1,1,1,1,1
1,4646740,1,1,0,2,1,1,1,1,0,...,1,1,1,2,1,1,0,0,0,1


In [26]:
coords_df.sample(2)

Unnamed: 0,study_id,series_id,instance,condition,level,x,y,ss_id,instance_id,cl,series_description,rows,columns,filename,patientposition,x_perc,y_perc,inst_perc
25515,2294509371,2895790445,7,SCS,L1L2,257.327262,138.787883,2294509371_2895790445,2294509371_2895790445_7,SCSL1L2,Sagittal T2/STIR,448,448,E:\data\RSNA2024\pngs_256\2294509371_289579044...,HFS,0.574391,0.309794,0.428571
35128,3116585926,3043780434,3,LNFN,L4L5,146.454106,321.545894,3116585926_3043780434,3116585926_3043780434_3,LNFNL4L5,Sagittal T1,512,512,E:\data\RSNA2024\pngs_256\3116585926_304378043...,FFS,0.286043,0.628019,0.181818


In [27]:
coords_df.condition.unique()

array(['SCS', 'RNFN', 'LNFN', 'LSS', 'RSS'], dtype=object)

In [28]:
coords_df.cl.unique()

array(['SCSL1L2', 'SCSL2L3', 'SCSL3L4', 'SCSL4L5', 'SCSL5S1', 'RNFNL4L5',
       'RNFNL5S1', 'RNFNL3L4', 'RNFNL1L2', 'RNFNL2L3', 'LNFNL1L2',
       'LNFNL4L5', 'LNFNL5S1', 'LNFNL2L3', 'LNFNL3L4', 'LSSL1L2',
       'RSSL1L2', 'LSSL2L3', 'RSSL2L3', 'LSSL3L4', 'RSSL3L4', 'LSSL4L5',
       'RSSL4L5', 'LSSL5S1', 'RSSL5S1'], dtype=object)

In [29]:
coords_df.cl.nunique()

25

#### Prepare

In [30]:
embed_files = os.listdir(CFG.embeds_path)
stacked_files = os.listdir(CFG.stacked_path)

len(embed_files), len(stacked_files)

(147219, 6295)

In [31]:
study_id = 838134337

In [32]:
selected_files = preds_df[preds_df.study_id == study_id]

selected_files.head(2)

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id,series_description,instance_id,proj
140320,838134337_1285354049_1,0.0,0.0,0.0,0.0,0.0,1.0,0.008669,0.004251,0.002214,0.010968,0.010222,0.996511,838134337,1285354049,1,838134337_1285354049,Sagittal T2/STIR,838134337_1285354049_1,19
140321,838134337_1285354049_10,0.0,0.0,0.0,0.0,0.0,1.0,0.857816,0.003316,0.00282,0.012196,0.014888,0.031567,838134337,1285354049,10,838134337_1285354049,Sagittal T2/STIR,838134337_1285354049_10,-20


In [33]:
# selected_files.head()

In [34]:
# selected_files.groupby('series_description').sort(['proj'])

In [35]:
selected_files.sort_values(['series_description', 'proj'], ascending=[False, False])

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id,series_description,instance_id,proj
140320,838134337_1285354049_1,0.0,0.0,0.0,0.0,0.0,1.0,0.008669,0.004251,0.002214,0.010968,0.010222,0.996511,838134337,1285354049,1,838134337_1285354049,Sagittal T2/STIR,838134337_1285354049_1,19
140331,838134337_1285354049_2,0.0,0.0,0.0,0.0,0.0,1.0,0.007301,0.003968,0.001861,0.006535,0.005958,0.997781,838134337,1285354049,2,838134337_1285354049,Sagittal T2/STIR,838134337_1285354049_2,15
140333,838134337_1285354049_3,0.0,0.0,0.0,0.0,0.0,1.0,0.006322,0.003869,0.001825,0.008667,0.007997,0.997514,838134337,1285354049,3,838134337_1285354049,Sagittal T2/STIR,838134337_1285354049_3,10
140334,838134337_1285354049_4,0.0,0.0,0.0,0.0,0.0,1.0,0.007862,0.004880,0.004075,0.008933,0.009013,0.997641,838134337,1285354049,4,838134337_1285354049,Sagittal T2/STIR,838134337_1285354049_4,6
140335,838134337_1285354049_5,0.0,0.0,0.0,0.0,0.0,1.0,0.007605,0.005034,0.002569,0.006341,0.007329,0.994792,838134337,1285354049,5,838134337_1285354049,Sagittal T2/STIR,838134337_1285354049_5,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
140414,838134337_3108613161_59,0.0,0.0,0.0,0.0,0.0,1.0,0.013585,0.007014,0.006603,0.545113,0.552875,0.110206,838134337,3108613161,59,838134337_3108613161,Axial T2,838134337_3108613161_59,-516
140416,838134337_3108613161_60,0.0,0.0,0.0,0.0,0.0,1.0,0.008732,0.005328,0.005256,0.088571,0.118986,0.698470,838134337,3108613161,60,838134337_3108613161,Axial T2,838134337_3108613161_60,-519
140417,838134337_3108613161_61,0.0,0.0,0.0,0.0,0.0,1.0,0.010470,0.006225,0.007601,0.019147,0.021621,0.960496,838134337,3108613161,61,838134337_3108613161,Axial T2,838134337_3108613161_61,-523
140418,838134337_3108613161_62,0.0,0.0,0.0,0.0,0.0,1.0,0.013279,0.007912,0.011143,0.015453,0.025385,0.979502,838134337,3108613161,62,838134337_3108613161,Axial T2,838134337_3108613161_62,-526


In [36]:
selected_files.study_id.unique()

array([838134337], dtype=int64)

In [37]:
idx = selected_files.sort_values(['series_description', 'proj'], ascending=[False, True]).index.to_list()
len(idx), selected_files.shape

(103, (103, 20))

In [38]:
stacked_embeds = np.load(CFG.embeds_path / 'stacked.npy')
stacked_embeds.shape

(147218, 128)

In [39]:
stacked_embeds[idx].shape

(103, 128)

In [40]:
# for name, group in selected_files.sort_values('proj').groupby('series_description'):
#     print(name)
#     print(group.image.count())

In [41]:
files = files_df[files_df.study_id == study_id].instance_id.to_list()
files = [CFG.embeds_path / f'{f}.npy' for f in files]

files[0]

WindowsPath('E:/data/RSNA2024/embeddings/838134337_1285354049_1.npy')

In [42]:
np.load(files[2])

array([-1.3868768 , -1.8670154 , -2.4197767 ,  1.0718939 , -2.0131953 ,
        0.9038663 ,  0.96438974, -1.655277  , -1.1635075 ,  1.0065799 ,
        1.302022  ,  3.1330764 , -1.9478861 ,  0.8630925 ,  0.16886447,
        1.846276  ,  1.3884737 ,  1.4937617 , -0.19075578,  2.0443947 ,
        0.73868847, -0.64445627,  1.8864077 ,  2.6812937 , -1.2419662 ,
        2.8782215 , -0.31735116,  0.47601956, -1.0237284 , -1.7396214 ,
        0.8056327 ,  1.6725999 , -0.3930759 ,  0.6275044 ,  0.34967494,
       -1.0178531 ,  0.95754117, -0.05976506,  2.0251436 , -1.3062149 ,
        0.00745753, -0.91142875,  0.2940572 ,  0.74012256,  2.7130435 ,
       -0.6819354 ,  1.111156  ,  1.5737361 , -0.36712924, -0.91991526,
       -0.9702089 , -0.7707571 , -0.4642533 , -1.2098924 ,  2.7476518 ,
       -1.8290665 ,  1.1101327 ,  0.10819634, -0.8503396 ,  2.5317743 ,
        2.0561385 ,  2.9731913 ,  0.89999783,  2.0628805 ], dtype=float32)

In [43]:
files = files_df[files_df.study_id == study_id].ss_id.unique().tolist()
files = [CFG.stacked_path / f'{f}.npy' for f in files]

len(files)

3

In [44]:
np.load(files[0]).shape

(20, 64)

In [45]:
train_df[train_df.study_id == study_id].values.flatten().tolist()[1:]

[1, 1, 1, 2, 1, 1, 1, 1, 2, 0, 1, 1, 0, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1]

In [46]:
train_desc_df.head(5)

Unnamed: 0,study_id,series_id,series_description,ss_id
0,4003253,702807833,Sagittal T2/STIR,4003253_702807833
1,4003253,1054713880,Sagittal T1,4003253_1054713880
2,4003253,2448190387,Axial T2,4003253_2448190387
3,4646740,3201256954,Axial T2,4646740_3201256954
4,4646740,3486248476,Sagittal T1,4646740_3486248476


In [47]:
foo = train_desc_df[train_desc_df.study_id == study_id].sort_values('series_description', ascending=False)

foo

Unnamed: 0,study_id,series_id,series_description,ss_id
1215,838134337,1285354049,Sagittal T2/STIR,838134337_1285354049
1216,838134337,1345841225,Sagittal T1,838134337_1345841225
1217,838134337,3108613161,Axial T2,838134337_3108613161


In [48]:
foo.ss_id.tolist()

['838134337_1285354049', '838134337_1345841225', '838134337_3108613161']

### Dataset

In [123]:
from dataset import rsna_lstm_dataset, rsna_lstm_dataset2

In [155]:
preds_df[preds_df.RSS > 0].head(15)

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id,series_description,instance_id,proj


In [157]:
preds_df['RSS'].unique()

array([0.])

In [150]:
preds_df.sample(10)

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id,series_description,instance_id,proj
9633,1246802434_2530289455_5,0.0,0.0,0.0,0.0,0.0,1.0,0.014651,0.005698,0.007709,0.360556,0.368441,0.381716,1246802434,2530289455,5,1246802434_2530289455,Axial T2,1246802434_2530289455_5,-432
99853,3612685875_2943022937_2,0.0,0.0,0.0,0.0,0.0,1.0,0.009616,0.004076,0.007683,0.012329,0.015736,0.994216,3612685875,2943022937,2,3612685875_2943022937,Sagittal T1,3612685875_2943022937_2,-9
54587,2455924281_1645353593_18,0.0,0.0,0.0,0.0,0.0,1.0,0.010334,0.005059,0.005969,0.01896,0.024023,0.985709,2455924281,1645353593,18,2455924281_1645353593,Sagittal T1,2455924281_1645353593_18,-39
33948,1897045431_220072654_5,0.0,0.0,0.0,0.0,0.0,1.0,0.012641,0.00655,0.008909,0.018371,0.01766,0.980799,1897045431,220072654,5,1897045431_220072654,Axial T2,1897045431_220072654_5,-5
5542,113121178_2624376155_11,0.0,0.0,0.0,0.0,0.0,1.0,0.006214,0.003586,0.004145,0.00724,0.008384,0.994022,113121178,2624376155,11,113121178_2624376155,Axial T2,113121178_2624376155_11,-468
132350,625376596_338048129_11,0.0,0.0,0.0,0.0,0.0,1.0,0.007917,0.003528,0.003439,0.111915,0.083025,0.736104,625376596,338048129,11,625376596_338048129,Axial T2,625376596_338048129_11,-382
1442,1039182563_814821691_16,0.0,0.0,0.0,0.0,0.0,1.0,0.006231,0.003703,0.004742,0.006956,0.008886,0.994398,1039182563,814821691,16,1039182563_814821691,Axial T2,1039182563_814821691_16,-3
10100,1258848546_3003067430_22,0.0,0.0,0.0,0.0,0.0,1.0,0.008835,0.003534,0.004563,0.014118,0.011669,0.994157,1258848546,3003067430,22,1258848546_3003067430,Sagittal T2/STIR,1258848546_3003067430_22,-50
8008,11943292_1212326388_7,0.0,0.0,0.0,0.0,0.0,1.0,0.004609,0.003296,0.002554,0.009371,0.009944,0.998347,11943292,1212326388,7,11943292_1212326388,Sagittal T1,11943292_1212326388_7,-7
61784,2626030939_956892659_2,0.0,0.0,0.0,0.0,0.0,1.0,0.008317,0.005151,0.008881,0.009005,0.010837,0.994311,2626030939,956892659,2,2626030939_956892659,Sagittal T2/STIR,2626030939_956892659_2,-18


In [122]:
preds_df.instance_id.nunique()

147218

In [129]:
dset = rsna_lstm_dataset(train_df, train_desc_df, CFG.stacked_path)
dset = rsna_lstm_dataset2(train_df, preds_df, CFG.stacked_path)

print(dset.__len__())

seq, target = dset.__getitem__(10)
print(seq.shape, target.shape)
print(seq.dtype, target.dtype)

1975
116
29931867
torch.Size([116, 128]) torch.Size([25])
torch.float32 torch.int64


In [51]:
# preds_df.sample(frac=0.1)

In [52]:
# seq dim: (bs, seq_len, 1, num_features)
# target dim: (N, d1)
target

tensor([1, 1, 0, 2, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 2, 1, 1, 0, 0, 0,
        1])

In [53]:
seq[0]

tensor([-0.5228, -0.1782, -0.6903,  0.5997, -0.4489, -1.0280,  0.2119, -0.6594,
         0.6898, -1.1903,  0.3165,  1.5068, -0.7865,  0.5029, -1.0719,  0.4189,
         0.5564,  0.4282, -0.2935, -0.9852,  0.3526, -0.4669, -0.5027,  0.3444,
        -0.7786, -0.7690, -0.7657, -1.0102, -0.3332, -0.8924,  1.2580, -0.1102,
         0.7869,  0.1274,  1.0775, -0.6019,  0.0730, -0.1483, -1.4893,  0.1967,
         0.3407, -0.0978,  0.4087, -0.7306, -1.0403,  0.1851, -0.8510,  0.3468,
        -0.2782, -0.1983,  0.0975, -0.9890,  0.3682,  0.4134,  1.1431, -0.3665,
         0.8837, -0.5044, -0.7994,  0.0136,  0.3304, -0.3338, -0.0366,  1.3782,
         1.3190, -0.3707,  0.5600,  0.5615, -0.3195, -0.6933,  0.1179,  0.1510,
         0.1354, -0.2475,  0.8667, -0.2318, -1.0075,  0.5286, -0.1825, -1.2873,
         0.5826, -0.1120, -0.3456,  0.4346, -0.1036, -0.9096, -0.8098, -0.1004,
         0.2509, -0.3196, -0.0296,  1.3002,  0.2478,  0.4836, -1.7138,  0.6634,
        -0.4320,  0.2279, -0.1096, -0.63

### Data Module

In [54]:
from dataset import rsna_lstm_dataset, rsna_lstm_dataset2, collate_fn_padd

In [55]:
train_desc_df.sample()

Unnamed: 0,study_id,series_id,series_description,ss_id
800,532925408,476707229,Sagittal T1,532925408_476707229


In [56]:
# from torch.nn.utils.rnn import pad_sequence

# def collate_fn_padd(data):
#     tensors, targets = zip(*data)
#     features = pad_sequence(tensors, batch_first=True)
#     targets = torch.stack(targets)
#     return features, targets
    
class lstm_datamodule(pl.LightningDataModule):
    # def __init__(self, train_df, val_df, train_desc_df, cfg):
    def __init__(self, train_df, val_df, preds_df, cfg):
        super().__init__()
        
        self.train_df = train_df
        self.val_df = val_df
        # self.train_desc_df = train_desc_df
        self.preds_df = preds_df
        
        self.train_bs = cfg.BATCH_SIZE
        self.val_bs = cfg.BATCH_SIZE

        self.cfg = cfg
        self.path = cfg.stacked_path
        
        self.num_workers = cfg.num_workers
        
    def train_dataloader(self):
        train_ds = rsna_lstm_dataset2(self.train_df, self.preds_df, self.path)
        
        train_loader = torch.utils.data.DataLoader(
            train_ds,
            batch_size=self.train_bs,
            collate_fn=collate_fn_padd,
            pin_memory=False,
            drop_last=False,
            shuffle=True,
            persistent_workers=True,
            num_workers=self.num_workers,
        )
        
        return train_loader
        
    def val_dataloader(self):
        val_ds = rsna_lstm_dataset2(self.val_df, self.preds_df, self.path)
        
        val_loader = torch.utils.data.DataLoader(
            val_ds,
            batch_size=self.val_bs,
            collate_fn=collate_fn_padd,
            pin_memory=False,
            drop_last=False,
            shuffle=False,
            persistent_workers=True,
            num_workers=2,
        )
        
        return val_loader

In [57]:
t_df = train_df[:-100]
# t_df = pd.concat([meta_df[:-100], ul_df[:-100]], ignore_index=True)
v_df = train_df[-100:]

CFG2 = CFG()
# CFG2 = copy.deepcopy(CFG)
CFG2.BATCH_SIZE = 8
CFG2.num_workers = 4

# dm = lstm_datamodule(t_df, v_df, train_desc_df, CFG2)
dm = lstm_datamodule(t_df, v_df, preds_df, CFG2)

x, y = next(iter(dm.train_dataloader()))
x.shape, y.shape, x.dtype, y.dtype

(torch.Size([8, 87, 128]), torch.Size([8, 25]), torch.float32, torch.int64)

In [130]:
x[0]

tensor([[-0.2320,  0.0143, -0.9939,  ...,  0.2312,  0.5752,  1.1324],
        [-0.3192, -0.1989, -1.1888,  ...,  0.2762,  0.4888,  1.5444],
        [-0.2673, -0.1179, -0.7755,  ...,  0.2152,  0.3759,  1.3764],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])

In [138]:
y[0]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1])

In [140]:
y[2]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 0,
        1])

In [58]:
del dm

### Loss function

In [59]:
class FocalLossBCE(torch.nn.Module):
    def __init__(
            self,
            alpha: float = 0.25,
            gamma: float = 2,
            reduction: str = "mean",
            bce_weight: float = 1.0,
            focal_weight: float = 1.0,
    ):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.bce = torch.nn.BCEWithLogitsLoss(reduction=reduction)
        self.bce_weight = bce_weight
        self.focal_weight = focal_weight

    def forward(self, logits, targets):
        focall_loss = torchvision.ops.focal_loss.sigmoid_focal_loss(
            inputs=logits,
            targets=targets,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )
        bce_loss = self.bce(logits, targets)
        return self.bce_weight * bce_loss + self.focal_weight * focall_loss

In [60]:
# logprobs = F.cross_entropy(input, target, reduction='none')
# at = at.view(-1, len(alphas))
# pt = torch.exp(-logprobs)
# focal_loss = at*(1-pt)** gamma * logprobs
# return focal_loss.mean()

### Model

#### Definition

In [61]:
CFG.N_LABELS

3

In [62]:
# torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]).shape, torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]).shape

In [63]:
class LSTMClassifier(pl.LightningModule):
    def __init__(self, cfg=CFG):
        super(LSTMClassifier, self).__init__()

        self.cfg = cfg
        
        self.input_dim = cfg.input_dim
        self.hidden_dim = cfg.hidden_dim
        
        self.levels = 25
        self.classes = cfg.N_LABELS

        # https://discuss.pytorch.org/t/pytorchs-non-deterministic-cross-entropy-loss-and-the-problem-of-reproducibility/172180/9
        # reduction is set to none
        # self.criterion = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.tensor([1,0.1,2]))
        weight = None
        if self.cfg.weighted_loss:
            weight = torch.tensor(cfg.class_weights)
            
        self.criterion = torch.nn.CrossEntropyLoss(reduction='none', weight=weight)

        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, num_layers=cfg.num_layers, batch_first=True)
        self.fc = nn.Linear(self.hidden_dim, self.levels * self.classes)
        # self.classifiers = torch.nn.ModuleList([nn.Linear(self.target_size, cfg.N_LABELS) for i in range(25)])

        # no average maccs
        macc_serverity = ClasswiseWrapper(MulticlassAccuracy(
            num_classes=self.cfg.N_LABELS,
            average='none', 
            multidim_average='global'
            # label encoder classes
        ), labels=le.classes_.tolist(), prefix='multiacc_severity/')

        macc_levels = ClasswiseWrapper(MulticlassAccuracy(
            num_classes=self.cfg.N_LABELS,
            average='none', 
            multidim_average='global'
        ), labels=classes, prefix='multiacc_levels/')
        
        metrics = MetricCollection({
            'macc': MulticlassAccuracy(num_classes=self.cfg.N_LABELS),
            'macc_none': macc_serverity,
            'mpr': MulticlassPrecision(num_classes=self.cfg.N_LABELS),
            'mrec': MulticlassRecall(num_classes=self.cfg.N_LABELS),
            'f1': MulticlassF1Score(num_classes=self.cfg.N_LABELS)
        })

        self.train_metrics = metrics.clone(prefix='train/')
        self.valid_metrics = metrics.clone(prefix='val/')

    def forward(self, sequence):
        #  seq: (seq_len, bs, num_features)
        lstm_out, (h, c) = self.lstm(sequence)
        
        y = self.fc(h[-1])

        # (N,C,d1) -> (N,3,25)
        return y.view(-1, self.classes, self.levels)

    def step(self, batch, batch_idx, mode='train'):
        x, y = batch

        preds = self(x)

        loss = self.criterion(preds, y)

        # https://discuss.pytorch.org/t/pytorchs-non-deterministic-cross-entropy-loss-and-the-problem-of-reproducibility/172180/9
        loss = loss.mean()

        # print(preds.shape, y.shape)

        if mode == 'train':
            output = self.train_metrics(preds, y)
            self.log_dict(output)
        else:
            self.valid_metrics.update(preds, y)

        self.log(f'{mode}/loss', loss, on_step=True, on_epoch=True)

        return loss

    def training_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx, mode='train')
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx, mode='val')
    
        return loss

    def on_train_epoch_end(self):
        self.train_metrics.reset()

    def on_validation_epoch_end(self):
        output = self.valid_metrics.compute()
        self.log_dict(output)

        self.valid_metrics.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.LEARNING_RATE, weight_decay=CFG.weight_decay)
        
        if self.cfg.USE_SCHD:
            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.cfg.COS_EPOCHS)
            scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=10, total_epoch=self.cfg.WARM_EPOCHS, after_scheduler=scheduler_cosine)

            return [optimizer], [scheduler_warmup]
        else:
            return optimizer

#### building blocks

In [64]:
seq.shape, seq.view(1, len(seq), -1).shape, target.shape

(torch.Size([70, 128]), torch.Size([1, 70, 128]), torch.Size([25]))

In [163]:
target_size = 64

lstm = nn.LSTM(target_size, target_size, num_layers=1, batch_first=True)
fc = nn.Linear(target_size, target_size)
classifiers = [nn.Linear(target_size, 3) for i in range(25)]

lstm_out, (h, c) = lstm(torch.randn(5,88,64))
print(lstm_out.shape, h.shape, c.shape)

y = fc(h[-1])
print('fc shape:', y.shape)

preds = [c(y).T for c in classifiers]
print('pred shape:', preds[0].shape)

preds = torch.stack(preds).T

print('preds shape:', preds.shape)

torch.Size([5, 88, 64]) torch.Size([1, 5, 64]) torch.Size([1, 5, 64])
fc shape: torch.Size([5, 64])
pred shape: torch.Size([3, 5])
preds shape: torch.Size([5, 3, 25])


In [164]:
h.shape, c.shape

(torch.Size([1, 5, 64]), torch.Size([1, 5, 64]))

In [165]:
target_size = 64

lstm = nn.LSTM(target_size, target_size, num_layers=16, batch_first=True)
fc = nn.Linear(target_size, target_size)
classifiers = [nn.Linear(target_size, 3) for i in range(25)]

lstm_out, (h, c) = lstm(torch.randn(5,88,64))
print(lstm_out.shape, h.shape, c.shape)

y = fc(h[-1])
print('fc shape:', y.shape)

preds = [c(y).T for c in classifiers]
print('pred shape:', preds[0].shape)

preds = torch.stack(preds).T

print('preds shape:', preds.shape)

torch.Size([5, 88, 64]) torch.Size([16, 5, 64]) torch.Size([16, 5, 64])
fc shape: torch.Size([5, 64])
pred shape: torch.Size([3, 5])
preds shape: torch.Size([5, 3, 25])


In [166]:
h.shape, c.shape

(torch.Size([16, 5, 64]), torch.Size([16, 5, 64]))

In [67]:
target_size = 64

lstm = nn.LSTM(target_size, target_size, num_layers=1, batch_first=True)
fc = nn.Linear(target_size, 75)

criterion = torch.nn.CrossEntropyLoss(reduction='none')

lstm_out, (h, c) = lstm(torch.randn(5,88,64))
print(lstm_out.shape, h.shape, c.shape)

pred = fc(h[-1])
print('fc shape:', y.shape)

pred = pred.view(-1, 3, 25)
print('preds shape:', pred.shape)

targets = torch.stack(5*[target])
loss = criterion(preds, targets)

print('loss:', loss.shape)

torch.Size([5, 88, 64]) torch.Size([1, 5, 64]) torch.Size([1, 5, 64])
fc shape: torch.Size([5, 64])
preds shape: torch.Size([5, 3, 25])
loss: torch.Size([5, 25])


In [68]:
preds.argmax(dim=1).shape, targets.shape

(torch.Size([5, 25]), torch.Size([5, 25]))

In [69]:
preds.argmax(dim=1)

tensor([[1, 1, 0, 0, 0, 1, 2, 1, 0, 2, 2, 1, 2, 2, 0, 1, 1, 0, 0, 2, 2, 1, 1, 2,
         2],
        [1, 1, 0, 0, 0, 1, 2, 1, 0, 2, 2, 1, 2, 2, 0, 1, 1, 0, 0, 2, 2, 1, 1, 2,
         2],
        [1, 1, 0, 0, 0, 1, 2, 1, 0, 2, 2, 1, 2, 2, 0, 1, 1, 0, 0, 2, 2, 1, 1, 2,
         2],
        [1, 1, 0, 0, 0, 1, 2, 1, 0, 2, 2, 1, 2, 2, 0, 1, 1, 0, 0, 2, 2, 1, 1, 2,
         2],
        [1, 1, 0, 0, 0, 1, 2, 1, 0, 2, 2, 1, 2, 2, 0, 1, 1, 0, 0, 2, 2, 1, 1, 2,
         2]])

In [70]:
# preds.max(dim=1)

In [71]:
preds.max(dim=1)[1].shape, preds.argmax(1).shape

(torch.Size([5, 25]), torch.Size([5, 25]))

In [72]:
zeros = torch.zeros(preds.shape, dtype=preds.dtype)
ones = torch.ones(preds.shape, dtype=preds.dtype)

max_vals = preds.argmax(1)

foo = zeros.scatter(1, max_vals.unsqueeze(1), 1)

zeros.shape, ones.shape, foo.shape, preds.shape, foo.swapaxes(1,2).shape

(torch.Size([5, 3, 25]),
 torch.Size([5, 3, 25]),
 torch.Size([5, 3, 25]),
 torch.Size([5, 3, 25]),
 torch.Size([5, 25, 3]))

In [73]:
a = torch.arange(15*25).reshape(5,3,25)
a.shape, a.swapaxes(1,2).shape

(torch.Size([5, 3, 25]), torch.Size([5, 25, 3]))

In [74]:
zeros.scatter(1, targets.unsqueeze(1), 1).swapaxes(1,2).shape

torch.Size([5, 25, 3])

In [75]:
# acc = tm.functional.classification.multilabel_accuracy(foo, targets, 25, )
acc = tm.functional.classification.multilabel_accuracy(foo.swapaxes(1,2), zeros.scatter(1, targets.unsqueeze(1), 1).swapaxes(1,2), 25, average='none')
acc

tensor([1.0000, 1.0000, 1.0000, 0.3333, 0.3333, 1.0000, 0.3333, 1.0000, 1.0000,
        0.3333, 0.3333, 1.0000, 0.3333, 0.3333, 0.3333, 1.0000, 1.0000, 0.3333,
        0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333])

In [76]:
torch.stack(5*[target])
preds.shape, targets.shape

(torch.Size([5, 3, 25]), torch.Size([5, 25]))

In [77]:
metric = MulticlassAccuracy(num_classes=CFG.N_LABELS, average='none', multidim_average='samplewise')
macc = MulticlassAccuracy(num_classes=CFG.N_LABELS, average='none')

m = metric(preds, targets)

m.shape, m

(torch.Size([5, 3]),
 tensor([[0.2500, 0.4667, 0.0000],
         [0.2500, 0.4667, 0.0000],
         [0.2500, 0.4667, 0.0000],
         [0.2500, 0.4667, 0.0000],
         [0.2500, 0.4667, 0.0000]]))

In [78]:
macc(preds, targets)

tensor([0.2500, 0.4667, 0.0000])

#### Test out inputs/outputs

In [79]:
model = LSTMClassifier(CFG)

In [80]:
model.forward((seq.view(1, len(seq), -1))).shape

torch.Size([1, 3, 25])

In [81]:
model.step((seq.view(1, len(seq), -1), target.view(1, len(target))), 0)

C:\ProgramData\anaconda3\envs\rsna\lib\site-packages\pytorch_lightning\core\module.py:445: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


tensor(1.4628, grad_fn=<MeanBackward0>)

In [82]:
y = model(torch.randn(5,88,128))

y.shape, y.softmax(dim=0).shape

(torch.Size([5, 3, 25]), torch.Size([5, 3, 25]))

In [83]:
y.softmax(dim=1).shape

torch.Size([5, 3, 25])

In [84]:
# y.softmax(dim=1).sum(dim=1)

In [85]:
# y[0].softmax(dim=0)

### Split

In [86]:
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit

In [87]:
train_df.shape

(1975, 26)

In [88]:
train_df.sample(2)

Unnamed: 0,study_id,SCSL1L2,SCSL2L3,SCSL3L4,SCSL4L5,SCSL5S1,LNFNL1L2,LNFNL2L3,LNFNL3L4,LNFNL4L5,...,LSSL1L2,LSSL2L3,LSSL3L4,LSSL4L5,LSSL5S1,RSSL1L2,RSSL2L3,RSSL3L4,RSSL4L5,RSSL5S1
125,269259654,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
1928,4179331034,1,1,1,1,1,1,1,1,1,...,1,1,1,2,1,1,1,1,0,1


In [89]:
train_df.iloc[:,1:].shape

(1975, 25)

In [90]:
# TODO: data split

In [91]:
# sss = StratifiedShuffleSplit(n_splits=1, test_size=1-CFG.split_fraction, random_state=CFG.random_seed)
# train_idx, val_idx = next(sss.split(train_df.study_id, train_df.iloc[:,1:]))

# t_df = train_df.iloc[train_idx]
# v_df = train_df.iloc[val_idx]

# t_df.shape, v_df.shape

### Train

In [92]:
CFG.BATCH_SIZE, CFG.device

(16, 'cuda')

In [93]:
train_desc_df.head()

Unnamed: 0,study_id,series_id,series_description,ss_id
0,4003253,702807833,Sagittal T2/STIR,4003253_702807833
1,4003253,1054713880,Sagittal T1,4003253_1054713880
2,4003253,2448190387,Axial T2,4003253_2448190387
3,4646740,3201256954,Axial T2,4646740_3201256954
4,4646740,3486248476,Sagittal T1,4646740_3486248476


In [94]:
# dm = lstm_datamodule(t_df, v_df, train_desc_df, cfg=CFG)
dm = lstm_datamodule(t_df, v_df, preds_df, cfg=CFG)

len(dm.train_dataloader()), len(dm.val_dataloader())

(118, 7)

In [95]:
run_name = f'{CFG.model_name} {CFG.LEARNING_RATE} {CFG.N_EPOCHS} eps {CFG.num_layers}l-{CFG.comment}'
run_name

'lstm 5e-05 30 eps 8l-undersample_healthy'

In [96]:
wandb_logger = WandbLogger(
    name=run_name,
    project=CFG.project,
    job_type='train',
    save_dir=CFG.RESULTS_DIR,
    # config=cfg,
)

loss_ckpt = pl.callbacks.ModelCheckpoint(
    monitor='val/loss',
    auto_insert_metric_name=False,
    dirpath=CFG.CKPT_DIR / run_name,
    filename='ep_{epoch:02d}_loss_{val/loss:.5f}',
    save_top_k=2,
    mode='min',
)

# acc_ckpt = pl.callbacks.ModelCheckpoint(
#     monitor='val/acc',
#     auto_insert_metric_name=False,
#     dirpath=CFG.CKPT_DIR / run_name,
#     filename='ep_{epoch:02d}_acc_{val/acc:.5f}',
#     save_top_k=2,
#     mode='max',
# )

lr_monitor = LearningRateMonitor(logging_interval='step')

In [97]:
trainer = pl.Trainer(
    max_epochs=CFG.N_EPOCHS,
    deterministic=True,
    accelerator=CFG.device,
    default_root_dir=CFG.RESULTS_DIR,
    gradient_clip_val=0.5, 
    # gradient_clip_algorithm="value",
    logger=wandb_logger,
    callbacks=[loss_ckpt, lr_monitor],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


#### Fit

In [98]:
model = LSTMClassifier(CFG)

In [99]:
trainer.fit(model, dm)

You are using a CUDA device ('NVIDIA GeForce RTX 4090 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01127777777777131, max=1.0)…

C:\ProgramData\anaconda3\envs\rsna\lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:654: Checkpoint directory E:\data\RSNA2024\results\ckpt\lstm 5e-05 30 eps 8l-undersample_healthy exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | criterion     | CrossEntropyLoss | 0      | train
1 | lstm          | LSTM             | 1.1 M  | train
2 | fc            | Linear           | 9.7 K  | train
3 | train_metrics | MetricCollection | 0      | train
4 | valid_metrics | MetricCollection | 0      | train
-----------------------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.266     Total estimated model params size (MB)
17        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                        | 0/? [00:00<?, ?it/…

Training: |                                                                               | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…



Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

`Trainer.fit` stopped: `max_epochs=30` reached.


### Predict

In [100]:
x, y = next(iter(dm.train_dataloader()))

x.shape, y.shape

(torch.Size([16, 76, 128]), torch.Size([16, 25]))

In [102]:
# pred = model(x.to(CFG.device)).detach().cpu()
pred = model(x).detach().cpu()
pred.shape

torch.Size([16, 3, 25])

In [116]:
torch.where(y==1, 1, 0)[:3]

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
         0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0,
         1]])

In [103]:
y

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1,
         0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 0,
         1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [1, 2, 2, 0, 1, 1, 2, 1, 0, 2, 1, 0, 0, 0, 0, 1, 0, 2, 0, 1, 0, 2, 2, 0,
         1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [1, 0, 2, 2, 1, 1, 1, 1, 0, 2, 1, 0, 0, 2, 0, 1, 2, 2, 2, 2, 1, 0, 2, 2,
         0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0,
         0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [1, 1, 1, 1, 1, 1, 0, 0, 2, 1, 1, 1, 0, 2, 0, 1, 1, 

In [104]:
train_df.head()

Unnamed: 0,study_id,SCSL1L2,SCSL2L3,SCSL3L4,SCSL4L5,SCSL5S1,LNFNL1L2,LNFNL2L3,LNFNL3L4,LNFNL4L5,...,LSSL1L2,LSSL2L3,LSSL3L4,LSSL4L5,LSSL5S1,RSSL1L2,RSSL2L3,RSSL3L4,RSSL4L5,RSSL5S1
0,4003253,1,1,1,1,1,1,1,1,0,...,1,1,1,0,1,1,1,1,1,1
1,4646740,1,1,0,2,1,1,1,1,0,...,1,1,1,2,1,1,0,0,0,1
2,7143189,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
3,8785691,1,1,1,1,1,1,1,1,0,...,1,1,1,1,1,1,1,1,1,1
4,10728036,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,0,1


In [105]:
pred.softmax(dim=1)

tensor([[[0.1311, 0.2310, 0.2676,  ..., 0.3676, 0.3069, 0.3503],
         [0.7625, 0.5590, 0.3816,  ..., 0.2235, 0.0949, 0.2537],
         [0.1065, 0.2101, 0.3508,  ..., 0.4090, 0.5982, 0.3959]],

        [[0.1311, 0.2310, 0.2676,  ..., 0.3676, 0.3069, 0.3503],
         [0.7624, 0.5589, 0.3816,  ..., 0.2235, 0.0950, 0.2537],
         [0.1065, 0.2101, 0.3508,  ..., 0.4089, 0.5981, 0.3959]],

        [[0.1311, 0.2310, 0.2676,  ..., 0.3676, 0.3069, 0.3503],
         [0.7625, 0.5590, 0.3816,  ..., 0.2235, 0.0949, 0.2537],
         [0.1065, 0.2101, 0.3508,  ..., 0.4090, 0.5982, 0.3959]],

        ...,

        [[0.1311, 0.2310, 0.2676,  ..., 0.3676, 0.3069, 0.3503],
         [0.7623, 0.5588, 0.3816,  ..., 0.2235, 0.0950, 0.2538],
         [0.1066, 0.2101, 0.3508,  ..., 0.4089, 0.5980, 0.3959]],

        [[0.1311, 0.2310, 0.2676,  ..., 0.3676, 0.3069, 0.3503],
         [0.7624, 0.5589, 0.3816,  ..., 0.2235, 0.0950, 0.2538],
         [0.1065, 0.2101, 0.3508,  ..., 0.4089, 0.5981, 0.3959]],

 

In [106]:
pred.argmax(dim=1)

tensor([[1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 