In [None]:
# Takes the predicted sharp image and the ground-truth sharp image as input and measures how much they differ (the variance between them).

In [18]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from astropy.io import fits
from skimage.metrics import structural_similarity as ssim

In [26]:
# Input FITS files: the blurry image and the sharp reference image.
# NOTE(review): absolute local paths; adjust them (or read them from an
# environment variable) when running on another machine.
BLURRY_FITS_PATH = "/Users/sarayu/Downloads/m31bH20.fits"
SHARP_FITS_PATH = "/Users/sarayu/Downloads/m31bH (1).fits"  # was "//Users/..." (typo)

# Patch extraction: stride == size gives non-overlapping 128x128 tiles.
PATCH_SIZE = 128
PATCH_STRIDE = 128

# Training hyperparameters (presumably consumed by a training loop later in
# the notebook; not used in this section).
BATCH_SIZE = 1
NUM_EPOCHS = 15
LEARNING_RATE = 1e-3
KERNEL_SIZE = 15  # PSF kernel size (assumed width/height in pixels; TODO confirm)

In [20]:
# Values bw 0 and 1
def minmax_norm(img):
    img = img.astype(np.float32)
    return (img - img.min()) / (img.max() - img.min() + 1e-8)


# Patch size = 128x128. No overlap
def extract_patches(image, patch_size=128, stride=128):
    patches = []
    h, w = image.shape
    for i in range(0, h - patch_size + 1, stride):
        for j in range(0, w - patch_size + 1, stride):
            patch = image[i:i+patch_size, j:j+patch_size]
            patches.append(patch)
    return np.array(patches)

def compute_metrics(pred, target):
    # pred - predicted sharp im.
    # target -  true sharp im.
    pred = pred.astype(np.float64)
    target = target.astype(np.float64)

    diff = pred - target
    mse = np.mean(diff**2) # mean square error
    rmse = math.sqrt(mse)
    mae = np.mean(np.abs(diff))

    data_range = target.max() - target.min()
    if mse == 0:
        psnr = float("inf")
    else:
        psnr = 10 * math.log10((data_range ** 2) / mse)

    ssim_val = ssim(target, pred, data_range=data_range)

    return {
        "MSE": mse,
        "RMSE": rmse,
        "MAE": mae,
        "PSNR": psnr,
        "SSIM": ssim_val,
    }