# Task B: Image Super-Resolution

Gravitational lensing has been a cornerstone of many cosmology experiments and studies since Einstein discussed it in his 1936 calculations, and the first strongly lensed system was discovered in 1979. One area of particular interest is the study of dark matter through substructure in strong lensing images. In this challenge, we explore the potential of ML models to enhance the resolution of lensing images.

In this task, we will develop and train a super-resolution model to enhance the quality of low-resolution strong gravitational lensing images. The goal is to upscale noisy and blurry images to higher resolutions, improving their clarity and detail. Participants can explore different super-resolution techniques, including convolutional neural networks (CNNs), generative adversarial networks (GANs), and other deep learning approaches.

![HR and LR Image Pair](https://github.com/pranath-reddy/DeepLearnHackathon/blob/main/GravitationalLensingChallenge/hr_lr_pair.png?raw=true)

This is an example notebook for the Image Super-Resolution Challenge. In this notebook, we demonstrate a simple CNN model implemented using the PyTorch library to solve the task of super-resolution of strong lensing images.

### Dataset

The dataset consists of paired high-resolution (HR) and low-resolution (LR) images. The images have been normalized using min-max normalization.

Link to the Dataset: https://drive.google.com/file/d/1yJBvKD4saonRfSy4r0ceuD9qzgrdrHld/view?usp=sharing
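Min-max normalization rescales values linearly so that the minimum maps to 0 and the maximum maps to 1. The dataset description does not say whether this was done per image or over the whole set, so the per-image version below is an assumption, and `min_max_normalize` is a hypothetical helper rather than part of the dataset tooling.

```python
import numpy as np

def min_max_normalize(image: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Rescale an array linearly so its minimum maps to 0 and its maximum to 1."""
    lo, hi = image.min(), image.max()
    # eps guards against division by zero for a constant image
    return (image - lo) / max(hi - lo, eps)

raw = np.array([[2.0, 4.0], [6.0, 10.0]])
norm = min_max_normalize(raw)
print(norm)  # values 0, 0.25, 0.5 and 1.0
```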

### Evaluation Metrics

* MSE (Mean Squared Error), SSIM (Structural Similarity Index), and PSNR (Peak Signal-to-Noise Ratio)

Model performance will be evaluated on a hidden test dataset using the metrics above. More details about these metrics, and the code to calculate them, are provided below.

### Instructions for using the notebook

1. Use GPU acceleration: (Edit --> Notebook settings --> Hardware accelerator --> GPU)
2. Run the cells: (Runtime --> Run all)

In [None]:
!pip install gdown

In [None]:
import gdown

In [None]:
import os
# Check if the dataset folder is missing
if not os.path.exists('./dataset_superres'):
    # Download and extract the dataset
    !gdown "http://drive.google.com/uc?id=1yJBvKD4saonRfSy4r0ceuD9qzgrdrHld"
    !unzip -q dataset.zip
    !mv dataset dataset_superres

## Single Image Super Resolution

### 1. Data Visualization and Preprocessing

#### 1.1 Import all the necessary libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.autograd import Variable
from tqdm.notebook import tqdm
import torch.utils.data as data
import torch.nn.functional as F
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr
%matplotlib inline

#### 1.2 Preview the Data

In [None]:
# Define the input paths for high-resolution (HR) and low-resolution (LR) training images
# (sorted so that the HR and LR previews show matching pairs when the filenames correspond)
train_hr_path = './dataset_superres/train/HR'
train_hr_files = sorted(os.path.join(train_hr_path, f) for f in os.listdir(train_hr_path) if f.endswith(".npy"))
train_lr_path = './dataset_superres/train/LR'
train_lr_files = sorted(os.path.join(train_lr_path, f) for f in os.listdir(train_lr_path) if f.endswith(".npy"))

# Number of samples to display
n = 5

# Plot the high-resolution (HR) samples in a single row
print('High-Resolution (HR) samples: ')
plt.figure(figsize=(14, 3))
for i, image in enumerate(train_hr_files[:n], start=1):  # Loop through the first n HR images
    ax = plt.subplot(1, n, i)  # Create subplot for the current image
    plt.imshow(np.load(image).reshape(128, 128), cmap='gray')  # Load and display the image in grayscale
    ax.axis('off')  # Hide both axes
plt.show()  # Display the plot

# Plot the low-resolution (LR) samples in a single row
print('Low-Resolution (LR) samples: ')
plt.figure(figsize=(14, 3))
for i, image in enumerate(train_lr_files[:n], start=1):  # Loop through the first n LR images
    ax = plt.subplot(1, n, i)  # Create subplot for the current image
    plt.imshow(np.load(image).reshape(64, 64), cmap='gray')  # Load and display the image in grayscale
    ax.axis('off')  # Hide both axes
plt.show()  # Display the plot
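The preview reshapes HR images to 128×128 and LR images to 64×64, which implies a 2× upscaling factor. A minimal check like the one below can confirm that every pair agrees before training. `upscale_factor` is a hypothetical helper, and it assumes each `.npy` file stores a single-channel image whose last two axes are height and width; in the notebook it could be called as `upscale_factor(np.load(lr_file), np.load(hr_file))`.

```python
import numpy as np

def upscale_factor(lr: np.ndarray, hr: np.ndarray) -> int:
    """Return the integer upscaling factor between an LR/HR pair.

    Only the last two axes (height, width) are compared, so (1, 64, 64) and
    (64, 64) inputs both work. Raises ValueError if the spatial shapes are
    not related by a single integer factor.
    """
    (lh, lw), (hh, hw) = lr.shape[-2:], hr.shape[-2:]
    if hh % lh or hw % lw or hh // lh != hw // lw:
        raise ValueError(f"incompatible shapes {lr.shape} and {hr.shape}")
    return hh // lh

lr = np.zeros((1, 64, 64))
hr = np.zeros((1, 128, 128))
print(upscale_factor(lr, hr))  # 2
```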

#### 1.3 Import Training and Validation Data

In [None]:
# Set Batch Size
batch_size = 100

# Define a custom Dataset class for loading Super Resolution data
class SuperResolutionDataset(data.Dataset):
    def __init__(self, lr_path, hr_path):
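        # Initialize the dataset with lists of low-resolution and high-resolution image file paths.
        # os.listdir returns entries in arbitrary order, so both lists are sorted to keep the
        # i-th LR file paired with the i-th HR file (assumes corresponding files share names).
        self.lr_files = sorted(os.path.join(lr_path, f) for f in os.listdir(lr_path) if f.endswith(".npy"))
        self.hr_files = sorted(os.path.join(hr_path, f) for f in os.listdir(hr_path) if f.endswith(".npy"))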
        
    def __len__(self):
        # Return the total number of low-resolution images (The number of HR and LR images is the same)
        return len(self.lr_files)
    
    def __getitem__(self, idx):
        # Load the low-resolution and high-resolution images from the file paths
        lr_image = np.load(self.lr_files[idx])
        hr_image = np.load(self.hr_files[idx])
        # Convert numpy arrays to PyTorch tensors and return them
        return torch.from_numpy(lr_image).float(), torch.from_numpy(hr_image).float()

# Create the training data loader
train_data = SuperResolutionDataset('./dataset_superres/train/LR', './dataset_superres/train/HR')
train_data_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)

# Create the validation data loader (no shuffling, so evaluation order is deterministic)
val_data = SuperResolutionDataset('./dataset_superres/val/LR', './dataset_superres/val/HR')
val_data_loader = data.DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=4)
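Because `os.listdir` returns entries in arbitrary order, two independent listings of the `LR` and `HR` folders are not guaranteed to line up. A more defensive way to build the file lists is sketched below with a hypothetical helper, `paired_files`. It assumes each LR/HR pair uses the same filename in both folders, which the dataset description does not confirm; its output could replace the two file lists built in `SuperResolutionDataset.__init__`.

```python
import os
import tempfile

def paired_files(lr_dir: str, hr_dir: str, ext: str = ".npy"):
    """Pair LR and HR files that share the same filename.

    Raises ValueError if any file in one folder has no partner in the other.
    """
    lr_names = {f for f in os.listdir(lr_dir) if f.endswith(ext)}
    hr_names = {f for f in os.listdir(hr_dir) if f.endswith(ext)}
    if lr_names != hr_names:
        unpaired = sorted(lr_names ^ hr_names)  # names present in only one folder
        raise ValueError(f"unpaired files: {unpaired[:5]}")
    # Sorting gives a deterministic order that is identical for both folders
    return [(os.path.join(lr_dir, f), os.path.join(hr_dir, f)) for f in sorted(lr_names)]

# Usage on a throwaway directory tree with two dummy files per folder
with tempfile.TemporaryDirectory() as root:
    for sub in ("LR", "HR"):
        os.makedirs(os.path.join(root, sub))
        for name in ("b.npy", "a.npy"):
            open(os.path.join(root, sub, name), "w").close()
    pairs = paired_files(os.path.join(root, "LR"), os.path.join(root, "HR"))
    names = [(os.path.basename(l), os.path.basename(h)) for l, h in pairs]
print(names)  # [('a.npy', 'a.npy'), ('b.npy', 'b.npy')]
```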

### 2. Training

#### 2.1 Defining a Super-Resolution CNN Model

You may refer to this [article](https://medium.com/@RaghavPrabhu/understanding-of-convolutional-neural-network-cnn-deep-learning-99760835f148) to learn about Convolutional Neural Networks (CNN) and this [article](https://medium.com/coinmonks/review-srcnn-super-resolution-3cb3a4f67a7c) to learn more about how CNNs can be used for super-resolution.

In [None]:
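# Define an SRCNN-style Super-Resolution Convolutional Neural Network. Unlike the original SRCNN,
# which upsamples the input with bicubic interpolation before the convolutions, this variant
# applies the convolutions at LR resolution and upsamples the output at the end.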
class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN, self).__init__()
        # First convolutional layer: 1 input channel, 64 output channels, 9x9 kernel, 4 pixels padding
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=9, padding=4)
        # Second convolutional layer: 64 input channels, 32 output channels, 5x5 kernel, 2 pixels padding
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, padding=2)
        # Third convolutional layer: 32 input channels, 1 output channel, 5x5 kernel, 2 pixels padding
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, padding=2)
        # ReLU activation function
        self.relu = nn.ReLU()

    def forward(self, x):
        # Apply the first convolutional layer followed by ReLU activation
        x = self.relu(self.conv1(x))
        # Apply the second convolutional layer followed by ReLU activation
        x = self.relu(self.conv2(x))
        # Apply the third convolutional layer
        x = self.conv3(x)
        # Upsample the output to the HR resolution using bicubic interpolation
        x = F.interpolate(x, size=(128, 128), mode='bicubic', align_corners=False)
        return x

# Set the device to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the SRCNN model and move it to the appropriate device
model = SRCNN().to(device)
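Because each layer uses a padding of (kernel_size − 1)/2, the three convolutions preserve the 64×64 LR grid, and only the final `F.interpolate` call produces the 128×128 output. This can be checked with the standard convolution output-size formula, ⌊(n + 2p − k)/s⌋ + 1, and the per-layer parameter count, C_in·C_out·k² + C_out (with bias). The two helpers below are illustrative and not part of PyTorch; in the notebook, the total could be cross-checked with `sum(p.numel() for p in model.parameters())`.

```python
def conv2d_out_size(size: int, kernel: int, padding: int, stride: int = 1) -> int:
    """Spatial output size of a 2D convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv2d_params(in_ch: int, out_ch: int, kernel: int, bias: bool = True) -> int:
    """Number of learnable parameters in a square-kernel Conv2d layer."""
    return in_ch * out_ch * kernel * kernel + (out_ch if bias else 0)

# (in_channels, out_channels, kernel_size, padding) for conv1..conv3 of the model above
layers = [(1, 64, 9, 4), (64, 32, 5, 2), (32, 1, 5, 2)]

size = 64  # LR input is 64x64
for in_ch, out_ch, k, p in layers:
    size = conv2d_out_size(size, k, p)
print(size)  # 64: the padding keeps the LR resolution through all three layers

total = sum(conv2d_params(i, o, k) for i, o, k, _ in layers)
print(total)  # 57281
```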

#### 2.2 Training the Super-Resolution CNN Model

In [None]:
# Loss Function
criteria = nn.MSELoss()  # Mean Squared Error Loss

# Optimizer (Adam)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)  

# Train the model
n_epochs = 20  # Number of Training Epochs
loss_array = []  # Array to store the loss values
pbar = tqdm(range(1, n_epochs+1))  # Progress bar for tracking epochs
for epoch in pbar:
    model.train()  # Ensure training mode (the evaluation cells below switch the model to eval mode)
    train_loss = 0.0  # Initialize training loss for the epoch

    # Iterate over the training data loader
    for step, (lr, hr) in enumerate(train_data_loader):

        lr = lr.to(device)  # Move low-resolution images to the device
        hr = hr.to(device)  # Move high-resolution images to the device
        optimizer.zero_grad()  # Clear the gradients
        outputs = model(lr)  # Forward pass through the model
        loss = criteria(outputs, hr)  # Calculate the loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Update the model parameters

        train_loss += loss.item()  # Accumulate the training loss

    train_loss = train_loss / len(train_data_loader)  # Compute average training loss for the epoch
    loss_array.append(train_loss)  # Append the loss to the loss array
    # Display the Training Stats
    pbar.set_postfix({'Training Loss': train_loss})  # Update progress bar with training loss
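The loop above divides the summed batch losses by `len(train_data_loader)`, which is the mean of per-batch means. When the dataset size is not a multiple of `batch_size`, the smaller final batch is weighted the same as a full one. A sample-weighted alternative is sketched below (the helper name is hypothetical); inside the loop it amounts to accumulating `loss.item() * lr.size(0)` and dividing by `len(train_data_loader.dataset)`.

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Sample-weighted mean of per-batch mean losses."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

# Example: a full batch of 100 samples followed by a smaller final batch of 50
losses, sizes = [1.0, 3.0], [100, 50]
print(sum(losses) / len(losses))       # 2.0 (unweighted mean of batch means)
print(epoch_mean_loss(losses, sizes))  # 1.666... (mean over all 150 samples)
```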

### 3. Testing

#### 3.1 Testing the Super-Resolution CNN Model on Validation Data - Calculate Quantitative Metrics

- **MSE (Mean Squared Error):** The average of the squared pixel-wise differences between the super-resolved image and the ground-truth HR image. Lower values indicate better performance.

- **SSIM (Structural Similarity Index):** A measure of the similarity between two images based on local comparisons of luminance, contrast, and structure. Here it measures how closely the super-resolved images match the original high-resolution images; values are at most 1 (identical images), and higher values indicate better quality.

- **PSNR (Peak Signal-to-Noise Ratio):** The ratio between the maximum possible power of a signal and the power of the noise that corrupts it, expressed in decibels as PSNR = 10·log10(MAX² / MSE), where MAX is the data range of the image. Higher values indicate better image quality.

You may refer to this [article](https://medium.com/@datamonsters/a-quick-overview-of-methods-to-measure-the-similarity-between-images-f907166694ee) to learn more about these metrics.
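The metric definitions can be written out directly. The sketch below uses NumPy only and assumes images scaled to [0, 1] (so `data_range=1.0`); the constants K1 = 0.01 and K2 = 0.03 follow the usual SSIM definition. `global_ssim` is a deliberately simplified single-window version for illustration, not a replacement for `skimage.metrics.structural_similarity`, which averages SSIM over local windows and will therefore give different values.

```python
import numpy as np

def mse(x: np.ndarray, y: np.ndarray) -> float:
    """Mean squared error between two images."""
    return float(np.mean((x - y) ** 2))

def psnr_db(x: np.ndarray, y: np.ndarray, data_range: float = 1.0) -> float:
    """PSNR in decibels: 10 * log10(data_range^2 / MSE)."""
    return float(10 * np.log10(data_range ** 2 / mse(x, y)))

def global_ssim(x: np.ndarray, y: np.ndarray, data_range: float = 1.0) -> float:
    """SSIM computed over the whole image as a single window (simplified)."""
    c1, c2 = (0.01 * data_range) ** 2, (0.03 * data_range) ** 2
    mx, my = x.mean(), y.mean()
    vx, vy = x.var(), y.var()
    cov = np.mean((x - mx) * (y - my))
    return float(((2 * mx * my + c1) * (2 * cov + c2)) /
                 ((mx ** 2 + my ** 2 + c1) * (vx + vy + c2)))

hr = np.full((128, 128), 0.5)
sr = hr + 0.1               # uniform error of 0.1 on every pixel
print(mse(hr, sr))          # ~0.01
print(psnr_db(hr, sr))      # ~20.0 dB
print(global_ssim(hr, hr))  # 1.0 for identical images
```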

*Note: Metrics need to be calculated on a sample-by-sample basis, not on a batch basis. This is because metrics like SSIM and PSNR are used for assessing the quality of individual images and the scikit-image functions do not average over the first axis when we pass batches of images to them.*
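To see why the averaging order matters, consider two images whose errors differ by an order of magnitude. Averaging per-image PSNR values gives a different number than computing one PSNR over the pooled batch, because the logarithm is applied before rather than after averaging. The illustration below uses synthetic NumPy data and assumes `data_range=1.0`.

```python
import numpy as np

def psnr_db(x, y, data_range=1.0):
    """PSNR in decibels for arrays of any shape (pools all elements)."""
    return 10 * np.log10(data_range ** 2 / np.mean((x - y) ** 2))

hr = np.zeros((2, 64, 64))  # a "batch" of two ground-truth images
sr = hr.copy()
sr[0] += 0.1                # image 0: MSE = 0.01   -> 20 dB
sr[1] += 0.01               # image 1: MSE = 0.0001 -> 40 dB

per_sample = np.mean([psnr_db(s, h) for s, h in zip(sr, hr)])
pooled = psnr_db(sr, hr)    # treats the whole batch as one array
print(per_sample, pooled)   # ~30.0 vs ~22.97
```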

In [None]:
# Calculate Metrics

# Set the model to evaluation mode
model.eval()
preds, targets = [], []  # Model predictions and the matching ground-truth HR images
with torch.no_grad():  # Disable gradient calculation for validation
    for lr, hr in val_data_loader:
        recon = model(lr.to(device))  # Get model predictions
        preds.append(recon.cpu().numpy())  # Move predictions to CPU and store them
        # Store the ground truth from the same batch, so predictions and targets stay aligned
        # even if the loader shuffles
        targets.append(hr.numpy())
dataSR = np.concatenate(preds, axis=0)  # Shape: (N, 1, 128, 128)
val_hr = np.concatenate(targets, axis=0)

# Calculate metrics
print("Metrics:")
criteria = nn.MSELoss()  # Mean Squared Error Loss
criteria2 = nn.L1Loss()  # L1 Loss

losses = []  # List to store MSE losses
losses2 = []  # List to store L1 losses
Ssim = []  # List to store SSIM scores
Psnr = []  # List to store PSNR scores

for i in range(dataSR.shape[0]):
    sr_t, hr_t = torch.from_numpy(dataSR[i]), torch.from_numpy(val_hr[i])
    # Calculate MSE and L1 losses between predicted and ground truth images
    losses.append(criteria(sr_t, hr_t).item())
    losses2.append(criteria2(sr_t, hr_t).item())
    # Calculate SSIM and PSNR scores between predicted and ground truth images
    data_range = dataSR[i][0].max() - dataSR[i][0].min()
    Ssim.append(ssim(val_hr[i][0], dataSR[i][0], data_range=data_range))
    Psnr.append(psnr(val_hr[i][0], dataSR[i][0], data_range=data_range))

# Print average metrics
print(f"Average MSE super resolution samples: {np.mean(losses):.7f}")
print(f"Average L1 super resolution samples: {np.mean(losses2):.7f}")
print(f"Average SSIM super resolution samples: {np.mean(Ssim):.5f}")
print(f"Average PSNR super resolution samples: {np.mean(Psnr):.5f}")
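PSNR and SSIM both depend on the `data_range` argument. The cell above passes the value range of each predicted image, so two models with identical errors but different output ranges can report different PSNR values. For a fixed MSE, PSNR grows as 20·log10(data_range), so halving the range lowers the score by about 6 dB. The synthetic NumPy check below shows this; the data and the `psnr_db` helper are illustrative only.

```python
import numpy as np

def psnr_db(x, y, data_range):
    """PSNR in decibels: 10 * log10(data_range^2 / MSE)."""
    return 10 * np.log10(data_range ** 2 / np.mean((x - y) ** 2))

rng = np.random.default_rng(0)
hr = rng.random((128, 128))                                 # synthetic "ground truth" in [0, 1)
sr = np.clip(hr + rng.normal(0, 0.05, hr.shape), 0, 1)      # noisy synthetic "prediction"

a = psnr_db(sr, hr, data_range=1.0)
b = psnr_db(sr, hr, data_range=0.5)
print(a - b)  # 20 * log10(2), about 6.02 dB
```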

#### 3.2 Visualize Outputs for Qualitative Analysis

In [None]:
# Visualize Outputs
model.eval()  # Make sure the model is in evaluation mode
with torch.no_grad():  # Disable gradient calculation
    lr, hr = next(iter(val_data_loader))  # Take the first validation batch
    output = model(lr.to(device)).cpu().numpy()  # Get model predictions and move them to CPU
lr, hr = lr.numpy(), hr.numpy()  # Convert inputs and targets to numpy arrays

# Display the first 5 samples: LR input (top row), HR ground truth (middle row), model output (bottom row)
plt.figure(figsize=(12, 8))  # Set figure size
for i in range(5):
    plt.subplot(3, 5, i + 1)  # Low-resolution input
    plt.imshow(lr[i].reshape(64, 64), cmap='gray')
    plt.title('Low Res')
    plt.axis('off')
    plt.subplot(3, 5, i + 6)  # High-resolution ground truth
    plt.imshow(hr[i].reshape(128, 128), cmap='gray')
    plt.title('High Res')
    plt.axis('off')
    plt.subplot(3, 5, i + 11)  # Model output
    plt.imshow(output[i].reshape(128, 128), cmap='gray')
    plt.title('Output')
    plt.axis('off')
plt.show()  # Display the figure

## Submission Guidelines

* Fill out the pre- and post-hackathon surveys.
* You are required to submit a Google Colab Jupyter Notebook (.ipynb and PDF) that clearly shows your implementation along with the evaluation metrics (MSE, SSIM, and PSNR) on the validation data.
* You must also submit the final trained model, including the model architecture and the trained weights (for example, an HDF5 file, .pb file, or .pt file).
* You can use this example notebook as a template for your work.

> **_NOTE:_**  You are free to use any ML framework such as PyTorch, Keras, TensorFlow, etc.