In [None]:
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import TensorDataset,DataLoader, Subset
import torch.nn as nn
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
torch.autograd.set_detect_anomaly(True)
#from kornia.metrics import SSIM
import kornia
from decomposition import Decomposition
from restoration import Restoration
from illuminationAdjustment import IlluminationAdjustment,AdjustmentLoss
from image_utils import *
import cv2

In [None]:
# Report the available accelerator and pick the compute device used by every
# later cell: GPU 0 when CUDA is present, otherwise the CPU.
cuda_ok = torch.cuda.is_available()
if not cuda_ok:
    print('CUDA unavailable')
else:
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs available: {gpu_count}")

    # Make GPU 0 the default CUDA device for this process.
    torch.cuda.set_device(0)

    active_gpu = torch.cuda.current_device()
    print(f"Current GPU: {torch.cuda.get_device_name(active_gpu)}")

device = torch.device('cuda' if cuda_ok else 'cpu')

In [None]:
# Build the paired (high/low-light) test split and its DataLoaders.
# The dataset root can be overridden through the TEST_IMAGES_DIR environment
# variable so the notebook is not tied to one machine; the default is the
# original location. A previously used alternative: './../../../../LOLdataset/eval15'
import os

test_images_directory = os.environ.get(
    'TEST_IMAGES_DIR',
    'C:/Users/pspra/OneDrive/Desktop/ECE_271_Project/GAN_dataset/test',
)
test_dataset = testData(test_images_directory)

# Images per batch. (A module-level `global` statement is a no-op, so it was dropped.)
batch_size = 15

test_data_loaders = getDataLoaders(test_dataset, batch_size)

# Pretrained YOLOv5-small fetched through torch.hub (needs network access on first use).
# NOTE(review): not used by the visible inference loop; it appears to be meant for the
# commented-out "Detection using YOLO" panel in displaytight.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
model.eval()

In [None]:
# Instantiate the three sub-networks on `device`, then restore each one's
# trained parameters from ./Saved_models_crop.
decom = Decomposition().to(device)
rest = Restoration().to(device)
adjust = IlluminationAdjustment().to(device)

checkpoints = (
    (decom, './Saved_models_crop/only_decom_parameters.pth'),
    (rest, './Saved_models_crop/only_restoration_parameters.pth'),
    (adjust, './Saved_models_crop/only_adjustment_parameters.pth'),
)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True on torch versions that support it).
for sub_net, ckpt_path in checkpoints:
    state = torch.load(ckpt_path, map_location=device)
    sub_net.load_state_dict(state)


In [None]:
def displayImages(actual_high, actual_low, r_map, a_map,ref_low,final_img,count):
    """Plot, save and show normal / low-light / enhanced images for each batch sample.

    For every index ``i`` in the batch:
      * a 1x3 figure (normal, low-light, enhanced) is written to
        ``./Results/Image_<count>_<i>.jpg`` and shown;
      * the enhanced image alone is written to
        ``./Enhanced_images/Image_<count>_<i>.jpg``.

    Args:
        actual_high: batch of normal-light images. Each sample is transposed
            (1, 2, 0) for display, so presumably (B, C, H, W) with values in
            [0, 1] -- TODO confirm against the data loader.
        actual_low: batch of low-light inputs, same layout as ``actual_high``.
        r_map: unused; kept so existing call sites keep working.
        a_map: unused; kept so existing call sites keep working.
        ref_low: unused; kept so existing call sites keep working.
        final_img: batch of enhanced images, same layout as ``actual_high``.
        count: batch counter, used to make output file names unique.
    """
    import os

    results_dir = './Results'
    enhanced_dir = './Enhanced_images'
    # savefig / PIL.Image.save do not create missing parent directories.
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(enhanced_dir, exist_ok=True)

    def _to_hwc(batch, idx):
        # .cpu() is required: .numpy() raises for CUDA tensors, and callers may
        # pass tensors that still live on the GPU.
        return batch[idx].detach().cpu().numpy().transpose(1, 2, 0)

    for i in range(actual_high.shape[0]):
        fig, ax = plt.subplots(1, 3, figsize=(10, 5))
        panels = (
            (actual_high, 'Normal Image'),
            (actual_low, 'Low Light Image'),
            (final_img, 'Final Enhanced image'),
        )
        for axis, (batch, title) in zip(ax, panels):
            axis.imshow(_to_hwc(batch, i))
            axis.title.set_text(title)

        file_stem = 'Image_' + str(count) + '_' + str(i) + '.jpg'
        fig.savefig(results_dir + '/' + file_stem)

        # Clip before the uint8 cast: values outside [0, 1] would otherwise wrap around.
        enhanced = _to_hwc(final_img, i)
        enhanced_u8 = (np.clip(enhanced, 0.0, 1.0) * 255).astype(np.uint8)
        Image.fromarray(enhanced_u8).save(enhanced_dir + '/' + file_stem)

        plt.show()
        # Release the figure explicitly so long loops do not accumulate open figures.
        plt.close(fig)

In [None]:
def displaytight(actual_high, actual_low, r_map, a_map, ref_low, final_img):
    """Show low-light / enhanced image pairs, two batch samples per figure.

    Each sample gets its own sub-figure holding the low-light input next to the
    enhanced result, with ticks and tick labels hidden. Samples are grouped two
    per figure. When the batch size is odd, the last sample is shown in a figure
    of its own instead of being skipped.

    Args:
        actual_high: batch whose first dimension gives the number of samples;
            its pixels are not plotted.
        actual_low: batch of low-light inputs. Each sample is transposed
            (1, 2, 0) for display, so presumably (B, C, H, W) -- TODO confirm.
        r_map: unused; kept so existing call sites keep working.
        a_map: unused; kept so existing call sites keep working.
        ref_low: unused; kept so existing call sites keep working.
        final_img: batch of enhanced images, same layout as ``actual_low``.
    """
    def _to_hwc(batch, idx):
        # .cpu() is required: .numpy() raises for CUDA tensors.
        return batch[idx].detach().cpu().numpy().transpose(1, 2, 0)

    n_images = actual_high.shape[0]
    for start in range(0, n_images, 2):
        indices = list(range(start, min(start + 2, n_images)))

        # Width scales with the number of samples; a full pair keeps the original (10, 5).
        fig = plt.figure(constrained_layout=False, figsize=(5 * len(indices), 5))
        # squeeze=False always yields an array, even for a single leftover sample.
        subfigs = fig.subfigures(1, len(indices), squeeze=False).ravel()

        for subfig, idx in zip(subfigs, indices):
            axes = subfig.subplots(1, 2)
            panels = ((actual_low, 'Low Light Image'), (final_img, 'Enhanced image'))
            for axis, (batch, title) in zip(axes, panels):
                axis.imshow(_to_hwc(batch, idx))
                axis.title.set_text(title)
                axis.tick_params(left=False, right=False, labelleft=False,
                                 labelbottom=False, bottom=False)

        plt.show()
        # Release the figure explicitly so long loops do not accumulate open figures.
        plt.close(fig)
    

In [None]:
# Put all three sub-networks into evaluation mode before running inference.
for sub_net in (decom, rest, adjust):
    sub_net.eval()

In [None]:
# Enhance every low-light test batch and save/show the results.
# NOTE(review): the high and low loaders are zipped, which assumes they yield
# matching images in the same order (no shuffling) -- confirm in getDataLoaders.
ALPHA_VALUE = 3.0  # constant illumination-adjustment strength fed to `adjust`

with torch.no_grad():
    for count, ((high_img_test, label_high_test), (low_img_test, label_low_test)) in enumerate(
            zip(test_data_loaders['high'], test_data_loaders['low'])):
        # Only the low-light batch feeds the networks. The high-light batch and
        # the labels are only displayed or unused, so they stay on the CPU.
        low_img_test = low_img_test.to(device)

        # Call the modules rather than .forward() so registered hooks still run.
        ref_low, illum_low = decom(low_img_test)
        restored_img = rest(ref_low, illum_low)

        # One alpha channel per sample, matching illum_low's batch and spatial
        # size, allocated directly on `device`.
        alpha = torch.full(
            (illum_low.shape[0], 1, illum_low.shape[2], illum_low.shape[3]),
            ALPHA_VALUE,
            device=device,
        )
        adjustment_map = adjust(illum_low, alpha)
        final_img = adjustment_map.expand_as(restored_img) * restored_img

        # The display helper converts tensors with .numpy(), which only works on
        # CPU tensors, so move everything off the GPU first.
        displayImages(high_img_test.cpu(), low_img_test.cpu(), restored_img.cpu(),
                      adjustment_map.cpu(), ref_low.cpu(), final_img.cpu(), count)
        # Alternative layout (low-light vs. enhanced pairs):
        # displaytight(high_img_test.cpu(), low_img_test.cpu(), restored_img.cpu(),
        #              adjustment_map.cpu(), ref_low.cpu(), final_img.cpu())
