## Imports

In [1]:
# imports
import torch
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import wandb
#from prettytable import PrettyTable
import random
import os
from datetime import datetime
from torch.optim import lr_scheduler


# local imports
from models_parameters import losses
from models_parameters import models
from utils import helper_functions
from utils.dataloader import Dataset as dataset

# Use the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [2]:
# Force execution on the CPU; this replaces the device auto-detected in the import cell.
# NOTE(review): presumably a temporary debugging override - confirm before a GPU run.
device = torch.device("cpu")

# GENERAL SETTINGS

In [3]:
# Execution environment: "colab", "ubs" or "local" (selects the path set used when the datasets are created)
location = "local"

# Sentinel-2 tiles to train / validate / test on, e.g. "T30UXU" (9k) or "T30UUU" (3k)
sen2_tile_train = "T30UXU"
sen2_tile_val = "T30UUU"
sen2_tile_test = "all"

# Dataloader transform: 'standardize', 'histogram_matching', 'moment_matching' or 'spot6'
transform = "histogram_matching"

# Where training plots go: 'show' or an output directory,
# e.g. /share/projects/erasmus/sesure/working_dir/images/
output_location = "C:\\Users\\accou\\Documents\\thesis\\images\\"

# Checkpoint to load before training: the string "None" for no checkpoint, otherwise the checkpoint path
load_checkpoint = "None"

# Whether to save a checkpoint after every epoch (the training loop only tests truthiness)
save_checkpoint = True

# WandB project name; also the prefix of the run name
model_name = "HighResNet_local"

In [4]:
# Unique run name: <model>_<transform>_<train tile>__<day-month-year_hour-minute-second>
run_name = "_".join((model_name, transform, sen2_tile_train)) + "__" + datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
print(run_name)

HighResNet_local_histogram_matching_T30UXU__20-04-2022_22-07-38


# MODEL SETTINGS

In [5]:
# Configuration dict handed to models.HighResNet.get_model, grouped by network stage.
config = {
    "encoder": {
        "in_channels": 2,
        "num_layers": 2,
        "kernel_size": 3,
        "channel_size": 64,
    },
    "recursive": {
        "alpha_residual": True,
        "in_channels": 64,
        "num_layers": 2,
        "kernel_size": 3,
    },
    "decoder": {
        "deconv": {
            "in_channels": 64,
            "kernel_size": 3,
            "stride": 3,
            "out_channels": 64,
        },
        "final": {
            "in_channels": 64,
            "kernel_size": 1,
            "out_channels": 1,
        },
    },
}

In [6]:
# Build the model from the configuration dict
model = models.HighResNet.get_model(config)

# Optionally restore pretrained weights; the string "None" means "no checkpoint".
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU machine still loads when the notebook runs on the CPU.
if load_checkpoint != "None":
    model.load_state_dict(torch.load(load_checkpoint, map_location=device))

# Loss function used by the training loop
# NOTE(review): weighted_loss is instantiated as `f`, but loss_func is set to
# losses.loss_mae - confirm which loss is intended.
from models_parameters.weighted_losses import weighted_loss
f = weighted_loss()
loss_func = losses.loss_mae

# Enable the cuDNN benchmark mode for faster execution (only useful if input sizes do not change)
torch.backends.cudnn.benchmark = True

In [7]:
# TRAINING SETTINGS
batch_size = 12      # samples per training batch
lr = 1e-4            # initial learning rate passed to the optimizer
epochs = 10          # number of passes over the training loader
plot_frequency = 1   # plot every N training batches

## Create Dataset Object

In [8]:
# Resolve working directory, data folder and dataset index file for the selected environment
if location == "colab":
    working_directory = "/content/drive/MyDrive/thesis/"
    folder_path = "/content/drive/MyDrive/thesis/data/"
    dataset_file = "/content/drive/MyDrive/thesis/data/images_subfolders3.pkl"
elif location == "ubs":
    working_directory = "/share/projects/erasmus/sesure/working_dir/"
    folder_path = "/share/projects/erasmus/sesure/working_dir/"
    dataset_file = "/share/projects/erasmus/sesure/working_dir/final_dataset.pkl"
elif location == "local":
    working_directory = "C:\\Users\\accou\\Documents\\thesis"
    folder_path = "C:\\Users\\accou\\Documents\\thesis\\data\\"
    dataset_file = "C:\\Users\\accou\\Documents\\thesis\\images_subfolders3.pkl"
else:
    # Fail here with a clear message instead of a NameError when the paths are first used
    raise ValueError('Unknown location {!r}: expected "colab", "ubs" or "local"'.format(location))

# Dataset objects for the training and validation tiles (one Sentinel-2 image per sample)
dataset_train = dataset(folder_path, dataset_file, transform, sen2_amount=1, sen2_tile=sen2_tile_train, location=location)
dataset_val = dataset(folder_path, dataset_file, transform, sen2_amount=1, sen2_tile=sen2_tile_val, location=location)
# The test dataset is currently disabled:
#dataset_test  = dataset(folder_path,dataset_file,transform,sen2_amount=1, sen2_tile = sen2_tile_test,  location=location)

In [None]:
# LOADER URBAN
# Alternative setup using the dataset class from utils.dataloader_urban.
# Running this cell overwrites batch_size, lr, epochs, plot_frequency and
# dataset_train as defined in the earlier cells.

# TRAINING SETTINGS
batch_size = 5
lr = 1e-4
epochs = 100
plot_frequency = 25  # in batches

from utils.dataloader_urban import Dataset
dataset_train = Dataset()
# Default DataLoader arguments: batch size 1, no shuffling
loader_train = DataLoader(dataset_train)
print("Urban100 Dataset: ", len(loader_train))

## Train function

In [9]:
# implementation of model trainer function
def train_model(run_name,model,batch_size=1,lr=0.001,epochs=10,wandb_name="test",plot_frequency=10):
    """Train `model` on the training dataset and validate it after every epoch.

    Relies on names defined in earlier cells: dataset_train, dataset_val,
    device, loss_func, losses, helper_functions, output_location,
    save_checkpoint and working_directory.

    Args:
        run_name: Name of the run; used as WandB run name and checkpoint file name.
        model: Network to train; it is moved to `device` and updated in place.
        batch_size: Number of samples per training batch.
        lr: Initial learning rate of the SGD optimizer.
        epochs: Number of passes over the training loader.
        wandb_name: WandB project name (only used when logging is enabled).
        plot_frequency: Print statistics and plot every this many training batches.

    Returns:
        None. Progress is printed/plotted and, if enabled, a checkpoint is
        written after every epoch.
    """
    logging = False        # log to WandB
    logging_val = True     # also log a validation loss at every training step
    plotting = True        # set if images should be plotted
    n_val_samples = 5      # random validation images evaluated after each epoch

    if logging:
        wandb.init(name=run_name, project=wandb_name, entity="simon-donike")
        wandb.config = {
          "learning_rate": lr,
          "epochs": epochs,
          "batch_size": batch_size
        }

    # define loaders; pinned host memory only pays off for host-to-GPU copies
    pin_memory = device.type == "cuda"
    loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory, drop_last=True)
    loader_val = DataLoader(dataset_val, batch_size=1, shuffle=True, num_workers=0, pin_memory=pin_memory, drop_last=True)
    print("dataloader instanciated!")
    # the validation loader always uses batch size 1
    print("Len. Train: ",len(loader_train),"(Batch Sz. "+str(batch_size),") Len. Val: ",len(dataset_val),"(Batch Sz. 1)")

    # initialize model and optimizer
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    #optimizer = optim.Adam(model.parameters())
    # LR scheduler: step_size=1, gamma=0.1, i.e. the LR is multiplied by 0.1 after every epoch.
    # verbose=True prints each adjustment (the flag is deprecated in newer PyTorch releases).
    scheduler = lr_scheduler.StepLR(optimizer, 1, 0.1, verbose=True)

    for e in range(epochs):
        model.train()
        batch_count = 0
        for (x_train_batch, y_train_batch) in loader_train:
            batch_count += 1

            # send data to device
            x_train_batch = x_train_batch.to(device)
            y_train_batch = y_train_batch.to(device)

            # forward pass and loss
            y_hat = model(x_train_batch)
            loss = loss_func(y_hat, y_train_batch)

            loss.backward()        # obtain the gradients with respect to the loss
            optimizer.step()       # perform one step of gradient descent
            optimizer.zero_grad()  # reset the gradients to 0

            """ START LOGGING & PLOTTING """
            if logging:
                # get_last_lr() is the supported way to read the current LR outside of step()
                wandb.log({'loss_train_step': loss.item() / len(x_train_batch),
                           'LR': scheduler.get_last_lr()[0]})
            # log the loss of one random validation sample alongside the training loss
            if logging_val and logging:
                model.eval()
                # a fresh iterator over the shuffled loader yields a random sample
                x_val_batch, y_val_batch = next(iter(loader_val))
                x_val_batch = x_val_batch.to(torch.float).to(device)
                y_val_batch = y_val_batch.to(torch.float).to(device)
                # no gradients are needed for the validation forward pass
                with torch.no_grad():
                    y_hat_val = model(x_val_batch)
                    loss_val = loss_func(y_hat_val, y_val_batch)
                wandb.log({'loss_val_step': loss_val.item() / len(y_val_batch)})
                del x_val_batch
                del y_val_batch
                model.train()

            # print statistics and plot every `plot_frequency` batches
            if plotting and batch_count % plot_frequency == 0:
                print('Epoch', e+1, ' Batch',batch_count,' finished.  No. of trained Images: '+str(batch_count*batch_size))
                with torch.no_grad():
                    # baseline PSNR: interpolated *input* vs. target (same as in validation below)
                    psnr_int = round(losses.loss_psnr(y_train_batch, helper_functions.interpolate_tensor(x_train_batch)).item(), 2)
                    psnr_pred = round(losses.loss_psnr(y_train_batch, y_hat).item(), 2)
                helper_functions.plot_tensors_window(y_train_batch, x_train_batch, y_hat, psnr_int, psnr_pred, fig_path=output_location)
                print("Tensor Mean:",round(y_hat.mean().item(),6),"Max: ",round(y_hat.max().item(),6),"Min: ",round(y_hat.min().item(),6))

        """ END OF EPOCH, logging and plotting """
        print ('Epoch', e+1, ' training finished.')
        scheduler.step()
        if save_checkpoint:
            # build the path portably and make sure the target directory exists
            checkpoint_dir = os.path.join(working_directory, "checkpoints")
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, run_name + ".pkl"))

        """
        PERFORM VALIDATION STEP after each epoch
        """
        model.eval()
        # perform inference on `n_val_samples` random validation images
        psnr_inter_list = []
        psnr_pred_list = []
        loss_val_list = []
        for _ in range(n_val_samples):
            # load val data and send to device
            x_val_batch, y_val_batch = next(iter(loader_val))
            x_val_batch = x_val_batch.to(torch.float).to(device)
            y_val_batch = y_val_batch.to(torch.float).to(device)

            # inference step and metrics, without gradient tracking
            with torch.no_grad():
                pred = model(x_val_batch)
                loss_val_list.append(loss_func(pred, y_val_batch).item())
                psnr_inter_list.append(round(losses.loss_psnr(y_val_batch, helper_functions.interpolate_tensor(x_val_batch)).item(), 2))
                psnr_pred_list.append(round(losses.loss_psnr(y_val_batch, pred).item(), 2))

        # TODO: implement SSIM
        print("VALIDATION: Visualizing 1 of", n_val_samples, "inference steps. Values are averages. Time: ",datetime.now().strftime("%d-%m-%Y %H:%M:%S"))
        psnr_int_avg = round(sum(psnr_inter_list) / len(psnr_inter_list), 2)
        psnr_pred_avg = round(sum(psnr_pred_list) / len(psnr_pred_list), 2)
        loss_val_avg = round(sum(loss_val_list) / len(loss_val_list), 2)
        # the last evaluated sample is plotted together with the averaged PSNR values
        helper_functions.plot_tensors_window(y_val_batch, x_val_batch, pred, psnr_int_avg, psnr_pred_avg)
        # return model to train state
        model.train()

        # log validation metrics after each epoch
        if logging:
            wandb.log({'val_loss_epoch': loss_val_avg})
            wandb.log({'val_PSNR_epoch': psnr_pred_avg})

    # Steps to perform after training is finished
    print("Training finished!")
    if logging:
        wandb.finish()

# Perform Training

In [10]:
# Start training with the settings configured in the cells above
train_model(
    run_name=run_name,
    model=model,
    batch_size=batch_size,
    lr=lr,
    epochs=epochs,
    wandb_name=model_name,
    plot_frequency=plot_frequency,
)

dataloader instanciated!
Len. Train:  765 (Batch Sz. 12 ) Len. Val:  3506 (Batch Sz. 12)
Adjusting learning rate of group 0 to 7.0000e-04.


AttributeError: 'bool' object has no attribute 'view'

In [None]:
# Manually close the active WandB run.
# NOTE(review): presumably a cleanup cell for runs where train_model stopped
# before reaching its own wandb.finish() call - confirm.
wandb.finish()