In [1]:
%load_ext autoreload
%autoreload 2

### Imports

In [2]:
import os
import time
import wandb
import torch
import random
import torchvision

import numpy as np
import pandas as pd
import torchmetrics as tm 
# import plotly.express as px
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from torch import nn
from pathlib import Path, PurePath
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, AdamW, RMSprop # optmizers
from sklearn import preprocessing 
# from warmup_scheduler import GradualWarmupScheduler
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau # Learning rate schedulers

import albumentations as A
# from albumentations.pytorch import ToTensorV2

import torch.nn.functional as F

import torchmetrics as tm

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
from torchmetrics.wrappers import ClasswiseWrapper
from torchmetrics import MetricCollection
from torchmetrics.classification import MultilabelAccuracy, MultilabelPrecision, MultilabelRecall, MultilabelF1Score
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassF1Score

import timm

In [3]:
print('timm version', timm.__version__)
print('torch version', torch.__version__)

timm version 1.0.8
torch version 2.3.1


In [4]:
wandb.login(key=os.getenv('wandb_api_key'))

wandb: Currently logged in as: rosu-lucian. Use `wandb login --relogin` to force relogin
wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\Asus\.netrc


True

In [5]:
# detect and define device 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

cuda


In [6]:
# for reproducibility
def seed_torch(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

### Config

In [7]:
# TODO: maybe use condition and level for classes
classes = ['SCS', 'RNFN', 'LNFN', 'LSS', 'RSS'] + ['H'] # add healthy class

# classes = ['SCS', 'RNFN', 'LNFN'] + ['H'] # add healthy class

# classes = ['LSS', 'RSS'] + ['H'] # add healthy class

classes = ['SCSL1L2', 'SCSL2L3', 'SCSL3L4', 'SCSL4L5', 'SCSL5S1', 'RNFNL4L5',
       'RNFNL5S1', 'RNFNL3L4', 'RNFNL1L2', 'RNFNL2L3', 'LNFNL1L2',
       'LNFNL4L5', 'LNFNL5S1', 'LNFNL2L3', 'LNFNL3L4', 'LSSL1L2',
       'RSSL1L2', 'LSSL2L3', 'RSSL2L3', 'LSSL3L4', 'RSSL3L4', 'LSSL4L5',
       'RSSL4L5', 'LSSL5S1', 'RSSL5S1']

num_classes = len(classes)
class2id = {b: i for i, b in enumerate(classes)}

num_classes

25

In [8]:
train_dir = Path('E:\data\RSNA2024')

class CFG:

    project = 'rsna-lstm'
    comment = 'newembeds'

    ### model
    model_name = 'lstm' # 'resnet34', 'resnet200d', 'efficientnet_b1_pruned', 'efficientnetv2_m', efficientnet_b7 

    # LEARNING_RATE = 5*1e-5 # best
    LEARNING_RATE = 1e-5
    
    image_size = 256

    healthy_frac = 1

    weighted_loss = True
    class_weights = [2.06762982, 0.42942998, 5.32804575]
    # class_weights = [1, 0.2, 1.5]
    
    num_layers = 4
    dropout = 0

     ### training
    BATCH_SIZE = 16
    
    ROOT_FOLDER = train_dir
    IMAGES_DIR = ROOT_FOLDER / 'train_images'
    PNG_DIR = ROOT_FOLDER / f'pngs_{image_size}'
    FILES_CSV = ROOT_FOLDER / 'train_files.csv'
    PREDS_CSV = ROOT_FOLDER / 'predictions.csv'
    TRAIN_CSV = ROOT_FOLDER / 'train.csv'
    TRAIN_DESC_CSV = ROOT_FOLDER / 'train_series_descriptions.csv'
    COORDS_CSV = ROOT_FOLDER / 'train_label_coordinates.csv'

    # ckpt_path = Path(r"E:\data\RSNA2024\results\ckpt\eca_nfnet_l0 5e-05 10 eps all-labels\ep_03_loss_0.15231.ckpt")
    embeds_path = Path(r"E:\data\RSNA2024\embeddings")
    stacked_path = Path(r"E:\data\RSNA2024\embeddings_stacked")

    RESULTS_DIR = train_dir / 'results'
    CKPT_DIR = RESULTS_DIR / 'ckpt'

    input_dim = 128
    hidden_dim = 128
    target_size = 128

    classes = classes

    split_fraction = 0.95

    MIXUP = False
    
    ### Optimizer
    N_EPOCHS = 30
    USE_SCHD = False
    WARM_EPOCHS = 3
    COS_EPOCHS = N_EPOCHS - WARM_EPOCHS
    
    weight_decay = 1e-6 # for adamw

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    ### split train and validation sets
    num_workers = 4

    random_seed = 42

# CFG.N_LABELS = len(CFG.classes)
CFG.N_LABELS = 3

seed_torch(seed = CFG.random_seed)

In [9]:
CFG.N_LABELS 

3

### Load data

In [10]:
train_df = pd.read_csv(CFG.TRAIN_CSV)
train_desc_df = pd.read_csv(CFG.TRAIN_DESC_CSV)
coords_df = pd.read_csv(CFG.COORDS_CSV)
files_df = pd.read_csv(CFG.FILES_CSV)
preds_df = pd.read_csv(CFG.PREDS_CSV)

train_df.shape, train_desc_df.shape, coords_df.shape, files_df.shape, preds_df.shape

((1975, 26), (6294, 3), (48692, 18), (147218, 20), (147218, 16))

In [11]:
train_desc_df['ss_id'] = train_desc_df.apply(lambda row: f'{row.study_id}_{row.series_id}', axis=1)
preds_df['ss_id'] = preds_df.apply(lambda row: f'{row.study_id}_{row.series_id}', axis=1)

train_desc_df.head(2)

Unnamed: 0,study_id,series_id,series_description,ss_id
0,4003253,702807833,Sagittal T2/STIR,4003253_702807833
1,4003253,1054713880,Sagittal T1,4003253_1054713880


In [12]:
preds_df.head(2)

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id
0,4290709089_4237840455_1,0.0,0.0,0.0,0.0,0.0,1.0,0.001585,0.001949,0.001369,0.002785,0.002748,0.99344,4290709089,4237840455,1,4290709089_4237840455
1,4290709089_4237840455_2,0.0,0.0,0.0,0.0,0.0,1.0,0.001521,0.001219,0.001093,0.005625,0.003293,0.996517,4290709089,4237840455,2,4290709089_4237840455


In [13]:
preds_df[preds_df.pred_RSS > 0.8].shape

(3788, 17)

In [14]:
files_df.head(2)

Unnamed: 0,study_id,series_id,image,proj,instancenumber,rows,columns,slicethickness,spacingbetweenslices,patientposition,seriesdescription,ss_id,instance_id,filename,series_description,healthy,inst_min,inst_max,inst,inst_perc
0,4290709089,4237840455,1,19,1,384,384,4.0,4.6,HFS,,4290709089_4237840455,4290709089_4237840455_1,E:\data\RSNA2024\pngs_256\4290709089_423784045...,Sagittal T1,True,1,15,0,0.0
1,4290709089,4237840455,2,14,2,384,384,4.0,4.6,HFS,,4290709089_4237840455,4290709089_4237840455_2,E:\data\RSNA2024\pngs_256\4290709089_423784045...,Sagittal T1,True,1,15,1,0.066667


In [15]:
preds_df.shape

(147218, 17)

In [16]:
preds_df = pd.merge(preds_df, train_desc_df.loc[:, ['ss_id', 'series_description']],  how='inner', left_on=['ss_id'], right_on=['ss_id'])

preds_df.sample(2)

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id,series_description
27165,3489738948_3186939889_6,0.0,0.0,0.0,0.0,0.0,1.0,0.000227,0.000267,0.000136,0.000363,0.000574,0.99972,3489738948,3186939889,6,3489738948_3186939889,Axial T2
109560,1099112122_1815821295_4,0.0,0.0,0.0,0.0,0.0,1.0,0.001099,0.001675,0.001109,0.003543,0.003516,0.99723,1099112122,1815821295,4,1099112122_1815821295,Sagittal T2/STIR


In [17]:
preds_df.shape

(147218, 18)

In [18]:
preds_df = pd.merge(preds_df, files_df.loc[:, ['instance_id', 'proj']],  how='inner', left_on=['ids'], right_on=['instance_id'])

preds_df.sample(2)

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id,series_description,instance_id,proj
135360,331970469_538044249_17,0.0,0.0,0.0,0.0,0.0,1.0,0.000654,0.00099,0.000915,0.002225,0.002246,0.998471,331970469,538044249,17,331970469_538044249,Sagittal T2/STIR,331970469_538044249_17,-59
144302,88465004_4105999208_4,0.0,0.0,0.0,0.0,0.0,1.0,0.001293,0.004091,0.002412,0.003399,0.003456,0.998554,88465004,4105999208,4,88465004_4105999208,Sagittal T2/STIR,88465004_4105999208_4,15


In [19]:
preds_df.shape

(147218, 20)

In [20]:
# train_desc_df[train_desc_df['series_id'] == 3909740603]

In [21]:
train_df.fillna('N', inplace=True)
train_df.head(2)

Unnamed: 0,study_id,SCSL1L2,SCSL2L3,SCSL3L4,SCSL4L5,SCSL5S1,LNFNL1L2,LNFNL2L3,LNFNL3L4,LNFNL4L5,...,LSSL1L2,LSSL2L3,LSSL3L4,LSSL4L5,LSSL5S1,RSSL1L2,RSSL2L3,RSSL3L4,RSSL4L5,RSSL5S1
0,4003253,N,N,N,N,N,N,N,N,M,...,N,N,N,M,N,N,N,N,N,N
1,4646740,N,N,M,S,N,N,N,N,M,...,N,N,N,S,N,N,M,M,M,N


In [22]:
le = preprocessing.LabelEncoder() 
le.fit(train_df.iloc[:, 1])

le.classes_
# foo = le.fit_transform(train_df.iloc[:,1])

array(['M', 'N', 'S'], dtype=object)

In [23]:
from sklearn.utils.class_weight import compute_class_weight

In [24]:
foo = train_df.iloc[:, 1:].to_numpy()

foo.shape

(1975, 25)

In [25]:
foo.flatten().shape

(49375,)

In [26]:
compute_class_weight(class_weight='balanced', classes=np.unique(foo.flatten()), y=foo.flatten())

array([2.06762982, 0.42942998, 5.32804575])

In [27]:
CFG.ignore_index = le.transform(['N'])[0]

CFG.ignore_index

1

In [28]:
train_df.iloc[:,1:] = train_df.iloc[:,1:].apply(le.fit_transform)

In [29]:
train_df.head(2)

Unnamed: 0,study_id,SCSL1L2,SCSL2L3,SCSL3L4,SCSL4L5,SCSL5S1,LNFNL1L2,LNFNL2L3,LNFNL3L4,LNFNL4L5,...,LSSL1L2,LSSL2L3,LSSL3L4,LSSL4L5,LSSL5S1,RSSL1L2,RSSL2L3,RSSL3L4,RSSL4L5,RSSL5S1
0,4003253,1,1,1,1,1,1,1,1,0,...,1,1,1,0,1,1,1,1,1,1
1,4646740,1,1,0,2,1,1,1,1,0,...,1,1,1,2,1,1,0,0,0,1


In [30]:
arr = train_df.iloc[:, 1:].to_numpy()

(arr == 0).sum(), (arr == 1).sum(), (arr == 2).sum()

(7960, 38326, 3089)

In [31]:
coords_df.sample(2)

Unnamed: 0,study_id,series_id,instance,condition,level,x,y,ss_id,instance_id,cl,series_description,rows,columns,filename,patientposition,x_perc,y_perc,inst_perc
25515,2294509371,2895790445,7,SCS,L1L2,257.327262,138.787883,2294509371_2895790445,2294509371_2895790445_7,SCSL1L2,Sagittal T2/STIR,448,448,E:\data\RSNA2024\pngs_256\2294509371_289579044...,HFS,0.574391,0.309794,0.428571
35128,3116585926,3043780434,3,LNFN,L4L5,146.454106,321.545894,3116585926_3043780434,3116585926_3043780434_3,LNFNL4L5,Sagittal T1,512,512,E:\data\RSNA2024\pngs_256\3116585926_304378043...,FFS,0.286043,0.628019,0.181818


In [32]:
coords_df.condition.unique()

array(['SCS', 'RNFN', 'LNFN', 'LSS', 'RSS'], dtype=object)

In [33]:
coords_df.cl.unique()

array(['SCSL1L2', 'SCSL2L3', 'SCSL3L4', 'SCSL4L5', 'SCSL5S1', 'RNFNL4L5',
       'RNFNL5S1', 'RNFNL3L4', 'RNFNL1L2', 'RNFNL2L3', 'LNFNL1L2',
       'LNFNL4L5', 'LNFNL5S1', 'LNFNL2L3', 'LNFNL3L4', 'LSSL1L2',
       'RSSL1L2', 'LSSL2L3', 'RSSL2L3', 'LSSL3L4', 'RSSL3L4', 'LSSL4L5',
       'RSSL4L5', 'LSSL5S1', 'RSSL5S1'], dtype=object)

In [34]:
coords_df.cl.nunique()

25

#### Prepare

In [35]:
embed_files = os.listdir(CFG.embeds_path)
stacked_files = os.listdir(CFG.stacked_path)

len(embed_files), len(stacked_files)

(147219, 6295)

In [36]:
study_id = 838134337

In [37]:
selected_files = preds_df[preds_df.study_id == study_id]

selected_files.head(2)

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id,series_description,instance_id,proj
118820,838134337_3108613161_1,0.0,0.0,0.0,0.0,0.0,1.0,0.001354,0.001501,0.00132,0.002872,0.002463,0.998599,838134337,3108613161,1,838134337_3108613161,Axial T2,838134337_3108613161_1,-309
118821,838134337_3108613161_2,0.0,0.0,0.0,0.0,0.0,1.0,0.001957,0.002748,0.001661,0.003534,0.003918,0.998603,838134337,3108613161,2,838134337_3108613161,Axial T2,838134337_3108613161_2,-312


In [38]:
# selected_files.head()

In [39]:
# selected_files.groupby('series_description').sort(['proj'])

In [40]:
selected_files.sort_values(['series_description', 'proj'], ascending=[False, False])

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id,series_description,instance_id,proj
118903,838134337_1285354049_1,0.0,0.0,0.0,0.0,0.0,1.0,0.000830,0.001846,0.000824,0.004176,0.003610,0.997460,838134337,1285354049,1,838134337_1285354049,Sagittal T2/STIR,838134337_1285354049_1,19
118904,838134337_1285354049_2,0.0,0.0,0.0,0.0,0.0,1.0,0.000999,0.001894,0.000911,0.004106,0.003183,0.997706,838134337,1285354049,2,838134337_1285354049,Sagittal T2/STIR,838134337_1285354049_2,15
118905,838134337_1285354049_3,0.0,0.0,0.0,0.0,0.0,1.0,0.001089,0.003119,0.000816,0.003127,0.002793,0.997842,838134337,1285354049,3,838134337_1285354049,Sagittal T2/STIR,838134337_1285354049_3,10
118906,838134337_1285354049_4,0.0,0.0,0.0,0.0,0.0,1.0,0.001067,0.002225,0.001284,0.003330,0.003203,0.997264,838134337,1285354049,4,838134337_1285354049,Sagittal T2/STIR,838134337_1285354049_4,6
118907,838134337_1285354049_5,0.0,0.0,0.0,0.0,0.0,1.0,0.000683,0.001384,0.000429,0.001838,0.001545,0.998883,838134337,1285354049,5,838134337_1285354049,Sagittal T2/STIR,838134337_1285354049_5,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
118878,838134337_3108613161_59,0.0,0.0,0.0,0.0,0.0,1.0,0.002638,0.002090,0.001770,0.492424,0.688045,0.068409,838134337,3108613161,59,838134337_3108613161,Axial T2,838134337_3108613161_59,-516
118879,838134337_3108613161_60,0.0,0.0,0.0,0.0,0.0,1.0,0.001942,0.002482,0.002314,0.115469,0.197212,0.709857,838134337,3108613161,60,838134337_3108613161,Axial T2,838134337_3108613161_60,-519
118880,838134337_3108613161_61,0.0,0.0,0.0,0.0,0.0,1.0,0.001421,0.001702,0.002013,0.004992,0.003479,0.996188,838134337,3108613161,61,838134337_3108613161,Axial T2,838134337_3108613161_61,-523
118881,838134337_3108613161_62,0.0,0.0,0.0,0.0,0.0,1.0,0.000917,0.000943,0.000874,0.003200,0.003534,0.996320,838134337,3108613161,62,838134337_3108613161,Axial T2,838134337_3108613161_62,-526


In [41]:
selected_files.study_id.unique()

array([838134337], dtype=int64)

In [42]:
idx = selected_files.sort_values(['series_description', 'proj'], ascending=[False, True]).index.to_list()
len(idx), selected_files.shape

(103, (103, 20))

In [43]:
stacked_embeds = np.load(CFG.embeds_path / 'stacked.npy')
stacked_embeds.shape

(147218, 128)

In [44]:
stacked_embeds[idx].shape

(103, 128)

In [45]:
# for name, group in selected_files.sort_values('proj').groupby('series_description'):
#     print(name)
#     print(group.image.count())

In [46]:
files = files_df[files_df.study_id == study_id].instance_id.to_list()
files = [CFG.embeds_path / f'{f}.npy' for f in files]

files[0]

WindowsPath('E:/data/RSNA2024/embeddings/838134337_3108613161_1.npy')

In [47]:
np.load(files[2])

array([-1.0995494 ,  0.49097744,  1.0215834 , -1.6787816 ,  0.29210705,
        0.23124354, -1.682824  , -1.5906051 ,  0.04192492, -0.1955464 ,
       -0.2667491 , -0.17186733,  0.95804614,  0.37150368,  1.7241437 ,
        1.1258326 ,  1.0414605 , -0.40807748,  0.4055421 , -0.0740265 ,
       -1.3100444 ,  0.29382887, -0.42474675,  0.22165483, -1.5763164 ,
        1.2931648 ,  1.5698792 , -1.6912712 , -0.15872872,  0.3404426 ,
       -0.0427428 ,  0.54748905, -0.33488297, -0.37769663, -1.5615858 ,
        1.5908116 , -1.7296395 , -0.4697454 , -0.42289543,  0.44413814,
       -0.6652043 ,  1.3982869 ,  0.3384983 ,  0.50999105, -0.26906472,
       -0.6438707 ,  0.07620332, -0.07154421,  1.8592027 , -0.12772073,
        1.609306  , -0.30250317,  1.7261735 ,  1.2107503 , -0.8999722 ,
       -0.31249663, -1.2230419 , -0.23692471, -1.1786542 ,  0.95265937,
       -0.1917044 , -0.5779623 , -0.72399706,  0.85201985], dtype=float32)

In [48]:
files = files_df[files_df.study_id == study_id].ss_id.unique().tolist()
files = [CFG.stacked_path / f'{f}.npy' for f in files]

len(files)

3

In [49]:
np.load(files[0]).shape

(63, 64)

In [50]:
train_df[train_df.study_id == study_id].values.flatten().tolist()[1:]

[1, 1, 1, 2, 1, 1, 1, 1, 2, 0, 1, 1, 0, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1]

In [51]:
train_desc_df.head(5)

Unnamed: 0,study_id,series_id,series_description,ss_id
0,4003253,702807833,Sagittal T2/STIR,4003253_702807833
1,4003253,1054713880,Sagittal T1,4003253_1054713880
2,4003253,2448190387,Axial T2,4003253_2448190387
3,4646740,3201256954,Axial T2,4646740_3201256954
4,4646740,3486248476,Sagittal T1,4646740_3486248476


In [52]:
foo = train_desc_df[train_desc_df.study_id == study_id].sort_values('series_description', ascending=False)

foo

Unnamed: 0,study_id,series_id,series_description,ss_id
1215,838134337,1285354049,Sagittal T2/STIR,838134337_1285354049
1216,838134337,1345841225,Sagittal T1,838134337_1345841225
1217,838134337,3108613161,Axial T2,838134337_3108613161


In [53]:
foo.ss_id.tolist()

['838134337_1285354049', '838134337_1345841225', '838134337_3108613161']

### Check input data

In [54]:
embeds = np.load(CFG.embeds_path / 'stacked.npy')

In [55]:
df = preds_df[preds_df.study_id == 100206310]
# df = preds_df[preds_df.study_id == 838134337]
# df = df.sort_values(['series_description', 'proj'], ascending=[False, True])

In [56]:
embeds.shape, df.shape

((147218, 128), (96, 20))

In [57]:
idx = df.index.to_list()
embeds = embeds[idx]

In [58]:
# idx

In [59]:
embeds.shape

(96, 128)

In [60]:
embeds[:,1]

array([-0.1502166 , -0.13919176, -4.1691155 , -5.535725  , -1.3638068 ,
       -0.15850405, -0.04254709, -0.1241435 , -0.16648896,  0.3508525 ,
        4.0990744 , -0.26638272, -0.31634194, -0.30690244, -0.34259567,
       -0.37317523, -0.1486419 , -0.05325514, -0.06121084, -0.20191208,
       -0.2971905 , -0.26361814, -0.07879854, -0.1257669 , -0.22480534,
       -0.16994789, -0.17953274, -0.03545944,  0.07151809, -0.11276001,
       -0.10309287, -0.2681766 , -0.18603814, -0.01431203, -0.19431952,
       -0.23498192, -0.27063888, -0.2655052 , -0.02401078, -0.58710146,
       -0.2498314 , -0.19352336, -0.2361399 , -0.23905426, -0.19944517,
       -0.12068097, -0.18416779, -0.09809046, -0.03073646, -0.04118643,
       -0.12292372, -0.10268951, -0.09341175, -0.1714987 , -0.11936589,
       -0.35471657, -0.11589582, -0.13076183, -0.45870098, -1.5385287 ,
       -3.120806  , -0.9753626 , -0.43897247, -0.22652276, -0.13100655,
       -0.21621694, -0.25103098, -0.17783034, -0.14182015,  1.76

In [61]:
df.head(10)

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id,series_description,instance_id,proj
143759,100206310_2092806862_1,0.0,0.0,0.0,0.0,0.0,1.0,0.001142,0.001998,0.000932,0.006164,0.002349,0.9987,100206310,2092806862,1,100206310_2092806862,Sagittal T1,100206310_2092806862_1,20
143760,100206310_2092806862_2,0.0,0.0,0.0,0.0,0.0,1.0,0.000709,0.004118,0.000581,0.002755,0.002364,0.997486,100206310,2092806862,2,100206310_2092806862,Sagittal T1,100206310_2092806862_2,15
143761,100206310_2092806862_3,0.0,0.0,0.0,0.0,0.0,1.0,0.002421,0.033496,0.001279,0.004077,0.005075,0.984101,100206310,2092806862,3,100206310_2092806862,Sagittal T1,100206310_2092806862_3,11
143762,100206310_2092806862_4,0.0,1.0,0.0,0.0,0.0,0.0,0.000564,0.999648,0.000387,0.001096,0.000764,0.003171,100206310,2092806862,4,100206310_2092806862,Sagittal T1,100206310_2092806862_4,6
143763,100206310_2092806862_5,0.0,1.0,0.0,0.0,0.0,0.0,0.000547,0.997587,0.000175,0.001053,0.000862,0.007583,100206310,2092806862,5,100206310_2092806862,Sagittal T1,100206310_2092806862_5,2
143764,100206310_2092806862_6,0.0,1.0,0.0,0.0,0.0,0.0,0.001389,0.965823,0.000432,0.001363,0.001165,0.02497,100206310,2092806862,6,100206310_2092806862,Sagittal T1,100206310_2092806862_6,-2
143765,100206310_2092806862_7,0.0,0.0,0.0,0.0,0.0,1.0,0.001695,0.021071,0.007121,0.003482,0.002838,0.974268,100206310,2092806862,7,100206310_2092806862,Sagittal T1,100206310_2092806862_7,-6
143766,100206310_2092806862_8,0.0,0.0,0.0,0.0,0.0,1.0,0.007089,0.012795,0.013448,0.005826,0.006997,0.987549,100206310,2092806862,8,100206310_2092806862,Sagittal T1,100206310_2092806862_8,-11
143767,100206310_2092806862_9,0.0,0.0,0.0,0.0,0.0,1.0,0.009156,0.009769,0.00833,0.00801,0.007488,0.989857,100206310,2092806862,9,100206310_2092806862,Sagittal T1,100206310_2092806862_9,-15
143768,100206310_2092806862_10,0.0,0.0,0.0,0.0,0.0,1.0,0.003289,0.006161,0.007487,0.00325,0.003299,0.994555,100206310,2092806862,10,100206310_2092806862,Sagittal T1,100206310_2092806862_10,-20


In [62]:
df.tail(10)

Unnamed: 0,ids,SCS,RNFN,LNFN,LSS,RSS,H,pred_SCS,pred_RNFN,pred_LNFN,pred_LSS,pred_RSS,pred_H,study_id,series_id,instance,ss_id,series_description,instance_id,proj
143845,100206310_1012284084_51,0.0,0.0,0.0,0.0,0.0,1.0,0.001849,0.002726,0.001528,0.007656,0.011341,0.966479,100206310,1012284084,51,100206310_1012284084,Axial T2,100206310_1012284084_51,-540
143846,100206310_1012284084_52,0.0,0.0,0.0,0.0,0.0,1.0,0.001444,0.001357,0.001297,0.011961,0.011959,0.979055,100206310,1012284084,52,100206310_1012284084,Axial T2,100206310_1012284084_52,-544
143847,100206310_1012284084_53,0.0,0.0,0.0,0.0,0.0,1.0,0.003157,0.00271,0.002551,0.004657,0.004763,0.993834,100206310,1012284084,53,100206310_1012284084,Axial T2,100206310_1012284084_53,-547
143848,100206310_1012284084_54,0.0,0.0,0.0,0.0,0.0,1.0,0.001458,0.001736,0.00123,0.275448,0.261829,0.348866,100206310,1012284084,54,100206310_1012284084,Axial T2,100206310_1012284084_54,-551
143849,100206310_1012284084_55,0.0,0.0,0.0,1.0,1.0,0.0,0.003414,0.002638,0.003983,0.61213,0.736784,0.030937,100206310,1012284084,55,100206310_1012284084,Axial T2,100206310_1012284084_55,-554
143850,100206310_1012284084_56,0.0,0.0,0.0,0.0,0.0,1.0,0.005796,0.004421,0.003515,0.611376,0.815653,0.015366,100206310,1012284084,56,100206310_1012284084,Axial T2,100206310_1012284084_56,-558
143851,100206310_1012284084_57,0.0,0.0,0.0,0.0,0.0,1.0,0.00174,0.002726,0.001834,0.01604,0.013076,0.97659,100206310,1012284084,57,100206310_1012284084,Axial T2,100206310_1012284084_57,-561
143852,100206310_1012284084_58,0.0,0.0,0.0,0.0,0.0,1.0,0.000992,0.001275,0.00081,0.007748,0.010601,0.992076,100206310,1012284084,58,100206310_1012284084,Axial T2,100206310_1012284084_58,-565
143853,100206310_1012284084_59,0.0,0.0,0.0,0.0,0.0,1.0,0.001028,0.001857,0.001319,0.001814,0.00161,0.99882,100206310,1012284084,59,100206310_1012284084,Axial T2,100206310_1012284084_59,-568
143854,100206310_1012284084_60,0.0,0.0,0.0,0.0,0.0,1.0,0.000777,0.000981,0.000742,0.001911,0.002188,0.999004,100206310,1012284084,60,100206310_1012284084,Axial T2,100206310_1012284084_60,-572


In [63]:
df.tail(10).index

Index([143845, 143846, 143847, 143848, 143849, 143850, 143851, 143852, 143853,
       143854],
      dtype='int64')

In [64]:
embeds[1]

array([-0.34221628, -0.13919176, -0.78077346,  0.64501834, -0.07195064,
       -1.2440209 ,  0.9645591 , -0.8141093 ,  0.21966034, -1.099449  ,
        0.6013279 ,  1.5434505 , -0.86642814,  0.40831086, -1.1543027 ,
        0.54212123,  0.7684498 ,  0.29176962, -0.25353748, -1.0091592 ,
        0.5234485 , -0.76076007, -0.67219174,  0.11052073, -0.8537112 ,
       -0.75529987, -0.95640016, -1.0791835 , -0.0261636 , -1.0279031 ,
        1.1882536 , -0.5265772 ,  0.8112257 , -0.01480129,  0.91758794,
       -0.27568927,  0.02743484,  0.03835327, -1.430163  ,  0.09892607,
        0.2911259 , -0.21277288,  0.38242516, -0.6235944 , -1.1379712 ,
        0.25593656, -1.0292481 ,  0.04417314, -0.16738598, -0.38505363,
       -0.1594606 , -1.0211108 ,  0.4320745 ,  0.6459609 ,  1.5208894 ,
       -0.28295177,  1.1789021 , -0.5766379 , -0.9189778 ,  0.18157019,
        0.14352582, -0.03669149,  0.01666886,  1.18903   ,  1.4388409 ,
       -0.09910575,  0.6161871 ,  0.73354167, -0.31500736, -0.66

### Dataset

In [65]:
from dataset import rsna_lstm_dataset, rsna_lstm_dataset2

In [66]:
# dset = rsna_lstm_dataset(train_df, train_desc_df, CFG.stacked_path)
dset = rsna_lstm_dataset2(train_df, preds_df, CFG.stacked_path, CFG)

print(dset.__len__())

seq, target = dset.__getitem__(0)
print(seq.shape, target.shape)
print(seq.dtype, target.dtype)

1975
torch.Size([73, 128]) torch.Size([25])
torch.float32 torch.int64


In [67]:
dset.healthy_frac = 1

print(dset.__len__())

seq, target = dset.__getitem__(0)
print(seq.shape, target.shape)
print(seq.dtype, target.dtype)

1975
torch.Size([73, 128]) torch.Size([25])
torch.float32 torch.int64


In [68]:
# dset = rsna_lstm_dataset(train_df, train_desc_df, CFG.stacked_path)
# dset = rsna_lstm_dataset2(train_df, preds_df, CFG.stacked_path)
dset.healthy_frac = 0

print(dset.__len__())

seq, target = dset.__getitem__(0)
print(seq.shape, target.shape)
print(seq.dtype, target.dtype)

1975
torch.Size([22, 128]) torch.Size([25])
torch.float32 torch.int64


In [69]:
study_id = train_df.loc[0].study_id

selection = preds_df[preds_df.study_id == study_id]

selection.ids.nunique(), selection.shape

(73, (73, 20))

In [70]:
selection[selection.pred_H < 0.8].shape, selection[selection.H < 1].shape

((22, 20), (12, 20))

In [71]:
# preds_df.sample(frac=0.1)

In [72]:
# seq dim: (bs, seq_len, 1, num_features)
# target dim: (N, d1)
target

tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
        1])

In [73]:
seq[0]

tensor([ 0.1740, -0.2892, -0.0194,  0.6008, -1.1034,  0.0763, -0.9378, -0.2737,
         0.8729, -0.1389, -1.1029,  0.2722,  0.0886, -0.2062,  0.2179, -0.3113,
        -0.1670,  0.0475,  0.9207,  0.1925, -0.0863, -0.3824, -0.2791, -0.7670,
        -0.4980,  0.1610, -0.3804, -0.2098, -0.9039, -0.3670,  0.2090,  0.9509,
         0.2215,  0.4828, -0.0169, -1.2843,  0.4423,  0.1195, -0.1747,  0.4252,
         0.3568, -1.1643, -0.0369,  0.1330, -0.0452,  0.2945,  0.0762, -0.1786,
        -0.7236, -0.0318, -0.1395, -0.2227,  0.0445, -0.9931,  0.4349, -0.0923,
        -0.0083, -0.3475,  0.0600, -0.8649,  0.8867, -1.2293,  0.9148, -0.0535,
         0.5973, -0.3420,  1.0847, -0.8047, -0.1860, -0.2820,  0.5169,  0.2863,
         0.5076, -0.4433, -0.1834, -0.2242, -0.2448, -0.1520, -0.2787, -0.4791,
         0.2262, -0.3664, -0.1167,  0.8033,  0.3123, -0.0444, -0.8730,  0.5321,
        -0.9100, -0.4221,  0.2313,  0.2102, -0.4873, -0.2858, -0.1669, -0.0357,
        -0.1119, -1.1185, -0.1582,  0.93

### Data Module

In [74]:
from dataset import rsna_lstm_dataset, rsna_lstm_dataset2, collate_fn_padd

In [75]:
train_desc_df.sample()

Unnamed: 0,study_id,series_id,series_description,ss_id
758,504362668,4031652250,Sagittal T2/STIR,504362668_4031652250


In [76]:
# from torch.nn.utils.rnn import pad_sequence

# def collate_fn_padd(data):
#     tensors, targets = zip(*data)
#     features = pad_sequence(tensors, batch_first=True)
#     targets = torch.stack(targets)
#     return features, targets
    
class lstm_datamodule(pl.LightningDataModule):
    # def __init__(self, train_df, val_df, train_desc_df, cfg):
    def __init__(self, train_df, val_df, preds_df, cfg):
        super().__init__()
        
        self.train_df = train_df
        self.val_df = val_df
        # self.train_desc_df = train_desc_df
        self.preds_df = preds_df
        
        self.train_bs = cfg.BATCH_SIZE
        self.val_bs = cfg.BATCH_SIZE

        self.cfg = cfg
        self.path = cfg.stacked_path
        
        self.num_workers = cfg.num_workers
        
    def train_dataloader(self):
        train_ds = rsna_lstm_dataset2(self.train_df, self.preds_df, self.path, self.cfg)
        
        train_loader = torch.utils.data.DataLoader(
            train_ds,
            batch_size=self.train_bs,
            collate_fn=collate_fn_padd,
            pin_memory=False,
            drop_last=False,
            shuffle=True,
            persistent_workers=True,
            num_workers=self.num_workers,
        )
        
        return train_loader
        
    def val_dataloader(self):
        val_ds = rsna_lstm_dataset2(self.val_df, self.preds_df, self.path, self.cfg)
        
        val_loader = torch.utils.data.DataLoader(
            val_ds,
            batch_size=self.val_bs,
            collate_fn=collate_fn_padd,
            pin_memory=False,
            drop_last=False,
            shuffle=False,
            persistent_workers=True,
            num_workers=2,
        )
        
        return val_loader

In [77]:
t_df = train_df[:-100]
# t_df = pd.concat([meta_df[:-100], ul_df[:-100]], ignore_index=True)
v_df = train_df[-100:]

CFG2 = CFG()
# CFG2 = copy.deepcopy(CFG)
CFG2.BATCH_SIZE = 8
CFG2.num_workers = 4

# dm = lstm_datamodule(t_df, v_df, train_desc_df, CFG2)
dm = lstm_datamodule(t_df, v_df, preds_df, CFG2)

x, y = next(iter(dm.train_dataloader()))
x.shape, y.shape, x.dtype, y.dtype

(torch.Size([8, 123, 128]), torch.Size([8, 25]), torch.float32, torch.int64)

In [78]:
# x.shape[0]

In [79]:
y[0]

tensor([1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 1, 1, 1, 0, 0, 1,
        2])

In [80]:
y[1]

tensor([1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1,
        1])

In [81]:
del dm

### Loss function

In [82]:
class FocalLossBCE(torch.nn.Module):
    def __init__(
            self,
            alpha: float = 0.25,
            gamma: float = 5,
            reduction: str = "mean",
            bce_weight: float = 1.0,
            focal_weight: float = 1.0,
    ):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        # self.bce = torch.nn.BCEWithLogitsLoss(reduction=reduction, pos_weight=0.1)
        self.bce = torch.nn.BCEWithLogitsLoss(reduction=reduction)
        self.bce_weight = bce_weight
        self.focal_weight = focal_weight

    def forward(self, logits, targets):
        focall_loss = torchvision.ops.focal_loss.sigmoid_focal_loss(
            inputs=logits,
            targets=targets,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )
        
        bce_loss = self.bce(logits, targets)
        
        return self.bce_weight * bce_loss + self.focal_weight * focall_loss

In [83]:
# logprobs = F.cross_entropy(input, target, reduction='none')
# at = at.view(-1, len(alphas))
# pt = torch.exp(-logprobs)
# focal_loss = at*(1-pt)** gamma * logprobs
# return focal_loss.mean()

### Model

#### Definition

In [84]:
CFG.N_LABELS

3

In [85]:
# torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]).shape, torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]).shape

In [86]:
CFG.ignore_index

1

In [87]:
class LSTMClassifier(pl.LightningModule):
    def __init__(self, cfg=CFG):
        super(LSTMClassifier, self).__init__()

        self.cfg = cfg
        
        self.input_dim = cfg.input_dim
        self.hidden_dim = cfg.hidden_dim

        self.num_layers = cfg.num_layers

        self.ignore_index = cfg.ignore_index
        
        self.levels = 25
        self.classes = cfg.N_LABELS

        self.hidden = None

        # https://discuss.pytorch.org/t/pytorchs-non-deterministic-cross-entropy-loss-and-the-problem-of-reproducibility/172180/9
        # reduction is set to none
        # self.criterion = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.tensor([1,0.1,2]))
        weight = None
        if self.cfg.weighted_loss:
            weight = torch.tensor(cfg.class_weights)

        self.focal = FocalLossBCE(bce_weight=1)
        # self.criterion = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=cfg.ignore_index)
        self.criterion = torch.nn.CrossEntropyLoss(reduction='none', weight=weight)

        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, num_layers=self.num_layers, batch_first=True, dropout=cfg.dropout)
        self.fc = nn.Linear(self.hidden_dim, self.levels * self.classes)
        self.fc_healthy = nn.Linear(self.hidden_dim, self.levels)

        # no average maccs
        macc_serverity = ClasswiseWrapper(MulticlassAccuracy(
            num_classes=self.cfg.N_LABELS,
            average='none', 
            multidim_average='global'
            # label encoder classes
        ), labels=le.classes_.tolist(), prefix='multiacc_severity/')

        macc_levels = ClasswiseWrapper(MulticlassAccuracy(
            num_classes=self.cfg.N_LABELS,
            average='none', 
            multidim_average='global'
        ), labels=classes, prefix='multiacc_levels/')
        
        metrics = MetricCollection({
            'macc': MulticlassAccuracy(num_classes=self.cfg.N_LABELS),
            'macc_none': macc_serverity,
            'mpr': MulticlassPrecision(num_classes=self.cfg.N_LABELS),
            'mrec': MulticlassRecall(num_classes=self.cfg.N_LABELS),
            'f1': MulticlassF1Score(num_classes=self.cfg.N_LABELS)
        })

        self.train_metrics = metrics.clone(prefix='train/')
        self.valid_metrics = metrics.clone(prefix='val/')

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (
            weight.new_zeros(self.num_layers, bsz, self.hidden_dim),
            weight.new_zeros(self.num_layers, bsz, self.hidden_dim),
        )

    def forward(self, sequence):
        # if self.hidden is None:
        #     self.hidden = self.init_hidden(sequence.shape[0])
            
        # self.hidden = (self.hidden[0].detach(), self.hidden[1].detach())
        
        #  seq: (seq_len, bs, num_features)
        lstm_out, (h, c) = self.lstm(sequence)
        # lstm_out, (h, c) = self.lstm(sequence, self.hidden)
        
        y = self.fc(h[-1])
        y2 = self.fc_healthy(h[-1])

        # (N,C,d1) -> (N,3,25)
        return y.view(-1, self.classes, self.levels), y2.view(-1, self.levels)

    def step(self, batch, batch_idx, mode='train'):
        x, y = batch

        preds, preds_healthy = self(x)

        # focal_loss = self.focal(preds[:, self.ignore_index], torch.where(y==self.ignore_index, 1., 0.))
        
        focal_loss = self.focal(preds_healthy, torch.where(y==self.ignore_index, 1., 0.))

        # focal_loss = self.focal( torch.where(y==self.ignore_index, 1., 0.), preds[:, self.ignore_index])
        # https://discuss.pytorch.org/t/pytorchs-non-deterministic-cross-entropy-loss-and-the-problem-of-reproducibility/172180/9
        loss = self.criterion(preds, y).mean()

        total_loss = loss + focal_loss

        # print(preds.shape, y.shape)

        self.hidden = None

        if mode == 'train':
            output = self.train_metrics(preds, y)
            self.log_dict(output)
        else:
            self.valid_metrics.update(preds, y)

        self.log(f'{mode}/loss', total_loss, on_step=True, on_epoch=True)
        self.log(f'{mode}/loss_focal', focal_loss, on_step=True, on_epoch=False)
        self.log(f'{mode}/loss_classif', loss, on_step=True, on_epoch=False)

        return total_loss

    def training_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx, mode='train')

        # self.hidden = None
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx, mode='val')
    
        return loss

    def on_train_epoch_end(self):
        self.train_metrics.reset()

    def on_validation_epoch_end(self):
        output = self.valid_metrics.compute()
        self.log_dict(output)

        self.valid_metrics.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.LEARNING_RATE, weight_decay=CFG.weight_decay)
        
        if self.cfg.USE_SCHD:
            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.cfg.COS_EPOCHS)
            scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=10, total_epoch=self.cfg.WARM_EPOCHS, after_scheduler=scheduler_cosine)

            return [optimizer], [scheduler_warmup]
        else:
            return optimizer

#### building blocks

In [88]:
seq.shape, seq.view(1, len(seq), -1).shape, target.shape

(torch.Size([22, 128]), torch.Size([1, 22, 128]), torch.Size([25]))

In [89]:
target_size = 64

lstm = nn.LSTM(target_size, target_size, num_layers=1, batch_first=True)
fc = nn.Linear(target_size, target_size)
classifiers = [nn.Linear(target_size, 3) for i in range(25)]

lstm_out, (h, c) = lstm(torch.randn(5,88,64))
print(lstm_out.shape, h.shape, c.shape)

y = fc(h[-1])
print('fc shape:', y.shape)

preds = [c(y).T for c in classifiers]
print('pred shape:', preds[0].shape)

preds = torch.stack(preds).T

print('preds shape:', preds.shape)

torch.Size([5, 88, 64]) torch.Size([1, 5, 64]) torch.Size([1, 5, 64])
fc shape: torch.Size([5, 64])
pred shape: torch.Size([3, 5])


  preds = torch.stack(preds).T


preds shape: torch.Size([5, 3, 25])


In [90]:
target_size = 64

lstm = nn.LSTM(target_size, target_size, num_layers=16, batch_first=True)
fc = nn.Linear(target_size, target_size)
classifiers = [nn.Linear(target_size, 3) for i in range(25)]

lstm_out, (h, c) = lstm(torch.randn(5,88,64))
print(lstm_out.shape, h.shape, c.shape)

y = fc(h[-1])
print('fc shape:', y.shape)

preds = [c(y).T for c in classifiers]
print('pred shape:', preds[0].shape)

preds = torch.stack(preds).T

print('preds shape:', preds.shape)

torch.Size([5, 88, 64]) torch.Size([16, 5, 64]) torch.Size([16, 5, 64])
fc shape: torch.Size([5, 64])
pred shape: torch.Size([3, 5])
preds shape: torch.Size([5, 3, 25])


In [91]:
target_size = 64

lstm = nn.LSTM(target_size, target_size, num_layers=1, batch_first=True)
fc = nn.Linear(target_size, 75)

criterion = torch.nn.CrossEntropyLoss(reduction='none')

lstm_out, (h, c) = lstm(torch.randn(5,88,64))
print(lstm_out.shape, h.shape, c.shape)

pred = fc(h[-1])
print('fc shape:', y.shape)

pred = pred.view(-1, 3, 25)
print('preds shape:', pred.shape)

targets = torch.stack(5*[target])
loss = criterion(preds, targets)

print('loss:', loss.shape)

torch.Size([5, 88, 64]) torch.Size([1, 5, 64]) torch.Size([1, 5, 64])
fc shape: torch.Size([5, 64])
preds shape: torch.Size([5, 3, 25])
loss: torch.Size([5, 25])


In [92]:
preds.argmax(dim=1).shape, targets.shape

(torch.Size([5, 25]), torch.Size([5, 25]))

In [93]:
preds.argmax(dim=1)

tensor([[1, 1, 0, 0, 0, 1, 2, 1, 0, 2, 2, 1, 2, 2, 0, 1, 1, 0, 0, 2, 2, 1, 1, 2,
         2],
        [1, 1, 0, 0, 0, 1, 2, 1, 0, 2, 2, 1, 2, 2, 0, 1, 1, 0, 0, 2, 2, 1, 1, 2,
         2],
        [1, 1, 0, 0, 0, 1, 2, 1, 0, 2, 2, 1, 2, 2, 0, 1, 1, 0, 0, 2, 2, 1, 1, 2,
         2],
        [1, 1, 0, 0, 0, 1, 2, 1, 0, 2, 2, 1, 2, 2, 0, 1, 1, 0, 0, 2, 2, 1, 1, 2,
         2],
        [1, 1, 0, 0, 0, 1, 2, 1, 0, 2, 2, 1, 2, 2, 0, 1, 1, 0, 0, 2, 2, 1, 1, 2,
         2]])

In [94]:
# preds.max(dim=1)

In [95]:
preds.max(dim=1)[1].shape, preds.argmax(1).shape

(torch.Size([5, 25]), torch.Size([5, 25]))

In [96]:
zeros = torch.zeros(preds.shape, dtype=preds.dtype)
ones = torch.ones(preds.shape, dtype=preds.dtype)

max_vals = preds.argmax(1)

foo = zeros.scatter(1, max_vals.unsqueeze(1), 1)

zeros.shape, ones.shape, foo.shape, preds.shape, foo.swapaxes(1,2).shape

(torch.Size([5, 3, 25]),
 torch.Size([5, 3, 25]),
 torch.Size([5, 3, 25]),
 torch.Size([5, 3, 25]),
 torch.Size([5, 25, 3]))

In [97]:
a = torch.arange(15*25).reshape(5,3,25)
a.shape, a.swapaxes(1,2).shape

(torch.Size([5, 3, 25]), torch.Size([5, 25, 3]))

In [98]:
zeros.scatter(1, targets.unsqueeze(1), 1).swapaxes(1,2).shape

torch.Size([5, 25, 3])

In [99]:
# acc = tm.functional.classification.multilabel_accuracy(foo, targets, 25, )
acc = tm.functional.classification.multilabel_accuracy(foo.swapaxes(1,2), zeros.scatter(1, targets.unsqueeze(1), 1).swapaxes(1,2), 25, average='none')
acc

tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.3333, 1.0000, 0.3333, 1.0000, 1.0000,
        0.3333, 0.3333, 1.0000, 0.3333, 0.3333, 0.3333, 1.0000, 1.0000, 0.3333,
        1.0000, 0.3333, 0.3333, 1.0000, 1.0000, 0.3333, 0.3333])

In [100]:
torch.stack(5*[target])
preds.shape, targets.shape

(torch.Size([5, 3, 25]), torch.Size([5, 25]))

In [101]:
metric = MulticlassAccuracy(num_classes=CFG.N_LABELS, average='none', multidim_average='samplewise')
macc = MulticlassAccuracy(num_classes=CFG.N_LABELS, average='none')

m = metric(preds, targets)

m.shape, m

(torch.Size([5, 3]),
 tensor([[0.5000, 0.4286, 0.0000],
         [0.5000, 0.4286, 0.0000],
         [0.5000, 0.4286, 0.0000],
         [0.5000, 0.4286, 0.0000],
         [0.5000, 0.4286, 0.0000]]))

In [102]:
macc(preds, targets)

tensor([0.5000, 0.4286, 0.0000])

#### Test out inputs/outputs

In [103]:
model = LSTMClassifier(CFG)

In [104]:
preds, preds_healthy = model.forward((seq.view(1, len(seq), -1)))
preds.shape, preds_healthy.shape

(torch.Size([1, 3, 25]), torch.Size([1, 25]))

In [105]:
model.step((seq.view(1, len(seq), -1), target.view(1, len(target))), 0)

C:\ProgramData\anaconda3\envs\rsna\lib\site-packages\pytorch_lightning\core\module.py:445: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


tensor(1.4700, grad_fn=<AddBackward0>)

In [106]:
model = LSTMClassifier(CFG)

In [107]:
y, y2 = model(torch.randn(5,88,128))

y.shape, y.softmax(dim=0).shape

(torch.Size([5, 3, 25]), torch.Size([5, 3, 25]))

In [108]:
y.softmax(dim=1).shape

torch.Size([5, 3, 25])

In [109]:
# y.softmax(dim=1).sum(dim=1)

In [110]:
# y[0].softmax(dim=0)

### Split

In [111]:
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit

In [112]:
train_df.shape

(1975, 26)

In [113]:
train_df.sample(2)

Unnamed: 0,study_id,SCSL1L2,SCSL2L3,SCSL3L4,SCSL4L5,SCSL5S1,LNFNL1L2,LNFNL2L3,LNFNL3L4,LNFNL4L5,...,LSSL1L2,LSSL2L3,LSSL3L4,LSSL4L5,LSSL5S1,RSSL1L2,RSSL2L3,RSSL3L4,RSSL4L5,RSSL5S1
312,691922299,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
73,159721286,1,1,1,0,1,1,1,0,0,...,1,1,1,0,1,1,1,0,2,1


In [114]:
train_df.iloc[:,1:].shape

(1975, 25)

In [115]:
# TODO: data split

In [116]:
# sss = StratifiedShuffleSplit(n_splits=1, test_size=1-CFG.split_fraction, random_state=CFG.random_seed)
# train_idx, val_idx = next(sss.split(train_df.study_id, train_df.iloc[:,1:]))

# t_df = train_df.iloc[train_idx]
# v_df = train_df.iloc[val_idx]

# t_df.shape, v_df.shape

### Train

In [117]:
CFG.BATCH_SIZE, CFG.device

(16, 'cuda')

In [118]:
train_desc_df.head()

Unnamed: 0,study_id,series_id,series_description,ss_id
0,4003253,702807833,Sagittal T2/STIR,4003253_702807833
1,4003253,1054713880,Sagittal T1,4003253_1054713880
2,4003253,2448190387,Axial T2,4003253_2448190387
3,4646740,3201256954,Axial T2,4646740_3201256954
4,4646740,3486248476,Sagittal T1,4646740_3486248476


In [119]:
# dm = lstm_datamodule(t_df, v_df, train_desc_df, cfg=CFG)
dm = lstm_datamodule(t_df, v_df, preds_df, cfg=CFG)

len(dm.train_dataloader()), len(dm.val_dataloader())

(118, 7)

In [120]:
run_name = f'{CFG.LEARNING_RATE} {CFG.N_EPOCHS} eps {CFG.num_layers}l-{CFG.comment}'
run_name

'1e-05 30 eps 4l-newembeds'

In [121]:
wandb_logger = WandbLogger(
    name=run_name,
    project=CFG.project,
    job_type='train',
    save_dir=CFG.RESULTS_DIR,
    # config=cfg,
)

loss_ckpt = pl.callbacks.ModelCheckpoint(
    monitor='val/loss',
    auto_insert_metric_name=False,
    dirpath=CFG.CKPT_DIR / run_name,
    filename='ep_{epoch:02d}_loss_{val/loss:.5f}',
    save_top_k=2,
    mode='min',
)

# acc_ckpt = pl.callbacks.ModelCheckpoint(
#     monitor='val/acc',
#     auto_insert_metric_name=False,
#     dirpath=CFG.CKPT_DIR / run_name,
#     filename='ep_{epoch:02d}_acc_{val/acc:.5f}',
#     save_top_k=2,
#     mode='max',
# )

lr_monitor = LearningRateMonitor(logging_interval='step')

In [122]:
trainer = pl.Trainer(
    max_epochs=CFG.N_EPOCHS,
    deterministic=True,
    accelerator=CFG.device,
    default_root_dir=CFG.RESULTS_DIR,
    gradient_clip_val=0.5, 
    # gradient_clip_algorithm="value",
    logger=wandb_logger,
    callbacks=[loss_ckpt, lr_monitor],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


#### Fit

In [123]:
model = LSTMClassifier(CFG)

In [124]:
trainer.fit(model, dm)

You are using a CUDA device ('NVIDIA GeForce RTX 4090 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | focal         | FocalLossBCE     | 0      | train
1 | criterion     | CrossEntropyLoss | 0      | train
2 | lstm          | LSTM             | 528 K  | train
3 | fc            | Linear           | 9.7 K  | train
4 | fc_healthy    | Linear           | 3.2 K  | train
5 | train_metrics | MetricCollection | 0      | train
6 | valid_metrics | MetricCollection | 0      | train
-----------------------------------------------------------
541 K     Trainable params
0         Non-trainable params
541 K     Total params
2.165     Total estimated model params size (MB)
20        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                        | 0/? [00:00<?, ?it/…

Training: |                                                                               | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…



Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

`Trainer.fit` stopped: `max_epochs=30` reached.


### Predict

In [125]:
x, y = next(iter(dm.train_dataloader()))

x.shape, y.shape

(torch.Size([16, 102, 128]), torch.Size([16, 25]))

In [126]:
torch.where(y==1, 1., 0.).shape

torch.Size([16, 25])

In [127]:
x.shape

torch.Size([16, 102, 128])

In [128]:
# pred = model(x.to(CFG.device)).detach().cpu()
pred, pred2 = model(x)
pred = pred.detach().cpu()
pred.shape

torch.Size([16, 3, 25])

In [129]:
y[:10]

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1,
         1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0,
         1],
        [1, 1, 2, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 2, 1, 1, 1, 1, 0, 2,
         0],
        [1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 1, 1, 1, 1, 0,
         1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 2, 1, 0, 0, 0, 1, 1, 0, 0, 0,
         0],
        [1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 2, 1, 1, 1, 1, 1, 0, 1,
         1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 0, 2, 0, 1, 0, 0, 0,
         1]])

In [130]:
train_df.head()

Unnamed: 0,study_id,SCSL1L2,SCSL2L3,SCSL3L4,SCSL4L5,SCSL5S1,LNFNL1L2,LNFNL2L3,LNFNL3L4,LNFNL4L5,...,LSSL1L2,LSSL2L3,LSSL3L4,LSSL4L5,LSSL5S1,RSSL1L2,RSSL2L3,RSSL3L4,RSSL4L5,RSSL5S1
0,4003253,1,1,1,1,1,1,1,1,0,...,1,1,1,0,1,1,1,1,1,1
1,4646740,1,1,0,2,1,1,1,1,0,...,1,1,1,2,1,1,0,0,0,1
2,7143189,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
3,8785691,1,1,1,1,1,1,1,1,0,...,1,1,1,1,1,1,1,1,1,1
4,10728036,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,0,1


In [131]:
pred.softmax(dim=1)

tensor([[[0.1215, 0.2312, 0.2626,  ..., 0.3714, 0.3076, 0.3516],
         [0.7674, 0.5663, 0.3913,  ..., 0.2234, 0.0934, 0.2579],
         [0.1111, 0.2026, 0.3462,  ..., 0.4052, 0.5990, 0.3905]],

        [[0.1215, 0.2311, 0.2625,  ..., 0.3716, 0.3074, 0.3516],
         [0.7674, 0.5663, 0.3914,  ..., 0.2234, 0.0934, 0.2579],
         [0.1111, 0.2026, 0.3461,  ..., 0.4051, 0.5992, 0.3905]],

        [[0.1218, 0.2328, 0.2651,  ..., 0.3670, 0.3132, 0.3527],
         [0.7684, 0.5642, 0.3881,  ..., 0.2231, 0.0931, 0.2581],
         [0.1098, 0.2030, 0.3468,  ..., 0.4100, 0.5937, 0.3892]],

        ...,

        [[0.1217, 0.2326, 0.2645,  ..., 0.3683, 0.3115, 0.3523],
         [0.7681, 0.5645, 0.3886,  ..., 0.2231, 0.0934, 0.2581],
         [0.1101, 0.2029, 0.3469,  ..., 0.4086, 0.5952, 0.3896]],

        [[0.1215, 0.2313, 0.2627,  ..., 0.3711, 0.3079, 0.3516],
         [0.7675, 0.5661, 0.3910,  ..., 0.2233, 0.0934, 0.2579],
         [0.1110, 0.2026, 0.3463,  ..., 0.4055, 0.5987, 0.3905]],

 

In [132]:
pred.argmax(dim=1)[:10]

tensor([[1, 1, 1, 2, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2],
        [1, 1, 1, 2, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 2, 2, 0, 1, 1, 2, 2,
         2]])

In [133]:
# pred.argmax(dim=1)[:10]