## Imports

In [None]:
%pip install -r requirements.txt

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from skimage import io
from skimage.filters import meijering
import matplotlib.pyplot as plt
from astropy.visualization import simple_norm

from NormalizedData import NormalizedData
from Models import *
from utils import *

# Compute device for the model and training batches: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load Data & Normalize

In [None]:
data_path = 'sample_data.tif' # Enter raw data path
# Passed to NormalizedData as `is_normalized`; presumably True when data_path
# already contains normalized data -- confirm in NormalizedData.
normalized = False
save_path = 'normalized_data.tif' # Enter path for saving normalized data

# NOTE(review): mask_threshold, label_threshold and modified_db_size are
# interpreted inside NormalizedData (not shown here); see that class for
# their meaning and valid ranges.
train_dataset = NormalizedData(data_path, is_normalized=normalized, dest=save_path, mask_threshold=0.93, label_threshold=130, modified_db_size='original')

## Training

### Training Parameters

In [None]:
# Batch size used by both the training and the inference DataLoaders.
batch_size = 16
# Upsampling factor: passed to DecoderEncoder and to F.interpolate when
# building the L1-loss target.
scale_factor = 4
# Interpolation mode for F.interpolate and the model's `interp_mode` argument.
intp_mode = 'nearest'

# NOTE(review): lista_folds is not referenced by any visible cell.
lista_folds = 1
# Kernel sizes forwarded to DecoderEncoder as e_kernel_size / d_kernel_size.
kernel_size_enc = 30
kernel_size_dec = 30
# Value of the model's `iters` argument during training and during inference.
training_iters = 1
inference_iters = 2

# Adam learning rate and number of passes over the dataset.
learning_rate = 0.001
no_of_epochs = 25

##################################### Inference function #####################################
# Loader used by `infer`. No `shuffle` argument is given, so it iterates the
# dataset in order and per-epoch results line up sample-for-sample.
inference_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=2)
def infer(model, loader=None, iters=None, interp_mode=None):
    """Run ``model`` over a loader without gradients and stack its second output.

    Args:
        model: network called as ``model(x, iters=..., interp_mode=...)``; it must
            return a pair whose second element is the output collected here.
        loader: DataLoader to iterate; defaults to the notebook's ``inference_loader``.
        iters: value forwarded as ``iters``; defaults to ``inference_iters``.
        interp_mode: value forwarded as ``interp_mode``; defaults to ``intp_mode``.

    Returns:
        numpy.ndarray: per-batch outputs, with singleton non-batch axes squeezed,
        concatenated along axis 0 (one entry per sample).

    Raises:
        ValueError: from ``np.concatenate`` if the loader yields no batches.
    """
    # Defaults are resolved at call time so `infer(model)` behaves as before.
    loader = inference_loader if loader is None else loader
    iters = inference_iters if iters is None else iters
    interp_mode = intp_mode if interp_mode is None else interp_mode

    # Put inputs on the device holding the model weights, with the same
    # `.float()` conversion the training loop applies. Without this,
    # inference fails when the model lives on CUDA.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    ol = []
    with torch.no_grad():
        for c in loader:
            _, f_output = model(c.float().to(target_device), iters=iters, interp_mode=interp_mode)
            output = f_output.cpu().numpy()
            # Squeeze singleton axes but never the batch axis. A trailing batch
            # of size 1 would otherwise lose axis 0 and break the concatenation.
            singleton_axes = tuple(ax for ax in range(1, output.ndim) if output.shape[ax] == 1)
            ol.append(np.squeeze(output, axis=singleton_axes))
    return np.concatenate(ol)

### Training Loop

In [None]:
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

model = DecoderEncoder(e_kernel_size=kernel_size_enc, d_kernel_size=kernel_size_dec, scale_factor=scale_factor)
model = model.to(device)

criterion_l1 = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Inference results from `infer`, one entry per epoch in epoch order.
outputs = []
# The DataLoader keeps the final partial batch (drop_last defaults to False),
# so an epoch has ceil(len(dataset) / batch_size) batches. len(trainloader)
# is exactly that. Floor division over-reported the mean loss and divided by
# zero when the dataset was smaller than one batch.
no_of_batches_per_epoch = len(trainloader)

for epoch in range(no_of_epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    for batch in trainloader:
        # Each batch is the input tensor itself; cast and move it for the model.
        inpt = batch.float().to(device)

        # Upsampled copy of the input, already on `device`; it is the target
        # that encoder_output is compared against.
        inpt_interp = F.interpolate(inpt, scale_factor=scale_factor, mode=intp_mode)

        optimizer.zero_grad()

        # forward
        encoder_output, decoder_output = model(inpt, iters=training_iters, interp_mode=intp_mode)

        # nn.L1Loss defaults to reduction='mean', so `loss` is already a scalar.
        loss = criterion_l1(encoder_output, inpt_interp)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean per-batch loss for this epoch.
    print(f'[{epoch + 1}] loss: {running_loss / no_of_batches_per_epoch:.8f}')
    # Snapshot the model's inference output after this epoch.
    outputs.append(infer(model))

print('Finished Training')

### Plot Results (Per Epoch)

In [None]:
### Plot input
# Per-pixel standard deviation over axis 0 of the squeezed dataset
# (presumably the frame axis -- confirm against NormalizedData.data).
input_im = np.std(np.squeeze(train_dataset.data.cpu().numpy()), axis=0)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4))
ax.imshow(input_im, cmap='afmhot')
ax.set_title('Input', fontdict={'fontsize': 40})

### Plot results
# One panel per epoch on a grid with `ncols` columns. Round the row count up
# so trailing epochs are not dropped, and keep at least one row so subplots
# never gets nrows=0.
ncols = 5
r = max(1, (len(outputs) + ncols - 1) // ncols)
# squeeze=False keeps `ax` 2-D even for a single row, so ax[row, col] always works.
fig, ax = plt.subplots(nrows=r, ncols=ncols, figsize=(100, r * 20), squeeze=False)
for i, epoch_output in enumerate(outputs):
    oi = np.std(epoch_output, axis=0)
    # Bright-ridge Meijering filtering, applied twice at two scales.
    # skimage documents `sigmas` as an iterable, hence the one-element lists.
    oi = meijering(oi, black_ridges=False, sigmas=[2])
    oi = meijering(oi, black_ridges=False, sigmas=[1])
    panel = ax[i // ncols, i % ncols]
    panel.imshow(oi, cmap='afmhot', norm=simple_norm(oi, percent=99.5))
    panel.set_title('{e} Epochs'.format(e=(i + 1)), fontdict={'fontsize': 40})
# Hide grid cells that are left over after the last epoch.
for unused_panel in ax.flat[len(outputs):]:
    unused_panel.axis('off')

### Save result

In [None]:
# Number of training epochs whose inference result should be saved.
no_epochs_save = 10
save_path = "result_image.tif"

# outputs[k] holds the result after k + 1 epochs, so the epoch count must lie
# in 1..len(outputs). 0 would silently wrap to the last epoch via outputs[-1].
if not 1 <= no_epochs_save <= len(outputs):
    raise ValueError(
        f"no_epochs_save must be between 1 and {len(outputs)}, got {no_epochs_save}"
    )

oi = np.std(outputs[no_epochs_save - 1], axis=0)
# Same two-scale bright-ridge Meijering filtering as in the plots.
# skimage documents `sigmas` as an iterable, hence the one-element lists.
oi = meijering(oi, black_ridges=False, sigmas=[2])
oi = meijering(oi, black_ridges=False, sigmas=[1])
io.imsave(save_path, oi)