In [10]:
import numpy as np
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
from sfh import SFH

In [6]:
def rmse(base, restricted=False):
    """Compute per-label and overall RMSE for a run's saved cross-validation output.

    Loads the prediction and ground-truth arrays stored under
    ``OUTPUTS/{base}/`` (the ``*_restricted.npy`` pair when ``restricted`` is
    truthy, the plain pair otherwise), discards every row whose prediction
    contains a NaN, and evaluates the root-mean-square error.

    Parameters
    ----------
    base : str
        Run identifier; used both as the sub-directory name and as the file
        name prefix.
    restricted : bool, optional
        When truthy, read the ``_restricted`` variant of the saved arrays.

    Returns
    -------
    tuple
        ``(per_label, overall)`` where ``per_label`` is the square root of
        ``mean_squared_error(..., multioutput='raw_values')`` and ``overall``
        is the square root of ``mean_squared_error`` with default averaging.
    """
    output_dir = f"OUTPUTS/{base}/"

    # Choose the file pair first, then load: prediction first, truth second.
    pred_path, true_path = (
        (
            f"{output_dir}{base}_cv_all_pred_restricted.npy",
            f"{output_dir}{base}_cv_all_true_restricted.npy",
        )
        if restricted
        else (
            f"{output_dir}{base}_cv_all_pred.npy",
            f"{output_dir}{base}_cv_all_true.npy",
        )
    )
    predictions = np.load(pred_path)
    targets = np.load(true_path)

    # Keep only rows whose prediction is NaN-free (the mask is built from the
    # predictions alone and applied to both arrays so they stay aligned).
    keep = ~np.isnan(predictions).any(axis=1)
    predictions_kept = predictions[keep]
    targets_kept = targets[keep]

    per_label = np.sqrt(
        mean_squared_error(targets_kept, predictions_kept, multioutput='raw_values')
    )
    overall = np.sqrt(mean_squared_error(targets_kept, predictions_kept))

    return per_label, overall

In [11]:
def extract_data(base, n):
    """Load a run's restricted CV arrays and build spectra for row ``n``.

    Reads the ``*_cv_all_true_restricted.npy`` and
    ``*_cv_all_pred_restricted.npy`` files from ``OUTPUTS/{base}/``, wraps row
    ``n`` of each in an ``SFH`` object and calls ``final_spectrum()`` on both.

    Parameters
    ----------
    base : str
        Run identifier; used as sub-directory name and file name prefix.
    n : int
        Row index selected from both arrays.

    Returns
    -------
    tuple
        ``(f_real, f_pred, wav_real, s_real, wav_pred, s_pred)``: the two full
        loaded arrays followed by the first two values returned by
        ``final_spectrum()`` for the true row and for the predicted row.
    """
    f_real = np.load(f"OUTPUTS/{base}/{base}_cv_all_true_restricted.npy")
    f_pred = np.load(f"OUTPUTS/{base}/{base}_cv_all_pred_restricted.npy")

    # Build both SFH objects before evaluating either spectrum.
    truth_model, pred_model = SFH(f_real[n]), SFH(f_pred[n])

    # final_spectrum() yields three values; the third is not needed here.
    wav_real, s_real, _ = truth_model.final_spectrum()
    wav_pred, s_pred, _ = pred_model.final_spectrum()

    return f_real, f_pred, wav_real, s_real, wav_pred, s_pred

In [12]:
def plot_spectra(wav_real, s_real, wav_pred, s_pred):
    """Overlay two spectra in four stacked panels, each with its own window.

    Every panel draws ``s_real`` against ``wav_real`` as a semi-transparent red
    line and ``s_pred`` against ``wav_pred`` as a thin black line; the panels
    differ only in their fixed x and y limits.

    Parameters
    ----------
    wav_real, s_real : array-like
        x and y values of the first (red) curve.
    wav_pred, s_pred : array-like
        x and y values of the second (black) curve.

    Returns
    -------
    None
    """
    # (x limits, y limits) for each panel, top to bottom.
    panel_limits = (
        ((3500, 4500), (-0.15, 0.1)),
        ((4500, 5500), (-0.1, 0.05)),
        ((5500, 6500), (-0.02, 0.01)),
        ((6500, 7500), (-0.15, 0.05)),
    )

    fig, ax = plt.subplots(4,1,figsize=(20,7))

    for panel, (x_window, y_window) in zip(ax, panel_limits):
        panel.plot(wav_real, s_real, 'r', alpha=0.5)
        panel.plot(wav_pred, s_pred, 'k', linewidth=0.5)
        panel.set_ylim(*y_window)
        panel.set_xlim(*x_window)

    return None

In [57]:
# Row index to inspect. Defined once and used for both the spectrum extraction
# and the weight bar plot so the two can never refer to different rows.
n = 0
f_real, f_pred, wav_real, s_real, wav_pred, s_pred = extract_data("sfh_2000_10_20250904_134728", n)

# Plot histogram-like bar plot for a single f_real array with custom bin widths
# Bin edges: seven hand-picked values scaled by 1e6 (presumably Myr -> yr; confirm)
# followed by four log-spaced edges, giving 11 edges / 10 bins.
bin_arr = np.r_[np.array([0, 0.1, 20, 50, 100, 200, 500])*1e6, np.logspace(9.5, 10.15, 4)]

# The first edge is 0 and log10(0) is -inf, which would give the first bar an
# infinite width and a NaN centre so it could not be drawn. For display only,
# place the lower edge of that bin one dex below the first positive edge.
# bin_arr itself is left untouched.
plot_edges = bin_arr.copy()
if plot_edges[0] <= 0:
    plot_edges[0] = plot_edges[1] / 10.0
binning = np.log10(plot_edges)
bin_widths = np.diff(binning)
bin_centers = binning[:-1] + bin_widths/2

real_weights = f_real[n]
pred_weights = f_pred[n]

plt.figure(figsize=(10, 5))
plt.bar(bin_centers, real_weights, width=bin_widths, align='center', color='b', alpha=0.5, edgecolor='b', label='f_real')
plt.bar(bin_centers, pred_weights, width=bin_widths, align='center', color='r', alpha=0.5, edgecolor='r', label='f_pred')
plt.xlabel('log(Age) bin')
plt.ylabel('Weight')
# The bars carry labels; without this call they are never shown.
plt.legend()
plt.show()

FileNotFoundError: [Errno 2] No such file or directory: 'OUTPUTS/sfh_2000_10_20250904_134728/sfh_2000_10_20250904_134728_cv_all_true_restricted.npy'

In [56]:
# Display row 0 of f_pred; relies on f_pred having been assigned by the
# extract_data call in the plotting cell.
f_pred[0]

array([ 0.03259614,  0.18396943,  0.15258875, -0.05955222,  0.15461805,
        0.00043491,  0.08144851,  0.16409504,  0.38246676, -0.09266535])