In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 2 

dtype = torch.float32 
# define dtype, float is space efficient than double

if USE_GPU and torch.cuda.is_available():
    
    device = torch.device('cuda')
    
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    # magic flag that accelerate
    
    print('using GPU for training')
else:
    device = torch.device('cpu')
    print('using CPU for training')

using GPU for training


In [3]:
train_dataset = pyramid_dataset(data_type = 'nii_train', 
                transform=transforms.Compose([
                random_affine(90, 15),
                random_filp(0.5)]))
# do data augumentation on train dataset

validation_dataset = pyramid_dataset(data_type = 'nii_test', 
                transform=None)
# no data augumentation on validation dataset

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS) # drop_last
# loaders come with auto batch division and multi-thread acceleration

In [4]:
'''
deeplab = DeepLab(output_stride=8)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)
deeplab = deeplab.to(device=device, dtype=dtype)
#shape_test(icnet1, True)
# create the model, by default model type is float, use model.double(), model.float() to convert
# move the model to desirable device

optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
epoch = 0

# create an optimizer object
# note that only the model_2 params and model_4 params will be optimized by optimizer
'''

"\ndeeplab = DeepLab(output_stride=8)\ndeeplab = nn.DataParallel(deeplab)\ndeeplab = convert_model(deeplab)\ndeeplab = deeplab.to(device=device, dtype=dtype)\n#shape_test(icnet1, True)\n# create the model, by default model type is float, use model.double(), model.float() to convert\n# move the model to desirable device\n\noptimizer = optim.Adam(deeplab.parameters(), lr=1e-2)\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)\nepoch = 0\n\n# create an optimizer object\n# note that only the model_2 params and model_4 params will be optimized by optimizer\n"

In [None]:

deeplab = DeepLab(output_stride=8)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

#checkpoint = torch.load('../deeplab_save/2019-07-29 04:00:14.630172.pth') # second best
#checkpoint = torch.load('../deeplab_save/2019-07-28 23:47:36.279119.pth') # second best
#checkpoint = torch.load('../deeplab_save/2019-07-29 00:15:49.271222.pth') # best
#checkpoint = torch.load('../deeplab_save/2019-07-29 00:44:11.825872.pth')
checkpoint = torch.load('../deeplab_output_8_3_save/2019-08-11 15:25:43.880400 epoch: 320.pth') # latest one

deeplab.load_state_dict(checkpoint['state_dict_1'])
deeplab = deeplab.to(device, dtype)

optimizer = optim.Adam(deeplab.parameters(), lr=1e-3)
optimizer.load_state_dict(checkpoint['optimizer'])

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
scheduler.load_state_dict(checkpoint['scheduler'])
epoch = checkpoint['epoch']
print(epoch)
for param_group in optimizer.param_groups:
    print(param_group['lr'])


320
0.0001


In [None]:
epochs = 5000

min_val = 0.0778

record = open('train_deeplab_output_8_3.txt','a+')

logger = {'train':[], 'validation_1': []}

for e in tqdm(range(epoch + 1, epochs)):
# iter over epoches

    epoch_loss = 0
        
    for t, batch in enumerate(train_loader):
    # iter over the train mini batches
    
        deeplab.train()
        # Set the model flag to train
        # 1. enable dropout
        # 2. batchnorm behave differently in train and test
        
        image_1 = batch['image1_data'].to(device=device, dtype=dtype)
        label_1 = batch['image1_label'].to(device=device, dtype=dtype)
        # move data to device, convert dtype to desirable dtype
        
        out_1 = deeplab(image_1)
        # do the inference

        loss_1 = dice_loss_3(out_1, label_1)
        # calculate loss
        
        epoch_loss += loss_1.item()
        # record minibatch loss to epoch loss
        
        optimizer.zero_grad()
        # set the model parameter gradient to zero
        
        loss_1.backward()
        # calculate the gradient wrt loss
        optimizer.step()
        #scheduler.step(loss_1)
        # take a gradient descent step
        
    outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, epoch_loss/(t+1)) + '\n'
    
    logger['train'].append(epoch_loss/(t+1))
    
    print(outstr)
    record.write(outstr)
    record.flush()

    if e%1 == 0:
    # do validation every 5 epoches
    
        deeplab.eval()
        # set model flag to eval
        # 1. disable dropout
        # 2. batchnorm behave differs

        with torch.no_grad():
        # stop taking gradient
        
            #valloss_4 = 0
            #valloss_2 = 0
            valloss_1 = 0
            
            for v, vbatch in enumerate(validation_loader):
            # iter over validation mini batches
                
                image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                if get_dimensions(image_1_val) == 4:
                    image_1_val.unsqueeze_(0)
                label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                if get_dimensions(label_1_val) == 4:
                    label_1_val.unsqueeze_(0)
                # move data to device, convert dtype to desirable dtype
                # add one dimension to labels if they are 4D tensors
                
                out_1_val = deeplab(image_1_val)
                # do the inference
                
                loss_1 = dice_loss_3(out_1_val, label_1_val)
                # calculate loss

                valloss_1 += loss_1.item()
                # record mini batch loss
            
            avg_val_loss = (valloss_1 / (v+1))
            outstr = '------- 1st valloss={0:.4f}'\
                .format(avg_val_loss) + '\n'
            
            logger['validation_1'].append(avg_val_loss)
            scheduler.step(avg_val_loss)
            
            print(outstr)
            record.write(outstr)
            record.flush()
            
            if avg_val_loss < min_val:
                print(avg_val_loss, "less than", min_val)
                min_val = avg_val_loss
                
            save_1('deeplab_output_8_3_save', deeplab, optimizer, logger, e, scheduler)

record.close()

  0%|          | 0/4679 [00:00<?, ?it/s]

Epoch 321 finished ! Training Loss: 0.0699

------- 1st valloss=nan



  0%|          | 1/4679 [14:06<1099:34:39, 846.19s/it]

Checkpoint 321 saved !
Epoch 322 finished ! Training Loss: 0.0714

------- 1st valloss=nan



  0%|          | 2/4679 [26:57<1070:11:30, 823.75s/it]

Checkpoint 322 saved !
Epoch 323 finished ! Training Loss: 0.0704

------- 1st valloss=0.0920



  0%|          | 3/4679 [39:46<1048:33:05, 807.27s/it]

Checkpoint 323 saved !
Epoch 324 finished ! Training Loss: 0.0703

------- 1st valloss=nan



  0%|          | 4/4679 [52:39<1034:54:15, 796.93s/it]

Checkpoint 324 saved !
Epoch 325 finished ! Training Loss: 0.0698

------- 1st valloss=nan



  0%|          | 5/4679 [1:05:37<1027:30:56, 791.41s/it]

Checkpoint 325 saved !
Epoch 326 finished ! Training Loss: 0.0699

------- 1st valloss=0.0911



  0%|          | 6/4679 [1:18:38<1023:10:23, 788.24s/it]

Checkpoint 326 saved !
Epoch 327 finished ! Training Loss: 0.0693

------- 1st valloss=nan



  0%|          | 7/4679 [1:31:26<1015:06:10, 782.19s/it]

Checkpoint 327 saved !
Epoch 328 finished ! Training Loss: 0.0695

------- 1st valloss=0.0952



  0%|          | 8/4679 [1:44:18<1010:48:53, 779.05s/it]

Checkpoint 328 saved !
Epoch 329 finished ! Training Loss: 0.0701

------- 1st valloss=nan



  0%|          | 9/4679 [1:57:09<1007:33:37, 776.71s/it]

Checkpoint 329 saved !
Epoch 330 finished ! Training Loss: 0.0696

------- 1st valloss=nan



  0%|          | 10/4679 [2:09:58<1004:13:14, 774.30s/it]

Checkpoint 330 saved !
Epoch 331 finished ! Training Loss: 0.0700

------- 1st valloss=nan



  0%|          | 11/4679 [2:23:00<1006:56:19, 776.56s/it]

Checkpoint 331 saved !
Epoch 332 finished ! Training Loss: 0.0702

------- 1st valloss=nan



  0%|          | 12/4679 [2:35:59<1007:46:59, 777.38s/it]

Checkpoint 332 saved !
Epoch 333 finished ! Training Loss: 0.0705

------- 1st valloss=nan



  0%|          | 13/4679 [2:48:49<1004:47:07, 775.23s/it]

Checkpoint 333 saved !
Epoch 334 finished ! Training Loss: 0.0711

------- 1st valloss=nan



  0%|          | 14/4679 [3:01:50<1006:55:53, 777.05s/it]

Checkpoint 334 saved !
Epoch 335 finished ! Training Loss: 0.0704

------- 1st valloss=nan



  0%|          | 15/4679 [3:14:41<1004:04:49, 775.02s/it]

Checkpoint 335 saved !
Epoch 336 finished ! Training Loss: 0.0707

------- 1st valloss=0.0930



  0%|          | 16/4679 [3:27:28<1001:01:16, 772.82s/it]

Checkpoint 336 saved !
Epoch 337 finished ! Training Loss: 0.0706

------- 1st valloss=nan



  0%|          | 17/4679 [3:40:19<1000:08:03, 772.30s/it]

Checkpoint 337 saved !
Epoch 338 finished ! Training Loss: 0.0699

------- 1st valloss=nan



  0%|          | 18/4679 [3:53:14<1000:45:26, 772.95s/it]

Checkpoint 338 saved !
Epoch 339 finished ! Training Loss: 0.0697

------- 1st valloss=nan



  0%|          | 19/4679 [4:06:02<998:43:55, 771.55s/it] 

Checkpoint 339 saved !
Epoch 340 finished ! Training Loss: 0.0703

------- 1st valloss=nan



  0%|          | 20/4679 [4:18:52<997:43:36, 770.94s/it]

Checkpoint 340 saved !
Epoch 341 finished ! Training Loss: 0.0696

------- 1st valloss=0.0923



  0%|          | 21/4679 [4:31:41<996:53:18, 770.46s/it]

Checkpoint 341 saved !
Epoch 342 finished ! Training Loss: 0.0707

------- 1st valloss=nan



  0%|          | 22/4679 [4:44:30<995:57:16, 769.90s/it]

Checkpoint 342 saved !
Epoch 343 finished ! Training Loss: 0.0697

------- 1st valloss=nan



  0%|          | 23/4679 [4:57:16<994:11:18, 768.70s/it]

Checkpoint 343 saved !
Epoch 344 finished ! Training Loss: 0.0702

------- 1st valloss=nan



  1%|          | 24/4679 [5:10:03<993:35:06, 768.40s/it]

Checkpoint 344 saved !
Epoch 345 finished ! Training Loss: 0.0713

------- 1st valloss=nan



  1%|          | 25/4679 [5:22:53<994:00:18, 768.89s/it]

Checkpoint 345 saved !
Epoch 346 finished ! Training Loss: 0.0696

------- 1st valloss=nan



  1%|          | 26/4679 [5:35:45<994:52:37, 769.73s/it]

Checkpoint 346 saved !
Epoch 347 finished ! Training Loss: 0.0702

------- 1st valloss=nan



  1%|          | 27/4679 [5:48:47<999:28:50, 773.46s/it]

Checkpoint 347 saved !
Epoch 348 finished ! Training Loss: 0.0705

------- 1st valloss=0.0650

0.06503271230536958 less than 0.0778


  1%|          | 28/4679 [6:01:51<1003:14:21, 776.53s/it]

Checkpoint 348 saved !
Epoch 349 finished ! Training Loss: 0.0710

------- 1st valloss=0.0934



  1%|          | 29/4679 [6:14:49<1003:44:34, 777.09s/it]

Checkpoint 349 saved !
Epoch 350 finished ! Training Loss: 0.0703

------- 1st valloss=nan



  1%|          | 30/4679 [6:27:42<1001:59:00, 775.90s/it]

Checkpoint 350 saved !
Epoch 351 finished ! Training Loss: 0.0708

------- 1st valloss=nan



  1%|          | 31/4679 [6:40:39<1001:55:08, 776.01s/it]

Checkpoint 351 saved !
Epoch 352 finished ! Training Loss: 0.0697

------- 1st valloss=nan



  1%|          | 32/4679 [6:53:33<1001:01:47, 775.49s/it]

Checkpoint 352 saved !
Epoch 353 finished ! Training Loss: 0.0714

------- 1st valloss=nan



  1%|          | 33/4679 [7:06:32<1002:16:23, 776.62s/it]

Checkpoint 353 saved !
Epoch 354 finished ! Training Loss: 0.0692

------- 1st valloss=nan



  1%|          | 34/4679 [7:19:22<999:34:32, 774.70s/it] 

Checkpoint 354 saved !
Epoch 355 finished ! Training Loss: 0.0697

------- 1st valloss=nan



  1%|          | 35/4679 [7:32:13<997:56:59, 773.60s/it]

Checkpoint 355 saved !
Epoch 356 finished ! Training Loss: 0.0694

------- 1st valloss=nan



  1%|          | 36/4679 [7:45:14<1000:16:29, 775.57s/it]

Checkpoint 356 saved !
Epoch 357 finished ! Training Loss: 0.0694

------- 1st valloss=nan



  1%|          | 37/4679 [7:58:08<999:27:02, 775.10s/it] 

Checkpoint 357 saved !
Epoch 358 finished ! Training Loss: 0.0699

------- 1st valloss=nan



  1%|          | 38/4679 [8:10:53<995:27:41, 772.17s/it]

Checkpoint 358 saved !
Epoch 359 finished ! Training Loss: 0.0708

------- 1st valloss=nan



  1%|          | 39/4679 [8:23:52<997:47:04, 774.14s/it]

Checkpoint 359 saved !
Epoch 360 finished ! Training Loss: 0.0698

------- 1st valloss=nan



  1%|          | 40/4679 [8:36:53<1000:16:14, 776.24s/it]

Checkpoint 360 saved !
Epoch 361 finished ! Training Loss: 0.0703

------- 1st valloss=nan



  1%|          | 41/4679 [8:49:46<999:02:27, 775.45s/it] 

Checkpoint 361 saved !
Epoch 362 finished ! Training Loss: 0.0703

------- 1st valloss=nan



  1%|          | 42/4679 [9:02:36<996:38:05, 773.75s/it]

Checkpoint 362 saved !
Epoch 363 finished ! Training Loss: 0.0696

------- 1st valloss=nan



  1%|          | 43/4679 [9:15:39<1000:01:17, 776.55s/it]

Checkpoint 363 saved !
Epoch 364 finished ! Training Loss: 0.0700

------- 1st valloss=nan



  1%|          | 44/4679 [9:28:43<1002:43:04, 778.81s/it]

Checkpoint 364 saved !
Epoch 365 finished ! Training Loss: 0.0703

------- 1st valloss=nan



  1%|          | 45/4679 [9:41:28<997:10:29, 774.67s/it] 

Checkpoint 365 saved !
Epoch 366 finished ! Training Loss: 0.0712

------- 1st valloss=0.0908



  1%|          | 46/4679 [9:54:18<995:01:10, 773.16s/it]

Checkpoint 366 saved !


In [None]:
deeplab.eval()

with torch.no_grad():
    
    bgloss = 0
    bdloss = 0
    bvloss = 0
    
    for v, vbatch in tqdm(enumerate(validation_loader)):
            # move data to device, convert dtype to desirable dtype

        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        output = deeplab(image_1)
        # do the inference
        output_numpy = output.cpu().numpy()
        
        
        #out_1 = torch.round(output)
        out_1 = torch.from_numpy((output_numpy == output_numpy.max(axis=1)[:, None]).astype(int)).to(device=device, dtype=dtype)
        loss_1 = dice_loss_3(out_1, label_1)

        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        # calculate loss
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()

    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/(v+1), bdloss/(v+1), bvloss/(v+1)) + '\n'
    print(outstr)