# Notatnik wzorowany na https://github.com/usuyama/pytorch-unet/blob/master/pytorch_unet.ipynb

In [1]:
!pip install torchinfo



In [2]:
import os
#set workdir
os.chdir("/content/drive/MyDrive/RiverSemanticSegmentation/")

In [3]:
#imports
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from dataloader import Dataset
from torchinfo import summary
import time
import copy
import pdb
from tqdm import tqdm


In [4]:
#training parameters in neptune format
PARAMS = {
    "input_size": 416,
    "output_size": 416,
    "model": "vgg_unet",
    "learning_rate": 0.0001,
    "batch_size": 8,
    'epochs': 1000,
    'patience': 10,
    "train_dataset_size": -1, # set train dataset subset. Useful when neet to 
                              # overtrain model with small amount of images.
                              # -1 -all images from train directories.
    "test_dataset_size": -1,  # set test dataset subset.
                              # -1 -all images from train directories.
    "n_classes": 2,
    'image_preload': False,
}

In [5]:
#neptune installation and initialization
!pip install neptune-client
import neptune
neptune.init(project_qualified_name='radek/pth1',
             api_token='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUuYWkiLCJhcGlfa2V5IjoiYmY4YjQ3YjEtNmY5My00MDc2LWI4NzAtMWE5MmUwZjQ1NDE2In0=',
             )
neptune.create_experiment(params=PARAMS)


https://ui.neptune.ai/radek/pth1/e/PTH1-153


Experiment(PTH1-153)

In [6]:
#dataset configuration
dataset_dir = os.path.normpath("/content/drive/MyDrive/SemanticSegmentationV2/dataset/")
x_train_dir = os.path.join(dataset_dir,"x_train")
y_train_dir = os.path.join(dataset_dir,"y_train")
x_test_dir = os.path.join(dataset_dir,"x_test")
y_test_dir = os.path.join(dataset_dir,"y_test")

train_set = Dataset(x_train_dir, y_train_dir, input_size=PARAMS['input_size'], output_size=PARAMS['output_size'], n_classes=PARAMS["n_classes"], count=PARAMS["train_dataset_size"])
test_set = Dataset(x_test_dir, y_test_dir, input_size=PARAMS['input_size'], output_size=PARAMS['output_size'], n_classes=PARAMS["n_classes"], count=PARAMS["test_dataset_size"])

batch_size = PARAMS['batch_size']
dataloaders = {
    'train': DataLoader(train_set, batch_size=PARAMS['batch_size'], shuffle=True, num_workers=0),
    'val': DataLoader(test_set, batch_size=PARAMS['batch_size'], shuffle=True, num_workers=0)
}

In [7]:
# load images - useful if you want to save some time by preloading images (very time-consuming) when 
# the model is still not fuctional and cant run standard training.
if PARAMS['image_preload']:
  for phase in dataloaders:
    for inputs, labels in tqdm(dataloaders[phase]):
      pass


In [8]:
#model loading
if PARAMS['model'] == "simple":
  from models.simple import Simple
  model = Simple()
elif PARAMS['model'] == "vgg_unet":
  from models.vgg_unet import VggUnet
  model = VggUnet()
elif PARAMS['model'] == "vgg_unet_ks":
  from models.vgg_unet_ks import VggUnetKs
  model = VggUnetKs()
elif PARAMS['model'] == "unet":
  from models.unet import UNet
  model = UNet()
elif PARAMS['model'] == "vgg_deconvnet":
  from models.vgg16_deconvnet import VggDeconvNet
  model = VggDeconvNet()

In [9]:
#model structure preview
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model_stats = summary(model, input_size=(PARAMS['batch_size'], 3, PARAMS['input_size'], PARAMS['input_size']))
for line in str(model_stats).splitlines():
  neptune.log_text('model_summary', line)

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1                          []                        --
|    └─Conv2d: 2-1                       [8, 64, 416, 416]         1,792
|    └─ReLU: 2-2                         [8, 64, 416, 416]         --
|    └─Conv2d: 2-3                       [8, 64, 416, 416]         36,928
|    └─ReLU: 2-4                         [8, 64, 416, 416]         --
|    └─MaxPool2d: 2-5                    [8, 64, 208, 208]         --
|    └─Conv2d: 2-6                       [8, 128, 208, 208]        73,856
|    └─ReLU: 2-7                         [8, 128, 208, 208]        --
|    └─Conv2d: 2-8                       [8, 128, 208, 208]        147,584
|    └─ReLU: 2-9                         [8, 128, 208, 208]        --
|    └─MaxPool2d: 2-10                   [8, 128, 104, 104]        --
|    └─Conv2d: 2-11                      [8, 256, 104, 104]        295,168
|    └─ReLU: 2-12                        [8, 256, 104, 104]     

In [10]:
from collections import defaultdict
import torch.nn.functional as F
SMOOTH = 1e-6
def iou_metric(outputs: torch.Tensor, labels: torch.Tensor):
    outputs = outputs[:,1,:,:]  # BATCH x 1 x H x W => BATCH x H x W
    labels = labels[:,1,:,:]
    intersection = (outputs * labels).sum(2).sum(1)  # Will be zero if Truth=0 or Prediction=0
    union = (outputs + labels).sum(2).sum(1) - intersection  # Will be zzero if both are 0
    iou = (intersection + SMOOTH) / (union + SMOOTH)  # We smooth our devision to avoid 0/0
    return iou.mean()
    

def dice_loss(pred, target, smooth = 1.):
    pred = pred.contiguous()
    target = target.contiguous()    
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    return loss.mean()

def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy(pred, target)
    pred = torch.round(pred)
    dice = dice_loss(pred, target)
    loss = bce# * bce_weight + dice * (1 - bce_weight)
    iou = iou_metric(pred, target)
    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)
    metrics['iou'] += iou.data.cpu().numpy() * target.size(0)
    return loss

def print_metrics(metrics, epoch_samples, phase):   
    print(epoch_samples) 
    outputs = []
    for k in metrics.keys():
        outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples))
        neptune.log_metric(phase+"_"+k, metrics[k] / epoch_samples) #log
    print("{}: {}".format(phase, ", ".join(outputs)))    



In [11]:
#training loop
def train_model(model, dataloaders, optimizer, device, num_epochs=25, patience=-1):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_accuracy = 0
    no_improvement = 0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                    
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)             

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)
                    #print(model.encoder[0].weight.grad)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        #pdb.set_trace()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples
            epoch_accuracy = metrics['iou'] / epoch_samples

            # deep copy the model
            if phase == 'val':
              if epoch_accuracy > best_accuracy:
                no_improvement = 0
                print("Val IoU improved by {}. Saving best model.".format(epoch_accuracy-best_accuracy))
                best_accuracy = epoch_accuracy
                best_model_wts = copy.deepcopy(model.state_dict())
              else:
                no_improvement += 1
                print("No accuracy improvement since {}/{} epochs.".format(no_improvement,patience))
        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        if patience >= 0 and no_improvement > patience:
          break
    print('Best accuracy: {:4f}'.format(best_accuracy))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [12]:
#model training
optimizer_ft = optim.Adam(model.parameters(), lr=PARAMS['learning_rate'])
model = train_model(model, dataloaders, optimizer_ft, device, num_epochs=PARAMS['epochs'], patience=PARAMS['patience'])

  0%|          | 0/302 [00:00<?, ?it/s]

Epoch 0/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]


2412


  0%|          | 0/72 [00:00<?, ?it/s]

train: bce: 0.223076, dice: 0.141114, loss: 0.223076, iou: 0.643406


100%|██████████| 72/72 [00:28<00:00,  2.49it/s]


570


  0%|          | 0/302 [00:00<?, ?it/s]

val: bce: 0.119871, dice: 0.114095, loss: 0.119871, iou: 0.681044
Val IoU improved by 0.6810440316534879. Saving best model.
3m 56s
Epoch 1/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.088848, dice: 0.099962, loss: 0.088848, iou: 0.721370


100%|██████████| 72/72 [00:29<00:00,  2.44it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.066423, dice: 0.083912, loss: 0.066423, iou: 0.759506
Val IoU improved by 0.07846201637334993. Saving best model.
3m 55s
Epoch 2/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.057832, dice: 0.082585, loss: 0.057832, iou: 0.763538


100%|██████████| 72/72 [00:29<00:00,  2.42it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.051249, dice: 0.070248, loss: 0.051249, iou: 0.793479
Val IoU improved by 0.03397314276611585. Saving best model.
3m 55s
Epoch 3/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.044603, dice: 0.073558, loss: 0.044603, iou: 0.786105


100%|██████████| 72/72 [00:29<00:00,  2.43it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.040683, dice: 0.066990, loss: 0.040683, iou: 0.802534
Val IoU improved by 0.00905519389269649. Saving best model.
3m 55s
Epoch 4/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.037242, dice: 0.066062, loss: 0.037242, iou: 0.805306


100%|██████████| 72/72 [00:29<00:00,  2.46it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.038989, dice: 0.065299, loss: 0.038989, iou: 0.804703
Val IoU improved by 0.002168249456506066. Saving best model.
3m 54s
Epoch 5/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.048485, dice: 0.076435, loss: 0.048485, iou: 0.779850


100%|██████████| 72/72 [00:29<00:00,  2.46it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.050235, dice: 0.075451, loss: 0.050235, iou: 0.780702
No accuracy improvement since 1/10 epochs.
3m 54s
Epoch 6/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.039620, dice: 0.069956, loss: 0.039620, iou: 0.795339


100%|██████████| 72/72 [00:29<00:00,  2.44it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.039169, dice: 0.068431, loss: 0.039169, iou: 0.799634
No accuracy improvement since 2/10 epochs.
3m 55s
Epoch 7/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.032381, dice: 0.061496, loss: 0.032381, iou: 0.816658


100%|██████████| 72/72 [00:29<00:00,  2.47it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.032486, dice: 0.058409, loss: 0.032486, iou: 0.824677
Val IoU improved by 0.01997387743832768. Saving best model.
3m 54s
Epoch 8/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.029834, dice: 0.057676, loss: 0.029834, iou: 0.825912


100%|██████████| 72/72 [00:28<00:00,  2.49it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.031334, dice: 0.056153, loss: 0.031334, iou: 0.829621
Val IoU improved by 0.004944970942380111. Saving best model.
3m 54s
Epoch 9/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.027999, dice: 0.055383, loss: 0.027999, iou: 0.832051


100%|██████████| 72/72 [00:29<00:00,  2.48it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.030127, dice: 0.054060, loss: 0.030127, iou: 0.835106
Val IoU improved by 0.005484218137306152. Saving best model.
3m 54s
Epoch 10/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.027457, dice: 0.054596, loss: 0.027457, iou: 0.834167


100%|██████████| 72/72 [00:28<00:00,  2.48it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.034677, dice: 0.053593, loss: 0.034677, iou: 0.833821
No accuracy improvement since 1/10 epochs.
3m 54s
Epoch 11/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.026637, dice: 0.052843, loss: 0.026637, iou: 0.838550


100%|██████████| 72/72 [00:28<00:00,  2.49it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.030239, dice: 0.051925, loss: 0.030239, iou: 0.839746
Val IoU improved by 0.004640308806770688. Saving best model.
3m 54s
Epoch 12/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.025441, dice: 0.050985, loss: 0.025441, iou: 0.842892


100%|██████████| 72/72 [00:28<00:00,  2.50it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.028771, dice: 0.051527, loss: 0.028771, iou: 0.842318
Val IoU improved by 0.0025721056419506505. Saving best model.
3m 54s
Epoch 13/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.024580, dice: 0.049191, loss: 0.024580, iou: 0.848119


100%|██████████| 72/72 [00:28<00:00,  2.48it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.031414, dice: 0.052139, loss: 0.031414, iou: 0.838121
No accuracy improvement since 1/10 epochs.
3m 54s
Epoch 14/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.023286, dice: 0.047644, loss: 0.023286, iou: 0.852445


100%|██████████| 72/72 [00:29<00:00,  2.41it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.029778, dice: 0.052031, loss: 0.029778, iou: 0.840106
No accuracy improvement since 2/10 epochs.
3m 55s
Epoch 15/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.027861, dice: 0.056083, loss: 0.027861, iou: 0.830276


100%|██████████| 72/72 [00:29<00:00,  2.41it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.036215, dice: 0.054453, loss: 0.036215, iou: 0.834669
No accuracy improvement since 3/10 epochs.
3m 55s
Epoch 16/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.034079, dice: 0.062938, loss: 0.034079, iou: 0.812989


100%|██████████| 72/72 [00:29<00:00,  2.44it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.033796, dice: 0.055780, loss: 0.033796, iou: 0.829602
No accuracy improvement since 4/10 epochs.
3m 55s
Epoch 17/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.024322, dice: 0.049614, loss: 0.024322, iou: 0.846876


100%|██████████| 72/72 [00:29<00:00,  2.46it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027733, dice: 0.047366, loss: 0.027733, iou: 0.851761
Val IoU improved by 0.009443356279741244. Saving best model.
3m 54s
Epoch 18/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.021796, dice: 0.045399, loss: 0.021796, iou: 0.858559


100%|██████████| 72/72 [00:29<00:00,  2.47it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.028224, dice: 0.047534, loss: 0.028224, iou: 0.850698
No accuracy improvement since 1/10 epochs.
3m 54s
Epoch 19/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.021651, dice: 0.044721, loss: 0.021651, iou: 0.859828


100%|██████████| 72/72 [00:29<00:00,  2.48it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.026707, dice: 0.045829, loss: 0.026707, iou: 0.856075
Val IoU improved by 0.0043139673115913935. Saving best model.
3m 54s
Epoch 20/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.020775, dice: 0.043060, loss: 0.020775, iou: 0.864498


100%|██████████| 72/72 [00:28<00:00,  2.49it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027736, dice: 0.046597, loss: 0.027736, iou: 0.854073
No accuracy improvement since 1/10 epochs.
3m 54s
Epoch 21/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.020833, dice: 0.043004, loss: 0.020833, iou: 0.863999


100%|██████████| 72/72 [00:28<00:00,  2.50it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.028707, dice: 0.047151, loss: 0.028707, iou: 0.852023
No accuracy improvement since 2/10 epochs.
3m 54s
Epoch 22/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.020856, dice: 0.043459, loss: 0.020856, iou: 0.863509


100%|██████████| 72/72 [00:28<00:00,  2.49it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027700, dice: 0.046413, loss: 0.027700, iou: 0.854342
No accuracy improvement since 3/10 epochs.
3m 54s
Epoch 23/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.019994, dice: 0.041703, loss: 0.019994, iou: 0.867679


100%|██████████| 72/72 [00:28<00:00,  2.49it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027233, dice: 0.047134, loss: 0.027233, iou: 0.852005
No accuracy improvement since 4/10 epochs.
3m 54s
Epoch 24/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.019613, dice: 0.041079, loss: 0.019613, iou: 0.869276


100%|██████████| 72/72 [00:28<00:00,  2.50it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027918, dice: 0.046342, loss: 0.027918, iou: 0.855266
No accuracy improvement since 5/10 epochs.
3m 54s
Epoch 25/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.019613, dice: 0.041274, loss: 0.019613, iou: 0.869068


100%|██████████| 72/72 [00:29<00:00,  2.43it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.029038, dice: 0.048118, loss: 0.029038, iou: 0.849943
No accuracy improvement since 6/10 epochs.
3m 55s
Epoch 26/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.019899, dice: 0.041041, loss: 0.019899, iou: 0.869047


100%|██████████| 72/72 [00:29<00:00,  2.48it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027739, dice: 0.044741, loss: 0.027739, iou: 0.858935
Val IoU improved by 0.002860029747611681. Saving best model.
3m 54s
Epoch 27/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.018647, dice: 0.038685, loss: 0.018647, iou: 0.875717


100%|██████████| 72/72 [00:28<00:00,  2.50it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027707, dice: 0.046401, loss: 0.027707, iou: 0.854880
No accuracy improvement since 1/10 epochs.
3m 54s
Epoch 28/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.018159, dice: 0.038093, loss: 0.018159, iou: 0.877386


100%|██████████| 72/72 [00:28<00:00,  2.49it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.026996, dice: 0.044803, loss: 0.026996, iou: 0.858503
No accuracy improvement since 2/10 epochs.
3m 54s
Epoch 29/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.018431, dice: 0.038489, loss: 0.018431, iou: 0.876154


100%|██████████| 72/72 [00:28<00:00,  2.48it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027569, dice: 0.044240, loss: 0.027569, iou: 0.860406
Val IoU improved by 0.0014701502364977692. Saving best model.
3m 54s
Epoch 30/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.034325, dice: 0.058827, loss: 0.034325, iou: 0.826081


100%|██████████| 72/72 [00:28<00:00,  2.49it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.031362, dice: 0.055975, loss: 0.031362, iou: 0.829951
No accuracy improvement since 1/10 epochs.
3m 54s
Epoch 31/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.028835, dice: 0.055819, loss: 0.028835, iou: 0.830856


100%|██████████| 72/72 [00:28<00:00,  2.49it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.026992, dice: 0.046702, loss: 0.026992, iou: 0.853133
No accuracy improvement since 2/10 epochs.
3m 54s
Epoch 32/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.020588, dice: 0.042653, loss: 0.020588, iou: 0.865219


100%|██████████| 72/72 [00:28<00:00,  2.50it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027189, dice: 0.047527, loss: 0.027189, iou: 0.851422
No accuracy improvement since 3/10 epochs.
3m 54s
Epoch 33/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.018803, dice: 0.039296, loss: 0.018803, iou: 0.874559


100%|██████████| 72/72 [00:28<00:00,  2.50it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.026760, dice: 0.044183, loss: 0.026760, iou: 0.860444
Val IoU improved by 3.832495003419645e-05. Saving best model.
3m 54s
Epoch 34/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.017903, dice: 0.037199, loss: 0.017903, iou: 0.879753


100%|██████████| 72/72 [00:29<00:00,  2.44it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.026167, dice: 0.044346, loss: 0.026167, iou: 0.859886
No accuracy improvement since 1/10 epochs.
3m 54s
Epoch 35/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.017691, dice: 0.037076, loss: 0.017691, iou: 0.880250


100%|██████████| 72/72 [00:28<00:00,  2.49it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027022, dice: 0.046587, loss: 0.027022, iou: 0.852758
No accuracy improvement since 2/10 epochs.
3m 54s
Epoch 36/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.017477, dice: 0.036346, loss: 0.017477, iou: 0.882024


100%|██████████| 72/72 [00:29<00:00,  2.43it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027980, dice: 0.045524, loss: 0.027980, iou: 0.856707
No accuracy improvement since 3/10 epochs.
3m 55s
Epoch 37/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.017057, dice: 0.035646, loss: 0.017057, iou: 0.884389


100%|██████████| 72/72 [00:29<00:00,  2.43it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.030436, dice: 0.047121, loss: 0.030436, iou: 0.851813
No accuracy improvement since 4/10 epochs.
3m 55s
Epoch 38/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.017195, dice: 0.035685, loss: 0.017195, iou: 0.883791


100%|██████████| 72/72 [00:29<00:00,  2.45it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.026807, dice: 0.045222, loss: 0.026807, iou: 0.856824
No accuracy improvement since 5/10 epochs.
3m 54s
Epoch 39/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.016920, dice: 0.035297, loss: 0.016920, iou: 0.884982


100%|██████████| 72/72 [00:30<00:00,  2.39it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.026826, dice: 0.043606, loss: 0.026826, iou: 0.861690
Val IoU improved by 0.0012461367406343404. Saving best model.
3m 55s
Epoch 40/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.016450, dice: 0.034607, loss: 0.016450, iou: 0.887132


100%|██████████| 72/72 [00:30<00:00,  2.38it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027790, dice: 0.043585, loss: 0.027790, iou: 0.861131
No accuracy improvement since 1/10 epochs.
3m 55s
Epoch 41/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.017977, dice: 0.037360, loss: 0.017977, iou: 0.880086


100%|██████████| 72/72 [00:30<00:00,  2.38it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.033516, dice: 0.052915, loss: 0.033516, iou: 0.836250
No accuracy improvement since 2/10 epochs.
3m 55s
Epoch 42/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.036270, dice: 0.067185, loss: 0.036270, iou: 0.802705


100%|██████████| 72/72 [00:29<00:00,  2.41it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.028066, dice: 0.047790, loss: 0.028066, iou: 0.849551
No accuracy improvement since 3/10 epochs.
3m 55s
Epoch 43/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.019822, dice: 0.041481, loss: 0.019822, iou: 0.868382


100%|██████████| 72/72 [00:29<00:00,  2.43it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.026640, dice: 0.044736, loss: 0.026640, iou: 0.858269
No accuracy improvement since 4/10 epochs.
3m 55s
Epoch 44/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.017264, dice: 0.036222, loss: 0.017264, iou: 0.882619


100%|██████████| 72/72 [00:29<00:00,  2.45it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.026474, dice: 0.043232, loss: 0.026474, iou: 0.862758
Val IoU improved by 0.0010681486966317433. Saving best model.
3m 54s
Epoch 45/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.016521, dice: 0.034785, loss: 0.016521, iou: 0.886686


100%|██████████| 72/72 [00:29<00:00,  2.44it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027655, dice: 0.043714, loss: 0.027655, iou: 0.861684
No accuracy improvement since 1/10 epochs.
3m 55s
Epoch 46/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.016093, dice: 0.034048, loss: 0.016093, iou: 0.888998


100%|██████████| 72/72 [00:29<00:00,  2.41it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.028164, dice: 0.043410, loss: 0.028164, iou: 0.862267
No accuracy improvement since 2/10 epochs.
3m 55s
Epoch 47/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.015933, dice: 0.033277, loss: 0.015933, iou: 0.890735


100%|██████████| 72/72 [00:29<00:00,  2.45it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.028746, dice: 0.044815, loss: 0.028746, iou: 0.858362
No accuracy improvement since 3/10 epochs.
3m 55s
Epoch 48/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.015787, dice: 0.033177, loss: 0.015787, iou: 0.891076


100%|██████████| 72/72 [00:29<00:00,  2.46it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.028993, dice: 0.044369, loss: 0.028993, iou: 0.859514
No accuracy improvement since 4/10 epochs.
3m 54s
Epoch 49/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.015724, dice: 0.032981, loss: 0.015724, iou: 0.891536


100%|██████████| 72/72 [00:29<00:00,  2.46it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027710, dice: 0.044520, loss: 0.027710, iou: 0.859268
No accuracy improvement since 5/10 epochs.
3m 54s
Epoch 50/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.015482, dice: 0.032483, loss: 0.015482, iou: 0.893043


100%|██████████| 72/72 [00:29<00:00,  2.44it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.029925, dice: 0.044311, loss: 0.029925, iou: 0.859685
No accuracy improvement since 6/10 epochs.
3m 55s
Epoch 51/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.015185, dice: 0.031815, loss: 0.015185, iou: 0.894867


100%|██████████| 72/72 [00:29<00:00,  2.47it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.029409, dice: 0.043945, loss: 0.029409, iou: 0.860829
No accuracy improvement since 7/10 epochs.
3m 54s
Epoch 52/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.015419, dice: 0.032803, loss: 0.015419, iou: 0.892388


100%|██████████| 72/72 [00:29<00:00,  2.46it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.028907, dice: 0.044500, loss: 0.028907, iou: 0.859692
No accuracy improvement since 8/10 epochs.
3m 54s
Epoch 53/999
----------
LR 0.0001


100%|██████████| 302/302 [03:25<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.015233, dice: 0.031759, loss: 0.015233, iou: 0.895065


100%|██████████| 72/72 [00:29<00:00,  2.48it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.027912, dice: 0.043255, loss: 0.027912, iou: 0.861970
No accuracy improvement since 9/10 epochs.
3m 54s
Epoch 54/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.015237, dice: 0.031984, loss: 0.015237, iou: 0.894438


100%|██████████| 72/72 [00:29<00:00,  2.45it/s]
  0%|          | 0/302 [00:00<?, ?it/s]

570
val: bce: 0.029380, dice: 0.043919, loss: 0.029380, iou: 0.860732
No accuracy improvement since 10/10 epochs.
3m 54s
Epoch 55/999
----------
LR 0.0001


100%|██████████| 302/302 [03:24<00:00,  1.47it/s]
  0%|          | 0/72 [00:00<?, ?it/s]

2412
train: bce: 0.015257, dice: 0.031904, loss: 0.015257, iou: 0.894753


100%|██████████| 72/72 [00:29<00:00,  2.48it/s]

570
val: bce: 0.030171, dice: 0.044514, loss: 0.030171, iou: 0.858877
No accuracy improvement since 11/10 epochs.
3m 54s
Best accuracy: 0.862758





In [13]:
# save weights
torch.save(model.state_dict(),"state_dict.pth")

In [23]:
neptune.log_artifact('state_dict.pth')

NeptuneNoExperimentContextException: ignored

In [14]:
# load weights
model.load_state_dict(torch.load("state_dict.pth", map_location="cpu"))

<All keys matched successfully>

In [15]:
# denormalization function
from torchvision import transforms
inv_normalize = transforms.Normalize(
   mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
   std=[1/0.229, 1/0.224, 1/0.225]
)

def reverse_transform(inp):
    print(inp.shape)
    inp = inv_normalize(inp)
    inp = inp.numpy()
    inp = np.swapaxes(inp, 1, 3)
    inp = np.swapaxes(inp, 1, 2)
    inp = np.clip(inp, 0, 1)
    inp = (inp * 255).astype(np.uint8)
    
    return inp
def labels2mask(labels):
    return labels[:,1,:,:]

In [16]:
# helper function to plot input, ground truth and predict images in grid
import matplotlib.pyplot as plt
def plot_side_by_side(rgb,ground_truth,predict):
  assert rgb.shape[0] == ground_truth.shape[0] == predict.shape[0]
  batch_size = rgb.shape[0]
  fig, axs = plt.subplots(batch_size, 3, figsize=(30,50))
  for i in range(batch_size):
    axs[i, 0].imshow(rgb[i])
    axs[i, 1].imshow(ground_truth[i])
    axs[i, 2].imshow(predict[i])

In [25]:
# visualize example segmentation
import math
model.eval()   # Set model to evaluate mode
test_dataset = Dataset(x_test_dir, y_test_dir, input_size=PARAMS['input_size'], output_size=PARAMS['output_size'], n_classes=PARAMS['n_classes'])
test_loader = DataLoader(test_dataset, batch_size=6, shuffle=True, num_workers=0)
inputs, labels = next(iter(test_loader))
inputs = inputs.to(device)
labels = labels.to(device)

labels = labels.data.cpu().numpy()
pred = model(inputs)

#print(pred)
pred = torch.round(pred)
#print(pred.size())

pred = pred.data.cpu().numpy()
inputs = inputs.data.cpu()

# dataloader return normalized input image, so we have to denormalize before viewing
input_images = reverse_transform(inputs)
# target and predict mask are single channel, so squeeze
target_masks = labels2mask(labels)

pred = labels2mask(pred)

# use helper function to plot
plot_side_by_side(input_images, target_masks, pred)

Output hidden; open in https://colab.research.google.com to view.

In [18]:
#evaluate model
test_dataset = Dataset(x_test_dir, y_test_dir, input_size=PARAMS['input_size'], output_size=PARAMS['output_size'], n_classes=PARAMS['n_classes'])
test_loader = DataLoader(test_dataset, batch_size=PARAMS["batch_size"], shuffle=True, num_workers=0)
intersection=0
union=0
for inputs, labels in tqdm(test_loader):
  inputs = inputs.to(device)
  labels = labels.to(device)
  labels = labels.data.cpu().numpy()
  pred = model(inputs)
  pred = torch.round(pred)
  pred = pred.data.cpu().numpy()
  target = labels[:,1,:,:]
  predict = pred[:,1,:,:]
  temp = (target * predict).sum()
  intersection+=temp
  union+=((target + predict).sum() - temp)
iou = intersection/union
print("IoU: {}".format(iou))
neptune.log_metric("total_iou",iou)

100%|██████████| 72/72 [00:30<00:00,  2.40it/s]


IoU: 0.901741448727608


In [19]:
# update neptune status
neptune.stop()