In [1]:
import sys
# Silence every warning for the rest of the session, but only when the
# interpreter was started without explicit -W options (sys.warnoptions is
# empty), so a warning filter supplied on the command line is left in charge.
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# ---- runtime configuration -------------------------------------------------
USE_GPU = True      # set to False to force CPU even when CUDA is present
NUM_WORKERS = 12    # passed to the DataLoaders as num_workers
BATCH_SIZE = 2      # passed to the DataLoaders as batch_size

# Single precision: half the memory footprint of double (float64).
dtype = torch.float32

# Use CUDA only when it is both requested and actually available.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    # Enable cuDNN and let it benchmark/auto-tune its algorithms, which
    # usually speeds training up when input sizes do not change much.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    print('using GPU for training')
else:
    print('using CPU for training')

using GPU for training


In [3]:
# Training set: every sample goes through the augmentation pipeline below.
# NOTE(review): the meaning of the random_affine arguments (90, 15) and of the
# random_filp argument (0.5) is defined in the project's data utilities --
# presumably rotation/translation ranges and a flip probability; confirm there.
train_dataset = pyramid_dataset(
    data_type='nii_train',
    transform=transforms.Compose([
        random_affine(90, 15),
        random_filp(0.5),
    ]),
)

# Validation set: no augmentation, samples are used exactly as stored.
validation_dataset = pyramid_dataset(data_type='nii_test', transform=None)

# The loaders handle batching and load samples in NUM_WORKERS worker processes.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# shuffle=False: validation batches are formed identically on every pass, so
# the per-batch averaged validation loss (which drives the LR scheduler and
# best-checkpoint selection) does not vary with how samples are paired into
# batches. Shuffling adds nothing when no gradient step is taken.
validation_loader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [None]:

# Build the network from the inside out: the bare DeepLabModified model is
# wrapped in nn.DataParallel and then handed to convert_model.
# NOTE(review): convert_model comes from the sync_batchnorm package and
# presumably swaps in synchronized batch-norm layers -- confirm there.
deeplab = convert_model(nn.DataParallel(DeepLabModified(output_stride=8)))

# Modules are created as float by default; move the model to the selected
# device and cast it to the configured dtype in one call.
deeplab = deeplab.to(device=device, dtype=dtype)

# Adam receives every parameter of the wrapped model.
optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)

# Multiply the learning rate by 0.1 once the monitored value (mode='min')
# has failed to improve for more than `patience` scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=7,
)

# Fresh run: no epochs completed yet (the training loop starts at epoch + 1).
epoch = 0

In [None]:
# Disabled "resume from checkpoint" snippet. It is kept as a bare string
# literal, so none of it runs; the cell merely echoes the string as its output.
# If re-enabled it would rebuild the model, load the weights from the
# checkpoint's 'state_dict_1' entry, restore the optimizer and scheduler state,
# set `epoch` from the checkpoint, and print the epoch, the mean of the logged
# 'validation_1' losses and the current learning rate(s).
# NOTE(review): this snippet builds the model with output_stride=16 and uses
# lr=1e-3 / patience=8, whereas the active setup cell uses output_stride=8,
# lr=1e-2 / patience=7 -- reconcile these before re-enabling.
'''
deeplab = DeepLabModified(output_stride=16)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

checkpoint = torch.load('../deeplab_modified_save/2019-08-12 17:24:36.436884 epoch: 290.pth') # latest one

deeplab.load_state_dict(checkpoint['state_dict_1'])
deeplab = deeplab.to(device, dtype)

optimizer = optim.Adam(deeplab.parameters(), lr=1e-3)
optimizer.load_state_dict(checkpoint['optimizer'])

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=8)
scheduler.load_state_dict(checkpoint['scheduler'])

epoch = checkpoint['epoch']
print(epoch)
print(sum(checkpoint['logger']['validation_1']) / len(checkpoint['logger']['validation_1']))
for param_group in optimizer.param_groups:
    print(param_group['lr'])
'''

"\ndeeplab = DeepLabModified(output_stride=16)\ndeeplab = nn.DataParallel(deeplab)\ndeeplab = convert_model(deeplab)\n\ncheckpoint = torch.load('../deeplab_modified_save/2019-08-12 17:24:36.436884 epoch: 290.pth') # latest one\n\ndeeplab.load_state_dict(checkpoint['state_dict_1'])\ndeeplab = deeplab.to(device, dtype)\n\noptimizer = optim.Adam(deeplab.parameters(), lr=1e-3)\noptimizer.load_state_dict(checkpoint['optimizer'])\n\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=8)\nscheduler.load_state_dict(checkpoint['scheduler'])\n\nepoch = checkpoint['epoch']\nprint(epoch)\nprint(sum(checkpoint['logger']['validation_1']) / len(checkpoint['logger']['validation_1']))\nfor param_group in optimizer.param_groups:\n    print(param_group['lr'])\n"

In [None]:
# Train until `epochs` is reached, starting from the epoch after `epoch`.
epochs = 5000

# Cadence, in epochs. Checkpointing is only considered on validation epochs,
# so CHECKPOINT_EVERY should stay a multiple of VALIDATE_EVERY.
VALIDATE_EVERY = 3
CHECKPOINT_EVERY = 15

# First argument handed to save_1 for every checkpoint of this run.
SAVE_NAME = 'deeplab_output_8_modified_save'

# Best (lowest) average validation loss seen so far. Starting at 1 means a
# validation loss of 1 or more never counts as a new best.
min_val = 1

# Per-epoch training losses and per-validation-pass losses; passed to save_1.
logger = {'train':[], 'validation_1': []}

# The context manager guarantees the text log is closed even when the run is
# interrupted or raises part-way through (a run of this length usually is).
with open('train_deeplab_output_8_modified.txt','a+') as record:

    for e in tqdm(range(epoch + 1, epochs)):
    # iterate over epochs

        deeplab.train()
        # Training mode (dropout active, batch-norm uses batch statistics).
        # Set once per epoch: the validation pass below switches to eval mode.

        epoch_loss = 0

        for t, batch in enumerate(train_loader):
        # iterate over the training mini-batches

            image_1 = batch['image1_data'].to(device=device, dtype=dtype)
            label_1 = batch['image1_label'].to(device=device, dtype=dtype)
            # move the data to the device and cast it to the working dtype

            out_1 = deeplab(image_1)
            # forward pass

            loss_1 = dice_loss_3(out_1, label_1)
            # mini-batch loss

            epoch_loss += loss_1.item()
            # accumulate as a Python float so no graph is kept alive

            optimizer.zero_grad()
            # clear gradients left over from the previous step

            loss_1.backward()
            # back-propagate
            optimizer.step()
            # apply one optimizer update

        train_loss = epoch_loss/(t+1)
        # mean of the mini-batch losses (t is the last zero-based batch index)

        outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, train_loss) + '\n'

        logger['train'].append(train_loss)

        print(outstr)
        record.write(outstr)
        record.flush()

        if e % VALIDATE_EVERY != 0:
            # not a validation epoch
            continue

        deeplab.eval()
        # Evaluation mode (dropout off, batch-norm uses running statistics).

        with torch.no_grad():
        # no gradients are needed for validation

            valloss_1 = 0

            for v, vbatch in enumerate(validation_loader):
            # iterate over the validation mini-batches

                image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                if get_dimensions(image_1_val) == 4:
                    image_1_val.unsqueeze_(0)
                label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                if get_dimensions(label_1_val) == 4:
                    label_1_val.unsqueeze_(0)
                # move the data to the device and cast it to the working dtype;
                # tensors reported as 4-D get a leading dimension added in place

                out_1_val = deeplab(image_1_val)
                # forward pass

                loss_1 = dice_loss_3(out_1_val, label_1_val)
                # mini-batch loss

                valloss_1 += loss_1.item()

            avg_val_loss = (valloss_1 / (v+1))
            outstr = '------- 1st valloss={0:.4f}'\
                .format(avg_val_loss) + '\n'

            logger['validation_1'].append(avg_val_loss)
            scheduler.step(avg_val_loss)
            # The scheduler only steps on validation epochs, so its patience
            # is counted in validation passes, not in training epochs.

            print(outstr)
            record.write(outstr)
            record.flush()

            # Save on a new best validation loss, and otherwise on every
            # CHECKPOINT_EVERY-th epoch so a recent state is always on disk.
            improved = avg_val_loss < min_val
            if improved:
                print(avg_val_loss, "less than", min_val)
                min_val = avg_val_loss
            if improved or e % CHECKPOINT_EVERY == 0:
                save_1(SAVE_NAME, deeplab, optimizer, logger, e, scheduler)

  0%|          | 1/4999 [16:44<1395:10:40, 1004.93s/it]

Epoch 1 finished ! Training Loss: 0.4848



  0%|          | 2/4999 [31:02<1333:42:22, 960.85s/it] 

Epoch 2 finished ! Training Loss: 0.4539

Epoch 3 finished ! Training Loss: 0.4466

------- 1st valloss=0.4800

0.47998546517413593 less than 1


  0%|          | 3/4999 [46:07<1309:55:05, 943.90s/it]

Checkpoint 3 saved !


  0%|          | 4/4999 [1:00:29<1275:29:14, 919.27s/it]

Epoch 4 finished ! Training Loss: 0.4364



  0%|          | 5/4999 [1:14:53<1252:31:01, 902.90s/it]

Epoch 5 finished ! Training Loss: 0.4334

Epoch 6 finished ! Training Loss: 0.4027

------- 1st valloss=0.3473

0.3472966668398484 less than 0.47998546517413593


  0%|          | 6/4999 [1:29:55<1251:40:32, 902.47s/it]

Checkpoint 6 saved !


  0%|          | 7/4999 [1:44:10<1231:49:30, 888.34s/it]

Epoch 7 finished ! Training Loss: 0.3740



  0%|          | 8/4999 [1:58:22<1216:29:01, 877.45s/it]

Epoch 8 finished ! Training Loss: 0.3500

Epoch 9 finished ! Training Loss: 0.3153



  0%|          | 9/4999 [2:13:29<1228:40:11, 886.42s/it]

------- 1st valloss=0.3773



  0%|          | 10/4999 [2:27:42<1214:24:50, 876.31s/it]

Epoch 10 finished ! Training Loss: 0.3009



  0%|          | 11/4999 [2:42:02<1207:22:19, 871.40s/it]

Epoch 11 finished ! Training Loss: 0.2861

Epoch 12 finished ! Training Loss: 0.2736



  0%|          | 12/4999 [2:56:59<1217:49:01, 879.11s/it]

------- 1st valloss=0.9444



  0%|          | 13/4999 [3:11:13<1206:49:48, 871.36s/it]

Epoch 13 finished ! Training Loss: 0.2809



  0%|          | 14/4999 [3:25:38<1204:01:48, 869.51s/it]

Epoch 14 finished ! Training Loss: 0.2701

Epoch 15 finished ! Training Loss: 0.2478

------- 1st valloss=0.6459



  0%|          | 15/4999 [3:40:22<1210:01:23, 874.01s/it]

Checkpoint 15 saved !


  0%|          | 16/4999 [3:54:40<1202:57:39, 869.09s/it]

Epoch 16 finished ! Training Loss: 0.2555



  0%|          | 17/4999 [4:08:58<1198:15:37, 865.86s/it]

Epoch 17 finished ! Training Loss: 0.2464

Epoch 18 finished ! Training Loss: 0.2353

------- 1st valloss=0.1801

0.18014607546122177 less than 0.3472966668398484


  0%|          | 18/4999 [4:23:51<1209:21:17, 874.06s/it]

Checkpoint 18 saved !


  0%|          | 19/4999 [4:38:13<1203:46:21, 870.20s/it]

Epoch 19 finished ! Training Loss: 0.2422



  0%|          | 20/4999 [4:52:22<1195:02:41, 864.06s/it]

Epoch 20 finished ! Training Loss: 0.2386

Epoch 21 finished ! Training Loss: 0.2191



  0%|          | 21/4999 [5:07:19<1208:08:47, 873.71s/it]

------- 1st valloss=0.2379



  0%|          | 22/4999 [5:21:33<1199:53:54, 867.92s/it]

Epoch 22 finished ! Training Loss: 0.1975



  0%|          | 23/4999 [5:35:46<1193:20:33, 863.35s/it]

Epoch 23 finished ! Training Loss: 0.2159

Epoch 24 finished ! Training Loss: 0.2151



  0%|          | 24/4999 [5:50:41<1206:26:16, 873.00s/it]

------- 1st valloss=0.4779



  1%|          | 25/4999 [6:04:54<1197:39:55, 866.83s/it]

Epoch 25 finished ! Training Loss: 0.2047



  1%|          | 26/4999 [6:19:11<1193:41:54, 864.13s/it]

Epoch 26 finished ! Training Loss: 0.2025

Epoch 27 finished ! Training Loss: 0.1929

------- 1st valloss=0.1761

0.17606516765511554 less than 0.18014607546122177


  1%|          | 27/4999 [6:34:11<1208:07:14, 874.75s/it]

Checkpoint 27 saved !


  1%|          | 28/4999 [6:48:24<1198:53:20, 868.24s/it]

Epoch 28 finished ! Training Loss: 0.1949



  1%|          | 29/4999 [7:02:38<1192:45:44, 863.97s/it]

Epoch 29 finished ! Training Loss: 0.1914

Epoch 30 finished ! Training Loss: 0.1943

------- 1st valloss=0.1189

0.11885526031255722 less than 0.17606516765511554


  1%|          | 30/4999 [7:17:35<1206:15:45, 873.93s/it]

Checkpoint 30 saved !


  1%|          | 31/4999 [7:31:50<1198:11:47, 868.26s/it]

Epoch 31 finished ! Training Loss: 0.1905



  1%|          | 32/4999 [7:46:04<1191:58:12, 863.92s/it]

Epoch 32 finished ! Training Loss: 0.1828

Epoch 33 finished ! Training Loss: 0.1827



  1%|          | 33/4999 [8:00:49<1200:24:53, 870.22s/it]

------- 1st valloss=0.3338



  1%|          | 34/4999 [8:15:03<1193:19:57, 865.26s/it]

Epoch 34 finished ! Training Loss: 0.1919



  1%|          | 35/4999 [8:29:15<1187:48:04, 861.42s/it]

Epoch 35 finished ! Training Loss: 0.1736

Epoch 36 finished ! Training Loss: 0.1700



  1%|          | 36/4999 [8:44:03<1198:33:55, 869.40s/it]

------- 1st valloss=0.1488



  1%|          | 37/4999 [8:58:21<1193:25:05, 865.84s/it]

Epoch 37 finished ! Training Loss: 0.1695



  1%|          | 38/4999 [9:12:34<1187:51:05, 861.98s/it]

Epoch 38 finished ! Training Loss: 0.1566

Epoch 39 finished ! Training Loss: 0.1797



  1%|          | 39/4999 [9:27:36<1204:28:07, 874.21s/it]

------- 1st valloss=0.6760



  1%|          | 40/4999 [9:41:43<1192:51:32, 865.96s/it]

Epoch 40 finished ! Training Loss: 0.2100



  1%|          | 41/4999 [9:55:53<1186:02:09, 861.18s/it]

Epoch 41 finished ! Training Loss: 0.1827

Epoch 42 finished ! Training Loss: 0.1835



  1%|          | 42/4999 [10:10:42<1197:16:07, 869.51s/it]

------- 1st valloss=0.1463



  1%|          | 43/4999 [10:24:46<1186:27:18, 861.83s/it]

Epoch 43 finished ! Training Loss: 0.1841



  1%|          | 44/4999 [10:38:57<1181:45:48, 858.60s/it]

Epoch 44 finished ! Training Loss: 0.1614

Epoch 45 finished ! Training Loss: 0.1594

------- 1st valloss=0.1649



  1%|          | 45/4999 [10:54:07<1202:45:32, 874.03s/it]

Checkpoint 45 saved !


  1%|          | 46/4999 [11:08:16<1192:19:46, 866.62s/it]

Epoch 46 finished ! Training Loss: 0.1574



  1%|          | 47/4999 [11:22:26<1185:04:22, 861.52s/it]

Epoch 47 finished ! Training Loss: 0.1602

Epoch 48 finished ! Training Loss: 0.1533



  1%|          | 48/4999 [11:37:27<1201:19:09, 873.51s/it]

------- 1st valloss=0.6669



  1%|          | 49/4999 [11:51:41<1192:48:41, 867.50s/it]

Epoch 49 finished ! Training Loss: 0.1416



  1%|          | 50/4999 [12:06:05<1191:03:09, 866.40s/it]

Epoch 50 finished ! Training Loss: 0.1617

Epoch 51 finished ! Training Loss: 0.1453

------- 1st valloss=0.1083

0.10827822788901952 less than 0.11885526031255722


  1%|          | 51/4999 [12:20:59<1202:13:21, 874.70s/it]

Checkpoint 51 saved !


  1%|          | 52/4999 [12:35:19<1195:58:13, 870.32s/it]

Epoch 52 finished ! Training Loss: 0.1381



  1%|          | 53/4999 [12:49:34<1189:23:15, 865.71s/it]

Epoch 53 finished ! Training Loss: 0.1441

Epoch 54 finished ! Training Loss: 0.1328



  1%|          | 54/4999 [13:04:19<1197:11:42, 871.57s/it]

------- 1st valloss=0.1411



  1%|          | 55/4999 [13:18:38<1191:44:39, 867.77s/it]

Epoch 55 finished ! Training Loss: 0.1324



  1%|          | 56/4999 [13:33:00<1189:18:11, 866.17s/it]

Epoch 56 finished ! Training Loss: 0.1343

Epoch 57 finished ! Training Loss: 0.1399



  1%|          | 57/4999 [13:48:01<1203:10:00, 876.45s/it]

------- 1st valloss=0.4368



  1%|          | 58/4999 [14:02:14<1193:30:33, 869.59s/it]

Epoch 58 finished ! Training Loss: 0.1245



  1%|          | 59/4999 [14:16:40<1191:41:01, 868.43s/it]

Epoch 59 finished ! Training Loss: 0.1290

Epoch 60 finished ! Training Loss: 0.1210

------- 1st valloss=0.1658



  1%|          | 60/4999 [14:31:45<1206:25:20, 879.35s/it]

Checkpoint 60 saved !


  1%|          | 61/4999 [14:46:12<1201:15:56, 875.77s/it]

Epoch 61 finished ! Training Loss: 0.1138



  1%|          | 62/4999 [15:00:34<1195:05:23, 871.44s/it]

Epoch 62 finished ! Training Loss: 0.1282

Epoch 63 finished ! Training Loss: 0.2361



  1%|▏         | 63/4999 [15:15:41<1209:28:28, 882.11s/it]

------- 1st valloss=0.1594



  1%|▏         | 64/4999 [15:29:58<1199:09:24, 874.76s/it]

Epoch 64 finished ! Training Loss: 0.2006



  1%|▏         | 65/4999 [15:44:22<1194:25:30, 871.49s/it]

Epoch 65 finished ! Training Loss: 0.1782

Epoch 66 finished ! Training Loss: 0.1698



  1%|▏         | 66/4999 [15:59:26<1207:25:41, 881.16s/it]

------- 1st valloss=0.3343



  1%|▏         | 67/4999 [16:13:39<1195:28:18, 872.61s/it]

Epoch 67 finished ! Training Loss: 0.1495



  1%|▏         | 68/4999 [16:28:00<1190:42:23, 869.31s/it]

Epoch 68 finished ! Training Loss: 0.1295

Epoch 69 finished ! Training Loss: 0.1368



  1%|▏         | 69/4999 [16:42:53<1199:56:45, 876.23s/it]

------- 1st valloss=0.4143



  1%|▏         | 70/4999 [16:57:17<1195:01:31, 872.81s/it]

Epoch 70 finished ! Training Loss: 0.1283



  1%|▏         | 71/4999 [17:11:32<1187:15:53, 867.32s/it]

Epoch 71 finished ! Training Loss: 0.1225

Epoch 72 finished ! Training Loss: 0.1272



  1%|▏         | 72/4999 [17:26:36<1202:04:03, 878.31s/it]

------- 1st valloss=0.2221



  1%|▏         | 73/4999 [17:40:53<1193:14:55, 872.05s/it]

Epoch 73 finished ! Training Loss: 0.1122



  1%|▏         | 74/4999 [17:55:22<1191:37:14, 871.03s/it]

Epoch 74 finished ! Training Loss: 0.1263

Epoch 75 finished ! Training Loss: 0.1189

------- 1st valloss=0.2422



  2%|▏         | 75/4999 [18:10:22<1203:09:29, 879.64s/it]

Checkpoint 75 saved !


  2%|▏         | 76/4999 [18:24:27<1188:59:32, 869.46s/it]

Epoch 76 finished ! Training Loss: 0.1144



  2%|▏         | 77/4999 [18:38:39<1181:21:14, 864.05s/it]

Epoch 77 finished ! Training Loss: 0.0987

Epoch 78 finished ! Training Loss: 0.0970



  2%|▏         | 78/4999 [18:53:55<1202:21:10, 879.59s/it]

------- 1st valloss=0.1109



  2%|▏         | 79/4999 [19:08:02<1188:57:31, 869.97s/it]

Epoch 79 finished ! Training Loss: 0.0939



  2%|▏         | 80/4999 [19:22:26<1186:06:48, 868.06s/it]

Epoch 80 finished ! Training Loss: 0.0918

Epoch 81 finished ! Training Loss: 0.0922

------- 1st valloss=0.0928

0.09282480245051176 less than 0.10827822788901952


  2%|▏         | 81/4999 [19:37:36<1203:02:18, 880.63s/it]

Checkpoint 81 saved !


  2%|▏         | 82/4999 [19:51:45<1189:53:08, 871.18s/it]

Epoch 82 finished ! Training Loss: 0.0937



  2%|▏         | 83/4999 [20:05:58<1182:07:19, 865.67s/it]

Epoch 83 finished ! Training Loss: 0.0923

Epoch 84 finished ! Training Loss: 0.0917

------- 1st valloss=0.0910

0.09103151864331702 less than 0.09282480245051176


  2%|▏         | 84/4999 [20:20:56<1195:09:36, 875.40s/it]

Checkpoint 84 saved !


  2%|▏         | 85/4999 [20:35:19<1189:59:29, 871.79s/it]

Epoch 85 finished ! Training Loss: 0.0908



  2%|▏         | 86/4999 [20:49:43<1186:23:13, 869.32s/it]

Epoch 86 finished ! Training Loss: 0.0902

Epoch 87 finished ! Training Loss: 0.0895



  2%|▏         | 87/4999 [21:04:31<1193:50:19, 874.96s/it]

------- 1st valloss=0.0946



  2%|▏         | 88/4999 [21:18:55<1189:00:49, 871.60s/it]

Epoch 88 finished ! Training Loss: 0.0890



  2%|▏         | 89/4999 [21:33:14<1183:53:57, 868.03s/it]

Epoch 89 finished ! Training Loss: 0.0916

Epoch 90 finished ! Training Loss: 0.0872

------- 1st valloss=0.1051



  2%|▏         | 90/4999 [21:48:12<1195:38:43, 876.82s/it]

Checkpoint 90 saved !


  2%|▏         | 91/4999 [22:02:35<1190:01:39, 872.88s/it]

Epoch 91 finished ! Training Loss: 0.0895



  2%|▏         | 92/4999 [22:16:50<1182:26:37, 867.49s/it]

Epoch 92 finished ! Training Loss: 0.0865

Epoch 93 finished ! Training Loss: 0.0888

------- 1st valloss=0.0834

0.08344225326310033 less than 0.09103151864331702


  2%|▏         | 93/4999 [22:31:47<1194:19:06, 876.39s/it]

Checkpoint 93 saved !


  2%|▏         | 94/4999 [22:46:00<1184:21:53, 869.26s/it]

Epoch 94 finished ! Training Loss: 0.0874



  2%|▏         | 95/4999 [23:00:25<1182:31:57, 868.09s/it]

Epoch 95 finished ! Training Loss: 0.0851

Epoch 96 finished ! Training Loss: 0.0869



  2%|▏         | 96/4999 [23:15:31<1197:37:33, 879.35s/it]

------- 1st valloss=0.0934



  2%|▏         | 97/4999 [23:29:44<1186:26:15, 871.31s/it]

Epoch 97 finished ! Training Loss: 0.0864



  2%|▏         | 98/4999 [23:44:09<1183:49:29, 869.57s/it]

Epoch 98 finished ! Training Loss: 0.0842

Epoch 99 finished ! Training Loss: 0.0851



  2%|▏         | 99/4999 [23:59:06<1194:50:22, 877.84s/it]

------- 1st valloss=0.0925



  2%|▏         | 100/4999 [24:13:26<1187:11:39, 872.40s/it]

Epoch 100 finished ! Training Loss: 0.0868



  2%|▏         | 101/4999 [24:27:28<1174:41:08, 863.39s/it]

Epoch 101 finished ! Training Loss: 0.0833

Epoch 102 finished ! Training Loss: 0.0861



  2%|▏         | 102/4999 [24:42:19<1185:36:35, 871.59s/it]

------- 1st valloss=0.0961



  2%|▏         | 103/4999 [24:56:34<1178:31:58, 866.57s/it]

Epoch 103 finished ! Training Loss: 0.0857



  2%|▏         | 104/4999 [25:10:51<1174:19:51, 863.66s/it]

Epoch 104 finished ! Training Loss: 0.0811

Epoch 105 finished ! Training Loss: 0.0817

------- 1st valloss=0.1153



  2%|▏         | 105/4999 [25:25:48<1187:54:35, 873.82s/it]

Checkpoint 105 saved !


  2%|▏         | 106/4999 [25:40:04<1180:11:50, 868.32s/it]

Epoch 106 finished ! Training Loss: 0.0834



  2%|▏         | 107/4999 [25:54:17<1173:38:48, 863.68s/it]

Epoch 107 finished ! Training Loss: 0.0817

Epoch 108 finished ! Training Loss: 0.0802



  2%|▏         | 108/4999 [26:09:11<1186:05:56, 873.02s/it]

------- 1st valloss=0.0900



  2%|▏         | 109/4999 [26:23:35<1181:59:26, 870.18s/it]

Epoch 109 finished ! Training Loss: 0.0803



In [None]:
# Final evaluation: run the model over the validation set, turn each output
# into a hard one-hot prediction and report the three values returned by
# dice_loss_3_debug (printed below as background / body / bv loss).
deeplab.eval()

with torch.no_grad():

    bgloss = 0
    bdloss = 0
    bvloss = 0

    # Wrap the loader itself (not the enumerate object) so tqdm can read its
    # length and show a total / ETA.
    for v, vbatch in enumerate(tqdm(validation_loader)):

        # move data to device, convert dtype to desirable dtype
        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        output = deeplab(image_1)
        # do the inference

        # One-hot arg-max over dim 1: 1 wherever a value equals the maximum
        # along dim 1 (ties mark every maximal entry), 0 elsewhere. Done
        # directly on the tensor, avoiding a device -> NumPy -> device trip.
        channel_max = output.max(dim=1, keepdim=True)[0]
        out_1 = (output == channel_max).to(device=device, dtype=dtype)

        loss_1 = dice_loss_3(out_1, label_1)

        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        # per-component losses for this batch, converted to Python floats once
        bg_value, bd_value, bv_value = bg.item(), bd.item(), bv.item()

        print(bg_value, bd_value, bv_value, loss_1.item())
        bgloss += bg_value
        bdloss += bd_value
        bvloss += bv_value

        # Visualise poorly scoring batches. Note that the raw network output
        # is shown here, not the one-hot prediction out_1.
        # NOTE(review): the 0.2 / 0.1 thresholds look hand-picked -- confirm.
        if bv_value >= 0.2 or bd_value >= 0.1:
            show_image_slice(image_1)
            show_image_slice(label_1)
            show_image_slice(output)

    # averages over the number of validation batches (v is zero-based)
    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/(v+1), bdloss/(v+1), bvloss/(v+1)) + '\n'
    print(outstr)