In [1]:
from __future__ import division
import os
import random
# Restrict CUDA to GPU 0; must happen before CUDA is initialised (here: before torch is imported).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from loader import *  # NOTE(review): star import supplies isic_loader; prefer an explicit import
import pandas as pd
import glob
import argparse
import nibabel as nib
import numpy as np
import copy
import yaml
from model.vit_seg_modeling import VisionTransformer as ViT_seg
from model.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg

In [2]:
# Experiment options.
# Caveats visible in this notebook:
#   * num_classes, root_path and list_dir are overwritten in the model-setup cell,
#     so their defaults here are not what the run actually uses.
#   * max_iterations, max_epochs and batch_size only feed the snapshot directory
#     name; the epoch count and batch sizes used for training come from config_skin.yml.
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
                    default='', help='root dir for data')
parser.add_argument('--dataset', type=str,
                    default='isic2018', help='experiment_name')
parser.add_argument('--list_dir', type=str,
                    default='./lists/lists_Synapse', help='list dir')
parser.add_argument('--num_classes', type=int,
                    default=9, help='output channel of network')
parser.add_argument('--max_iterations', type=int,
                    default=30000, help='maximum iteration number to train')
parser.add_argument('--max_epochs', type=int,
                    default=150, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int,
                    default=24, help='batch_size per gpu')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
parser.add_argument('--deterministic', type=int,  default=1,
                    help='whether use deterministic training')
parser.add_argument('--base_lr', type=float,  default=0.01,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=256, help='input patch size of network input')
parser.add_argument('--seed', type=int,
                    default=1234, help='random seed')
parser.add_argument('--n_skip', type=int,
                    default=3, help='using number of skip-connect, default is num')
parser.add_argument('--vit_name', type=str,
                    default='R50-ViT-B_16', help='select one vit model')
parser.add_argument('--vit_patches_size', type=int,
                    default=16, help='vit_patches_size, default is 16')
# parse_args(args=[]) ignores sys.argv (under Jupyter it holds the kernel
# launcher's own flags); call parser.parse_args() instead when run as a script.
args = parser.parse_args(args=[])

In [3]:
## Loader
## Hyper parameters
# Read the experiment config; the `with` block guarantees the file handle is closed.
# NOTE(review): FullLoader is acceptable for a local, trusted config file; switch to
# yaml.safe_load if this path could ever point at untrusted input.
with open('./config_skin.yml') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
# NOTE(review): the model-setup cell builds the network from args.num_classes (forced
# to 1) rather than number_classes, and input_channels is not read in the cells
# shown — confirm these are used elsewhere before relying on them.
number_classes = int(config['number_classes'])
input_channels = 3
# Lowest validation loss seen so far; the training cell compares against and updates it.
best_val_loss  = np.inf
device = 'cuda' if torch.cuda.is_available() else 'cpu'

data_path = config['path_to_data']

# Training batches are reshuffled every epoch; validation order stays fixed.
train_dataset = isic_loader(path_Data = data_path, train = True)
train_loader  = DataLoader(train_dataset, batch_size = int(config['batch_size_tr']), shuffle= True)
val_dataset   = isic_loader(path_Data = data_path, train = False)
val_loader    = DataLoader(val_dataset, batch_size = int(config['batch_size_va']), shuffle= False)


In [4]:
# Seed every RNG the pipeline draws from (Python, NumPy, torch CPU and CUDA).
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
# Apply the parsed --deterministic flag to cuDNN: when set, use deterministic
# kernels and disable autotuning (benchmark), which can pick different
# algorithms from run to run.
torch.backends.cudnn.benchmark = not args.deterministic
torch.backends.cudnn.deterministic = bool(args.deterministic)
dataset_name = args.dataset
# NOTE(review): dataset_config is not referenced in this cell and only describes
# 'Synapse', not the isic2018 dataset in use — looks like a leftover; confirm before removing.
dataset_config = {'Synapse': {'root_path': '', 'list_dir': '', 'num_classes': 1}}

# Overrides of the parsed options. num_classes is forced to a single output
# channel (presumably one foreground channel for the lesion mask — confirm).
args.num_classes = 1
args.root_path = ''
args.list_dir = ''
args.is_pretrain = True
args.exp = 'TU_' + dataset_name + str(args.img_size)

# The run directory name encodes the settings. Most optional suffixes are only
# added when the option differs from its parser default. Exception: max_epochs
# is compared against 30 while its default is 150, so '_epo<N>' is normally present.
name_parts = []
if args.is_pretrain:
    name_parts.append('_pretrain')
name_parts.append(f'_{args.vit_name}')
name_parts.append(f'_skip{args.n_skip}')
if args.vit_patches_size != 16:
    name_parts.append(f'_vitpatch{args.vit_patches_size}')
if args.max_iterations != 30000:
    name_parts.append('_' + str(args.max_iterations)[0:2] + 'k')
if args.max_epochs != 30:
    name_parts.append(f'_epo{args.max_epochs}')
name_parts.append(f'_bs{args.batch_size}')
if args.base_lr != 0.01:
    name_parts.append(f'_lr{args.base_lr}')
name_parts.append(f'_{args.img_size}')
if args.seed != 1234:
    name_parts.append(f'_s{args.seed}')
snapshot_path = "./model_results/{}/{}".format(args.exp, 'TU') + ''.join(name_parts)

# Create the run directory; a no-op if it already exists.
os.makedirs(snapshot_path, exist_ok=True)

# Start from the registered ViT config for the chosen backbone and override the
# output channels and number of skip connections.
config_vit = CONFIGS_ViT_seg[args.vit_name]
config_vit.n_classes = args.num_classes
config_vit.n_skip = args.n_skip
if 'R50' in args.vit_name:
    # Hybrid ResNet-50 backbones: the patch grid must match img_size / patch size.
    grid_side = args.img_size // args.vit_patches_size
    config_vit.patches.grid = (grid_side, grid_side)
# NOTE(review): args.is_pretrain is True (and tags snapshot_path with '_pretrain'),
# but no pretrained weights are loaded into Net here, so it starts from the
# constructor's initialisation — confirm whether a checkpoint should be loaded.
Net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).to(device)

In [5]:
# SGD with momentum and L2 weight decay. The scheduler multiplies the LR by 0.5
# once the monitored loss (stepped with the validation loss in the training cell)
# has not improved for `patience` epochs.
optimizer = optim.SGD(
    Net.parameters(),
    lr=args.base_lr,
    momentum=0.9,
    weight_decay=1e-4,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=config['patience'],
)
# NOTE(review): BCELoss requires predictions already in [0, 1]; this assumes
# ViT_seg ends with a sigmoid. If it returns logits, use BCEWithLogitsLoss instead.
criteria = torch.nn.BCELoss()


In [6]:
# Print the running-average training loss roughly every `progress_p` fraction of
# an epoch. Computed once (len(train_loader) does not change) and clamped to >= 1
# so that a small loader or tiny progress_p cannot turn the modulus into `% 0`.
log_every = max(1, int(float(config['progress_p']) * len(train_loader)))

for ep in range(int(config['epochs'])):
    ## Training phase
    Net.train()
    epoch_loss = 0
    for itter, batch in enumerate(train_loader):
        img = batch['image'].to(device, dtype=torch.float)
        # Single transfer that also casts: the loss needs a float target.
        msk = batch['mask'].to(device=device, dtype=torch.float32)
        msk_pred = Net(img)
        loss = criteria(msk_pred, msk)
        optimizer.zero_grad()
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
        if itter % log_every == 0:
            print(f' Epoch>> {ep+1} and itteration {itter+1} Loss>> {((epoch_loss/(itter+1)))}')

    ## Validation phase
    Net.eval()
    val_loss = 0
    with torch.no_grad():
        print('val_mode')
        for batch in val_loader:
            img = batch['image'].to(device, dtype=torch.float)
            msk = batch['mask'].to(device=device, dtype=torch.float32)
            msk_pred = Net(img)
            val_loss += criteria(msk_pred, msk).item()
    # Divide by the number of validation batches rather than a loop index: an
    # empty val_loader then fails loudly (ZeroDivisionError) instead of silently
    # reusing the training loop's `itter`.
    mean_val_loss = val_loss / len(val_loader)
    # The reported value is whatever `criteria` computes (BCE here), not a Dice loss.
    print(f' validation on epoch>> {ep+1} val loss>> {mean_val_loss}')

    # Keep the checkpoint with the lowest mean validation loss.
    if mean_val_loss < best_val_loss:
        print('New best loss, saving...')
        best_val_loss = mean_val_loss
        # deepcopy keeps an in-memory snapshot of the best weights in `state`
        # (Net keeps training afterwards); torch.save also writes it to disk.
        state = copy.deepcopy({'model_weights': Net.state_dict(), 'val_loss': best_val_loss})
        torch.save(state, config['saved_model'])

    # Plateau scheduler monitors the validation loss.
    scheduler.step(mean_val_loss)

print('Trainng phase finished')

 Epoch>> 1 and itteration 1 Loss>> 0.6713525056838989
 Epoch>> 1 and itteration 91 Loss>> 0.4028219772236688
 Epoch>> 1 and itteration 181 Loss>> 0.38357710797154443
 Epoch>> 1 and itteration 271 Loss>> 0.3657187595688549
 Epoch>> 1 and itteration 361 Loss>> 0.35696848865773867
 Epoch>> 1 and itteration 451 Loss>> 0.33765244406303124
 Epoch>> 1 and itteration 541 Loss>> 0.3199430010152392
 Epoch>> 1 and itteration 631 Loss>> 0.3071564765980618
 Epoch>> 1 and itteration 721 Loss>> 0.29575338292655107
 Epoch>> 1 and itteration 811 Loss>> 0.28973091795636313
 Epoch>> 1 and itteration 901 Loss>> 0.2829993542502908
val_mode
 validation on epoch>> 1 dice loss>> 0.19406575661756703
New best loss, saving...
 Epoch>> 2 and itteration 1 Loss>> 0.13811182975769043
 Epoch>> 2 and itteration 91 Loss>> 0.18968174541062052
 Epoch>> 2 and itteration 181 Loss>> 0.23027639584976006
 Epoch>> 2 and itteration 271 Loss>> 0.21821508584097302
 Epoch>> 2 and itteration 361 Loss>> 0.21367615019680722
 Epoch>> 

 Epoch>> 12 and itteration 811 Loss>> 0.11547165150086955
 Epoch>> 12 and itteration 901 Loss>> 0.11584056242184092
val_mode
 validation on epoch>> 12 dice loss>> 0.15677177614953125
 Epoch>> 13 and itteration 1 Loss>> 0.03681149333715439
 Epoch>> 13 and itteration 91 Loss>> 0.12272640525435026
 Epoch>> 13 and itteration 181 Loss>> 0.1060704005295565
 Epoch>> 13 and itteration 271 Loss>> 0.1027420899827322
 Epoch>> 13 and itteration 361 Loss>> 0.10210066516767578
 Epoch>> 13 and itteration 451 Loss>> 0.10166595992683836
 Epoch>> 13 and itteration 541 Loss>> 0.10280110451862795
 Epoch>> 13 and itteration 631 Loss>> 0.10400714289851593
 Epoch>> 13 and itteration 721 Loss>> 0.10462889788819583
 Epoch>> 13 and itteration 811 Loss>> 0.10559485861905259
 Epoch>> 13 and itteration 901 Loss>> 0.11066437584681893
val_mode
 validation on epoch>> 13 dice loss>> 0.13818974642291607
New best loss, saving...
 Epoch>> 14 and itteration 1 Loss>> 0.08348086476325989
 Epoch>> 14 and itteration 91 Loss>>

 Epoch>> 24 and itteration 361 Loss>> 0.09099230538135256
 Epoch>> 24 and itteration 451 Loss>> 0.09119318590443705
 Epoch>> 24 and itteration 541 Loss>> 0.0907432349009677
 Epoch>> 24 and itteration 631 Loss>> 0.08810986321944915
 Epoch>> 24 and itteration 721 Loss>> 0.0908886490779801
 Epoch>> 24 and itteration 811 Loss>> 0.0892214288719173
 Epoch>> 24 and itteration 901 Loss>> 0.08979873841506461
val_mode
 validation on epoch>> 24 dice loss>> 0.13329794751228513
 Epoch>> 25 and itteration 1 Loss>> 0.039466045796871185
 Epoch>> 25 and itteration 91 Loss>> 0.08355986046000988
 Epoch>> 25 and itteration 181 Loss>> 0.09869541953580238
 Epoch>> 25 and itteration 271 Loss>> 0.10130206137428205
 Epoch>> 25 and itteration 361 Loss>> 0.0961293187637367
 Epoch>> 25 and itteration 451 Loss>> 0.09342845629264669
 Epoch>> 25 and itteration 541 Loss>> 0.09363412459392975
 Epoch>> 25 and itteration 631 Loss>> 0.09756277826881825
 Epoch>> 25 and itteration 721 Loss>> 0.09313040495373505
 Epoch>> 25

val_mode
 validation on epoch>> 35 dice loss>> 0.13571898135424926
 Epoch>> 36 and itteration 1 Loss>> 0.1970347762107849
 Epoch>> 36 and itteration 91 Loss>> 0.0699596681253432
 Epoch>> 36 and itteration 181 Loss>> 0.07400471992399646
 Epoch>> 36 and itteration 271 Loss>> 0.07284666310198523
 Epoch>> 36 and itteration 361 Loss>> 0.074494489762736
 Epoch>> 36 and itteration 451 Loss>> 0.0726635915179739
 Epoch>> 36 and itteration 541 Loss>> 0.07224928387724786
 Epoch>> 36 and itteration 631 Loss>> 0.0757657031133098
 Epoch>> 36 and itteration 721 Loss>> 0.07678460796401111
 Epoch>> 36 and itteration 811 Loss>> 0.0762702076473077
 Epoch>> 36 and itteration 901 Loss>> 0.07875575491647907
val_mode
 validation on epoch>> 36 dice loss>> 0.13737063328142216
 Epoch>> 37 and itteration 1 Loss>> 0.11091698706150055
 Epoch>> 37 and itteration 91 Loss>> 0.07021547916376002
 Epoch>> 37 and itteration 181 Loss>> 0.0689175667688868
 Epoch>> 37 and itteration 271 Loss>> 0.07221505959514816
 Epoch>> 3

 Epoch>> 47 and itteration 631 Loss>> 0.06427540589735006
 Epoch>> 47 and itteration 721 Loss>> 0.06502274173310597
 Epoch>> 47 and itteration 811 Loss>> 0.06834454713940492
 Epoch>> 47 and itteration 901 Loss>> 0.06931787715567385
val_mode
 validation on epoch>> 47 dice loss>> 0.13686747658408058
 Epoch>> 48 and itteration 1 Loss>> 0.15090720355510712
 Epoch>> 48 and itteration 91 Loss>> 0.07771419795168626
 Epoch>> 48 and itteration 181 Loss>> 0.07226358300321915
 Epoch>> 48 and itteration 271 Loss>> 0.07236285348048234
 Epoch>> 48 and itteration 361 Loss>> 0.06785390009443144
 Epoch>> 48 and itteration 451 Loss>> 0.06655057260411443
 Epoch>> 48 and itteration 541 Loss>> 0.06613974446988563
 Epoch>> 48 and itteration 631 Loss>> 0.06632902610843348
 Epoch>> 48 and itteration 721 Loss>> 0.06567612064656143
 Epoch>> 48 and itteration 811 Loss>> 0.06451988075463659
 Epoch>> 48 and itteration 901 Loss>> 0.06475601580301993
val_mode
 validation on epoch>> 48 dice loss>> 0.1415470859025062
