## load data

In [None]:
import os
import time

In [None]:
#if not os.path.isdir('/content/train_videos'):
#    !unzip -q '/content/drive/MyDrive/image224.zip' -d '/content/train_videos'

In [None]:
if not os.path.isdir('/content/box'):
    !unzip -q '/content/drive/MyDrive/output.zip' -d 'box'

In [None]:
if not os.path.isdir('/content/boxv2'):
    !unzip -q '/content/drive/MyDrive/outputv2.zip' -d 'boxv2'

In [None]:
if not os.path.isdir('/content/mask'):
    !unzip -q '/content/drive/MyDrive/mask.zip' -d 'mask'

## library

In [None]:
!pip install pydicom
!pip install nibabel
!pip install timm
!pip install transformers
!pip install -U albumentations



In [None]:
import pandas as pd
import numpy as np

import cv2
import zipfile
import os
import gc
import glob
import shutil
import time
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image

import pydicom #as dicom
import nibabel as nib

import timm

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torchvision.transforms import v2
from concurrent.futures import ThreadPoolExecutor

import albumentations as A

from transformers.optimization import get_cosine_schedule_with_warmup

import numpy as np
import pandas as pd

import math
import pandas.api.types
import sklearn.metrics
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score

## preprocess

In [None]:
from sklearn.model_selection import KFold

def preprocess(
    data_dir='/content/drive/MyDrive/Kaggle/RSNA 2023 Abdominal Trauma Detection/DataSources',
    n_splits=5,
):
    """Build the per-series training table and unshuffled K-fold splits.

    Reads ``train_series_meta.csv`` and ``train.csv`` from ``data_dir`` and
    attaches every series to its patient's label row, producing one row per
    series.  The columns are those of ``train.csv`` followed by ``series_id``.

    Args:
        data_dir: Directory holding the two CSV files.
        n_splits: Number of folds produced for each group.

    Returns:
        A tuple ``(train, injury_folds, normal_folds)``.  ``train`` is the
        per-series table ordered by ``patient_id``.  Each folds list has
        ``n_splits`` entries of the form ``[train_df, val_df]``, built
        separately from rows with ``any_injury == 1`` and ``any_injury == 0``.

    Raises:
        ValueError: If a series refers to a patient missing from ``train.csv``.
        pandas.errors.MergeError: If ``train.csv`` lists a patient twice.

    NOTE(review): folds are split over series rows, not over patients, so a
    patient with several series that straddles a fold boundary ends up in both
    the training and the validation part of that fold.
    """
    train_series_meta = pd.read_csv(f'{data_dir}/train_series_meta.csv')
    train_series_meta = train_series_meta.sort_values(by='patient_id').reset_index(drop=True)
    labels = pd.read_csv(f'{data_dir}/train.csv')
    labels = labels.sort_values(by='patient_id').reset_index(drop=True)

    # One vectorised join replaces the row-by-row lookup; selecting the key
    # columns by name keeps this working if the meta file gains extra columns.
    series = train_series_meta[['patient_id', 'series_id']]
    train = series.merge(labels, on='patient_id', how='inner', validate='many_to_one')
    if len(train) != len(series):
        raise ValueError('Some series refer to a patient_id that is missing from train.csv')

    # Downstream code selects label columns by position, so keep the original
    # layout: the train.csv columns first, series_id last.
    train = train.reindex(columns=list(labels.columns) + ['series_id'])

    injury_train = train[train['any_injury'] == 1].reset_index(drop=True)
    normal_train = train[train['any_injury'] == 0].reset_index(drop=True)

    def _make_folds(frame):
        # Unshuffled KFold yields contiguous index blocks of the sorted table.
        splitter = KFold(n_splits=n_splits)
        return [
            [frame.loc[train_index], frame.loc[test_index]]
            for train_index, test_index in splitter.split(frame)
        ]

    return train, _make_folds(injury_train), _make_folds(normal_train)

train, injury_folds, normal_folds = preprocess()

## utils

In [None]:
class ParticipantVisibleError(Exception):
    """Raised by the metric helpers when the solution or submission data fail validation."""


def normalize_probabilities_to_one(df: pd.DataFrame, group_columns: list) -> pd.DataFrame:
    """Rescale ``group_columns`` so that every row of the group sums to one.

    For example a row of (0.75, 0.75) becomes (0.5, 0.5), and so does (0.1, 0.1).
    The frame is modified in place and also returned.

    Raises:
        ParticipantVisibleError: If the smallest row total over the group is zero.
    """
    totals = df[group_columns].sum(axis=1)
    if totals.min() == 0:
        raise ParticipantVisibleError('All rows must contain at least one non-zero prediction')
    for name in group_columns:
        df[name] = df[name].div(totals)
    return df


def score(solution_:pd.DataFrame, submission_:pd.DataFrame, row_id_column_name:str, reduction:str ='mean'):
    '''
    Sample-weighted log loss averaged over the label groups.

    Steps:
    1. Work on copies of both frames with the row-id column removed.
    2. For each label group (bowel, extravasation, kidney, liver, spleen):
        - normalize the group's columns so each row sums to one,
        - compute the log loss weighted by the solution's `<group>_weight` column.
    3. Derive an any_injury value as the row-wise max of 1 - p(healthy) over the
       groups, for labels and predictions alike, and compute its log loss
       weighted by `any_injury_weight`.
    4. Return the mean of the six losses when reduction == 'mean', otherwise the
       list of losses in the order above with any_injury last.

    Raises ParticipantVisibleError when the submission is non-numeric, contains
    non-finite or negative values, lacks an expected column, or when the
    solution contains negative values.
    '''
    solution = solution_.copy()
    submission = submission_.copy()
    for frame in (solution, submission):
        del frame[row_id_column_name]

    # Run basic QC checks on the inputs
    if not pandas.api.types.is_numeric_dtype(submission.values):
        raise ParticipantVisibleError('All submission values must be numeric')
    if not np.isfinite(submission.values).all():
        raise ParticipantVisibleError('All submission values must be finite')
    if solution.min().min() < 0:
        raise ParticipantVisibleError('All labels must be at least zero')
    if submission.min().min() < 0:
        raise ParticipantVisibleError('All predictions must be at least zero')

    # Class suffixes per label group; binary groups first, then three-level ones.
    two_level = ('healthy', 'injury')
    three_level = ('healthy', 'low', 'high')
    group_suffixes = {
        'bowel': two_level,
        'extravasation': two_level,
        'kidney': three_level,
        'liver': three_level,
        'spleen': three_level,
    }

    label_group_losses = []
    for category, suffixes in group_suffixes.items():
        col_group = [f'{category}_{suffix}' for suffix in suffixes]

        solution = normalize_probabilities_to_one(solution, col_group)

        missing = next((col for col in col_group if col not in submission.columns), None)
        if missing is not None:
            raise ParticipantVisibleError(f'Missing submission column {missing}')
        submission = normalize_probabilities_to_one(submission, col_group)

        group_loss = sklearn.metrics.log_loss(
            y_true=solution[col_group].values,
            y_pred=submission[col_group].values,
            sample_weight=solution[f'{category}_weight'].values
        )
        label_group_losses.append(group_loss)

    # any_injury is the largest "not healthy" probability across the groups.
    healthy_cols = [f'{category}_healthy' for category in group_suffixes]
    any_injury_labels = (1 - solution[healthy_cols]).max(axis=1)
    any_injury_predictions = (1 - submission[healthy_cols]).max(axis=1)
    label_group_losses.append(
        sklearn.metrics.log_loss(
            y_true=any_injury_labels.values,
            y_pred=any_injury_predictions.values,
            sample_weight=solution['any_injury_weight'].values
        )
    )

    if reduction != 'mean':
        return label_group_losses
    return np.mean(label_group_losses)

# Assign the appropriate weights to each category
def create_training_solution(y_train):
    """Return a copy of ``y_train`` with one sample-weight column per target.

    Weights added (healthy weight is always 1):
      * bowel: injury 2
      * extravasation: injury 6
      * kidney / liver / spleen: low 2, high 4
      * any_injury: injury 6

    The input frame is left untouched.
    """
    solution = y_train.copy()

    # Binary targets: (name, weight applied when the injury column equals 1).
    for organ, injured_weight in (('bowel', 2), ('extravasation', 6)):
        is_injured = solution[f'{organ}_injury'] == 1
        solution[f'{organ}_weight'] = np.where(is_injured, injured_weight, 1)

    # Three-level targets: the low grade is checked before the high grade.
    for organ in ('kidney', 'liver', 'spleen'):
        is_low = solution[f'{organ}_low'] == 1
        is_high = solution[f'{organ}_high'] == 1
        solution[f'{organ}_weight'] = np.where(is_low, 2, np.where(is_high, 4, 1))

    solution['any_injury_weight'] = np.where(solution['any_injury'] == 1, 6, 1)
    return solution


def get_high_aortic_hu(df):
    """Keep, for every patient, the single row with the largest ``aortic_hu``.

    Args:
        df: Frame with at least ``patient_id`` and ``aortic_hu`` columns; all
            columns must be convertible to ``int32``.

    Returns:
        A new frame with one row per patient, ordered by ascending
        ``patient_id``, a fresh 0..n-1 index and every column cast to
        ``int32`` (fractional values are truncated by the cast).  An empty
        input yields an empty frame.
    """
    # Sort once instead of filtering the frame per patient: within each patient
    # the highest aortic_hu comes first, so the first row per patient is kept.
    ordered = df.sort_values(['patient_id', 'aortic_hu'], ascending=[True, False])
    best_rows = ordered.drop_duplicates(subset='patient_id', keep='first')
    return best_rows.reset_index(drop=True).astype('int32')

## preload

In [None]:
train, _, _ = preprocess()

# Create the local cache root and one sub-directory per patient.  makedirs with
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create
# race; iterating unique patients avoids one redundant check per extra series.
os.makedirs('/content/train_videos', exist_ok=True)

for patient_id in tqdm(train['patient_id'].unique()):
    os.makedirs(f'/content/train_videos/{int(patient_id)}', exist_ok=True)

# Number of entries now present under the cache root.
len(glob.glob('/content/train_videos/*'))

100%|██████████| 4711/4711 [00:00<00:00, 18598.85it/s]


3147

In [None]:
class LoadDataset(torch.utils.data.Dataset):
    """Dataset used only for its side effect of copying arrays to local disk.

    Indexing item ``i`` reads the four ``.npy`` arrays (images, liver, spleen,
    kidney) of row ``i`` from the Drive folder and re-saves them under
    ``/content/train_videos``.  The returned tensor is a placeholder.
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.loc[index]
        patient_id = int(row['patient_id'])
        series_id = int(row['series_id'])
        _ = int(row['any_injury'])  # read for parity with the other datasets; not used here

        source_root = f'/content/drive/MyDrive/train_videos/{patient_id}/{series_id}'
        target_root = f'/content/train_videos/{patient_id}/{series_id}'
        parts = ('images', 'liver', 'spleen', 'kidney')

        # Open every source first so nothing is written if one of them is missing.
        arrays = [np.load(f'{source_root}_{part}.npy', mmap_mode='r') for part in parts]
        for part, array in zip(parts, arrays):
            np.save(f'{target_root}_{part}.npy', array)

        return torch.zeros(1)

# Iterate the loader purely for its side effect: every LoadDataset.__getitem__
# call copies one series' arrays to local disk, and the DataLoader spreads
# those calls over its worker processes.  The yielded batches are discarded.
train, _, _ = preprocess()
dataset = LoadDataset(train)
loader = torch.utils.data.DataLoader(dataset, batch_size = 2, num_workers = 12, shuffle = False, drop_last = False)

for i, _ in enumerate(tqdm(loader)):
    pass

100%|██████████| 2356/2356 [29:09<00:00,  1.35it/s]


## dataset

In [None]:
class CustomAug(nn.Module):
    """Random geometric augmentation assembled from torchvision v2 transforms.

    On each call a rotation in [-45, 45] degrees is applied with probability
    ``prob``; independently, with probability ``prob``, a scale jitter in
    [0.8, 1.2] followed by a zero-padded random crop back to ``s`` x ``s`` is
    applied.  The input then always goes through the horizontal and vertical
    flip transforms, each of which flips with probability ``prob``.
    """

    def __init__(self, prob = 0.5, s = 224):
        super().__init__()
        self.prob = prob

        bilinear = torchvision.transforms.InterpolationMode.BILINEAR

        self.do_random_rotate = v2.RandomRotation(
            degrees=(-45, 45),
            interpolation=bilinear,
            expand=False,
            center=None,
            fill=0,
        )
        self.do_random_scale = v2.ScaleJitter(
            target_size=[s, s],
            scale_range=(0.8, 1.2),
            interpolation=bilinear,
            antialias=True,
        )
        # pad_if_needed restores the s x s size after a down-scaling jitter.
        self.do_random_crop = v2.RandomCrop(
            size=[s, s],
            pad_if_needed=True,
            fill=0,
            padding_mode='constant',
        )
        self.do_horizontal_flip = v2.RandomHorizontalFlip(self.prob)
        self.do_vertical_flip = v2.RandomVerticalFlip(self.prob)

    def forward(self, x):
        apply_rotation = np.random.rand() < self.prob
        if apply_rotation:
            x = self.do_random_rotate(x)

        apply_rescale = np.random.rand() < self.prob
        if apply_rescale:
            x = self.do_random_crop(self.do_random_scale(x))

        return self.do_vertical_flip(self.do_horizontal_flip(x))

aug_function = CustomAug()

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Dataset yielding the cached full-scan and organ-crop volumes of one series.

    Each item is read from the ``.npy`` files under ``/content/train_videos``
    and returned together with the labels, the per-target class indices and
    the per-target sample weights.

    Args:
        df: Table with ``patient_id``, ``series_id`` and ``any_injury`` columns,
            indexed 0..n-1.  The label columns are taken by position (see
            ``__init__``).
        augmentation: When true, every volume goes through the module-level
            ``aug_function`` before normalisation.
        use_competition_weights: When true, use the 1/2/4/6 class weights that
            ``create_training_solution`` also applies; the default keeps all
            weights at 1.
    """

    # Class weights matching the competition metric.
    _COMPETITION_WEIGHTS = {
        'bowel' : {0:1, 1:2},
        'extravasation' : {0:1, 1:6},
        'kidney' : {0:1, 1:2, 2:4},
        'liver' : {0:1, 1:2, 2:4},
        'spleen' : {0:1, 1:2, 2:4},
        'any_injury' : {0:1, 1:6},
    }

    # Every class counts equally.
    _UNIFORM_WEIGHTS = {
        'bowel' : {0:1, 1:1},
        'extravasation' : {0:1, 1:1},
        'kidney' : {0:1, 1:1, 2:1},
        'liver' : {0:1, 1:1, 2:1},
        'spleen' : {0:1, 1:1, 2:1},
        'any_injury' : {0:1, 1:1},
    }

    def __init__(self, df, augmentation=False, use_competition_weights=False):
        self.df = df
        # NOTE(review): label columns are selected by position (everything
        # between the first column and the last two).  __getitem__ then slices
        # the first 13 entries as bowel(2), extravasation(2), kidney(3),
        # liver(3), spleen(3), so the column order of df must match that.
        self.label_columns = self.df.columns[1:-2]

        self.max_frame = 256
        self.img_size = 224

        self.augmentation = augmentation

        chosen = self._COMPETITION_WEIGHTS if use_competition_weights else self._UNIFORM_WEIGHTS
        # Copy so that per-instance edits never touch the class-level tables.
        self.sample_weights = {target: dict(weights) for target, weights in chosen.items()}

    def __len__(self):
        return len(self.df)

    def get_stride_box(self, min_y, min_x, max_y, max_x, stride=10):
        """Grow the box by ``stride`` pixels on every side, clipped to [0, 512]."""
        min_y = np.clip(min_y - stride, a_min=0, a_max=512)
        min_x = np.clip(min_x - stride, a_min=0, a_max=512)
        max_y = np.clip(max_y + stride, a_min=0, a_max=512)
        max_x = np.clip(max_x + stride, a_min=0, a_max=512)
        return min_y, min_x, max_y, max_x

    def get_cropped_organs(self, video, box, ratio=(512/320)):
        """Crop one sub-volume per box row and resize each to 96 x 224 x 224.

        Args:
            video: 3-D float tensor indexed as (z, y, x).
            box: Array with one row per organ holding
                ``(min_z, min_y, min_x, max_z, max_y, max_x)``.
            ratio: Factor applied to the y/x coordinates before cropping
                (presumably boxes were computed on 320-pixel slices - TODO confirm).

        Returns:
            List with one resized tensor per box row.  A row whose extent is
            zero along any axis falls back to the whole ``video``.
        """
        organs = []
        for i in range(box.shape[0]):
            min_z, min_y, min_x, max_z, max_y, max_x = box[i]
            if 0.0 not in [max_z - min_z, max_y - min_y, max_x - min_x]:
                # Slice bounds must be integers even when the box is stored as floats.
                min_z, max_z = int(min_z), int(max_z)
                min_y, min_x, max_y, max_x = int(ratio*min_y), int(ratio*min_x), int(ratio*max_y), int(ratio*max_x)
                min_y, min_x, max_y, max_x = self.get_stride_box(min_y, min_x, max_y, max_x)
                organ = video[min_z:max_z, min_y:max_y, min_x:max_x]
            else:
                organ = video

            organ = F.interpolate(
                organ.unsqueeze(0).unsqueeze(0),
                size=[96, 224, 224],
                mode='trilinear'
                ).squeeze(0).squeeze(0)
            organs.append(organ)
        return organs

    def load_image(self, image_path, img_size=512):
        """Read the first channel of an image file and resize it to ``img_size`` squared.

        Raises:
            FileNotFoundError: If OpenCV cannot read ``image_path``.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        return cv2.resize(image[:,:,0], dsize=(img_size, img_size))

    def _load_volume(self, patient_id, series_id, part):
        """Load one cached ``.npy`` volume as a float tensor."""
        path = f'/content/train_videos/{patient_id}/{series_id}_{part}.npy'
        return torch.tensor(np.load(path, mmap_mode='r'), dtype=torch.float)

    def __getitem__(self, index):
        """Return the volumes, labels and weights of row ``index``.

        Returns:
            Tuple ``(images, crop_liver, crop_spleen, crop_kidney, label, bowel,
            extravasation, kidney, liver, spleen, any_injury, sample_weights)``.
            The four volumes are divided by 255, ``label`` holds the raw label
            columns, the five organ entries are class indices, ``any_injury``
            is a float scalar and ``sample_weights`` has one weight per target
            in that same order.
        """
        sample = self.df.loc[index]
        patient_id = int(sample['patient_id'])
        series_id = int(sample['series_id'])
        any_injury_label = int(sample['any_injury'])

        # The volumes are read from the local cache; the slice-level pipeline
        # that produced them (stack image slices, trilinear resize to
        # 256 x 512 x 512, crop organs via get_cropped_organs, resize the full
        # scan to 128 x 224 x 224) is not run here.
        images = self._load_volume(patient_id, series_id, 'images')
        crop_liver = self._load_volume(patient_id, series_id, 'liver')
        crop_spleen = self._load_volume(patient_id, series_id, 'spleen')
        crop_kidney = self._load_volume(patient_id, series_id, 'kidney')

        if self.augmentation:
            # Each volume draws its own random augmentation parameters.
            images = aug_function(images)
            crop_liver = aug_function(crop_liver)
            crop_spleen = aug_function(crop_spleen)
            crop_kidney = aug_function(crop_kidney)

        label = torch.tensor(sample[self.label_columns].values, dtype=torch.long)

        # One-hot label groups -> class indices.
        bowel = label[0:2].argmax()
        extravasation = label[2:4].argmax()
        kidney = label[4:7].argmax()
        liver = label[7:10].argmax()
        spleen = label[10:13].argmax()
        any_injury = torch.tensor(any_injury_label, dtype=torch.float)

        sample_weights = torch.tensor([
            self.sample_weights['bowel'][bowel.tolist()],
            self.sample_weights['extravasation'][extravasation.tolist()],
            self.sample_weights['kidney'][kidney.tolist()],
            self.sample_weights['liver'][liver.tolist()],
            self.sample_weights['spleen'][spleen.tolist()],
            self.sample_weights['any_injury'][any_injury_label]
        ])

        images, crop_liver, crop_spleen, crop_kidney = images/255.0, crop_liver/255.0, crop_spleen/255.0, crop_kidney/255.0

        return images, crop_liver, crop_spleen, crop_kidney, label, bowel, extravasation, kidney, liver, spleen, any_injury, sample_weights

if __name__ == "__main__":
    #train, _, _ = preprocess()
    dataset = CustomDataset(train, False)
    # randint's upper bound is exclusive, so len(dataset) lets every row
    # (including the last one) be drawn.
    index = np.random.randint(0, len(dataset))
    sample = dataset[index]
    # Shapes of the full volume and of the liver / spleen / kidney crops.
    print(sample[0].shape)
    print(sample[1].shape)
    print(sample[2].shape)
    print(sample[3].shape)
    # Labels, per-target class indices, any_injury flag and sample weights.
    print(sample[4:])

torch.Size([128, 224, 224])
torch.Size([96, 224, 224])
torch.Size([96, 224, 224])
torch.Size([96, 224, 224])
(tensor([1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1]), tensor(0), tensor(0), tensor(0), tensor(1), tensor(0), tensor(1.), tensor([1, 1, 1, 1, 1, 1]))


In [None]:
# Slice index shown for every volume.
# NOTE(review): `id` shadows the builtin; the name is kept in case other cells read it.
id = 48

# randint's upper bound is exclusive, so len(dataset) lets every row be drawn.
index = np.random.randint(0, len(dataset))
sample = dataset[index]

# 2 x 2 grid: full scan, then the liver, spleen and kidney crops.
fig, axs = plt.subplots(2, 2, figsize=(8, 8))
axs[0, 0].imshow(sample[0][id], cmap='gray')
axs[0, 1].imshow(sample[1][id], cmap='gray')
axs[1, 0].imshow(sample[2][id], cmap='gray')
axs[1, 1].imshow(sample[3][id], cmap='gray')

plt.show()

## model

In [None]:
from transformers import RobertaPreLayerNormConfig, RobertaPreLayerNormModel

class FeatureExtractor(nn.Module):
    """Encode a stack of frames with a 2-D CNN applied to groups of frames.

    The ``num_frame`` axis of the input is cut into consecutive groups of
    ``num_channel`` frames; each group is fed to the ``regnety_002`` backbone
    as one multi-channel image.  The backbone output is reshaped to
    ``(batch, groups, hidden)``, so ``hidden`` has to equal the backbone's
    feature width, and is then projected to ``hidden // 2``.
    """

    def __init__(self, hidden, num_channel):
        super().__init__()

        self.hidden = hidden
        self.num_channel = num_channel

        # num_classes=0 makes timm return pooled features instead of logits.
        self.cnn = timm.create_model(
            model_name='regnety_002',
            pretrained=True,
            num_classes=0,
            in_chans=num_channel,
        )
        self.fc = nn.Linear(hidden, hidden//2)

    def forward(self, x):
        batch_size, num_frame, height, width = x.shape
        num_groups = num_frame // self.num_channel

        # Fold the frame groups into the batch axis for the 2-D backbone.
        grouped = x.reshape(batch_size * num_groups, self.num_channel, height, width)
        features = self.cnn(grouped).reshape(batch_size, num_groups, self.hidden)
        return self.fc(features)

class ContextProcessor(nn.Module):
    """Mix a token sequence with one transformer layer and pool it to a vector.

    The input is a sequence of ``hidden // 2``-wide embeddings.  After the
    transformer, max-pooled and mean-pooled tokens are concatenated (giving
    ``hidden`` features when ``hidden`` is even) and passed through a linear
    layer and a ReLU.
    """

    def __init__(self, hidden):
        super().__init__()
        config = RobertaPreLayerNormConfig(
            hidden_size=hidden//2,
            num_hidden_layers=1,
            num_attention_heads=4,
            intermediate_size=hidden*2,
            hidden_act='gelu_new',
        )
        self.transformer = RobertaPreLayerNormModel(config)

        # Inputs arrive as embeddings, so the token embedding table is dropped.
        del self.transformer.embeddings.word_embeddings

        self.dense = nn.Linear(hidden, hidden)
        self.activation = nn.ReLU()

    def forward(self, x):
        encoded = self.transformer(inputs_embeds=x).last_hidden_state

        mean_pooled = encoded.mean(dim=1)
        max_pooled = encoded.max(dim=1).values
        pooled = torch.cat([max_pooled, mean_pooled], dim=-1)

        return self.activation(self.dense(pooled))

class Custom3DCNN(nn.Module):
    """Multi-head classifier over the full scan and three organ crops.

    Each of the four inputs has its own FeatureExtractor.  The bowel and
    extravasation heads read a context built from all four feature sequences;
    the kidney, liver and spleen heads each read a context built from the full
    scan plus the matching organ crop.  ``any_injury`` is derived from the head
    outputs as the largest ``1 - softmax(...)[healthy]`` across the five heads.
    """

    def __init__(self, hidden = 368, num_channel = 2):
        super().__init__()

        # Construction order is kept fixed: extractors, processors, then heads.
        extractor_args = dict(hidden=hidden, num_channel=num_channel)
        self.full_extractor = FeatureExtractor(**extractor_args)
        self.kidney_extractor = FeatureExtractor(**extractor_args)
        self.liver_extractor = FeatureExtractor(**extractor_args)
        self.spleen_extractor = FeatureExtractor(**extractor_args)

        self.full_processor = ContextProcessor(hidden=hidden)
        self.kidney_processor = ContextProcessor(hidden=hidden)
        self.liver_processor = ContextProcessor(hidden=hidden)
        self.spleen_processor = ContextProcessor(hidden=hidden)

        self.bowel = nn.Linear(hidden, 2)
        self.extravasation = nn.Linear(hidden, 2)
        self.kidney = nn.Linear(hidden, 3)
        self.liver = nn.Linear(hidden, 3)
        self.spleen = nn.Linear(hidden, 3)

        self.softmax = nn.Softmax(dim = -1)

    def forward(self, full_input, crop_liver, crop_spleen, crop_kidney, mask, mode):
        # `mask` and `mode` are accepted but not used by this implementation.
        full_feat = self.full_extractor(full_input)
        kidney_feat = self.kidney_extractor(crop_kidney)
        liver_feat = self.liver_extractor(crop_liver)
        spleen_feat = self.spleen_extractor(crop_spleen)

        # Sequences are concatenated along the token axis before context mixing.
        global_ctx = self.full_processor(
            torch.cat([full_feat, kidney_feat, liver_feat, spleen_feat], dim=1)
        )
        kidney_ctx = self.kidney_processor(torch.cat([full_feat, kidney_feat], dim=1))
        liver_ctx = self.liver_processor(torch.cat([full_feat, liver_feat], dim=1))
        spleen_ctx = self.spleen_processor(torch.cat([full_feat, spleen_feat], dim=1))

        logits = (
            self.bowel(global_ctx),
            self.extravasation(global_ctx),
            self.kidney(kidney_ctx),
            self.liver(liver_ctx),
            self.spleen(spleen_ctx),
        )

        # Column 0 of every head is the "healthy" class.
        healthy = torch.stack([self.softmax(head)[:, 0] for head in logits], dim=-1)
        any_injury = (1 - healthy).max(dim=1).values
        return (*logits, any_injury)

In [None]:
# Smoke test: build a fresh model on the GPU, push one batch from `dataset`
# through it without gradients and print the raw outputs.
# NOTE(review): `dataset` is whatever an earlier cell last bound to that name;
# this cell expects the CustomDataset instance - confirm before running.
device = 'cuda'
model = Custom3DCNN().to(device).float()
loader = torch.utils.data.DataLoader(dataset, batch_size = 2, num_workers = 2, shuffle = True, drop_last = True)

with torch.no_grad():
    sample = next(iter(loader))
    # sample[4] (the label tensor) fills the model's unused `mask` argument.
    output = model(sample[0].to(device), sample[1].to(device), sample[2].to(device), sample[3].to(device), sample[4].to(device), mode = 'train')
    print(output)

Downloading model.safetensors:   0%|          | 0.00/67.4M [00:00<?, ?B/s]

(tensor([[0.2819, 0.2997],
        [0.3204, 0.3262]], device='cuda:0'), tensor([[-0.2597, -0.1033],
        [-0.3401, -0.1753]], device='cuda:0'), tensor([[-0.1039,  0.3544,  0.8981],
        [-0.0786,  0.2911,  0.8833]], device='cuda:0'), tensor([[ 0.5575, -0.0765, -0.3373],
        [ 0.4492, -0.0638, -0.4594]], device='cuda:0'), tensor([[ 0.0694,  0.3060, -0.4675],
        [-0.0050,  0.2950, -0.4781]], device='cuda:0'), tensor([0.5977, 0.6018], device='cuda:0'))


## train function

In [None]:
def train_function(model,
                   optimizer,
                   scheduler,
                   loss_functions,
                   scaler,
                   loader,
                   device,
                   iters_to_accumulate):
    """Run one training epoch with mixed precision and gradient accumulation.

    Args:
        model: Network returning ``(bowel, extravasation, kidney, liver,
            spleen, any_injury)`` for a batch.
        optimizer: Optimizer stepped every ``iters_to_accumulate`` batches.
        scheduler: Learning-rate scheduler stepped together with the optimizer.
        loss_functions: ``[per-sample classification loss, per-sample binary
            loss]``; both must return one value per sample.
        scaler: ``GradScaler`` used for the backward pass and optimizer step.
        loader: Iterable of 12-element batches as produced by CustomDataset.
        device: Device the batch tensors are moved to.
        iters_to_accumulate: Number of batches whose gradients are summed
            before each optimizer step.

    Returns:
        0-dim CPU tensor: mean over the six targets of the sample-weighted
        epoch loss.  ``any_injury`` is part of this reported value but is not
        part of the optimised loss.
    """
    model.train()

    # Gradients left by an incomplete accumulation window of a previous epoch
    # were scaled with an older loss scale; drop them instead of mixing them
    # into this epoch's first optimizer step.
    optimizer.zero_grad()

    classification_loss, binary_loss = loss_functions[0], loss_functions[1]

    # Running sums over the six targets, kept on the device so the host is not
    # synchronised a dozen times per batch.
    loss_sums = torch.zeros(6, device=device)
    weight_sums = torch.zeros(6, device=device)

    for bi, sample in enumerate(tqdm(loader)):
        sample = [x.to(device) for x in sample]
        video, crop_liver, crop_spleen, crop_kidney, label, bowel, extravasation, kidney, liver, spleen, any_injury, sample_weights = sample

        with torch.cuda.amp.autocast():
            outputs = model(video, crop_liver, crop_spleen, crop_kidney, label, mode = 'train')

        # Losses are computed outside autocast and on float32 copies, so the
        # logits produced in reduced precision do not yield half-precision losses.
        bowel_output, extravasation_output, kidney_output, liver_output, spleen_output, any_injury_output = (
            output.float() for output in outputs
        )

        per_sample_losses = torch.stack([
            classification_loss(bowel_output, bowel),
            classification_loss(extravasation_output, extravasation),
            classification_loss(kidney_output, kidney),
            classification_loss(liver_output, liver),
            classification_loss(spleen_output, spleen),
            binary_loss(any_injury_output, any_injury),
        ], dim=1)

        weights = sample_weights.to(per_sample_losses.dtype)
        weighted_sums = (per_sample_losses * weights).sum(dim=0)
        batch_weight_sums = weights.sum(dim=0)

        # Only the five organ heads are optimised; any_injury (index 5) is
        # tracked for reporting but left out of the training loss.
        loss = (weighted_sums[:5] / batch_weight_sums[:5]).mean()
        loss = loss / iters_to_accumulate

        scaler.scale(loss).backward()
        if (bi + 1) % iters_to_accumulate == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            scheduler.step()

        loss_sums += weighted_sums.detach()
        weight_sums += batch_weight_sums

    per_target_loss = loss_sums / weight_sums
    total_loss = per_target_loss.mean().cpu()
    return total_loss


def test_function(model,
                  loader,
                  device,
                  input_df,
                  temperature=1.0):
    """Evaluate the model on ``loader`` and score it against ``input_df``.

    Head logits are divided by ``temperature`` and soft-maxed; the resulting
    probabilities overwrite the matching columns in a copy of ``input_df``.
    The competition score is computed from that copy and from an untouched
    copy of ``input_df`` extended with sample weights.  ROC AUC is computed per
    target: positive-class probability for the binary heads, one-vs-rest for
    the three-level heads, and the model's own ``any_injury`` output (which is
    not temperature scaled).

    Returns:
        ``(test_df, mean_score, message)`` where ``test_df`` holds the
        predictions, ``mean_score`` is a 0-dim tensor with the mean of the six
        losses and ``message`` is a dict of rounded losses and AUCs.
    """
    test_df = input_df.copy()
    true_df = input_df.copy()
    model.eval()

    # Probability columns written per head, in class order.
    head_columns = {
        'bowel': ['bowel_healthy', 'bowel_injury'],
        'extravasation': ['extravasation_healthy', 'extravasation_injury'],
        'kidney': ['kidney_healthy', 'kidney_low', 'kidney_high'],
        'liver': ['liver_healthy', 'liver_low', 'liver_high'],
        'spleen': ['spleen_healthy', 'spleen_low', 'spleen_high'],
    }
    heads = list(head_columns)
    metric_names = heads + ['any_injury']

    # competition metric
    probabilities = {col: [] for cols in head_columns.values() for col in cols}
    # auc
    auc_preds = {name: [] for name in metric_names}
    auc_trues = {name: [] for name in metric_names}

    softmax = nn.Softmax(dim=-1)

    for batch in tqdm(loader):
        batch = [x.to(device) for x in batch]
        video, crop_liver, crop_spleen, crop_kidney, label, bowel, extravasation, kidney, liver, spleen, any_injury, _ = batch

        with torch.no_grad():
            output = model(video, crop_liver, crop_spleen, crop_kidney, label, mode = 'test')

        targets = {
            'bowel': bowel,
            'extravasation': extravasation,
            'kidney': kidney,
            'liver': liver,
            'spleen': spleen,
        }
        for position, head in enumerate(heads):
            probs = softmax(output[position].cpu()/temperature)
            columns = head_columns[head]
            for class_idx, col in enumerate(columns):
                probabilities[col].extend(probs[:, class_idx].tolist())

            if len(columns) == 2:
                # Binary heads are scored on the injury probability only.
                auc_preds[head].extend(probs[:, 1].tolist())
            else:
                auc_preds[head].extend(probs.tolist())
            auc_trues[head].extend(targets[head].tolist())

        auc_preds['any_injury'].extend(output[5].cpu().tolist())
        auc_trues['any_injury'].extend(any_injury.tolist())

    for col, values in probabilities.items():
        test_df[col] = values

    test_score = score(create_training_solution(true_df), test_df, 'patient_id', reduction='none')

    auc = {}
    for name in metric_names:
        if name in ('kidney', 'liver', 'spleen'):
            auc[name] = roc_auc_score(auc_trues[name], auc_preds[name], multi_class = 'ovr')
        else:
            auc[name] = roc_auc_score(auc_trues[name], auc_preds[name])

    log_loss_summary = {name: round(test_score[i], 4) for i, name in enumerate(metric_names)}
    log_loss_summary['score'] = round(np.mean(test_score), 4)

    message = {
        'weighted-log-loss' : log_loss_summary,
        'auc' : {name: round(auc[name], 4) for name in metric_names},
    }

    return test_df, torch.tensor(test_score).mean(), message

## run

In [None]:
# Settings shared by every fold.
device = 'cuda'
epoch = 20
batch_size = 4
lr = 2e-4
wd = 0.01
warmup_ratio = 0.1
num_workers = 12
iters_to_accumulate = 4
label_smoothing = 0.0
# Training stops after this many epochs although the LR schedule spans `epoch`.
early_stop_epoch = 15
dir_name = 'result-channel2-512'
result_root = f'/content/drive/MyDrive/Kaggle/RSNA 2023 Abdominal Trauma Detection/{dir_name}'

# preprocess() involves no randomness, so the folds are built once for all runs.
train, injury_folds, normal_folds = preprocess()

for k in range(5):
    # Fold k: injured and normal series are split separately and then merged.
    train_df = pd.concat([injury_folds[k][0], normal_folds[k][0]]).reset_index(drop=True)
    val_df = pd.concat([injury_folds[k][1], normal_folds[k][1]]).reset_index(drop=True)

    train_dataset = CustomDataset(train_df, augmentation=True)
    val_dataset = CustomDataset(val_df)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, num_workers = num_workers, shuffle = True, drop_last = True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, num_workers = num_workers, shuffle = False, drop_last = False)

    model = Custom3DCNN().to(device).float()

    # [per-sample classification loss, per-sample binary loss]
    loss_functions = [
        nn.CrossEntropyLoss(label_smoothing = label_smoothing, reduction='none'),
        nn.BCELoss(reduction='none')
    ]

    optimizer = torch.optim.AdamW(params = model.parameters(), lr = lr, weight_decay = wd)
    # Scheduler steps happen once per accumulation window, not once per batch.
    total_steps = int(len(train_df) * epoch/(batch_size * iters_to_accumulate))
    warmup_steps = int(total_steps * warmup_ratio)
    print('total_steps: ', total_steps)
    print('warmup_steps: ', warmup_steps)

    scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                num_warmup_steps = warmup_steps,
                                                num_training_steps = total_steps)
    scaler = torch.cuda.amp.GradScaler()

    # makedirs creates the result and fold directories in one call and is a
    # no-op when they already exist.
    fold_dir = f'{result_root}/fold{k+1}'
    os.makedirs(fold_dir, exist_ok=True)
    log_path = f'{fold_dir}/log.txt'

    for i in range(epoch):
        # Check the cut-off first so no "start" message is printed for an
        # epoch that never runs.
        if i == early_stop_epoch:
            break

        print(f'{i+1}th epoch training is start...')

        # train
        train_loss = train_function(model,
                                    optimizer,
                                    scheduler,
                                    loss_functions,
                                    scaler,
                                    train_loader,
                                    device,
                                    iters_to_accumulate)

        # val
        _, val_loss, message = test_function(model,
                                              val_loader,
                                              device,
                                              val_loader.dataset.df.copy())

        train_loss_value = round(train_loss.tolist(), 4)
        val_loss_value = round(val_loss.tolist(), 4)

        # save
        save_path = f'{fold_dir}/epoch' + f'{i+1}'.zfill(3) + \
                    f'-trainloss{train_loss_value}' + \
                    f'-valloss{val_loss_value}' + '.bin'
        torch.save(model.state_dict(), save_path)

        _lr = optimizer.param_groups[0]['lr']
        message['log'] = f'epoch : {i+1}, lr : {_lr}, trainloss : {train_loss_value}, valloss : {val_loss_value}'
        print(message)
        with open(log_path, 'a') as logger:
            logger.write(f'{message}\n')