In [None]:
# Imports and global notebook configuration.
import astropy.units as u
import numpy as np
import os

# PIXEDFIT_HOME is set before importing EXPANSE, presumably because piXedfit
# (used by EXPANSE) reads it at import time -- keep this line above that import.
# NOTE(review): hard-coded absolute path; only valid on the original machine.
os.environ["PIXEDFIT_HOME"] = "/nvme/scratch/work/tharvey/piXedfit/"
from EXPANSE import ResolvedGalaxy, MultipleResolvedGalaxy
from matplotlib import pyplot as plt
import glob
from scipy import signal
from scipy.interpolate import interp1d
from scipy.stats import binned_statistic
# Base figure resolution (later cells raise this to 300 dpi).

plt.rcParams["figure.dpi"] = 100

# Disable TeX rendering of text in matplotlib.

plt.rcParams["text.usetex"] = False

# %matplotlib inline

In [None]:
""" Initialize a galaxy object - loading from galfind if necessary """

galaxy = ResolvedGalaxy.init(1438, "JOF_psfmatched", "v11")

galaxy.plot_overview()
overwrite = True
""" Optional - make plots """

""" Do binning """
# galaxy.pixedfit_plot_binmap()

#''' Measure fluxes in bins '''
# tab = galaxy.measure_flux_in_bins()
# tab
#''' Do Bagpipes (if not done) '''

#''' Plot Bagpipes results '''
print(galaxy.det_data)

#galaxy.pixedfit_processing(gal_region_use="detection", overwrite=True)
#galaxy.pixedfit_binning(overwrite=True)
#galaxy.measure_flux_in_bins(overwrite=True)

In [None]:
# Call provide_bagpipes_phot for bin 3 (presumably returns the photometry that
# Bagpipes would fit for that bin -- confirm against EXPANSE).
galaxy.provide_bagpipes_phot(3)

#galaxy.photometry_table['star_stack']['pixedfit']

In [None]:
plt.imshow(galaxy.psf_matched_data['star_stack']['F435W'], origin='lower')  # 'star_stack' PSF-matched F435W cutout; origin='lower' puts row 0 at the bottom

In [None]:
print(galaxy.photometry_table['star_stack'])  # photometry table for the 'star_stack' PSF-matching set

In [None]:
# Smoke-test the EAZY wrapper on a flat synthetic SED: 19 identical 0.1 uJy
# fluxes with 0.01 uJy errors.
# NOTE(review): 19 presumably equals the number of bands in this field -- confirm.
N_TEST_BANDS = 19
test_fluxes = np.full(N_TEST_BANDS, 0.1) * u.uJy
test_flux_errs = np.full(N_TEST_BANDS, 0.01) * u.uJy
ez = galaxy.fit_eazy_photometry(fluxes=test_fluxes, flux_errs=test_flux_errs, run_name="test")

In [None]:
ez.show_fit(0)  # show the EAZY fit at index 0 of the object returned by the previous cell


In [None]:
galaxy.add_detection_data(overwrite = True)  # overwrite=True presumably replaces any stored detection data -- confirm

In [None]:
# Print the cutout size, then run EXPANSE's pixedfit_processing step using the
# detection galaxy region (gal_region_use="detection").
print(galaxy.cutout_size)

galaxy.pixedfit_processing(gal_region_use="detection",)

In [None]:
%matplotlib inline
# Print the shapes of the detection galaxy region and of the detection RMS error
# map so they can be compared.
print(np.shape(galaxy.gal_region['detection']))

print(np.shape(galaxy.det_data['rms_err']))

In [None]:
 galaxy.add_flux_aper_total(catalogue_path="/raid/scratch/work/austind/GALFIND_WORK/Catalogues/v11/ACS_WFC+NIRCam/JOF_psfmatched/JOF_psfmatched_MASTER_Sel-F277W+F356W+F444W_v11_total.fits",
                overwrite=True)

galaxy.measure_flux_in_bins(overwrite=True)

In [None]:
# Overview plot restricted to the listed ACS bands, with fluxes in AB magnitudes;
# save=True presumably also writes the figure to disk -- confirm.
galaxy.plot_overview(save=True, flux_unit = u.ABmag, bands_to_show = ['F435W', 'F606W', 'F775W', 'F814W', 'F850LP'], show=True);


In [None]:
# photutils.aperture is the canonical module for these classes; recent photutils
# releases no longer re-export them from the top-level package.
from photutils.aperture import CircularAperture, aperture_photometry

# 0.32 arcsec diameter aperture photometry at the centre of the F606W
# PSF-matched cutout.
PIXEL_SCALE_ARCSEC = 0.03
APERTURE_RADIUS_ARCSEC = 0.16

data = galaxy.psf_matched_data["star_stack"]["F606W"]

# photutils positions are (x, y) = (column, row), hence shape[1] first.
# NOTE(review): shape / 2 is half a pixel from the geometric centre ((n - 1) / 2)
# in photutils' pixel convention -- confirm against how the cutouts are centred.
positions = [(data.shape[1] / 2, data.shape[0] / 2)]
apertures = CircularAperture(positions, r=APERTURE_RADIUS_ARCSEC / PIXEL_SCALE_ARCSEC)

phot_table = aperture_photometry(data, apertures)

# Aperture sum, tagged as uJy (assumes the cutout is stored in uJy per pixel).
d = phot_table["aperture_sum"] * u.uJy



In [None]:
# Fit this galaxy with a pre-computed dense-basis atlas (the file name suggests
# 10000 SEDs with 3 SFH parameters). plot=True presumably also draws the fit.
DB_ATLAS_DIR = "/nvme/scratch/work/tharvey/EXPANSE/scripts/pregrids"
db_atlas_path = f"{DB_ATLAS_DIR}/db_atlas_JOF_10000_Nparam_3.dbatlas"

fit_results = galaxy.run_dense_basis(db_atlas_path, plot=True)

In [None]:
fit_results[0].plot_posteriors();  # parameter posteriors of the first dense-basis fit result

In [None]:
fit_results[0].plot_posterior_SFH(fit_results[0].z[0])  # z[0] is presumably the fit redshift -- confirm

In [None]:
# Plot the posterior spectrum of the first dense-basis fit at the filter
# wavelengths of galaxy.bands, using the priors stored with the atlas.
from EXPANSE.dense_basis import get_priors
priors = get_priors(db_atlas_path)

# Populates galaxy.filter_wavs (read on the next line).
galaxy.get_filter_wavs()

# One wavelength per band, in Angstrom, as a plain float array.
wavs = np.array([galaxy.filter_wavs[band].to(u.Angstrom).value for band in galaxy.bands])

fit_results[0].plot_posterior_spec(wavs, priors)

In [None]:
# Load every saved JOF_psfmatched galaxy and compare the masses from the
# "photoz_delayed" and "CNST_SFH_RESOLVED" results (presumably run names).
# NOTE: `galaxies` is rebound to a plain list of galaxies in a later cell.
galaxies = MultipleResolvedGalaxy(
    ResolvedGalaxy.init_all_field_from_h5("JOF_psfmatched")
)

#galaxies.run_function("plot_overview", save=True)
galaxies.mass_comparison_plot("photoz_delayed", "CNST_SFH_RESOLVED", label = True, markersize = 4, markeredgecolor = 'black', markeredgewidth = 0.5, elinewidth = 1);

In [None]:
galaxy.provide_bagpipes_phot("TOTAL_BIN")

In [None]:
from astropy.io import fits
import statsmodels.api as sm
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Investigate the F444W error model of the PSF-matched cutout. The cell compares
# the pipeline RMS error map with a pure Poisson estimate and with an empirical
# LOWESS fit of error vs. pixel value, then inspects the residual
# ("correlated") noise.
NOISE_BAND = "F444W"
# Calibration constants (values as previously used in this notebook).
# NOTE(review): presumably from the F444W image header -- confirm before reusing
# them for another band.
PIXEL_AREA_ARCSEC2 = 0.000899999999999999  # nominal pixel area in arcsec^2
UJY_ARCSEC2_PER_CPS = 9.225489294810032  # flux density (uJy/arcsec^2) producing 1 count/s
SNR_MASK_THRESHOLD = 8  # pixels above this S/N are used for the Poisson scaling ratio

err = galaxy.psf_matched_rms_err["star_stack"][NOISE_BAND]
im = galaxy.psf_matched_data["star_stack"][NOISE_BAND]
seg = galaxy.seg_imgs[NOISE_BAND]
header = fits.Header.fromstring(galaxy.phot_img_headers[NOISE_BAND])
exptime = header["XPOSURE"]

# Convert the cutout to counts: per pixel -> per arcsec^2 -> counts/s -> counts.
im2 = im / PIXEL_AREA_ARCSEC2
im2 /= UJY_ARCSEC2_PER_CPS
im2 *= exptime
# The Poisson error in counts is sqrt(|counts|). np.abs gives the same value as
# np.sqrt(im2**2) without the overflow/underflow risk of squaring (float32 data).
poission_err = np.sqrt(np.abs(im2))

# Convert the Poisson error back to the cutout's per-pixel flux units.
poission_err /= exptime
poission_err *= PIXEL_AREA_ARCSEC2
poission_err *= UJY_ARCSEC2_PER_CPS

flat_im = im.flatten()
flat_err = err.flatten()

# Mean error in 100 bins of pixel value, plus a linear interpolant over the bin
# upper edges (kept for interactive use; not used below).
b = binned_statistic(flat_im, flat_err, bins=100)
g = interp1d(b.bin_edges[1:], b.statistic, kind="linear", fill_value="extrapolate")

# LOWESS smoothing of error vs. pixel value: z[:, 0] holds the sorted pixel values
# and z[:, 1] the smoothed errors.
lowess = sm.nonparametric.lowess
z = lowess(flat_err, flat_im, frac=0.1)

plt.plot(z[:, 0], z[:, 1], label="lowess", color="red")
plt.scatter(flat_im, flat_err, s=1)
plt.xlabel(f"{NOISE_BAND} pixel value")
plt.ylabel(f"{NOISE_BAND} RMS error")
plt.legend()
plt.show()

snr_map = im / err
mask = snr_map > SNR_MASK_THRESHOLD

# Median ratio of measured to pure-Poisson error in high-S/N pixels. It is only
# computed for inspection; applying it is deliberately left disabled.
scaling = np.median(err[mask] / poission_err[mask])
# poission_err *= scaling

# Empirical error model: evaluate the LOWESS curve at every pixel value.
# lowess returns one row per input pixel, so repeated pixel values (e.g. exact
# zeros) produce duplicate x. A cubic interp1d rejects those with a ValueError,
# so keep the first row for each unique x.
lowess_x, first_rows = np.unique(z[:, 0], return_index=True)
f = interp1d(lowess_x, z[first_rows, 1], kind="cubic", fill_value="extrapolate")
ferr = f(im)

# From here on the "model" error is the empirical LOWESS model, not the Poisson one.
poission_err = ferr

fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(13, 8))
for panel in ax:
    panel.tick_params(
        axis="both",
        which="both",
        bottom=False,
        top=False,
        labelbottom=False,
        right=False,
        left=False,
        labelleft=False,
    )

one = ax[0].imshow(im, origin="lower")
ax[0].set_title("Data")
two = ax[1].imshow(err, origin="lower")
ax[1].set_title("Error")
three = ax[2].imshow(poission_err, origin="lower")
ax[2].set_title("Model Err")
four = ax[3].imshow(err / poission_err, origin="lower", vmax=1.1, vmin=0.9)
ax[3].set_title("Err / Model Err")

# One colorbar per panel, in a thin axis appended to its right.
for panel, mappable in zip(ax, (one, two, three, four)):
    cax = make_axes_locatable(panel).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(mappable=mappable, cax=cax)
plt.show()

# Measured error minus model error: the part the model does not capture, treated
# here as correlated noise.
correlated_noise = err - poission_err

plt.imshow(correlated_noise / err, origin="lower")
plt.colorbar()
plt.title("(Err - Model Err) / Err")

# Next: estimate a simulated correlated-noise map.


def generate_correlated_noise(shape, correlation_length, rng=None):
    """Generate a spatially correlated Gaussian noise field.

    The function draws N(0, 1) white noise of the requested shape and convolves it
    with a normalised 2D Gaussian kernel whose standard deviation is
    ``correlation_length`` pixels. The convolution uses periodic ("wrap")
    boundaries. The result is then rescaled to have the same standard deviation
    as the white noise it was built from.

    Parameters
    ----------
    shape : tuple of int
        Shape of the 2D output array.
    correlation_length : float
        Standard deviation of the Gaussian kernel, in pixels. Must be > 0.
    rng : numpy.random.Generator, optional
        Generator used to draw the white noise. If None (the default), the legacy
        global ``np.random`` state is used, as before.

    Returns
    -------
    numpy.ndarray
        Correlated noise array with the requested ``shape``.

    Raises
    ------
    ValueError
        If ``correlation_length`` is not positive.
    """
    if correlation_length <= 0:
        raise ValueError(f"correlation_length must be > 0, got {correlation_length!r}")

    if rng is None:
        white_noise = np.random.normal(0, 1, shape)
    else:
        white_noise = rng.normal(0, 1, shape)

    # Integer half-width covering +/- 3 sigma. An integer radius guarantees an odd,
    # pixel-centred kernel. A non-integer length (e.g. 1.5) would otherwise give an
    # even-sized grid such as arange(-4.5, 5.5) with no central pixel, and
    # mode="same" would then shift the field by half a pixel. Integer lengths
    # produce the same grid as before.
    half_width = int(np.ceil(3 * correlation_length))
    offsets = np.arange(-half_width, half_width + 1)
    x, y = np.meshgrid(offsets, offsets)
    kernel = np.exp(-(x**2 + y**2) / (2 * correlation_length**2))
    kernel /= kernel.sum()

    correlated_noise = signal.convolve2d(
        white_noise, kernel, mode="same", boundary="wrap"
    )

    # Smoothing reduces the variance; restore the white noise's standard deviation.
    correlated_noise *= white_noise.std() / correlated_noise.std()

    return correlated_noise


# Simulate correlated noise with a 1.5 pixel Gaussian correlation length and
# compare it, side by side, with the measured residual `correlated_noise`.
sim_correlated_noise = generate_correlated_noise(im.shape, 1.5)

# Scale the simulation to the standard deviation of the measured residual, not of
# the error model (poission_err). nanstd tolerates NaN pixels in the cutout.
sim_correlated_noise *= np.nanstd(correlated_noise) / sim_correlated_noise.std()

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

real_artist = ax[0].imshow(correlated_noise, origin="lower")
fig.colorbar(real_artist, ax=ax[0])
ax[0].set_title("Correlated Noise")

sim_artist = ax[1].imshow(sim_correlated_noise, origin="lower")
fig.colorbar(sim_artist, ax=ax[1])
ax[1].set_title("Simulated Correlated Noise")

# TODO: repeat this comparison for the full image (load the full mosaic and its error map).

In [None]:
import statsmodels.api as sm
from itertools import cycle
from matplotlib.gridspec import GridSpec

# Pool pixel values and RMS errors for one band across every saved galaxy, then
# fit a smooth error-vs-signal relation to the pooled pixels.
possible_galaxies = glob.glob("galaxies/JOF_psfmatched_*.h5")
# File names end in "_<id>.h5"; extract the integer id.
ids = [int(path.split("_")[-1].split(".")[0]) for path in possible_galaxies]
print(ids)
galaxies = ResolvedGalaxy.init(ids, "JOF_psfmatched", "v11")

colors = cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])

fig = plt.figure(figsize=(10, 10))
gs = GridSpec(2, 4, figure=fig)

# Scatter axis.
ax1 = fig.add_subplot(gs[0, :2])

band = "F090W"
# Which cutouts to pool: "PSF" (PSF-matched 'star_stack') or "ORIGINAL" (unmatched).
data_type = "ORIGINAL"

total_err = []
total_data = []

# Use a distinct loop variable. Iterating with `galaxy` would overwrite the single
# galaxy object that later cells rely on.
for gal in galaxies:
    if data_type == "PSF":
        err = gal.psf_matched_rms_err["star_stack"][band]
        im = gal.psf_matched_data["star_stack"][band]
    elif data_type == "ORIGINAL":
        im = gal.unmatched_data[band]
        err = gal.unmatched_rms_err[band]
    else:
        raise ValueError(f"Unknown data_type {data_type!r}; expected 'PSF' or 'ORIGINAL'")

    ax1.scatter(
        im.flatten(),
        err.flatten(),
        s=1,
        color=next(colors),
        label=f"{gal.galaxy_id}",
        alpha=0.5,
    )
    total_err.extend(err.flatten())
    total_data.extend(im.flatten())

total_data = np.array(total_data)
total_err = np.array(total_err)

# Drop non-finite pixels. NaNs would otherwise sort to the end of np.unique and
# turn the x_fine grid below into NaN.
finite = np.isfinite(total_data) & np.isfinite(total_err)

# Sorted unique pixel values; each keeps the error from its first occurrence.
x_unique, unique_indices = np.unique(total_data[finite], return_index=True)
y_unique = total_err[finite][unique_indices]

lowess_fit = sm.nonparametric.lowess(y_unique, x_unique, frac=0.1)
lowess_x = lowess_fit[:, 0]
lowess_y = lowess_fit[:, 1]

# Linear interpolation of the LOWESS curve. With bounds_error=False and the
# default fill_value, inputs outside [lowess_x.min(), lowess_x.max()] return NaN.
# This does NOT extrapolate; pass fill_value="extrapolate" for that.
f = interp1d(lowess_x, lowess_y, bounds_error=False)

# Evaluate on a finer grid across the full pooled data range.
x_fine = np.linspace(x_unique[0], x_unique[-1], 200)
y_fine = f(x_fine)

ax1.plot(x_fine, y_fine, color="black", label="LOWESS interpolation")
ax1.set_xlabel("Data")
ax1.set_ylabel("Error")
ax1.legend()
plt.show()

print(f(0.015))

In [None]:
data = galaxy.im_paths["F444W"]
err = galaxy.rms_err_paths["F444W"]

data = fits.getdata(data, ext=galaxy.im_exts["F444W"])
err = fits.getdata(err, ext=galaxy.rms_err_exts["F444W"])

# Data is in MJy/sr with 0.03 arcsec pixels
# Convert to uJy

data *= 1e12  # MJy to uJy
data *= 2.11590909090909e-14  # pixel area in sr

err *= 1e12
err *= 2.11590909090909e-14

plt.scatter(data.flatten(), err.flatten(), s=1)

plt.plot(z[:, 0], z[:, 1], label="lowess", color="red")

In [None]:
from IPython.display import HTML

# Synthetic test cube for make_animation: 100 noisy realisations of a fixed 64x64
# circular Gaussian (sigma = 1 on a [-3, 3] grid). Each realisation adds
# N(0, 0.1) noise and is clipped at zero. A seeded generator makes it reproducible.
rng = np.random.default_rng(42)
grid = np.linspace(-3, 3, 64)
xx, yy = np.meshgrid(grid, grid)
# Distinct name so the LOWESS result `z` from earlier cells is not shadowed.
gaussian_profile = np.exp(-0.5 * (xx**2 + yy**2))
test = np.repeat(gaussian_profile[np.newaxis, :, :], 100, axis=0)
test += rng.normal(0, 0.1, test.shape)
test = np.clip(test, 0, None)

html = galaxy.make_animation(test, save=False, html=True, n_draws=50)

HTML(html)

In [None]:
# Build the animated Bagpipes stellar-mass map (returned as HTML, as its use with
# HTML() below shows) and display it.
# NOTE(review): weight_mass_sfr / logmap semantics are defined in EXPANSE -- confirm.
html = galaxy.plot_bagpipes_map_gif(
    parameter="stellar_mass", weight_mass_sfr=True, logmap=True
)

from IPython.display import HTML

HTML(html)

In [None]:
galaxy.photometry_property_names  # displayed as the cell's last expression

In [None]:
galaxy.plot_cutouts()  # image cutouts for this galaxy

In [None]:
plt.rcParams["figure.dpi"] = 300
%matplotlib inline

galaxy.plot_photometry_bins()

plt.show()

In [None]:
# `tab` is only assigned by the commented-out `tab = galaxy.measure_flux_in_bins()`
# line near the top, so on a fresh kernel the name does not exist. Guard the lookup
# so Restart & Run All does not stop here with a NameError.
if "tab" in globals():
    print(tab)
else:
    print("`tab` is not defined: run `tab = galaxy.measure_flux_in_bins()` first.")

In [None]:
# Minimal Bagpipes fit: constant SFH, nebular emission at fixed logU, Calzetti dust.
# Tuples presumably give (min, max) prior limits for fitted parameters.
sfh = dict(
    age_max=(0.03, 1),  # Gyr
    age_min=(0, 0.5),  # Gyr
    metallicity=(1e-3, 2.5),  # solar
    massformed=(4, 12),  # log mstar/msun
)
nebular = {"logU": -2.0}
dust = {"type": "Calzetti", "Av": (0, 5.0)}

fit_instructions = dict(t_bc=0.01, constant=sfh, nebular=nebular, dust=dust)
meta = {"run_name": "initial_test_cnst_sfh"}

overall_dict = dict(meta=meta, fit_instructions=fit_instructions)

galaxy.run_bagpipes(overall_dict, overwrite=False)

In [None]:
%matplotlib inline
plt.rcParams["figure.dpi"] = 300
# Results of the "initial_test_cnst_sfh" run, plotted twice: with the default
# weight_mass_sfr and with weight_mass_sfr=False. reload_from_cat=False is passed
# both times (presumably to skip re-reading the catalogue -- confirm).
galaxy.plot_bagpipes_results("initial_test_cnst_sfh", reload_from_cat=False)
galaxy.plot_bagpipes_results(
    "initial_test_cnst_sfh", reload_from_cat=False, weight_mass_sfr=False
)
plt.show()

In [None]:
plt.rcParams["figure.dpi"] = 300
galaxy.plot_bagpipes_sed("initial_test_cnst_sfh", bins_to_show=[1, 16, 15]);

In [None]:
# Component comparison for the run; n_draws=10000 presumably sets the number of
# posterior samples drawn -- confirm.
galaxy.plot_bagpipes_component_comparison(
    run_name="initial_test_cnst_sfh", n_draws=10000
)

In [None]:
galaxy.plot_bagpipes_corner(run_name="initial_test_cnst_sfh");  # corner plot for the run

In [None]:
# SFHs for one resolved bin ("16") next to the aperture / aggregate photometry
# entries ("MAG_*", "RESOLVED", "TOTAL_BIN") for comparison.
galaxy.plot_bagpipes_sfh(
    run_name="initial_test_cnst_sfh",
    bins_to_show=[
        "16",
        "MAG_APER_0.32 arcsec",
        "RESOLVED",
        "MAG_BEST",
        "MAG_AUTO",
        "MAG_ISO",
        "TOTAL_BIN",
    ],
);

In [None]:
table = galaxy.sed_fitting_table["bagpipes"]["initial_test_cnst_sfh"]

# Keep rows whose '#ID' is at most two characters long, i.e. the short bin ids
# ("1".."99"). This drops named rows such as "TOTAL_BIN".
# NOTE(review): it would also drop numeric bins >= 100 if any exist.
mask = [len(row_id) <= 2 for row_id in table["#ID"]]
filtered_table = table[mask]

# Sum the linear median masses of the kept bins and convert back to log10.
bin_masses = 10 ** filtered_table["stellar_mass_50"]
combined_log_mass = np.log10(np.sum(bin_masses))
print(f"Combined stellar mass is log10 Mstar = {combined_log_mass:.2f}")

In [None]:
galaxy.init_galfind_phot()  # presumably populates galaxy.galfind_photometry_rest used below -- confirm

In [None]:
# Settings for galfind rest-frame UV property maps, followed by an SFR_UV map.
rest_UV_wav_lims = [1250.0, 3000.0] * u.Angstrom
ref_wav = 1_500.0 * u.AA
conv_author_year = "M99"
kappa_UV_conv_author_year = "MD14"
dust_author_year = "M99"
# Only used by the commented-out calls below.
load_in = False
# galaxy.galfind_phot_property_map('beta_phot', rest_UV_wav_lims = rest_UV_wav_lims, load_in = load_in);
# galaxy.galfind_phot_property_map('mUV_phot', rest_UV_wav_lims = rest_UV_wav_lims, ref_wav = ref_wav, load_in = load_in);
# NOTE(review): the triple-quoted block below is an unused string statement holding
# scratch code (it references `copy`, which is never imported); it is not executed.
"""
print(galaxy.galfind_photometry_rest[bin].flux_Jy)
print(galaxy.galfind_photometry_rest[bin].properties)

bin = 'TOTAL_BIN'
print(galaxy.galfind_photometry_rest[bin].flux_Jy)
print(galaxy.galfind_photometry_rest[bin].properties)
bin = 'MAG_APER_0.32 arcsec'


phot_obj = copy.deepcopy(galaxy.galfind_photometry_rest[bin])
func = phot_obj.calc_SFR_UV_phot
phot_obj._calc_property(func,rest_UV_wav_lims = rest_UV_wav_lims, frame = 'obs', iters = 150,
                        kappa_UV_conv_author_year = kappa_UV_conv_author_year, dust_author_year = dust_author_year,
                        ref_wav = ref_wav)
props = phot_obj.properties
print(props)
"""


# galaxy.galfind_phot_property_map('MUV_phot', rest_UV_wav_lims = rest_UV_wav_lims, ref_wav = ref_wav, load_in = load_in);
# galaxy.galfind_phot_property_map('SFR_UV_phot', rest_UV_wav_lims = rest_UV_wav_lims, frame = 'obs', iters = 150,
#                                kappa_UV_conv_author_year = kappa_UV_conv_author_year, dust_author_year = dust_author_year,
#                                ref_wav = ref_wav, density = True, logmap = True, load_in = load_in);


# Map of the UV SFR (frame="obs", 150 iterations) with density=False;
# the commented call above is the density/log variant.
galaxy.galfind_phot_property_map(
    "SFR_UV_phot",
    rest_UV_wav_lims=rest_UV_wav_lims,
    frame="obs",
    iters=150,
    kappa_UV_conv_author_year=kappa_UV_conv_author_year,
    dust_author_year=dust_author_year,
    ref_wav=ref_wav,
    density=False,
);

# galaxy.galfind_photometry_rest['1'].get_rest_UV_phot(rest_UV_wav_lims).flux_Jy

In [None]:
print(galaxy.M1500.unit)  # unit attached to the M1500 attribute

In [None]:
galaxy.galfind_photometry_rest["5"].get_rest_UV_phot(rest_UV_wav_lims).flux_Jy_errs

In [None]:
phot1 = galaxy.galfind_photometry_rest["1"]

phot2 = galaxy.galfind_photometry_rest["2"].property_PDFs

print(phot1.get_rest_UV_phot(rest_UV_wav_lims).flux_Jy)

print(phot2)

In [None]:
# List the emission lines available for this galaxy, then plot the EW figure
# using all bands (medium_bands_only=False).
print(galaxy.available_em_lines)

plt.rcParams["figure.dpi"] = 300
galaxy.plot_ew_figure(medium_bands_only=False);


# galaxy.galfind_phot_property_map('EW_rest_optical', line_names = ['[OII]-3727'], medium_bands_only = False, plot=True);

### Investigating galaxy 830 in the JOF_psfmatched field

In [None]:
galaxy = ResolvedGalaxy.init(830, "JOF_psfmatched", "v11")  # rebinds `galaxy`: the cells below now refer to galaxy 830


In [None]:
band = "F814W"

data = galaxy.psf_matched_data["star_stack"][band]

mask = galaxy.gal_region['pixedfit'].astype(bool)

signal_data = np.sum(data[mask])

print(f'Signal data in pixedfit mask for {band} is {signal_data}')

# 0.007593571674078703 - in table

In [None]:
# Kron ellipse parameters for this galaxy; the next cell reuses a, b, theta.
# NOTE(review): the unit of theta (degrees, radians or a Quantity) is set by
# EXPANSE's plot_kron_ellipse -- confirm before passing it to photutils.
a, b, theta = galaxy.plot_kron_ellipse(return_params=True, ax = None, center = None)

# Print the returned values (there is no `params` variable; printing it raised NameError).
print(f"Kron ellipse: a={a}, b={b}, theta={theta}")

In [None]:
# photutils.aperture is the canonical module for these classes; recent photutils
# releases no longer re-export them from the top-level package.
from photutils.aperture import CircularAperture, EllipticalAperture

# Compare the binned photometry with fluxes measured directly in the Kron ellipse
# and in a 0.32 arcsec diameter circle.
# NOTE(review): uses `data` (F814W PSF-matched cutout) and `a`, `b`, `theta` from
# the two previous cells.
PIXEL_SCALE_ARCSEC = 0.03
APERTURE_RADIUS_ARCSEC = 0.16

galaxy.plot_photometry_bins(
    label_individual=False,
    bins_to_show=["TOTAL_BIN", "MAG_APER_0.32 arcsec", "MAG_APER_TOTAL"],
    flux_unit=u.uJy,
)

# photutils positions are (x, y) = (column, row), so x comes from shape[1]. This
# matches the aperture centre used in the F606W aperture-photometry cell.
# NOTE(review): shape / 2 is half a pixel from the geometric centre ((n - 1) / 2)
# in photutils' convention -- confirm against how the cutouts are centred.
center = (np.shape(data)[1] / 2, np.shape(data)[0] / 2)
print(center)

# do_photometry returns (aperture_sums, aperture_sum_errors); keep the sums.
kron_aperture = EllipticalAperture(center, a, b, theta)
kron_flux = kron_aperture.do_photometry(data)[0] * u.uJy
print('phot kron', kron_flux.to(u.ABmag))

circ_aperture = CircularAperture(center, APERTURE_RADIUS_ARCSEC / PIXEL_SCALE_ARCSEC)
aper_flux = circ_aperture.do_photometry(data)[0] * u.uJy
print('phot aper', aper_flux.to(u.ABmag))

In [None]:
print(galaxy.im_zps)  # zero points stored in galaxy.im_zps