# Find the generated images with the highest PSNR / SSIM

### The higher the PSNR (in decibels, dB), the better the reconstruction quality.
### SSIM is at most 1 (for natural images it typically lies between 0 and 1); higher values indicate greater structural similarity to the ground truth and thus better dehazing results.

In [None]:
import numpy as np
from math import log10
from skimage.metrics import structural_similarity as ssim
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from PIL import Image
import os
from database2 import DehazingDataset
import matplotlib.pyplot as plt
from skimage.transform import resize

In [None]:
# Install scikit-image into the running kernel's environment.
# NOTE(review): the imports cell above already imports from skimage, so on a
# fresh environment run this cell before that one.
%pip install scikit-image

In [None]:
def PSNR(img1, img2, max_pixel=255.0):
    """Return the peak signal-to-noise ratio between two images, in dB.

    Args:
        img1, img2: array-likes of the same shape.
        max_pixel: largest possible pixel value of the inputs (255 for 8-bit
            images, 1.0 for images scaled to [0, 1]). The inputs must be on
            this scale for the score to be meaningful.

    Returns:
        PSNR in decibels, or 100 when the images are identical (MSE == 0).
    """
    # Promote to float64 first: subtracting uint8 arrays would wrap around
    # modulo 256 and produce a wildly wrong MSE.
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    return 10 * log10(max_pixel ** 2 / mse)

def visualize_top_images(psnr, top_x):
    """Plot the ``top_x`` generated images with the best PSNR or SSIM score.

    Reads ``generated_images``, ``sorted_indices_psnr`` and
    ``sorted_indices_ssim`` from the notebook's global scope.

    Args:
        psnr: 1 to rank by PSNR (``sorted_indices_psnr``); any other value
            ranks by SSIM (``sorted_indices_ssim``).
        top_x: number of images to show; capped at the number available.
    """
    # 1 -> BASED ON HIGHEST PSNR, otherwise BASED ON HIGHEST SSIM.
    ranking = sorted_indices_psnr if psnr == 1 else sorted_indices_ssim
    top_images = [generated_images[i] for i in ranking[:top_x]]
    if not top_images:
        print("No scored images to display.")
        return

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # squeeze=False keeps `axes` 2-D even when only one image is plotted,
    # so indexing works for any count.
    fig, axes = plt.subplots(1, len(top_images), figsize=(15, 5), squeeze=False)
    for rank, (ax, image) in enumerate(zip(axes[0], top_images), start=1):
        # Transpose (C, H, W) -> (H, W, C) for matplotlib, then undo the
        # ImageNet normalisation.
        pixels = image.transpose(1, 2, 0) * std + mean
        # imshow expects floats in [0, 1]; clip explicitly instead of relying
        # on matplotlib's warning-emitting implicit clipping.
        ax.imshow(np.clip(pixels, 0.0, 1.0))
        ax.axis('off')
        ax.set_title(f"Image {rank}")

    plt.tight_layout()
    plt.show()




class Generator(nn.Module):
    """Convolutional encoder-decoder mapping a 3-channel image to a 3-channel image.

    The encoder has three Conv2d(kernel 4, stride 2, padding 1) + BatchNorm +
    ReLU stages (3 -> 64 -> 128 -> 256 channels), each halving height and
    width for even sizes. The decoder mirrors it with ConvTranspose2d stages
    (256 -> 128 -> 64 -> 3) and ends in Tanh, so outputs lie in [-1, 1]. For
    inputs whose height and width are divisible by 8, the output has the same
    spatial size as the input.

    The layers live in two ``nn.Sequential`` containers named ``encoder`` and
    ``decoder`` with a fixed layer order. Saved state dicts are keyed by those
    names and indices (e.g. ``encoder.0.weight``), so they must not change.
    """

    def __init__(self):
        super().__init__()

        down = []
        for in_ch, out_ch in ((3, 64), (64, 128), (128, 256)):
            down.extend([
                nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ])
        self.encoder = nn.Sequential(*down)

        up = []
        for in_ch, out_ch in ((256, 128), (128, 64)):
            up.extend([
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ])
        # Final projection back to RGB uses Tanh instead of BatchNorm + ReLU.
        up.extend([nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1), nn.Tanh()])
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        """Encode ``x`` to a 256-channel feature map and decode it back to 3 channels."""
        features = self.encoder(x)
        return self.decoder(features)
        
# similarity = ssim(img1, img2)

In [None]:
# Dataset layout: <root_dir>/train and <root_dir>/val.
root_dir = 'Task2Dataset'
train_dir, val_dir = (os.path.join(root_dir, split) for split in ('train', 'val'))

# No Resize step: all images are assumed to already be 256x256.
# ToTensor produces channel-first tensors; Normalize then applies the ImageNet
# per-channel mean/std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_dataset = DehazingDataset(train_dir, transform)
val_dataset = DehazingDataset(val_dir, transform)

# Same batching for both splits; only the training split is shuffled.
loader_kwargs = dict(batch_size=32, num_workers=2)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
# INITIALISE AND LOAD WEIGHTS TO A GENERATOR

generator = Generator()

weights_path = 'generator_l1loss_scheduler.pth'

# Load the weights into the generator model.
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only
# machine; this notebook never moves the generator or its inputs to a GPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust (recent PyTorch versions also accept weights_only=True).
generator.load_state_dict(torch.load(weights_path, map_location='cpu'))

# Evaluation mode: BatchNorm layers use their running statistics.
generator.eval()

In [None]:


# Score every generated image against its clean target with PSNR and SSIM.
# NOTE(review): this iterates the *training* split; use val_dataloader instead
# to rank images the generator never saw during training.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def _to_pixel_range(img_chw):
    """Convert an ImageNet-normalised (C, H, W) array to an (H, W, C) float array in [0, 255].

    Both the targets (normalised by ``transform``) and the generator output
    are treated as ImageNet-normalised, matching how the plotting cells
    de-normalise them.
    """
    img_hwc = img_chw.transpose(1, 2, 0) * IMAGENET_STD + IMAGENET_MEAN
    return np.clip(img_hwc, 0.0, 1.0) * 255.0


psnr_scores = []
ssim_scores = []
generated_images = []  # raw (C, H, W) generator outputs, read by visualize_top_images

# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for hazy_imgs, clean_imgs in tqdm(train_dataloader, desc='Computing Metrics'):
        generated_imgs = generator(hazy_imgs)

        # Score each (generated, clean) pair in the batch.
        for generated_img, clean_img in zip(generated_imgs, clean_imgs):
            generated_img_np = generated_img.detach().cpu().numpy()
            clean_img_np = clean_img.detach().cpu().numpy()

            # Store the raw output for later visualisation (the plotting
            # function de-normalises it itself).
            generated_images.append(generated_img_np)

            # Compare in the [0, 255] pixel space that PSNR's default
            # max_pixel=255 assumes; scoring the normalised tensors directly
            # would make the numbers meaningless.
            generated_pixels = _to_pixel_range(generated_img_np)
            clean_pixels = _to_pixel_range(clean_img_np)

            psnr_scores.append(PSNR(generated_pixels, clean_pixels))
            # channel_axis requires scikit-image >= 0.19 (it replaces multichannel=True).
            ssim_scores.append(
                ssim(generated_pixels, clean_pixels, data_range=255.0, channel_axis=-1)
            )

In [None]:
# Value range of the last clean image array from the metrics loop. It is still
# ImageNet-normalised, so values can fall outside [0, 1].
min_value, max_value = np.amin(clean_img_np), np.amax(clean_img_np)
for label, value in (("Minimum pixel value:", min_value), ("Maximum pixel value:", max_value)):
    print(label, value)

In [None]:
# Value range of the last generated image array from the metrics loop. The
# generator ends in Tanh, so values lie in [-1, 1].
min_value, max_value = np.amin(generated_img_np), np.amax(generated_img_np)
for label, value in (("Minimum pixel value:", min_value), ("Maximum pixel value:", max_value)):
    print(label, value)

In [None]:
# Shape of the last clean image array from the metrics loop; expected to be
# channel-first (3, H, W), since ToTensor produces channel-first tensors.
clean_img_np.shape

In [None]:
# Rank generated images by score, best first (argsort is ascending, so reverse).
sorted_indices_psnr = np.argsort(psnr_scores)[::-1]
# Also define the SSIM ranking so visualize_top_images(0, ...) works; it is
# simply empty if no SSIM scores were collected.
sorted_indices_ssim = np.argsort(ssim_scores)[::-1]

# Visualize the top X images: pass 1 to rank by PSNR, any other value for SSIM.
top_x = 5  # Change this value as needed
visualize_top_images(1, top_x)

In [None]:
# Display the third hazy input of the last batch from the metrics loop.
# NOTE(review): view_image is defined in a later cell; run that cell first.
view_image(hazy_imgs[2])

In [None]:
# Display the generator's output for that same hazy input.
# NOTE(review): this runs the generator on the whole batch, with autograd
# enabled, just to show one image.
view_image(generator(hazy_imgs)[2])

In [None]:
# Display the first clean (ground-truth) image of the last batch. Undo the
# ImageNet normalisation first: imshow expects floats in [0, 1], and the raw
# normalised tensor renders with wrong colours.
clean_preview = clean_imgs[0].detach().numpy().transpose(1, 2, 0)
clean_preview = clean_preview * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
plt.imshow(np.clip(clean_preview, 0.0, 1.0))

In [None]:
# Channel-last copy of the second clean image in the batch, de-normalised with
# the ImageNet statistics used by `transform`, then displayed.
clean_image = (
    clean_imgs[1].permute(1, 2, 0).cpu().numpy()
    * np.array([0.229, 0.224, 0.225])
    + np.array([0.485, 0.456, 0.406])
)
plt.imshow(clean_image)

In [None]:
def view_image(image):
    """Display one ImageNet-normalised (C, H, W) image tensor with matplotlib.

    The tensor is detached from autograd (so generator outputs can be passed
    directly), moved to the CPU, and converted to (H, W, C). It is then
    de-normalised with the mean/std that ``transform`` applies and clipped to
    [0, 1], the range ``imshow`` expects for float images.
    """
    pixels = image.detach().permute(1, 2, 0).cpu().numpy()
    pixels = pixels * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    plt.imshow(np.clip(pixels, 0.0, 1.0))
    plt.axis('off')
    plt.show()

In [None]:

# Despite the variable name, this is the generator's output for the second
# hazy image in the batch: channel-last, de-normalised with the ImageNet
# statistics, then displayed.
clean_image = (
    generator(hazy_imgs)[1].detach().permute(1, 2, 0).cpu().numpy()
    * np.array([0.229, 0.224, 0.225])
    + np.array([0.485, 0.456, 0.406])
)
plt.imshow(clean_image)

In [None]:
# Inspect the collected PSNR scores in ascending order.
sorted(psnr_scores)

In [None]:
# De-normalise one clean image from the last batch of the metrics loop.
sample_idx = 0  # index within the batch
clean_image = clean_imgs[sample_idx].permute(1, 2, 0).cpu().numpy()  # Convert to NumPy array (H, W, C)
clean_image = clean_image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])  # Denormalize



In [None]:
generated_im