In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from sync_batchnorm import convert_model
import adabound
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Run-wide switches: try to use a GPU, number of DataLoader workers,
# and number of samples per mini-batch.
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 2

# Tensors are cast to single precision, which takes less memory than double.
dtype = torch.float32

if not (USE_GPU and torch.cuda.is_available()):
    # GPU use was switched off above, or CUDA is not available on this machine.
    device = torch.device('cpu')
    print('using CPU for training')
else:
    device = torch.device('cuda')

    # Turn cuDNN on and enable its benchmark mode as a speed-up.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    print('using GPU for training')

using GPU for training


In [3]:
# Training split ('nii_train') with data augmentation: every sample goes
# through random_affine followed by random_filp (both project helpers).
# NOTE(review): the meaning of the arguments (90, 15) and 0.5 is defined by
# those helpers — presumably rotation/translation ranges and a flip
# probability; confirm in the helper module.
train_dataset = pyramid_dataset(data_type = 'nii_train',
                transform=transforms.Compose([
                random_affine(90, 15),
                random_filp(0.5)]))

# Validation split ('nii_test'): no augmentation.
validation_dataset = pyramid_dataset(data_type = 'nii_test',
                transform=None)

# The loaders handle batching and load samples in NUM_WORKERS worker processes.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS)

# Validation batches are served in a fixed order (shuffle=False) so that the
# batch composition, and therefore the per-batch-averaged validation loss,
# is the same every time the loader is iterated.
# drop_last is left at its default, so the final batch may be smaller.
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS)

In [None]:
"""
deeplab = DeepLab(output_stride=2)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)
deeplab = deeplab.to(device=device, dtype=dtype)
#shape_test(icnet1, True)
# create the model, by default model type is float, use model.double(), model.float() to convert
# move the model to desirable device

optimizer = adabound.AdaBound(deeplab.parameters(), lr=1e-3, final_lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
epoch = 0

# create an optimizer object
# note that only the model_2 params and model_4 params will be optimized by optimizer
"""

"\ndeeplab = DeepLab(output_stride=2)\ndeeplab = nn.DataParallel(deeplab)\ndeeplab = convert_model(deeplab)\ndeeplab = deeplab.to(device=device, dtype=dtype)\n#shape_test(icnet1, True)\n# create the model, by default model type is float, use model.double(), model.float() to convert\n# move the model to desirable device\n\noptimizer = adabound.AdaBound(deeplab.parameters(), lr=1e-3, final_lr=0.1)\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)\nepoch = 0\n\n# create an optimizer object\n# note that only the model_2 params and model_4 params will be optimized by optimizer\n"

In [None]:

# Resume training from a saved checkpoint.
#
# Build the network and apply the DataParallel / convert_model wrappers
# before loading the weights.
# NOTE(review): presumably the checkpoint was saved from a model wrapped the
# same way, so the state-dict keys match — confirm against save_1.
deeplab = DeepLab(output_stride=2)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

# Earlier checkpoints, kept for reference:
#checkpoint = torch.load('../deeplab_save/2019-07-29 00:44:11.825872.pth')
#checkpoint = torch.load('../deeplab_dilated_save/2019-08-01 08:57:17.225282.pth') # best one
#
# map_location=device remaps the stored tensors onto the device chosen in the
# configuration cell, so a checkpoint written on a GPU machine can also be
# restored when this notebook falls back to the CPU.
checkpoint = torch.load('../deeplab_output_2_adabound_save/2019-08-06 11:02:10.001003 epoch: 85.pth',
                        map_location=device) # latest one

deeplab.load_state_dict(checkpoint['state_dict_1'])
deeplab = deeplab.to(device, dtype)

# The optimizer and scheduler are created with placeholder settings and then
# overwritten with the state stored in the checkpoint.
optimizer = adabound.AdaBound(deeplab.parameters(), lr=1e-3, final_lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])

# Last epoch stored in the checkpoint; the training loop continues from epoch + 1.
epoch = checkpoint['epoch']
print(epoch)


85


In [None]:
# Train from the restored epoch up to `epochs`, validating and checkpointing
# every 5th epoch. Per-epoch losses go to stdout, to the text log and to `logger`.
epochs = 5000

logger = {'train':[], 'validation_1': []}

# Lowest validation loss seen during this run; only used for the console
# message printed when a new minimum is reached.
min_val = 1

# The context manager closes the log file even if the long-running loop is
# interrupted or raises part-way through.
with open('train_deeplab_output_2_adabound.txt','a') as record:

    for e in tqdm(range(epoch + 1, epochs)):
    # iterate over epochs, continuing after the epoch loaded from the checkpoint

        deeplab.train()
        # Set the model to train mode once per epoch (validation below switches
        # it to eval mode):
        # 1. enables dropout
        # 2. batchnorm behaves differently in train and eval

        epoch_loss = 0

        for t, batch in enumerate(train_loader):
        # iterate over the training mini-batches

            image_1 = batch['image1_data'].to(device=device, dtype=dtype)
            label_1 = batch['image1_label'].to(device=device, dtype=dtype)
            # move data to the device and cast to the working dtype

            out_1 = deeplab(image_1)
            # forward pass

            loss_1 = dice_loss_3(out_1, label_1)
            # compute the loss

            epoch_loss += loss_1.item()
            # accumulate the mini-batch loss into the epoch total

            optimizer.zero_grad()
            # clear the gradients left over from the previous step

            loss_1.backward()
            # back-propagate the loss
            optimizer.step()
            # take one optimizer step
            # NOTE(review): the scheduler restored from the checkpoint is never
            # stepped in this loop (both scheduler.step calls are disabled).
            #scheduler.step(loss_1)

        mean_train_loss = epoch_loss/(t+1)
        # mean over the mini-batches of this epoch

        outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, mean_train_loss) + '\n'

        logger['train'].append(mean_train_loss)

        print(outstr)
        record.write(outstr)
        record.flush()

        if e%5 == 0:
        # validate and save a checkpoint every 5 epochs

            deeplab.eval()
            # Set the model to eval mode:
            # 1. disables dropout
            # 2. batchnorm behaves differently

            with torch.no_grad():
            # no gradients are needed for validation

                valloss_1 = 0

                for v, vbatch in enumerate(validation_loader):
                # iterate over the validation mini-batches

                    image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                    if get_dimensions(image_1_val) == 4:
                        image_1_val.unsqueeze_(0)
                    label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                    if get_dimensions(label_1_val) == 4:
                        label_1_val.unsqueeze_(0)
                    # move data to the device and cast to the working dtype;
                    # tensors reported as 4-dimensional get a leading dimension added

                    out_1_val = deeplab(image_1_val)
                    # forward pass

                    val_loss_1 = dice_loss_3(out_1_val, label_1_val)
                    # compute the loss

                    valloss_1 += val_loss_1.item()
                    # accumulate the mini-batch loss

                avg_val_loss = (valloss_1/(v+1))
                outstr = '------- 1st valloss={0:.4f}'\
                    .format(avg_val_loss) + '\n'

                logger['validation_1'].append(avg_val_loss)
                #scheduler.step(avg_val_loss)

                print(outstr)
                record.write(outstr)
                record.flush()

                if avg_val_loss < min_val:
                    print(avg_val_loss, ' <', min_val)
                    min_val = avg_val_loss

                # A checkpoint is written after every validation pass, not only
                # when the validation loss improves.
                save_1('deeplab_output_2_adabound_save', deeplab, optimizer, logger, e, scheduler)

  0%|          | 1/4914 [12:07<993:06:10, 727.70s/it]

Epoch 86 finished ! Training Loss: 0.2509



  0%|          | 2/4914 [23:09<965:48:56, 707.85s/it]

Epoch 87 finished ! Training Loss: 0.2425



  0%|          | 3/4914 [34:15<948:27:36, 695.27s/it]

Epoch 88 finished ! Training Loss: 0.2387



  0%|          | 4/4914 [45:14<933:36:36, 684.52s/it]

Epoch 89 finished ! Training Loss: 0.2405

Epoch 90 finished ! Training Loss: 0.2358

------- 1st valloss=0.2292

0.22916898066582886  < 1


  0%|          | 5/4914 [56:51<938:22:34, 688.16s/it]

Checkpoint 90 saved !


  0%|          | 6/4914 [1:07:50<926:24:50, 679.52s/it]

Epoch 91 finished ! Training Loss: 0.2339



  0%|          | 7/4914 [1:19:00<922:08:11, 676.52s/it]

Epoch 92 finished ! Training Loss: 0.2263



  0%|          | 8/4914 [1:29:52<911:55:56, 669.17s/it]

Epoch 93 finished ! Training Loss: 0.2357



  0%|          | 9/4914 [1:40:58<910:28:09, 668.23s/it]

Epoch 94 finished ! Training Loss: 0.2234

Epoch 95 finished ! Training Loss: 0.2306

------- 1st valloss=0.2195

0.2194654151149418  < 0.22916898066582886


  0%|          | 10/4914 [1:52:45<926:13:20, 679.93s/it]

Checkpoint 95 saved !


  0%|          | 11/4914 [2:03:40<915:40:56, 672.33s/it]

Epoch 96 finished ! Training Loss: 0.2272



  0%|          | 12/4914 [2:14:48<913:50:49, 671.12s/it]

Epoch 97 finished ! Training Loss: 0.2266



  0%|          | 13/4914 [2:25:48<909:17:33, 667.92s/it]

Epoch 98 finished ! Training Loss: 0.2236



  0%|          | 14/4914 [2:36:42<903:20:36, 663.68s/it]

Epoch 99 finished ! Training Loss: 0.2117

Epoch 100 finished ! Training Loss: 0.2125

------- 1st valloss=0.1648

0.16484586797330691  < 0.2194654151149418


  0%|          | 15/4914 [2:48:33<922:17:06, 677.74s/it]

Checkpoint 100 saved !


  0%|          | 16/4914 [2:59:44<919:19:42, 675.70s/it]

Epoch 101 finished ! Training Loss: 0.2328



  0%|          | 17/4914 [3:10:46<913:52:05, 671.82s/it]

Epoch 102 finished ! Training Loss: 0.2169



  0%|          | 18/4914 [3:21:55<912:17:06, 670.80s/it]

Epoch 103 finished ! Training Loss: 0.2030



  0%|          | 19/4914 [3:33:06<912:10:19, 670.85s/it]

Epoch 104 finished ! Training Loss: 0.2178

Epoch 105 finished ! Training Loss: 0.2078

------- 1st valloss=0.1709



  0%|          | 20/4914 [3:45:00<929:35:57, 683.81s/it]

Checkpoint 105 saved !


  0%|          | 21/4914 [3:56:08<923:05:48, 679.16s/it]

Epoch 106 finished ! Training Loss: 0.2127



  0%|          | 22/4914 [4:07:21<920:16:55, 677.23s/it]

Epoch 107 finished ! Training Loss: 0.2176



  0%|          | 23/4914 [4:18:20<912:53:19, 671.93s/it]

Epoch 108 finished ! Training Loss: 0.2185



  0%|          | 24/4914 [4:29:30<911:54:58, 671.35s/it]

Epoch 109 finished ! Training Loss: 0.2058

Epoch 110 finished ! Training Loss: 0.2066

------- 1st valloss=0.1765



  1%|          | 25/4914 [4:41:15<925:22:41, 681.40s/it]

Checkpoint 110 saved !


  1%|          | 26/4914 [4:52:13<915:38:31, 674.37s/it]

Epoch 111 finished ! Training Loss: 0.1966



  1%|          | 27/4914 [5:03:06<906:45:28, 667.96s/it]

Epoch 112 finished ! Training Loss: 0.1971



  1%|          | 28/4914 [5:14:06<903:23:00, 665.61s/it]

Epoch 113 finished ! Training Loss: 0.1974



  1%|          | 29/4914 [5:25:02<899:18:07, 662.74s/it]

Epoch 114 finished ! Training Loss: 0.1968

Epoch 115 finished ! Training Loss: 0.1961

------- 1st valloss=0.1382

0.13821909155534662  < 0.16484586797330691


  1%|          | 30/4914 [5:36:45<915:14:13, 674.62s/it]

Checkpoint 115 saved !


  1%|          | 31/4914 [5:47:47<909:51:39, 670.80s/it]

Epoch 116 finished ! Training Loss: 0.2156



  1%|          | 32/4914 [5:58:39<902:09:58, 665.26s/it]

Epoch 117 finished ! Training Loss: 0.2041



  1%|          | 33/4914 [6:09:38<899:24:56, 663.37s/it]

Epoch 118 finished ! Training Loss: 0.2054



  1%|          | 34/4914 [6:20:45<900:56:16, 664.63s/it]

Epoch 119 finished ! Training Loss: 0.2085

Epoch 120 finished ! Training Loss: 0.1950

------- 1st valloss=0.1437



  1%|          | 35/4914 [6:32:29<916:37:35, 676.34s/it]

Checkpoint 120 saved !


  1%|          | 36/4914 [6:43:28<909:27:51, 671.19s/it]

Epoch 121 finished ! Training Loss: 0.1985



  1%|          | 37/4914 [6:54:36<908:02:00, 670.27s/it]

Epoch 122 finished ! Training Loss: 0.1923



  1%|          | 38/4914 [7:05:29<900:49:00, 665.08s/it]

Epoch 123 finished ! Training Loss: 0.1999



  1%|          | 39/4914 [7:16:33<900:00:31, 664.62s/it]

Epoch 124 finished ! Training Loss: 0.1882

Epoch 125 finished ! Training Loss: 0.1991

------- 1st valloss=0.1323

0.13234394853529724  < 0.13821909155534662


  1%|          | 40/4914 [7:28:28<920:28:55, 679.88s/it]

Checkpoint 125 saved !


  1%|          | 41/4914 [7:39:29<912:27:46, 674.10s/it]

Epoch 126 finished ! Training Loss: 0.1942



  1%|          | 42/4914 [7:50:24<904:20:46, 668.24s/it]

Epoch 127 finished ! Training Loss: 0.1810



  1%|          | 43/4914 [8:01:34<904:53:52, 668.78s/it]

Epoch 128 finished ! Training Loss: 0.1874



  1%|          | 44/4914 [8:12:41<904:09:43, 668.37s/it]

Epoch 129 finished ! Training Loss: 0.1994

Epoch 130 finished ! Training Loss: 0.1951

------- 1st valloss=0.1547



  1%|          | 45/4914 [8:24:26<918:40:36, 679.24s/it]

Checkpoint 130 saved !


  1%|          | 46/4914 [8:35:41<916:54:50, 678.08s/it]

Epoch 131 finished ! Training Loss: 0.1850



  1%|          | 47/4914 [8:46:41<909:20:59, 672.62s/it]

Epoch 132 finished ! Training Loss: 0.1888



  1%|          | 48/4914 [8:57:41<904:12:44, 668.96s/it]

Epoch 133 finished ! Training Loss: 0.1939



  1%|          | 49/4914 [9:08:52<904:49:46, 669.56s/it]

Epoch 134 finished ! Training Loss: 0.1737

Epoch 135 finished ! Training Loss: 0.1919

------- 1st valloss=0.1576



  1%|          | 50/4914 [9:20:32<917:01:58, 678.73s/it]

Checkpoint 135 saved !


  1%|          | 51/4914 [9:31:36<910:41:11, 674.17s/it]

Epoch 136 finished ! Training Loss: 0.1955



  1%|          | 52/4914 [9:42:38<905:28:07, 670.44s/it]

Epoch 137 finished ! Training Loss: 0.1874



  1%|          | 53/4914 [9:53:49<905:27:32, 670.57s/it]

Epoch 138 finished ! Training Loss: 0.1736



  1%|          | 54/4914 [10:04:44<899:02:21, 665.95s/it]

Epoch 139 finished ! Training Loss: 0.1895

Epoch 140 finished ! Training Loss: 0.1864

------- 1st valloss=0.1397



  1%|          | 55/4914 [10:16:28<914:25:12, 677.49s/it]

Checkpoint 140 saved !


  1%|          | 56/4914 [10:27:39<911:45:03, 675.65s/it]

Epoch 141 finished ! Training Loss: 0.1803



  1%|          | 57/4914 [10:38:35<903:33:49, 669.72s/it]

Epoch 142 finished ! Training Loss: 0.1860



  1%|          | 58/4914 [10:49:35<899:12:48, 666.63s/it]

Epoch 143 finished ! Training Loss: 0.1927



  1%|          | 59/4914 [11:00:52<903:12:05, 669.73s/it]

Epoch 144 finished ! Training Loss: 0.1966

Epoch 145 finished ! Training Loss: 0.1816

------- 1st valloss=0.1333



  1%|          | 60/4914 [11:12:40<918:26:53, 681.17s/it]

Checkpoint 145 saved !


  1%|          | 61/4914 [11:23:38<909:07:59, 674.40s/it]

Epoch 146 finished ! Training Loss: 0.1760



  1%|▏         | 62/4914 [11:34:33<901:00:58, 668.52s/it]

Epoch 147 finished ! Training Loss: 0.1817



  1%|▏         | 63/4914 [11:45:34<897:55:19, 666.36s/it]

Epoch 148 finished ! Training Loss: 0.1822



  1%|▏         | 64/4914 [11:56:32<894:05:28, 663.66s/it]

Epoch 149 finished ! Training Loss: 0.1847

Epoch 150 finished ! Training Loss: 0.1781

------- 1st valloss=0.1400



  1%|▏         | 65/4914 [12:08:22<912:54:01, 677.76s/it]

Checkpoint 150 saved !


  1%|▏         | 66/4914 [12:19:40<912:49:27, 677.84s/it]

Epoch 151 finished ! Training Loss: 0.1881



  1%|▏         | 67/4914 [12:30:47<908:07:32, 674.49s/it]

Epoch 152 finished ! Training Loss: 0.1731



  1%|▏         | 68/4914 [12:41:56<905:53:06, 672.96s/it]

Epoch 153 finished ! Training Loss: 0.1762



  1%|▏         | 69/4914 [12:53:06<904:15:56, 671.90s/it]

Epoch 154 finished ! Training Loss: 0.1814

Epoch 155 finished ! Training Loss: 0.1668

------- 1st valloss=0.1184

0.11840220281611318  < 0.13234394853529724


  1%|▏         | 70/4914 [13:05:00<920:59:03, 684.46s/it]

Checkpoint 155 saved !


  1%|▏         | 71/4914 [13:16:02<911:48:01, 677.78s/it]

Epoch 156 finished ! Training Loss: 0.1781



  1%|▏         | 72/4914 [13:26:54<901:18:25, 670.12s/it]

Epoch 157 finished ! Training Loss: 0.1626



  1%|▏         | 73/4914 [13:38:07<902:12:40, 670.93s/it]

Epoch 158 finished ! Training Loss: 0.1780



  2%|▏         | 74/4914 [13:49:12<899:31:29, 669.07s/it]

Epoch 159 finished ! Training Loss: 0.1769

Epoch 160 finished ! Training Loss: 0.1630

------- 1st valloss=0.1510



  2%|▏         | 75/4914 [14:00:58<914:13:41, 680.14s/it]

Checkpoint 160 saved !


  2%|▏         | 76/4914 [14:12:06<909:19:16, 676.63s/it]

Epoch 161 finished ! Training Loss: 0.1722



  2%|▏         | 77/4914 [14:23:12<904:47:20, 673.40s/it]

Epoch 162 finished ! Training Loss: 0.1673



  2%|▏         | 78/4914 [14:34:21<902:45:17, 672.03s/it]

Epoch 163 finished ! Training Loss: 0.1760



  2%|▏         | 79/4914 [14:45:13<894:43:07, 666.18s/it]

Epoch 164 finished ! Training Loss: 0.1764

Epoch 165 finished ! Training Loss: 0.1766

------- 1st valloss=0.1394



  2%|▏         | 80/4914 [14:56:54<908:23:20, 676.50s/it]

Checkpoint 165 saved !


  2%|▏         | 81/4914 [15:08:07<906:43:47, 675.40s/it]

Epoch 166 finished ! Training Loss: 0.1707



  2%|▏         | 82/4914 [15:19:08<900:46:02, 671.10s/it]

Epoch 167 finished ! Training Loss: 0.1715



  2%|▏         | 83/4914 [15:30:19<900:50:14, 671.29s/it]

Epoch 168 finished ! Training Loss: 0.1667



  2%|▏         | 84/4914 [15:41:24<897:54:33, 669.25s/it]

Epoch 169 finished ! Training Loss: 0.1514

Epoch 170 finished ! Training Loss: 0.1585

------- 1st valloss=0.1535



  2%|▏         | 85/4914 [15:53:11<912:46:07, 680.47s/it]

Checkpoint 170 saved !


  2%|▏         | 86/4914 [16:04:10<903:54:54, 674.00s/it]

Epoch 171 finished ! Training Loss: 0.1589



  2%|▏         | 87/4914 [16:15:24<904:03:15, 674.25s/it]

Epoch 172 finished ! Training Loss: 0.1607



  2%|▏         | 88/4914 [16:26:35<902:26:16, 673.18s/it]

Epoch 173 finished ! Training Loss: 0.1674



  2%|▏         | 89/4914 [16:37:43<900:08:47, 671.61s/it]

Epoch 174 finished ! Training Loss: 0.1638

Epoch 175 finished ! Training Loss: 0.1772

------- 1st valloss=0.1447



  2%|▏         | 90/4914 [16:49:35<916:10:28, 683.71s/it]

Checkpoint 175 saved !


  2%|▏         | 91/4914 [17:00:27<903:11:57, 674.17s/it]

Epoch 176 finished ! Training Loss: 0.1749



  2%|▏         | 92/4914 [17:11:21<894:49:03, 668.05s/it]

Epoch 177 finished ! Training Loss: 0.1666



  2%|▏         | 93/4914 [17:22:35<897:16:37, 670.03s/it]

Epoch 178 finished ! Training Loss: 0.1681



  2%|▏         | 94/4914 [17:33:37<893:50:41, 667.60s/it]

Epoch 179 finished ! Training Loss: 0.1707

Epoch 180 finished ! Training Loss: 0.1685

------- 1st valloss=0.1381



  2%|▏         | 95/4914 [17:45:22<908:40:38, 678.82s/it]

Checkpoint 180 saved !


  2%|▏         | 96/4914 [17:56:15<897:56:41, 670.94s/it]

Epoch 181 finished ! Training Loss: 0.1609



  2%|▏         | 97/4914 [18:07:14<892:57:08, 667.35s/it]

Epoch 182 finished ! Training Loss: 0.1696



  2%|▏         | 98/4914 [18:18:17<891:01:20, 666.05s/it]

Epoch 183 finished ! Training Loss: 0.1559



  2%|▏         | 99/4914 [18:29:20<889:47:32, 665.27s/it]

Epoch 184 finished ! Training Loss: 0.1773

Epoch 185 finished ! Training Loss: 0.1686

------- 1st valloss=0.1367



  2%|▏         | 100/4914 [18:41:05<905:17:01, 676.99s/it]

Checkpoint 185 saved !


  2%|▏         | 101/4914 [18:52:13<901:33:46, 674.35s/it]

Epoch 186 finished ! Training Loss: 0.1621



  2%|▏         | 102/4914 [19:03:06<892:55:31, 668.02s/it]

Epoch 187 finished ! Training Loss: 0.1721



  2%|▏         | 103/4914 [19:14:06<889:32:51, 665.64s/it]

Epoch 188 finished ! Training Loss: 0.1602



  2%|▏         | 104/4914 [19:25:04<886:27:52, 663.47s/it]

Epoch 189 finished ! Training Loss: 0.1597

Epoch 190 finished ! Training Loss: 0.1719

------- 1st valloss=0.1425



  2%|▏         | 105/4914 [19:36:48<902:11:39, 675.38s/it]

Checkpoint 190 saved !


  2%|▏         | 106/4914 [19:47:57<899:43:40, 673.67s/it]

Epoch 191 finished ! Training Loss: 0.1663



  2%|▏         | 107/4914 [19:59:02<895:51:18, 670.91s/it]

Epoch 192 finished ! Training Loss: 0.1632



  2%|▏         | 108/4914 [20:09:57<889:13:06, 666.08s/it]

Epoch 193 finished ! Training Loss: 0.1595



  2%|▏         | 109/4914 [20:20:56<886:12:34, 663.97s/it]

Epoch 194 finished ! Training Loss: 0.1521

Epoch 195 finished ! Training Loss: 0.1585

------- 1st valloss=0.1643



  2%|▏         | 110/4914 [20:32:21<894:44:55, 670.50s/it]

Checkpoint 195 saved !


  2%|▏         | 111/4914 [20:43:24<891:23:44, 668.13s/it]

Epoch 196 finished ! Training Loss: 0.1609



  2%|▏         | 112/4914 [20:54:34<891:47:23, 668.56s/it]

Epoch 197 finished ! Training Loss: 0.1481



  2%|▏         | 113/4914 [21:05:44<892:17:40, 669.08s/it]

Epoch 198 finished ! Training Loss: 0.1609



  2%|▏         | 114/4914 [21:16:42<887:48:31, 665.86s/it]

Epoch 199 finished ! Training Loss: 0.1664

Epoch 200 finished ! Training Loss: 0.1581

------- 1st valloss=0.1448



  2%|▏         | 115/4914 [21:28:22<901:20:41, 676.15s/it]

Checkpoint 200 saved !


  2%|▏         | 116/4914 [21:39:23<895:01:13, 671.55s/it]

Epoch 201 finished ! Training Loss: 0.1606



  2%|▏         | 117/4914 [21:50:13<886:20:04, 665.17s/it]

Epoch 202 finished ! Training Loss: 0.1426



  2%|▏         | 118/4914 [22:01:18<886:05:56, 665.13s/it]

Epoch 203 finished ! Training Loss: 0.1559



  2%|▏         | 119/4914 [22:12:30<888:40:33, 667.20s/it]

Epoch 204 finished ! Training Loss: 0.1698

Epoch 205 finished ! Training Loss: 0.1580

------- 1st valloss=0.1405



  2%|▏         | 120/4914 [22:24:28<908:33:22, 682.27s/it]

Checkpoint 205 saved !


  2%|▏         | 121/4914 [22:35:38<903:38:49, 678.73s/it]

Epoch 206 finished ! Training Loss: 0.1545



  2%|▏         | 122/4914 [22:46:36<894:51:36, 672.27s/it]

Epoch 207 finished ! Training Loss: 0.1479



  3%|▎         | 123/4914 [22:57:42<892:24:08, 670.56s/it]

Epoch 208 finished ! Training Loss: 0.1487



  3%|▎         | 124/4914 [23:08:43<888:10:35, 667.52s/it]

Epoch 209 finished ! Training Loss: 0.1493

Epoch 210 finished ! Training Loss: 0.1543

------- 1st valloss=0.1521



  3%|▎         | 125/4914 [23:20:26<902:19:26, 678.30s/it]

Checkpoint 210 saved !


  3%|▎         | 126/4914 [23:31:31<896:43:33, 674.23s/it]

Epoch 211 finished ! Training Loss: 0.1559



  3%|▎         | 127/4914 [23:42:25<888:34:21, 668.24s/it]

Epoch 212 finished ! Training Loss: 0.1627



  3%|▎         | 128/4914 [23:53:32<887:44:31, 667.75s/it]

Epoch 213 finished ! Training Loss: 0.1464



  3%|▎         | 129/4914 [24:04:47<890:35:08, 670.03s/it]

Epoch 214 finished ! Training Loss: 0.1580

Epoch 215 finished ! Training Loss: 0.1639

------- 1st valloss=0.1556



  3%|▎         | 130/4914 [24:16:24<901:01:10, 678.02s/it]

Checkpoint 215 saved !


  3%|▎         | 131/4914 [24:27:28<895:17:41, 673.86s/it]

Epoch 216 finished ! Training Loss: 0.1654



  3%|▎         | 132/4914 [24:38:36<892:48:47, 672.13s/it]

Epoch 217 finished ! Training Loss: 0.1510



  3%|▎         | 133/4914 [24:49:44<890:52:28, 670.81s/it]

Epoch 218 finished ! Training Loss: 0.1634



  3%|▎         | 134/4914 [25:00:40<885:06:09, 666.60s/it]

Epoch 219 finished ! Training Loss: 0.1548



  3%|▎         | 136/4914 [25:23:35<896:15:03, 675.28s/it]

Epoch 221 finished ! Training Loss: 0.1534



  3%|▎         | 137/4914 [25:34:52<896:32:29, 675.64s/it]

Epoch 222 finished ! Training Loss: 0.1472



  3%|▎         | 138/4914 [25:45:57<892:00:04, 672.36s/it]

Epoch 223 finished ! Training Loss: 0.1462



  3%|▎         | 139/4914 [25:56:50<884:23:44, 666.77s/it]

Epoch 224 finished ! Training Loss: 0.1401

Epoch 225 finished ! Training Loss: 0.1530

------- 1st valloss=0.1320



  3%|▎         | 140/4914 [26:08:34<898:43:33, 677.72s/it]

Checkpoint 225 saved !


  3%|▎         | 141/4914 [26:19:42<894:42:09, 674.82s/it]

Epoch 226 finished ! Training Loss: 0.1416



  3%|▎         | 142/4914 [26:30:42<888:48:01, 670.51s/it]

Epoch 227 finished ! Training Loss: 0.1529



  3%|▎         | 143/4914 [26:41:56<889:51:22, 671.45s/it]

Epoch 228 finished ! Training Loss: 0.1459



  3%|▎         | 144/4914 [26:52:56<885:02:45, 667.96s/it]

Epoch 229 finished ! Training Loss: 0.1395

Epoch 230 finished ! Training Loss: 0.1465

------- 1st valloss=0.1365



  3%|▎         | 145/4914 [27:04:49<902:51:17, 681.54s/it]

Checkpoint 230 saved !


  3%|▎         | 146/4914 [27:15:53<895:55:45, 676.46s/it]

Epoch 231 finished ! Training Loss: 0.1516



  3%|▎         | 147/4914 [27:26:53<888:56:07, 671.32s/it]

Epoch 232 finished ! Training Loss: 0.1451



  3%|▎         | 148/4914 [27:37:50<883:19:31, 667.22s/it]

Epoch 233 finished ! Training Loss: 0.1439



  3%|▎         | 149/4914 [27:48:50<880:18:49, 665.08s/it]

Epoch 234 finished ! Training Loss: 0.1365

Epoch 235 finished ! Training Loss: 0.1435

------- 1st valloss=0.1272



  3%|▎         | 150/4914 [28:00:25<891:58:49, 674.04s/it]

Checkpoint 235 saved !


  3%|▎         | 151/4914 [28:11:27<886:49:05, 670.28s/it]

Epoch 236 finished ! Training Loss: 0.1396



  3%|▎         | 152/4914 [28:22:30<883:53:37, 668.21s/it]

Epoch 237 finished ! Training Loss: 0.1353



  3%|▎         | 153/4914 [28:33:30<880:12:04, 665.56s/it]

Epoch 238 finished ! Training Loss: 0.1458



  3%|▎         | 154/4914 [28:44:26<876:31:21, 662.92s/it]

Epoch 239 finished ! Training Loss: 0.1378

Epoch 240 finished ! Training Loss: 0.1400

------- 1st valloss=0.1335



  3%|▎         | 155/4914 [28:56:06<890:59:59, 674.01s/it]

Checkpoint 240 saved !


  3%|▎         | 156/4914 [29:07:08<886:01:23, 670.38s/it]

Epoch 241 finished ! Training Loss: 0.1427



  3%|▎         | 157/4914 [29:18:04<879:59:57, 665.97s/it]

Epoch 242 finished ! Training Loss: 0.1540



  3%|▎         | 158/4914 [29:29:07<878:38:39, 665.08s/it]

Epoch 243 finished ! Training Loss: 0.1476



  3%|▎         | 159/4914 [29:40:10<877:34:23, 664.41s/it]

Epoch 244 finished ! Training Loss: 0.1487

Epoch 245 finished ! Training Loss: 0.1444

------- 1st valloss=0.1274



  3%|▎         | 160/4914 [29:51:52<892:23:48, 675.77s/it]

Checkpoint 245 saved !


  3%|▎         | 161/4914 [30:03:04<890:44:35, 674.66s/it]

Epoch 246 finished ! Training Loss: 0.1342



  3%|▎         | 162/4914 [30:14:00<883:10:40, 669.07s/it]

Epoch 247 finished ! Training Loss: 0.1478



  3%|▎         | 163/4914 [30:25:06<881:41:12, 668.09s/it]

Epoch 248 finished ! Training Loss: 0.1447



  3%|▎         | 164/4914 [30:36:08<879:15:55, 666.39s/it]

Epoch 249 finished ! Training Loss: 0.1344

Epoch 250 finished ! Training Loss: 0.1474

------- 1st valloss=0.1160

0.115959657923035  < 0.11840220281611318


  3%|▎         | 165/4914 [30:47:48<892:13:14, 676.35s/it]

Checkpoint 250 saved !


  3%|▎         | 166/4914 [30:58:54<887:53:04, 673.21s/it]

Epoch 251 finished ! Training Loss: 0.1369



In [None]:
# Evaluate the current model on the validation set with hard (one-hot)
# predictions and report the three loss components returned by
# dice_loss_3_debug, averaged over the mini-batches.
deeplab.eval()

with torch.no_grad():

    bgloss = 0
    bdloss = 0
    bvloss = 0

    # tqdm wraps the loader itself (not the enumerate object) so the progress
    # bar can show the total number of batches.
    for v, vbatch in enumerate(tqdm(validation_loader)):

        # move data to the device and cast to the working dtype
        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        output = deeplab(image_1)
        # forward pass

        #out_1 = torch.round(output)
        # Harden the prediction: 1 where a value equals the maximum along
        # dim 1, 0 elsewhere (ties mark every maximal position). Done directly
        # on the device instead of round-tripping through NumPy.
        channel_max = output.max(dim=1, keepdim=True)[0]
        out_1 = (output == channel_max).to(device=device, dtype=dtype)
        loss_1 = dice_loss_3(out_1, label_1)

        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        # per-component losses for this mini-batch
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()

    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/(v+1), bdloss/(v+1), bvloss/(v+1)) + '\n'
    print(outstr)