In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

datadir = './preprocessed'
libdir = '.'
outputdir = './ax_t2_soft_90_256'
otherdir = '.'
train_bs_ = 4
valid_bs_ = 4
num_workers_ = 4

In [2]:
# [i for i in timm.list_models() if 'resnest' in i]

In [3]:
conditions = ['spinal_canal_stenosis', 'left_neural_foraminal_narrowing', 'right_neural_foraminal_narrowing',
          'left_subarticular_stenosis', 'right_subarticular_stenosis']
levels = ['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']
severity = ['Normal/Mild', 'Moderate', 'Severe']
class CFG:
    seed=42
    device='GPU'
    nprocs=1 # [1, 8]
    num_workers=num_workers_
    train_bs=train_bs_
    valid_bs=valid_bs_
    data_dir = './'
    INP_DIR = 'rsna/'
    target_cols=[c+'_'+l+'_'+s for c in conditions for l in levels for s in severity]
    imgs_dir = ['./input/']
    num_classes=75

    accum_iter=1#2
    max_grad_norm=1000
    print_freq=100
    normalize_mean=[0.4824, 0.4824, 0.4824] # [0.485, 0.456, 0.406] [0.4824, 0.4824, 0.4824]
    normalize_std=[0.22, 0.22, 0.22] # [0.229, 0.224, 0.225] [0.22, 0.22, 0.22]

    suffix="440"
    fold_num=4
    fold_list=[0, 1, 2, 3]
    min_epoch = -1
    epochs = 40
    shift_epoch = 1000
    model_arch="resnest50d" # tf_efficientnetv2_s, resnest50d, resnext50_32x4d, resnet200d, convnext_tiny_384_in22ft1k
    optimizer="AdamW" # Adam, SGD, AdamW
    scheduler="CosineAnnealingLR"#'ReduceLROnPlateau', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts'
    loss_fn= "Custom_loss"#'Custom_loss', "BCEWithLogitsLoss", "FocalLoss"
    scheduler_warmup="GradualWarmupSchedulerV3" 

    warmup_epo=1
    warmup_factor = 10
    T_0 = 1
    T_max= epochs-warmup_epo-2 
    
    seq_len = 90
    img_size = 256
    p_mixup = 0.5

    lr=4e-5
    min_lr=1e-7
    # lr=23e-5
    # min_lr=23e-6
    weight_decay=0.02
    dropout=0.1

    gpu_parallel=False
    n_early_stopping=8
    debug=False
    multihead=False
    plane = 'ax_t2' # 'sag_t1', 'sag_t2'

In [4]:
# !ls

In [5]:
# [i for i in timm.list_models(pretrained=True) if 'resnest50d' in i]

In [6]:
# !pip install pandas -q
# !pip install scikit-learn -q
# !pip install warmup-scheduler==0.3 -q
# !pip install timm==0.9.7 -q
# !pip install albumentations==1.3.1
# !pip install opencv-python -q
# !pip install segmentation_models_pytorch
# !apt-get update && apt-get install ffmpeg libsm6 libxext6  -y

# Import

In [7]:
import sys; 

package_paths = [f'{libdir}pytorch-image-models-master']
for pth in package_paths:
    sys.path.append(pth)

import ast
from glob import glob
import os
from datetime import datetime
import time
import random
from tqdm import tqdm
from contextlib import contextmanager
import math
import cv2

import numpy as np
import pandas as pd
import sklearn
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedGroupKFold
import torch
import torchvision
from torchvision import transforms
from torchvision.transforms import v2

from torch import nn
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F

from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from warmup_scheduler import GradualWarmupScheduler
import timm
import warnings
import joblib
from scipy.ndimage.interpolation import zoom

import gc 
from torch.nn import DataParallel

import segmentation_models_pytorch as smp


import albumentations as A
from albumentations.pytorch import ToTensorV2



if CFG.device == 'TPU':
    !pip install -q pytorch-ignite
    import ignite.distributed as idist
elif CFG.device == 'GPU':
    from torch.cuda.amp import autocast, GradScaler


import warnings
warnings.filterwarnings('ignore')

  from scipy.ndimage.interpolation import zoom


# Loss & Metric

In [8]:
def sig(x):
    return 1/(1+np.exp(-x))

class FocalLoss(nn.Module):
    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma
        
    def forward(self, input, target):
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))

        max_val = (-input).clamp(min=0)
        loss = input - input * target + max_val + \
            ((-max_val).exp() + (-input - max_val).exp()).log()

        invprobs = F.logsigmoid(-input * (target * 2.0 - 1.0))
        loss = (invprobs * self.gamma).exp() * loss
        
        return loss.sum(dim=1).mean()


def torch_sigmoid(x): return (1 + (-x).exp()).reciprocal()


class Custom_loss(nn.Module):


    def __init__(self, temperature=0.0):
        """
        Use max if temperature = 0
        """
        super().__init__()
        self.t = temperature
        assert self.t >= 0
    def __repr__(self):
        return 'SevereLoss(t=%.1f)' % self.t

    def forward(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        y_pred = y_pred.reshape(y_pred.shape[0], 25, 3).transpose(1, 2)

        y = y.reshape(y.shape[0], 25, 3).transpose(1, 2).argmax(1)
        """
        Args:
          y_pred (Tensor[float]): logit             (batch_size, 3, 25)
          y      (Tensor[int]):   true label index  (batch_size, 25)
        """
        assert y_pred.size(0) == y.size(0)
        assert y_pred.size(1) == 3 and y_pred.size(2) == 25
        assert y.size(1) == 25
        assert y.size(0) > 0
        slices = [slice(0, 5), slice(5, 15), slice(15, 25)] 

        loss = F.cross_entropy(y_pred, y.long(), reduction='none')  # (batch_size, 25)

        wloss_sums = []
        for k, idx in enumerate(slices):
            wloss_sums.append((loss[:, idx]).sum())
        y_spinal_prob = y_pred[:, :, :5].softmax(dim=1)             # (batch_size, 3,  5)
        y_max = torch.amax(y[:, :5] == 2, dim=1).to(y_pred.dtype)   # 0 or 1
        if self.t > 0:
            # Attention for the maximum value
            attn = F.softmax(y_spinal_prob[:, 2, :] / self.t, dim=1)  # (batch_size, 5)

            # Pick the sofmax among 5 severe=2 y_spinal_probs with attn
            y_pred_max = (attn * y_spinal_prob[:, 2, :]).sum(dim=1)   # weighted average among 5 spinal columns 
        else:
            # Exact max; this works too
            y_pred_max = y_spinal_prob[:, 2, :].amax(dim=1)
        loss_max = F.binary_cross_entropy(y_pred_max, y_max, reduction='none')
        wloss_sums.append((loss_max).sum())

        # See below about these numbers
        loss = (wloss_sums[0] / 6.084050632911392 +
                wloss_sums[1] / 12.962531645569621 + 
                wloss_sums[2] / 14.38632911392405# +
                # wloss_sums[3] / 1.729113924050633
               ) / (3 * y.size(0))

        return loss


def comp_metric_score(y_true, y_pred, w1 = 1, w2 = 1, apply_sig = True):
    y_pred = sig(y_pred)
    y_pred[:, 1::3] *= w1
    y_pred[:, 2::3] *= w2
    for i in range(0, 75, 3):
        y_pred[:, i:i+3] /= y_pred[:, i:i+3].sum(1)[:, None]
    y_pred = torch.from_numpy(y_pred)
    y_pred = (torch.log(y_pred)/(1-y_pred))


    y = torch.from_numpy(y_true.copy())
    
    y_pred = y_pred.reshape(y_pred.shape[0], 25, 3).transpose(1, 2)
    y = y.reshape(y.shape[0], 25, 3).transpose(1, 2).argmax(1)
    # 4 losses for the evaluation metric
    loss4_sum = torch.zeros(4)
    w_sum = torch.zeros(4)
    slices = [slice(0, 5), slice(5, 15), slice(15, 25)]  # spinal, foraminal, subarticular

    w = 2 ** y  # sample_weight w = (1, 2, 4) for y = 0, 1, 2 (batch_size, 25)

    batch_size = y_pred.shape[0]

    ce_loss = F.cross_entropy(y_pred, y.long(), reduction='none')  # (batch_size, 25)
    for k, idx in enumerate(slices):
        w_sum[k] += w[:, idx].sum()
        loss4_sum[k] += (w[:, idx] * ce_loss[:, idx]).sum()
    # Spinal max
    y_spinal_prob = y_pred[:, :, :5].softmax(dim=1)            # (batch_size, 3,  5)
    w_max = torch.amax(w[:, :5], dim=1)                        # (batch_size, )
    y_max = torch.amax(y[:, :5] == 2, dim=1).to(torch.float)   # 0 or 1
    y_pred_max = y_spinal_prob[:, 2, :].amax(dim=1)            # max in severe (class=2)
    loss_max = F.binary_cross_entropy(y_pred_max, y_max, reduction='none')
    loss4_sum[3] += (w_max * loss_max).sum()
    w_sum[3] += w_max.sum()
    # Average over spinal, foraminal, subarticular, and any_severe_spinal
    score = (loss4_sum / w_sum).sum().item() / 4
    return score

class Custom_loss2(nn.Module):

    def forward(self, p, t):
        p = p.sigmoid().reshape(-1, 25, 3)
        p[..., 1] *= 2
        p[..., 2] *= 4
        p /= p.sum(2)[:,:,None]
        p = p.reshape(-1, 75)
        return F.binary_cross_entropy(p, t)

def comp_metric_score2(y_true, y_pred, w1 = 1, w2 = 1, apply_sig = True):
    t = y_true.copy()
    p = y_pred.copy()
    if apply_sig:
        p = sig(p)
    p[:, 1::3] *= w1
    p[:, 2::3] *= w2
    for i in range(0, 75, 3):
        p[:,i:i+3] /= p[:,i:i+3].sum(1)[:, None]
    t_sp, p_sp = t[:, :15].reshape(-1, 3), p[:, :15].reshape(-1, 3)
    t_ne, p_ne = t[:, 15:45].reshape(-1, 3), p[:, 15:45].reshape(-1, 3)
    t_su, p_su = t[:, 45:].reshape(-1, 3), p[:, 45:].reshape(-1, 3)
    t_sp_sev, p_sp_sev = t[:, :15][:, 2::3].max(1), p[:, :15][:, 2::3].max(1)
    loss_sp = log_loss(t_sp, p_sp, sample_weight = t_sp[:,0] + 2*t_sp[:,1]+4*t_sp[:, 2])
    loss_ne = log_loss(t_ne, p_ne, sample_weight = t_ne[:,0] + 2*t_ne[:,1]+4*t_ne[:, 2])
    loss_su = log_loss(t_su, p_su, sample_weight = t_su[:,0] + 2*t_su[:,1]+4*t_su[:, 2])
    loss_sp_sev = log_loss(t_sp_sev, p_sp_sev)
    print(loss_sp, loss_ne, loss_su, loss_sp_sev)
    loss = (loss_sp + loss_ne + loss_su + loss_sp_sev)/4
    print('w1 =', w1, 'w2 =', w2, loss)
    return loss

# helper

In [10]:
ENV = 'kaggle'
if CFG.device == 'TPU':
    import os
    VERSION = "1.7"
    CP_V = "36" if ENV == "colab" else "37"
    wheel = f"torch_xla-{VERSION}-cp{CP_V}-cp{CP_V}m-linux_x86_64.whl"
    url = f"https://storage.googleapis.com/tpu-pytorch/wheels/{wheel}"
    !pip3 -q install cloud-tpu-client==0.10 $url
    os.system('export XLA_USE_BF16=1')
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp
    CFG.lr = CFG.lr * CFG.nprocs
    CFG.train_bs = CFG.train_bs // CFG.nprocs
    device = xm.xla_device()
    
elif CFG.device == "GPU":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True 

seed_everything(CFG.seed)


def get_score(y_true, y_pred):
    scores = []
    for i in range(y_true.shape[1]):
        score = roc_auc_score(y_true[:,i], y_pred[:,i])
        scores.append(score)
    avg_score = np.mean(scores)
    return avg_score, scores


@contextmanager
def timer(name):
    t0 = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    LOGGER.info(f'[{name}] done in {time.time() - t0:.0f} s.')


def init_logger(log_file=outputdir+'stage2_train.log'):
    from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

if not os.path.isdir(outputdir):
    os.mkdir(outputdir)
LOGGER = init_logger(outputdir+f'/stage2_train{CFG.suffix}.log')

if CFG.device=='TPU' and CFG.nprocs==8:
    loginfo = xm.master_print
    cusprint = xm.master_print
else:
    loginfo = LOGGER.info
    cusprint = print



def get_timediff(time1,time2):
    minute_,second_ = divmod(time2-time1,60)
    return f"{int(minute_):02d}:{int(second_):02d}"  


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))


def get_img(path):
    im_bgr = cv2.imread(path)
    im_rgb = im_bgr[:, :, ::-1]
    return im_rgb

def load_dicom(path):
    """
    This supports loading both regular and compressed JPEG images. 
    See the first sell with `pip install` commands for the necessary dependencies
    """
    img = dicom.dcmread(path)
    img.PhotometricInterpretation = 'YBR_FULL'
    data = img.pixel_array
    data = data - np.min(data)
    if np.max(data) != 0:
        data = data / np.max(data)
    # data = (data * 255).astype(np.uint8)
    return data

# DataSet

In [12]:
class TrainDataset(Dataset):
    def __init__(self, csv, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2, targets, mode, meta_features, in_chans, transform=None):

        self.csv = csv.reset_index(drop=True)
        self.ser_dict_sag_t1 = ser_dict_sag_t1
        self.ser_dict_sag_t2 = ser_dict_sag_t2
        self.ser_dict_ax_t2 = ser_dict_ax_t2
        self.targets = targets
        self.mode = mode
        self.use_meta = meta_features is not None
        self.meta_features = meta_features
        self.transform = transform
        self.in_chans = in_chans
        
        self.aug_tr = v2.RandomRotation(
                            degrees = (-45, 45),
                            interpolation = torchvision.transforms.InterpolationMode.BILINEAR,
                            expand = False,
                            center = None,
                            fill = 0
                        )
        self.aug_sj = v2.ScaleJitter(
                        target_size = [CFG.img_size, CFG.img_size],
                        scale_range = (0.8, 1.2),
                        interpolation = torchvision.transforms.InterpolationMode.BILINEAR,
                        antialias = True)
        self.aug_rc = v2.RandomCrop(
                            size = [CFG.img_size, CFG.img_size],
                            pad_if_needed = True,
                            fill = 0,
                            padding_mode = 'constant'
                        )
        self.aug_hf = v2.RandomHorizontalFlip(0.5)
        self.aug_vf = v2.RandomVerticalFlip(0.5)

    def __len__(self):
        return self.csv.shape[0]

    def __getitem__(self, index):

        row = self.csv.iloc[index]
        
        plane = CFG.plane
        

        study_id = int(row.study_id)
                
        if plane in ['sag_t1', 'sag_t2'] and study_id not in self.ser_dict_sag_t1.keys():
            if self.mode == 'train':
                serie_id = np.random.choice(self.ser_dict_sag_t2[study_id])
            else:
                serie_id = self.ser_dict_sag_t2[study_id][-1]
        elif plane in ['sag_t1', 'sag_t2'] and study_id not in self.ser_dict_sag_t2.keys():
            if self.mode == 'train':
                serie_id = np.random.choice(self.ser_dict_sag_t1[study_id])
            else:
                serie_id = self.ser_dict_sag_t1[study_id][-1]
        else:
            if plane == 'sag_t1':
                if self.mode == 'train':
                    serie_id = np.random.choice(self.ser_dict_sag_t1[study_id])
                else:
                    serie_id = self.ser_dict_sag_t1[study_id][-1]
            elif plane == 'sag_t2':
                if self.mode == 'train':
                    serie_id = np.random.choice(self.ser_dict_sag_t2[study_id])
                else:
                    serie_id = self.ser_dict_sag_t2[study_id][-1]
            else:
                if self.mode == 'train':
                    serie_id = np.random.choice(self.ser_dict_ax_t2[study_id])
                else:
                    serie_id = self.ser_dict_ax_t2[study_id][-1]
        if self.mode == 'train':
            serie_id_ax_t2 = np.random.choice(self.ser_dict_ax_t2[study_id])
        else:
            serie_id_ax_t2 = self.ser_dict_ax_t2[study_id][-1]
        fp = f'/dev/shm/preprocessed/{study_id}_{serie_id}.npy'
        image = np.load(fp)


        image = torch.from_numpy(image).float()/255
        image = F.interpolate(
                 image.unsqueeze(0).unsqueeze(0),
                 size=[CFG.seq_len, CFG.img_size, CFG.img_size],
                 mode='trilinear'
             ).squeeze(0).squeeze(0)

        if (self.mode == 'train') and np.random.rand()<0.5:
            image = self.aug_tr(image)
            image = self.aug_sj(image)
            image = self.aug_rc(image)

        l_img = image.shape[0]
        if (self.mode == 'train') and np.random.rand()<0.5:
            inds = np.random.choice(np.arange(l_img), CFG.seq_len)
            inds.sort()
            image = image[inds]

        image = image.numpy()
        tar = torch.tensor(row[self.targets]).float()
        if self.mode == 'train' and np.random.uniform()<=0.5:
            tar = torch.cat([tar[:15], tar[30:45], tar[15:30], tar[60:75], tar[45:60]])
            image = image[:, ::-1]


        # transform
        if self.transform:
            
            image = np.transpose(image, (1, 2, 0))
            augmented = self.transform(image=image)
            image = augmented['image']
            image = np.transpose(image, (2, 0, 1))
            image = torch.from_numpy(image)


        if self.mode == 'test':
            return study_id, image
        else:
            return study_id, image, tar



def get_transforms(*, data):
    if data == 'train':
        return Compose([
    albumentations.Resize(CFG.img_size, CFG.img_size),
    albumentations.Perspective(p=0.5),
    albumentations.HorizontalFlip(p=0.5),
    albumentations.VerticalFlip(p=0.5),
    albumentations.Transpose(p=0.5),
    albumentations.RandomBrightness(limit=0.1, p=0.7),
    albumentations.ShiftScaleRotate(shift_limit=0.3, scale_limit=0.3, rotate_limit=45, border_mode=4, p=0.7),

    albumentations.OneOf([
        albumentations.MotionBlur(blur_limit=3),
        albumentations.MedianBlur(blur_limit=3),
        albumentations.GaussianBlur(blur_limit=3),
        albumentations.GaussNoise(var_limit=(3.0, 9.0)),
    ], p=0.5),
    albumentations.OneOf([
        albumentations.OpticalDistortion(distort_limit=1.),
        albumentations.GridDistortion(num_steps=5, distort_limit=1.),
    ], p=0.5),

    albumentations.Cutout(max_h_size=int(CFG.img_size * 0.5), max_w_size=int(CFG.img_size * 0.5), num_holes=1, p=0.5),
            ])
    elif data == 'light_train':
        return A.Compose([
            A.Resize(CFG.img_size, CFG.img_size),
            A.Perspective(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.3, scale_limit=0.3, rotate_limit=45, p=0.75),
            A.OneOf(
                [
                    A.RandomGamma(p=1),
                    A.RandomBrightnessContrast(contrast_limit=0.2, brightness_limit=0.0, p=1),
                    A.RandomBrightnessContrast(contrast_limit=0.0, brightness_limit=0.2, p=1),
                ],
                p=0.5,
            ),
            A.OneOf([
                A.MotionBlur(p = 1),
                A.GaussianBlur(p = 1),
                A.GaussNoise(p = 1),
                ], p=0.5),
            A.GridDistortion(num_steps=5, distort_limit=0.05, p=0.5),
            ], p=1.0)
    elif data == 'valid':
        return A.Compose([
            A.Resize(CFG.img_size, CFG.img_size),

        ], p=1.0)



def get_df():

    df = pd.read_csv('train.csv')

    for c in conditions:
        for l in levels:
            for s in severity:
                df[c + '_' + l + '_' + s] =  (df[c + '_' + l] == s).astype(int)
    df = df[['study_id']+CFG.target_cols]
    tr_ser_desc = pd.read_csv('train_series_descriptions.csv')
    ser_dict_sag_t1 = tr_ser_desc[tr_ser_desc['series_description']=='Sagittal T1'][['study_id', 'series_id']].groupby('study_id').apply(lambda df:df.series_id.tolist()).to_dict()
    ser_dict_sag_t2 = tr_ser_desc[tr_ser_desc['series_description']=='Sagittal T2/STIR'][['study_id', 'series_id']].groupby('study_id').apply(lambda df:df.series_id.tolist()).to_dict()
    ser_dict_ax_t2 = tr_ser_desc[tr_ser_desc['series_description']=='Axial T2'][['study_id', 'series_id']].groupby('study_id').apply(lambda df:df.series_id.tolist()).to_dict()
    print(len(df), len(ser_dict_sag_t1), len(ser_dict_sag_t2), len(ser_dict_ax_t2))

    skf = StratifiedGroupKFold(n_splits = CFG.fold_num)
    df['fold'] = -1
    for i, (train_inds, val_inds) in enumerate(skf.split(df, df[CFG.target_cols[-1]], df['study_id'])):
        df.loc[val_inds, 'fold'] = i

    return df, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2

In [13]:
# from pylab import rcParams
# dataset_show = TrainDataset(
#     train_df,
#     transform=get_transforms(data='light_train') # None, get_transforms(data='check')
#     )
# rcParams['figure.figsize'] = 30,20
# for i in range(2):
#     f, axarr = plt.subplots(1,5)
#     idx = np.random.randint(0, len(dataset_show))
#     img, label= dataset_show[idx]
#     # axarr[p].imshow(img) # transform=None
#     axarr[0].imshow(img[0]); plt.axis('OFF');
#     axarr[1].imshow(img[1]); plt.axis('OFF');
#     axarr[2].imshow(img[2]); plt.axis('OFF');
#     axarr[3].imshow(img[3]); plt.axis('OFF');
#     axarr[4].imshow(img[4]); plt.axis('OFF');

# Model

In [14]:
import torch.nn as nn
from itertools import repeat

class SpatialDropout(nn.Module):
    def __init__(self, drop=0.5):
        super(SpatialDropout, self).__init__()
        self.drop = drop
        
    def forward(self, inputs, noise_shape=None):
        """
        @param: inputs, tensor
        @param: noise_shape, tuple
        """
        outputs = inputs.clone()
        if noise_shape is None:
            noise_shape = (inputs.shape[0], *repeat(1, inputs.dim()-2), inputs.shape[-1]) 
        
        self.noise_shape = noise_shape
        if not self.training or self.drop == 0:
            return inputs
        else:
            noises = self._make_noises(inputs)
            if self.drop == 1:
                noises.fill_(0.0)
            else:
                noises.bernoulli_(1 - self.drop).div_(1 - self.drop)
            noises = noises.expand_as(inputs)    
            outputs.mul_(noises)
            return outputs
            
    def _make_noises(self, inputs):
        return inputs.new().resize_(self.noise_shape)


import torch
from torch import nn
import torch.nn.functional as F

from typing import Dict, Optional
 
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor



class MLPAttentionNetwork(nn.Module):
 
    def __init__(self, hidden_dim, attention_dim=None):
        super(MLPAttentionNetwork, self).__init__()
 
        self.hidden_dim = hidden_dim
        self.attention_dim = attention_dim
        if self.attention_dim is None:
            self.attention_dim = self.hidden_dim
        # W * x + b
        self.proj_w = nn.Linear(self.hidden_dim, self.attention_dim, bias=True)
        # v.T
        self.proj_v = nn.Linear(self.attention_dim, 1, bias=False)
 
    def forward(self, x):
        """
        :param x: seq_len, batch_size, hidden_dim
        :return: batch_size * seq_len, batch_size * hidden_dim
        """
        # print(f"x shape:{x.shape}")
        batch_size, seq_len, _ = x.size()
        # flat_inputs = x.reshape(-1, self.hidden_dim) # (batch_size*seq_len, hidden_dim)
        # print(f"flat_inputs shape:{flat_inputs.shape}")
        
        H = torch.tanh(self.proj_w(x)) # (batch_size, seq_len, hidden_dim)
        # print(f"H shape:{H.shape}")
        
        att_scores = torch.softmax(self.proj_v(H),axis=1) # (batch_size, seq_len)
        # print(f"att_scores shape:{att_scores.shape}")
        
        attn_x = (x * att_scores).sum(1) # (batch_size, hidden_dim)
        # print(f"attn_x shape:{attn_x.shape}")
        return attn_x

In [15]:
def gem(x, p=3, eps=1e-6):
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1.0 / p)


class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6, p_trainable=False):
        super(GeM, self).__init__()
        if p_trainable:
            self.p = nn.parameter.Parameter(torch.ones(1) * p)
        else:
            self.p = p
        self.eps = eps

    def forward(self, x):
        ret = gem(x, p=self.p, eps=self.eps)
        return ret

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + "p="
            + "{:.4f}".format(self.p.data.tolist()[0])
            + ", "
            + "eps="
            + str(self.eps)
            + ")"
        )
    
    
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x, mask=None):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

class RSNAClassifier(nn.Module):
    def __init__(self, model_arch, hidden_dim=256, seq_len=CFG.seq_len, pretrained=False, softmax = False):
        super().__init__()
        self.seq_len = seq_len
        self.model = timm.create_model(model_arch, in_chans=3, pretrained=pretrained)
        self.softmax = softmax

        if 'efficientnet' in CFG.model_arch:
            cnn_feature = self.model.classifier.in_features
            self.model.classifier = nn.Identity()
        elif "res" in CFG.model_arch or 1:
            cnn_feature = self.model.fc.in_features
            self.model.global_pool = nn.Identity()
            self.model.fc = nn.Identity()
            self.pooling = nn.AdaptiveAvgPool2d(1)#GeM(p_trainable=False)
        
        elif CFG.model_arch == "convnextv2_tiny":
            cnn_feature = 1000
            self.model.global_pool = nn.Identity()
            self.pooling = nn.AdaptiveAvgPool2d(1)#GeM(p_trainable=True)

        self.spatialdropout = SpatialDropout(CFG.dropout)
        self.gru = nn.GRU(cnn_feature, cnn_feature//2, num_layers=1, bidirectional=True, batch_first=True)
        self.mlp_attention_layers = nn.ModuleList([MLPAttentionNetwork(cnn_feature) for _ in range(5)])

        self.logits = nn.ModuleList([nn.Sequential(
                nn.Dropout(0.1),
                nn.Linear(cnn_feature, 15),
                ) for _ in range(5)])

    def forward(self, x): # (B, seq_len, H, W)
        bs = x.size(0)
        
        x = x.reshape(bs*self.seq_len//3, 3, x.size(2), x.size(3)) # (B*seq_len, 1, H, W)
        features = self.model(x)
        if "res" in CFG.model_arch:
            features = self.pooling(features).view(bs*self.seq_len//3, -1) # (B*seq_len, cnn_feature)
        features = self.spatialdropout(features)                # (B*seq_len, cnn_feature)
        features = features.reshape(bs, self.seq_len//3, -1)       # (B, seq_len, cnn_feature)
        features, _ = self.gru(features)                        # (B, seq_len, hidden_dim*2)
        preds = []
        for att_layer, logit in zip(self.mlp_attention_layers, self.logits):
            p = att_layer(features)
            p = logit(p)
            preds.append(p)
        preds = torch.cat(preds, -1)
        if self.softmax:
            preds = torch.cat([preds[:, i:i+3].softmax(1) for i in range(0, 75, 3)], -1)
        return preds

In [16]:
# model = RSNAClassifier(CFG.model_arch, hidden_dim=128, seq_len=CFG.seq_len, pretrained=True)

In [17]:
def get_activation(activ_name: str="relu"):
    """"""
    act_dict = {
        "relu": nn.ReLU(inplace=True),
        "tanh": nn.Tanh(),
        "sigmoid": nn.Sigmoid(),
        "identity": nn.Identity()}
    if activ_name in act_dict:
        return act_dict[activ_name]
    else:
        raise NotImplementedError
        

class Conv2dBNActiv(nn.Module):
    """Conv2d -> (BN ->) -> Activation"""
    
    def __init__(
        self, in_channels, out_channels,
        kernel_size, stride, padding,
        bias=False, use_bn=True, activ="relu"
    ):
        """"""
        super(Conv2dBNActiv, self).__init__()
        layers = []
        layers.append(nn.Conv2d(
            in_channels, out_channels,
            kernel_size, stride, padding, bias=bias))
        if use_bn:
            layers.append(nn.BatchNorm2d(out_channels))
            
        layers.append(get_activation(activ))
        self.layers = nn.Sequential(*layers)
        
    def forward(self, x):
        """Forward"""
        return self.layers(x)
        
    
class SpatialAttentionBlock(nn.Module):
    """Spatial Attention for (C, H, W) feature maps"""
    
    def __init__(
        self, in_channels,
        out_channels_list,
    ):
        """Initialize"""
        super(SpatialAttentionBlock, self).__init__()
        self.n_layers = len(out_channels_list)
        channels_list = [in_channels] + out_channels_list
        assert self.n_layers > 0
        assert channels_list[-1] == 1
        
        for i in range(self.n_layers - 1):
            in_chs, out_chs = channels_list[i: i + 2]
            layer = Conv2dBNActiv(in_chs, out_chs, 3, 1, 1, activ="relu")
            setattr(self, f"conv{i + 1}", layer)
            
        in_chs, out_chs = channels_list[-2:]
        layer = Conv2dBNActiv(in_chs, out_chs, 3, 1, 1, activ="sigmoid")
        setattr(self, f"conv{self.n_layers}", layer)
    
    def forward(self, x):
        """Forward"""
        h = x
        for i in range(self.n_layers):
            h = getattr(self, f"conv{i + 1}")(h)
            
        h = h * x
        return h



class MultiHeadResNet200D(nn.Module):
    def __init__(self, out_dims_head=[3, 4, 3, 1],  pretrained=False):
        self.base_name = "resnet200d_320"
        self.n_heads = len(out_dims_head)
        super(MultiHeadResNet200D, self).__init__()
        
        # # load base model
        base_model = timm.create_model(self.base_name, num_classes=sum(out_dims_head), pretrained=False)
        in_features = base_model.num_features
        
        if pretrained:
            pretrained_model_path = CFG.student
            state_dict = dict()
            for k, v in torch.load(pretrained_model_path, map_location='cpu')["model"].items():
                if k[:6] == "model.":
                    k = k.replace("model.", "")
                state_dict[k] = v
            base_model.load_state_dict(state_dict)
        
        # # remove global pooling and head classifier
        base_model.reset_classifier(0, '')
        
        # # Shared CNN Bacbone
        self.backbone = base_model
        
        # # Multi Heads.
        for i, out_dim in enumerate(out_dims_head):
            layer_name = f"head_{i}"
            layer = nn.Sequential(
                SpatialAttentionBlock(in_features, [64, 32, 16, 1]),
                nn.AdaptiveAvgPool2d(output_size=1),
                nn.Flatten(start_dim=1),
                nn.Linear(in_features, in_features),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
                nn.Linear(in_features, out_dim))
            setattr(self, layer_name, layer)

    def forward(self, x):
        h = self.backbone(x)
        hs = [getattr(self, f"head_{i}")(h) for i in range(self.n_heads)]
        y = torch.cat(hs, axis=1)
        return None, None, y

In [18]:
def mixup(input, truth, clip=[0, 1]):
    indices = torch.randperm(input.size(0))
    shuffled_input = input[indices]
    shuffled_labels = truth[indices]

    lam = np.random.uniform(clip[0], clip[1])
    input = input * lam + shuffled_input * (1 - lam)
    return input, truth, shuffled_labels, lam



def train_one_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    if CFG.device == 'GPU':
        scaler = GradScaler()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    for step, (study_id, images, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)
        batch_size = labels.size(0)
        
        do_mixup = False
        if random.random() < CFG.p_mixup:
            do_mixup = True
            images, labels, labels_mix, lam = mixup(images, labels)

        if CFG.device == 'GPU':
            # with autocast():
            y_preds = model(images)
            y_preds = y_preds.squeeze(1)
            loss = criterion(y_preds, labels)
            if do_mixup:
                loss11 = criterion(y_preds, labels_mix)
                loss = loss * lam  + loss11 * (1 - lam)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.accum_iter > 1:
                loss = loss / CFG.accum_iter
            scaler.scale(loss).backward()
            grad_norm = 0 # torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            if (step + 1) % CFG.accum_iter == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        elif CFG.device == 'TPU':
            y_preds = model(images)
            loss = criterion(y_preds, labels)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.accum_iter > 1:
                loss = loss / CFG.accum_iter
            loss.backward()
            grad_norm = 0 # torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            if (step + 1) % CFG.accum_iter == 0:
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            cusprint('Epoch: [{0}][{1}/{2}] '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Elapsed {remain:s} '
                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                'Grad: {grad_norm:.4f}  '
                'LR: {lr:.7f}  '
                .format(
                epoch, step, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses,
                remain=timeSince(start, float(step+1)/len(train_loader)),
                grad_norm=grad_norm,
                lr=optimizer.param_groups[0]["lr"],
                ))
        if(step == 0):print(loss)

    return losses.avg, optimizer.param_groups[0]["lr"]

In [19]:
# 

In [20]:
def tta(images, nbr):
    return torch.cat([images[:,nbr:,...], torch.cat(nbr*[images[:,-1:,...]], 1)], 1)


def evaluate(y_preds, y_true, www = 1):
    y_pred = torch.from_numpy(y_preds.copy())
    y = torch.from_numpy(y_true.copy())
    # loss0 = criterion(y_pred, y)
    
    y_pred = y_pred.reshape(y_pred.shape[0], 25, 3).transpose(1, 2).softmax(dim=1)
    if www==2:
        y_pred[:, 1] *= 2
        y_pred[:, 2] *= 4
        y_pred /= y_pred.sum(1)[:, None]
    y_pred = (torch.log(y_pred)/(1-y_pred))
    y = y.reshape(y.shape[0], 25, 3).transpose(1, 2).argmax(1)

    n_sum = 0
    # loss0_sum = 0.0  # loss for the criterion
    
    # 4 losses for the evaluation metric
    loss4_sum = torch.zeros(4)
    w_sum = torch.zeros(4)
    slices = [slice(0, 5), slice(5, 15), slice(15, 25)]  # spinal, foraminal, subarticular

    w = 2 ** y  # sample_weight w = (1, 2, 4) for y = 0, 1, 2 (batch_size, 25)

    batch_size = y_pred.shape[0]

    n_sum += batch_size
    # loss0_sum += loss0.item() * batch_size

    # Compute score
    # - weighted loss for spinal, foraminal, subarticular
    # - binary cross entropy for maximum spinal severe
    ce_loss = F.cross_entropy(y_pred, y.long(), reduction='none')  # (batch_size, 25)
    for k, idx in enumerate(slices):
        w_sum[k] += w[:, idx].sum()
        loss4_sum[k] += (w[:, idx] * ce_loss[:, idx]).sum()
    # Spinal max
    y_spinal_prob = y_pred[:, :, :5].softmax(dim=1)            # (batch_size, 3,  5)
    w_max = torch.amax(w[:, :5], dim=1)                        # (batch_size, )
    y_max = torch.amax(y[:, :5] == 2, dim=1).to(torch.float)   # 0 or 1
    y_pred_max = y_spinal_prob[:, 2, :].amax(dim=1)            # max in severe (class=2)

    loss_max = F.binary_cross_entropy(y_pred_max, y_max, reduction='none')
    loss4_sum[3] += (w_max * loss_max).sum()
    w_sum[3] += w_max.sum()
    # Average over spinal, foraminal, subarticular, and any_severe_spinal
    score = (loss4_sum / w_sum).sum().item() / 4
    print('weight =', www, 'evaluation score =', score)


    return score



def valid_one_epoch(valid_loader, model, criterion, device, softmax = False):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to evaluation mode
    model.eval()
    trues = []
    preds = []
    start = end = time.time()
    for step, (study_id, images, labels) in enumerate(valid_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)
        batch_size = labels.size(0)
        # compute loss
        with torch.no_grad():
            y_preds = model(images)#(model(images) + model(tta(images, 1)))/2
            y_preds = y_preds.squeeze(1)
        loss = criterion(y_preds, labels)
        losses.update(loss.item(), batch_size)
        # record accuracy
        # y_preds = y_preds.reshape(-1, 25, 3).softmax(dim=2).reshape(-1, 75)
        trues.append(labels.to('cpu').numpy())
        preds.append(y_preds.to('cpu').numpy())
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            cusprint('EVAL: [{0}/{1}] '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Elapsed {remain:s} '
                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                .format(
                step, len(valid_loader), batch_time=batch_time,
                data_time=data_time, loss=losses,
                remain=timeSince(start, float(step+1)/len(valid_loader)),
                ))

    trues = np.concatenate(trues)
    predictions = np.concatenate(preds)
    sc = evaluate(predictions, trues, www = 1)
    sc = evaluate(predictions, trues, www = 2)
    print(f"trues.shape: {trues.shape}")
    print(f"predictions.shape: {predictions.shape}")
    if softmax:
        score = nn.BCELoss()(torch.from_numpy(predictions).type(torch.float32), torch.from_numpy(trues).type(torch.float32))
        score2 = comp_metric_score(trues, predictions, apply_sig = False)
        score3 = comp_metric_score(trues, predictions, w1 = 2, w2 = 4, apply_sig = False)
    else:
        score = nn.BCEWithLogitsLoss()(torch.from_numpy(predictions).type(torch.float32), torch.from_numpy(trues).type(torch.float32))
        score2 = comp_metric_score(trues, predictions, apply_sig = True)
        score3 = comp_metric_score(trues, predictions, w1 = 2, w2 = 4, apply_sig = True)
        _ = comp_metric_score2(trues, predictions, w1 = 1, w2 = 1, apply_sig = True)
        _ = comp_metric_score2(trues, predictions, w1 = 2, w2 = 4, apply_sig = True)

    
    return losses.avg, predictions, trues, score, score2, score3, sc

# loss & optimizer & scheduler

In [21]:
class GradualWarmupSchedulerV3(GradualWarmupScheduler):
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super(GradualWarmupSchedulerV3, self).__init__(optimizer, multiplier, total_epoch, after_scheduler)
    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

# Training

In [22]:
def train_loop(df, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2, fold):
    loginfo(f"========== fold: {fold} training ==========")
    if CFG.plane == 'sag_t2':
        df = df[df['study_id'].isin(ser_dict_sag_t2.keys())]
    elif CFG.plane == 'sag_t1':
        df = df[df['study_id'].isin(ser_dict_sag_t1.keys())]
    else:
        df = df[df['study_id'].isin(ser_dict_ax_t2.keys())]

    # ====================================================
    # loader
    # ====================================================
    train_folds = df[df['fold']!=fold].reset_index(drop=True)################################################"[:700]
    valid_folds = df[df['fold']==fold].reset_index(drop=True)
    
    softmax = False
    if CFG.loss_fn == "BCEWithLogitsLoss":
        criterion = nn.BCEWithLogitsLoss()
    elif CFG.loss_fn == 'Custom_loss':
        cust = Custom_loss()
        cust2 = nn.BCEWithLogitsLoss()
        def criterion(y_pred, y):
            return 0.75*cust(y_pred, y) + 0.25*cust2(y_pred, y)
        # softmax = True
    elif(CFG.loss_fn == 'FocalLoss'):
        criterion = FocalLoss()


    train_dataset = TrainDataset(train_folds, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2, CFG.target_cols, 'train', None, CFG.seq_len, transform=get_transforms(data='light_train'))
    valid_dataset = TrainDataset(valid_folds, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2, CFG.target_cols, 'valid', None, CFG.seq_len, transform=get_transforms(data='valid'))
    if CFG.device == 'GPU':
        # sampler_train = torch.utils.data.RandomSampler(ds_train_len)
        train_loader = DataLoader(train_dataset, batch_size=CFG.train_bs, shuffle=True, num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
        valid_loader = DataLoader(valid_dataset, batch_size=CFG.valid_bs, shuffle=False, num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
    elif CFG.device == 'TPU':
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=CFG.train_bs, sampler=train_sampler, drop_last=True, num_workers=CFG.num_workers)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=False)
        valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=CFG.valid_bs, sampler=valid_sampler, drop_last=False, num_workers=CFG.num_workers)

    # ====================================================
    # model & optimizer & scheduler & loss
    # ====================================================
    # not checkpoint

    if CFG.multihead:
        model = MultiHeadResNet200D([3, 4, 3, 1], True)
    else:
        model = RSNAClassifier(CFG.model_arch, hidden_dim=256, seq_len=CFG.seq_len, pretrained=True, softmax = softmax)

        
    if CFG.gpu_parallel:    
        num_gpu = torch.cuda.device_count()
        model = DataParallel(model, device_ids=range(num_gpu))
    model.to(device)
    
    
    # optimizer
    if CFG.optimizer == "AdamW":
        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            optimizer = AdamW(model.parameters(), lr=CFG.lr/CFG.warmup_factor, weight_decay=CFG.weight_decay) 
        else:
            optimizer = AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    if CFG.optimizer == "Adam":
        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            optimizer = Adam(model.parameters(), lr=CFG.lr/CFG.warmup_factor, weight_decay=CFG.weight_decay) 
        else:
            optimizer = Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    if CFG.optimizer == "SGD":
        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            optimizer = SGD(model.parameters(), lr=CFG.lr/CFG.warmup_factor, weight_decay=CFG.weight_decay) 
        else:
            optimizer = SGD(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    # scheduler,
    if CFG.scheduler=='ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
    elif CFG.scheduler=='CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    elif CFG.scheduler=='CosineAnnealingWarmRestarts':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)

    scheduler_warmup = GradualWarmupSchedulerV3(optimizer, multiplier=CFG.warmup_factor, total_epoch=CFG.warmup_epo, after_scheduler=scheduler)



    # ====================================================
    # loop
    # ====================================================
    valid_acc_max=0; valid_loss_min=float("inf")
    valid_acc_max_cnt=0; valid_loss_min_cnt=0;
    best_acc_epoch=0;


    best_score = 10000

    shift_epoch = CFG.shift_epoch#(3*CFG.epochs)//4

    for epoch in range(CFG.epochs):
        if epoch == shift_epoch:
            model_ = RSNAClassifier(CFG.model_arch, hidden_dim=256, seq_len=CFG.seq_len, pretrained=True, softmax = True)
            model_.to(device)
            model_.load_state_dict(model.state_dict())
            model = model_
            criterion = Custom_loss()
            softmax = True
            optimizer_ = AdamW(model.parameters(), lr=CFG.lr/CFG.warmup_factor, weight_decay=CFG.weight_decay) 
            optimizer_.load_state_dict(optimizer.state_dict())
            optimizer = optimizer_

        loginfo(f"***** Epoch {epoch} *****")

        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            loginfo(f"schwarmup_last_epoch:{scheduler_warmup.last_epoch}, schwarmup_lr:{scheduler_warmup.get_last_lr()[0]}")
        if CFG.scheduler=='CosineAnnealingLR':
            loginfo(f"scheduler_last_epoch:{scheduler.last_epoch}, scheduler_lr:{scheduler.get_last_lr()[0]}")
        loginfo(f"optimizer_lr:{optimizer.param_groups[0]['lr']}")


        start_time = time.time()
        if(epoch>=CFG.min_epoch):
            avg_loss, cur_lr = train_one_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device) # train
        else:
            avg_loss, cur_lr = -1, -1
            model.load_state_dict(torch.load(outputdir+f'/{CFG.model_arch}_{CFG.suffix}_fold{fold}.pth')['model'])
        avg_val_loss, preds, trues, score, score2, score3, score4 = valid_one_epoch(valid_loader, model, criterion, device, softmax = softmax) # valid

        # scoring
        elapsed = time.time() - start_time 

        loginfo(f'Epoch {epoch} - avg_train_loss: {avg_loss:.4f}  time: {elapsed:.0f}s')
        loginfo(f'Epoch {epoch} - avg_val_loss: {avg_val_loss:.4f} - avg_val_bce: {score:.4f} - avg_val_comp_metric : {score2:.4f} - w_avg_val_comp_metric : {score3:.4f}')

        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            scheduler_warmup.step()
        elif CFG.scheduler == "ReduceLROnPlateau":
            scheduler.step(avg_val_loss)
        elif CFG.scheduler in ["CosineAnnealingLR", "CosineAnnealingWarmRestarts"]:
            scheduler.step()

        if best_score>score4:
            best_score = score4
            np.save(outputdir+f'/{CFG.model_arch}_{CFG.suffix}__fold{fold}_preds.npy', preds)
            np.save(outputdir+f'/{CFG.model_arch}_{CFG.suffix}_fold{fold}_trues.npy', trues)
            torch.save({'model': model.state_dict()}, outputdir+f'/{CFG.model_arch}_{CFG.suffix}_fold{fold}.pth')

    return preds, trues

In [23]:
def main():
    oof_df = pd.DataFrame()
    oof_list = []
    train_df, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2 = get_df()
    for fold in CFG.fold_list:
        preds, trues = train_loop(train_df, ser_dict_sag_t1, ser_dict_sag_t2, ser_dict_ax_t2, fold)
        oof_list.append([preds, trues])
    return oof_list

# Main

In [24]:
if __name__ == '__main__':
    print(CFG.suffix)
    if CFG.device == 'TPU':
        def _mp_fn(rank, flags):
            torch.set_default_tensor_type('torch.FloatTensor')
            a = main()
        FLAGS = {}
        xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=CFG.nprocs, start_method='fork')
    elif CFG.device == 'GPU':
        oof_list = main()

440
1975 1973 1974 1975


***** Epoch 0 *****
schwarmup_last_epoch:0, schwarmup_lr:4.000000000000001e-06
scheduler_last_epoch:0, scheduler_lr:4.000000000000001e-06
optimizer_lr:4.000000000000001e-06


Epoch: [0][0/371] Data 1.577 (1.577) Elapsed 0m 6s (remain 39m 52s) Loss: 0.7837(0.7837) Grad: 0.0000  LR: 0.0000040  
tensor(0.7837, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [0][100/371] Data 0.000 (0.035) Elapsed 0m 41s (remain 1m 50s) Loss: 0.5367(0.6327) Grad: 0.0000  LR: 0.0000040  
Epoch: [0][200/371] Data 0.000 (0.040) Elapsed 1m 18s (remain 1m 6s) Loss: 0.3227(0.5501) Grad: 0.0000  LR: 0.0000040  
Epoch: [0][300/371] Data 0.000 (0.042) Elapsed 1m 56s (remain 0m 27s) Loss: 0.3276(0.5093) Grad: 0.0000  LR: 0.0000040  
Epoch: [0][370/371] Data 0.000 (0.040) Elapsed 2m 23s (remain 0m 0s) Loss: 0.2368(0.4899) Grad: 0.0000  LR: 0.0000040  
EVAL: [0/124] Data 1.122 (1.122) Elapsed 0m 1s (remain 2m 32s) Loss: 0.3727(0.3727) 
EVAL: [100/124] Data 0.000 (0.027) Elapsed 0m 13s (remain 0m 3s) Loss: 0.2698(0.4047) 
EVAL: [123/124] Data 0.000 (0.024) Elapsed 0m 16s (remain 0m 0s) Loss: 0.7324(0.4056) 
weight = 1 evaluation score = 0.8433234095573425
weight = 2 evaluation score = 0.795

Epoch 0 - avg_train_loss: 0.4899  time: 163s
Epoch 0 - avg_val_loss: 0.4056 - avg_val_bce: 0.4067 - avg_val_comp_metric : 0.8588 - w_avg_val_comp_metric : 0.9306


0.7419198439274307 0.844329339235604 0.9665810492185584 0.490820657295534
w1 = 1 w2 = 1 0.7609127224192818
0.9413512739865532 0.9728052561585461 1.032609104900716 0.631371999823924
w1 = 2 w2 = 4 0.8945344087174347


***** Epoch 1 *****
schwarmup_last_epoch:1, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:0, scheduler_lr:4.000000000000001e-06
optimizer_lr:4.000000000000001e-05


Epoch: [1][0/371] Data 2.947 (2.947) Elapsed 0m 3s (remain 20m 18s) Loss: 0.3397(0.3397) Grad: 0.0000  LR: 0.0000400  
tensor(0.3397, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [1][100/371] Data 0.000 (0.058) Elapsed 0m 39s (remain 1m 45s) Loss: 0.2763(0.4119) Grad: 0.0000  LR: 0.0000400  
Epoch: [1][200/371] Data 0.000 (0.029) Elapsed 1m 12s (remain 1m 1s) Loss: 0.3476(0.4145) Grad: 0.0000  LR: 0.0000400  
Epoch: [1][300/371] Data 0.000 (0.021) Elapsed 1m 45s (remain 0m 24s) Loss: 0.4329(0.4118) Grad: 0.0000  LR: 0.0000400  
Epoch: [1][370/371] Data 0.000 (0.021) Elapsed 2m 9s (remain 0m 0s) Loss: 0.3571(0.4106) Grad: 0.0000  LR: 0.0000400  
EVAL: [0/124] Data 1.225 (1.225) Elapsed 0m 1s (remain 2m 43s) Loss: 0.3453(0.3453) 
EVAL: [100/124] Data 0.000 (0.077) Elapsed 0m 18s (remain 0m 4s) Loss: 0.2832(0.3646) 
EVAL: [123/124] Data 0.000 (0.072) Elapsed 0m 21s (remain 0m 0s) Loss: 0.6138(0.3649) 
weight = 1 evaluation score = 0.7798622846603394
weight = 2 evaluation score = 0.7063

Epoch 1 - avg_train_loss: 0.4106  time: 153s
Epoch 1 - avg_val_loss: 0.3649 - avg_val_bce: 0.3519 - avg_val_comp_metric : 0.7835 - w_avg_val_comp_metric : 0.8019


0.6291199201516546 0.803136475912091 0.9072375737186068 0.42135652332701184
w1 = 1 w2 = 1 0.690212623277341
0.6922226124258302 0.8186479348525237 0.8870460743163279 0.47005739209367464
w1 = 2 w2 = 4 0.7169935034220891


***** Epoch 2 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:0, scheduler_lr:4.000000000000001e-06
optimizer_lr:4.000000000000001e-05


Epoch: [2][0/371] Data 1.723 (1.723) Elapsed 0m 2s (remain 12m 41s) Loss: 0.6146(0.6146) Grad: 0.0000  LR: 0.0000400  
tensor(0.6146, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [2][100/371] Data 0.000 (0.025) Elapsed 0m 35s (remain 1m 35s) Loss: 0.4120(0.3882) Grad: 0.0000  LR: 0.0000400  
Epoch: [2][200/371] Data 0.000 (0.013) Elapsed 1m 8s (remain 0m 58s) Loss: 0.3874(0.3996) Grad: 0.0000  LR: 0.0000400  
Epoch: [2][300/371] Data 0.000 (0.009) Elapsed 1m 41s (remain 0m 23s) Loss: 0.4375(0.3955) Grad: 0.0000  LR: 0.0000400  
Epoch: [2][370/371] Data 0.000 (0.007) Elapsed 2m 4s (remain 0m 0s) Loss: 1.1872(0.3890) Grad: 0.0000  LR: 0.0000400  
EVAL: [0/124] Data 0.892 (0.892) Elapsed 0m 0s (remain 2m 2s) Loss: 0.4542(0.4542) 
EVAL: [100/124] Data 0.000 (0.013) Elapsed 0m 11s (remain 0m 2s) Loss: 0.2122(0.3578) 
EVAL: [123/124] Data 0.000 (0.011) Elapsed 0m 14s (remain 0m 0s) Loss: 0.6240(0.3570) 
weight = 1 evaluation score = 0.7925848960876465
weight = 2 evaluation score = 0.66066

Epoch 2 - avg_train_loss: 0.3890  time: 140s
Epoch 2 - avg_val_loss: 0.3570 - avg_val_bce: 0.3195 - avg_val_comp_metric : 0.7381 - w_avg_val_comp_metric : 0.7092


0.574548209540958 0.781706180024465 0.9208682263617168 0.3793652181175898
w1 = 1 w2 = 1 0.6641219585111824
0.5593427754576205 0.7335978266724797 0.8203052224634687 0.34317889490306225
w1 = 2 w2 = 4 0.6141061798741578


***** Epoch 3 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:1, scheduler_lr:3.992812989642341e-05
optimizer_lr:3.992812989642341e-05


Epoch: [3][0/371] Data 2.208 (2.208) Elapsed 0m 2s (remain 15m 39s) Loss: 0.2098(0.2098) Grad: 0.0000  LR: 0.0000399  
tensor(0.2098, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [3][100/371] Data 0.000 (0.026) Elapsed 0m 35s (remain 1m 36s) Loss: 0.3000(0.3749) Grad: 0.0000  LR: 0.0000399  
Epoch: [3][200/371] Data 0.000 (0.013) Elapsed 1m 8s (remain 0m 58s) Loss: 0.4479(0.3818) Grad: 0.0000  LR: 0.0000399  
Epoch: [3][300/371] Data 0.000 (0.009) Elapsed 1m 41s (remain 0m 23s) Loss: 0.1976(0.3803) Grad: 0.0000  LR: 0.0000399  
Epoch: [3][370/371] Data 0.000 (0.007) Elapsed 2m 4s (remain 0m 0s) Loss: 0.2417(0.3790) Grad: 0.0000  LR: 0.0000399  
EVAL: [0/124] Data 1.252 (1.252) Elapsed 0m 1s (remain 2m 47s) Loss: 0.3527(0.3527) 
EVAL: [100/124] Data 0.000 (0.027) Elapsed 0m 13s (remain 0m 3s) Loss: 0.2142(0.3320) 
EVAL: [123/124] Data 0.000 (0.022) Elapsed 0m 15s (remain 0m 0s) Loss: 0.5607(0.3309) 
weight = 1 evaluation score = 0.6532269716262817
weight = 2 evaluation score = 0.5965

Epoch 3 - avg_train_loss: 0.3790  time: 141s
Epoch 3 - avg_val_loss: 0.3309 - avg_val_bce: 0.3204 - avg_val_comp_metric : 0.6835 - w_avg_val_comp_metric : 0.7002


0.5232840493932013 0.731886144536415 0.8149411062710912 0.3146891214030798
w1 = 1 w2 = 1 0.5962001054009469
0.5627235420429982 0.7467980966202015 0.8025739479529698 0.36403066373127707
w1 = 2 w2 = 4 0.6190315625868617


***** Epoch 4 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:2, scheduler_lr:3.971303741143679e-05
optimizer_lr:3.971303741143679e-05


Epoch: [4][0/371] Data 2.366 (2.366) Elapsed 0m 2s (remain 16m 38s) Loss: 0.3830(0.3830) Grad: 0.0000  LR: 0.0000397  
tensor(0.3830, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [4][100/371] Data 0.403 (0.037) Elapsed 0m 37s (remain 1m 38s) Loss: 0.3673(0.3603) Grad: 0.0000  LR: 0.0000397  
Epoch: [4][200/371] Data 0.000 (0.019) Elapsed 1m 9s (remain 0m 59s) Loss: 0.3788(0.3735) Grad: 0.0000  LR: 0.0000397  
Epoch: [4][300/371] Data 0.000 (0.013) Elapsed 1m 42s (remain 0m 23s) Loss: 0.4839(0.3794) Grad: 0.0000  LR: 0.0000397  
Epoch: [4][370/371] Data 0.000 (0.010) Elapsed 2m 5s (remain 0m 0s) Loss: 0.0640(0.3739) Grad: 0.0000  LR: 0.0000397  
EVAL: [0/124] Data 1.231 (1.231) Elapsed 0m 1s (remain 2m 44s) Loss: 0.4184(0.4184) 
EVAL: [100/124] Data 0.000 (0.017) Elapsed 0m 12s (remain 0m 2s) Loss: 0.1691(0.3702) 
EVAL: [123/124] Data 0.000 (0.014) Elapsed 0m 14s (remain 0m 0s) Loss: 0.7510(0.3726) 
weight = 1 evaluation score = 0.8110330104827881
weight = 2 evaluation score = 0.6556

Epoch 4 - avg_train_loss: 0.3739  time: 141s
Epoch 4 - avg_val_loss: 0.3726 - avg_val_bce: 0.3122 - avg_val_comp_metric : 0.7292 - w_avg_val_comp_metric : 0.6746
***** Epoch 5 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:3, scheduler_lr:3.935627229132331e-05
optimizer_lr:3.935627229132331e-05


0.588497845975669 0.8282757858311082 0.9399402745721676 0.35893896704343736
w1 = 1 w2 = 1 0.6789132183555956
0.5125645302195473 0.7216581092665411 0.7948323837758897 0.33371511809617477
w1 = 2 w2 = 4 0.5906925353395382
Epoch: [5][0/371] Data 2.083 (2.083) Elapsed 0m 2s (remain 14m 54s) Loss: 0.3826(0.3826) Grad: 0.0000  LR: 0.0000394  
tensor(0.3826, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [5][100/371] Data 0.000 (0.028) Elapsed 0m 36s (remain 1m 36s) Loss: 0.3838(0.3647) Grad: 0.0000  LR: 0.0000394  
Epoch: [5][200/371] Data 0.000 (0.014) Elapsed 1m 8s (remain 0m 58s) Loss: 0.3223(0.3670) Grad: 0.0000  LR: 0.0000394  
Epoch: [5][300/371] Data 0.000 (0.010) Elapsed 1m 41s (remain 0m 23s) Loss: 0.4357(0.3734) Grad: 0.0000  LR: 0.0000394  
Epoch: [5][370/371] Data 0.000 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.1530(0.3709) Grad: 0.0000  LR: 0.0000394  
EVAL: [0/124] Data 1.508 (1.508) Elapsed 0m 1s (remain 3m 18s) Loss: 0.3108(0.3108) 
EVAL: [100/124] Data 0.000 (0.055) Elaps

Epoch 5 - avg_train_loss: 0.3709  time: 146s
Epoch 5 - avg_val_loss: 0.3295 - avg_val_bce: 0.3287 - avg_val_comp_metric : 0.6891 - w_avg_val_comp_metric : 0.7129
***** Epoch 6 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:4, scheduler_lr:3.886040503694863e-05
optimizer_lr:3.886040503694863e-05


0.5257845532326325 0.7393807701807827 0.8023380651988927 0.31510656620782723
w1 = 1 w2 = 1 0.5956524887050338
0.5762823550182056 0.7644696688543574 0.8213662650465948 0.35679152471692205
w1 = 2 w2 = 4 0.62972745340902
Epoch: [6][0/371] Data 1.759 (1.759) Elapsed 0m 2s (remain 12m 55s) Loss: 0.3744(0.3744) Grad: 0.0000  LR: 0.0000389  
tensor(0.3744, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [6][100/371] Data 0.000 (0.020) Elapsed 0m 35s (remain 1m 34s) Loss: 0.3577(0.3760) Grad: 0.0000  LR: 0.0000389  
Epoch: [6][200/371] Data 0.000 (0.010) Elapsed 1m 8s (remain 0m 57s) Loss: 0.2002(0.3737) Grad: 0.0000  LR: 0.0000389  
Epoch: [6][300/371] Data 0.000 (0.007) Elapsed 1m 41s (remain 0m 23s) Loss: 0.3942(0.3644) Grad: 0.0000  LR: 0.0000389  
Epoch: [6][370/371] Data 0.000 (0.006) Elapsed 2m 4s (remain 0m 0s) Loss: 0.1779(0.3616) Grad: 0.0000  LR: 0.0000389  
EVAL: [0/124] Data 0.877 (0.877) Elapsed 0m 0s (remain 2m 0s) Loss: 0.3298(0.3298) 
EVAL: [100/124] Data 0.000 (0.012) Elapsed

Epoch 6 - avg_train_loss: 0.3616  time: 139s
Epoch 6 - avg_val_loss: 0.3187 - avg_val_bce: 0.3120 - avg_val_comp_metric : 0.6326 - w_avg_val_comp_metric : 0.6642


0.4575715703538093 0.7164079644640079 0.7829124485367001 0.25572843629187864
w1 = 1 w2 = 1 0.553155104911599
0.5196367430637133 0.7567231656672391 0.7701882043916575 0.32923187352667926
w1 = 2 w2 = 4 0.5939449966623223


***** Epoch 7 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:5, scheduler_lr:3.8229008383243314e-05
optimizer_lr:3.8229008383243314e-05


Epoch: [7][0/371] Data 1.876 (1.876) Elapsed 0m 2s (remain 13m 36s) Loss: 0.5772(0.5772) Grad: 0.0000  LR: 0.0000382  
tensor(0.5772, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [7][100/371] Data 0.000 (0.035) Elapsed 0m 36s (remain 1m 38s) Loss: 0.2686(0.3655) Grad: 0.0000  LR: 0.0000382  
Epoch: [7][200/371] Data 0.000 (0.018) Elapsed 1m 9s (remain 0m 58s) Loss: 0.4867(0.3621) Grad: 0.0000  LR: 0.0000382  
Epoch: [7][300/371] Data 0.000 (0.012) Elapsed 1m 42s (remain 0m 23s) Loss: 0.3827(0.3629) Grad: 0.0000  LR: 0.0000382  
Epoch: [7][370/371] Data 0.000 (0.010) Elapsed 2m 5s (remain 0m 0s) Loss: 0.7311(0.3589) Grad: 0.0000  LR: 0.0000382  
EVAL: [0/124] Data 1.219 (1.219) Elapsed 0m 1s (remain 2m 42s) Loss: 0.3375(0.3375) 
EVAL: [100/124] Data 0.000 (0.033) Elapsed 0m 13s (remain 0m 3s) Loss: 0.2225(0.3122) 
EVAL: [123/124] Data 0.000 (0.033) Elapsed 0m 16s (remain 0m 0s) Loss: 0.5456(0.3126) 
weight = 1 evaluation score = 0.5883018970489502
weight = 2 evaluation score = 0.5394

Epoch 7 - avg_train_loss: 0.3589  time: 143s
Epoch 7 - avg_val_loss: 0.3126 - avg_val_bce: 0.3004 - avg_val_comp_metric : 0.6267 - w_avg_val_comp_metric : 0.6445


0.4546708077363768 0.719868699173662 0.7631989532907599 0.25986737802045634
w1 = 1 w2 = 1 0.5494014595553138
0.5007205938967226 0.7011659238495568 0.7503790893559256 0.32840748940299386
w1 = 2 w2 = 4 0.5701682741262998


***** Epoch 8 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:6, scheduler_lr:3.746663155756571e-05
optimizer_lr:3.746663155756571e-05


Epoch: [8][0/371] Data 2.333 (2.333) Elapsed 0m 2s (remain 16m 25s) Loss: 0.2220(0.2220) Grad: 0.0000  LR: 0.0000375  
tensor(0.2220, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [8][100/371] Data 0.000 (0.026) Elapsed 0m 35s (remain 1m 35s) Loss: 0.1933(0.3640) Grad: 0.0000  LR: 0.0000375  
Epoch: [8][200/371] Data 0.000 (0.015) Elapsed 1m 9s (remain 0m 58s) Loss: 0.2768(0.3481) Grad: 0.0000  LR: 0.0000375  
Epoch: [8][300/371] Data 0.000 (0.010) Elapsed 1m 41s (remain 0m 23s) Loss: 0.4499(0.3501) Grad: 0.0000  LR: 0.0000375  
Epoch: [8][370/371] Data 0.000 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.3230(0.3559) Grad: 0.0000  LR: 0.0000375  
EVAL: [0/124] Data 0.929 (0.929) Elapsed 0m 1s (remain 2m 7s) Loss: 0.2748(0.2748) 
EVAL: [100/124] Data 0.000 (0.055) Elapsed 0m 16s (remain 0m 3s) Loss: 0.2247(0.3151) 
EVAL: [123/124] Data 0.000 (0.045) Elapsed 0m 18s (remain 0m 0s) Loss: 0.5519(0.3158) 
weight = 1 evaluation score = 0.6143432259559631
weight = 2 evaluation score = 0.56475

Epoch 8 - avg_train_loss: 0.3559  time: 144s
Epoch 8 - avg_val_loss: 0.3158 - avg_val_bce: 0.3109 - avg_val_comp_metric : 0.6571 - w_avg_val_comp_metric : 0.6731
***** Epoch 9 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:7, scheduler_lr:3.657876750241466e-05
optimizer_lr:3.657876750241466e-05


0.486366443707522 0.7248331422497347 0.7882382540027053 0.2833273089366249
w1 = 1 w2 = 1 0.5706912872241467
0.5444867835738466 0.7111917711197665 0.7736094558341932 0.3212174751592679
w1 = 2 w2 = 4 0.5876263714217685
Epoch: [9][0/371] Data 1.618 (1.618) Elapsed 0m 1s (remain 12m 2s) Loss: 0.3198(0.3198) Grad: 0.0000  LR: 0.0000366  
tensor(0.3198, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [9][100/371] Data 0.000 (0.033) Elapsed 0m 36s (remain 1m 37s) Loss: 0.1124(0.3452) Grad: 0.0000  LR: 0.0000366  
Epoch: [9][200/371] Data 0.000 (0.019) Elapsed 1m 9s (remain 0m 59s) Loss: 0.3433(0.3520) Grad: 0.0000  LR: 0.0000366  
Epoch: [9][300/371] Data 0.000 (0.013) Elapsed 1m 42s (remain 0m 23s) Loss: 0.1597(0.3501) Grad: 0.0000  LR: 0.0000366  
Epoch: [9][370/371] Data 0.000 (0.011) Elapsed 2m 6s (remain 0m 0s) Loss: 0.2452(0.3500) Grad: 0.0000  LR: 0.0000366  
EVAL: [0/124] Data 1.349 (1.349) Elapsed 0m 1s (remain 2m 59s) Loss: 0.3282(0.3282) 
EVAL: [100/124] Data 0.000 (0.052) Elapsed 

Epoch 9 - avg_train_loss: 0.3500  time: 146s
Epoch 9 - avg_val_loss: 0.3083 - avg_val_bce: 0.2930 - avg_val_comp_metric : 0.6230 - w_avg_val_comp_metric : 0.6251


0.4758730510702443 0.6906866767656276 0.7714362500440402 0.2571118636744634
w1 = 1 w2 = 1 0.5487769603885939
0.48977459501451126 0.6838273544807226 0.7281667345475067 0.27620842696873577
w1 = 2 w2 = 4 0.5444942777528691


***** Epoch 10 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:8, scheduler_lr:3.557181329865287e-05
optimizer_lr:3.557181329865287e-05


Epoch: [10][0/371] Data 2.162 (2.162) Elapsed 0m 2s (remain 15m 22s) Loss: 0.3140(0.3140) Grad: 0.0000  LR: 0.0000356  
tensor(0.3140, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [10][100/371] Data 0.000 (0.046) Elapsed 0m 37s (remain 1m 41s) Loss: 0.3556(0.3415) Grad: 0.0000  LR: 0.0000356  
Epoch: [10][200/371] Data 0.000 (0.023) Elapsed 1m 10s (remain 1m 0s) Loss: 0.3633(0.3442) Grad: 0.0000  LR: 0.0000356  
Epoch: [10][300/371] Data 0.000 (0.016) Elapsed 1m 43s (remain 0m 24s) Loss: 0.2755(0.3438) Grad: 0.0000  LR: 0.0000356  
Epoch: [10][370/371] Data 0.000 (0.013) Elapsed 2m 6s (remain 0m 0s) Loss: 0.1047(0.3468) Grad: 0.0000  LR: 0.0000356  
EVAL: [0/124] Data 0.980 (0.980) Elapsed 0m 1s (remain 2m 13s) Loss: 0.3630(0.3630) 
EVAL: [100/124] Data 0.000 (0.045) Elapsed 0m 15s (remain 0m 3s) Loss: 0.1686(0.3120) 
EVAL: [123/124] Data 0.000 (0.037) Elapsed 0m 17s (remain 0m 0s) Loss: 0.5477(0.3146) 
weight = 1 evaluation score = 0.6516157388687134
weight = 2 evaluation score = 0

Epoch 10 - avg_train_loss: 0.3468  time: 145s
Epoch 10 - avg_val_loss: 0.3146 - avg_val_bce: 0.2922 - avg_val_comp_metric : 0.6509 - w_avg_val_comp_metric : 0.6392
***** Epoch 11 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:9, scheduler_lr:3.445302407439277e-05
optimizer_lr:3.445302407439277e-05


0.4880394485056038 0.720777680696926 0.7978739811584754 0.29071456888345887
w1 = 1 w2 = 1 0.574351419811116
0.48492349382563565 0.6725962614939152 0.7297730834353466 0.2895126351975475
w1 = 2 w2 = 4 0.5442013684881113
Epoch: [11][0/371] Data 2.299 (2.299) Elapsed 0m 2s (remain 16m 14s) Loss: 0.2549(0.2549) Grad: 0.0000  LR: 0.0000345  
tensor(0.2549, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [11][100/371] Data 0.000 (0.029) Elapsed 0m 36s (remain 1m 36s) Loss: 0.3225(0.3479) Grad: 0.0000  LR: 0.0000345  
Epoch: [11][200/371] Data 0.000 (0.016) Elapsed 1m 9s (remain 0m 58s) Loss: 0.4182(0.3439) Grad: 0.0000  LR: 0.0000345  
Epoch: [11][300/371] Data 0.000 (0.013) Elapsed 1m 42s (remain 0m 23s) Loss: 0.4705(0.3460) Grad: 0.0000  LR: 0.0000345  
Epoch: [11][370/371] Data 0.000 (0.013) Elapsed 2m 6s (remain 0m 0s) Loss: 0.0587(0.3466) Grad: 0.0000  LR: 0.0000345  
EVAL: [0/124] Data 0.826 (0.826) Elapsed 0m 0s (remain 1m 54s) Loss: 0.2820(0.2820) 
EVAL: [100/124] Data 0.211 (0.053) E

Epoch 11 - avg_train_loss: 0.3466  time: 146s
Epoch 11 - avg_val_loss: 0.3097 - avg_val_bce: 0.3116 - avg_val_comp_metric : 0.6200 - w_avg_val_comp_metric : 0.6605
***** Epoch 12 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:10, scheduler_lr:3.3230460731632125e-05
optimizer_lr:3.3230460731632125e-05


0.44991999480669626 0.6805130239698972 0.7357822004079313 0.2563683484371655
w1 = 1 w2 = 1 0.5306458919054226
0.5325241439809516 0.7185013174543983 0.7750842015529342 0.3053140335327898
w1 = 2 w2 = 4 0.5828559241302684
Epoch: [12][0/371] Data 1.296 (1.296) Elapsed 0m 1s (remain 10m 2s) Loss: 0.2801(0.2801) Grad: 0.0000  LR: 0.0000332  
tensor(0.2801, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [12][100/371] Data 0.000 (0.041) Elapsed 0m 37s (remain 1m 39s) Loss: 0.3688(0.3493) Grad: 0.0000  LR: 0.0000332  
Epoch: [12][200/371] Data 0.000 (0.032) Elapsed 1m 12s (remain 1m 1s) Loss: 0.4242(0.3453) Grad: 0.0000  LR: 0.0000332  
Epoch: [12][300/371] Data 0.000 (0.022) Elapsed 1m 45s (remain 0m 24s) Loss: 0.4769(0.3455) Grad: 0.0000  LR: 0.0000332  
Epoch: [12][370/371] Data 0.000 (0.018) Elapsed 2m 8s (remain 0m 0s) Loss: 0.9055(0.3442) Grad: 0.0000  LR: 0.0000332  
EVAL: [0/124] Data 1.577 (1.577) Elapsed 0m 1s (remain 3m 27s) Loss: 0.3202(0.3202) 
EVAL: [100/124] Data 0.438 (0.088) E

Epoch 12 - avg_train_loss: 0.3442  time: 152s
Epoch 12 - avg_val_loss: 0.2977 - avg_val_bce: 0.2906 - avg_val_comp_metric : 0.6097 - w_avg_val_comp_metric : 0.6245


0.44147237120780786 0.6664600855592127 0.7422050880067745 0.2538106886890563
w1 = 1 w2 = 1 0.5259870583657129
0.4588434347830548 0.6813056356847625 0.726552739872618 0.28772742023283726
w1 = 2 w2 = 4 0.5386073076433182


***** Epoch 13 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:11, scheduler_lr:3.1912931867270525e-05
optimizer_lr:3.1912931867270525e-05


Epoch: [13][0/371] Data 1.605 (1.605) Elapsed 0m 1s (remain 11m 56s) Loss: 0.6996(0.6996) Grad: 0.0000  LR: 0.0000319  
tensor(0.6996, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [13][100/371] Data 0.000 (0.022) Elapsed 0m 35s (remain 1m 34s) Loss: 0.2937(0.3511) Grad: 0.0000  LR: 0.0000319  
Epoch: [13][200/371] Data 0.000 (0.011) Elapsed 1m 8s (remain 0m 57s) Loss: 0.2594(0.3459) Grad: 0.0000  LR: 0.0000319  
Epoch: [13][300/371] Data 0.000 (0.007) Elapsed 1m 41s (remain 0m 23s) Loss: 0.4481(0.3449) Grad: 0.0000  LR: 0.0000319  
Epoch: [13][370/371] Data 0.000 (0.006) Elapsed 2m 3s (remain 0m 0s) Loss: 0.2752(0.3394) Grad: 0.0000  LR: 0.0000319  
EVAL: [0/124] Data 0.974 (0.974) Elapsed 0m 1s (remain 2m 13s) Loss: 0.2963(0.2963) 
EVAL: [100/124] Data 0.000 (0.052) Elapsed 0m 15s (remain 0m 3s) Loss: 0.2450(0.3029) 
EVAL: [123/124] Data 0.000 (0.055) Elapsed 0m 19s (remain 0m 0s) Loss: 0.5268(0.3040) 
weight = 1 evaluation score = 0.5487306118011475
weight = 2 evaluation score = 0

Epoch 13 - avg_train_loss: 0.3394  time: 144s
Epoch 13 - avg_val_loss: 0.3040 - avg_val_bce: 0.3021 - avg_val_comp_metric : 0.6081 - w_avg_val_comp_metric : 0.6418
***** Epoch 14 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:12, scheduler_lr:3.0509930306966776e-05
optimizer_lr:3.0509930306966776e-05


0.44303117913179985 0.6664994870027248 0.7225995568716662 0.2446269714742354
w1 = 1 w2 = 1 0.5191892986201065
0.5104887553474988 0.7081598148189997 0.7554716126414138 0.28083960666922814
w1 = 2 w2 = 4 0.5637399473692851
Epoch: [14][0/371] Data 1.936 (1.936) Elapsed 0m 2s (remain 13m 59s) Loss: 0.2211(0.2211) Grad: 0.0000  LR: 0.0000305  
tensor(0.2211, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [14][100/371] Data 0.000 (0.019) Elapsed 0m 35s (remain 1m 34s) Loss: 0.3182(0.3532) Grad: 0.0000  LR: 0.0000305  
Epoch: [14][200/371] Data 0.000 (0.010) Elapsed 1m 8s (remain 0m 57s) Loss: 0.2724(0.3441) Grad: 0.0000  LR: 0.0000305  
Epoch: [14][300/371] Data 0.000 (0.007) Elapsed 1m 40s (remain 0m 23s) Loss: 0.4179(0.3377) Grad: 0.0000  LR: 0.0000305  
Epoch: [14][370/371] Data 0.000 (0.005) Elapsed 2m 3s (remain 0m 0s) Loss: 0.4711(0.3355) Grad: 0.0000  LR: 0.0000305  
EVAL: [0/124] Data 0.859 (0.859) Elapsed 0m 0s (remain 1m 58s) Loss: 0.2847(0.2847) 
EVAL: [100/124] Data 0.000 (0.096)

Epoch 14 - avg_train_loss: 0.3355  time: 147s
Epoch 14 - avg_val_loss: 0.2989 - avg_val_bce: 0.2969 - avg_val_comp_metric : 0.6162 - w_avg_val_comp_metric : 0.6417
***** Epoch 15 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:13, scheduler_lr:2.9031564709112585e-05
optimizer_lr:2.9031564709112585e-05


0.4485568599490395 0.6905235789650456 0.7331581652298359 0.2576312232582993
w1 = 1 w2 = 1 0.5324674568505551
0.5218942405394282 0.6781528334848552 0.7362854417904885 0.31838022566597934
w1 = 2 w2 = 4 0.5636781853701878
Epoch: [15][0/371] Data 2.313 (2.313) Elapsed 0m 2s (remain 16m 20s) Loss: 0.2398(0.2398) Grad: 0.0000  LR: 0.0000290  
tensor(0.2398, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [15][100/371] Data 0.000 (0.033) Elapsed 0m 36s (remain 1m 37s) Loss: 0.2442(0.3257) Grad: 0.0000  LR: 0.0000290  
Epoch: [15][200/371] Data 0.000 (0.021) Elapsed 1m 10s (remain 0m 59s) Loss: 0.4278(0.3373) Grad: 0.0000  LR: 0.0000290  
Epoch: [15][300/371] Data 0.000 (0.014) Elapsed 1m 43s (remain 0m 23s) Loss: 0.2687(0.3418) Grad: 0.0000  LR: 0.0000290  
Epoch: [15][370/371] Data 0.000 (0.015) Elapsed 2m 7s (remain 0m 0s) Loss: 0.0668(0.3408) Grad: 0.0000  LR: 0.0000290  
EVAL: [0/124] Data 0.971 (0.971) Elapsed 0m 1s (remain 2m 12s) Loss: 0.3100(0.3100) 
EVAL: [100/124] Data 0.000 (0.030)

Epoch 15 - avg_train_loss: 0.3408  time: 144s
Epoch 15 - avg_val_loss: 0.2988 - avg_val_bce: 0.2894 - avg_val_comp_metric : 0.6426 - w_avg_val_comp_metric : 0.6312
***** Epoch 16 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:14, scheduler_lr:2.7488486731717165e-05
optimizer_lr:2.7488486731717165e-05


0.477127341279303 0.6886673833008714 0.7694034388889889 0.29142271585760093
w1 = 1 w2 = 1 0.5566552198316911
0.47346976208850367 0.670187358809683 0.727843531528016 0.2590454422902915
w1 = 2 w2 = 4 0.5326365236791235
Epoch: [16][0/371] Data 1.610 (1.610) Elapsed 0m 1s (remain 11m 59s) Loss: 0.3067(0.3067) Grad: 0.0000  LR: 0.0000275  
tensor(0.3067, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [16][100/371] Data 0.000 (0.025) Elapsed 0m 35s (remain 1m 35s) Loss: 0.3555(0.3431) Grad: 0.0000  LR: 0.0000275  
Epoch: [16][200/371] Data 0.052 (0.017) Elapsed 1m 9s (remain 0m 58s) Loss: 0.3650(0.3409) Grad: 0.0000  LR: 0.0000275  
Epoch: [16][300/371] Data 0.000 (0.018) Elapsed 1m 44s (remain 0m 24s) Loss: 0.2760(0.3379) Grad: 0.0000  LR: 0.0000275  
Epoch: [16][370/371] Data 0.000 (0.015) Elapsed 2m 7s (remain 0m 0s) Loss: 0.4350(0.3363) Grad: 0.0000  LR: 0.0000275  
EVAL: [0/124] Data 1.191 (1.191) Elapsed 0m 1s (remain 2m 39s) Loss: 0.2726(0.2726) 
EVAL: [100/124] Data 0.000 (0.031) El

Epoch 16 - avg_train_loss: 0.3363  time: 145s
Epoch 16 - avg_val_loss: 0.2952 - avg_val_bce: 0.2850 - avg_val_comp_metric : 0.5850 - w_avg_val_comp_metric : 0.6168


0.4311995713043213 0.6646934106848551 0.7029762874159668 0.2579591163515421
w1 = 1 w2 = 1 0.5142070964391713
0.49765834331346565 0.6563620966722784 0.703217252263803 0.34963551437779955
w1 = 2 w2 = 4 0.5517183016568367


***** Epoch 17 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:15, scheduler_lr:2.5891814286967183e-05
optimizer_lr:2.5891814286967183e-05


Epoch: [17][0/371] Data 1.897 (1.897) Elapsed 0m 2s (remain 13m 44s) Loss: 0.5513(0.5513) Grad: 0.0000  LR: 0.0000259  
tensor(0.5513, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [17][100/371] Data 0.000 (0.027) Elapsed 0m 36s (remain 1m 36s) Loss: 0.1242(0.3276) Grad: 0.0000  LR: 0.0000259  
Epoch: [17][200/371] Data 0.000 (0.014) Elapsed 1m 8s (remain 0m 58s) Loss: 0.3232(0.3269) Grad: 0.0000  LR: 0.0000259  
Epoch: [17][300/371] Data 0.000 (0.009) Elapsed 1m 41s (remain 0m 23s) Loss: 0.3435(0.3329) Grad: 0.0000  LR: 0.0000259  
Epoch: [17][370/371] Data 0.000 (0.007) Elapsed 2m 4s (remain 0m 0s) Loss: 0.0385(0.3321) Grad: 0.0000  LR: 0.0000259  
EVAL: [0/124] Data 0.836 (0.836) Elapsed 0m 0s (remain 1m 55s) Loss: 0.2583(0.2583) 
EVAL: [100/124] Data 0.000 (0.064) Elapsed 0m 16s (remain 0m 3s) Loss: 0.2563(0.2938) 
EVAL: [123/124] Data 0.000 (0.067) Elapsed 0m 21s (remain 0m 0s) Loss: 0.4966(0.2943) 
weight = 1 evaluation score = 0.5419021844863892
weight = 2 evaluation score = 0

Epoch 17 - avg_train_loss: 0.3321  time: 146s
Epoch 17 - avg_val_loss: 0.2943 - avg_val_bce: 0.2983 - avg_val_comp_metric : 0.6064 - w_avg_val_comp_metric : 0.6413
***** Epoch 18 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:16, scheduler_lr:2.425305143641475e-05
optimizer_lr:2.425305143641475e-05


0.4327260378370324 0.6554664439108571 0.7129616206206089 0.25715410318744786
w1 = 1 w2 = 1 0.5145770513889866
0.49667158542134415 0.697149408347136 0.7396315157992459 0.31231094141406696
w1 = 2 w2 = 4 0.5614408627454482
Epoch: [18][0/371] Data 1.640 (1.640) Elapsed 0m 1s (remain 12m 9s) Loss: 0.2743(0.2743) Grad: 0.0000  LR: 0.0000243  
tensor(0.2743, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [18][100/371] Data 0.000 (0.048) Elapsed 0m 38s (remain 1m 42s) Loss: 0.2523(0.3197) Grad: 0.0000  LR: 0.0000243  
Epoch: [18][200/371] Data 0.181 (0.029) Elapsed 1m 12s (remain 1m 1s) Loss: 0.3564(0.3273) Grad: 0.0000  LR: 0.0000243  
Epoch: [18][300/371] Data 0.000 (0.021) Elapsed 1m 45s (remain 0m 24s) Loss: 0.3290(0.3293) Grad: 0.0000  LR: 0.0000243  
Epoch: [18][370/371] Data 0.000 (0.017) Elapsed 2m 8s (remain 0m 0s) Loss: 0.1995(0.3296) Grad: 0.0000  LR: 0.0000243  
EVAL: [0/124] Data 0.929 (0.929) Elapsed 0m 1s (remain 2m 7s) Loss: 0.3010(0.3010) 
EVAL: [100/124] Data 0.000 (0.079) E

Epoch 18 - avg_train_loss: 0.3296  time: 150s
Epoch 18 - avg_val_loss: 0.2930 - avg_val_bce: 0.2871 - avg_val_comp_metric : 0.6069 - w_avg_val_comp_metric : 0.6213
***** Epoch 19 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:17, scheduler_lr:2.2584005503950234e-05
optimizer_lr:2.2584005503950234e-05


0.44962804235522164 0.647039780103251 0.7398806341841014 0.25482112227977877
w1 = 1 w2 = 1 0.5228423947305882
0.4605028955277191 0.6791896029374144 0.7180216024692523 0.29226014684450907
w1 = 2 w2 = 4 0.5374935619447238
Epoch: [19][0/371] Data 1.708 (1.708) Elapsed 0m 2s (remain 12m 35s) Loss: 0.2844(0.2844) Grad: 0.0000  LR: 0.0000226  
tensor(0.2844, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [19][100/371] Data 0.000 (0.033) Elapsed 0m 36s (remain 1m 37s) Loss: 0.5861(0.3354) Grad: 0.0000  LR: 0.0000226  
Epoch: [19][200/371] Data 0.000 (0.020) Elapsed 1m 10s (remain 0m 59s) Loss: 0.2205(0.3241) Grad: 0.0000  LR: 0.0000226  
Epoch: [19][300/371] Data 0.000 (0.014) Elapsed 1m 43s (remain 0m 24s) Loss: 0.3309(0.3235) Grad: 0.0000  LR: 0.0000226  
Epoch: [19][370/371] Data 0.000 (0.011) Elapsed 2m 6s (remain 0m 0s) Loss: 0.5598(0.3280) Grad: 0.0000  LR: 0.0000226  
EVAL: [0/124] Data 1.421 (1.421) Elapsed 0m 1s (remain 3m 7s) Loss: 0.2847(0.2847) 
EVAL: [100/124] Data 0.000 (0.040)

Epoch 19 - avg_train_loss: 0.3280  time: 144s
Epoch 19 - avg_val_loss: 0.2933 - avg_val_bce: 0.2995 - avg_val_comp_metric : 0.6103 - w_avg_val_comp_metric : 0.6505
***** Epoch 20 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:18, scheduler_lr:2.0896702003763166e-05
optimizer_lr:2.0896702003763166e-05


0.4359242637299125 0.6657660671992655 0.7114114203636848 0.257185105162947
w1 = 1 w2 = 1 0.5175717141139524
0.5033150078345504 0.7009720815162065 0.7431028412134201 0.3413075329940943
w1 = 2 w2 = 4 0.5721743658895678
Epoch: [20][0/371] Data 2.285 (2.285) Elapsed 0m 2s (remain 16m 8s) Loss: 0.2182(0.2182) Grad: 0.0000  LR: 0.0000209  
tensor(0.2182, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [20][100/371] Data 0.000 (0.142) Elapsed 0m 47s (remain 2m 7s) Loss: 0.6258(0.3348) Grad: 0.0000  LR: 0.0000209  
Epoch: [20][200/371] Data 0.000 (0.124) Elapsed 1m 31s (remain 1m 17s) Loss: 0.2557(0.3297) Grad: 0.0000  LR: 0.0000209  
Epoch: [20][300/371] Data 0.000 (0.132) Elapsed 2m 19s (remain 0m 32s) Loss: 0.1286(0.3244) Grad: 0.0000  LR: 0.0000209  
Epoch: [20][370/371] Data 0.000 (0.119) Elapsed 2m 46s (remain 0m 0s) Loss: 0.1919(0.3244) Grad: 0.0000  LR: 0.0000209  
EVAL: [0/124] Data 0.943 (0.943) Elapsed 0m 1s (remain 2m 9s) Loss: 0.3119(0.3119) 
EVAL: [100/124] Data 0.000 (0.027) Ela

Epoch 20 - avg_train_loss: 0.3244  time: 186s
Epoch 20 - avg_val_loss: 0.2932 - avg_val_bce: 0.2800 - avg_val_comp_metric : 0.6193 - w_avg_val_comp_metric : 0.6082
***** Epoch 21 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:19, scheduler_lr:1.920329799623685e-05
optimizer_lr:1.920329799623685e-05


0.4635095209929564 0.6551026514413429 0.7517918957226359 0.2787650845272079
w1 = 1 w2 = 1 0.5372922881710358
0.44549012841287833 0.6561032481878195 0.6937193311853753 0.2655745779956612
w1 = 2 w2 = 4 0.5152218214454336
Epoch: [21][0/371] Data 1.994 (1.994) Elapsed 0m 2s (remain 14m 23s) Loss: 0.3198(0.3198) Grad: 0.0000  LR: 0.0000192  
tensor(0.3198, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [21][100/371] Data 0.000 (0.048) Elapsed 0m 38s (remain 1m 42s) Loss: 0.3408(0.3457) Grad: 0.0000  LR: 0.0000192  
Epoch: [21][200/371] Data 0.000 (0.027) Elapsed 1m 11s (remain 1m 0s) Loss: 0.4933(0.3307) Grad: 0.0000  LR: 0.0000192  
Epoch: [21][300/371] Data 0.000 (0.018) Elapsed 1m 44s (remain 0m 24s) Loss: 0.1267(0.3266) Grad: 0.0000  LR: 0.0000192  
Epoch: [21][370/371] Data 0.000 (0.015) Elapsed 2m 7s (remain 0m 0s) Loss: 0.5614(0.3256) Grad: 0.0000  LR: 0.0000192  
EVAL: [0/124] Data 1.035 (1.035) Elapsed 0m 1s (remain 2m 20s) Loss: 0.2940(0.2940) 
EVAL: [100/124] Data 0.000 (0.036) 

Epoch 21 - avg_train_loss: 0.3256  time: 146s
Epoch 21 - avg_val_loss: 0.2929 - avg_val_bce: 0.2864 - avg_val_comp_metric : 0.5978 - w_avg_val_comp_metric : 0.6120


0.42392185042312913 0.6775899336533698 0.7273244922570834 0.241831109112575
w1 = 1 w2 = 1 0.5176668463615394
0.48256734410814134 0.6507594691136132 0.7132042531708664 0.27453595592383606
w1 = 2 w2 = 4 0.5302667555791143


***** Epoch 22 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:20, scheduler_lr:1.751599449604977e-05
optimizer_lr:1.751599449604977e-05


Epoch: [22][0/371] Data 1.651 (1.651) Elapsed 0m 1s (remain 12m 13s) Loss: 0.3391(0.3391) Grad: 0.0000  LR: 0.0000175  
tensor(0.3391, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [22][100/371] Data 0.000 (0.029) Elapsed 0m 36s (remain 1m 36s) Loss: 0.3027(0.3187) Grad: 0.0000  LR: 0.0000175  
Epoch: [22][200/371] Data 0.000 (0.015) Elapsed 1m 9s (remain 0m 58s) Loss: 0.2742(0.3147) Grad: 0.0000  LR: 0.0000175  
Epoch: [22][300/371] Data 0.000 (0.010) Elapsed 1m 42s (remain 0m 23s) Loss: 0.5289(0.3173) Grad: 0.0000  LR: 0.0000175  
Epoch: [22][370/371] Data 0.000 (0.008) Elapsed 2m 4s (remain 0m 0s) Loss: 0.2449(0.3195) Grad: 0.0000  LR: 0.0000175  
EVAL: [0/124] Data 1.101 (1.101) Elapsed 0m 1s (remain 2m 28s) Loss: 0.2682(0.2682) 
EVAL: [100/124] Data 0.000 (0.042) Elapsed 0m 14s (remain 0m 3s) Loss: 0.2523(0.2915) 
EVAL: [123/124] Data 0.000 (0.039) Elapsed 0m 17s (remain 0m 0s) Loss: 0.4623(0.2914) 
weight = 1 evaluation score = 0.5231600999832153
weight = 2 evaluation score = 0

Epoch 22 - avg_train_loss: 0.3195  time: 146s
Epoch 22 - avg_val_loss: 0.2914 - avg_val_bce: 0.2970 - avg_val_comp_metric : 0.5908 - w_avg_val_comp_metric : 0.6274


0.4107591590794617 0.6620408243676731 0.7105361589532653 0.2383302038240292
w1 = 1 w2 = 1 0.5054165865561073
0.4786362702692672 0.6819193724599643 0.7449943909197823 0.30835742837673036
w1 = 2 w2 = 4 0.553476865506436


***** Epoch 23 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:21, scheduler_lr:1.5846948563585264e-05
optimizer_lr:1.5846948563585264e-05


Epoch: [23][0/371] Data 2.021 (2.021) Elapsed 0m 2s (remain 14m 32s) Loss: 0.6158(0.6158) Grad: 0.0000  LR: 0.0000158  
tensor(0.6158, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [23][100/371] Data 0.663 (0.119) Elapsed 0m 45s (remain 2m 1s) Loss: 0.2985(0.3181) Grad: 0.0000  LR: 0.0000158  
Epoch: [23][200/371] Data 0.136 (0.135) Elapsed 1m 33s (remain 1m 19s) Loss: 0.5022(0.3227) Grad: 0.0000  LR: 0.0000158  
Epoch: [23][300/371] Data 0.000 (0.116) Elapsed 2m 14s (remain 0m 31s) Loss: 0.4884(0.3216) Grad: 0.0000  LR: 0.0000158  
Epoch: [23][370/371] Data 0.000 (0.100) Elapsed 2m 39s (remain 0m 0s) Loss: 0.4500(0.3194) Grad: 0.0000  LR: 0.0000158  
EVAL: [0/124] Data 0.924 (0.924) Elapsed 0m 1s (remain 2m 6s) Loss: 0.3205(0.3205) 
EVAL: [100/124] Data 0.000 (0.074) Elapsed 0m 17s (remain 0m 4s) Loss: 0.1820(0.2890) 
EVAL: [123/124] Data 0.000 (0.064) Elapsed 0m 20s (remain 0m 0s) Loss: 0.5265(0.2903) 
weight = 1 evaluation score = 0.5872368216514587
weight = 2 evaluation score = 0

Epoch 23 - avg_train_loss: 0.3194  time: 181s
Epoch 23 - avg_val_loss: 0.2903 - avg_val_bce: 0.2735 - avg_val_comp_metric : 0.5995 - w_avg_val_comp_metric : 0.5879
***** Epoch 24 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:22, scheduler_lr:1.4208185713032831e-05
optimizer_lr:1.4208185713032831e-05


0.44615447098033845 0.660812092918453 0.7366352232323389 0.2549493845663247
w1 = 1 w2 = 1 0.5246377929243639
0.4358343803912802 0.63066053405152 0.6881555388289877 0.2502889562531168
w1 = 2 w2 = 4 0.5012348523812262
Epoch: [24][0/371] Data 2.053 (2.053) Elapsed 0m 2s (remain 14m 45s) Loss: 0.2847(0.2847) Grad: 0.0000  LR: 0.0000142  
tensor(0.2847, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [24][100/371] Data 0.000 (0.024) Elapsed 0m 35s (remain 1m 35s) Loss: 0.7135(0.3169) Grad: 0.0000  LR: 0.0000142  
Epoch: [24][200/371] Data 0.000 (0.012) Elapsed 1m 8s (remain 0m 58s) Loss: 0.3828(0.3159) Grad: 0.0000  LR: 0.0000142  
Epoch: [24][300/371] Data 0.000 (0.009) Elapsed 1m 41s (remain 0m 23s) Loss: 0.3817(0.3186) Grad: 0.0000  LR: 0.0000142  
Epoch: [24][370/371] Data 0.000 (0.007) Elapsed 2m 4s (remain 0m 0s) Loss: 0.1129(0.3191) Grad: 0.0000  LR: 0.0000142  
EVAL: [0/124] Data 1.282 (1.282) Elapsed 0m 1s (remain 2m 50s) Loss: 0.2803(0.2803) 
EVAL: [100/124] Data 0.000 (0.058) Ela

Epoch 24 - avg_train_loss: 0.3191  time: 145s
Epoch 24 - avg_val_loss: 0.2851 - avg_val_bce: 0.2813 - avg_val_comp_metric : 0.5883 - w_avg_val_comp_metric : 0.6087


0.4198465152337722 0.648150262138491 0.706940385163212 0.24729541351921677
w1 = 1 w2 = 1 0.505558144013673
0.4519370109454487 0.6640158974569664 0.6994301275029864 0.29544363115086997
w1 = 2 w2 = 4 0.5277066667640679


***** Epoch 25 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:23, scheduler_lr:1.2611513268282851e-05
optimizer_lr:1.2611513268282851e-05


Epoch: [25][0/371] Data 1.676 (1.676) Elapsed 0m 2s (remain 12m 23s) Loss: 0.2460(0.2460) Grad: 0.0000  LR: 0.0000126  
tensor(0.2460, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [25][100/371] Data 0.000 (0.036) Elapsed 0m 36s (remain 1m 38s) Loss: 0.1539(0.3238) Grad: 0.0000  LR: 0.0000126  
Epoch: [25][200/371] Data 0.000 (0.021) Elapsed 1m 10s (remain 0m 59s) Loss: 0.2917(0.3252) Grad: 0.0000  LR: 0.0000126  
Epoch: [25][300/371] Data 0.000 (0.014) Elapsed 1m 43s (remain 0m 24s) Loss: 0.3633(0.3207) Grad: 0.0000  LR: 0.0000126  
Epoch: [25][370/371] Data 0.000 (0.012) Elapsed 2m 6s (remain 0m 0s) Loss: 0.4959(0.3196) Grad: 0.0000  LR: 0.0000126  
EVAL: [0/124] Data 0.928 (0.928) Elapsed 0m 1s (remain 2m 7s) Loss: 0.2620(0.2620) 
EVAL: [100/124] Data 0.485 (0.038) Elapsed 0m 14s (remain 0m 3s) Loss: 0.2292(0.2809) 
EVAL: [123/124] Data 0.000 (0.032) Elapsed 0m 16s (remain 0m 0s) Loss: 0.4679(0.2816) 
weight = 1 evaluation score = 0.519325852394104
weight = 2 evaluation score = 0.

Epoch 25 - avg_train_loss: 0.3196  time: 144s
Epoch 25 - avg_val_loss: 0.2816 - avg_val_bce: 0.2796 - avg_val_comp_metric : 0.5733 - w_avg_val_comp_metric : 0.5959


0.39613140771900185 0.6456346636506298 0.6982312927172619 0.23191196091338223
w1 = 1 w2 = 1 0.492977331250069
0.4409083600163803 0.6481914401405531 0.698841393780732 0.2831102634231226
w1 = 2 w2 = 4 0.517762864340197


***** Epoch 26 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:24, scheduler_lr:1.106843529088742e-05
optimizer_lr:1.106843529088742e-05


Epoch: [26][0/371] Data 1.589 (1.589) Elapsed 0m 1s (remain 11m 51s) Loss: 0.1406(0.1406) Grad: 0.0000  LR: 0.0000111  
tensor(0.1406, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [26][100/371] Data 0.000 (0.032) Elapsed 0m 36s (remain 1m 37s) Loss: 0.3003(0.3262) Grad: 0.0000  LR: 0.0000111  
Epoch: [26][200/371] Data 0.000 (0.016) Elapsed 1m 9s (remain 0m 58s) Loss: 0.3469(0.3219) Grad: 0.0000  LR: 0.0000111  
Epoch: [26][300/371] Data 0.080 (0.016) Elapsed 1m 43s (remain 0m 24s) Loss: 0.2307(0.3131) Grad: 0.0000  LR: 0.0000111  
Epoch: [26][370/371] Data 0.000 (0.015) Elapsed 2m 7s (remain 0m 0s) Loss: 0.2884(0.3142) Grad: 0.0000  LR: 0.0000111  
EVAL: [0/124] Data 1.193 (1.193) Elapsed 0m 1s (remain 2m 40s) Loss: 0.2680(0.2680) 
EVAL: [100/124] Data 0.000 (0.043) Elapsed 0m 14s (remain 0m 3s) Loss: 0.2129(0.2835) 
EVAL: [123/124] Data 0.000 (0.038) Elapsed 0m 17s (remain 0m 0s) Loss: 0.4729(0.2843) 
weight = 1 evaluation score = 0.540757417678833
weight = 2 evaluation score = 0.

Epoch 26 - avg_train_loss: 0.3142  time: 146s
Epoch 26 - avg_val_loss: 0.2843 - avg_val_bce: 0.2770 - avg_val_comp_metric : 0.5815 - w_avg_val_comp_metric : 0.5952
***** Epoch 27 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:25, scheduler_lr:9.590069693033237e-06
optimizer_lr:9.590069693033237e-06


0.41297740352828777 0.657667226743254 0.7022477660203644 0.2425963675938028
w1 = 1 w2 = 1 0.5038721909714272
0.443915336266985 0.6431016214396016 0.688384688921735 0.286384481998244
w1 = 2 w2 = 4 0.5154465321566414
Epoch: [27][0/371] Data 1.455 (1.455) Elapsed 0m 1s (remain 11m 1s) Loss: 0.3245(0.3245) Grad: 0.0000  LR: 0.0000096  
tensor(0.3245, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [27][100/371] Data 0.000 (0.034) Elapsed 0m 36s (remain 1m 38s) Loss: 0.4856(0.3224) Grad: 0.0000  LR: 0.0000096  
Epoch: [27][200/371] Data 0.000 (0.052) Elapsed 1m 17s (remain 1m 5s) Loss: 0.2219(0.3264) Grad: 0.0000  LR: 0.0000096  
Epoch: [27][300/371] Data 0.000 (0.051) Elapsed 1m 54s (remain 0m 26s) Loss: 0.1309(0.3164) Grad: 0.0000  LR: 0.0000096  
Epoch: [27][370/371] Data 0.001 (0.056) Elapsed 2m 23s (remain 0m 0s) Loss: 0.3892(0.3175) Grad: 0.0000  LR: 0.0000096  
EVAL: [0/124] Data 1.010 (1.010) Elapsed 0m 1s (remain 2m 17s) Loss: 0.2706(0.2706) 
EVAL: [100/124] Data 0.000 (0.032) Elap

Epoch 27 - avg_train_loss: 0.3175  time: 163s
Epoch 27 - avg_val_loss: 0.2798 - avg_val_bce: 0.2784 - avg_val_comp_metric : 0.5671 - w_avg_val_comp_metric : 0.5873


0.3852964712800954 0.6491705429146302 0.7036127610464306 0.22393939735490526
w1 = 1 w2 = 1 0.49050479314901535
0.4215037436168295 0.6561000965279887 0.6945006610713208 0.2787467047183357
w1 = 2 w2 = 4 0.5127128014836186


***** Epoch 28 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:26, scheduler_lr:8.187068132729488e-06
optimizer_lr:8.187068132729488e-06


Epoch: [28][0/371] Data 2.340 (2.340) Elapsed 0m 2s (remain 16m 33s) Loss: 0.3171(0.3171) Grad: 0.0000  LR: 0.0000082  
tensor(0.3171, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [28][100/371] Data 0.000 (0.039) Elapsed 0m 37s (remain 1m 39s) Loss: 0.2452(0.3213) Grad: 0.0000  LR: 0.0000082  
Epoch: [28][200/371] Data 0.000 (0.020) Elapsed 1m 10s (remain 0m 59s) Loss: 0.2399(0.3145) Grad: 0.0000  LR: 0.0000082  
Epoch: [28][300/371] Data 0.000 (0.013) Elapsed 1m 43s (remain 0m 24s) Loss: 0.1866(0.3142) Grad: 0.0000  LR: 0.0000082  
Epoch: [28][370/371] Data 0.000 (0.011) Elapsed 2m 6s (remain 0m 0s) Loss: 0.0407(0.3139) Grad: 0.0000  LR: 0.0000082  
EVAL: [0/124] Data 0.924 (0.924) Elapsed 0m 1s (remain 2m 6s) Loss: 0.2930(0.2930) 
EVAL: [100/124] Data 0.001 (0.019) Elapsed 0m 12s (remain 0m 2s) Loss: 0.2090(0.2838) 
EVAL: [123/124] Data 0.000 (0.022) Elapsed 0m 15s (remain 0m 0s) Loss: 0.5068(0.2842) 
weight = 1 evaluation score = 0.5604969263076782
weight = 2 evaluation score = 0

Epoch 28 - avg_train_loss: 0.3139  time: 142s
Epoch 28 - avg_val_loss: 0.2842 - avg_val_bce: 0.2770 - avg_val_comp_metric : 0.5942 - w_avg_val_comp_metric : 0.5989
***** Epoch 29 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:27, scheduler_lr:6.869539268367882e-06
optimizer_lr:6.869539268367882e-06


0.42443524945814165 0.6511649066692532 0.7256727869577088 0.2542097806260858
w1 = 1 w2 = 1 0.5138706809277974
0.43222759161928875 0.6508813330311078 0.6912640289432158 0.278514262120595
w1 = 2 w2 = 4 0.5132218039285519
Epoch: [29][0/371] Data 1.573 (1.573) Elapsed 0m 1s (remain 11m 45s) Loss: 0.1471(0.1471) Grad: 0.0000  LR: 0.0000069  
tensor(0.1471, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [29][100/371] Data 0.139 (0.055) Elapsed 0m 38s (remain 1m 43s) Loss: 0.2516(0.3229) Grad: 0.0000  LR: 0.0000069  
Epoch: [29][200/371] Data 0.000 (0.046) Elapsed 1m 15s (remain 1m 3s) Loss: 0.2691(0.3156) Grad: 0.0000  LR: 0.0000069  
Epoch: [29][300/371] Data 0.000 (0.041) Elapsed 1m 51s (remain 0m 26s) Loss: 0.3673(0.3140) Grad: 0.0000  LR: 0.0000069  
Epoch: [29][370/371] Data 0.000 (0.035) Elapsed 2m 14s (remain 0m 0s) Loss: 0.2050(0.3166) Grad: 0.0000  LR: 0.0000069  
EVAL: [0/124] Data 0.830 (0.830) Elapsed 0m 0s (remain 1m 55s) Loss: 0.2839(0.2839) 
EVAL: [100/124] Data 0.000 (0.054)

Epoch 29 - avg_train_loss: 0.3166  time: 155s
Epoch 29 - avg_val_loss: 0.2820 - avg_val_bce: 0.2790 - avg_val_comp_metric : 0.5786 - w_avg_val_comp_metric : 0.5986
***** Epoch 30 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:28, scheduler_lr:5.646975925607241e-06
optimizer_lr:5.646975925607241e-06


0.3996651064173348 0.6494145104737052 0.7016319716131888 0.23764040823989716
w1 = 1 w2 = 1 0.49708799918603147
0.44012614438645853 0.6506401972550897 0.6964113123831269 0.28478581749463994
w1 = 2 w2 = 4 0.5179908678798287
Epoch: [30][0/371] Data 1.959 (1.959) Elapsed 0m 2s (remain 14m 7s) Loss: 0.2001(0.2001) Grad: 0.0000  LR: 0.0000056  
tensor(0.2001, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [30][100/371] Data 0.000 (0.028) Elapsed 0m 36s (remain 1m 36s) Loss: 0.4258(0.3093) Grad: 0.0000  LR: 0.0000056  
Epoch: [30][200/371] Data 0.000 (0.017) Elapsed 1m 9s (remain 0m 58s) Loss: 0.4668(0.3144) Grad: 0.0000  LR: 0.0000056  
Epoch: [30][300/371] Data 0.000 (0.011) Elapsed 1m 42s (remain 0m 23s) Loss: 0.3094(0.3082) Grad: 0.0000  LR: 0.0000056  
Epoch: [30][370/371] Data 0.000 (0.009) Elapsed 2m 5s (remain 0m 0s) Loss: 0.3871(0.3077) Grad: 0.0000  LR: 0.0000056  
EVAL: [0/124] Data 1.068 (1.068) Elapsed 0m 1s (remain 2m 24s) Loss: 0.2922(0.2922) 
EVAL: [100/124] Data 0.030 (0.066

Epoch 30 - avg_train_loss: 0.3077  time: 147s
Epoch 30 - avg_val_loss: 0.2810 - avg_val_bce: 0.2720 - avg_val_comp_metric : 0.5733 - w_avg_val_comp_metric : 0.5843
***** Epoch 31 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:29, scheduler_lr:4.528186701347132e-06
optimizer_lr:4.528186701347132e-06


0.40006592449847406 0.6486039453776526 0.70712172265937 0.2376592926392154
w1 = 1 w2 = 1 0.49836272129367803
0.42969403326600525 0.6337158848719537 0.6763469835565936 0.28455942200837253
w1 = 2 w2 = 4 0.5060790809257313
Epoch: [31][0/371] Data 2.158 (2.158) Elapsed 0m 2s (remain 15m 21s) Loss: 0.4109(0.4109) Grad: 0.0000  LR: 0.0000045  
tensor(0.4109, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [31][100/371] Data 0.000 (0.045) Elapsed 0m 37s (remain 1m 41s) Loss: 0.3004(0.3200) Grad: 0.0000  LR: 0.0000045  
Epoch: [31][200/371] Data 0.000 (0.061) Elapsed 1m 18s (remain 1m 6s) Loss: 0.3650(0.3122) Grad: 0.0000  LR: 0.0000045  
Epoch: [31][300/371] Data 0.000 (0.068) Elapsed 1m 59s (remain 0m 27s) Loss: 0.3873(0.3109) Grad: 0.0000  LR: 0.0000045  
Epoch: [31][370/371] Data 0.000 (0.055) Elapsed 2m 22s (remain 0m 0s) Loss: 0.2213(0.3141) Grad: 0.0000  LR: 0.0000045  
EVAL: [0/124] Data 1.611 (1.611) Elapsed 0m 1s (remain 3m 31s) Loss: 0.2890(0.2890) 
EVAL: [100/124] Data 0.000 (0.072

Epoch 31 - avg_train_loss: 0.3141  time: 164s
Epoch 31 - avg_val_loss: 0.2832 - avg_val_bce: 0.2721 - avg_val_comp_metric : 0.5647 - w_avg_val_comp_metric : 0.5789
***** Epoch 32 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:30, scheduler_lr:3.5212324975853526e-06
optimizer_lr:3.5212324975853526e-06


0.3905361062986696 0.6575684731972347 0.7046748964353261 0.22586287684092252
w1 = 1 w2 = 1 0.4946605881930382
0.427346054582833 0.6329515635445755 0.6784124948033929 0.2815473660085573
w1 = 2 w2 = 4 0.5050643697348397
Epoch: [32][0/371] Data 1.544 (1.544) Elapsed 0m 1s (remain 11m 33s) Loss: 0.2621(0.2621) Grad: 0.0000  LR: 0.0000035  
tensor(0.2621, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [32][100/371] Data 0.000 (0.041) Elapsed 0m 37s (remain 1m 40s) Loss: 0.4373(0.3131) Grad: 0.0000  LR: 0.0000035  
Epoch: [32][200/371] Data 0.000 (0.021) Elapsed 1m 10s (remain 0m 59s) Loss: 0.1470(0.3041) Grad: 0.0000  LR: 0.0000035  
Epoch: [32][300/371] Data 0.000 (0.017) Elapsed 1m 44s (remain 0m 24s) Loss: 0.3060(0.3031) Grad: 0.0000  LR: 0.0000035  
Epoch: [32][370/371] Data 0.000 (0.020) Elapsed 2m 9s (remain 0m 0s) Loss: 0.3737(0.3042) Grad: 0.0000  LR: 0.0000035  
EVAL: [0/124] Data 0.917 (0.917) Elapsed 0m 1s (remain 2m 6s) Loss: 0.2743(0.2743) 
EVAL: [100/124] Data 0.000 (0.030) E

Epoch 32 - avg_train_loss: 0.3042  time: 148s
Epoch 32 - avg_val_loss: 0.2807 - avg_val_bce: 0.2788 - avg_val_comp_metric : 0.5699 - w_avg_val_comp_metric : 0.5938
***** Epoch 33 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:31, scheduler_lr:2.6333684424343023e-06
optimizer_lr:2.6333684424343023e-06


0.3950737213950877 0.64081655071457 0.6913888674125511 0.23170105150343945
w1 = 1 w2 = 1 0.48974504775641203
0.4393570741429208 0.6491957798786389 0.6959744164794675 0.2882956983560064
w1 = 2 w2 = 4 0.5182057422142584
Epoch: [33][0/371] Data 2.007 (2.007) Elapsed 0m 2s (remain 14m 26s) Loss: 0.1925(0.1925) Grad: 0.0000  LR: 0.0000026  
tensor(0.1925, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [33][100/371] Data 0.000 (0.058) Elapsed 0m 39s (remain 1m 44s) Loss: 0.3715(0.2999) Grad: 0.0000  LR: 0.0000026  
Epoch: [33][200/371] Data 0.023 (0.043) Elapsed 1m 15s (remain 1m 3s) Loss: 0.3482(0.3119) Grad: 0.0000  LR: 0.0000026  
Epoch: [33][300/371] Data 0.000 (0.041) Elapsed 1m 51s (remain 0m 26s) Loss: 0.2459(0.3094) Grad: 0.0000  LR: 0.0000026  
Epoch: [33][370/371] Data 0.000 (0.040) Elapsed 2m 17s (remain 0m 0s) Loss: 0.3284(0.3080) Grad: 0.0000  LR: 0.0000026  
EVAL: [0/124] Data 1.032 (1.032) Elapsed 0m 1s (remain 2m 20s) Loss: 0.2879(0.2879) 
EVAL: [100/124] Data 0.000 (0.031) 

Epoch 33 - avg_train_loss: 0.3080  time: 154s
Epoch 33 - avg_val_loss: 0.2811 - avg_val_bce: 0.2745 - avg_val_comp_metric : 0.5784 - w_avg_val_comp_metric : 0.5902
***** Epoch 34 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:32, scheduler_lr:1.870991616756696e-06
optimizer_lr:1.870991616756696e-06


0.40665479603455174 0.6474304432621496 0.7114008911483349 0.238252375694086
w1 = 1 w2 = 1 0.5009346265347806
0.43030011389555195 0.6416656159798062 0.6827507997429452 0.2814006043489358
w1 = 2 w2 = 4 0.5090292834918098
Epoch: [34][0/371] Data 2.243 (2.243) Elapsed 0m 2s (remain 15m 53s) Loss: 0.1125(0.1125) Grad: 0.0000  LR: 0.0000019  
tensor(0.1125, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [34][100/371] Data 0.000 (0.034) Elapsed 0m 36s (remain 1m 38s) Loss: 0.2450(0.3205) Grad: 0.0000  LR: 0.0000019  
Epoch: [34][200/371] Data 0.000 (0.017) Elapsed 1m 10s (remain 0m 59s) Loss: 0.3097(0.3114) Grad: 0.0000  LR: 0.0000019  
Epoch: [34][300/371] Data 0.000 (0.012) Elapsed 1m 43s (remain 0m 23s) Loss: 0.3663(0.3128) Grad: 0.0000  LR: 0.0000019  
Epoch: [34][370/371] Data 0.000 (0.009) Elapsed 2m 5s (remain 0m 0s) Loss: 0.2522(0.3086) Grad: 0.0000  LR: 0.0000019  
EVAL: [0/124] Data 1.009 (1.009) Elapsed 0m 1s (remain 2m 17s) Loss: 0.2814(0.2814) 
EVAL: [100/124] Data 0.311 (0.062)

Epoch 34 - avg_train_loss: 0.3086  time: 150s
Epoch 34 - avg_val_loss: 0.2785 - avg_val_bce: 0.2763 - avg_val_comp_metric : 0.5693 - w_avg_val_comp_metric : 0.5899
***** Epoch 35 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:33, scheduler_lr:1.2395949630513709e-06
optimizer_lr:1.2395949630513709e-06


0.39171829686230725 0.6384557319079793 0.6969056009173846 0.23058253810729792
w1 = 1 w2 = 1 0.48941554194874226
0.42657787568946876 0.6496577678891546 0.6883295453410639 0.280850543874454
w1 = 2 w2 = 4 0.5113539331985353
Epoch: [35][0/371] Data 2.255 (2.255) Elapsed 0m 2s (remain 16m 1s) Loss: 0.3646(0.3646) Grad: 0.0000  LR: 0.0000012  
tensor(0.3646, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [35][100/371] Data 0.000 (0.078) Elapsed 0m 41s (remain 1m 50s) Loss: 0.3628(0.3021) Grad: 0.0000  LR: 0.0000012  
Epoch: [35][200/371] Data 0.000 (0.039) Elapsed 1m 14s (remain 1m 2s) Loss: 0.3242(0.2966) Grad: 0.0000  LR: 0.0000012  
Epoch: [35][300/371] Data 0.000 (0.026) Elapsed 1m 47s (remain 0m 24s) Loss: 0.1041(0.3052) Grad: 0.0000  LR: 0.0000012  
Epoch: [35][370/371] Data 0.000 (0.023) Elapsed 2m 10s (remain 0m 0s) Loss: 0.1066(0.3062) Grad: 0.0000  LR: 0.0000012  
EVAL: [0/124] Data 1.082 (1.082) Elapsed 0m 1s (remain 2m 25s) Loss: 0.2907(0.2907) 
EVAL: [100/124] Data 0.000 (0.051

Epoch 35 - avg_train_loss: 0.3062  time: 150s
Epoch 35 - avg_val_loss: 0.2797 - avg_val_bce: 0.2734 - avg_val_comp_metric : 0.5692 - w_avg_val_comp_metric : 0.5802
***** Epoch 36 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:34, scheduler_lr:7.437277086766979e-07
optimizer_lr:7.437277086766979e-07


0.3950268154792734 0.6452700067467085 0.7029201439790926 0.2293678416958564
w1 = 1 w2 = 1 0.49314620197523273
0.4191927661413205 0.6363108307205 0.683359159671422 0.2668640461912827
w1 = 2 w2 = 4 0.5014317006811313
Epoch: [36][0/371] Data 1.928 (1.928) Elapsed 0m 2s (remain 13m 56s) Loss: 0.2904(0.2904) Grad: 0.0000  LR: 0.0000007  
tensor(0.2904, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [36][100/371] Data 0.000 (0.027) Elapsed 0m 36s (remain 1m 36s) Loss: 0.6546(0.3164) Grad: 0.0000  LR: 0.0000007  
Epoch: [36][200/371] Data 0.000 (0.024) Elapsed 1m 10s (remain 0m 59s) Loss: 0.1889(0.3074) Grad: 0.0000  LR: 0.0000007  
Epoch: [36][300/371] Data 0.000 (0.018) Elapsed 1m 44s (remain 0m 24s) Loss: 0.2198(0.3143) Grad: 0.0000  LR: 0.0000007  
Epoch: [36][370/371] Data 0.000 (0.014) Elapsed 2m 6s (remain 0m 0s) Loss: 0.1646(0.3109) Grad: 0.0000  LR: 0.0000007  
EVAL: [0/124] Data 0.860 (0.860) Elapsed 0m 0s (remain 1m 58s) Loss: 0.2871(0.2871) 
EVAL: [100/124] Data 0.000 (0.014) Ela

Epoch 36 - avg_train_loss: 0.3109  time: 142s
Epoch 36 - avg_val_loss: 0.2797 - avg_val_bce: 0.2731 - avg_val_comp_metric : 0.5686 - w_avg_val_comp_metric : 0.5811
***** Epoch 37 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:35, scheduler_lr:3.8696258856321644e-07
optimizer_lr:3.8696258856321644e-07


0.39466521002516985 0.6451837911622225 0.7006143656413315 0.2290379961691642
w1 = 1 w2 = 1 0.492375340749472
0.4212305904867738 0.6361957447746177 0.681860110263621 0.268811018270642
w1 = 2 w2 = 4 0.5020243659489136
Epoch: [37][0/371] Data 2.517 (2.517) Elapsed 0m 2s (remain 17m 35s) Loss: 0.2009(0.2009) Grad: 0.0000  LR: 0.0000004  
tensor(0.2009, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [37][100/371] Data 0.000 (0.036) Elapsed 0m 36s (remain 1m 38s) Loss: 0.1890(0.3189) Grad: 0.0000  LR: 0.0000004  
Epoch: [37][200/371] Data 0.000 (0.023) Elapsed 1m 10s (remain 0m 59s) Loss: 0.1895(0.3191) Grad: 0.0000  LR: 0.0000004  
Epoch: [37][300/371] Data 0.000 (0.015) Elapsed 1m 43s (remain 0m 24s) Loss: 0.2153(0.3145) Grad: 0.0000  LR: 0.0000004  
Epoch: [37][370/371] Data 0.000 (0.012) Elapsed 2m 6s (remain 0m 0s) Loss: 0.9161(0.3091) Grad: 0.0000  LR: 0.0000004  
EVAL: [0/124] Data 1.145 (1.145) Elapsed 0m 1s (remain 2m 33s) Loss: 0.2781(0.2781) 
EVAL: [100/124] Data 0.000 (0.044) El

Epoch 37 - avg_train_loss: 0.3091  time: 145s
Epoch 37 - avg_val_loss: 0.2813 - avg_val_bce: 0.2751 - avg_val_comp_metric : 0.5711 - w_avg_val_comp_metric : 0.5865
***** Epoch 38 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:36, scheduler_lr:1.7187010357659594e-07
optimizer_lr:1.7187010357659594e-07


0.39857374059060785 0.646444783689839 0.7019417980847367 0.23188027783003992
w1 = 1 w2 = 1 0.49471015004880586
0.43296058792048486 0.6382552006325978 0.6861028620738403 0.27588090854997604
w1 = 2 w2 = 4 0.5082998897942248
Epoch: [38][0/371] Data 1.507 (1.507) Elapsed 0m 1s (remain 11m 20s) Loss: 0.3745(0.3745) Grad: 0.0000  LR: 0.0000002  
tensor(0.3745, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [38][100/371] Data 0.000 (0.029) Elapsed 0m 36s (remain 1m 36s) Loss: 0.1969(0.2978) Grad: 0.0000  LR: 0.0000002  
Epoch: [38][200/371] Data 0.000 (0.016) Elapsed 1m 9s (remain 0m 58s) Loss: 0.2532(0.2959) Grad: 0.0000  LR: 0.0000002  
Epoch: [38][300/371] Data 0.000 (0.011) Elapsed 1m 42s (remain 0m 23s) Loss: 0.2591(0.3062) Grad: 0.0000  LR: 0.0000002  
Epoch: [38][370/371] Data 0.000 (0.009) Elapsed 2m 5s (remain 0m 0s) Loss: 0.2489(0.3067) Grad: 0.0000  LR: 0.0000002  
EVAL: [0/124] Data 1.842 (1.842) Elapsed 0m 1s (remain 3m 59s) Loss: 0.2866(0.2866) 
EVAL: [100/124] Data 0.000 (0.07

Epoch 38 - avg_train_loss: 0.3067  time: 147s
Epoch 38 - avg_val_loss: 0.2815 - avg_val_bce: 0.2743 - avg_val_comp_metric : 0.5711 - w_avg_val_comp_metric : 0.5862
***** Epoch 39 *****
schwarmup_last_epoch:2, schwarmup_lr:4.000000000000001e-05
scheduler_last_epoch:37, scheduler_lr:1e-07
optimizer_lr:1e-07


0.4009922951134894 0.6433457401690706 0.6996266059238587 0.2334318637881063
w1 = 1 w2 = 1 0.49434912624863125
0.43476630975598646 0.6356988693226345 0.6841678969319956 0.2759693789750132
w1 = 2 w2 = 4 0.5076506137464074
Epoch: [39][0/371] Data 1.656 (1.656) Elapsed 0m 1s (remain 12m 17s) Loss: 0.2713(0.2713) Grad: 0.0000  LR: 0.0000001  
tensor(0.2713, device='cuda:0', grad_fn=<AddBackward0>)
Epoch: [39][100/371] Data 0.000 (0.101) Elapsed 0m 43s (remain 1m 56s) Loss: 0.1792(0.3106) Grad: 0.0000  LR: 0.0000001  
Epoch: [39][200/371] Data 0.000 (0.083) Elapsed 1m 23s (remain 1m 10s) Loss: 0.2952(0.3121) Grad: 0.0000  LR: 0.0000001  
Epoch: [39][300/371] Data 0.131 (0.075) Elapsed 2m 1s (remain 0m 28s) Loss: 0.3181(0.3116) Grad: 0.0000  LR: 0.0000001  
Epoch: [39][370/371] Data 0.000 (0.071) Elapsed 2m 28s (remain 0m 0s) Loss: 0.2526(0.3085) Grad: 0.0000  LR: 0.0000001  
EVAL: [0/124] Data 1.360 (1.360) Elapsed 0m 1s (remain 3m 0s) Loss: 0.2848(0.2848) 
EVAL: [100/124] Data 0.000 (0.038)

Epoch 39 - avg_train_loss: 0.3085  time: 168s
Epoch 39 - avg_val_loss: 0.2789 - avg_val_bce: 0.2733 - avg_val_comp_metric : 0.5642 - w_avg_val_comp_metric : 0.5788


0.38955084398967416 0.639636316942691 0.6983999890341014 0.22489254977658482
w1 = 1 w2 = 1 0.48811992493576284
0.4162167944683705 0.6400797028892962 0.6818970555329169 0.2648965948852967
w1 = 2 w2 = 4 0.5007725369439701


***** Epoch 0 *****
schwarmup_last_epoch:0, schwarmup_lr:4.000000000000001e-06
scheduler_last_epoch:0, scheduler_lr:4.000000000000001e-06
optimizer_lr:4.000000000000001e-06


Epoch: [0][0/371] Data 2.849 (2.849) Elapsed 0m 3s (remain 19m 42s) Loss: 0.8244(0.8244) Grad: 0.0000  LR: 0.0000040  
tensor(0.8244, device='cuda:0', grad_fn=<AddBackward0>)



KeyboardInterrupt

