# In [None]:
# %%
import yt
import trident
import numpy as np
import matplotlib.pyplot as plt
import astropy
from astropy.table import Table, join, vstack
import os
import time
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import MultipleLocator, FormatStrFormatter, MaxNLocator, FixedLocator
from matplotlib.gridspec import GridSpec
import pygad as pg
from trident import LSF
import h5py

# %%
# List every colormap name registered with matplotlib
import matplotlib.pyplot as plt
available_colormaps = plt.colormaps()
print(available_colormaps)

# %%
# Input spectra (HDF5). All files live under the same "Trident Codes" project folder.
_trident_codes_dir = ('/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/'
                      'MgII_TNG_Project/Primary Project/Trident Codes')

simba_data_path = f'{_trident_codes_dir}/SIMBA_IGrM/revised_final/spectra_Simba100_z0.1_n300d2-sample_localized_COS-G130M_OVI_combined.hdf5'

tng_data_path = f'{_trident_codes_dir}/TNG_50_z_0.1_groups/revised_final/spectra_TNG50-1_z0.1_n300d2-sample_localized_COS-G130M_OVI_combined.hdf5'

tng_file_0 = f'{_trident_codes_dir}/TNG_50_z_0.1_groups/revised_final/spectra_TNG50-1_z0.1_n300d2-sample_localized_COS-G130M_OVI_0-of-14.hdf5'

# %%
# List the top-level datasets in the SIMBA spectra file.
# Every cell below opens the file in a `with` block so the handle is always
# closed (previously it was opened once here and never closed).
with h5py.File(simba_data_path, 'r') as f:
    keys = list(f.keys())
print(keys)

# %%
# Print the shape of every dataset listed above
with h5py.File(simba_data_path, 'r') as f:
    for key in keys:
        print(f'{key} : {f[key].shape}')


# %%
# Load EW_OVI_1031 for all sightlines and report the largest value
with h5py.File(simba_data_path, 'r') as f:
    EW_OVI_1031 = np.array(f['EW_OVI_1031'])
print(EW_OVI_1031.shape)
print(np.max(EW_OVI_1031))

# %%
# Select one TNG sightline by the rank of its O VI 1031 equivalent width among
# the first group of sightlines (indices 0-89999), then fetch all of its data.
group_stop = 90000  # first group: dataset indices [0, 90000)
ew_rank = 50        # 0-based position in descending-EW order, i.e. the 51st largest

with h5py.File(tng_data_path, 'r') as f:
    ew_dataset = f['EW_OVI_1031']
    # Never read past the end of the dataset
    stop = min(group_stop, ew_dataset.shape[0])
    if ew_rank >= stop:
        raise ValueError(
            f"Need more than {ew_rank} sightlines to pick rank {ew_rank + 1}; only {stop} available."
        )

    # Single slice read. Position k in `eq_widths` is dataset index `eq_width_indices[k]`.
    eq_widths = np.asarray(ew_dataset[0:stop])
    eq_width_indices = np.arange(0, stop)

    # Sort in descending order and take the requested rank
    sorted_indices = np.argsort(eq_widths)[::-1]
    target_index_in_chunk = sorted_indices[ew_rank]
    target_index = eq_width_indices[target_index_in_chunk]

    # The old messages said "100th maximum" although index 50 (51st largest) was used
    print(f"EW_OVI_1031 value at descending rank {ew_rank + 1}: {eq_widths[target_index_in_chunk]}")
    print(f"Index of the rank-{ew_rank + 1} EW_OVI_1031 sightline in the dataset: {target_index}")

    # Fetch all information for the identified index
    result = {
        'EW_OVI_1031': f['EW_OVI_1031'][target_index],
        'EW_OVI_1037': f['EW_OVI_1037'][target_index],
        'flux': f['flux'][target_index, :],
        'ray_dir': f['ray_dir'][:],  # ray_dir is not indexed per sightline
        'ray_pos': f['ray_pos'][target_index, :],
        'ray_total_dl': f['ray_total_dl'][()],  # Scalar value
        'tau_OVI_1031': f['tau_OVI_1031'][target_index, :],
        'tau_OVI_1037': f['tau_OVI_1037'][target_index, :],
        'wave': f['wave'][:]
    }

# %%
# Redshift of the selected sightline, needed for the x-limits below. It used
# to be defined only in a later cell, so running top-to-bottom raised NameError.
z = 0.09940180263022191

# Report what was fetched (nothing is written to disk in this cell)
print("Data fetched successfully. Example:", result)

# Plot the stored flux against wavelength around the O VI doublet (observed frame)
plt.figure(figsize=(10, 6))
plt.plot(result['wave'], result['flux'], label='Flux')
plt.xlabel('Wavelength')
plt.ylabel('Flux')
plt.title('Flux vs. Wavelength')
plt.legend()
plt.xlim(1030*(1+z),1042*(1+z))
#plt.xlim(1170,1190)
plt.grid()

plt.show()

# Noiseless O VI 1031 flux reconstructed from its optical depth: F = exp(-tau)
flux_from_tau = np.exp(-result['tau_OVI_1031'])
plt.figure(figsize=(10, 6))
plt.plot(result['wave'], flux_from_tau, label='Flux (from Tau)')
plt.xlabel('Wavelength')
plt.ylabel('Flux (from Tau)')
plt.title('Flux (from Tau) vs. Wavelength')
plt.legend()
plt.grid()
#plt.xlim(1120,1160)
plt.xlim(1030*(1+z),1035*(1+z))
plt.show()

# %%
# Write the noiseless O VI 1031 spectrum (flux = exp(-tau)), its wavelength grid
# and optical depth to an HDF5 file for Trident to reload.
datasets_to_write = {
    'flux': flux_from_tau,
    'wavelength': result['wave'],
    'tau': result['tau_OVI_1031'],
}
with h5py.File('flux_from_tau.h5', 'w') as f:
    for dataset_name, values in datasets_to_write.items():
        f.create_dataset(dataset_name, data=values)
    

# %%
# Reload the saved spectrum as a Trident spectrum object
sg = trident.load_spectrum('flux_from_tau.h5')

# %%
# Keep references to the arrays as loaded, before any convolution
wavelength_old, flux_old = sg.lambda_field, sg.flux_field


# %%
# Convolve with the COS G130M line-spread function read from file, then keep
# references to the convolved arrays
sg.apply_lsf(filename='COS_G130M_1150.txt')
wavelength_new, flux_new = sg.lambda_field, sg.flux_field



# %%
# Overlay the spectrum before and after LSF convolution around O VI 1031-1035 Å
# in the observed frame (redshift below)
z = 0.09940180263022191

plt.figure(figsize=(10, 6))
for lam, flx, tag in ((wavelength_old, flux_old, 'Flux (before LSF)'),
                      (wavelength_new, flux_new, 'Flux (after LSF)')):
    plt.plot(lam, flx, label=tag)
plt.xlabel('Wavelength')
plt.ylabel('Flux')
plt.title('Flux vs. Wavelength')
plt.legend()
plt.grid()
plt.xlim(1030*(1+z), 1035*(1+z))
plt.show()


# %%
# Add Gaussian noise to the convolved spectrum (Trident's first argument is the target S/N)
sg.add_gaussian_noise(20)

# %%
# Velocity-space view of the LSF-convolved, noisy O VI 1031 spectrum
wavelength_obs = sg.lambda_field.value      # plain array (units dropped)
flux = sg.flux_field
flux_err = np.ones_like(flux) * 0.1         # constant placeholder error per pixel

# Line constants
rest_wavelength = 1031.926                  # O VI rest wavelength [Å]
z = 0.09940180263022191                     # sightline redshift (an earlier run used 0.137807)
observed_rest_wavelength = rest_wavelength * (1 + z)  # line centre in the observed frame
c = 299792.458                              # speed of light [km/s]

# Pixel velocities relative to the redshifted line centre; keep |v| <= 800 km/s
velocity = c * (wavelength_obs - observed_rest_wavelength) / observed_rest_wavelength
zoom_mask = np.abs(velocity) <= 800
velocity_zoom, flux_zoom, flux_err_zoom = (arr[zoom_mask] for arr in (velocity, flux, flux_err))

# Edges of the shaded +/- error band
flux_lower, flux_upper = flux_zoom - flux_err_zoom, flux_zoom + flux_err_zoom

fig, ax = plt.subplots(figsize=(10, 6))

# Stepped spectrum with its error band
ax.step(velocity_zoom, flux_zoom, where='mid', label='Flux', color='blue', linewidth=1.5)
ax.fill_between(velocity_zoom, flux_lower, flux_upper, color='green', alpha=0.3, label='Error')

# Reference lines: line centre (0 km/s) and unit continuum
ax.axvline(x=0, color='red', linestyle='--', label=f'OVI {rest_wavelength} Å (0 km/s)', linewidth=2)
ax.axhline(y=1, color='gray', linestyle='--', label='Continuum')

# Axes, ticks, grid and legend
ax.set_xlabel('Velocity (km/s)', fontsize=14)
ax.set_ylabel('Normalized Flux', fontsize=14)
ax.set_ylim(0.1, 1.35)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.grid(True, linestyle='--', alpha=0.7)
ax.legend(fontsize=12)

#plt.savefig('OVI_velocity_space_zoom_shaded_z_013748.png', dpi=300)
plt.show()

# %%
def shift_to_restframe(wave,z):
    """Return the observed-frame wavelength(s) *wave* shifted to the rest frame at redshift *z*."""
    return wave / (1 + z)

def wave_to_vel(wavelength,center_wave):
    """Return the velocity in km/s of *wavelength* relative to *center_wave*.

    Uses the non-relativistic relation v = c * (lambda / lambda_0 - 1). Both
    wavelengths must be in the same units; numpy arrays work element-wise.
    """
    speed_of_light_kms = 2.99792458e5
    ratio = wavelength / center_wave
    return speed_of_light_kms * (ratio - 1)

def vel_to_wave(vel,center_wave):
    """Return the wavelength at velocity *vel* (km/s) from *center_wave* (non-relativistic inverse of wave_to_vel)."""
    beta = vel / 2.99792458e5
    return (1. + beta) * center_wave

def match_ion(line):
    """Translate a Trident-style line label into the matching pygad line key.

    Parameters
    ----------
    line : str
        Label such as ``'O VI 1032'`` or ``'H I 1216'``.

    Returns
    -------
    str or None
        The pygad key (e.g. ``'OVI1031'``). For an unrecognised label an error
        message is printed and ``None`` is returned.
    """
    trident_to_pygad = {
        'H I 1216': 'H1215',
        'H I 1026': 'H1025',
        # test_for_saturation falls back from 'H I 1026' to 'H I 972', which
        # previously had no entry here; 'Ly c' is kept as an alias.
        'H I 972': 'H972',
        'Ly c': 'H972',
        'O VI 1032': 'OVI1031',
        'O VI 1038': 'OVI1037',
        'Si II 1193': 'SiII1193',
        'Si II 1190': 'SiII1190',
        'Si III 1206': 'SiIII1206',
        'N V 1239': 'NV1238',
        'C II 1036': 'CII1036',
    }
    ion = trident_to_pygad.get(line)
    if ion is None:
        print('Error: Check to make sure Trident ion matches Pygad')
    return ion

def test_for_saturation(strong_line, flux, z):
    if strong_line == 'H I 1216':
        if np.min(flux) <= 0.05:
            use_line = 'H I 1026'
            saturated = True
        else:
            use_line = strong_line
            saturated = False
    elif strong_line == 'H I 1026':
        if np.min(flux) <= 0.25:
            use_line = 'H I 972'
            saturated = True
        else:
            use_line = strong_line
            saturated = False
    elif strong_line == 'O VI 1032':
        if np.min(flux) <= 0.25:
            use_line = 'O VI 1038'
            saturated = True
        else:
            use_line = strong_line
            saturated = False
    elif strong_line == 'Si II 1193':
        if np.min(flux) <= 0.25:
            use_line = 'Si II 1190'
            saturated = True
        else:
            use_line = strong_line
            saturated = False
    else:
        use_line = strong_line
        saturated = False
    return saturated, use_line

def EW_to_N(pg_ion,ew_err):
    """Convert an equivalent width to log10 column density (linear curve of growth).

    Implements Draine eq. 9.15, N = 1.13e12 * W / (f * lambda^2) with W and
    lambda in cm.

    Parameters
    ----------
    pg_ion : str
        Key into ``pg.analysis.absorption_spectra.lines`` (e.g. ``'OVI1031'``),
        used for the oscillator strength and rest wavelength.
    ew_err : float
        Equivalent width in mÅ (the 1e-11 factor converts mÅ to cm). Despite
        the name, any equivalent width, not only an error, can be passed.

    Returns
    -------
    float
        log10 of the column density.
    """
    line_info = pg.analysis.absorption_spectra.lines[pg_ion]
    osc_strength = float(line_info['f'])
    # Rest wavelength string starts with the number in Å; convert to cm
    rest_lambda_cm = float(line_info['l'].split()[0]) * 1.0e-8
    ew_cm = ew_err * 1.0e-11
    column_density = 1.13e12 * ew_cm / osc_strength / rest_lambda_cm**2
    return np.log10(column_density)

def N_to_EW(pg_ion,logN):
    """Convert log10 column density to equivalent width (Draine eq. 9.15, linear regime).

    Parameters
    ----------
    pg_ion : str
        Key into ``pg.analysis.absorption_spectra.lines``.
    logN : float
        log10 of the column density.

    Returns
    -------
    float
        Equivalent width in cm. Note that EW_to_N takes mÅ, so
        ``N_to_EW(ion, EW_to_N(ion, w))`` equals ``w * 1e-11``.
    """
    line_info = pg.analysis.absorption_spectra.lines[pg_ion]
    osc_strength = float(line_info['f'])
    rest_lambda_cm = float(line_info['l'].split()[0]) * 1.0e-8
    return osc_strength * rest_lambda_cm**2 * 10**(logN) / 1.13e12




def fit_vp_pipeline(ray, ion, wave, vel, flux, error, z, logN_bounds=[11, 19], b_bounds=[5, 200], min_region_width=5, N_sigma=1.5):
    """
    Fit Voigt profiles to absorption regions within +/-1500 km/s and tabulate the results.

    Parameters:
    - ray: Sightline ID (stored in the 'Sightline' column)
    - ion: pygad line key (e.g. 'OVI1031'). It is passed to pygad's fitter and
      also selects the rest wavelength used for the velocity columns and the
      upper-limit conversion. Spaces are stripped for the 'Species' column.
    - wave: Wavelength array
    - vel: Velocity array aligned with `wave`; selects the +/-1500 km/s window
    - flux: Flux array
    - error: Error array
    - z: Redshift (currently unused)
    - logN_bounds: Logarithmic column density bounds for the fit
    - b_bounds: Doppler parameter bounds for the fit
    - min_region_width: Minimum width of the region to consider for fitting (in pixels)
    - N_sigma: Detection threshold in sigma for the region

    Returns:
    - sat_flag: True if the minimum of the full `flux` array is <= 0.2
    - t: Astropy Table with one row per fitted component, or a single
      upper-limit row when no absorption region is found

    Raises:
    - ValueError: if no region is found and fewer than two pixels lie within
      +/-50 km/s, so the pixel width for the upper limit cannot be measured
    """
    # This used to read a global `pg_ion` that is not defined in this function.
    # `ion` is the line actually being fitted, so it is the correct key.
    pg_ion = ion
    rest_lambda = float(pg.analysis.absorption_spectra.lines[pg_ion]['l'].split()[0])

    window = (vel >= -1500) & (vel <= 1500)
    wave_subset = wave[window]
    flux_subset = flux[window]
    error_subset = error[window]

    sat_flag = bool(np.min(flux) <= 0.2)

    # Creating output table
    t = Table(names=('Sightline', 'Species', 'EW(mA)', 'dEW(mA)', 'N', 'dN', 'b', 'db', 'v', 'dv', 'l', 'dl', 'UpLim', 'Sat', 'Chisq'),
              dtype=['i4', 'S10', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'bool', 'bool', 'f8'])

    # Detect absorption regions
    regions, _ = pg.analysis.vpfit.find_regions(
        wave_subset, flux_subset, error_subset, min_region_width=min_region_width, N_sigma=N_sigma, extend=True)
    print(f'Found {len(regions)} absorption regions for {ion}')

    if len(regions) > 0:
        # Fit detected regions
        fit = pg.analysis.vpfit.fit_profiles(
            ion, wave_subset, flux_subset, error_subset,
            chisq_lim=1, max_lines=6, mode="Voigt",
            logN_bounds=logN_bounds, b_bounds=b_bounds,
            min_region_width=min_region_width, N_sigma=N_sigma, extend=True
        )
        print(fit)

        # One row per fitted component. The previous version nested this loop
        # inside `for v in vels`, which repeated every component once per component.
        for j in range(len(fit['EW'])):
            v_j = wave_to_vel(fit['l'][j], rest_lambda)
            dv_j = (wave_to_vel(fit['l'][j] + fit['dl'][j], rest_lambda) - v_j) / 2.0
            t.add_row((
                ray,
                ion.replace(' ', ''),
                fit['EW'][j] * 1000,  # EW in mA for the j-th component
                np.nan,               # dEW is not provided by the fit
                fit['N'][j],
                fit['dN'][j],
                fit['b'][j],
                fit['db'][j],
                v_j,
                dv_j,
                fit['l'][j],
                fit['dl'][j],
                False,                # detection, not an upper limit
                sat_flag,
                fit['chisq'][j]       # chisq as reported for this component
            ))

    else:
        # No detected regions: estimate an upper limit from the +/-50 km/s core
        vel_subset = wave_to_vel(wave_subset, rest_lambda)
        core = (vel_subset >= -50.0) & (vel_subset <= 50.0)
        wave_pm100 = wave_subset[core]
        flux_pm100 = flux_subset[core]
        error_pm100 = error_subset[core]
        if len(wave_pm100) < 2:
            raise ValueError(
                f'Need at least two pixels within +/-50 km/s of {pg_ion} to '
                f'estimate an upper limit; found {len(wave_pm100)}.'
            )
        ew_pm100 = pg.analysis.vpfit.EquivalentWidth(flux_pm100, wave_pm100) * 1000  # mA
        print(f'EW: {ew_pm100}')
        print(f'error_pm100: {error_pm100}')
        pixel_width = abs(wave_pm100[1] - wave_pm100[0])
        print(f'abs(wave_pm100[1] - wave_pm100[0])= {pixel_width}')
        # Quadrature sum of the pixel errors times the pixel width, in mA
        dew_pm100 = ((np.sqrt(np.sum(error_pm100**2))) * pixel_width) * 1000  # mA
        print(f'dEW: {dew_pm100}')
        N_lim = EW_to_N(pg_ion, dew_pm100)

        # Add a single row for the non-detection case
        t.add_row((
            ray,
            ion.replace(' ', ''),
            ew_pm100,   # EW in mA
            dew_pm100,  # dEW in mA
            N_lim,
            np.nan,     # dN placeholder
            np.nan,     # b placeholder
            np.nan,     # db placeholder
            0,          # velocity placeholder
            50,         # dv placeholder
            np.nan,     # wavelength placeholder
            np.nan,     # dl placeholder
            True,       # Upper limit
            sat_flag,
            np.nan      # chisq placeholder
        ))

    return sat_flag, t

# %%
# Quick check: log10 column density implied by an 11.88 mÅ O VI 1031 equivalent width
EW_to_N(pg_ion='OVI1031', ew_err=11.88)

# %%
def fit_vp_pipeline_new(ray, ion, wave, vel, flux, error, z, logN_bounds=[11, 19], b_bounds=[5, 200], min_region_width=5, N_sigma=1.5,chisq_lim=0.25, pg_ion='OVI1031', vel_window=800):
    """
    Fit Voigt profiles to absorption regions within +/-`vel_window` km/s and tabulate the results.

    Parameters:
    - ray: Sightline ID (stored in the 'Sightline' column)
    - ion: Line identifier passed to `pg.analysis.vpfit.fit_profiles`; spaces are
      stripped for the 'Species' column
    - wave: Wavelength array
    - vel: Velocity array aligned with `wave`; selects the fitting window
    - flux: Flux array
    - error: Error array
    - z: Redshift (currently unused)
    - logN_bounds: Logarithmic column density bounds for the fit
    - b_bounds: Doppler parameter bounds for the fit
    - min_region_width: Minimum width of the region to consider for fitting (in pixels)
    - N_sigma: Detection threshold in sigma for the region
    - chisq_lim: Passed through to `fit_profiles` as `chisq_lim`
    - pg_ion: pygad line key whose rest wavelength is used for the velocity
      columns and the upper-limit conversion (default 'OVI1031', the value
      that was previously hard-coded)
    - vel_window: Half-width in km/s of the fitting window (default 800, as before)

    Returns:
    - sat_flag: True if the minimum of the full `flux` array is <= 0.2
    - t: Astropy Table with one row per fitted component, or a single
      upper-limit row when no absorption region is found
    - fit: the object returned by `fit_profiles`, or None when no region is found
    - regions: the regions returned by `find_regions`

    Raises:
    - ValueError: if no region is found and fewer than two pixels lie within
      +/-50 km/s, so the pixel width for the upper limit cannot be measured
    """
    rest_lambda = float(pg.analysis.absorption_spectra.lines[pg_ion]['l'].split()[0])

    window = (vel >= -vel_window) & (vel <= vel_window)
    wave_subset = wave[window]
    flux_subset = flux[window]
    error_subset = error[window]

    sat_flag = bool(np.min(flux) <= 0.2)

    # Creating output table
    t = Table(names=('Sightline', 'Species', 'EW(mA)', 'dEW(mA)', 'N', 'dN', 'b', 'db', 'v', 'dv', 'l', 'dl', 'UpLim', 'Sat', 'Chisq'),
              dtype=['i4', 'S10', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'bool', 'bool', 'f8'])

    # Detect absorption regions
    regions, _ = pg.analysis.vpfit.find_regions(
        wave_subset, flux_subset, error_subset, min_region_width=min_region_width, N_sigma=N_sigma, extend=True)
    print(f'Found {len(regions)} absorption regions for {ion}')

    if len(regions) > 0:
        # Fit detected regions
        fit = pg.analysis.vpfit.fit_profiles(
            ion, wave_subset, flux_subset, error_subset,
            chisq_lim=chisq_lim, max_lines=4, mode="Voigt",
            logN_bounds=logN_bounds, b_bounds=b_bounds,
            min_region_width=min_region_width, N_sigma=N_sigma, extend=True
        )
        print(fit)

        # One row per fitted component
        for j in range(len(fit['EW'])):
            v_j = wave_to_vel(fit['l'][j], rest_lambda)
            dv_j = (wave_to_vel(fit['l'][j] + fit['dl'][j], rest_lambda) - v_j) / 2.0
            t.add_row((
                ray,
                ion.replace(' ', ''),
                fit['EW'][j] * 1000,  # EW in mA for the j-th component
                np.nan,               # dEW is not provided by the fit
                fit['N'][j],
                fit['dN'][j],
                fit['b'][j],
                fit['db'][j],
                v_j,
                dv_j,
                fit['l'][j],
                fit['dl'][j],
                False,                # detection, not an upper limit
                sat_flag,
                fit['chisq'][j]       # chisq as reported for this component
            ))

    else:
        fit = None
        # No detected regions: estimate an upper limit from the +/-50 km/s core
        vel_subset = wave_to_vel(wave_subset, rest_lambda)
        core = (vel_subset >= -50.0) & (vel_subset <= 50.0)
        wave_pm100 = wave_subset[core]
        flux_pm100 = flux_subset[core]
        error_pm100 = error_subset[core]
        if len(wave_pm100) < 2:
            raise ValueError(
                f'Need at least two pixels within +/-50 km/s of {pg_ion} to '
                f'estimate an upper limit; found {len(wave_pm100)}.'
            )
        ew_pm100 = pg.analysis.vpfit.EquivalentWidth(flux_pm100, wave_pm100) * 1000  # mA
        print(f'EW: {ew_pm100}')
        print(f'error_pm100: {error_pm100}')
        pixel_width = abs(wave_pm100[1] - wave_pm100[0])
        print(f'abs(wave_pm100[1] - wave_pm100[0])= {pixel_width}')
        # Quadrature sum of the pixel errors times the pixel width, in mA
        dew_pm100 = ((np.sqrt(np.sum(error_pm100**2))) * pixel_width) * 1000  # mA
        print(f'dEW: {dew_pm100}')
        N_lim = EW_to_N(pg_ion, dew_pm100)

        # Add a single row for the non-detection case
        t.add_row((
            ray,
            ion.replace(' ', ''),
            ew_pm100,   # EW in mA
            dew_pm100,  # dEW in mA
            N_lim,
            np.nan,     # dN placeholder
            np.nan,     # b placeholder
            np.nan,     # db placeholder
            0,          # velocity placeholder
            50,         # dv placeholder
            np.nan,     # wavelength placeholder
            np.nan,     # dl placeholder
            True,       # Upper limit
            sat_flag,
            np.nan      # chisq placeholder
        ))

    return sat_flag, t,fit,regions

# Column schema for the per-component fit results
_fit_result_schema = [
    ('Sightline', 'i4'), ('Species', 'S10'), ('EW(mA)', 'f8'), ('dEW(mA)', 'f8'),
    ('N', 'f8'), ('dN', 'f8'), ('b', 'f8'), ('db', 'f8'), ('v', 'f8'), ('dv', 'f8'),
    ('l', 'f8'), ('dl', 'f8'), ('UpLim', 'bool'), ('Sat', 'bool'), ('Chisq', 'f8'),
]
results_table = Table(
    names=tuple(col for col, _ in _fit_result_schema),
    dtype=[kind for _, kind in _fit_result_schema],
)

# Transitions to fit
line_list = ['O VI 1032']

# %%
# Start a fresh, empty results table (one row per component or upper limit)
column_names, column_dtypes = zip(
    ('Sightline', 'i4'), ('Species', 'S10'), ('EW(mA)', 'f8'), ('dEW(mA)', 'f8'),
    ('N', 'f8'), ('dN', 'f8'), ('b', 'f8'), ('db', 'f8'), ('v', 'f8'),
    ('dv', 'f8'), ('l', 'f8'), ('dl', 'f8'), ('UpLim', 'bool'), ('Sat', 'bool'), ('Chisq', 'f8'),
)
results_table = Table(names=column_names, dtype=list(column_dtypes))

line_list = ['O VI 1032']

# %%
# Voigt-profile fit of the O VI 1031 absorption for the selected sightline.
#
# The saved noiseless spectrum is reloaded, then the COS LSF and S/N = 20 noise
# are applied exactly once. Previously this cell re-applied both to `sg`, which
# the earlier cells had already convolved and made noisy, so the fitted
# spectrum was convolved twice and carried two noise realisations. Reloading
# also makes the cell safe to re-run.
pg_ion = 'OVI1031'

sg = trident.load_spectrum('flux_from_tau.h5')
sg.apply_lsf(filename='COS_G130M_1150.txt')
sg.add_gaussian_noise(20)

# Resample onto a uniform 0.0112 Å grid
wave_binned = np.arange(sg.lambda_min.value, sg.lambda_max.value, 0.0112)
flux_binned = np.interp(wave_binned, sg.lambda_field.value, sg.flux_field)
#error_binned = np.interp(wave_binned, sg.lambda_field.value, sg.error_func(sg.flux_field))
# `flux_err` is the constant per-pixel error array defined in the velocity-plot cell
error_binned = np.interp(wave_binned, sg.lambda_field.value, flux_err)
print(f"Error Binned is {error_binned}")
# change the error to 1 sigma (1/3)
# error_binned = error_binned *0.5
# print(f"Error Binned is {error_binned}")

# Shift to the rest frame and express as velocity relative to the line centre
wave_binned /= (1 + z)
vel_binned = wave_to_vel(wave_binned, float(pg.analysis.absorption_spectra.lines[pg_ion]['l'].split()[0]))

min_region_width = 3  # pixels
#N_sigma = 0.5  # 1-sigma detection limit
N_sigma = 1 # 1-sigma detection limit
logN_bounds = [12, 18]
b_bounds = [5, 200]

try:
    # fit_vp_pipeline_new returns four values; unpacking three raised ValueError
    saturation_flag, output_table, fit, regions = fit_vp_pipeline_new(
        1, pg_ion, wave_binned, vel_binned, flux_binned, error_binned, z,
        logN_bounds, b_bounds, min_region_width, N_sigma)
    results_table = vstack((results_table, output_table))

except RuntimeError:
    # The fit raised RuntimeError (presumably a non-converging optimiser inside
    # pygad). Record a placeholder row so the sightline is not silently dropped.
    # The old handler built such a row but never stored it, put NaN in the
    # string/bool columns, and left `fit` undefined for the next cell.
    fit = None
    results_table.add_row((
        1, pg_ion,
        np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan,
        False, False, np.nan,
    ))

# # Save intermediate results every 50 iterations
# if (i + 1) % 50 == 0:
#     print(f"Writing intermediate results after {i + 1} sightlines...")
#     results_table.write(
#         os.path.join(data_path, f'Grp_{halo_ids[grp_index]}_fitting_results.txt'),
#         format='ascii.fixed_width',
#         overwrite=True
#     )

# # Final save for all results
# print("Writing final results...")
# results_table.write(
#     os.path.join(sol_result_path, 'local_run_6360.txt'),
#     format='ascii.fixed_width',
#     overwrite=True
# )

# %%
fit

# %%
def generate_params(fitting_data):
    """
    Build the flat Voigt-model parameter vector [N0, b0, l0, N1, b1, l1, ...].

    Parameters:
        fitting_data (dict): Mapping with equal-length 'N' (column density),
            'b' (Doppler parameter) and 'l' (centroid wavelength) sequences.

    Returns:
        np.ndarray: float array of length 3 * len(fitting_data['N']) with the
        three quantities interleaved per component.

    Raises:
        ValueError: if any of the 'N', 'b' or 'l' keys is missing.
    """
    component_keys = ('N', 'b', 'l')
    if any(key not in fitting_data for key in component_keys):
        raise ValueError("fitting_data must contain 'N', 'b', and 'l' keys.")

    n_components = len(fitting_data['N'])
    params = np.empty(3 * n_components)
    # Fill every third slot, starting at 0 for N, 1 for b and 2 for l
    for offset, key in enumerate(component_keys):
        params[offset::3] = fitting_data[key]
    return params

# %%
import numpy as np
import matplotlib.pyplot as plt
import os
import h5py
from scipy.constants import c as c_km_s  # Speed of light in km/s
import pygad as pg

# Constants
rest_wavelength = 1031.926  # OVI rest wavelength in Å
z = 0.09940180263022191                 # Redshift
#z =  0.137807   
c_km_s = c_km_s / 1e3       # scipy.constants.c is in m/s; convert to km/s

# Provided Data
wavelength_obs = sg.lambda_field.value  # Observed wavelength from the previous code
flux = sg.flux_field                    # Flux from the previous code
flux_err = np.ones_like(flux) * 0.1    # Assume a constant flux error

# Convert observed wavelength to rest frame
wavelength_rest = wavelength_obs / (1 + z)

# Select 1030-1035 Å in the rest frame
zoom_mask = (wavelength_rest >= 1030) & (wavelength_rest <= 1035)

# Extract zoomed data in rest wavelength space
wavelength_zoom_rest = wavelength_rest[zoom_mask]
flux_zoom = flux[zoom_mask]
flux_err_zoom = flux_err[zoom_mask]
flux_upper = flux_zoom + flux_err_zoom
flux_lower = flux_zoom - flux_err_zoom


# Fitted components from the fitting cell (None when no region was fitted)
fitting_data = fit
line_data = pg.analysis.absorption_spectra.lines['OVI1031']
wave_subset = wavelength_zoom_rest  # evaluate the model on the zoomed rest-frame grid

if fitting_data is None:
    # Nothing was fitted: show a flat continuum and no component markers
    # instead of crashing inside generate_params(None)
    params = np.empty(0)
    model_flux = np.ones_like(wave_subset)
else:
    params = generate_params(fitting_data)
    total_tau = pg.analysis.vpfit.model_tau(line_data, params, wave_subset, mode='Voigt')
    model_flux = np.exp(-total_tau)

# Velocity of every pixel relative to the redshifted O VI centre
observed_rest_wavelength = rest_wavelength * (1 + z)
velocity = c_km_s * (wavelength_obs - observed_rest_wavelength) / observed_rest_wavelength

# Velocities of the rest-frame 1030-1035 Å window selected by zoom_mask
velocity_zoom = velocity[zoom_mask]

# FWHM of each component in Å: 2*sqrt(ln 2) * (b / c) * lambda
centroids = params[2::3]  # Extract every 3rd value starting from index 2
doppler_params = params[1::3]  # Extract every 3rd value starting from index 1
fwhms = [(2 * np.sqrt(np.log(2)) * b / c_km_s) * l for b, l in zip(doppler_params, centroids)]

# Plotting
fig, ax = plt.subplots(figsize=(12, 6))

# Original spectrum
ax.plot(wavelength_zoom_rest, flux_zoom, label='Original Spectrum', color='blue')
ax.fill_between(wavelength_zoom_rest, flux_lower, flux_upper, color='green', alpha=0.3, label='Error')

# Fitted spectrum (Voigt profile)
ax.step(wave_subset, model_flux, label='Fitted Spectrum (Voigt Profile)', color='red', where='mid', linewidth=3)

# Highlight centroids and add FWHM arrows (widths in Å on this wavelength axis)
arrowprops = dict(arrowstyle='<->', color='black', lw=1.5)
for i, (centroid, fwhm, b) in enumerate(zip(centroids, fwhms, doppler_params)):
    ax.axvline(centroid, color=f'C{i}', linestyle='--', label=f'Centroid {i+1}: {centroid:.2f} Å', linewidth=2)
    ax.annotate('', xy=(centroid - fwhm / 2, 0.75 - i * 0.05), 
                xytext=(centroid + fwhm / 2, 0.75 - i * 0.05), arrowprops=arrowprops)
    ax.text(centroid, 0.77 - i * 0.05, f'b = {b:.1f} km/s', ha='center', fontsize=10)

# Highlight reference line for OVI in rest wavelength
ax.axvline(rest_wavelength, color='purple', linestyle='--', label=f'OVI {rest_wavelength:.2f} Å (Rest Frame)', linewidth=2)

# Set labels, limits, and grid
ax.set_xlabel('Rest Wavelength (Å)', fontsize=14)
ax.set_ylabel('Normalized Flux', fontsize=14)
ax.set_xlim(wavelength_zoom_rest.min(), wavelength_zoom_rest.max())
ax.set_ylim(0.2, 1.35)
ax.grid(True, linestyle='--', alpha=0.7)
#ax.legend(fontsize=12)

# Save and show the plot
plt.savefig('OVI_Voigt_Profile_Fit_0.1_err.png', dpi=300)
plt.show()

# %%
import numpy as np
import matplotlib.pyplot as plt
from scipy.constants import c as c_km_s  # Speed of light in km/s
import pygad as pg

# Constants
rest_wavelength = 1031.926  # OVI rest wavelength in Å
z = 0.09940180263022191  # Redshift
c_km_s = c_km_s / 1e3  # scipy.constants.c is in m/s; convert to km/s

# Provided Data
wavelength_obs = sg.lambda_field.value  # Observed wavelength
flux = sg.flux_field  # Flux
# NOTE(review): 0.15 differs from the 0.1 error used for the fit and from the
# "_0.1" in the output filename below. Confirm which value is intended.
flux_err = np.ones_like(flux) * 0.15  # Assume a constant flux error

# Convert observed wavelength to rest frame
wavelength_rest = wavelength_obs / (1 + z)

# Select 1030-1035 Å in the rest frame (roughly -560 to +890 km/s around 1031.926 Å)
zoom_mask = (wavelength_rest >= 1030) & (wavelength_rest <= 1035)

# Extract zoomed data in rest wavelength space
wavelength_zoom_rest = wavelength_rest[zoom_mask]
flux_zoom = flux[zoom_mask]
flux_err_zoom = flux_err[zoom_mask]
flux_upper = flux_zoom + flux_err_zoom
flux_lower = flux_zoom - flux_err_zoom

# Voigt model for the fitted components, evaluated on the zoomed rest-frame grid
line_data = pg.analysis.absorption_spectra.lines['OVI1031']
wave_subset = wavelength_zoom_rest
if fitting_data is None:
    # Nothing was fitted: flat continuum and no component markers
    params = np.empty(0)
    model_flux = np.ones_like(wave_subset)
else:
    params = generate_params(fitting_data)
    total_tau = pg.analysis.vpfit.model_tau(line_data, params, wave_subset, mode='Voigt')
    model_flux = np.exp(-total_tau)

# Velocity of every pixel relative to the redshifted O VI centre
observed_rest_wavelength = rest_wavelength * (1 + z)
velocity = c_km_s * (wavelength_obs - observed_rest_wavelength) / observed_rest_wavelength
velocity_zoom = velocity[zoom_mask]

# `wave_subset` is in the REST frame, so it must be referenced to
# `rest_wavelength`. Referencing it to the observed-frame centre shifted the
# model by about -c*z/(1+z) (~ -27,000 km/s here) and pushed it off the plot.
velocity_model = c_km_s * (wave_subset - rest_wavelength) / rest_wavelength

# Component parameters and Gaussian FWHMs
centroids = params[2::3]  # Extract every 3rd value starting from index 2
doppler_params = params[1::3]  # Extract every 3rd value starting from index 1
fwhms = [(2 * np.sqrt(np.log(2)) * b / c_km_s) * l for b, l in zip(doppler_params, centroids)]  # Å
# The x-axis is velocity, so the arrows need FWHM in km/s (2*sqrt(ln 2)*b).
# The Å values above made the arrows essentially invisible.
fwhms_kms = [2 * np.sqrt(np.log(2)) * b for b in doppler_params]

# Plotting in velocity space
fig, ax = plt.subplots(figsize=(12, 6))

# Plot flux in velocity space
ax.step(velocity_zoom, flux_zoom, where='mid', label='Original Spectrum', color='blue', linewidth=1.5)

# Shaded error region
ax.fill_between(velocity_zoom, flux_lower, flux_upper, color='green', alpha=0.3, label='Error')

# Plot Voigt profile in velocity space
ax.step(velocity_model, model_flux, label='Fitted Spectrum (Voigt Profile)', color='red', where='mid', linewidth=3)

# Highlight centroids and add FWHM arrows
arrowprops = dict(arrowstyle='<->', color='black', lw=1.5)
for i, (centroid, fwhm_kms, b) in enumerate(zip(centroids, fwhms_kms, doppler_params)):
    centroid_velocity = c_km_s * (centroid - rest_wavelength) / rest_wavelength
    if -800 <= centroid_velocity <= 800:  # Include only visible centroids
        ax.axvline(centroid_velocity, color=f'C{i}', linestyle='--', label=f'Centroid {i+1}: {centroid_velocity:.1f} km/s')
        ax.annotate('', xy=(centroid_velocity - fwhm_kms / 2, 0.75 - i * 0.05), 
                    xytext=(centroid_velocity + fwhm_kms / 2, 0.75 - i * 0.05), arrowprops=arrowprops)
        ax.text(centroid_velocity, 0.77 - i * 0.05, f'b = {b:.1f} km/s', ha='center', fontsize=10)

# Highlight reference line for 0 km/s
ax.axvline(0, color='purple', linestyle='--', label='OVI 1031.926 Å (0 km/s)', linewidth=2)

# Set labels, limits, and grid
ax.set_xlabel('Velocity (km/s)', fontsize=14)
ax.set_ylabel('Normalized Flux', fontsize=14)
ax.set_xlim(-800, 800)
ax.set_ylim(0.2, 1.35)
ax.grid(True, linestyle='--', alpha=0.7)
ax.legend(fontsize=12)

# Save and show the plot
plt.savefig('OVI_Voigt_Profile_Fit_Velocity_0.1.png', dpi=300)
plt.show()

# %% [markdown]
# #### Test for a single group

# %%
import numpy as np
import h5py
from astropy.table import Table, vstack
import os
import pygad as pg
import matplotlib.pyplot as plt
from scipy.constants import c as c_km_s
from tqdm import tqdm

# Output folder for this run (created if it does not exist yet)
output_dir = "./results_group_0_err_0.2/"
os.makedirs(output_dir, exist_ok=True)

# Line constants for the O VI 1031 fits
rest_wavelength = 1031.926   # O VI rest wavelength [Å]
z = 0.09940180263022191      # sightline redshift
c_km_s = c_km_s / 1e3        # scipy.constants.c is in m/s; rescale to km/s

chunk_size = 25  # interval (in spectra) for saving intermediate results

# Input file
#tng_data_path = "spectra_TNG50-1_z0.1_n300d2-sample_localized_COS-G130M_OVI_combined.hdf5"

# Initialize results table
def initialize_results_table():
    """Return a new empty astropy Table using the per-component fit-result schema."""
    schema = [
        ('Sightline', 'i4'), ('Species', 'S10'), ('EW(mA)', 'f8'), ('dEW(mA)', 'f8'),
        ('N', 'f8'), ('dN', 'f8'), ('b', 'f8'), ('db', 'f8'), ('v', 'f8'),
        ('dv', 'f8'), ('l', 'f8'), ('dl', 'f8'), ('UpLim', 'bool'), ('Sat', 'bool'), ('Chisq', 'f8'),
    ]
    column_names = tuple(name for name, _ in schema)
    column_dtypes = [dtype for _, dtype in schema]
    return Table(names=column_names, dtype=column_dtypes)

# Process spectra
def process_spectra(group_indices, spectra_file, output_dir, save_interval=25):
    """
    Voigt-fit the OVI 1031 absorption of a set of sightlines and tabulate the results.

    For every requested row:
      1. The stored ``tau_OVI_1031`` is converted to a flux (``exp(-tau)``).
      2. That flux is written to ``flux_from_tau_temp.h5`` and reloaded with
         ``trident.load_spectrum``.
      3. It is convolved with the LSF in ``COS_G130M_1150.txt`` and given
         Gaussian noise via ``sg.add_gaussian_noise(20)``.
      4. The spectrum is resampled onto a uniform 0.0112 Å grid and shifted to
         the rest frame with the module-level ``z``.
      5. It is fitted with ``fit_vp_pipeline_new``.

    Per-sightline fit tables are stacked into one astropy Table.

    Parameters:
        group_indices: Row indices into the ``flux`` / ``tau_OVI_1031`` datasets.
        spectra_file (str): Path to the spectra HDF5 file.
        output_dir (str): Existing directory for the intermediate and final tables.
        save_interval (int): Write ``intermediate_results_<i>.txt`` every
            ``save_interval`` spectra.

    Side effects:
        Overwrites ``flux_from_tau_temp.h5`` in the working directory on every
        iteration and writes ``final_results.txt`` to ``output_dir``. Reads the
        module-level ``z``, ``rest_wavelength`` and ``c_km_s``.

    Note:
        The sightline id passed to the fitter (and stored in the table) is the
        1-based position ``i`` within ``group_indices``, not the HDF5 row index.
        The stored ``flux`` row is only used to size the constant error array.
        Only ``RuntimeError`` is caught per sightline; any other exception stops
        the run.
    """
    # Read everything up front so the (large) spectra file is not held open for
    # the whole fitting run.
    # NOTE(review): relies on h5py accepting `group_indices` (e.g. a range) as a
    # selection -- confirm for the index types you pass.
    with h5py.File(spectra_file, 'r') as f:
        wavelength_obs = f['wave'][:]
        flux_data = f['flux'][group_indices, :]
        tau_data = f['tau_OVI_1031'][group_indices, :]

    # Fit settings are identical for every sightline, so set them once.
    min_region_width = 3  # pixels
    N_sigma = 1  # 1-sigma detection limit
    logN_bounds = [12, 18]
    b_bounds = [10, 200]

    results_table = initialize_results_table()

    for i, (flux, tau) in tqdm(enumerate(zip(flux_data, tau_data), start=1), total=len(flux_data)):
        try:
            # Noise-free flux from the optical depth.
            flux_from_tau = np.exp(-tau)

            # trident.load_spectrum takes a file name, so round-trip through a temp file.
            with h5py.File('flux_from_tau_temp.h5', 'w') as temp_file:
                temp_file.create_dataset('flux', data=flux_from_tau)
                temp_file.create_dataset('wavelength', data=wavelength_obs)
                temp_file.create_dataset('tau', data=tau)

            sg = trident.load_spectrum('flux_from_tau_temp.h5')
            sg.apply_lsf(filename='COS_G130M_1150.txt')
            # NOTE(review): no explicit seed here, so results may differ run-to-run.
            sg.add_gaussian_noise(20)

            # Constant 0.20 flux uncertainty on every pixel.
            flux_err = np.ones_like(flux) * 0.20

            # Resample onto a uniform 0.0112 Å grid, then shift to the rest frame.
            wave_binned = np.arange(sg.lambda_min.value, sg.lambda_max.value, 0.0112)
            flux_binned = np.interp(wave_binned, sg.lambda_field.value, sg.flux_field)
            # NOTE(review): np.interp needs len(flux_err) == len(sg.lambda_field);
            # this holds only while trident keeps the input wavelength grid.
            error_binned = np.interp(wave_binned, sg.lambda_field.value, flux_err)
            wave_binned /= (1 + z)
            vel_binned = c_km_s * (wave_binned - rest_wavelength) / rest_wavelength

            _saturation_flag, output_table, _fit = fit_vp_pipeline_new(
                i, 'OVI1031', wave_binned, vel_binned, flux_binned, error_binned,
                z, logN_bounds, b_bounds, min_region_width, N_sigma
            )

            # Append results
            results_table = vstack((results_table, output_table))

        except RuntimeError as e:
            # Keep a NaN placeholder row so every processed sightline appears in the output.
            print(f"Error processing spectrum {i}: {e}")
            empty_row = Table(
                [[i], ['OVI1031'], [np.nan], [np.nan], [np.nan], [np.nan], [np.nan], [np.nan],
                 [np.nan], [np.nan], [np.nan], [np.nan], [False], [False], [np.nan]],
                names=results_table.colnames
            )
            results_table = vstack((results_table, empty_row))

        # Save intermediate results
        if i % save_interval == 0:
            print(f"Saving intermediate results at spectrum {i}...")
            results_table.write(
                os.path.join(output_dir, f'intermediate_results_{i}.txt'),
                format='ascii.fixed_width',
                overwrite=True
            )

    # Save final results
    print("Saving final results...")
    results_table.write(
        os.path.join(output_dir, 'final_results.txt'),
        format='ascii.fixed_width',
        overwrite=True
    )

# Define indices for the first group
# Only rows 0-49 (the first 50 sightlines of group 0) are fitted in this test.
group_indices = range(0,50)

# Process the spectra
# `tng_data_path` is the combined TNG50 spectra file defined in an earlier cell;
# `chunk_size` (25) sets how often intermediate tables are written.
process_spectra(group_indices, tng_data_path, output_dir, save_interval=chunk_size)

# %% [markdown]
# ### Code for binning, adding noise and then fitting 

# %% [markdown]
# ##### Loading the data (example 1 = 50th, example 2 = 500th, example 3 = 750th)

# %%
simba_data_path = '/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/SIMBA_IGrM/revised_final/spectra_Simba100_z0.1_n300d2-sample_localized_COS-G130M_OVI_combined.hdf5'

tng_data_path = '/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/revised_final/spectra_TNG50-1_z0.1_n300d2-sample_localized_COS-G130M_OVI_combined.hdf5'

from scipy.constants import c as c_km_s
c_km_s = c_km_s / 1e3  # Speed of light in km/s

# 0-based position of the target sightline in the EW ranking (descending order,
# 0 = largest EW). 200 selects the 201st-largest EW_OVI_1031 in the group.
target_rank = 200

# Open HDF5 file
with h5py.File(tng_data_path, 'r') as f:
    # Read EW_OVI_1031 for the first group (indices 0:90000) in chunks.
    chunk_size = 10000  # Chunk size for processing
    data_indices = range(0, 90000)  # Range for the first group
    eq_widths = np.concatenate([
        f['EW_OVI_1031'][start:min(start + chunk_size, data_indices.stop)]
        for start in range(data_indices.start, data_indices.stop, chunk_size)
    ])
    # Dataset index of every value read. Sized from the data actually returned,
    # so it stays aligned even if the dataset has fewer than 90000 rows.
    eq_width_indices = np.arange(data_indices.start, data_indices.start + eq_widths.size)

    # Rank the equivalent widths and pick the requested one
    sorted_indices = np.argsort(eq_widths)[::-1]  # Sort in descending order
    target_index_in_chunk = sorted_indices[target_rank]
    target_index = eq_width_indices[target_index_in_chunk]

    print(f"Rank-{target_rank + 1} EW_OVI_1031 value (rank 1 = largest): {eq_widths[target_index_in_chunk]}")
    print(f"Index of the rank-{target_rank + 1} EW_OVI_1031 in the dataset: {target_index}")

    # Fetch all information for the identified index
    result = {
        'EW_OVI_1031': f['EW_OVI_1031'][target_index],
        'EW_OVI_1037': f['EW_OVI_1037'][target_index],
        'flux': f['flux'][target_index, :],
        'ray_dir': f['ray_dir'][:],  # ray_dir is not indexed
        'ray_pos': f['ray_pos'][target_index, :],
        'ray_total_dl': f['ray_total_dl'][()],  # Scalar value
        'tau_OVI_1031': f['tau_OVI_1031'][target_index, :],
        'tau_OVI_1037': f['tau_OVI_1037'][target_index, :],
        'wave': f['wave'][:]
    }

# %%
import h5py
import numpy as np
from scipy.constants import c as c_km_s


# set the random seed for reproducibility = 42
# (seeds NumPy's global RNG for every later cell that draws random numbers)
np.random.seed(42)

# Speed of light in km/s
# scipy.constants.c is in m/s; correct only because this cell re-imports c_km_s above.
c_km_s = c_km_s / 1e3  

# Define file path
tng_data_path = '/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/revised_final/spectra_TNG50-1_z0.1_n300d2-sample_localized_COS-G130M_OVI_combined.hdf5'
#tng_data_path = '/Users/tsingh65/Downloads/spectra_TNG50-1_z0.1_n300d2-sample_localized_COS-G130M_OVI_6-of-14.hdf5'
# Specify the target index directly
# Global index in the combined file = group * 90000 + local index within the group.
# 1) 9665
# 2) 35710 
# for group 6 , index 30 is 6*90000 + 30  = 540030 (can't leave the N=18 absorber)
# for group 6 , index 36 is 6*90000 + 36  = 540036 (works good)
# for group 6 , index 56 is 6*90000 + 56  = 540056 (no feature but still fitting is being done)
# for group 6,  index 68 is 6*90000 + 68  = 540068 (Ok )
# for group 6,  index 93 is 6*90000 + 93  = 540093 (OK, leaving the higher N value removes the saturation effect. )
# for group 6,  index 129 is 6*90000 + 129  = 540129 ( working)
# for group 6, index 45 is 6*90000 + 45 = 540045 (can't leave the N=18 absorber)
# for group 6, index 277 is 6*90000 + 277 = 540277 
# for group 6, index 778 is 6*90000 + 778 = 540778
# for group 6, index 80583 is 6*90000 + 80583 = 620583
# for group 6, index 40317 is 6*90000 + 40317 = 580317
# for group 1, index 34791 is 1*90000 + 34791 = 124791
# for group 7 , index = 2877 is 7*90000 + 2877 = 632877

# Current target: 656688 = 7*90000 + 26688, i.e. group 7, local index 26688.
target_index = 656688


# Open HDF5 file
with h5py.File(tng_data_path, 'r') as f:
    # Fetch all information for the identified index
    # (overwrites any `result` produced by earlier cells)
    result = {
        'EW_OVI_1031': f['EW_OVI_1031'][target_index],
        'EW_OVI_1037': f['EW_OVI_1037'][target_index],
        'flux': f['flux'][target_index, :],
        'ray_dir': f['ray_dir'][:],  # ray_dir is not indexed
        'ray_pos': f['ray_pos'][target_index, :],
        'ray_total_dl': f['ray_total_dl'][()],  # Scalar value
        'tau_OVI_1031': f['tau_OVI_1031'][target_index, :],
        'tau_OVI_1037': f['tau_OVI_1037'][target_index, :],
        'wave': f['wave'][:]
    }

# Print the results for the given index
print(f"Results for index {target_index}:")
for key, value in result.items():
    print(f"{key}: {value}")


# %%
# Local (within-group) index of target_index 656688 in group 7: evaluates to 26688.
656688- 7*90000

# %%
# Read the ray geometry (position, direction, path length) of the target sightline
# from the group-7 integral file; the group centre / R_vir overlay is plotted in later cells.
group_7_path = '/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/revised_final/integral_TNG50-1_z0.1_n300d2-sample_localized_OVInumdens_7-of-14.hdf5'

# Read the file
with h5py.File(group_7_path, 'r') as f:
    # Print all keys
    keys = list(f.keys())
    print("Keys in the file:", keys)
    
    # Print shapes of each dataset
    for key in keys:
        print(f"{key}: {f[key].shape}")
    
    # Extract the data for local index 26688 of this group-7 file.
    # Presumably this is the same sightline as global index 656688 (= 7*90000 + 26688)
    # in the combined spectra file -- TODO confirm the two files share row ordering.
    index = 26688
    ray_pos = f['ray_pos'][index]  # Position of the sightline
    ray_dir = f['ray_dir'][:]  # Direction of the rays
    ray_total_dl = f['ray_total_dl'][()]  # Total path length

    print("\nCoordinates of sightline (ray_pos):", ray_pos)
    print("Direction of the sightline (ray_dir):", ray_dir)
    print("Total path length (ray_total_dl):", ray_total_dl)
    # TODO: unfinished note left here ("calculate and print the i...").

        
    
        

# %%
import h5py
import numpy as np
import matplotlib.pyplot as plt

# File path
tng_data_path = '/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/revised_final/spectra_TNG50-1_z0.1_n300d2-sample_localized_COS-G130M_OVI_combined.hdf5'

# Group 7 parameters
group_center = np.array([4834.154, 22167.719, 16398.639])  # Group center (x, y, z)
r_vir = 384.19003  # Virial radius
circle_radius = 1.5 * r_vir  # Circle radius for 1.5 R_vir
# The combined file stores 90000 sightlines per group, so group 7 occupies
# global rows [7*90000, 8*90000).
group_start_index = 90000 * 7
group_end_index = 90000 * 8

# EW and position extraction
with h5py.File(tng_data_path, 'r') as f:
    EW_OVI_1031 = f['EW_OVI_1031'][group_start_index:group_end_index]
    ray_pos = f['ray_pos'][group_start_index:group_end_index, :]

# Selection: 0.3 <= EW_OVI_1031 <= 0.7 and projected (x, y) distance from the
# group centre between 0.7 and 0.9 r_vir. Vectorised instead of looping over
# all 90000 rows in Python.
# NOTE(review): distances are in raw ray_pos units. Elsewhere in this file the same
# difference is divided by h and (1 + z) to get kpc, so the "kpc" printed / on the
# axes below may really be ckpc/h -- confirm, and confirm r_vir uses the same units.
distance_2d = np.linalg.norm(ray_pos[:, :2] - group_center[:2], axis=1)
selected = (
    (EW_OVI_1031 >= 0.3) & (EW_OVI_1031 <= 0.7)
    & (distance_2d >= 0.7 * r_vir) & (distance_2d <= 0.9 * r_vir)
)
for i in np.flatnonzero(selected):
    print(f"Sightline {i + group_start_index} meets the criteria.")
    print(f"EW_OVI_1031: {EW_OVI_1031[i]:.3f}, 2D Distance: {distance_2d[i]:.3f} kpc")

# (x, y) positions of the selected sightlines, shape (n_selected, 2)
sightline_positions = ray_pos[selected, :2]

# Plotting
fig, ax = plt.subplots(figsize=(8, 6))

# Plot group center
ax.scatter(group_center[0], group_center[1], color='red', label='Group Center', s=100, zorder=3)

# Plot sightlines
if sightline_positions.size > 0:
    ax.scatter(sightline_positions[:, 0], sightline_positions[:, 1], color='blue', label='Sightlines', s=50, zorder=3)
else:
    print("No sightlines found that match the criteria.")

# Draw a circle for 1.5 R_vir radius. add_patch (not add_artist) registers the
# circle as a patch, so its '1.5 R_vir' label actually appears in the legend.
circle = plt.Circle(group_center[:2], circle_radius, color='green', fill=False, linestyle='--', label='1.5 R_vir', zorder=2)
ax.add_patch(circle)

# Customize plot
ax.set_xlabel("X [kpc]")
ax.set_ylabel("Y [kpc]")
ax.set_title("2D Location of Sightlines and Group Center")
ax.legend()
ax.grid(True)
plt.axis("equal")

# Adjust limits for better visualization
x_min, x_max = group_center[0] - circle_radius, group_center[0] + circle_radius
y_min, y_max = group_center[1] - circle_radius, group_center[1] + circle_radius
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)

# Show plot
plt.show()

# Print positions of sightlines meeting the criteria
print("Sightline Positions (x, y):")
for pos in sightline_positions:
    print(pos)

# %%
import h5py
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd


def plot_2d_dataset_with_galaxy_overlays(hdf5_file_path, dataset_name, tng_data_path, galaxy_catalog_path, target_index, group_center, r_vir, output_path, galaxy_circle_radius=100):
    """
    Plot a 3000x3000 2D map from an HDF5 file with group, sightline and galaxy overlays.

    The map is drawn over a square of half-width 2 * r_vir centred on the group
    centre's (x, y). The overlays are:
      - the group centre;
      - a 1.5 R_vir circle;
      - the sightline's (x, y) position;
      - every galaxy in the catalog, each with a circle of radius
        ``galaxy_circle_radius``.

    The function also prints the sightline's projected distance from the group
    centre after dividing by h = 0.6774 and (1 + z) with z = 0.0994.
    Errors are printed rather than raised (best-effort notebook helper).

    Parameters:
        hdf5_file_path (str): Path to the HDF5 file containing the 2D dataset.
        dataset_name (str): Name of the 2D dataset to plot.
        tng_data_path (str): Path to the HDF5 file containing sightline data.
        galaxy_catalog_path (str): Path to the '|'-delimited galaxy catalog text file
            (must provide 'SubhaloPos_0' and 'SubhaloPos_1' columns).
        target_index (int): Index of the sightline to overlay.
        group_center (list or np.array): Coordinates of the group center [x, y, z].
        r_vir (float): Virial radius of the group.
        output_path (str): Directory in which to save the resulting plot.
        galaxy_circle_radius (float): Radius of the circle drawn around each galaxy, in
            the same units as the plot axes. Defaults to 100, the value previously
            hard-coded. The old docstring said 25 kpc, which did not match the code.
    """
    try:
        # Load the 2D dataset
        with h5py.File(hdf5_file_path, 'r') as hdf:
            if dataset_name not in hdf:
                print(f"Dataset '{dataset_name}' not found in the HDF5 file.")
                return
            data = hdf[dataset_name][()]

        # Only 3000x3000 maps are accepted
        if data.ndim != 2 or data.shape != (3000, 3000):
            print(f"Dataset '{dataset_name}' is not 2D or not 3000x3000. Found {data.shape} dimensions.")
            return

        # Replace NaN values with the colour-scale minimum
        vmin = 10  # Minimum value for visualization
        data = np.nan_to_num(data, nan=vmin)

        # The map spans +/- 2 r_vir around the group centre in x and y
        group_center = np.asarray(group_center)
        group_center_2d = group_center[:2]
        half_width = 2 * r_vir
        extent = [
            -half_width + group_center[0], half_width + group_center[0],
            -half_width + group_center[1], half_width + group_center[1],
        ]

        # Load sightline position
        with h5py.File(tng_data_path, 'r') as tng_file:
            ray_pos = tng_file['ray_pos'][target_index]
        sightline_2d = ray_pos[:2]

        # Load galaxy catalog and extract (x, y) positions
        galaxy_catalog = pd.read_csv(galaxy_catalog_path, delimiter='|', skipinitialspace=True)
        galaxy_catalog.columns = galaxy_catalog.columns.str.strip()  # Clean column names
        galaxy_positions = galaxy_catalog[['SubhaloPos_0', 'SubhaloPos_1']].to_numpy()

        # Impact parameter of the sightline from the group centre in the x-y plane.
        # Presumably ray_pos is in comoving kpc/h (cf. the variable name d_ckpch), so
        # dividing by h and (1 + z) gives physical kpc; h and z are hard-coded here.
        cen_gx = group_center[0]
        cen_gy = group_center[1]
        sightline_x = ray_pos[0]
        sightline_y = ray_pos[1]
        print(f"Group center x: {cen_gx}, y: {cen_gy}")
        print(f"Sightline x: {sightline_x}, y: {sightline_y}")
        d_ckpch = np.sqrt((cen_gx - sightline_x)**2 + (cen_gy - sightline_y)**2)
        h = 0.6774
        z = 0.09940180263022191
        d_kpc = d_ckpch / h / (1+z)
        print(f"impact parameter of the sightline from the group center is {d_kpc} kpc")

        # Create the plot
        plt.figure(figsize=(10, 8))
        plt.imshow(data, origin='lower', extent=extent, cmap='viridis', aspect='auto', vmin=vmin)
        plt.colorbar(label='Value')
        plt.title(f"2D Dataset: {dataset_name} with Physical Overlays")
        plt.xlabel("X [kpc]")
        plt.ylabel("Y [kpc]")
        ax = plt.gca()

        # Overlay group center
        plt.scatter(group_center_2d[0], group_center_2d[1], color='red', label='Group Center', s=100, zorder=3)

        # Overlay 1.5 R_vir circle; add_patch (not add_artist) so its label reaches the legend
        ax.add_patch(plt.Circle(group_center_2d, 1.5 * r_vir, color='green', fill=False, linestyle='--', label='1.5 R_vir', zorder=2))

        # Overlay sightline
        plt.scatter(sightline_2d[0], sightline_2d[1], color='blue', label='Sightline', s=100, zorder=3)

        # Overlay galaxies with a single scatter call, giving one 'Galaxy Center' legend
        # entry instead of one duplicate entry per galaxy.
        if len(galaxy_positions) > 0:
            plt.scatter(galaxy_positions[:, 0], galaxy_positions[:, 1], color='orange', label='Galaxy Center', s=50, zorder=3)
        for pos in galaxy_positions:
            # Unlabelled circles; add_artist keeps them out of the legend and the data limits
            ax.add_artist(plt.Circle(pos, galaxy_circle_radius, color='orange', fill=False, linestyle='-', zorder=2))

        # Add legend and grid
        plt.legend()
        plt.grid(True)

        # Save the plot
        output_file = os.path.join(output_path, f"{dataset_name}_2D_plot_with_galaxy_overlays.png")
        plt.savefig(output_file, dpi=600)
        print(f"Plot saved at {output_file}")

        # Show the plot
        plt.show()
    except Exception as e:
        print(f"Error reading or plotting dataset: {e}")


# Parameters
# NOTE(review): the grid file is for "grp_15_halo_15", while group_center / r_vir are the
# values labelled "Group 7 parameters" in an earlier cell and target_index 656688 lies in
# group 7 of the combined spectra file (7*90000 + 26688). Confirm these refer to the same halo.
hdf5_file_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/grid_data/grp_15_halo_15_snapshot_91.hdf5"
tng_data_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/revised_final/spectra_TNG50-1_z0.1_n300d2-sample_localized_COS-G130M_OVI_combined.hdf5"
galaxy_catalog_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Synthetic_IGrM_Sightlines/TNG50_fitting_results/galaxy_cats/group_15_galaxy_catalog_converted.txt"
dataset_name = "grid"  # Replace with the actual dataset name if different
target_index = 656688
group_center = np.array([4834.154, 22167.719, 16398.639])  # Group center (x, y, z)
r_vir = 384.19003  # Virial radius
output_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/grid_data"

# Plot with physical coordinate overlays
plot_2d_dataset_with_galaxy_overlays(hdf5_file_path, dataset_name, tng_data_path, galaxy_catalog_path, target_index, group_center, r_vir, output_path)

# %%
def plot_2d_dataset_with_zoom(hdf5_file_path, dataset_name, tng_data_path, target_index, group_center, r_vir, output_path, zoom=True, zoom_radius=1.0):
    """
    Plot a 3000x3000 2D map with group-centre, 1.5 R_vir and sightline overlays.

    The map is drawn over a square of half-width 2 * r_vir centred on the group
    centre's (x, y). When ``zoom`` is True, the axes are limited to
    +/- ``zoom_radius`` around the sightline's (x, y) position.
    Errors are printed rather than raised (best-effort notebook helper).

    Parameters:
        hdf5_file_path (str): Path to the HDF5 file containing the 2D dataset.
        dataset_name (str): Name of the 2D dataset to plot.
        tng_data_path (str): Path to the HDF5 file containing sightline data.
        target_index (int): Index of the sightline to overlay.
        group_center (list or np.array): Coordinates of the group center [x, y, z].
        r_vir (float): Virial radius of the group.
        output_path (str): Directory in which to save the resulting plot.
        zoom (bool): Whether to zoom in around the sightline position.
        zoom_radius (float): Half-width of the zoomed window, in the same units as
            ray_pos and the plot axes (the axes are labelled kpc).
    """
    try:
        # Load the 2D dataset
        with h5py.File(hdf5_file_path, 'r') as hdf:
            if dataset_name not in hdf:
                print(f"Dataset '{dataset_name}' not found in the HDF5 file.")
                return
            data = hdf[dataset_name][()]

        # Only 3000x3000 maps are accepted
        if data.ndim != 2 or data.shape != (3000, 3000):
            print(f"Dataset '{dataset_name}' is not 2D or not 3000x3000. Found {data.shape} dimensions.")
            return

        # Replace NaN values with the colour-scale minimum
        vmin = 11  # Minimum value for visualization
        data = np.nan_to_num(data, nan=vmin)

        # The map spans +/- 2 r_vir around the group centre in x and y
        group_center = np.asarray(group_center)
        group_center_2d = group_center[:2]
        half_width = 2 * r_vir
        extent = [
            -half_width + group_center[0], half_width + group_center[0],
            -half_width + group_center[1], half_width + group_center[1],
        ]

        # Load sightline position
        with h5py.File(tng_data_path, 'r') as tng_file:
            ray_pos = tng_file['ray_pos'][target_index]
        sightline_2d = ray_pos[:2]

        # Create the plot
        plt.figure(figsize=(10, 8))
        plt.imshow(data, origin='lower', extent=extent, cmap='magma', aspect='auto', vmin=vmin)
        plt.colorbar(label='Value')
        plt.title(f"2D Dataset: {dataset_name} with Physical Overlays")
        plt.xlabel("X [kpc]")
        plt.ylabel("Y [kpc]")

        # Overlay group center
        plt.scatter(group_center_2d[0], group_center_2d[1], color='red', label='Group Center', s=100, zorder=3)

        # Overlay 1.5 R_vir circle; add_patch (not add_artist) so its label reaches the legend
        plt.gca().add_patch(plt.Circle(group_center_2d, 1.5 * r_vir, color='green', fill=False, linestyle='--', label='1.5 R_vir', zorder=2))

        # Overlay sightline
        plt.scatter(sightline_2d[0], sightline_2d[1], color='blue', label='Sightline', s=100, zorder=3)

        # Apply zoom if enabled
        if zoom:
            plt.xlim(sightline_2d[0] - zoom_radius, sightline_2d[0] + zoom_radius)
            plt.ylim(sightline_2d[1] - zoom_radius, sightline_2d[1] + zoom_radius)

        # Add legend and grid
        plt.legend()
        plt.grid(True)

        # Save the plot
        output_file = os.path.join(output_path, f"{dataset_name}_2D_plot_with_physical_overlays_zoomed.png" if zoom else f"{dataset_name}_2D_plot_with_physical_overlays.png")
        plt.savefig(output_file, dpi=600)
        print(f"Plot saved at {output_file}")

        # Show the plot
        plt.show()
    except Exception as e:
        print(f"Error reading or plotting dataset: {e}")


# Parameters
# Same map, group and sightline as the galaxy-overlay cell; here the view is zoomed
# to +/- 25 (plot-axis units) around the sightline.
hdf5_file_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/grid_data/grp_15_halo_15_snapshot_91.hdf5"
tng_data_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/revised_final/spectra_TNG50-1_z0.1_n300d2-sample_localized_COS-G130M_OVI_combined.hdf5"
dataset_name = "grid"  # Replace with the actual dataset name if different
target_index =  656688
group_center = np.array([4834.154, 22167.719, 16398.639])  # Group center (x, y, z)
r_vir = 384.19003  # Virial radius
output_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/grid_data"

# Plot with zoom around sightline
plot_2d_dataset_with_zoom(hdf5_file_path, dataset_name, tng_data_path, target_index, group_center, r_vir, output_path, zoom=True, zoom_radius=25.0)

# %% [markdown]
# #### Plot the Raw Spectra
# - $$\Delta \lambda = 0.00997 \, \text{\AA}$$
# 

# %%
# Inspect and plot the sightline stored in `result` (nothing is saved to disk here
# except the velocity-space figure at the end).

# Constants
# NOTE(review): 1031.927 Å here vs 1031.926 Å in the single-group fitting cell;
# the ~0.3 km/s offset is small but the two should agree.
rest_wavelength = 1031.927  # OVI rest wavelength in Å
z = 0.09940180263022191                 # Redshift
#z = 0.137807

#z = 0.13
print("Data fetched successfully. Example:", result)

# Calculate the noise-free flux exp(-tau_OVI_1031) and plot it against observed wavelength
flux_from_tau = np.exp(-result['tau_OVI_1031'])
plt.figure(figsize=(10, 6))
plt.plot(result['wave'], flux_from_tau, label='Flux (from Tau)')
plt.xlabel('Wavelength')
plt.ylabel('Flux (from Tau)')
plt.title('Flux (from Tau) vs. Wavelength')
plt.legend()
plt.grid()
# plt.xlim(1170,1190)
#plt.xlim(1170,1190)
plt.show()


# plot in the velocity space
# Convert to velocity space: v = c * (lambda - lambda_0 (1 + z)) / (lambda_0 (1 + z)),
# so v = 0 at the redshifted OVI 1031 line.
velocity_orig = c_km_s * (result['wave'] - rest_wavelength * (1 + z)) / (rest_wavelength * (1 + z))
plt.figure(figsize=(10, 6))
plt.plot(velocity_orig, flux_from_tau, label='Flux (from Tau)')
plt.xlabel('Velocity (km/s)')
plt.ylabel('Flux (from Tau)')
plt.title('Flux (from Tau) vs. Velocity')
plt.legend()
plt.grid()
plt.ylim(0, 1.1)
plt.xlim(-800,3600)
plt.savefig('Flux_from_Tau_Velocity_spectra_3.png', dpi=300)
plt.show()



# %%
# import numpy as np
# import h5py
# import matplotlib.pyplot as plt
# from scipy.constants import c as c_km_s

# c_km_s = c_km_s / 1e3  # Speed of light in km/s
# rest_wavelength = 1031.927  # OVI rest wavelength in Å
# z = 0.139706  # Redshift

# # Open HDF5 file
# with h5py.File(simba_data_path, 'r') as f:
#     # Fetch equivalent width data and indices
#     data_indices = range(0, 90000)
#     eq_widths = []
#     eq_width_indices = []
#     chunk_size = 10000
    
#     for start in range(data_indices.start, data_indices.stop, chunk_size):
#         end = min(start + chunk_size, data_indices.stop)
#         chunk = f['EW_OVI_1031'][start:end]
#         eq_widths.extend(chunk)
#         eq_width_indices.extend(range(start, end))

#     # Convert to numpy arrays and sort
#     eq_widths = np.array(eq_widths)
#     eq_width_indices = np.array(eq_width_indices)
#     sorted_indices = np.argsort(eq_widths)[::-1]  # Sort in descending order
    
#     # Get top 50 indices
#     top_50_indices = eq_width_indices[sorted_indices[:540]]
    
#     # Fetch and average spectra
#     flux_spectra = []
#     wave_array = f['wave'][:]  # Same for all spectra
    
#     for idx in top_50_indices:
#         flux_spectra.append(f['flux'][idx, :])
    
#     flux_spectra = np.array(flux_spectra)
#     avg_flux_spectrum = np.mean(flux_spectra, axis=0)
    
#     # Convert to velocity space
#     velocity = c_km_s * (wave_array - rest_wavelength * (1 + z)) / (rest_wavelength * (1 + z))
    
#     # Plot averaged spectrum in wavelength space
#     plt.figure(figsize=(10, 6))
#     plt.plot(wave_array, avg_flux_spectrum, label='Averaged Spectrum')
#     plt.xlabel('Wavelength (Å)')
#     plt.ylabel('Flux')
#     plt.title('Averaged Spectrum in Wavelength Space')
#     plt.legend()
#     plt.grid()
#     plt.xlim(1170, 1190)
#     #plt.savefig('Averaged_Spectrum_Wavelength.png', dpi=300)
#     plt.show()
    
#     # Plot averaged spectrum in velocity space
#     plt.figure(figsize=(10, 6))
#     plt.plot(velocity, avg_flux_spectrum, label='Averaged Spectrum')
#     plt.xlabel('Velocity (km/s)')
#     plt.ylabel('Flux')
#     plt.title('Averaged Spectrum in Velocity Space')
#     plt.legend()
#     plt.grid()
#     plt.ylim(0, 1.1)
#     plt.xlim(-800, 3600)
#     #plt.savefig('Averaged_Spectrum_Velocity.png', dpi=300)
#     plt.show()

# %%
# # Existing redshift and additional velocity
# z = 0.137807
# v = 500  # km/s

# # Speed of light in km/s
# c = c_km_s

# # Calculate the new redshift
# z_new = (1 + z) * np.sqrt((1 + v / c) / (1 - v / c)) - 1

# print(f"Original redshift (z): {z}")
# print(f"Additional velocity: {v} km/s")
# print(f"New redshift (z_new): {z_new:.6f}")

# %% [markdown]
# #### First: Applying the LSF

# %%
def apply_lsf_to_spectrum(filename,flux):
    """Convolve ``flux`` with a line-spread-function kernel read from a text file.

    The kernel is the second whitespace-separated column (column index 1) of
    ``filename``; its sum is printed as a sanity check. The convolution is done
    with ``astropy.convolution.convolve`` using its defaults (per the astropy
    docs: kernel normalised to unit sum, zero-filled boundary). Negative values
    in the result are clipped to zero.

    Parameters:
        filename (str): Path to the LSF table (e.g. 'COS_G130M_1150.txt').
        flux (array_like): Flux sampled on the same pixel grid the kernel assumes.

    Returns:
        np.ndarray: The convolved, non-negative flux.
    """
    import pandas as pd
    from astropy.convolution import convolve

    # sep=r'\s+' is the documented replacement for `delim_whitespace=True`,
    # which is deprecated since pandas 2.2.
    lsf_data = pd.read_csv(filename, sep=r'\s+', header=None)
    lsf_kernel = lsf_data[1].values

    print(f"SUM of LSF Kernel: {np.sum(lsf_kernel)}")

    # NOTE(review): with a zero-filled boundary, pixels within half a kernel width
    # of the array edges are pulled towards 0; consider boundary='extend' if the
    # spectrum edges matter.
    flux_lsf = convolve(flux, lsf_kernel)

    np.clip(flux_lsf, 0, np.inf, out=flux_lsf)  # Clip negative values to zero

    return flux_lsf

# %%
# apply LSF to the spectrum and plot 
# NOTE(review): the input file name contains "COS-G130M", so result['flux'] may
# already include the COS LSF; convolving it again would apply the LSF twice. Confirm.
flux_lsf = apply_lsf_to_spectrum('COS_G130M_1150.txt',result['flux'])
#binned_flux_lsf = apply_lsf_to_spectrum('avg_COS.txt',binned_flux)

# Plot the (unbinned) LSF-convolved flux against the stored flux. Binning happens in a
# later cell; the 'Binned' in the title below is a leftover.
plt.figure(figsize=(10, 6))
plt.plot(result['wave'], flux_lsf, label='Flux with LSF')
plt.plot(result['wave'], result['flux'], label='Original Flux')
plt.xlabel('Wavelength')
plt.ylabel('Flux (from Tau)')
plt.title('Binned Flux with LSF vs. Wavelength')
plt.legend()
plt.grid()
# plt.xlim(1120,1160)

# Zoom to the observed-frame window around OVI 1031/1037 (1030-1035 Å rest).
plt.xlim(1030*(1+z),1035*(1+z))
plt.show()



# %%
plt.figure(figsize=(10, 6))

# Plot the original flux as a step plot
# NOTE(review): "Original" is exp(-tau_OVI_1031), whereas flux_lsf was computed from
# result['flux']; if those arrays differ (e.g. other lines in 'flux'), this is not a
# like-for-like comparison.
plt.step(velocity_orig, flux_from_tau, label='Original', where='mid', linewidth=2, alpha=1)

# Plot the LSF-applied flux as a step plot
plt.step(velocity_orig, flux_lsf, label='LSF Applied', where='mid', linewidth=2, alpha=1)

# Add labels and title
plt.xlabel('Velocity (km/s)', fontsize=14)
plt.ylabel('Flux (from Tau)', fontsize=14)
#plt.title('Flux (from Tau) vs. Velocity', fontsize=16)

# Add legend and grid
plt.legend(fontsize=12, loc='best')
plt.grid(True, linestyle='--', alpha=0.7)

# Set x-axis limits
plt.xlim(-1000, 1000)
#plt.savefig('Original_vs_LSF.png', dpi=300)
#plt.ylim(0.5,1.5)
# Show the plot
plt.show()

# %% [markdown]
# #### Binning with N pix

# %%
bin_pixels = 3  # Number of native pixels averaged into each bin


def bin_data(wavelength, flux, bin_pixels):
    """Average ``wavelength`` and ``flux`` over consecutive groups of ``bin_pixels`` pixels.

    Trailing pixels that do not fill a complete bin are dropped. ``flux`` is
    truncated to the same number of pixels as ``wavelength`` uses.

    Parameters:
        wavelength (array_like): 1D wavelength array.
        flux (array_like): 1D flux array, at least as long as ``wavelength``.
        bin_pixels (int): Number of consecutive pixels per bin (>= 1).

    Returns:
        tuple[np.ndarray, np.ndarray]: (binned_wavelength, binned_flux), each of
        length ``len(wavelength) // bin_pixels``.

    Raises:
        ValueError: If ``bin_pixels`` is smaller than 1.
    """
    if bin_pixels < 1:
        raise ValueError(f"bin_pixels must be a positive integer, got {bin_pixels}")

    # Accept lists as well as arrays (lists have no .reshape)
    wavelength = np.asarray(wavelength)
    flux = np.asarray(flux)

    num_bins = len(wavelength) // bin_pixels
    usable = num_bins * bin_pixels  # pixels that fill complete bins

    # One row per bin; average across each row
    binned_flux = flux[:usable].reshape(num_bins, bin_pixels).mean(axis=1)
    binned_wavelength = wavelength[:usable].reshape(num_bins, bin_pixels).mean(axis=1)

    return binned_wavelength, binned_flux

# Bin the flux and wavelength data
# (bins the LSF-convolved flux, not exp(-tau); the "(from Tau)" labels below are leftovers)
binned_wavelength, binned_flux = bin_data(result['wave'], flux_lsf, bin_pixels)

# Plot the binned flux
plt.figure(figsize=(10, 6))
plt.plot(binned_wavelength, binned_flux, label='Binned Flux (from Tau)')
plt.xlabel('Wavelength')
plt.ylabel('Flux (from Tau)')
plt.title('Binned Flux (from Tau) vs. Wavelength')
plt.legend()
plt.grid()
plt.xlim(1120,1160)
plt.show()

# plot in the velocity space
# Same velocity definition as velocity_orig, evaluated at the bin-centre wavelengths
velocity = c_km_s * (binned_wavelength - rest_wavelength * (1 + z)) / (rest_wavelength * (1 + z))
plt.figure(figsize=(10, 6))
plt.plot(velocity, binned_flux, label='Binned Flux (from Tau)')
plt.xlabel('Velocity (km/s)')
plt.ylabel('Flux (from Tau)')
plt.title('Binned Flux (from Tau) vs. Velocity')
plt.legend()
plt.grid()
plt.xlim(-800,800)
plt.show()


# # overlay unbinned and binned flux in the velocity space
# plt.figure(figsize=(10, 6))
# plt.plot(velocity, binned_flux, label='Binned Flux (LSF applied)')
# plt.plot(velocity_orig, flux_from_tau, label='Flux')
# plt.xlabel('Velocity (km/s)')
# plt.ylabel('Flux (from Tau)')
# plt.title('Binned Flux vs. Flux')
# plt.legend()
# plt.grid()
# plt.xlim(-800,800)
# plt.show()

# Compare LSF-only and LSF + binning on the same velocity axis
plt.figure(figsize=(12, 7))

# Plot binned flux as a step plot
plt.step(velocity, binned_flux, label='LSF + Binning', color='blue', where='mid', linewidth=1.5, alpha=0.8)

# Plot unbinned flux as a step plot
plt.step(velocity_orig, flux_lsf, label='LSF', color='red', where='mid', linewidth=1.5, alpha=1)

# Add labels and title
plt.xlabel('Velocity (km/s)', fontsize=16, labelpad=10)
plt.ylabel('Flux (from Tau)', fontsize=16, labelpad=10)
#plt.title('Comparison of Binned and Unbinned Flux in Velocity Space', fontsize=18, pad=15)

# Customize legend
plt.legend(fontsize=14, loc='best', frameon=True, shadow=True)

# Add a grid with light style
plt.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

# Set x and y limits
# y-range pads the global min/max of both curves by 5% (whole spectrum, not just the x window)
plt.xlim(-1000, 1000)  # Corrected xlim to cover the proper range
plt.ylim(min(min(binned_flux), min(flux_lsf)) * 0.95, max(max(binned_flux), max(flux_lsf)) * 1.05)

# Enhance tick marks
plt.tick_params(axis='both', which='major', labelsize=12, direction='in', length=6, width=1.5)

# Save the plot as a high-resolution image
#plt.savefig('LSF_vs_LSF_binned_spectra_1.png', dpi=300, bbox_inches='tight')

# Show the plot
plt.show()

# %% [markdown]
# #### Applying Noise (SNR = 10 per resolution element)

# %%
# Target signal-to-noise per resolution element, and the number of native
# pixels assumed to make up one resolution element.
SNR_per_res = 10
resolution_pix = 6

# Convert to SNR per binned pixel: SNR_bin = SNR_res * sqrt(bin_pixels / resolution_pix),
# i.e. SNR is taken to scale as sqrt(number of pixels averaged) -- this
# presumes independent per-pixel noise.
SNR_per_bin = SNR_per_res/(np.sqrt(resolution_pix/bin_pixels))
print(f"SNR per bin: {SNR_per_bin}")

# %%
# make a function to add gaussian noise to the flux with the specified SNR

def add_gaussian_noise(flux, SNR, seed=42):
    """Return ``flux`` with additive Gaussian noise at signal-to-noise ``SNR``.

    The noise standard deviation is ``mean(flux) / SNR`` and is identical for
    every element. Draws come from a private ``np.random.RandomState(seed)``.
    This reproduces exactly the values that ``np.random.seed(seed)`` followed by
    ``np.random.normal`` gives, but does not reset NumPy's global random state
    as a side effect.

    Parameters
    ----------
    flux : array_like
        1-D flux array to perturb.
    SNR : float
        Target signal-to-noise ratio per element of ``flux``.
    seed : int, optional
        Seed for the private generator. Defaults to 42, the value previously
        hard-coded.

    Returns
    -------
    numpy.ndarray
        ``flux + noise``.
    """
    rng = np.random.RandomState(seed)
    noise_std = np.mean(flux) / SNR

    # Generate Gaussian noise (one independent draw per element)
    noise = rng.normal(0, noise_std, len(flux))

    return flux + noise



# %%
# Add Gaussian noise to the binned flux
flux_noisy = add_gaussian_noise(binned_flux, SNR_per_bin)

# Plot the noisy flux
plt.figure(figsize=(10, 6))
# plt.plot(binned_wavelength, flux_noisy, label='Noisy Flux')
plt.step(binned_wavelength, flux_noisy, label='Noisy Flux', where='mid', linewidth=1.5, alpha=1)
plt.xlabel('Wavelength')
plt.ylabel('Flux (from Tau)')
plt.title('Noisy Flux vs. Wavelength')
plt.legend()
plt.grid()
plt.xlim(1030*(1+z),1035*(1+z))
plt.show()

velocity = c_km_s * (binned_wavelength - rest_wavelength * (1 + z)) / (rest_wavelength * (1 + z))
print(velocity.shape)
# Plot the noisy flux in velocity space
plt.figure(figsize=(10, 6))
# plt.plot(velocity, flux_noisy, label='Noisy Flux')
plt.step(velocity, flux_noisy, label='Noisy Flux', where='mid', linewidth=1.5, alpha=1)
plt.xlabel('Velocity (km/s)')
plt.ylabel('Flux (from Tau)')
plt.title('Noisy Flux vs. Velocity')
plt.legend()
plt.grid()
# plt.ylim(0.5,1.5)
plt.xlim(-1000,1000)
plt.show()



# %%
# make the noise array with n sigma noise corresponding to the given SNR
def make_noise_array(flux, SNR, constant=True):
    """Build a per-element error array to accompany ``flux``.

    Parameters
    ----------
    flux : array_like
        Flux array. Only its shape/length is used, plus its mean when
        ``constant`` is false.
    SNR : float
        Signal-to-noise ratio per element.
    constant : bool, optional
        If true (default), every element equals ``1 / SNR``, the 1-sigma
        error for a flux normalized to a continuum of 1. If false, a random
        Gaussian draw with standard deviation ``mean(flux) / SNR`` is returned.

    Returns
    -------
    numpy.ndarray
        One entry per element of ``flux``.
    """
    if constant:
        return 1/SNR * np.ones_like(flux)

    # NOTE(review): this branch returns a signed random realization (entries
    # can be negative) drawn from NumPy's global RNG, not a sigma array.
    # Confirm this is intended before passing it to a fit as an error array.
    noise_std = np.mean(flux) / SNR
    return np.random.normal(0, noise_std, len(flux))

# Generate the noise array
# constant=True: every pixel gets the same 1-sigma error, 1 / SNR_per_bin.
noise = make_noise_array(flux_noisy, SNR_per_bin,constant=True)
print(noise.shape)


# %%
# Notebook display of the error array.
noise

# %%
# plot both the flux and the noise array together like this 
# Convert to velocity space
# Velocity offset from the redshifted line centre rest_wavelength * (1 + z).
velocity = c_km_s * (binned_wavelength - rest_wavelength * (1 + z)) / (rest_wavelength * (1 + z))
print(f"Velocity: {velocity}")

fig, ax = plt.subplots(figsize=(10, 6))

# Create a stepped plot using ax.step
ax.step(velocity, flux_noisy, where='mid', label='Flux', color='blue', linewidth=1.5)

# Add the shaded error region
# Band spans flux_noisy +/- noise (the 1-sigma error array built above).
ax.fill_between(velocity, flux_noisy - noise , flux_noisy + noise, color='green', alpha=0.3, label='Error')

# Add vertical and horizontal lines for reference
ax.axvline(x=0, color='red', linestyle='--', label=f'OVI {rest_wavelength} Å (0 km/s)', linewidth=2)
ax.axhline(y=1, color='gray', linestyle='--', label='Continuum')

# # Set x-axis formatter to display floating-point numbers
# formatter = FuncFormatter(lambda x, _: f'{x:.0f}')
# ax.xaxis.set_major_formatter(formatter)

# Set plot labels and title
ax.set_xlabel('Velocity (km/s)', fontsize=14)
ax.set_ylabel('Normalized Flux', fontsize=14)
ax.set_xlim(-800, 800)
ax.set_ylim(0.0, 1.35)

# Customize ticks and grid
ax.tick_params(axis='both', which='major', labelsize=12)
ax.grid(True, linestyle='--', alpha=0.7)

# Add legend
ax.legend(fontsize=12)

# Save and show the plot
#plt.savefig('OVI_velocity_space_zoom_shaded_z_013748.png', dpi=300)
plt.show()

# %%
# Empty table of Voigt-fit results; rows are appended below with vstack.
# Column names and dtypes define the schema those rows must match.
results_table = Table(
    names=(
        'Sightline', 'Species', 'EW(mA)', 'dEW(mA)', 'N', 'dN', 'b', 'db', 'v', 
        'dv', 'l', 'dl', 'UpLim', 'Sat', 'Chisq'
    ),
    dtype=['i4', 'S10', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'bool', 'bool', 'f8']
)

# %%


# Voigt-profile fit of the noisy, binned OVI 1031 spectrum (sightline id 1).
pg_ion = 'OVI1031'
wave_binned = binned_wavelength
wave_binned_rest = wave_binned/(1 + z)  # observed frame -> rest frame
vel_binned = wave_to_vel(wave_binned_rest, float(pg.analysis.absorption_spectra.lines[pg_ion]['l'].split()[0]))
flux_binned = flux_noisy
error_binned = noise

min_region_width = 3  # pixels
N_sigma = 3  # 3-sigma detection limit
logN_bounds = [13.5, 18]
b_bounds = [6, 100]
chisq_lim = 1

# Guard only the fit itself; the setup above is not what is expected to fail.
try:
    saturation_flag, output_table,fit,regions = fit_vp_pipeline_new( 1, pg_ion, wave_binned_rest, vel_binned, flux_binned, error_binned, z, logN_bounds, b_bounds, min_region_width, N_sigma,chisq_lim)
except RuntimeError:
    # The fit failed. Append a placeholder row (NaN measurements, both flags
    # False) so the sightline is still represented in ``results_table``.
    # Column names and dtypes are taken from ``results_table`` so vstack works
    # (the Species column is 'S10' and Sat is bool, so neither can hold NaN).
    # NOTE(review): ``fit``/``output_table`` stay undefined on this path, so
    # later cells that read them will raise NameError.
    empty_row = Table(
        rows=[(1, pg_ion) + (np.nan,) * 10 + (False, False, np.nan)],
        names=results_table.colnames,
        dtype=[results_table[name].dtype for name in results_table.colnames],
    )
    results_table = vstack((results_table, empty_row))
else:
    # NOTE(review): assumes output_table's columns match results_table's schema.
    results_table = vstack((results_table, output_table))



# %%
# Log column-density limit for an EW uncertainty of 13.33 mÅ.
# NOTE(review): EW_to_N (its visible definition is further down the file)
# already multiplies ew_err by 3, so the extra *3 here yields a 9x (not 3-sigma)
# limit -- confirm intent. In top-to-bottom execution this call also precedes
# that definition.
EW_to_N('OVI1031', 13.328228594498576*3)

# %%
# Re-create an empty results table (same schema as above) and append the
# rows from the most recent successful fit.
# NOTE(review): raises NameError if the fit cell took its RuntimeError path,
# because output_table is only bound on success.
results_table = Table(
    names=(
        'Sightline', 'Species', 'EW(mA)', 'dEW(mA)', 'N', 'dN', 'b', 'db', 'v', 
        'dv', 'l', 'dl', 'UpLim', 'Sat', 'Chisq'
    ),
    dtype=['i4', 'S10', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'bool', 'bool', 'f8']
)


results_table = vstack((results_table, output_table))

# %%
# Notebook display of the observed-frame grid passed to the fit (an alias of binned_wavelength).
wave_binned

# %%
binned_wavelength

# %%
# Convert observed wavelength to rest frame
# wave_binned_rest was already divided by (1 + z) in the fit cell.
wavelength_rest = wave_binned_rest 

# Select the range in rest wavelength space
# Boolean mask for the rest-frame 1028-1035 Å window around OVI 1031.
zoom_mask = (wavelength_rest >= 1028) & (wavelength_rest <= 1035)

# count true values in the zoom mask
print(f"Number of True values in the zoom mask: {np.sum(zoom_mask)}")

# %%
# Extract zoomed data in rest wavelength space
wavelength_zoom_rest = wavelength_rest[zoom_mask]
flux_zoom = flux_noisy[zoom_mask]
flux_err_zoom = noise[zoom_mask]
# 1-sigma envelope used for the shaded error band.
flux_upper = flux_zoom + flux_err_zoom
flux_lower = flux_zoom - flux_err_zoom

# %%

# Hard-coded fit parameters from an earlier run, kept for reference.
# NOTE(review): this dict is immediately overwritten by ``fitting_data = fit``
# below, so it has no effect on the result.
fitting_data = {
    'N': [17.992483830528826, 13.924630006215876, 14.322969764691424, 13.504412498282178],  # Column densities
    'b': [23.354884550850837, 35.102423554192654, 38.87781593059373, 34.809892803010214],  # Doppler parameters
    'l': [1032.9528822036423, 1032.6238137252512, 1033.2457246082984, 1033.2870964343672]   # Centroid wavelengths
}

# fitting_data = {
#     'N': [13.627762728567626, 13.83463650854652, 13.907586508371804, 13.506077943128467, 13.710643440946845, 13.680620097917311],  # Column densities
#     'b': [34.62001630649476, 99.98341163335535, 21.627720825696773, 8.051624910628453, 21.494323628859412, 29.08035308392944],     # Doppler parameters
#     'l': [1032.4140559610435, 1032.879437552414, 1033.0962234038689, 1032.8915702362729, 1033.2501701901726, 1033.6654472160149]  # Centroid wavelengths
# }

# fitting_data = {
#     'N': [18.013234609966627, 14.82061101872063],  # Column densities
#     'b': [17.17429070294082, 77.42142875692223],   # Doppler parameters
#     'l': [1031.9047774831201, 1031.8703202230272]  # Centroid wavelengths
# }

# # Define the fitting data
# fitting_data = {
#     "Sightline": [35710, 35710, 35710, 35710, 35710],
#     "Species": ["OVI1031"] * 5,
#     "EW(mA)": [470.9112527570984, 35.73384219627579, 35.39387096372975, 35.634952771976636, 268.3832591343839],
#     "dEW(mA)": [float('nan')] * 5,
#     "N": [18.017747337421095, 13.50709737036714, 13.502423360164864, 13.505772637400014, 14.538282149599882],
#     "dN": [0.6202466336088821, 1.000000055410929, 1.0000001346476475, 1.0000001337853366, 0.17091492698129115],
#     "b": [21.120247025492837, 18.98624152919868, 18.982636561271896, 18.98077232285682, 45.46853940331213],
#     "db": [2.590321983068815, 1.000000112539643, 1.0000000768941977, 1.000000083456931, 16.099769404879623],
#     "v": [-216.6258488018081, -205.77610784515178, -206.820507882346, -211.88233321068637, -36.17721606646197],
#     "dv": [2.0939466455272537, 145.26751694740992, 145.26353900131502, 145.27199627587964, 5.742119247646899],
#     "l": [1031.166354781371, 1031.2037005437185, 1031.2001056269605, 1031.182682379114, 1031.78747484138],
#     "dl": [0.014415096932742682, 1.0000471322613318, 1.0000197473807777, 1.0000779687461026, 0.03952975866451903],
#     "UpLim": [False] * 5,
#     "Sat": [True] * 5,
#     "Chisq": [2.046075413031284, 2.046075413031284, 2.046075413031284, 2.046075413031284, 0.7835714179362755],
# }
# Use the parameters returned by the fit cell instead.
fitting_data = fit

# Generate the Voigt profile for the fitted spectrum
line_data = pg.analysis.absorption_spectra.lines['OVI1031']

# Generate the params array dynamically
params = generate_params(fitting_data)

# Voigt profile computation remains the same
# (line_data is re-fetched here; it is identical to the assignment above.)
line_data = pg.analysis.absorption_spectra.lines['OVI1031']
wave_subset = wavelength_zoom_rest  # Use the zoomed rest wavelength range
total_tau = pg.analysis.vpfit.model_tau(line_data, params, wave_subset, mode='Voigt')
# Model transmitted flux from the summed optical depth.
model_flux = np.exp(-total_tau)

# %%
# Notebook display of the fit parameters used to build the model.
fitting_data

# %%
# Rest of the code for plotting, FWHM calculation, and annotations remains unchanged.
# Velocity relative to the OVI rest wavelength, computed from the rest-frame
# grid. ``wave_binned`` is the observed-frame grid (wave_binned_rest * (1 + z)),
# so using it here would fold the cosmological redshift into every velocity.
velocity = c_km_s * (wave_binned_rest - rest_wavelength) / (rest_wavelength)
# Restrict to the rest-frame 1028-1035 Å window selected by ``zoom_mask``.
velocity_zoom = velocity[zoom_mask]



# Calculate FWHM for each feature
# NOTE(review): assumes generate_params lays ``params`` out as repeated
# (N, b, lambda) triples -- confirm against its definition.
centroids = params[2::3]  # Extract every 3rd value starting from index 2
doppler_params = params[1::3]  # Extract every 3rd value starting from index 1
# Gaussian FWHM 2*sqrt(ln 2)*b (km/s), converted to Å at each centroid via dλ = λ dv / c.
fwhms = [(2 * np.sqrt(np.log(2)) * b / c_km_s) * l for b, l in zip(doppler_params, centroids)]

# Vertical spacing between successive component annotations. The FWHM arrow
# and its b-label share it, so each label stays next to its own arrow.
label_step = 0.1

# Plotting
fig, ax = plt.subplots(figsize=(12, 6))

# Original spectrum
ax.plot(wavelength_zoom_rest, flux_zoom, label='Original Spectrum', color='blue')
ax.fill_between(wavelength_zoom_rest, flux_lower, flux_upper, color='green', alpha=0.3, label='Error')

# Fitted spectrum (Voigt profile)
ax.step(wave_subset, model_flux, label='Fitted Spectrum (Voigt Profile)', color='red', where='mid', linewidth=3)

# Highlight centroids and add FWHM arrows
arrowprops = dict(arrowstyle='<->', color='black', lw=1.5)
for i, (centroid, fwhm, b) in enumerate(zip(centroids, fwhms, doppler_params)):
    arrow_y = 0.75 - i * label_step
    ax.axvline(centroid, color=f'C{i}', linestyle='--', label=f'Centroid {i+1}: {centroid:.2f} Å', linewidth=2)
    ax.annotate('', xy=(centroid - fwhm / 2, arrow_y), 
                xytext=(centroid + fwhm / 2, arrow_y), arrowprops=arrowprops)
    ax.text(centroid, arrow_y + 0.02, f'b = {b:.1f} km/s', ha='center', fontsize=10)

# Highlight reference line for OVI in rest wavelength
ax.axvline(rest_wavelength, color='purple', linestyle='--', label=f'OVI {rest_wavelength:.2f} Å (Rest Frame)', linewidth=2)

# Set labels, limits, and grid
ax.set_xlabel('Rest Wavelength (Å)', fontsize=14)
ax.set_ylabel('Normalized Flux', fontsize=14)
ax.set_xlim(wavelength_zoom_rest.min(), wavelength_zoom_rest.max())
ax.set_ylim(0.0, 1.35)
ax.grid(True, linestyle='--', alpha=0.7)
#ax.legend(fontsize=12)

# Save and show the plot
#plt.savefig('OVI_Voigt_Profile_Fit_0.1_err.png', dpi=300)
plt.show()

# %%
# Notebook display of the velocity grid from the previous cell.
velocity

# %%
# Constants
c_km_s = 299792.458  # Speed of light in km/s

# Calculate velocity from wavelength (rest-frame grid, relative to the OVI rest wavelength)
velocity = c_km_s * (wave_binned_rest - rest_wavelength) / rest_wavelength
velocity_zoom = velocity[zoom_mask]


# Adjust FWHM for velocity space
# np.asarray so the vectorised subtraction also works if ``params`` is a plain list.
centroid_velocities = c_km_s * (np.asarray(centroids) - rest_wavelength) / rest_wavelength
# Gaussian FWHM in velocity: 2*sqrt(ln 2)*b.
fwhms_velocity = [(2 * np.sqrt(np.log(2)) * b) for b in doppler_params]

# Vertical spacing between successive component annotations. The FWHM arrow
# and its b-label share it, so each label stays next to its own arrow.
label_step = 0.1

# Plotting
fig, ax = plt.subplots(figsize=(12, 6))

# Original spectrum as a step plot
ax.step(velocity_zoom, flux_zoom, label='Original Spectrum', color='blue', where='mid', linewidth=2,alpha=0.7)
#ax.fill_between(velocity_zoom, flux_lower, flux_upper, color='green', alpha=0.3, label='Error')

# Fitted spectrum (Voigt profile)
ax.step(velocity[zoom_mask], model_flux, label='Fitted Spectrum (Voigt Profile)', color='red', where='mid', linewidth=3)

# Highlight centroids and add FWHM arrows
arrowprops = dict(arrowstyle='<->', color='black', lw=1.5)
for i, (centroid_vel, fwhm_vel, b) in enumerate(zip(centroid_velocities, fwhms_velocity, doppler_params)):
    arrow_y = 0.75 - i * label_step
    ax.axvline(centroid_vel, color=f'C{i}', linestyle='--', label=f'Centroid {i+1}: {centroid_vel:.1f} km/s', linewidth=2)
    ax.annotate('', xy=(centroid_vel - fwhm_vel / 2, arrow_y), 
                xytext=(centroid_vel + fwhm_vel / 2, arrow_y), arrowprops=arrowprops)
    ax.text(centroid_vel, arrow_y + 0.02, f'b = {b:.1f} km/s', ha='center', fontsize=10)

# Highlight reference line for OVI in velocity space (0 km/s)
ax.axvline(0, color='purple', linestyle='--', label=f'OVI {rest_wavelength:.2f} Å (Rest Frame)', linewidth=2)

# Set labels, limits, and grid
ax.set_xlabel('Velocity (km/s)', fontsize=14)
ax.set_ylabel('Normalized Flux', fontsize=14)
ax.set_xlim(-800, 800)  # Adjust velocity range as needed
ax.set_ylim(0.0, 1.35)
ax.grid(True, linestyle='--', alpha=0.7)
#ax.legend(fontsize=12)

# Save and show the plot
#plt.savefig('VP_fit_80583_group_6_sol_dN>2.png', dpi=300)
plt.show()

# %%
# Setup for the combined "group map + velocity spectrum" figure (version 1).
import h5py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Set font styles
plt.rc('font', family='Times New Roman')
plt.rcParams.update({
    'axes.labelsize': 32,
    'xtick.labelsize': 30,
    'ytick.labelsize': 30,
    'legend.fontsize': 30,
})

c_km_s = 299792.458  # Speed of light in km/s

# Fitting results
# Hard-coded single-component fit output. Presumably 'l'/'dl' are in Å
# (rest frame), 'b'/'db' in km/s, and 'N'/'dN' in log10(cm^-2) -- confirm
# against the fit output.
fitting_results = {
    'region': np.array([0.]),
    'l': np.array([1032.47545845]),
    'dl': np.array([0.0799331]),
    'b': np.array([73.64487674]),
    'db': np.array([0.9999996]),
    'N': np.array([14.53968515]),
    'dN': np.array([0.17620323]),
    'EW': np.array([0.31674129]),
    'chisq': np.array([0.79111285])
}

# Calculate velocity centroids
# Velocity of the fitted centroid relative to 1031.92 Å, and the Gaussian FWHM 2*sqrt(ln 2)*b.
centroid_velocities = c_km_s * (fitting_results['l'] - 1031.92) / 1031.92
fwhms_velocity = [(2 * np.sqrt(np.log(2)) * fitting_results['b'][0])]

# Plotting
fig, axs = plt.subplots(1, 2, figsize=(18, 8), gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.25})

# Left panel: Galaxy Overlay
def plot_galaxy_overlay(ax, hdf5_file_path, dataset_name, tng_data_path, galaxy_catalog_path, target_index, group_center, r_vir,
                        sightline_offset=90000 * 7, half_width=2.0):
    """Draw a group's projected map, the chosen sightline and member galaxies on ``ax``.

    All positions are plotted relative to ``group_center`` in units of ``r_vir``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    hdf5_file_path : str
        HDF5 file holding the 2-D map in ``dataset_name``.
    dataset_name : str
        Name of the map dataset inside ``hdf5_file_path``.
    tng_data_path : str
        Spectra HDF5 file with a ``ray_pos`` dataset. Row ``target_index``
        gives the sightline position; only its first two components are used.
    galaxy_catalog_path : str
        '|'-delimited catalog with ``SubhaloPos_0`` / ``SubhaloPos_1`` columns.
    target_index : int
        Row of ``ray_pos`` to mark.
    group_center : array_like
        Group centre; its first two components are subtracted from positions.
    r_vir : float
        Length used to normalise every plotted position.
    sightline_offset : int, optional
        Subtracted from ``target_index`` for the "Sightline N" label.
        Defaults to 90000*7, the previously hard-coded value.
    half_width : float, optional
        Half-width (R_vir units) of the ``imshow`` extent. Defaults to 2.0,
        the previously hard-coded extent.
    """
    # ``[()]`` reads the whole dataset into memory, so the file can be closed
    # before any plotting happens.
    with h5py.File(hdf5_file_path, 'r') as hdf:
        data = hdf[dataset_name][()]
    # NaN pixels are set to 10, the same value used as ``vmin`` below.
    data = np.nan_to_num(data, nan=10)

    with h5py.File(tng_data_path, 'r') as tng_file:
        ray_pos = tng_file['ray_pos'][target_index]
    sightline_2d = (ray_pos[:2] - group_center[:2]) / r_vir

    galaxy_catalog = pd.read_csv(galaxy_catalog_path, delimiter='|', skipinitialspace=True)
    galaxy_catalog.columns = galaxy_catalog.columns.str.strip()
    galaxy_positions = (galaxy_catalog[['SubhaloPos_0', 'SubhaloPos_1']].to_numpy() - group_center[:2]) / r_vir

    # NOTE(review): assumes the map covers [-half_width, half_width] R_vir on
    # both axes, centred on the group -- confirm against the gridding code.
    extent = [-half_width, half_width, -half_width, half_width]
    ax.imshow(data, origin='lower', extent=extent,
              cmap='magma', aspect='equal', vmin=10)
    ax.scatter(0, 0, color='crimson', s=50, zorder=3)  # Group center
    # Mark the sightline with a cyan star
    ax.scatter(sightline_2d[0], sightline_2d[1], color='cyan', marker='*', s=250, zorder=3)

    # Boxed sightline label in the lower-left, with an arrow to the sightline
    box_props = dict(boxstyle='round,pad=0.3', edgecolor='black', facecolor='snow', alpha=1)
    arrowprops = dict(arrowstyle='->', color='white', lw=1.5)
    ax.annotate(f"Sightline {target_index - sightline_offset}",
                xy=(sightline_2d[0], sightline_2d[1]),
                xytext=(-1.2, -1.65),  # Bottom left corner in R_vir units
                fontsize=20, ha='center', va='center', bbox=box_props,
                arrowprops=arrowprops)

    # Galaxy circles: fixed radius of 50 (catalog position units) before normalisation
    for pos in galaxy_positions:
        ax.add_artist(plt.Circle(pos, 50 / r_vir, color='aquamarine', fill=False, linestyle='-', zorder=2, linewidth=1.5))

    # Circle for 1.5 R_vir
    ax.add_artist(plt.Circle((0, 0), 1.5, color='white', fill=False, linestyle='solid', linewidth=2, zorder=2))

    # Raw strings so "\," reaches mathtext instead of being an invalid escape.
    ax.set_xlabel(r"$X \, [R_{vir}]$", labelpad=10, fontsize=32)
    ax.set_ylabel(r"$Y \, [R_{vir}]$", labelpad=10, fontsize=32)

# Parameters for the galaxy overlay
hdf5_file_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/grid_data/grp_15_halo_15_snapshot_91.hdf5"
tng_data_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/revised_final/spectra_TNG50-1_z0.1_n300d2-sample_localized_COS-G130M_OVI_combined.hdf5"
galaxy_catalog_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Synthetic_IGrM_Sightlines/TNG50_fitting_results/galaxy_cats/group_15_galaxy_catalog_converted.txt"
dataset_name = "grid"
target_index = 656688
group_center = np.array([4834.154, 22167.719, 16398.639])  # Group center
r_vir = 384.19003

# Left panel: Galaxy overlay
plot_galaxy_overlay(axs[0], hdf5_file_path, dataset_name, tng_data_path, galaxy_catalog_path, target_index, group_center, r_vir)

# Right panel: Velocity spectrum
ax = axs[1]
ax.step(velocity_zoom, flux_zoom, label='Original Spectrum', color='blue', where='mid', linewidth=2, alpha=0.7)
ax.step(velocity[zoom_mask], model_flux, label='Fitted Spectrum (Voigt Profile)', color='red', where='mid', linewidth=4)

# Highlight reference line for OVI in velocity space (0 km/s)
ax.axvline(0, color='purple', linestyle='--', linewidth=2)
ax.axvline(centroid_velocities[0], color='green', linestyle='--', linewidth=2)
ax.axhline(1, color='gray', linestyle='--', linewidth=2)

# Add a text label for "OVI 1031" at (500, 0.2)
ax.text(500, 0.2, r"$\mathrm{OVI \, 1031}$", fontsize=36, color='black', ha='center')

# Arrow spanning the Gaussian FWHM (2*sqrt(ln 2)*b) around the centroid
arrowprops = dict(arrowstyle='<->', color='black', lw=2)
b_value_start = centroid_velocities[0] - fwhms_velocity[0] / 2
b_value_end = centroid_velocities[0] + fwhms_velocity[0] / 2
ax.annotate('', xy=(b_value_start, 0.6), xytext=(b_value_end, 0.6), arrowprops=arrowprops)
ax.text(centroid_velocities[0]+220, 0.55, r"$b = $" + f"{fitting_results['b'][0]:.1f} km/s",
        fontsize=18, ha='center', va='bottom', color='black')

# Centroid velocity uncertainty. ``dl`` is the uncertainty on the centroid
# wavelength ``l`` (Å). Since v = c (l - l0) / l0, dv = c * dl / l0, with the
# same l0 = 1031.92 Å used for ``centroid_velocities``.
dv_centroid = c_km_s * fitting_results['dl'][0] / 1031.92

# Fit parameters rendered with mathtext
textstr = (
    f"$v_{{0}}$: {centroid_velocities[0]:.1f} ± {dv_centroid:.2f} km/s\n"
    f"$b$: {fitting_results['b'][0]:.1f} ± {fitting_results['db'][0]:.2f} km/s\n"
    f"$N$: {fitting_results['N'][0]:.2f} ± {fitting_results['dN'][0]:.2f}\n"
    f"$EW$: {fitting_results['EW'][0]:.3f} Å\n"
    # raw f-string so "\chi" is not parsed as an (invalid) escape sequence
    rf"$\chi^2$: {fitting_results['chisq'][0]:.3f}"
)

# Add text to the plot with center alignment
ax.text(
    -450, 0.25, textstr,
    fontsize=18,
    ha='center',
    va='center',
    alpha=0.9,
    color='black',
    backgroundcolor='lightgrey',
    bbox=dict(facecolor='lightgrey', edgecolor='black', boxstyle='round,pad=0.5'),
)
ax.set_xticks(np.arange(-800, 801, 200))
ax.set_yticks([0.0, 0.25, 0.5, 0.75, 1.0])
# Explicit thin space: mathtext ignores plain spaces, so "km s" would render as "kms".
ax.set_xlabel(r"$\mathrm{Velocity \, [km \, s^{-1}]}$",labelpad=10,fontsize=32)
ax.set_ylabel('Normalized Flux')
ax.set_xlim(-800, 800)
ax.set_ylim(0.0, 1.35)

# Adjust layout and save
plt.tight_layout()
#plt.savefig("combined_plot_rvir_velocity_fit.png", dpi=300)
plt.show()

# %%
# Setup for the combined figure, version 2 (inward ticks, thicker frame, 1:2 panel widths).
import h5py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Set font styles
plt.rc('font', family='Times New Roman')
plt.rcParams.update({
    'axes.labelsize': 32,
    'xtick.labelsize': 30,
    'ytick.labelsize': 30,
    'legend.fontsize': 30,
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'axes.linewidth': 2  # Set the width of the box
})

c_km_s = 299792.458  # Speed of light in km/s

# Fitting results
# Same hard-coded single-component fit as version 1. Presumably 'l'/'dl' in Å,
# 'b'/'db' in km/s, and 'N'/'dN' in log10(cm^-2) -- confirm.
fitting_results = {
    'region': np.array([0.]),
    'l': np.array([1032.47545845]),
    'dl': np.array([0.0799331]),
    'b': np.array([73.64487674]),
    'db': np.array([0.9999996]),
    'N': np.array([14.53968515]),
    'dN': np.array([0.17620323]),
    'EW': np.array([0.31674129]),
    'chisq': np.array([0.79111285])
}

# Calculate velocity centroids
# Centroid velocity relative to 1031.92 Å, and the Gaussian FWHM 2*sqrt(ln 2)*b.
centroid_velocities = c_km_s * (fitting_results['l'] - 1031.92) / 1031.92
fwhms_velocity = [(2 * np.sqrt(np.log(2)) * fitting_results['b'][0])]

# Plotting
fig, axs = plt.subplots(1, 2, figsize=(18, 6), gridspec_kw={'width_ratios': [1, 2], 'wspace': 0.25})

# Left panel: Galaxy Overlay
def plot_galaxy_overlay(ax, hdf5_file_path, dataset_name, tng_data_path, galaxy_catalog_path, target_index, group_center, r_vir,
                        sightline_offset=90000 * 7, half_width=2.0):
    """Draw a group's projected map, the chosen sightline and member galaxies on ``ax``.

    All positions are plotted relative to ``group_center`` in units of ``r_vir``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    hdf5_file_path : str
        HDF5 file holding the 2-D map in ``dataset_name``.
    dataset_name : str
        Name of the map dataset inside ``hdf5_file_path``.
    tng_data_path : str
        Spectra HDF5 file with a ``ray_pos`` dataset. Row ``target_index``
        gives the sightline position; only its first two components are used.
    galaxy_catalog_path : str
        '|'-delimited catalog with ``SubhaloPos_0`` / ``SubhaloPos_1`` columns.
    target_index : int
        Row of ``ray_pos`` to mark.
    group_center : array_like
        Group centre; its first two components are subtracted from positions.
    r_vir : float
        Length used to normalise every plotted position.
    sightline_offset : int, optional
        Subtracted from ``target_index`` for the "Sightline N" label.
        Defaults to 90000*7, the previously hard-coded value.
    half_width : float, optional
        Half-width (R_vir units) of the ``imshow`` extent. Defaults to 2.0,
        the previously hard-coded extent.
    """
    # ``[()]`` reads the whole dataset into memory, so the file can be closed early.
    with h5py.File(hdf5_file_path, 'r') as hdf:
        data = hdf[dataset_name][()]
    # NaN pixels are set to 10, the same value used as ``vmin`` below.
    data = np.nan_to_num(data, nan=10)

    with h5py.File(tng_data_path, 'r') as tng_file:
        ray_pos = tng_file['ray_pos'][target_index]
    sightline_2d = (ray_pos[:2] - group_center[:2]) / r_vir

    galaxy_catalog = pd.read_csv(galaxy_catalog_path, delimiter='|', skipinitialspace=True)
    galaxy_catalog.columns = galaxy_catalog.columns.str.strip()
    galaxy_positions = (galaxy_catalog[['SubhaloPos_0', 'SubhaloPos_1']].to_numpy() - group_center[:2]) / r_vir

    # NOTE(review): assumes the map covers [-half_width, half_width] R_vir on
    # both axes, centred on the group -- confirm against the gridding code.
    extent = [-half_width, half_width, -half_width, half_width]
    ax.imshow(data, origin='lower', extent=extent,
              cmap='magma', aspect='equal', vmin=10)
    ax.scatter(0, 0, color='crimson', s=50, zorder=3)  # Group center
    ax.scatter(sightline_2d[0], sightline_2d[1], color='cyan', marker='*', s=250, zorder=3)  # Sightline

    # Boxed sightline label with an arrow to the sightline
    box_props = dict(boxstyle='round,pad=0.3', edgecolor='black', facecolor='snow', alpha=1)
    arrowprops = dict(arrowstyle='->', color='white', lw=1.5)
    ax.annotate(f"Sightline {target_index - sightline_offset}",
                xy=(sightline_2d[0], sightline_2d[1]),
                xytext=(-1.2, -1.65),
                fontsize=20, ha='center', va='center', bbox=box_props,
                arrowprops=arrowprops)

    # Galaxy circles: fixed radius of 50 (catalog position units) before normalisation
    for pos in galaxy_positions:
        ax.add_artist(plt.Circle(pos, 50 / r_vir, color='aquamarine', fill=False, linestyle='-', zorder=2, linewidth=1.5))

    # Circle for 1.5 R_vir
    ax.add_artist(plt.Circle((0, 0), 1.5, color='white', fill=False, linestyle='solid', linewidth=2, zorder=2))

    # Raw strings so "\," reaches mathtext instead of being an invalid escape.
    ax.set_xlabel(r"$X \, [R_{vir}]$", labelpad=10, fontsize=32)
    ax.set_ylabel(r"$Y \, [R_{vir}]$", labelpad=10, fontsize=32)

hdf5_file_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/grid_data/grp_15_halo_15_snapshot_91.hdf5"
tng_data_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/revised_final/spectra_TNG50-1_z0.1_n300d2-sample_localized_COS-G130M_OVI_combined.hdf5"
galaxy_catalog_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Synthetic_IGrM_Sightlines/TNG50_fitting_results/galaxy_cats/group_15_galaxy_catalog_converted.txt"
dataset_name = "grid"
target_index = 656688
group_center = np.array([4834.154, 22167.719, 16398.639])  # Group center
r_vir = 384.19003

# Left panel: Galaxy overlay
plot_galaxy_overlay(axs[0], hdf5_file_path, dataset_name, tng_data_path, galaxy_catalog_path, target_index, group_center, r_vir)

# Right panel: Velocity spectrum
ax = axs[1]
ax.step(velocity_zoom, flux_zoom, label='Original Spectrum', color='blue', where='mid', linewidth=2, alpha=0.7)
ax.step(velocity[zoom_mask], model_flux, label='Fitted Spectrum (Voigt Profile)', color='red', where='mid', linewidth=4)

# Reference lines: OVI rest (0 km/s), fitted centroid, and the continuum
ax.axvline(0, color='purple', linestyle='--', linewidth=2)
ax.axvline(centroid_velocities[0], color='green', linestyle='--', linewidth=2)
ax.axhline(1, color='gray', linestyle='--', linewidth=2)

ax.text(500, 0.2, r"$\mathrm{OVI \, 1031}$", fontsize=36, color='black', ha='center')

# Arrow spanning the Gaussian FWHM (2*sqrt(ln 2)*b) around the centroid
arrowprops = dict(arrowstyle='<->', color='black', lw=2)
b_value_start = centroid_velocities[0] - fwhms_velocity[0] / 2
b_value_end = centroid_velocities[0] + fwhms_velocity[0] / 2
ax.annotate('', xy=(b_value_start, 0.6), xytext=(b_value_end, 0.6), arrowprops=arrowprops)
ax.text(centroid_velocities[0] + 220, 0.55, r"$b = $" + f"{fitting_results['b'][0]:.1f} km/s",
        fontsize=18, ha='center', va='bottom', color='black')

# Centroid velocity uncertainty. ``dl`` is the uncertainty on the centroid
# wavelength ``l`` (Å). Since v = c (l - l0) / l0, dv = c * dl / l0, with the
# same l0 = 1031.92 Å used for ``centroid_velocities``.
dv_centroid = c_km_s * fitting_results['dl'][0] / 1031.92

textstr = (
    f"$v_{{0}}$: {centroid_velocities[0]:.1f} ± {dv_centroid:.2f} km/s\n"
    f"$b$: {fitting_results['b'][0]:.1f} ± {fitting_results['db'][0]:.2f} km/s\n"
    f"$N$: {fitting_results['N'][0]:.2f} ± {fitting_results['dN'][0]:.2f}\n"
    f"$EW$: {fitting_results['EW'][0]:.3f} Å\n"
    # raw f-string so "\chi" is not parsed as an (invalid) escape sequence
    rf"$\chi^2$: {fitting_results['chisq'][0]:.3f}"
)

ax.text(-450, 0.25, textstr, fontsize=18, ha='center', va='center', alpha=0.9, color='black', backgroundcolor='lightgrey')

# Add minor ticks and customize ticks
for ax in axs:
    ax.tick_params(which='both', direction='in', width=2)
    ax.minorticks_on()
    ax.tick_params(which='minor', length=4)
    ax.tick_params(which='major', length=8)

# Set labels, limits, and layout
# Explicit thin space: mathtext ignores plain spaces, so "km s" would render as "kms".
axs[1].set_xlabel(r"$\mathrm{Velocity \, [km \, s^{-1}]}$", labelpad=10, fontsize=32)
axs[1].set_ylabel("Normalized Flux")
axs[1].set_xlim(-800, 800)
axs[1].set_ylim(0.0, 1.35)

plt.tight_layout()
plt.show()

# %%
# Setup for the combined figure, version 3 (the one saved as eg_spectra.pdf below).
import h5py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Set font styles
plt.rc('font', family='Times New Roman')
plt.rcParams.update({
    'axes.labelsize': 32,
    'xtick.labelsize': 30,
    'ytick.labelsize': 30,
    'legend.fontsize': 30,
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'axes.linewidth': 2  # Set the width of the box
})

c_km_s = 299792.458  # Speed of light in km/s

# Fitting results
# Same hard-coded single-component fit as the earlier versions. Presumably
# 'l'/'dl' in Å, 'b'/'db' in km/s, and 'N'/'dN' in log10(cm^-2) -- confirm.
fitting_results = {
    'region': np.array([0.]),
    'l': np.array([1032.47545845]),
    'dl': np.array([0.0799331]),
    'b': np.array([73.64487674]),
    'db': np.array([0.9999996]),
    'N': np.array([14.53968515]),
    'dN': np.array([0.17620323]),
    'EW': np.array([0.31674129]),
    'chisq': np.array([0.79111285])
}

# Calculate velocity centroids
# Centroid velocity relative to 1031.92 Å, and the Gaussian FWHM 2*sqrt(ln 2)*b.
centroid_velocities = c_km_s * (fitting_results['l'] - 1031.92) / 1031.92
fwhms_velocity = [(2 * np.sqrt(np.log(2)) * fitting_results['b'][0])]

# Plotting
fig, axs = plt.subplots(1, 2, figsize=(20, 6), gridspec_kw={'width_ratios': [1, 2], 'wspace': 0.15})

# Left panel: Galaxy Overlay
def plot_galaxy_overlay(ax, hdf5_file_path, dataset_name, tng_data_path, galaxy_catalog_path, target_index, group_center, r_vir,
                        sightline_offset=90000 * 7, half_width=2.0, id_label="ID = 264620"):
    """Draw a group's map, sightline, member galaxies and an inset colorbar on ``ax``.

    All positions are plotted relative to ``group_center`` in units of ``r_vir``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on. The colorbar is attached to ``ax.figure``.
    hdf5_file_path : str
        HDF5 file holding the 2-D map in ``dataset_name``.
    dataset_name : str
        Name of the map dataset inside ``hdf5_file_path``.
    tng_data_path : str
        Spectra HDF5 file with a ``ray_pos`` dataset. Row ``target_index``
        gives the sightline position; only its first two components are used.
    galaxy_catalog_path : str
        '|'-delimited catalog with ``SubhaloPos_0`` / ``SubhaloPos_1`` columns.
    target_index : int
        Row of ``ray_pos`` to mark.
    group_center : array_like
        Group centre; its first two components are subtracted from positions.
    r_vir : float
        Length used to normalise every plotted position.
    sightline_offset : int, optional
        Subtracted from ``target_index`` for the "Sightline N" label.
        Defaults to 90000*7, the previously hard-coded value.
    half_width : float, optional
        Half-width (R_vir units) of the ``imshow`` extent. Defaults to 2.0,
        the previously hard-coded extent.
    id_label : str, optional
        Text drawn in the top-left corner. Defaults to the previously
        hard-coded "ID = 264620".
    """
    # ``[()]`` reads the whole dataset into memory, so the file can be closed early.
    with h5py.File(hdf5_file_path, 'r') as hdf:
        data = hdf[dataset_name][()]
    # NaN pixels are set to 10, the same value used as ``vmin`` below.
    data = np.nan_to_num(data, nan=10)

    with h5py.File(tng_data_path, 'r') as tng_file:
        ray_pos = tng_file['ray_pos'][target_index]
    sightline_2d = (ray_pos[:2] - group_center[:2]) / r_vir

    galaxy_catalog = pd.read_csv(galaxy_catalog_path, delimiter='|', skipinitialspace=True)
    galaxy_catalog.columns = galaxy_catalog.columns.str.strip()
    galaxy_positions = (galaxy_catalog[['SubhaloPos_0', 'SubhaloPos_1']].to_numpy() - group_center[:2]) / r_vir

    # NOTE(review): assumes the map covers [-half_width, half_width] R_vir on
    # both axes, centred on the group -- confirm against the gridding code.
    extent = [-half_width, half_width, -half_width, half_width]
    im = ax.imshow(data, origin='lower', extent=extent,
                   cmap='magma', aspect='equal', vmin=10)
    ax.scatter(0, 0, color='crimson', s=50, zorder=3)  # Group center
    ax.scatter(sightline_2d[0], sightline_2d[1], color='cyan', marker='*', s=250, zorder=3)  # Sightline

    # Boxed sightline label with an arrow to the sightline
    box_props = dict(boxstyle='round,pad=0.3', edgecolor='black', facecolor='snow', alpha=1)
    arrowprops = dict(arrowstyle='->', color='white', lw=1.5)
    ax.annotate(f"Sightline {target_index - sightline_offset}",
                xy=(sightline_2d[0], sightline_2d[1]),
                xytext=(-1.1, -1.75),
                fontsize=20, ha='center', va='top' if False else 'center', bbox=box_props,
                arrowprops=arrowprops)

    # Galaxy circles: fixed radius of 50 (catalog position units) before normalisation
    for pos in galaxy_positions:
        ax.add_artist(plt.Circle(pos, 50 / r_vir, color='aquamarine', fill=False, linestyle='-', zorder=2, linewidth=1.5))

    # Circle for 1.5 R_vir
    ax.add_artist(plt.Circle((0, 0), 1.5, color='white', fill=False, linestyle='solid', linewidth=2, zorder=2))

    # Raw strings so "\," reaches mathtext instead of being an invalid escape.
    ax.set_xlabel(r"$X \, [R_{vir}]$", labelpad=10, fontsize=32)
    ax.set_ylabel(r"$Y \, [R_{vir}]$", labelpad=10, fontsize=32)

    # Identifier in the top-left corner (data coordinates), in white
    ax.text(-1.9, 1.8, id_label, color='white', fontsize=22, ha='left', va='top')

    # Horizontal colorbar inside the panel; bounds are [x0, y0, width, height]
    # in axes-fraction coordinates. Use the axes' own figure rather than a
    # global ``fig`` so the function works on any figure.
    cbar_ax = ax.inset_axes([0.55, 0.07, 0.4, 0.02])
    cbar = ax.figure.colorbar(im, cax=cbar_ax, orientation='horizontal')
    cbar.set_label(r"$\log$[OVI Column Density](cm$^{-2}$)", fontsize=12, color='white', labelpad=3)

    # Place tick markers at the top side
    cbar.ax.tick_params(labelsize=15, length=8, width=2, direction='in', colors='white', top=True, bottom=False)
    cbar.ax.xaxis.set_ticks_position('top')  # Move ticks to the top
    cbar.ax.set_xticks([12, 14, 16])  # Set tick positions
    cbar.ax.set_xticklabels(['12', '14', '16'])  # Set tick labels

    cbar.outline.set_edgecolor('snow')  # Recolor (not remove) the colorbar border

# Parameters for the galaxy overlay
hdf5_file_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/grid_data/grp_15_halo_15_snapshot_91.hdf5"
tng_data_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Grad School/MgII_TNG_Project/Primary Project/Trident Codes/TNG_50_z_0.1_groups/revised_final/spectra_TNG50-1_z0.1_n300d2-sample_localized_COS-G130M_OVI_combined.hdf5"
galaxy_catalog_path = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Synthetic_IGrM_Sightlines/TNG50_fitting_results/galaxy_cats/group_15_galaxy_catalog_converted.txt"
dataset_name = "grid"
target_index = 656688
group_center = np.array([4834.154, 22167.719, 16398.639])  # Group center
r_vir = 384.19003

# Left panel: Galaxy overlay
plot_galaxy_overlay(axs[0], hdf5_file_path, dataset_name, tng_data_path, galaxy_catalog_path, target_index, group_center, r_vir)

# Right panel: Velocity spectrum
ax = axs[1]
ax.step(velocity_zoom, flux_zoom, label='Original Spectrum', color='blue', where='mid', linewidth=2, alpha=0.8)
ax.step(velocity[zoom_mask], model_flux, label='Fitted Spectrum (Voigt Profile)', color='red', where='mid', linewidth=5)

# Reference lines: OVI rest (0 km/s), fitted centroid, and the continuum
ax.axvline(0, color='purple', linestyle='--', linewidth=2)
ax.axvline(centroid_velocities[0], color='black', linestyle='solid', linewidth=2)
ax.axhline(1, color='gray', linestyle='--', linewidth=2)

ax.text(480, 0.2, r"$\mathrm{OVI \, 1031}$", fontsize=34, color='black', ha='center',
        bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.2'))

# Arrow spanning the Gaussian FWHM (2*sqrt(ln 2)*b) around the centroid
arrowprops = dict(arrowstyle='<->', color='black', lw=2)
b_value_start = centroid_velocities[0] - fwhms_velocity[0] / 2
b_value_end = centroid_velocities[0] + fwhms_velocity[0] / 2
ax.annotate('', xy=(b_value_start, 0.6), xytext=(b_value_end, 0.6), arrowprops=arrowprops)
ax.text(centroid_velocities[0] + 220, 0.55, r"$b = $" + f"{fitting_results['b'][0]:.1f} km/s",
        fontsize=18, ha='center', va='bottom', color='black')

# Centroid velocity uncertainty. ``dl`` is the uncertainty on the centroid
# wavelength ``l`` (Å). Since v = c (l - l0) / l0, dv = c * dl / l0, with the
# same l0 = 1031.92 Å used for ``centroid_velocities``.
dv_centroid = c_km_s * fitting_results['dl'][0] / 1031.92

textstr = (
    f"$v_{{0}}$: {centroid_velocities[0]:.1f} ± {dv_centroid:.2f} km/s\n"
    f"$b$: {fitting_results['b'][0]:.1f} ± {fitting_results['db'][0]:.2f} km/s\n"
    f"$N$: {fitting_results['N'][0]:.2f} ± {fitting_results['dN'][0]:.2f}\n"
    f"$EW$: {fitting_results['EW'][0]:.3f} Å\n"
    # raw f-string so "\chi" is not parsed as an (invalid) escape sequence
    rf"$\chi^2$: {fitting_results['chisq'][0]:.3f}"
)

ax.text(-450, 0.35, textstr, fontsize=20, ha='center', va='center', alpha=0.9, color='black', backgroundcolor='lightgrey')

# Add minor ticks and customize ticks
for ax in axs:
    ax.tick_params(which='both', direction='in', width=2)
    ax.tick_params(which='major', length=7, width=2, direction='in', top=True, right=True)
    ax.tick_params(which='minor', length=4, width=1.5, direction='in', top=True, right=True)

    ax.tick_params(which='minor', length=4)
    ax.tick_params(which='major', length=7)
    ax.minorticks_on()

# Set labels, limits, and layout
# Explicit thin space: mathtext ignores plain spaces, so "km s" would render as "kms".
axs[1].set_xlabel(r"$\mathrm{Velocity \, [km \, s^{-1}]}$", labelpad=10, fontsize=32)
axs[1].set_ylabel("Normalized Flux")
axs[1].set_xlim(-800, 800)
# set x labels at , -600, -400, -200, 0, 200, 400, 600, 800
axs[1].set_xticks(np.arange(-600, 801, 200))
# reduce the font size of the xtickslabels slightly
axs[1].set_xticklabels(axs[1].get_xticks(), fontsize=28)

axs[1].set_ylim(0.0, 1.35)


output_dir = "/Users/tsingh65/ASU Dropbox/Tanmay Singh/Synthetic_IGrM_Sightlines/TNG50_fitting_results/Plots"
# Save the plot with additional adjustments to avoid clipping
plt.savefig(f"{output_dir}/eg_spectra.pdf", dpi=400, bbox_inches='tight', pad_inches=0.3)
plt.show()

# %%
import matplotlib
matplotlib.rcdefaults()  # Reset Matplotlib configuration to defaults
# NOTE(review): Agg is a non-interactive (file-output) backend, so later
# plt.show() calls will likely not display figures -- confirm this is intended.
matplotlib.use('Agg')  # Use the Agg backend for rendering
import matplotlib.pyplot as plt
plt.rc('font', family='DejaVu Sans')  # DejaVu Sans is UTF-8 compatible

# %%
# Sanity-check cell: draw and save a simple sine plot with default rcParams.
import matplotlib.pyplot as plt
import numpy as np

# Reset to defaults
import matplotlib
matplotlib.rcdefaults()

# Generate sample data
x = np.linspace(0, 10, 100)
y = np.sin(x)

# Plot
plt.plot(x, y, label="sin(x)")
plt.xlabel("x")
plt.ylabel("y")
plt.title("Basic Plot")
plt.legend()

# Save the figure
plt.savefig("basic_plot_1.png")
plt.show()

# %%
def EW_to_N(pg_ion, ew_err, n_sigma=3, oscillator_strength=None, wavelength=None):
    """Convert an equivalent-width uncertainty into an n-sigma log column-density limit.

    Uses the optically thin (linear curve-of-growth) relation of Draine eq. 9.15,
    ``N = 1.13e12 cm^-1 * W / (f * lambda**2)``, with W and lambda in cm.

    Parameters
    ----------
    pg_ion : str
        Line key in ``pg.analysis.absorption_spectra.lines`` (e.g. ``'OVI1031'``).
        It is used to look up f and lambda when they are not passed explicitly.
    ew_err : float
        1-sigma equivalent-width uncertainty in mÅ (1 mÅ = 1e-11 cm).
    n_sigma : float, optional
        Multiplier applied to ``ew_err``. Defaults to 3, the previously
        hard-coded 3-sigma limit.
    oscillator_strength : float, optional
        Override for the line's oscillator strength f.
    wavelength : float, optional
        Override for the line's rest wavelength, in Å.

    Returns
    -------
    float
        log10 of the column-density limit in cm^-2.
    """
    # Only consult pygad's line list for values the caller did not supply.
    if oscillator_strength is None or wavelength is None:
        line = pg.analysis.absorption_spectra.lines[pg_ion]
    if oscillator_strength is None:
        oscillator_strength = line['f']
    if wavelength is None:
        # The wavelength entry is a string whose first token is the value in Å.
        wavelength = line['l'].split()[0]
    f_osc = float(oscillator_strength)
    lam = float(wavelength)

    # mÅ -> cm (1e-11) and Å -> cm (1e-8)
    N = 1.13e12 * (n_sigma * ew_err * 1.0e-11) / f_osc / (lam * 1.0e-8)**2

    return np.log10(N)
# 3-sigma log10 column-density limit for an EW uncertainty of ~13.33 mÅ (OVI 1031).
EW_to_N('OVI1031', 13.328228594498576)


