In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from model import *
from loss import *
from train import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 2 

dtype = torch.float32 
# define dtype, float is space efficient than double

if USE_GPU and torch.cuda.is_available():
    
    device = torch.device('cuda')
    
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    # magic flag that accelerate
    
    print('using GPU for training')
else:
    device = torch.device('cpu')
    print('using CPU for training')

using GPU for training


In [3]:
train_dataset = pyramid_dataset(data_type = 'nii_train', 
                transform=transforms.Compose([
                random_affine(90, 15),
                random_filp(0.5)]))
# do data augumentation on train dataset

validation_dataset = pyramid_dataset(data_type = 'nii_test', 
                transform=None)
# no data augumentation on validation dataset

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS) # drop_last
# loaders come with auto batch division and multi-thread acceleration

In [4]:
def shape_test(model, cuda_bool):
    x = torch.zeros((1, 1, 256, 256, 256))
    x = x.to(device=device, dtype=dtype) if cuda_bool else x
    scores = model(x)
    for i in scores:
        print(i.size())

In [5]:
#from model import *

#icnet1 = ModifiedICNet(num_classes=3)
#icnet1.apply(init_weights)
#icnet1 = icnet1.to(device=device, dtype=dtype)
#shape_test(icnet1, True)
# create the model, by default model type is float, use model.double(), model.float() to convert
# move the model to desirable device

#optimizer1 = optim.Adam(icnet1.parameters(), lr=1e-2)
# create an optimizer object
# note that only the model_2 params and model_4 params will be optimized by optimizer

In [6]:
icnet1 = ModifiedICNet(num_classes=3)

checkpoint = torch.load('../half_res_save/2019-07-29 00:26:10.775920.pth')

icnet1.load_state_dict(checkpoint['state_dict_1'])
#icnet1 = nn.DataParallel(icnet1)
#icnet1 = convert_model(icnet1)
icnet1 = icnet1.to(device=device, dtype=dtype)

optimizer1 = optim.Adam(icnet1.parameters())
optimizer1.load_state_dict(checkpoint['optimizer'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer1)

epoch = checkpoint['epoch']
print(epoch)

209


In [6]:
epochs = 5000

record= open('train_half_res_no_sync.txt','a+')

logger = {'train':[], 'validation_1': []}

for e in tqdm(range(epoch, epochs)):
# iter over epoches

    epoch_loss = 0
        
    for t, batch in enumerate(train_loader):
    # iter over the train mini batches
    
        icnet1.train()
        # Set the model flag to train
        # 1. enable dropout
        # 2. batchnorm behave differently in train and test
        
        image_4 = batch['image4_data'].to(device=device, dtype=dtype)
        label_4 = batch['image4_label'].to(device=device, dtype=dtype)
        
        image_2 = batch['image2_data'].to(device=device, dtype=dtype)
        label_2 = batch['image2_label'].to(device=device, dtype=dtype)
        
        image_1 = batch['image1_data'].to(device=device, dtype=dtype)
        label_1 = batch['image1_label'].to(device=device, dtype=dtype)
        # move data to device, convert dtype to desirable dtype
        
        # Downsample labels to coincide with icnet model outputs
        label_1_resize_2 = downsample_label(label_1, 1/2)
        label_2_resize_2 = downsample_label(label_2, 1/2)
        label_4_resize_2 = downsample_label(label_4, 1/2)
        
        out_1, out_2, out_4 = icnet1(image_1)
        # do the inference

        loss_4 = dice_loss_3(out_4, label_4_resize_2)
        loss_2 = dice_loss_3(out_2, label_2_resize_2)
        loss_1 = dice_loss_3(out_1, label_1_resize_2)
        # calculate loss

        loss = loss_4 + loss_2 + loss_1 
        # add loss
        
        epoch_loss += loss.item()
        # record minibatch loss to epoch loss
        
        optimizer1.zero_grad()
        # set the model parameter gradient to zero
        
        loss.backward()
        # calculate the gradient wrt loss
        
        optimizer1.step()
        # take a gradient descent step
        
    outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, epoch_loss/(t+1)) + '\n'
    
    logger['train'].append(epoch_loss/(t+1))
    
    print(outstr)
    record.write(outstr)
    record.flush()

    if e%5 == 4:
    # do validation every 5 epoches
    
        icnet1.eval()
        # set model flag to eval
        # 1. disable dropout
        # 2. batchnorm behave differs

        with torch.no_grad():
        # stop taking gradient
        
            #valloss_4 = 0
            #valloss_2 = 0
            valloss_1 = 0
            
            for v, vbatch in enumerate(validation_loader):
            # iter over validation mini batches
                
                image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                if get_dimensions(image_1_val) == 4:
                    image_1_val.unsqueeze_(0)
                label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                if get_dimensions(label_1_val) == 4:
                    label_1_val.unsqueeze_(0)
                # move data to device, convert dtype to desirable dtype
                # add one dimension to labels if they are 4D tensors
                
                # Downsample labels to coincide with icnet model outputs
                label_1_val_resize_2 = downsample_label(label_1_val, 1/2) 
                
                out_1_val = icnet1(image_1_val)
                # do the inference
                
                loss_1 = dice_loss_3(out_1_val, label_1_val_resize_2)
                # calculate loss

                valloss_1 += loss_1.item()
                
                
                # record mini batch loss
            
            outstr = '------- 1st valloss={0:.4f}'\
                .format(valloss_1/(v+1)) + '\n'
            
            logger['validation_1'].append(valloss_1/(v+1))
            scheduler.step(valloss_1/(v+1))
            
            print(outstr)
            record.write(outstr)
            record.flush()
            
            save_1('half_res_save', icnet1, optimizer1, logger, e)

record.close()

  0%|          | 0/4791 [00:00<?, ?it/s]

Epoch 209 finished ! Training Loss: 0.4119



  0%|          | 1/4791 [06:10<492:26:08, 370.10s/it]

------- 1st valloss=0.0591

Checkpoint 209 saved !


  0%|          | 2/4791 [11:34<474:08:22, 356.42s/it]

Epoch 210 finished ! Training Loss: 0.4116



  0%|          | 3/4791 [17:05<463:47:49, 348.72s/it]

Epoch 211 finished ! Training Loss: 0.4137



  0%|          | 4/4791 [22:40<458:26:32, 344.77s/it]

Epoch 212 finished ! Training Loss: 0.4179



  0%|          | 5/4791 [28:08<451:31:40, 339.64s/it]

Epoch 213 finished ! Training Loss: 0.4122

Epoch 214 finished ! Training Loss: 0.4062



  0%|          | 6/4791 [34:24<466:04:25, 350.65s/it]

------- 1st valloss=0.0628

Checkpoint 214 saved !


  0%|          | 7/4791 [39:44<453:39:31, 341.38s/it]

Epoch 215 finished ! Training Loss: 0.4086



  0%|          | 8/4791 [45:13<448:33:24, 337.61s/it]

Epoch 216 finished ! Training Loss: 0.4126



  0%|          | 9/4791 [50:42<444:55:35, 334.95s/it]

Epoch 217 finished ! Training Loss: 0.4097



  0%|          | 10/4791 [56:13<443:18:02, 333.80s/it]

Epoch 218 finished ! Training Loss: 0.4093

Epoch 219 finished ! Training Loss: 0.4071



  0%|          | 11/4791 [1:02:29<459:59:55, 346.44s/it]

------- 1st valloss=0.0620

Checkpoint 219 saved !


  0%|          | 12/4791 [1:08:08<457:00:21, 344.26s/it]

Epoch 220 finished ! Training Loss: 0.4194



  0%|          | 13/4791 [1:13:39<451:49:09, 340.42s/it]

Epoch 221 finished ! Training Loss: 0.4002



  0%|          | 14/4791 [1:19:14<449:13:17, 338.54s/it]

Epoch 222 finished ! Training Loss: 0.4040



  0%|          | 15/4791 [1:24:32<441:00:26, 332.42s/it]

Epoch 223 finished ! Training Loss: 0.4132

Epoch 224 finished ! Training Loss: 0.4072



  0%|          | 16/4791 [1:30:40<455:06:25, 343.12s/it]

------- 1st valloss=0.0620

Checkpoint 224 saved !


  0%|          | 17/4791 [1:36:03<447:11:51, 337.22s/it]

Epoch 225 finished ! Training Loss: 0.4224



  0%|          | 18/4791 [1:41:36<445:22:11, 335.92s/it]

Epoch 226 finished ! Training Loss: 0.4066



  0%|          | 19/4791 [1:46:55<438:34:58, 330.87s/it]

Epoch 227 finished ! Training Loss: 0.4088



  0%|          | 20/4791 [1:52:31<440:18:07, 332.23s/it]

Epoch 228 finished ! Training Loss: 0.4000

Epoch 229 finished ! Training Loss: 0.4134



  0%|          | 21/4791 [1:58:37<453:51:44, 342.54s/it]

------- 1st valloss=0.0636

Checkpoint 229 saved !


  0%|          | 22/4791 [2:04:06<448:24:05, 338.49s/it]

Epoch 230 finished ! Training Loss: 0.4184



  0%|          | 23/4791 [2:09:36<445:00:55, 336.00s/it]

Epoch 231 finished ! Training Loss: 0.4131



  1%|          | 24/4791 [2:15:16<446:22:33, 337.10s/it]

Epoch 232 finished ! Training Loss: 0.4162



  1%|          | 25/4791 [2:20:39<440:29:34, 332.73s/it]

Epoch 233 finished ! Training Loss: 0.4193

Epoch 234 finished ! Training Loss: 0.4225



  1%|          | 26/4791 [2:27:05<461:39:45, 348.79s/it]

------- 1st valloss=0.0617

Checkpoint 234 saved !


  1%|          | 27/4791 [2:32:27<450:53:12, 340.72s/it]

Epoch 235 finished ! Training Loss: 0.4284



  1%|          | 28/4791 [2:37:59<447:26:16, 338.19s/it]

Epoch 236 finished ! Training Loss: 0.4010



  1%|          | 29/4791 [2:43:22<441:05:57, 333.46s/it]

Epoch 237 finished ! Training Loss: 0.4145



  1%|          | 30/4791 [2:48:45<437:09:11, 330.55s/it]

Epoch 238 finished ! Training Loss: 0.4119

Epoch 239 finished ! Training Loss: 0.4160



  1%|          | 31/4791 [2:55:05<456:31:56, 345.28s/it]

------- 1st valloss=0.0614

Checkpoint 239 saved !


  1%|          | 32/4791 [3:00:28<447:32:48, 338.55s/it]

Epoch 240 finished ! Training Loss: 0.4146



  1%|          | 33/4791 [3:05:54<442:35:40, 334.88s/it]

Epoch 241 finished ! Training Loss: 0.4164



  1%|          | 34/4791 [3:11:21<439:14:28, 332.41s/it]

Epoch 242 finished ! Training Loss: 0.4166



  1%|          | 35/4791 [3:16:45<435:44:50, 329.83s/it]

Epoch 243 finished ! Training Loss: 0.4092

Epoch 244 finished ! Training Loss: 0.4087



  1%|          | 36/4791 [3:23:01<454:11:04, 343.86s/it]

------- 1st valloss=0.0608

Checkpoint 244 saved !


  1%|          | 37/4791 [3:28:26<446:37:12, 338.21s/it]

Epoch 245 finished ! Training Loss: 0.4204



  1%|          | 38/4791 [3:33:50<440:57:30, 333.99s/it]

Epoch 246 finished ! Training Loss: 0.4058



  1%|          | 39/4791 [3:39:15<437:19:00, 331.30s/it]

Epoch 247 finished ! Training Loss: 0.4058



  1%|          | 40/4791 [3:44:35<432:27:33, 327.69s/it]

Epoch 248 finished ! Training Loss: 0.4099

Epoch 249 finished ! Training Loss: 0.4146



  1%|          | 41/4791 [3:50:48<450:32:16, 341.46s/it]

------- 1st valloss=0.0679

Checkpoint 249 saved !


  1%|          | 42/4791 [3:56:09<442:08:13, 335.16s/it]

Epoch 250 finished ! Training Loss: 0.4181



  1%|          | 43/4791 [4:01:32<437:10:19, 331.47s/it]

Epoch 251 finished ! Training Loss: 0.3982



  1%|          | 44/4791 [4:07:14<441:21:37, 334.72s/it]

Epoch 252 finished ! Training Loss: 0.4115



  1%|          | 45/4791 [4:12:34<435:35:25, 330.41s/it]

Epoch 253 finished ! Training Loss: 0.4077

Epoch 254 finished ! Training Loss: 0.4185



  1%|          | 46/4791 [4:18:46<451:50:14, 342.81s/it]

------- 1st valloss=0.0640

Checkpoint 254 saved !


  1%|          | 47/4791 [4:24:14<446:00:25, 338.45s/it]

Epoch 255 finished ! Training Loss: 0.4164



  1%|          | 48/4791 [4:29:37<439:54:19, 333.89s/it]

Epoch 256 finished ! Training Loss: 0.3982



  1%|          | 49/4791 [4:34:55<433:28:17, 329.08s/it]

Epoch 257 finished ! Training Loss: 0.4121



  1%|          | 50/4791 [4:40:13<428:59:06, 325.74s/it]

Epoch 258 finished ! Training Loss: 0.4124

Epoch 259 finished ! Training Loss: 0.4173



  1%|          | 51/4791 [4:46:21<445:35:43, 338.43s/it]

------- 1st valloss=0.0601

Checkpoint 259 saved !


  1%|          | 52/4791 [4:51:49<441:10:46, 335.14s/it]

Epoch 260 finished ! Training Loss: 0.4182



  1%|          | 53/4791 [4:57:25<441:43:02, 335.62s/it]

Epoch 261 finished ! Training Loss: 0.4033



  1%|          | 54/4791 [5:02:51<437:43:24, 332.66s/it]

Epoch 262 finished ! Training Loss: 0.4092



  1%|          | 55/4791 [5:08:11<432:41:55, 328.91s/it]

Epoch 263 finished ! Training Loss: 0.4101

Epoch 264 finished ! Training Loss: 0.4163



  1%|          | 56/4791 [5:14:25<450:09:23, 342.25s/it]

------- 1st valloss=0.0622

Checkpoint 264 saved !


  1%|          | 57/4791 [5:19:57<446:10:29, 339.30s/it]

Epoch 265 finished ! Training Loss: 0.4179



  1%|          | 58/4791 [5:25:17<438:17:05, 333.37s/it]

Epoch 266 finished ! Training Loss: 0.4089



  1%|          | 59/4791 [5:30:38<433:22:51, 329.71s/it]

Epoch 267 finished ! Training Loss: 0.4042



  1%|▏         | 60/4791 [5:36:08<433:33:38, 329.91s/it]

Epoch 268 finished ! Training Loss: 0.4149

Epoch 269 finished ! Training Loss: 0.4097



  1%|▏         | 61/4791 [5:42:26<452:08:47, 344.13s/it]

------- 1st valloss=0.0605

Checkpoint 269 saved !


  1%|▏         | 62/4791 [5:47:49<443:51:48, 337.90s/it]

Epoch 270 finished ! Training Loss: 0.4013



  1%|▏         | 63/4791 [5:53:11<437:29:40, 333.12s/it]

Epoch 271 finished ! Training Loss: 0.4002



  1%|▏         | 64/4791 [5:58:35<433:50:29, 330.41s/it]

Epoch 272 finished ! Training Loss: 0.4152



  1%|▏         | 65/4791 [6:04:04<433:03:36, 329.88s/it]

Epoch 273 finished ! Training Loss: 0.4172

Epoch 274 finished ! Training Loss: 0.3990



  1%|▏         | 66/4791 [6:10:18<450:25:51, 343.19s/it]

------- 1st valloss=0.0621

Checkpoint 274 saved !


  1%|▏         | 67/4791 [6:15:44<443:43:01, 338.14s/it]

Epoch 275 finished ! Training Loss: 0.4118



  1%|▏         | 68/4791 [6:21:09<438:30:36, 334.24s/it]

Epoch 276 finished ! Training Loss: 0.4052



  1%|▏         | 69/4791 [6:26:29<432:36:05, 329.81s/it]

Epoch 277 finished ! Training Loss: 0.4106



  1%|▏         | 70/4791 [6:31:54<430:48:48, 328.52s/it]

Epoch 278 finished ! Training Loss: 0.4157

Epoch 279 finished ! Training Loss: 0.4133



  1%|▏         | 71/4791 [6:38:07<448:16:18, 341.90s/it]

------- 1st valloss=0.0626

Checkpoint 279 saved !


  2%|▏         | 72/4791 [6:43:41<444:56:52, 339.44s/it]

Epoch 280 finished ! Training Loss: 0.4068



  2%|▏         | 73/4791 [6:49:07<439:32:22, 335.38s/it]

Epoch 281 finished ! Training Loss: 0.4122



  2%|▏         | 74/4791 [6:54:34<436:02:44, 332.79s/it]

Epoch 282 finished ! Training Loss: 0.4118



  2%|▏         | 75/4791 [6:59:59<432:46:28, 330.36s/it]

Epoch 283 finished ! Training Loss: 0.4065

Epoch 284 finished ! Training Loss: 0.4111



  2%|▏         | 76/4791 [7:06:09<448:30:41, 342.45s/it]

------- 1st valloss=0.0652

Checkpoint 284 saved !


  2%|▏         | 77/4791 [7:11:32<440:43:32, 336.57s/it]

Epoch 285 finished ! Training Loss: 0.4203



  2%|▏         | 78/4791 [7:16:54<434:42:11, 332.05s/it]

Epoch 286 finished ! Training Loss: 0.4024



  2%|▏         | 79/4791 [7:22:15<430:33:14, 328.95s/it]

Epoch 287 finished ! Training Loss: 0.4193



  2%|▏         | 80/4791 [7:27:46<431:20:31, 329.62s/it]

Epoch 288 finished ! Training Loss: 0.4137

Epoch 289 finished ! Training Loss: 0.4091



  2%|▏         | 81/4791 [7:34:03<449:50:31, 343.83s/it]

------- 1st valloss=0.0628

Checkpoint 289 saved !


  2%|▏         | 82/4791 [7:39:28<442:02:45, 337.94s/it]

Epoch 290 finished ! Training Loss: 0.4190



  2%|▏         | 83/4791 [7:44:50<435:40:20, 333.14s/it]

Epoch 291 finished ! Training Loss: 0.4077



  2%|▏         | 84/4791 [7:50:13<431:43:40, 330.19s/it]

Epoch 292 finished ! Training Loss: 0.4156



  2%|▏         | 85/4791 [7:55:32<427:22:44, 326.94s/it]

Epoch 293 finished ! Training Loss: 0.4116

Epoch 294 finished ! Training Loss: 0.4071



  2%|▏         | 86/4791 [8:01:41<443:39:37, 339.46s/it]

------- 1st valloss=0.0593

Checkpoint 294 saved !


  2%|▏         | 87/4791 [8:07:05<437:22:06, 334.72s/it]

Epoch 295 finished ! Training Loss: 0.4165



  2%|▏         | 88/4791 [8:12:31<434:11:10, 332.36s/it]

Epoch 296 finished ! Training Loss: 0.4037



  2%|▏         | 89/4791 [8:17:50<428:47:42, 328.30s/it]

Epoch 297 finished ! Training Loss: 0.4109



  2%|▏         | 90/4791 [8:23:07<424:00:54, 324.71s/it]

Epoch 298 finished ! Training Loss: 0.4181

Epoch 299 finished ! Training Loss: 0.4035



  2%|▏         | 91/4791 [8:29:21<443:14:29, 339.50s/it]

------- 1st valloss=0.0642

Checkpoint 299 saved !


  2%|▏         | 92/4791 [8:34:54<440:51:30, 337.75s/it]

Epoch 300 finished ! Training Loss: 0.4255



  2%|▏         | 93/4791 [8:40:19<435:30:58, 333.73s/it]

Epoch 301 finished ! Training Loss: 0.4122



  2%|▏         | 94/4791 [8:45:44<432:04:18, 331.16s/it]

Epoch 302 finished ! Training Loss: 0.4273



  2%|▏         | 95/4791 [8:51:17<432:38:38, 331.67s/it]

Epoch 303 finished ! Training Loss: 0.4147

Epoch 304 finished ! Training Loss: 0.3977



  2%|▏         | 96/4791 [8:57:41<453:03:41, 347.40s/it]

------- 1st valloss=0.0616

Checkpoint 304 saved !


  2%|▏         | 97/4791 [9:03:10<446:00:39, 342.06s/it]

Epoch 305 finished ! Training Loss: 0.3883



  2%|▏         | 98/4791 [9:08:33<438:20:47, 336.26s/it]

Epoch 306 finished ! Training Loss: 0.4082



  2%|▏         | 99/4791 [9:14:02<435:22:18, 334.04s/it]

Epoch 307 finished ! Training Loss: 0.4126



  2%|▏         | 100/4791 [9:19:27<431:47:54, 331.37s/it]

Epoch 308 finished ! Training Loss: 0.4132

Epoch 309 finished ! Training Loss: 0.4178



  2%|▏         | 101/4791 [9:25:43<449:08:35, 344.76s/it]

------- 1st valloss=0.0632

Checkpoint 309 saved !


  2%|▏         | 102/4791 [9:31:11<442:24:27, 339.66s/it]

Epoch 310 finished ! Training Loss: 0.4095



  2%|▏         | 103/4791 [9:36:35<436:26:15, 335.15s/it]

Epoch 311 finished ! Training Loss: 0.4272



  2%|▏         | 104/4791 [9:42:02<432:49:19, 332.44s/it]

Epoch 312 finished ! Training Loss: 0.4126



  2%|▏         | 105/4791 [9:47:20<427:22:44, 328.33s/it]

Epoch 313 finished ! Training Loss: 0.4222

Epoch 314 finished ! Training Loss: 0.4067



  2%|▏         | 106/4791 [9:53:37<446:09:25, 342.83s/it]

------- 1st valloss=0.0621

Checkpoint 314 saved !


  2%|▏         | 107/4791 [9:59:12<443:09:14, 340.60s/it]

Epoch 315 finished ! Training Loss: 0.4172



  2%|▏         | 108/4791 [10:04:29<433:35:36, 333.32s/it]

Epoch 316 finished ! Training Loss: 0.4095



  2%|▏         | 109/4791 [10:09:53<429:50:27, 330.51s/it]

Epoch 317 finished ! Training Loss: 0.4084



  2%|▏         | 110/4791 [10:15:21<428:52:07, 329.83s/it]

Epoch 318 finished ! Training Loss: 0.4130

Epoch 319 finished ! Training Loss: 0.4238



  2%|▏         | 111/4791 [10:21:33<445:18:05, 342.54s/it]

------- 1st valloss=0.0613

Checkpoint 319 saved !


  2%|▏         | 112/4791 [10:26:59<438:45:51, 337.58s/it]

Epoch 320 finished ! Training Loss: 0.4052



  2%|▏         | 113/4791 [10:32:27<434:46:39, 334.59s/it]

Epoch 321 finished ! Training Loss: 0.4034



  2%|▏         | 114/4791 [10:37:49<429:49:49, 330.85s/it]

Epoch 322 finished ! Training Loss: 0.4129



  2%|▏         | 115/4791 [10:43:09<425:40:26, 327.72s/it]

Epoch 323 finished ! Training Loss: 0.4044

Epoch 324 finished ! Training Loss: 0.4202



  2%|▏         | 116/4791 [10:49:13<439:33:58, 338.49s/it]

------- 1st valloss=0.0600

Checkpoint 324 saved !


  2%|▏         | 117/4791 [10:54:37<433:58:33, 334.26s/it]

Epoch 325 finished ! Training Loss: 0.4025



  2%|▏         | 118/4791 [11:00:02<430:13:45, 331.44s/it]

Epoch 326 finished ! Training Loss: 0.4071



  2%|▏         | 119/4791 [11:05:27<427:36:37, 329.49s/it]

Epoch 327 finished ! Training Loss: 0.4173



  3%|▎         | 120/4791 [11:10:51<425:31:28, 327.96s/it]

Epoch 328 finished ! Training Loss: 0.4104

Epoch 329 finished ! Training Loss: 0.4138



  3%|▎         | 121/4791 [11:17:05<443:13:45, 341.68s/it]

------- 1st valloss=0.0601

Checkpoint 329 saved !


  3%|▎         | 122/4791 [11:22:29<436:09:30, 336.30s/it]

Epoch 330 finished ! Training Loss: 0.4157



KeyboardInterrupt: 

In [7]:
icnet1.eval()

with torch.no_grad():
    
    bgloss = 0
    bdloss = 0
    bvloss = 0
    
    for v, vbatch in tqdm(enumerate(validation_loader)):
            # move data to device, convert dtype to desirable dtype

        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)
        label_2 = vbatch['image2_label'].to(device=device, dtype=dtype)
        
        
        output = icnet1(image_1)
        # do the inference
        output_numpy = output.cpu().numpy()
        
        
        #out_1 = torch.round(output)
        out_1 = torch.from_numpy((output_numpy == output_numpy.max(axis=1)[:, None]).astype(int)).to(device=device, dtype=dtype)
        loss_1 = dice_loss_3(out_1, label_2)
        
        bg, bd, bv = dice_loss_3_debug(out_1, label_2)
        # calculate loss
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()

    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/(v+1), bdloss/(v+1), bvloss/(v+1)) + '\n'
    print(outstr)

0it [00:00, ?it/s]

torch.Size([2, 3, 128, 128, 128])


AssertionError: 