## Import modules

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from model import ChannelEffFormer

from utils import DiceLossV2, ISICLoader

import pandas as pd
import glob
import argparse

import numpy as np
import copy
import yaml
from tqdm import tqdm

from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
%config InlineBackend.figure_format="svg"
%matplotlib inline

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

data_path = './isic18_224/'

## Loader Hyper parameters
config         = yaml.load(open('./config_skin.yml'), Loader=yaml.FullLoader)
number_classes = int(config['number_classes'])
input_channels = 3
best_val_loss  = np.inf
# data_path = config['path_to_data']

In [3]:
train_dataset = ISICLoader(path_Data = data_path, train = True)
train_loader  = DataLoader(train_dataset, batch_size = 24, shuffle= True)
val_dataset   = ISICLoader(path_Data = data_path, train = False, Test = True)
val_loader    = DataLoader(val_dataset, batch_size = 1, shuffle= False)

## Create Model

In [4]:
class skin_net(torch.nn.Module):
    def __init__(self, classes = 1):
        super().__init__()
        self.net = ChannelEffFormer(num_classes=1, head_count=8, token_mlp_mode="mix_skip")
        
    def forward(self, x):
        x = self.net(x)
        return x 

net = skin_net(classes = 1)
net = net.to(device)

## Define optimizer, loss functions, etc.

In [5]:
# optimizer = optim.SGD(Net.parameters(), lr=float(config['lr']), momentum=0.9, weight_decay=0.0001)
optimizer = optim.Adam(net.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor = 0.5, patience = config['patience'])
criteria  = torch.nn.BCEWithLogitsLoss()
# criteria_dice  = DiceLossV2()

In [6]:
max_epochs = 100
save_name = './model_results/' + 'ISIC_V1/'
os.makedirs(save_name, exist_ok=True)
best_val_loss  = np.inf
eval_interval = 1
Start_epoch = 0

## Define train-validation loop

In [7]:
best_F1_score = 0.0
for ep in range(Start_epoch, max_epochs):
    net.train()
    epoch_loss = 0
    for itter, batch in enumerate(train_loader):
        img = batch['image'].to(device, dtype=torch.float)
        msk = batch['mask'].to(device)
        mask_type = torch.float32
        msk = msk.to(device=device, dtype=mask_type)
        msk_pred = net(img)

        loss          = criteria(msk_pred, msk)
#         loss_dice     = criteria_dice(msk_pred, msk)
        
        optimizer.zero_grad()
        loss_total = loss
        loss_total.backward()
        epoch_loss += loss_total.item()
        optimizer.step()  
        if itter%int(0.1 * len(train_loader))==0:
            print(f' Epoch: {ep+1}, itteration: {itter+1}, Loss: {((epoch_loss/(itter+1)))}, CE loss: {loss.item()}')
    predictions = []
    gt = []

    if (ep+1)% eval_interval==0:
        with torch.no_grad():
            print('val_mode')
            val_loss = 0
            net.eval()
            for itter, batch in enumerate(val_loader):
                img = batch['image'].to(device, dtype=torch.float)
                msk = batch['mask']
                msk_pred = net(img)
                gt.append(msk.numpy()[0, 0])
                msk_pred = msk_pred.cpu().detach().numpy()[0, 0]
                msk_pred  = np.where(msk_pred>=0.41, 1, 0)

                predictions.append(msk_pred)        

        predictions = np.array(predictions)
        gt = np.array(gt)

        y_scores = predictions.reshape(-1)
        y_true   = gt.reshape(-1)

        y_scores2 = np.where(y_scores>0.45, 1, 0) #0.47
        y_true2   = np.where(y_true>0.5, 1, 0)


        #F1 score
        F1_score = f1_score(y_true2, y_scores2, labels=None, average='binary', sample_weight=None)
        print (f"\nF1 score (F-measure) or DSC in epoch {ep+1}:  {F1_score}")
            
        if ((F1_score) > best_F1_score) or F1_score >= 0.92:
            print('New best loss, saving...')
            best_F1_score = copy.deepcopy(F1_score)
            state = copy.deepcopy({'model_weights': net.state_dict(), 'test_F1_score': F1_score})
            torch.save(state, save_name + f"ISCF_{F1_score}_ep_{ep}.model")
            print(save_name + f"ISCF_{F1_score}_ep_{ep}.model")

 Epoch: 1, itteration: 1, Loss: 0.9320162534713745, CE loss: 0.9320162534713745
 Epoch: 1, itteration: 8, Loss: 0.5585475899279118, CE loss: 0.45749005675315857
 Epoch: 1, itteration: 15, Loss: 0.5159610688686371, CE loss: 0.47572678327560425
 Epoch: 1, itteration: 22, Loss: 0.4971086802807721, CE loss: 0.47728121280670166
 Epoch: 1, itteration: 29, Loss: 0.4672179674280101, CE loss: 0.334820955991745
 Epoch: 1, itteration: 36, Loss: 0.4446494181950887, CE loss: 0.2703462541103363
 Epoch: 1, itteration: 43, Loss: 0.4309470431749211, CE loss: 0.3736322522163391
 Epoch: 1, itteration: 50, Loss: 0.42118457973003387, CE loss: 0.3723260164260864
 Epoch: 1, itteration: 57, Loss: 0.4091652678815942, CE loss: 0.3114868402481079
 Epoch: 1, itteration: 64, Loss: 0.3970083447638899, CE loss: 0.36305922269821167
 Epoch: 1, itteration: 71, Loss: 0.3830763859228349, CE loss: 0.27616918087005615
val_mode

F1 score (F-measure) or DSC in epoch 1:  0.730172629285498
New best loss, saving...
./model_resu

 Epoch: 9, itteration: 15, Loss: 0.15203378448883692, CE loss: 0.13646146655082703
 Epoch: 9, itteration: 22, Loss: 0.160563989457759, CE loss: 0.23834046721458435
 Epoch: 9, itteration: 29, Loss: 0.15360839215332064, CE loss: 0.05656455084681511
 Epoch: 9, itteration: 36, Loss: 0.15392783584280145, CE loss: 0.15183508396148682
 Epoch: 9, itteration: 43, Loss: 0.15393957452372062, CE loss: 0.12629903852939606
 Epoch: 9, itteration: 50, Loss: 0.15690314941108227, CE loss: 0.1320725828409195
 Epoch: 9, itteration: 57, Loss: 0.15594329874505075, CE loss: 0.11165323853492737
 Epoch: 9, itteration: 64, Loss: 0.151330093445722, CE loss: 0.11712788790464401
 Epoch: 9, itteration: 71, Loss: 0.14994900972700456, CE loss: 0.10138718783855438
val_mode

F1 score (F-measure) or DSC in epoch 9:  0.8300211242989919
 Epoch: 10, itteration: 1, Loss: 0.1421116143465042, CE loss: 0.1421116143465042
 Epoch: 10, itteration: 8, Loss: 0.14847531728446484, CE loss: 0.20300279557704926
 Epoch: 10, itteration: 

 Epoch: 17, itteration: 8, Loss: 0.1295870542526245, CE loss: 0.14059501886367798
 Epoch: 17, itteration: 15, Loss: 0.12221386830012003, CE loss: 0.10013030469417572
 Epoch: 17, itteration: 22, Loss: 0.11720189689235254, CE loss: 0.0928124189376831
 Epoch: 17, itteration: 29, Loss: 0.1139083160408612, CE loss: 0.15620169043540955
 Epoch: 17, itteration: 36, Loss: 0.11606223591499859, CE loss: 0.11187649518251419
 Epoch: 17, itteration: 43, Loss: 0.11627637750880662, CE loss: 0.12044590711593628
 Epoch: 17, itteration: 50, Loss: 0.11580699436366558, CE loss: 0.1271122395992279
 Epoch: 17, itteration: 57, Loss: 0.11529051819652841, CE loss: 0.11160171777009964
 Epoch: 17, itteration: 64, Loss: 0.11316956934751943, CE loss: 0.07526498287916183
 Epoch: 17, itteration: 71, Loss: 0.11212976429034287, CE loss: 0.09604363143444061
val_mode

F1 score (F-measure) or DSC in epoch 17:  0.8482612540966624
 Epoch: 18, itteration: 1, Loss: 0.05777815729379654, CE loss: 0.05777815729379654
 Epoch: 18,

 Epoch: 25, itteration: 36, Loss: 0.1098471697833803, CE loss: 0.08904760330915451
 Epoch: 25, itteration: 43, Loss: 0.10949487585661023, CE loss: 0.12053444981575012
 Epoch: 25, itteration: 50, Loss: 0.10991407394409179, CE loss: 0.0702512189745903
 Epoch: 25, itteration: 57, Loss: 0.10862982612952851, CE loss: 0.1294032484292984
 Epoch: 25, itteration: 64, Loss: 0.10754192317835987, CE loss: 0.10420827567577362
 Epoch: 25, itteration: 71, Loss: 0.1059797095056151, CE loss: 0.059812337160110474
val_mode

F1 score (F-measure) or DSC in epoch 25:  0.867382519218311
 Epoch: 26, itteration: 1, Loss: 0.11262916773557663, CE loss: 0.11262916773557663
 Epoch: 26, itteration: 8, Loss: 0.10020467080175877, CE loss: 0.06869470328092575
 Epoch: 26, itteration: 15, Loss: 0.09323146988948186, CE loss: 0.07032924890518188
 Epoch: 26, itteration: 22, Loss: 0.10278869826685298, CE loss: 0.1406545490026474
 Epoch: 26, itteration: 29, Loss: 0.10250098232565255, CE loss: 0.11131960898637772
 Epoch: 26, 

 Epoch: 33, itteration: 57, Loss: 0.09313499090964333, CE loss: 0.06464356184005737
 Epoch: 33, itteration: 64, Loss: 0.0931987859075889, CE loss: 0.1130862683057785
 Epoch: 33, itteration: 71, Loss: 0.09422046428834888, CE loss: 0.0770091637969017
val_mode

F1 score (F-measure) or DSC in epoch 33:  0.863040464205306
 Epoch: 34, itteration: 1, Loss: 0.09638034552335739, CE loss: 0.09638034552335739
 Epoch: 34, itteration: 8, Loss: 0.10070277377963066, CE loss: 0.11798358708620071
 Epoch: 34, itteration: 15, Loss: 0.09242260605096816, CE loss: 0.06115434318780899
 Epoch: 34, itteration: 22, Loss: 0.0988886071877046, CE loss: 0.12444943189620972
 Epoch: 34, itteration: 29, Loss: 0.10106865290937753, CE loss: 0.19085443019866943
 Epoch: 34, itteration: 36, Loss: 0.10208183930565913, CE loss: 0.09378132224082947
 Epoch: 34, itteration: 43, Loss: 0.10594290226351383, CE loss: 0.12161732465028763
 Epoch: 34, itteration: 50, Loss: 0.10758961908519268, CE loss: 0.11769814789295197
 Epoch: 34, 

 Epoch: 41, itteration: 71, Loss: 0.09060659820974713, CE loss: 0.08929755538702011
val_mode

F1 score (F-measure) or DSC in epoch 41:  0.8779422018082098
 Epoch: 42, itteration: 1, Loss: 0.06027240306138992, CE loss: 0.06027240306138992
 Epoch: 42, itteration: 8, Loss: 0.07492474559694529, CE loss: 0.06488291919231415
 Epoch: 42, itteration: 15, Loss: 0.07603812118371328, CE loss: 0.05881473422050476
 Epoch: 42, itteration: 22, Loss: 0.0778275855224241, CE loss: 0.07652700692415237
 Epoch: 42, itteration: 29, Loss: 0.07682159144816728, CE loss: 0.06739351153373718
 Epoch: 42, itteration: 36, Loss: 0.0783109079218573, CE loss: 0.0980987623333931
 Epoch: 42, itteration: 43, Loss: 0.08060038280348446, CE loss: 0.11819491535425186
 Epoch: 42, itteration: 50, Loss: 0.0818361745774746, CE loss: 0.12313083559274673
 Epoch: 42, itteration: 57, Loss: 0.08137069433404688, CE loss: 0.060806453227996826
 Epoch: 42, itteration: 64, Loss: 0.08142428001156077, CE loss: 0.06203674152493477
 Epoch: 42

 Epoch: 50, itteration: 1, Loss: 0.07445535808801651, CE loss: 0.07445535808801651
 Epoch: 50, itteration: 8, Loss: 0.07241826830431819, CE loss: 0.059135422110557556
 Epoch: 50, itteration: 15, Loss: 0.07031754578153292, CE loss: 0.07572156935930252
 Epoch: 50, itteration: 22, Loss: 0.07685899277302352, CE loss: 0.05795037001371384
 Epoch: 50, itteration: 29, Loss: 0.07985376553802655, CE loss: 0.10830766707658768
 Epoch: 50, itteration: 36, Loss: 0.0792450699955225, CE loss: 0.0815947875380516
 Epoch: 50, itteration: 43, Loss: 0.07867744798923648, CE loss: 0.05597880110144615
 Epoch: 50, itteration: 50, Loss: 0.07827681422233582, CE loss: 0.07560042291879654
 Epoch: 50, itteration: 57, Loss: 0.07756092503928301, CE loss: 0.07640589028596878
 Epoch: 50, itteration: 64, Loss: 0.07793939980911091, CE loss: 0.0744161382317543
 Epoch: 50, itteration: 71, Loss: 0.0772814424314969, CE loss: 0.0805392637848854
val_mode

F1 score (F-measure) or DSC in epoch 50:  0.8855179064545303
 Epoch: 51,

 Epoch: 58, itteration: 29, Loss: 0.06829003244638443, CE loss: 0.07124023884534836
 Epoch: 58, itteration: 36, Loss: 0.06644325330853462, CE loss: 0.047809794545173645
 Epoch: 58, itteration: 43, Loss: 0.0670342875947786, CE loss: 0.07272704690694809
 Epoch: 58, itteration: 50, Loss: 0.06632805831730365, CE loss: 0.062900610268116
 Epoch: 58, itteration: 57, Loss: 0.06673532645953328, CE loss: 0.062112338840961456
 Epoch: 58, itteration: 64, Loss: 0.0694422543165274, CE loss: 0.12360340356826782
 Epoch: 58, itteration: 71, Loss: 0.07182859466739104, CE loss: 0.05976495146751404
val_mode

F1 score (F-measure) or DSC in epoch 58:  0.8825924575696655
 Epoch: 59, itteration: 1, Loss: 0.06269074231386185, CE loss: 0.06269074231386185
 Epoch: 59, itteration: 8, Loss: 0.08069103956222534, CE loss: 0.07852267473936081
 Epoch: 59, itteration: 15, Loss: 0.08026069874564806, CE loss: 0.05260134115815163
 Epoch: 59, itteration: 22, Loss: 0.07723127915100618, CE loss: 0.051328860223293304
 Epoch: 

 Epoch: 66, itteration: 57, Loss: 0.08052473417238186, CE loss: 0.04910847172141075
 Epoch: 66, itteration: 64, Loss: 0.07982642145361751, CE loss: 0.06105448305606842
 Epoch: 66, itteration: 71, Loss: 0.07820620524211669, CE loss: 0.06668329983949661
val_mode

F1 score (F-measure) or DSC in epoch 66:  0.8751011967240722
 Epoch: 67, itteration: 1, Loss: 0.07958836853504181, CE loss: 0.07958836853504181
 Epoch: 67, itteration: 8, Loss: 0.07170257484540343, CE loss: 0.08894000947475433
 Epoch: 67, itteration: 15, Loss: 0.06660994440317154, CE loss: 0.06470876932144165
 Epoch: 67, itteration: 22, Loss: 0.0658702262761918, CE loss: 0.06682153046131134
 Epoch: 67, itteration: 29, Loss: 0.06521946602854235, CE loss: 0.07293301820755005
 Epoch: 67, itteration: 36, Loss: 0.0651558522755901, CE loss: 0.07372914254665375
 Epoch: 67, itteration: 43, Loss: 0.06313921355230864, CE loss: 0.05651867017149925
 Epoch: 67, itteration: 50, Loss: 0.0637630382925272, CE loss: 0.07320854067802429
 Epoch: 67

val_mode

F1 score (F-measure) or DSC in epoch 74:  0.8809148919331226
 Epoch: 75, itteration: 1, Loss: 0.04729951173067093, CE loss: 0.04729951173067093
 Epoch: 75, itteration: 8, Loss: 0.05639808997511864, CE loss: 0.04330228269100189
 Epoch: 75, itteration: 15, Loss: 0.0583259051044782, CE loss: 0.06004280224442482
 Epoch: 75, itteration: 22, Loss: 0.05755676329135895, CE loss: 0.05142191797494888
 Epoch: 75, itteration: 29, Loss: 0.05745754596488229, CE loss: 0.04136926680803299
 Epoch: 75, itteration: 36, Loss: 0.056178838221563235, CE loss: 0.0323573499917984
 Epoch: 75, itteration: 43, Loss: 0.05619013569382734, CE loss: 0.057589542120695114
 Epoch: 75, itteration: 50, Loss: 0.057376488372683526, CE loss: 0.06403981894254684
 Epoch: 75, itteration: 57, Loss: 0.057724512982786746, CE loss: 0.07262049615383148
 Epoch: 75, itteration: 64, Loss: 0.056893475470133126, CE loss: 0.04666344076395035
 Epoch: 75, itteration: 71, Loss: 0.05728962183208533, CE loss: 0.0599198117852211
val_m

 Epoch: 83, itteration: 15, Loss: 0.05192873775959015, CE loss: 0.05132221058011055
 Epoch: 83, itteration: 22, Loss: 0.05123907649381594, CE loss: 0.042608074843883514
 Epoch: 83, itteration: 29, Loss: 0.05324153589277432, CE loss: 0.045917704701423645
 Epoch: 83, itteration: 36, Loss: 0.05215149890217516, CE loss: 0.04343942925333977
 Epoch: 83, itteration: 43, Loss: 0.0522069135724112, CE loss: 0.05785563215613365
 Epoch: 83, itteration: 50, Loss: 0.052277185693383216, CE loss: 0.04187930375337601
 Epoch: 83, itteration: 57, Loss: 0.05220464891509006, CE loss: 0.04271477833390236
 Epoch: 83, itteration: 64, Loss: 0.052879819297231734, CE loss: 0.06650989502668381
 Epoch: 83, itteration: 71, Loss: 0.05245793007419143, CE loss: 0.03387034684419632
val_mode

F1 score (F-measure) or DSC in epoch 83:  0.8845692209017504
 Epoch: 84, itteration: 1, Loss: 0.06867147982120514, CE loss: 0.06867147982120514
 Epoch: 84, itteration: 8, Loss: 0.058308208361268044, CE loss: 0.05975813791155815
 Ep

 Epoch: 91, itteration: 36, Loss: 0.05845504812896252, CE loss: 0.05976170301437378
 Epoch: 91, itteration: 43, Loss: 0.05724472032729969, CE loss: 0.04385696351528168
 Epoch: 91, itteration: 50, Loss: 0.056082597970962524, CE loss: 0.04439936578273773
 Epoch: 91, itteration: 57, Loss: 0.05568668409659151, CE loss: 0.04852697625756264
 Epoch: 91, itteration: 64, Loss: 0.05558636784553528, CE loss: 0.05306275933980942
 Epoch: 91, itteration: 71, Loss: 0.05499397076561417, CE loss: 0.05364568158984184
val_mode

F1 score (F-measure) or DSC in epoch 91:  0.8847118864773423
 Epoch: 92, itteration: 1, Loss: 0.03676231950521469, CE loss: 0.03676231950521469
 Epoch: 92, itteration: 8, Loss: 0.048788501881062984, CE loss: 0.06216714158654213
 Epoch: 92, itteration: 15, Loss: 0.05550383850932121, CE loss: 0.06552360206842422
 Epoch: 92, itteration: 22, Loss: 0.05187302116643299, CE loss: 0.04375477880239487
 Epoch: 92, itteration: 29, Loss: 0.052222574707762946, CE loss: 0.05229493975639343
 Epo

 Epoch: 99, itteration: 57, Loss: 0.07981367337337711, CE loss: 0.15638186037540436
 Epoch: 99, itteration: 64, Loss: 0.08056282921461388, CE loss: 0.11715954542160034
 Epoch: 99, itteration: 71, Loss: 0.08374377964457995, CE loss: 0.20568588376045227
val_mode

F1 score (F-measure) or DSC in epoch 99:  0.862531402878219
 Epoch: 100, itteration: 1, Loss: 0.08799128234386444, CE loss: 0.08799128234386444
 Epoch: 100, itteration: 8, Loss: 0.07696897769346833, CE loss: 0.09219802916049957
 Epoch: 100, itteration: 15, Loss: 0.0813539244234562, CE loss: 0.1242685467004776
 Epoch: 100, itteration: 22, Loss: 0.0778660076585683, CE loss: 0.06647955626249313
 Epoch: 100, itteration: 29, Loss: 0.07311194557054289, CE loss: 0.06138480454683304
 Epoch: 100, itteration: 36, Loss: 0.06927791010174486, CE loss: 0.05671593174338341
 Epoch: 100, itteration: 43, Loss: 0.06754006289465483, CE loss: 0.04931176081299782
 Epoch: 100, itteration: 50, Loss: 0.06538884095847607, CE loss: 0.043981995433568954
 E