# GPU Status

In [1]:
# check GPU type.
!nvidia-smi

Mon Mar 11 18:25:08 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.154.05             Driver Version: 535.154.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:16:00.0 Off |                  Off |
|  0%   30C    P8              11W / 450W |  14258MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# Python package installations

In [2]:
!pip install torchio
!pip install 'git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup'

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup
  Cloning https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup to /tmp/pip-req-build-d0p3l2n6
  Running command git clone --filter=blob:none --quiet https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup /tmp/pip-req-build-d0p3l2n6
  Resolved https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup to commit 12d03c075

In [3]:
import numpy as np
import random
import warnings
from typing import Dict
from pathlib import Path
import pandas as pd
import math

# System packages
import os
from os import listdir
from os.path import splitext, isfile, join
from pathlib import Path

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchio as tio

# Visualization packages
from PIL import Image
import matplotlib.pyplot as plt

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.svm import SVC
from sklearn.preprocessing import normalize

from argparse import ArgumentParser

from timm.models.layers import trunc_normal_
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

In [4]:
import logging
import datetime

current_datetime = datetime.datetime.now()
timestamp = current_datetime.timestamp()
filename = './log_{}.log'.format(timestamp)
print(f'loggin file name: {filename}')
logging.basicConfig(filename=filename, level=logging.DEBUG, force=True)

loggin file name: ./log_1710181517.142172.log


# Configuration

In [5]:
class cfg:
    seed = 77777777
    
    # Data
    img_size = (512, 512)
    in_channels = 1
    
    # 4 classes
    n_classes = 4
    train_img_path = './rtmets-sarcopenia/muscle_group_segment/train/xdata'
    train_mask_path = './rtmets-sarcopenia/muscle_group_segment/train/ydata'
    valid_img_path = './rtmets-sarcopenia/muscle_group_segment/validation/xdata'
    valid_mask_path = './rtmets-sarcopenia/muscle_group_segment/validation/ydata'
    all_img_path = './rtmets-sarcopenia/muscle_group_segment/all'
    img_4c_path = './rtmets-sarcopenia/4_channels_sm'
    
    # Clinical Data
    train_data_path = './rtmets-sarcopenia/RT_spine_NESMS_info/train.csv'
    valid_data_path = './rtmets-sarcopenia/RT_spine_NESMS_info/valid.csv'
    all_data_path = './rtmets-sarcopenia/RT_spine_NESMS_info/all.csv'
    train_val_test_list_path = './rtmets-sarcopenia/train_val_test_split.csv'
    pred_days = 365 # 42days, 90days, 365days
    y_data_reverse = True
    use_damper = True
    
    # Data Folding
    cross_valid = 1
    data_sampling = 'under-sampling' # 'no-sampling' | 'over-sampling' | 'under-sampling'

    val_percent: float = 0.1
    img_scale: float = 1.0
    
    # Transforms
    scale = (0.7, 1.0)
    
    # Model args
    bilinear = False
    save_checkpoint: bool = True
    amp: bool = False
    checkpoint = './rtmets-sarcopenia/checkpoints/checkpoint_epoch48_0008.pth'
    model_dir = './model/under-sampling'
    
    # Hyper parameters
    patience: int = math.inf
    epochs: int = 100
    batch_size: int = 8
    lr: float = 1e-3
    weight_decay: float = 1e-1
    momentum: float = 0.999
    gradient_clipping: float = 1.0
    optimizer = 'AdamW'
    patience = 5
    dropout = 0.4

# Utils

In [6]:
def seed_setting(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False

def save_model(model_dir, model, args, epoch=None, k=None):
    model_dir = Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    model_dir = model_dir.joinpath('{}d_fold{}_checkpoint_epoch{}.pth'.format(cfg.pred_days, k, epoch))

    torch.save({
        'epochs': args.epochs,
        'model_state_dict': model.state_dict(),
        'batch_size': args.batch_size,
        'lr': args.lr,
        'weight_decay': args.weight_decay,
        'optimizer': args.optimizer,
        'patience': args.patience,
    }, model_dir)

def get_label_weights(args, k):
#     weights = [3, 1]
    weights = [1, 1]
    logging.warning('weights = {}'.format(weights))
    weights = torch.tensor(weights.copy(), dtype=torch.float32)
    weights = weights.to(device)
    return weights

# all_scorer = sklearn.metrics.get_scorer_names()
# print(all_scorer)

def evaluate_preds(y_test, preds, _preds):
    tn, fp, fn, tp = sklearn.metrics.confusion_matrix(y_test, _preds).ravel()
#     print(f'tn: {tn}, fp: {fp}, fn: {fn}, tp: {tp}')

    acc = (tp+tn)/(tn+fp+fn+tp)
    # acc = sklearn.metrics.accuracy_score(y_test, _preds)

    tpr = tp/(tp+fn) # Recall, tpr, sensitivity
    # sensitivity = sklearn.metrics.recall_score(y_test, _preds)

    tnr = tn/(tn+fp) # tnr, specificity
    
    ppv = tp/(tp+fp) # Precision & ppv
    # precision = sklearn.metrics.precision_score(y_test, _preds)

    npv = tn/(tn+fn)
    
#     try:
#         f1 = ((2*ppv*tpr)/(ppv+tpr))
#     except:
#         f1 = 0
    # f1 = sklearn.metrics.f1_score(y_test, _preds)

    y_test = 1 - np.array(y_test)
    auc = sklearn.metrics.roc_auc_score(y_test, preds)
    auc = 1 - auc if auc < 0.5 else auc
    
#     fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_test, preds)
#     roc_auc = sklearn.metrics.auc(fpr, tpr)
#     display = sklearn.metrics.RocCurveDisplay(
#         fpr=fpr,
#         tpr=tpr,
#         roc_auc=roc_auc
#     )
#     display.plot()
#     plt.show()

    logging.info(f'acc: {acc:.4}, tpr: {tpr:.4}, tnr: {tnr:.4}, ppv: {ppv:.4}, npv: {npv:.4}, auc: {auc:.4}')

    return acc, auc

# Data Preparing

In [7]:
def data_preparing(data_path):
    res = pd.read_csv(data_path)
    data = res.values
    header = res.columns.to_numpy()
    
    x_data, y_42d_data, y_90d_data, y_365d_data = data[:, :-3], data[:, -3], data[:, -2], data[:, -1]
    
    return x_data, y_42d_data, y_90d_data, y_365d_data, header

def distribution_map_preparing(train_val_test_list_path):
    test = pd.read_csv(train_val_test_list_path, header=None)
    data = test.values.tolist()
    new_data = []
    for each in data:
        each = [i for i in each if str(i) != 'nan']
        each = each[1:]
        new_data.append(each)
    return new_data

# Datasets

In [8]:
os.environ['KMP_DUPLICATE_LIB_OK']='True'
os.environ['CUDA_LAUNCH_BLOCKING']='1'

##### Dataset #####
class L3_Dataset(Dataset):
    def __init__(self,
                 images_path=None, clinical_data_path=None, train_val_test_list_path=None,
                 mode=None, k=0, predict_mode=False, days=42, reverse=False, sampling='over-sampling'):
        # finish no-sampling
        assert sampling in ('over-sampling', 'under-sampling', 'no-sampling')
        
        self.images_path = images_path
        self.clinical_data_path = clinical_data_path
        self.mode = mode
        self.predict_mode = predict_mode
        split_map = distribution_map_preparing(train_val_test_list_path)
        
        x_data, y_42d_data, y_90d_data, y_365d_data, self.header = data_preparing(self.clinical_data_path)
        self.x_data = x_data

        if days == 42:
            self.y_data = y_42d_data
            base_idx = 0
        elif days == 90:
            self.y_data = y_90d_data
            base_idx = 15
        elif days == 365:
            self.y_data = y_365d_data
            base_idx = 30
        else:
            raise
        
        if sampling == 'under-sampling':
            base_idx += 45
        
        base_idx += k * 3
        if mode == 'train':
            self.split_map = split_map[base_idx] + split_map[base_idx + 1]
        elif mode == 'val':
            self.split_map = split_map[base_idx + 2]
        else:
            raise
            
        if reverse:
            self.y_data = 1 - self.y_data
        
    def __len__(self):
        return len(self.split_map)

    def __getitem__(self, origin_idx):
        idx = int(self.split_map[origin_idx] - 1)

        img_path = Path(self.images_path, f'{idx + 1}.npy')
#         mask_path = Path(self.images_path, 'ydata', f'{idx+1}.png')
        img_name = img_path.stem

        # label
        label = self.y_data[idx]
        label = int(label)
        label = torch.tensor(label)

        # Image
        c4_img = np.load(img_path)
        c4_img = c4_img.astype(np.float32)
#         c4_img = np.zeros((512, 512, 4), dtype=np.float32)
#         ct_img = np.array(Image.open(img_path))
#         mask_img = np.array(Image.open(mask_path))
#         for (y, rows) in enumerate(ct_img):
#             for (x, p) in enumerate(rows):
#                 [r, g, b] = mask_img[y][x]
#                 c4_img[y][x] = [r, g, b, p]
        c4_img = c4_img[:, :, :, np.newaxis].transpose((2, 3, 0, 1))

        # clinical features
        clinical_values = torch.tensor(self.x_data[idx], dtype=torch.float32)

        return img_name, c4_img, clinical_values, label

# Model Evaluation

In [9]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=0, size_average=True, weights=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.size_average = size_average
        self.weights = weights

    def forward(self, input, target):
        pt = F.softmax(input, dim=1)
        logpt = F.log_softmax(input, dim=1)
        weights = self.weights.index_select(0, target) # nice
        target = target.view(-1, 1) # 1維 -> 2維

        pt = pt.gather(1, target)
        logpt = logpt.gather(1, target)
        pt = pt.view(-1)
        logpt = logpt.view(-1)

        loss = -1 * weights * (1-pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()

# Blocks

In [10]:
class SE(nn.Module):
    def __init__(self, channel, reduction=2):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1, 1)
        return x * y.expand_as(x)

class GCT(nn.Module):
    def __init__(self, channel, epsilon=1e-5, mode='l2', after_relu=False):
        super(GCT, self).__init__()
        self.alpha = nn.Parameter(torch.ones(1, channel, 1, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.epsilon = epsilon
        self.mode = mode
        self.after_relu = after_relu

    def forward(self, x):
        if self.mode == 'l2':
            embedding = (x.pow(2).sum((2, 3, 4), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
            norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
        else:
            logging.WARNING('Unknown mode!')
            sys.exit()

        # gate = 1. + torch.tanh(embedding * norm + self.beta)
        gate = 1. + torch.sigmoid(embedding * norm + self.beta)

        return x * gate

class BAM(nn.Module):
    def __init__(self, gate_channel, reduction=2, dilation_val=4, num_layers=1):
        super(BAM, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.channel_att = nn.Sequential(
            nn.Linear(gate_channel, gate_channel // reduction),
            nn.ReLU(),
            nn.Linear(gate_channel // reduction, gate_channel)
        )

        # input (_, _, D, H, W) -> output (_, _, D, H, W) for any Conv3D
        self.spatial_att = nn.Sequential(
            nn.Conv3d(gate_channel, gate_channel // reduction, kernel_size=1),
            LayerNorm(gate_channel // reduction, data_format="channels_first"),
            nn.GELU(),
            nn.Conv3d(gate_channel // reduction, gate_channel // reduction, kernel_size=3, padding=dilation_val, dilation=dilation_val),
            LayerNorm(gate_channel // reduction, data_format="channels_first"),
            nn.GELU(),
            nn.Conv3d(gate_channel // reduction, gate_channel // reduction, kernel_size=3, padding=dilation_val, dilation=dilation_val),
            LayerNorm(gate_channel // reduction, data_format="channels_first"),
            nn.GELU(),
            nn.Conv3d(gate_channel//reduction, 1, kernel_size=1)
        )

    def forward(self, x):
        b, c, _, _, _ = x.size()
        # compute channel attention
        channel_part = self.avg_pool(x).view(b, c)
        channel_part = self.channel_att(channel_part).view(b, c, 1, 1, 1).expand_as(x)
        # compute spatial attention
        spatial_part = self.spatial_att(x).expand_as(x)
        # add together
        att = 1 + F.sigmoid(channel_part + spatial_part)
        return att * x

class CBAM(nn.Module):
    def __init__(self, gate_channel, reduction=2):
        super().__init__()
        # channel attention
        self.pools = [
            nn.AdaptiveAvgPool3d(1),
            nn.AdaptiveMaxPool3d(1)
        ]
        self.mlp = nn.Sequential(
            nn.Linear(gate_channel, gate_channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(gate_channel // reduction, gate_channel)
        )
        self.sigmoid = nn.Sigmoid()

        # spatial attention
        kernel_size = 7
        self.conv = nn.Sequential(
            nn.Conv3d(2, 1, kernel_size=kernel_size, stride=1, padding=(kernel_size-1) // 2),
            LayerNorm(1, data_format="channels_first"),
            nn.Sigmoid()
        )

    def forward(self, x):
        res = x
        b, c, _, _, _ = x.size()

        channel_part = None
        for pool in self.pools:
            y = pool(x).view(b, c)
            y = self.mlp(y)
            channel_part = channel_part + y if channel_part is not None else y 
        channel_part = self.sigmoid(channel_part).view(b, c, 1, 1, 1)
        x = x * channel_part.expand_as(x)

        spatial_part = torch.cat( (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1 )
        spatial_part = self.conv(spatial_part)
        x = x * spatial_part.expand_as(x)

        return x + res

class SE_GCT(nn.Module):
    def __init__(self, channel, reduction=2, epsilon=1e-5, mode='l2', after_relu=False):
        super().__init__()
        # Squeeze and excitation
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
        )
        # GCT
        self.alpha = nn.Parameter(torch.ones(1, channel, 1, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.epsilon = epsilon
        self.mode = mode
        self.after_relu = after_relu

    def forward(self, x):
        # compute SE attention
        b, c, _, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        se_gate = self.fc(y).view(b, c, 1, 1, 1).expand_as(x)

        # compute GCT attention
        if self.mode == 'l2':
            embedding = (x.pow(2).sum((2, 3, 4), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
            norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
        else:
            logging.WARNING('Unknown mode!')
            sys.exit()
        gct_gate = embedding * norm + self.beta

        # add together
        att = 1 + F.sigmoid(se_gate + gct_gate)
        return att * x

class BAM_GCT(nn.Module):
    def __init__(self, channel, reduction=2, dilation_val=4, num_layers=1, epsilon=1e-5, mode='l2', after_relu=False):
        super().__init__()
        # Squeeze and excitation
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
        )
        # GCT
        self.alpha = nn.Parameter(torch.ones(1, channel, 1, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.epsilon = epsilon
        self.mode = mode
        self.after_relu = after_relu

        # input (_, _, D, H, W) -> output (_, _, D, H, W) for any Conv3D
        self.spatial_att = nn.Sequential(
            nn.Conv3d(channel, channel // reduction, kernel_size=1),
            LayerNorm(channel // reduction, data_format="channels_first"),
            nn.GELU(),
            nn.Conv3d(channel // reduction, channel // reduction, kernel_size=3, padding=dilation_val, dilation=dilation_val),
            LayerNorm(channel // reduction, data_format="channels_first"),
            nn.GELU(),
            nn.Conv3d(channel // reduction, channel // reduction, kernel_size=3, padding=dilation_val, dilation=dilation_val),
            LayerNorm(channel // reduction, data_format="channels_first"),
            nn.GELU(),
            nn.Conv3d(channel//reduction, 1, kernel_size=1)
        )

    def forward(self, x):
        b, c, _, _, _ = x.size()
        # compute channel attention

        # compute SE attention
        b, c, _, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1, 1).expand_as(x)

        # compute GCT attention
        if self.mode == 'l2':
            embedding = (x.pow(2).sum((2, 3, 4), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
            norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
        else:
            logging.WARNING('Unknown mode!')
            sys.exit()
        gate = embedding * norm + self.beta

        # add together
        channel_part = y + gate

        # compute spatial attention
        spatial_part = self.spatial_att(x).expand_as(x)
        # add together
        att = 1 + F.sigmoid(channel_part + spatial_part)
        return att * x

class CBAM_GCT(nn.Module):
    def __init__(self, channel, reduction=2, epsilon=1e-5, mode='l2', after_relu=False):
        super().__init__()
        # channel attention
        self.pools = [
            nn.AdaptiveAvgPool3d(1),
            nn.AdaptiveMaxPool3d(1)
        ]
        self.mlp = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel)
        )
        self.sigmoid = nn.Sigmoid()

        # GCT
        self.alpha = nn.Parameter(torch.ones(1, channel, 1, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, channel, 1, 1, 1))
        self.epsilon = epsilon
        self.mode = mode
        self.after_relu = after_relu

        # spatial attention
        kernel_size = 7
        self.conv = nn.Sequential(
            nn.Conv3d(2, 1, kernel_size=kernel_size, stride=1, padding=(kernel_size-1) // 2),
            LayerNorm(1, data_format="channels_first"),
            nn.Sigmoid()
        )

    def forward(self, x):
        res = x
        b, c, _, _, _ = x.size()

        channel_part = None
        for pool in self.pools:
            y = pool(x).view(b, c)
            y = self.mlp(y)
            channel_part = channel_part + y if channel_part is not None else y 
        channel_part = channel_part.view(b, c, 1, 1, 1)

        # compute GCT attention
        if self.mode == 'l2':
            embedding = (x.pow(2).sum((2, 3, 4), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
            norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
        else:
            logging.WARNING('Unknown mode!')
            sys.exit()
        gate = embedding * norm + self.beta

        # add together
        channel_part = channel_part + gate
        channel_part = F.sigmoid(channel_part)
        x = x * channel_part.expand_as(x)

        spatial_part = torch.cat( (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1 )
        spatial_part = self.conv(spatial_part)
        x = x * spatial_part.expand_as(x)

        return x + res

class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs 
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
            return x

# Models

In [11]:
class GRN(nn.Module):
    # gamma, beta: learnable affine transform parameters
    # X: input of shape (N, D, H, W, C)
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros((dim)))
        self.beta = nn.Parameter(torch.zeros((dim)))

    def forward(self, x):
        gx = torch.norm(x, p=2, dim=(1, 2, 3), keepdim=True)
        nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * nx) + self.beta + x

class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, D, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch
    
    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, depth_number: int, pos_in_depth: int, drop_path=0.5, layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 
                                    requires_grad=True) if layer_scale_init_value > 0 else None
        self.depth_number = depth_number
        self.pos_in_depth = pos_in_depth
        self.se_gct = SE_GCT(channel=dim, reduction=2) # TODO:
        self.dim = dim

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 4, 1) # (N, C, D, H, W) -> (N, D, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 4, 1, 2, 3) # (N, D, H, W, C) -> (N, C, D, H, W)

        if self.pos_in_depth == 3:
            x = self.se_gct(x)

        x = input + x
        return x

class ConvNeXt(nn.Module):
    r""" ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s`  -
          https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """
    def __init__(
            self, in_chans=4, num_classes=2, 
            depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 
            layer_scale_init_value=1e-6, head_init_scale=1., len_of_clinical_features=71,
            use_damper=False
        ):
        super().__init__()

        # depths=[1, 1, 3, 1], dims=[8, 16, 32, 64]
        self.use_damper = use_damper
        self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=(1, 4, 4), stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv3d(dims[i], dims[i+1], kernel_size=(1, 2, 2), stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_value, depth_number=i+1, pos_in_depth=j+1) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
        
        # voi size
        self.v_fc1 = nn.Linear(len_of_clinical_features, 16)
        self.v_fc2 = nn.Linear(16, dims[-1])
        self.v_fc3 = nn.Linear(dims[-1], dims[-1])
        
        # final
        len_of_features = dims[-1] + len_of_clinical_features
        self.fc1 = nn.Linear(len_of_features, len_of_features // 2) # 214 = dim[-1] + dim(clinical_info) which is 150
        self.fc2 = nn.Linear(len_of_features // 2, len_of_features // 4)
        self.fc3 = nn.Linear(len_of_features // 4, 2)
        self.bn1 = nn.BatchNorm1d(len_of_features // 2)
        self.bn2 = nn.BatchNorm1d(len_of_features // 4)
        self.softmax = nn.Softmax(dim=1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            trunc_normal_(m.weight, std=.02) # 正態分佈
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-3, -2, -1])) # global average pooling, (N, C, D, H, W) -> (N, C)

    def forward(self, x, clinical_info):
        x_feature = self.forward_features(x)
        
        if self.use_damper:
            voi_feature = self.v_fc1(clinical_info)
            voi_feature = self.v_fc2(voi_feature)
            voi_feature = self.v_fc3(voi_feature)

            b, n = x_feature.size()
            _sum = (x_feature * voi_feature).sum(dim=(-1))
            x = x_feature - (_sum / n).view(b, 1).expand_as(x_feature)

        x = torch.cat((x_feature, clinical_info), 1)
        x = F.leaky_relu(self.bn1(self.fc1(x)))
        x = F.leaky_relu(self.bn2(self.fc2(x)))
        x = self.fc3(x)
        x = self.softmax(x)

        return x

def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
    # original dims = [96, 192, 384, 768]
    # best dims = [8, 16, 32, 64]
    model = ConvNeXt(depths=[1, 1, 3, 1], dims=[8, 16, 32, 64], **kwargs)
    return model

# Training Functions

In [12]:
def train(args, k: int, device, model, optimizer, criterion, scheduler, trainloader, valloader, seed):
    
    the_last_loss = 100
    trigger_times = 0

    the_best_loss = 0
#     the_best_accuracy = 0
    the_best_auc = 0
    the_best_acc_auc_mul = 0

    save_path = Path(args.model_dir)
    save_tradoff_path = Path(args.model_dir, 'tradoff')
    save_auc_path = Path(args.model_dir, 'auc')
    save_acc_auc_mul_path = Path(args.model_dir, 'acc_auc_mul')

    for epoch in range(args.epochs):
#         epoch_loss, epoch_accuracy = 0, 0
        epoch_loss = 0
        train_label, train_pred = [], []
        model.train()
        for i, (_, img, clinical_info, label) in enumerate(trainloader):
            img, clinical_info, label = img.to(device), clinical_info.to(device), label.to(device)

            output = model(img, clinical_info)  #[bs, 2]

            torch.cuda.empty_cache()

            # _, pred = torch.max(output, dim=1)
            pred = torch.where(output[:, 0] > 0.475, 0, 1)
            loss = criterion(output, label)  #focal
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.detach().cpu().item()
#             epoch_accuracy += torch.sum(pred == label)

            train_label += label.detach().cpu().numpy().tolist()
            train_pred += pred.detach().cpu().numpy().tolist()

        scheduler.step()

        epoch_loss = epoch_loss / (len(trainloader) * args.batch_size)
#         epoch_accuracy = epoch_accuracy.float() / (len(trainloader) * args.batch_size)
#         tp, fn, fp, tn = confusion_matrix(train_label, train_pred).ravel()
#         tnr, fpr, fnr, tpr = tn/(fp+tn), fp/(fp+tn), fn/(tp+fn), tp/(tp+fn)
#         print(f'[{k}/{5}][{epoch}/{args.epochs}] loss: {epoch_loss:.8}, accuracy: {epoch_accuracy:.2}, tnr = {tnr:.2}, tpr = {tpr:.2}')
        print(f'[{k+1}/{5}][{epoch}/{args.epochs}] loss: {epoch_loss:.8}')
        logging.info(f'[{k+1}/{5}][{epoch}/{args.epochs}] loss: {epoch_loss:.8}')

        # Early stopping
        the_current_loss, the_current_accuracy, the_current_auc = validation(device, model, criterion, valloader)
#         print(f'[validation] The current loss: {the_current_loss:.8}, accuracy: {the_current_accuracy:.2}')
        acc_auc_mul = the_current_accuracy * the_current_auc
        logging.info(f'[validation] The current loss: {the_current_loss:.8}, acc*auc: {acc_auc_mul:.4}')

        if the_current_loss > the_last_loss:
            trigger_times += 1
            logging.info('trigger times: {}'.format(trigger_times))

            if trigger_times >= args.patience:
                logging.warning(f'Early stopping!\nEpoch = {epoch}')
                return
        else:
#             print('trigger times: 0')
            trigger_times = 0

#         if epoch == 0 or the_best_loss >= the_current_loss:
#             print(f'Recording best model. ({save_path})')
#             save_model(save_path, model, args, epoch, k)
#             the_best_loss = the_current_loss

#         if the_current_accuracy > the_best_accuracy:
#             print(f'Recording best tradeoff model. ({save_tradoff_path})')
#             save_model(save_tradoff_path, model, args, epoch, k)
#             the_best_accuracy = the_current_accuracy
        
        if the_current_auc > the_best_auc:
            logging.info(f'Recording best auc model. ({save_auc_path})')
            save_model(save_auc_path, model, args, epoch, k)
            the_best_auc = the_current_auc
        
        if acc_auc_mul > the_best_acc_auc_mul:
            logging.critical(f'Recording best acc*auc model. ({save_acc_auc_mul_path})')
            save_model(save_acc_auc_mul_path, model, args, epoch, k)
            the_best_acc_auc_mul = acc_auc_mul

        the_last_loss = the_current_loss

    logging.info(f'stopping! Epoch = {args.epochs}')
    return 

def validation(device, model, criterion, valloader):
    model.eval()
    val_loss = 0
#     val_accuracy = 0
    val_label, val_pred, val_c0 = [], [], []

    with torch.no_grad():
        for _, img, clinical_info, label in valloader:
            img, clinical_info, label = img.to(device), clinical_info.to(device), label.to(device)
            output = model(img, clinical_info)  #[bs, 2]
            
            # _, pred = torch.max(output, dim=1)
            c0 = output[:, 0]
            pred = torch.where(output[:, 0] > 0.475, 0, 1)
            loss = criterion(output, label)  #focal
            val_loss += loss.detach().cpu().item()
#             val_accuracy += torch.sum(pred == label)

            val_label += label.detach().cpu().numpy().tolist()
            val_pred += pred.detach().cpu().numpy().tolist()
            val_c0 += c0.detach().cpu().numpy().tolist()

    acc, auc = evaluate_preds(val_label, val_c0, val_pred)

    val_loss = val_loss / len(valloader)
#     val_accuracy = val_accuracy.float() / len(valloader)
            
    return val_loss, acc, auc

# Arguments Preparing

# Start Training

In [13]:
if __name__ == '__main__' :
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.warning('device = {}'.format(device))

#     for days in [42, 90, 365]:
    for k in range(4, 5):
        ##### Setting Seed #####
        seed = cfg.seed
        seed_setting(seed)

        model = convnext_tiny(len_of_clinical_features=71, use_damper=cfg.use_damper)
        model.to(device)

        weights = get_label_weights(cfg, k)
        criterion = FocalLoss(gamma=2, weights=weights).to(device)
        optimizer = getattr(optim, cfg.optimizer)(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

        scheduler = CosineAnnealingWarmupRestarts(
            optimizer,
            first_cycle_steps=25,
            cycle_mult=1.0,
            max_lr=cfg.lr / 10,
            min_lr=0.000001,
            warmup_steps=10,
            gamma=0.85,
        )
        # Data
        trainset = L3_Dataset(
            images_path=cfg.img_4c_path, 
            clinical_data_path=cfg.all_data_path,
            train_val_test_list_path=cfg.train_val_test_list_path,
            mode="train",
            k=k,
#             days=days,
            days=cfg.pred_days,
            reverse=cfg.y_data_reverse,
            sampling=cfg.data_sampling
        )
        trainloader = DataLoader(trainset, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
        valset = L3_Dataset(
            images_path=cfg.img_4c_path, 
            clinical_data_path=cfg.all_data_path,
            train_val_test_list_path=cfg.train_val_test_list_path,
            mode="val",
            k=k,
#             days=days,
            days=cfg.pred_days,
            reverse=cfg.y_data_reverse,
            sampling=cfg.data_sampling
        )
        valloader = DataLoader(valset, batch_size=1, shuffle=False)

        train(cfg, k, device, model, optimizer, criterion, scheduler, trainloader, valloader, seed)

[5/5][0/100] loss: 0.021929885


  ppv = tp/(tp+fp) # Precision & ppv


[5/5][1/100] loss: 0.02193206


  ppv = tp/(tp+fp) # Precision & ppv


[5/5][2/100] loss: 0.021854491


  ppv = tp/(tp+fp) # Precision & ppv


[5/5][3/100] loss: 0.021715436
[5/5][4/100] loss: 0.02157149
[5/5][5/100] loss: 0.021488281
[5/5][6/100] loss: 0.021397201
[5/5][7/100] loss: 0.02151585
[5/5][8/100] loss: 0.021224004
[5/5][9/100] loss: 0.02123809
[5/5][10/100] loss: 0.021196503
[5/5][11/100] loss: 0.021246462
[5/5][12/100] loss: 0.021070482
[5/5][13/100] loss: 0.020863624
[5/5][14/100] loss: 0.020896857
[5/5][15/100] loss: 0.020743324
[5/5][16/100] loss: 0.020903559
[5/5][17/100] loss: 0.02074339
[5/5][18/100] loss: 0.020423974
[5/5][19/100] loss: 0.02059706
[5/5][20/100] loss: 0.02031956
[5/5][21/100] loss: 0.020526505
[5/5][22/100] loss: 0.020113529
[5/5][23/100] loss: 0.019965697
[5/5][24/100] loss: 0.020136396
[5/5][25/100] loss: 0.020163702
[5/5][26/100] loss: 0.020059818
[5/5][27/100] loss: 0.020252133
[5/5][28/100] loss: 0.019940838
[5/5][29/100] loss: 0.019889957
[5/5][30/100] loss: 0.019938863
[5/5][31/100] loss: 0.020193256
[5/5][32/100] loss: 0.019544589
[5/5][33/100] loss: 0.020215821
[5/5][34/100] loss: 0

  npv = tn/(tn+fn)


[5/5][40/100] loss: 0.01885021
[5/5][41/100] loss: 0.018370528
[5/5][42/100] loss: 0.018072394
[5/5][43/100] loss: 0.018541821
[5/5][44/100] loss: 0.018138083
[5/5][45/100] loss: 0.017632451
[5/5][46/100] loss: 0.017504911
[5/5][47/100] loss: 0.017301232
[5/5][48/100] loss: 0.016895767
[5/5][49/100] loss: 0.01709317
[5/5][50/100] loss: 0.017264307
[5/5][51/100] loss: 0.017059718
[5/5][52/100] loss: 0.016674747
[5/5][53/100] loss: 0.016908877
[5/5][54/100] loss: 0.016716931
[5/5][55/100] loss: 0.016510654
[5/5][56/100] loss: 0.017126794
[5/5][57/100] loss: 0.017402603
[5/5][58/100] loss: 0.016392786
[5/5][59/100] loss: 0.016462873
[5/5][60/100] loss: 0.016783525
[5/5][61/100] loss: 0.016499259
[5/5][62/100] loss: 0.016086412
[5/5][63/100] loss: 0.015928622
[5/5][64/100] loss: 0.015376169
[5/5][65/100] loss: 0.015470947
[5/5][66/100] loss: 0.014949909
[5/5][67/100] loss: 0.014729256
[5/5][68/100] loss: 0.014297752
[5/5][69/100] loss: 0.014482783
[5/5][70/100] loss: 0.013875433
[5/5][71/1

In [14]:
import requests

requests.get('https://hooks.zapier.com/hooks/catch/18160905/3cvu8pz/')

<Response [200]>

In [15]:
def test_model():
    model = convnext_tiny(len_of_clinical_features=71, use_damper=cfg.use_damper)
    model.to(device)
    model.load_state_dict(torch.load('./model/auc/42d_fold2_checkpoint_epoch84.pth')['model_state_dict'])
    model.eval()

    valset = L3_Dataset(
        images_path=cfg.img_4c_path, 
        clinical_data_path=cfg.all_data_path,
        train_val_test_list_path=cfg.train_val_test_list_path,
        mode="val",
        k=2,
        days=42,
        reverse=cfg.y_data_reverse
    )
    valloader = DataLoader(valset, batch_size=1, shuffle=False)

    val_label, val_pred, val_c0 = [], [], []
    with torch.no_grad():
        for _, img, clinical_info, label in valloader:
            img, clinical_info, label = img.to(device), clinical_info.to(device), label.to(device)
            output = model(img, clinical_info)  #[bs, 2]

            # _, pred = torch.max(output, dim=1)
            c0 = output[:, 0]
            pred = torch.where(output[:, 0] > 0.475, 0, 1)
            val_label += label.detach().cpu().numpy().tolist()
            val_pred += pred.detach().cpu().numpy().tolist()
            val_c0 += c0.detach().cpu().numpy().tolist()

    evaluate_preds(val_label, val_c0, val_pred)

# test_model()