In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from sync_batchnorm import convert_model
import adabound
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 2 

dtype = torch.float32 
# define dtype, float is space efficient than double

if USE_GPU and torch.cuda.is_available():
    
    device = torch.device('cuda')
    
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    # magic flag that accelerate
    
    print('using GPU for training')
else:
    device = torch.device('cpu')
    print('using CPU for training')

using GPU for training


In [3]:
train_dataset = pyramid_dataset(data_type = 'nii_train', 
                transform=transforms.Compose([
                random_affine(90, 15),
                random_filp(0.5)]))
# do data augumentation on train dataset

validation_dataset = pyramid_dataset(data_type = 'nii_test', 
                transform=None)
# no data augumentation on validation dataset

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS) # drop_last
# loaders come with auto batch division and multi-thread acceleration

In [4]:
"""
deeplab = DeepLab(output_stride=16)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)
deeplab = deeplab.to(device=device, dtype=dtype)
#shape_test(icnet1, True)
# create the model, by default model type is float, use model.double(), model.float() to convert
# move the model to desirable device

optimizer = adabound.AdaBound(deeplab.parameters(), lr=1e-3, final_lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
epoch = 0

# create an optimizer object
# note that only the model_2 params and model_4 params will be optimized by optimizer
"""

"\ndeeplab = DeepLab(output_stride=16)\ndeeplab = nn.DataParallel(deeplab)\ndeeplab = convert_model(deeplab)\ndeeplab = deeplab.to(device=device, dtype=dtype)\n#shape_test(icnet1, True)\n# create the model, by default model type is float, use model.double(), model.float() to convert\n# move the model to desirable device\n\noptimizer = adabound.AdaBound(deeplab.parameters(), lr=1e-3, final_lr=0.1)\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)\nepoch = 0\n\n# create an optimizer object\n# note that only the model_2 params and model_4 params will be optimized by optimizer\n"

In [None]:

deeplab = DeepLab(output_stride=16)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

optimizer = adabound.AdaBound(deeplab.parameters(), lr=1e-3, final_lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1)

#checkpoint = torch.load('../deeplab_save/2019-07-29 00:44:11.825872.pth')
#checkpoint = torch.load('../deeplab_dilated_save/2019-08-01 08:57:17.225282.pth') # best one
checkpoint = torch.load('../deeplab_output_16_adabound_save/2019-08-06 11:05:17.805214 epoch: 85.pth') # latest one

deeplab.load_state_dict(checkpoint['state_dict_1'])
deeplab = deeplab.to(device, dtype)

optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])

epoch = checkpoint['epoch']
print(epoch)


85


In [None]:
epochs = 5000

record= open('train_deeplab_output_16_adabound.txt','a+')

logger = {'train':[], 'validation_1': []}

min_val = 1

for e in tqdm(range(epoch + 1, epochs)):
# iter over epoches

    epoch_loss = 0
        
    for t, batch in enumerate(train_loader):
    # iter over the train mini batches
    
        deeplab.train()
        # Set the model flag to train
        # 1. enable dropout
        # 2. batchnorm behave differently in train and test
        
        image_1 = batch['image1_data'].to(device=device, dtype=dtype)
        label_1 = batch['image1_label'].to(device=device, dtype=dtype)
        # move data to device, convert dtype to desirable dtype
        
        out_1 = deeplab(image_1)
        # do the inference

        loss_1 = dice_loss_3(out_1, label_1)
        # calculate loss
        
        epoch_loss += loss_1.item()
        # record minibatch loss to epoch loss
        
        optimizer.zero_grad()
        # set the model parameter gradient to zero
        
        loss_1.backward()
        # calculate the gradient wrt loss
        optimizer.step()
        #scheduler.step(loss_1)
        # take a gradient descent step
        
    outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, epoch_loss/(t+1)) + '\n'
    
    logger['train'].append(epoch_loss/(t+1))
    
    print(outstr)
    record.write(outstr)
    record.flush()

    if e%5 == 0:
    # do validation every 5 epoches
    
        deeplab.eval()
        # set model flag to eval
        # 1. disable dropout
        # 2. batchnorm behave differs

        with torch.no_grad():
        # stop taking gradient
        
            #valloss_4 = 0
            #valloss_2 = 0
            valloss_1 = 0
            
            for v, vbatch in enumerate(validation_loader):
            # iter over validation mini batches
                
                image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                if get_dimensions(image_1_val) == 4:
                    image_1_val.unsqueeze_(0)
                label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                if get_dimensions(label_1_val) == 4:
                    label_1_val.unsqueeze_(0)
                # move data to device, convert dtype to desirable dtype
                # add one dimension to labels if they are 4D tensors
                
                out_1_val = deeplab(image_1_val)
                # do the inference
                
                loss_1 = dice_loss_3(out_1_val, label_1_val)
                # calculate loss

                valloss_1 += loss_1.item()
                # record mini batch loss
            avg_val_loss = (valloss_1/(v+1))
            outstr = '------- 1st valloss={0:.4f}'\
                .format(avg_val_loss) + '\n'
            
            logger['validation_1'].append(avg_val_loss)
            #scheduler.step(avg_val_loss)
            
            print(outstr)
            record.write(outstr)
            record.flush()
            
            if avg_val_loss < min_val:
                print(avg_val_loss, ' <', min_val)
                min_val = avg_val_loss
            
            save_1('deeplab_output_16_adabound_save', deeplab, optimizer, logger, e, scheduler)

record.close()

  0%|          | 1/4914 [12:10<996:33:58, 730.23s/it]

Epoch 86 finished ! Training Loss: 0.2440



  0%|          | 2/4914 [23:26<974:10:16, 713.97s/it]

Epoch 87 finished ! Training Loss: 0.2443



  0%|          | 3/4914 [34:38<956:59:03, 701.52s/it]

Epoch 88 finished ! Training Loss: 0.2354



  0%|          | 4/4914 [45:33<937:47:38, 687.59s/it]

Epoch 89 finished ! Training Loss: 0.2329

Epoch 90 finished ! Training Loss: 0.2238

------- 1st valloss=0.1897

0.18965560521768488  < 1


  0%|          | 5/4914 [57:25<947:18:42, 694.71s/it]

Checkpoint 90 saved !


  0%|          | 6/4914 [1:08:22<931:55:06, 683.56s/it]

Epoch 91 finished ! Training Loss: 0.2274



  0%|          | 7/4914 [1:19:16<919:26:20, 674.54s/it]

Epoch 92 finished ! Training Loss: 0.2261



  0%|          | 8/4914 [1:30:22<915:56:36, 672.12s/it]

Epoch 93 finished ! Training Loss: 0.2178



  0%|          | 9/4914 [1:41:23<911:17:32, 668.84s/it]

Epoch 94 finished ! Training Loss: 0.2290

Epoch 95 finished ! Training Loss: 0.2210

------- 1st valloss=0.2411



  0%|          | 10/4914 [1:53:05<924:27:17, 678.64s/it]

Checkpoint 95 saved !


  0%|          | 11/4914 [2:04:05<916:53:49, 673.23s/it]

Epoch 96 finished ! Training Loss: 0.2174



  0%|          | 12/4914 [2:15:15<915:08:13, 672.07s/it]

Epoch 97 finished ! Training Loss: 0.2188



  0%|          | 13/4914 [2:26:19<911:39:53, 669.66s/it]

Epoch 98 finished ! Training Loss: 0.2200



  0%|          | 14/4914 [2:37:33<913:18:59, 671.01s/it]

Epoch 99 finished ! Training Loss: 0.2144

Epoch 100 finished ! Training Loss: 0.2128

------- 1st valloss=0.1768

0.17684343262859012  < 0.18965560521768488


  0%|          | 15/4914 [2:49:15<925:40:09, 680.22s/it]

Checkpoint 100 saved !


  0%|          | 16/4914 [3:00:07<914:05:06, 671.85s/it]

Epoch 101 finished ! Training Loss: 0.2119



  0%|          | 17/4914 [3:11:17<913:08:18, 671.29s/it]

Epoch 102 finished ! Training Loss: 0.2136



  0%|          | 18/4914 [3:22:27<912:25:24, 670.90s/it]

Epoch 103 finished ! Training Loss: 0.2073



  0%|          | 19/4914 [3:33:44<914:39:30, 672.68s/it]

Epoch 104 finished ! Training Loss: 0.2181

Epoch 105 finished ! Training Loss: 0.2135

------- 1st valloss=0.1757

0.1756950195716775  < 0.17684343262859012


  0%|          | 20/4914 [3:45:27<926:44:48, 681.71s/it]

Checkpoint 105 saved !


  0%|          | 21/4914 [3:56:31<919:28:50, 676.50s/it]

Epoch 106 finished ! Training Loss: 0.2066



  0%|          | 22/4914 [4:07:25<910:16:55, 669.87s/it]

Epoch 107 finished ! Training Loss: 0.2093



  0%|          | 23/4914 [4:18:43<913:19:46, 672.25s/it]

Epoch 108 finished ! Training Loss: 0.2044



  0%|          | 24/4914 [4:29:43<908:10:56, 668.60s/it]

Epoch 109 finished ! Training Loss: 0.2040

Epoch 110 finished ! Training Loss: 0.2049

------- 1st valloss=0.1918



  1%|          | 25/4914 [4:41:34<925:21:53, 681.39s/it]

Checkpoint 110 saved !
Epoch 111 finished ! Training Loss: 0.2026



  1%|          | 27/4914 [5:04:00<919:55:10, 677.66s/it]

Epoch 112 finished ! Training Loss: 0.1996



  1%|          | 28/4914 [5:15:14<918:16:09, 676.58s/it]

Epoch 113 finished ! Training Loss: 0.1978



  1%|          | 29/4914 [5:26:19<913:35:25, 673.27s/it]

Epoch 114 finished ! Training Loss: 0.1966

Epoch 115 finished ! Training Loss: 0.1937

------- 1st valloss=0.1942



  1%|          | 30/4914 [5:38:06<927:11:32, 683.43s/it]

Checkpoint 115 saved !


  1%|          | 31/4914 [5:49:14<920:32:12, 678.67s/it]

Epoch 116 finished ! Training Loss: 0.1896



  1%|          | 32/4914 [6:00:09<910:52:07, 671.68s/it]

Epoch 117 finished ! Training Loss: 0.1988



  1%|          | 33/4914 [6:11:17<909:11:10, 670.57s/it]

Epoch 118 finished ! Training Loss: 0.1940



  1%|          | 34/4914 [6:22:11<902:00:12, 665.41s/it]

Epoch 119 finished ! Training Loss: 0.1999

Epoch 120 finished ! Training Loss: 0.2024

------- 1st valloss=0.1385

0.13849218837592914  < 0.1756950195716775


  1%|          | 35/4914 [6:34:00<919:37:21, 678.55s/it]

Checkpoint 120 saved !


  1%|          | 36/4914 [6:45:04<913:34:03, 674.22s/it]

Epoch 121 finished ! Training Loss: 0.1892



  1%|          | 37/4914 [6:56:09<909:42:56, 671.51s/it]

Epoch 122 finished ! Training Loss: 0.1902



  1%|          | 38/4914 [7:07:30<913:23:38, 674.37s/it]

Epoch 123 finished ! Training Loss: 0.1945



  1%|          | 39/4914 [7:18:38<910:26:12, 672.32s/it]

Epoch 124 finished ! Training Loss: 0.1889

Epoch 125 finished ! Training Loss: 0.1950

------- 1st valloss=0.1784



  1%|          | 40/4914 [7:30:19<921:50:49, 680.89s/it]

Checkpoint 125 saved !


  1%|          | 41/4914 [7:41:20<913:42:09, 675.01s/it]

Epoch 126 finished ! Training Loss: 0.1831



  1%|          | 42/4914 [7:52:42<916:15:09, 677.03s/it]

Epoch 127 finished ! Training Loss: 0.1839



  1%|          | 43/4914 [8:03:53<913:41:36, 675.28s/it]

Epoch 128 finished ! Training Loss: 0.1966



  1%|          | 44/4914 [8:14:53<907:11:35, 670.62s/it]

Epoch 129 finished ! Training Loss: 0.1957

Epoch 130 finished ! Training Loss: 0.1842

------- 1st valloss=0.3107



  1%|          | 45/4914 [8:26:38<921:10:20, 681.09s/it]

Checkpoint 130 saved !


  1%|          | 46/4914 [8:37:59<920:47:14, 680.94s/it]

Epoch 131 finished ! Training Loss: 0.1916



  1%|          | 47/4914 [8:48:59<912:16:07, 674.78s/it]

Epoch 132 finished ! Training Loss: 0.1870



  1%|          | 48/4914 [9:00:01<906:51:44, 670.92s/it]

Epoch 133 finished ! Training Loss: 0.1873



  1%|          | 49/4914 [9:11:03<902:59:51, 668.20s/it]

Epoch 134 finished ! Training Loss: 0.1860

Epoch 135 finished ! Training Loss: 0.1802

------- 1st valloss=0.2137



  1%|          | 50/4914 [9:22:50<918:31:56, 679.83s/it]

Checkpoint 135 saved !


  1%|          | 51/4914 [9:33:46<908:34:48, 672.61s/it]

Epoch 136 finished ! Training Loss: 0.1873



  1%|          | 52/4914 [9:44:48<904:17:22, 669.57s/it]

Epoch 137 finished ! Training Loss: 0.1861



  1%|          | 53/4914 [9:55:58<904:06:59, 669.58s/it]

Epoch 138 finished ! Training Loss: 0.1916



  1%|          | 54/4914 [10:06:57<899:54:45, 666.60s/it]

Epoch 139 finished ! Training Loss: 0.1803

Epoch 140 finished ! Training Loss: 0.1895

------- 1st valloss=0.1450



  1%|          | 55/4914 [10:18:40<914:06:46, 677.26s/it]

Checkpoint 140 saved !


  1%|          | 56/4914 [10:29:51<911:25:41, 675.41s/it]

Epoch 141 finished ! Training Loss: 0.1854



  1%|          | 57/4914 [10:40:51<905:20:26, 671.04s/it]

Epoch 142 finished ! Training Loss: 0.1821



  1%|          | 58/4914 [10:52:00<904:16:55, 670.39s/it]

Epoch 143 finished ! Training Loss: 0.1799



  1%|          | 59/4914 [11:03:03<900:52:29, 668.00s/it]

Epoch 144 finished ! Training Loss: 0.1759

Epoch 145 finished ! Training Loss: 0.1825

------- 1st valloss=0.1539



  1%|          | 60/4914 [11:14:53<917:40:21, 680.60s/it]

Checkpoint 145 saved !


  1%|          | 61/4914 [11:25:45<905:55:24, 672.02s/it]

Epoch 146 finished ! Training Loss: 0.1756



  1%|▏         | 62/4914 [11:36:58<906:13:38, 672.39s/it]

Epoch 147 finished ! Training Loss: 0.1787



  1%|▏         | 63/4914 [11:48:09<905:23:11, 671.90s/it]

Epoch 148 finished ! Training Loss: 0.1736



  1%|▏         | 64/4914 [11:59:01<897:08:03, 665.91s/it]

Epoch 149 finished ! Training Loss: 0.1741

Epoch 150 finished ! Training Loss: 0.1769

------- 1st valloss=0.1350

0.13495922023835388  < 0.13849218837592914


  1%|▏         | 65/4914 [12:10:45<912:24:25, 677.39s/it]

Checkpoint 150 saved !


  1%|▏         | 66/4914 [12:21:53<908:19:31, 674.50s/it]

Epoch 151 finished ! Training Loss: 0.1740



  1%|▏         | 67/4914 [12:33:03<906:35:49, 673.35s/it]

Epoch 152 finished ! Training Loss: 0.1732



  1%|▏         | 68/4914 [12:44:11<904:16:19, 671.77s/it]

Epoch 153 finished ! Training Loss: 0.1756



  1%|▏         | 69/4914 [12:55:17<901:31:30, 669.86s/it]

Epoch 154 finished ! Training Loss: 0.1775

Epoch 155 finished ! Training Loss: 0.1744

------- 1st valloss=0.1826



  1%|▏         | 70/4914 [13:07:00<914:56:28, 679.97s/it]

Checkpoint 155 saved !


  1%|▏         | 71/4914 [13:17:58<905:35:05, 673.16s/it]

Epoch 156 finished ! Training Loss: 0.1721



  1%|▏         | 72/4914 [13:29:09<904:34:22, 672.54s/it]

Epoch 157 finished ! Training Loss: 0.1772



  1%|▏         | 73/4914 [13:40:14<901:37:26, 670.49s/it]

Epoch 158 finished ! Training Loss: 0.1742



  2%|▏         | 74/4914 [13:51:06<893:48:51, 664.82s/it]

Epoch 159 finished ! Training Loss: 0.1699

Epoch 160 finished ! Training Loss: 0.1700

------- 1st valloss=0.1309

0.1309332698583603  < 0.13495922023835388


  2%|▏         | 75/4914 [14:02:57<912:14:57, 678.67s/it]

Checkpoint 160 saved !


  2%|▏         | 76/4914 [14:14:10<909:34:08, 676.82s/it]

Epoch 161 finished ! Training Loss: 0.1715



  2%|▏         | 77/4914 [14:25:09<902:34:41, 671.76s/it]

Epoch 162 finished ! Training Loss: 0.1644



  2%|▏         | 78/4914 [14:36:15<900:03:44, 670.02s/it]

Epoch 163 finished ! Training Loss: 0.1668



  2%|▏         | 79/4914 [14:47:20<897:49:51, 668.50s/it]

Epoch 164 finished ! Training Loss: 0.1644

Epoch 165 finished ! Training Loss: 0.1636

------- 1st valloss=0.1658



  2%|▏         | 80/4914 [14:58:58<909:13:30, 677.12s/it]

Checkpoint 165 saved !


  2%|▏         | 81/4914 [15:10:10<907:18:19, 675.83s/it]

Epoch 166 finished ! Training Loss: 0.1733



  2%|▏         | 82/4914 [15:21:19<904:13:43, 673.68s/it]

Epoch 167 finished ! Training Loss: 0.1610



  2%|▏         | 83/4914 [15:32:31<903:11:42, 673.05s/it]

Epoch 168 finished ! Training Loss: 0.1635



  2%|▏         | 84/4914 [15:43:45<903:42:01, 673.57s/it]

Epoch 169 finished ! Training Loss: 0.1697

Epoch 170 finished ! Training Loss: 0.1678

------- 1st valloss=0.1522



  2%|▏         | 85/4914 [15:55:37<918:47:47, 684.96s/it]

Checkpoint 170 saved !


  2%|▏         | 86/4914 [16:06:32<906:34:25, 675.99s/it]

Epoch 171 finished ! Training Loss: 0.1664



  2%|▏         | 87/4914 [16:17:49<906:49:16, 676.31s/it]

Epoch 172 finished ! Training Loss: 0.1632



  2%|▏         | 88/4914 [16:28:46<898:44:07, 670.42s/it]

Epoch 173 finished ! Training Loss: 0.1596



  2%|▏         | 89/4914 [16:39:46<894:25:25, 667.34s/it]

Epoch 174 finished ! Training Loss: 0.1722

Epoch 175 finished ! Training Loss: 0.1605

------- 1st valloss=0.1258

0.12582957291084787  < 0.1309332698583603


  2%|▏         | 90/4914 [16:51:31<909:26:33, 678.69s/it]

Checkpoint 175 saved !


  2%|▏         | 91/4914 [17:02:28<900:22:40, 672.06s/it]

Epoch 176 finished ! Training Loss: 0.1596



  2%|▏         | 92/4914 [17:13:37<898:58:13, 671.15s/it]

Epoch 177 finished ! Training Loss: 0.1605



  2%|▏         | 93/4914 [17:24:58<903:01:57, 674.32s/it]

Epoch 178 finished ! Training Loss: 0.1646



  2%|▏         | 94/4914 [17:36:17<904:43:49, 675.73s/it]

Epoch 179 finished ! Training Loss: 0.1570

Epoch 180 finished ! Training Loss: 0.1645

------- 1st valloss=0.1314



  2%|▏         | 95/4914 [17:48:10<919:08:37, 686.64s/it]

Checkpoint 180 saved !


  2%|▏         | 96/4914 [17:59:18<911:29:47, 681.07s/it]

Epoch 181 finished ! Training Loss: 0.1557



  2%|▏         | 97/4914 [18:10:26<906:01:29, 677.12s/it]

Epoch 182 finished ! Training Loss: 0.1599



  2%|▏         | 98/4914 [18:21:16<894:56:19, 668.97s/it]

Epoch 183 finished ! Training Loss: 0.1617



  2%|▏         | 99/4914 [18:32:28<896:11:11, 670.05s/it]

Epoch 184 finished ! Training Loss: 0.1592

Epoch 185 finished ! Training Loss: 0.1619

------- 1st valloss=0.1642



  2%|▏         | 100/4914 [18:44:14<910:30:01, 680.89s/it]

Checkpoint 185 saved !


  2%|▏         | 101/4914 [18:55:28<907:15:22, 678.60s/it]

Epoch 186 finished ! Training Loss: 0.1557



  2%|▏         | 102/4914 [19:06:29<900:22:23, 673.60s/it]

Epoch 187 finished ! Training Loss: 0.1592



  2%|▏         | 103/4914 [19:17:40<899:00:32, 672.72s/it]

Epoch 188 finished ! Training Loss: 0.1552



  2%|▏         | 104/4914 [19:28:58<901:03:51, 674.39s/it]

Epoch 189 finished ! Training Loss: 0.1518

Epoch 190 finished ! Training Loss: 0.1489

------- 1st valloss=0.1660



  2%|▏         | 105/4914 [19:40:50<915:58:32, 685.70s/it]

Checkpoint 190 saved !


  2%|▏         | 106/4914 [19:52:02<909:59:48, 681.36s/it]

Epoch 191 finished ! Training Loss: 0.1538



  2%|▏         | 107/4914 [20:03:08<903:45:19, 676.83s/it]

Epoch 192 finished ! Training Loss: 0.1501



  2%|▏         | 108/4914 [20:14:05<895:46:35, 670.99s/it]

Epoch 193 finished ! Training Loss: 0.1506



  2%|▏         | 109/4914 [20:25:38<904:13:40, 677.47s/it]

Epoch 194 finished ! Training Loss: 0.1485

Epoch 195 finished ! Training Loss: 0.1434

------- 1st valloss=0.1687



  2%|▏         | 110/4914 [20:37:14<911:17:44, 682.90s/it]

Checkpoint 195 saved !


  2%|▏         | 111/4914 [20:48:23<905:45:48, 678.90s/it]

Epoch 196 finished ! Training Loss: 0.1572



  2%|▏         | 112/4914 [20:59:35<902:54:41, 676.90s/it]

Epoch 197 finished ! Training Loss: 0.1495



  2%|▏         | 113/4914 [21:10:36<896:21:15, 672.13s/it]

Epoch 198 finished ! Training Loss: 0.1504



  2%|▏         | 114/4914 [21:21:30<888:38:51, 666.49s/it]

Epoch 199 finished ! Training Loss: 0.1516

Epoch 200 finished ! Training Loss: 0.1532

------- 1st valloss=0.1371



  2%|▏         | 115/4914 [21:33:10<901:59:06, 676.63s/it]

Checkpoint 200 saved !


  2%|▏         | 116/4914 [21:44:21<899:46:12, 675.11s/it]

Epoch 201 finished ! Training Loss: 0.1488



  2%|▏         | 117/4914 [21:55:15<890:48:36, 668.53s/it]

Epoch 202 finished ! Training Loss: 0.1469



  2%|▏         | 118/4914 [22:06:24<890:46:18, 668.64s/it]

Epoch 203 finished ! Training Loss: 0.1505



  2%|▏         | 119/4914 [22:17:27<888:40:37, 667.20s/it]

Epoch 204 finished ! Training Loss: 0.1470

Epoch 205 finished ! Training Loss: 0.1444

------- 1st valloss=0.1575



  2%|▏         | 120/4914 [22:29:08<901:49:41, 677.22s/it]

Checkpoint 205 saved !


  2%|▏         | 121/4914 [22:40:23<900:42:22, 676.52s/it]

Epoch 206 finished ! Training Loss: 0.1495



  2%|▏         | 122/4914 [22:51:31<897:20:11, 674.13s/it]

Epoch 207 finished ! Training Loss: 0.1461



  3%|▎         | 123/4914 [23:02:47<897:43:37, 674.56s/it]

Epoch 208 finished ! Training Loss: 0.1513



  3%|▎         | 124/4914 [23:13:46<891:26:04, 669.97s/it]

Epoch 209 finished ! Training Loss: 0.1457

Epoch 210 finished ! Training Loss: 0.1477

------- 1st valloss=0.1393



  3%|▎         | 125/4914 [23:25:38<907:49:31, 682.43s/it]

Checkpoint 210 saved !


  3%|▎         | 126/4914 [23:36:51<903:52:59, 679.61s/it]

Epoch 211 finished ! Training Loss: 0.1392



  3%|▎         | 127/4914 [23:47:59<899:18:12, 676.31s/it]

Epoch 212 finished ! Training Loss: 0.1454



  3%|▎         | 128/4914 [23:59:01<893:08:53, 671.82s/it]

Epoch 213 finished ! Training Loss: 0.1428



  3%|▎         | 129/4914 [24:10:11<892:14:30, 671.28s/it]

Epoch 214 finished ! Training Loss: 0.1419

Epoch 215 finished ! Training Loss: 0.1417

------- 1st valloss=0.1351



  3%|▎         | 130/4914 [24:21:58<906:15:45, 681.97s/it]

Checkpoint 215 saved !


  3%|▎         | 131/4914 [24:32:58<897:27:06, 675.48s/it]

Epoch 216 finished ! Training Loss: 0.1421



  3%|▎         | 132/4914 [24:43:59<891:36:36, 671.22s/it]

Epoch 217 finished ! Training Loss: 0.1381



  3%|▎         | 133/4914 [24:55:02<888:06:30, 668.73s/it]

Epoch 218 finished ! Training Loss: 0.1361



  3%|▎         | 134/4914 [25:05:56<881:53:18, 664.18s/it]

Epoch 219 finished ! Training Loss: 0.1373

Epoch 220 finished ! Training Loss: 0.1380

------- 1st valloss=0.1423



  3%|▎         | 135/4914 [25:17:35<895:42:21, 674.73s/it]

Checkpoint 220 saved !


  3%|▎         | 136/4914 [25:28:35<889:37:59, 670.30s/it]

Epoch 221 finished ! Training Loss: 0.1401



  3%|▎         | 137/4914 [25:39:39<886:48:54, 668.31s/it]

Epoch 222 finished ! Training Loss: 0.1384



  3%|▎         | 138/4914 [25:50:40<884:00:53, 666.34s/it]

Epoch 223 finished ! Training Loss: 0.1417



  3%|▎         | 139/4914 [26:01:53<886:06:10, 668.06s/it]

Epoch 224 finished ! Training Loss: 0.1352

Epoch 225 finished ! Training Loss: 0.1401

------- 1st valloss=0.1381



  3%|▎         | 140/4914 [26:13:47<904:18:04, 681.92s/it]

Checkpoint 225 saved !


  3%|▎         | 141/4914 [26:24:52<897:26:08, 676.88s/it]

Epoch 226 finished ! Training Loss: 0.1416



  3%|▎         | 142/4914 [26:35:59<893:28:05, 674.03s/it]

Epoch 227 finished ! Training Loss: 0.1337



  3%|▎         | 143/4914 [26:47:01<888:27:19, 670.39s/it]

Epoch 228 finished ! Training Loss: 0.1352



  3%|▎         | 144/4914 [26:58:13<889:00:00, 670.94s/it]

Epoch 229 finished ! Training Loss: 0.1363

Epoch 230 finished ! Training Loss: 0.1355

------- 1st valloss=0.1848



  3%|▎         | 145/4914 [27:09:46<897:35:39, 677.57s/it]

Checkpoint 230 saved !


  3%|▎         | 146/4914 [27:20:53<893:11:44, 674.39s/it]

Epoch 231 finished ! Training Loss: 0.1352



  3%|▎         | 147/4914 [27:31:55<887:46:12, 670.44s/it]

Epoch 232 finished ! Training Loss: 0.1401



  3%|▎         | 148/4914 [27:43:08<888:49:45, 671.38s/it]

Epoch 233 finished ! Training Loss: 0.1314



  3%|▎         | 149/4914 [27:54:09<884:19:56, 668.12s/it]

Epoch 234 finished ! Training Loss: 0.1392

Epoch 235 finished ! Training Loss: 0.1306

------- 1st valloss=0.1556



  3%|▎         | 150/4914 [28:06:04<902:48:25, 682.22s/it]

Checkpoint 235 saved !


  3%|▎         | 151/4914 [28:17:13<897:12:43, 678.14s/it]

Epoch 236 finished ! Training Loss: 0.1308



  3%|▎         | 152/4914 [28:28:15<890:48:48, 673.44s/it]

Epoch 237 finished ! Training Loss: 0.1330



  3%|▎         | 153/4914 [28:39:20<887:16:55, 670.91s/it]

Epoch 238 finished ! Training Loss: 0.1348



  3%|▎         | 154/4914 [28:50:20<882:54:12, 667.74s/it]

Epoch 239 finished ! Training Loss: 0.1257

Epoch 240 finished ! Training Loss: 0.1335

------- 1st valloss=0.1358



  3%|▎         | 155/4914 [29:02:14<901:06:38, 681.66s/it]

Checkpoint 240 saved !


  3%|▎         | 156/4914 [29:13:22<895:16:03, 677.38s/it]

Epoch 241 finished ! Training Loss: 0.1294



  3%|▎         | 157/4914 [29:24:16<885:58:31, 670.49s/it]

Epoch 242 finished ! Training Loss: 0.1292



  3%|▎         | 158/4914 [29:35:22<883:58:13, 669.11s/it]

Epoch 243 finished ! Training Loss: 0.1250



  3%|▎         | 159/4914 [29:46:28<882:38:35, 668.25s/it]

Epoch 244 finished ! Training Loss: 0.1282

Epoch 245 finished ! Training Loss: 0.1258

------- 1st valloss=0.1429



  3%|▎         | 160/4914 [29:58:24<901:01:19, 682.31s/it]

Checkpoint 245 saved !


  3%|▎         | 161/4914 [30:09:43<899:43:35, 681.47s/it]

Epoch 246 finished ! Training Loss: 0.1264



  3%|▎         | 162/4914 [30:20:44<891:18:56, 675.24s/it]

Epoch 247 finished ! Training Loss: 0.1286



  3%|▎         | 163/4914 [30:31:51<887:52:28, 672.77s/it]

Epoch 248 finished ! Training Loss: 0.1293



  3%|▎         | 164/4914 [30:42:57<885:11:50, 670.89s/it]

Epoch 249 finished ! Training Loss: 0.1270

Epoch 250 finished ! Training Loss: 0.1252

------- 1st valloss=0.1285



  3%|▎         | 165/4914 [30:54:39<897:18:30, 680.21s/it]

Checkpoint 250 saved !


  3%|▎         | 166/4914 [31:05:43<890:29:36, 675.18s/it]

Epoch 251 finished ! Training Loss: 0.1191



In [None]:
deeplab.eval()

with torch.no_grad():
    
    bgloss = 0
    bdloss = 0
    bvloss = 0
    
    for v, vbatch in tqdm(enumerate(validation_loader)):
            # move data to device, convert dtype to desirable dtype

        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        output = deeplab(image_1)
        # do the inference
        output_numpy = output.cpu().numpy()
        
        
        #out_1 = torch.round(output)
        out_1 = torch.from_numpy((output_numpy == output_numpy.max(axis=1)[:, None]).astype(int)).to(device=device, dtype=dtype)
        loss_1 = dice_loss_3(out_1, label_1)

        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        # calculate loss
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()

    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/(v+1), bdloss/(v+1), bvloss/(v+1)) + '\n'
    print(outstr)