In [None]:
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import TensorDataset,DataLoader, Subset
import torch.nn as nn
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
torch.autograd.set_detect_anomaly(True)
#from kornia.metrics import SSIM
import kornia
from decomposition import Decomposition
from restoration import Restoration
from illuminationAdjustment import IlluminationAdjustment,AdjustmentLoss
from image_utils import *

In [None]:
# Select the compute device. When CUDA is present, report how many GPUs are
# visible, make GPU 0 the default and print its name; otherwise fall back to CPU.
if not torch.cuda.is_available():
    print('CUDA unavailable')
    device = torch.device('cpu')
else:
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available: {num_gpus}")

    # Use the first GPU as the default device for this process.
    torch.cuda.set_device(0)

    current_device = torch.cuda.current_device()
    print(f"Current GPU: {torch.cuda.get_device_name(current_device)}")
    device = torch.device('cuda')

In [None]:
# Dataset roots: change these two paths to wherever the images live locally.
# Each root is expected to contain one sub-folder per exposure level:
#   <train_images_directory>/high/*.png   and   <train_images_directory>/low/*.png
#   <test_images_directory>/high/*.png    and   <test_images_directory>/low/*.png
train_images_directory = './../../../../LOLdataset/our485'
test_images_directory = './../../../../LOLdataset/eval15'

train_dataset, test_dataset = preprocessDataset(
    train_images_directory,
    test_images_directory,
)

In [None]:
# Batch size handed to getDataLoaders for both the training and the test set.
# (A `global` statement is a no-op at notebook/module scope, so a plain
# assignment is all that is needed here.)
batch_size = 15

# Build the data loaders; later cells index each result with 'high' and 'low'.
class_data_loaders = getDataLoaders(train_dataset, batch_size)
test_data_loaders = getDataLoaders(test_dataset, batch_size)

In [None]:
# Sanity check: take the first batch from the 'high' loader and then from the
# 'low' loader, and hand both to verifyDataset to confirm the images are loaded
# and paired as expected.
high_img, low_img = (
    next(iter(class_data_loaders[exposure])) for exposure in ('high', 'low')
)

verifyDataset(high_img, low_img)

In [None]:
# Instantiate the three networks on the selected device.
decom = Decomposition().to(device)
rest = Restoration().to(device)
adjust = IlluminationAdjustment().to(device)

# The optimizer is given only the illumination-adjustment parameters.
optimizer_adjustment = torch.optim.Adam(
    adjust.parameters(),
    lr=0.001,
    weight_decay=0.001,
)

# Load the saved weights for the decomposition and restoration networks.
decom.load_state_dict(
    torch.load('./Saved_models/only_decom_parameters.pth', map_location=device)
)
rest.load_state_dict(
    torch.load('./Saved_models/only_restoration_parameters.pth', map_location=device)
)

In [None]:
# Horizontal / vertical Sobel filters; both are passed to AdjustmentLoss in the training loop.
(sobel_horizontal, sobel_vertical) = generateSobelFilters()
def displayImages(r_h,r_l,i_h,i_l,r_map,a_map):
    """Show the first sample of each batch in a 3x3 matplotlib grid.

    Layout (row, column):
        (0, 0) r_h   -> 'Reflectance Normal'
        (1, 0) i_h   -> 'Illumination  Normal'
        (0, 1) r_l   -> 'Reflectance Low'
        (1, 1) i_l   -> 'Illumination Low'
        (0, 2) r_map -> 'Restored Image'
        (1, 2) a_map -> 'Illumination Adjusted Image'
        (2, 1) a_map * r_map (element-wise) -> 'Final Enhanced image'

    Every argument is a batched tensor; only index 0 of each batch is drawn.
    Each sample is transposed with (1, 2, 0) before plotting, i.e. the tensors
    are presumably channel-first (C, H, W) per sample -- TODO confirm.
    ``a_map`` must be expandable to the shape of ``r_map``.

    Tensors may live on any device and may require grad: they are detached and
    copied to host memory before being converted to NumPy.
    """
    def _first_sample_hwc(batch):
        # First sample of the batch as a NumPy array with the channel axis last.
        # `.cpu()` makes this work for CUDA tensors too (a no-op for CPU tensors).
        return batch[0].detach().cpu().numpy().transpose(1, 2, 0)

    # Final result: adjusted illumination broadcast over the restored map.
    final_img = torch.mul(a_map.expand_as(r_map), r_map)

    # (grid position, batch, title) for every panel that is drawn.
    panels = (
        ((0, 0), r_h, 'Reflectance Normal'),
        ((1, 0), i_h, 'Illumination  Normal'),
        ((0, 1), r_l, 'Reflectance Low'),
        ((1, 1), i_l, 'Illumination Low'),
        ((0, 2), r_map, 'Restored Image'),
        ((1, 2), a_map, 'Illumination Adjusted Image'),
        ((2, 1), final_img, 'Final Enhanced image'),
    )

    fig, ax = plt.subplots(3, 3, figsize=(10, 10))

    for (row, col), batch, title in panels:
        ax[row, col].imshow(_first_sample_hwc(batch))
        ax[row, col].title.set_text(title)

    # Hide the grid cells that carry no image so they do not show empty frames.
    used_positions = {position for position, _, _ in panels}
    for row in range(3):
        for col in range(3):
            if (row, col) not in used_positions:
                ax[row, col].axis('off')

    plt.show()

In [None]:
# Per-epoch training losses and the test losses collected at evaluation epochs.
train_errors = list()
test_errors = list()

# Total number of epochs, and the epochs (every 10th, starting at 0) at which
# the training loop checks whether to evaluate.
max_epochs = 31
test_at_epochs = [epoch_idx for epoch_idx in range(max_epochs) if epoch_idx % 10 == 0]

In [None]:
def _constant_alpha(illum, value=5):
    """Return a map filled with `value`, of shape
    (illum.shape[0], 1, illum.shape[2], illum.shape[3]), created directly on
    `illum`'s device (no CPU allocation followed by a copy)."""
    return value * torch.ones(
        illum.shape[0], 1, illum.shape[2], illum.shape[3], device=illum.device
    )


# Train only the illumination-adjustment network (`adjust`). The decomposition
# and restoration networks stay in eval mode and are run without gradients.
# Every epoch listed in `test_at_epochs` (except epoch 0) the test set is
# evaluated, a sample is displayed and the model/optimizer state is saved.
for epoch in range(max_epochs):
    adjust.train()
    rest.eval()
    decom.eval()

    # Plain Python float: summing the loss tensors themselves would keep every
    # batch's autograd history referenced until the end of the epoch.
    epoch_loss = 0.0

    # NOTE(review): zipping two loaders assumes both yield the same images in
    # the same order (no independent shuffling) -- confirm in getDataLoaders.
    for high, low in zip(class_data_loaders['high'], class_data_loaders['low']):

        optimizer_adjustment.zero_grad()
        high_img, label_high = high
        low_img, label_low = low
        high_img, label_high = high_img.to(device), label_high.to(device)
        low_img, label_low = low_img.to(device), label_low.to(device)

        # Neither network below is optimized here, so no graph is needed.
        with torch.no_grad():
            ref_high, illum_high = decom(high_img)
            ref_low, illum_low = decom(low_img)
            # Not part of the loss; kept only for the optional visualisation below.
            restoration_map = rest(ref_low, illum_low)

        alpha = _constant_alpha(illum_low)
        adjustment_map = adjust(illum_low, alpha)

        adjustment_loss = AdjustmentLoss(adjustment_map, illum_high, sobel_horizontal, sobel_vertical)

        adjustment_loss.backward()

        optimizer_adjustment.step()

        # .item() detaches the scalar from the graph before accumulating it.
        epoch_loss += adjustment_loss.item()

    final_train_loss = epoch_loss / len(class_data_loaders['high'])

    print(f'Epoch {epoch} Train Loss is {final_train_loss}')
    train_errors.append(final_train_loss)
    # Optional per-epoch visualisation of the last training batch:
    #displayImages(ref_high.cpu(),ref_low.cpu(),illum_high.cpu(),illum_low.cpu(),restoration_map.cpu(),adjustment_map.cpu())

    if epoch in test_at_epochs and epoch != 0:
        # Evaluate the network being trained in eval mode. `adjust.train()` at
        # the top of the next epoch switches it back; `rest` and `decom` are
        # already in eval mode and are left that way.
        adjust.eval()
        test_loss = 0.0
        with torch.no_grad():
            for h, l in zip(test_data_loaders['high'], test_data_loaders['low']):
                high_img_test, label_high_test = h
                low_img_test, label_low_test = l
                high_img_test, label_high_test = high_img_test.to(device), label_high_test.to(device)
                low_img_test, label_low_test = low_img_test.to(device), label_low_test.to(device)

                ref_high_test, illum_high_test = decom(high_img_test)
                ref_low_test, illum_low_test = decom(low_img_test)

                restoration_map_test = rest(ref_low_test, illum_low_test)
                alpha_test = _constant_alpha(illum_low_test)
                adjustment_map_test = adjust(illum_low_test, alpha_test)
                adjustment_loss_test = AdjustmentLoss(adjustment_map_test, illum_high_test, sobel_horizontal, sobel_vertical)

                test_loss += adjustment_loss_test.item()

        final_test_loss = test_loss / len(test_data_loaders['high'])
        test_errors.append(final_test_loss)
        print(f'Test loss after {epoch} epochs is {final_test_loss} for all test images')
        print(f'After {epoch} epochs we have test results as')
        # Shows the first sample of the last test batch.
        displayImages(ref_high_test.cpu(), ref_low_test.cpu(), illum_high_test.cpu(), illum_low_test.cpu(), restoration_map_test.cpu(), adjustment_map_test.cpu())

        # Whole-module pickle (tied to the class definitions at load time) ...
        torch.save(adjust, './Saved_models/only_adjustment_model.pth')
        # ... and the more portable state dicts for the model and the optimizer.
        torch.save(adjust.state_dict(), './Saved_models/only_adjustment_parameters.pth')
        torch.save(optimizer_adjustment.state_dict(), './Saved_models/only_optimizer_adjustment_parameters.pth')