In [1]:
import os
# Expose only GPU 0 to CUDA. This is set before torch is imported further down.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Directory settings, relative to the notebook's working directory.
datadir = './preprocessed'
libdir = '.'  # prefix used below to build the extra sys.path entry
outputdir = './output_all'  # created further down; the training log is written here
otherdir = '.'
# Loader settings; copied into CFG.train_bs / CFG.valid_bs / CFG.num_workers below.
train_bs_ = 4
valid_bs_ = 4
num_workers_ = 8

In [2]:
# [i for i in timm.list_models() if 'resnest' in i]

In [3]:
# Label vocabulary. The order of these three lists fixes the order of CFG.target_cols
# (condition-major, then level, then severity), which the loss/metric code indexes by position.
conditions = ['spinal_canal_stenosis', 'left_neural_foraminal_narrowing', 'right_neural_foraminal_narrowing',
          'left_subarticular_stenosis', 'right_subarticular_stenosis']
levels = ['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']
severity = ['Normal/Mild', 'Moderate', 'Severe']
class CFG:
    # Central experiment configuration; read elsewhere in the notebook as CFG.<name>.
    seed=42
    device='GPU'  # 'GPU' or 'TPU'; selects the import/device branches further down
    nprocs=1 # [1, 8]
    num_workers=num_workers_
    train_bs=train_bs_
    valid_bs=valid_bs_
    data_dir = './'
    INP_DIR = 'rsna/'
    # One column per (condition, level, severity): 5 * 5 * 3 = 75 names.
    target_cols=[c+'_'+l+'_'+s for c in conditions for l in levels for s in severity]
    imgs_dir = ['./input/']
    num_classes=75  # equals len(target_cols)

    accum_iter=1#2
    max_grad_norm=1000
    print_freq=100
    normalize_mean=[0.4824, 0.4824, 0.4824] # [0.485, 0.456, 0.406] [0.4824, 0.4824, 0.4824]
    normalize_std=[0.22, 0.22, 0.22] # [0.229, 0.224, 0.225] [0.22, 0.22, 0.22]

    suffix="440"  # appended to the log file name below
    fold_num=4    # number of splits used by get_df()
    fold_list=[2, 3]
    min_epoch = -1
    epochs = 40
    shift_epoch = 1000
    model_arch="resnest50d" # tf_efficientnetv2_s, resnest50d, resnext50_32x4d, resnet200d, convnext_tiny_384_in22ft1k
    optimizer="AdamW" # Adam, SGD, AdamW
    scheduler="CosineAnnealingLR"#'ReduceLROnPlateau', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts'
    loss_fn= "Custom_loss"#'Custom_loss', "BCEWithLogitsLoss", "FocalLoss"
    scheduler_warmup="GradualWarmupSchedulerV3" 

    warmup_epo=1
    warmup_factor = 10
    T_0 = 1
    # Evaluated once at class-creation time from the values above (40 - 1 - 2 = 37).
    T_max= epochs-warmup_epo-2 
    
    seq_len = 90    # depth every volume is resampled to in TrainDataset
    img_size = 256  # height/width every volume is resampled to in TrainDataset
    p_mixup = 0.5

    lr=4e-5
    min_lr=1e-7
    # lr=23e-5
    # min_lr=23e-6
    weight_decay=0.02
    dropout=0.1  # drop rate passed to SpatialDropout in RSNAClassifier

    gpu_parallel=False
    n_early_stopping=8
    debug=False
    multihead=False
    plane = 'sag_t1'

In [4]:
# !ls

In [5]:
# [i for i in timm.list_models(pretrained=True) if 'resnest50d' in i]

In [6]:
# !pip install pandas -q
# !pip install scikit-learn -q
# !pip install warmup-scheduler==0.3 -q
# !pip install timm==0.9.7 -q
# !pip install albumentations==1.3.1
# !pip install opencv-python -q
# !pip install segmentation_models_pytorch
# !apt-get update && apt-get install ffmpeg libsm6 libxext6  -y

# Import

In [7]:
import sys; 

# Build the path with os.path.join so it is valid whether or not `libdir` ends
# with a separator (plain concatenation produced '.pytorch-image-models-master').
package_paths = [os.path.join(libdir, 'pytorch-image-models-master')]
for pth in package_paths:
    # Skip entries that are already present so re-running the cell does not
    # keep growing sys.path.
    if pth not in sys.path:
        sys.path.append(pth)

import ast
from glob import glob
# import cv2
# from skimage import io
import os
from datetime import datetime
import time
import random
from tqdm import tqdm
from contextlib import contextmanager
import math
import cv2

import numpy as np
import pandas as pd
import sklearn
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedGroupKFold
import torch
import torchvision
from torchvision import transforms
from torchvision.transforms import v2

from torch import nn
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F
# import matplotlib.pyplot as plt

from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from warmup_scheduler import GradualWarmupScheduler
import timm
import warnings
import joblib
from scipy.ndimage.interpolation import zoom
# import nibabel as nib
# import pydicom as dicom
import gc 
from torch.nn import DataParallel

import segmentation_models_pytorch as smp


import albumentations as A
from albumentations.pytorch import ToTensorV2



if CFG.device == 'TPU':
    !pip install -q pytorch-ignite
    import ignite.distributed as idist
elif CFG.device == 'GPU':
    from torch.cuda.amp import autocast, GradScaler


import warnings
warnings.filterwarnings('ignore')

  from scipy.ndimage.interpolation import zoom


# Loss & Metric

In [8]:
def sig(x):
    """Element-wise logistic sigmoid, ``1 / (1 + exp(-x))``, computed with NumPy."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Each element's binary cross-entropy is scaled by ``(1 - p_t) ** gamma``,
    where ``p_t`` is the probability assigned to the true label. The result is
    summed over dim 1 and averaged over the remaining values.
    """

    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        """Return the scalar focal loss.

        Args:
            input: logits.
            target: 0/1 targets with exactly the same size as ``input``.

        Raises:
            ValueError: if the two sizes differ.
        """
        if target.size() != input.size():
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))

        # Numerically stable binary cross-entropy with logits (log-sum-exp
        # shifted by max(-input, 0)).
        shift = (-input).clamp(min=0)
        log_term = ((-shift).exp() + (-input - shift).exp()).log()
        bce = input - input * target + shift + log_term

        # log(1 - p_t): log-sigmoid of the logit with its sign flipped for positives.
        signed_target = target * 2.0 - 1.0
        log_one_minus_pt = F.logsigmoid(-input * signed_target)
        focal = (log_one_minus_pt * self.gamma).exp() * bce

        return focal.sum(dim=1).mean()

def torch_sigmoid(x):
    """Logistic sigmoid of a tensor, written out as ``1 / (1 + exp(-x))``."""
    denominator = 1 + (-x).exp()
    return denominator.reciprocal()


class Custom_loss(nn.Module):
    """Group-normalised cross-entropy over 25 targets with 3 classes each.

    The 75 input columns are read as 25 consecutive triples. Targets 0-4,
    5-14 and 15-24 form three groups; each group's summed cross-entropy is
    divided by a fixed constant, and the total is divided by 3 * batch size.

    An extra "any severe among the first five targets" binary term is also
    evaluated, but it is not part of the returned value (its weight is
    commented out in the original formula).
    """

    def __init__(self, temperature=0.0):
        """
        Use max if temperature = 0
        """
        super().__init__()
        self.t = temperature
        assert self.t >= 0

    def __repr__(self):
        return 'SevereLoss(t=%.1f)' % self.t

    def forward(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the loss.

        Args:
          y_pred (Tensor[float]): logits, reshaped here to (batch_size, 3, 25)
          y      (Tensor):        one-hot style targets, reduced here by argmax
                                  to class indices of shape (batch_size, 25)
        """
        logits = y_pred.reshape(y_pred.shape[0], 25, 3).transpose(1, 2)
        labels = y.reshape(y.shape[0], 25, 3).transpose(1, 2).argmax(1)

        assert logits.size(0) == labels.size(0)
        assert logits.size(1) == 3 and logits.size(2) == 25
        assert labels.size(1) == 25
        assert labels.size(0) > 0

        # Per-target cross-entropy, then one sum per group of targets.
        ce = F.cross_entropy(logits, labels.long(), reduction='none')  # (batch_size, 25)
        group_slices = (slice(0, 5), slice(5, 15), slice(15, 25))
        group_sums = [ce[:, cols].sum() for cols in group_slices]

        # Probability of class 2 for each of the first five targets.
        severe_prob = logits[:, :, :5].softmax(dim=1)[:, 2, :]          # (batch_size, 5)
        any_severe = torch.amax(labels[:, :5] == 2, dim=1).to(logits.dtype)  # 0 or 1
        if self.t > 0:
            # Soft maximum: attention-weighted average of the five probabilities.
            attn = F.softmax(severe_prob / self.t, dim=1)
            pred_any_severe = (attn * severe_prob).sum(dim=1)
        else:
            # Exact maximum.
            pred_any_severe = severe_prob.amax(dim=1)
        # Evaluated as in the original, but excluded from the total below.
        any_severe_loss = F.binary_cross_entropy(pred_any_severe, any_severe, reduction='none')
        group_sums.append(any_severe_loss.sum())

        # Fixed per-group normalisers (their derivation is not shown in this file).
        total = (group_sums[0] / 6.084050632911392 +
                 group_sums[1] / 12.962531645569621 +
                 group_sums[2] / 14.38632911392405)
        return total / (3 * labels.size(0))


def comp_metric_score(y_true, y_pred, w1 = 1, w2 = 1, apply_sig = True):
    """Weighted multi-part validation score (lower is better).

    Args:
        y_true: NumPy array (n, 75) of one-hot targets, 25 consecutive triples.
        y_pred: NumPy array (n, 75) of logits (``apply_sig=True``) or of
            per-class probabilities (``apply_sig=False``). Never modified.
        w1, w2: multipliers applied to the 2nd and 3rd class of every triple
            before each triple is renormalised to sum to 1.
        apply_sig: whether to pass ``y_pred`` through ``sig`` first.
            Previously this flag was ignored and the sigmoid was always applied.

    Returns:
        float: mean of four sample-weighted losses (targets 0-4, 5-14, 15-24,
        and "any class-2 among targets 0-4").
    """
    if apply_sig:
        y_pred = sig(y_pred)
    else:
        # Work on a copy so the in-place scaling below never touches the caller's array.
        y_pred = y_pred.copy()
    y_pred[:, 1::3] *= w1
    y_pred[:, 2::3] *= w2
    for i in range(0, 75, 3):
        y_pred[:, i:i+3] /= y_pred[:, i:i+3].sum(1)[:, None]
    y_pred = torch.from_numpy(y_pred)
    # NOTE(review): this is neither log(p) nor logit(p) = log(p / (1 - p)).
    # The cross-entropy below re-applies a softmax, so log(p) looks like the
    # intended inverse. Left unchanged because it alters reported scores -
    # confirm the intent before fixing.
    y_pred = (torch.log(y_pred)/(1-y_pred))


    y = torch.from_numpy(y_true.copy())
    
    y_pred = y_pred.reshape(y_pred.shape[0], 25, 3).transpose(1, 2)
    y = y.reshape(y.shape[0], 25, 3).transpose(1, 2).argmax(1)
    # 4 losses for the evaluation metric
    loss4_sum = torch.zeros(4)
    w_sum = torch.zeros(4)
    slices = [slice(0, 5), slice(5, 15), slice(15, 25)]  # spinal, foraminal, subarticular

    w = 2 ** y  # sample_weight w = (1, 2, 4) for y = 0, 1, 2 (batch_size, 25)

    ce_loss = F.cross_entropy(y_pred, y.long(), reduction='none')  # (batch_size, 25)
    for k, idx in enumerate(slices):
        w_sum[k] += w[:, idx].sum()
        loss4_sum[k] += (w[:, idx] * ce_loss[:, idx]).sum()
    # Spinal max
    y_spinal_prob = y_pred[:, :, :5].softmax(dim=1)            # (batch_size, 3,  5)
    w_max = torch.amax(w[:, :5], dim=1)                        # (batch_size, )
    y_max = torch.amax(y[:, :5] == 2, dim=1).to(torch.float)   # 0 or 1
    y_pred_max = y_spinal_prob[:, 2, :].amax(dim=1)            # max in severe (class=2)
    loss_max = F.binary_cross_entropy(y_pred_max, y_max, reduction='none')
    loss4_sum[3] += (w_max * loss_max).sum()
    w_sum[3] += w_max.sum()
    # Average over spinal, foraminal, subarticular, and any_severe_spinal
    score = (loss4_sum / w_sum).sum().item() / 4
    return score

class Custom_loss2(nn.Module):
    """Binary cross-entropy on class-reweighted, renormalised sigmoid outputs.

    The 75 logits are read as 25 triples. After the sigmoid, the three classes
    of each triple are scaled by 1, 2 and 4, the triple is renormalised to sum
    to 1, and BCE is taken against the flat (n, 75) target.
    """

    def forward(self, p, t):
        """Return the scalar loss for logits ``p`` and targets ``t`` (both (n, 75))."""
        probs = p.sigmoid().reshape(-1, 25, 3)
        # Out-of-place ops only: the original `*=` / `/=` wrote into a view of
        # the sigmoid output, which autograd needs intact for the backward
        # pass, so loss.backward() raised an in-place modification error.
        class_weights = probs.new_tensor([1.0, 2.0, 4.0])
        probs = probs * class_weights
        probs = probs / probs.sum(dim=2, keepdim=True)
        return F.binary_cross_entropy(probs.reshape(-1, 75), t)

def comp_metric_score2(y_true, y_pred, w1 = 1, w2 = 1, apply_sig = True):
    """Validation score built on scikit-learn's ``log_loss`` (lower is better).

    Works on copies of both inputs. Predictions optionally go through ``sig``,
    the 2nd/3rd class of every triple is scaled by ``w1``/``w2``, and each
    triple is renormalised. Three sample-weighted log losses (columns 0-14,
    15-44, 45-74) and one unweighted "any class-2 in columns 0-14" log loss
    are printed and averaged.
    """
    t = y_true.copy()
    p = y_pred.copy()
    if apply_sig:
        p = sig(p)
    p[:, 1::3] *= w1
    p[:, 2::3] *= w2
    for start in range(0, 75, 3):
        triple = p[:, start:start + 3]  # view into p, so the division below is in place
        triple /= triple.sum(1)[:, None]

    def _weighted_group_loss(cols):
        # Log loss over one block of columns, each row weighted 1/2/4 by its true class.
        t_blk = t[:, cols].reshape(-1, 3)
        p_blk = p[:, cols].reshape(-1, 3)
        weights = t_blk[:, 0] + 2 * t_blk[:, 1] + 4 * t_blk[:, 2]
        return log_loss(t_blk, p_blk, sample_weight=weights)

    # "Any severe" target/prediction: max over the class-2 column of the first five triples.
    t_sp_sev = t[:, :15][:, 2::3].max(1)
    p_sp_sev = p[:, :15][:, 2::3].max(1)

    loss_sp = _weighted_group_loss(slice(0, 15))
    loss_ne = _weighted_group_loss(slice(15, 45))
    loss_su = _weighted_group_loss(slice(45, None))
    loss_sp_sev = log_loss(t_sp_sev, p_sp_sev)
    print(loss_sp, loss_ne, loss_su, loss_sp_sev)
    loss = (loss_sp + loss_ne + loss_su + loss_sp_sev)/4
    print('w1 =', w1, 'w2 =', w2, loss)
    return loss

# helper

In [10]:
ENV = 'kaggle'
if CFG.device == 'TPU':
    import os
    VERSION = "1.7"
    CP_V = "36" if ENV == "colab" else "37"
    wheel = f"torch_xla-{VERSION}-cp{CP_V}-cp{CP_V}m-linux_x86_64.whl"
    url = f"https://storage.googleapis.com/tpu-pytorch/wheels/{wheel}"
    !pip3 -q install cloud-tpu-client==0.10 $url
    os.system('export XLA_USE_BF16=1')
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp
    CFG.lr = CFG.lr * CFG.nprocs
    CFG.train_bs = CFG.train_bs // CFG.nprocs
    device = xm.xla_device()
    
elif CFG.device == "GPU":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible CUDA device, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning (benchmark=True) may pick a different convolution algorithm
    # on each run, which defeats the deterministic setting above.
    torch.backends.cudnn.benchmark = False

seed_everything(CFG.seed)


def get_score(y_true, y_pred):
    """Column-wise ROC AUC.

    Returns:
        (mean AUC over all columns, list with one AUC per column).
    """
    n_cols = y_true.shape[1]
    scores = [roc_auc_score(y_true[:, col], y_pred[:, col]) for col in range(n_cols)]
    return np.mean(scores), scores


@contextmanager
def timer(name):
    """Context manager that logs a start line and, on normal exit, the elapsed seconds."""
    started_at = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    elapsed = time.time() - started_at
    LOGGER.info(f'[{name}] done in {elapsed:.0f} s.')


def init_logger(log_file=None):
    """Return this module's logger, writing bare messages to the console and a file.

    Args:
        log_file: path of the log file. Defaults to ``stage2_train.log`` inside
            ``outputdir``. The previous default concatenated the two without a
            path separator.

    Returns:
        logging.Logger with exactly one stream handler and one file handler.
    """
    from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
    if log_file is None:
        log_file = os.path.join(outputdir, 'stage2_train.log')
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    # getLogger returns the same object on every call, so drop handlers left by
    # an earlier call (e.g. re-running the notebook cell). Otherwise every
    # message is emitted once per previous call.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

# makedirs(exist_ok=True) also creates missing parent directories and is a
# no-op when the directory already exists (os.mkdir failed for nested paths).
os.makedirs(outputdir, exist_ok=True)
LOGGER = init_logger(outputdir+f'/stage2_train{CFG.suffix}.log')

# On an 8-process TPU run only the master process should print/log.
if CFG.device=='TPU' and CFG.nprocs==8:
    loginfo = xm.master_print
    cusprint = xm.master_print
else:
    loginfo = LOGGER.info
    cusprint = print



def get_timediff(time1,time2):
    """Format the span ``time2 - time1`` (in seconds) as zero-padded ``MM:SS``."""
    elapsed = time2 - time1
    whole_minutes, leftover_seconds = divmod(elapsed, 60)
    return "{:02d}:{:02d}".format(int(whole_minutes), int(leftover_seconds))


class AverageMeter(object):
    """Tracks the latest value and the running (weighted) average of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the sum, the count and the average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def asMinutes(s):
    """Format a duration in seconds as ``'<minutes>m <seconds>s'``."""
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60
    return '%dm %ds' % (minutes, seconds)


def timeSince(since, percent):
    """Return elapsed time and the estimated remaining time as text.

    Args:
        since: start timestamp from ``time.time()``.
        percent: completed fraction of the work; the total is extrapolated
            as ``elapsed / percent``.
    """
    elapsed = time.time() - since
    estimated_total = elapsed / (percent)
    remaining = estimated_total - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))


def get_img(path):
    """Read an image with OpenCV and return it with the channel axis reversed (BGR -> RGB)."""
    bgr = cv2.imread(path)
    return bgr[:, :, ::-1]

def load_dicom(path):
    """
    Read a DICOM file and return its pixel data min-max scaled to [0, 1].

    This supports loading both regular and compressed JPEG images.
    See the first cell with `pip install` commands for the necessary dependencies.
    """
    # NOTE(review): `dicom` is not imported anywhere in the visible source (the
    # `import pydicom as dicom` line is commented out), so calling this will
    # presumably raise NameError - confirm and restore the import if needed.
    img = dicom.dcmread(path)
    img.PhotometricInterpretation = 'YBR_FULL'
    data = img.pixel_array
    # Shift so the minimum is 0, then divide by the new maximum unless the image is constant.
    data = data - np.min(data)
    peak = np.max(data)
    if peak != 0:
        data = data / peak
    # data = (data * 255).astype(np.uint8)
    return data

# DataSet

In [12]:
class TrainDataset(Dataset):
    """Per-study dataset returning two preprocessed volumes stacked on the last axis.

    For every row of ``csv`` one axial-T2 series and one sagittal series are
    loaded from ``/dev/shm/preprocessed/{study_id}_{series_id}.npy``, resampled
    to ``(CFG.seq_len, CFG.img_size, CFG.img_size)``, optionally augmented, and
    stacked into a tensor of shape ``(seq_len, H, W, 2)``.

    Args:
        csv: DataFrame with a ``study_id`` column and the target columns.
        ser_dict_sag_t1 / ser_dict_sag_t2 / ser_dict_ax_t2: dicts mapping
            study_id -> list of series ids for that plane.
        targets: list of target column names.
        mode: ``'train'`` (random series choice and augmentation), ``'test'``
            (no targets returned), or anything else for validation.
        meta_features: stored but not used by the code in this class.
        in_chans: stored but not used by the code in this class.
        transform: optional callable invoked as ``transform(image=HWC_array)``
            returning a dict with an ``'image'`` NumPy array.
    """

    def __init__(self, csv, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2, targets, mode, meta_features, in_chans, transform=None):

        self.csv = csv.reset_index(drop=True)
        self.ser_dict_sag_t1 = ser_dict_sag_t1
        self.ser_dict_sag_t2 = ser_dict_sag_t2
        self.ser_dict_ax_t2 = ser_dict_ax_t2
        self.targets = targets
        self.mode = mode
        self.use_meta = meta_features is not None
        self.meta_features = meta_features
        self.transform = transform
        self.in_chans = in_chans

        # Volume-level augmentations, applied together with probability 0.5 in train mode.
        self.aug_tr = v2.RandomRotation(
                            degrees = (-45, 45),
                            interpolation = torchvision.transforms.InterpolationMode.BILINEAR,
                            expand = False,
                            center = None,
                            fill = 0
                        )
        self.aug_sj = v2.ScaleJitter(
                        target_size = [CFG.img_size, CFG.img_size],
                        scale_range = (0.8, 1.2),
                        interpolation = torchvision.transforms.InterpolationMode.BILINEAR,
                        antialias = True)
        self.aug_rc = v2.RandomCrop(
                            size = [CFG.img_size, CFG.img_size],
                            #padding = None,
                            pad_if_needed = True,
                            fill = 0,
                            padding_mode = 'constant'
                        )
        # Created but not applied in __getitem__.
        self.aug_hf = v2.RandomHorizontalFlip(0.5)
        self.aug_vf = v2.RandomVerticalFlip(0.5)

    def __len__(self):
        return self.csv.shape[0]

    def _pick_series(self, series_ids):
        """Random series in train mode, otherwise the last one in the list."""
        if self.mode == 'train':
            return np.random.choice(series_ids)
        return series_ids[-1]

    def _series_for_plane(self, plane, study_id):
        """Choose the series id for ``plane``, falling back across sagittal sequences."""
        is_sagittal = plane in ['sag_t1', 'sag_t2']
        if is_sagittal and study_id not in self.ser_dict_sag_t1.keys():
            # Study has no sagittal T1 entry: use a sagittal T2 series.
            pool = self.ser_dict_sag_t2[study_id]
        elif is_sagittal and study_id not in self.ser_dict_sag_t2.keys():
            # Study has no sagittal T2 entry: use a sagittal T1 series.
            pool = self.ser_dict_sag_t1[study_id]
        elif plane == 'sag_t1':
            pool = self.ser_dict_sag_t1[study_id]
        elif plane == 'sag_t2':
            pool = self.ser_dict_sag_t2[study_id]
        else:
            pool = self.ser_dict_ax_t2[study_id]
        return self._pick_series(pool)

    def __getitem__(self, index):
        """Return ``(study_id, image)`` in test mode, else ``(study_id, image, targets)``."""
        row = self.csv.iloc[index]

        all_image = []
        for plane in ['ax_t2', 'sag_t2']:#, 'sag_t1'

            study_id = int(row.study_id)
            serie_id = self._series_for_plane(plane, study_id)
            # The result is unused, but the lookup/draw is kept so the NumPy
            # random stream (and therefore seeded augmentations) is unchanged.
            self._pick_series(self.ser_dict_ax_t2[study_id])

            fp = f'/dev/shm/preprocessed/{study_id}_{serie_id}.npy'
            image = np.load(fp)

            # Scale by 1/255 and resample the 3-D volume to a fixed (depth, H, W).
            image = torch.from_numpy(image).float()/255
            image = F.interpolate(
                     image.unsqueeze(0).unsqueeze(0),
                     size=[CFG.seq_len, CFG.img_size, CFG.img_size],
                     mode='trilinear'
                 ).squeeze(0).squeeze(0)

            if (self.mode == 'train') and np.random.rand()<0.5:
                image = self.aug_tr(image)
                image = self.aug_sj(image)
                image = self.aug_rc(image)

            l_img = image.shape[0]
            if (self.mode == 'train') and np.random.rand()<0.5:
                # Resample slices with replacement, keeping ascending order.
                inds = np.random.choice(np.arange(l_img), CFG.seq_len)
                inds.sort()
                image = image[inds]

            # transform
            if self.transform:
                # The transform receives channels-last data; slices act as channels.
                image = np.transpose(image.numpy(), (1, 2, 0))
                augmented = self.transform(image=image)
                image = augmented['image']
                image = np.transpose(image, (2, 0, 1))
                image = torch.from_numpy(image)
            # Without a transform the volume stays a tensor. Previously it had
            # already been converted to NumPy, which made torch.stack fail.

            all_image.append(image)

        image = torch.stack(all_image, -1)

        if self.mode == 'test':
            return study_id, image
        else:
            return study_id, image, torch.tensor(row[self.targets]).float()



def get_transforms(*, data):
    """Build the albumentations pipeline for the requested split.

    Args:
        data: ``'train'`` (heavy augmentation), ``'light_train'`` (lighter
            augmentation) or ``'valid'`` (resize only).

    Returns:
        An ``A.Compose`` pipeline, or ``None`` for any other value of ``data``.
    """
    if data == 'train':
        # This branch used the bare names `Compose` and `albumentations`, which
        # are never imported (only `import albumentations as A`), so it raised
        # NameError. It now uses the `A` alias like the other branches.
        return A.Compose([
            A.Resize(CFG.img_size, CFG.img_size),
            A.Perspective(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Transpose(p=0.5),
            A.RandomBrightness(limit=0.1, p=0.7),
            A.ShiftScaleRotate(shift_limit=0.3, scale_limit=0.3, rotate_limit=45, border_mode=4, p=0.7),

            A.OneOf([
                A.MotionBlur(blur_limit=3),
                A.MedianBlur(blur_limit=3),
                A.GaussianBlur(blur_limit=3),
                A.GaussNoise(var_limit=(3.0, 9.0)),
            ], p=0.5),
            A.OneOf([
                A.OpticalDistortion(distort_limit=1.),
                A.GridDistortion(num_steps=5, distort_limit=1.),
            ], p=0.5),

            A.Cutout(max_h_size=int(CFG.img_size * 0.5), max_w_size=int(CFG.img_size * 0.5), num_holes=1, p=0.5),
        ])
    elif data == 'light_train':
        return A.Compose([
            A.Resize(CFG.img_size, CFG.img_size),
            A.Perspective(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.3, scale_limit=0.3, rotate_limit=45, p=0.75),
            A.OneOf(
                [
                    A.RandomGamma(p=1),
                    A.RandomBrightnessContrast(contrast_limit=0.2, brightness_limit=0.0, p=1),
                    A.RandomBrightnessContrast(contrast_limit=0.0, brightness_limit=0.2, p=1),
                ],
                p=0.5,
            ),
            A.OneOf([
                A.MotionBlur(p = 1),
                A.GaussianBlur(p = 1),
                A.GaussNoise(p = 1),
                ], p=0.5),
            A.GridDistortion(num_steps=5, distort_limit=0.05, p=0.5),
            ], p=1.0)
    elif data == 'valid':
        return A.Compose([
            A.Resize(CFG.img_size, CFG.img_size),
        ], p=1.0)

def get_df():
    """Load labels and series metadata and assign cross-validation folds.

    Reads ``train.csv`` and ``train_series_descriptions.csv`` from the working
    directory.

    Returns:
        (df, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2) where ``df`` has
        ``study_id``, the one-hot target columns and a ``fold`` column, and each
        dict maps study_id -> list of series ids for that series description.
    """
    df = pd.read_csv('train.csv')

    # One 0/1 indicator column per (condition, level, severity). A label that
    # matches none of the severities yields zeros in all three columns.
    for cond in conditions:
        for lvl in levels:
            source_col = cond + '_' + lvl
            for sev in severity:
                df[source_col + '_' + sev] =  (df[source_col] == sev).astype(int)
    df = df[['study_id']+CFG.target_cols]

    tr_ser_desc = pd.read_csv('train_series_descriptions.csv')

    def _series_by_study(description):
        # study_id -> list of series_id for rows with this series_description.
        subset = tr_ser_desc[tr_ser_desc['series_description']==description]
        grouped = subset[['study_id', 'series_id']].groupby('study_id')
        return grouped.apply(lambda grp: grp.series_id.tolist()).to_dict()

    ser_dict_sag_t1 = _series_by_study('Sagittal T1')
    ser_dict_sag_t2 = _series_by_study('Sagittal T2/STIR')
    ser_dict_ax_t2 = _series_by_study('Axial T2')
    print(len(df), len(ser_dict_sag_t1), len(ser_dict_sag_t2), len(ser_dict_ax_t2))

    # Folds are stratified on the last target column and grouped by study.
    splitter = StratifiedGroupKFold(n_splits = CFG.fold_num)
    df['fold'] = -1
    splits = splitter.split(df, df[CFG.target_cols[-1]], df['study_id'])
    for fold_id, (_, val_inds) in enumerate(splits):
        df.loc[val_inds, 'fold'] = fold_id
    return df, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2

# Model

In [14]:
import torch.nn as nn
from itertools import repeat

class SpatialDropout(nn.Module):
    """Dropout that zeroes whole feature channels instead of single elements.

    With the default noise shape ``(batch, 1, ..., 1, features)`` one Bernoulli
    draw per (sample, last-dim index) is broadcast over all middle dimensions.
    """

    def __init__(self, drop=0.5):
        super(SpatialDropout, self).__init__()
        self.drop = drop
        
    def forward(self, inputs, noise_shape=None):
        """
        @param: inputs, tensor
        @param: noise_shape, tuple; defaults to (batch, 1, ..., 1, last_dim)
        @return: ``inputs`` itself in eval mode or when drop == 0, otherwise a
                 new tensor with dropped channels and survivors scaled by
                 1 / (1 - drop)
        """
        if noise_shape is None:
            noise_shape = (inputs.shape[0], *repeat(1, inputs.dim()-2), inputs.shape[-1]) 
        
        self.noise_shape = noise_shape
        # Return early *before* allocating anything: the original cloned the
        # input first and then threw the copy away on this path.
        if not self.training or self.drop == 0:
            return inputs

        noises = self._make_noises(inputs)
        if self.drop == 1:
            noises.fill_(0.0)
        else:
            noises.bernoulli_(1 - self.drop).div_(1 - self.drop)
        # Out-of-place multiply: leaves `inputs` untouched without a clone.
        return inputs * noises.expand_as(inputs)
            
    def _make_noises(self, inputs):
        """Allocate an uninitialised mask with the input's dtype/device and ``self.noise_shape``."""
        return inputs.new_empty(self.noise_shape)


import torch
from torch import nn
import torch.nn.functional as F

from typing import Dict, Optional
 
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor


    
class MLPAttentionNetwork(nn.Module):
    """Additive (MLP) attention pooling over the sequence dimension.

    Scores are ``v^T tanh(W x + b)`` per time step, softmax-normalised over the
    sequence, and used to take a weighted sum of the inputs.
    """

    def __init__(self, hidden_dim, attention_dim=None):
        super().__init__()

        self.hidden_dim = hidden_dim
        # The attention space defaults to the input width.
        self.attention_dim = hidden_dim if attention_dim is None else attention_dim
        # W * x + b
        self.proj_w = nn.Linear(self.hidden_dim, self.attention_dim, bias=True)
        # v.T
        self.proj_v = nn.Linear(self.attention_dim, 1, bias=False)

    def forward(self, x):
        """
        :param x: (batch_size, seq_len, hidden_dim)
        :return: (batch_size, hidden_dim) attention-weighted sum over seq_len
        """
        # Unpacking enforces a 3-D input, exactly as before.
        _batch_size, _seq_len, _ = x.size()

        projected = self.proj_w(x).tanh()                   # (batch_size, seq_len, attention_dim)
        weights = self.proj_v(projected).softmax(dim=1)     # (batch_size, seq_len, 1)
        return (x * weights).sum(1)                         # (batch_size, hidden_dim)

In [15]:
def gem(x, p=3, eps=1e-6):
    """Generalized-mean pooling over the last two dimensions.

    Values are clamped to at least ``eps``, raised to ``p``, averaged over the
    full spatial extent and raised to ``1 / p``.
    """
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1.0 / p)


class GeM(nn.Module):
    """Module wrapper around :func:`gem` with an optionally learnable exponent.

    Args:
        p: pooling exponent.
        eps: lower clamp applied before exponentiation.
        p_trainable: if True ``p`` is stored as a one-element ``nn.Parameter``,
            otherwise as the plain number that was passed in.
    """

    def __init__(self, p=3, eps=1e-6, p_trainable=False):
        super(GeM, self).__init__()
        if p_trainable:
            self.p = nn.parameter.Parameter(torch.ones(1) * p)
        else:
            self.p = p
        self.eps = eps

    def forward(self, x):
        ret = gem(x, p=self.p, eps=self.eps)
        return ret

    def __repr__(self):
        # `p` is a tensor only when p_trainable=True. The original always
        # called `.data` on it, so repr() (and printing any model containing
        # this layer) raised AttributeError for the default configuration.
        if isinstance(self.p, torch.Tensor):
            p_value = self.p.data.tolist()[0]
        else:
            p_value = self.p
        return (
            self.__class__.__name__
            + "("
            + "p="
            + "{:.4f}".format(p_value)
            + ", "
            + "eps="
            + str(self.eps)
            + ")"
        )
    
    
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Stochastic depth: randomly zero whole samples of ``x`` during training.

    One Bernoulli(1 - drop_prob) draw is made per sample (first dimension) and
    broadcast over all remaining dimensions, so it works for tensors of any
    rank. Kept samples are divided by the keep probability when
    ``scale_by_keep`` is set. ``x`` is returned unchanged when not training or
    when ``drop_prob`` is 0.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # (batch, 1, 1, ...) so the mask broadcasts across every non-batch dim.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask

class DropPath(nn.Module):
    """Module form of ``drop_path``: per-sample stochastic depth, active only in training mode."""

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super().__init__()
        self.scale_by_keep = scale_by_keep
        self.drop_prob = drop_prob

    def forward(self, x, mask=None):
        # `mask` is accepted for call compatibility but is not used.
        return drop_path(
            x,
            self.drop_prob,
            self.training,
            self.scale_by_keep,
        )

class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

class RSNAClassifier(nn.Module):
    """Two-branch classifier fusing an axial-T2 and a sagittal-T2 sub-network.

    Each branch is CNN backbone -> global average pool -> bidirectional GRU ->
    five attention poolers. The pooled features of both branches are
    concatenated per head and mapped to 15 logits; five heads give 75 outputs.

    Args:
        model_arch: timm architecture name (must expose ``.fc``).
        hidden_dim: accepted for interface compatibility; not used here.
        seq_len: slices per volume; consumed three at a time as image channels.
        pretrained: forwarded to ``timm.create_model``.
        softmax: if True, apply a softmax to every consecutive logit triple.
        freeze_layers: if True, load the per-plane checkpoints for ``fold``
            and freeze everything that was loaded.
        fold: fold index used in the checkpoint file names.
    """

    def __init__(self, model_arch, hidden_dim=256, seq_len=CFG.seq_len, pretrained=False, softmax = False, freeze_layers = True, fold = -1):
        super().__init__()
        self.seq_len = seq_len
        self.fold = fold
        # Branch 1 (`model`) sees plane index 0, branch 2 (`model2`) plane index 1.
        self.model = timm.create_model(model_arch, in_chans=3, pretrained=pretrained)
        self.model2 = timm.create_model(model_arch, in_chans=3, pretrained=pretrained)
        self.softmax = softmax
        self.freeze_layers = freeze_layers

        cnn_feature = self.model.fc.in_features
        # Strip the timm heads; pooling is done explicitly in forward().
        self.model.global_pool = nn.Identity()
        self.model.fc = nn.Identity()
        self.pooling = nn.AdaptiveAvgPool2d(1)#GeM(p_trainable=False)

        self.model2.global_pool = nn.Identity()
        self.model2.fc = nn.Identity()

        # Not used in forward(); kept so the module layout is unchanged.
        self.spatialdropout = SpatialDropout(CFG.dropout)
        self.gru = nn.GRU(cnn_feature, cnn_feature//2, num_layers=1, bidirectional=True, batch_first=True)
        self.gru2 = nn.GRU(cnn_feature, cnn_feature//2, num_layers=1, bidirectional=True, batch_first=True)
        self.mlp_attention_layers = nn.ModuleList([MLPAttentionNetwork(cnn_feature) for _ in range(5)])
        self.mlp_attention_layers2 = nn.ModuleList([MLPAttentionNetwork(cnn_feature) for _ in range(5)])
        # Fusion heads: the only heads used in forward().
        self.logits = nn.ModuleList([nn.Sequential(
                nn.Linear(2*cnn_feature, 512),
                nn.Mish(),
                nn.Dropout(0.1),
                nn.Linear(512, 15),
                ) for _ in range(5)])
        # Single-plane heads restored from the checkpoints; not used in forward().
        self.logits2 = nn.ModuleList([nn.Sequential(
                nn.Dropout(0.1),
                nn.Linear(cnn_feature, 15),
                ) for _ in range(5)])
        self.logits3 = nn.ModuleList([nn.Sequential(
                nn.Dropout(0.1),
                nn.Linear(cnn_feature, 15),
                ) for _ in range(5)])

        if self.freeze_layers:
            self._load_and_freeze_pretrained()

    @staticmethod
    def _sub_state(state, prefix):
        """Entries of ``state`` whose key starts with ``prefix``, with the prefix removed."""
        return {key[len(prefix):]: value for key, value in state.items() if key.startswith(prefix)}

    @staticmethod
    def _freeze(module):
        """Disable gradients for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    def _load_and_freeze_pretrained(self):
        """Restore both single-plane checkpoints into their branch and freeze them."""
        # Each file is read once (it used to be read four times) and mapped to
        # CPU so loading does not depend on the device it was saved from.
        ws1 = torch.load(f'/dev/shm/rsna-v5/ax-t2/resnest50d_440_fold{self.fold}.pth', map_location='cpu')['model']
        ws2 = torch.load(f'/dev/shm/rsna-v5/sag-t2/resnest50d_440_fold{self.fold}.pth', map_location='cpu')['model']

        # (checkpoint key prefix, module fed from ws1, module fed from ws2)
        restore_plan = [
            ('model.', self.model, self.model2),
            ('gru.', self.gru, self.gru2),
            ('logits.', self.logits2, self.logits3),
            ('mlp_attention_layers.', self.mlp_attention_layers, self.mlp_attention_layers2),
        ]
        for prefix, module_ax, module_sag in restore_plan:
            module_ax.load_state_dict(self._sub_state(ws1, prefix))
            module_sag.load_state_dict(self._sub_state(ws2, prefix))
            # Freeze the modules that were just loaded. The original froze
            # gru/gru2 a second time after loading logits2/logits3, leaving
            # those restored heads trainable.
            self._freeze(module_ax)
            self._freeze(module_sag)

    def forward(self, x): # (B, seq_len, H, W, 2)
        """Return (B, 75) logits, or per-triple probabilities when ``softmax`` is set.

        ``x`` is indexed on five dimensions; the last one selects the plane
        (0 -> ``model``, 1 -> ``model2``).
        """
        bs = x.size(0)

        # Fold groups of 3 consecutive slices into the channel dimension.
        x = x.reshape(bs*self.seq_len//3, 3, x.size(2), x.size(3), x.size(4)) # (B*seq_len/3, 3, H, W, 2)
        features = self.model(x[..., 0])
        features2 = self.model2(x[..., 1])
        if "res" in CFG.model_arch:
            features = self.pooling(features).view(bs*self.seq_len//3, -1) # (B*seq_len/3, cnn_feature)
            features2 = self.pooling(features2).view(bs*self.seq_len//3, -1)
        features = features.reshape(bs, self.seq_len//3, -1)       # (B, seq_len/3, cnn_feature)
        features2 = features2.reshape(bs, self.seq_len//3, -1)
        features, _ = self.gru(features)                        # (B, seq_len/3, cnn_feature)
        features2, _ = self.gru2(features2)
        preds = []
        for att_layer, att_layer2, logit in zip(self.mlp_attention_layers,
                                                self.mlp_attention_layers2,
                                                self.logits):
            p = att_layer(features)
            p2 = att_layer2(features2)
            p = logit(torch.cat([p, p2], -1))
            preds.append(p)
        preds = torch.cat(preds, -1)
        if self.softmax:
            preds = torch.cat([preds[:, i:i+3].softmax(1) for i in range(0, 75, 3)], -1)
        return preds

In [16]:
# model = RSNAClassifier(CFG.model_arch, hidden_dim=128, seq_len=CFG.seq_len, pretrained=True)

In [17]:
def get_activation(activ_name: str="relu"):
    """Return a newly constructed activation module for ``activ_name``.

    Supported names are ``"relu"`` (in-place ReLU), ``"tanh"``, ``"sigmoid"``
    and ``"identity"``. Only the requested module is instantiated.

    Args:
        activ_name: key selecting the activation.

    Returns:
        A fresh ``nn.Module`` instance.

    Raises:
        NotImplementedError: if ``activ_name`` is not one of the supported
            names; the message lists the valid choices.
    """
    factories = {
        "relu": lambda: nn.ReLU(inplace=True),
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "identity": nn.Identity,
    }
    if activ_name not in factories:
        raise NotImplementedError(
            f"Unknown activation {activ_name!r}; expected one of {sorted(factories)}")
    return factories[activ_name]()
        

class Conv2dBNActiv(nn.Module):
    """Convolution block: Conv2d, then optional BatchNorm2d, then a named activation."""

    def __init__(
        self, in_channels, out_channels,
        kernel_size, stride, padding,
        bias=False, use_bn=True, activ="relu"
    ):
        """Assemble the block as one ``nn.Sequential`` stored in ``self.layers``.

        Args:
            in_channels, out_channels, kernel_size, stride, padding, bias:
                forwarded to ``nn.Conv2d``.
            use_bn: insert ``nn.BatchNorm2d(out_channels)`` after the convolution.
            activ: activation name resolved through ``get_activation``.
        """
        super().__init__()
        conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size, stride, padding, bias=bias)
        norm = [nn.BatchNorm2d(out_channels)] if use_bn else []
        # Sub-module order (conv, [bn], activation) fixes the state_dict keys.
        self.layers = nn.Sequential(conv, *norm, get_activation(activ))

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.layers(x)
        
    
class SpatialAttentionBlock(nn.Module):
    """Compute a single-channel sigmoid gate from a feature map and multiply it back in."""

    def __init__(
        self, in_channels,
        out_channels_list,
    ):
        """Register 3x3 ``Conv2dBNActiv`` layers as ``conv1`` .. ``convN``.

        Args:
            in_channels: channel count of the incoming feature map.
            out_channels_list: list of output channels per layer; must be
                non-empty and end with 1 so the gate has one channel.
        """
        super().__init__()
        self.n_layers = len(out_channels_list)
        channels_list = [in_channels] + out_channels_list
        assert self.n_layers > 0
        assert channels_list[-1] == 1

        # Every layer uses ReLU except the last, which uses a sigmoid.
        for depth in range(1, self.n_layers + 1):
            in_chs, out_chs = channels_list[depth - 1], channels_list[depth]
            activ = "sigmoid" if depth == self.n_layers else "relu"
            setattr(self, f"conv{depth}", Conv2dBNActiv(in_chs, out_chs, 3, 1, 1, activ=activ))

    def forward(self, x):
        """Return ``x`` scaled by the gate produced from ``x`` itself."""
        gate = x
        for depth in range(1, self.n_layers + 1):
            gate = getattr(self, f"conv{depth}")(gate)
        return gate * x



class MultiHeadResNet200D(nn.Module):
    """Shared timm backbone followed by several spatial-attention classification heads."""

    def __init__(self, out_dims_head=[3, 4, 3, 1],  pretrained=False):
        """Build the backbone and one head per entry of ``out_dims_head``.

        Args:
            out_dims_head: output width of each head (only read, never mutated).
            pretrained: when true, load backbone weights from the checkpoint at
                ``CFG.student`` (its ``"model"`` entry) before the classifier
                is removed.
                # NOTE(review): `student` is not among the CFG attributes
                # defined at the top of the notebook — confirm it is set
                # before using pretrained=True.
        """
        # Initialise nn.Module first, then attach plain attributes.
        super(MultiHeadResNet200D, self).__init__()
        self.base_name = "resnet200d_320"
        self.n_heads = len(out_dims_head)

        # Load base model; the classifier width only needs to match the checkpoint.
        base_model = timm.create_model(self.base_name, num_classes=sum(out_dims_head), pretrained=False)
        in_features = base_model.num_features

        if pretrained:
            pretrained_model_path = CFG.student
            prefix = "model."
            # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
            checkpoint = torch.load(pretrained_model_path, map_location='cpu')["model"]
            # Strip only a leading "model." — str.replace would also remove the
            # substring from the middle of a key (e.g. "model.submodel.w").
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in checkpoint.items()
            }
            base_model.load_state_dict(state_dict)

        # Remove global pooling and head classifier
        base_model.reset_classifier(0, '')

        # Shared CNN backbone
        self.backbone = base_model

        # Multi heads, registered as head_0 .. head_{n-1}
        for i, out_dim in enumerate(out_dims_head):
            layer_name = f"head_{i}"
            layer = nn.Sequential(
                SpatialAttentionBlock(in_features, [64, 32, 16, 1]),
                nn.AdaptiveAvgPool2d(output_size=1),
                nn.Flatten(start_dim=1),
                nn.Linear(in_features, in_features),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
                nn.Linear(in_features, out_dim))
            setattr(self, layer_name, layer)

    def forward(self, x):
        """Return ``(None, None, y)`` where ``y`` concatenates all head outputs on dim 1."""
        h = self.backbone(x)
        hs = [getattr(self, f"head_{i}")(h) for i in range(self.n_heads)]
        y = torch.cat(hs, dim=1)
        return None, None, y

In [18]:
def mixup(input, truth, clip=[0, 1]):
    """Blend each sample with a randomly chosen partner from the same batch.

    Args:
        input: batch tensor; dim 0 is the sample axis.
        truth: labels indexed the same way as ``input`` along dim 0.
        clip: two-element sequence giving the range the mixing weight is
            drawn from (uniformly, via ``np.random.uniform``).

    Returns:
        ``(mixed_input, truth, partner_labels, lam)`` where
        ``mixed_input = input * lam + partner_input * (1 - lam)``. Labels are
        NOT blended here; the caller combines the two losses with ``lam``.
    """
    perm = torch.randperm(input.size(0))
    partner_input, partner_labels = input[perm], truth[perm]

    lam = np.random.uniform(clip[0], clip[1])
    mixed_input = input * lam + partner_input * (1 - lam)
    return mixed_input, truth, partner_labels, lam



def train_one_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """Run one training pass over ``train_loader``.

    Each batch is mixed up with probability ``CFG.p_mixup``; when it is, the
    loss is the ``lam``-weighted blend of the loss against the original and
    the partner labels (on both the GPU and the TPU path). Gradients are
    accumulated over ``CFG.accum_iter`` steps before the optimizer steps.

    Args:
        train_loader: yields ``(study_id, images, labels)`` batches.
        model: network to train (put in train mode here).
        criterion: callable ``(predictions, labels) -> scalar loss``.
        optimizer: optimizer updated in place.
        epoch: epoch index, used only for logging.
        scheduler: accepted for interface compatibility; not used here.
        device: target device for the batch tensors.

    Returns:
        ``(average loss over the epoch, learning rate of the first param group)``.
    """
    if CFG.device == 'GPU':
        # autocast is disabled below, so the forward pass runs in full
        # precision; the scaler is still used for backward/step.
        scaler = GradScaler()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    for step, (study_id, images, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)
        batch_size = labels.size(0)

        do_mixup = False
        if random.random() < CFG.p_mixup:
            do_mixup = True
            images, labels, labels_mix, lam = mixup(images, labels)

        if CFG.device == 'GPU':
            # with autocast():
            y_preds = model(images)
            y_preds = y_preds.squeeze(1)
            loss = criterion(y_preds, labels)
            if do_mixup:
                loss11 = criterion(y_preds, labels_mix)
                loss = loss * lam  + loss11 * (1 - lam)
            # record loss (before the accumulation scaling)
            losses.update(loss.item(), batch_size)
            if CFG.accum_iter > 1:
                loss = loss / CFG.accum_iter
            scaler.scale(loss).backward()
            grad_norm = 0 # torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            if (step + 1) % CFG.accum_iter == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        elif CFG.device == 'TPU':
            y_preds = model(images)
            loss = criterion(y_preds, labels)
            if do_mixup:
                # The images were blended above, so the loss must be blended
                # with the partner labels too (same as the GPU path).
                loss11 = criterion(y_preds, labels_mix)
                loss = loss * lam  + loss11 * (1 - lam)
            # record loss (before the accumulation scaling)
            losses.update(loss.item(), batch_size)
            if CFG.accum_iter > 1:
                loss = loss / CFG.accum_iter
            loss.backward()
            grad_norm = 0 # torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            if (step + 1) % CFG.accum_iter == 0:
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            # grad_norm is a placeholder (always 0) while clipping is commented out.
            cusprint('Epoch: [{0}][{1}/{2}] '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Elapsed {remain:s} '
                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                'Grad: {grad_norm:.4f}  '
                'LR: {lr:.7f}  '
                .format(
                epoch, step, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses,
                remain=timeSince(start, float(step+1)/len(train_loader)),
                grad_norm=grad_norm,
                lr=optimizer.param_groups[0]["lr"],
                ))
        # Show the first batch's loss tensor as a quick sanity check.
        if(step == 0):print(loss)

    return losses.avg, optimizer.param_groups[0]["lr"]

In [19]:
# 

In [20]:
def tta(images, nbr):
    """Shift ``images`` by ``nbr`` positions along dim 1, padding the end with the last slice.

    The first ``nbr`` slices along dim 1 are dropped and ``nbr`` copies of the
    final slice are appended, so for ``0 <= nbr <= images.size(1)`` the output
    has the same shape as the input.

    Args:
        images: tensor with at least two dimensions.
        nbr: non-negative number of positions to shift.

    Returns:
        A new tensor (never a view of ``images``).

    Raises:
        ValueError: if ``nbr`` is negative.
    """
    if nbr < 0:
        raise ValueError(f"nbr must be non-negative, got {nbr}")
    shifted = images[:, nbr:, ...]
    if nbr == 0:
        # torch.cat rejects an empty list of tensors, so there is nothing to
        # pad with; return a copy to keep the "new tensor" contract.
        return shifted.clone()
    tail = torch.cat(nbr * [images[:, -1:, ...]], 1)
    return torch.cat([shifted, tail], 1)


def evaluate(y_preds, y_true, www = 1):
    y_pred = torch.from_numpy(y_preds.copy())
    y = torch.from_numpy(y_true.copy())
    # loss0 = criterion(y_pred, y)
    
    y_pred = y_pred.reshape(y_pred.shape[0], 25, 3).transpose(1, 2).softmax(dim=1)
    if www==2:
        y_pred[:, 1] *= 2
        y_pred[:, 2] *= 4
        y_pred /= y_pred.sum(1)[:, None]
    y_pred = (torch.log(y_pred)/(1-y_pred))
    y = y.reshape(y.shape[0], 25, 3).transpose(1, 2).argmax(1)

    n_sum = 0
    # loss0_sum = 0.0  # loss for the criterion
    
    # 4 losses for the evaluation metric
    loss4_sum = torch.zeros(4)
    w_sum = torch.zeros(4)
    slices = [slice(0, 5), slice(5, 15), slice(15, 25)]  # spinal, foraminal, subarticular

    w = 2 ** y  # sample_weight w = (1, 2, 4) for y = 0, 1, 2 (batch_size, 25)

    batch_size = y_pred.shape[0]

    n_sum += batch_size
    # loss0_sum += loss0.item() * batch_size

    # Compute score
    # - weighted loss for spinal, foraminal, subarticular
    # - binary cross entropy for maximum spinal severe
    ce_loss = F.cross_entropy(y_pred, y.long(), reduction='none')  # (batch_size, 25)
    for k, idx in enumerate(slices):
        w_sum[k] += w[:, idx].sum()
        loss4_sum[k] += (w[:, idx] * ce_loss[:, idx]).sum()
    # Spinal max
    y_spinal_prob = y_pred[:, :, :5].softmax(dim=1)            # (batch_size, 3,  5)
    w_max = torch.amax(w[:, :5], dim=1)                        # (batch_size, )
    y_max = torch.amax(y[:, :5] == 2, dim=1).to(torch.float)   # 0 or 1
    y_pred_max = y_spinal_prob[:, 2, :].amax(dim=1)            # max in severe (class=2)

    loss_max = F.binary_cross_entropy(y_pred_max, y_max, reduction='none')
    loss4_sum[3] += (w_max * loss_max).sum()
    w_sum[3] += w_max.sum()
    # Average over spinal, foraminal, subarticular, and any_severe_spinal
    score = (loss4_sum / w_sum).sum().item() / 4
    print('weight =', www, 'evaluation score =', score)


    return score



def valid_one_epoch(valid_loader, model, criterion, device, softmax = False):
    """Evaluate ``model`` on ``valid_loader`` and compute the validation metrics.

    Args:
        valid_loader: yields ``(study_id, images, labels)`` batches.
        model: network to evaluate (put in eval mode here).
        criterion: callable ``(predictions, labels) -> scalar loss``.
        device: target device for the batch tensors.
        softmax: truthy when the model output is already a probability; then
            ``nn.BCELoss`` is used and the competition metric is called with
            ``apply_sig=False``. Otherwise ``nn.BCEWithLogitsLoss`` and
            ``apply_sig=True`` are used.

    Returns:
        ``(avg_loss, predictions, trues, score, score2, score3, sc)`` where
        ``sc`` is the ``evaluate`` score computed with ``www=2``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()  # created as in the original; never updated
    # switch to evaluation mode
    model.eval()
    true_batches, pred_batches = [], []
    start = end = time.time()
    for step, (study_id, images, labels) in enumerate(valid_loader):
        # time spent waiting for the batch
        data_time.update(time.time() - end)
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)
        batch_size = labels.size(0)

        with torch.no_grad():
            y_preds = model(images)#(model(images) + model(tta(images, 1)))/2
            y_preds = y_preds.squeeze(1)
        loss = criterion(y_preds, labels)
        losses.update(loss.item(), batch_size)

        # keep raw outputs; no activation is applied before storing
        true_batches.append(labels.to('cpu').numpy())
        pred_batches.append(y_preds.to('cpu').numpy())

        batch_time.update(time.time() - end)
        end = time.time()
        n_batches = len(valid_loader)
        is_last = step == (n_batches - 1)
        if step % CFG.print_freq == 0 or is_last:
            cusprint('EVAL: [{0}/{1}] '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Elapsed {remain:s} '
                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                .format(
                step, n_batches, batch_time=batch_time,
                data_time=data_time, loss=losses,
                remain=timeSince(start, float(step+1)/n_batches),
                ))

    trues = np.concatenate(true_batches)
    predictions = np.concatenate(pred_batches)
    # The www=1 call is made for its printed score only; www=2 is what is returned.
    evaluate(predictions, trues, www = 1)
    sc = evaluate(predictions, trues, www = 2)
    print(f"trues.shape: {trues.shape}")
    print(f"predictions.shape: {predictions.shape}")

    apply_sig = not softmax
    bce = nn.BCEWithLogitsLoss() if apply_sig else nn.BCELoss()
    score = bce(torch.from_numpy(predictions).type(torch.float32),
                torch.from_numpy(trues).type(torch.float32))
    score2 = comp_metric_score(trues, predictions, apply_sig = apply_sig)
    score3 = comp_metric_score(trues, predictions, w1 = 2, w2 = 4, apply_sig = apply_sig)
    if apply_sig:
        # Extra diagnostics; the return values are discarded.
        _ = comp_metric_score2(trues, predictions, w1 = 1, w2 = 1, apply_sig = True)
        _ = comp_metric_score2(trues, predictions, w1 = 2, w2 = 4, apply_sig = True)

    return losses.avg, predictions, trues, score, score2, score3, sc

# loss & optimizer & scheduler

In [21]:
class GradualWarmupSchedulerV3(GradualWarmupScheduler):
    """Warm-up scheduler that overrides ``get_lr`` and then hands over to ``after_scheduler``."""

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        """Forward all arguments unchanged to ``GradualWarmupScheduler``."""
        super().__init__(optimizer, multiplier, total_epoch, after_scheduler)

    def get_lr(self):
        """Return the learning rates for the current ``last_epoch``.

        While ``last_epoch <= total_epoch`` the rate is interpolated linearly:
        from 0 to ``base_lr`` when ``multiplier == 1.0``, otherwise from
        ``base_lr`` to ``base_lr * multiplier``. Afterwards the wrapped
        scheduler's rates are returned (its ``base_lrs`` are rescaled once), or
        ``base_lr * multiplier`` if no wrapped scheduler was given.
        """
        if self.last_epoch <= self.total_epoch:
            # Still warming up.
            if self.multiplier == 1.0:
                return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

        if not self.after_scheduler:
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if not self.finished:
            # First step after warm-up: start the wrapped scheduler from the warmed-up rate.
            self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
            self.finished = True
        return self.after_scheduler.get_lr()

# Training

In [22]:
def train_loop(df, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2, fold):
    """Train and validate one cross-validation fold, checkpointing the best epoch.

    Rows whose ``fold`` column equals ``fold`` form the validation set; all
    other rows form the training set. For each epoch the model is trained
    (or, when ``epoch < CFG.min_epoch``, loaded from the saved checkpoint
    instead), validated, and the scheduler is stepped. Whenever the
    ``evaluate(..., www=2)`` score returned by ``valid_one_epoch`` improves,
    predictions, labels and the full model state dict are written to
    ``outputdir``.

    Args:
        df: dataframe with at least ``study_id`` and ``fold`` columns.
        ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2: per-series lookups
            passed straight to ``TrainDataset``; ``ser_dict_sag_t2`` keys are
            also used to filter ``df`` by ``study_id``.
        fold: index of the validation fold.

    Returns:
        ``(preds, trues)`` from the LAST epoch's validation pass (not
        necessarily the best epoch; the best ones are the saved ``.npy`` files).

    Notes:
        ``device`` and ``outputdir`` are not parameters; they are read from
        the enclosing (notebook-global) scope.
    """
    loginfo(f"========== fold: {fold} training ==========")
    # Keep only studies that have an entry in the sagittal-T2 dictionary.
    df = df[df['study_id'].isin(ser_dict_sag_t2.keys())]

    # ====================================================
    # loader
    # ====================================================
    train_folds = df[df['fold']!=fold].reset_index(drop=True)  # (debug: append [:700] to train on a subset)
    valid_folds = df[df['fold']==fold].reset_index(drop=True)
    
    softmax = False
    if CFG.loss_fn == "BCEWithLogitsLoss":
        criterion = nn.BCEWithLogitsLoss()
    elif CFG.loss_fn == 'Custom_loss':
        cust = Custom_loss()
        # NOTE(review): cust2 is only referenced by the commented-out blend below.
        cust2 = nn.BCEWithLogitsLoss()
        def criterion(y_pred, y):
            return cust(y_pred, y)#0.75*cust(y_pred, y) + 0.25*cust2(y_pred, y)
        # softmax = True
    elif(CFG.loss_fn == 'FocalLoss'):
        criterion = FocalLoss()
    # NOTE(review): any other CFG.loss_fn value leaves `criterion` undefined.


    train_dataset = TrainDataset(train_folds, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2, CFG.target_cols, 'train', None, CFG.seq_len, transform=get_transforms(data='light_train'))
    valid_dataset = TrainDataset(valid_folds, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2, CFG.target_cols, 'valid', None, CFG.seq_len, transform=get_transforms(data='valid'))
    if CFG.device == 'GPU':
        # sampler_train = torch.utils.data.RandomSampler(ds_train_len)
        train_loader = DataLoader(train_dataset, batch_size=CFG.train_bs, shuffle=True, num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
        valid_loader = DataLoader(valid_dataset, batch_size=CFG.valid_bs, shuffle=False, num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
    elif CFG.device == 'TPU':
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=CFG.train_bs, sampler=train_sampler, drop_last=True, num_workers=CFG.num_workers)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=False)
        valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=CFG.valid_bs, sampler=valid_sampler, drop_last=False, num_workers=CFG.num_workers)

    # ====================================================
    # model & optimizer & scheduler & loss
    # ====================================================
    # not checkpoint

    model = RSNAClassifier(CFG.model_arch, hidden_dim=256, seq_len=CFG.seq_len, pretrained=True, softmax = softmax, fold = fold)

        
    if CFG.gpu_parallel:    
        num_gpu = torch.cuda.device_count()
        model = DataParallel(model, device_ids=range(num_gpu))
    model.to(device)
    
    
    # optimizer
    # With a warm-up scheduler the optimizer starts at lr / warmup_factor and
    # the warm-up multiplies it back up to CFG.lr.
    # NOTE(review): an unrecognised CFG.optimizer leaves `optimizer` undefined.
    if CFG.optimizer == "AdamW":
        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            optimizer = AdamW(model.parameters(), lr=CFG.lr/CFG.warmup_factor, weight_decay=CFG.weight_decay) 
        else:
            optimizer = AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    if CFG.optimizer == "Adam":
        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            optimizer = Adam(model.parameters(), lr=CFG.lr/CFG.warmup_factor, weight_decay=CFG.weight_decay) 
        else:
            optimizer = Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    if CFG.optimizer == "SGD":
        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            optimizer = SGD(model.parameters(), lr=CFG.lr/CFG.warmup_factor, weight_decay=CFG.weight_decay) 
        else:
            optimizer = SGD(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    # scheduler
    # NOTE(review): CFG.factor / CFG.patience / CFG.eps are not among the CFG
    # attributes defined at the top of the notebook — confirm before selecting
    # 'ReduceLROnPlateau'.
    if CFG.scheduler=='ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
    elif CFG.scheduler=='CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    elif CFG.scheduler=='CosineAnnealingWarmRestarts':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)

    # The warm-up wrapper is always built as GradualWarmupSchedulerV3,
    # regardless of the value of CFG.scheduler_warmup.
    scheduler_warmup = GradualWarmupSchedulerV3(optimizer, multiplier=CFG.warmup_factor, total_epoch=CFG.warmup_epo, after_scheduler=scheduler)



    # ====================================================
    # loop
    # ====================================================
    # NOTE(review): the five trackers below are initialised but never read or
    # updated in this function.
    valid_acc_max=0; valid_loss_min=float("inf")
    valid_acc_max_cnt=0; valid_loss_min_cnt=0;
    best_acc_epoch=0;


    # Lower is better; any real score replaces this sentinel.
    best_score = 10000

    shift_epoch = CFG.shift_epoch#(3*CFG.epochs)//4

    for epoch in range(CFG.epochs):
        if epoch == shift_epoch:
            # Switch to a softmax-output copy of the model and a fresh
            # Custom_loss, carrying over weights and optimizer state.
            # NOTE(review): `scheduler` / `scheduler_warmup` were constructed
            # with the previous optimizer object; presumably they keep driving
            # that one after this swap — verify.
            # NOTE(review): `fold` is not passed here, unlike the first
            # RSNAClassifier construction above.
            model_ = RSNAClassifier(CFG.model_arch, hidden_dim=256, seq_len=CFG.seq_len, pretrained=True, softmax = True)
            model_.to(device)
            model_.load_state_dict(model.state_dict())
            model = model_
            criterion = Custom_loss()
            softmax = True
            optimizer_ = AdamW(model.parameters(), lr=CFG.lr/CFG.warmup_factor, weight_decay=CFG.weight_decay) 
            optimizer_.load_state_dict(optimizer.state_dict())
            optimizer = optimizer_

        loginfo(f"***** Epoch {epoch} *****")

        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            loginfo(f"schwarmup_last_epoch:{scheduler_warmup.last_epoch}, schwarmup_lr:{scheduler_warmup.get_last_lr()[0]}")
        if CFG.scheduler=='CosineAnnealingLR':
            loginfo(f"scheduler_last_epoch:{scheduler.last_epoch}, scheduler_lr:{scheduler.get_last_lr()[0]}")
        loginfo(f"optimizer_lr:{optimizer.param_groups[0]['lr']}")


        start_time = time.time()
        if(epoch>=CFG.min_epoch):
            avg_loss, cur_lr = train_one_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device) # train
        else:
            # Training is skipped for this epoch: report -1 placeholders and
            # overlay the previously saved checkpoint onto the current weights.
            avg_loss, cur_lr = -1, -1
            # model.load_state_dict(torch.load(outputdir+f'/{CFG.model_arch}_{CFG.suffix}_fold{fold}.pth')['model'])
            ws = torch.load(outputdir+f'/{CFG.model_arch}_{CFG.suffix}_fold{fold}.pth')['model']
            ws2 = model.state_dict()
            ws2.update(ws)
            model.load_state_dict(ws2)
        avg_val_loss, preds, trues, score, score2, score3, score4 = valid_one_epoch(valid_loader, model, criterion, device, softmax = softmax) # valid

        # scoring
        elapsed = time.time() - start_time 

        loginfo(f'Epoch {epoch} - avg_train_loss: {avg_loss:.4f}  time: {elapsed:.0f}s')
        loginfo(f'Epoch {epoch} - avg_val_loss: {avg_val_loss:.4f} - avg_val_bce: {score:.4f} - avg_val_comp_metric : {score2:.4f} - w_avg_val_comp_metric : {score3:.4f}')

        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            scheduler_warmup.step()
        elif CFG.scheduler == "ReduceLROnPlateau":
            scheduler.step(avg_val_loss)
        elif CFG.scheduler in ["CosineAnnealingLR", "CosineAnnealingWarmRestarts"]:
            scheduler.step()


        # score4 is the value valid_one_epoch returns from evaluate(..., www=2).
        if best_score>score4:
            best_score = score4
            # NOTE(review): `ws` is filtered down to the 'logits.' entries but
            # is never saved; torch.save below writes the full state dict.
            ws = model.state_dict()
            keys = list(ws.keys())
            for k in keys:
                if not k.startswith('logits.'):
                    ws.pop(k)   

            # NOTE(review): the preds file name has a double underscore before
            # 'fold', unlike the trues and .pth files — confirm downstream readers.
            np.save(outputdir+f'/{CFG.model_arch}_{CFG.suffix}__fold{fold}_preds.npy', preds)
            np.save(outputdir+f'/{CFG.model_arch}_{CFG.suffix}_fold{fold}_trues.npy', trues)
            torch.save({'model': model.state_dict()}, outputdir+f'/{CFG.model_arch}_{CFG.suffix}_fold{fold}.pth')

    return preds, trues

In [23]:
def main():
    """Run ``train_loop`` for every fold in ``CFG.fold_list``.

    Returns:
        A list with one ``[preds, trues]`` pair per fold, in the order of
        ``CFG.fold_list``.
    """
    train_df, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2 = get_df()
    oof_list = []
    for fold in CFG.fold_list:
        preds, trues = train_loop(train_df, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2, fold)
        oof_list.append([preds, trues])
    return oof_list

# Main

In [26]:
# Run configuration for this cell (overrides the defaults set on CFG above).
CFG.epochs = 12
CFG.min_epoch = -1

# Evaluation-only variant: with min_epoch greater than every epoch index,
# train_loop skips training and validates the saved checkpoint instead.
# CFG.epochs = 1
# CFG.min_epoch = 1

CFG.plane = 'sag-t2'
CFG.fold_list = [0, 1, 2, 3]
outputdir = 'output_all'
seed_everything(CFG.seed)
# Create the output directory if needed; a no-op when it already exists.
os.makedirs(outputdir, exist_ok=True)

if __name__ == '__main__':
    print(CFG.suffix)
    if CFG.device == 'TPU':
        def _mp_fn(rank, flags):
            """Per-process TPU entry point; the result of main() is not collected."""
            torch.set_default_tensor_type('torch.FloatTensor')
            main()
        FLAGS = {}
        xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=CFG.nprocs, start_method='fork')
    elif CFG.device == 'GPU':
        oof_list = main()

440
1975 1973 1974 1975


***** Epoch 0 *****
schwarmup_last_epoch:0, schwarmup_lr:4.000000000000001e-06
scheduler_last_epoch:0, scheduler_lr:4.000000000000001e-06
optimizer_lr:4.000000000000001e-06


EVAL: [0/124] Data 2.650 (2.650) Elapsed 0m 4s (remain 9m 29s) Loss: 0.3225(0.3225) 
EVAL: [100/124] Data 0.000 (0.027) Elapsed 0m 27s (remain 0m 6s) Loss: 0.1677(0.2664) 


Epoch 0 - avg_train_loss: -1.0000  time: 34s
Epoch 0 - avg_val_loss: 0.2659 - avg_val_bce: 0.2866 - avg_val_comp_metric : 0.5717 - w_avg_val_comp_metric : 0.6060


EVAL: [123/124] Data 0.000 (0.022) Elapsed 0m 33s (remain 0m 0s) Loss: 0.4243(0.2659) 
weight = 1 evaluation score = 0.4868314862251282
weight = 2 evaluation score = 0.4496035873889923
trues.shape: (494, 75)
predictions.shape: (494, 75)
0.37545499555141904 0.6329692830344926 0.6920404889397052 0.2238256192243744
w1 = 1 w2 = 1 0.4810725966874978
0.42799708826052957 0.677373465639109 0.7139865400216576 0.2715328903303016
w1 = 2 w2 = 4 0.5227224960628994


***** Epoch 0 *****
schwarmup_last_epoch:0, schwarmup_lr:4.000000000000001e-06
scheduler_last_epoch:0, scheduler_lr:4.000000000000001e-06
optimizer_lr:4.000000000000001e-06


EVAL: [0/124] Data 1.780 (1.780) Elapsed 0m 2s (remain 4m 8s) Loss: 0.2583(0.2583) 
EVAL: [100/124] Data 0.000 (0.046) Elapsed 0m 28s (remain 0m 6s) Loss: 0.3685(0.2667) 


Epoch 0 - avg_train_loss: -1.0000  time: 35s
Epoch 0 - avg_val_loss: 0.2672 - avg_val_bce: 0.2846 - avg_val_comp_metric : 0.5752 - w_avg_val_comp_metric : 0.6149


EVAL: [123/124] Data 0.000 (0.054) Elapsed 0m 35s (remain 0m 0s) Loss: 0.1548(0.2672) 
weight = 1 evaluation score = 0.489918977022171
weight = 2 evaluation score = 0.4580690562725067
trues.shape: (494, 75)
predictions.shape: (494, 75)
0.3949338888394965 0.6322290339972138 0.6847042906480152 0.23156905676606712
w1 = 1 w2 = 1 0.48585906756269814
0.4504792091523381 0.6850357642090823 0.6970852225038795 0.3118136058382566
w1 = 2 w2 = 4 0.5361034504258891


***** Epoch 0 *****
schwarmup_last_epoch:0, schwarmup_lr:4.000000000000001e-06
scheduler_last_epoch:0, scheduler_lr:4.000000000000001e-06
optimizer_lr:4.000000000000001e-06


EVAL: [0/124] Data 1.711 (1.711) Elapsed 0m 1s (remain 3m 59s) Loss: 0.2738(0.2738) 
EVAL: [100/124] Data 0.000 (0.017) Elapsed 0m 25s (remain 0m 5s) Loss: 0.2395(0.2772) 


Epoch 0 - avg_train_loss: -1.0000  time: 31s
Epoch 0 - avg_val_loss: 0.2714 - avg_val_bce: 0.2937 - avg_val_comp_metric : 0.5786 - w_avg_val_comp_metric : 0.6153


EVAL: [123/124] Data 0.000 (0.014) Elapsed 0m 30s (remain 0m 0s) Loss: 0.0173(0.2714) 
weight = 1 evaluation score = 0.48960524797439575
weight = 2 evaluation score = 0.45944106578826904
trues.shape: (493, 75)
predictions.shape: (493, 75)
0.38750533387626196 0.6462525656347464 0.6952088928470364 0.2363541466187582
w1 = 1 w2 = 1 0.49133023474420073
0.4410903447950007 0.6991291426936865 0.7036998545783477 0.3146752227070914
w1 = 2 w2 = 4 0.5396486411935316


***** Epoch 0 *****
schwarmup_last_epoch:0, schwarmup_lr:4.000000000000001e-06
scheduler_last_epoch:0, scheduler_lr:4.000000000000001e-06
optimizer_lr:4.000000000000001e-06


EVAL: [0/124] Data 1.296 (1.296) Elapsed 0m 1s (remain 3m 9s) Loss: 0.2563(0.2563) 
EVAL: [100/124] Data 0.000 (0.014) Elapsed 0m 24s (remain 0m 5s) Loss: 0.1479(0.2671) 


Epoch 0 - avg_train_loss: -1.0000  time: 30s
Epoch 0 - avg_val_loss: 0.2667 - avg_val_bce: 0.2928 - avg_val_comp_metric : 0.5673 - w_avg_val_comp_metric : 0.6155


EVAL: [123/124] Data 0.000 (0.011) Elapsed 0m 30s (remain 0m 0s) Loss: 0.2137(0.2667) 
weight = 1 evaluation score = 0.4696274399757385
weight = 2 evaluation score = 0.44606131315231323
trues.shape: (493, 75)
predictions.shape: (493, 75)
0.38481526732235677 0.6307844676861583 0.6958750236136952 0.20944457217865944
w1 = 1 w2 = 1 0.48022983270021746
0.4630939858948766 0.7022556809934897 0.6936760374492577 0.29453082265925085
w1 = 2 w2 = 4 0.5383891317492187


In [27]:
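# Pasted log lines kept for reference (www=2 evaluation scores, apparently one per fold).
# They are commented out because running them as code raises SyntaxError.
# weight = 2 evaluation score = 0.44960105419158936
# weight = 2 evaluation score = 0.45807141065597534
# weight = 2 evaluation score = 0.4594390392303467
# weight = 2 evaluation score = 0.4460652470588684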

SyntaxError: invalid syntax (2261054437.py, line 1)