In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 2 

dtype = torch.float32 
# define dtype, float is space efficient than double

if USE_GPU and torch.cuda.is_available():
    
    device = torch.device('cuda')
    
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    # magic flag that accelerate
    
    print('using GPU for training')
else:
    device = torch.device('cpu')
    print('using CPU for training')

using GPU for training


In [3]:
train_dataset = pyramid_dataset(data_type = 'nii_train', 
                transform=transforms.Compose([
                random_affine(90, 15),
                random_filp(0.5)]))
# do data augumentation on train dataset

validation_dataset = pyramid_dataset(data_type = 'nii_test', 
                transform=None)
# no data augumentation on validation dataset

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS) # drop_last
# loaders come with auto batch division and multi-thread acceleration

In [4]:
'''
deeplab = DeepLab(output_stride=4)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)
deeplab = deeplab.to(device=device, dtype=dtype)
#shape_test(icnet1, True)
# create the model, by default model type is float, use model.double(), model.float() to convert
# move the model to desirable device

optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1)
epoch = 0

# create an optimizer object
# note that only the model_2 params and model_4 params will be optimized by optimizer
'''

"\ndeeplab = DeepLab(output_stride=4)\ndeeplab = nn.DataParallel(deeplab)\ndeeplab = convert_model(deeplab)\ndeeplab = deeplab.to(device=device, dtype=dtype)\n#shape_test(icnet1, True)\n# create the model, by default model type is float, use model.double(), model.float() to convert\n# move the model to desirable device\n\noptimizer = optim.Adam(deeplab.parameters(), lr=1e-2)\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1)\nepoch = 0\n\n# create an optimizer object\n# note that only the model_2 params and model_4 params will be optimized by optimizer\n"

In [5]:

deeplab = DeepLab(output_stride=4)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

checkpoint = torch.load('../deeplab_output_4_4_save/2019-08-10 22:17:19.117394 epoch: 170.pth') # latest one

deeplab.load_state_dict(checkpoint['state_dict_1'])
deeplab = deeplab.to(device, dtype)

optimizer = optim.Adam(deeplab.parameters())
optimizer.load_state_dict(checkpoint['optimizer'])

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1)
scheduler.load_state_dict(checkpoint['scheduler'])

epoch = checkpoint['epoch']
print(epoch)

for param_group in optimizer.param_groups:
    print(param_group['lr'])
    

170
0.01


In [6]:
epochs = 5000

min_val = 1

record = open('train_deeplab_output_4_4.txt','a+')

logger = {'train':[], 'validation_1': []}

for e in tqdm(range(epoch + 1, epochs)):
# iter over epoches

    epoch_loss = 0
        
    for t, batch in enumerate(train_loader):
    # iter over the train mini batches
    
        deeplab.train()
        # Set the model flag to train
        # 1. enable dropout
        # 2. batchnorm behave differently in train and test
        
        image_1 = batch['image1_data'].to(device=device, dtype=dtype)
        label_1 = batch['image1_label'].to(device=device, dtype=dtype)
        # move data to device, convert dtype to desirable dtype
        
        out_1 = deeplab(image_1)
        # do the inference

        loss_1 = dice_loss_3(out_1, label_1)
        # calculate loss
        
        epoch_loss += loss_1.item()
        # record minibatch loss to epoch loss
        
        optimizer.zero_grad()
        # set the model parameter gradient to zero
        
        loss_1.backward()
        # calculate the gradient wrt loss
        optimizer.step()
        #scheduler.step(loss_1)
        # take a gradient descent step
        
    outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, epoch_loss/(t+1)) + '\n'
    
    logger['train'].append(epoch_loss/(t+1))
    
    print(outstr)
    record.write(outstr)
    record.flush()

    if e%5 == 0:
    # do validation every 5 epoches
    
        deeplab.eval()
        # set model flag to eval
        # 1. disable dropout
        # 2. batchnorm behave differs

        with torch.no_grad():
        # stop taking gradient
        
            #valloss_4 = 0
            #valloss_2 = 0
            valloss_1 = 0
            
            for v, vbatch in enumerate(validation_loader):
            # iter over validation mini batches
                
                image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                if get_dimensions(image_1_val) == 4:
                    image_1_val.unsqueeze_(0)
                label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                if get_dimensions(label_1_val) == 4:
                    label_1_val.unsqueeze_(0)
                # move data to device, convert dtype to desirable dtype
                # add one dimension to labels if they are 4D tensors
                
                out_1_val = deeplab(image_1_val)
                # do the inference
                
                loss_1 = dice_loss_3(out_1_val, label_1_val)
                # calculate loss

                valloss_1 += loss_1.item()
                # record mini batch loss
            
            avg_val_loss = (valloss_1 / (v+1))
            outstr = '------- 1st valloss={0:.4f}'\
                .format(avg_val_loss) + '\n'
            
            logger['validation_1'].append(avg_val_loss)
            scheduler.step(avg_val_loss)
            
            print(outstr)
            record.write(outstr)
            record.flush()
            
            if avg_val_loss < min_val:
                print(avg_val_loss, "less than", min_val)
                min_val = avg_val_loss
            
            save_1('deeplab_output_4_4_save', deeplab, optimizer, logger, e, scheduler)

record.close()

  0%|          | 0/4829 [00:00<?, ?it/s]

Epoch 171 finished ! Training Loss: 0.1173



  0%|          | 2/4829 [36:29<1496:55:12, 1116.41s/it]

Epoch 172 finished ! Training Loss: 0.1326



  0%|          | 3/4829 [53:40<1462:13:02, 1090.75s/it]

Epoch 173 finished ! Training Loss: 0.1104



  0%|          | 4/4829 [1:11:01<1441:48:24, 1075.75s/it]

Epoch 174 finished ! Training Loss: 0.1056

Epoch 175 finished ! Training Loss: 0.1005

------- 1st valloss=0.5074

0.507421050382697 less than 1


  0%|          | 5/4829 [1:29:06<1445:12:29, 1078.51s/it]

Checkpoint 175 saved !


  0%|          | 6/4829 [1:46:26<1429:27:32, 1066.98s/it]

Epoch 176 finished ! Training Loss: 0.0957



  0%|          | 7/4829 [2:03:55<1422:02:45, 1061.67s/it]

Epoch 177 finished ! Training Loss: 0.0904



  0%|          | 8/4829 [2:21:20<1414:57:52, 1056.60s/it]

Epoch 178 finished ! Training Loss: 0.0909



  0%|          | 9/4829 [2:38:56<1414:35:49, 1056.55s/it]

Epoch 179 finished ! Training Loss: 0.0949

Epoch 180 finished ! Training Loss: 0.0872

------- 1st valloss=0.4253

0.42531127903772437 less than 0.507421050382697


  0%|          | 10/4829 [2:57:17<1431:47:14, 1069.61s/it]

Checkpoint 180 saved !


  0%|          | 11/4829 [3:14:39<1420:41:01, 1061.53s/it]

Epoch 181 finished ! Training Loss: 0.0845



  0%|          | 12/4829 [3:31:57<1410:56:12, 1054.47s/it]

Epoch 182 finished ! Training Loss: 0.0858



  0%|          | 13/4829 [3:49:20<1406:07:56, 1051.10s/it]

Epoch 183 finished ! Training Loss: 0.0856



  0%|          | 14/4829 [4:06:46<1403:36:28, 1049.43s/it]

Epoch 184 finished ! Training Loss: 0.0912

Epoch 185 finished ! Training Loss: 0.0918

------- 1st valloss=0.4401



  0%|          | 15/4829 [4:25:04<1422:51:25, 1064.04s/it]

Checkpoint 185 saved !


  0%|          | 16/4829 [4:42:24<1412:41:47, 1056.66s/it]

Epoch 186 finished ! Training Loss: 0.0940



  0%|          | 17/4829 [4:59:56<1410:43:43, 1055.41s/it]

Epoch 187 finished ! Training Loss: 0.0954



  0%|          | 18/4829 [5:17:16<1404:18:06, 1050.82s/it]

Epoch 188 finished ! Training Loss: 0.0944



  0%|          | 19/4829 [5:34:43<1402:20:08, 1049.57s/it]

Epoch 189 finished ! Training Loss: 0.0875

Epoch 190 finished ! Training Loss: 0.0844

------- 1st valloss=0.1058

0.1057870112683462 less than 0.42531127903772437


  0%|          | 20/4829 [5:52:52<1417:55:00, 1061.45s/it]

Checkpoint 190 saved !


  0%|          | 21/4829 [6:10:09<1407:57:41, 1054.21s/it]

Epoch 191 finished ! Training Loss: 0.0817



  0%|          | 22/4829 [6:27:37<1404:53:58, 1052.14s/it]

Epoch 192 finished ! Training Loss: 0.0841



  0%|          | 23/4829 [6:45:20<1409:02:31, 1055.46s/it]

Epoch 193 finished ! Training Loss: 0.0857



  0%|          | 24/4829 [7:02:44<1404:03:26, 1051.95s/it]

Epoch 194 finished ! Training Loss: 0.0880

Epoch 195 finished ! Training Loss: 0.0874

------- 1st valloss=0.2385



  1%|          | 25/4829 [7:21:22<1430:27:43, 1071.95s/it]

Checkpoint 195 saved !


  1%|          | 26/4829 [7:39:08<1427:41:53, 1070.10s/it]

Epoch 196 finished ! Training Loss: 0.0927



  1%|          | 27/4829 [7:56:27<1414:53:32, 1060.73s/it]

Epoch 197 finished ! Training Loss: 0.0811



  1%|          | 28/4829 [8:14:03<1412:47:07, 1059.37s/it]

Epoch 198 finished ! Training Loss: 0.0812



  1%|          | 29/4829 [8:31:31<1407:54:41, 1055.93s/it]

Epoch 199 finished ! Training Loss: 0.0820

Epoch 200 finished ! Training Loss: 0.0787

------- 1st valloss=0.1648



  1%|          | 30/4829 [8:49:53<1426:05:48, 1069.80s/it]

Checkpoint 200 saved !


  1%|          | 31/4829 [9:07:14<1414:09:58, 1061.07s/it]

Epoch 201 finished ! Training Loss: 0.0810



  1%|          | 32/4829 [9:24:28<1402:59:15, 1052.90s/it]

Epoch 202 finished ! Training Loss: 0.0962



  1%|          | 33/4829 [9:41:57<1401:24:26, 1051.93s/it]

Epoch 203 finished ! Training Loss: 0.0850



  1%|          | 34/4829 [9:59:25<1399:29:35, 1050.71s/it]

Epoch 204 finished ! Training Loss: 0.0800

Epoch 205 finished ! Training Loss: 0.0870

------- 1st valloss=0.1007

0.10074582598779512 less than 0.1057870112683462


  1%|          | 35/4829 [10:17:46<1419:08:40, 1065.69s/it]

Checkpoint 205 saved !


  1%|          | 36/4829 [10:35:11<1410:41:20, 1059.56s/it]

Epoch 206 finished ! Training Loss: 0.0845



  1%|          | 37/4829 [10:52:41<1406:31:29, 1056.65s/it]

Epoch 207 finished ! Training Loss: 0.0805



  1%|          | 38/4829 [11:10:19<1406:49:39, 1057.10s/it]

Epoch 208 finished ! Training Loss: 0.0811



  1%|          | 39/4829 [11:27:43<1401:24:46, 1053.25s/it]

Epoch 209 finished ! Training Loss: 0.0847

Epoch 210 finished ! Training Loss: 0.0771

------- 1st valloss=0.1452



  1%|          | 40/4829 [11:45:49<1414:14:19, 1063.12s/it]

Checkpoint 210 saved !


  1%|          | 41/4829 [12:03:21<1409:28:23, 1059.75s/it]

Epoch 211 finished ! Training Loss: 0.0795



  1%|          | 42/4829 [12:20:43<1402:01:54, 1054.38s/it]

Epoch 212 finished ! Training Loss: 0.0796



  1%|          | 43/4829 [12:37:58<1393:55:56, 1048.51s/it]

Epoch 213 finished ! Training Loss: 0.0839



  1%|          | 44/4829 [12:55:35<1397:03:48, 1051.08s/it]

Epoch 214 finished ! Training Loss: 0.0968

Epoch 215 finished ! Training Loss: 0.0855

------- 1st valloss=0.2211



  1%|          | 45/4829 [13:13:54<1415:46:41, 1065.39s/it]

Checkpoint 215 saved !


  1%|          | 46/4829 [13:31:24<1409:28:15, 1060.86s/it]

Epoch 216 finished ! Training Loss: 0.0824



  1%|          | 47/4829 [13:48:35<1397:19:19, 1051.94s/it]

Epoch 217 finished ! Training Loss: 0.0811



  1%|          | 48/4829 [14:06:00<1394:20:52, 1049.92s/it]

Epoch 218 finished ! Training Loss: 0.0801



  1%|          | 49/4829 [14:23:39<1397:24:32, 1052.44s/it]

Epoch 219 finished ! Training Loss: 0.0760

Epoch 220 finished ! Training Loss: 0.0760

------- 1st valloss=0.3448



  1%|          | 50/4829 [14:41:52<1413:09:25, 1064.53s/it]

Checkpoint 220 saved !


  1%|          | 51/4829 [14:59:07<1401:14:26, 1055.77s/it]

Epoch 221 finished ! Training Loss: 0.0771



  1%|          | 52/4829 [15:16:39<1399:22:57, 1054.59s/it]

Epoch 222 finished ! Training Loss: 0.0834



  1%|          | 53/4829 [15:33:58<1393:02:11, 1050.03s/it]

Epoch 223 finished ! Training Loss: 0.0841



  1%|          | 54/4829 [15:51:19<1389:06:06, 1047.28s/it]

Epoch 224 finished ! Training Loss: 0.0817

Epoch 225 finished ! Training Loss: 0.0797

------- 1st valloss=nan



  1%|          | 55/4829 [16:09:46<1412:45:01, 1065.33s/it]

Checkpoint 225 saved !


  1%|          | 56/4829 [16:27:02<1400:47:09, 1056.53s/it]

Epoch 226 finished ! Training Loss: 0.0769



  1%|          | 57/4829 [16:44:23<1394:12:19, 1051.79s/it]

Epoch 227 finished ! Training Loss: 0.0768



  1%|          | 58/4829 [17:01:40<1388:05:21, 1047.39s/it]

Epoch 228 finished ! Training Loss: 0.0815



  1%|          | 59/4829 [17:19:18<1391:48:36, 1050.42s/it]

Epoch 229 finished ! Training Loss: 0.0753

Epoch 230 finished ! Training Loss: 0.0730

------- 1st valloss=nan



  1%|          | 60/4829 [17:37:22<1404:50:55, 1060.49s/it]

Checkpoint 230 saved !


  1%|▏         | 61/4829 [17:54:30<1391:41:48, 1050.78s/it]

Epoch 231 finished ! Training Loss: 0.0774



  1%|▏         | 62/4829 [18:11:40<1383:00:36, 1044.44s/it]

Epoch 232 finished ! Training Loss: 0.0821



  1%|▏         | 63/4829 [18:28:59<1380:47:13, 1042.98s/it]

Epoch 233 finished ! Training Loss: 0.0770



  1%|▏         | 64/4829 [18:46:37<1386:21:17, 1047.40s/it]

Epoch 234 finished ! Training Loss: 0.0782

Epoch 235 finished ! Training Loss: 0.0821

------- 1st valloss=0.0808

0.08083511759405551 less than 0.10074582598779512


  1%|▏         | 65/4829 [19:04:39<1399:57:12, 1057.90s/it]

Checkpoint 235 saved !


  1%|▏         | 66/4829 [19:22:00<1392:56:47, 1052.83s/it]

Epoch 236 finished ! Training Loss: 0.0774



  1%|▏         | 67/4829 [19:39:28<1390:50:40, 1051.46s/it]

Epoch 237 finished ! Training Loss: 0.0773



  1%|▏         | 68/4829 [19:56:39<1382:07:00, 1045.08s/it]

Epoch 238 finished ! Training Loss: 0.0769



  1%|▏         | 69/4829 [20:14:07<1383:06:19, 1046.05s/it]

Epoch 239 finished ! Training Loss: 0.0757

Epoch 240 finished ! Training Loss: 0.0756

------- 1st valloss=nan



  1%|▏         | 70/4829 [20:32:23<1402:50:07, 1061.19s/it]

Checkpoint 240 saved !


  1%|▏         | 71/4829 [20:49:46<1395:16:52, 1055.70s/it]

Epoch 241 finished ! Training Loss: 0.0745



  1%|▏         | 72/4829 [21:07:03<1387:31:59, 1050.06s/it]

Epoch 242 finished ! Training Loss: 0.0744



  2%|▏         | 73/4829 [21:24:28<1385:02:57, 1048.40s/it]

Epoch 243 finished ! Training Loss: 0.0759



  2%|▏         | 74/4829 [21:41:35<1376:26:49, 1042.11s/it]

Epoch 244 finished ! Training Loss: 0.0763

Epoch 245 finished ! Training Loss: 0.0769

------- 1st valloss=nan



  2%|▏         | 75/4829 [21:59:54<1398:42:15, 1059.18s/it]

Checkpoint 245 saved !


  2%|▏         | 76/4829 [22:17:14<1390:36:20, 1053.27s/it]

Epoch 246 finished ! Training Loss: 0.0741



  2%|▏         | 77/4829 [22:34:37<1386:33:05, 1050.42s/it]

Epoch 247 finished ! Training Loss: 0.0715



  2%|▏         | 78/4829 [22:51:58<1382:29:35, 1047.56s/it]

Epoch 248 finished ! Training Loss: 0.0749



  2%|▏         | 79/4829 [23:09:18<1378:54:57, 1045.07s/it]

Epoch 249 finished ! Training Loss: 0.0748

Epoch 250 finished ! Training Loss: 0.1234

------- 1st valloss=0.3312



  2%|▏         | 80/4829 [23:27:17<1392:03:55, 1055.26s/it]

Checkpoint 250 saved !


  2%|▏         | 81/4829 [23:44:56<1393:32:31, 1056.60s/it]

Epoch 251 finished ! Training Loss: 0.1085



  2%|▏         | 82/4829 [24:02:32<1392:58:52, 1056.40s/it]

Epoch 252 finished ! Training Loss: 0.0960



  2%|▏         | 83/4829 [24:19:57<1387:54:05, 1052.77s/it]

Epoch 253 finished ! Training Loss: 0.0850



  2%|▏         | 84/4829 [24:37:40<1391:57:20, 1056.07s/it]

Epoch 254 finished ! Training Loss: 0.0829

Epoch 255 finished ! Training Loss: 0.0858

------- 1st valloss=0.1399



  2%|▏         | 85/4829 [24:56:04<1410:21:18, 1070.25s/it]

Checkpoint 255 saved !


KeyboardInterrupt: 

In [None]:
deeplab.eval()

with torch.no_grad():
    
    bgloss = 0
    bdloss = 0
    bvloss = 0
    
    for v, vbatch in tqdm(enumerate(validation_loader)):
            # move data to device, convert dtype to desirable dtype

        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        output = deeplab(image_1)
        # do the inference
        output_numpy = output.cpu().numpy()
        
        
        #out_1 = torch.round(output)
        out_1 = torch.from_numpy((output_numpy == output_numpy.max(axis=1)[:, None]).astype(int)).to(device=device, dtype=dtype)
        loss_1 = dice_loss_3(out_1, label_1)
        show_image_slice(image_1)
        show_image_slice(label_1)
        show_image_slice(out_1)
        
        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        # calculate loss
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()

    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/(v+1), bdloss/(v+1), bvloss/(v+1)) + '\n'
    print(outstr)

# 