In [1]:
import sys
# Silence all Python warnings unless the interpreter was started with explicit -W options.
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# NOTE(review): these project star imports are the (invisible) source of names used below,
# e.g. pyramid_dataset, random_affine, random_filp, dice_loss_3, dice_loss_3_debug, save_1,
# get_dimensions, show_image_slice, DeepLab and presumably `nn` - confirm which module exports each.
from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
# convert_model swaps BatchNorm layers for synchronized (cross-GPU) BatchNorm.
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
# Run configuration.
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 2

# float32 is used throughout: half the memory of float64 (double).
dtype = torch.float32

# Pick the compute device: CPU unless a GPU is both requested and actually present.
if not (USE_GPU and torch.cuda.is_available()):
    device = torch.device('cpu')
    print('using CPU for training')
else:
    device = torch.device('cuda')
    # cuDNN switches; benchmark mode autotunes convolution algorithms,
    # which the author found speeds up training.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    print('using GPU for training')

using GPU for training


In [3]:
# Training set with data augmentation (project-defined random_affine / random_filp transforms).
train_dataset = pyramid_dataset(data_type = 'nii_train',
                transform=transforms.Compose([
                random_affine(90, 15),
                random_filp(0.5)]))

# Validation set: no data augmentation, so the loss is measured on unmodified inputs.
validation_dataset = pyramid_dataset(data_type = 'nii_test',
                transform=None)

# DataLoaders provide batching and multi-process loading (NUM_WORKERS workers).
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS)
# Validation is NOT shuffled. Shuffling only randomizes batch composition, which:
# - perturbs the per-batch-averaged validation loss when the last batch is partial;
# - makes the per-batch prints/visualisations in the debug cell irreproducible.
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS) # drop_last
# loaders come with auto batch division and multi-thread acceleration

In [None]:
# Disabled from-scratch initialisation: the whole cell is a bare string literal, so nothing
# inside it executes. It appears to be the alternative to the checkpoint-resume cell. It builds
# the same `deeplab`, `optimizer`, `scheduler` and `epoch` names, but with lr=1e-2 and epoch = 0.
# To use it, turn the string back into code and skip the resume cell.
# NOTE(review): the trailing remark about "model_2 / model_4 params" inside the string looks
# stale; only `deeplab.parameters()` are handed to the optimizer here.
'''
deeplab = DeepLab(output_stride=4)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)
deeplab = deeplab.to(device=device, dtype=dtype)
#shape_test(icnet1, True)
# create the model, by default model type is float, use model.double(), model.float() to convert
# move the model to desirable device

optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1)
epoch = 0

# create an optimizer object
# note that only the model_2 params and model_4 params will be optimized by optimizer
'''

In [None]:

# Resume training from a saved checkpoint.
# The checkpoint dict is read for the keys 'state_dict_1', 'optimizer', 'scheduler' and 'epoch'.
CHECKPOINT_PATH = '../deeplab_output_4_4_save/2019-08-10 22:17:19.117394 epoch: 170.pth'  # latest one

# Rebuild the exact wrapper structure the weights were saved from (DataParallel + sync BN),
# so the state_dict keys match.
deeplab = DeepLab(output_stride=4)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

# map_location remaps stored tensors onto the current device. Without it, a checkpoint
# written from GPU tensors cannot be loaded when the CPU fallback was selected above.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

deeplab.load_state_dict(checkpoint['state_dict_1'])
# Move the model before constructing the optimizer (torch.optim recommendation).
deeplab = deeplab.to(device, dtype)

# Rebuild the optimizer and LR scheduler, then restore their saved states.
optimizer = optim.Adam(deeplab.parameters())
optimizer.load_state_dict(checkpoint['optimizer'])

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1)
scheduler.load_state_dict(checkpoint['scheduler'])

# Last completed epoch; the training loop continues from epoch + 1.
epoch = checkpoint['epoch']
print(epoch)

# Show the restored learning rate(s); ReduceLROnPlateau may have decayed them.
for param_group in optimizer.param_groups:
    print(param_group['lr'])
    

"\ndeeplab = DeepLab(output_stride=4)\ndeeplab = nn.DataParallel(deeplab)\ndeeplab = convert_model(deeplab)\n\noptimizer = optim.Adam(deeplab.parameters())\n\ncheckpoint = torch.load('../deeplab_output_4_save/2019-08-05 12:43:35.164261 epoch: 165.pth') # latest one\n\ndeeplab.load_state_dict(checkpoint['state_dict_1'])\noptimizer.load_state_dict(checkpoint['optimizer'])\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1)\n#scheduler.load_state_dict(checkpoint['scheduler'])\nepoch = checkpoint['epoch']\nprint(epoch)\n#print(checkpoint['logger']['validation_1'])\n"

In [None]:
epochs = 5000

# Best validation loss seen so far. +inf (rather than 1) guarantees the first validation
# registers even if dice_loss_3 can exceed 1; its range is not visible here.
min_val = float('inf')

logger = {'train': [], 'validation_1': []}

# `with` guarantees the log file is flushed and closed even if training is interrupted
# (KeyboardInterrupt in the notebook) or raises.
with open('train_deeplab_output_4_4.txt', 'a+') as record:

    for e in tqdm(range(epoch + 1, epochs)):
        # iterate over epochs

        # Set the model to train mode once per epoch (enables dropout; batchnorm uses batch
        # statistics). It must be re-applied every epoch because validation switches to eval().
        deeplab.train()

        epoch_loss = 0
        n_train_batches = 0

        for batch in train_loader:
            # move data to device, convert dtype to desirable dtype
            image_1 = batch['image1_data'].to(device=device, dtype=dtype)
            label_1 = batch['image1_label'].to(device=device, dtype=dtype)

            out_1 = deeplab(image_1)
            loss_1 = dice_loss_3(out_1, label_1)

            # accumulate the mini-batch loss into the epoch loss
            epoch_loss += loss_1.item()
            n_train_batches += 1

            # standard step: clear grads, backprop, update
            optimizer.zero_grad()
            loss_1.backward()
            optimizer.step()

        # Explicit batch count instead of the leaked enumerate index: no NameError or stale
        # value if the loader yields nothing.
        avg_train_loss = epoch_loss / n_train_batches if n_train_batches else float('nan')
        outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, avg_train_loss) + '\n'

        logger['train'].append(avg_train_loss)

        print(outstr)
        record.write(outstr)
        record.flush()

        if e % 5 == 0:
            # validate (and checkpoint) every 5 epochs

            # eval mode: disables dropout; batchnorm uses running statistics
            deeplab.eval()

            with torch.no_grad():
                valloss_1 = 0
                n_val_batches = 0

                for vbatch in validation_loader:
                    # move data to device, convert dtype to desirable dtype
                    image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                    # NOTE(review): prepends a leading dim to 4-D tensors. If the intent is to
                    # add a missing channel axis to a batched tensor, dim 1 may be correct - confirm.
                    if get_dimensions(image_1_val) == 4:
                        image_1_val.unsqueeze_(0)
                    label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                    if get_dimensions(label_1_val) == 4:
                        label_1_val.unsqueeze_(0)

                    out_1_val = deeplab(image_1_val)
                    loss_1 = dice_loss_3(out_1_val, label_1_val)

                    valloss_1 += loss_1.item()
                    n_val_batches += 1

                avg_val_loss = valloss_1 / n_val_batches if n_val_batches else float('nan')
                outstr = '------- 1st valloss={0:.4f}'\
                    .format(avg_val_loss) + '\n'

                logger['validation_1'].append(avg_val_loss)
                # ReduceLROnPlateau decays the LR when validation loss stops improving.
                scheduler.step(avg_val_loss)

                print(outstr)
                record.write(outstr)
                record.flush()

                if avg_val_loss < min_val:
                    print(avg_val_loss, "less than", min_val)
                    min_val = avg_val_loss

            # A checkpoint is written at every validation step, not only on improvement.
            save_1('deeplab_output_4_4_save', deeplab, optimizer, logger, e, scheduler)

  0%|          | 1/4999 [19:22<1614:30:35, 1162.91s/it]

Epoch 1 finished ! Training Loss: 0.5005



  0%|          | 2/4999 [37:11<1574:53:35, 1134.60s/it]

Epoch 2 finished ! Training Loss: 0.4424



  0%|          | 3/4999 [55:06<1549:56:02, 1116.85s/it]

Epoch 3 finished ! Training Loss: 0.4160



  0%|          | 4/4999 [1:12:43<1524:33:52, 1098.79s/it]

Epoch 4 finished ! Training Loss: 0.3793

------- 1st valloss=0.5548

0.5548094161178755 less than 1


  0%|          | 5/4999 [1:31:38<1539:07:35, 1109.50s/it]

Checkpoint 5 saved !


  0%|          | 6/4999 [1:49:25<1521:30:36, 1097.02s/it]

Epoch 6 finished ! Training Loss: 0.3317



  0%|          | 7/4999 [2:06:58<1502:32:55, 1083.57s/it]

Epoch 7 finished ! Training Loss: 0.3138



  0%|          | 8/4999 [2:24:36<1491:45:55, 1076.01s/it]

Epoch 8 finished ! Training Loss: 0.2871



  0%|          | 9/4999 [2:42:08<1481:36:21, 1068.89s/it]

Epoch 9 finished ! Training Loss: 0.2868

Epoch 10 finished ! Training Loss: 0.2665

------- 1st valloss=0.2157

0.21569309027298636 less than 0.5548094161178755


  0%|          | 10/4999 [3:00:32<1495:51:21, 1079.39s/it]

Checkpoint 10 saved !


  0%|          | 11/4999 [3:17:51<1478:38:09, 1067.18s/it]

Epoch 11 finished ! Training Loss: 0.2700



  0%|          | 12/4999 [3:35:08<1465:50:59, 1058.16s/it]

Epoch 12 finished ! Training Loss: 0.2581



  0%|          | 13/4999 [3:52:39<1462:37:17, 1056.04s/it]

Epoch 13 finished ! Training Loss: 0.2459



  0%|          | 14/4999 [4:10:00<1456:07:07, 1051.56s/it]

Epoch 14 finished ! Training Loss: 0.2348

Epoch 15 finished ! Training Loss: 0.2448

------- 1st valloss=0.2446



  0%|          | 15/4999 [4:28:05<1469:27:21, 1061.40s/it]

Checkpoint 15 saved !


  0%|          | 16/4999 [4:45:23<1459:37:58, 1054.52s/it]

Epoch 16 finished ! Training Loss: 0.2323



  0%|          | 17/4999 [5:03:08<1463:44:27, 1057.70s/it]

Epoch 17 finished ! Training Loss: 0.2328



  0%|          | 18/4999 [5:20:58<1468:32:36, 1061.38s/it]

Epoch 18 finished ! Training Loss: 0.2118



  0%|          | 19/4999 [5:38:38<1467:49:41, 1061.08s/it]

Epoch 19 finished ! Training Loss: 0.2186

Epoch 20 finished ! Training Loss: 0.2184

------- 1st valloss=0.1657

0.16568971295719562 less than 0.21569309027298636


  0%|          | 20/4999 [5:56:33<1472:56:41, 1064.99s/it]

Checkpoint 20 saved !


  0%|          | 21/4999 [6:13:55<1463:21:54, 1058.28s/it]

Epoch 21 finished ! Training Loss: 0.2109



  0%|          | 22/4999 [6:31:44<1467:18:33, 1061.34s/it]

Epoch 22 finished ! Training Loss: 0.2086



  0%|          | 23/4999 [6:49:34<1470:38:12, 1063.97s/it]

Epoch 23 finished ! Training Loss: 0.2049



  0%|          | 24/4999 [7:07:19<1470:49:42, 1064.32s/it]

Epoch 24 finished ! Training Loss: 0.2057

Epoch 25 finished ! Training Loss: 0.1957

------- 1st valloss=0.1537

0.1536602096065231 less than 0.16568971295719562


  1%|          | 25/4999 [7:25:21<1477:48:47, 1069.59s/it]

Checkpoint 25 saved !


  1%|          | 26/4999 [7:43:02<1473:53:37, 1066.97s/it]

Epoch 26 finished ! Training Loss: 0.1952



  1%|          | 27/4999 [8:00:13<1459:01:35, 1056.42s/it]

Epoch 27 finished ! Training Loss: 0.1954



  1%|          | 28/4999 [8:17:29<1450:17:37, 1050.30s/it]

Epoch 28 finished ! Training Loss: 0.2169



  1%|          | 29/4999 [8:34:42<1442:35:42, 1044.94s/it]

Epoch 29 finished ! Training Loss: 0.1888

Epoch 30 finished ! Training Loss: 0.1949

------- 1st valloss=0.1340

0.13402820475723431 less than 0.1536602096065231


  1%|          | 30/4999 [8:52:49<1459:47:30, 1057.61s/it]

Checkpoint 30 saved !


  1%|          | 31/4999 [9:10:08<1451:36:45, 1051.89s/it]

Epoch 31 finished ! Training Loss: 0.1782



  1%|          | 32/4999 [9:27:58<1459:09:22, 1057.57s/it]

Epoch 32 finished ! Training Loss: 0.1978



  1%|          | 33/4999 [9:45:38<1459:39:58, 1058.16s/it]

Epoch 33 finished ! Training Loss: 0.2001



  1%|          | 34/4999 [10:03:27<1463:59:57, 1061.51s/it]

Epoch 34 finished ! Training Loss: 0.1911

Epoch 35 finished ! Training Loss: 0.1735

------- 1st valloss=0.1984



  1%|          | 35/4999 [10:21:34<1474:00:59, 1068.99s/it]

Checkpoint 35 saved !


  1%|          | 36/4999 [10:39:01<1464:32:14, 1062.33s/it]

Epoch 36 finished ! Training Loss: 0.1730



  1%|          | 37/4999 [10:56:13<1451:58:44, 1053.43s/it]

Epoch 37 finished ! Training Loss: 0.1895



  1%|          | 38/4999 [11:13:33<1445:55:26, 1049.25s/it]

Epoch 38 finished ! Training Loss: 0.1762



  1%|          | 39/4999 [11:30:42<1437:33:10, 1043.39s/it]

Epoch 39 finished ! Training Loss: 0.1875

Epoch 40 finished ! Training Loss: 0.1876

------- 1st valloss=0.2254



  1%|          | 40/4999 [11:49:07<1462:31:48, 1061.73s/it]

Checkpoint 40 saved !


  1%|          | 41/4999 [12:06:52<1463:34:23, 1062.70s/it]

Epoch 41 finished ! Training Loss: 0.1848



  1%|          | 42/4999 [12:24:04<1450:35:54, 1053.49s/it]

Epoch 42 finished ! Training Loss: 0.1769



  1%|          | 43/4999 [12:41:24<1444:49:14, 1049.51s/it]

Epoch 43 finished ! Training Loss: 0.1747



  1%|          | 44/4999 [12:59:11<1451:46:38, 1054.77s/it]

Epoch 44 finished ! Training Loss: 0.1601

Epoch 45 finished ! Training Loss: 0.1816

------- 1st valloss=0.6033



  1%|          | 45/4999 [13:17:43<1474:52:17, 1071.77s/it]

Checkpoint 45 saved !


  1%|          | 46/4999 [13:35:12<1465:29:12, 1065.16s/it]

Epoch 46 finished ! Training Loss: 0.1795



  1%|          | 47/4999 [13:52:50<1462:00:25, 1062.85s/it]

Epoch 47 finished ! Training Loss: 0.1586



  1%|          | 48/4999 [14:10:13<1453:44:31, 1057.05s/it]

Epoch 48 finished ! Training Loss: 0.1639



  1%|          | 49/4999 [14:27:41<1449:23:39, 1054.10s/it]

Epoch 49 finished ! Training Loss: 0.1697

Epoch 50 finished ! Training Loss: 0.1527

------- 1st valloss=0.1317

0.13167721810548202 less than 0.13402820475723431


  1%|          | 50/4999 [14:46:06<1470:28:30, 1069.65s/it]

Checkpoint 50 saved !


  1%|          | 51/4999 [15:03:40<1463:33:07, 1064.83s/it]

Epoch 51 finished ! Training Loss: 0.1675



  1%|          | 52/4999 [15:21:29<1465:02:44, 1066.13s/it]

Epoch 52 finished ! Training Loss: 0.1742



In [None]:
# Debug pass over the validation set:
# - hard-assign each voxel to its highest-scoring output channel;
# - show image / label / prediction slices;
# - report the per-component losses from dice_loss_3_debug.
deeplab.eval()

with torch.no_grad():

    bgloss = 0
    bdloss = 0
    bvloss = 0
    n_batches = 0

    # tqdm over the loader itself so the progress bar knows the total batch count
    for vbatch in tqdm(validation_loader):
        # move data to device, convert dtype to desirable dtype
        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        # do the inference
        output = deeplab(image_1)

        # One-hot of the max over dim 1 (assumed to be the class/channel axis - TODO confirm).
        # It is computed on-device instead of round-tripping through numpy. The semantics match
        # the numpy version: every channel tied for the max is set to 1.
        out_1 = (output == output.max(dim=1, keepdim=True).values).to(device=device, dtype=dtype)
        loss_1 = dice_loss_3(out_1, label_1)

        show_image_slice(image_1)
        show_image_slice(label_1)
        show_image_slice(out_1)

        # calculate the per-component losses (background / body / bv, per the variable names)
        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()
        n_batches += 1

    # explicit count avoids relying on a leaked loop index (NameError on an empty loader)
    denom = n_batches if n_batches else float('nan')
    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/denom, bdloss/denom, bvloss/denom) + '\n'
    print(outstr)

# 