In [None]:
# General Imports

import numpy as np
from PIL import Image
from astropy import nddata
from astropy.io import fits
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
from jax.random import PRNGKey # Need to use a seed to start jax's random number generation

In [None]:
##############################################
# Acquiring Image, PSF, Variance ------------#
##############################################

# The HDUList is left open on purpose: the arrays pulled from it below are
# used by later cells.
img = 'rsb_fits_images/test1.fits'
hdu = fits.open(img)

# HDU 1 holds the science image, HDU 3 the variance plane.
img_data = hdu[1].data
uncertainty_data = hdu[3].data

# PSFEx model tables: HDU 9 supplies `_pixstep`; HDU 10 supplies
# `_size`, `_comp`, `coeff` and `basis`.
psfex_hdu_info_index, psfex_hdu_data_index = 9, 10
psfex_info, psfex_data = hdu[psfex_hdu_info_index], hdu[psfex_hdu_data_index]

pixstep = psfex_info.data._pixstep[0]   # image pixels per PSF pixel
size = psfex_data.data["_size"]         # PSF dimensions per row: (nx, ny, n_basis_vectors)
comp = psfex_data.data["_comp"]         # flattened PSF basis components
coeff = psfex_data.data["coeff"]        # per-basis-vector coefficients (not used in this cell)

# Unflatten the first component row into (n_basis, ny, nx), weight every
# basis image by the first row of "basis", then collapse to a single 2-D PSF.
psf_basis_image = comp[0].reshape(*size[0][::-1])
basis_weights = psfex_data.data["basis"][0][:, None, None]
psf_image = np.sum(psf_basis_image * basis_weights, axis=0)

# Scale so the PSF sums to 1 / pixstep**2 (each PSF pixel spans pixstep**2
# image pixels).
psf_image /= psf_image.sum() * pixstep**2

# Optional diagnostics:
# plt.imshow(psf_image, cmap='gray', interpolation='none', vmin=-0.0001, vmax=0.0001); plt.xlabel('X-axis'); plt.ylabel('Y-axis'); plt.show()
# plt.imshow(img_data, vmin=0, vmax=0.3, origin="lower", cmap="gray"); plt.title('Original Image Array'); plt.show()

In [None]:
def gen_mask(image_shape):
    """Build an all-zero mask (no pixel masked) as a JAX array.

    Parameters
    ----------
    image_shape : tuple of int
        Shape of the mask to create.

    Returns
    -------
    jax.Array
        Zeros of shape ``image_shape``, converted from a NumPy float array.
    """
    from jax import numpy as jnp
    host_zeros = np.zeros(shape=image_shape, dtype=float)
    return jnp.asarray(host_zeros)

def gen_psf(image_shape, sigma=2):
    """Build a synthetic, unit-sum Gaussian PSF.

    A unit impulse is placed at ``(image_shape[0] // 2, image_shape[1] // 2)``
    and smoothed with ``scipy.ndimage.gaussian_filter``.

    Parameters
    ----------
    image_shape : tuple of int
        Shape ``(rows, cols)`` of the PSF array to generate.
    sigma : float or sequence of float, optional
        Gaussian standard deviation in pixels, forwarded to
        ``gaussian_filter``. Defaults to 2 (the previously hard-coded value).

    Returns
    -------
    ndarray
        PSF of shape ``image_shape`` whose pixels sum to 1.
    """
    psf = np.zeros(image_shape)
    center = (image_shape[0] // 2, image_shape[1] // 2)
    psf[center] = 1
    psf = gaussian_filter(psf, sigma=sigma)
    psf /= psf.sum()
    return psf

def resize_psf(psf_image, new_shape):
    """Resample a PSF to ``new_shape`` with area interpolation and renormalise it.

    Parameters
    ----------
    psf_image : 2-D ndarray
        PSF image to resample.
    new_shape : (int, int)
        Target shape in NumPy order ``(n_rows, n_cols)``. This is the same
        convention as ``ndarray.shape`` / ``Cutout2D.shape``, which is what
        callers pass.

    Returns
    -------
    ndarray
        Resampled PSF of shape ``new_shape`` whose pixels sum to 1.
    """
    import cv2
    n_rows, n_cols = new_shape
    # cv2.resize expects dsize as (width, height) == (n_cols, n_rows).
    # Passing a NumPy shape straight through transposes non-square outputs,
    # so the PSF no longer matches its cutout.
    resized_psf = cv2.resize(psf_image, (int(n_cols), int(n_rows)),
                             interpolation=cv2.INTER_AREA)
    resized_psf /= np.sum(resized_psf)
    return resized_psf

In [None]:
##############################################
# Identifying Sources & Creating Cutouts ----#
##############################################

from photutils.background import Background2D, MedianBackground
from astropy.convolution import convolve
from photutils.segmentation import make_2dgaussian_kernel
from photutils.segmentation import detect_sources
from matplotlib.colors import LogNorm

# Detection threshold: 1.9x the background RMS, estimated on a 50x50 mesh
# with a 3x3 median filter (how bright a source must be).
bkg_estimator = MedianBackground()
bkg = Background2D(img_data, (50, 50), filter_size=(3, 3), bkg_estimator=bkg_estimator)
threshold = 1.9 * bkg.background_rms

# Smooth with a 5x5 Gaussian kernel (FWHM = 3.0 px) before detection.
kernel = make_2dgaussian_kernel(3.0, size=5)
convolved_data = convolve(img_data, kernel)

# Detecting Sources --> Segmentation Map (regions of >= 10 connected pixels)
segment_map = detect_sources(convolved_data, threshold, npixels=10)
if segment_map is None:
    # detect_sources returns None instead of an empty map when nothing is found.
    raise RuntimeError("detect_sources found no sources above the threshold")
# display segment map
# plt.imshow(segment_map, cmap="gray", origin="lower", vmin=0, vmax=500)
# Drop sources within 10 px of the border and relabel consecutively.
segment_map.remove_border_labels(10, partial_overlap=False, relabel=True)

print(len(segment_map.labels))

bbox = segment_map.bbox
labels = segment_map.labels
assert len(labels) == len(bbox), "expected one bounding box per label"

# Cutout selection / sizing (pixels).
MIN_BOX_SIDE, MAX_BOX_SIDE = 10, 40  # exclusive bounds on both bbox sides
min_length = 12                      # floor for the cutout side before padding (22 also tried)
CUTOUT_PAD = 1.25                    # cutout side = padded largest bbox side

# Creating Same-Dimension Cutouts of Image, Mask, Variance and PSF.
# (Previously iterated over an undefined `shortened_bbox`, raising NameError.)
cutouts = []
for box, label in zip(bbox, labels):
    y_center, x_center = box.center       # BoundingBox.center is (y, x)
    box_ny, box_nx = box.shape            # BoundingBox.shape is (ny, nx)
    if not (MIN_BOX_SIDE < box_nx < MAX_BOX_SIDE and MIN_BOX_SIDE < box_ny < MAX_BOX_SIDE):
        continue
    cutout_side = int(max(box_nx, box_ny, min_length) * CUTOUT_PAD)
    # Cutout2D takes the position as (x, y).
    my_cutout = nddata.Cutout2D(img_data, (x_center, y_center), cutout_side)
    cutout_mask = gen_mask(my_cutout.shape)
    actual_psf = resize_psf(psf_image, my_cutout.shape)
    cutout_var = nddata.Cutout2D(uncertainty_data, (x_center, y_center), cutout_side)
    cutouts.append([my_cutout, cutout_mask, cutout_var, actual_psf, label])

In [7]:
# Number of sources that passed the size cut and were packaged as cutouts.
print("%d" % len(cutouts))

1905


In [None]:
##############################################
# Fitting a Sersic Profile ------------------#
##############################################

from scipy.stats import chi2
from pysersic import FitSingle
from pysersic.loss import student_t_loss
from pysersic import results
from pysersic.priors import autoprior
from pysersic.priors import SourceProperties
from pysersic import check_input_data
from pysersic.loss import gaussian_loss
from pysersic.results import plot_residual

# future --> test different loss functions --> student-t-loss, gaussian

# Free parameters of one Sersic profile with sky_type='none'
# (xc, yc, flux, r_eff, n, ellip, theta), used for the chi-square dof.
# NOTE(review): confirm against pysersic's parameter list if the prior changes.
N_SERSIC_PARAMS = 7
# Draw residual plots only for the first few fits. One figure per cutout
# (~2000) exhausts memory and bloats the notebook.
MAX_RESIDUAL_PLOTS = 10
MAP_SEED = 1000


def _fit_quality(image, model, rms, mask, n_params):
    """Chi-square goodness of fit of `model` to `image` given per-pixel noise.

    chi^2 = sum(((image - model) / rms) ** 2) over unmasked pixels
    (mask == 0) whose rms is finite and positive. The degrees of freedom are
    that pixel count minus `n_params`.

    Returns (chi_square, dof, p_value); p_value is NaN when dof <= 0.
    """
    image = np.asarray(image, dtype=float)
    model = np.asarray(model, dtype=float)
    rms = np.asarray(rms, dtype=float)
    good = (np.asarray(mask) == 0) & np.isfinite(rms) & (rms > 0)
    chi_square = float(np.sum(((image[good] - model[good]) / rms[good]) ** 2))
    dof = int(good.sum()) - n_params
    p_value = float(chi2.sf(chi_square, dof)) if dof > 0 else float("nan")
    return chi_square, dof, p_value


# Per-pixel (segmap label, Sersic n, fit p-value), indexed [row (y), col (x)]
# like img_data / segment_map.data.
labelled_seg = np.zeros((segment_map.shape[0], segment_map.shape[1], 3))
n_plotted = 0

for i, (im, mask, sig, psf, label) in enumerate(cutouts):  # image, mask, variance, psf, label
    if im.shape != psf.shape:
        print(f"cutout {i} (label {label}): image shape {im.shape} != PSF shape {psf.shape}; skipping")
        continue

    # `sig` is the *variance* cutout (HDU 3). pysersic's `rms` argument and
    # the chi-square below both need the per-pixel standard deviation.
    rms = np.sqrt(sig.data)

    # Verify Components are Usable
    check_input_data(im.data, rms, psf, mask)

    # Prior Estimation of Parameters
    props = SourceProperties(im.data, mask=mask)
    prior = props.generate_prior('sersic', sky_type='none')

    # Fit (MAP estimate; map_params holds the Sersic values plus 'model')
    fitter = FitSingle(data=im.data, rms=rms, psf=psf, prior=prior, mask=mask, loss_func=gaussian_loss)
    map_params = fitter.find_MAP(rkey=PRNGKey(MAP_SEED))

    image = im.data
    model = np.asarray(map_params['model'])
    assert image.shape == model.shape

    if n_plotted < MAX_RESIDUAL_PLOTS:
        fig, ax = plot_residual(image, model, mask=mask, vmin=-1, vmax=1)
        fig.suptitle(f"Analysis of Fit (label {label})")
        plt.show()
        n_plotted += 1

    # Chi-square: is the image-model difference consistent with the noise?
    chi_square, df, p_value = _fit_quality(image, model, rms, mask, N_SERSIC_PARAMS)

    # Labelling the Segmap: write (label, n, p-value) only on this source's own
    # segmentation footprint inside the cutout. Cutout2D.slices_original gives
    # (row, col) slices with the correct inclusive extent; painting the whole
    # square would overwrite neighbouring sources' pixels.
    n = map_params['n']
    region = im.slices_original
    footprint = segment_map.data[region] == label
    labelled_seg[region][footprint] = (label, n, p_value)

    # TODO: also evaluate an L1-norm (mean absolute residual) statistic as an
    # alternative quality cut, and decide which threshold marks a "good" fit
    # for the training set.

In [None]:
# Summarise the labelled segmentation map with one line per distinct
# (label, Sersic n, p-value) triple. Printing every pixel value with no
# separator via a Python triple loop was slow and unreadable.
annotated = np.any(labelled_seg != 0, axis=-1)
annotated_values = labelled_seg[annotated]
if annotated_values.shape[0] == 0:
    print("labelled_seg has no annotated pixels")
else:
    for seg_label, sersic_n, p_val in np.unique(annotated_values, axis=0):
        print(f"label={seg_label:.0f}  n={sersic_n:.3f}  p={p_val:.3g}")

In [None]:
# for i in bad_list:
#     fig, axes = plt.subplots(1,3)
#     axes[0].imshow(i[0],cmap="gray", vmin=0, vmax=0.5)
#     axes[1].imshow(i[1],cmap="gray", vmin=0, vmax=0.5)
#     axes[2].imshow(i[0]-i[1],cmap="gray",vmin=0, vmax=0.1)

#     fig.suptitle(f"Image (LEFT)       Model (MIDDLE)         Difference (RIGHT)     p-val:{i[2]}    l1_norm:{i[3]}")
# for i in good_list:
#     fig, axes = plt.subplots(1,3)
#     axes[0].imshow(i[0],cmap="gray", vmin=0, vmax=0.5)
#     axes[1].imshow(i[1],cmap="gray", vmin=0, vmax=0.5)
#     axes[2].imshow(i[0]-i[1],cmap="gray",vmin=0, vmax=0.1)

#     fig.suptitle(f"Image (LEFT)       Model (MIDDLE)         Difference (RIGHT)     p-val:{i[2]}    l1_norm:{i[3]}")

In [90]:
"""
Sources:
1) https://medium.com/swlh/different-types-of-distances-used-in-machine-learning-ec7087616442#:~:text=L1%20Norm%3A,the%20components%20of%20the%20vectors.

"""

'\nSources:\n1) https://medium.com/swlh/different-types-of-distances-used-in-machine-learning-ec7087616442#:~:text=L1%20Norm%3A,the%20components%20of%20the%20vectors.\n\n'