In [1]:
import numpy as np
from PIL import Image
from astropy import nddata
from astropy.io import fits
import matplotlib.pyplot as plt

In [2]:
# following these docs: https://docs.astropy.org/en/stable/api/astropy.nddata.Cutout2D.html
# https://photutils.readthedocs.io/en/stable/segmentation.html following these helpful docs

In [18]:
img = 'rsb_fits_images/test.fits'
hdul = fits.open(img)

# grabbing information necessary for sersic profile fit
img_data = hdul[1].data # image
mask_data = hdul[2].data # mask
uncertainty_data = hdul[3].data # variance

# code from this lsst community post:
# https://community.lsst.org/t/how-to-extract-psfex-psf-from-a-pvi-calexp-outside-of-science-pipelines/8057/6
psfex_hdu_info_index = 9
psfex_hdu_data_index = 10

psfex_info = hdul[psfex_hdu_info_index]
psfex_data = hdul[psfex_hdu_data_index]

pixstep = psfex_info.data._pixstep[0]  # Image pixel per PSF pixel
size = psfex_data.data["_size"]  # size of PSF  (nx, ny, n_basis_vectors)
comp = psfex_data.data["_comp"]  # PSF basis components
coeff = psfex_data.data["coeff"]  # Coefficients modifying each basis vector
psf_basis_image = np.roll(size, shift=-2, axis=1)
psf_image = psf_basis_image * psfex_data.data["basis"][0, :, np.newaxis, np.newaxis]
# psf_image
psf_image = psf_image.sum(0)[0:20]
psf_image /= psf_image.sum()

print(psf_image)
print(psf_image.shape)
# psf_image /= psf_image.sum() * pixstep**2 # normalized to 1
# print(sum(psf_image))

[[0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]
 [0.00178571 0.02410714 0.02410714]]
(20, 3)


In [4]:
from photutils.background import Background2D, MedianBackground
bkg_estimator = MedianBackground()
bkg = Background2D(img_data, (50, 50), filter_size=(3, 3),
                   bkg_estimator=bkg_estimator)
threshold = 1.9 * bkg.background_rms

# img_data -= bkg.background

In [5]:
from astropy.convolution import convolve
from photutils.segmentation import make_2dgaussian_kernel
kernel = make_2dgaussian_kernel(3.0, size=5)  # FWHM = 3.0
convolved_data = convolve(img_data, kernel)

In [6]:
# plt.hist(convolved_data.flatten(), bins=100, range=[-1,1]);

In [7]:
from photutils.segmentation import detect_sources
from matplotlib.colors import LogNorm

segment_map = detect_sources(convolved_data, threshold, npixels=10)
%matplotlib inline
print(segment_map.shape)
# plt.imshow(img_data, cmap='gray_r', norm=LogNorm(vmin=0.1, vmax=1),origin='lower')

(4200, 4200)


In [8]:
# plt.imshow(segment_map, cmap='gray_r', norm=LogNorm(vmin=0.1, vmax=1), origin='lower')

In [9]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.visualization import SqrtStretch
from astropy.visualization.mpl_normalize import ImageNormalize
norm = ImageNormalize(stretch=SqrtStretch())
# fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12.5))
# ax1.imshow(data, origin='lower', cmap='Greys_r', norm=norm)
# ax1.set_title('Background-subtracted Data')
# ax2.imshow(segment_map, origin='lower', cmap=segment_map.cmap,
#            interpolation='nearest')
# ax2.set_title('Segmentation Image')

In [10]:
# get bounding boxes of sources
segment_map.remove_border_labels(10, partial_overlap=False, relabel=False)
bbox = segment_map.bbox
shortened_bbox = bbox[500:3500:30]

In [11]:
# ixmin, ixmax, iymin, iymax
cutouts = []
for i in range(len(shortened_bbox)):
    y_center, x_center = shortened_bbox[i].center
    x_len,y_len = shortened_bbox[i].shape
    min_length = 12
    if (x_len> 10 and y_len > 10 and x_len < 40 and y_len < 40):
        length = max([x_len, y_len, min_length]) * 1.25
        my_cutout = nddata.Cutout2D(img_data, (x_center,y_center), int(length))
        cutout_mask = nddata.Cutout2D(mask_data, (x_center,y_center), int(length))
        cutout_var = nddata.Cutout2D(uncertainty_data, (x_center,y_center), int(length))
        package = [my_cutout, cutout_mask, cutout_var]
        cutouts.append(package)

In [21]:
# PYSERSIC BEGINNINGS
from pysersic import results
from pysersic.priors import autoprior
import jax.numpy as jnp

im,mask,sig = cutouts[0]
prior  = autoprior(image = im.data, profile_type = 'sersic', mask=mask.data, sky_type = 'none')

# PRINTING SHAPES
print(f"Psf dimensions: {psf_image.shape}")
print(f"Image dimensions: {im.shape}")
print(f"Var dimensions: {sig.shape}")
print(f"Mask dimensions: {mask.shape}")
print(im.shape > psf_image.shape)

# psf_jax = jnp.array(psf_index)

from pysersic import FitSingle
from pysersic.loss import student_t_loss

mask_vals = mask.data <= 32
# help(FitSingle)
fitter = FitSingle(data=im.data,rms=sig.data,mask=mask_vals,psf=psf_image,prior=prior,loss_func=student_t_loss)

Psf dimensions: (20, 3)
Image dimensions: (22, 22)
Var dimensions: (22, 22)
Mask dimensions: (22, 22)
True
[[False False False False False False False False False False False False
  False False False False False False False False False False]
 [False False False False False False False False False False False False
  False  True False False False False False False False False]
 [False False False False False False False False False False False False
   True  True False False False False False False False False]
 [False False False False False False False False False False  True  True
   True  True False False False False False False False False]
 [False False False False False False False False False  True  True  True
   True  True False False False False False False False False]
 [False False False False False False False False  True  True  True  True
   True  True False False False False False False False False]
 [False False False False False False False  True  True  True  True  Tr

TypeError: mul got incompatible shapes for broadcasting: (20, 3), (3, 20).

In [None]:
# for j in cutouts:
#     fig, (ax1) = plt.subplots(1, 1, figsize=(4, 4))
#     ax1.imshow(j.data, origin='lower', cmap='Greys_r')
#     ax1.set_title('Attempt of cuttout image')

# display just first image
fig, (ax1) = plt.subplots(1, 1, figsize=(4, 4))
ax1.imshow(cutouts[0].data, origin='lower', cmap='Greys_r')
ax1.set_title('Attempt of cuttout image')

In [None]:
# from astropy.modeling.models import Sersic2D
# spherical_img = cutouts[0]

# img_data = spherical_img.data

# # Model
# print(f"Shape of data before model fitting: {spherical_img.data.shape}")

# sersic_initial = Sersic2D(amplitude=np.max(img_data), r_eff=1, n=4, x_0=spherical_img.center_cutout[0],y_0=spherical_img.center_cutout[1], ellip=0.1, theta=0)

# sersic_img = sersic_initial(img_data[0], img_data[1])
# print(f"Shape of data after model fitting?: {sersic_img.shape}")
# fig, ax = plt.subplots()
# ax.imshow(spherical_img.data)
# plt.show()