In [None]:
# Import necessary packages
import os
import numpy as np
import pylab as py
import matplotlib.pyplot as plt
from spisea import synthetic, evolution, atmospheres, reddening
from spisea.imf import imf, multiplicity
from matplotlib.colors import LogNorm
from matplotlib.cm import ScalarMappable
import csv
from itertools import combinations

# --- Input / output locations ---------------------------------------------
iso_dir = 'isochrones/'                 # where SPISEA isochrone files are written
output_dir = 'output_diagrams/'         # where the age-mass diagrams are saved
os.makedirs(output_dir, exist_ok=True)  # create the plot folder if missing

# --- Which star to analyse --------------------------------------------------
star_index = 0             # 0-based data row of the CSV (header row excluded)
num_top_predictions = 10   # how many top predictions to plot

# --- Fixed isochrone parameters ---------------------------------------------
metallicity = 0
dist = 4500                # distance handed to IsochronePhot (presumably pc — confirm)
AKs = 0.7                  # fixed extinction value; update as needed
evo_model = evolution.Baraffe15()
atm_func = atmospheres.get_merged_atmosphere
red_law = reddening.RedLawCardelli(3.1)

# SPISEA filter identifiers and the matching magnitude column names
# ('jwst,F162M' -> 'm_jwst_F162M'); deriving the second list from the first
# keeps the two in sync.
filt_list = ['jwst,F162M', 'jwst,F182M', 'jwst,F200W', 'jwst,F356M', 'jwst,F405N']
filters = ['m_' + spisea_name.replace(',', '_') for spisea_name in filt_list]

# Age grid: 19 evenly spaced ages from 1 to 10 Myr (in years), plus its log10
level_ages = np.linspace(1, 10, 19) * 1e6
log_age_arr = np.log10(level_ages)

# Load observed magnitudes: one list of floats per CSV data row, header skipped.
# NOTE(review): column k is assumed to correspond to filters[k] — confirm
# against the CSV header.
sample_mags = []
# newline='' is what the csv module expects for files passed to csv.reader
with open('../s284-no-errors.csv', mode='r', newline='') as csv_handle:
    reader = csv.reader(csv_handle)
    next(reader)  # Skip header row
    for row in reader:
        # Blank lines yield empty rows; skipping them keeps star_index aligned
        # with actual data rows.
        if not row:
            continue
        sample_mags.append([float(value) for value in row])

# Chi-square minimization function
def chi_square_reverse_model(iso_grid, sample_mags, filters):
    results = []
    for i, iso in enumerate(iso_grid):
        for star in iso.points:
            chi_square = sum(((sample_mags[k] - star[filters[k]]) ** 2) / star[filters[k]] for k in range(len(sample_mags)))
            results.append([chi_square, star['mass'], 10 ** log_age_arr[i]])
    return results

# Make sure both the isochrone folder and the plot output folder exist
for _dirname in (iso_dir, output_dir):
    os.makedirs(_dirname, exist_ok=True)

# Every unordered pair of SPISEA filters; the fit below is run once per pair
filter_combinations = list(combinations(filt_list, r=2))

# Model points whose chi-square falls below this cut-off are highlighted
chi_square_threshold = 0.1

# Plot the results
fig, ax = plt.subplots(figsize=(10, 8))

# Observed magnitudes of the selected star.
# NOTE(review): CSV column k is assumed to correspond to filters[k] — confirm.
star_mags = sample_mags[star_index]

# Loop through each filter combination and compute chi-square
for idx, (filter_pair, col_idx) in enumerate(
        zip(filter_combinations, combinations(range(len(filters)), 2))):
    # Model column names and observed magnitudes for just this pair, so each
    # observed magnitude is compared with the matching model column (passing
    # the full star row here would misalign magnitudes and filters).
    filter_labels = [filters[k] for k in col_idx]
    pair_mags = [star_mags[k] for k in col_idx]

    # Clear previously generated isochrone files before building this pair's
    # grid (presumably so files made for another filter pair are not reused —
    # confirm against SPISEA's caching behaviour).
    for fname in os.listdir(iso_dir):
        os.remove(os.path.join(iso_dir, fname))

    # Generate the isochrone grid (one isochrone per age) for this filter pair
    instances = [
        synthetic.IsochronePhot(log_age, AKs, dist, metallicity=metallicity,
                                evo_model=evo_model, atm_func=atm_func,
                                red_law=red_law, filters=filter_pair,
                                iso_dir=iso_dir)
        for log_age in log_age_arr
    ]

    # Compute chi-square for each model point in the filter pair
    results = chi_square_reverse_model(instances, pair_mags, filter_labels)
    if not results:
        continue  # nothing to plot; zip(*[]) could not be unpacked
    chi_square_values, masses, ages = (
        np.asarray(column, dtype=float) for column in zip(*results)
    )

    # All model points in black, kept underneath every highlight (zorder=1)
    ax.scatter(ages, masses, c='black', s=50, edgecolor='k', linewidth=0.5,
               zorder=1)

    # Outline the points below the threshold with one distinct colour per
    # filter pair (tab10 provides 10 colours, one for each of the C(5, 2) pairs)
    good = chi_square_values < chi_square_threshold
    if good.any():
        pair_name = ' / '.join(name.split(',')[-1] for name in filter_pair)
        ax.scatter(ages[good], masses[good], color=plt.cm.tab10(idx % 10),
                   s=80, edgecolor='k', linewidth=1.5, zorder=2,
                   label=pair_name)

# Apply logarithmic scale to y-axis
ax.set_yscale('log')

# Add labels, grid, and legend
ax.set_xlabel('Age (years)')
ax.set_ylabel('Mass (M☉)')
ax.set_title(f'Age-Mass Diagram with Chi-Squared Outlines for Star {star_index+1}')
ax.grid(True, which='both', linestyle='--', linewidth=0.5)  # Adjust grid for log scale
handles, _ = ax.get_legend_handles_labels()
if handles:  # avoid an empty-legend warning when nothing passed the cut
    ax.legend(title=f'χ² < {chi_square_threshold}')

# Save the figure with the star's index as part of the filename
fig.savefig(os.path.join(output_dir, f'star-{star_index+1}.png'))
plt.close(fig)

