# Imports

In [1]:
# general
import time
from datetime import datetime
import io
import numpy as np
import matplotlib.pyplot as plt
import PIL
from tqdm import tqdm
import wandb
from pynvml import *
import kornia

# torch
import torch
from torch import nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# custom imports
from lanczos import lanczos_2d as lanczos
from model.losses import define_loss
from utils.helper_functions import get_lr
from utils.helper_functions import plot_tensors
from utils.helper_functions import plot_tensors_extra_info
from utils.dataloader_spot import dataset_spot6

# Sesure DataLoader
from utils.dataloader import Dataset as dataset

# model
from model.shiftnet import ShiftNet

# Load Model and register/shift/loss functions

In [3]:
# load model
regis_model = ShiftNet(in_channel=1)
regis_model = regis_model.train()
regis_model = regis_model.to(device)

Input Channels: 1


In [2]:
from model.shiftnet import apply_shifts
from model.shiftnet import get_shift_loss
from model.shiftnet import get_thetas

  from .autonotebook import tqdm as notebook_tqdm


# Mimic SR

In [4]:
def superresolute(lr,factor=4):
    return(torch.nn.functional.interpolate(lr, size=None, scale_factor=factor,mode="bicubic"))

def shifter(im,shift_factor=0.05):
    affinator = torchvision.transforms.Compose([torchvision.transforms.RandomAffine(degrees=0, translate=(shift_factor,shift_factor), scale=None, shear=None,
                                    interpolation=torchvision.transforms.InterpolationMode.NEAREST, fill=0, center=None)])
    im = affinator(im)
    return(im)

# Training Settings

In [5]:
# Training Settings
# dataloader settings
#batch_size = 32
epochs = 100
lr = 0.001 # try next with 0.00001
optimizer = torch.optim.Adam(list(regis_model.parameters()), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',factor=0.1,patience=3,min_lr=0.00000001)

# Logging Settings
log_freq   =    10 # in it
image_freq =  100 # in it
val_freq   =  99999999 # in it
save_freq  =  99999999 # in epochs
wandb_project = "ShiftNet_sen2"
wandb_entity = "simon-donike"

# Loss Settings
which_loss_func = "MAE"  # choose loss: 'MSE', 'MAE','SSIM','PSNR'
ssim_window_size = 5

# Settings of Shift Data
# train on downsampled and manually shifted SPOT6?
train_shifted = False
shift_pixels = 10 # max shift in pixels
shift_factor = round((1/300)*shift_pixels,3) # max shift in factor

# Data Loader

In [6]:
# Custom dataloaders
working_directory = "C:\\Users\\accou\\Documents\\GitHub\\a-PyTorch-Tutorial-to-Super-Resolution\\"
folder_path = "C:\\Users\\accou\\Documents\\thesis\\data_v2\\"
dataset_file = "C:\\Users\\accou\\Documents\\thesis\\data_v2\\final_dataset.pkl"
transform = "histogram_matching"
sen2_tile_train = "T30UXU"
sen2_tile_test   = "T30UUU"
sen2_tile_val  = "all"
location = "local"
batch_size = 4
strat = True # decide wether agricultural areas should be stratified to have more balanced dataset

#folder_path,dataset_file,test_train_val="train",transform="histogram_matching",sen2_amount=1,sen2_tile="all",location="colab"):
dataset_train = dataset(folder_path,dataset_file,test_train_val="train",transform=transform,sen2_amount=1, location=location,strat=strat)
train_loader = DataLoader(dataset_train,batch_size=batch_size,
                          shuffle=True, num_workers=4,pin_memory=True,drop_last=True,prefetch_factor=4) # prefetch 32

# Training Loop

In [None]:
# initialize logging
run_name = str(datetime.now().strftime("%d-%m-%Y_%H-%M-%S"))
wandb.init(name=run_name,project=wandb_project,entity=wandb_entity)


# initialize loss
loss_func = define_loss(which_loss_func)

lowest_loss = 9999
for epoch in range(epochs):
    epoch+=1
    print("Epoch:",epoch)
    it = 0
    for lr,hr in tqdm(train_loader,ascii=True):
        it+=1
        
        # PERFORM SR HERE, atm pnly interplation
        sr = superresolute(lr,factor=4)
        sr,hr = sr.to(device),hr.to(device)
        
        # calculate predicted thetas
        thetas, hr_small, sr_small = get_thetas(hr,sr,regis_model)
        # perform shift based on calculated thetas
        new_images,thetas = apply_shifts(sr,thetas,regis_model)
        # calculate train loss according to defined function
        train_loss,hr_loss,new_images_loss = get_shift_loss(new_images,hr,loss_func,sr_small,hr_small)
        
        # train network
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()


        # IMAGE LOGGING
        if it%image_freq==0 and it!=1:
            # create image and turn to WandB format
            # wandb_image = plot_tensors(hr,sr,torch.clone(thetas))
            #wandb_image = plot_tensors_extra_info(hr,sr,new_images,new_images_loss,hr_loss,torch.clone(thetas))
            #wandb.log({"Training Image":wandb_image})
            plot_tensors_extra_info(a=hr,b=sr,c=new_images,d=hr_loss,e=new_images_loss,f=hr_loss,thetas=torch.clone(thetas))
            
        # METRICS LOGGING
        if it%log_freq==0 and it!=1:
            # create loss dict and append image
            losses = {
            "epoch":epoch,
            "train_loss":train_loss,
            "mae":torch.nn.functional.l1_loss(new_images,hr),
            "mse":torch.nn.functional.mse_loss(new_images,hr),
            "ssim_loss":kornia.losses.ssim_loss(new_images,hr,window_size=ssim_window_size),
            "psnr_loss":kornia.losses.psnr_loss(new_images,hr,max_val=1.0),
            "lr":get_lr(optimizer)}
            # send to WandB
            wandb.log(losses)
            # save if loss lower
            if train_loss<lowest_loss:
                lowest_loss = float(torch.clone(train_loss).detach().cpu().numpy())
                torch.save(regis_model,"checkpoints/regis_model.pth")

0,1
epoch,▁▁▁▁▁▁▁▁▁
lr,▁▁▁▁▁▁▁▁▁
mae,▂▆▅█▄▁▂▇▆
mse,▁▆▇█▄▁▂██
psnr_loss,▁▆▇█▅▁▂██
ssim_loss,▂▇▅▆▄▁▃█▆
train_loss,▅█▃▄▅▁▇▇█

0,1
epoch,1.0
lr,0.001
mae,0.01511
mse,0.00076
psnr_loss,-31.20064
ssim_loss,0.09438
train_loss,0.94274


Loss func:  <function l1_loss at 0x00000165B73AD160>
Epoch: 1


100%|##############################################################################| 3427/3427 [07:43<00:00,  7.40it/s]


Epoch: 2


100%|##############################################################################| 3427/3427 [07:37<00:00,  7.49it/s]


Epoch: 3


100%|##############################################################################| 3427/3427 [07:45<00:00,  7.36it/s]


Epoch: 4


100%|##############################################################################| 3427/3427 [07:40<00:00,  7.44it/s]


Epoch: 5


100%|##############################################################################| 3427/3427 [07:34<00:00,  7.55it/s]


Epoch: 6


100%|##############################################################################| 3427/3427 [07:20<00:00,  7.78it/s]


Epoch: 7


100%|##############################################################################| 3427/3427 [07:17<00:00,  7.84it/s]


Epoch: 8


100%|##############################################################################| 3427/3427 [07:23<00:00,  7.73it/s]


Epoch: 9


100%|##############################################################################| 3427/3427 [07:23<00:00,  7.72it/s]


Epoch: 10


100%|##############################################################################| 3427/3427 [07:45<00:00,  7.36it/s]


Epoch: 11


 56%|###########################################3                                  | 1906/3427 [04:24<03:02,  8.31it/s]

In [7]:
run_name = str(datetime.now().strftime("%d-%m-%Y_%H-%M-%S"))
wandb.init(name=run_name,project=wandb_project,entity=wandb_entity)

# initialize loss
loss_func = define_loss(which_loss_func)

lowest_loss = 9999
for epoch in range(epochs):
    epoch+=1
    print("Epoch:",epoch)
    it = 0
    for lr,hr in tqdm(train_loader,ascii=True):
        it+=1
        
        # PERFORM SR HERE, atm pnly interplation
        sr = superresolute(lr,factor=4)
        sr,hr = sr.to(device),hr.to(device)
       
        # crop image
        target_size = 128 # target w & h of image
        middle = sr.shape[2] //2 # get middle of tensor
        offset = target_size //2 # calculate offset needed from middle of tensor
        n_channels = sr.shape[1]
        hr_small = torch.clone(hr)[:,0:1,middle-offset:middle+offset,middle-offset: middle+offset] # perform crop and keep only 1 band
        sr_small = torch.clone(sr)[:,0:1,middle-offset:middle+offset,middle-offset: middle+offset] # perform crop and keep only 1 band
        #print(f'After Cropping: HR shape: {hr_small.shape}, SR shape: {sr_small.shape}')


        # rearrange from (B,C,W,H) to (B*3,1,W,H)
        hr_small = hr_small.view(-1, 1, 128, 128)
        sr_small = sr_small.view(-1, 1, 128, 128)
        if hr_small.shape!=sr_small.shape:
            print("shape mismatch")
        #print(f'After Cropping & rearranging: HR shape: {hr_small.shape}, SR shape: {sr_small.shape}')


        ## register_batch via network code
        n_views = hr_small.size(1) # get number of views -> amount of images in original, here its 1
        thetas = []
        for i in range(n_views): # iterate over channels (should be 1 in out case)
            theta = regis_model(torch.cat([hr_small[:, i : i + 1], sr_small[:, i : i + 1]], 1)) # send relevant channel to model
            thetas.append(theta)
        thetas = torch.stack(thetas, 1) # stack return
        #print(f'Thetas shape: {thetas.shape}')
        thetas = thetas[:, None, :, :].repeat(1, n_channels, 1, 1) # expand back to 3x channels
        #print(f'Thetas shape after expanding: {thetas.shape}')

        # perform translation
        # clone tensors (?)
        shifts = torch.clone(thetas)
        images = torch.clone(sr)
        
        # change names for clarity
        #shifts=thetas
        #images=sr

        ## apply_shift code
        batch_size, n_views, height, width = images.shape
        images = images.view(-1, 1, height, width)
        thetas = thetas.view(-1, 2)

        #print(f'Apply_shift to input shape: {images.shape}, thetas shape: {thetas.shape}')
        # perform translation via built-in function
        new_images = regis_model.transform(thetas, images) # error here
        #print(f'New Images shifted shape: {new_images.shape}')
        # rearrange from (B*C,1,H,W) to (B,3,H,W)
        new_images = new_images.view(-1, n_channels, images.size(2), images.size(3))
        hr = hr.view(-1, n_channels, images.size(2), images.size(3))
        #print(f'HR: {hr.shape} - ShiftNet ouput: {new_images.shape}')

        # calculate training loss
        # SR has been shifted with regards to HR-GT, therefore we need to calculate loss only over valid pixels
        #loss_mask = new_images==0 # get mask where 0
        #hr_masked = hr.masked_fill(loss_mask, 0.0) # in hr where mask is 0 with 0 aswell to minimize error
        new_images_loss = torch.clone(new_images)[:,:,middle-offset:middle+offset,middle-offset: middle+offset]
        hr_loss = torch.clone(hr)[:,:,middle-offset:middle+offset,middle-offset: middle+offset]
        
        
        #train_loss = loss_func(new_images_loss,hr_loss) # standard loss
        
        
        # Loss as relative decrease in loss as compared to unshifted images
        loss_before_shift = loss_func(sr_small,hr_small)
        loss_after_shift = loss_func(new_images_loss,hr_loss)
        loss_relative = (1/loss_before_shift)*loss_after_shift
        train_loss = loss_relative
        
        # train network
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        
        # VAL calc and Logging
        if it%val_freq==0 and it!=0:
            # TODO: theta pred and register batches outsourcing
            # perform prediction
            regis_model = regis_model.eval()
            hr_val = next(iter(val_loader))
            sr_val = shifter(sr,shift_factor)
            hr_val_small = torch.clone(hr_val)[:,0:1,middle-offset:middle+offset,middle-offset: middle+offset]
            sr_val_small = torch.clone(sr_val)[:,0:1,middle-offset:middle+offset,middle-offset: middle+offset]
            hr_val_small = hr_small.view(-1, 1, 128, 128)
            sr_val_small = sr_small.view(-1, 1, 128, 128)
            n_views = hr_val_small.size(1) # get number of views -> amount of images in original, here its 1
            thetas_val = []
            for i in range(n_views): # iterate over channels (should be 1 in out case)
                theta = regis_model(torch.cat([hr_small[:, i : i + 1], sr_small[:, i : i + 1]], 1)) # send relevant channel to model
                thetas_val.append(theta)
            thetas_val = torch.stack(thetas_val, 1) # stack return
            thetas_val = thetas_val[:, None, :, :].repeat(1, n_channels, 1, 1) # expand back to 3x channels
           
            thetas_pred = regis_model(hr_val)
            
            hr_val = torch.clone(hr_val)[:,:,middle-offset:middle+offset,middle-offset: middle+offset]
            pred = torch.clone(pred)[:,:,middle-offset:middle+offset,middle-offset: middle+offset]
            # calculate losses
            val_loss = loss_func(pred,hr_val)
            val_losses =  {
            "val_loss":val_loss,
            "mae_val":torch.nn.functional.l1_loss(pred,hr_val),
            "mse_val":torch.nn.functional.mse_loss(pred,hr_val),
            "ssim_loss_val":kornia.losses.ssim_loss(pred,hr_val,window_size=5),
            "psnr_loss_val":kornia.losses.psnr_loss(pred,hr_val,max_val=1.0),}
            # send to wandb
            wandb.log(val_losses)
            # put model back to train mode
            regis_model = regis_model.train()
        
        # IMAGE LOGGING
        if it%image_freq==0 and it!=1:
            # create image and turn to WandB format
            # wandb_image = plot_tensors(hr,sr,torch.clone(thetas))
            #wandb_image = plot_tensors_extra_info(hr,sr,new_images,new_images_loss,hr_loss,torch.clone(thetas))
            #wandb.log({"Training Image":wandb_image})
            plot_tensors_extra_info(a=hr,b=sr,c=new_images,d=hr_loss,e=new_images_loss,f=hr_loss,thetas=torch.clone(thetas))
            
        # METRICS LOGGING
        if it%log_freq==0 and it!=1:
            # create loss dict and append image
            losses = {
            "epoch":epoch,
            "train_loss":train_loss,
            "mae":torch.nn.functional.l1_loss(new_images,hr),
            "mse":torch.nn.functional.mse_loss(new_images,hr),
            "ssim_loss":kornia.losses.ssim_loss(new_images,hr,window_size=ssim_window_size),
            "psnr_loss":kornia.losses.psnr_loss(new_images,hr,max_val=1.0),
            "lr":get_lr(optimizer)}
            # send to WandB
            wandb.log(losses)
            # save if loss lower
            if train_loss<lowest_loss:
                lowest_loss = float(torch.clone(train_loss).detach().cpu().numpy())
                torch.save(regis_model,"checkpoints/regis_model.pth")
            
    # step scheduler
    scheduler.step(train_loss)
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33msimon-donike[0m. Use [1m`wandb login --relogin`[0m to force relogin


Loss func:  <function l1_loss at 0x000001D6ACE92160>
Epoch: 1


100%|################################################################################| 428/428 [08:37<00:00,  1.21s/it]


Epoch: 2


100%|################################################################################| 428/428 [08:36<00:00,  1.21s/it]


Epoch: 3


100%|################################################################################| 428/428 [08:36<00:00,  1.21s/it]


Epoch: 4


100%|################################################################################| 428/428 [08:37<00:00,  1.21s/it]


Epoch: 5


100%|################################################################################| 428/428 [08:39<00:00,  1.21s/it]


Epoch: 6


100%|################################################################################| 428/428 [08:36<00:00,  1.21s/it]


Epoch: 7


 88%|######################################################################2         | 376/428 [07:44<01:04,  1.23s/it]


KeyboardInterrupt: 

# Appendix: Sweep Search for best LR and SSIM window size

In [7]:


def train_sweep():
    epochs = 15
    
    wandb.init(project='ShiftNet_sweep')
    lr = wandb.config.lr
    batch_size = wandb.config.batch_size
    which_loss_func =  wandb.config.which_loss_func
    #lr = 0.001
    #batch_size = 1
    #which_loss_func = "MAE"
    
    # initialize dataset and dataloader
    dataset_train = dataset_spot6("data/train/")
    train_loader = DataLoader(dataset_train,batch_size=batch_size,shuffle=True, num_workers=8,pin_memory=True,drop_last=True,prefetch_factor=32) # prefetch 32
    print("Dataset Length: ",len(dataset_train))

    # initialize loss
    loss_func = define_loss(which_loss_func)
    
    # initialize optim
    optimizer = torch.optim.Adam(list(regis_model.parameters()), lr=lr)

    for epoch in range(epochs):
        epoch+=1
        print("Epoch:",epoch)
        it = 0
        for hr in tqdm(train_loader,ascii=True):
            it+=1
            #sr = superresolute(lr) # mimic SR by bicub. interpolation
            hr = hr.to(device)
            sr = hr.clone() # copy HR as SR

            if train_shifted: # if we want synthetic data
                #lr = superresolute(hr,factor=0.25) # downsample 300 SPOT6 to 75
                #sr = superresolute(lr,factor=0.4)  # artificially super-resolute from 30

                sr = shifter(sr,shift_factor) # shift sr image
            sr,hr = sr.to(device),hr.to(device)

            # crop image
            target_size = 128 # target w & h of image
            middle = sr.shape[2] //2 # get middle of tensor
            offset = target_size //2 # calculate offset needed from middle of tensor
            n_channels = sr.shape[1]
            hr_small = torch.clone(hr)[:,0:1,middle-offset:middle+offset,middle-offset: middle+offset] # perform crop and keep only 1 band
            sr_small = torch.clone(sr)[:,0:1,middle-offset:middle+offset,middle-offset: middle+offset] # perform crop and keep only 1 band
            #print(f'After Cropping: HR shape: {hr_small.shape}, SR shape: {sr_small.shape}')


            # rearrange from (B,C,W,H) to (B*3,1,W,H)
            hr_small = hr_small.view(-1, 1, 128, 128)
            sr_small = sr_small.view(-1, 1, 128, 128)
            if hr_small.shape!=sr_small.shape:
                print("shape mismatch")
            #print(f'After Cropping & rearranging: HR shape: {hr_small.shape}, SR shape: {sr_small.shape}')


            ## register_batch via network code
            n_views = hr_small.size(1) # get number of views -> amount of images in original, here its 1
            thetas = []
            for i in range(n_views): # iterate over channels (should be 1 in out case)
                theta = regis_model(torch.cat([hr_small[:, i : i + 1], sr_small[:, i : i + 1]], 1)) # send relevant channel to model
                thetas.append(theta)
            thetas = torch.stack(thetas, 1) # stack return
            #print(f'Thetas shape: {thetas.shape}')
            thetas = thetas[:, None, :, :].repeat(1, n_channels, 1, 1) # expand back to 3x channels
            #print(f'Thetas shape after expanding: {thetas.shape}')

            # perform translation
            # clone tensors (?)
            shifts = torch.clone(thetas)
            images = torch.clone(sr)

            # change names for clarity
            #shifts=thetas
            #images=sr

            ## apply_shift code
            batch_size, n_views, height, width = images.shape
            images = images.view(-1, 1, height, width)
            thetas = thetas.view(-1, 2)

            #print(f'Apply_shift to input shape: {images.shape}, thetas shape: {thetas.shape}')
            # perform translation via built-in function
            new_images = regis_model.transform(thetas, images) # error here
            #print(f'New Images shifted shape: {new_images.shape}')
            # rearrange from (B*C,1,H,W) to (B,3,H,W)
            new_images = new_images.view(-1, n_channels, images.size(2), images.size(3))
            hr = hr.view(-1, n_channels, images.size(2), images.size(3))
            #print(f'HR: {hr.shape} - ShiftNet ouput: {new_images.shape}')

            # calculate training loss
            # SR has been shifted with regards to HR-GT, therefore we need to calculate loss only over valid pixels
            #loss_mask = new_images==0 # get mask where 0
            #hr_masked = hr.masked_fill(loss_mask, 0.0) # in hr where mask is 0 with 0 aswell to minimize error
            new_images_loss = torch.clone(new_images)[:,:,middle-offset:middle+offset,middle-offset: middle+offset]
            hr_loss = torch.clone(hr)[:,:,middle-offset:middle+offset,middle-offset: middle+offset]
            #train_loss=nn.functional.l1_loss(new_images_loss,hr_loss) # compute loss
            #train_loss = kornia.losses.ssim_loss(new_images_loss,hr_loss,window_size=ssim_window_size)
            train_loss = loss_func(new_images_loss,hr_loss)

            # train network
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            if it%image_freq==0 and it!=1:
                # create image and turn to WandB format
                # wandb_image = plot_tensors(hr,sr,torch.clone(thetas))
                #wandb_image = plot_tensors_extra_info(hr,sr,new_images,new_images_loss,hr_loss,torch.clone(thetas))
                #wandb.log({"Training Image":wandb_image})
                plot_tensors_extra_info(a=hr,b=sr,c=new_images,d=hr_loss,e=new_images_loss,f=hr_loss,thetas=torch.clone(thetas))
            if it%log_freq==0 and it!=1:
                # create loss dict and append image
                losses = {
                "epoch":epoch,
                "train_loss":train_loss,
                "mae":torch.nn.functional.l1_loss(new_images,hr),
                "mse":torch.nn.functional.mse_loss(new_images,hr),
                "ssim_loss":kornia.losses.ssim_loss(new_images,hr,window_size=ssim_window_size),
                "psnr_loss":kornia.losses.psnr_loss(new_images,hr,max_val=1.0),}

                # send to WandB
                wandb.log(losses)
                
                # save if best result yet

        
        #scheduler.step(train_loss)

In [None]:
run_name = str(datetime.now().strftime("%d-%m-%Y_%H-%M-%S"))
sweep_configuration = {
    'method': 'grid',
    'name': 'sweep',
    'metric': {'goal': 'minimize', 'name': 'train_loss'},
    'parameters': 
    {
        'lr': {"values":[0.001,0.0001]},
        'which_loss_func':{"values":["MSE"]},  #"MSE","SSIM","PSNR"
        'batch_size':{"values":[1,4,16]},
     }
}


sweep_id = wandb.sweep(entity="simon-donike",sweep=sweep_configuration, project='ShiftNet_sweep')
wandb.agent(sweep_id, function=train_sweep, count=10)

Create sweep with ID: kpdveo60
Sweep URL: https://wandb.ai/simon-donike/ShiftNet_sweep/sweeps/kpdveo60


[34m[1mwandb[0m: Agent Starting Run: e2to4hi5 with config:
[34m[1mwandb[0m: 	batch_size: 1
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	which_loss_func: MAE
[34m[1mwandb[0m: Currently logged in as: [33msimon-donike[0m. Use [1m`wandb login --relogin`[0m to force relogin


Dataset Length:  10000
Loss func:  <function l1_loss at 0x000001B088940160>
Epoch: 1


100%|############################################################################| 10000/10000 [04:23<00:00, 38.00it/s]


Epoch: 2


100%|############################################################################| 10000/10000 [04:16<00:00, 38.98it/s]


Epoch: 3


100%|############################################################################| 10000/10000 [04:17<00:00, 38.86it/s]


Epoch: 4


100%|############################################################################| 10000/10000 [04:15<00:00, 39.18it/s]


Epoch: 5


100%|############################################################################| 10000/10000 [04:13<00:00, 39.39it/s]


Epoch: 6


100%|############################################################################| 10000/10000 [04:14<00:00, 39.29it/s]


Epoch: 7


100%|############################################################################| 10000/10000 [04:13<00:00, 39.48it/s]


Epoch: 8


100%|############################################################################| 10000/10000 [04:14<00:00, 39.28it/s]


Epoch: 9


100%|############################################################################| 10000/10000 [04:14<00:00, 39.22it/s]


Epoch: 10


100%|############################################################################| 10000/10000 [04:14<00:00, 39.36it/s]


Epoch: 11


100%|############################################################################| 10000/10000 [04:14<00:00, 39.29it/s]


Epoch: 12


100%|############################################################################| 10000/10000 [04:14<00:00, 39.26it/s]


Epoch: 13


100%|############################################################################| 10000/10000 [04:13<00:00, 39.38it/s]


Epoch: 14


100%|############################################################################| 10000/10000 [04:15<00:00, 39.18it/s]


Epoch: 15


100%|############################################################################| 10000/10000 [04:13<00:00, 39.43it/s]


0,1
epoch,▁▁▁▁▁▁▂▂▃▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▇▇▇▇▇▇▇▇███
mae,▅▅▅▄▆▄▄▄▅█▅▄▅▄▇▄▅▅▂▆▅▇▅▄▅▄▃▄▆▁█▅▆▆▅▄▆▆▄▅
mse,▄▄▄▃▅▃▂▃▄█▃▂▃▃▆▃▄▄▁▄▄▇▄▃▄▂▂▃▄▁▇▄▅▅▄▃▅▅▂▃
psnr_loss,███▇█▇▇▇███▇██████▇████▇█▇▇██▁████████▇█
ssim_loss,▅▅▅▅▇▄▅▄▅█▅▄▅▄▇▄▄▅▂▅▅▇▅▄▅▄▃▄▅▁▇▄▆▆▅▄▅▆▄▅
train_loss,▄▄▅▄█▄▃▄▄█▄▄▄▃▇▃▃▅▂▄▅▇▄▄▅▃▃▃▄▁▇▄▆▇▄▃▄▅▃▃

0,1
epoch,15.0
mae,0.01669
mse,0.00087
psnr_loss,-30.60095
ssim_loss,0.09286
train_loss,0.01293


[34m[1mwandb[0m: Agent Starting Run: ufkeu1fq with config:
[34m[1mwandb[0m: 	batch_size: 1
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	which_loss_func: MAE


Dataset Length:  10000
Loss func:  <function l1_loss at 0x000001B088940160>
Epoch: 1


100%|############################################################################| 10000/10000 [04:12<00:00, 39.55it/s]


Epoch: 2


100%|############################################################################| 10000/10000 [04:14<00:00, 39.23it/s]


Epoch: 3


100%|############################################################################| 10000/10000 [04:12<00:00, 39.66it/s]


Epoch: 4


100%|############################################################################| 10000/10000 [04:15<00:00, 39.22it/s]


Epoch: 5


100%|############################################################################| 10000/10000 [04:13<00:00, 39.52it/s]


Epoch: 6


100%|############################################################################| 10000/10000 [04:15<00:00, 39.07it/s]


Epoch: 7


100%|############################################################################| 10000/10000 [04:12<00:00, 39.54it/s]


Epoch: 8


100%|############################################################################| 10000/10000 [04:13<00:00, 39.47it/s]


Epoch: 9


100%|############################################################################| 10000/10000 [04:18<00:00, 38.69it/s]


Epoch: 10


 74%|#########################################################1                   | 7422/10000 [03:16<01:11, 36.10it/s]