# In [15]:
import os
import numpy as np
import pylab as py
import matplotlib.pyplot as plt
from spisea import synthetic, evolution, atmospheres, reddening, ifmr
from spisea.imf import imf, multiplicity
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
import csv

# Directories: `iso_dir` is handed to IsochronePhot as its iso_dir argument,
# `output_dir` receives the saved figures.
iso_dir = "isochrones/"
output_dir = "output_diagrams/"
os.makedirs(output_dir, exist_ok=True)  # no error if it already exists

# Which data row of the magnitude CSV (header not counted) to analyse
star_index = 0

# Inputs shared by every isochrone in the grid
dist = 4500  # NOTE(review): presumably parsecs, per the SPISEA docs -- confirm
metallicity = 0
evo_model = evolution.Baraffe15()
atm_func = atmospheres.get_merged_atmosphere
red_law = reddening.RedLawCardelli(3.1)  # NOTE(review): 3.1 is presumably R_V -- confirm

# The same two filters, spelled two ways: `filt_list` is passed to
# IsochronePhot, `filters` holds the matching keys used to index `.points`.
filt_list = ["jwst,F162M", "jwst,F182M"]
filters = ["m_jwst_F162M", "m_jwst_F182M"]

# Age grid: 19 evenly spaced values from 1e6 to 1e7 (steps of 0.5e6), and
# their base-10 logarithms, which is the form given to IsochronePhot.
level_ages = 1e6 * np.linspace(1, 10, 19)
log_age_arr = np.log10(level_ages)

# Load sample magnitudes, skipping the header row.
# Each entry of `sample_mags` is one CSV data row parsed into a list of floats.
sample_mags = []
# newline='' is the mode the csv module documents for files given to csv.reader
with open('./s284-162-182.csv', mode='r', newline='') as csv_handle:
    reader = csv.reader(csv_handle)
    next(reader, None)  # Skip header row; a default avoids StopIteration on an empty file
    for row in reader:
        if not row:
            # csv.reader yields [] for blank lines (e.g. a trailing newline);
            # storing those would create "stars" with no magnitudes at all.
            continue
        sample_mags.append([float(value) for value in row])

# Function to generate isochrone grid
def generate_isochrones(AKs):
    """Build one ``synthetic.IsochronePhot`` per entry of ``log_age_arr``.

    Everything except the extinction comes from module-level settings
    (``dist``, ``metallicity``, ``evo_model``, ``atm_func``, ``red_law``,
    ``filt_list``, ``iso_dir``).

    Parameters
    ----------
    AKs : float
        Extinction value forwarded as the second positional argument of
        ``IsochronePhot``.

    Returns
    -------
    numpy.ndarray
        Array holding the isochrone objects, in the same order as
        ``log_age_arr``.
    """
    # Keyword arguments that are identical for every age in the grid
    shared_kwargs = dict(
        metallicity=metallicity,
        evo_model=evo_model,
        atm_func=atm_func,
        red_law=red_law,
        filters=filt_list,
        iso_dir=iso_dir,
    )

    grid = []
    for log_age in log_age_arr:
        grid.append(synthetic.IsochronePhot(log_age, AKs, dist, **shared_kwargs))
    return np.array(grid)

# Two isochrone grids: one at AKs = 0.7 and one with no extinction
isochrone_AKs_07 = generate_isochrones(0.7)
isochrone_AKs_00 = generate_isochrones(0.0)

# Wavelength assigned to each magnitude column.
# NOTE(review): values look like microns (filter name / 100) -- confirm against
# the units expected by red_law.Cardelli89.
filter_wavelengths = dict(
    m_jwst_F162M=1.62,
    m_jwst_F182M=1.82,
    m_jwst_F200W=2.00,
    m_jwst_F356M=3.56,
    m_jwst_F405N=4.05,
)

# Wavelengths of the filters actually in use, in the order of `filters`
wavelengths = list(map(filter_wavelengths.__getitem__, filters))
AKs_ref = 0.7

# Deredden the selected star: subtract from each observed magnitude the value
# Cardelli89 returns for that filter's wavelength at AKs_ref.
# zip() stops at the shorter input, so any extra CSV columns are ignored.
sample_mags_dereddened = []
for observed_mag, filter_wavelength in zip(sample_mags[star_index], wavelengths):
    extinction = red_law.Cardelli89(filter_wavelength, AKs_ref)
    sample_mags_dereddened.append(observed_mag - extinction)

# Plot function
def plot_diagram(isochrone_grid, ref_mags, title, filename):
    """Save a two-panel figure of an isochrone grid with one reference star.

    The left panel is a colour-magnitude diagram (first filter minus second
    filter against second filter); the right panel is a magnitude-magnitude
    diagram (first filter against second filter). "First" and "second" refer
    to the module-level ``filters`` list. Isochrones are coloured along the
    'coolwarm' colormap according to their position in ``isochrone_grid``.

    Parameters
    ----------
    isochrone_grid : sequence
        Sized, iterable collection of objects whose ``points`` attribute can
        be indexed with the column names in ``filters``.
    ref_mags : sequence
        Magnitudes of the reference star; element 0 is plotted as the first
        filter and element 1 as the second.
    title : str
        Figure-level title.
    filename : str
        File name of the image, created inside the module-level
        ``output_dir``.

    Returns
    -------
    None
        The figure is written to disk and then closed.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
    cmap = plt.get_cmap('coolwarm')

    n_isochrones = len(isochrone_grid)
    for i, instance in enumerate(isochrone_grid):
        # Spread the isochrones over the full colormap range [0, 1]. A grid
        # with a single isochrone has no range to spread over, so pin it to
        # 0.0 instead of dividing by zero.
        fraction = i / (n_isochrones - 1) if n_isochrones > 1 else 0.0
        color = cmap(fraction)

        first_mag = instance.points[filters[0]]
        second_mag = instance.points[filters[1]]

        # CMD
        ax1.plot(first_mag - second_mag, second_mag, color=color)
        # MMD
        ax2.plot(first_mag, second_mag, color=color)

    ax1.set_xlabel('F162M - F182M')
    ax1.set_ylabel('F182M')
    ax1.invert_yaxis()  # magnitudes: brighter (smaller) values at the top
    ax1.grid(True)

    ax2.set_xlabel('F162M')
    ax2.set_ylabel('F182M')
    ax2.invert_xaxis()
    ax2.invert_yaxis()
    ax2.grid(True)

    # Reference star
    ax1.plot(ref_mags[0] - ref_mags[1], ref_mags[1], '*', markersize=10,
             color='gold', label="Reference Star")
    ax2.plot(ref_mags[0], ref_mags[1], '*', markersize=10,
             color='gold', label="Reference Star")

    ax1.legend()
    ax2.legend()
    fig.suptitle(title)

    # Act on this figure explicitly rather than on pyplot's "current" figure,
    # so the right one is saved and released even if others are open.
    fig.savefig(os.path.join(output_dir, filename))
    plt.close(fig)

# Render both figures: the observed star on the AKs = 0.7 grid, and the
# dereddened star on the AKs = 0 grid.
plot_diagram(
    isochrone_grid=isochrone_AKs_07,
    ref_mags=sample_mags[star_index],
    title="Isochrone Grid with AKs = 0.7",
    filename="CMD_MMD_AKs_07.png",
)
plot_diagram(
    isochrone_grid=isochrone_AKs_00,
    ref_mags=sample_mags_dereddened,
    title="Isochrone Grid with AKs = 0 and Dereddened Flux",
    filename="CMD_MMD_AKs_00_Dereddened.png",
)
