### Generate Deconvolution Image

In [1]:
import torch
import numpy as np
from os import environ
from pathlib import Path

from hydra import initialize, compose
from hydra.utils import instantiate

In [19]:
# The composed config presumably interpolates BLISS_HOME; point it at the
# directory two levels above the current working directory before composing.
environ["BLISS_HOME"] = str(Path().resolve().parents[1])

with initialize(version_base=None, config_path="."):
    cfg = compose(config_name="config")

# Build the DC2 survey object from its config node, then let it prepare its data.
dc2 = instantiate(cfg.surveys.dc2)
dc2.prepare_data()

In [3]:
def generate_gaussian_psf(size, center, sigma_x, sigma_y, sigma_xy):
    """Render a 2-D Gaussian PSF on a pixel grid and normalize it to unit sum.

    Args:
        size: ``(width, height)`` of the grid in pixels. Because ``np.meshgrid``
            uses its default ``"xy"`` indexing, the returned array has shape
            ``(size[1], size[0])``.
        center: ``(x_c, y_c)`` pixel coordinates of the Gaussian's peak.
        sigma_x: Standard deviation along x, in pixels.
        sigma_y: Standard deviation along y, in pixels.
        sigma_xy: Dimensionless weight of the x/y cross term (the caller in this
            notebook passes a correlation coefficient).

    Returns:
        The PSF divided by its own sum. With plain-float sigmas this is a NumPy
        array; with tensor sigmas the type follows the NumPy/torch arithmetic.
    """
    x, y = np.meshgrid(np.arange(size[0]), np.arange(size[1]))
    x_c, y_c = center
    dx = x - x_c
    dy = y - y_c
    # NOTE(review): the textbook bivariate normal with correlation rho multiplies
    # the whole exponent by 1 / (1 - rho**2). That factor is omitted here, which
    # makes the profile wider than the textbook form when sigma_xy != 0.
    # Confirm whether this is intended.
    exponent = -(dx ** 2 / (2 * sigma_x ** 2) +
                 dy ** 2 / (2 * sigma_y ** 2) -
                 sigma_xy * dx * dy / (sigma_x * sigma_y))
    psf = np.exp(exponent)
    # Normalize with the object's own .sum(): torch.sum only accepts Tensors and
    # raises TypeError on the NumPy array produced for plain-float inputs.
    return psf / psf.sum()

def generate_psf_image(avg_asm_x, avg_asm_y, avg_asm_xy, size=(25, 25)):
    """Build a centered Gaussian PSF stamp from average ASM moment values.

    Args:
        avg_asm_x: Average x moment; its square root is used as ``sigma_x``.
        avg_asm_y: Average y moment; its square root is used as ``sigma_y``.
        avg_asm_xy: Average cross moment; divided by ``sigma_x * sigma_y`` to
            obtain a correlation coefficient.
        size: ``(width, height)`` of the PSF stamp in pixels. Defaults to the
            previously hard-coded 25x25 stamp.

    Returns:
        The unit-sum PSF produced by ``generate_gaussian_psf``.
    """
    # Place the peak on the pixel at index size // 2 along each axis.
    center = (size[0] // 2, size[1] // 2)

    # The moments are treated as variances: their square roots are the sigmas.
    sigma_x = np.sqrt(avg_asm_x)
    sigma_y = np.sqrt(avg_asm_y)

    # Dividing the cross moment by both sigmas yields the (dimensionless)
    # correlation coefficient, not the covariance itself.
    rho = avg_asm_xy / (sigma_x * sigma_y)

    return generate_gaussian_psf(size, center, sigma_x, sigma_y, rho)

def get_psf_band(psf_params):
    """Return a list with one PSF image for each of the first six bands.

    The first three entries of each row ``psf_params[band]`` are passed, in
    order, as the x, y and cross ASM values to ``generate_psf_image``.
    """
    n_bands = 6
    return [
        generate_psf_image(
            psf_params[band][0], psf_params[band][1], psf_params[band][2]
        )
        for band in range(n_bands)
    ]

from skimage.restoration import richardson_lucy
def get_deconvolved_images(images, backgrounds, psfs):
    """Deconvolve every band of a band-first image stack with its own PSF.

    Args:
        images: Array-like indexed by band first (``images[band]`` is one 2-D
            image). It is converted with ``np.asarray``, so CPU torch tensors work.
        backgrounds: Per-band backgrounds; ``backgrounds[band]`` is handed to
            ``deconvolve_image``.
        psfs: Per-band PSF kernels; ``psfs[band]`` is handed to
            ``deconvolve_image``.

    Returns:
        A ``torch.Tensor`` with the same shape as ``images``. Floating-point
        inputs keep their dtype. Other dtypes are promoted to float64 so the
        fractional deconvolution output is not truncated.
    """
    images_np = np.asarray(images)
    if np.issubdtype(images_np.dtype, np.floating):
        out_dtype = images_np.dtype
    else:
        # An integer output buffer would truncate the max-normalized result.
        out_dtype = np.float64
    deconv_images = np.zeros_like(images_np, dtype=out_dtype)
    # Cover every band actually present, rather than a hard-coded six, so no band
    # is silently left as zeros and inputs with fewer bands still work.
    for band in range(images_np.shape[0]):
        deconv_images[band] = deconvolve_image(
            images_np[band], backgrounds[band], psfs[band]
        )
    return torch.from_numpy(deconv_images)

def deconvolve_image(image, background, psf, pad=10):
    """Richardson-Lucy deconvolve one 2-D image, padding its borders first.

    Args:
        image: 2-D image (array-like; ``np.pad`` converts it to an ndarray).
        background: Background for this image. Its mean (``.mean().item()``)
            is used as the padding value instead of zeros.
        psf: PSF kernel passed to ``richardson_lucy``.
        pad: Pixels of constant padding added on every side and cropped off
            afterwards. ``0`` disables padding.

    Returns:
        The deconvolved image, cropped back to ``image``'s shape. The padded
        input is divided by its maximum before deconvolution and the result is
        not rescaled, so the output is relative to that maximum.
    """
    padded_image = np.pad(
        image, pad, mode="constant", constant_values=background.mean().item()
    )
    # Presumably normalized because richardson_lucy clips to [-1, 1] by default
    # (clip=True); verify against the installed scikit-image version.
    normalized = padded_image / np.max(padded_image)
    deconv = richardson_lucy(normalized, psf)
    # Crop with explicit stop indices: with pad == 0, ``deconv[0:-0]`` is
    # ``deconv[0:0]``, i.e. an empty array.
    stop_row = deconv.shape[0] - pad
    stop_col = deconv.shape[1] - pad
    return deconv[pad:stop_row, pad:stop_col]


In [23]:
# Build per-band PSF stamps for dataset item 0, deconvolve that item's images
# band by band, and save the resulting tensor to disk.
psf_image = get_psf_band(dc2.dc2_data[0]['psf_params'])
# The PSF stamps are computed with NumPy functions, so they are not guaranteed to
# be torch tensors. torch.stack accepts only tensors, while np.asarray + np.stack
# handles both and gives the same array for CPU tensors.
psf_stack = np.stack([np.asarray(psf) for psf in psf_image])
decolve = get_deconvolved_images(
    dc2[0]['images'],
    dc2[0]['background'],
    psf_stack,
)
# NOTE(review): the file name looks specific to one tile/position; confirm it
# still matches dc2[0] if the index above is changed.
torch.save(decolve, 'g-3828-0,0.pt')