In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 2 

dtype = torch.float32 
# define dtype, float is space efficient than double

if USE_GPU and torch.cuda.is_available():
    
    device = torch.device('cuda')
    
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    # magic flag that accelerate
    
    print('using GPU for training')
else:
    device = torch.device('cpu')
    print('using CPU for training')

using GPU for training


In [3]:
train_dataset = pyramid_dataset(data_type = 'nii_train', 
                transform=transforms.Compose([
                random_affine(90, 15),
                random_filp(0.5), 
                transforms.RandomApply([CropAndPad()])
                ]))
# do data augumentation on train dataset

validation_dataset = pyramid_dataset(data_type = 'nii_test', 
                transform=None)
# no data augumentation on validation dataset

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS) # drop_last
# loaders come with auto batch division and multi-thread acceleration

In [4]:
deeplab = DeepLab(output_stride=2)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)
deeplab = deeplab.to(device=device, dtype=dtype)
#shape_test(icnet1, True)
# create the model, by default model type is float, use model.double(), model.float() to convert
# move the model to desirable device

optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1)
epoch = 0

# create an optimizer object
# note that only the model_2 params and model_4 params will be optimized by optimizer

In [None]:
"""
deeplab = DeepLab(output_stride=8)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)

#checkpoint = torch.load('../deeplab_save/2019-07-29 04:00:14.630172.pth') # second best
#checkpoint = torch.load('../deeplab_save/2019-07-28 23:47:36.279119.pth') # second best
#checkpoint = torch.load('../deeplab_save/2019-07-29 00:15:49.271222.pth') # best
#checkpoint = torch.load('../deeplab_save/2019-07-29 00:44:11.825872.pth')
checkpoint = torch.load('../deeplab_save/2019-07-31 20:34:01.096131.pth') # latest one

deeplab.load_state_dict(checkpoint['state_dict_1'])
optimizer.load_state_dict(checkpoint['optimizer'])
#scheduler.load_state_dict(checkpoint['scheduler'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
scheduler.load_state_dict(checkpoint['scheduler'])
epoch = checkpoint['epoch']
print(epoch)
"""

"\ndeeplab = DeepLab(output_stride=8)\ndeeplab = nn.DataParallel(deeplab)\ndeeplab = convert_model(deeplab)\n\noptimizer = optim.Adam(deeplab.parameters(), lr=1e-2)\n\n#checkpoint = torch.load('../deeplab_save/2019-07-29 04:00:14.630172.pth') # second best\n#checkpoint = torch.load('../deeplab_save/2019-07-28 23:47:36.279119.pth') # second best\n#checkpoint = torch.load('../deeplab_save/2019-07-29 00:15:49.271222.pth') # best\n#checkpoint = torch.load('../deeplab_save/2019-07-29 00:44:11.825872.pth')\ncheckpoint = torch.load('../deeplab_save/2019-07-31 20:34:01.096131.pth') # latest one\n\ndeeplab.load_state_dict(checkpoint['state_dict_1'])\noptimizer.load_state_dict(checkpoint['optimizer'])\n#scheduler.load_state_dict(checkpoint['scheduler'])\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)\nscheduler.load_state_dict(checkpoint['scheduler'])\nepoch = checkpoint['epoch']\nprint(epoch)\n"

In [None]:
epochs = 5000

min_val = 1

record = open('train_deeplab_output_2_more_aug.txt','w+')

logger = {'train':[], 'validation_1': []}

for e in tqdm(range(epoch + 1, epochs)):
# iter over epoches

    epoch_loss = 0
        
    for t, batch in enumerate(train_loader):
    # iter over the train mini batches
    
        deeplab.train()
        # Set the model flag to train
        # 1. enable dropout
        # 2. batchnorm behave differently in train and test
        
        image_1 = batch['image1_data'].to(device=device, dtype=dtype)
        label_1 = batch['image1_label'].to(device=device, dtype=dtype)
        # move data to device, convert dtype to desirable dtype
        
        out_1 = deeplab(image_1)
        # do the inference

        loss_1 = dice_loss_3(out_1, label_1)
        # calculate loss
        
        epoch_loss += loss_1.item()
        # record minibatch loss to epoch loss
        
        optimizer.zero_grad()
        # set the model parameter gradient to zero
        
        loss_1.backward()
        # calculate the gradient wrt loss
        optimizer.step()
        #scheduler.step(loss_1)
        # take a gradient descent step
        
    outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, epoch_loss/(t+1)) + '\n'
    
    logger['train'].append(epoch_loss/(t+1))
    
    print(outstr)
    record.write(outstr)
    record.flush()

    if e%5 == 0:
    # do validation every 5 epoches
    
        deeplab.eval()
        # set model flag to eval
        # 1. disable dropout
        # 2. batchnorm behave differs

        with torch.no_grad():
        # stop taking gradient
        
            #valloss_4 = 0
            #valloss_2 = 0
            valloss_1 = 0
            
            for v, vbatch in enumerate(validation_loader):
            # iter over validation mini batches
                
                image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                if get_dimensions(image_1_val) == 4:
                    image_1_val.unsqueeze_(0)
                label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                if get_dimensions(label_1_val) == 4:
                    label_1_val.unsqueeze_(0)
                # move data to device, convert dtype to desirable dtype
                # add one dimension to labels if they are 4D tensors
                
                out_1_val = deeplab(image_1_val)
                # do the inference
                
                loss_1 = dice_loss_3(out_1_val, label_1_val)
                # calculate loss

                valloss_1 += loss_1.item()
                # record mini batch loss
            
            avg_val_loss = (valloss_1 / (v+1))
            outstr = '------- 1st valloss={0:.4f}'\
                .format(avg_val_loss) + '\n'
            
            logger['validation_1'].append(avg_val_loss)
            #scheduler.step(avg_val_loss)
            
            print(outstr)
            record.write(outstr)
            record.flush()
            
            if avg_val_loss < min_val:
                print(avg_val_loss, "less than", min_val)
                min_val = avg_val_loss
                
            save_1('deeplab_output_2_more_aug_save', deeplab, optimizer, logger, e, scheduler)

record.close()

  0%|          | 1/4999 [13:29<1124:25:59, 809.92s/it]

Epoch 1 finished ! Training Loss: 0.5205



  0%|          | 2/4999 [25:30<1086:55:02, 783.05s/it]

Epoch 2 finished ! Training Loss: 0.4898



  0%|          | 3/4999 [37:37<1063:22:54, 766.25s/it]

Epoch 3 finished ! Training Loss: 0.4841



  0%|          | 4/4999 [49:44<1047:03:24, 754.64s/it]

Epoch 4 finished ! Training Loss: 0.4798

Epoch 5 finished ! Training Loss: 0.4736

------- 1st valloss=0.5134

0.5133609421875166 less than 1


  0%|          | 5/4999 [1:02:55<1061:55:16, 765.50s/it]

Checkpoint 5 saved !


  0%|          | 6/4999 [1:15:23<1054:29:09, 760.29s/it]

Epoch 6 finished ! Training Loss: 0.4670



  0%|          | 7/4999 [1:27:41<1044:44:40, 753.42s/it]

Epoch 7 finished ! Training Loss: 0.4666



  0%|          | 8/4999 [1:39:53<1035:52:47, 747.18s/it]

Epoch 8 finished ! Training Loss: 0.4467



  0%|          | 9/4999 [1:51:51<1023:28:19, 738.38s/it]

Epoch 9 finished ! Training Loss: 0.4312

Epoch 10 finished ! Training Loss: 0.4329

------- 1st valloss=0.5511



  0%|          | 10/4999 [2:04:50<1040:01:03, 750.46s/it]

Checkpoint 10 saved !


  0%|          | 11/4999 [2:16:56<1029:50:27, 743.27s/it]

Epoch 11 finished ! Training Loss: 0.4333



  0%|          | 12/4999 [2:29:05<1023:23:55, 738.77s/it]

Epoch 12 finished ! Training Loss: 0.4113



  0%|          | 13/4999 [2:41:23<1023:00:37, 738.64s/it]

Epoch 13 finished ! Training Loss: 0.4182



  0%|          | 14/4999 [2:53:34<1019:46:40, 736.45s/it]

Epoch 14 finished ! Training Loss: 0.3993

Epoch 15 finished ! Training Loss: 0.3934

------- 1st valloss=0.4220

0.4219803265903307 less than 0.5133609421875166


  0%|          | 15/4999 [3:06:27<1034:36:39, 747.31s/it]

Checkpoint 15 saved !


  0%|          | 16/4999 [3:18:43<1029:40:30, 743.90s/it]

Epoch 16 finished ! Training Loss: 0.3997



  0%|          | 17/4999 [3:30:55<1024:26:33, 740.26s/it]

Epoch 17 finished ! Training Loss: 0.4137



  0%|          | 18/4999 [3:43:04<1019:46:29, 737.04s/it]

Epoch 18 finished ! Training Loss: 0.3741



  0%|          | 19/4999 [3:55:13<1016:06:48, 734.54s/it]

Epoch 19 finished ! Training Loss: 0.3819

Epoch 20 finished ! Training Loss: 0.4061

------- 1st valloss=0.3101

0.3100691968980043 less than 0.4219803265903307


  0%|          | 20/4999 [4:08:18<1036:54:25, 749.72s/it]

Checkpoint 20 saved !


  0%|          | 21/4999 [4:20:37<1032:25:38, 746.63s/it]

Epoch 21 finished ! Training Loss: 0.3849



  0%|          | 22/4999 [4:32:53<1027:27:15, 743.19s/it]

Epoch 22 finished ! Training Loss: 0.3787



  0%|          | 23/4999 [4:45:01<1021:00:40, 738.67s/it]

Epoch 23 finished ! Training Loss: 0.3855



  0%|          | 24/4999 [4:57:18<1020:09:32, 738.21s/it]

Epoch 24 finished ! Training Loss: 0.3760

Epoch 25 finished ! Training Loss: 0.3670

------- 1st valloss=0.2449

0.2449473395295765 less than 0.3100691968980043


  1%|          | 25/4999 [5:10:05<1031:55:48, 746.87s/it]

Checkpoint 25 saved !


  1%|          | 26/4999 [5:22:12<1023:30:57, 740.93s/it]

Epoch 26 finished ! Training Loss: 0.3784



  1%|          | 27/4999 [5:34:24<1019:26:00, 738.13s/it]

Epoch 27 finished ! Training Loss: 0.3813



  1%|          | 28/4999 [5:46:28<1013:31:17, 733.99s/it]

Epoch 28 finished ! Training Loss: 0.3704



  1%|          | 29/4999 [5:58:35<1010:22:05, 731.86s/it]

Epoch 29 finished ! Training Loss: 0.3556

Epoch 30 finished ! Training Loss: 0.3583

------- 1st valloss=0.4064



  1%|          | 30/4999 [6:11:38<1031:23:13, 747.23s/it]

Checkpoint 30 saved !


  1%|          | 31/4999 [6:23:45<1022:47:18, 741.15s/it]

Epoch 31 finished ! Training Loss: 0.3810



  1%|          | 32/4999 [6:35:45<1013:48:46, 734.79s/it]

Epoch 32 finished ! Training Loss: 0.3650



  1%|          | 33/4999 [6:48:11<1018:20:38, 738.23s/it]

Epoch 33 finished ! Training Loss: 0.3809



  1%|          | 34/4999 [7:00:32<1019:08:01, 738.95s/it]

Epoch 34 finished ! Training Loss: 0.3688

Epoch 35 finished ! Training Loss: 0.3601

------- 1st valloss=0.4721



  1%|          | 35/4999 [7:13:29<1034:36:07, 750.32s/it]

Checkpoint 35 saved !


  1%|          | 36/4999 [7:25:36<1024:49:15, 743.37s/it]

Epoch 36 finished ! Training Loss: 0.3619



  1%|          | 37/4999 [7:37:49<1020:33:09, 740.43s/it]

Epoch 37 finished ! Training Loss: 0.3762



  1%|          | 38/4999 [7:50:07<1019:20:01, 739.69s/it]

Epoch 38 finished ! Training Loss: 0.3839



  1%|          | 39/4999 [8:02:32<1021:14:37, 741.23s/it]

Epoch 39 finished ! Training Loss: 0.3730

Epoch 40 finished ! Training Loss: 0.3766

------- 1st valloss=0.4467



  1%|          | 40/4999 [8:15:21<1032:35:58, 749.62s/it]

Checkpoint 40 saved !


  1%|          | 41/4999 [8:27:36<1026:13:39, 745.14s/it]

Epoch 41 finished ! Training Loss: 0.3795



  1%|          | 42/4999 [8:39:48<1020:34:14, 741.19s/it]

Epoch 42 finished ! Training Loss: 0.3622



  1%|          | 43/4999 [8:52:00<1016:34:24, 738.43s/it]

Epoch 43 finished ! Training Loss: 0.3567



  1%|          | 44/4999 [9:04:25<1019:17:36, 740.56s/it]

Epoch 44 finished ! Training Loss: 0.3480

Epoch 45 finished ! Training Loss: 0.3420

------- 1st valloss=0.1941

0.1941262373457784 less than 0.2449473395295765


  1%|          | 45/4999 [9:17:34<1038:56:46, 754.99s/it]

Checkpoint 45 saved !


  1%|          | 46/4999 [9:29:34<1024:22:20, 744.55s/it]

Epoch 46 finished ! Training Loss: 0.3611



  1%|          | 47/4999 [9:41:52<1021:24:59, 742.55s/it]

Epoch 47 finished ! Training Loss: 0.3439



  1%|          | 48/4999 [9:54:12<1019:54:20, 741.60s/it]

Epoch 48 finished ! Training Loss: 0.3452



  1%|          | 49/4999 [10:06:20<1014:07:17, 737.54s/it]

Epoch 49 finished ! Training Loss: 0.3667

Epoch 50 finished ! Training Loss: 0.3531

------- 1st valloss=0.4745



  1%|          | 50/4999 [10:19:12<1028:10:04, 747.91s/it]

Checkpoint 50 saved !


  1%|          | 51/4999 [10:31:10<1015:39:32, 738.96s/it]

Epoch 51 finished ! Training Loss: 0.3471



  1%|          | 52/4999 [10:43:09<1007:24:41, 733.11s/it]

Epoch 52 finished ! Training Loss: 0.3507



  1%|          | 53/4999 [10:55:32<1011:02:47, 735.90s/it]

Epoch 53 finished ! Training Loss: 0.3445



  1%|          | 54/4999 [11:07:42<1008:26:22, 734.15s/it]

Epoch 54 finished ! Training Loss: 0.3527

Epoch 55 finished ! Training Loss: 0.3276

------- 1st valloss=0.2350



  1%|          | 55/4999 [11:20:27<1021:10:35, 743.58s/it]

Checkpoint 55 saved !


  1%|          | 56/4999 [11:32:36<1014:47:50, 739.08s/it]

Epoch 56 finished ! Training Loss: 0.3403



  1%|          | 57/4999 [11:44:38<1007:36:51, 734.00s/it]

Epoch 57 finished ! Training Loss: 0.3454



  1%|          | 58/4999 [11:56:54<1008:20:36, 734.68s/it]

Epoch 58 finished ! Training Loss: 0.3629



  1%|          | 59/4999 [12:09:16<1010:56:24, 736.72s/it]

Epoch 59 finished ! Training Loss: 0.3547

Epoch 60 finished ! Training Loss: 0.3431

------- 1st valloss=0.6136



  1%|          | 60/4999 [12:22:14<1027:49:30, 749.17s/it]

Checkpoint 60 saved !


  1%|          | 61/4999 [12:34:36<1024:30:43, 746.91s/it]

Epoch 61 finished ! Training Loss: 0.3258



  1%|          | 62/4999 [12:46:51<1019:21:28, 743.30s/it]

Epoch 62 finished ! Training Loss: 0.3509



  1%|▏         | 63/4999 [12:59:00<1013:18:49, 739.05s/it]

Epoch 63 finished ! Training Loss: 0.3418



  1%|▏         | 64/4999 [13:11:10<1009:22:08, 736.32s/it]

Epoch 64 finished ! Training Loss: 0.3126

Epoch 65 finished ! Training Loss: 0.3572

------- 1st valloss=0.4824



  1%|▏         | 65/4999 [13:24:05<1025:24:47, 748.17s/it]

Checkpoint 65 saved !


  1%|▏         | 66/4999 [13:36:13<1016:54:16, 742.12s/it]

Epoch 66 finished ! Training Loss: 0.3458



  1%|▏         | 67/4999 [13:48:18<1009:29:52, 736.86s/it]

Epoch 67 finished ! Training Loss: 0.3382



In [None]:
deeplab.eval()

with torch.no_grad():
    
    bgloss = 0
    bdloss = 0
    bvloss = 0
    
    for v, vbatch in tqdm(enumerate(validation_loader)):
            # move data to device, convert dtype to desirable dtype

        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        output = deeplab(image_1)
        # do the inference
        output_numpy = output.cpu().numpy()
        
        
        #out_1 = torch.round(output)
        out_1 = torch.from_numpy((output_numpy == output_numpy.max(axis=1)[:, None]).astype(int)).to(device=device, dtype=dtype)
        loss_1 = dice_loss_3(out_1, label_1)

        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        # calculate loss
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()

    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/(v+1), bdloss/(v+1), bvloss/(v+1)) + '\n'
    print(outstr)