In [1]:
!pip install neptune-client
!pip install torchinfo

Collecting neptune-client
[?25l  Downloading https://files.pythonhosted.org/packages/49/1a/a3ad339640f7bd73118b9afeeb4b411ae82c149d258c013075f603e118ef/neptune-client-0.5.1.tar.gz (115kB)
[K     |██▉                             | 10kB 20.4MB/s eta 0:00:01[K     |█████▊                          | 20kB 27.1MB/s eta 0:00:01[K     |████████▌                       | 30kB 25.5MB/s eta 0:00:01[K     |███████████▍                    | 40kB 19.3MB/s eta 0:00:01[K     |██████████████▏                 | 51kB 16.3MB/s eta 0:00:01[K     |█████████████████               | 61kB 15.0MB/s eta 0:00:01[K     |████████████████████            | 71kB 13.9MB/s eta 0:00:01[K     |██████████████████████▊         | 81kB 14.5MB/s eta 0:00:01[K     |█████████████████████████▋      | 92kB 13.8MB/s eta 0:00:01[K     |████████████████████████████▍   | 102kB 13.0MB/s eta 0:00:01[K     |███████████████████████████████▎| 112kB 13.0MB/s eta 0:00:01[K     |████████████████████████████████| 122kB 

In [2]:
# -*- coding: utf-8 -*-

import os
#set workdir
os.chdir("/content/drive/MyDrive/DEM-waterlevel/ml/")

#imports
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from dataloader import DenoiseDataset
from torchinfo import summary
import time
import copy
import pdb
from tqdm import tqdm

#training parameters in neptune format
PARAMS = {
    "img_size": 256,
    "model": "vgg_unet",
    "learning_rate": 0.01,
    "batch_size": 8,
    'epochs': 1000,
    'patience': 10,
    "train_dataset_size": -1, # set train dataset subset. Useful when neet to 
                              # overtrain model with small amount of images.
                              # -1 -all images from train directories.
    "test_dataset_size": -1,  # set test dataset subset.
                              # -1 -all images from train directories.
    'image_preload': False,
}


#dataset configuration
dataset_dir = os.path.normpath("/content/drive/MyDrive/DEM-waterlevel/dataset")
train_dir = os.path.join(dataset_dir,"train")
test_dir = os.path.join(dataset_dir,"test")

train_set = DenoiseDataset(train_dir, img_size=PARAMS['img_size'], count=PARAMS["train_dataset_size"])
test_set = DenoiseDataset(test_dir, img_size=PARAMS['img_size'], count=PARAMS["test_dataset_size"])

batch_size = PARAMS['batch_size']
dataloaders = {
    'train': DataLoader(train_set, batch_size=PARAMS['batch_size'], shuffle=True, num_workers=0),
    'val': DataLoader(test_set, batch_size=PARAMS['batch_size'], shuffle=True, num_workers=0)
}

# load images - useful if you want to save some time by preloading images (very time-consuming) when 
# the model is still not fuctional and cant run standard training.
if PARAMS['image_preload']:
  for phase in dataloaders:
    for inputs, labels in tqdm(dataloaders[phase]):
      pass

#model loading
elif PARAMS['model'] == "vgg_unet":
  from models.vgg_unet import VggUnet
  model = VggUnet()

#model structure preview
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))




In [3]:
#neptune initialization
import neptune
neptune.init(project_qualified_name='radek/denoise1',
             api_token='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUuYWkiLCJhcGlfa2V5IjoiYmY4YjQ3YjEtNmY5My00MDc2LWI4NzAtMWE5MmUwZjQ1NDE2In0=',
             )
neptune.create_experiment(params=PARAMS)

model = model.to(device)
model_stats = summary(model, input_size=(PARAMS['batch_size'], 4, PARAMS['img_size'], PARAMS['img_size']))
for line in str(model_stats).splitlines():
  neptune.log_text('model_summary', line)

from collections import defaultdict
import torch.nn.functional as F

def calc_loss(pred, target, metrics):
    loss = F.mse_loss(pred, target)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)
    return loss

def print_metrics(metrics, epoch_samples, phase):   
    print(epoch_samples) 
    outputs = []
    for k in metrics.keys():
        outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples))
        neptune.log_metric(phase+"_"+k, metrics[k] / epoch_samples) #log
    print("{}: {}".format(phase, ", ".join(outputs)))

#training loop
def train_model(model, dataloaders, optimizer, device, num_epochs=25, patience=-1):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    no_improvement = 0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                    
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)             

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)
                    #print(model.encoder[0].weight.grad)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        #pdb.set_trace()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            # deep copy the model
            if phase == 'val':
              if epoch_loss < best_loss:
                no_improvement = 0
                print("Val loss improved by {}. Saving best model.".format(best_loss-epoch_loss))
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
              else:
                no_improvement += 1
                print("No loss improvement since {}/{} epochs.".format(no_improvement,patience))
        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        if patience >= 0 and no_improvement > patience:
          break
    print('Best loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

#model training
optimizer_ft = optim.Adam(model.parameters(), lr=PARAMS['learning_rate'])
model = train_model(model, dataloaders, optimizer_ft, device, num_epochs=PARAMS['epochs'], patience=PARAMS['patience'])

# save weights
torch.save(model.state_dict(),"state_dict.pth")

neptune.log_artifact('state_dict.pth')



https://ui.neptune.ai/radek/denoise1/e/DEN1-38


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch 0/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  6.51it/s]
 33%|███▎      | 1/3 [00:00<00:00,  6.86it/s]

90
train: loss: 44983.087674


100%|██████████| 3/3 [00:00<00:00,  7.38it/s]
  0%|          | 0/12 [00:00<?, ?it/s]

22
val: loss: 14109.768821
Val loss improved by inf. Saving best model.
0m 2s
Epoch 1/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  6.53it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.49it/s]

90
train: loss: 40355.835764


100%|██████████| 3/3 [00:00<00:00,  7.90it/s]
  8%|▊         | 1/12 [00:00<00:01,  6.42it/s]

22
val: loss: 51661.409091
No loss improvement since 1/10 epochs.
0m 2s
Epoch 2/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  6.84it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.54it/s]

90
train: loss: 27200.053776


100%|██████████| 3/3 [00:00<00:00,  7.81it/s]
  8%|▊         | 1/12 [00:00<00:01,  6.94it/s]

22
val: loss: 1230315.840909
No loss improvement since 2/10 epochs.
0m 2s
Epoch 3/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.20it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.24it/s]

90
train: loss: 10340.071387


100%|██████████| 3/3 [00:00<00:00,  9.01it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.05it/s]

22
val: loss: 856869.909091
No loss improvement since 3/10 epochs.
0m 2s
Epoch 4/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.31it/s]

90
train: loss: 3064.687348


100%|██████████| 3/3 [00:00<00:00,  8.91it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.06it/s]

22
val: loss: 317433.113636
No loss improvement since 4/10 epochs.
0m 2s
Epoch 5/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.29it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.05it/s]

90
train: loss: 1885.111990


100%|██████████| 3/3 [00:00<00:00,  8.80it/s]
  8%|▊         | 1/12 [00:00<00:01,  6.49it/s]

22
val: loss: 21559.242188
No loss improvement since 5/10 epochs.
0m 2s
Epoch 6/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.09it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.39it/s]

90
train: loss: 1244.256795


100%|██████████| 3/3 [00:00<00:00,  8.10it/s]
  8%|▊         | 1/12 [00:00<00:01,  6.51it/s]

22
val: loss: 36312.169389
No loss improvement since 6/10 epochs.
0m 2s
Epoch 7/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.10it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.81it/s]

90
train: loss: 886.672248


100%|██████████| 3/3 [00:00<00:00,  8.18it/s]
  8%|▊         | 1/12 [00:00<00:01,  6.85it/s]

22
val: loss: 3460.790838
Val loss improved by 10648.977982954546. Saving best model.
0m 2s
Epoch 8/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.08it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.00it/s]

90
train: loss: 775.059736


100%|██████████| 3/3 [00:00<00:00,  8.54it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.58it/s]

22
val: loss: 8477.166903
No loss improvement since 1/10 epochs.
0m 2s
Epoch 9/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.43it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.00it/s]

90
train: loss: 751.669968


100%|██████████| 3/3 [00:00<00:00,  8.84it/s]
  8%|▊         | 1/12 [00:00<00:01,  6.98it/s]

22
val: loss: 5436.642312
No loss improvement since 2/10 epochs.
0m 2s
Epoch 10/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.17it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.24it/s]

90
train: loss: 740.110758


100%|██████████| 3/3 [00:00<00:00,  8.85it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.64it/s]

22
val: loss: 1503.180376
Val loss improved by 1957.6104625355115. Saving best model.
0m 2s
Epoch 11/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.29it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.38it/s]

90
train: loss: 734.285314


100%|██████████| 3/3 [00:00<00:00,  8.76it/s]
  8%|▊         | 1/12 [00:00<00:01,  6.99it/s]

22
val: loss: 3059.556396
No loss improvement since 1/10 epochs.
0m 2s
Epoch 12/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.33it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.17it/s]

90
train: loss: 732.031905


100%|██████████| 3/3 [00:00<00:00,  8.67it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.53it/s]

22
val: loss: 1218.463568
Val loss improved by 284.71680797230124. Saving best model.
0m 2s
Epoch 13/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.33it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.83it/s]

90
train: loss: 730.060582


100%|██████████| 3/3 [00:00<00:00,  8.70it/s]
  8%|▊         | 1/12 [00:00<00:01,  6.99it/s]

22
val: loss: 823.065629
Val loss improved by 395.39793812144876. Saving best model.
0m 2s
Epoch 14/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.21it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.15it/s]

90
train: loss: 728.249495


100%|██████████| 3/3 [00:00<00:00,  8.73it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.25it/s]

22
val: loss: 848.864469
No loss improvement since 1/10 epochs.
0m 2s
Epoch 15/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.40it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.69it/s]

90
train: loss: 726.433241


100%|██████████| 3/3 [00:00<00:00,  8.68it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.69it/s]

22
val: loss: 902.594532
No loss improvement since 2/10 epochs.
0m 2s
Epoch 16/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.35it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.32it/s]

90
train: loss: 724.616995


100%|██████████| 3/3 [00:00<00:00,  9.04it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.69it/s]

22
val: loss: 914.318748
No loss improvement since 3/10 epochs.
0m 2s
Epoch 17/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.40it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.46it/s]

90
train: loss: 722.796864


100%|██████████| 3/3 [00:00<00:00,  8.78it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.68it/s]

22
val: loss: 960.976174
No loss improvement since 4/10 epochs.
0m 2s
Epoch 18/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.35it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.96it/s]

90
train: loss: 721.559248


100%|██████████| 3/3 [00:00<00:00,  8.76it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.64it/s]

22
val: loss: 747.804033
Val loss improved by 75.2615966796875. Saving best model.
0m 2s
Epoch 19/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.29it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.58it/s]

90
train: loss: 720.590679


100%|██████████| 3/3 [00:00<00:00,  9.17it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.93it/s]

22
val: loss: 725.435325
Val loss improved by 22.368707830255744. Saving best model.
0m 2s
Epoch 20/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.33it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.26it/s]

90
train: loss: 719.481443


100%|██████████| 3/3 [00:00<00:00,  9.11it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.97it/s]

22
val: loss: 833.154813
No loss improvement since 1/10 epochs.
0m 2s
Epoch 21/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.37it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.94it/s]

90
train: loss: 718.936217


100%|██████████| 3/3 [00:00<00:00,  8.32it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.80it/s]

22
val: loss: 889.382247
No loss improvement since 2/10 epochs.
0m 2s
Epoch 22/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.48it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.83it/s]

90
train: loss: 718.304964


100%|██████████| 3/3 [00:00<00:00,  8.79it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.44it/s]

22
val: loss: 736.834278
No loss improvement since 3/10 epochs.
0m 2s
Epoch 23/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.28it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.56it/s]

90
train: loss: 717.384405


100%|██████████| 3/3 [00:00<00:00,  8.23it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.46it/s]

22
val: loss: 859.891629
No loss improvement since 4/10 epochs.
0m 2s
Epoch 24/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.28it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.85it/s]

90
train: loss: 716.654319


100%|██████████| 3/3 [00:00<00:00,  8.63it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.85it/s]

22
val: loss: 870.000166
No loss improvement since 5/10 epochs.
0m 2s
Epoch 25/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.31it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.10it/s]

90
train: loss: 716.120983


100%|██████████| 3/3 [00:00<00:00,  8.81it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.91it/s]

22
val: loss: 730.262140
No loss improvement since 6/10 epochs.
0m 2s
Epoch 26/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.43it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.81it/s]

90
train: loss: 715.146701


100%|██████████| 3/3 [00:00<00:00,  8.87it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.78it/s]

22
val: loss: 702.370722
Val loss improved by 23.064602938565372. Saving best model.
0m 2s
Epoch 27/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.40it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.36it/s]

90
train: loss: 537.698875


100%|██████████| 3/3 [00:00<00:00,  9.25it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.40it/s]

22
val: loss: 2653.296697
No loss improvement since 1/10 epochs.
0m 2s
Epoch 28/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.46it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.89it/s]

90
train: loss: 390.477505


100%|██████████| 3/3 [00:00<00:00,  8.95it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.82it/s]

22
val: loss: 892.066928
No loss improvement since 2/10 epochs.
0m 2s
Epoch 29/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.35it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.92it/s]

90
train: loss: 372.017027


100%|██████████| 3/3 [00:00<00:00,  8.62it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.66it/s]

22
val: loss: 648.410966
Val loss improved by 53.959755637428884. Saving best model.
0m 2s
Epoch 30/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.33it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.58it/s]

90
train: loss: 367.855530


100%|██████████| 3/3 [00:00<00:00,  8.53it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.54it/s]

22
val: loss: 444.690144
Val loss improved by 203.72082242098725. Saving best model.
0m 2s
Epoch 31/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.52it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.20it/s]

90
train: loss: 365.816588


100%|██████████| 3/3 [00:00<00:00,  8.64it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.86it/s]

22
val: loss: 411.280401
Val loss improved by 33.409742875532686. Saving best model.
0m 2s
Epoch 32/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.43it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.28it/s]

90
train: loss: 364.454742


100%|██████████| 3/3 [00:00<00:00,  8.95it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.62it/s]

22
val: loss: 369.479881
Val loss improved by 41.80052046342331. Saving best model.
0m 2s
Epoch 33/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.28it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.00it/s]

90
train: loss: 363.541522


100%|██████████| 3/3 [00:00<00:00,  8.62it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.01it/s]

22
val: loss: 369.718675
No loss improvement since 1/10 epochs.
0m 2s
Epoch 34/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.25it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.42it/s]

90
train: loss: 362.763069


100%|██████████| 3/3 [00:00<00:00,  8.45it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.20it/s]

22
val: loss: 367.456554
Val loss improved by 2.023326526988626. Saving best model.
0m 2s
Epoch 35/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.47it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.46it/s]

90
train: loss: 362.306584


100%|██████████| 3/3 [00:00<00:00,  8.86it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.83it/s]

22
val: loss: 366.982758
Val loss improved by 0.47379649769175103. Saving best model.
0m 2s
Epoch 36/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.40it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.27it/s]

90
train: loss: 361.631968


100%|██████████| 3/3 [00:00<00:00,  8.31it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.77it/s]

22
val: loss: 368.548590
No loss improvement since 1/10 epochs.
0m 2s
Epoch 37/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.39it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.81it/s]

90
train: loss: 361.179326


100%|██████████| 3/3 [00:00<00:00,  8.48it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.24it/s]

22
val: loss: 364.438113
Val loss improved by 2.544644442471565. Saving best model.
0m 2s
Epoch 38/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.42it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.30it/s]

90
train: loss: 360.767152


100%|██████████| 3/3 [00:00<00:00,  8.95it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.71it/s]

22
val: loss: 362.047255
Val loss improved by 2.390858043323874. Saving best model.
0m 2s
Epoch 39/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.46it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.12it/s]

90
train: loss: 360.484924


100%|██████████| 3/3 [00:00<00:00,  8.55it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.47it/s]

22
val: loss: 362.151312
No loss improvement since 1/10 epochs.
0m 2s
Epoch 40/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.29it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.57it/s]

90
train: loss: 360.265808


100%|██████████| 3/3 [00:00<00:00,  8.36it/s]
  8%|▊         | 1/12 [00:00<00:01,  6.96it/s]

22
val: loss: 362.597803
No loss improvement since 2/10 epochs.
0m 2s
Epoch 41/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.19it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.30it/s]

90
train: loss: 360.170034


100%|██████████| 3/3 [00:00<00:00,  8.98it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.63it/s]

22
val: loss: 369.755252
No loss improvement since 3/10 epochs.
0m 2s
Epoch 42/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.30it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.65it/s]

90
train: loss: 359.913603


100%|██████████| 3/3 [00:00<00:00,  8.69it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.57it/s]

22
val: loss: 370.838282
No loss improvement since 4/10 epochs.
0m 2s
Epoch 43/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.38it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.86it/s]

90
train: loss: 359.695148


100%|██████████| 3/3 [00:00<00:00,  8.35it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.80it/s]

22
val: loss: 360.740290
Val loss improved by 1.306965221058249. Saving best model.
0m 2s
Epoch 44/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.39it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.11it/s]

90
train: loss: 359.511902


100%|██████████| 3/3 [00:00<00:00,  8.55it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.85it/s]

22
val: loss: 361.915230
No loss improvement since 1/10 epochs.
0m 2s
Epoch 45/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.39it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.21it/s]

90
train: loss: 359.397860


100%|██████████| 3/3 [00:00<00:00,  9.03it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.77it/s]

22
val: loss: 364.833957
No loss improvement since 2/10 epochs.
0m 2s
Epoch 46/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.41it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.17it/s]

90
train: loss: 359.310509


100%|██████████| 3/3 [00:00<00:00,  8.87it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.82it/s]

22
val: loss: 360.546842
Val loss improved by 0.19344815340912191. Saving best model.
0m 2s
Epoch 47/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.32it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.67it/s]

90
train: loss: 359.177262


100%|██████████| 3/3 [00:00<00:00,  8.40it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.63it/s]

22
val: loss: 361.996582
No loss improvement since 1/10 epochs.
0m 2s
Epoch 48/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.30it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.91it/s]

90
train: loss: 359.148431


100%|██████████| 3/3 [00:00<00:00,  8.25it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.26it/s]

22
val: loss: 360.645214
No loss improvement since 2/10 epochs.
0m 2s
Epoch 49/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.40it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.90it/s]

90
train: loss: 359.163820


100%|██████████| 3/3 [00:00<00:00,  8.40it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.37it/s]

22
val: loss: 368.526711
No loss improvement since 3/10 epochs.
0m 2s
Epoch 50/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.15it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.71it/s]

90
train: loss: 359.050997


100%|██████████| 3/3 [00:00<00:00,  8.42it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.80it/s]

22
val: loss: 359.612116
Val loss improved by 0.9347256747158781. Saving best model.
0m 2s
Epoch 51/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.31it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.20it/s]

90
train: loss: 358.902973


100%|██████████| 3/3 [00:00<00:00,  8.76it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.57it/s]

22
val: loss: 361.187392
No loss improvement since 1/10 epochs.
0m 2s
Epoch 52/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.48it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.29it/s]

90
train: loss: 358.830685


100%|██████████| 3/3 [00:00<00:00,  9.00it/s]
  8%|▊         | 1/12 [00:00<00:01,  8.04it/s]

22
val: loss: 359.732255
No loss improvement since 2/10 epochs.
0m 2s
Epoch 53/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.33it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.83it/s]

90
train: loss: 359.006303


100%|██████████| 3/3 [00:00<00:00,  8.43it/s]
  8%|▊         | 1/12 [00:00<00:01,  8.08it/s]

22
val: loss: 359.711348
No loss improvement since 3/10 epochs.
0m 2s
Epoch 54/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.43it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.64it/s]

90
train: loss: 358.809108


100%|██████████| 3/3 [00:00<00:00,  8.57it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.36it/s]

22
val: loss: 360.695579
No loss improvement since 4/10 epochs.
0m 2s
Epoch 55/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.27it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.04it/s]

90
train: loss: 358.956466


100%|██████████| 3/3 [00:00<00:00,  8.54it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.35it/s]

22
val: loss: 361.641848
No loss improvement since 5/10 epochs.
0m 2s
Epoch 56/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.39it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.17it/s]

90
train: loss: 358.928018


100%|██████████| 3/3 [00:00<00:00,  9.02it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.50it/s]

22
val: loss: 370.563302
No loss improvement since 6/10 epochs.
0m 2s
Epoch 57/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.10it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.59it/s]

90
train: loss: 358.824544


100%|██████████| 3/3 [00:00<00:00,  8.30it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.51it/s]

22
val: loss: 387.949391
No loss improvement since 7/10 epochs.
0m 2s
Epoch 58/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.36it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.38it/s]

90
train: loss: 358.754915


100%|██████████| 3/3 [00:00<00:00,  8.78it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.82it/s]

22
val: loss: 359.353696
Val loss improved by 0.2584200772372469. Saving best model.
0m 2s
Epoch 59/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.25it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.86it/s]

90
train: loss: 358.737662


100%|██████████| 3/3 [00:00<00:00,  8.10it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.53it/s]

22
val: loss: 359.473880
No loss improvement since 1/10 epochs.
0m 2s
Epoch 60/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.48it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.09it/s]

90
train: loss: 358.716547


100%|██████████| 3/3 [00:00<00:00,  8.90it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.70it/s]

22
val: loss: 365.789035
No loss improvement since 2/10 epochs.
0m 2s
Epoch 61/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.40it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.60it/s]

90
train: loss: 358.706394


100%|██████████| 3/3 [00:00<00:00,  8.41it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.78it/s]

22
val: loss: 361.679657
No loss improvement since 3/10 epochs.
0m 2s
Epoch 62/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.42it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.57it/s]

90
train: loss: 358.670448


100%|██████████| 3/3 [00:00<00:00,  8.53it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.74it/s]

22
val: loss: 363.133403
No loss improvement since 4/10 epochs.
0m 2s
Epoch 63/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.37it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.27it/s]

90
train: loss: 358.638118


100%|██████████| 3/3 [00:00<00:00,  9.19it/s]
  8%|▊         | 1/12 [00:00<00:01,  8.07it/s]

22
val: loss: 362.802368
No loss improvement since 5/10 epochs.
0m 2s
Epoch 64/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.47it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.28it/s]

90
train: loss: 358.649226


100%|██████████| 3/3 [00:00<00:00,  8.84it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.90it/s]

22
val: loss: 360.854986
No loss improvement since 6/10 epochs.
0m 2s
Epoch 65/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.42it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.22it/s]

90
train: loss: 358.606782


100%|██████████| 3/3 [00:00<00:00,  8.79it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.80it/s]

22
val: loss: 359.434096
No loss improvement since 7/10 epochs.
0m 2s
Epoch 66/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.33it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.16it/s]

90
train: loss: 358.636148


100%|██████████| 3/3 [00:00<00:00,  8.58it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.64it/s]

22
val: loss: 360.792195
No loss improvement since 8/10 epochs.
0m 2s
Epoch 67/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.35it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.31it/s]

90
train: loss: 358.647880


100%|██████████| 3/3 [00:00<00:00,  8.83it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.44it/s]

22
val: loss: 360.995342
No loss improvement since 9/10 epochs.
0m 2s
Epoch 68/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.41it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.94it/s]

90
train: loss: 358.694744


100%|██████████| 3/3 [00:00<00:00,  8.45it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.89it/s]

22
val: loss: 359.000344
Val loss improved by 0.35335193980819213. Saving best model.
0m 2s
Epoch 69/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.44it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.12it/s]

90
train: loss: 358.640820


100%|██████████| 3/3 [00:00<00:00,  8.96it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.80it/s]

22
val: loss: 362.955034
No loss improvement since 1/10 epochs.
0m 2s
Epoch 70/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.29it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.04it/s]

90
train: loss: 358.703921


100%|██████████| 3/3 [00:00<00:00,  8.53it/s]
  8%|▊         | 1/12 [00:00<00:01,  6.98it/s]

22
val: loss: 365.764463
No loss improvement since 2/10 epochs.
0m 2s
Epoch 71/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.34it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.13it/s]

90
train: loss: 358.841816


100%|██████████| 3/3 [00:00<00:00,  8.65it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.68it/s]

22
val: loss: 359.868500
No loss improvement since 3/10 epochs.
0m 2s
Epoch 72/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.36it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.09it/s]

90
train: loss: 358.917834


100%|██████████| 3/3 [00:00<00:00,  8.79it/s]
  8%|▊         | 1/12 [00:00<00:01,  6.90it/s]

22
val: loss: 367.985601
No loss improvement since 4/10 epochs.
0m 2s
Epoch 73/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.38it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.77it/s]

90
train: loss: 358.691302


100%|██████████| 3/3 [00:00<00:00,  8.67it/s]
  8%|▊         | 1/12 [00:00<00:01,  8.00it/s]

22
val: loss: 359.042725
No loss improvement since 5/10 epochs.
0m 2s
Epoch 74/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.52it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.41it/s]

90
train: loss: 358.664162


100%|██████████| 3/3 [00:00<00:00,  8.94it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.85it/s]

22
val: loss: 363.684082
No loss improvement since 6/10 epochs.
0m 2s
Epoch 75/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.45it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.69it/s]

90
train: loss: 358.643336


100%|██████████| 3/3 [00:00<00:00,  8.42it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.86it/s]

22
val: loss: 359.965812
No loss improvement since 7/10 epochs.
0m 2s
Epoch 76/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.50it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.74it/s]

90
train: loss: 358.717048


100%|██████████| 3/3 [00:00<00:00,  8.15it/s]
  8%|▊         | 1/12 [00:00<00:01,  6.91it/s]

22
val: loss: 358.988309
Val loss improved by 0.012035023082432872. Saving best model.
0m 2s
Epoch 77/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.33it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.08it/s]

90
train: loss: 358.640816


100%|██████████| 3/3 [00:00<00:00,  8.54it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.48it/s]

22
val: loss: 360.177640
No loss improvement since 1/10 epochs.
0m 2s
Epoch 78/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.50it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.22it/s]

90
train: loss: 358.647691


100%|██████████| 3/3 [00:00<00:00,  8.85it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.85it/s]

22
val: loss: 359.041093
No loss improvement since 2/10 epochs.
0m 2s
Epoch 79/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.47it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.11it/s]

90
train: loss: 358.622928


100%|██████████| 3/3 [00:00<00:00,  8.74it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.97it/s]

22
val: loss: 359.064320
No loss improvement since 3/10 epochs.
0m 2s
Epoch 80/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.43it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.37it/s]

90
train: loss: 358.595161


100%|██████████| 3/3 [00:00<00:00,  8.57it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.33it/s]

22
val: loss: 359.435691
No loss improvement since 4/10 epochs.
0m 2s
Epoch 81/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.52it/s]
 33%|███▎      | 1/3 [00:00<00:00,  6.60it/s]

90
train: loss: 358.587778


100%|██████████| 3/3 [00:00<00:00,  8.09it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.74it/s]

22
val: loss: 359.096730
No loss improvement since 5/10 epochs.
0m 2s
Epoch 82/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.56it/s]
 33%|███▎      | 1/3 [00:00<00:00,  6.36it/s]

90
train: loss: 358.604793


100%|██████████| 3/3 [00:00<00:00,  7.95it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.44it/s]

22
val: loss: 359.342801
No loss improvement since 6/10 epochs.
0m 2s
Epoch 83/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.52it/s]
 33%|███▎      | 1/3 [00:00<00:00,  6.79it/s]

90
train: loss: 358.641379


100%|██████████| 3/3 [00:00<00:00,  8.34it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.87it/s]

22
val: loss: 360.906500
No loss improvement since 7/10 epochs.
0m 2s
Epoch 84/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.56it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.37it/s]

90
train: loss: 358.673061


100%|██████████| 3/3 [00:00<00:00,  8.02it/s]
  8%|▊         | 1/12 [00:00<00:01,  8.30it/s]

22
val: loss: 359.930747
No loss improvement since 8/10 epochs.
0m 2s
Epoch 85/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.44it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.64it/s]

90
train: loss: 358.609589


100%|██████████| 3/3 [00:00<00:00,  8.53it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.31it/s]

22
val: loss: 359.093167
No loss improvement since 9/10 epochs.
0m 2s
Epoch 86/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.41it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.11it/s]

90
train: loss: 358.539952


100%|██████████| 3/3 [00:00<00:00,  8.53it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.59it/s]

22
val: loss: 359.549100
No loss improvement since 10/10 epochs.
0m 2s
Epoch 87/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.46it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.79it/s]

90
train: loss: 358.562198


100%|██████████| 3/3 [00:00<00:00,  7.91it/s]


22
val: loss: 359.066465
No loss improvement since 11/10 epochs.
0m 2s
Best loss: 358.988309


In [4]:
# load weights
model.load_state_dict(torch.load("state_dict.pth", map_location="cpu"))
device = torch.device('cpu')
model = model.to(device)
# denormalization function
from torchvision import transforms
inv_normalize = transforms.Normalize(
   mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
   std=[1/0.229, 1/0.224, 1/0.225]
)

def reverse_transform(inp):
    print(inp.shape)
    inp = inv_normalize(inp)
    inp = inp.numpy()
    inp = np.swapaxes(inp, 1, 3)
    inp = np.swapaxes(inp, 1, 2)
    inp = np.clip(inp, 0, 1)
    inp = (inp * 255).astype(np.uint8)
    
    return inp
def labels2mask(labels):
    return labels[:,1,:,:]

# helper function to plot x, ground truth and predict images in grid
import matplotlib.pyplot as plt
def plot_side_by_side(x,y_dem_gt, y_dem_pr):
  assert x.shape[0] == y_dem_gt.shape[0] == y_dem_pr.shape[0]
  batch_size = x.shape[0]
  fig, axs = plt.subplots(batch_size, 4, figsize=(30,50))
  for i in range(batch_size):
    axs[i, 0].imshow(x[i,1:4].permute(1, 2, 0))
    min_val = torch.min(x[i,0])
    max_val = torch.max(x[i,0])
    axs[i, 1].imshow(x[i,0], vmin = min_val, vmax = max_val)
    axs[i, 2].imshow(np.squeeze(y_dem_gt[i]), vmin = min_val, vmax = max_val)
    axs[i, 3].imshow(np.squeeze(y_dem_pr[i]), vmin = min_val, vmax = max_val)

# visualize example segmentation
import math
model.eval()   # Set model to evaluate mode
test_dataset = DenoiseDataset(test_dir, img_size=PARAMS['img_size'], count=PARAMS["test_dataset_size"])
test_loader = DataLoader(test_dataset, batch_size=6, shuffle=True, num_workers=0)
inputs, gts = next(iter(test_loader))
inputs = inputs.to(device)
gts = gts.to(device)

gts = gts.data.cpu().numpy()
pred = model(inputs)

pred = pred.data.cpu().numpy()
inputs = inputs.data.cpu()

# use helper function to plot
plot_side_by_side(inputs, gts, pred)

#evaluate model
#test_dataset = DenoiseDataset(x_test_dir, y_test_dir, input_size=PARAMS['input_size'], output_size=PARAMS['output_size'], n_classes=PARAMS['n_classes'])
#test_loader = DataLoader(test_dataset, batch_size=PARAMS["batch_size"], shuffle=True, num_workers=0)
#intersection=0
#union=0
#for inputs, labels in tqdm(test_loader):
#  inputs = inputs.to(device)
#  labels = labels.to(device)
#  labels = labels.data.cpu().numpy()
#  pred = model(inputs)
#  pred = torch.round(pred)
#  pred = pred.data.cpu().numpy()
#  target = labels[:,1,:,:]
#  predict = pred[:,1,:,:]
#  temp = (target * predict).sum()
#  intersection+=temp
#  union+=((target + predict).sum() - temp)
#iou = intersection/union
#print("IoU: {}".format(iou))
#neptune.log_metric("total_iou",iou)


Output hidden; open in https://colab.research.google.com to view.

In [5]:

# update neptune status
neptune.stop()