In [1]:
import os
import glob
import cv2
import time
import copy
import pickle  # Log dictionary data
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import seaborn as sn
import sklearn.metrics as metrics

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F # stateless functions
import torchvision.transforms as T
import torchvision.models as models

import multiprocessing
# We must import this explicitly, it is not imported by the top-level
# multiprocessing module.
import multiprocessing.pool

from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import cohen_kappa_score,confusion_matrix
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from datetime import datetime
from multiprocessing import Manager
from PIL import Image

In [2]:
"""Configuration in Common
    To resolve the CUDA out of memory issue, we can trade-off between number of tiles and batch_size
"""
class CFG:
    """Hyper-parameters, paths and runtime settings shared by every cell."""

    # --- data -------------------------------------------------------------
    TRAIN = '../yi_data/panda-16x128x128-tiles-data/train/'
    LABELS = '../data/train.csv'
    n_tile = 12          # tiles per biopsy fed to the model
    num_classes = 6      # ISUP grades 0..5
    test_split = 516     # number of images held out by train_test_split
    nfolds = 4           # StratifiedKFold splits
    seed = 524
    debug = False        # True -> train on a 100-sample subset

    # --- training ---------------------------------------------------------
    model_name = 'resnet_mlt'
    epochs = 20
    batch_size = 14
    lr = 1e-4
    weight_decay = .25   # NOTE(review): unusually large L2 penalty for Adam; confirm intended
    threshold = .8       # only referenced by commented-out code in train_model
    nworkers = 1         # DataLoader worker processes per fold

    # --- runtime ----------------------------------------------------------
    # Kept explicit rather than probing torch.cuda.is_available() here: probing
    # CUDA in the parent before the fork-based fold pool can break CUDA in the
    # forked workers.
    device = torch.device('cuda')
    dtype = torch.float32

# Datasets and Dataloader

In [3]:
train = pd.read_csv(CFG.LABELS).set_index('image_id')
# Tile files are named "<image_id>_<k>.png"; the first 32 characters of each
# file name are the image_id. Keep only labelled images that have tiles on disk.
files = sorted({p[:32] for p in os.listdir(CFG.TRAIN)})
train = train.loc[files].reset_index()

# Hold out CFG.test_split images as a test set before building the CV folds.
train, test = train_test_split(train, test_size=CFG.test_split, random_state=CFG.seed)
train = train.reset_index(drop=True)
test = test.reset_index(drop=True)

if CFG.debug:
    df = train.sample(n=100, random_state=CFG.seed).reset_index(drop=True)
else:
    df = train.copy()

# Generate train/validation sets containing the same distribution of isup_grade
skf = StratifiedKFold(n_splits=CFG.nfolds, random_state=CFG.seed, shuffle=True)
splits = list(skf.split(df, df.isup_grade))
# Assign each training sample the index of the fold in which it is validated.
# Plain `int` dtype: the `np.int` alias was removed in NumPy 1.24.
folds_splits = np.zeros(len(df), dtype=int)
for fold_no, (_, fold_val_idx) in enumerate(splits):
    folds_splits[fold_val_idx] = fold_no
df['split'] = folds_splits
df.head()

Unnamed: 0,image_id,data_provider,isup_grade,gleason_score,split
0,89ce8e0a494db9c7ddc68dbf58729c68,radboud,1,3+3,0
1,ef9c72d529806d9ba341476d4c159838,radboud,5,5+4,2
2,139e1e89bca7897799f80da93adcd3c7,karolinska,0,0+0,1
3,0abc61a1bc7abe47b9e44a9f69979eb0,karolinska,2,3+4,3
4,fe7812f20a38495743cf7bf7f811c108,karolinska,1,3+3,1


In [4]:
# https://www.kaggle.com/yasufuminakama/panda-se-resnext50-regression-baseline
class TrainDataset(Dataset):
    """Prostate Cancer Biopsy Dataset

    Each item is ``(img_tiles, label)``. ``img_tiles`` is the list of the
    ``CFG.n_tile`` tile images of one biopsy read from ``CFG.TRAIN`` (passed
    through ``transform`` when given). ``label`` is the biopsy's entry in
    ``labels``.
    """

    def __init__(self, df, labels, transform=None):
        """
        Args:
            df (pandas.DataFrame): one row per biopsy with an ``image_id``
                column. Rows are looked up by index label
                (``df['image_id'][idx]``), so a 0..n-1 index is expected;
                callers pass ``reset_index(drop=True)`` frames.
            labels (pandas.Series): integer grades in ``range(CFG.num_classes)``,
                aligned with ``df``.
            transform (callable, optional): applied to every tile (a PIL image).
        """
        self.df = df
        self.labels = labels
        self.transform = transform

        # Inverse-frequency per-sample weights, consumed by WeightedRandomSampler.
        class_counts = [len(labels[labels == g]) for g in range(CFG.num_classes)]
        self.num_samples = sum(class_counts)
        # A grade absent from this split (e.g. small debug folds) would divide by
        # zero; give it weight 0 instead, since no sample carries that grade.
        class_weights = [self.num_samples / count if count else 0.0
                         for count in class_counts]
        self.weights = [class_weights[labels[i]] for i in range(self.num_samples)]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # https://stackoverflow.com/questions/33369832/read-multiple-images-on-a-folder-in-opencv-python
        image_id = self.df['image_id'][idx]
        tile_fns = [f"{CFG.TRAIN}/{image_id}_{sub_id}.png" for sub_id in range(CFG.n_tile)]
        # cv2 loads channels in BGR order and no conversion is applied, so the
        # tiles stay BGR. https://stackoverflow.com/questions/50963283/python-opencv-imshow-doesnt-need-convert-from-bgr-to-rgb
        img_tiles = []
        for fn in tile_fns:
            arr = cv2.imread(fn)
            # cv2.imread reports a missing/unreadable file by returning None,
            # which would otherwise surface as an obscure PIL error.
            if arr is None:
                raise FileNotFoundError(f"Could not read tile image: {fn}")
            img_tiles.append(Image.fromarray(arr))
        if self.transform:
            img_tiles = list(map(self.transform, img_tiles))

        label = self.labels[idx]

        return img_tiles, label

In [5]:
def get_transforms(phase):
    """Return the torchvision tile pipeline for ``phase`` ('train' or 'val').

    Both pipelines end with ToTensor followed by per-channel normalisation.
    'train' first applies random horizontal/vertical flips and a random
    rotation within +/-15 degrees, filling exposed corners with 255.
    """
    assert phase in {'train', 'val'}

    # Per-channel mean/std measured on the 12-tile images. Tiles are read with
    # cv2 (BGR), so presumably these stats are in BGR order too; TODO confirm.
    normalize = T.Normalize(
        mean=[0.8578, 0.7889, 0.8946],
        std=[0.1713, 0.2596, 0.1408]
    )

    steps = []
    if phase == 'train':
        steps += [
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomRotation(15, fill=255),
        ]
    steps += [T.ToTensor(), normalize]
    return T.Compose(steps)

In [6]:
# Use fold idx as validation set
def data_loader(fold_idx):
    """Create the ``(train_loader, val_loader)`` pair for one CV fold.

    Rows of the global ``df`` whose ``split`` equals ``fold_idx`` become the
    validation set; every other row is training data. Training batches are
    drawn by a WeightedRandomSampler over the dataset's inverse-frequency
    weights (``len(train set)`` draws per epoch). Validation order is fixed.
    """
    in_val = df['split'] == fold_idx
    train_part = df[~in_val].reset_index(drop=True)
    val_part = df[in_val].reset_index(drop=True)

    train_dataset = TrainDataset(train_part, train_part['isup_grade'],
                                 transform=get_transforms(phase='train'))
    val_dataset = TrainDataset(val_part, val_part['isup_grade'],
                               transform=get_transforms(phase='val'))

    sampler = WeightedRandomSampler(weights=train_dataset.weights,
                                    num_samples=train_dataset.num_samples)
    return (
        DataLoader(train_dataset, batch_size=CFG.batch_size, sampler=sampler,
                   num_workers=CFG.nworkers),
        DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False,
                   num_workers=CFG.nworkers),
    )

# Training

In [7]:
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


def train_model(model, fold, dataloaders,
                criterion_bc, criterion_mult,
                optimizer, scheduler, num_epochs=25):
    """Train the two-head model for one CV fold and keep the best weights.

    Each integer label is turned into two targets:
      * binary head: 1 if ``label >= 2`` else 0, scored by ``criterion_bc``;
      * multi-class head: only samples with ``label >= 2``, shifted from
        (2, 3, 4, 5) to (0, 1, 2, 3), scored by ``criterion_mult``.
    The optimised loss is ``0.4 * loss_bc + 0.6 * loss_mult``.

    Args:
        model: module returning ``(outputs_bc, outputs_mult)``; ``outputs_bc``
            is squeezed on dim 1 before ``criterion_bc``.
        fold: fold id, used only in log messages.
        dataloaders: dict with 'train' and 'val' loaders yielding
            ``(list_of_tile_batches, labels)``.
        criterion_bc, criterion_mult: loss modules for the two heads.
        optimizer: optimiser over ``model``'s parameters.
        scheduler: stepped with the validation multi-class loss
            (``scheduler.step(loss)``), or None.
        num_epochs: number of epochs.

    Returns:
        ``(model, bc_loss_history, mult_loss_history, mult_train_acc_history,
        mult_val_acc_history)``. ``model`` holds the weights with the best
        validation multi-class accuracy. All histories are plain Python floats,
        so they pickle back from pool workers without CUDA state. A value is NaN
        when undefined, e.g. an epoch with no label >= 2.
    """
    since = time.time()

    # Send the model to GPU/CPU
    model = model.to(device=CFG.device)

    mult_train_acc_history = []
    mult_val_acc_history = []
    bc_loss_history = []
    mult_loss_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            elif (len(dataloaders[phase].dataset) != CFG.test_split) or (num_epochs == epoch + 1):
                # A 'val' loader whose dataset has exactly CFG.test_split samples
                # is treated as the held-out test set: evaluate it only in the
                # last epoch.
                model.eval()
            else:
                continue

            bc_loss = 0.0
            mult_loss = 0.0
            mult_num = 0
            mult_corrects = 0

            print(' ', end='', flush=True)  # To workaround tqdm issue in multiprocess
            for inputs, labels in tqdm(dataloaders[phase],
                                       desc='[{}] {}/{}({:5s})'.format(fold, epoch+1, num_epochs, phase)):
                _, C, H, W = inputs[0].shape
                # n_tile[(bs,3,128,128)]-->(bs,n_tile,3,128,128)-->(bs*n_tile,3,128,128)
                inputs = torch.stack(inputs, 1).view(-1, C, H, W).to(device=CFG.device, dtype=CFG.dtype)
                labels = labels.to(device=CFG.device)
                # Mask on the same device as the model outputs it indexes.
                is_mult = labels >= 2
                labels_bc = is_mult.to(dtype=CFG.dtype)
                labels_mult = (labels[is_mult] - 2).to(dtype=torch.long)  # Shift (2,3,4,5) to (0,1,2,3)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward, track history only in training
                with torch.set_grad_enabled(phase == 'train'):
                    outputs_bc, outputs_mult = model(inputs)
                    outputs_bc = outputs_bc.squeeze(1)
                    outputs_mult = outputs_mult[is_mult]

                    loss_bc = criterion_bc(outputs_bc, labels_bc)
                    loss_mult = criterion_mult(outputs_mult, labels_mult)
                    loss = .4 * loss_bc + .6 * loss_mult
                    pred_mult = torch.argmax(outputs_mult, 1)

                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics, accumulated as Python numbers (not device tensors)
                bc_loss += loss_bc.item()
                mult_loss += loss_mult.item()
                mult_num += len(labels_mult)
                mult_corrects += (pred_mult == labels_mult).sum().item()

            # End of epoch
            mult_epoch_acc = _safe_ratio(mult_corrects, mult_num)
            bc_avg_loss = _safe_ratio(bc_loss, len(dataloaders[phase].dataset))
            mult_avg_loss = _safe_ratio(mult_loss, mult_num)

            if phase == 'val':
                mult_val_acc_history.append(mult_epoch_acc)
                # deep copy the model (NaN never compares greater)
                if mult_epoch_acc > best_acc:
                    best_acc = mult_epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                # Apply lr_scheduler
                if scheduler is not None:
                    scheduler.step(mult_avg_loss)
            else:
                mult_train_acc_history.append(mult_epoch_acc)
                bc_loss_history.append(bc_avg_loss)
                mult_loss_history.append(mult_avg_loss)
            print('[{}] {:5s} BcLoss: {:4f} MultLoss: {:4f} MultAcc: {:4f}'.format(
                fold, phase, bc_avg_loss, mult_avg_loss, mult_epoch_acc))

    time_elapsed = time.time() - since
    print('[{}] Training complete in {:.0f}m {:.0f}s'.format(fold, time_elapsed//60, time_elapsed%60))
    print('[{}] Best val MultAcc: {:4f}'.format(fold, best_acc))
    print()

    model.load_state_dict(best_model_wts)
    return model, bc_loss_history, mult_loss_history, mult_train_acc_history, mult_val_acc_history

## Two-Stage Classifier

In [8]:
"""Customize Resnet Header"""
class AdaptiveConcatPool2d(nn.Module):
    "Layer that concats `AdaptiveMaxPool2d` and `AdaptiveAvgPool2d` along channels."
    def __init__(self, sz=1):
        """Pool to spatial size ``sz`` (1 when ``sz`` is None).

        The output has twice the input channels: max-pooled channels first,
        then average-pooled ones.
        """
        super().__init__()
        # Passing None straight to the adaptive pools would keep the input's
        # spatial size; honour the documented "1 if sz is None" instead.
        self.output_size = 1 if sz is None else sz
        self.ap = nn.AdaptiveAvgPool2d(self.output_size)
        self.mp = nn.AdaptiveMaxPool2d(self.output_size)

    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)

# https://www.kaggle.com/nelsongriffiths/mish-activation-and-transfer-learning-pytorch
def mish(x):
    """Element-wise Mish activation: ``x * tanh(softplus(x))``."""
    smooth = F.softplus(x)
    return x * smooth.tanh()

class mish_layer(nn.Module):
    """``nn.Module`` wrapper so the functional ``mish`` can sit inside ``nn.Sequential``.

    It has no parameters, so the inherited ``nn.Module.__init__`` is sufficient.
    """

    def forward(self, input):
        return mish(input)
    
class CustomResnet(nn.Module):
    """ResNet-18 trunk shared by two heads over the tiles of each biopsy.

    Every tile is encoded independently. The feature maps of one biopsy's
    ``n_tile`` tiles are then stacked along the height axis and fed to:
      * ``head_bc``   - one sigmoid output (binary head),
      * ``head_mult`` - four logits (multi-class head).

    Args:
        n_tile: tiles per biopsy. When None (default), ``CFG.n_tile`` is read at
            forward time, matching the previous behaviour.
    """
    def __init__(self, n_tile=None):
        super().__init__()
        # Build the backbone once; its fc.in_features is the trunk's channel width.
        backbone = models.resnet18(pretrained=False)
        num_ftrs = backbone.fc.in_features
        # Keep the convolutional trunk; drop the final avgpool and fc layers.
        self.enc = nn.Sequential(*list(backbone.children())[:-2])
        self.n_tile = n_tile
        # For binary classification
        self.head_bc = nn.Sequential(
                        nn.AdaptiveAvgPool2d((1,1)),
                        nn.Flatten(),
                        nn.Linear(num_ftrs, 1),  # One class
                        nn.Sigmoid())
        # For the multi-task (max + avg pooling doubles the feature width)
        self.head_mult = nn.Sequential(
                        AdaptiveConcatPool2d(),
                        nn.Flatten(),
                        nn.Linear(num_ftrs*2,512),
                        mish_layer(),
                        nn.BatchNorm1d(512),
                        nn.Dropout(0.5),
                        nn.Linear(512, 4) # Four classes
                    )

    def forward(self, x):
        """Assumes x is (bs * n_tile, 3, H, W), grouped n_tile-by-n_tile per biopsy."""
        n_tile = self.n_tile if self.n_tile is not None else CFG.n_tile
        enc_out = self.enc(x)
        _, C, H, W = enc_out.shape
        # (bs,n_tile,512,4,4)-->(bs,512,n_tile,4,4)-->(bs,512,n_tile*4,4)
        head_in = enc_out.view(-1,n_tile,C,H,W).permute(0,2,1,3,4)\
                  .contiguous().view(-1,C,n_tile*H,W)
        head_out_bc = self.head_bc(head_in)
        head_out_mult = self.head_mult(head_in)

        return head_out_bc, head_out_mult

## Focal Loss

In [None]:
class FocalLoss(nn.Module):
    """Binary focal loss: ``alpha * (1 - p_t) ** gamma * BCE``.

    Args:
        alpha: scalar weight applied to every element.
        gamma: focusing parameter; ``gamma=0`` gives alpha-scaled BCE.
        logits: if True, ``inputs`` are raw scores and a sigmoid is applied
            internally (numerically stable). Otherwise ``inputs`` must already
            be probabilities in [0, 1].
        reduce: if True return the sum over all elements, else the
            element-wise loss.
    """
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        # reduction='none' replaces the deprecated reduce=False keyword.
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        # p_t is the model's probability of the true class: BCE = -log(p_t).
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduce:
            return torch.sum(F_loss)
        else:
            return F_loss

In [9]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    When it is falsy the model is left untouched.
    """
    if not feature_extracting:
        return
    for p in model.parameters():
        p.requires_grad_(False)
            

def initialize_model(model_name, num_classes, feature_extract=False, use_pretrained=False):
    """Build a model by name.

    Params:
        model_name: 'alexnet', 'resnet' (ResNet-18) or 'resnet_mlt'
            (CustomResnet with a binary head and a 4-class head).
        num_classes: output size of the replaced final layer for 'alexnet' and
            'resnet'. Ignored by 'resnet_mlt', whose head sizes are fixed.
        feature_extract:
            True  - freeze the body; only the newly created final layer trains
            False - fine-tune all parameters
            Ignored by 'resnet_mlt'.
        use_pretrained: load torchvision pretrained weights ('alexnet'/'resnet').

    Raises:
        ValueError: if ``model_name`` is not one of the names above.
    """
    if model_name == 'alexnet':
        model_ft = models.alexnet(pretrained=use_pretrained)
        # Freeze before replacing the head so the new layer stays trainable.
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
    elif model_name == 'resnet':
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
    elif model_name == 'resnet_mlt':
        # Resnet with customized two headers
        model_ft = CustomResnet()
    else:
        raise ValueError(f"Unknown model_name: {model_name!r}")
    return model_ft


def train_fd(fold):
    """Train one CV fold end to end and return ``train_model``'s result tuple.

    Uses Adam with L2 weight decay and halves the learning rate when the
    validation multi-class loss has not improved for more than one epoch.
    Both heads use summed (not averaged) losses.
    """
    model_ft = initialize_model(CFG.model_name, CFG.num_classes, use_pretrained=False)
    optimizer = optim.Adam(model_ft.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=1, verbose=True, eps=1e-06)

    loader_train, loader_val = data_loader(fold)
    dataloaders = {'train': loader_train, 'val': loader_val}
    criterion_bc = nn.BCELoss(reduction='sum')
    criterion_mult = nn.CrossEntropyLoss(reduction='sum')

    return train_model(model_ft, fold, dataloaders,
                       criterion_bc, criterion_mult,
                       optimizer, scheduler, CFG.epochs)

## Multiprocessing

In [10]:
class NoDaemonProcess(multiprocessing.Process):
    """Process whose ``daemon`` flag always reads False and ignores writes.

    Pool marks its workers daemonic, and a daemonic process may not start
    child processes of its own. Presumably this exists so each fold's worker
    can run DataLoaders with their own worker processes.
    """
    @property
    def daemon(self):
        return False

    @daemon.setter
    def daemon(self, value):
        pass


# Context of the current default start method, but creating NoDaemonProcess.
# Note: calling get_context() here fixes the default start method for the session.
class NoDaemonContext(type(multiprocessing.get_context())):
    Process = NoDaemonProcess


# We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool
# because the latter is only a wrapper function, not a proper class.
class MyPool(multiprocessing.pool.Pool):
    """Pool whose worker processes are non-daemonic.

    Workers are injected through a custom context. This works on Python 3.7+,
    whereas overriding the ``Process`` class attribute breaks on 3.8+. There,
    ``Pool.Process`` is a staticmethod called as ``Process(ctx, ...)``, so a
    Process subclass would receive the context as its ``group`` argument.
    """
    def __init__(self, *args, **kwargs):
        kwargs['context'] = NoDaemonContext()
        super().__init__(*args, **kwargs)

def progressor(fold):
    """Pool worker: train ``fold`` and return its results keyed with a ``_{fold}`` suffix.

    The best model is moved to CPU before it is sent back to the parent process.
    """
    model, bc_hist, mult_hist, train_acc_hist, val_acc_hist = train_fd(fold)
    result = {f'best_model_{fold}': model.to('cpu')}  # Don't save model as cuda
    result[f'bc_loss_history_{fold}'] = bc_hist
    result[f'mult_loss_history_{fold}'] = mult_hist
    result[f'mult_train_acc_history_{fold}'] = train_acc_hist
    result[f'mult_val_acc_history_{fold}'] = val_acc_hist
    return result

## Start Training

In [None]:
# Run-level settings, stored next to the per-fold results.
log_dict = {'batch_size': CFG.batch_size,
            'epochs': CFG.epochs,
            'learning_rate': CFG.lr,
            'model': CFG.model_name,
            'nworkers': CFG.nworkers,
            'nfolds': CFG.nfolds,
            'random_seed': CFG.seed}

# Train every fold in parallel, one non-daemonic worker per fold. Close and
# join the pool once map() returns so worker processes and the pool's helper
# threads are released; the context manager terminates it on error.
with MyPool(CFG.nfolds) as pool:
    result_list = pool.map(progressor, range(CFG.nfolds))
    pool.close()
    pool.join()

# Accumulate result from each process
for result in result_list:
    log_dict.update(result)

    

HBox(children=(FloatProgress(value=0.0, description='[1] 1/20(train)', max=536.0, style=ProgressStyle(descript…

HBox(children=(FloatProgress(value=0.0, description='[0] 1/20(train)', max=536.0, style=ProgressStyle(descript…

HBox(children=(FloatProgress(value=0.0, description='[3] 1/20(train)', max=536.0, style=ProgressStyle(descript…

HBox(children=(FloatProgress(value=0.0, description='[2] 1/20(train)', max=536.0, style=ProgressStyle(descript…


[0] train BcLoss: 0.587028 MultLoss: 1.496089 MultAcc: 0.285167
 
[3] train BcLoss: 0.577732 MultLoss: 1.466034 MultAcc: 0.305393


HBox(children=(FloatProgress(value=0.0, description='[0] 1/20(val  )', max=179.0, style=ProgressStyle(descript…

 

HBox(children=(FloatProgress(value=0.0, description='[3] 1/20(val  )', max=179.0, style=ProgressStyle(descript…


[1] train BcLoss: 0.581112 MultLoss: 1.468593 MultAcc: 0.302666
 

HBox(children=(FloatProgress(value=0.0, description='[1] 1/20(val  )', max=179.0, style=ProgressStyle(descript…


[2] train BcLoss: 0.564781 MultLoss: 1.471399 MultAcc: 0.304959
 

HBox(children=(FloatProgress(value=0.0, description='[2] 1/20(val  )', max=179.0, style=ProgressStyle(descript…


[0] val   BcLoss: 0.698715 MultLoss: 1.440574 MultAcc: 0.317460
 

HBox(children=(FloatProgress(value=0.0, description='[0] 2/20(train)', max=536.0, style=ProgressStyle(descript…


[3] val   BcLoss: 0.585905 MultLoss: 1.412430 MultAcc: 0.308271
 

HBox(children=(FloatProgress(value=0.0, description='[3] 2/20(train)', max=536.0, style=ProgressStyle(descript…


[1] val   BcLoss: 0.708248 MultLoss: 1.426735 MultAcc: 0.331386
 

HBox(children=(FloatProgress(value=0.0, description='[1] 2/20(train)', max=536.0, style=ProgressStyle(descript…


[2] val   BcLoss: 0.601688 MultLoss: 1.352101 MultAcc: 0.360601
 

HBox(children=(FloatProgress(value=0.0, description='[2] 2/20(train)', max=536.0, style=ProgressStyle(descript…


[0] train BcLoss: 0.535171 MultLoss: 1.388347 MultAcc: 0.332343
 

HBox(children=(FloatProgress(value=0.0, description='[0] 2/20(val  )', max=179.0, style=ProgressStyle(descript…



[3] train BcLoss: 0.516466 MultLoss: 1.397001 MultAcc: 0.345302
[1] train BcLoss: 0.536791 MultLoss: 1.381356 MultAcc: 0.350530
  

HBox(children=(FloatProgress(value=0.0, description='[1] 2/20(val  )', max=179.0, style=ProgressStyle(descript…

HBox(children=(FloatProgress(value=0.0, description='[3] 2/20(val  )', max=179.0, style=ProgressStyle(descript…


[2] train BcLoss: 0.523467 MultLoss: 1.410784 MultAcc: 0.329616
 

HBox(children=(FloatProgress(value=0.0, description='[2] 2/20(val  )', max=179.0, style=ProgressStyle(descript…


[0] val   BcLoss: 0.594135 MultLoss: 1.363706 MultAcc: 0.360902
 

HBox(children=(FloatProgress(value=0.0, description='[0] 3/20(train)', max=536.0, style=ProgressStyle(descript…


[3] val   BcLoss: 0.574986 MultLoss: 1.365644 MultAcc: 0.390142
 

HBox(children=(FloatProgress(value=0.0, description='[3] 3/20(train)', max=536.0, style=ProgressStyle(descript…


[2] val   BcLoss: 0.616242 MultLoss: 1.309857 MultAcc: 0.373957
 


HBox(children=(FloatProgress(value=0.0, description='[2] 3/20(train)', max=536.0, style=ProgressStyle(descript…

[1] val   BcLoss: 0.810590 MultLoss: 1.342544 MultAcc: 0.358932
 

HBox(children=(FloatProgress(value=0.0, description='[1] 3/20(train)', max=536.0, style=ProgressStyle(descript…


[0] train BcLoss: 0.518670 MultLoss: 1.359546 MultAcc: 0.353693
 

HBox(children=(FloatProgress(value=0.0, description='[0] 3/20(val  )', max=179.0, style=ProgressStyle(descript…


[3] train BcLoss: 0.508837 MultLoss: 1.357270 MultAcc: 0.362011
 

HBox(children=(FloatProgress(value=0.0, description='[3] 3/20(val  )', max=179.0, style=ProgressStyle(descript…


[2] train BcLoss: 0.496571 MultLoss: 1.348670 MultAcc: 0.367041
 

HBox(children=(FloatProgress(value=0.0, description='[2] 3/20(val  )', max=179.0, style=ProgressStyle(descript…


[1] train BcLoss: 0.506960 MultLoss: 1.357504 MultAcc: 0.349859
 

HBox(children=(FloatProgress(value=0.0, description='[1] 3/20(val  )', max=179.0, style=ProgressStyle(descript…


[0] val   BcLoss: 0.636567 MultLoss: 1.315666 MultAcc: 0.355054
 

HBox(children=(FloatProgress(value=0.0, description='[0] 4/20(train)', max=536.0, style=ProgressStyle(descript…


[2] val   BcLoss: 0.582849 MultLoss: 1.279805 MultAcc: 0.404841
 

HBox(children=(FloatProgress(value=0.0, description='[2] 4/20(train)', max=536.0, style=ProgressStyle(descript…


[3] val   BcLoss: 0.557539 MultLoss: 1.300082 MultAcc: 0.366750
 

HBox(children=(FloatProgress(value=0.0, description='[3] 4/20(train)', max=536.0, style=ProgressStyle(descript…


[1] val   BcLoss: 0.552021 MultLoss: 1.338092 MultAcc: 0.363940
 

HBox(children=(FloatProgress(value=0.0, description='[1] 4/20(train)', max=536.0, style=ProgressStyle(descript…


[0] train BcLoss: 0.501675 MultLoss: 1.326486 MultAcc: 0.372533
 

HBox(children=(FloatProgress(value=0.0, description='[0] 4/20(val  )', max=179.0, style=ProgressStyle(descript…


[2] train BcLoss: 0.485881 MultLoss: 1.343003 MultAcc: 0.360127
 

HBox(children=(FloatProgress(value=0.0, description='[2] 4/20(val  )', max=179.0, style=ProgressStyle(descript…


[3] train BcLoss: 0.510024 MultLoss: 1.333929 MultAcc: 0.372070
 

HBox(children=(FloatProgress(value=0.0, description='[3] 4/20(val  )', max=179.0, style=ProgressStyle(descript…


[1] train BcLoss: 0.505259 MultLoss: 1.328606 MultAcc: 0.373731
 

HBox(children=(FloatProgress(value=0.0, description='[1] 4/20(val  )', max=179.0, style=ProgressStyle(descript…


[0] val   BcLoss: 0.645776 MultLoss: 1.292463 MultAcc: 0.387636
 

HBox(children=(FloatProgress(value=0.0, description='[0] 5/20(train)', max=536.0, style=ProgressStyle(descript…


[2] val   BcLoss: 0.646738 MultLoss: 1.284619 MultAcc: 0.387312

 [1] val   BcLoss: 0.541101 MultLoss: 1.362760 MultAcc: 0.348080
 

HBox(children=(FloatProgress(value=0.0, description='[2] 5/20(train)', max=536.0, style=ProgressStyle(descript…

HBox(children=(FloatProgress(value=0.0, description='[1] 5/20(train)', max=536.0, style=ProgressStyle(descript…


[3] val   BcLoss: 0.528763 MultLoss: 1.282752 MultAcc: 0.403509
 

HBox(children=(FloatProgress(value=0.0, description='[3] 5/20(train)', max=536.0, style=ProgressStyle(descript…


[0] train BcLoss: 0.494022 MultLoss: 1.321540 MultAcc: 0.374420
 

HBox(children=(FloatProgress(value=0.0, description='[0] 5/20(val  )', max=179.0, style=ProgressStyle(descript…


[1] train BcLoss: 0.498967 MultLoss: 1.313279 MultAcc: 0.379663
 

HBox(children=(FloatProgress(value=0.0, description='[1] 5/20(val  )', max=179.0, style=ProgressStyle(descript…


[2] train BcLoss: 0.501418 MultLoss: 1.332694 MultAcc: 0.367839
 

HBox(children=(FloatProgress(value=0.0, description='[2] 5/20(val  )', max=179.0, style=ProgressStyle(descript…


[3] train BcLoss: 0.488640 MultLoss: 1.317222 MultAcc: 0.394085
 

HBox(children=(FloatProgress(value=0.0, description='[3] 5/20(val  )', max=179.0, style=ProgressStyle(descript…


[0] val   BcLoss: 0.547516 MultLoss: 1.348752 MultAcc: 0.361738
 

HBox(children=(FloatProgress(value=0.0, description='[0] 6/20(train)', max=536.0, style=ProgressStyle(descript…


[1] val   BcLoss: 0.563453 MultLoss: 1.326602 MultAcc: 0.377295
 

HBox(children=(FloatProgress(value=0.0, description='[1] 6/20(train)', max=536.0, style=ProgressStyle(descript…


Epoch     5: reducing learning rate of group 0 to 5.0000e-05.
[2] val   BcLoss: 0.718849 MultLoss: 1.446567 MultAcc: 0.305509
 

HBox(children=(FloatProgress(value=0.0, description='[2] 6/20(train)', max=536.0, style=ProgressStyle(descript…


[3] val   BcLoss: 0.620051 MultLoss: 1.403695 MultAcc: 0.321637
 

HBox(children=(FloatProgress(value=0.0, description='[3] 6/20(train)', max=536.0, style=ProgressStyle(descript…


[0] train BcLoss: 0.486786 MultLoss: 1.317217 MultAcc: 0.371279
 

HBox(children=(FloatProgress(value=0.0, description='[0] 6/20(val  )', max=179.0, style=ProgressStyle(descript…


[1] train BcLoss: 0.493547 MultLoss: 1.299104 MultAcc: 0.388505
 
[2] train BcLoss: 0.477530 MultLoss: 1.301438 MultAcc: 0.393043
 

HBox(children=(FloatProgress(value=0.0, description='[1] 6/20(val  )', max=179.0, style=ProgressStyle(descript…

HBox(children=(FloatProgress(value=0.0, description='[2] 6/20(val  )', max=179.0, style=ProgressStyle(descript…


[3] train BcLoss: 0.494357 MultLoss: 1.305609 MultAcc: 0.396071
 

HBox(children=(FloatProgress(value=0.0, description='[3] 6/20(val  )', max=179.0, style=ProgressStyle(descript…


[0] val   BcLoss: 0.512904 MultLoss: 1.284558 MultAcc: 0.390977
 

HBox(children=(FloatProgress(value=0.0, description='[0] 7/20(train)', max=536.0, style=ProgressStyle(descript…


[2] val   BcLoss: 0.517304 MultLoss: 1.249478 MultAcc: 0.422371
 

HBox(children=(FloatProgress(value=0.0, description='[2] 7/20(train)', max=536.0, style=ProgressStyle(descript…


[1] val   BcLoss: 0.547371 MultLoss: 1.274981 MultAcc: 0.402337
 

HBox(children=(FloatProgress(value=0.0, description='[1] 7/20(train)', max=536.0, style=ProgressStyle(descript…


Epoch     6: reducing learning rate of group 0 to 5.0000e-05.
[3] val   BcLoss: 1.247723 MultLoss: 1.855396 MultAcc: 0.254804
 

HBox(children=(FloatProgress(value=0.0, description='[3] 7/20(train)', max=536.0, style=ProgressStyle(descript…


[0] train BcLoss: 0.496451 MultLoss: 1.286218 MultAcc: 0.400518
 

HBox(children=(FloatProgress(value=0.0, description='[0] 7/20(val  )', max=179.0, style=ProgressStyle(descript…


[2] train BcLoss: 0.481978 MultLoss: 1.291943 MultAcc: 0.397690
 

HBox(children=(FloatProgress(value=0.0, description='[2] 7/20(val  )', max=179.0, style=ProgressStyle(descript…


[1] train BcLoss: 0.495538 MultLoss: 1.302889 MultAcc: 0.386318
 

HBox(children=(FloatProgress(value=0.0, description='[1] 7/20(val  )', max=179.0, style=ProgressStyle(descript…


[3] train BcLoss: 0.479666 MultLoss: 1.292052 MultAcc: 0.393023
 

HBox(children=(FloatProgress(value=0.0, description='[3] 7/20(val  )', max=179.0, style=ProgressStyle(descript…


[0] val   BcLoss: 1.287942 MultLoss: 1.841371 MultAcc: 0.295739
 

HBox(children=(FloatProgress(value=0.0, description='[0] 8/20(train)', max=536.0, style=ProgressStyle(descript…


[2] val   BcLoss: 0.723485 MultLoss: 1.244790 MultAcc: 0.423205
 

HBox(children=(FloatProgress(value=0.0, description='[2] 8/20(train)', max=536.0, style=ProgressStyle(descript…


[3] val   BcLoss: 0.593929 MultLoss: 1.551603 MultAcc: 0.308271
 

HBox(children=(FloatProgress(value=0.0, description='[3] 8/20(train)', max=536.0, style=ProgressStyle(descript…


[1] val   BcLoss: 0.590193 MultLoss: 1.334626 MultAcc: 0.386477
 

HBox(children=(FloatProgress(value=0.0, description='[1] 8/20(train)', max=536.0, style=ProgressStyle(descript…


[0] train BcLoss: 0.494118 MultLoss: 1.297361 MultAcc: 0.389248
 

HBox(children=(FloatProgress(value=0.0, description='[0] 8/20(val  )', max=179.0, style=ProgressStyle(descript…


[2] train BcLoss: 0.481004 MultLoss: 1.288316 MultAcc: 0.401687
 

HBox(children=(FloatProgress(value=0.0, description='[2] 8/20(val  )', max=179.0, style=ProgressStyle(descript…


[3] train BcLoss: 0.473176 MultLoss: 1.278776 MultAcc: 0.400562
 

HBox(children=(FloatProgress(value=0.0, description='[3] 8/20(val  )', max=179.0, style=ProgressStyle(descript…


[1] train BcLoss: 0.495552 MultLoss: 1.288540 MultAcc: 0.406114
 

HBox(children=(FloatProgress(value=0.0, description='[1] 8/20(val  )', max=179.0, style=ProgressStyle(descript…


Epoch     8: reducing learning rate of group 0 to 5.0000e-05.
[0] val   BcLoss: 0.608050 MultLoss: 1.352286 MultAcc: 0.352548
 

HBox(children=(FloatProgress(value=0.0, description='[0] 9/20(train)', max=536.0, style=ProgressStyle(descript…


[2] val   BcLoss: 0.516820 MultLoss: 1.304700 MultAcc: 0.391486
 

HBox(children=(FloatProgress(value=0.0, description='[2] 9/20(train)', max=536.0, style=ProgressStyle(descript…



Epoch     8: reducing learning rate of group 0 to 5.0000e-05.
[1] val   BcLoss: 0.585438 MultLoss: 1.307380 MultAcc: 0.394825
 [3] val   BcLoss: 0.536200 MultLoss: 1.241269 MultAcc: 0.406015
 

HBox(children=(FloatProgress(value=0.0, description='[1] 9/20(train)', max=536.0, style=ProgressStyle(descript…

HBox(children=(FloatProgress(value=0.0, description='[3] 9/20(train)', max=536.0, style=ProgressStyle(descript…


[0] train BcLoss: 0.470352 MultLoss: 1.258545 MultAcc: 0.416404
 

HBox(children=(FloatProgress(value=0.0, description='[0] 9/20(val  )', max=179.0, style=ProgressStyle(descript…


[2] train BcLoss: 0.482818 MultLoss: 1.292781 MultAcc: 0.397802
 

HBox(children=(FloatProgress(value=0.0, description='[2] 9/20(val  )', max=179.0, style=ProgressStyle(descript…


[3] train BcLoss: 0.472388 MultLoss: 1.272184 MultAcc: 0.404264
 

HBox(children=(FloatProgress(value=0.0, description='[3] 9/20(val  )', max=179.0, style=ProgressStyle(descript…


[1] train BcLoss: 0.486522 MultLoss: 1.276754 MultAcc: 0.403281
 

HBox(children=(FloatProgress(value=0.0, description='[1] 9/20(val  )', max=179.0, style=ProgressStyle(descript…


[0] val   BcLoss: 0.528400 MultLoss: 1.289067 MultAcc: 0.390977
 

HBox(children=(FloatProgress(value=0.0, description='[0] 10/20(train)', max=536.0, style=ProgressStyle(descrip…


[3] val   BcLoss: 0.534025 MultLoss: 1.364912 MultAcc: 0.349206
 
Epoch     9: reducing learning rate of group 0 to 2.5000e-05.
[2] val   BcLoss: 0.548414 MultLoss: 1.328026 MultAcc: 0.416528
 

HBox(children=(FloatProgress(value=0.0, description='[3] 10/20(train)', max=536.0, style=ProgressStyle(descrip…

HBox(children=(FloatProgress(value=0.0, description='[2] 10/20(train)', max=536.0, style=ProgressStyle(descrip…


[1] val   BcLoss: 0.510893 MultLoss: 1.341912 MultAcc: 0.390651
 

HBox(children=(FloatProgress(value=0.0, description='[1] 10/20(train)', max=536.0, style=ProgressStyle(descrip…


[0] train BcLoss: 0.476479 MultLoss: 1.269984 MultAcc: 0.413573
 

HBox(children=(FloatProgress(value=0.0, description='[0] 10/20(val  )', max=179.0, style=ProgressStyle(descrip…


[3] train BcLoss: 0.475840 MultLoss: 1.276515 MultAcc: 0.396105
 

HBox(children=(FloatProgress(value=0.0, description='[3] 10/20(val  )', max=179.0, style=ProgressStyle(descrip…


[2] train BcLoss: 0.476998 MultLoss: 1.258726 MultAcc: 0.420322
 

HBox(children=(FloatProgress(value=0.0, description='[2] 10/20(val  )', max=179.0, style=ProgressStyle(descrip…


[1] train BcLoss: 0.480476 MultLoss: 1.264139 MultAcc: 0.419846
 

HBox(children=(FloatProgress(value=0.0, description='[1] 10/20(val  )', max=179.0, style=ProgressStyle(descrip…


[0] val   BcLoss: 0.836755 MultLoss: 1.233541 MultAcc: 0.431913
 

HBox(children=(FloatProgress(value=0.0, description='[0] 11/20(train)', max=536.0, style=ProgressStyle(descrip…


Epoch    10: reducing learning rate of group 0 to 2.5000e-05.
[3] val   BcLoss: 0.582877 MultLoss: 1.286764 MultAcc: 0.403509
 

HBox(children=(FloatProgress(value=0.0, description='[3] 11/20(train)', max=536.0, style=ProgressStyle(descrip…


[2] val   BcLoss: 0.542896 MultLoss: 1.205003 MultAcc: 0.454090
 

HBox(children=(FloatProgress(value=0.0, description='[2] 11/20(train)', max=536.0, style=ProgressStyle(descrip…


[1] val   BcLoss: 0.518718 MultLoss: 1.235953 MultAcc: 0.436561
 

HBox(children=(FloatProgress(value=0.0, description='[1] 11/20(train)', max=536.0, style=ProgressStyle(descrip…


[0] train BcLoss: 0.457514 MultLoss: 1.265399 MultAcc: 0.412114
 

HBox(children=(FloatProgress(value=0.0, description='[0] 11/20(val  )', max=179.0, style=ProgressStyle(descrip…


[3] train BcLoss: 0.472790 MultLoss: 1.241050 MultAcc: 0.423123
 

HBox(children=(FloatProgress(value=0.0, description='[3] 11/20(val  )', max=179.0, style=ProgressStyle(descrip…


[2] train BcLoss: 0.457350 MultLoss: 1.241981 MultAcc: 0.434309
 

HBox(children=(FloatProgress(value=0.0, description='[2] 11/20(val  )', max=179.0, style=ProgressStyle(descrip…


[1] train BcLoss: 0.490286 MultLoss: 1.283845 MultAcc: 0.397595
 

HBox(children=(FloatProgress(value=0.0, description='[1] 11/20(val  )', max=179.0, style=ProgressStyle(descrip…


[0] val   BcLoss: 0.741231 MultLoss: 1.351465 MultAcc: 0.365915
 

HBox(children=(FloatProgress(value=0.0, description='[0] 12/20(train)', max=536.0, style=ProgressStyle(descrip…


[3] val   BcLoss: 0.499037 MultLoss: 1.214460 MultAcc: 0.432749
 

HBox(children=(FloatProgress(value=0.0, description='[3] 12/20(train)', max=536.0, style=ProgressStyle(descrip…


[2] val   BcLoss: 0.502897 MultLoss: 1.205675 MultAcc: 0.453255
 

HBox(children=(FloatProgress(value=0.0, description='[2] 12/20(train)', max=536.0, style=ProgressStyle(descrip…


[1] val   BcLoss: 0.505698 MultLoss: 1.236424 MultAcc: 0.426544
 

HBox(children=(FloatProgress(value=0.0, description='[1] 12/20(train)', max=536.0, style=ProgressStyle(descrip…


[0] train BcLoss: 0.476234 MultLoss: 1.269758 MultAcc: 0.400714
 

HBox(children=(FloatProgress(value=0.0, description='[0] 12/20(val  )', max=179.0, style=ProgressStyle(descrip…


[3] train BcLoss: 0.458286 MultLoss: 1.217824 MultAcc: 0.452610
 

HBox(children=(FloatProgress(value=0.0, description='[3] 12/20(val  )', max=179.0, style=ProgressStyle(descrip…


[2] train BcLoss: 0.457323 MultLoss: 1.255803 MultAcc: 0.416981
 

HBox(children=(FloatProgress(value=0.0, description='[2] 12/20(val  )', max=179.0, style=ProgressStyle(descrip…


[1] train BcLoss: 0.471854 MultLoss: 1.261609 MultAcc: 0.420263
 

HBox(children=(FloatProgress(value=0.0, description='[1] 12/20(val  )', max=179.0, style=ProgressStyle(descrip…


Epoch    12: reducing learning rate of group 0 to 2.5000e-05.
[0] val   BcLoss: 0.463746 MultLoss: 1.264925 MultAcc: 0.439432
 

HBox(children=(FloatProgress(value=0.0, description='[0] 13/20(train)', max=536.0, style=ProgressStyle(descrip…


[3] val   BcLoss: 0.531263 MultLoss: 1.253814 MultAcc: 0.426901
 

HBox(children=(FloatProgress(value=0.0, description='[3] 13/20(train)', max=536.0, style=ProgressStyle(descrip…


[2] val   BcLoss: 0.498636 MultLoss: 1.186842 MultAcc: 0.458264
 

HBox(children=(FloatProgress(value=0.0, description='[2] 13/20(train)', max=536.0, style=ProgressStyle(descrip…


Epoch    12: reducing learning rate of group 0 to 2.5000e-05.
[1] val   BcLoss: 0.590394 MultLoss: 1.423249 MultAcc: 0.343072
 

HBox(children=(FloatProgress(value=0.0, description='[1] 13/20(train)', max=536.0, style=ProgressStyle(descrip…


[0] train BcLoss: 0.458263 MultLoss: 1.230982 MultAcc: 0.432833
 

HBox(children=(FloatProgress(value=0.0, description='[0] 13/20(val  )', max=179.0, style=ProgressStyle(descrip…


[3] train BcLoss: 0.460820 MultLoss: 1.231047 MultAcc: 0.424586
 

HBox(children=(FloatProgress(value=0.0, description='[3] 13/20(val  )', max=179.0, style=ProgressStyle(descrip…


[2] train BcLoss: 0.463107 MultLoss: 1.245093 MultAcc: 0.422607
 

HBox(children=(FloatProgress(value=0.0, description='[2] 13/20(val  )', max=179.0, style=ProgressStyle(descrip…


[1] train BcLoss: 0.472930 MultLoss: 1.231925 MultAcc: 0.435747
 

HBox(children=(FloatProgress(value=0.0, description='[1] 13/20(val  )', max=179.0, style=ProgressStyle(descrip…


[0] val   BcLoss: 0.534116 MultLoss: 1.212927 MultAcc: 0.455305
 

HBox(children=(FloatProgress(value=0.0, description='[0] 14/20(train)', max=536.0, style=ProgressStyle(descrip…


[2] val   BcLoss: 0.510400 MultLoss: 1.220067 MultAcc: 0.433222
 

HBox(children=(FloatProgress(value=0.0, description='[2] 14/20(train)', max=536.0, style=ProgressStyle(descrip…


Epoch    13: reducing learning rate of group 0 to 1.2500e-05.
[3] val   BcLoss: 0.487747 MultLoss: 1.237722 MultAcc: 0.451963
 

HBox(children=(FloatProgress(value=0.0, description='[3] 14/20(train)', max=536.0, style=ProgressStyle(descrip…


[1] val   BcLoss: 0.508977 MultLoss: 1.231034 MultAcc: 0.430718
 

HBox(children=(FloatProgress(value=0.0, description='[1] 14/20(train)', max=536.0, style=ProgressStyle(descrip…


[0] train BcLoss: 0.463621 MultLoss: 1.222756 MultAcc: 0.432416
 

HBox(children=(FloatProgress(value=0.0, description='[0] 14/20(val  )', max=179.0, style=ProgressStyle(descrip…


[2] train BcLoss: 0.451356 MultLoss: 1.242361 MultAcc: 0.432647
 

HBox(children=(FloatProgress(value=0.0, description='[2] 14/20(val  )', max=179.0, style=ProgressStyle(descrip…


[3] train BcLoss: 0.454638 MultLoss: 1.216664 MultAcc: 0.446086
 

HBox(children=(FloatProgress(value=0.0, description='[3] 14/20(val  )', max=179.0, style=ProgressStyle(descrip…


[1] train BcLoss: 0.464089 MultLoss: 1.236508 MultAcc: 0.429169
 

HBox(children=(FloatProgress(value=0.0, description='[1] 14/20(val  )', max=179.0, style=ProgressStyle(descrip…


[0] val   BcLoss: 0.536996 MultLoss: 1.214984 MultAcc: 0.449457
 

HBox(children=(FloatProgress(value=0.0, description='[0] 15/20(train)', max=536.0, style=ProgressStyle(descrip…


Epoch    14: reducing learning rate of group 0 to 1.2500e-05.
[2] val   BcLoss: 0.570060 MultLoss: 1.188034 MultAcc: 0.449917
 

HBox(children=(FloatProgress(value=0.0, description='[2] 15/20(train)', max=536.0, style=ProgressStyle(descrip…


[3] val   BcLoss: 0.495914 MultLoss: 1.197309 MultAcc: 0.447786
 

HBox(children=(FloatProgress(value=0.0, description='[3] 15/20(train)', max=536.0, style=ProgressStyle(descrip…


[1] val   BcLoss: 0.526615 MultLoss: 1.223351 MultAcc: 0.434891
 

HBox(children=(FloatProgress(value=0.0, description='[1] 15/20(train)', max=536.0, style=ProgressStyle(descrip…

## Log results

In [None]:
# Persist the training log (it holds e.g. the `best_model_{fold}` entries used
# by the metrics cells) under a timestamped file name.
run_stamp = datetime.now().strftime("%m_%d_%Y_%H_%M")
log_file = f'{CFG.model_name}_{run_stamp}.pkl'
with open(log_file, mode='wb') as pkl_file:
    pickle.dump(log_dict, pkl_file)

# Metrics
## Analyze Stage One

In [None]:
# Read the pickled training log back into `log_dict`.
# NOTE: pickle.load can execute arbitrary code — only load files this notebook produced.
saved_log_path = 'resnet_mlt_06_04_2020_14_31.pkl'
with open(saved_log_path, mode='rb') as log_fh:
    log_dict = pickle.load(log_fh)

In [None]:
# Run each fold's best checkpoint over that fold's validation loader and collect:
#   targets    - raw grade labels
#   targets_bc - stage-one binary targets (grade <= 1 -> 0, otherwise 1)
#   scores_bc  - stage-one binary head outputs
#   preds_mult - stage-two argmax predictions (head classes 0..3)
targets, targets_bc, scores_bc, preds_mult = [], [], [], []

for fold in range(CFG.nfolds):
    model_fd = log_dict[f'best_model_{fold}'].to(device=CFG.device, dtype=CFG.dtype)
    # Inference only: put dropout / BatchNorm layers into evaluation mode.
    model_fd.eval()
    _, loader_val = data_loader(fold)
    for inputs, labels in tqdm(loader_val, desc=f'fold {fold}'):
        _, C, H, W = inputs[0].shape
        # n_tile[(bs,3,128,128)]-->(bs,n_tile,3,128,128)-->(bs*n_tile,3,128,128)
        inputs = torch.stack(inputs, 1).view(-1, C, H, W).to(device=CFG.device, dtype=CFG.dtype)
        # Stage-one target: grades 0/1 -> 0, grades >= 2 -> 1.
        labels_bc = (labels > 1).long()

        # Pure inference: no autograd history needed.
        with torch.no_grad():
            outputs_bc, outputs_mult = model_fd(inputs)

        # Move results to CPU per batch so they do not pile up in GPU memory.
        targets.append(labels.cpu())
        targets_bc.append(labels_bc.cpu())
        scores_bc.append(outputs_bc.squeeze(1).cpu())
        preds_mult.append(torch.argmax(outputs_mult, 1).cpu())

t = torch.cat(targets)
t_bc = torch.cat(targets_bc)
s_bc = torch.cat(scores_bc)
p_mult = torch.cat(preds_mult) + 2  # Shift (0,1,2,3) to (2,3,4,5)

print(t)
print(t_bc)
print(s_bc)
print(p_mult)

## Locate The Threshold of The First Layer

In [None]:
# ROC curve of the stage-one (binary) classifier over all out-of-fold validation scores.
fpr, tpr, threshold = metrics.roc_curve(t_bc, s_bc)
roc_auc = metrics.auc(fpr, tpr)
log_dict['fpr'] = fpr
log_dict['tpr'] = tpr
log_dict['threshold'] = threshold

plt.title('Receiver Operating Characteristic')
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],'r--')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()

# Sensitivity / specificity of stage one at the operating threshold CFG.threshold,
# using the same rule as the second-layer cell (score >= threshold -> positive).
# Computed here from t_bc / s_bc so this cell does not depend on a matrix
# defined in a later cell.
p_bc_at_threshold = (s_bc >= CFG.threshold).long()
conf_mat_bc = confusion_matrix(t_bc.numpy(), p_bc_at_threshold.numpy(), labels=[0, 1])
tn, fp, fn, tp = conf_mat_bc.ravel()
print(f'sensitivity: {tp/(tp+fn)}')
print(f'specificity: {tn/(tn+fp)}')

print(len(log_dict['fpr']))
# Manually chosen ROC point to inspect; clamped so a shorter curve
# (e.g. a debug run) does not raise IndexError.
rate_idx = min(1450, len(log_dict['fpr']) - 1)
print(log_dict['fpr'][rate_idx])
print(log_dict['tpr'][rate_idx])
print(log_dict['threshold'][rate_idx])

## Analyze The Second Layer
Set the desired stage-one threshold (`CFG.threshold`) based on the ROC curve (fpr / tpr / threshold) computed above.

In [None]:
# Stage-one prediction: score >= CFG.threshold -> positive (grade >= 2).
p_bc = torch.where(s_bc < CFG.threshold, torch.tensor(0), torch.tensor(1))

# Merge both stages into one prediction per sample (mutates p_mult in place;
# re-run the inference cell to restore the raw stage-two predictions):
#   * stage one negative (true AND false negatives) -> grade 0
#   * stage one false positive                      -> sentinel 8
#   * stage one true positive                       -> keep stage-two grade (2..5)
p_mult[p_bc == 0] = 0
p_mult[(t_bc == 0) & (p_bc == 1)] = 8  # False positive

# Targets (mutated in place): grade 1 is merged into 0 because stage one treats
# 0/1 as a single class; false negatives get the sentinel 8 as their own row.
t[t == 1] = 0
t[(t_bc == 1) & (p_bc == 0)] = 8  # False negative

In [None]:
# Final two-stage evaluation. Label 8 is the sentinel for stage-one errors
# (false-positive predictions / false-negative targets).
isup_labels = [0, 2, 3, 4, 5, 8]

kappa = cohen_kappa_score(t, p_mult, weights='quadratic')
print(f'Kappa: {kappa}')

# Pin the label order so the matrix is always 6x6 and lines up with the tick
# labels, even when some label (e.g. 8) never occurs.
conf_mat = confusion_matrix(t, p_mult, labels=isup_labels)
plt.figure(figsize=(14, 7))
ax = sn.heatmap(conf_mat, annot=True, fmt='d', xticklabels=isup_labels, yticklabels=isup_labels)
ax.set(xlabel='Predicted grade', ylabel='True grade', title='Two-stage confusion matrix')
plt.show()