In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Run configuration: device preference, DataLoader parallelism, batch size and precision.
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 2

# float32 is more space-efficient than float64 (double).
dtype = torch.float32

# Prefer CUDA when requested and present; otherwise fall back to the CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    # Let cuDNN pick the fastest convolution algorithms for the observed input sizes.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    print('using GPU for training')
else:
    print('using CPU for training')

using GPU for training


In [3]:
# Training set with data augmentation: random_affine(90, 15) followed by random_filp(0.5)
# (see their definitions for what each argument controls).
train_dataset = pyramid_dataset(data_type='nii_train',
                                transform=transforms.Compose([
                                    random_affine(90, 15),
                                    random_filp(0.5)]))

# Validation set: no augmentation.
validation_dataset = pyramid_dataset(data_type='nii_test', transform=None)

# DataLoader provides batching and multi-process loading.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS)

# Validation order is fixed. The validation loss is averaged per batch, so its value
# depends on how samples are grouped into batches (e.g. a smaller final batch, or a loss
# that pools over the batch). Shuffling would make the metric that decides which
# checkpoint is "best" change from epoch to epoch for the same weights.
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False,
                               num_workers=NUM_WORKERS)
# loaders come with auto batch division and multi-thread acceleration

In [4]:
# Fresh-start initialisation, intentionally disabled: the whole block below is a string
# literal, so running this cell only echoes its text. The following cell restores the
# model, optimizer and scheduler from a saved checkpoint instead.
'''
deeplab = DeepLab_ELU(output_stride=8)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)
deeplab = deeplab.to(device=device, dtype=dtype)
#shape_test(icnet1, True)
# create the model, by default model type is float, use model.double(), model.float() to convert
# move the model to desirable device

optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)
epoch = 0

# create an optimizer object
# note that only the model_2 params and model_4 params will be optimized by optimizer
'''

"\ndeeplab = DeepLab_ELU(output_stride=8)\ndeeplab = nn.DataParallel(deeplab)\ndeeplab = convert_model(deeplab)\ndeeplab = deeplab.to(device=device, dtype=dtype)\n#shape_test(icnet1, True)\n# create the model, by default model type is float, use model.double(), model.float() to convert\n# move the model to desirable device\n\noptimizer = optim.Adam(deeplab.parameters(), lr=1e-2)\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)\nepoch = 0\n\n# create an optimizer object\n# note that only the model_2 params and model_4 params will be optimized by optimizer\n"

In [None]:

# Resume training from a saved checkpoint.
# Other checkpoints from an earlier run (in ../deeplab_save/), kept for reference:
#   2019-07-29 00:15:49.271222.pth  (best)
#   2019-07-29 04:00:14.630172.pth, 2019-07-28 23:47:36.279119.pth  (second best)
#   2019-07-29 00:44:11.825872.pth
CHECKPOINT_PATH = '../deeplab_output_8_elu_save/2019-08-20 21:21:15.471972 epoch: 350.pth'  # latest one

# Wrap the network the same way before loading, presumably matching how it was wrapped
# when saved. DataParallel prefixes state-dict keys with "module.".
deeplab = DeepLab_ELU(output_stride=8)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

# map_location lets a checkpoint written on a GPU be restored when `device` is the CPU
# fallback. Without it, torch.load tries to recreate the CUDA tensors and fails.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

deeplab.load_state_dict(checkpoint['state_dict_1'])
deeplab = deeplab.to(device=device, dtype=dtype)

# Create the optimizer only after the model is on `device`. The lr given here is replaced
# by the param_groups restored from the checkpoint.
optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)
optimizer.load_state_dict(checkpoint['optimizer'])

# NOTE: ReduceLROnPlateau.load_state_dict restores every scheduler attribute (patience,
# factor, best, ...), so the constructor arguments are overwritten by the saved values.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=25)
scheduler.load_state_dict(checkpoint['scheduler'])

epoch = checkpoint['epoch']
print(epoch)
for param_group in optimizer.param_groups:
    print(param_group['lr'])


350
0.01


In [None]:
epochs = 5000

# Best (lowest) validation dice loss so far. Beating it triggers a "best" checkpoint.
# NOTE(review): hard-coded seed value, presumably the best loss reached before the resumed
# epoch. Confirm it, or derive it from the checkpoint's history.
min_val = .0621

LOG_PATH = 'train_deeplab_output_8_elu.txt'
SAVE_DIR = 'deeplab_output_8_elu_save'

# Per-epoch loss history passed to save_1 with every checkpoint.
# NOTE(review): this starts empty on every resume, so checkpoints only hold this session's
# history. Consider restoring it from `checkpoint` if save_1 stores it there.
logger = {'train': [], 'validation_1': []}


def _should_validate(epoch_idx):
    """Return True on validation epochs: every 5th epoch up to 150, then every epoch."""
    return epoch_idx > 150 or epoch_idx % 5 == 0


def _train_one_epoch(model, loader, opt):
    """Run one optimisation pass over `loader` and return the mean minibatch loss.

    Raises:
        RuntimeError: if `loader` yields no batches (the mean would be undefined).
    """
    # Set once per epoch: only the validation pass switches the model to eval().
    model.train()
    total_loss = 0.0
    n_batches = 0
    for batch in loader:
        images = batch['image1_data'].to(device=device, dtype=dtype)
        labels = batch['image1_label'].to(device=device, dtype=dtype)

        loss = dice_loss_3(model(images), labels)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.item()
        n_batches += 1
    if n_batches == 0:
        raise RuntimeError('train loader yielded no batches')
    return total_loss / n_batches


def _validate(model, loader):
    """Return the mean dice_loss_3 over `loader`, in eval mode and without autograd.

    Raises:
        RuntimeError: if `loader` yields no batches.
    """
    # Disables dropout and switches batch-norm to its inference behaviour.
    model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for vbatch in loader:
            # Add a leading dimension when a tensor comes back 4D (get_dimensions is
            # presumably the tensor rank; behaviour kept from the original loop).
            images = vbatch['image1_data'].to(device=device, dtype=dtype)
            if get_dimensions(images) == 4:
                images.unsqueeze_(0)
            labels = vbatch['image1_label'].to(device=device, dtype=dtype)
            if get_dimensions(labels) == 4:
                labels.unsqueeze_(0)

            total_loss += dice_loss_3(model(images), labels).item()
            n_batches += 1
    if n_batches == 0:
        raise RuntimeError('validation loader yielded no batches')
    return total_loss / n_batches


def _log(log_file, text):
    """Echo `text` to stdout and append it to `log_file`, flushing so progress survives a crash."""
    print(text)
    log_file.write(text)
    log_file.flush()


# The `with` block closes the log file even if training is interrupted (e.g. KeyboardInterrupt).
with open(LOG_PATH, 'a+') as record:
    for e in tqdm(range(epoch + 1, epochs)):
        train_loss = _train_one_epoch(deeplab, train_loader, optimizer)
        logger['train'].append(train_loss)
        _log(record, 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, train_loss) + '\n')

        if not _should_validate(e):
            continue

        avg_val_loss = _validate(deeplab, validation_loader)
        logger['validation_1'].append(avg_val_loss)
        # NOTE: the LR scheduler is restored but never stepped, so the learning rate stays constant.
        _log(record, '------- 1st valloss={0:.4f}'.format(avg_val_loss) + '\n')

        if avg_val_loss < min_val:
            print(avg_val_loss, "less than", min_val)
            min_val = avg_val_loss
            save_1(SAVE_DIR, deeplab, optimizer, logger, e, scheduler)
        elif e % 10 == 0:
            # Periodic checkpoint. Every 10th epoch is also a validation epoch.
            save_1(SAVE_DIR, deeplab, optimizer, logger, e, scheduler)

  0%|          | 0/4649 [00:00<?, ?it/s]

Epoch 351 finished ! Training Loss: 0.0718



  0%|          | 1/4649 [13:57<1081:43:26, 837.82s/it]

------- 1st valloss=0.0765

Epoch 352 finished ! Training Loss: 0.0736



  0%|          | 2/4649 [26:33<1049:32:04, 813.07s/it]

------- 1st valloss=0.0635

Epoch 353 finished ! Training Loss: 0.0720



  0%|          | 3/4649 [39:07<1026:30:46, 795.40s/it]

------- 1st valloss=0.0727

Epoch 354 finished ! Training Loss: 0.0746



  0%|          | 4/4649 [51:37<1008:48:36, 781.85s/it]

------- 1st valloss=0.0686

Epoch 355 finished ! Training Loss: 0.0747



  0%|          | 5/4649 [1:04:39<1008:40:24, 781.92s/it]

------- 1st valloss=0.0671

Epoch 356 finished ! Training Loss: 0.0714



  0%|          | 6/4649 [1:17:29<1003:37:38, 778.17s/it]

------- 1st valloss=0.0664

Epoch 357 finished ! Training Loss: 0.0745



  0%|          | 7/4649 [1:30:07<995:49:06, 772.28s/it] 

------- 1st valloss=0.0705

Epoch 358 finished ! Training Loss: 0.0723



  0%|          | 8/4649 [1:42:52<992:41:11, 770.02s/it]

------- 1st valloss=0.0658

Epoch 359 finished ! Training Loss: 0.0722



  0%|          | 9/4649 [1:55:44<993:22:16, 770.72s/it]

------- 1st valloss=0.1516

Epoch 360 finished ! Training Loss: 0.0728

------- 1st valloss=0.0681



  0%|          | 10/4649 [2:08:24<988:56:13, 767.44s/it]

Checkpoint 360 saved !
Epoch 361 finished ! Training Loss: 0.0738



  0%|          | 11/4649 [2:20:58<983:29:59, 763.39s/it]

------- 1st valloss=0.0631

Epoch 362 finished ! Training Loss: 0.0686



  0%|          | 12/4649 [2:33:34<980:16:38, 761.05s/it]

------- 1st valloss=0.0656

Epoch 363 finished ! Training Loss: 0.0695



  0%|          | 13/4649 [2:46:21<982:36:52, 763.03s/it]

------- 1st valloss=0.0680

Epoch 364 finished ! Training Loss: 0.0703



  0%|          | 14/4649 [2:59:06<983:08:25, 763.60s/it]

------- 1st valloss=0.0715

Epoch 365 finished ! Training Loss: 0.0723



  0%|          | 15/4649 [3:11:49<982:40:54, 763.41s/it]

------- 1st valloss=0.0658

Epoch 366 finished ! Training Loss: 0.0699



  0%|          | 16/4649 [3:24:31<981:48:24, 762.90s/it]

------- 1st valloss=0.0644

Epoch 367 finished ! Training Loss: 0.0702



  0%|          | 17/4649 [3:37:07<978:57:39, 760.85s/it]

------- 1st valloss=0.0625

Epoch 368 finished ! Training Loss: 0.0702



  0%|          | 18/4649 [3:49:47<978:33:39, 760.70s/it]

------- 1st valloss=0.0664

Epoch 369 finished ! Training Loss: 0.0702



  0%|          | 19/4649 [4:02:30<979:17:38, 761.44s/it]

------- 1st valloss=0.0674

Epoch 370 finished ! Training Loss: 0.0701

------- 1st valloss=0.0671



  0%|          | 20/4649 [4:15:11<978:44:04, 761.17s/it]

Checkpoint 370 saved !
Epoch 371 finished ! Training Loss: 0.0678



  0%|          | 21/4649 [4:28:06<983:52:57, 765.34s/it]

------- 1st valloss=0.0631

Epoch 372 finished ! Training Loss: 0.0674



  0%|          | 22/4649 [4:40:40<979:21:40, 761.98s/it]

------- 1st valloss=0.0648

Epoch 373 finished ! Training Loss: 0.0671



  0%|          | 23/4649 [4:53:26<980:42:20, 763.20s/it]

------- 1st valloss=0.0647

Epoch 374 finished ! Training Loss: 0.0677



  1%|          | 24/4649 [5:06:20<984:35:17, 766.38s/it]

------- 1st valloss=0.0667

Epoch 375 finished ! Training Loss: 0.0669



  1%|          | 25/4649 [5:19:13<986:55:08, 768.36s/it]

------- 1st valloss=0.0920

Epoch 376 finished ! Training Loss: 0.0697



  1%|          | 26/4649 [5:31:55<984:24:21, 766.57s/it]

------- 1st valloss=0.0708

Epoch 377 finished ! Training Loss: 0.0683



  1%|          | 27/4649 [5:44:45<985:30:36, 767.60s/it]

------- 1st valloss=0.0628

Epoch 378 finished ! Training Loss: 0.0717



  1%|          | 28/4649 [5:57:45<989:54:34, 771.19s/it]

------- 1st valloss=0.0742

Epoch 379 finished ! Training Loss: 0.0787



  1%|          | 29/4649 [6:10:29<987:00:43, 769.10s/it]

------- 1st valloss=0.0675

Epoch 380 finished ! Training Loss: 0.0692

------- 1st valloss=0.0659



  1%|          | 30/4649 [6:23:13<984:48:21, 767.55s/it]

Checkpoint 380 saved !
Epoch 381 finished ! Training Loss: 0.0664



  1%|          | 31/4649 [6:36:06<986:36:22, 769.12s/it]

------- 1st valloss=0.0647

Epoch 382 finished ! Training Loss: 0.0671

------- 1st valloss=0.0618

0.06180673448935799 less than 0.0621


  1%|          | 32/4649 [6:48:45<982:28:14, 766.06s/it]

Checkpoint 382 saved !
Epoch 383 finished ! Training Loss: 0.0673

------- 1st valloss=0.0618

0.061755539120539375 less than 0.06180673448935799


  1%|          | 33/4649 [7:01:23<979:12:45, 763.68s/it]

Checkpoint 383 saved !
Epoch 384 finished ! Training Loss: 0.0684



  1%|          | 34/4649 [7:14:10<980:28:16, 764.83s/it]

------- 1st valloss=0.0672

Epoch 385 finished ! Training Loss: 0.0684

------- 1st valloss=0.0601

0.060146167712367096 less than 0.061755539120539375


  1%|          | 35/4649 [7:27:11<986:07:48, 769.41s/it]

Checkpoint 385 saved !
Epoch 386 finished ! Training Loss: 0.0740



  1%|          | 36/4649 [7:40:07<988:42:17, 771.59s/it]

------- 1st valloss=0.0713

Epoch 387 finished ! Training Loss: 0.0802



  1%|          | 37/4649 [7:52:56<987:20:21, 770.69s/it]

------- 1st valloss=0.0720

Epoch 388 finished ! Training Loss: 0.0868



  1%|          | 38/4649 [8:05:45<986:44:21, 770.39s/it]

------- 1st valloss=0.0872

Epoch 389 finished ! Training Loss: 0.0726



  1%|          | 39/4649 [8:18:25<982:21:09, 767.13s/it]

------- 1st valloss=0.0768

Epoch 390 finished ! Training Loss: 0.0749

------- 1st valloss=0.0683



  1%|          | 40/4649 [8:31:00<977:32:05, 763.53s/it]

Checkpoint 390 saved !
Epoch 391 finished ! Training Loss: 0.0694



  1%|          | 41/4649 [8:44:03<984:49:02, 769.39s/it]

------- 1st valloss=0.0687

Epoch 392 finished ! Training Loss: 0.0712



  1%|          | 42/4649 [8:56:40<979:46:21, 765.61s/it]

------- 1st valloss=0.0806

Epoch 393 finished ! Training Loss: 0.0697



  1%|          | 43/4649 [9:09:27<980:06:03, 766.04s/it]

------- 1st valloss=0.0793

Epoch 394 finished ! Training Loss: 0.0695



  1%|          | 44/4649 [9:22:32<987:09:47, 771.72s/it]

------- 1st valloss=0.0616

Epoch 395 finished ! Training Loss: 0.0702



  1%|          | 45/4649 [9:35:08<980:47:23, 766.91s/it]

------- 1st valloss=0.0696



In [None]:
# Final evaluation: score hard (one-hot) predictions on the validation set and report
# the background / body / bv losses returned by dice_loss_3_debug.
deeplab.eval()

with torch.no_grad():

    bgloss = 0
    bdloss = 0
    bvloss = 0
    n_batches = 0

    # tqdm over the loader itself so the bar knows the total number of batches.
    for vbatch in tqdm(validation_loader):
        # move data to device, convert dtype to desirable dtype
        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        # do the inference
        output = deeplab(image_1)

        # One-hot the prediction along dim 1 (presumably the class/channel axis): 1 where an
        # entry equals the maximum over that dim, 0 elsewhere. This is computed on-device
        # instead of via a CPU/numpy round-trip. As before, ties mark every maximal channel.
        out_1 = (output == output.max(dim=1, keepdim=True)[0]).to(device=device, dtype=dtype)
        loss_1 = dice_loss_3(out_1, label_1)

        # calculate loss
        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()
        n_batches += 1

    if n_batches == 0:
        raise RuntimeError('validation_loader yielded no batches')

    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/n_batches, bdloss/n_batches, bvloss/n_batches) + '\n'
    print(outstr)