In [1]:
import numpy as np
import pandas as pd
import polars as pl
import os
import random
import time
import datetime
import warnings
import yaml

from pathlib import Path
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler

import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold
from sklearn.metrics import log_loss

import pandas.api.types
import sklearn.metrics

import timm
import transformers
import pydicom
import cv2

warnings.simplefilter("ignore")

# Seeding

In [2]:
def seed_everything(seed: int):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and the current CUDA
    device), exports ``PYTHONHASHSEED`` and switches cuDNN to its
    deterministic, non-benchmarking mode.

    Args:
        seed: Seed value shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # Trade cuDNN auto-tuning for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Splits

In [3]:
# def split_fold(conf, df_):
#     print('split fold..')
#     df = df_.clone()
    
#     sgkf = StratifiedGroupKFold(n_splits=conf.fold_num, shuffle=True, random_state=conf.seed)
#     splitter = np.zeros(df.height)
    
#     for fold, (_, valid_idx) in enumerate(sgkf.split(X=df['study_id'], y=df['level'], groups=df['study_id'])):
#         splitter[valid_idx] = fold
    
#     df = df.with_columns(fold=pl.Series(splitter).cast(pl.Int8))
    
#     return df

# Dataset

In [4]:
class LSDCTrainDataset(Dataset):
    """Dataset yielding one multi-series image stack per ``study_id_level`` sample.

    For each sample the pre-extracted ``.npy`` slices of the Sagittal T1,
    Sagittal T2/STIR and Axial T2 series are cropped around the mean annotated
    (x, y) coordinate, resized to ``conf.image_size`` and written into a stack
    of ``conf.n_slices`` channels. Series that are missing keep all-zero
    channels.

    Args:
        conf: Configuration object; reads ``image_size``, ``n_slices``,
            ``sagittal_window_ratio``, ``axial_window_ratio`` and
            ``axial_min_dist_threshold``.
        df: polars frame with (at least) ``study_id_level``, ``study_id``,
            ``series_description``, ``series_path``, ``slices``, ``x``, ``y``
            and the label columns listed in ``label_names_cond``.
        study_id_level_df: Indexable collection of ``study_id_level`` keys; one
            entry per dataset item.
        patient_coords: polars frame with ``study_id_level``,
            ``series_description`` and ``patient_coords`` columns.
        axial_IPP: polars frame with ``study_id``, ``instance_number`` and
            ``SliceLocation`` columns for axial series.
        transforms: Callable applied as ``transforms(image=...)['image']``, or
            ``None`` to skip augmentation.
    """

    # Raw channel layout assumed by the regrouping step in __getitem__:
    # channels 0-10 hold Sagittal T1, 11-21 Sagittal T2/STIR, 22-24 Axial T2.
    _SAGITTAL_T1_START = 0
    _SAGITTAL_T2_START = 11
    _AXIAL_T2_START = 22
    _SAGITTAL_CHANNELS = 11
    _AXIAL_CHANNELS = 3

    def __init__(self, conf, df, study_id_level_df, patient_coords, axial_IPP, transforms):
        super().__init__()
        self.conf = conf
        self.df = df
        self.study_id_level_df = study_id_level_df
        self.patient_coords = patient_coords
        self.axial_IPP = axial_IPP
        self.transforms = transforms

        # Label columns, grouped by the series that shows the condition best.
        self.label_names_cond = [
            # sagittal_t1
            'left_neural_foraminal_narrowing_normal_mild',
            'left_neural_foraminal_narrowing_moderate',
            'left_neural_foraminal_narrowing_severe',

            'right_neural_foraminal_narrowing_normal_mild',
            'right_neural_foraminal_narrowing_moderate',
            'right_neural_foraminal_narrowing_severe',

            # sagittal_t2
            'spinal_canal_stenosis_normal_mild',
            'spinal_canal_stenosis_moderate',
            'spinal_canal_stenosis_severe',

            # axial
            'left_subarticular_stenosis_normal_mild',
            'left_subarticular_stenosis_moderate',
            'left_subarticular_stenosis_severe',

            'right_subarticular_stenosis_normal_mild',
            'right_subarticular_stenosis_moderate',
            'right_subarticular_stenosis_severe',
        ]

    def __len__(self):
        """Number of ``study_id_level`` samples."""
        return len(self.study_id_level_df)

    def __getitem__(self, idx):
        """Build the image stack and label vector for sample ``idx``.

        Returns:
            dict with ``study_id_level`` (sample key), ``images`` (float
            tensor) and ``labels`` (float tensor of the 15 label columns).
        """
        image_stack = np.zeros((self.conf.image_size, self.conf.image_size, self.conf.n_slices))

        study_id_level_samples = self.study_id_level_df[idx]
        this_study_lvl = self.df.filter(pl.col('study_id_level') == study_id_level_samples)
        this_study = this_study_lvl['study_id'].item(0)
        sagittal_t1 = this_study_lvl.filter(pl.col('series_description') == 'Sagittal T1')
        sagittal_t2 = this_study_lvl.filter(pl.col('series_description') == 'Sagittal T2/STIR')
        axial_t2 = this_study_lvl.filter(pl.col('series_description') == 'Axial T2')
        labels = this_study_lvl.select(self.label_names_cond).to_numpy()[0]

        # Every series writes at a fixed channel offset so that a short or
        # missing pack can never shift the series that follow it.
        if not sagittal_t1.is_empty():
            self._fill_channels(
                image_stack,
                self._SAGITTAL_T1_START,
                self._SAGITTAL_CHANNELS,
                sagittal_t1['series_path'].item(0),
                sagittal_t1['slices'][0].to_numpy(),
                sagittal_t1.select(['x', 'y']).to_numpy().mean(axis=0),
                self.conf.sagittal_window_ratio,
            )

        if not sagittal_t2.is_empty():
            self._fill_channels(
                image_stack,
                self._SAGITTAL_T2_START,
                self._SAGITTAL_CHANNELS,
                sagittal_t2['series_path'].item(0),
                sagittal_t2['slices'][0].to_numpy(),
                sagittal_t2.select(['x', 'y']).to_numpy().mean(axis=0),
                self.conf.sagittal_window_ratio,
            )

            # can not cross-ref if sagittal t2 does not exist
            if not axial_t2.is_empty():
                axial_t2_path = axial_t2['series_path'].item(0)
                axial_t2_slices = axial_t2['slices'][0].to_numpy()
                axial_t2_coords = axial_t2.select(['x', 'y']).to_numpy().mean(axis=0)

                axial_z_axis = self.axial_IPP.filter(pl.col('study_id') == this_study)[['instance_number', 'SliceLocation']].to_numpy()
                sag_coord = self.patient_coords.filter(
                    pl.col('study_id_level') == study_id_level_samples,
                    pl.col('series_description') == 'Sagittal T2/STIR'
                )['patient_coords'].to_numpy()[0]

                level_idx = self._find_axial(axial_z_axis, sag_coord, self.conf.axial_min_dist_threshold)
                if level_idx is not None:
                    # NOTE(review): level_idx is an instance number used as a
                    # position in the sorted slice list; this presumably relies
                    # on contiguous, 1-based instance numbers - confirm.
                    # The start is clamped so a small index cannot wrap around.
                    level_pack_slices = axial_t2_slices[max(level_idx - 2, 0): level_idx + 1]
                    self._fill_channels(
                        image_stack,
                        self._AXIAL_T2_START,
                        self._AXIAL_CHANNELS,
                        axial_t2_path,
                        level_pack_slices,
                        axial_t2_coords,
                        self.conf.axial_window_ratio,
                    )

        # Min-max normalise the whole stack; a constant stack (e.g. no series
        # found) is left at zero instead of becoming NaN through 0 / 0.
        stack_min, stack_max = image_stack.min(), image_stack.max()
        if stack_max > stack_min:
            image_stack = (image_stack - stack_min) / (stack_max - stack_min)
        else:
            image_stack = image_stack - stack_min
        image_stack = image_stack.astype(np.float32)

        # Regroup channels by anatomical side, interleaving T1 and T2.
        new_image_stack = np.concatenate([
            # left
            image_stack[:, :, 0: 3], # T1
            image_stack[:, :, 11: 14], # T2
            # middle
            image_stack[:, :, 3: 8], # T1
            image_stack[:, :, 14: 19], # T2
            # right
            image_stack[:, :, 8: 11], # T1
            image_stack[:, :, 19: 22], # T2
            # axial
            image_stack[:, :, 22: 25],
        ], axis=-1)

        if self.transforms is not None:
            new_image_stack = self.transforms(image=new_image_stack)['image']
        else:
            # Without a pipeline there is no ToTensorV2 step, so move the
            # slice axis first (HWC -> CHW) by hand.
            new_image_stack = np.transpose(new_image_stack, (2, 0, 1)).copy()

        batch = {}
        batch['study_id_level'] = study_id_level_samples
        # as_tensor accepts both the tensor produced by the transforms and the
        # ndarray from the no-transform path.
        batch['images'] = torch.as_tensor(new_image_stack, dtype=torch.float)
        batch['labels'] = torch.as_tensor(labels, dtype=torch.float)

        return batch

    def _fill_channels(self, image_stack, first_chan, max_chans, series_path, slice_ids, center_xy, window_ratio):
        """Load, crop and resize slices into consecutive channels of ``image_stack``.

        At most ``max_chans`` slices are written, starting at ``first_chan``,
        so one series can never spill into the channels of the next.
        """
        for offset, ins_num in enumerate(slice_ids[:max_chans]):
            dicom = np.load(f'{series_path}/{ins_num}.npy')
            image_stack[:, :, first_chan + offset] = self._crop_window(dicom, center_xy, window_ratio)

    def _crop_window(self, image, center_xy, window_ratio):
        """Crop a square window around ``center_xy`` and resize it to the model input size.

        The window side is ``window_ratio`` times the mean image side. Start
        indices are clamped at 0 because a negative start would wrap around.
        """
        window_half_size = (np.mean(image.shape) * window_ratio) // 2
        top = max(round(center_xy[1] - window_half_size), 0)
        bottom = round(center_xy[1] + window_half_size)
        left = max(round(center_xy[0] - window_half_size), 0)
        right = round(center_xy[0] + window_half_size)
        cropped = image[top: bottom, left: right]
        return cv2.resize(cropped, (self.conf.image_size, self.conf.image_size))

    def _read_dcm(self, path):
        """Read a DICOM file and return its pixel array."""
        data = pydicom.dcmread(path).pixel_array
        return data

    def _center_crop(self, img, dim):
        # https://gist.github.com/Nannigalaxy/35dd1d0722f29672e68b700bc5d44767
        """Returns center cropped image
        Args:
        img: image to be center cropped
        dim: dimensions (width, height) to be cropped
        """
        width, height = img.shape[1], img.shape[0]

        # process crop width and height for max available dimension
        crop_width = min(dim[0], width)
        crop_height = min(dim[1], height)
        mid_x, mid_y = int(width / 2), int(height / 2)
        cw2, ch2 = int(crop_width / 2), int(crop_height / 2)
        crop_img = img[mid_y - ch2: mid_y + ch2, mid_x - cw2: mid_x + cw2]
        return crop_img

    def _find_axial(self, axial_array, sag_coords, threshold):
        """Find the axial slice whose z position is closest to a sagittal point.

        Args:
            axial_array: 2-D array with rows ``(instance_number, z_position)``.
            sag_coords: Patient-space coordinate; only index 2 (z) is used.
            threshold: Maximum accepted |z distance|, or ``None`` to accept
                any distance.

        Returns:
            The instance number (int) of the closest slice, or ``None`` when
            there are no slices or the closest one is farther than
            ``threshold``.
        """
        if len(axial_array) == 0:
            return None

        # axial_array (instance, IPP-z)
        slice_position = axial_array[:, 1] # axial z-axis
        distance = np.abs(slice_position - sag_coords[2])
        slice_pos = np.argmin(distance)

        # Compare the distance to the threshold, not the raw slice position.
        if threshold is not None and distance[slice_pos] > threshold:
            return None

        return int(axial_array[slice_pos][0])

In [5]:
# sample = train_df.sample().item(0, 1)

In [6]:
# df = data_preprocess(CONF, train_df.filter(pl.col('study_id') == sample))

In [7]:
# conf = CONF()
# dataset = LSDCTrainDataset(conf, df, df['study_id_level'].unique(maintain_order=True), patient_coords, axial_IPP, get_transforms(CONF, types='valid'))

In [8]:
# data = dataset[0]

In [9]:
# print(data['study_id_level'])
# print(data['labels'].view(5, 3))
# fig = plt.figure(figsize=(20, 20))
# for i, img in enumerate(data['images']):
#     num = i + 1
#     ax = fig.add_subplot(5, 5, num, frameon=False)
#     ax.title.set_text(num)
#     plt.imshow(img, cmap='gray')
#     plt.axis('off')
# plt.show()

# Transforms

In [10]:
def get_transforms(conf, types):
    """Return the albumentations pipeline registered under ``types``.

    Args:
        conf: Configuration object; ``conf.image_size`` sizes the dropout holes.
        types: One of ``'train'``, ``'valid'`` or ``'check'``.

    Returns:
        The matching ``A.Compose`` pipeline. Every pipeline ends with
        ``ToTensorV2``.

    Raises:
        KeyError: If ``types`` is not a registered pipeline name.
    """
    # VerticalFlip and ShiftScaleRotate were tried here and are currently disabled.
    train_pipeline = A.Compose([
        A.RandomBrightnessContrast(p=0.7),
        A.OneOf([
            A.MotionBlur(blur_limit=3),
            A.MedianBlur(blur_limit=3),
            A.GaussianBlur(blur_limit=3),
            A.GaussNoise(var_limit=(3.0, 9.0)),
        ], p=0.5),
        A.OneOf([
            A.OpticalDistortion(distort_limit=1.),
            A.GridDistortion(num_steps=5, distort_limit=1.),
        ], p=0.5),
        A.CoarseDropout(
            max_height=int(conf.image_size * 0.1),
            max_width=int(conf.image_size * 0.1),
            max_holes=4,
            fill_value=0.,
            p=0.5,
        ),
        ToTensorV2(),
    ])

    valid_pipeline = A.Compose([
        ToTensorV2(),
    ])

    # Mild, always-on distortion used for visual sanity checks.
    check_pipeline = A.Compose([
        A.GridDistortion(num_steps=5, distort_limit=0.05, always_apply=True),
        ToTensorV2(),
    ])

    pipelines = {
        'train': train_pipeline,
        'valid': valid_pipeline,
        'check': check_pipeline,
    }
    return pipelines[types]

# Model

In [11]:
class GeM(nn.Module):
    """Generalized-mean pooling over the last two (spatial) dimensions.

    Computes ``mean(clamp(x, eps) ** p) ** (1 / p)`` across the full spatial
    extent. The exponent ``p`` is a learnable parameter.

    Args:
        p: Initial value of the pooling exponent.
        eps: Lower clamp applied to the input before raising it to ``p``.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` with the learned exponent."""
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """Generalized-mean pool ``x`` over its full spatial window."""
        spatial_window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=eps).pow(p)
        return F.avg_pool2d(powered, spatial_window).pow(1. / p)

    def __repr__(self):
        current_p = self.p.data.tolist()[0]
        return f'{self.__class__.__name__}(p={current_p:.4f}, eps={self.eps})'

In [12]:
# https://www.kaggle.com/bminixhofer/a-validation-framework-impact-of-the-random-seed
class Attention(nn.Module):
    """Additive attention that pools a sequence into a single feature vector.

    Each step of the input is scored with a learned projection (plus an
    optional per-step bias), squashed with ``tanh`` and exponentiated; the
    normalised scores weight the sum over the step axis.

    Args:
        feature_dim: Size of the last (feature) axis of the input.
        step_dim: Length of the step axis of the input.
        bias: Whether to learn one additive bias per step.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
        super(Attention, self).__init__(**kwargs)

        self.supports_masking = True

        self.bias = bias
        self.feature_dim = feature_dim
        self.step_dim = step_dim
        self.features_dim = 0

        weight = torch.zeros(feature_dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight)

        if bias:
            self.b = nn.Parameter(torch.zeros(step_dim))

    def forward(self, x, mask=None):
        """Pool ``x`` of shape (batch, step_dim, feature_dim) over the step axis.

        Args:
            x: Input sequence.
            mask: Optional (batch, step_dim) multiplier applied to the
                unnormalised attention scores; zeros remove a step.

        Returns:
            Tensor of shape (batch, feature_dim).
        """
        eij = torch.mm(
            x.contiguous().view(-1, self.feature_dim),
            self.weight
        ).view(-1, self.step_dim)

        if self.bias:
            eij = eij + self.b

        a = torch.exp(torch.tanh(eij))

        if mask is not None:
            a = a * mask

        # The epsilon belongs inside the denominator: it guards the division
        # for fully masked rows (0 / 0 -> NaN) instead of merely offsetting the
        # already-normalised weights.
        a = a / (torch.sum(a, 1, keepdim=True) + 1e-10)

        weighted_input = x * torch.unsqueeze(a, -1)
        return torch.sum(weighted_input, 1)

In [13]:
class LSDCModel(nn.Module):
    """Per-slice CNN backbone followed by one LSTM branch and head per condition.

    Every slice of the input stack is encoded independently by a timm
    backbone and spatially pooled. The resulting sequence of slice features is
    fed to five parallel bidirectional LSTMs (one per condition suffix in
    ``_CONDITIONS``), reduced according to ``conf.head_type`` and classified
    by a small MLP head. The five head outputs are concatenated.

    Args:
        conf: Configuration object; reads ``head_type`` (``'lstm_attn'``,
            ``'lstm_mean_max'`` or ``'avg'``), ``backbone``, ``in_chans``,
            ``image_size``, ``pooling``, ``n_slices`` and ``num_class``.
        pretrained: Whether timm should load pretrained backbone weights.

    Raises:
        NotImplementedError: For an unsupported backbone family or pooling name.
    """

    # Attribute suffixes of the five condition branches, in output order.
    _CONDITIONS = ('lnfn', 'rnfn', 'scs', 'lss', 'rss')

    def __init__(self, conf, pretrained=False):
        super().__init__()
        self.head_type = conf.head_type
        if self.head_type in ('lstm_attn', 'avg'):
            self.head_feats = 512
        elif self.head_type == 'lstm_mean_max':
            # mean and max over the sequence are concatenated
            self.head_feats = 512 * 2

        backbone_kwargs = dict(
            pretrained=pretrained,
            features_only=False,
            in_chans=conf.in_chans,
        )
        # These backbone families are additionally given the input resolution.
        if 'vit' in conf.backbone or 'coat' in conf.backbone:
            backbone_kwargs['img_size'] = conf.image_size
        self.backbone = timm.create_model(conf.backbone, **backbone_kwargs)

        # Strip the classifier so the backbone returns features only.
        if 'efficientnet' in conf.backbone:
            in_features = self.backbone.classifier.in_features
            self.backbone.global_pool = nn.Identity()
            self.backbone.classifier = nn.Identity()
        elif any(family in conf.backbone for family in ('convnext', 'maxvit', 'coat')):
            in_features = self.backbone.head.fc.in_features
            self.backbone.head = nn.Identity()
        else:
            raise NotImplementedError

        self.pooler = self._get_pooling(conf.pooling)

        # Modules are registered in the order lstm_* -> attn_* -> head_*.
        for cond in self._CONDITIONS:
            setattr(self, f'lstm_{cond}', nn.LSTM(
                in_features, 256, num_layers=1, dropout=0., bidirectional=True, batch_first=True
            ))

        if 'attn' in self.head_type:
            for cond in self._CONDITIONS:
                setattr(self, f'attn_{cond}', Attention(512, conf.n_slices))

        for cond in self._CONDITIONS:
            setattr(self, f'head_{cond}', nn.Sequential(
                nn.Linear(self.head_feats, 256),
                nn.BatchNorm1d(256),
                nn.Dropout(0.1),
                nn.LeakyReLU(0.1),
                nn.Linear(256, conf.num_class // 5)
            ))

    def _get_pooling(self, pooling_name):
        """Return the spatial pooling layer named ``pooling_name``."""
        if pooling_name == 'avg':
            return nn.AdaptiveAvgPool2d((1))
        if pooling_name == 'max':
            return nn.AdaptiveMaxPool2d((1))
        if pooling_name == 'gem':
            return GeM()
        raise NotImplementedError

    def forward(self, inputs):
        """Classify a batch of slice stacks.

        Args:
            inputs: Tensor unpacked as (batch, slices, height, width).

        Returns:
            Tensor of shape (batch, 5 * (num_class // 5)): the five head
            outputs concatenated in ``_CONDITIONS`` order.
        """
        bs, slices, img_h, img_w = inputs.shape

        # Each slice is encoded as an independent single-channel image.
        per_slice = inputs.view(bs * slices, 1, img_h, img_w)
        features = self.pooler(self.backbone(per_slice)).view(bs, slices, -1)

        logits = []
        for cond in self._CONDITIONS:
            sequence, _ = getattr(self, f'lstm_{cond}')(features)

            if self.head_type == 'lstm_attn':
                sequence = getattr(self, f'attn_{cond}')(sequence)
            elif self.head_type == 'lstm_mean_max':
                sequence = torch.concat([sequence.mean(dim=1), sequence.amax(dim=1)], dim=-1)
            elif self.head_type == 'avg':
                # Classify every step on its own; predictions are averaged below.
                sequence = sequence.contiguous().view(bs * slices, -1)

            prediction = getattr(self, f'head_{cond}')(sequence)

            if self.head_type == 'avg':
                prediction = prediction.view(bs, slices, -1).mean(dim=1)

            logits.append(prediction)

        return torch.concat(logits, dim=-1)

In [14]:
# conf = CONF()
# model = LSDCModel(conf)

In [15]:
# img = torch.rand(2, 25, 128, 128, dtype=torch.float)
# outputs = model(img)

In [16]:
# outputs.shape

# Official Metric Calculator

In [17]:
class ParticipantVisibleError(Exception):
    """Raised by the metric's input checks when a submission or solution is invalid."""


def get_condition(full_location: str) -> str:
    """Extract the condition keyword contained in a row identifier.

    Given an input like ``spinal_canal_stenosis_l1_l2`` this returns
    ``'spinal'``. Keywords are tried in the order spinal, foraminal,
    subarticular and the first one found wins.

    Raises:
        ValueError: If none of the keywords occurs in ``full_location``.
    """
    known_conditions = ('spinal', 'foraminal', 'subarticular')
    matched = next((cond for cond in known_conditions if cond in full_location), None)
    if matched is None:
        raise ValueError(f'condition not found in {full_location}')
    return matched


# def score(
def calculate_score(
        solution: pd.DataFrame,
        submission: pd.DataFrame,
        evaluate_condition_list: list = ['spinal', 'foraminal', 'subarticular'],
        row_id_column_name: str = 'row_id', ### modified: default
        any_severe_scalar: float = 1.0, ### modified: added default
    ) -> tuple:
    '''
    Local variant of the competition metric (originally named ``score``).

    Pseudocode:
    1. Calculate the sample weighted log loss for each medical condition:
    2. Derive a new any_severe label.
    3. Calculate the sample weighted log loss for the new any_severe label.
    4. Return the average of all of the label group log losses as the final score, normalized for the number of columns in each group.
       This mitigates the impact of spinal stenosis having only half as many columns as the other two conditions.

    Args:
        solution: Ground truth with the row id column, the three target level
            columns and ``sample_weight``.
        submission: Predictions with the row id column and the three target
            level columns. Rows are matched to ``solution`` by index.
        evaluate_condition_list: Conditions whose log loss enters the average.
            The list is only read, never modified.
        row_id_column_name: Name of the row id column in both frames.
        any_severe_scalar: Weight of the any-severe-spinal loss in the average.

    Returns:
        ``(weighted average of all losses, any-severe-spinal loss)``.

    Raises:
        ParticipantVisibleError: If a QC check on the inputs fails.
    '''

    # Work on copies: columns are added and removed below, and the caller's
    # frames must stay reusable (e.g. scoring the same solution every epoch).
    solution = solution.copy()
    submission = submission.copy()

    target_levels = ['normal_mild', 'moderate', 'severe']

    # Run basic QC checks on the inputs
    if not pandas.api.types.is_numeric_dtype(submission[target_levels].values):
        raise ParticipantVisibleError('All submission values must be numeric')

    if not np.isfinite(submission[target_levels].values).all():
        raise ParticipantVisibleError('All submission values must be finite')

    if solution[target_levels].min().min() < 0:
        raise ParticipantVisibleError('All labels must be at least zero')
    if submission[target_levels].min().min() < 0:
        raise ParticipantVisibleError('All predictions must be at least zero')

    # Row ids look like <study_id>_<condition ...>_<level>; read them from the
    # configured column rather than a hard-coded 'row_id'.
    row_ids = solution[row_id_column_name]
    solution['study_id'] = row_ids.apply(lambda x: x.split('_')[0])
    solution['location'] = row_ids.apply(lambda x: '_'.join(x.split('_')[1:]))
    solution['condition'] = row_ids.apply(get_condition)

    del solution[row_id_column_name]
    del submission[row_id_column_name]
    assert sorted(submission.columns) == sorted(target_levels)

    submission['study_id'] = solution['study_id']
    submission['location'] = solution['location']
    submission['condition'] = solution['condition']

    condition_losses = []
    condition_weights = []

#     for condition in ['spinal', 'foraminal', 'subarticular']: ## modified
    for condition in evaluate_condition_list:
        condition_indices = solution.loc[solution['condition'] == condition].index.values
        condition_loss = sklearn.metrics.log_loss(
            y_true=solution.loc[condition_indices, target_levels].values,
            y_pred=submission.loc[condition_indices, target_levels].values,
            sample_weight=solution.loc[condition_indices, 'sample_weight'].values
        )
        condition_losses.append(condition_loss)
        condition_weights.append(1)

    # Per study: is any spinal level severe, and the largest predicted severe probability.
    spinal_solution = solution.loc[solution['condition'] == 'spinal']
    spinal_submission = submission.loc[submission['condition'] == 'spinal']
    any_severe_spinal_labels = pd.Series(spinal_solution.groupby('study_id')['severe'].max())
    any_severe_spinal_weights = pd.Series(spinal_solution.groupby('study_id')['sample_weight'].max())
    any_severe_spinal_predictions = pd.Series(spinal_submission.groupby('study_id')['severe'].max())

    any_severe_spinal_loss = sklearn.metrics.log_loss(
        y_true=any_severe_spinal_labels,
        y_pred=any_severe_spinal_predictions,
        sample_weight=any_severe_spinal_weights,
        labels=[0., 1.]  ### modified: make it also run in debug run
    )

    condition_losses.append(any_severe_spinal_loss)
    condition_weights.append(any_severe_scalar)
    return np.average(condition_losses, weights=condition_weights), any_severe_spinal_loss

# Utils

In [18]:
# https://www.kaggle.com/code/yasufuminakama/fb3-deberta-v3-base-baseline-train/notebook
class Averager:
    """Tracks the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, sample count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def get_average(self):
        """Return the running average."""
        return self.avg

    def get_value(self):
        """Return the most recently recorded value."""
        return self.val

In [19]:
class TimerError(Exception):
    """Signals incorrect use of the Timer class (e.g. stopping a timer that is not running)."""

class Timer:
    """Stopwatch based on ``time.perf_counter`` with support for split times."""

    def __init__(self):
        self.split_time = []
        self._start_time = None

    def _require_running(self):
        """Raise TimerError unless the timer has been started."""
        if self._start_time is None:
            raise TimerError(f"Timer is not running. Use .start() to start it")

    def start(self):
        """Start the timer; raises TimerError if it is already running."""
        if self._start_time is not None:
            raise TimerError(f"Timer is running. Use .stop() to stop it")

        self._start_time = time.perf_counter()

    def stop(self):
        """Stop the timer. Recorded split times are kept; nothing is returned."""
        self._require_running()
        self._start_time = None

    def get_time(self):
        """Return the seconds elapsed since ``start()``."""
        self._require_running()
        return time.perf_counter() - self._start_time

    def split(self):
        """Append the seconds elapsed since ``start()`` to ``split_time``."""
        self._require_running()
        elapsed = time.perf_counter() - self._start_time
        self.split_time.append(elapsed)

    def get_split_time(self, idx):
        """Return the recorded split at position ``idx``."""
        return self.split_time[idx]

    @staticmethod
    def formatting(second):
        """Format a duration in seconds as ``H:MM:SS``, rounded to whole seconds."""
        whole_seconds = round(second)
        return str(datetime.timedelta(seconds=whole_seconds))

## Data preprocess


In [20]:
def data_preprocess(conf, df_):
    """Attach a series path and a list of selected slice numbers to every row.

    For each series directory (``conf.data_path/<study_id>/<series_id>``) the
    numeric stems of its ``*.npy`` files are collected and sorted. Axial T2
    series keep all of them. For every other series an 11-element pack is
    built: three slices around each of two lateral positions (the slices
    closest to +/- ``conf.spine_side_dist_mm`` from the middle slice) and five
    slices around the middle.

    Relies on the module-level ``sagittal_IPP`` frame (columns used:
    ``series_id``, ``instance_number``, ``SliceLocation``,
    ``ImagePositionPatient_x``).

    Args:
        conf: Configuration object; reads ``data_path`` and
            ``spine_side_dist_mm``.
        df_: polars frame with ``study_id``, ``series_id`` and
            ``series_description`` columns. It is cloned, not modified.

    Returns:
        The cloned frame with ``series_path`` added, joined on ``series_id``
        with a new ``slices`` column.
    """
    print('data preprocessing..')
    df = df_.clone()

    # series_path = <data_path>/<study_id>/<series_id>
    df = df.with_columns(pl.concat_str([
        pl.lit(conf.data_path),
        pl.col('study_id'),
        pl.col('series_id'),
    ], separator='/').alias('series_path'))

    # Column-wise accumulator: one entry per series.
    slice_samples_stack = {
        'series_id': [],
        'slices': [],
    }

    for name, data in df.group_by(['series_path'], maintain_order=True):
        plane = data['series_description'][0]
        series_path = Path(name[0])
        series_id = series_path.name

        # Slices are stored as <number>.npy (an earlier version globbed *.dcm).
        # NOTE(review): the stems are presumably DICOM instance numbers - confirm.
        dcm_array = np.sort([int(path.stem) for path in series_path.glob('*.npy')])
        if plane != 'Axial T2':

            series_info = sagittal_IPP.filter(pl.col('series_id') == int(series_id))
            # mid_idx is the *value* stored at the middle position of dcm_array.
            mid_idx = dcm_array[len(dcm_array) // 2] # should be fine in most case but may be inaccurate, check later

            # Position of the slice with the smallest instance number ...
            first = series_info.filter(pl.col('instance_number') == series_info['instance_number'].min())
            first_IPP_x = first["ImagePositionPatient_x"].item(0)
            first_sl = first["SliceLocation"].item(0)

            # ... and of the slice with the largest one.
            last = series_info.filter(pl.col('instance_number') == series_info['instance_number'].max())
            last_IPP_x = last["ImagePositionPatient_x"].item(0)
            last_sl = last["SliceLocation"].item(0)

            mid_info = series_info.filter(pl.col('instance_number') == mid_idx)

            # Signed SliceLocation offset of every slice relative to the middle slice.
            neutral = series_info['SliceLocation'].to_numpy() - mid_info['SliceLocation'].to_numpy()
            # Row positions (+1) of the slices closest to +/- spine_side_dist_mm.
            side1_idx = np.argmin(np.abs(neutral - conf.spine_side_dist_mm)) + 1
            side2_idx = np.argmin(np.abs(neutral + conf.spine_side_dist_mm)) + 1

            # NOTE(review): side*_idx (row position + 1) and mid_idx (a file
            # number) are used directly as positions into dcm_array. This
            # presumably assumes contiguous, 1-based numbering with
            # series_info rows in the same order - TODO confirm. A start below
            # zero would wrap around and may yield a short or empty pack.
            side1_pack = dcm_array[side1_idx - 2: side1_idx + 1]
            mid_pack = dcm_array[mid_idx - 3: mid_idx + 2]
            side2_pack = dcm_array[side2_idx - 2: side2_idx + 1]

            # True when ImagePositionPatient_x and SliceLocation change in the
            # same direction between the first and last instance.
            cond1 = (first_IPP_x > last_IPP_x) and (first_sl > last_sl)
            cond2 = (first_IPP_x < last_IPP_x) and (first_sl < last_sl)

            if cond1 or cond2:
                pack = np.concatenate([side1_pack, mid_pack, side2_pack]) # this pack: less is right, greater is left
            else:
                pack = np.concatenate([side2_pack, mid_pack, side1_pack]) # this pack: lesser is left, greater is right
        else:
            pack = dcm_array # only axial t2 take all file name to process later in dataset

        slice_samples_stack['series_id'].append(series_id)
        slice_samples_stack['slices'].append(pack)

    # series_id was taken from a path (str); cast it back to Int64 for the join.
    slice_samples_stack = pl.DataFrame(slice_samples_stack).with_columns(pl.col('series_id').cast(pl.Int64))

    return df.join(slice_samples_stack, on='series_id')

In [21]:
class FocalLoss(nn.modules.loss._WeightedLoss):
    """Multi-class focal loss: ``(1 - p_t) ** gamma * CE``, computed per sample.

    Adapted from
    https://github.com/gokulprasadthekkel/pytorch-multi-class-focal-loss/tree/master

    Args:
        weight: Optional per-class weights; acts as the alpha term that
            balances classes.
        gamma: Focusing exponent; 0 reduces the loss to cross-entropy.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``.
    """

    def __init__(self, weight=None, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma
        self.weight = weight #weight parameter will act as the alpha parameter to balance class weights

    def forward(self, input, target):
        """Compute the focal loss of logits ``input`` against ``target``.

        ``target`` may be anything ``F.cross_entropy`` accepts (class indices
        or class probabilities).
        """
        # The focal factor must be derived per sample. Reducing the
        # cross-entropy first would modulate the batch mean with a single
        # factor, so easy and hard samples would no longer be weighted apart.
        ce_loss = F.cross_entropy(input, target, reduction='none', weight=self.weight)

        # p_t is the model's confidence in the target, so it is taken from the
        # unweighted cross-entropy; class weights only scale the loss term.
        if self.weight is None:
            pt = torch.exp(-ce_loss)
        else:
            pt = torch.exp(-F.cross_entropy(input, target, reduction='none'))

        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

In [22]:
def read_dcm(path):
    """Read the DICOM file at ``path`` and return its pixel array."""
    return pydicom.dcmread(path).pixel_array

def extract_config(conf_):
    """Dump the public attributes of ``conf_`` to ``<conf_.save_path>/config.yaml``.

    Attributes are taken from ``vars(conf_)``; names starting with an
    underscore are skipped.
    """
    config_dict = {
        key: value
        for key, value in vars(conf_).items()
        if not key.startswith('_')
    }

    config_file = Path(conf_.save_path, 'config.yaml')
    with open(config_file, 'w+') as file:
        yaml.dump(config_dict, file)

    print('Extracted config')

def get_dataloader(conf, df, train_dataset, valid_dataset):
    """Build the training and validation DataLoaders.

    Uses the module-level ``patient_coords`` and ``axial_IPP`` frames.

    Args:
        conf: Configuration object; reads ``batch_size`` and ``num_workers``
            (and whatever the dataset and transforms read).
        df: Preprocessed frame shared by both datasets.
        train_dataset: Sample keys (``study_id_level``) of the training split.
        valid_dataset: Sample keys (``study_id_level``) of the validation split.

    Returns:
        ``(train_dataloader, valid_dataloader)``. The training loader shuffles
        and drops the last incomplete batch; the validation loader does neither.
    """
    train_set = LSDCTrainDataset(
        conf, df, train_dataset, patient_coords, axial_IPP, get_transforms(conf, types='train')
    )
    valid_set = LSDCTrainDataset(
        conf, df, valid_dataset, patient_coords, axial_IPP, get_transforms(conf, types='valid')
    )

    # Options shared by both loaders.
    loader_options = dict(
        batch_size=conf.batch_size,
        num_workers=conf.num_workers,
        pin_memory=True,
    )

    train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **loader_options)
    valid_loader = DataLoader(valid_set, shuffle=False, drop_last=False, **loader_options)
    return train_loader, valid_loader

def get_model(conf):
    """Create an ``LSDCModel`` for ``conf`` with a pretrained backbone."""
    return LSDCModel(conf, pretrained=True)

def get_criterion(conf_criterion, class_weight=None):
    """Return the loss module registered under ``conf_criterion``.

    Args:
        conf_criterion: ``'ce'``, ``'bce'`` or ``'focal'``.
        class_weight: Optional weight tensor passed to the loss as ``weight``.

    Raises:
        KeyError: If ``conf_criterion`` is not a registered name.
    """
    # All candidates are instantiated up front, then one is picked by name.
    available = dict(
        ce=nn.CrossEntropyLoss(weight=class_weight),
        bce=nn.BCEWithLogitsLoss(weight=class_weight),
        focal=FocalLoss(weight=class_weight),
    )
    return available[conf_criterion]

def get_scheduler(conf, samples_per_epoch):
    """Look up the scheduler factory and its keyword arguments for ``conf.scheduler``.

    Args:
        conf: Configuration object; reads ``scheduler`` (``'cosine_warmup'``,
            ``'linear_warmup'`` or ``'constant_warmup'``), ``num_epochs`` and
            ``warmup_ratios``.
        samples_per_epoch: Multiplied by ``conf.num_epochs`` to obtain the
            total number of training steps.

    Returns:
        ``(factory, hparams)``: a ``transformers`` schedule factory and the
        keyword arguments (besides the optimizer) to call it with.

    Raises:
        KeyError: If ``conf.scheduler`` is not a registered name.
    """
    total_steps = samples_per_epoch * conf.num_epochs
    warmup_steps = int(samples_per_epoch * conf.num_epochs * conf.warmup_ratios)

    registry = {
        'cosine_warmup': (
            transformers.get_cosine_schedule_with_warmup,
            {'num_warmup_steps': warmup_steps, 'num_training_steps': total_steps},
        ),
        'linear_warmup': (
            transformers.get_linear_schedule_with_warmup,
            {'num_warmup_steps': warmup_steps, 'num_training_steps': total_steps},
        ),
        # The constant schedule takes no total step count.
        'constant_warmup': (
            transformers.get_constant_schedule_with_warmup,
            {'num_warmup_steps': warmup_steps},
        ),
    }

    scheduler_factory, hparams = registry[conf.scheduler]
    return scheduler_factory, hparams

def get_optimizer(conf):
    """Return the optimizer class (not an instance) named by ``conf.optimizer``.

    Raises:
        KeyError: If ``conf.optimizer`` is not a supported name.
    """
    supported = {'adamw': optim.AdamW}
    return supported[conf.optimizer]

def formatting_predictions(oof_: pd.DataFrame):
    """Reshape wide per-level predictions into one row per ``row_id``.

    Args:
        oof_: pandas frame with a ``study_id_level`` column plus one column per
            ``<condition>_<severity>`` prediction, where severity is one of
            ``normal_mild``, ``moderate`` or ``severe``.

    Returns:
        pandas frame sorted by ``row_id`` with columns ``row_id``, ``normal_mild``,
        ``moderate``, ``severe`` and ``study_id``; ``row_id`` is
        ``<study_id>_<condition>_<level>``.
    """
    severity_names = ['normal_mild', 'moderate', 'severe']

    # Wide -> long, then split the key: the last five characters are the level
    # and everything before the final six characters is the study id.
    long_df = (
        pl.from_pandas(oof_.copy())
        .melt(id_vars='study_id_level')
        .with_columns(
            pl.col('study_id_level').str.head(-6).alias('study_id'),
            pl.col('study_id_level').str.slice(-5).alias('level'),
        )
    )

    merged = pl.DataFrame()
    for severity in severity_names:
        suffix_len = len(severity) + 1  # length of the trailing '_<severity>'
        per_severity = (
            long_df
            .filter(pl.col('variable').str.contains(severity))
            .with_columns(pl.col('variable').str.head(-suffix_len).alias('condition'))
            .with_columns(
                pl.concat_str(
                    [pl.col('study_id'), pl.col('condition'), pl.col('level')],
                    separator='_',
                ).alias('row_id')
            )
            .rename({'value': severity})
            .select(['row_id', severity])  # keep only the key and this severity's column
        )
        # The first non-empty frame seeds the result; later ones are joined on row_id.
        merged = per_severity if merged.is_empty() else merged.join(per_severity, on='row_id')

    # for filtering: the first '_'-separated token of row_id is exposed as study_id
    merged = merged.with_columns(
        pl.col('row_id').str.split('_').list.first().alias('study_id')
    )

    return merged.sort('row_id').to_pandas()

def mixup(x, y, alpha=1.0):
    """Blend every sample with a randomly chosen partner from the same batch.

    Args:
        x: batch tensor; the first dimension is permuted to pick partners.
        y: targets indexed with the same permutation.
        alpha: Beta(alpha, alpha) parameter for the mixing weight; when it is not
            positive the weight is fixed to 1, which returns ``x`` unblended.

    Returns:
        ``(mixed_x, y_a, y_b, lam)``: the blended inputs, the original targets,
        the partners' targets and the weight given to the original sample.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x.size(0))
    blended = lam * x + (1 - lam) * x[perm]
    return blended, y, y[perm], lam

def cutmix(data, targets, alpha=1.0):
    """Paste a random rectangle from a shuffled copy of the batch into ``data``.

    Adapted from https://github.com/hysts/pytorch_cutmix/blob/master/cutmix.py

    Args:
        data: batch tensor whose last two dimensions are (height, width). It is
            modified IN PLACE and also returned.
        targets: targets indexed with the same permutation as ``data``.
        alpha: Beta(alpha, alpha) parameter controlling the box size. When it is
            not positive no box is pasted and ``lam`` is 1, mirroring ``mixup``.

    Returns:
        ``(data, targets, shuffled_targets, lam)`` where ``lam`` is the fraction
        of each image that still comes from the original sample.
    """
    idx = torch.randperm(data.size(0))
    image_h, image_w = data.size(-2), data.size(-1)
    shuffled_data = data[idx]  # advanced indexing copies, so the in-place paste below is safe
    shuffled_targets = targets[idx]

    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    # Box centre is uniform over the image; its side lengths scale with sqrt(1 - lam).
    cx = np.random.uniform(0, image_w)
    cy = np.random.uniform(0, image_h)
    w = image_w * np.sqrt(1 - lam)
    h = image_h * np.sqrt(1 - lam)
    x0 = int(np.round(max(cx - w / 2, 0)))
    x1 = int(np.round(min(cx + w / 2, image_w)))
    y0 = int(np.round(max(cy - h / 2, 0)))
    y1 = int(np.round(min(cy + h / 2, image_h)))

    # Ellipsis keeps this valid for any number of leading (batch/channel/slice) dims.
    data[..., y0:y1, x0:x1] = shuffled_data[..., y0:y1, x0:x1]

    # The box is clipped at the borders and rounded to whole pixels, so the
    # sampled lam over-states how much was replaced. Recompute it from the area
    # actually pasted so the loss weights match the pixels.
    total_area = image_h * image_w
    if total_area > 0:
        lam = 1 - ((x1 - x0) * (y1 - y0)) / total_area

    return data, targets, shuffled_targets, lam

# Trainer

In [23]:
class Trainer:
    """Train and validate one cross-validation fold, keeping the best checkpoints.

    Two "best" snapshots are tracked independently: the epoch with the lowest
    competition score and the epoch with the lowest validation loss. For each,
    the model weights are written to ``conf.save_path`` and the out-of-fold
    predictions (raw logits and per-group softmax probabilities) are kept on
    the instance.

    Relies on names defined elsewhere in this notebook: ``Averager``, ``Timer``,
    ``mixup``, ``cutmix``, ``formatting_predictions``, ``calculate_score`` and the
    module-level ``solution_df``.
    """

    def __init__(
        self, debug_run, fold, conf, device, model, optimizer, scheduler, scheduler_hparams, criterion
    ):
        """Store the run settings and build the optimizer and scheduler.

        Args:
            debug_run: when true, every loop stops after one iteration and nothing
                is appended to the log file.
            fold: fold index, used in log lines and checkpoint file names.
            conf: configuration object; the attributes read are listed below.
            device: torch device the model and batches are moved to.
            model: network to train.
            optimizer: optimizer CLASS; instantiated here with the lr/eps/betas from ``conf``.
            scheduler: scheduler factory, called as ``scheduler(optimizer, **scheduler_hparams)``.
            scheduler_hparams: keyword arguments for ``scheduler``.
            criterion: loss applied to each 3-column group of the output.
        """
        self.debug_run = debug_run
        self.current_fold = fold
        self.device = device
        self.model = model
        self.optimizer = optimizer(model.parameters(), lr=conf.lr, eps=conf.optim_eps, betas=(conf.optim_betas1, conf.optim_betas2))
        self.scheduler = scheduler(self.optimizer, **scheduler_hparams)
        self.criterion = criterion

        # `apex` toggles mixed precision: it enables both autocast and the grad scaler.
        self.apex = conf.apex
        self.scaler = GradScaler(enabled=self.apex)
        self.exp_num = conf.exp
        self.num_class = conf.num_class
        self.num_epochs = conf.num_epochs
        self.verbose_step = conf.verbose_step
        self.save_path = conf.save_path
        self.use_mix = conf.use_mix
        self.mix_type = conf.mix_type
        self.mix_p = conf.mix_p
        self.clip_grad_norm = conf.clip_grad_norm
        self.max_grad_norm = conf.max_grad_norm

        # Lower is better for both; 100 is a sentinel that the first epoch replaces.
        self.best_score = 100
        self.best_valid_loss = 100

        self.oof_cols = [ # following alphabet sorting
            'study_id_level',
            # sagittal_t1
            'left_neural_foraminal_narrowing_normal_mild',
            'left_neural_foraminal_narrowing_moderate',
            'left_neural_foraminal_narrowing_severe',

            'right_neural_foraminal_narrowing_normal_mild',
            'right_neural_foraminal_narrowing_moderate',
            'right_neural_foraminal_narrowing_severe',

            # sagittal_t2
            'spinal_canal_stenosis_normal_mild',
            'spinal_canal_stenosis_moderate',
            'spinal_canal_stenosis_severe',

            # axial
            'left_subarticular_stenosis_normal_mild',
            'left_subarticular_stenosis_moderate',
            'left_subarticular_stenosis_severe',

            'right_subarticular_stenosis_normal_mild',
            'right_subarticular_stenosis_moderate',
            'right_subarticular_stenosis_severe',
        ]

        # OOF predictions of the best-score epoch (probabilities / raw logits) ...
        self.oof_df = pd.DataFrame()
        self.oof_valid_loss_df = pd.DataFrame()

        # ... and of the best-valid-loss epoch.
        self.raw_oof_df = pd.DataFrame()
        self.raw_oof_valid_loss_df = pd.DataFrame()

        self.record_cols = ['fold', 'epoch', 'train_loss', 'valid_loss', 'score']
        self.record = pd.DataFrame(columns=self.record_cols)

    def fit(self, train_loader, valid_loader):
        """Run all epochs for this fold.

        Returns:
            ``(record, oof_df, oof_valid_loss_df, raw_oof_df, raw_oof_valid_loss_df)``:
            the per-epoch metrics table followed by the stored OOF frames
            (softmax / raw, for the best-score and best-valid-loss epochs).
        """
        self.model.to(self.device)

        self.log(f'exp: {self.exp_num}')
        self.log(f'--- FOLD {self.current_fold} ---')

        for epoch in range(self.num_epochs):
            self.current_epoch = epoch

            train_loss = self._train_fn(train_loader)

            valid_loss, ids_list, labels_list, outputs_list = self._valid_fn(valid_loader)

            this_epoch_score = self._eval_fn(ids_list, labels_list, outputs_list, valid_loss)

            # One row per epoch; going through np.array stores every column as a float.
            self.record = pd.concat([
                self.record,
                pd.DataFrame(dict(zip(self.record_cols, np.array([
                    self.current_fold, self.current_epoch, train_loss, valid_loss, this_epoch_score
                ]).reshape(-1, 1))))
            ], axis=0)

            self.log(f'-- [Fold: {self.current_fold}, Epoch: {self.current_epoch + 1}] DONE --\n')

            if self.debug_run: break

        return self.record, self.oof_df, self.oof_valid_loss_df, self.raw_oof_df, self.raw_oof_valid_loss_df

    def _mean_group_loss(self, outputs, labels):
        """Average ``self.criterion`` over consecutive 3-column groups of ``outputs``/``labels``."""
        loss = 0
        for c in range(0, self.num_class, 3):
            loss += self.criterion(outputs[:, c: c + 3], labels[:, c: c + 3])
        loss /= (self.num_class // 3)
        return loss

    def _save_checkpoint(self, file_name):
        """Write the model weights plus exp/fold/epoch metadata to ``save_path/file_name``."""
        self.model.eval()
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'exp': self.exp_num,
            'fold': self.current_fold,
            'epoch': self.current_epoch,
        }, Path(self.save_path, file_name))

    def _train_fn(self, train_loader):
        """Train for one epoch and return the sample-weighted mean training loss."""
        self.log('TRAINL_LOOP')
        self.model.train()
        total_loss = Averager()
        timer = Timer()
        timer.start()

        for step, batch in enumerate(train_loader):

            inputs = batch['images'].to(self.device, dtype=torch.float)
            labels = batch['labels'].to(self.device, dtype=torch.float)
            batchsize = labels.shape[0]

            # `mix_p` is the probability of mixing a batch. The original test used
            # `>`, which applied mixing with probability 1 - mix_p instead.
            is_mix_applied = False
            if self.use_mix and random.random() < self.mix_p:
                is_mix_applied = True
                if self.mix_type == 'mixup':
                    inputs, labels, labels_mixed, lam = mixup(inputs, labels)
                elif self.mix_type == 'cutmix':
                    inputs, labels, labels_mixed, lam = cutmix(inputs, labels)
                else:
                    raise NotImplementedError

            with autocast(enabled=self.apex):
                outputs = self.model(inputs)

                loss = self._mean_group_loss(outputs, labels)

                if is_mix_applied:
                    # Weight the two target sets by the mixing coefficient.
                    loss_mixed = self._mean_group_loss(outputs, labels_mixed)
                    loss = (loss * lam)  + (loss_mixed * (1 - lam))

            total_loss.update(loss.item(), batchsize)

            # Read the LR before stepping so the log shows the value used for this step.
            # get_last_lr() is the supported accessor; get_lr() is meant for internal use.
            current_lr = self.scheduler.get_last_lr()[0]

            self.scaler.scale(loss).backward()
            if self.clip_grad_norm:
                # Gradients must be unscaled before clipping when the scaler is active.
                self.scaler.unscale_(self.optimizer)
                nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # The schedule advances once per optimizer step, not once per epoch.
            self.scheduler.step()
            self.optimizer.zero_grad()

            if step % self.verbose_step == 0 or step == (len(train_loader) - 1):
                self.log(
                    f'[TRAIN_F{self.current_fold}], ' + \
                    f'E: {self.current_epoch + 1}/{self.num_epochs}, ' + \
                    f'S: {str(step).zfill(len(str(len(train_loader))))}/{len(train_loader)}, ' + \
                    f'L: {total_loss.get_average():.5f}, ' + \
                    f'LR: {current_lr:.8f}, ' + \
                    f'T: {Timer.formatting(timer.get_time())}'
                )

            if self.debug_run: break
            # end of train loop
        timer.stop()

        return total_loss.get_average()

    def _valid_fn(self, valid_loader):
        """Run one validation pass.

        Returns:
            ``(mean_loss, ids_list, labels, outputs)`` where ``labels`` and
            ``outputs`` are CPU tensors concatenated over all batches and
            ``outputs`` holds the raw model outputs.
        """
        self.log("\nVALID_LOOP")
        self.model.eval()

        ids_list = []
        outputs_list = []
        labels_list = []

        total_loss = Averager()
        timer = Timer()
        timer.start()

        for step, batch in enumerate(valid_loader):

            ids = batch['study_id_level']
            inputs = batch['images'].to(self.device, dtype=torch.float)
            labels = batch['labels'].to(self.device, dtype=torch.float)
            batchsize = labels.shape[0]

            with torch.no_grad():
                outputs = self.model(inputs)

                loss = self._mean_group_loss(outputs, labels)

            total_loss.update(loss.item(), batchsize)

            ids_list.extend(ids)
            # Move each batch to the CPU right away so the whole validation set's
            # predictions do not pile up in device memory.
            labels_list.append(labels.cpu())
            outputs_list.append(outputs.cpu())

            if step % self.verbose_step == 0 or step == (len(valid_loader) - 1):
                self.log(
                    f'[VALID_F{self.current_fold}], ' + \
                    f'E: {self.current_epoch + 1}/{self.num_epochs}, ' + \
                    f'S: {str(step).zfill(len(str(len(valid_loader))))}/{len(valid_loader)}, ' + \
                    f'L: {total_loss.get_average():.5f}, ' + \
                    f'T: {Timer.formatting(timer.get_time())}'
                )
            if self.debug_run: break
            # end of the valid loop

        labels_list = torch.concat(labels_list)
        outputs_list = torch.concat(outputs_list)

        return total_loss.get_average(), ids_list, labels_list, outputs_list

    def _eval_fn(self, ids_list, y_trues, y_preds, valid_loss):
        """Score this epoch's predictions and save new best checkpoints.

        Args:
            ids_list: ``study_id_level`` keys, one per prediction row.
            y_trues: validation labels (not used here; the score is computed
                against the module-level ``solution_df``).
            y_preds: raw model outputs, one column per entry of ``oof_cols[1:]``.
            valid_loss: mean validation loss of this epoch.

        Returns:
            The overall score from ``calculate_score`` (lower is treated as better).
        """
        this_raw_oof_df = pd.DataFrame(
                {self.oof_cols[0]: ids_list} | dict(zip(self.oof_cols[1:], y_preds.T))
            )

        # Softmax within each 3-column group, on a copy so the caller's logits stay intact.
        y_probs = y_preds.clone()
        for c in range(0, self.num_class, 3):
            y_probs[:, c: c + 3] = y_preds[:, c: c + 3].softmax(1)

        this_oof_df = pd.DataFrame(
                {self.oof_cols[0]: ids_list} | dict(zip(self.oof_cols[1:], y_probs.T))
            )
        this_oof_df_formatted = formatting_predictions(this_oof_df).drop('study_id', axis=1)
        # Restrict the ground truth to the rows that were actually predicted.
        this_solution_df = solution_df[solution_df['row_id'].isin(this_oof_df_formatted['row_id'])].reset_index(drop=True)

        check_condition_list = ['spinal', 'foraminal', 'subarticular']

        score, any_severe = calculate_score(
            this_solution_df.copy(),
            this_oof_df_formatted.copy(),
            check_condition_list,
        )

        self.log(f'\nscore: {score}')
        self.log(f'any_severe: {any_severe}')
        for indiv_cond in check_condition_list:
            indiv_score = calculate_score(
                this_solution_df.copy(),
                this_oof_df_formatted.copy(),
                [indiv_cond],
            )[0]
            self.log(f'{indiv_cond}: {indiv_score}')

        if self.best_score > score :
            self.best_score = score

            self.raw_oof_df = this_raw_oof_df
            self.oof_df = this_oof_df

            self._save_checkpoint(f'best_score_fold{self.current_fold}.pt')

            self.log(f'\n-> [SAVED] Fold: {self.current_fold}, Epoch: {self.current_epoch + 1}, score: {self.best_score}\n')

        if self.best_valid_loss > valid_loss:
            self.best_valid_loss = valid_loss

            self.raw_oof_valid_loss_df = this_raw_oof_df
            self.oof_valid_loss_df = this_oof_df

            self._save_checkpoint(f'best_loss_fold{self.current_fold}.pt')

            self.log(f'\n-> [SAVED] Fold: {self.current_fold}, Epoch: {self.current_epoch + 1}, valid_loss: {self.best_valid_loss}\n')

        return score

    def log(self, msg):
        """Print ``msg`` and, outside debug runs, append it to ``save_path/train.log``."""
        print(msg)
        if not self.debug_run:
            with open(Path(self.save_path, 'train.log'), mode='a+', encoding='utf-8') as log:
                log.write(f'{msg}\n')

# Config

In [24]:
class CONF:
    # Experiment-wide settings, read as class attributes (the class itself is
    # passed around as `conf`; it is never instantiated in the visible cells).

    # Experiment tag; logged at the start of training and stored in every checkpoint.
    exp = 'exp164'

    data_path = '/kaggle/input/rsna24-lsdc-npy/train_images'
    # Checkpoints, train.log and the OOF / CV CSV files are written here.
    save_path = '/kaggle/working/'
    seed = 42

    # Backbone in use; the commented names below are alternatives kept for reference.
    backbone = 'tf_efficientnetv2_s.in21k_ft_in1k'
    # 'tf_efficientnetv2_s.in21k_ft_in1k' *

    # 'convnext_tiny.in12k_ft_in1k'
    # 'convnextv2_tiny.fcmae_ft_in22k_in1k'
    # 'convnextv2_nano.fcmae_ft_in22k_in1k'
    # 'convnextv2_pico.fcmae_ft_in1k'
    
    # 'coatnet_1_rw_224.sw_in1k'
    # 'coatnet_0_rw_224.sw_in1k'
    # 'coatnet_nano_rw_224.sw_in1k'

    # 'maxvit_nano_rw_256.sw_in1k'
    # 'maxvit_tiny_tf_224.in1k'
    # 'maxvit_tiny_tf_384.in1k'
    # 'maxvit_tiny_tf_512.in1k'


    pooling = 'avg' # ['avg', 'max', 'gem']
    head_type = 'lstm_attn' #['lstm_attn', 'lstm_mean_max', 'avg']
    
    fold_num = 4
    # Folds that run_training iterates over (the debug branch swaps this for [0]).
    train_fold_list = [0, 1, 2, 3]

    # NOTE(review): the window / distance settings below are consumed by dataset
    # code outside this section; their exact meaning is not documented here.
    sagittal_window_ratio = 0.12
    
    spine_side_dist_mm = 17

#     axial_center_crop_size = 224
    axial_window_ratio = 0.35
    axial_min_dist_threshold = None

    # Number of output columns; Trainer computes the loss over consecutive groups of 3.
    num_class = 15
    in_chans = 1
    n_slices = 25 # (11 + 11) + 3
    image_size = 128
    
    # Batch mixing augmentation: Trainer._train_fn compares random.random() against
    # mix_p to decide, per batch, whether `mix_type` is applied.
    use_mix = False
    mix_type = 'mixup' # ['mixup', 'cutmix']
    mix_p = 0.5

    # Enables autocast and the GradScaler in Trainer (mixed-precision training).
    apex = True
    
    # When true, gradients are clipped to max_grad_norm before each optimizer step.
    clip_grad_norm  = False
    max_grad_norm = 1.0

    # Key into get_criterion: 'ce', 'bce' or 'focal'.
    criterion = 'ce'

    # Key into get_optimizer; lr / eps / betas are passed to its constructor.
    optimizer = 'adamw'
    lr = 4e-4
    optim_eps = 1e-6
    optim_betas1 = 0.9
    optim_betas2 = 0.999
    # Key into get_scheduler; warmup_ratios is the fraction of all training steps
    # (steps per epoch * num_epochs) spent warming up.
    scheduler = 'cosine_warmup'
    num_epochs = 10
    warmup_ratios = 0.2

    batch_size = 8
    num_workers = 0
    # Log every `verbose_step` steps (and on the last step of each loop).
    verbose_step = 200

# Load data

In [25]:
path = '/kaggle/input/rsna24-lsdc-create-dataset/'

# Annotation table with the per-level label columns joined on via `study_id_level`.
labels_df = pl.read_csv(path + 'study_id_level_labels.csv')
train_df = pl.read_csv(path + 'merged_train.csv').join(labels_df, on='study_id_level')

# Ground truth used for scoring; read with pandas because Trainer._eval_fn
# filters it with pandas indexing.
solution_df = pd.read_csv(path + 'solution_df.csv')

# Coordinate / image-position tables stored as parquet.
patient_coords = pl.read_parquet(path + 'patient_coords.parquet')
axial_IPP = pl.read_parquet(path + 'axial_img_pos.parquet')
sagittal_IPP = pl.read_parquet(path + 'sagittal_img_pos.parquet')

display(train_df)

row_id,study_id,series_id,instance_number,series_description,condition,labels,level,x,y,name,study_id_level,fold,left_neural_foraminal_narrowing_normal_mild,left_neural_foraminal_narrowing_moderate,left_neural_foraminal_narrowing_severe,right_neural_foraminal_narrowing_normal_mild,right_neural_foraminal_narrowing_moderate,right_neural_foraminal_narrowing_severe,spinal_canal_stenosis_normal_mild,spinal_canal_stenosis_moderate,spinal_canal_stenosis_severe,left_subarticular_stenosis_normal_mild,left_subarticular_stenosis_moderate,left_subarticular_stenosis_severe,right_subarticular_stenosis_normal_mild,right_subarticular_stenosis_moderate,right_subarticular_stenosis_severe
str,i64,i64,i64,str,str,i64,str,f64,f64,str,str,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64
"""100206310_left_neural_foramina…",100206310,2092806862,13,"""Sagittal T1""","""left_neural_foraminal_narrowin…",0,"""l1_l2""",270.34225,148.221459,"""100206310_2092806862_0013""","""100206310_l1_l2""",1,1,0,0,0,1,0,1,0,0,1,0,0,1,0,0
"""100206310_left_neural_foramina…",100206310,2092806862,12,"""Sagittal T1""","""left_neural_foraminal_narrowin…",1,"""l2_l3""",260.177602,191.705532,"""100206310_2092806862_0012""","""100206310_l2_l3""",1,0,1,0,0,1,0,1,0,0,1,0,0,1,0,0
"""100206310_left_neural_foramina…",100206310,2092806862,13,"""Sagittal T1""","""left_neural_foraminal_narrowin…",1,"""l3_l4""",250.176889,234.398551,"""100206310_2092806862_0013""","""100206310_l3_l4""",1,0,1,0,0,0,1,0,1,0,0,1,0,0,1,0
"""100206310_left_neural_foramina…",100206310,2092806862,12,"""Sagittal T1""","""left_neural_foraminal_narrowin…",2,"""l4_l5""",249.241774,274.786914,"""100206310_2092806862_0012""","""100206310_l4_l5""",1,0,0,1,0,1,0,0,0,1,0,0,1,0,1,0
"""100206310_left_neural_foramina…",100206310,2092806862,12,"""Sagittal T1""","""left_neural_foraminal_narrowin…",1,"""l5_s1""",258.80649,319.853318,"""100206310_2092806862_0012""","""100206310_l5_s1""",1,0,1,0,1,0,0,1,0,0,0,1,0,0,1,0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""992674144_spinal_canal_stenosi…",992674144,1576603050,9,"""Sagittal T2/STIR""","""spinal_canal_stenosis""",6,"""l1_l2""",190.682111,92.252252,"""992674144_1576603050_0009""","""992674144_l1_l2""",2,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0
"""992674144_spinal_canal_stenosi…",992674144,1576603050,9,"""Sagittal T2/STIR""","""spinal_canal_stenosis""",6,"""l2_l3""",182.033462,123.552124,"""992674144_1576603050_0009""","""992674144_l2_l3""",2,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0
"""992674144_spinal_canal_stenosi…",992674144,1576603050,9,"""Sagittal T2/STIR""","""spinal_canal_stenosis""",6,"""l3_l4""",175.855856,162.265122,"""992674144_1576603050_0009""","""992674144_l3_l4""",2,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0
"""992674144_spinal_canal_stenosi…",992674144,1576603050,9,"""Sagittal T2/STIR""","""spinal_canal_stenosis""",6,"""l4_l5""",175.032175,193.976834,"""992674144_1576603050_0009""","""992674144_l4_l5""",2,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0


# Training

In [26]:
def run_training(conf, df, debug_run=True):
    """Train the configured folds and write OOF predictions and CV summaries to disk.

    For every fold in ``conf.train_fold_list`` a fresh model, data loaders,
    optimizer, scheduler and criterion are built and handed to ``Trainer``.
    Afterwards the per-fold results are concatenated, displayed and saved as
    CSV files under ``conf.save_path``.

    Args:
        conf: configuration object (``CONF`` in this notebook).
        df: polars frame with at least ``study_id``, ``study_id_level`` and ``fold`` columns.
        debug_run: when true, only 8 sampled rows' studies are kept, the batch
            size is forced to 8, only fold 0 is trained and every loop stops
            after one iteration.

    Returns:
        None. Results are written to ``oof_df.csv``, ``oof_valid_loss_df.csv``,
        ``raw_oof_df.csv``, ``raw_oof_valid_loss_df.csv``, ``cv_record.csv`` and
        ``best_epoch_record.csv``.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    class_weight = None # for balancing imbalance dataset
    seed_everything(seed=conf.seed)

    def _stack(frames):
        # One concat per result instead of re-concatenating inside the fold loop.
        return pd.concat(frames, axis=0) if frames else pd.DataFrame()

    fold_records = []
    fold_oofs = []
    fold_oofs_valid_loss = []
    fold_raw_oofs = []
    fold_raw_oofs_valid_loss = []

    extract_config(conf)

    # The debug branch overrides these on `conf` (which other helpers read).
    # Remember the originals so a debug run does not silently shrink a later
    # full run in the same kernel to fold 0 only.
    original_batch_size = conf.batch_size
    original_fold_list = conf.train_fold_list

    try:
        if debug_run:
            print('DEBUG_RUN')
            debug_samples = df.sample(n=8, seed=0).select('study_id')
            df = data_preprocess(conf, df.filter(pl.col('study_id').is_in(debug_samples)))
            conf.batch_size = 8
            conf.train_fold_list = [0]
        else:
            df = data_preprocess(conf, df)

        # df = split_fold(conf, df) # already split

        trained_folds = list(conf.train_fold_list)

        for fold in trained_folds:
            # Re-seed per fold so each fold's run does not depend on the previous ones.
            seed_everything(conf.seed)
            model = get_model(conf)

            train_dataset = df.filter(pl.col('fold') != fold)['study_id_level'].unique(maintain_order=True)
            valid_dataset = df.filter(pl.col('fold') == fold)['study_id_level'].unique(maintain_order=True)

            train_loader, valid_loader = get_dataloader(conf, df, train_dataset, valid_dataset)
            optimizer = get_optimizer(conf)
            # The schedule is stepped per batch, so it is sized in batches per epoch.
            scheduler, scheduler_hparams = get_scheduler(conf, len(train_loader))
            criterion = get_criterion(conf.criterion, class_weight)

            trainer = Trainer(
                debug_run, fold, conf, device, model, optimizer, scheduler, scheduler_hparams, criterion,
            )
            fold_record, fold_oof, fold_oof_valid_loss, fold_raw_oof, fold_raw_oof_valid_loss = trainer.fit(train_loader, valid_loader)

            fold_records.append(fold_record)
            fold_oofs.append(fold_oof)
            fold_oofs_valid_loss.append(fold_oof_valid_loss)
            fold_raw_oofs.append(fold_raw_oof)
            fold_raw_oofs_valid_loss.append(fold_raw_oof_valid_loss)

            if debug_run: break
    finally:
        if debug_run:
            conf.batch_size = original_batch_size
            conf.train_fold_list = original_fold_list

    cv_record = _stack(fold_records).reset_index(drop=True)
    oof_df = _stack(fold_oofs)
    oof_valid_loss_df = _stack(fold_oofs_valid_loss)
    raw_oof_df = _stack(fold_raw_oofs)
    raw_oof_valid_loss_df = _stack(fold_raw_oofs_valid_loss)

    # Best epoch per fold = the row with the lowest score.
    best_epoch_idx = [cv_record[cv_record['fold'] == i]['score'].idxmin() for i in trained_folds]
    best_epoch_record  = cv_record[cv_record.index.isin(best_epoch_idx)].reset_index(drop=True)

    # Trainer stores fold/epoch as floats; cast them back for display and CSV output.
    cv_record[['fold', 'epoch']] = cv_record[['fold', 'epoch']].astype(int)
    best_epoch_record[['fold', 'epoch']] = best_epoch_record[['fold', 'epoch']].astype(int)

    display(cv_record)
    display(best_epoch_record)

    oof_df.to_csv(Path(conf.save_path, 'oof_df.csv'), index=False)
    oof_valid_loss_df.to_csv(Path(conf.save_path, 'oof_valid_loss_df.csv'), index=False)
    raw_oof_df.to_csv(Path(conf.save_path, 'raw_oof_df.csv'), index=False)
    raw_oof_valid_loss_df.to_csv(Path(conf.save_path, 'raw_oof_valid_loss_df.csv'), index=False)
    cv_record.to_csv(Path(conf.save_path, 'cv_record.csv'), index=False)
    best_epoch_record.to_csv(Path(conf.save_path, 'best_epoch_record.csv'), index=False)

In [27]:
# Full (non-debug) run: trains every fold listed in CONF.train_fold_list.
run_training(CONF, train_df, debug_run=False)

Extracted config
data preprocessing..


model.safetensors:   0%|          | 0.00/86.5M [00:00<?, ?B/s]

exp: exp164
--- FOLD 0 ---
TRAINL_LOOP
[TRAIN_F0], E: 1/10, S: 000/840, L: 1.14990, LR: 0.00000000, T: 0:00:04
[TRAIN_F0], E: 1/10, S: 200/840, L: 1.04115, LR: 0.00004762, T: 0:06:34
[TRAIN_F0], E: 1/10, S: 400/840, L: 0.87740, LR: 0.00009524, T: 0:10:50
[TRAIN_F0], E: 1/10, S: 600/840, L: 0.77560, LR: 0.00014286, T: 0:14:22
[TRAIN_F0], E: 1/10, S: 800/840, L: 0.71871, LR: 0.00019048, T: 0:17:54
[TRAIN_F0], E: 1/10, S: 839/840, L: 0.70780, LR: 0.00019976, T: 0:18:35

VALID_LOOP
[VALID_F0], E: 1/10, S: 000/279, L: 0.44782, T: 0:00:01
[VALID_F0], E: 1/10, S: 200/279, L: 0.41154, T: 0:02:58
[VALID_F0], E: 1/10, S: 278/279, L: 0.41790, T: 0:04:09

score: 0.5957185865725956
any_severe: 0.5506159415968659
spinal: 0.4826597098266733
foraminal: 0.6250869118275185
subarticular: 0.6343064930878652

-> [SAVED] Fold: 0, Epoch: 1, score: 0.5957185865725956


-> [SAVED] Fold: 0, Epoch: 1, valid_loss: 0.417896091897926

-- [Fold: 0, Epoch: 1] DONE --

TRAINL_LOOP
[TRAIN_F0], E: 2/10, S: 000/840, L: 0

Unnamed: 0,fold,epoch,train_loss,valid_loss,score
0,0,0,0.707798,0.417896,0.595719
1,0,1,0.497835,0.386945,0.469837
2,0,2,0.489835,0.380442,0.469765
3,0,3,0.459858,0.362411,0.459455
4,0,4,0.444063,0.373686,0.512106
5,0,5,0.433518,0.344301,0.430301
6,0,6,0.405309,0.346397,0.443741
7,0,7,0.385494,0.33622,0.417567
8,0,8,0.363452,0.334638,0.419807
9,0,9,0.352443,0.334051,0.414386


Unnamed: 0,fold,epoch,train_loss,valid_loss,score
0,0,9,0.352443,0.334051,0.414386
1,1,9,0.353751,0.348767,0.415112
2,2,8,0.356195,0.367278,0.471654
3,3,9,0.348863,0.360099,0.460119
