In [1]:
!pip install neptune-client
!pip install torchinfo



In [2]:
# -*- coding: utf-8 -*-

import os
#set workdir
os.chdir("/content/drive/MyDrive/DEM-waterlevel/ml/")

#imports
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from dataloader import DenoiseDataset
from torchinfo import summary
import time
import copy
import pdb
from tqdm import tqdm

#training parameters in neptune format
PARAMS = {
    "img_size": 256,
    "model": "vgg_unet",
    "learning_rate": 0.01,
    "batch_size": 8,
    'epochs': 1000,
    'patience': 10,
    "train_dataset_size": -1, # set train dataset subset. Useful when neet to 
                              # overtrain model with small amount of images.
                              # -1 -all images from train directories.
    "test_dataset_size": -1,  # set test dataset subset.
                              # -1 -all images from train directories.
    'image_preload': False,
}


#dataset configuration
dataset_dir = os.path.normpath("/content/drive/MyDrive/DEM-waterlevel/dataset")
train_dir = os.path.join(dataset_dir,"train")
test_dir = os.path.join(dataset_dir,"test")

train_set = DenoiseDataset(train_dir, img_size=PARAMS['img_size'], count=PARAMS["train_dataset_size"])
test_set = DenoiseDataset(test_dir, img_size=PARAMS['img_size'], count=PARAMS["test_dataset_size"])

batch_size = PARAMS['batch_size']
dataloaders = {
    'train': DataLoader(train_set, batch_size=PARAMS['batch_size'], shuffle=True, num_workers=0),
    'val': DataLoader(test_set, batch_size=PARAMS['batch_size'], shuffle=True, num_workers=0)
}

# load images - useful if you want to save some time by preloading images (very time-consuming) when 
# the model is still not fuctional and cant run standard training.
if PARAMS['image_preload']:
  for phase in dataloaders:
    for inputs, labels in tqdm(dataloaders[phase]):
      pass

#model loading
elif PARAMS['model'] == "vgg_unet":
  from models.vgg_unet import VggUnet
  model = VggUnet()

#model structure preview
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
#neptune initialization
import neptune
neptune.init(project_qualified_name='radek/denoise1',
             api_token='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUuYWkiLCJhcGlfa2V5IjoiYmY4YjQ3YjEtNmY5My00MDc2LWI4NzAtMWE5MmUwZjQ1NDE2In0=',
             )
neptune.create_experiment(params=PARAMS)

model = model.to(device)
model_stats = summary(model, input_size=(PARAMS['batch_size'], 4, PARAMS['img_size'], PARAMS['img_size']))
for line in str(model_stats).splitlines():
  neptune.log_text('model_summary', line)

from collections import defaultdict
import torch.nn.functional as F

def calc_loss(pred, target, metrics):
    loss = F.mse_loss(pred, target)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)
    return loss

def print_metrics(metrics, epoch_samples, phase):   
    print(epoch_samples) 
    outputs = []
    for k in metrics.keys():
        outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples))
        neptune.log_metric(phase+"_"+k, metrics[k] / epoch_samples) #log
    print("{}: {}".format(phase, ", ".join(outputs)))

#training loop
def train_model(model, dataloaders, optimizer, device, num_epochs=25, patience=-1):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    no_improvement = 0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                    
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)             

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)
                    #print(model.encoder[0].weight.grad)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        #pdb.set_trace()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            # deep copy the model
            if phase == 'train':
              if epoch_loss < best_loss:
                no_improvement = 0
                print("Val loss improved by {}. Saving best model.".format(best_loss-epoch_loss))
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
              else:
                no_improvement += 1
                print("No loss improvement since {}/{} epochs.".format(no_improvement,patience))
        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        if patience >= 0 and no_improvement > patience:
          break
    print('Best loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

#model training
optimizer_ft = optim.Adam(model.parameters(), lr=PARAMS['learning_rate'])
model = train_model(model, dataloaders, optimizer_ft, device, num_epochs=PARAMS['epochs'], patience=PARAMS['patience'])

# save weights
torch.save(model.state_dict(),"state_dict.pth")

neptune.log_artifact('state_dict.pth')



https://ui.neptune.ai/radek/denoise1/e/DEN1-36


  8%|▊         | 1/12 [00:00<00:01,  6.80it/s]

Epoch 0/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  6.96it/s]
  0%|          | 0/3 [00:00<?, ?it/s]

90
train: loss: 44431.782813
Val loss improved by inf. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.73it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.99it/s]

22
val: loss: 2504.389293
0m 2s
Epoch 1/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.31it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.88it/s]

90
train: loss: 37518.344358
Val loss improved by 6913.438454861112. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.21it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.72it/s]

22
val: loss: 50296.368253
0m 2s
Epoch 2/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.32it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.78it/s]

90
train: loss: 22853.058442
Val loss improved by 14665.28591579861. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.28it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.34it/s]

22
val: loss: 45603.251776
0m 2s
Epoch 3/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.16it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.35it/s]

90
train: loss: 8101.681641
Val loss improved by 14751.37680121528. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.36it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.52it/s]

22
val: loss: 339144.517045
0m 2s
Epoch 4/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.24it/s]
 33%|███▎      | 1/3 [00:00<00:00,  6.45it/s]

90
train: loss: 2866.062354
Val loss improved by 5235.619287109375. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  7.53it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.40it/s]

22
val: loss: 320181.500000
0m 2s
Epoch 5/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.76it/s]

90
train: loss: 1818.670397
Val loss improved by 1047.3919569227432. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.56it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.44it/s]

22
val: loss: 25866.794034
0m 2s
Epoch 6/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.31it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.66it/s]

90
train: loss: 1519.716710
Val loss improved by 298.95368652343745. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.47it/s]
  8%|▊         | 1/12 [00:00<00:01,  8.05it/s]

22
val: loss: 22791.291371
0m 2s
Epoch 7/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.30it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.73it/s]

90
train: loss: 1314.086347
Val loss improved by 205.6303629557292. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.41it/s]
  8%|▊         | 1/12 [00:00<00:01,  8.05it/s]

22
val: loss: 1245.137862
0m 2s
Epoch 8/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.24it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.96it/s]

90
train: loss: 952.762203
Val loss improved by 361.3241441514756. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.55it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.78it/s]

22
val: loss: 1863.466009
0m 2s
Epoch 9/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.32it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.90it/s]

90
train: loss: 785.466109
Val loss improved by 167.29609375000007. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.34it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.94it/s]

22
val: loss: 1273.272039
0m 2s
Epoch 10/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.35it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.11it/s]

90
train: loss: 637.102485
Val loss improved by 148.3636244032117. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.53it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.95it/s]

22
val: loss: 4779.976296
0m 2s
Epoch 11/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.32it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.78it/s]

90
train: loss: 427.418330
Val loss improved by 209.68415459526915. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.50it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.08it/s]

22
val: loss: 1251.335782
0m 2s
Epoch 12/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.00it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.12it/s]

90
train: loss: 376.812875
Val loss improved by 50.60545518663196. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.41it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.54it/s]

22
val: loss: 380.628740
0m 2s
Epoch 13/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.18it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.84it/s]

90
train: loss: 369.431406
Val loss improved by 7.381469048394081. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.51it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.12it/s]

22
val: loss: 400.960674
0m 2s
Epoch 14/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.02it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.25it/s]

90
train: loss: 365.411941
Val loss improved by 4.01946478949651. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.89it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.70it/s]

22
val: loss: 375.553223
0m 2s
Epoch 15/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.19it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.05it/s]

90
train: loss: 363.300002
Val loss improved by 2.111939154730919. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.64it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.12it/s]

22
val: loss: 383.103771
0m 2s
Epoch 16/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.10it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.00it/s]

90
train: loss: 361.949096
Val loss improved by 1.3509060329861313. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.30it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.00it/s]

22
val: loss: 367.168207
0m 2s
Epoch 17/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.03it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.75it/s]

90
train: loss: 361.092603
Val loss improved by 0.8564934624565694. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.14it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.07it/s]

22
val: loss: 363.505105
0m 2s
Epoch 18/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.27it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.81it/s]

90
train: loss: 360.503423
Val loss improved by 0.5891798231336907. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.47it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.11it/s]

22
val: loss: 360.147139
0m 2s
Epoch 19/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.07it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.30it/s]

90
train: loss: 360.176565
Val loss improved by 0.3268575032552121. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.78it/s]
  8%|▊         | 1/12 [00:00<00:01,  8.00it/s]

22
val: loss: 360.324202
0m 2s
Epoch 20/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.29it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.13it/s]

90
train: loss: 359.794864
Val loss improved by 0.38170166015623863. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.95it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.85it/s]

22
val: loss: 360.448167
0m 2s
Epoch 21/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.20it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.20it/s]

90
train: loss: 359.644962
Val loss improved by 0.1499016655815808. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.69it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.75it/s]

22
val: loss: 361.966300
0m 2s
Epoch 22/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.23it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.99it/s]

90
train: loss: 359.351893
Val loss improved by 0.29306911892365406. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.68it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.95it/s]

22
val: loss: 361.705417
0m 2s
Epoch 23/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.35it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.91it/s]

90
train: loss: 359.247317
Val loss improved by 0.10457560221351514. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.54it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.57it/s]

22
val: loss: 362.188632
0m 2s
Epoch 24/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.05it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.07it/s]

90
train: loss: 359.127925
Val loss improved by 0.11939222547744066. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.63it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.96it/s]

22
val: loss: 361.768313
0m 2s
Epoch 25/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.33it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.09it/s]

90
train: loss: 359.057016
Val loss improved by 0.07090928819445708. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.86it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.46it/s]

22
val: loss: 361.688338
0m 2s
Epoch 26/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.19it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.11it/s]

90
train: loss: 359.004074
Val loss improved by 0.052941894531215894. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.84it/s]
  8%|▊         | 1/12 [00:00<00:01,  8.00it/s]

22
val: loss: 363.206931
0m 2s
Epoch 27/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.26it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.36it/s]

90
train: loss: 358.915665
Val loss improved by 0.08840874565976264. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.96it/s]
  8%|▊         | 1/12 [00:00<00:01,  8.15it/s]

22
val: loss: 362.787046
0m 2s
Epoch 28/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.28it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.79it/s]

90
train: loss: 358.960792
No loss improvement since 1/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  8.24it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.79it/s]

22
val: loss: 360.204939
0m 2s
Epoch 29/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.24it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.95it/s]

90
train: loss: 358.856216
Val loss improved by 0.059448920355862356. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.18it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.57it/s]

22
val: loss: 360.528648
0m 2s
Epoch 30/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.25it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.59it/s]

90
train: loss: 358.799810
Val loss improved by 0.056405978732641415. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.66it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.66it/s]

22
val: loss: 359.646834
0m 2s
Epoch 31/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.34it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.32it/s]

90
train: loss: 358.809696
No loss improvement since 1/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  8.86it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.85it/s]

22
val: loss: 359.264266
0m 2s
Epoch 32/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.32it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.32it/s]

90
train: loss: 358.805008
No loss improvement since 2/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  8.86it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.83it/s]

22
val: loss: 359.464050
0m 2s
Epoch 33/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.27it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.81it/s]

90
train: loss: 358.770605
Val loss improved by 0.029204644097262644. Saving best model.


100%|██████████| 3/3 [00:00<00:00,  8.33it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.56it/s]

22
val: loss: 359.152457
0m 2s
Epoch 34/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.25it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.32it/s]

90
train: loss: 358.793212
No loss improvement since 1/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  8.93it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.76it/s]

22
val: loss: 358.982780
0m 2s
Epoch 35/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.25it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.20it/s]

90
train: loss: 358.942255
No loss improvement since 2/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  9.11it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.34it/s]

22
val: loss: 359.147339
0m 2s
Epoch 36/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.19it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.04it/s]

90
train: loss: 359.162794
No loss improvement since 3/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  8.69it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.70it/s]

22
val: loss: 358.973883
0m 2s
Epoch 37/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.24it/s]
 33%|███▎      | 1/3 [00:00<00:00,  8.04it/s]

90
train: loss: 358.817029
No loss improvement since 4/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  8.46it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.32it/s]

22
val: loss: 359.002453
0m 2s
Epoch 38/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.06it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.55it/s]

90
train: loss: 358.839425
No loss improvement since 5/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  8.43it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.54it/s]

22
val: loss: 359.113950
0m 2s
Epoch 39/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.14it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.32it/s]

90
train: loss: 358.777611
No loss improvement since 6/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  7.87it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.84it/s]

22
val: loss: 359.012218
0m 2s
Epoch 40/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.11it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.41it/s]

90
train: loss: 358.800252
No loss improvement since 7/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  8.42it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.96it/s]

22
val: loss: 359.030540
0m 2s
Epoch 41/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  7.22it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.16it/s]

90
train: loss: 358.785697
No loss improvement since 8/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  7.94it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.30it/s]

22
val: loss: 359.083668
0m 2s
Epoch 42/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  6.94it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.36it/s]

90
train: loss: 358.773857
No loss improvement since 9/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  7.73it/s]
  8%|▊         | 1/12 [00:00<00:01,  7.59it/s]

22
val: loss: 358.952193
0m 2s
Epoch 43/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  6.70it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.28it/s]

90
train: loss: 358.882557
No loss improvement since 10/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  7.97it/s]
  8%|▊         | 1/12 [00:00<00:01,  6.33it/s]

22
val: loss: 359.268768
0m 2s
Epoch 44/999
----------
LR 0.01


100%|██████████| 12/12 [00:01<00:00,  6.93it/s]
 33%|███▎      | 1/3 [00:00<00:00,  7.54it/s]

90
train: loss: 358.782886
No loss improvement since 11/10 epochs.


100%|██████████| 3/3 [00:00<00:00,  8.31it/s]


22
val: loss: 359.053514
0m 2s
Best loss: 358.770605


In [4]:
# load weights
model.load_state_dict(torch.load("state_dict.pth", map_location="cpu"))
device = torch.device('cpu')
model = model.to(device)
# denormalization function
from torchvision import transforms
inv_normalize = transforms.Normalize(
   mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
   std=[1/0.229, 1/0.224, 1/0.225]
)

def reverse_transform(inp):
    print(inp.shape)
    inp = inv_normalize(inp)
    inp = inp.numpy()
    inp = np.swapaxes(inp, 1, 3)
    inp = np.swapaxes(inp, 1, 2)
    inp = np.clip(inp, 0, 1)
    inp = (inp * 255).astype(np.uint8)
    
    return inp
def labels2mask(labels):
    return labels[:,1,:,:]

# helper function to plot x, ground truth and predict images in grid
import matplotlib.pyplot as plt
def plot_side_by_side(x,y_dem_gt, y_dem_pr):
  assert x.shape[0] == y_dem_gt.shape[0] == y_dem_pr.shape[0]
  batch_size = x.shape[0]
  fig, axs = plt.subplots(batch_size, 4, figsize=(30,50))
  for i in range(batch_size):
    axs[i, 0].imshow(x[i,1:4].permute(1, 2, 0))
    min_val = torch.min(x[i,0])
    max_val = torch.max(x[i,0])
    axs[i, 1].imshow(x[i,0], vmin = min_val, vmax = max_val)
    axs[i, 2].imshow(np.squeeze(y_dem_gt[i]), vmin = min_val, vmax = max_val)
    axs[i, 3].imshow(np.squeeze(y_dem_pr[i]), vmin = min_val, vmax = max_val)

# visualize example segmentation
import math
model.eval()   # Set model to evaluate mode
test_dataset = DenoiseDataset(train_dir, img_size=PARAMS['img_size'], count=PARAMS["test_dataset_size"])
test_loader = DataLoader(test_dataset, batch_size=6, shuffle=True, num_workers=0)
inputs, gts = next(iter(test_loader))
inputs = inputs.to(device)
gts = gts.to(device)

gts = gts.data.cpu().numpy()
pred = model(inputs)

pred = pred.data.cpu().numpy()
inputs = inputs.data.cpu()

# use helper function to plot
plot_side_by_side(inputs, gts, pred)

#evaluate model
#test_dataset = DenoiseDataset(x_test_dir, y_test_dir, input_size=PARAMS['input_size'], output_size=PARAMS['output_size'], n_classes=PARAMS['n_classes'])
#test_loader = DataLoader(test_dataset, batch_size=PARAMS["batch_size"], shuffle=True, num_workers=0)
#intersection=0
#union=0
#for inputs, labels in tqdm(test_loader):
#  inputs = inputs.to(device)
#  labels = labels.to(device)
#  labels = labels.data.cpu().numpy()
#  pred = model(inputs)
#  pred = torch.round(pred)
#  pred = pred.data.cpu().numpy()
#  target = labels[:,1,:,:]
#  predict = pred[:,1,:,:]
#  temp = (target * predict).sum()
#  intersection+=temp
#  union+=((target + predict).sum() - temp)
#iou = intersection/union
#print("IoU: {}".format(iou))
#neptune.log_metric("total_iou",iou)


Output hidden; open in https://colab.research.google.com to view.

In [5]:

# update neptune status
neptune.stop()