In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from model import *
from loss import *
from train import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# ---- run configuration ----
USE_GPU = True      # prefer CUDA when it is available
NUM_WORKERS = 6     # DataLoader worker processes
BATCH_SIZE = 2

# Single precision: half the memory of float64 and the native GPU float type.
dtype = torch.float32

if not (USE_GPU and torch.cuda.is_available()):
    device = torch.device('cpu')
    print('using CPU for training')
else:
    device = torch.device('cuda')
    # Make sure cuDNN is used and let it benchmark convolution algorithms,
    # keeping the fastest one for the input sizes seen during the run.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    print('using GPU for training')

using GPU for training


In [3]:
# Training set, with data augmentation: a random affine transform followed by a
# random flip. The arguments (90, 15) and 0.5 are interpreted by the project's
# random_affine / random_filp helpers.
train_dataset = pyramid_dataset(
    data_type='nii_train',
    transform=transforms.Compose([
        random_affine(90, 15),
        random_filp(0.5),
    ]),
)

# Validation set: no data augmentation.
validation_dataset = pyramid_dataset(data_type='nii_test', transform=None)

# DataLoaders batch the samples and load them in NUM_WORKERS worker processes.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS)
# Validation is NOT shuffled: a fixed order gives a fixed batch composition, so
# the averaged per-batch validation loss (and the per-batch debug printout in
# the evaluation cell) is reproducible from run to run.
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE,
                               shuffle=False, num_workers=NUM_WORKERS)
# loaders come with auto batch division and multi-thread acceleration

In [None]:
def shape_test(model, cuda_bool, input_shape=(1, 1, 256, 256, 256)):
    """Run a dummy forward pass through ``model`` and print each output's size.

    Parameters
    ----------
    model : torch.nn.Module
        Network to probe. It may return a single tensor or a sequence of
        tensors.
    cuda_bool : bool
        If True, the all-zeros input is moved to the notebook-level ``device``
        and cast to ``dtype`` (both set in the configuration cell). Despite the
        name, ``device`` is the CPU when no GPU is available.
    input_shape : tuple of int, optional
        Shape of the dummy input. Defaults to the original
        (1, 1, 256, 256, 256) volume.

    Returns
    -------
    list of torch.Size
        Size of every output tensor, in the order they are printed.
    """
    x = torch.zeros(input_shape)
    if cuda_bool:
        x = x.to(device=device, dtype=dtype)
    # A shape probe needs no gradients; disabling autograd avoids keeping the
    # activations of a full-size 3-D forward pass alive.
    with torch.no_grad():
        scores = model(x)
    # A bare tensor would otherwise be iterated along its first (batch) axis.
    if torch.is_tensor(scores):
        scores = (scores,)
    sizes = [s.size() for s in scores]
    for size in sizes:
        print(size)
    return sizes

In [None]:
from model import *

# Build the network, initialise its weights and move it to the configured
# device/dtype in a single chain (nn.Module.apply and nn.Module.to both return
# the module itself).
icnet1 = (
    ModifiedICNet(num_classes=3)
    .apply(init_weights)
    .to(device=device, dtype=dtype)
)

# Sanity-check the output shapes with a dummy forward pass.
shape_test(icnet1, True)

# Adam over every parameter of icnet1.
optimizer1 = optim.Adam(params=icnet1.parameters(), lr=1e-2)

In [9]:
# Checkpoint to resume from (path relative to the notebook's working directory).
CHECKPOINT_PATH = '../half_res_save/2019-07-26 13:06:55.389831.pth'

# Wrap in DataParallel so the state-dict keys match the saved (wrapped) model.
icnet1 = nn.DataParallel(ModifiedICNet(num_classes=3))
# Optional: convert BatchNorm layers to synchronized BatchNorm across GPUs.
#icnet1 = convert_model(icnet1)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine too.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

icnet1.load_state_dict(checkpoint['state_dict_1'])
icnet1 = icnet1.to(device=device, dtype=dtype)

# Construct the optimizer only after the model is on its final device (as the
# torch.optim docs recommend), then restore its state. The lr given here is
# replaced by the param_groups restored from the checkpoint.
optimizer1 = optim.Adam(icnet1.parameters(), lr=1e-2)
optimizer1.load_state_dict(checkpoint['optimizer'])

# Epoch stored in the checkpoint (presumably the last completed epoch passed to
# save_1 — confirm).
epoch = checkpoint['epoch']

In [8]:
epochs = 5000

# If the checkpoint-loading cell has run, continue numbering after the epoch it
# restored (assumed to be the last completed epoch given to save_1 — confirm);
# otherwise start at 0. This keeps the logged epoch, the every-5-epoch
# validation schedule and the epoch passed to save_1 consistent on resume.
start_epoch = epoch + 1 if 'epoch' in globals() else 0

logger = {'train': [], 'validation_1': []}

# The context manager closes the log file even when training is interrupted
# (e.g. KeyboardInterrupt); a bare open()/close() pair leaks the handle then.
with open('train_half_res_continue.txt', 'a') as record:
    for e in tqdm(range(start_epoch, start_epoch + epochs)):

        # ---- training ----
        # Train mode: dropout active, batchnorm uses batch statistics. Set once
        # per epoch because the validation step below switches to eval mode.
        icnet1.train()
        epoch_loss = 0

        for batch in train_loader:
            # Only the full-resolution image is fed to the model; the other
            # pyramid levels are needed only for their labels.
            image_1 = batch['image1_data'].to(device=device, dtype=dtype)
            label_1 = batch['image1_label'].to(device=device, dtype=dtype)
            label_2 = batch['image2_label'].to(device=device, dtype=dtype)
            label_4 = batch['image4_label'].to(device=device, dtype=dtype)

            # Downsample labels by 1/2 to coincide with the model outputs.
            label_1_resize_2 = downsample_label(label_1, 1/2)
            label_2_resize_2 = downsample_label(label_2, 1/2)
            label_4_resize_2 = downsample_label(label_4, 1/2)

            # In train mode the model returns three outputs.
            out_1, out_2, out_4 = icnet1(image_1)

            loss_4 = dice_loss_3(out_4, label_4_resize_2)
            loss_2 = dice_loss_3(out_2, label_2_resize_2)
            loss_1 = dice_loss_3(out_1, label_1_resize_2)
            loss = loss_4 + loss_2 + loss_1

            epoch_loss += loss.item()

            optimizer1.zero_grad()
            loss.backward()
            optimizer1.step()

        mean_train_loss = epoch_loss / len(train_loader)
        outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, mean_train_loss) + '\n'
        logger['train'].append(mean_train_loss)

        print(outstr)
        record.write(outstr)
        record.flush()

        # ---- validation, every 5 epochs ----
        if e % 5 == 4:
            # Eval mode: dropout off, batchnorm uses running statistics. The
            # code below treats the eval-mode output as a single tensor.
            icnet1.eval()

            with torch.no_grad():
                valloss_1 = 0

                for vbatch in validation_loader:
                    image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                    label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                    # Add a leading axis to 4-D tensors.
                    # NOTE(review): unsqueeze_(0) inserts the new axis before
                    # the batch axis; confirm this is the intended layout when
                    # BATCH_SIZE > 1 (a channel axis would be dim 1).
                    if get_dimensions(image_1_val) == 4:
                        image_1_val.unsqueeze_(0)
                    if get_dimensions(label_1_val) == 4:
                        label_1_val.unsqueeze_(0)

                    # Downsample labels by 1/2 to coincide with the model output.
                    label_1_val_resize_2 = downsample_label(label_1_val, 1/2)

                    out_1_val = icnet1(image_1_val)
                    val_loss_1 = dice_loss_3(out_1_val, label_1_val_resize_2)
                    valloss_1 += val_loss_1.item()

                mean_val_loss = valloss_1 / len(validation_loader)
                outstr = '------- 1st valloss={0:.4f}'.format(mean_val_loss) + '\n'
                logger['validation_1'].append(mean_val_loss)

                print(outstr)
                record.write(outstr)
                record.flush()

                save_1('half_res_save', icnet1, optimizer1, logger, e)

  0%|          | 0/5000 [00:00<?, ?it/s]


KeyboardInterrupt: 

In [11]:
# Per-class dice breakdown on the validation set at full resolution.
icnet1.eval()

with torch.no_grad():

    bgloss = 0
    bdloss = 0
    bvloss = 0

    # Wrap the loader itself (not enumerate) so tqdm can read len() and show a
    # real progress bar with a total.
    for v, vbatch in enumerate(tqdm(validation_loader)):
        # Move data to the configured device/dtype.
        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        output = icnet1(image_1)

        # Upsample the half-resolution output back to label resolution and
        # round it. F.upsample is deprecated; F.interpolate with the same
        # arguments (default mode='nearest') is its drop-in replacement.
        # NOTE(review): rounding assumes the output holds per-class scores in
        # [0, 1] — confirm against the model's final activation.
        out_1 = torch.round(F.interpolate(output, scale_factor=2))
        loss_1 = dice_loss_3(out_1, label_1)

        # Per-class losses: background, body and bv (as labelled in the summary).
        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()

    n_val_batches = len(validation_loader)
    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss / n_val_batches, bdloss / n_val_batches, bvloss / n_val_batches) + '\n'
    print(outstr)

1it [00:14, 14.99s/it]

0.0052602291107177734 0.0847400426864624 0.23513567447662354 0.10837864875793457


2it [00:15, 10.58s/it]

0.004479885101318359 0.1279239058494568 0.20461124181747437 0.1123383492231369


3it [00:15,  7.56s/it]

0.007687509059906006 0.1340082883834839 0.6134685277938843 0.25172144174575806


4it [00:16,  5.38s/it]

0.004954695701599121 0.06959617137908936 0.22952210903167725 0.10135766118764877


5it [00:16,  3.92s/it]

0.005783796310424805 0.08450788259506226 0.2346789836883545 0.10832355916500092


6it [00:16,  2.83s/it]

0.004691004753112793 0.13996100425720215 0.19551748037338257 0.11338983476161957


7it [00:28,  5.60s/it]

0.0063089728355407715 0.20133793354034424 0.1454073190689087 0.1176847442984581


8it [00:29,  4.00s/it]

0.0068756937980651855 0.12412106990814209 0.21736931800842285 0.11612202972173691


9it [00:29,  2.96s/it]

0.005072355270385742 0.09443116188049316 0.14227402210235596 0.08059251308441162


10it [00:30,  2.16s/it]

0.004362940788269043 0.11617636680603027 0.1791277527809143 0.09988902509212494


11it [00:30,  1.66s/it]

0.0032694339752197266 0.08180522918701172 0.3444911241531372 0.14318859577178955


12it [00:30,  1.25s/it]

0.005502223968505859 0.17943108081817627 0.3071554899215698 0.16402959823608398


13it [00:38,  3.21s/it]

0.006381094455718994 0.1459430456161499 0.17629224061965942 0.10953879356384277


14it [00:38,  2.33s/it]

0.0061261653900146484 0.11813539266586304 0.14944398403167725 0.09123518317937851


15it [00:39,  1.77s/it]

0.00417017936706543 0.11263483762741089 0.37312567234039307 0.1633102297782898


16it [00:39,  1.33s/it]

0.004286646842956543 0.1115034818649292 0.16439056396484375 0.09339356422424316


17it [00:40,  1.09s/it]

0.004506230354309082 0.12906861305236816 0.15003246068954468 0.09453576803207397


18it [00:40,  1.18it/s]

0.007269501686096191 0.1112661361694336 0.16603732109069824 0.09485765546560287


19it [00:49,  3.20s/it]

0.007206380367279053 0.08918595314025879 0.4792160391807556 0.19186946749687195


20it [00:49,  2.33s/it]

0.005843400955200195 0.12397027015686035 0.5860595107078552 0.23862439393997192


21it [00:49,  1.77s/it]

0.005456209182739258 0.08646136522293091 0.3016459345817566 0.13118784129619598


22it [00:50,  1.33s/it]

0.004150092601776123 0.10980939865112305 0.22542667388916016 0.11312872171401978


23it [00:50,  1.07s/it]

0.009096503257751465 0.18356263637542725 0.2112862467765808 0.13464847207069397
------- background loss = 0.0056, body loss = 0.1200, bv loss = 0.2622




