In [8]:
import numpy as np
from PIL import Image
from astropy import nddata
from astropy.io import fits
import matplotlib.pyplot as plt

In [None]:
# following these docs: https://docs.astropy.org/en/stable/api/astropy.nddata.Cutout2D.html
# https://photutils.readthedocs.io/en/stable/segmentation.html following these helpful docs

In [9]:
img = 'rsb_fits_images/test.fits'
hdul = fits.open(img)

# grabbing information necessary for sersic profile fit
img_data = hdul[1].data # image
mask_data = hdul[2].data # mask
uncertainty_data = hdul[3].data # variance
psf_data = hdul[8].data # warped psf


# data -= bgmean
# plt.hist(data.flatten(), bins=100, range=[-5,2200]);
# plt.yscale('log');
# data = fits.getdata(img)
# header = fits.getheader(img)
# bgvar = header['BGVAR']
# bgmean = header['BGMEAN']
# # plt.hist(data.flatten(), bins=100, range=[-5,2200]);
# # plt.yscale('log');

In [10]:
from photutils.background import Background2D, MedianBackground
bkg_estimator = MedianBackground()
img_data = data
bkg = Background2D(img_data, (50, 50), filter_size=(3, 3),
                   bkg_estimator=bkg_estimator)
img_data -= bkg.background

In [5]:
# plt.hist(img_data.flatten(), bins=100, range=[-1,1]);
# very different... normalized to 1

In [12]:
threshold = 1.9 * bkg.background_rms
from astropy.convolution import convolve
from photutils.segmentation import make_2dgaussian_kernel
kernel = make_2dgaussian_kernel(3.0, size=5)  # FWHM = 3.0
convolved_data = convolve(img_data, kernel)

In [7]:
# plt.hist(convolved_data.flatten(), bins=100, range=[-1,1]);

In [13]:
from photutils.segmentation import detect_sources
from matplotlib.colors import LogNorm

segment_map = detect_sources(convolved_data, threshold, npixels=10)
%matplotlib inline
print(segment_map.shape)
# plt.imshow(img_data, cmap='gray_r', norm=LogNorm(vmin=0.1, vmax=1),origin='lower')

(4200, 4200)


In [None]:
# plt.imshow(segment_map, cmap='gray_r', norm=LogNorm(vmin=0.1, vmax=1), origin='lower')
# seems to be an okay seg...

In [14]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.visualization import SqrtStretch
from astropy.visualization.mpl_normalize import ImageNormalize
norm = ImageNormalize(stretch=SqrtStretch())
# fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12.5))
# ax1.imshow(data, origin='lower', cmap='Greys_r', norm=norm)
# ax1.set_title('Background-subtracted Data')
# ax2.imshow(segment_map, origin='lower', cmap=segment_map.cmap,
#            interpolation='nearest')
# ax2.set_title('Segmentation Image')

In [15]:
# get bounding boxes of sources
segment_map.remove_border_labels(10, partial_overlap=False, relabel=False)
bbox = segment_map.bbox
shortened_bbox = bbox[500:3500:30]

In [16]:
# ixmin, ixmax, iymin, iymax
cutouts = []
for i in range(len(shortened_bbox)):
    y_center, x_center = shortened_bbox[i].center
    x_len,y_len = shortened_bbox[i].shape
    min_length = 12
    if (x_len> 10 and y_len > 10 and x_len < 40 and y_len < 40):
        length = max([x_len, y_len, min_length]) * 1.25
        my_cutout = nddata.Cutout2D(img_data, (x_center,y_center), int(length))
        cutout_mask = nddata.Cutout2D(mask_data, (x_center,y_center), int(length))
        cutout_var = nddata.Cutout2D(uncertainty_data, (x_center,y_center), int(length))
        package = [my_cutout, cutout_mask, cutout_var]
        cutouts.append(package)

In [17]:
# PYSERSIC BEGINNINGS
from pysersic import results
from pysersic.priors import autoprior

im,mask,sig = cutouts[0]
psf = psf_data

prior  = autoprior(image = im.data, profile_type = 'sersic', mask=mask.data, sky_type = 'none')


# PRINTING SHAPES
print(f"Psf dimensions: {psf.shape}")
print(f"Image dimensions: {im.shape}")
print(f"Var dimensions: {sig.shape}")
print(f"Mask dimensions: {mask.shape}")
print(prior)

from pysersic import FitSingle
from pysersic.loss import student_t_loss

fitter = FitSingle(data=im.data,rms=sig.data,mask=mask.data,psf=psf,prior=prior,loss_func=student_t_loss)

Psf dimensions: (74,)
Image dimensions: (21, 21)
Var dimensions: (21, 21)
Mask dimensions: (21, 21)
Prior for a sersic source:
--------------------------
flux ---  Normal w/ mu = 0.00, sigma = nan
xc ---  Normal w/ mu = nan, sigma = 1.00
yc ---  Normal w/ mu = nan, sigma = 1.00
r_eff ---  Truncated Normal w/ mu = nan, sigma = nan, between: 0.50 -> inf
ellip ---  Uniform between: 0.00 -> 0.90
theta ---  Custom prior of type: <class 'numpyro.distributions.directional.VonMises'>
n ---  Uniform between: 0.65 -> 8.00

Sky Type: none



TypeError: Value '[(  5,   6,   7) ( 14,  15,  16) ( 23,  24,  25) ( 32,  33,  34)
 ( 41,  42,  43) ( 50,  51,  52) ( 59,  60,  61) ( 68,  69,  70)
 ( 77,  78,  79) ( 86,  87,  88) ( 95,  96,  97) (104, 105, 106)
 (113, 114, 115) (122, 123, 124) (131, 132, 133) (140, 141, 142)
 (149, 150, 151) (158, 159, 160) (167, 168, 169) (176, 177, 178)
 (185, 186, 187) (194, 195, 196) (203, 204, 205) (212, 213, 214)
 (221, 222, 223) (230, 231, 232) (239, 240, 241) (248, 249, 250)
 (257, 258, 259) (266, 267, 268) (275, 276, 277) (284, 285, 286)
 (293, 294, 295) (302, 303, 304) (311, 312, 313) (320, 321, 322)
 (329, 330, 331) (338, 339, 340) (347, 348, 349) (356, 357, 358)
 (365, 366, 367) (374, 375, 376) (383, 384, 385) (392, 393, 394)
 (401, 402, 403) (410, 411, 412) (419, 420, 421) (428, 429, 430)
 (437, 438, 439) (446, 447, 448) (455, 456, 457) (464, 465, 466)
 (473, 474, 475) (482, 483, 484) (491, 492, 493) (500, 501, 502)
 (509, 510, 511) (518, 519, 520) (527, 528, 529) (536, 537, 538)
 (545, 546, 547) (554, 555, 556) (563, 564, 565) (572, 573, 574)
 (581, 582, 583) (590, 591, 592) (599, 600, 601) (608, 609, 610)
 (617, 618, 619) (626, 627, 628) (635, 636, 637) (644, 645, 646)
 (653, 654, 655) (662, 663, 664)]' with dtype (numpy.record, [('psfIndex', '<i4'), ('transformIndex', '<i4'), ('controlIndex', '<i4')]) is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

In [None]:
# for j in cutouts:
#     fig, (ax1) = plt.subplots(1, 1, figsize=(4, 4))
#     ax1.imshow(j.data, origin='lower', cmap='Greys_r')
#     ax1.set_title('Attempt of cuttout image')

# display just first image
fig, (ax1) = plt.subplots(1, 1, figsize=(4, 4))
ax1.imshow(cutouts[0].data, origin='lower', cmap='Greys_r')
ax1.set_title('Attempt of cuttout image')

In [None]:
# # pass to sersic
# # find distribution of sources

# # what does it mean to 'fit' an image?
# # find a model and adjust its parameters until it is able to quite accurately represent the desired image

# from astropy.modeling.models import Sersic2D
# spherical_img = cutouts[0]

# img_data = spherical_img.data

# # Printing Initial Values
# print("Initial Guesses")
# print(f"Amplitude:            {np.max(img_data)}")
# print(f"Half-light Radius :   1")
# print(f"Sersic Index:         4")
# print(f"Center:               {spherical_img.center_cutout}")


# # Model
# print(f"Shape of data before model fitting: {spherical_img.data.shape}")

# sersic_initial = Sersic2D(amplitude=np.max(img_data), r_eff=1, n=4, x_0=spherical_img.center_cutout[0],y_0=spherical_img.center_cutout[1], ellip=0.1, theta=0)

# sersic_img = sersic_initial(img_data[0], img_data[1])
# print(f"Shape of data after model fitting?: {sersic_img.shape}")
# fig, ax = plt.subplots()
# ax.imshow(spherical_img.data)
# plt.show()