In [1]:
import sys
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_utility import *
from data_utils import *
from loss import *
from train import *
from deeplab_model.deeplab import *
from sync_batchnorm import convert_model
import datetime

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
USE_GPU = True
NUM_WORKERS = 12
BATCH_SIZE = 2 

dtype = torch.float32 
# define dtype, float is space efficient than double

if USE_GPU and torch.cuda.is_available():
    
    device = torch.device('cuda')
    
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    # magic flag that accelerate
    
    print('using GPU for training')
else:
    device = torch.device('cpu')
    print('using CPU for training')

using GPU for training


In [3]:
train_dataset = pyramid_dataset(data_type = 'nii_train', 
                transform=transforms.Compose([
                random_affine(90, 15),
                random_filp(0.5)]))
# do data augumentation on train dataset

validation_dataset = pyramid_dataset(data_type = 'nii_test', 
                transform=None)
# no data augumentation on validation dataset

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=NUM_WORKERS) # drop_last
# loaders come with auto batch division and multi-thread acceleration

In [4]:
'''
deeplab = DeepLabModified(output_stride=16)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)
deeplab = deeplab.to(device=device, dtype=dtype)
#shape_test(icnet1, True)
# create the model, by default model type is float, use model.double(), model.float() to convert
# move the model to desirable device

optimizer = optim.Adam(deeplab.parameters(), lr=1e-2)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=7)
epoch = 0
'''
# create an optimizer object
# note that only the model_2 params and model_4 params will be optimized by optimizer

"\ndeeplab = DeepLabModified(output_stride=16)\ndeeplab = nn.DataParallel(deeplab)\ndeeplab = convert_model(deeplab)\ndeeplab = deeplab.to(device=device, dtype=dtype)\n#shape_test(icnet1, True)\n# create the model, by default model type is float, use model.double(), model.float() to convert\n# move the model to desirable device\n\noptimizer = optim.Adam(deeplab.parameters(), lr=1e-2)\nscheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=7)\nepoch = 0\n"

In [None]:

deeplab = DeepLabModified(output_stride=16)
deeplab = nn.DataParallel(deeplab)
deeplab = convert_model(deeplab)

checkpoint = torch.load('../deeplab_modified_save/2019-08-16 06:20:16.596756 epoch: 270.pth') # latest one

deeplab.load_state_dict(checkpoint['state_dict_1'])
deeplab = deeplab.to(device, dtype)

optimizer = optim.Adam(deeplab.parameters(), lr=1e-3)
#optimizer.load_state_dict(checkpoint['optimizer'])

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=15)
scheduler.load_state_dict(checkpoint['scheduler'])

epoch = checkpoint['epoch']
print(epoch)
print(sum(checkpoint['logger']['validation_1']) / len(checkpoint['logger']['validation_1']))
for param_group in optimizer.param_groups:
    print(param_group['lr'])


270
0.16324048634924454
0.001


In [None]:
epochs = 5000

min_val = 1

record = open('train_deeplab_modified.txt','a+')

logger = {'train':[], 'validation_1': []}

for e in tqdm(range(epoch + 1, epochs)):
# iter over epoches

    epoch_loss = 0
        
    for t, batch in enumerate(train_loader):
    # iter over the train mini batches
    
        deeplab.train()
        # Set the model flag to train
        # 1. enable dropout
        # 2. batchnorm behave differently in train and test
        
        image_1 = batch['image1_data'].to(device=device, dtype=dtype)
        label_1 = batch['image1_label'].to(device=device, dtype=dtype)
        # move data to device, convert dtype to desirable dtype
        
        out_1 = deeplab(image_1)
        # do the inference

        loss_1 = dice_loss_3(out_1, label_1)
        # calculate loss
        
        epoch_loss += loss_1.item()
        # record minibatch loss to epoch loss
        
        optimizer.zero_grad()
        # set the model parameter gradient to zero
        
        loss_1.backward()
        # calculate the gradient wrt loss
        optimizer.step()
        #scheduler.step(loss_1)
        # take a gradient descent step
        
    outstr = 'Epoch {0} finished ! Training Loss: {1:.4f}'.format(e, epoch_loss/(t+1)) + '\n'
    
    logger['train'].append(epoch_loss/(t+1))
    
    print(outstr)
    record.write(outstr)
    record.flush()

    if e%1 == 0:
    # do validation every 5 epoches
    
        deeplab.eval()
        # set model flag to eval
        # 1. disable dropout
        # 2. batchnorm behave differs

        with torch.no_grad():
        # stop taking gradient
        
            #valloss_4 = 0
            #valloss_2 = 0
            valloss_1 = 0
            
            for v, vbatch in enumerate(validation_loader):
            # iter over validation mini batches
                
                image_1_val = vbatch['image1_data'].to(device=device, dtype=dtype)
                if get_dimensions(image_1_val) == 4:
                    image_1_val.unsqueeze_(0)
                label_1_val = vbatch['image1_label'].to(device=device, dtype=dtype)
                if get_dimensions(label_1_val) == 4:
                    label_1_val.unsqueeze_(0)
                # move data to device, convert dtype to desirable dtype
                # add one dimension to labels if they are 4D tensors
                
                out_1_val = deeplab(image_1_val)
                # do the inference
                
                loss_1 = dice_loss_3(out_1_val, label_1_val)
                # calculate loss

                valloss_1 += loss_1.item()
                # record mini batch loss
            
            avg_val_loss = (valloss_1 / (v+1))
            outstr = '------- 1st valloss={0:.4f}'\
                .format(avg_val_loss) + '\n'
            
            logger['validation_1'].append(avg_val_loss)
            scheduler.step(avg_val_loss)
            
            print(outstr)
            record.write(outstr)
            record.flush()
            
            if avg_val_loss < min_val:
                print(avg_val_loss, "less than", min_val)
                min_val = avg_val_loss
                save_1('deeplab_modified_save', deeplab, optimizer, logger, e, scheduler)
            elif e % 10 == 0:
                save_1('deeplab_modified_save', deeplab, optimizer, logger, e, scheduler)

record.close()

  0%|          | 0/4729 [00:00<?, ?it/s]

Epoch 271 finished ! Training Loss: 0.1479

------- 1st valloss=0.0657

0.06566553935408592 less than 1


  0%|          | 1/4729 [16:35<1307:09:30, 995.30s/it]

Checkpoint 271 saved !
Epoch 272 finished ! Training Loss: 0.1412



  0%|          | 2/4729 [31:09<1259:03:08, 958.87s/it]

------- 1st valloss=0.0684

Epoch 273 finished ! Training Loss: 0.1391



  0%|          | 3/4729 [45:37<1223:00:38, 931.62s/it]

------- 1st valloss=0.0675

Epoch 274 finished ! Training Loss: 0.1289



  0%|          | 4/4729 [1:00:01<1196:10:46, 911.37s/it]

------- 1st valloss=0.0670

Epoch 275 finished ! Training Loss: 0.1412

------- 1st valloss=0.0642

0.06421103620010873 less than 0.06566553935408592


  0%|          | 5/4729 [1:14:39<1182:47:55, 901.37s/it]

Checkpoint 275 saved !
Epoch 276 finished ! Training Loss: 0.1239

------- 1st valloss=0.0632

0.06324755934917409 less than 0.06421103620010873


  0%|          | 6/4729 [1:29:15<1172:47:12, 893.93s/it]

Checkpoint 276 saved !
Epoch 277 finished ! Training Loss: 0.1339



  0%|          | 7/4729 [1:43:46<1163:32:15, 887.07s/it]

------- 1st valloss=0.0639

Epoch 278 finished ! Training Loss: 0.1232



  0%|          | 8/4729 [1:58:09<1153:36:52, 879.69s/it]

------- 1st valloss=0.0652

Epoch 279 finished ! Training Loss: 0.1289



  0%|          | 9/4729 [2:12:44<1151:23:12, 878.18s/it]

------- 1st valloss=0.0934

Epoch 280 finished ! Training Loss: 0.1331

------- 1st valloss=0.0696



  0%|          | 10/4729 [2:27:07<1145:20:20, 873.75s/it]

Checkpoint 280 saved !
Epoch 281 finished ! Training Loss: 0.1385



  0%|          | 11/4729 [2:41:42<1145:29:31, 874.05s/it]

------- 1st valloss=0.0641

Epoch 282 finished ! Training Loss: 0.1267



  0%|          | 12/4729 [2:56:10<1142:54:16, 872.26s/it]

------- 1st valloss=0.0878

Epoch 284 finished ! Training Loss: 0.1269



  0%|          | 14/4729 [3:25:01<1137:48:45, 868.74s/it]

------- 1st valloss=0.0657

Epoch 285 finished ! Training Loss: 0.1274



  0%|          | 15/4729 [3:39:28<1136:44:38, 868.11s/it]

------- 1st valloss=0.0644

Epoch 286 finished ! Training Loss: 0.1158



  0%|          | 16/4729 [3:54:00<1138:11:46, 869.41s/it]

------- 1st valloss=0.0674

Epoch 287 finished ! Training Loss: 0.1156



  0%|          | 17/4729 [4:08:35<1140:06:46, 871.05s/it]

------- 1st valloss=0.0638

Epoch 288 finished ! Training Loss: 0.1259

------- 1st valloss=0.0626

0.06259030087486557 less than 0.06324755934917409


  0%|          | 18/4729 [4:23:03<1138:35:04, 870.07s/it]

Checkpoint 288 saved !
Epoch 289 finished ! Training Loss: 0.1274



  0%|          | 19/4729 [4:37:25<1135:16:52, 867.73s/it]

------- 1st valloss=0.0631

Epoch 290 finished ! Training Loss: 0.1375

------- 1st valloss=0.0637



  0%|          | 20/4729 [4:51:52<1134:49:28, 867.57s/it]

Checkpoint 290 saved !
Epoch 291 finished ! Training Loss: 0.1198



  0%|          | 21/4729 [5:06:19<1134:10:56, 867.26s/it]

------- 1st valloss=0.0643

Epoch 292 finished ! Training Loss: 0.1330



  0%|          | 22/4729 [5:21:02<1140:18:53, 872.13s/it]

------- 1st valloss=0.0632

Epoch 293 finished ! Training Loss: 0.1158



  0%|          | 23/4729 [5:35:23<1135:25:17, 868.58s/it]

------- 1st valloss=0.0645

Epoch 294 finished ! Training Loss: 0.1286



  1%|          | 24/4729 [5:49:54<1136:11:37, 869.35s/it]

------- 1st valloss=0.0682

Epoch 295 finished ! Training Loss: 0.1330



  1%|          | 25/4729 [6:04:26<1137:07:03, 870.24s/it]

------- 1st valloss=0.0643

Epoch 296 finished ! Training Loss: 0.1275



  1%|          | 26/4729 [6:18:55<1136:10:11, 869.70s/it]

------- 1st valloss=0.0637

Epoch 297 finished ! Training Loss: 0.1272



  1%|          | 27/4729 [6:33:21<1134:39:10, 868.73s/it]

------- 1st valloss=0.0696

Epoch 298 finished ! Training Loss: 0.1058



  1%|          | 28/4729 [6:47:41<1131:08:53, 866.23s/it]

------- 1st valloss=0.0634

Epoch 299 finished ! Training Loss: 0.1384



  1%|          | 29/4729 [7:02:13<1133:03:03, 867.87s/it]

------- 1st valloss=0.0642

Epoch 300 finished ! Training Loss: 0.1270

------- 1st valloss=0.0638



  1%|          | 30/4729 [7:16:47<1134:59:52, 869.55s/it]

Checkpoint 300 saved !
Epoch 301 finished ! Training Loss: 0.1263



  1%|          | 31/4729 [7:31:08<1131:33:17, 867.09s/it]

------- 1st valloss=0.0638

Epoch 302 finished ! Training Loss: 0.1312



  1%|          | 32/4729 [7:45:37<1132:06:08, 867.70s/it]

------- 1st valloss=0.0654

Epoch 303 finished ! Training Loss: 0.1270



  1%|          | 33/4729 [8:00:01<1130:36:07, 866.73s/it]

------- 1st valloss=0.0988

Epoch 304 finished ! Training Loss: 0.1498



  1%|          | 34/4729 [8:14:29<1130:35:55, 866.91s/it]

------- 1st valloss=0.0626

Epoch 305 finished ! Training Loss: 0.1386



  1%|          | 35/4729 [8:29:00<1132:07:08, 868.26s/it]

------- 1st valloss=0.0651

Epoch 306 finished ! Training Loss: 0.1102



  1%|          | 36/4729 [8:43:41<1136:40:03, 871.94s/it]

------- 1st valloss=0.0632

Epoch 307 finished ! Training Loss: 0.1203



  1%|          | 37/4729 [8:58:06<1133:52:28, 869.98s/it]

------- 1st valloss=0.0652

Epoch 308 finished ! Training Loss: 0.1264



  1%|          | 38/4729 [9:12:48<1138:16:12, 873.54s/it]

------- 1st valloss=0.0629

Epoch 309 finished ! Training Loss: 0.1374



  1%|          | 39/4729 [9:27:22<1138:17:19, 873.74s/it]

------- 1st valloss=0.0635

Epoch 310 finished ! Training Loss: 0.1321

------- 1st valloss=0.0631



  1%|          | 40/4729 [9:41:41<1132:14:19, 869.28s/it]

Checkpoint 310 saved !
Epoch 311 finished ! Training Loss: 0.1379



  1%|          | 41/4729 [9:56:01<1128:27:35, 866.56s/it]

------- 1st valloss=0.0628

Epoch 312 finished ! Training Loss: 0.1322



  1%|          | 42/4729 [10:10:39<1132:29:12, 869.84s/it]

------- 1st valloss=0.0673

Epoch 313 finished ! Training Loss: 0.1329



  1%|          | 43/4729 [10:24:54<1126:38:58, 865.54s/it]

------- 1st valloss=0.0632

Epoch 314 finished ! Training Loss: 0.1214



  1%|          | 44/4729 [10:39:11<1122:57:43, 862.90s/it]

------- 1st valloss=0.0753

Epoch 315 finished ! Training Loss: 0.1359



  1%|          | 45/4729 [10:53:44<1126:42:45, 865.96s/it]

------- 1st valloss=0.0722

Epoch 316 finished ! Training Loss: 0.1311



  1%|          | 46/4729 [11:08:14<1128:03:44, 867.18s/it]

------- 1st valloss=0.0645

Epoch 317 finished ! Training Loss: 0.1343



  1%|          | 47/4729 [11:22:39<1127:02:13, 866.58s/it]

------- 1st valloss=0.0658

Epoch 318 finished ! Training Loss: 0.1253



  1%|          | 48/4729 [11:37:04<1125:56:17, 865.92s/it]

------- 1st valloss=0.0651



In [None]:
deeplab.eval()

with torch.no_grad():
    
    bgloss = 0
    bdloss = 0
    bvloss = 0
    
    for v, vbatch in tqdm(enumerate(validation_loader)):
            # move data to device, convert dtype to desirable dtype

        image_1 = vbatch['image1_data'].to(device=device, dtype=dtype)
        label_1 = vbatch['image1_label'].to(device=device, dtype=dtype)

        output = deeplab(image_1)
        # do the inference
        output_numpy = output.cpu().numpy()
        
        
        #out_1 = torch.round(output)
        out_1 = torch.from_numpy((output_numpy == output_numpy.max(axis=1)[:, None]).astype(int)).to(device=device, dtype=dtype)
        loss_1 = dice_loss_3(out_1, label_1)

        bg, bd, bv = dice_loss_3_debug(out_1, label_1)
        # calculate loss
        print(bg.item(), bd.item(), bv.item(), loss_1.item())
        bgloss += bg.item()
        bdloss += bd.item()
        bvloss += bv.item()
        
        if bv.item() >= 0.2 or bd.item() >= 0.1:
            show_image_slice(image_1)
            show_image_slice(label_1)
            show_image_slice(output)

    outstr = '------- background loss = {0:.4f}, body loss = {1:.4f}, bv loss = {2:.4f}'\
        .format(bgloss/(v+1), bdloss/(v+1), bvloss/(v+1)) + '\n'
    print(outstr)