In [None]:
import os

import kornia
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import TensorDataset,DataLoader, Subset
from torchvision.datasets import ImageFolder
#from kornia.metrics import SSIM

from decomposition import Decomposition,DecomLoss
from image_utils import *

torch.autograd.set_detect_anomaly(True)

In [None]:
# Report what CUDA hardware is visible and, when there is any, make GPU 0 the default.
if not torch.cuda.is_available():
    print('CUDA unavailable')
else:
    # How many GPUs this process can see
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available: {num_gpus}")

    # Select GPU 0 as the default CUDA device
    torch.cuda.set_device(0)

    # Echo the name of the GPU that is now current
    current_device = torch.cuda.current_device()
    print(f"Current GPU: {torch.cuda.get_device_name(current_device)}")

# Device handle used by the later cells for every `.to(device)` call
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Dataset roots - edit these two paths to point at your copy of the data.
# Each root is expected to be laid out as:
#     <root>/high/*.png
#     <root>/low/*.png
train_images_directory = './../../../../LOLdataset/our485'
test_images_directory = './../../../../LOLdataset/eval15'

# Build the train and test datasets from the two roots.
train_dataset, test_dataset = preprocessDataset(
    train_images_directory,
    test_images_directory,
)

In [None]:
# Mini-batch size shared by the training and test loaders.
# (A `global` statement is a no-op at notebook/module scope, so a plain
# assignment is all that is needed to make this name visible to later cells.)
batch_size = 15

# Build the loaders for each split. The returned objects are indexed by the
# sub-folder name ('high' / 'low') in the cells below.
class_data_loaders = getDataLoaders(train_dataset, batch_size)
test_data_loaders = getDataLoaders(test_dataset, batch_size)

In [None]:
# Sanity check: confirm the data loads and that the high/low images are paired
# correctly. iter()/next() pulls just the first batch (batch_size items) from
# each loader, 'high' first and then 'low'.
high_img, low_img = (
    next(iter(class_data_loaders[split])) for split in ('high', 'low')
)

verifyDataset(high_img, low_img)

In [None]:
# Decomposition network, placed on the selected device
decom = Decomposition().to(device)

# Adam over all of the network's parameters, with weight decay
optimizer_decom = torch.optim.Adam(
    decom.parameters(),
    lr=1e-3,
    weight_decay=1e-3,
)

In [None]:
# Build the two Sobel kernels and move both onto the compute device so they can
# be passed to DecomLoss together with the image batches.
sobel_horizontal, sobel_vertical = (
    kernel.to(device) for kernel in generateSobelFilters()
)
def displayImages(r_h,r_l,i_h,i_l):
    """Plot the first sample of each decomposition output in a 2x2 grid.

    Top row: reflectance and illumination for the normal-light input.
    Bottom row: reflectance and illumination for the low-light input.

    Args:
        r_h: batch of reflectance outputs for the normal-light images.
        r_l: batch of reflectance outputs for the low-light images.
        i_h: batch of illumination outputs for the normal-light images.
        i_l: batch of illumination outputs for the low-light images.

    Each argument is indexed with ``[0]`` and the result is transposed with
    ``(1, 2, 0)``, so every sample must be a 3-D, channels-first tensor
    (assumed layout ``(N, C, H, W)`` - TODO confirm against Decomposition).
    Tensors may live on any device and may still require grad.
    """

    def _as_image(batch):
        # First sample only: cut it from the autograd graph, bring it to host
        # memory (``.numpy()`` raises on CUDA tensors) and make it channels-last.
        image = batch[0].detach().cpu().numpy().transpose(1, 2, 0)
        # imshow rejects (H, W, 1) arrays, so drop a singleton channel axis.
        return image[..., 0] if image.shape[-1] == 1 else image

    panels = (
        (0, 0, r_h, 'Reflectance Component for Normal'),
        (0, 1, i_h, 'Illumination Component for Normal'),
        (1, 0, r_l, 'Reflectance Component for Low'),
        (1, 1, i_l, 'Illumination Component for Low'),
    )

    fig,ax=plt.subplots(2,2,figsize=(8,8))

    for row, col, batch, title in panels:
        image = _as_image(batch)
        # 2-D (single-channel) maps are drawn in grayscale; colour images
        # ignore the colormap, so None keeps their rendering unchanged.
        ax[row,col].imshow(image, cmap='gray' if image.ndim == 2 else None)
        ax[row,col].title.set_text(title)

    plt.show()
    # This function is called once per epoch; release the figure explicitly so
    # figures do not pile up on backends that keep them open after show().
    plt.close(fig)

In [None]:
# Loss history: one entry per training epoch / per test evaluation
train_errors, test_errors = [], []

# Total number of epochs, and the epochs (multiples of 10) flagged for testing
max_epochs = 31
test_at_epochs = [epoch for epoch in range(max_epochs) if epoch % 10 == 0]

In [None]:
# Create the checkpoint folder up front so the torch.save calls below cannot
# fail on a missing directory after several epochs of training.
os.makedirs('./Saved_models', exist_ok=True)

for epoch in range(max_epochs):
    decom.train()
    # Running sum of the batch losses, kept as a plain Python float so that no
    # autograd graph is retained between batches.
    epoch_loss = 0.0

    # The 'high' and 'low' loaders are walked in lock-step.
    # NOTE(review): this assumes both loaders yield images in the same order so
    # that every batch is a matched high/low pair - confirm in getDataLoaders.
    for high,low in zip(class_data_loaders['high'],class_data_loaders['low']):

        optimizer_decom.zero_grad()

        # Each loader item is an (images, labels) pair. The labels are not used
        # by the loss, so only the images are moved to the compute device.
        high_img,label_high=high
        low_img,label_low=low
        high_img=high_img.to(device)
        low_img=low_img.to(device)

        # Call the module itself (not .forward) so nn.Module hooks are honoured.
        ref_high,illum_high=decom(high_img)
        ref_low,illum_low=decom(low_img)

        decom_loss= DecomLoss(ref_low,ref_high,illum_low,illum_high,sobel_horizontal,sobel_vertical,low_img,high_img)

        decom_loss.backward()

        optimizer_decom.step()

        # .item() extracts the scalar value; adding the tensor itself would keep
        # every batch's computation graph alive for the whole epoch.
        epoch_loss+=decom_loss.item()

    # Mean batch loss for this epoch (a float, safe to store and plot).
    final_train_loss=epoch_loss/len(class_data_loaders['high'])

    print(f'Epoch {epoch} Train Loss is {final_train_loss}')
    train_errors.append(final_train_loss)
    # Visual check on the last training batch of the epoch.
    displayImages(ref_high.cpu(),ref_low.cpu(),illum_high.cpu(),illum_low.cpu())

    # Periodic evaluation and checkpointing (epoch 0 is deliberately skipped).
    if epoch in test_at_epochs and epoch!=0:
        decom.eval()
        test_loss=0.0
        with torch.no_grad():
            for h,l in zip(test_data_loaders['high'],test_data_loaders['low']):
                high_img_test,label_high_test=h
                low_img_test,label_low_test=l
                high_img_test=high_img_test.to(device)
                low_img_test=low_img_test.to(device)

                ref_high_test,illum_high_test=decom(high_img_test)
                ref_low_test,illum_low_test=decom(low_img_test)

                decom_loss_test= DecomLoss(ref_low_test,ref_high_test,illum_low_test,illum_high_test,sobel_horizontal,sobel_vertical,low_img_test,high_img_test)

                test_loss+=decom_loss_test.item()

        # Mean batch loss over the test loader.
        final_test_loss=test_loss/len(test_data_loaders['high'])
        test_errors.append(final_test_loss)
        print(f'Test loss after {epoch} epochs is {final_test_loss} for all test images')
        decom.train()
        print(f'After {epoch} epochs we have test results as')
        # Visual check on the last test batch.
        displayImages(ref_high_test.cpu(),ref_low_test.cpu(),illum_high_test.cpu(),illum_low_test.cpu())

        # Checkpoints. The full-model file pickles the whole module object; the
        # state_dict files are the portable form and are what should normally
        # be reloaded.
        torch.save(decom,'./Saved_models/only_decomposition_model.pth')
        torch.save(decom.state_dict(), './Saved_models/only_decom_parameters.pth')
        torch.save(optimizer_decom.state_dict(), './Saved_models/only_optimizer_decom_parameters.pth')