In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 2

# Work in single precision: float32 needs half the memory of float64 (double).
dtype = torch.float32

_cuda_ok = USE_GPU and torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')

if _cuda_ok:
    # benchmark=True lets cuDNN time candidate convolution algorithms and keep the
    # fastest; enabled=True makes sure cuDNN is used at all.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    print('using GPU for training')
else:
    print('using CPU for training')

using GPU for training


In [3]:
# Training set: augmented on the fly with random_affine(90, 15), random_filp(0.5)
# and CropAndPad wrapped in transforms.RandomApply (torchvision's default p=0.5).
train_dataset = pyramid_dataset(
    data_type='nii_train',
    transform=transforms.Compose([
        random_affine(90, 15),
        random_filp(0.5),  # (sic) project helper name; 0.5 presumably the flip probability
        transforms.RandomApply([CropAndPad()]),
    ]),
)

# Validation set: no augmentation, so scores are comparable across epochs.
validation_dataset = pyramid_dataset(data_type='nii_test', transform=None)

# DataLoaders handle batching and load samples in NUM_WORKERS worker processes.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS)
# Validation is NOT shuffled: the reported validation loss is a mean of per-batch
# dice losses, so a fixed batch composition keeps it reproducible between runs.
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False,
                               num_workers=NUM_WORKERS)
# loaders come with auto batch division and multi-thread acceleration

In [None]:
# Build the segmentation model and wrap it for multi-GPU data parallelism.
# convert_model (sync_batchnorm) presumably swaps BatchNorm layers for
# cross-GPU synchronized BatchNorm — see that package to confirm.
# NOTE(review): the checkpoint-resume code builds DeepLab(output_stride=8);
# confirm which stride is intended before mixing checkpoints between the two.
deeplab = DeepLab(output_stride=2)
deeplab = torch.nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)
# Move parameters to the selected device and cast them to the working dtype.
deeplab = deeplab.to(device=device, dtype=dtype)

# Adam over *all* model parameters.
optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)
# Plateau scheduler. Both scheduler.step(...) calls in the training loop are
# commented out, so unless it is stepped elsewhere the learning rate is never
# reduced; the scheduler is still passed to save_1 so its state is checkpointed.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1)
# Last completed epoch; the training loop starts at epoch + 1.
epoch = 0

In [None]:
"""
deeplab = DeepLab(output_stride=8)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)

#checkpoint = torch.load('../deeplab_save/2019-07-29 04:00:14.630172.pth') # second best
#checkpoint = torch.load('../deeplab_save/2019-07-28 23:47:36.279119.pth') # second best
#checkpoint = torch.load('../deeplab_save/2019-07-29 00:15:49.271222.pth') # best
#checkpoint = torch.load('../deeplab_save/2019-07-29 00:44:11.825872.pth')
checkpoint = torch.load('../deeplab_save/2019-07-31 20:34:01.096131.pth') # latest one

deeplab.load_state_dict(checkpoint['state_dict_1'])
optimizer.load_state_dict(checkpoint['optimizer'])
#scheduler.load_state_dict(checkpoint['scheduler'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
scheduler.load_state_dict(checkpoint['scheduler'])
epoch = checkpoint['epoch']
print(epoch)
"""

"\ndeeplab = DeepLab(output_stride=8)\ndeeplab = nn.DataParallel(deeplab)\ndeeplab = convert_model(deeplab)\n\noptimizer = optim.Adam(deeplab.parameters(), lr=1e-2)\n\n#checkpoint = torch.load('../deeplab_save/2019-07-29 04:00:14.630172.pth') # second best\n#checkpoint = torch.load('../deeplab_save/2019-07-28 23:47:36.279119.pth') # second best\n#checkpoint = torch.load('../deeplab_save/2019-07-29 00:15:49.271222.pth') # best\n#checkpoint = torch.load('../deeplab_save/2019-07-29 00:44:11.825872.pth')\ncheckpoint = torch.load('../deeplab_save/2019-07-31 20:34:01.096131.pth') # latest one\n\ndeeplab.load_state_dict(checkpoint['state_dict_1'])\noptimizer.load_state_dict(checkpoint['optimizer'])\n#scheduler.load_state_dict(checkpoint['scheduler'])\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)\nscheduler.load_state_dict(checkpoint['scheduler'])\nepoch = checkpoint['epoch']\nprint(epoch)\n"

In [None]:
# Training loop: one optimisation pass per epoch; validation and a checkpoint
# every VALIDATE_EVERY epochs. Progress lines go through tqdm.write so they do
# not break the progress bar.
epochs = 5000            # exclusive upper bound: epochs epoch+1 .. epochs-1 are run
VALIDATE_EVERY = 5       # validate + checkpoint on epochs divisible by this
LOG_PATH = 'train_deeplab_output_4_elastic.txt'
SAVE_NAME = 'deeplab_output_2_elastic_save'  # first argument passed to save_1

min_val = float('inf')   # best (lowest) validation loss seen so far
logger = {'train': [], 'validation_1': []}


def _train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader``; return the mean minibatch loss.

    Each batch is a dict with 'image1_data' / 'image1_label' tensors, moved to the
    notebook-level ``device`` and ``dtype`` before the forward pass.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()  # training mode: dropout active, batchnorm uses batch statistics
    total_loss = 0.0
    n_batches = 0
    for batch in loader:
        image = batch['image1_data'].to(device=device, dtype=dtype)
        label = batch['image1_label'].to(device=device, dtype=dtype)

        loss = dice_loss_3(model(image), label)
        total_loss += loss.item()
        n_batches += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if n_batches == 0:
        raise ValueError('train loader yielded no batches')
    return total_loss / n_batches


def _validate(model, loader):
    """Return the mean ``dice_loss_3`` of ``model`` over ``loader``, without gradients.

    Inputs/labels for which ``get_dimensions`` reports 4 get a leading dimension
    added before inference.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()  # eval mode: dropout off, batchnorm uses running statistics
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for vbatch in loader:
            image = vbatch['image1_data'].to(device=device, dtype=dtype)
            if get_dimensions(image) == 4:
                image.unsqueeze_(0)
            label = vbatch['image1_label'].to(device=device, dtype=dtype)
            if get_dimensions(label) == 4:
                label.unsqueeze_(0)
            total_loss += dice_loss_3(model(image), label).item()
            n_batches += 1
    if n_batches == 0:
        raise ValueError('validation loader yielded no batches')
    return total_loss / n_batches


def _log(record, outstr):
    """Echo ``outstr`` above the tqdm bar and append it to the open log file."""
    tqdm.write(outstr)
    record.write(outstr)
    record.flush()


# `with` guarantees the log file is closed even if the (very long) run is interrupted.
with open(LOG_PATH, 'w+') as record:
    for e in tqdm(range(epoch + 1, epochs)):
        train_loss = _train_one_epoch(deeplab, train_loader, optimizer)
        logger['train'].append(train_loss)
        _log(record, 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, train_loss) + '\n')

        if e % VALIDATE_EVERY == 0:
            avg_val_loss = _validate(deeplab, validation_loader)
            logger['validation_1'].append(avg_val_loss)
            _log(record, '------- 1st valloss={0:.4f}'.format(avg_val_loss) + '\n')
            # scheduler.step(avg_val_loss)  # LR scheduling is disabled in this run

            if avg_val_loss < min_val:
                tqdm.write('{} less than {}'.format(avg_val_loss, min_val))
                min_val = avg_val_loss

            # NOTE(review): a checkpoint is saved after *every* validation, not only
            # when min_val improves (min_val is otherwise only used for logging).
            # Confirm this is intended.
            save_1(SAVE_NAME, deeplab, optimizer, logger, e, scheduler)

  0%|          | 1/4999 [12:04<1006:10:23, 724.73s/it]

Epoch 1 finished ! Training Loss: 0.5174



  0%|          | 2/4999 [23:19<985:15:14, 709.81s/it] 

Epoch 2 finished ! Training Loss: 0.4752



  0%|          | 3/4999 [34:16<962:48:44, 693.78s/it]

Epoch 3 finished ! Training Loss: 0.4689



  0%|          | 4/4999 [45:28<953:43:24, 687.37s/it]

Epoch 4 finished ! Training Loss: 0.4556

Epoch 5 finished ! Training Loss: 0.4426

------- 1st valloss=0.6607

0.6606662921283556 less than 1


  0%|          | 5/4999 [57:18<962:54:52, 694.13s/it]

Checkpoint 5 saved !


  0%|          | 6/4999 [1:08:39<957:17:54, 690.22s/it]

Epoch 6 finished ! Training Loss: 0.4312



  0%|          | 7/4999 [1:19:45<946:59:03, 682.92s/it]

Epoch 7 finished ! Training Loss: 0.4021



  0%|          | 8/4999 [1:30:32<932:02:04, 672.28s/it]

Epoch 8 finished ! Training Loss: 0.3993



  0%|          | 9/4999 [1:41:34<927:31:44, 669.16s/it]

Epoch 9 finished ! Training Loss: 0.3776

Epoch 10 finished ! Training Loss: 0.3602

------- 1st valloss=0.3740

0.37397159312082373 less than 0.6606662921283556


  0%|          | 10/4999 [1:53:16<941:00:39, 679.02s/it]

Checkpoint 10 saved !


  0%|          | 11/4999 [2:04:31<938:54:42, 677.64s/it]

Epoch 11 finished ! Training Loss: 0.3664



  0%|          | 12/4999 [2:15:29<930:29:03, 671.70s/it]

Epoch 12 finished ! Training Loss: 0.3385



  0%|          | 13/4999 [2:26:26<924:23:22, 667.43s/it]

Epoch 13 finished ! Training Loss: 0.3446



  0%|          | 14/4999 [2:37:36<925:22:25, 668.27s/it]

Epoch 14 finished ! Training Loss: 0.3226

Epoch 15 finished ! Training Loss: 0.3181

------- 1st valloss=0.3197

0.3197280790494836 less than 0.37397159312082373


  0%|          | 15/4999 [2:49:14<937:28:04, 677.14s/it]

Checkpoint 15 saved !


  0%|          | 16/4999 [3:00:22<933:18:26, 674.27s/it]

Epoch 16 finished ! Training Loss: 0.3289



  0%|          | 17/4999 [3:11:16<924:55:37, 668.35s/it]

Epoch 17 finished ! Training Loss: 0.2930



  0%|          | 18/4999 [3:22:12<919:24:12, 664.50s/it]

Epoch 18 finished ! Training Loss: 0.2968



  0%|          | 19/4999 [3:33:17<919:41:24, 664.84s/it]

Epoch 19 finished ! Training Loss: 0.3099

Epoch 20 finished ! Training Loss: 0.2880

------- 1st valloss=0.4074



  0%|          | 20/4999 [3:44:45<929:04:58, 671.76s/it]

Checkpoint 20 saved !


  0%|          | 21/4999 [3:55:33<918:54:12, 664.53s/it]

Epoch 21 finished ! Training Loss: 0.2860



  0%|          | 22/4999 [4:06:35<917:39:00, 663.76s/it]

Epoch 22 finished ! Training Loss: 0.2898



  0%|          | 23/4999 [4:17:39<917:31:34, 663.81s/it]

Epoch 23 finished ! Training Loss: 0.2867



  0%|          | 24/4999 [4:28:40<916:08:40, 662.94s/it]

Epoch 24 finished ! Training Loss: 0.2633

Epoch 25 finished ! Training Loss: 0.2679

------- 1st valloss=0.2164

0.21637782648853635 less than 0.3197280790494836


  1%|          | 25/4999 [4:40:19<931:14:34, 674.00s/it]

Checkpoint 25 saved !


  1%|          | 26/4999 [4:51:22<926:27:40, 670.67s/it]

Epoch 26 finished ! Training Loss: 0.2565



  1%|          | 27/4999 [5:02:08<915:53:14, 663.15s/it]

Epoch 27 finished ! Training Loss: 0.2697



  1%|          | 28/4999 [5:13:03<912:13:43, 660.64s/it]

Epoch 28 finished ! Training Loss: 0.2767



  1%|          | 29/4999 [5:24:00<910:33:48, 659.56s/it]

Epoch 29 finished ! Training Loss: 0.2610

Epoch 30 finished ! Training Loss: 0.2617

------- 1st valloss=0.1713

0.17127005788295166 less than 0.21637782648853635


  1%|          | 30/4999 [5:35:51<931:52:58, 675.14s/it]

Checkpoint 30 saved !


  1%|          | 31/4999 [5:46:51<925:17:04, 670.50s/it]

Epoch 31 finished ! Training Loss: 0.2641



  1%|          | 32/4999 [5:57:49<919:50:41, 666.69s/it]

Epoch 32 finished ! Training Loss: 0.2511



  1%|          | 33/4999 [6:08:45<915:07:22, 663.40s/it]

Epoch 33 finished ! Training Loss: 0.2609



  1%|          | 34/4999 [6:19:37<910:29:29, 660.18s/it]

Epoch 34 finished ! Training Loss: 0.2556

Epoch 35 finished ! Training Loss: 0.2516

------- 1st valloss=0.6581



  1%|          | 35/4999 [6:31:13<925:08:10, 670.93s/it]

Checkpoint 35 saved !


  1%|          | 36/4999 [6:42:12<919:53:50, 667.26s/it]

Epoch 36 finished ! Training Loss: 0.2850



  1%|          | 37/4999 [6:53:14<917:27:46, 665.63s/it]

Epoch 37 finished ! Training Loss: 0.2676



  1%|          | 38/4999 [7:04:21<918:09:48, 666.27s/it]

Epoch 38 finished ! Training Loss: 0.2621



  1%|          | 39/4999 [7:15:27<917:34:38, 665.98s/it]

Epoch 39 finished ! Training Loss: 0.2792

Epoch 40 finished ! Training Loss: 0.2589

------- 1st valloss=0.2901



  1%|          | 40/4999 [7:27:15<934:42:06, 678.55s/it]

Checkpoint 40 saved !


  1%|          | 41/4999 [7:38:07<923:31:17, 670.57s/it]

Epoch 41 finished ! Training Loss: 0.2560



  1%|          | 42/4999 [7:49:06<918:46:39, 667.26s/it]

Epoch 42 finished ! Training Loss: 0.2414



  1%|          | 43/4999 [7:59:58<912:08:13, 662.57s/it]

Epoch 43 finished ! Training Loss: 0.2788



  1%|          | 44/4999 [8:10:55<909:35:00, 660.85s/it]

Epoch 44 finished ! Training Loss: 0.2692

Epoch 45 finished ! Training Loss: 0.2928

------- 1st valloss=0.3646



  1%|          | 45/4999 [8:22:31<924:08:55, 671.57s/it]

Checkpoint 45 saved !


  1%|          | 46/4999 [8:33:28<917:58:49, 667.22s/it]

Epoch 46 finished ! Training Loss: 0.2846



  1%|          | 47/4999 [8:44:33<916:46:21, 666.47s/it]

Epoch 47 finished ! Training Loss: 0.2645



  1%|          | 48/4999 [8:55:37<915:27:18, 665.65s/it]

Epoch 48 finished ! Training Loss: 0.2794



  1%|          | 49/4999 [9:06:34<912:00:24, 663.28s/it]

Epoch 49 finished ! Training Loss: 0.2732

Epoch 50 finished ! Training Loss: 0.2588

------- 1st valloss=0.1388

0.13878745099772577 less than 0.17127005788295166


  1%|          | 50/4999 [9:18:22<929:59:04, 676.49s/it]

Checkpoint 50 saved !


  1%|          | 51/4999 [9:29:33<927:27:06, 674.78s/it]

Epoch 51 finished ! Training Loss: 0.2826



  1%|          | 52/4999 [9:40:29<919:37:44, 669.23s/it]

Epoch 52 finished ! Training Loss: 0.2650



  1%|          | 53/4999 [9:51:35<918:09:31, 668.29s/it]

Epoch 53 finished ! Training Loss: 0.2681



  1%|          | 54/4999 [10:02:35<914:45:52, 665.96s/it]

Epoch 54 finished ! Training Loss: 0.2619

Epoch 55 finished ! Training Loss: 0.2706

------- 1st valloss=0.2043



  1%|          | 55/4999 [10:14:15<928:30:52, 676.10s/it]

Checkpoint 55 saved !


  1%|          | 56/4999 [10:25:14<921:15:48, 670.96s/it]

Epoch 56 finished ! Training Loss: 0.2603



  1%|          | 57/4999 [10:36:12<915:51:58, 667.16s/it]

Epoch 57 finished ! Training Loss: 0.2767



  1%|          | 58/4999 [10:47:14<913:20:51, 665.46s/it]

Epoch 58 finished ! Training Loss: 0.2563



  1%|          | 59/4999 [10:58:18<912:41:22, 665.12s/it]

Epoch 59 finished ! Training Loss: 0.2668

Epoch 60 finished ! Training Loss: 0.2728

------- 1st valloss=0.1386

0.1385621136945227 less than 0.13878745099772577


  1%|          | 60/4999 [11:10:01<928:08:24, 676.51s/it]

Checkpoint 60 saved !


  1%|          | 61/4999 [11:21:11<924:56:54, 674.32s/it]

Epoch 61 finished ! Training Loss: 0.2717



  1%|          | 62/4999 [11:32:04<916:17:58, 668.15s/it]

Epoch 62 finished ! Training Loss: 0.2696



  1%|▏         | 63/4999 [11:43:17<918:02:52, 669.57s/it]

Epoch 63 finished ! Training Loss: 0.2585



  1%|▏         | 64/4999 [11:54:12<911:42:45, 665.08s/it]

Epoch 64 finished ! Training Loss: 0.2517

Epoch 65 finished ! Training Loss: 0.2622

------- 1st valloss=0.1246

0.12462053026842035 less than 0.1385621136945227


  1%|▏         | 65/4999 [12:05:55<927:10:27, 676.50s/it]

Checkpoint 65 saved !


  1%|▏         | 66/4999 [12:17:04<924:02:43, 674.35s/it]

Epoch 66 finished ! Training Loss: 0.2534



  1%|▏         | 67/4999 [12:28:03<917:19:28, 669.58s/it]

Epoch 67 finished ! Training Loss: 0.2558



  1%|▏         | 68/4999 [12:51:57<1231:27:36, 899.06s/it]

Epoch 68 finished ! Training Loss: 0.2696



  1%|▏         | 69/4999 [13:02:56<1132:19:59, 826.86s/it]

Epoch 69 finished ! Training Loss: 0.2695

Epoch 70 finished ! Training Loss: 0.2518

------- 1st valloss=0.1255



  1%|▏         | 70/4999 [13:14:37<1080:30:44, 789.18s/it]

Checkpoint 70 saved !


  1%|▏         | 71/4999 [13:25:29<1024:01:29, 748.07s/it]

Epoch 71 finished ! Training Loss: 0.2686



  1%|▏         | 72/4999 [13:36:37<990:56:24, 724.05s/it] 

Epoch 72 finished ! Training Loss: 0.2544



  1%|▏         | 73/4999 [13:47:30<961:29:53, 702.68s/it]

Epoch 73 finished ! Training Loss: 0.2690



  1%|▏         | 74/4999 [13:58:31<944:08:18, 690.13s/it]

Epoch 74 finished ! Training Loss: 0.2578

Epoch 75 finished ! Training Loss: 0.2689

------- 1st valloss=0.5986



  2%|▏         | 75/4999 [14:10:14<949:13:25, 693.99s/it]

Checkpoint 75 saved !


  2%|▏         | 76/4999 [14:21:05<931:31:29, 681.19s/it]

Epoch 76 finished ! Training Loss: 0.2611



  2%|▏         | 77/4999 [14:31:49<916:10:43, 670.10s/it]

Epoch 77 finished ! Training Loss: 0.2737



  2%|▏         | 78/4999 [14:42:53<913:10:39, 668.04s/it]

Epoch 78 finished ! Training Loss: 0.2563



  2%|▏         | 79/4999 [14:53:47<907:22:04, 663.93s/it]

Epoch 79 finished ! Training Loss: 0.2341

Epoch 80 finished ! Training Loss: 0.2583

------- 1st valloss=0.1646



  2%|▏         | 80/4999 [15:05:16<917:39:16, 671.59s/it]

Checkpoint 80 saved !


  2%|▏         | 81/4999 [15:16:18<913:23:47, 668.61s/it]

Epoch 81 finished ! Training Loss: 0.2708



  2%|▏         | 82/4999 [15:27:25<912:37:20, 668.18s/it]

Epoch 82 finished ! Training Loss: 0.2769



  2%|▏         | 83/4999 [15:38:26<909:21:40, 665.93s/it]

Epoch 83 finished ! Training Loss: 0.2707



  2%|▏         | 84/4999 [15:49:17<903:11:20, 661.54s/it]

Epoch 84 finished ! Training Loss: 0.2398

Epoch 85 finished ! Training Loss: 0.2543

------- 1st valloss=0.2289



  2%|▏         | 85/4999 [16:01:02<920:56:31, 674.68s/it]

Checkpoint 85 saved !


  2%|▏         | 86/4999 [16:12:08<916:49:30, 671.80s/it]

Epoch 86 finished ! Training Loss: 0.2609



  2%|▏         | 87/4999 [16:23:10<912:37:35, 668.86s/it]

Epoch 87 finished ! Training Loss: 0.2760



  2%|▏         | 88/4999 [16:33:55<902:51:10, 661.83s/it]

Epoch 88 finished ! Training Loss: 0.2673



  2%|▏         | 89/4999 [16:44:53<901:15:05, 660.80s/it]

Epoch 89 finished ! Training Loss: 0.2226

Epoch 90 finished ! Training Loss: 0.2656

------- 1st valloss=0.1274



  2%|▏         | 90/4999 [16:56:38<919:08:18, 674.05s/it]

Checkpoint 90 saved !


  2%|▏         | 91/4999 [17:07:30<909:53:22, 667.40s/it]

Epoch 91 finished ! Training Loss: 0.2676



  2%|▏         | 92/4999 [17:18:19<902:08:23, 661.85s/it]

Epoch 92 finished ! Training Loss: 0.2511



  2%|▏         | 93/4999 [17:29:18<900:32:27, 660.81s/it]

Epoch 93 finished ! Training Loss: 0.2468



  2%|▏         | 94/4999 [17:40:11<897:09:33, 658.47s/it]

Epoch 94 finished ! Training Loss: 0.2721

Epoch 95 finished ! Training Loss: 0.2513

------- 1st valloss=0.1516



  2%|▏         | 95/4999 [17:51:35<907:34:55, 666.25s/it]

Checkpoint 95 saved !


  2%|▏         | 96/4999 [18:02:49<910:35:07, 668.59s/it]

Epoch 96 finished ! Training Loss: 0.2665



  2%|▏         | 97/4999 [18:13:34<900:56:11, 661.64s/it]

Epoch 97 finished ! Training Loss: 0.2514



  2%|▏         | 98/4999 [18:24:38<901:24:14, 662.12s/it]

Epoch 98 finished ! Training Loss: 0.2602



  2%|▏         | 99/4999 [18:35:39<901:05:28, 662.03s/it]

Epoch 99 finished ! Training Loss: 0.2465

Epoch 100 finished ! Training Loss: 0.2377

------- 1st valloss=0.1743



  2%|▏         | 100/4999 [18:47:21<917:11:07, 673.99s/it]

Checkpoint 100 saved !


In [None]:
# Per-class evaluation of hard (max-binarised) predictions on the validation set.
deeplab.eval()  # eval mode: dropout off, batchnorm uses running statistics

with torch.no_grad():
    bgloss = 0
    bdloss = 0
    bvloss = 0
    n_batches = 0

    for vbatch in tqdm(validation_loader):
        # Move data to device/dtype and, as in the training-time validation, add a
        # leading dimension to tensors for which get_dimensions reports 4.
        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        if get_dimensions(image_1) == 4:
            image_1.unsqueeze_(0)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)
        if get_dimensions(label_1) == 4:
            label_1.unsqueeze_(0)

        output = deeplab(image_1)
        # Binarise on-device: 1 where a channel's score equals the max over dim 1,
        # else 0. Ties keep every maximal channel, matching the previous numpy
        # comparison, without a GPU -> CPU -> GPU round trip per batch.
        out_1 = (output == output.max(dim=1, keepdim=True)[0]).to(device=device, dtype=dtype)

        loss_1 = dice_loss_3(out_1, label_1)
        # Three debug terms, labelled background / body / bv in the summary below.
        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        tqdm.write('{} {} {} {}'.format(bg.item(), bd.item(), bv.item(), loss_1.item()))
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError('validation loader yielded no batches')

    outstr = ('------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'
              .format(bgloss / n_batches, bdloss / n_batches, bvloss / n_batches) + '\n')
    print(outstr)