In [None]:
# YADSPH: True -> pre-resized dataset under ./data2/SIIM-COVID19-Resized, False -> ./data/train_<size>.
# DEBUG: subsample 100 rows, patience 2, and keep library warnings visible.
YADSPH, DEBUG = False, True

In [None]:
import os 
import sys 
import random 
import time 
import json
from glob import glob 
import pandas as pd 
import timm
from timm.models.efficientnet import *
import numpy as np 
from tqdm.notebook import tqdm 
import gc 
import matplotlib.pyplot as plt 
import seaborn as sns 
import cv2
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import scipy 
from scipy import stats
import albumentations as A
from albumentations.pytorch import ToTensorV2
from ast import literal_eval
import sklearn 
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold 
from sklearn.metrics import accuracy_score, roc_auc_score, average_precision_score, precision_score 
import torch
import torch.nn as nn
import torch.optim as optim 
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast 
from torch.autograd import Variable
from torch.nn.modules.loss import _WeightedLoss
import typing as tp
try:
    from itertools import  ifilterfalse
except ImportError: # py3k
    from itertools import  filterfalse
import warnings 
# Silence warnings only for full runs; debug sessions keep them visible.
if not DEBUG:
    warnings.filterwarnings('ignore')
# Restrict the process to GPU 0 (only effective if CUDA has not been initialised yet).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('GPU is available' if DEVICE.type == 'cuda' else 'CPU is used')

In [None]:
VER = 'vptbin_TEST'
PARAMS = {
    'version': VER,
    'seed': 2021,
    'train_bs': 8,
    'valid_bs': 16,
    'image_size': 512,
    'amp': True,
    'accum_iter': 2,  # batches per optimizer step (gradient accumulation)
    'init_lr': 5e-4,
    'model_arch': 'tf_efficientnetv2_m',  # timm backbone; Net adds an auxiliary mask head
    'n_workers': 8,
    'splits': 5,
    'epochs': 200,
    't_max': 100,  # CosineAnnealingLR T_max; the scheduler is stepped per optimizer step
    'patience': 2 if DEBUG else 20,
    'device': ('cuda' if torch.cuda.is_available() else 'cpu'),
    'n_classes': 1,  # 1 -> binary 'None Opacity' target, >1 -> 'mc_target'
    'n_channel': 3
}
DATA_PATH = './data'
if YADSPH:
    # The key is 'image_size'; the former PARAMS["img_size"] raised KeyError here.
    DATASET_PATH = f'./data2/SIIM-COVID19-Resized/img_sz_{PARAMS["image_size"]}'
    IMGS_PATH = f'{DATASET_PATH}/train'
    # NOTE(review): MASKS_PATH is not defined for this layout, yet the dataset and the
    # preview cell read it -- set it before running with YADSPH = True.
else:
    IMGS_PATH = f'{DATA_PATH}/train_{PARAMS["image_size"]}'
    MASKS_PATH = f'{DATA_PATH}/train_{PARAMS["image_size"]}_masks'
MDLS_PATH = f'./models_{VER}'
# Idempotent on re-run and also creates missing parent directories.
os.makedirs(MDLS_PATH, exist_ok=True)
# Persist the run configuration next to the checkpoints.
with open(f'{MDLS_PATH}/params.json', 'w') as file:
    json.dump(PARAMS, file)
    
def seed_all(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator; also exported as
            PYTHONHASHSEED.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Seed every RNG once, then start the wall-clock timer reported after data loading.
seed_all(seed=PARAMS['seed'])
start_time = time.time()

In [None]:
def lovasz_grad(gt_sorted):
    """Gradient of the Lovasz extension (of the Jaccard loss) w.r.t. sorted errors.

    Presumably Alg. 1 of the Lovasz-Softmax paper (Berman et al.).

    Args:
        gt_sorted: 1-D tensor of 0/1 ground-truth labels, ordered by decreasing error.

    Returns:
        Float tensor of the same length: the first-order differences of the
        cumulative Jaccard loss along the sorted order.
    """
    positives = gt_sorted.sum()
    inter = positives - gt_sorted.float().cumsum(0)
    union = positives + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - inter / union
    # Keep the first entry, difference the rest (a no-op for 0- or 1-element input).
    return torch.cat((jaccard[:1], jaccard[1:] - jaccard[:-1]))

def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """Foreground IoU (in percent) for binary predictions (1 = foreground, 0 = background).

    Args:
        preds, labels: predictions and ground truth; iterated pairwise when
            `per_image` is True, otherwise treated as a single sample.
        EMPTY: IoU assigned when the union is empty.
        ignore: label value excluded from the predicted side of the union.
        per_image: average the IoU over samples instead of pooling.

    Returns:
        100 * mean IoU.
    """
    pairs = zip(preds, labels) if per_image else [(preds, labels)]
    scores = []
    for pred, label in pairs:
        fg_label = label == 1
        fg_pred = pred == 1
        inter = (fg_label & fg_pred).sum()
        union = (fg_label | (fg_pred & (label != ignore))).sum()
        scores.append(float(inter) / union if union else EMPTY)
    return 100 * mean(scores)  # mean across images when per_image

def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """Per-class IoU (in percent) for every non-ignored class.

    Args:
        preds, labels: predicted and ground-truth label maps; iterated pairwise
            when `per_image` is True, otherwise treated as a single sample.
        C: number of classes (ids 0..C-1).
        EMPTY: IoU assigned to a class whose union is empty.
        ignore: class id to skip (also excluded from the predicted side of the union).
        per_image: average per-class IoU over samples instead of pooling.

    Returns:
        numpy array of 100 * per-class IoU, one entry per non-ignored class.
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    per_sample = []
    for pred, label in zip(preds, labels):
        class_ious = []
        for i in range(C):
            # The ignored label is sometimes among predicted classes (ENet - CityScapes).
            if i == ignore:
                continue
            intersection = ((label == i) & (pred == i)).sum()
            union = ((label == i) | ((pred == i) & (label != ignore))).sum()
            class_ious.append(float(intersection) / union if union else EMPTY)
        per_sample.append(class_ious)
    # Materialise the per-class means: on Python 3 `map` is lazy, and
    # np.array(map_obj) would wrap the iterator itself in a 0-d object array.
    class_means = [mean(column) for column in zip(*per_sample)]  # mean across images
    return 100 * np.array(class_means)

In [None]:
# BINARY LOSSES

def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """Binary Lovasz hinge loss.

    Args:
        logits: [B, H, W] tensor of per-pixel logits (unbounded, -inf..+inf).
        labels: [B, H, W] tensor of binary ground-truth masks (0 or 1).
        per_image: compute the loss per image and average, instead of per batch.
        ignore: void label value removed before computing the loss.
    """
    if not per_image:
        return lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    per_sample_losses = (
        lovasz_hinge_flat(
            *flatten_binary_scores(img_logits.unsqueeze(0), img_labels.unsqueeze(0), ignore)
        )
        for img_logits, img_labels in zip(logits, labels)
    )
    return mean(per_sample_losses)

def lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss on already-flattened predictions.

    Args:
        logits: [P] tensor of per-prediction logits (unbounded, -inf..+inf).
        labels: [P] tensor of binary ground-truth labels (0 or 1).

    Returns:
        Scalar loss tensor.
    """
    if len(labels) == 0:
        # Only void pixels: return a zero that still depends on `logits` (zero gradient).
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = 1. - logits * Variable(signs)
    errors_sorted, order = torch.sort(errors, dim=0, descending=True)
    grad = lovasz_grad(labels[order.data])
    # ELU(x) + 1 is a smooth, strictly positive replacement for the ReLU hinge
    # (the plain variant would be torch.dot(F.relu(errors_sorted), grad)).
    return torch.dot(F.elu(errors_sorted) + 1, Variable(grad))

def flatten_binary_scores(scores, labels, ignore=None):
    """Flatten binary scores and labels of a batch to 1-D tensors.

    Args:
        scores: tensor of any shape.
        labels: tensor with the same number of elements as `scores`.
        ignore: if not None, positions whose label equals it are dropped.

    Returns:
        (flat_scores, flat_labels) tuple.
    """
    flat_scores, flat_labels = scores.view(-1), labels.view(-1)
    if ignore is not None:
        keep = flat_labels != ignore
        flat_scores, flat_labels = flat_scores[keep], flat_labels[keep]
    return flat_scores, flat_labels

class StableBCELoss(torch.nn.modules.Module):
    """Mean binary cross-entropy on logits, written in the overflow-safe form.

    Uses max(x, 0) - x * t + log(1 + exp(-|x|)), so exp() never sees a large
    positive argument.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        softplus_term = torch.log(1 + torch.exp(-input.abs()))
        return (input.clamp(min=0) - input * target + softplus_term).mean()

def binary_xloss(logits, labels, ignore=None):
    """Binary cross-entropy loss over all (non-void) pixels.

    Args:
        logits: [B, H, W] tensor of per-pixel logits (unbounded, -inf..+inf).
        labels: [B, H, W] tensor of binary ground-truth masks (0 or 1).
        ignore: void label value removed before computing the loss.
    """
    flat_logits, flat_labels = flatten_binary_scores(logits, labels, ignore)
    criterion = StableBCELoss()
    return criterion(flat_logits, Variable(flat_labels.float()))

In [None]:
# MULTICLASS LOSSES

def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=None):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probas: [B, C, H, W] class probabilities (between 0 and 1).
        labels: [B, H, W] ground-truth class ids (0..C-1).
        only_present: average only over classes present in the ground truth.
        per_image: compute the loss per image and average, instead of per batch.
        ignore: void label value removed before computing the loss.
    """
    if not per_image:
        return lovasz_softmax_flat(
            *flatten_probas(probas, labels, ignore),
            only_present=only_present
        )
    return mean(
        lovasz_softmax_flat(
            *flatten_probas(img_probas.unsqueeze(0), img_labels.unsqueeze(0), ignore),
            only_present=only_present
        )
        for img_probas, img_labels in zip(probas, labels)
    )

def lovasz_softmax_flat(probas, labels, only_present=False):
    """Multi-class Lovasz-Softmax loss on flattened predictions.

    Args:
        probas: [P, C] class probabilities (between 0 and 1).
        labels: [P] ground-truth class ids (0..C-1).
        only_present: skip classes with no ground-truth pixels.

    Returns:
        Mean of the per-class Lovasz losses.
    """
    per_class = []
    for cls in range(probas.size(1)):
        fg = (labels == cls).float()  # foreground indicator for this class
        if only_present and fg.sum() == 0:
            continue
        abs_err = (Variable(fg) - probas[:, cls]).abs()
        abs_err_sorted, order = torch.sort(abs_err, 0, descending=True)
        jaccard_grad = lovasz_grad(fg[order.data])
        per_class.append(torch.dot(abs_err_sorted, Variable(jaccard_grad)))
    return mean(per_class)

def flatten_probas(probas, labels, ignore=None):
    """Flatten batch probabilities to [P, C] and labels to [P].

    Args:
        probas: [B, C, H, W] class probabilities.
        labels: [B, H, W] class ids.
        ignore: if not None, pixels whose label equals it are dropped.

    Returns:
        (probas [P, C], labels [P]) with P = number of kept pixels.
    """
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = labels != ignore
    # Boolean-mask indexing always yields [P_valid, C]. The former
    # probas[valid.nonzero().squeeze()] collapsed to shape [C] when exactly one
    # pixel was valid.
    return probas[valid], labels[valid]

def xloss(logits, labels, ignore=None):
    """Cross-entropy loss that skips void labels.

    Args:
        logits: [N, C, ...] unnormalised class scores.
        labels: [N, ...] integer class ids.
        ignore: label id excluded from the loss. Defaults to 255, the value the
            function previously hard-coded while silently ignoring this argument.
    """
    ignore_index = 255 if ignore is None else ignore
    return F.cross_entropy(logits, labels, ignore_index=ignore_index)

In [None]:
# HELPER FUNCTIONS

def mean(l, ignore_nan=False, empty=0):
    """NaN-aware mean that also works on generators.

    Args:
        l: iterable of numbers, arrays or tensors supporting `+` and `/`.
        ignore_nan: drop elements for which np.isnan is true.
        empty: value returned for an empty input; pass 'raise' to raise instead.

    Returns:
        Sum of the elements divided by their count (the single element itself
        when there is exactly one).

    Raises:
        ValueError: if the input is empty and `empty == 'raise'`.
    """
    values = iter(l)
    if ignore_nan:
        # `ifilterfalse` is Python 2 only; on Python 3 it was never defined.
        from itertools import filterfalse
        values = filterfalse(np.isnan, values)
    try:
        acc = next(values)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    n = 1
    for n, v in enumerate(values, 2):
        # Out-of-place add: `acc += v` mutated the caller's first element
        # (arrays / tensors) in place.
        acc = acc + v
    if n == 1:
        return acc
    return acc / n

In [None]:
if YADSPH:
    train_df = pd.read_csv(f'{DATASET_PATH}/meta_sz_{PARAMS["image_size"]}.csv')
else:
    train_df = pd.read_csv(f'{DATA_PATH}/train_meta_{PARAMS["image_size"]}.csv')
display(train_df.head())
if DEBUG:
    # Explicit random_state: the debug subset no longer depends on how much the
    # global NumPy RNG has been consumed (e.g. when this cell is re-run).
    train_df = train_df.sample(min(100, len(train_df)), random_state=PARAMS['seed'])
df_train_img = pd.read_csv(f'{DATA_PATH}/train_image_level.csv')
df_train_sty = pd.read_csv(f'{DATA_PATH}/train_study_level.csv')
if YADSPH:
    train_df['img'] = train_df['image_id'].apply(lambda x: ''.join([x, '.jpg']))
    train_df = train_df.rename(columns={'dim1': 'dim_x', 'dim0': 'dim_y'})
    train_df['id'] = train_df['img'].apply(lambda x: x.split('/')[-1].replace('.jpg', '_image'))
else:
    train_df['id'] = train_df['img'].apply(lambda x: x.split('/')[-1].replace('.png', '_image'))
# Study ids in the study-level csv carry a '_study' suffix; strip it to join on StudyInstanceUID.
df_train_sty['StudyInstanceUID'] = df_train_sty['id'].apply(lambda x: x.replace('_study', ''))
df_train_sty = df_train_sty.drop(columns='id')
df_train_img = df_train_img.merge(df_train_sty, on='StudyInstanceUID')
train_df = df_train_img.merge(train_df, on='id')
train_df['None Opacity'] = train_df['boxes'].isnull()
display(train_df.head())
classes = [
    'Negative for Pneumonia',
    'Typical Appearance', 
    'Indeterminate Appearance', 
    'Atypical Appearance'
]
# Column index of the largest class indicator per row (first one on ties).
# Vectorised; np.argmax on a mixed-dtype row Series was pandas-version dependent.
train_df['mc_target'] = train_df[classes].to_numpy(dtype=float).argmax(axis=1)
print(train_df.shape)
display(train_df.head())

elapsed_time = time.time() - start_time
print(f'time elapsed: {elapsed_time // 60:.0f} min {elapsed_time % 60:.0f} sec')

In [None]:
def _preview_images(folder, label, n_images):
    """Plot the first `n_images` files of train_df['img'] from `folder` in one row.

    Afterwards, prints the min/max pixel value of the last image shown, as a quick
    range sanity check.
    """
    # squeeze=False keeps `axes` 2-D, so a single column still indexes correctly.
    fig, axes = plt.subplots(figsize=(16, 4), nrows=1, ncols=n_images, squeeze=False)
    img = None
    for j, ax in enumerate(axes[0]):
        name = train_df['img'].iloc[j]
        path = f'{folder}/{name}'
        img = cv2.imread(path)
        if img is None:  # cv2.imread returns None for missing or unreadable files
            raise FileNotFoundError(f'could not read image: {path}')
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax.set_title(name)
        ax.axis('off')
    plt.show()
    print(f'min max {label}', np.min(img), np.max(img))


bsize = min(6, PARAMS['train_bs'], len(train_df))
_preview_images(IMGS_PATH, 'image', bsize)
_preview_images(MASKS_PATH, 'mask', bsize)

In [None]:
def _build_train_aug(image_size):
    """Training augmentation applied jointly to image and mask; ends in ToTensorV2.

    NOTE(review): RandomBrightness / RandomContrast / Cutout and the
    `alpha_affine` argument are deprecated or removed in newer albumentations
    releases -- pin the library version.
    """
    hole = int(.05 * image_size)  # cutout holes up to 5% of the image side
    photometric = A.OneOf([
        A.RandomBrightness(limit=.2, p=1), 
        A.RandomContrast(limit=.2, p=1), 
    ], p=.5)
    noise_or_affine = A.OneOf([
        A.GaussNoise(0.002, p=.5),
        A.augmentations.geometric.transforms.Affine(p=.5),
    ], p=.25)
    distortions = A.OneOf([
        A.ElasticTransform(alpha=120, sigma=120 * .05, alpha_affine=120 * .03, p=.5),
        A.GridDistortion(p=.5),
        A.OpticalDistortion(distort_limit=2, shift_limit=.5, p=1)
    ], p=.25)
    return A.Compose([
        photometric,
        A.Blur(blur_limit=3, p=.25),
        noise_or_affine,
        distortions,
        A.RandomRotate90(p=.5),
        A.HorizontalFlip(p=.5),
        A.VerticalFlip(p=.5),
        A.Cutout(num_holes=10, max_h_size=hole, max_w_size=hole, p=.25),
        A.ShiftScaleRotate(p=.25),
        ToTensorV2(p=1)
    ])


train_aug = _build_train_aug(PARAMS['image_size'])
# Validation: tensor conversion only (no augmentation).
valid_aug = A.Compose([ToTensorV2(p=1)])

In [None]:
class StudyDataset(Dataset):
    """Per-row image / label / mask dataset.

    For row `idx`, reads `{IMGS_PATH}/{img}` as RGB and `{MASKS_PATH}/{img}` as
    grayscale. It scales both by 1/255 and returns `(image, target, mask)`, or
    `(image, target, mask, tab_data)` when `tab_data` is truthy.

    Args:
        df: frame with an 'img' file-name column and the `target` column; rows are
            addressed by position (the DataLoader index).
        target: name of the label column.
        tab_data: truthy to also return the ('modality', 'sex', 'body_part') values
            as a tensor (presumably these must be numeric for torch.tensor).
        aug: albumentations transform applied jointly to image and mask; expected
            to end with ToTensorV2 (both pipelines in this notebook do).
    """

    def __init__(self, df, target='mc_target', tab_data=None, aug=None):
        super().__init__()
        self.df = df 
        self.aug = aug 
        self.target = target
        self.tab_data = tab_data

    def __len__(self):
        return self.df.shape[0]

    @staticmethod
    def _read(path, conversion):
        """Read `path` with OpenCV, apply cvtColor `conversion`; fail loudly if unreadable."""
        img = cv2.imread(path)
        if img is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError(f'could not read image: {path}')
        return cv2.cvtColor(img, conversion)

    def __getitem__(self, idx):
        name = self.df['img'].iloc[idx]
        image = self._read(f'{IMGS_PATH}/{name}', cv2.COLOR_BGR2RGB)
        mask = self._read(f'{MASKS_PATH}/{name}', cv2.COLOR_BGR2GRAY)
        if self.aug:
            augmented = self.aug(image=image, mask=mask)
            image = augmented['image'] / 255
            mask = augmented['mask'] / 255
        else:
            # HWC uint8 -> CHW float. permute(2, 0, 1) keeps rows/columns aligned
            # with the mask; the former permute(2, 1, 0) transposed H and W.
            image = torch.from_numpy(image).float().permute(2, 0, 1) / 255
            mask = torch.from_numpy(mask).float() / 255
        target = self.df[self.target].iloc[idx]
        if self.tab_data:
            tab_data = self.df[['modality', 'sex', 'body_part']].values[idx]
            return image, target, mask, torch.tensor(tab_data)
        return image, target, mask
    
# Sanity check: one augmented training sample and its mask.
dataset_show = StudyDataset(df=train_df, aug=train_aug)
img_show, lbl_show, mask_show = dataset_show[2]
# CHW tensor -> HWC array clipped to [0, 1] for matplotlib.
img_show = np.clip(img_show.numpy().transpose([1, 2, 0]), 0, 1)
for shown, kind in ((img_show, 'image'), (mask_show, 'mask')):
    plt.imshow(shown)
    plt.title(kind + ', target ' + str(lbl_show))
    plt.show()

In [None]:
class Net(nn.Module):
    """EfficientNet classifier with an auxiliary 1-channel mask head.

    The timm backbone PARAMS['model_arch'] is unpacked into the stem (b0), the
    seven block stages blocks[0..6] (b1-b7) and the head conv (b8). `self.mask`
    predicts mask logits from the b5 (blocks[4]) output. Classification goes:
    pooled b8 features -> Linear(..., 1000) -> [optional tabular concat] ->
    ReLU -> Linear(..., n_classes).

    Args:
        n_classes: number of classification logits.
        tab_flag: add a small MLP over 3 tabular features whose 6 outputs are
            concatenated to the 1000 image features.
        mask_in_channels: channel width of the b5 output feeding the mask head.
            The default 176 is the previously hard-coded value, presumably the
            blocks[4] width of tf_efficientnetv2_m -- TODO confirm when changing
            PARAMS['model_arch'].

    forward(image, tab_data=None) returns (logits [B, n_classes], mask logits [B, 1, h, w]).
    """

    def __init__(self, n_classes=4, tab_flag=False, mask_in_channels=176):
        super(Net, self).__init__()
        e = timm.create_model(PARAMS['model_arch'], pretrained=True)
        self.b0 = nn.Sequential(
            e.conv_stem,
            e.bn1,
            e.act1
        )
        self.b1 = e.blocks[0]
        self.b2 = e.blocks[1]
        self.b3 = e.blocks[2]
        self.b4 = e.blocks[3]
        self.b5 = e.blocks[4]
        self.b6 = e.blocks[5]
        self.b7 = e.blocks[6]
        self.b8 = nn.Sequential(
            e.conv_head, 
            e.bn2,
            e.act2
        )
        self.final_eff_layer = nn.Linear(e.classifier.in_features, 1000)
        self.mask = nn.Sequential(
            nn.Conv2d(mask_in_channels, 160, kernel_size=3, padding=1),
            nn.BatchNorm2d(160),
            nn.ReLU(inplace=True),
            nn.Conv2d(160, 160, kernel_size=3, padding=1),
            nn.BatchNorm2d(160),
            nn.ReLU(inplace=True),
            nn.Conv2d(160, 1, kernel_size=1, padding=0),
        )
        self.tab_flag = tab_flag
        if self.tab_flag:
            # Tabular branch: 3 features -> 10 -> 6.
            num_features = 3
            self.hidden_size = [10, 6] 
            self.batch_norm1 = nn.BatchNorm1d(num_features)
            self.linear1 = nn.Linear(num_features, self.hidden_size[0])
            self.batch_norm2 = nn.BatchNorm1d(self.hidden_size[0])
            self.linear2 = nn.Linear(self.hidden_size[0], self.hidden_size[1])
            # 1000 image features + 6 tabular features.
            self.final = nn.Linear(1006, n_classes)
        else:
            self.final = nn.Linear(1000, n_classes)

    def forward(self, image, tab_data=None):
        batch_size = len(image)
        # Inputs are assumed scaled to [0, 1] (the dataset divides by 255); map to [-1, 1].
        x = 2 * image - 1
        for stage in (self.b0, self.b1, self.b2, self.b3, self.b4, self.b5):
            x = stage(x)
        mask = self.mask(x)  # auxiliary segmentation head on the b5 features
        for stage in (self.b6, self.b7, self.b8):
            x = stage(x)
        x = F.adaptive_avg_pool2d(x, 1).reshape(batch_size, -1)
        x = self.final_eff_layer(x)
        # `is not None`: the former `if tab_data:` raised "ambiguous truth value" on tensors.
        if tab_data is not None:
            if not self.tab_flag:
                raise ValueError('tab_data was passed but Net was built with tab_flag=False')
            y = self.batch_norm1(tab_data)
            y = F.relu(self.linear1(y))
            y = self.batch_norm2(y)
            y = F.relu(self.linear2(y))
            x = torch.cat((x, y), dim=1)  # image feats + tab feats
        x = F.relu(x)
        logit = self.final(x)
        return logit, mask

In [None]:
def custom_map(target, pred):
    """Mean one-vs-rest average precision over the 4 study classes, scaled by 2/3.

    Args:
        target: sequence of integer class ids (0..3).
        pred: sequence of per-sample score sequences; columns 0..3 are used.

    Returns:
        mean AP * 2/3 (presumably the study-level weight of the SIIM-COVID19
        metric -- confirm).
    """
    n_samples = len(target)
    ap_per_class = []
    for cls in range(4):
        scores = [0] * n_samples
        is_cls = [0] * n_samples
        for i in range(len(pred)):
            scores[i] = pred[i][cls]
            is_cls[i] = 1 if target[i] == cls else 0
        ap_per_class.append(average_precision_score(is_cls, scores))
    mean_ap = (ap_per_class[0] + ap_per_class[1] + ap_per_class[2] + ap_per_class[3]) / 4
    return mean_ap * (2 / 3)
        
def custom_map2(preds, targs):
    """Array-based variant of custom_map (note the reversed argument order).

    Args:
        preds: [N, >=4] numpy array of class scores.
        targs: [N] numpy array of class ids (must support elementwise `==`).
    """
    per_class_ap = [average_precision_score(targs == cls, preds[:, cls]) for cls in range(4)]
    return np.mean(per_class_ap) * 2 / 3

class FocalCosineLoss(nn.Module):
    """Cosine-embedding loss to one-hot targets plus a weighted focal cross-entropy.

    loss = cosine_embedding(input, one_hot(target), +1)
           + xent * alpha * (1 - pt) ** gamma * CE(normalize(input), target)

    Args:
        alpha: focal-term scale.
        gamma: focal exponent.
        xent: weight of the focal term relative to the cosine term.
    """

    def __init__(self, alpha=1, gamma=2, xent=.1):
        super(FocalCosineLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.xent = xent
        # "+1" target for cosine_embedding_loss. It stays on CPU here and is moved
        # to the input's device in forward(); the former `.cuda()` crashed without a GPU.
        self.y = torch.Tensor([1])

    def forward(self, input, target, reduction="mean"):
        """Args: input [N, C] scores, target [N] class ids, reduction 'mean' | 'sum' | 'none'."""
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f'unsupported reduction: {reduction!r}')
        one_hot = F.one_hot(target, num_classes=input.size(-1)).to(input.dtype)
        cosine_loss = F.cosine_embedding_loss(
            input, 
            one_hot, 
            self.y.to(input.device), 
            reduction=reduction
        )
        # `reduce=False` is deprecated; reduction='none' is the equivalent.
        cent_loss = F.cross_entropy(F.normalize(input), target, reduction='none')
        pt = torch.exp(-cent_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * cent_loss
        if reduction == 'mean':
            focal_loss = torch.mean(focal_loss)
        elif reduction == 'sum':
            # Previously left unreduced, so 'sum' returned a vector.
            focal_loss = torch.sum(focal_loss)
        return cosine_loss + self.xent * focal_loss

def elo_loss(logits, y, reduction='mean'):
    """Pairwise ranking (AUC surrogate) loss.

    https://towardsdatascience.com/explicit-auc-maximization-70beef6db14e

    For every class c present in `y`, each (positive i, negative j) pair
    contributes -sigmoid(logits[i, c] - logits[j, c]). The per-class value is
    the mean over its pairs.

    Args:
        logits: [N, C] scores.
        y: [N] integer class ids.
        reduction: 'mean' | 'sum' | 'none' over the per-class values.

    Returns:
        Scalar tensor ('mean' / 'sum') or [K] tensor ('none'). A zero scalar that
        is still attached to `logits` when no class has both positives and negatives.

    Raises:
        ValueError: for an unknown `reduction` (previously returned None silently).
    """
    if reduction not in ('mean', 'sum', 'none'):
        raise ValueError(f'unsupported reduction: {reduction!r}')
    losses = []
    for cls in y.unique():
        class_logits = logits[:, cls.item()]
        is_pos = (y == cls).float()
        pair_mask = (is_pos.unsqueeze(1) * (1 - is_pos.unsqueeze(0))).bool()
        if not pair_mask.any():
            # No negatives in the batch: the mean over zero pairs would be NaN.
            continue
        margins = class_logits.unsqueeze(1) - class_logits.unsqueeze(0)
        losses.append(-torch.sigmoid(margins)[pair_mask].mean())
    if not losses:
        return logits.sum() * 0.
    loss = torch.stack(losses)
    if reduction == 'mean':
        return torch.mean(loss)
    if reduction == 'sum':
        return torch.sum(loss)
    return loss
    
def symmetric_lovasz(outputs, targets):
    """Average of the Lovasz hinge on (outputs, targets) and on the flipped problem.

    The flipped term (-outputs, 1 - targets) scores the background as the
    positive class, so foreground and background both contribute.
    """
    foreground_term = lovasz_hinge(outputs, targets)
    background_term = lovasz_hinge(-outputs, 1 - targets)
    return .5 * (foreground_term + background_term)

In [None]:
def train_one_epoch(model, trainloader, optimizer, criterion,
                    epoch, device, scheduler=None):
    """Run one training epoch with AMP and gradient accumulation.

    Per-batch loss = mean of (classification `criterion`, BCE on the mask head,
    symmetric Lovasz on the mask head), divided by PARAMS['accum_iter']. The
    optimizer and `scheduler` (if given) step once every `accum_iter` batches
    and on the last batch, so the LR schedule advances per optimizer step.

    Returns:
        n_classes > 1: (mean loss, macro ROC AUC ovo, macro ROC AUC ovr, custom_map)
        n_classes == 1: (mean loss, ROC AUC, average precision, accuracy at 0.5)
        The mean loss averages only the classification term.

    Raises:
        NotImplementedError: if PARAMS['amp'] is False (previously this printed a
            message and then crashed computing metrics on empty lists).
    """
    if not PARAMS['amp']:
        raise NotImplementedError('train_one_epoch only supports PARAMS["amp"] = True')
    model.train()
    scaler = GradScaler()
    accum_iter = PARAMS['accum_iter']
    multiclass = PARAMS['n_classes'] > 1
    bce = nn.BCEWithLogitsLoss()
    n_batches = len(trainloader)
    final_loss = 0
    all_targets, all_preds, all_preds_ = [], [], []
    # Zero once here and afterwards only after each optimizer step, so gradients
    # really accumulate. Zeroing before every forward pass (as before) discarded
    # all but the last batch of each accumulation window.
    optimizer.zero_grad()
    pbar = tqdm(trainloader, total=n_batches)
    for i, (image, target, mask) in enumerate(pbar):
        image = image.to(device, dtype=torch.float32)
        if multiclass:
            target = target.to(device, dtype=torch.long)
        else:
            target = target.to(device).float().unsqueeze(1)
        with autocast():
            logits, pred_mask = model(image)
            # Resize the ground-truth mask to the mask head's spatial size (was
            # hard-coded to 32x32, i.e. one image size / backbone only).
            # NOTE(review): bilinear resizing yields fractional labels, while
            # lovasz_hinge documents binary (0/1) labels -- consider 'nearest'.
            mask = mask.to(device, dtype=torch.float32).unsqueeze(1)
            mask = F.interpolate(mask, size=pred_mask.shape[-2:], mode='bilinear', align_corners=False)
            loss1 = criterion(logits, target)
            loss2 = bce(pred_mask, mask)
            loss3 = symmetric_lovasz(pred_mask, mask)
            probs = torch.softmax(logits, 1) if multiclass else torch.sigmoid(logits)
            # Divide by accum_iter so the accumulated gradient is an average.
            sum_loss = (loss1 + loss2 + loss3) / 3 / accum_iter
        # Backward outside autocast, as recommended for AMP.
        scaler.scale(sum_loss).backward()
        if (i + 1) % accum_iter == 0 or i + 1 == n_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()
        final_loss += loss1.detach().item()
        all_targets.extend(target.cpu().detach().numpy().tolist())
        all_preds.extend(probs.cpu().detach().numpy().tolist())
        if not multiclass:
            all_preds_.extend(torch.round(probs).cpu().detach().numpy().tolist())
        pbar.set_description(f'epoch: {epoch + 1}, loss: {loss1.item():.3f}')
    mean_loss = final_loss / n_batches
    if multiclass:
        return mean_loss, roc_auc_score(
            all_targets, 
            all_preds,
            multi_class="ovo",
            average="macro"
        ), roc_auc_score(
            all_targets, 
            all_preds,
            multi_class="ovr",
            average="macro"
        ), custom_map(all_targets, all_preds)
    return mean_loss, roc_auc_score(
        all_targets, 
        all_preds
    ), average_precision_score(
        all_targets,
        all_preds
    ), accuracy_score(all_targets, all_preds_)

In [None]:
def validate_one_epoch(model, validloader, criterion, epoch, device):
    """Evaluate the classification head for one epoch.

    The mask head output and the loader's masks are not used. The caller is
    expected to wrap this in torch.no_grad() (engine does).

    Returns:
        n_classes > 1: (mean loss, macro ROC AUC ovo, macro ROC AUC ovr,
                        custom_map, accuracy, macro precision)
        n_classes == 1: (mean loss, ROC AUC, average precision, accuracy at 0.5)
    """
    model.eval()
    multiclass = PARAMS['n_classes'] > 1
    final_loss = 0.0
    all_targets, all_preds, all_preds_, all_preds_cl = [], [], [], []
    pbar = tqdm(validloader, total=len(validloader))
    # Masks are ignored here; they are no longer copied to the device for nothing.
    for image, target, _mask in pbar:
        image = image.to(device, dtype=torch.float32)
        if multiclass:
            target = target.to(device, dtype=torch.long)
        else:
            target = target.to(device).float().unsqueeze(1)
        logits, _ = model(image)
        loss = criterion(logits, target)
        final_loss += loss.detach().item()
        all_targets.extend(target.cpu().detach().numpy().tolist())
        if multiclass:
            probs = torch.softmax(logits, 1)
            all_preds_cl.extend(torch.argmax(logits, 1).cpu().detach().numpy().tolist())
        else:
            probs = torch.sigmoid(logits)
            all_preds_.extend(torch.round(probs).cpu().detach().numpy().tolist())
        all_preds.extend(probs.cpu().detach().numpy().tolist())
        pbar.set_description(f'epoch: {epoch + 1}, loss: {loss.item():.3f}')
    mean_loss = final_loss / len(validloader)
    if multiclass:
        return mean_loss, roc_auc_score(
            all_targets, 
            all_preds,
            multi_class="ovo",
            average="macro"
        ), roc_auc_score(
            all_targets, 
            all_preds, 
            multi_class="ovr", 
            average="macro"
        ), custom_map(all_targets, all_preds), accuracy_score(
            all_targets,
            all_preds_cl
        ), precision_score(
            all_targets,
            all_preds_cl,
            average='macro'
        )
    return mean_loss, roc_auc_score(
        all_targets, 
        all_preds
    ), average_precision_score(
        all_targets,
        all_preds
    ), accuracy_score(all_targets, all_preds_)

In [None]:
def engine(splits, target='Negative for Pneumonia'):
    """Stratified K-fold training driver.

    For each fold, builds a fresh Net, train/valid StudyDatasets and loaders, and
    AdamW + CosineAnnealingLR. It then trains for up to PARAMS['epochs'] epochs,
    stopping early after PARAMS['patience'] epochs without improvement.
    Improvement is judged by validation custom mAP (n_classes > 1) or validation
    ROC AUC (n_classes == 1).

    Per fold, it writes the epoch log to {MDLS_PATH}/log_{fold}.txt, the best
    checkpoint to model_best_fold_{fold}.pth and the last one to
    model_final_fold_{fold}.pth.

    Args:
        splits: number of folds.
        target: train_df column used for stratification and as the label.
    """
    multiclass = PARAMS['n_classes'] > 1
    device = PARAMS['device']
    skf = StratifiedKFold(n_splits=splits)
    folds = skf.split(train_df, y=train_df[target].tolist())
    for fold, (train_idx, valid_idx) in enumerate(folds):
        print('=' * 48)
        print('=' * 20, f'FOLD {fold}', '=' * 20)
        print('=' * 48)
        print(f'training for "{target}" class',
              f'\ntrain len: {len(train_idx)} | val len: {len(valid_idx)}')
        print('=' * 48)
        log_path = '{}/log_{}.txt'.format(MDLS_PATH, fold)
        model = Net(n_classes=PARAMS['n_classes'])
        train_data = train_df.iloc[train_idx, :].reset_index(drop=True)
        valid_data = train_df.iloc[valid_idx, :].reset_index(drop=True)
        train_dataset = StudyDataset(train_data, target, aug=train_aug)
        valid_dataset = StudyDataset(valid_data, target, aug=valid_aug)
        trainloader = DataLoader(
            train_dataset,
            shuffle=True,
            pin_memory=True,
            num_workers=PARAMS['n_workers'],
            batch_size=PARAMS['train_bs'],
            drop_last=False
        )
        validloader = DataLoader(
            valid_dataset,
            shuffle=False,
            pin_memory=True,
            num_workers=PARAMS['n_workers'],
            batch_size=PARAMS['valid_bs'],
            drop_last=False
        )
        model.to(device)
        optimizer = optim.AdamW(model.parameters(), PARAMS['init_lr']) 
        criterion = nn.CrossEntropyLoss() if multiclass else nn.BCEWithLogitsLoss()
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, PARAMS['t_max'])
        validation_ap = 0
        validation_roc_auc = 0
        best_ep = 0
        epochs_no_improve = 0
        for epoch in range(PARAMS['epochs']):
            train_metrics = train_one_epoch(
                model, trainloader, optimizer, criterion, epoch, device, scheduler
            )
            with torch.no_grad():
                valid_metrics = validate_one_epoch(
                    model, validloader, criterion, epoch, device
                )
            if multiclass:
                (train_loss, train_roc_auc_macro_ovo,
                 train_roc_auc_macro_ovr, train_map) = train_metrics
                (valid_loss, valid_roc_auc_macro_ovo, valid_roc_auc_macro_ovr,
                 val_map, val_acc, val_pre) = valid_metrics
                content = ''.join([
                    f'{time.ctime()} epoch {epoch} | ',
                    f'train loss {train_loss:.4f}, val loss {valid_loss:.4f} | ',
                    f'train roc auc macro ovo {train_roc_auc_macro_ovo:.4f}, '
                    f'val roc auc macro ovo {valid_roc_auc_macro_ovo:.4f} | ',
                    f'train_roc_auc_macro_ovr {train_roc_auc_macro_ovr:.4f}, ',
                    f'valid roc auc macro ovr {valid_roc_auc_macro_ovr:.4f} | '
                    f'train_map {train_map:.4f}, val map {val_map:.4f} | ',
                    f'val acc {val_acc:.4f}, val precision {val_pre:.4f}'
                ])
            else:
                train_loss, train_roc_auc, train_ap, train_acc = train_metrics
                valid_loss, valid_roc_auc, val_ap, valid_acc = valid_metrics
                content = ''.join([
                    f'{time.ctime()} epoch {epoch} | ',
                    f'train loss {train_loss:.4f}, val loss {valid_loss:.4f} | ',
                    f'train roc auc {train_roc_auc:.4f}, '
                    f'val roc auc {valid_roc_auc:.4f} | ',
                    f'train AP {train_ap:.4f}, valid AP {val_ap:.4f} | '
                    f'train accuracy {train_acc:.4f}, val accuracy {valid_acc:.4f}'
                ])
            # Echo the summary for both task types (the binary branch used to skip it).
            print(content)
            with open(log_path, 'a') as appender:
                appender.write(content + '\n')
            if multiclass:
                improved = val_map > validation_ap
                if improved:
                    print(f'mAP val improves: {validation_ap} -> {val_map}')
                    validation_ap = val_map
            else:
                improved = valid_roc_auc > validation_roc_auc
                if improved:
                    print(f'val roc auc improves: {validation_roc_auc} -> {valid_roc_auc}')
                    validation_roc_auc = valid_roc_auc
            if improved:
                best_ep = epoch
                torch.save(
                    model.state_dict(),
                    f"{MDLS_PATH}/model_best_fold_{fold}.pth"
                )
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
            if epochs_no_improve >= PARAMS['patience']:
                print('no improve for', epochs_no_improve, 'epochs | early stopping')
                break
        torch.save(
            model.state_dict(),
            f"{MDLS_PATH}/model_final_fold_{fold}.pth"
        )
        with open(log_path, 'a') as appender:
            if multiclass:
                appender.write(f'\nbest val mAP: {validation_ap:.3f} | best epoch {best_ep}')
            else:
                appender.write(f'\nbest val roc auc: {validation_roc_auc:.3f} | best epoch {best_ep}')
        # Release per-fold objects (the optimizer holds references to the parameters
        # and their state) before building the next fold's model.
        del model, optimizer, scheduler, train_data, valid_data, train_dataset, valid_dataset, trainloader, validloader
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
if __name__ == '__main__':
    # Multiclass runs use the 4-way study label; binary runs use the 'None Opacity' flag.
    target_column = 'mc_target' if PARAMS['n_classes'] > 1 else 'None Opacity'
    engine(PARAMS['splits'], target_column)