In [None]:
import os
# python-fsps locates its data files through $SPS_HOME. Only fall back to this
# local path when the variable is not already set, so other users / CI machines
# can point it at their own FSPS installation.
os.environ.setdefault('SPS_HOME', '/home/annalena/sps_fsps')

In [None]:
import logging

# Suppress every log record at CRITICAL and below, i.e. all standard levels,
# so the pipeline's DEBUG output does not flood the notebook.
logging.disable(level=logging.CRITICAL)

# Load the SSP template from FSPS

In [None]:
# NBVAL_SKIP
# Fetch the FSPS single-stellar-population template; its `.age` and
# `.metallicity` grids are used by the cells below.
from rubix.spectra.ssp import factory as ssp_factory

ssp_fsps = ssp_factory.get_ssp_template("FSPS")

In [None]:
# Parameter grids of the SSP template; print their shapes for a quick sanity check.
age_values, metallicity_values = ssp_fsps.age, ssp_fsps.metallicity
for grid in (age_values, metallicity_values):
    print(grid.shape)

In [None]:
# Grid indices of the parameters the optimizer should recover.
index_age = 90
index_metallicity = 7

# Grid indices of the starting guess.
initial_metallicity_index = 5
initial_age_index = 70

# Per-parameter Adam learning rates and convergence tolerance.
learning_age = 0.5
learning_metallicity = 1e-3
tol = 1e-5

for label, age_idx, metal_idx in (
    ("start", initial_age_index, initial_metallicity_index),
    ("target", index_age, index_metallicity),
):
    print(f"{label} age: {age_values[age_idx]}, {label} metallicity: {metallicity_values[metal_idx]}")

# Configure pipeline

In [None]:
from rubix.core.pipeline import RubixPipeline
import os

# Full pipeline config: loads galaxy 14 of TNG50-1 (snapshot 99) through the
# Illustris API (reusing the local file if present) and runs the standard pipeline.
# The subset settings presumably keep only 2 star particles; later cells assume two stars.
config = dict(
    pipeline=dict(name="calc_gradient"),
    logger=dict(
        log_level="DEBUG",
        log_file_path=None,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    ),
    data=dict(
        name="IllustrisAPI",
        args=dict(
            api_key=os.environ.get("ILLUSTRIS_API_KEY"),
            particle_type=["stars"],
            simulation="TNG50-1",
            snapshot=99,
            save_data_path="data",
        ),
        load_galaxy_args=dict(id=14, reuse=True),
        subset=dict(use_subset=True, subset_size=2),
    ),
    simulation=dict(
        name="IllustrisTNG",
        args=dict(path="data/galaxy-id-14.hdf5"),
    ),
    output_path="output",
    output_modified=False,
    telescope=dict(
        name="TESTGRADIENT",
        psf=dict(name="gaussian", size=5, sigma=0.6),
        lsf=dict(sigma=0.5),
        noise=dict(signal_to_noise=100, noise_distribution="normal"),
    ),
    cosmology=dict(name="PLANCK15"),
    galaxy=dict(dist_z=0.1, rotation=dict(type="edge-on")),
    ssp=dict(template=dict(name="FSPS")),
)

pipe = RubixPipeline(config)
rubixdata = pipe.run()

# Set the target stellar parameters

In [None]:
import jax.numpy as jnp

# Put both star particles on the target (age, metallicity) grid point,
# with unit mass and zero velocity.
target_age = age_values[index_age]
target_metallicity = metallicity_values[index_metallicity]
rubixdata.stars.age = jnp.array([target_age] * 2)
rubixdata.stars.metallicity = jnp.array([target_metallicity] * 2)
rubixdata.stars.mass = jnp.ones((1, 2))
rubixdata.stars.velocity = jnp.zeros((2, 3))

In [None]:
from rubix.core.pipeline_gradient import RubixPipeline

# Gradient pipeline config. It has no data-loader settings; the pipeline consumes
# the in-memory `rubixdata` (holding the target parameters) passed to run().
config = dict(
    pipeline=dict(name="calc_gradient"),
    logger=dict(
        log_level="DEBUG",
        log_file_path=None,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    ),
    data=dict(args=dict(particle_type=["stars"])),
    output_path="output",
    output_modified=False,
    telescope=dict(
        name="TESTGRADIENT",
        psf=dict(name="gaussian", size=5, sigma=0.6),
        lsf=dict(sigma=0.5),
        noise=dict(signal_to_noise=100, noise_distribution="normal"),
    ),
    cosmology=dict(name="PLANCK15"),
    galaxy=dict(dist_z=0.1, rotation=dict(type="edge-on")),
    ssp=dict(template=dict(name="FSPS")),
)

pipe = RubixPipeline(config)
rubixdata = pipe.run(rubixdata)

target = rubixdata

# Compute the initial datacube

In [None]:
import copy

# `target` was bound to this very object in the previous cell. Work on a deep copy
# so that overwriting the stellar parameters below does not also modify the target.
rubixdata = copy.deepcopy(rubixdata)

rubixdata.stars.age = jnp.array([age_values[initial_age_index], age_values[initial_age_index]])
rubixdata.stars.metallicity = jnp.array([metallicity_values[initial_metallicity_index], metallicity_values[initial_metallicity_index]])
rubixdata.stars.mass = jnp.array([[1.0, 1.0]])
rubixdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

In [None]:
from rubix.core.pipeline_gradient import RubixPipeline

# Same gradient pipeline config as for the target; here the in-memory `rubixdata`
# holds the starting-guess parameters.
config = dict(
    pipeline=dict(name="calc_gradient"),
    logger=dict(
        log_level="DEBUG",
        log_file_path=None,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    ),
    data=dict(args=dict(particle_type=["stars"])),
    output_path="output",
    output_modified=False,
    telescope=dict(
        name="TESTGRADIENT",
        psf=dict(name="gaussian", size=5, sigma=0.6),
        lsf=dict(sigma=0.5),
        noise=dict(signal_to_noise=100, noise_distribution="normal"),
    ),
    cosmology=dict(name="PLANCK15"),
    galaxy=dict(dist_z=0.1, rotation=dict(type="edge-on")),
    ssp=dict(template=dict(name="FSPS")),
)

pipe = RubixPipeline(config)
rubixdata = pipe.run(rubixdata)

initial = rubixdata

# Adam optimizer

In [None]:
from rubix.pipeline import linear_pipeline as pipeline

# Build a pipeline object and compile its transformer chain by hand, so that
# `pipeline_instance.func` can be called directly by the loss function below.
pipeline_instance = RubixPipeline(config)
linear_chain = pipeline.LinearTransformerPipeline(
    pipeline_instance.pipeline_config,
    pipeline_instance._get_pipeline_functions(),
)
pipeline_instance._pipeline = linear_chain
linear_chain.assemble()
pipeline_instance.func = linear_chain.compile_expression()

In [None]:
def loss_only_wrt_age_metallicity(age, metallicity, base_data, target):
    """
    Squared-error loss between the datacube for the given age/metallicity and a target.

    Args:
        age: stellar ages written to ``base_data.stars.age`` for the forward pass
            (a JAX tracer when called under ``jax.grad`` / ``jax.value_and_grad``).
        metallicity: stellar metallicities written to ``base_data.stars.metallicity``.
        base_data: Rubix data object passed to ``pipeline_instance.func``.
        target: Rubix data object whose ``stars.datacube`` is the reference.

    Returns:
        Scalar ``sum((output.stars.datacube - target.stars.datacube) ** 2)``.

    The original ``age`` / ``metallicity`` of ``base_data`` are restored before
    returning, so traced values never escape into the caller's object.
    """
    original_age = base_data.stars.age
    original_metallicity = base_data.stars.metallicity

    base_data.stars.age = age
    base_data.stars.metallicity = metallicity
    try:
        output = pipeline_instance.func(base_data)
        return jnp.sum((output.stars.datacube - target.stars.datacube) ** 2)
    finally:
        # Undo the temporary assignment even if the pipeline raises.
        base_data.stars.age = original_age
        base_data.stars.metallicity = original_metallicity


In [None]:

import jax
import jax.numpy as jnp
import optax


def adam_optimization_multi(loss_fn, params_init, data, target, age_lr=0.1, metallicity_lr=0.05, tol=1e-3, max_iter=500,
                            age_bounds=(0.0, 10.3), metallicity_bounds=(1e-4, 0.05)):
    """
    Jointly optimize age and metallicity with Adam, using a separate learning rate for each.

    Args:
        loss_fn: function with signature ``loss_fn(age, metallicity, data, target)``
            returning a scalar loss.
        params_init: dict with keys 'age' and 'metallicity', each a JAX array. The
            histories record element ``[0, 0]``, so both arrays must be at least 2-D.
        data: base data passed through to ``loss_fn``.
        target: target data passed through to ``loss_fn``.
        age_lr: Adam learning rate for 'age'.
        metallicity_lr: Adam learning rate for 'metallicity'.
        tol: convergence tolerance on the global norm of the Adam updates
            (measured before clipping).
        max_iter: maximum number of iterations.
        age_bounds: (low, high) interval that 'age' is clipped to after every step.
        metallicity_bounds: (low, high) interval that 'metallicity' is clipped to
            after every step.

    Returns:
        params: final parameters (dict with 'age' and 'metallicity').
        age_history: list of floats, ``params['age'][0, 0]`` before each update.
        metallicity_history: list of floats, ``params['metallicity'][0, 0]`` before each update.
        loss_history: list of loss values, one per iteration.
    """
    # Route each parameter to its own Adam instance (labels mirror the params dict).
    optimizer = optax.multi_transform(
        {'age': optax.adam(age_lr), 'metallicity': optax.adam(metallicity_lr)},
        {'age': 'age', 'metallicity': 'metallicity'},
    )
    # Shallow copy so the caller's dict is never modified.
    params = dict(params_init)
    optimizer_state = optimizer.init(params)

    # Build the value-and-gradient function once instead of on every iteration.
    value_and_grad_fn = jax.value_and_grad(
        lambda p: loss_fn(p['age'], p['metallicity'], data, target)
    )

    age_history = []
    metallicity_history = []
    loss_history = []

    for i in range(max_iter):
        loss, grads = value_and_grad_fn(params)
        loss_history.append(float(loss))
        age_history.append(float(params['age'][0, 0]))
        metallicity_history.append(float(params['metallicity'][0, 0]))

        updates, optimizer_state = optimizer.update(grads, optimizer_state)
        params = optax.apply_updates(params, updates)

        # Keep parameters inside the allowed physical ranges.
        params['age'] = jnp.clip(params['age'], *age_bounds)
        params['metallicity'] = jnp.clip(params['metallicity'], *metallicity_bounds)

        if optax.global_norm(updates) < tol:
            print(f"Converged at iteration {i}")
            break

    return params, age_history, metallicity_history, loss_history

In [None]:
"""
import jax
import jax.numpy as jnp
import optax
def adam_optimization_multi_normalized(loss_fn, params_init, data, target, age_lr=0.1, metallicity_lr=0.05, tol=1e-3, max_iter=500):
    
    Optimizes normalized age and metallicity. 
    Assumes age in [0, 14] and metallicity in [1e-4, 0.05]. 
    Normalized values are in [0, 1].
    
    # Define normalization functions
    def normalize_age(age):
        return age / 14.0
    def denormalize_age(age_norm):
        return age_norm * 14.0
    def normalize_metallicity(met):
        return (met - 1e-4) / (0.05 - 1e-4)
    def denormalize_metallicity(met_norm):
        return met_norm * (0.05 - 1e-4) + 1e-4

    # Normalize initial parameters
    params_norm = {
        'age': normalize_age(params_init['age']),
        'metallicity': normalize_metallicity(params_init['metallicity'])
    }
    
    # Create separate optimizers with different learning rates for normalized parameters
    optimizers = {
        'age': optax.adam(age_lr),
        'metallicity': optax.adam(metallicity_lr)
    }
    param_labels = {'age': 'age', 'metallicity': 'metallicity'}
    optimizer = optax.multi_transform(optimizers, param_labels)
    optimizer_state = optimizer.init(params_norm)
    
    age_history = []
    metallicity_history = []
    loss_history = []
    
    for i in range(max_iter):
        # Compute loss on the denormalized parameters
        loss, grads = jax.value_and_grad(
            lambda p: loss_fn(denormalize_age(p['age']),
                              denormalize_metallicity(p['metallicity']),
                              data,
                              target)
        )(params_norm)
        
        loss_history.append(float(loss))
        # Save history in original scale
        age_history.append(float(denormalize_age(params_norm['age'])[0,0]))
        metallicity_history.append(float(denormalize_metallicity(params_norm['metallicity'])[0,0]))
        
        updates, optimizer_state = optimizer.update(grads, optimizer_state, params_norm)
        params_norm = optax.apply_updates(params_norm, updates)

        # Optionally clip the parameters to enforce physical constraints:
        params_norm['age'] = jnp.clip(params_norm['age'], normalize_age(0), normalize_age(14))
        params_norm['metallicity'] = jnp.clip(params_norm['metallicity'], normalize_metallicity(1e-4), normalize_metallicity(0.05))
        
        # Check convergence based on the global norm of updates
        if optax.global_norm(updates) < tol:
            print(f"Converged at iteration {i}")
            break

    # Denormalize the final parameters
    final_params = {
        'age': denormalize_age(params_norm['age']),
        'metallicity': denormalize_metallicity(params_norm['metallicity'])
    }
    return final_params, age_history, metallicity_history, loss_history
"""

In [None]:
# Fit the starting-guess datacube ("initial") towards the target datacube.
data = initial
target_value = target

# Initial guesses with shape (1, 2), presumably one value per star particle.
start_age = age_values[initial_age_index]
start_metallicity = metallicity_values[initial_metallicity_index]
age_init = jnp.array([[start_age, start_age]])
metallicity_init = jnp.array([[start_metallicity, start_metallicity]])
params_init = dict(age=age_init, metallicity=metallicity_init)

optimized_params, age_history, metallicity_history, loss_history = adam_optimization_multi(
    loss_fn=loss_only_wrt_age_metallicity,
    params_init=params_init,
    data=data,
    target=target_value,
    age_lr=learning_age,
    metallicity_lr=learning_metallicity,
    tol=tol,
    max_iter=10000,
)

print(f"Optimized Age: {optimized_params['age']}")
print(f"Optimized Metallicity: {optimized_params['metallicity']}")


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Loss recorded at every optimizer iteration.
loss_history_np = np.asarray(loss_history)
indices = np.arange(loss_history_np.size)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(indices, loss_history_np, marker='o', linestyle='-')
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.set_title("Loss History")
ax.grid(True)
plt.show()

In [None]:
# Display the final parameters returned by the optimizer (dict with 'age' and 'metallicity').
optimized_params

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Age ([0, 0] element) at every iteration, with the target age as a red reference line.
age_history_np = np.asarray(age_history)
indices = np.arange(age_history_np.size)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(indices, age_history_np, marker='o', linestyle='-')
ax.hlines(y=age_values[index_age], xmin=0, xmax=age_history_np.size, color='r', linestyle='-')
ax.set_xlabel("Iteration")
ax.set_ylabel("Age")
ax.set_title("Age History")
ax.grid(True)
plt.show()

In [None]:
# Metallicity ([0, 0] element) at every iteration, with the target as a red reference line.
metallicity_history_np = np.asarray(metallicity_history)
indices = np.arange(metallicity_history_np.size)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(indices, metallicity_history_np, marker='o', linestyle='-')
ax.hlines(y=metallicity_values[index_metallicity], xmin=0, xmax=metallicity_history_np.size, color='r', linestyle='-')
ax.set_xlabel("Iteration")
ax.set_ylabel("Metallicity")
ax.set_title("Metallicity History")
ax.grid(True)
plt.show()

In [None]:
# Re-run the gradient pipeline with the parameters recorded at one optimizer
# iteration and compare its spectrum with the target spectrum.
HISTORY_INDEX = 10  # optimizer iteration to inspect
# Clamp so the cell still works when the optimizer stopped after fewer iterations.
i = min(HISTORY_INDEX, len(age_history) - 1)
rubixdata.stars.age = jnp.array([age_history[i], age_history[i]])
rubixdata.stars.metallicity = jnp.array([metallicity_history[i], metallicity_history[i]])
rubixdata.stars.mass = jnp.array([[1.0, 1.0]])
rubixdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

pipe = RubixPipeline(config)
rubixdata = pipe.run(rubixdata)

# Plot the target and re-computed spectra at datacube element [0, 0]
# (assumes a (x, y, wavelength) datacube layout — TODO confirm).
import matplotlib.pyplot as plt
wave = pipe.telescope.wave_seq

spectra_target = target.stars.datacube
spectra_optimized = rubixdata.stars.datacube

plt.plot(wave, spectra_target[0, 0, :], label=f"Target age = {age_values[index_age]:.2f}, metal. = {metallicity_values[index_metallicity]:.4f}")
plt.plot(wave, spectra_optimized[0, 0, :], label=f"Optimized age = {age_history[i]:.2f}, metal. = {metallicity_history[i]:.4f}")
plt.xlabel("Wavelength [Å]")
plt.ylabel("Luminosity [L/Å]")
plt.title(f"Loss {loss_history[i]:.2e}")
plt.legend()
plt.ylim(0, 0.0002)
plt.grid(True)
plt.show()

# Create a GIF of the optimization progress

In [None]:
import matplotlib.pyplot as plt
import imageio
import jax.numpy as jnp

# Folder for the per-iteration PNG frames that the next cell stitches into a GIF.
frames_folder = 'frames'
os.makedirs(frames_folder, exist_ok=True)

frame_files = []

wave = pipe.telescope.wave_seq

# One frame per recorded optimizer iteration.
for i in range(len(age_history)):
    # age_history / metallicity_history hold one float per iteration; assign each
    # value to both star particles before re-running the compiled pipeline.
    rubixdata.stars.age = jnp.array([age_history[i], age_history[i]])
    rubixdata.stars.metallicity = jnp.array([metallicity_history[i], metallicity_history[i]])
    rubixdata.stars.mass = jnp.array([[1.0, 1.0]])
    rubixdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
    rubixdata = pipeline_instance.func(rubixdata)

    spectra_target = target.stars.datacube
    spectra_optimized = rubixdata.stars.datacube

    # Use an explicit figure handle so each frame can be closed and memory stays bounded.
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(wave, spectra_target[0, 0, :], label=f"Target age = {age_values[index_age]:.2f}, metal. = {metallicity_values[index_metallicity]:.4f}")
    ax.plot(wave, spectra_optimized[0, 0, :], label=f"Optimized age = {age_history[i]:.2f}, metal. = {metallicity_history[i]:.4f}")
    ax.set_xlabel("Wavelength [Å]")
    ax.set_ylabel("Luminosity [L/Å]")
    ax.set_title(f"Loss {loss_history[i]:.2e}")
    ax.legend()
    ax.set_ylim(0, 0.0002)
    ax.grid(True)

    frame_filename = os.path.join(frames_folder, f"frame_{i:03d}.png")
    fig.savefig(frame_filename)
    plt.close(fig)
    frame_files.append(frame_filename)

# Optionally, clean up the temporary frames folder
#shutil.rmtree(frames_folder)

In [None]:
# Create the GIF from the saved frames
gif_filename = f"./output/optimizer/gifs/optimization_progress_both_metals{initial_metallicity_index}_to{index_metallicity}_agestart{initial_age_index}_to{index_age}_learning{learning_age}_{learning_metallicity}_tol{tol}.gif"
# The writer expects the target directory to exist; create it on a fresh checkout.
os.makedirs(os.path.dirname(gif_filename), exist_ok=True)
with imageio.get_writer(gif_filename, mode='I', duration=0.2) as writer:
    for frame_file in frame_files:
        writer.append_data(imageio.imread(frame_file))

print(f"GIF created: {gif_filename}")