In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Run-wide knobs: prefer the GPU when one exists, loader worker count, batch size.
USE_GPU, NUM_WORKERS, BATCH_SIZE = True, 12, 2

# Every tensor and the model are cast to float32 (half the memory of float64).
dtype = torch.float32

# Pick the compute device once; everything below branches on the result.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    # Enable cuDNN and let it benchmark convolution algorithms for the
    # input sizes it encounters.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

print('using GPU for training' if device.type == 'cuda' else 'using CPU for training')

using GPU for training


In [3]:
# Training split ('nii_train') with on-the-fly data augmentation.
# The three augmentation steps run in order for every sample:
#   1. random_affine(90, 15)  - presumably max rotation / translation; confirm in data_utils
#   2. random_filp(0.5)       - presumably a flip with probability 0.5; confirm in data_utils
#   3. ElasticTransformation  - wrapped in RandomApply, so it is only applied to a
#                               random subset of samples (RandomApply's default p)
train_dataset = pyramid_dataset(data_type = 'nii_train', 
                transform=transforms.Compose([
                random_affine(90, 15),
                random_filp(0.5), 
                transforms.RandomApply([ElasticTransformation(256*2, 256*0.08)])
                ]))

# Validation split ('nii_test'): no augmentation, samples are used as stored.
validation_dataset = pyramid_dataset(data_type = 'nii_test', 
                transform=None)

# DataLoader handles batching and loads samples in NUM_WORKERS worker subprocesses.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS)

# The validation loader must NOT shuffle: the validation loss is computed per
# mini-batch and then averaged over batches, so a random batch composition makes
# the reported metric vary between evaluations of the very same weights.
# A fixed order keeps epoch-to-epoch validation numbers comparable.
# drop_last is left at its default, so a smaller final batch is still evaluated.
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS)

In [None]:
# Build the network in one expression:
#   DeepLab(output_stride=16) -> nn.DataParallel wrapper -> sync_batchnorm.convert_model
# (convert_model presumably swaps in synchronized batch-norm layers - confirm in
# sync_batchnorm). Modules are float32 by default; .to() makes device and dtype explicit.
deeplab = convert_model(nn.DataParallel(DeepLab(output_stride=16)))
deeplab = deeplab.to(device=device, dtype=dtype)

# Adam optimizes every parameter of the wrapped model.
optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)

# Plateau scheduler: would cut the learning rate by 10x when the monitored value
# stops decreasing. It only takes effect if scheduler.step(<metric>) is called.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1)

# Epoch counter for a fresh run; the training loop starts at epoch + 1.
epoch = 0

In [None]:
# DISABLED CELL: everything below is a single bare triple-quoted string, so none
# of it executes; running the cell merely echoes the string back as its output.
#
# The text is kept as a recipe for resuming training from a saved checkpoint:
# rebuild the model and the Adam optimizer, torch.load() a .pth file, then restore
# the 'state_dict_1', 'optimizer' and 'scheduler' states and the 'epoch' counter.
#
# NOTE(review): before re-enabling, reconcile it with the active model cell - the
# recipe builds DeepLab(output_stride=8) whereas the live model uses
# output_stride=16, and the recipe never moves the model to `device`.
"""
deeplab = DeepLab(output_stride=8)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)

#checkpoint = torch.load('../deeplab_save/2019-07-29 04:00:14.630172.pth') # second best
#checkpoint = torch.load('../deeplab_save/2019-07-28 23:47:36.279119.pth') # second best
#checkpoint = torch.load('../deeplab_save/2019-07-29 00:15:49.271222.pth') # best
#checkpoint = torch.load('../deeplab_save/2019-07-29 00:44:11.825872.pth')
checkpoint = torch.load('../deeplab_save/2019-07-31 20:34:01.096131.pth') # latest one

deeplab.load_state_dict(checkpoint['state_dict_1'])
optimizer.load_state_dict(checkpoint['optimizer'])
#scheduler.load_state_dict(checkpoint['scheduler'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
scheduler.load_state_dict(checkpoint['scheduler'])
epoch = checkpoint['epoch']
print(epoch)
"""

"\ndeeplab = DeepLab(output_stride=8)\ndeeplab = nn.DataParallel(deeplab)\ndeeplab = convert_model(deeplab)\n\noptimizer = optim.Adam(deeplab.parameters(), lr=1e-2)\n\n#checkpoint = torch.load('../deeplab_save/2019-07-29 04:00:14.630172.pth') # second best\n#checkpoint = torch.load('../deeplab_save/2019-07-28 23:47:36.279119.pth') # second best\n#checkpoint = torch.load('../deeplab_save/2019-07-29 00:15:49.271222.pth') # best\n#checkpoint = torch.load('../deeplab_save/2019-07-29 00:44:11.825872.pth')\ncheckpoint = torch.load('../deeplab_save/2019-07-31 20:34:01.096131.pth') # latest one\n\ndeeplab.load_state_dict(checkpoint['state_dict_1'])\noptimizer.load_state_dict(checkpoint['optimizer'])\n#scheduler.load_state_dict(checkpoint['scheduler'])\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)\nscheduler.load_state_dict(checkpoint['scheduler'])\nepoch = checkpoint['epoch']\nprint(epoch)\n"

In [None]:
# Exclusive upper bound of the epoch range; training runs epochs epoch+1 .. epochs-1.
epochs = 5000

# Lowest average validation loss seen so far (initialised to 1).
min_val = 1

# Per-epoch training loss and per-validation-pass loss, handed to save_1 at each checkpoint.
logger = {'train':[], 'validation_1': []}

# The text log is opened in a context manager so the handle is flushed and closed
# even when this very long run raises or is interrupted from the keyboard.
with open('train_deeplab_output_16_elastic.txt','w+') as record:

    for e in tqdm(range(epoch + 1, epochs)):
        # ---------------------------- training pass ----------------------------

        deeplab.train()
        # Training mode (dropout active, batch-norm uses batch statistics).
        # Set once per epoch: the validation pass below switches to eval mode,
        # and nothing inside the batch loop changes the mode again.

        epoch_loss = 0

        for t, batch in enumerate(train_loader):

            # Move the mini-batch to the compute device and cast it to the working dtype.
            image_1 = batch['image1_data'].to(device=device, dtype=dtype)
            label_1 = batch['image1_label'].to(device=device, dtype=dtype)

            # Forward pass and loss.
            out_1 = deeplab(image_1)
            loss_1 = dice_loss_3(out_1, label_1)

            # Accumulate the scalar loss for the epoch average.
            epoch_loss += loss_1.item()

            # Standard update: clear old gradients, back-propagate, take one step.
            optimizer.zero_grad()
            loss_1.backward()
            optimizer.step()

        # Mean mini-batch loss of this epoch (t is the index of the last batch).
        outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, epoch_loss/(t+1)) + '\n'

        logger['train'].append(epoch_loss/(t+1))

        print(outstr)
        record.write(outstr)
        record.flush()

        # --------------------- validation pass, every 5 epochs ---------------------
        if e%5 == 0:

            deeplab.eval()
            # Eval mode: dropout disabled, batch-norm uses its running statistics.

            with torch.no_grad():
                # No autograd bookkeeping is needed while validating.

                valloss_1 = 0

                for v, vbatch in enumerate(validation_loader):

                    # Move data to the device; if get_dimensions reports a 4-D
                    # tensor, prepend one dimension in place (applied to both the
                    # image and the label).
                    image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                    if get_dimensions(image_1_val) == 4:
                        image_1_val.unsqueeze_(0)
                    label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                    if get_dimensions(label_1_val) == 4:
                        label_1_val.unsqueeze_(0)

                    out_1_val = deeplab(image_1_val)

                    loss_1 = dice_loss_3(out_1_val, label_1_val)

                    valloss_1 += loss_1.item()

                # Mean mini-batch validation loss.
                avg_val_loss = (valloss_1 / (v+1))
                outstr = '------- 1st valloss={0:.4f}'\
                    .format(avg_val_loss) + '\n'

                logger['validation_1'].append(avg_val_loss)
                # NOTE(review): the plateau scheduler is never stepped, so the
                # learning rate stays at its initial value for the whole run.
                # Re-enable with: scheduler.step(avg_val_loss)

                print(outstr)
                record.write(outstr)
                record.flush()

                # Track (and report) the best validation loss so far.
                if avg_val_loss < min_val:
                    print(avg_val_loss, "less than", min_val)
                    min_val = avg_val_loss

                # A checkpoint is written after every validation pass, whether
                # or not the loss improved.
                save_1('deeplab_output_16_elastic_save', deeplab, optimizer, logger, e, scheduler)

  0%|          | 1/4999 [12:06<1009:18:43, 727.00s/it]

Epoch 1 finished ! Training Loss: 0.5148



  0%|          | 2/4999 [23:05<980:26:15, 706.34s/it] 

Epoch 2 finished ! Training Loss: 0.4794



  0%|          | 3/4999 [34:15<965:16:12, 695.55s/it]

Epoch 3 finished ! Training Loss: 0.4703



  0%|          | 4/4999 [45:17<951:08:01, 685.50s/it]

Epoch 4 finished ! Training Loss: 0.4612

Epoch 5 finished ! Training Loss: 0.4468

------- 1st valloss=0.3675

0.3675405616345613 less than 1


  0%|          | 5/4999 [57:02<959:08:52, 691.42s/it]

Checkpoint 5 saved !


  0%|          | 6/4999 [1:08:00<945:00:09, 681.36s/it]

Epoch 6 finished ! Training Loss: 0.4336



  0%|          | 7/4999 [1:19:11<940:15:31, 678.07s/it]

Epoch 7 finished ! Training Loss: 0.4203



  0%|          | 8/4999 [1:30:00<928:05:17, 669.43s/it]

Epoch 8 finished ! Training Loss: 0.4168



  0%|          | 9/4999 [1:41:01<924:39:13, 667.08s/it]

Epoch 9 finished ! Training Loss: 0.4087

Epoch 10 finished ! Training Loss: 0.3987

------- 1st valloss=0.3311

0.33110871133597003 less than 0.3675405616345613


  0%|          | 10/4999 [1:52:35<935:20:51, 674.94s/it]

Checkpoint 10 saved !


  0%|          | 11/4999 [2:03:42<932:08:54, 672.76s/it]

Epoch 11 finished ! Training Loss: 0.3740



  0%|          | 12/4999 [2:14:52<930:35:11, 671.77s/it]

Epoch 12 finished ! Training Loss: 0.3815



  0%|          | 13/4999 [2:26:01<929:16:33, 670.96s/it]

Epoch 13 finished ! Training Loss: 0.3823



  0%|          | 14/4999 [2:37:02<924:49:44, 667.88s/it]

Epoch 14 finished ! Training Loss: 0.3582

Epoch 15 finished ! Training Loss: 0.3756

------- 1st valloss=0.3374



  0%|          | 15/4999 [2:48:43<938:32:17, 677.92s/it]

Checkpoint 15 saved !


  0%|          | 16/4999 [2:59:54<935:38:37, 675.96s/it]

Epoch 16 finished ! Training Loss: 0.3676



  0%|          | 17/4999 [3:10:51<927:22:53, 670.13s/it]

Epoch 17 finished ! Training Loss: 0.3556



  0%|          | 18/4999 [3:22:00<926:59:16, 669.98s/it]

Epoch 18 finished ! Training Loss: 0.3469



  0%|          | 19/4999 [3:33:00<922:18:09, 666.72s/it]

Epoch 19 finished ! Training Loss: 0.3249

Epoch 20 finished ! Training Loss: 0.3193

------- 1st valloss=0.2238

0.22379430545412976 less than 0.33110871133597003


  0%|          | 20/4999 [3:44:55<942:26:07, 681.42s/it]

Checkpoint 20 saved !


  0%|          | 21/4999 [3:55:56<933:26:51, 675.05s/it]

Epoch 21 finished ! Training Loss: 0.3247



  0%|          | 22/4999 [4:07:00<928:55:12, 671.91s/it]

Epoch 22 finished ! Training Loss: 0.3085



  0%|          | 23/4999 [4:18:05<925:50:20, 669.82s/it]

Epoch 23 finished ! Training Loss: 0.3176



  0%|          | 24/4999 [4:29:12<924:34:51, 669.04s/it]

Epoch 24 finished ! Training Loss: 0.3255

Epoch 25 finished ! Training Loss: 0.2984

------- 1st valloss=0.1761

0.17605876047974048 less than 0.22379430545412976


  1%|          | 25/4999 [4:40:59<939:54:13, 680.27s/it]

Checkpoint 25 saved !


  1%|          | 26/4999 [4:52:03<933:01:26, 675.42s/it]

Epoch 26 finished ! Training Loss: 0.3095



  1%|          | 27/4999 [5:02:51<921:33:27, 667.26s/it]

Epoch 27 finished ! Training Loss: 0.3050



  1%|          | 28/4999 [5:14:01<922:38:29, 668.18s/it]

Epoch 28 finished ! Training Loss: 0.3088



  1%|          | 29/4999 [5:24:58<917:51:40, 664.85s/it]

Epoch 29 finished ! Training Loss: 0.2987

Epoch 30 finished ! Training Loss: 0.2865

------- 1st valloss=0.1543

0.15434382823498352 less than 0.17605876047974048


  1%|          | 30/4999 [5:36:31<929:06:15, 673.13s/it]

Checkpoint 30 saved !


  1%|          | 31/4999 [5:47:37<925:52:27, 670.92s/it]

Epoch 31 finished ! Training Loss: 0.3015



  1%|          | 32/4999 [5:58:33<919:44:16, 666.61s/it]

Epoch 32 finished ! Training Loss: 0.2819



  1%|          | 33/4999 [6:09:43<920:55:45, 667.61s/it]

Epoch 33 finished ! Training Loss: 0.2808



  1%|          | 34/4999 [6:20:52<921:26:41, 668.12s/it]

Epoch 34 finished ! Training Loss: 0.2914

Epoch 35 finished ! Training Loss: 0.2772

------- 1st valloss=0.1476

0.147622001559838 less than 0.15434382823498352


  1%|          | 35/4999 [6:32:31<933:38:48, 677.10s/it]

Checkpoint 35 saved !


  1%|          | 36/4999 [6:43:29<925:56:30, 671.65s/it]

Epoch 36 finished ! Training Loss: 0.2829



  1%|          | 37/4999 [6:54:24<918:33:24, 666.43s/it]

Epoch 37 finished ! Training Loss: 0.2818



  1%|          | 38/4999 [7:05:14<911:39:28, 661.55s/it]

Epoch 38 finished ! Training Loss: 0.2965



  1%|          | 39/4999 [7:16:21<913:54:10, 663.32s/it]

Epoch 39 finished ! Training Loss: 0.2743

Epoch 40 finished ! Training Loss: 0.2486

------- 1st valloss=0.1289

0.1288576712426932 less than 0.147622001559838


  1%|          | 40/4999 [7:27:53<925:22:12, 671.78s/it]

Checkpoint 40 saved !


  1%|          | 41/4999 [7:38:49<918:36:33, 667.00s/it]

Epoch 41 finished ! Training Loss: 0.2653



  1%|          | 42/4999 [7:50:02<921:00:17, 668.88s/it]

Epoch 42 finished ! Training Loss: 0.2748



  1%|          | 43/4999 [8:00:58<915:25:19, 664.96s/it]

Epoch 43 finished ! Training Loss: 0.2722



  1%|          | 44/4999 [8:11:49<909:40:40, 660.92s/it]

Epoch 44 finished ! Training Loss: 0.2586

Epoch 45 finished ! Training Loss: 0.2828

------- 1st valloss=0.1686



  1%|          | 45/4999 [8:23:29<925:37:08, 672.63s/it]

Checkpoint 45 saved !


  1%|          | 46/4999 [8:34:23<917:30:03, 666.87s/it]

Epoch 46 finished ! Training Loss: 0.2603



  1%|          | 47/4999 [8:45:26<916:01:26, 665.93s/it]

Epoch 47 finished ! Training Loss: 0.2684



  1%|          | 48/4999 [8:56:20<910:34:49, 662.11s/it]

Epoch 48 finished ! Training Loss: 0.2670



  1%|          | 49/4999 [9:07:28<912:48:06, 663.86s/it]

Epoch 49 finished ! Training Loss: 0.2560

Epoch 50 finished ! Training Loss: 0.2626

------- 1st valloss=0.4348



  1%|          | 50/4999 [9:19:10<928:22:17, 675.32s/it]

Checkpoint 50 saved !


  1%|          | 51/4999 [9:30:17<924:43:56, 672.80s/it]

Epoch 51 finished ! Training Loss: 0.2654



  1%|          | 52/4999 [9:41:23<922:07:52, 671.05s/it]

Epoch 52 finished ! Training Loss: 0.2670



  1%|          | 53/4999 [9:52:29<919:41:18, 669.41s/it]

Epoch 53 finished ! Training Loss: 0.2688



  1%|          | 54/4999 [10:03:32<916:49:33, 667.46s/it]

Epoch 54 finished ! Training Loss: 0.2582

Epoch 55 finished ! Training Loss: 0.2753

------- 1st valloss=0.1845



  1%|          | 55/4999 [10:15:05<927:01:11, 675.01s/it]

Checkpoint 55 saved !


  1%|          | 56/4999 [10:25:59<918:28:05, 668.92s/it]

Epoch 56 finished ! Training Loss: 0.2571



  1%|          | 57/4999 [10:36:56<913:21:43, 665.34s/it]

Epoch 57 finished ! Training Loss: 0.2667



  1%|          | 58/4999 [10:48:11<917:01:52, 668.15s/it]

Epoch 58 finished ! Training Loss: 0.2543



  1%|          | 59/4999 [10:59:21<917:42:47, 668.78s/it]

Epoch 59 finished ! Training Loss: 0.2443

Epoch 60 finished ! Training Loss: 0.2547

------- 1st valloss=0.4041



  1%|          | 60/4999 [11:11:19<937:43:37, 683.50s/it]

Checkpoint 60 saved !


  1%|          | 61/4999 [11:22:26<930:42:16, 678.52s/it]

Epoch 61 finished ! Training Loss: 0.2632



  1%|          | 62/4999 [11:33:23<921:52:20, 672.22s/it]

Epoch 62 finished ! Training Loss: 0.2282



  1%|▏         | 63/4999 [11:44:20<915:02:12, 667.37s/it]

Epoch 63 finished ! Training Loss: 0.2515



  1%|▏         | 64/4999 [11:55:25<913:58:26, 666.73s/it]

Epoch 64 finished ! Training Loss: 0.2598

Epoch 65 finished ! Training Loss: 0.2482

------- 1st valloss=0.1793



  1%|▏         | 65/4999 [12:07:11<929:51:30, 678.45s/it]

Checkpoint 65 saved !


  1%|▏         | 66/4999 [12:18:02<918:42:18, 670.45s/it]

Epoch 66 finished ! Training Loss: 0.2279



  1%|▏         | 67/4999 [12:28:47<907:46:00, 662.60s/it]

Epoch 67 finished ! Training Loss: 0.2602



  1%|▏         | 68/4999 [12:39:51<908:19:20, 663.14s/it]

Epoch 68 finished ! Training Loss: 0.2445



  1%|▏         | 69/4999 [12:50:53<907:34:53, 662.74s/it]

Epoch 69 finished ! Training Loss: 0.2544

Epoch 70 finished ! Training Loss: 0.2360

------- 1st valloss=0.1601



  1%|▏         | 70/4999 [13:02:34<923:00:30, 674.14s/it]

Checkpoint 70 saved !


  1%|▏         | 71/4999 [13:13:36<918:09:49, 670.74s/it]

Epoch 71 finished ! Training Loss: 0.2523



  1%|▏         | 72/4999 [13:24:35<912:55:40, 667.05s/it]

Epoch 72 finished ! Training Loss: 0.2410



  1%|▏         | 73/4999 [13:35:41<912:30:44, 666.88s/it]

Epoch 73 finished ! Training Loss: 0.2572



  1%|▏         | 74/4999 [13:46:47<911:44:43, 666.45s/it]

Epoch 74 finished ! Training Loss: 0.2618

Epoch 75 finished ! Training Loss: 0.2360

------- 1st valloss=0.1776



  2%|▏         | 75/4999 [13:58:24<924:04:46, 675.61s/it]

Checkpoint 75 saved !


  2%|▏         | 76/4999 [14:09:20<915:47:47, 669.69s/it]

Epoch 76 finished ! Training Loss: 0.2410



  2%|▏         | 77/4999 [14:20:25<914:01:16, 668.52s/it]

Epoch 77 finished ! Training Loss: 0.2354



  2%|▏         | 78/4999 [14:31:17<906:47:05, 663.37s/it]

Epoch 78 finished ! Training Loss: 0.2499



  2%|▏         | 79/4999 [14:42:12<903:15:55, 660.93s/it]

Epoch 79 finished ! Training Loss: 0.2421

Epoch 80 finished ! Training Loss: 0.2439

------- 1st valloss=0.1152

0.11522486805915833 less than 0.1288576712426932


  2%|▏         | 80/4999 [14:53:58<921:44:06, 674.58s/it]

Checkpoint 80 saved !


  2%|▏         | 81/4999 [15:04:49<911:50:23, 667.47s/it]

Epoch 81 finished ! Training Loss: 0.2266



  2%|▏         | 82/4999 [15:15:38<904:06:24, 661.95s/it]

Epoch 82 finished ! Training Loss: 0.2408



  2%|▏         | 83/4999 [15:26:43<904:54:01, 662.66s/it]

Epoch 83 finished ! Training Loss: 0.2367



  2%|▏         | 84/4999 [15:37:41<902:48:15, 661.26s/it]

Epoch 84 finished ! Training Loss: 0.2215

Epoch 85 finished ! Training Loss: 0.2451

------- 1st valloss=0.1321



  2%|▏         | 85/4999 [15:49:37<925:12:31, 677.81s/it]

Checkpoint 85 saved !


  2%|▏         | 86/4999 [16:00:38<918:13:41, 672.83s/it]

Epoch 86 finished ! Training Loss: 0.2304



  2%|▏         | 87/4999 [16:11:44<914:59:44, 670.60s/it]

Epoch 87 finished ! Training Loss: 0.2200



  2%|▏         | 88/4999 [16:22:51<913:37:15, 669.73s/it]

Epoch 88 finished ! Training Loss: 0.2499



  2%|▏         | 89/4999 [16:33:44<906:30:01, 664.64s/it]

Epoch 89 finished ! Training Loss: 0.2338

Epoch 90 finished ! Training Loss: 0.2338

------- 1st valloss=0.2335



  2%|▏         | 90/4999 [16:50:16<1040:20:33, 762.93s/it]

Checkpoint 90 saved !


  2%|▏         | 91/4999 [17:01:02<992:13:59, 727.80s/it] 

Epoch 91 finished ! Training Loss: 0.2306



  2%|▏         | 92/4999 [17:12:09<967:04:44, 709.49s/it]

Epoch 92 finished ! Training Loss: 0.2375



  2%|▏         | 93/4999 [17:23:07<945:51:59, 694.07s/it]

Epoch 93 finished ! Training Loss: 0.2200



  2%|▏         | 94/4999 [17:34:13<933:59:52, 685.50s/it]

Epoch 94 finished ! Training Loss: 0.2510

Epoch 95 finished ! Training Loss: 0.2309

------- 1st valloss=0.2550



  2%|▏         | 95/4999 [17:46:00<942:55:47, 692.20s/it]

Checkpoint 95 saved !


  2%|▏         | 96/4999 [17:56:58<928:33:45, 681.79s/it]

Epoch 96 finished ! Training Loss: 0.2218



  2%|▏         | 97/4999 [18:08:10<924:29:50, 678.95s/it]

Epoch 97 finished ! Training Loss: 0.2336



  2%|▏         | 98/4999 [18:19:16<919:03:17, 675.09s/it]

Epoch 98 finished ! Training Loss: 0.2298



  2%|▏         | 99/4999 [18:30:21<914:34:03, 671.93s/it]

Epoch 99 finished ! Training Loss: 0.2381

Epoch 100 finished ! Training Loss: 0.2356

------- 1st valloss=0.2644



  2%|▏         | 100/4999 [18:42:09<928:58:55, 682.66s/it]

Checkpoint 100 saved !


  2%|▏         | 101/4999 [18:52:58<915:17:10, 672.73s/it]

Epoch 101 finished ! Training Loss: 0.2350



  2%|▏         | 102/4999 [19:04:00<910:31:07, 669.36s/it]

Epoch 102 finished ! Training Loss: 0.2184



  2%|▏         | 103/4999 [19:14:57<905:33:53, 665.86s/it]

Epoch 103 finished ! Training Loss: 0.2175



  2%|▏         | 104/4999 [19:25:46<898:16:14, 660.63s/it]

Epoch 104 finished ! Training Loss: 0.2319

Epoch 105 finished ! Training Loss: 0.2159

------- 1st valloss=0.5999



  2%|▏         | 105/4999 [19:37:45<921:54:32, 678.15s/it]

Checkpoint 105 saved !


  2%|▏         | 106/4999 [19:48:55<918:35:07, 675.84s/it]

Epoch 106 finished ! Training Loss: 0.2357



  2%|▏         | 107/4999 [19:59:59<913:33:42, 672.29s/it]

Epoch 107 finished ! Training Loss: 0.2306



  2%|▏         | 108/4999 [20:10:52<905:26:53, 666.45s/it]

Epoch 108 finished ! Training Loss: 0.2141



  2%|▏         | 109/4999 [20:21:51<902:00:36, 664.06s/it]

Epoch 109 finished ! Training Loss: 0.2196



In [None]:
# Final evaluation of the current weights on the validation loader, reporting the
# three components returned by dice_loss_3_debug (labelled background / body / bv
# in the summary line below).
deeplab.eval()

with torch.no_grad():

    # Running sums of the per-batch loss components.
    bgloss = 0
    bdloss = 0
    bvloss = 0

    # tqdm wraps the loader itself rather than enumerate(...): an enumerate object
    # has no length, so wrapping it would hide the total and the ETA from the bar.
    for v, vbatch in enumerate(tqdm(validation_loader)):

        # Move data to the device and cast it to the working dtype.
        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        # Inference.
        output = deeplab(image_1)

        # Harden the prediction: along dim 1, entries equal to the maximum become 1
        # and all others 0 (ties all become 1). This is done on the device, which
        # avoids a GPU -> CPU numpy -> GPU round trip for every batch.
        out_1 = (output == output.max(dim=1, keepdim=True)[0]).to(device=device, dtype=dtype)
        loss_1 = dice_loss_3(out_1, label_1)

        # Per-component losses for this batch.
        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()

    # Mean of each component over all validation batches.
    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/(v+1), bdloss/(v+1), bvloss/(v+1)) + '\n'
    print(outstr)