In [1]:
import sys
import warnings

# Silence every warning for this session unless the interpreter was started
# with explicit -W options, in which case the user's choice is respected.
if not sys.warnoptions:
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Run-wide settings: whether to try the GPU, DataLoader worker count, mini-batch size.
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 2

# Single precision everywhere: float32 takes half the memory of float64.
dtype = torch.float32

# Use CUDA only when it was requested AND is actually available; otherwise fall back to CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    # Enable cuDNN and its benchmark mode (speed-up flags for GPU runs).
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    print('using GPU for training')
else:
    print('using CPU for training')

using GPU for training


In [3]:
# Training set: augmented on the fly with a random affine transform followed by
# a random flip.
# NOTE(review): the meaning of the arguments (90, 15) and 0.5 is defined by the
# project's random_affine / random_filp helpers -- presumably rotation range,
# translation and flip probability; confirm against their definitions.
train_dataset = pyramid_dataset(
    data_type='nii_train',
    transform=transforms.Compose([
        random_affine(90, 15),
        random_filp(0.5),
    ]),
)

# Validation set: no augmentation, samples are used as stored.
validation_dataset = pyramid_dataset(data_type='nii_test', transform=None)

# The loaders batch the samples and load them with NUM_WORKERS worker processes.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Validation is NOT shuffled: the reported validation loss is a mean of
# per-batch losses, so a fixed batch composition keeps the number comparable
# from one epoch to the next (it drives best-checkpoint selection).
validation_loader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)  # drop_last

In [None]:
# Disabled cell: everything below is wrapped in a bare triple-quoted string, so
# none of it executes. It holds the from-scratch setup (new DeepLab_ELU wrapped
# in DataParallel + convert_model, Adam at lr=1e-2, ReduceLROnPlateau with
# patience=25, epoch = 0) as the alternative to resuming from a checkpoint.
# NOTE(review): the remark inside the string about "model_2 params and model_4
# params" does not match the single-model optimizer built above it -- it looks
# stale; confirm before re-enabling this cell.
'''
deeplab = DeepLab_ELU(output_stride=2)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)
deeplab = deeplab.to(device=device, dtype=dtype)
#shape_test(icnet1, True)
# create the model, by default model type is float, use model.double(), model.float() to convert
# move the model to desirable device

optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=25)
epoch = 0

# create an optimizer object
# note that only the model_2 params and model_4 params will be optimized by optimizer
'''

In [None]:

# Resume training: rebuild the model exactly as it was wrapped when saved
# (DataParallel + convert_model) so the checkpoint's state-dict keys line up,
# then restore model, optimizer, scheduler and epoch counter.
deeplab = DeepLab_ELU(output_stride=2)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

# Earlier checkpoints, kept for reference:
#checkpoint = torch.load('../deeplab_save/2019-07-29 04:00:14.630172.pth') # second best
#checkpoint = torch.load('../deeplab_save/2019-07-28 23:47:36.279119.pth') # second best
#checkpoint = torch.load('../deeplab_save/2019-07-29 00:15:49.271222.pth') # best
#checkpoint = torch.load('../deeplab_save/2019-07-29 00:44:11.825872.pth')

# map_location remaps the stored tensors onto the device chosen for this run,
# so a checkpoint written on a GPU still loads when only the CPU is available.
checkpoint = torch.load(
    '../deeplab_output_2_elu_save/2019-08-21 20:02:58.787116 epoch: 450.pth',  # latest one
    map_location=device,
)

deeplab.load_state_dict(checkpoint['state_dict_1'])
deeplab = deeplab.to(device, dtype)

# The optimizer is built after the model is on its final device; the lr given
# here is only a placeholder because the saved optimizer state is loaded on top.
optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)
optimizer.load_state_dict(checkpoint['optimizer'])

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
scheduler.load_state_dict(checkpoint['scheduler'])

# Epoch at which the checkpoint was written; training continues from epoch + 1.
epoch = checkpoint['epoch']
print(epoch)

# Show the learning rate(s) actually in effect after restoring the optimizer.
for param_group in optimizer.param_groups:
    print(param_group['lr'])


"\ndeeplab = DeepLab_ELU(output_stride=2)\ndeeplab = nn.DataParallel(deeplab)\ndeeplab = convert_model(deeplab)\n\n#checkpoint = torch.load('../deeplab_save/2019-07-29 04:00:14.630172.pth') # second best\n#checkpoint = torch.load('../deeplab_save/2019-07-28 23:47:36.279119.pth') # second best\n#checkpoint = torch.load('../deeplab_save/2019-07-29 00:15:49.271222.pth') # best\n#checkpoint = torch.load('../deeplab_save/2019-07-29 00:44:11.825872.pth')\ncheckpoint = torch.load('../deeplab_output_2_2_save/2019-08-11 15:21:04.416434 epoch: 425.pth') # latest one\n\ndeeplab.load_state_dict(checkpoint['state_dict_1'])\ndeeplab = deeplab.to(device, dtype)\n\noptimizer = optim.Adam(deeplab.parameters(), lr=1e-4)\n#optimizer.load_state_dict(checkpoint['optimizer'])\n#scheduler.load_state_dict(checkpoint['scheduler'])\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)\nscheduler.load_state_dict(checkpoint['scheduler'])\n\nepoch = checkpoint['epoch']\np

In [None]:
epochs = 5000

# A validation loss has to drop below this value before a "new best" checkpoint
# is written; it is lowered every time a better validation loss is observed.
min_val = .07

# Per-epoch history of the mean training loss and the mean validation loss.
logger = {'train': [], 'validation_1': []}


def _log_line(log_file, message):
    """Echo `message` to stdout and append it to `log_file`, flushing at once."""
    print(message)
    log_file.write(message)
    log_file.flush()


# The context manager guarantees the log file is closed even when the run is
# interrupted (KeyboardInterrupt) or an exception escapes the loop.
with open('train_deeplab_output_2_elu.txt', 'a+') as record:

    for e in tqdm(range(epoch + 1, epochs)):
        # ---- training pass over all mini-batches -------------------------

        # Train mode once per epoch (validation below switches to eval mode):
        # enables dropout and makes batchnorm use batch statistics.
        deeplab.train()

        epoch_loss = 0

        for batch in train_loader:
            # Move the batch to the target device / dtype.
            image_1 = batch['image1_data'].to(device=device, dtype=dtype)
            label_1 = batch['image1_label'].to(device=device, dtype=dtype)

            # Forward pass and loss.
            out_1 = deeplab(image_1)
            loss_1 = dice_loss_3(out_1, label_1)

            # Accumulate the mini-batch loss into the epoch total.
            epoch_loss += loss_1.item()

            # Clear old gradients, back-propagate, take one optimizer step.
            optimizer.zero_grad()
            loss_1.backward()
            optimizer.step()
            #scheduler.step(loss_1)

        # Mean loss per mini-batch for this epoch.
        train_loss = epoch_loss / len(train_loader)
        logger['train'].append(train_loss)
        _log_line(record, 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, train_loss) + '\n')

        # ---- validation ---------------------------------------------------
        # Validate every 5th epoch up to epoch 150, and every epoch after that.
        if e > 150 or e % 5 == 0:

            # Eval mode: disables dropout, batchnorm uses running statistics.
            deeplab.eval()

            # No gradients are needed while validating.
            with torch.no_grad():

                valloss_1 = 0

                for vbatch in validation_loader:
                    # Move data to the device; when get_dimensions reports 4
                    # (presumably the tensor rank -- helper defined elsewhere),
                    # prepend one dimension in place.
                    image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                    if get_dimensions(image_1_val) == 4:
                        image_1_val.unsqueeze_(0)
                    label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                    if get_dimensions(label_1_val) == 4:
                        label_1_val.unsqueeze_(0)

                    # Forward pass and loss for this validation mini-batch.
                    out_1_val = deeplab(image_1_val)
                    valloss_1 += dice_loss_3(out_1_val, label_1_val).item()

                # Mean loss per validation mini-batch.
                avg_val_loss = valloss_1 / len(validation_loader)
                logger['validation_1'].append(avg_val_loss)
                # NOTE(review): the scheduler is never stepped in this loop, so
                # it cannot lower the learning rate here; it is only passed on
                # to save_1. Re-enable the line below if LR decay is wanted.
                #scheduler.step(avg_val_loss)

                _log_line(record, '------- 1st valloss={0:.4f}'.format(avg_val_loss) + '\n')

                # Save on a new best validation loss, and otherwise on every
                # validated epoch that is a multiple of 10.
                is_new_best = avg_val_loss < min_val
                if is_new_best:
                    print(avg_val_loss, "less than", min_val)
                    min_val = avg_val_loss
                if is_new_best or e % 10 == 0:
                    save_1('deeplab_output_2_elu_save', deeplab, optimizer, logger, e, scheduler)

  0%|          | 1/4999 [13:16<1106:26:06, 796.95s/it]

Epoch 1 finished ! Training Loss: 0.5091



  0%|          | 2/4999 [25:20<1075:30:28, 774.83s/it]

Epoch 2 finished ! Training Loss: 0.4702



  0%|          | 3/4999 [37:32<1057:40:01, 762.13s/it]

Epoch 3 finished ! Training Loss: 0.4685



  0%|          | 4/4999 [49:29<1038:24:08, 748.40s/it]

Epoch 4 finished ! Training Loss: 0.4465

Epoch 5 finished ! Training Loss: 0.4335

------- 1st valloss=0.4162

0.4161752488302148 less than 1


  0%|          | 5/4999 [1:02:09<1043:23:26, 752.14s/it]

Checkpoint 5 saved !


  0%|          | 6/4999 [1:14:10<1030:14:16, 742.81s/it]

Epoch 6 finished ! Training Loss: 0.4024



  0%|          | 7/4999 [1:26:27<1027:26:01, 740.94s/it]

Epoch 7 finished ! Training Loss: 0.3736



  0%|          | 8/4999 [1:38:36<1022:07:25, 737.26s/it]

Epoch 8 finished ! Training Loss: 0.3502



  0%|          | 9/4999 [1:50:41<1016:49:12, 733.58s/it]

Epoch 9 finished ! Training Loss: 0.3453

Epoch 10 finished ! Training Loss: 0.3331

------- 1st valloss=0.2949

0.2949115452559098 less than 0.4161752488302148


  0%|          | 10/4999 [2:03:31<1031:47:50, 744.53s/it]

Checkpoint 10 saved !


  0%|          | 11/4999 [2:15:34<1022:38:29, 738.07s/it]

Epoch 11 finished ! Training Loss: 0.3034



  0%|          | 12/4999 [2:27:48<1020:45:55, 736.87s/it]

Epoch 12 finished ! Training Loss: 0.2821



  0%|          | 13/4999 [2:39:53<1015:34:45, 733.27s/it]

Epoch 13 finished ! Training Loss: 0.2846



  0%|          | 14/4999 [2:52:06<1015:14:37, 733.18s/it]

Epoch 14 finished ! Training Loss: 0.2613

Epoch 15 finished ! Training Loss: 0.2595



  0%|          | 15/4999 [3:05:03<1033:30:09, 746.51s/it]

------- 1st valloss=0.3447



  0%|          | 16/4999 [3:17:04<1022:34:11, 738.76s/it]

Epoch 16 finished ! Training Loss: 0.2583



  0%|          | 17/4999 [3:28:58<1012:05:18, 731.34s/it]

Epoch 17 finished ! Training Loss: 0.2540



  0%|          | 18/4999 [3:41:05<1010:03:37, 730.02s/it]

Epoch 18 finished ! Training Loss: 0.2326



  0%|          | 19/4999 [3:53:20<1011:53:54, 731.49s/it]

Epoch 19 finished ! Training Loss: 0.2360

Epoch 20 finished ! Training Loss: 0.2277

------- 1st valloss=0.1931

0.1931494662295217 less than 0.2949115452559098


  0%|          | 20/4999 [4:06:06<1026:02:38, 741.87s/it]

Checkpoint 20 saved !


  0%|          | 21/4999 [4:18:18<1021:41:41, 738.87s/it]

Epoch 21 finished ! Training Loss: 0.2187



  0%|          | 22/4999 [4:30:31<1019:11:44, 737.21s/it]

Epoch 22 finished ! Training Loss: 0.2286



  0%|          | 23/4999 [4:42:29<1010:58:21, 731.41s/it]

Epoch 23 finished ! Training Loss: 0.2339



  0%|          | 24/4999 [4:54:41<1011:10:06, 731.70s/it]

Epoch 24 finished ! Training Loss: 0.2206

Epoch 25 finished ! Training Loss: 0.2048

------- 1st valloss=0.1482

0.14815867821807446 less than 0.1931494662295217


  1%|          | 25/4999 [5:07:32<1027:07:20, 743.39s/it]

Checkpoint 25 saved !


  1%|          | 26/4999 [5:19:41<1020:53:13, 739.03s/it]

Epoch 26 finished ! Training Loss: 0.1975



  1%|          | 27/4999 [5:31:55<1018:47:16, 737.66s/it]

Epoch 27 finished ! Training Loss: 0.2027



  1%|          | 28/4999 [5:43:58<1012:09:11, 733.00s/it]

Epoch 28 finished ! Training Loss: 0.2149



  1%|          | 29/4999 [5:56:13<1013:07:58, 733.86s/it]

Epoch 29 finished ! Training Loss: 0.2160

Epoch 30 finished ! Training Loss: 0.2063

------- 1st valloss=0.1977



  1%|          | 30/4999 [6:09:03<1027:32:20, 744.44s/it]

Checkpoint 30 saved !


  1%|          | 31/4999 [6:21:08<1019:19:43, 738.64s/it]

Epoch 31 finished ! Training Loss: 0.1924



  1%|          | 32/4999 [6:33:17<1015:18:45, 735.88s/it]

Epoch 32 finished ! Training Loss: 0.2009



  1%|          | 33/4999 [6:45:30<1013:50:05, 734.96s/it]

Epoch 33 finished ! Training Loss: 0.1998



  1%|          | 34/4999 [6:57:34<1008:59:27, 731.59s/it]

Epoch 34 finished ! Training Loss: 0.1809

Epoch 35 finished ! Training Loss: 0.1956

------- 1st valloss=0.1208

0.12078349130309146 less than 0.14815867821807446


  1%|          | 35/4999 [7:10:21<1023:36:05, 742.34s/it]

Checkpoint 35 saved !


  1%|          | 36/4999 [7:22:15<1011:45:34, 733.90s/it]

Epoch 36 finished ! Training Loss: 0.1803



  1%|          | 37/4999 [7:34:43<1017:16:08, 738.04s/it]

Epoch 37 finished ! Training Loss: 0.1864



  1%|          | 38/4999 [7:46:50<1012:24:19, 734.66s/it]

Epoch 38 finished ! Training Loss: 0.1794



  1%|          | 39/4999 [7:59:08<1013:38:26, 735.71s/it]

Epoch 39 finished ! Training Loss: 0.1710

Epoch 40 finished ! Training Loss: 0.1825

------- 1st valloss=0.1143

0.1142981311549311 less than 0.12078349130309146


  1%|          | 40/4999 [8:12:05<1030:41:38, 748.24s/it]

Checkpoint 40 saved !


  1%|          | 41/4999 [8:24:09<1020:24:20, 740.92s/it]

Epoch 41 finished ! Training Loss: 0.1708



  1%|          | 42/4999 [8:36:29<1019:34:59, 740.47s/it]

Epoch 42 finished ! Training Loss: 0.1867



  1%|          | 43/4999 [8:48:39<1015:02:07, 737.31s/it]

Epoch 43 finished ! Training Loss: 0.1742



  1%|          | 44/4999 [9:00:52<1013:17:33, 736.20s/it]

Epoch 44 finished ! Training Loss: 0.1759

Epoch 45 finished ! Training Loss: 0.1708



  1%|          | 45/4999 [9:13:40<1026:10:27, 745.71s/it]

------- 1st valloss=0.2242



  1%|          | 46/4999 [9:25:40<1015:10:01, 737.86s/it]

Epoch 46 finished ! Training Loss: 0.1650



  1%|          | 47/4999 [9:37:48<1011:04:00, 735.02s/it]

Epoch 47 finished ! Training Loss: 0.1689



  1%|          | 48/4999 [9:49:54<1007:14:12, 732.39s/it]

Epoch 48 finished ! Training Loss: 0.1594



  1%|          | 49/4999 [10:01:51<1000:36:13, 727.71s/it]

Epoch 49 finished ! Training Loss: 0.1551

Epoch 50 finished ! Training Loss: 0.1662

------- 1st valloss=0.1132

0.11321386921664943 less than 0.1142981311549311


  1%|          | 50/4999 [10:15:01<1026:00:45, 746.34s/it]

Checkpoint 50 saved !


  1%|          | 51/4999 [10:27:11<1019:00:13, 741.39s/it]

Epoch 51 finished ! Training Loss: 0.1529



  1%|          | 52/4999 [10:39:14<1011:29:14, 736.07s/it]

Epoch 52 finished ! Training Loss: 0.1646



  1%|          | 53/4999 [10:51:26<1009:16:33, 734.61s/it]

Epoch 53 finished ! Training Loss: 0.1592



  1%|          | 54/4999 [11:03:34<1006:27:41, 732.71s/it]

Epoch 54 finished ! Training Loss: 0.1476

Epoch 55 finished ! Training Loss: 0.1437



  1%|          | 55/4999 [11:16:25<1022:11:31, 744.31s/it]

------- 1st valloss=0.1783



  1%|          | 56/4999 [11:28:34<1015:31:29, 739.61s/it]

Epoch 56 finished ! Training Loss: 0.1472



  1%|          | 57/4999 [11:40:39<1009:25:49, 735.32s/it]

Epoch 57 finished ! Training Loss: 0.1397



  1%|          | 58/4999 [11:52:40<1003:23:25, 731.07s/it]

Epoch 58 finished ! Training Loss: 0.1357



  1%|          | 59/4999 [12:04:50<1002:45:23, 730.75s/it]

Epoch 59 finished ! Training Loss: 0.1388

Epoch 60 finished ! Training Loss: 0.1327

------- 1st valloss=0.1144



  1%|          | 60/4999 [12:17:45<1020:37:13, 743.92s/it]

Checkpoint 60 saved !


  1%|          | 61/4999 [12:29:43<1009:39:08, 736.08s/it]

Epoch 61 finished ! Training Loss: 0.1322



  1%|          | 62/4999 [12:41:43<1003:08:19, 731.48s/it]

Epoch 62 finished ! Training Loss: 0.1321



  1%|▏         | 63/4999 [12:53:39<996:27:21, 726.75s/it] 

Epoch 63 finished ! Training Loss: 0.1316



  1%|▏         | 64/4999 [13:05:38<992:56:32, 724.33s/it]

Epoch 64 finished ! Training Loss: 0.1279

Epoch 65 finished ! Training Loss: 0.1233



  1%|▏         | 65/4999 [13:18:34<1014:05:22, 739.91s/it]

------- 1st valloss=0.2276



  1%|▏         | 66/4999 [13:30:52<1013:04:10, 739.32s/it]

Epoch 66 finished ! Training Loss: 0.1228



  1%|▏         | 67/4999 [13:43:00<1008:08:34, 735.87s/it]

Epoch 67 finished ! Training Loss: 0.1162



  1%|▏         | 68/4999 [13:55:04<1003:18:18, 732.49s/it]

Epoch 68 finished ! Training Loss: 0.1108



  1%|▏         | 69/4999 [14:07:16<1002:42:19, 732.20s/it]

Epoch 69 finished ! Training Loss: 0.1262

Epoch 70 finished ! Training Loss: 0.1339

------- 1st valloss=0.1665



  1%|▏         | 70/4999 [14:20:09<1019:17:27, 744.46s/it]

Checkpoint 70 saved !


  1%|▏         | 71/4999 [14:32:16<1011:41:20, 739.06s/it]

Epoch 71 finished ! Training Loss: 0.1339



  1%|▏         | 72/4999 [14:44:12<1002:16:14, 732.33s/it]

Epoch 72 finished ! Training Loss: 0.1219



  1%|▏         | 73/4999 [14:56:23<1001:20:31, 731.80s/it]

Epoch 73 finished ! Training Loss: 0.1174



  1%|▏         | 74/4999 [15:08:42<1004:24:45, 734.19s/it]

Epoch 74 finished ! Training Loss: 0.1137

Epoch 75 finished ! Training Loss: 0.1202



  2%|▏         | 75/4999 [15:21:26<1016:11:30, 742.95s/it]

------- 1st valloss=0.1690



  2%|▏         | 76/4999 [15:33:29<1007:51:49, 737.01s/it]

Epoch 76 finished ! Training Loss: 0.1128



  2%|▏         | 77/4999 [15:45:24<998:43:08, 730.47s/it] 

Epoch 77 finished ! Training Loss: 0.1130



  2%|▏         | 78/4999 [15:57:21<992:52:15, 726.34s/it]

Epoch 78 finished ! Training Loss: 0.1149



  2%|▏         | 79/4999 [16:09:25<991:52:25, 725.76s/it]

Epoch 79 finished ! Training Loss: 0.1149

Epoch 80 finished ! Training Loss: 0.1132

------- 1st valloss=0.1947



  2%|▏         | 80/4999 [16:22:08<1006:43:21, 736.78s/it]

Checkpoint 80 saved !


  2%|▏         | 81/4999 [16:34:29<1008:24:45, 738.16s/it]

Epoch 81 finished ! Training Loss: 0.1029



  2%|▏         | 82/4999 [16:46:30<1001:15:57, 733.08s/it]

Epoch 82 finished ! Training Loss: 0.1066



  2%|▏         | 83/4999 [16:58:43<1000:55:32, 732.98s/it]

Epoch 83 finished ! Training Loss: 0.1125



  2%|▏         | 84/4999 [17:10:58<1001:16:55, 733.39s/it]

Epoch 84 finished ! Training Loss: 0.1018

Epoch 85 finished ! Training Loss: 0.1110

------- 1st valloss=0.1042

0.10415547195336093 less than 0.11321386921664943


  2%|▏         | 85/4999 [17:23:39<1012:40:19, 741.88s/it]

Checkpoint 85 saved !


  2%|▏         | 86/4999 [17:35:45<1005:50:34, 737.03s/it]

Epoch 86 finished ! Training Loss: 0.1062



  2%|▏         | 87/4999 [17:47:52<1001:35:18, 734.06s/it]

Epoch 87 finished ! Training Loss: 0.1079



  2%|▏         | 88/4999 [17:59:57<997:41:37, 731.36s/it] 

Epoch 88 finished ! Training Loss: 0.1009



  2%|▏         | 89/4999 [18:12:05<996:02:42, 730.30s/it]

Epoch 89 finished ! Training Loss: 0.0980

Epoch 90 finished ! Training Loss: 0.1021

------- 1st valloss=0.0935

0.09353989897214848 less than 0.10415547195336093


  2%|▏         | 90/4999 [18:24:51<1010:26:39, 741.01s/it]

Checkpoint 90 saved !


  2%|▏         | 91/4999 [18:36:57<1004:10:21, 736.56s/it]

Epoch 91 finished ! Training Loss: 0.1075



  2%|▏         | 92/4999 [18:48:56<996:37:45, 731.17s/it] 

Epoch 92 finished ! Training Loss: 0.0971



  2%|▏         | 93/4999 [19:01:03<994:47:03, 729.97s/it]

Epoch 93 finished ! Training Loss: 0.1048



  2%|▏         | 94/4999 [19:13:09<992:51:32, 728.70s/it]

Epoch 94 finished ! Training Loss: 0.1012

Epoch 95 finished ! Training Loss: 0.1004



  2%|▏         | 95/4999 [19:26:01<1010:25:51, 741.75s/it]

------- 1st valloss=0.1760



  2%|▏         | 96/4999 [19:38:06<1003:24:42, 736.75s/it]

Epoch 96 finished ! Training Loss: 0.1092



In [None]:
# Evaluate the current model on the validation set with hard (one-hot)
# predictions and report the mean of each component returned by
# dice_loss_3_debug (labelled background / body / bv in the summary line).
deeplab.eval()

with torch.no_grad():

    bgloss = 0
    bdloss = 0
    bvloss = 0

    # Wrapping the loader itself (not enumerate(...)) lets tqdm see its length
    # and show a real progress bar with a total.
    for vbatch in tqdm(validation_loader):
        # Move data to the device, convert to the working dtype.
        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        # Forward pass.
        output = deeplab(image_1)

        #out_1 = torch.round(output)
        # Hard prediction: 1 where a position equals the maximum over dim 1,
        # 0 elsewhere (ties mark every maximal entry). Done in torch on the
        # current device instead of a CPU/NumPy round trip.
        channel_max = output.max(dim=1, keepdim=True)[0]
        out_1 = (output == channel_max).to(device=device, dtype=dtype)

        loss_1 = dice_loss_3(out_1, label_1)
        bg, bd, bv = dice_loss_3_debug(out_1, label_1)

        # Per-batch values: the three components, then the combined loss.
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()

    # Average each component over the number of validation mini-batches.
    num_batches = len(validation_loader)
    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/num_batches, bdloss/num_batches, bvloss/num_batches) + '\n'
    print(outstr)