In [1]:
from __future__ import division
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from loader import *
from model.TransMUNet import TransMUNet
import pandas as pd
import glob
import nibabel as nib
import numpy as np
import copy
import yaml

In [2]:
## Loader
## Hyper parameters
config         = yaml.load(open('./config_skin.yml'), Loader=yaml.FullLoader)
number_classes = int(config['number_classes'])
input_channels = 3
best_val_loss  = np.inf
device = 'cuda' if torch.cuda.is_available() else 'cpu'

data_path = config['path_to_data']  

train_dataset = isic_loader(path_Data = data_path, train = True)
train_loader  = DataLoader(train_dataset, batch_size = int(config['batch_size_tr']), shuffle= True)
val_dataset   = isic_loader(path_Data = data_path, train = False)
val_loader    = DataLoader(val_dataset, batch_size = int(config['batch_size_va']), shuffle= False)


In [3]:
Net = TransMUNet(n_classes = number_classes)

Net = Net.to(device)
if int(config['pretrained']):
    Net.load_state_dict(torch.load(config['saved_model'], map_location='cpu')['model_weights'])
    best_val_loss = torch.load(config['saved_model'], map_location='cpu')['val_loss']
optimizer = optim.Adam(Net.parameters(), lr= float(config['lr']))
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor = 0.5, patience = config['patience'])
criteria  = torch.nn.BCELoss()
criteria_boundary  = torch.nn.BCELoss()
criteria_region = torch.nn.MSELoss()

In [None]:
for ep in range(int(config['epochs'])):
    Net.train()
    epoch_loss = 0
    for itter, batch in enumerate(train_loader):
        img = batch['image'].to(device, dtype=torch.float)
        msk = batch['mask'].to(device)
        weak_ann = batch['weak_ann'].to(device)
        boundary = batch['boundary'].to(device)
        mask_type = torch.float32 if Net.n_classes == 1 else torch.long
        msk = msk.to(device=device, dtype=mask_type)
        weak_ann = weak_ann.to(device=device, dtype=mask_type)
        boundary = boundary.to(device=device, dtype=mask_type)
        msk_pred, B, R = Net(img, with_additional=True)
        loss          = criteria(msk_pred, msk) 
        loss_regions  = criteria_region(weak_ann[:,0], R[:,:-1,0])
        loss_boundary = criteria(msk_pred, msk)  
        tloss    = (0.7*(loss)) + (0.1* loss_regions) + (0.2*loss_boundary)
        optimizer.zero_grad()
        tloss.backward()
        epoch_loss += tloss.item()
        optimizer.step()  
        if itter%int(float(config['progress_p']) * len(train_loader))==0:
            print(f' Epoch>> {ep+1} and itteration {itter+1} Loss>> {((epoch_loss/(itter+1)))}')
    ## Validation phase
    with torch.no_grad():
        print('val_mode')
        val_loss = 0
        Net.eval()
        for itter, batch in enumerate(val_loader):
            img = batch['image'].to(device, dtype=torch.float)
            msk = batch['mask'].to(device)
            mask_type = torch.float32 if Net.n_classes == 1 else torch.long
            msk = msk.to(device=device, dtype=mask_type)
            msk_pred = Net(img)
            loss = criteria(msk_pred, msk) 
            val_loss += loss.item()
        print(f' validation on epoch>> {ep+1} dice loss>> {(abs(val_loss/(itter+1)))}')     
        mean_val_loss = (val_loss/(itter+1))
        # Check the performance and save the model
        if (mean_val_loss) < best_val_loss:
            print('New best loss, saving...')
            best_val_loss = copy.deepcopy(mean_val_loss)
            state = copy.deepcopy({'model_weights': Net.state_dict(), 'val_loss': best_val_loss})
            torch.save(state, config['saved_model'])

    scheduler.step(mean_val_loss)
    
print('Trainng phase finished')    

 Epoch>> 1 and itteration 1 Loss>> 0.4878859519958496
 Epoch>> 1 and itteration 46 Loss>> 0.436914870272512
 Epoch>> 1 and itteration 91 Loss>> 0.406768771645787
 Epoch>> 1 and itteration 136 Loss>> 0.387475683408625
 Epoch>> 1 and itteration 181 Loss>> 0.37029958166470184
 Epoch>> 1 and itteration 226 Loss>> 0.3621396487001824
 Epoch>> 1 and itteration 271 Loss>> 0.35190100401529967
 Epoch>> 1 and itteration 316 Loss>> 0.34527019302867634
 Epoch>> 1 and itteration 361 Loss>> 0.340015287214369
 Epoch>> 1 and itteration 406 Loss>> 0.3345703736975275
 Epoch>> 1 and itteration 451 Loss>> 0.3311662516017711
val_mode
 validation on epoch>> 1 dice loss>> 0.2752935446539901
New best loss, saving...
 Epoch>> 2 and itteration 1 Loss>> 0.28962594270706177
 Epoch>> 2 and itteration 46 Loss>> 0.29288280528524646
 Epoch>> 2 and itteration 91 Loss>> 0.2938939706309811
 Epoch>> 2 and itteration 136 Loss>> 0.2820981578791843
 Epoch>> 2 and itteration 181 Loss>> 0.2701566784092076
 Epoch>> 2 and ittera

 Epoch>> 12 and itteration 361 Loss>> 0.12610800418786064
 Epoch>> 12 and itteration 406 Loss>> 0.12551013803283864
 Epoch>> 12 and itteration 451 Loss>> 0.12573659851188934
val_mode
 validation on epoch>> 12 dice loss>> 0.14761795142744613
 Epoch>> 13 and itteration 1 Loss>> 0.06830105185508728
 Epoch>> 13 and itteration 46 Loss>> 0.1209399577068246
 Epoch>> 13 and itteration 91 Loss>> 0.11847573682502076
 Epoch>> 13 and itteration 136 Loss>> 0.1216593860539005
 Epoch>> 13 and itteration 181 Loss>> 0.12231733466791843
 Epoch>> 13 and itteration 226 Loss>> 0.12253736996347397
 Epoch>> 13 and itteration 271 Loss>> 0.12341048696329233
 Epoch>> 13 and itteration 316 Loss>> 0.12317021874875962
 Epoch>> 13 and itteration 361 Loss>> 0.12268629867242974
 Epoch>> 13 and itteration 406 Loss>> 0.1247509588473683
 Epoch>> 13 and itteration 451 Loss>> 0.125241905706883
val_mode
 validation on epoch>> 13 dice loss>> 0.13717417760921147
New best loss, saving...
 Epoch>> 14 and itteration 1 Loss>> 0.

 Epoch>> 24 and itteration 181 Loss>> 0.09348085230391329
 Epoch>> 24 and itteration 226 Loss>> 0.09155483698818535
 Epoch>> 24 and itteration 271 Loss>> 0.08984833234075691
 Epoch>> 24 and itteration 316 Loss>> 0.09034567181303908
 Epoch>> 24 and itteration 361 Loss>> 0.09153363255169913
 Epoch>> 24 and itteration 406 Loss>> 0.09166926263918725
 Epoch>> 24 and itteration 451 Loss>> 0.09071590649810704
val_mode
 validation on epoch>> 24 dice loss>> 0.13262791471541446
 Epoch>> 25 and itteration 1 Loss>> 0.06768117845058441
 Epoch>> 25 and itteration 46 Loss>> 0.08301182301796001
 Epoch>> 25 and itteration 91 Loss>> 0.08083239103575329
 Epoch>> 25 and itteration 136 Loss>> 0.08231433286495946
 Epoch>> 25 and itteration 181 Loss>> 0.08278180564141405
 Epoch>> 25 and itteration 226 Loss>> 0.08338782239078948
 Epoch>> 25 and itteration 271 Loss>> 0.08366489682580272
 Epoch>> 25 and itteration 316 Loss>> 0.08633823188233979
 Epoch>> 25 and itteration 361 Loss>> 0.08759486152070711
 Epoch>> 

 Epoch>> 36 and itteration 46 Loss>> 0.058433276479658874
 Epoch>> 36 and itteration 91 Loss>> 0.05948395443732267
 Epoch>> 36 and itteration 136 Loss>> 0.059972569675130004
 Epoch>> 36 and itteration 181 Loss>> 0.060879789303103206
 Epoch>> 36 and itteration 226 Loss>> 0.06133900438559003
 Epoch>> 36 and itteration 271 Loss>> 0.06139697876116226
 Epoch>> 36 and itteration 316 Loss>> 0.06330230644067066
 Epoch>> 36 and itteration 361 Loss>> 0.06435676964923451
 Epoch>> 36 and itteration 406 Loss>> 0.06426818171984015
 Epoch>> 36 and itteration 451 Loss>> 0.06382929415188053
val_mode
 validation on epoch>> 36 dice loss>> 0.1419757434857369
 Epoch>> 37 and itteration 1 Loss>> 0.06270346790552139
 Epoch>> 37 and itteration 46 Loss>> 0.05726564959015535
 Epoch>> 37 and itteration 91 Loss>> 0.06073792284907221
 Epoch>> 37 and itteration 136 Loss>> 0.062088602948386
 Epoch>> 37 and itteration 181 Loss>> 0.06590296876570467
 Epoch>> 37 and itteration 226 Loss>> 0.06604115374023671
 Epoch>> 37

 Epoch>> 47 and itteration 406 Loss>> 0.0537252500561511
