In [None]:
%load_ext autoreload
%autoreload 2


import corner
import jax
import jax.numpy as jnp
import numpy as np
import torch
from omegaconf import OmegaConf  # To create DictConfig-like objects if needed
from scoresbibm.evaluation.eval_task import (
    eval_coverage,
    eval_inference_task,
)
from scoresbibm.methods.score_transformer import train_transformer_model
from scoresbibm.utils.data_utils import (
    load_model,
    save_model,
)
from simformer import GalaxyPhotometryTask, GalaxySimulator
from synthesizer.emission_models import (
    TotalEmission,
)
from synthesizer.emission_models.attenuation import Calzetti2000
from synthesizer.grid import Grid
from synthesizer.instruments import FilterCollection, Instrument
from synthesizer.parametric import (
    SFH,
    ZDist,
)  # Need concrete SFH, ZDist classes
from unyt import Myr

# Torch device used by the simulator wrapper: the GPU when CUDA is available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Set up Synthesizer: SPS grid, JWST/NIRCam filters and emission model

In [None]:
# SPS grid on disk; per its file name: BPASS 2.2.1 (binaries), Chabrier03 IMF,
# processed with Cloudy c23.01. NOTE(review): absolute user path — adjust per machine.
grid_dir = "/home/tharvey/work/synthesizer_grids/"
grid_name = "bpass-2.2.1-bin_chabrier03-0.1,300.0_cloudy-c23.01-sps.hdf5"
grid = Grid(grid_name, grid_dir=grid_dir)

# JWST/NIRCam bands that make up the photometry vector (order = output order).
filter_codes = [
    f"JWST/NIRCam.{band}"
    for band in (
        "F090W", "F115W", "F150W", "F162M", "F182M", "F200W", "F210M",
        "F250M", "F277W", "F300M", "F335M", "F356W", "F410M", "F444W",
    )
]
filterset = FilterCollection(filter_codes)
instrument = Instrument("JWST", filters=filterset)

# Total emission model on the grid: Calzetti (2000) attenuation curve, no dust
# emission model; fesc / fesc_ly_alpha are the escape-fraction settings
# (0.0 and 0.1 here).
emission_model_instance = TotalEmission(
    dust_emission_model=None,
    dust_curve=Calzetti2000(),
    fesc_ly_alpha=0.1,
    fesc=0.0,
    grid=grid,
)

### Set up the photometry simulator and prior ranges

In [None]:
# Parametric model choices: log-normal star-formation history and a single
# (delta-function) metallicity distribution.
sfh_model_class = SFH.LogNormal
zdist_model_class = ZDist.DeltaConstant

# Extra per-emitter parameters: tau_v is attached to the "stellar" component.
emitter_params_dict = {"stellar": ["tau_v"]}

galaxy_simulator_instance = GalaxySimulator(
    sfh_model=sfh_model_class,
    zdist_model=zdist_model_class,
    grid=grid,
    instrument=instrument,
    emission_model=emission_model_instance,
    emission_model_key="total",
    emitter_params=emitter_params_dict,
    param_units={"peak_age": Myr, "max_age": Myr},  # ages are given in Myr
    normalize_method=None,
    output_type="photo_fnu",
    out_flux_unit="ABmag",  # photometry returned in AB magnitudes
)

# Prior bounds (low, high) for every inferred parameter; peak_age and max_age
# are in Myr (see param_units above). NOTE(review): the prior family is chosen
# inside GalaxyPhotometryTask — presumably uniform over these ranges.
# The insertion order of this dict defines the parameter order used downstream.
priors_ranges_dict = {
    "redshift": (5.0, 12.0),
    "log_mass": (7.0, 11.0),
    "log10metallicity": (-3.0, -1.3),
    "tau_v": (0.0, 2),
    "peak_age": (0.0, 500.0),
    "max_age": (500.0, 1000.0),
    "tau": (0.3, 1.5),
}

# Derive the ordered parameter names from the prior so the two cannot drift apart.
inputs_list = list(priors_ranges_dict)

assert all(low < high for low, high in priors_ranges_dict.values()), (
    "Each prior range must satisfy low < high"
)

### Wrap the simulator for the inference task and smoke-test it

In [None]:
def run_simulator_glob(
    params,
    return_type="tensor",
    simulator=None,
    param_names=None,
    out_device=None,
):
    """Run the galaxy simulator for a single parameter vector.

    Parameters
    ----------
    params : dict, torch.Tensor or array-like
        Either a dict already keyed by parameter name (passed through
        unchanged) or one parameter vector ordered like ``param_names``.
        Array-likes include lists, tuples, NumPy arrays, JAX arrays and torch
        tensors; singleton axes (e.g. a ``(1, n_params)`` batch of one) are
        squeezed away.
    return_type : str, default "tensor"
        ``"tensor"`` returns the simulator output as a float32 torch tensor
        with a leading axis of length 1, moved to ``out_device``. Any other
        value returns the raw simulator output.
    simulator : callable, optional
        Simulator to call; defaults to the notebook's
        ``galaxy_simulator_instance``.
    param_names : sequence of str, optional
        Parameter order for array-like ``params``; defaults to ``inputs_list``.
    out_device : torch.device, optional
        Device for the returned tensor; defaults to the notebook's ``device``.

    Returns
    -------
    torch.Tensor or the simulator's raw output (see ``return_type``).

    Raises
    ------
    ValueError
        If an array-like ``params`` is not a single vector holding exactly one
        value per parameter name (previously extra values were silently
        dropped).
    """
    # Resolve notebook globals lazily so the function can also be driven with
    # explicit dependencies (e.g. in tests).
    if simulator is None:
        simulator = galaxy_simulator_instance
    if param_names is None:
        param_names = inputs_list
    if out_device is None:
        out_device = device

    if isinstance(params, torch.Tensor):
        # detach() so tensors carrying autograd history can still be converted.
        params = params.detach().cpu().numpy()

    # JAX arrays are not np.ndarray instances but expose __array__; key any
    # array-like by name instead of passing it to the simulator un-keyed.
    if not isinstance(params, dict) and (
        isinstance(params, (list, tuple)) or hasattr(params, "__array__")
    ):
        values = np.atleast_1d(np.squeeze(np.asarray(params)))
        if values.ndim != 1 or values.shape[0] != len(param_names):
            raise ValueError(
                f"Expected a single vector of {len(param_names)} parameters "
                f"({', '.join(param_names)}), got an array of shape {values.shape}."
            )
        params = dict(zip(param_names, values))

    phot = simulator(params)

    if return_type == "tensor":
        return torch.tensor(phot[np.newaxis, :], dtype=torch.float32).to(out_device)
    return phot


# Build the inference task around the simulator wrapper defined above.
# NOTE(review): no backend is passed here, so GalaxyPhotometryTask's default is
# used; a later cell re-creates galaxy_task with backend="jax".
galaxy_task = GalaxyPhotometryTask(
    prior_dict=priors_ranges_dict,
    param_names_ordered=inputs_list,
    run_simulator_fn=run_simulator_glob,  # wrapper defined above
    num_filters=len(filter_codes),  # one output per NIRCam filter
)

# Test data generation: draw a few (theta, x) pairs and inspect values and shapes.
print(f"Theta dim: {galaxy_task.get_theta_dim()}")
print(f"X dim: {galaxy_task.get_x_dim()}")
data_batch = galaxy_task.get_data(num_samples=3)
print("Sampled theta (JAX):", data_batch["theta"])
print("Shape of theta:", data_batch["theta"].shape)
print("Sampled x (JAX):", data_batch["x"])
print("Shape of x:", data_batch["x"].shape)

# Test prior sampling directly: draw two samples and evaluate their log-density.
prior_for_test = galaxy_task.get_prior()
theta_samples_torch = prior_for_test.sample((2,))
print("Direct prior samples (Torch):", theta_samples_torch)
print("Log prob of prior samples:", prior_for_test.log_prob(theta_samples_torch))

# Test base mask function on an unpermuted run of node ids covering every
# theta and x entry.
mask_fn = galaxy_task.get_base_mask_fn()
node_ids_example = jnp.arange(galaxy_task.get_theta_dim() + galaxy_task.get_x_dim())
# NOTE(review): jax.random.shuffle is deprecated; use
# jax.random.permutation(key, node_ids_example) if permuted IDs are wanted.
# jax.random.shuffle(jax.random.PRNGKey(0), node_ids_example) # Example of permuted IDs
applied_mask = mask_fn(node_ids=node_ids_example, node_meta_data=None)
print("Base mask applied to node_ids:", applied_mask)

### Configure model hyperparameters


In [None]:
# Score-transformer architecture (mirrors config/method/model/score_transformer.yaml).
model_config_dict = dict(
    name="ScoreTransformer",
    d_model=128,
    n_heads=4,
    n_layers=4,
    d_feedforward=256,
    dropout=0.1,
    max_len=5000,  # sequence-length cap; presumably must be >= theta_dim + x_dim
    tokenizer=dict(name="LinearTokenizer", encoding_dim=64),
    use_output_scale_fn=True,
)

# Diffusion SDE (mirrors config/method/sde/vpsde.yaml); name may also be "VESDE".
sde_config_dict = dict(
    name="VPSDE",
    beta_min=0.1,
    beta_max=20.0,
    num_steps=1000,
    T_min=1e-05,
    T_max=1.0,
)

# Training loop settings (mirrors config/method/train/train_score_transformer.yaml).
train_config_dict = dict(
    learning_rate=1e-4,  # initial learning rate
    min_learning_rate=1e-6,  # minimum learning rate
    z_score_data=True,  # z-score (normalise) the data before training
    total_number_steps_scaling=5,  # scaling factor for the total number of steps
    max_number_steps=1e8,  # upper bound on training steps
    min_number_steps=1e4,  # lower bound on training steps
    training_batch_size=64,
    val_every=50,  # run validation every 50 steps
    clip_max_norm=10.0,  # gradient-norm clipping threshold
    condition_mask_fn=dict(name="structured_random"),  # condition-mask sampler for training
    edge_mask_fn=dict(name="faithfull"),  # lookup key — keep this exact spelling
    validation_fraction=0.1,  # fraction of the training data held out for validation
    val_repeat=5,  # number of times validation is repeated
    stop_early_count=5,  # early-stopping patience (unit not shown here — TODO confirm)
    rebalance_loss=True,
)

# Top-level method config handed to train_transformer_model.
method_config_dict = dict(
    device=str(device),  # string form of the torch device chosen at the top
    sde=sde_config_dict,
    model=model_config_dict,
    train=train_config_dict,
)

# Wrap as an OmegaConf DictConfig (attribute-style access to the nested settings).
method_cfg = OmegaConf.create(method_config_dict)

In [None]:
# Quick check that the task still returns a small (theta, x) batch; the returned
# value is shown as the cell output.
galaxy_task.get_data(num_samples=3)

### Set up the task and generate training data

In [None]:
# --- 1. Instantiate the task ---
# Re-create galaxy_task with an explicit JAX backend; this replaces the instance
# built in the smoke-test cell above.
print("Instantiating GalaxyPhotometryTask...")
galaxy_task = GalaxyPhotometryTask(
    prior_dict=priors_ranges_dict,
    param_names_ordered=inputs_list,
    run_simulator_fn=run_simulator_glob,
    num_filters=len(filter_codes),
    backend="jax",  # Or "torch"
)
print("Task instantiated.")

# --- 2. Generate Data ---
num_training_simulations = 5000
print(f"Generating {num_training_simulations} training simulations...")
# .get_data() returns a dict with JAX arrays if backend is "jax"
training_data = galaxy_task.get_data(num_samples=num_training_simulations)
theta_train = training_data["theta"]
x_train = training_data["x"]
print(f"Data generated: theta shape {theta_train.shape}, x shape {x_train.shape}")

# Held-out simulations. NOTE(review): these are not passed to
# train_transformer_model (the train config uses `validation_fraction`, not
# `val_split`); in this notebook they serve as test observations for the
# posterior corner plots.
num_validation_simulations = 20
validation_data = galaxy_task.get_data(num_samples=num_validation_simulations)
theta_val = validation_data["theta"]
x_val = validation_data["x"]

# --- 3. Set RNG Seed for JAX ---
# Master PRNG key; it is passed to training below and to later evaluation calls.
rng_seed_for_training = 0
master_rng_key = jax.random.PRNGKey(rng_seed_for_training)

### Train the model

In [None]:
# --- 4. Train the Model ---
# Fit the score transformer on the simulated (theta, x) dict produced by
# get_data(), using the OmegaConf method config assembled earlier.
print("Starting training...")
trained_score_model = train_transformer_model(
    rng=master_rng_key,
    method_cfg=method_cfg,
    data=training_data,
    task=galaxy_task,
)
print("Training finished. Model returned by train_transformer_model:", type(trained_score_model))

### Evaluate and validate the model

In [None]:
plot_corner = True
import corner
import matplotlib.pyplot as plt

# Draw posterior samples for each held-out observation and (optionally) save a
# corner plot of them with the true parameters marked.
num_posterior_samples = 1000
corner_plot_dir = "/home/tharvey/work/ltu-ili_testing/models/simformer/plots"

theta_dim = galaxy_task.get_theta_dim()
x_dim = galaxy_task.get_x_dim()
# Mask for posterior: theta is unknown (0), x is known (1)
posterior_condition_mask = jnp.array([0] * theta_dim + [1] * x_dim, dtype=jnp.bool_)

for i, xobs in enumerate(x_val):
    x_o = jnp.array([xobs], dtype=jnp.float32)
    # Derive a distinct key per observation: reusing master_rng_key (also the
    # training key) would feed every observation identical sampling noise.
    sample_key = jax.random.fold_in(master_rng_key, i)
    samples = np.array(
        trained_score_model.sample_batched(
            num_samples=num_posterior_samples,
            x_o=x_o,
            rng=sample_key,
            condition_mask=posterior_condition_mask,
        ),
        dtype=np.float32,
    )

    if plot_corner:
        truth = np.array(theta_val[i], dtype=np.float32)
        fig = corner.corner(
            samples,
            labels=galaxy_task.param_names_ordered,
            show_titles=True,
            truths=truth,
            quantiles=[0.16, 0.5, 0.84],
            title_kwargs={"fontsize": 12},
        )
        fig.savefig(f"{corner_plot_dir}/corner_plot_{i}.png")
        plt.show()
        # Close explicitly so one figure per observation does not pile up in memory.
        plt.close(fig)

### Evaluate the posterior: C2ST metric and coverage

In [None]:
# Classifier two-sample test (C2ST) evaluation with the posterior condition mask;
# the result is shown as the cell output. eval_inference_task itself is already
# imported in the first cell, so only the metric is imported here.
from scoresbibm.evaluation.eval_metrics import c2st

eval_inference_task(
    num_evaluations=10,
    num_samples=1000,
    rng=master_rng_key,
    metric_params={"condition_mask_fn": "posterior"},
    metric_fn=c2st,
    model=trained_score_model,
    task=galaxy_task,
)

In [None]:
def get_condition_mask_fn(str):
    """Pick one of ``galaxy_task``'s condition-mask functions by name.

    ``"base"`` returns ``galaxy_task.get_base_mask_fn()``; any other value
    (including ``"posterior"``) returns ``galaxy_task.get_joint_mask_fn()``.

    NOTE(review): the parameter name shadows the builtin ``str`` inside this
    function; it is kept so keyword callers keep working. The helper is not
    called in the visible notebook cells.
    """
    if str != "base":
        return galaxy_task.get_joint_mask_fn()
    return galaxy_task.get_base_mask_fn()


# Calibration check: posterior coverage over 50 evaluations of 1000 samples each.
# condition_mask_fn may be posterior, joint, likelihood, random or structured random.
metric_values, eval_time = eval_coverage(
    rng=master_rng_key,
    model=trained_score_model,
    task=galaxy_task,
    metric_params=dict(
        num_samples=1000,  # samples drawn per coverage evaluation
        num_evaluations=50,
        condition_mask_fn="posterior",
        num_bins=20,  # histogram bins
        sample_kwargs={},
        log_prob_kwargs={},
        batch_size=64,  # sampling batch size
    ),
)

In [None]:
import matplotlib.pyplot as plt

# Coverage (calibration) curve on its own explicit figure. Pyplot's implicit
# "current figure" may be one left open by an earlier cell (e.g. a corner plot)
# under non-inline backends. NOTE(review): assumes metric_values[0] holds the
# predicted levels and metric_values[1] the empirical coverage — confirm
# against eval_coverage.
fig, ax = plt.subplots()
ax.plot(metric_values[0], metric_values[1], marker="o", label="Coverage")
ax.plot([0, 1], [0, 1], "k--", label="Ideal Coverage")
ax.set(
    xlabel="Predicted Percentile",
    ylabel="Empirical Percentile",
    title="Posterior coverage",
)
ax.legend()
fig.savefig("/home/tharvey/work/ltu-ili_testing/models/simformer/plots/coverage_plot.png")

### Save and reload the model

In [None]:
import os

from joblib import dump

# Persist the trained network via scoresbibm's save_model, plus the metadata
# needed to rebuild a matching GalaxyPhotometryTask later. run_simulator_fn is
# not stored: it depends on notebook globals (simulator, parameter list,
# device) and must be re-created in code.
model_dir = "/home/tharvey/work/ltu-ili_testing/models/simformer/"
model_id = "simformer_galaxy_score_model_test_v1"

save_model(
    model=trained_score_model,
    dir_path=model_dir,
    model_id=model_id,
)

save_dict = {
    "_x_dim": galaxy_task.get_x_dim(),
    "_theta_dim": galaxy_task.get_theta_dim(),
    "num_filters": len(filter_codes),  # required by GalaxyPhotometryTask(num_filters=...)
    "filter_codes": list(filter_codes),
    "prior_dict": priors_ranges_dict,
    "param_names_ordered": galaxy_task.param_names_ordered,
    "backend": galaxy_task.backend,
    "method_config_dict": method_config_dict,
}

metadata_dir = os.path.join(model_dir, "models")
# joblib.dump does not create missing parent directories.
os.makedirs(metadata_dir, exist_ok=True)
dump(save_dict, os.path.join(metadata_dir, f"data_{model_id}.joblib"))

In [None]:
# Load the model saved above back from disk to exercise the save/load round trip.
reloaded_model = load_model(
    model_id="simformer_galaxy_score_model_test_v1",
    dir_path="/home/tharvey/work/ltu-ili_testing/models/simformer/",
)

In [None]:
plot_corner = True
import corner
import matplotlib.pyplot as plt

# Repeat the posterior corner plots with the *reloaded* model, to check that the
# save/load round trip preserved it. Per-observation keys are derived from
# master_rng_key with fold_in, so results can be compared observation by
# observation with plots made the same way from trained_score_model.
num_posterior_samples = 1000
corner_plot_dir = "/home/tharvey/work/ltu-ili_testing/models/simformer/plots"

theta_dim = galaxy_task.get_theta_dim()
x_dim = galaxy_task.get_x_dim()
# Mask for posterior: theta is unknown (0), x is known (1)
posterior_condition_mask = jnp.array([0] * theta_dim + [1] * x_dim, dtype=jnp.bool_)

for i, xobs in enumerate(x_val):
    x_o = jnp.array([xobs], dtype=jnp.float32)
    sample_key = jax.random.fold_in(master_rng_key, i)
    samples = np.array(
        reloaded_model.sample_batched(
            num_samples=num_posterior_samples,
            x_o=x_o,
            rng=sample_key,
            condition_mask=posterior_condition_mask,
        ),
        dtype=np.float32,
    )

    if plot_corner:
        truth = np.array(theta_val[i], dtype=np.float32)
        fig = corner.corner(
            samples,
            labels=galaxy_task.param_names_ordered,
            show_titles=True,
            truths=truth,
            quantiles=[0.16, 0.5, 0.84],
            title_kwargs={"fontsize": 12},
        )
        # Distinct file name so the original model's plots are not overwritten.
        fig.savefig(f"{corner_plot_dir}/corner_plot_reloaded_{i}.png")
        plt.show()
        plt.close(fig)

In [None]:
# Inspect the model's z-score attributes (training used z_score_data=True);
# presumably the normalisation statistics — shown as the cell output.
trained_score_model.z_score_params