# GPU Status

In [1]:
# Show the GPU model, driver / CUDA version and current memory use of this runtime.
!nvidia-smi

Mon Mar  4 13:56:08 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0              25W / 250W |      0MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# Python package installations

In [2]:
!pip install torchio
!pip install 'git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup'

Collecting torchio
  Downloading torchio-0.19.6-py2.py3-none-any.whl.metadata (48 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.9/48.9 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
Downloading torchio-0.19.6-py2.py3-none-any.whl (173 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m173.4/173.4 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchio
Successfully installed torchio-0.19.6
Collecting git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup
  Cloning https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup to /tmp/pip-req-build-_34bihvz
  Running command git clone --filter=blob:none --quiet https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup /tmp/pip-req-build-_34bihvz
  Resolved https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup to commit 12d03c07553aedd3d9e9155e2b3e31ce8c64081a
  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding whe

In [3]:
import numpy as np
import random
import warnings
from typing import Dict
from pathlib import Path
import pandas as pd
import math

# System packages
import os
from os import listdir
from os.path import splitext, isfile, join
from pathlib import Path

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchio as tio

# Visualization packages
from PIL import Image
import matplotlib.pyplot as plt

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.svm import SVC
from sklearn.preprocessing import normalize

from argparse import ArgumentParser

from timm.models.layers import trunc_normal_
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

# Configuration

In [4]:
class cfg:
    """Experiment configuration, used directly as a namespace (passed as ``args``).

    The class is never instantiated; cells read its attributes (``cfg.lr``, ...).
    Each setting is defined exactly once. The previous version defined
    ``optimizer`` twice and ``patience`` twice (``math.inf`` silently overridden by 5).

    NOTE(review): several fields are not read by any cell in this notebook and look
    like leftovers: img_size, in_channels, n_classes, the segmentation image/mask
    paths, y_data_reverse, cross_valid, val_percent, img_scale, scale, bilinear,
    save_checkpoint, amp, checkpoint, momentum, gradient_clipping, dropout.
    Confirm before relying on them.
    """
    seed = 77777777

    # Data
    img_size = (512, 512)
    in_channels = 1

    # 4 classes (NOTE(review): the classifier built below outputs 2 classes)
    n_classes = 4
    train_img_path = '/kaggle/input/rtmets-sarcopenia/muscle_group_segment/train/xdata'
    train_mask_path = '/kaggle/input/rtmets-sarcopenia/muscle_group_segment/train/ydata'
    valid_img_path = '/kaggle/input/rtmets-sarcopenia/muscle_group_segment/validation/xdata'
    valid_mask_path = '/kaggle/input/rtmets-sarcopenia/muscle_group_segment/validation/ydata'
    all_img_path = '/kaggle/input/rtmets-sarcopenia/muscle_group_segment/all'
    img_4c_path = '/kaggle/input/rtmets-sarcopenia/4_channels'

    # Clinical Data
    train_data_path = '/kaggle/input/rtmets-sarcopenia/RT_spine_NESMS_info/train.csv'
    valid_data_path = '/kaggle/input/rtmets-sarcopenia/RT_spine_NESMS_info/valid.csv'
    all_data_path = '/kaggle/input/rtmets-sarcopenia/RT_spine_NESMS_info/all.csv'
    pred_days = 42  # prediction horizon: 42, 90 or 365 days
    y_data_reverse = True

    # Data Folding
    cross_valid = 1

    val_percent: float = 0.1
    img_scale: float = 1.0

    # Transforms
    scale = (0.7, 1.0)

    # Model args
    optimizer = 'AdamW'  # name of an optimizer class in torch.optim
    bilinear = False
    save_checkpoint: bool = True
    amp: bool = False
    checkpoint = '/kaggle/input/rtmets-sarcopenia/checkpoints/checkpoint_epoch48_0008.pth'
    model_dir = './model'

    # Hyper parameters
    # Early stopping: number of consecutive epochs in which the validation loss
    # rises versus the previous epoch before training stops.
    patience: int = 5
    epochs: int = 100
    batch_size: int = 2
    lr: float = 1e-3
    weight_decay: float = 1e-1
    momentum: float = 0.999
    gradient_clipping: float = 1.0
    dropout = 0.4

# Utils

In [5]:
def seed_setting(seed):
    """Seed every RNG this notebook uses and put cuDNN into a reproducible configuration.

    torch (CPU, current CUDA device, all CUDA devices), NumPy's global RNG and
    Python's ``random`` module all receive the same ``seed``. cuDNN autotuning is
    turned off, deterministic kernels are requested, and cuDNN is disabled
    entirely (slower, but avoids its algorithm selection).
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False

def save_model(model_dir, model, args, epoch=None):
    """Save ``model``'s weights together with key hyper-parameters.

    The checkpoint is written to ``<model_dir>/checkpoint_epoch{epoch}.pth``
    (``checkpoint_epochNone.pth`` when ``epoch`` is omitted). ``model_dir`` is
    created if it does not exist.

    Args:
        model_dir: target directory (str or Path).
        model: module whose ``state_dict()`` is stored.
        args: config object exposing ``epochs``, ``batch_size``, ``lr``,
            ``weight_decay``, ``optimizer`` and ``patience``.
        epoch: epoch number embedded in the file name.

    Returns:
        Path of the written checkpoint file.
    """
    model_dir = Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    # Join with the path operator. Plain string concatenation dropped the
    # separator and wrote e.g. "modelcheckpoint_epoch0.pth" beside the directory.
    ckpt_path = model_dir / f'checkpoint_epoch{epoch}.pth'
    torch.save({
        'epochs': args.epochs,
        'model_state_dict': model.state_dict(),
        'batch_size': args.batch_size,
        'lr': args.lr,
        'weight_decay': args.weight_decay,
        'optimizer': args.optimizer,
        'patience': args.patience,
    }, ckpt_path)
    return ckpt_path

def get_label_weights(args, k):
    """Return the fixed per-class loss weights (class 0 -> 3, class 1 -> 1).

    The weights are printed, converted to a float32 tensor and moved to the
    module-level ``device`` defined by the training cell. ``args`` and ``k`` are
    accepted for interface compatibility but are not used.
    """
    class_weights = [3, 1]
    print('weights = ', class_weights)
    weight_tensor = torch.tensor(list(class_weights), dtype=torch.float32)
    return weight_tensor.to(device)

# all_scorer = sklearn.metrics.get_scorer_names()
# print(all_scorer)

def evaluate_preds(y_test, preds, threshold=0.5):
    """Print binary-classification metrics for scores ``preds`` and plot their ROC curve.

    A sample is predicted as label 1 when its score is strictly greater than
    ``threshold``. The function prints the confusion counts (``tn fp fn tp``),
    then accuracy, sensitivity, specificity, PPV, NPV, F1 and ROC-AUC, and shows
    the ROC curve.

    Args:
        y_test: ground-truth labels in {0, 1}.
        preds: scores for label 1 (higher means more likely label 1).
        threshold: decision threshold applied to ``preds``.
    """
    hard_preds = [1 if v > threshold else 0 for v in preds]

    # labels=[0, 1] keeps the matrix 2x2 even when a class is absent, so the
    # four-way unpack below cannot fail.
    tn, fp, fn, tp = sklearn.metrics.confusion_matrix(y_test, hard_preds, labels=[0, 1]).ravel()
    print(tn, fp, fn, tp)

    def _ratio(num, den):
        # An empty denominator means the metric is undefined. Report 0, mirroring
        # sklearn's default zero_division result, instead of nan plus a RuntimeWarning.
        return num / den if den else 0.0

    acc = _ratio(tp + tn, tn + fp + fn + tp)
    sensitivity = _ratio(tp, tp + fn)   # recall / TPR
    precision = _ratio(tp, tp + fp)     # PPV
    npv = _ratio(tn, tn + fn)
    specificity = _ratio(tn, tn + fp)   # TNR
    # Replaces a bare try/except that never fired: NumPy division returns nan
    # instead of raising, so F1 used to come out as nan rather than 0.
    f1 = _ratio(2 * precision * sensitivity, precision + sensitivity)

    if len(set(y_test)) < 2:
        # roc_auc_score / roc_curve are undefined with a single class present.
        print('AUC undefined: y_test contains a single class; skipping ROC curve.')
        auc = float('nan')
    else:
        auc = sklearn.metrics.roc_auc_score(y_test, preds)
        fpr, tpr, _ = sklearn.metrics.roc_curve(y_test, preds)
        display = sklearn.metrics.RocCurveDisplay(
            fpr=fpr,
            tpr=tpr,
            roc_auc=sklearn.metrics.auc(fpr, tpr)
        )
        display.plot()
        plt.show()

    print(f'acc: {acc}\nsensitivity: {sensitivity}\nspecificity: {specificity}\nppv(precision): {precision}\nnpv: {npv}\nF1-score: {f1}\nAUC: {auc}')

# Data Preparing

In [6]:
def data_preparing(data_path):
    """Read the clinical CSV and split it into features and the three outcome labels.

    The last three columns are the 42-, 90- and 365-day labels, in that order.
    Every preceding column is a feature.

    Returns:
        ``(x_data, y_42d_data, y_90d_data, y_365d_data, header)``. The first four
        are NumPy arrays sliced from ``DataFrame.values``; ``header`` is the array
        of column names.
    """
    frame = pd.read_csv(data_path)
    table = frame.values
    features = table[:, :-3]
    y_42d, y_90d, y_365d = (table[:, col] for col in (-3, -2, -1))
    return features, y_42d, y_90d, y_365d, frame.columns.to_numpy()

# Datasets

In [7]:
# Let the process continue when more than one OpenMP runtime is loaded (a common
# workaround for the "libiomp5 already initialized" abort); it hides the conflict
# rather than fixing it.
os.environ['KMP_DUPLICATE_LIB_OK']='True'
# Make CUDA kernel launches synchronous so errors point at the failing call.
# Useful for debugging, but it slows GPU training; it only takes effect if set
# before CUDA is initialised.
os.environ['CUDA_LAUNCH_BLOCKING']='1'

##### Dataset #####
class L3_Dataset(Dataset):
    """Pairs pre-computed multi-channel L3 images with clinical features and a binary outcome.

    Sample ``idx`` loads ``<images_path>/<idx+1>.npy`` and row ``idx`` of the
    clinical CSV. The CSV row order must therefore match the numbering of the
    image files.

    Each item is ``(img_name, image, clinical_values, label)``:
      * ``image``: float32 array of shape (C, 1, H, W).
      * ``clinical_values``: float32 tensor.
      * ``label``: integer tensor.

    NOTE(review): ``mode``, ``k`` and ``train_val_test_list_path`` do not affect
    which rows are used, so a "train" and a "val" dataset built from the same
    CSV contain identical samples. Confirm the intended fold split.

    Raises:
        ValueError: if ``days`` is not 42, 90 or 365.
    """

    _LABEL_DAYS = (42, 90, 365)

    def __init__(self, images_path=None, clinical_data_path=None, train_val_test_list_path=None, mode=None, k=0, predict_mode=False, days=42):
        # Validate before reading the CSV: an unsupported horizon used to leave
        # ``self.y_data`` unset and only failed later inside ``__getitem__``.
        if days not in self._LABEL_DAYS:
            raise ValueError(f'days must be one of {self._LABEL_DAYS}, got {days!r}')

        self.images_path = images_path
        self.clinical_data_path = clinical_data_path
        self.mode = mode
        self.predict_mode = predict_mode

        self.x_data, y_42d_data, y_90d_data, y_365d_data, self.header = data_preparing(self.clinical_data_path)
        self.y_data = {42: y_42d_data, 90: y_90d_data, 365: y_365d_data}[days]

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        img_path = Path(self.images_path, f'{idx+1}.npy')
        img_name = img_path.stem

        # label
        label = torch.tensor(int(self.y_data[idx]))

        # Image: the stored array is 3-D, presumably (H, W, C). A singleton depth
        # axis is inserted and channels moved first -> (C, 1, H, W), i.e. the
        # (C, D, H, W) layout a Conv3d expects once batched.
        c4_img = np.load(img_path).astype(np.float32)
        c4_img = c4_img[:, :, :, np.newaxis].transpose((2, 3, 0, 1))

        # clinical features
        clinical_values = torch.tensor(self.x_data[idx], dtype=torch.float32)

        return img_name, c4_img, clinical_values, label

# Loss Function (class-weighted focal loss)

In [8]:
class FocalLoss(nn.Module):
    """Class-weighted focal loss for multi-class logits.

    For each sample: ``loss = -w[y] * (1 - p) ** gamma * log(p)``, where ``p`` is
    the softmax probability of the true class ``y``. With ``gamma=0`` and no
    weights this equals cross-entropy.

    Args:
        gamma: focusing exponent; larger values down-weight well-classified samples.
        size_average: return the batch mean if True, otherwise the batch sum.
        weights: optional 1-D tensor of per-class weights, indexed by the target
            class and on the same device as the targets. When None, every class
            has weight 1. (Previously None crashed in ``forward``.)

    NOTE(review): ``forward`` applies log-softmax itself, so ``input`` should be
    raw logits. The ConvNeXt in this notebook already returns softmax
    probabilities, so softmax is applied twice during training. Confirm this is
    intended.
    """
    def __init__(self, gamma=0, size_average=True, weights=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.size_average = size_average
        self.weights = weights

    def forward(self, input, target):
        """``input``: (N, C) scores; ``target``: (N,) integer class indices."""
        target = target.view(-1)
        logpt = F.log_softmax(input, dim=1).gather(1, target.view(-1, 1)).view(-1)
        pt = logpt.exp()  # probability of the true class

        if self.weights is None:
            weights = torch.ones_like(logpt)
        else:
            weights = self.weights.index_select(0, target)

        loss = -1 * weights * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()

# Blocks

In [9]:
class SE(nn.Module):
    """Squeeze-and-Excitation gate for 5-D feature maps (N, C, D, H, W).

    Each channel is global-average-pooled, and the resulting C-vector goes
    through a bottleneck MLP (C -> C // reduction -> C) ending in a sigmoid.
    Every input channel is then scaled by its weight in (0, 1).
    """

    def __init__(self, channel, reduction=2):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _d, _h, _w = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        scale = self.fc(squeezed).view(batch, channels, 1, 1, 1)
        return x * scale.expand_as(x)

class GCT(nn.Module):
    """Gated Channel Transformation for 5-D feature maps (N, C, D, H, W).

    Per channel:
      * ``embedding = alpha * ||x_c||_2``, taken over D, H, W.
      * The embeddings are normalised by their RMS across channels and scaled by ``gamma``.
      * The input is multiplied by ``1 + sigmoid(embedding * norm + beta)``, a gate in (1, 2).

    With the zero-initialised ``gamma`` and ``beta`` the gate starts at 1.5.

    NOTE(review): the original GCT formulation uses ``1 + tanh(...)`` (kept
    commented out below); this variant uses sigmoid.

    Args:
        channel: number of input channels C.
        epsilon: stabiliser added inside the square roots.
        mode: embedding norm; only 'l2' is implemented.
        after_relu: stored for API compatibility; not used by ``forward``.

    Raises:
        ValueError: if ``mode`` is not 'l2'. The previous code called
            ``sys.exit()`` without importing ``sys``, which raised NameError.
    """
    def __init__(self, channel, epsilon=1e-5, mode='l2', after_relu=False):
        super(GCT, self).__init__()
        if mode != 'l2':
            raise ValueError(f"Unknown mode {mode!r}; only 'l2' is supported")
        self.alpha = nn.Parameter(torch.ones(1, channel, 1, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.epsilon = epsilon
        self.mode = mode
        self.after_relu = after_relu

    def forward(self, x):
        if self.mode != 'l2':
            raise ValueError(f"Unknown mode {self.mode!r}; only 'l2' is supported")
        embedding = (x.pow(2).sum((2, 3, 4), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
        norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)

        # gate = 1. + torch.tanh(embedding * norm + self.beta)
        gate = 1. + torch.sigmoid(embedding * norm + self.beta)

        return x * gate

class BAM(nn.Module):
    """Bottleneck Attention Module for 5-D feature maps (N, C, D, H, W).

    The module combines two branches:
      * Channel branch: global average pool, then a bottleneck MLP, giving shape (N, C, 1, 1, 1).
      * Spatial branch: a 1x1x1 reduction, two dilated 3x3x3 convolutions, then a
        1x1x1 projection to a single map, giving shape (N, 1, D, H, W).

    Both are broadcast to the input shape, summed, and the input is multiplied by
    ``1 + sigmoid(channel + spatial)``. ``num_layers`` is accepted but unused.
    """
    def __init__(self, gate_channel, reduction=2, dilation_val=4, num_layers=1):
        super(BAM, self).__init__()
        mid = gate_channel // reduction

        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.channel_att = nn.Sequential(
            nn.Linear(gate_channel, mid),
            nn.ReLU(),
            nn.Linear(mid, gate_channel)
        )

        def dilated_conv():
            # padding == dilation keeps D, H, W unchanged for a 3x3x3 kernel
            return nn.Conv3d(mid, mid, kernel_size=3, padding=dilation_val, dilation=dilation_val)

        self.spatial_att = nn.Sequential(
            nn.Conv3d(gate_channel, mid, kernel_size=1),
            LayerNorm(mid, data_format="channels_first"),
            nn.GELU(),
            dilated_conv(),
            LayerNorm(mid, data_format="channels_first"),
            nn.GELU(),
            dilated_conv(),
            LayerNorm(mid, data_format="channels_first"),
            nn.GELU(),
            nn.Conv3d(mid, 1, kernel_size=1)
        )

    def forward(self, x):
        n, c, _, _, _ = x.size()
        channel_logits = self.channel_att(self.avg_pool(x).view(n, c)).view(n, c, 1, 1, 1)
        spatial_logits = self.spatial_att(x)
        attention = 1 + torch.sigmoid(channel_logits.expand_as(x) + spatial_logits.expand_as(x))
        return attention * x

class CBAM(nn.Module):
    """Convolutional Block Attention Module for 5-D maps (N, C, D, H, W), with a residual add.

    Channel attention: average- and max-pooled descriptors pass through a shared
    bottleneck MLP. Their logits are summed and sigmoid-ed, and each channel is
    rescaled by the result.

    Spatial attention: the channel-wise max and mean maps are stacked (2
    channels), passed through a 7x7x7 Conv3d, LayerNorm and sigmoid, and every
    voxel is rescaled by the result.

    The attended tensor is added back to the input.
    """
    def __init__(self, gate_channel, reduction=2):
        super().__init__()
        # channel attention. nn.ModuleList (not a plain list) registers the pools
        # as submodules so they appear in repr / .apply / .to. They hold no
        # parameters, so the state dict is unchanged.
        self.pools = nn.ModuleList([
            nn.AdaptiveAvgPool3d(1),
            nn.AdaptiveMaxPool3d(1)
        ])
        self.mlp = nn.Sequential(
            nn.Linear(gate_channel, gate_channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(gate_channel // reduction, gate_channel)
        )
        self.sigmoid = nn.Sigmoid()

        # spatial attention ("same" padding for an odd kernel)
        kernel_size = 7
        self.conv = nn.Sequential(
            nn.Conv3d(2, 1, kernel_size=kernel_size, stride=1, padding=(kernel_size-1) // 2),
            LayerNorm(1, data_format="channels_first"),
            nn.Sigmoid()
        )

    def forward(self, x):
        res = x
        b, c, _, _, _ = x.size()

        # Shared MLP over each pooled descriptor; logits are summed before the sigmoid.
        channel_part = sum(self.mlp(pool(x).view(b, c)) for pool in self.pools)
        channel_part = self.sigmoid(channel_part).view(b, c, 1, 1, 1)
        x = x * channel_part.expand_as(x)

        spatial_part = torch.cat(
            (x.max(dim=1, keepdim=True).values, x.mean(dim=1, keepdim=True)), dim=1
        )
        spatial_part = self.conv(spatial_part)
        x = x * spatial_part.expand_as(x)

        return x + res

class SE_GCT(nn.Module):
    """Channel gate that sums Squeeze-and-Excitation and GCT logits, for (N, C, D, H, W) maps.

    The two logit branches are:
      * SE logits: global average pool, then a bottleneck MLP, with no sigmoid.
      * GCT logits: ``alpha * ||x_c||_2`` normalised across channels, times ``gamma``, plus ``beta``.

    The input is multiplied by ``1 + sigmoid(se + gct)``. With the zero-initialised
    ``gamma`` and ``beta`` the GCT term starts at 0.

    Args:
        channel: number of input channels C.
        reduction: SE bottleneck reduction factor.
        epsilon: stabiliser added inside the square roots.
        mode: GCT embedding norm; only 'l2' is implemented.
        after_relu: stored for API compatibility; not used by ``forward``.

    Raises:
        ValueError: if ``mode`` is not 'l2'. The previous code called
            ``sys.exit()`` without importing ``sys``, which raised NameError.
    """
    def __init__(self, channel, reduction=2, epsilon=1e-5, mode='l2', after_relu=False):
        super().__init__()
        if mode != 'l2':
            raise ValueError(f"Unknown mode {mode!r}; only 'l2' is supported")
        # Squeeze and excitation
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
        )
        # GCT
        self.alpha = nn.Parameter(torch.ones(1, channel, 1, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.epsilon = epsilon
        self.mode = mode
        self.after_relu = after_relu

    def forward(self, x):
        if self.mode != 'l2':
            raise ValueError(f"Unknown mode {self.mode!r}; only 'l2' is supported")

        # SE logits, (N, C, 1, 1, 1)
        b, c, _, _, _ = x.size()
        se_gate = self.fc(self.avg_pool(x).view(b, c)).view(b, c, 1, 1, 1)

        # GCT logits, (N, C, 1, 1, 1)
        embedding = (x.pow(2).sum((2, 3, 4), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
        norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
        gct_gate = embedding * norm + self.beta

        # add together
        att = 1 + torch.sigmoid(se_gate + gct_gate).expand_as(x)
        return att * x

class BAM_GCT(nn.Module):
    """BAM-style attention whose channel branch is SE logits + GCT logits, for (N, C, D, H, W) maps.

    The branches are:
      * Channel logits: SE bottleneck MLP on the pooled descriptor, plus the GCT
        term (``alpha * ||x_c||_2`` normalised across channels, times ``gamma``,
        plus ``beta``). Shape (N, C, 1, 1, 1).
      * Spatial logits: a 1x1x1 reduction, two dilated 3x3x3 convolutions with
        LayerNorm + GELU, then a 1x1x1 projection. Shape (N, 1, D, H, W).

    The input is multiplied by ``1 + sigmoid(channel + spatial)``.
    ``num_layers`` is accepted but unused.

    Raises:
        ValueError: if ``mode`` is not 'l2'. The previous code called
            ``sys.exit()`` without importing ``sys``, which raised NameError.
    """
    def __init__(self, channel, reduction=2, dilation_val=4, num_layers=1, epsilon=1e-5, mode='l2', after_relu=False):
        super().__init__()
        if mode != 'l2':
            raise ValueError(f"Unknown mode {mode!r}; only 'l2' is supported")
        # Squeeze and excitation
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
        )
        # GCT
        self.alpha = nn.Parameter(torch.ones(1, channel, 1, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.epsilon = epsilon
        self.mode = mode
        self.after_relu = after_relu

        # Spatial branch. padding == dilation keeps (D, H, W) unchanged for the 3x3x3 convs.
        self.spatial_att = nn.Sequential(
            nn.Conv3d(channel, channel // reduction, kernel_size=1),
            LayerNorm(channel // reduction, data_format="channels_first"),
            nn.GELU(),
            nn.Conv3d(channel // reduction, channel // reduction, kernel_size=3, padding=dilation_val, dilation=dilation_val),
            LayerNorm(channel // reduction, data_format="channels_first"),
            nn.GELU(),
            nn.Conv3d(channel // reduction, channel // reduction, kernel_size=3, padding=dilation_val, dilation=dilation_val),
            LayerNorm(channel // reduction, data_format="channels_first"),
            nn.GELU(),
            nn.Conv3d(channel//reduction, 1, kernel_size=1)
        )

    def forward(self, x):
        if self.mode != 'l2':
            raise ValueError(f"Unknown mode {self.mode!r}; only 'l2' is supported")
        b, c, _, _, _ = x.size()

        # SE logits, (N, C, 1, 1, 1)
        se_logits = self.fc(self.avg_pool(x).view(b, c)).view(b, c, 1, 1, 1)

        # GCT logits, (N, C, 1, 1, 1)
        embedding = (x.pow(2).sum((2, 3, 4), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
        norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
        gct_logits = embedding * norm + self.beta

        channel_part = (se_logits + gct_logits).expand_as(x)
        spatial_part = self.spatial_att(x).expand_as(x)

        att = 1 + torch.sigmoid(channel_part + spatial_part)
        return att * x

class CBAM_GCT(nn.Module):
    """CBAM whose channel gate also adds GCT logits, for (N, C, D, H, W) maps, with a residual add.

    Channel gate: ``sigmoid(mlp(avgpool) + mlp(maxpool) + gct)``, where ``gct``
    is ``alpha * ||x_c||_2`` normalised across channels, times ``gamma``, plus ``beta``.
    Spatial gate: the channel-wise max and mean maps go through a 7x7x7 Conv3d,
    LayerNorm and sigmoid. Both gates rescale the input in sequence, and the
    result is added back to the input.

    Raises:
        ValueError: if ``mode`` is not 'l2'. The previous code called
            ``sys.exit()`` without importing ``sys``, which raised NameError.
    """
    def __init__(self, channel, reduction=2, epsilon=1e-5, mode='l2', after_relu=False):
        super().__init__()
        if mode != 'l2':
            raise ValueError(f"Unknown mode {mode!r}; only 'l2' is supported")
        # channel attention. nn.ModuleList registers the (parameter-free) pools as
        # submodules; a plain list hid them from repr / .apply / .to.
        self.pools = nn.ModuleList([
            nn.AdaptiveAvgPool3d(1),
            nn.AdaptiveMaxPool3d(1)
        ])
        self.mlp = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel)
        )
        self.sigmoid = nn.Sigmoid()

        # GCT
        self.alpha = nn.Parameter(torch.ones(1, channel, 1, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.epsilon = epsilon
        self.mode = mode
        self.after_relu = after_relu

        # spatial attention ("same" padding for an odd kernel)
        kernel_size = 7
        self.conv = nn.Sequential(
            nn.Conv3d(2, 1, kernel_size=kernel_size, stride=1, padding=(kernel_size-1) // 2),
            LayerNorm(1, data_format="channels_first"),
            nn.Sigmoid()
        )

    def forward(self, x):
        if self.mode != 'l2':
            raise ValueError(f"Unknown mode {self.mode!r}; only 'l2' is supported")
        res = x
        b, c, _, _, _ = x.size()

        # Shared-MLP logits over both pooled descriptors, (N, C, 1, 1, 1)
        mlp_logits = sum(self.mlp(pool(x).view(b, c)) for pool in self.pools).view(b, c, 1, 1, 1)

        # GCT logits, (N, C, 1, 1, 1)
        embedding = (x.pow(2).sum((2, 3, 4), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
        norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
        gct_logits = embedding * norm + self.beta

        channel_part = self.sigmoid(mlp_logits + gct_logits)
        x = x * channel_part.expand_as(x)

        spatial_part = torch.cat(
            (x.max(dim=1, keepdim=True).values, x.mean(dim=1, keepdim=True)), dim=1
        )
        spatial_part = self.conv(spatial_part)
        x = x * spatial_part.expand_as(x)

        return x + res

class LayerNorm(nn.Module):
    r"""LayerNorm supporting two layouts: channels_last (default) or channels_first.

    * ``channels_last``: normalises over the trailing dimension of an input
      shaped (N, ..., C) using ``F.layer_norm``.
    * ``channels_first``: normalises over dim 1 of an input shaped (N, C, ...),
      e.g. (N, C, H, W) or (N, C, D, H, W), using the biased variance. It then
      applies the per-channel weight/bias broadcast over all trailing dims.

    Args:
        normalized_shape: number of channels C.
        eps: value added to the variance for numerical stability.
        data_format: "channels_last" or "channels_first".

    Raises:
        NotImplementedError: for any other ``data_format``.
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: statistics over the channel axis only
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        # Reshape (C,) -> (C, 1, ..., 1) so the affine broadcasts for any input
        # rank. The former fixed [:, None, None, None] only matched 5-D inputs
        # and mis-broadcast 4-D ones.
        affine_shape = (-1,) + (1,) * (x.dim() - 2)
        return self.weight.view(affine_shape) * x + self.bias.view(affine_shape)

# Models

In [10]:
class GRN(nn.Module):
    """Global Response Normalization (ConvNeXt V2) for channels-last inputs (N, *spatial, C).

    The layer computes, in order:
      * ``gx``: the L2 norm of each channel over all spatial dims.
      * ``nx``: ``gx`` divided by its mean across channels.
      * The output ``gamma * (x * nx) + beta + x``.

    ``gamma`` and ``beta`` start at zero, so the layer is the identity at initialisation.
    """
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros((dim)))
        self.beta = nn.Parameter(torch.zeros((dim)))

    def forward(self, x):
        # Reduce over every axis between batch and channels: (1, 2, 3) for the
        # 5-D (N, D, H, W, C) inputs this was written for, (1, 2) for 4-D ones.
        spatial_dims = tuple(range(1, x.dim() - 1))
        gx = torch.norm(x, p=2, dim=spatial_dims, keepdim=True)
        nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * nx) + self.beta + x

class Block(nn.Module):
    r"""ConvNeXt block for 5-D inputs (N, C, D, H, W), with an optional SE-GCT gate.

    The pipeline runs in this order:
      1. 7x7x7 depthwise Conv3d.
      2. Permute to channels-last (N, D, H, W, C).
      3. LayerNorm.
      4. Linear(C, 4C), then GELU, then Linear(4C, C).
      5. Optional layer-scale ``gamma``.
      6. Permute back to (N, C, D, H, W).
      7. SE_GCT attention, applied only when ``pos_in_depth == 3``.
      8. Residual add with the block input.

    Args:
        dim (int): Number of input channels.
        depth_number (int): 1-based stage index (stored only).
        pos_in_depth (int): 1-based position of this block inside its stage.
        drop_path (float): accepted for API compatibility; stochastic depth is not applied.
        layer_scale_init_value (float): initial layer-scale value; <= 0 disables layer-scale.

    NOTE(review): ``self.se_gct`` is built for every block but only used when
    ``pos_in_depth == 3``; its parameters still appear in the state dict.
    """
    def __init__(self, dim, depth_number: int, pos_in_depth: int, drop_path=0.5, layer_scale_init_value=1e-6):
        super().__init__()
        # Submodules are created in a fixed order: with a seeded RNG this order
        # determines the initial weights, and the names form the state-dict keys.
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs implemented as linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma = None
        self.depth_number = depth_number
        self.pos_in_depth = pos_in_depth
        self.se_gct = SE_GCT(channel=dim, reduction=2)
        self.dim = dim

    def forward(self, x):
        shortcut = x
        h = self.dwconv(x).permute(0, 2, 3, 4, 1)  # (N, C, D, H, W) -> (N, D, H, W, C)
        h = self.pwconv2(self.act(self.pwconv1(self.norm(h))))
        if self.gamma is not None:
            h = self.gamma * h
        h = h.permute(0, 4, 1, 2, 3)  # (N, D, H, W, C) -> (N, C, D, H, W)

        if self.pos_in_depth == 3:
            h = self.se_gct(h)

        return shortcut + h

class ConvNeXt(nn.Module):
    r"""3-D ConvNeXt image encoder fused with clinical features for classification.

    Adapted from `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf

    Images (N, in_chans, D, H, W) pass through four downsampling layers
    interleaved with four stages of ``Block``s:
      * the stem: Conv3d, kernel (1, 4, 4), stride 4;
      * three downsampling layers: Conv3d, kernel (1, 2, 2), stride 2.

    Global average pooling over (D, H, W) plus a final LayerNorm gives an
    (N, dims[-1]) vector. This vector is concatenated with ``clinical_info`` and
    classified by Linear -> BatchNorm1d -> LeakyReLU -> Linear -> Softmax.

    Args:
        in_chans (int): Number of input image channels. Default: 4
        num_classes (int): Number of output classes. Default: 2
        depths (list[int]): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (list[int]): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate passed to the blocks (which
            currently ignore it). Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Scale applied to the final classifier's initial
            weights and biases. Default: 1.
        len_of_clinical_features (int): Width of ``clinical_info``. Default: 71
    """
    def __init__(
            self, in_chans=4, num_classes=2,
            depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
            layer_scale_init_value=1e-6, head_init_scale=1., len_of_clinical_features=71
        ):
        super().__init__()

        # Stem plus 3 intermediate downsampling conv layers
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=(1, 4, 4), stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv3d(dims[i], dims[i+1], kernel_size=(1, 2, 2), stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        # 4 feature resolution stages, each consisting of multiple residual blocks
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_value, depth_number=i+1, pos_in_depth=j+1) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer

        # Fusion head: image feature (dims[-1]) concatenated with clinical features.
        len_of_features = dims[-1] + len_of_clinical_features
        self.fc1 = nn.Linear(len_of_features, len_of_features // 2)
        self.fc2 = nn.Linear(len_of_features // 2, num_classes)
        self.bn1 = nn.BatchNorm1d(len_of_features // 2)
        self.softmax = nn.Softmax(dim=1)

        self.apply(self._init_weights)
        with torch.no_grad():
            self.fc2.weight.mul_(head_init_scale)
            self.fc2.bias.mul_(head_init_scale)

    def _init_weights(self, m):
        # Truncated-normal weights and zero biases for every Conv3d / Linear layer.
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-3, -2, -1]))  # global average pooling, (N, C, D, H, W) -> (N, C)

    def forward(self, x, clinical_info):
        """Return class probabilities (N, num_classes).

        ``x`` is (N, in_chans, D, H, W) and ``clinical_info`` is
        (N, len_of_clinical_features).

        NOTE(review): the output is already softmax-normalised. A loss that
        applies its own softmax (such as a focal / cross-entropy loss on logits)
        will apply it twice.
        """
        x_feature = self.forward_features(x)
        x = torch.cat((x_feature, clinical_info), 1)
        x = F.leaky_relu(self.bn1(self.fc1(x)))
        x = self.fc2(x)
        return self.softmax(x)

def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
    """Build the reduced ConvNeXt used here: depths [1, 1, 3, 1], dims [8, 16, 32, 64].

    ``pretrained`` and ``in_22k`` are accepted for signature compatibility but
    ignored; no weights are loaded. Any other keyword arguments are forwarded to
    ``ConvNeXt``. (The reference ConvNeXt-T uses dims [96, 192, 384, 768].)
    """
    stage_depths = [1, 1, 3, 1]
    stage_dims = [8, 16, 32, 64]
    return ConvNeXt(depths=stage_depths, dims=stage_dims, **kwargs)

# Training Functions

In [11]:
def train(args, k: int, device, model, optimizer, criterion, scheduler, trainloader, valloader, seed):
    """Train ``model`` for up to ``args.epochs`` epochs with validation-based early stopping.

    Each epoch does one pass over ``trainloader``, taking one optimizer step per
    batch and one scheduler step per epoch, and then runs ``validation`` on
    ``valloader``.

    ``save_model`` keeps two checkpoints:
      * the lowest validation loss, under ``args.model_dir``;
      * the best "trade-off" model (highest validation accuracy with
        |tnr - tpr| <= 0.15), under ``args.model_dir/tradoff``.

    Training stops early once the validation loss has risen versus the previous
    epoch ``args.patience`` times in a row.

    A sample is predicted as class 0 when P(class 0) > 0.475, otherwise as
    class 1. The confusion counts are unpacked treating class 0 as positive, so
    ``tpr`` is class-0 recall and ``tnr`` is class-1 recall.

    Args:
        args: config with ``epochs``, ``patience``, ``model_dir`` (plus the
            fields ``save_model`` records).
        k: fold index, used only in log messages.
        seed: accepted but unused.
    """
    the_last_loss = 100
    trigger_times = 0

    the_best_loss = 0
    the_best_accuracy = 0

    save_path = Path(args.model_dir)
    save_tradoff_path = Path(args.model_dir, 'tradoff')

    for epoch in range(args.epochs):
        epoch_loss, n_correct = 0.0, 0
        train_label, train_pred = [], []
        model.train()
        for _, img, clinical_info, label in trainloader:
            img, clinical_info, label = img.to(device), clinical_info.to(device), label.to(device)

            output = model(img, clinical_info)  # [bs, 2] class probabilities

            torch.cuda.empty_cache()

            pred = torch.where(output[:, 0] > 0.475, 0, 1)
            loss = criterion(output, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_correct += (pred == label).sum().item()
            train_label += label.cpu().tolist()
            train_pred += pred.cpu().tolist()

        scheduler.step()

        # The criterion already averages over each batch, so average the
        # per-batch losses over batches only (dividing by batch_size as well
        # under-reported the loss and disagreed with the validation loss).
        epoch_loss = epoch_loss / max(len(trainloader), 1)
        epoch_accuracy = n_correct / max(len(train_label), 1)
        # labels=[0, 1] keeps the matrix 2x2 even if an epoch sees only one class.
        tp, fn, fp, tn = confusion_matrix(train_label, train_pred, labels=[0, 1]).ravel()
        tnr, tpr = tn/(fp+tn), tp/(tp+fn)
        print(f'[{k}/{5}][{epoch}/{args.epochs}] loss: {epoch_loss:.8}, accuracy: {epoch_accuracy:.2}, tnr = {tnr:.2}, tpr = {tpr:.2}')

        # Early stopping
        the_current_loss, the_current_accuracy, the_current_tnr, the_current_tpr = validation(device, model, criterion, valloader)
        print(f'[validation] The current loss: {the_current_loss:.8}, accuracy: {the_current_accuracy:.2}, tnr = {the_current_tnr:.2}, tpr = {the_current_tpr:.2}')

        if the_current_loss > the_last_loss:
            trigger_times += 1
            print('trigger times:', trigger_times)

            if trigger_times >= args.patience:
                print(f'Early stopping!\nEpoch = {epoch}')
                return
        else:
            print('trigger times: 0')
            trigger_times = 0

        if epoch == 0 or the_best_loss >= the_current_loss:
            print(f'Recording best model. ({save_path})')
            save_model(save_path, model, args, epoch)
            the_best_loss = the_current_loss

        if epoch == 0 or (the_best_accuracy <= the_current_accuracy and abs(the_current_tnr-the_current_tpr) <= 0.15):
            print(f'Recording best tradeoff model. ({save_tradoff_path})')
            save_model(save_tradoff_path, model, args, epoch)
            the_best_accuracy = the_current_accuracy

        the_last_loss = the_current_loss

    print(f'stopping! Epoch = {args.epochs}')
    return

def validation(device, model, criterion, valloader):
    """Evaluate ``model`` on ``valloader`` without gradient tracking.

    Calls ``evaluate_preds`` with P(class 1) as the score and threshold 0.5,
    which prints metrics and the ROC curve.

    Returns:
        ``(val_loss, val_accuracy, tnr, tpr)``:
          * ``val_loss``: mean per-batch loss.
          * ``val_accuracy``: fraction of samples whose prediction (class 0 iff
            P(class 0) > 0.475) matches the label.
          * ``tnr``, ``tpr``: the class recalls with class 0 treated as positive
            (``tpr`` = class-0 recall, ``tnr`` = class-1 recall), as in ``train``.
    """
    model.eval()
    val_loss = 0.0
    n_correct = 0
    val_label, val_pred, val_score = [], [], []

    with torch.no_grad():
        for _, img, clinical_info, label in valloader:
            img, clinical_info, label = img.to(device), clinical_info.to(device), label.to(device)
            output = model(img, clinical_info)  # [bs, 2] class probabilities

            pred = torch.where(output[:, 0] > 0.475, 0, 1)
            loss = criterion(output, label)
            val_loss += loss.item()
            n_correct += (pred == label).sum().item()

            val_label += label.cpu().tolist()
            val_pred += pred.cpu().tolist()
            # evaluate_preds maps score > threshold to label 1, so it needs
            # P(class 1). Passing P(class 0) inverted its predictions and its AUC.
            val_score += output[:, 1].cpu().tolist()

    evaluate_preds(val_label, val_score, 0.5)

    val_loss = val_loss / max(len(valloader), 1)
    # Normalise by the number of samples, not batches, so accuracy stays in
    # [0, 1] for any batch size.
    val_accuracy = n_correct / max(len(val_label), 1)
    tp, fn, fp, tn = confusion_matrix(val_label, val_pred, labels=[0, 1]).ravel()
    tnr, tpr = tn/(fp+tn), tp/(tp+fn)

    return val_loss, val_accuracy, tnr, tpr

# Arguments Preparing

# Start Training

In [None]:
if __name__ == '__main__' :
    # Fold loop over k = 1..5. Each fold re-seeds, rebuilds data, model,
    # optimizer and scheduler, then trains.
    # NOTE(review): both datasets are read from cfg.all_data_path, and L3_Dataset
    # does not use `mode`/`k` to select rows, so validation runs on the training
    # samples. Confirm the intended fold split before trusting these metrics.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("device = ", device)

    for k in range(1, 6):
        ##### Setting Seed #####
        # Re-seeding before the model is built gives every fold the same initial weights.
        seed = cfg.seed
        seed_setting(seed)

        # Data (building the datasets only reads the CSV; it draws no random numbers)
        trainset = L3_Dataset(
            images_path=cfg.img_4c_path,
            clinical_data_path=cfg.all_data_path,
            mode="train",
            k=k,
            days=cfg.pred_days
        )
        valset = L3_Dataset(
            images_path=cfg.img_4c_path,
            clinical_data_path=cfg.all_data_path,
            mode="val",
            k=k,
            days=cfg.pred_days
        )
        # Take the clinical-branch width from the CSV instead of a hard-coded 71,
        # so adding or removing feature columns cannot break the model's fc1.
        n_clinical_features = trainset.x_data.shape[1]

        model = convnext_tiny(len_of_clinical_features=n_clinical_features)
        model.to(device)

        weights = get_label_weights(cfg, k)  # moves the weights to the module-level `device`
        criterion = FocalLoss(gamma=2, weights=weights).to(device)
        optimizer = getattr(optim, cfg.optimizer)(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

        scheduler = CosineAnnealingWarmupRestarts(
            optimizer,
            first_cycle_steps=20,
            cycle_mult=1.0,
            max_lr=cfg.lr,
            min_lr=0.000001,
            warmup_steps=10,
            gamma=1.0,
        )

        trainloader = DataLoader(trainset, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
        valloader = DataLoader(valset, batch_size=1, shuffle=False)

        train(cfg, k, device, model, optimizer, criterion, scheduler, trainloader, valloader, seed)

device =  cuda
weights =  [3, 1]
