# Gradient calculation

In this notebook we show how to calculate the full Jacobian of the pipeline with respect to the input data.

First, we define the configuration for the pipeline, set up the pipeline, and run the forward model as usual: the IllustrisTNG data are downloaded and passed through the linear pipeline, which produces the datacube.

In [None]:
import os

from rubix.core.pipeline import RubixPipeline

# Forward-modelling configuration, assembled from one dict per section.
# (It could equally be loaded from a YAML file such as
# "../rubix/config/pipeline_config.yaml".)
logger_config = {
    "log_level": "DEBUG",
    "log_file_path": None,
    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
}

data_config = {
    "name": "IllustrisAPI",
    "args": {
        # Taken from the environment so the key is never stored in the notebook.
        "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
        "particle_type": ["stars"],
        "simulation": "TNG50-1",
        "snapshot": 99,
        "save_data_path": "data",
    },
    "load_galaxy_args": {"id": 14, "reuse": True},
    # Only two star particles are used; the later cells build length-2
    # age/metallicity arrays, presumably to match this.
    "subset": {"use_subset": True, "subset_size": 2},
}

simulation_config = {
    "name": "IllustrisTNG",
    "args": {"path": "data/galaxy-id-14.hdf5"},
}

telescope_config = {
    "name": "TESTGRADIENT",
    "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
    "lsf": {"sigma": 0.5},
    "noise": {"signal_to_noise": 1, "noise_distribution": "normal"},
}

galaxy_config = {"dist_z": 0.1, "rotation": {"type": "edge-on"}}

config = {
    "pipeline": {"name": "calc_gradient"},
    "logger": logger_config,
    "data": data_config,
    "simulation": simulation_config,
    "output_path": "output",
    "output_modified": False,
    "telescope": telescope_config,
    "cosmology": {"name": "PLANCK15"},
    "galaxy": galaxy_config,
    "ssp": {"template": {"name": "BruzualCharlot2003"}},
}

# Build the pipeline and run the full forward model (data loading included).
pipe = RubixPipeline(config)
rubixdata = pipe.run()

# rubixdata as input for the pipeline

For gradient-based optimization, it is convenient to pass a modified rubixdata object directly to the pipeline and compute the new datacube. This is possible if you import the pipeline from rubix.core.pipeline_gradient instead of rubix.core.pipeline: set up the pipeline and then pass the rubixdata to its run function.

In [None]:
#NBVAL_SKIP
import os

from rubix.core.pipeline_gradient import RubixPipeline

# Configuration for the gradient-capable pipeline. There is no data-source
# name and no "simulation" section: the ``rubixdata`` produced by the forward
# run above is handed to ``run`` directly. The telescope, cosmology, galaxy
# and SSP settings match the forward-run configuration above.
config = dict(
    pipeline=dict(name="calc_gradient"),
    logger=dict(
        log_level="DEBUG",
        log_file_path=None,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    ),
    data=dict(args=dict(particle_type=["stars"])),
    output_path="output",
    output_modified=False,
    telescope=dict(
        name="TESTGRADIENT",
        psf=dict(name="gaussian", size=5, sigma=0.6),
        lsf=dict(sigma=0.5),
        noise=dict(signal_to_noise=1, noise_distribution="normal"),
    ),
    cosmology=dict(name="PLANCK15"),
    galaxy=dict(dist_z=0.1, rotation=dict(type="edge-on")),
    ssp=dict(template=dict(name="BruzualCharlot2003")),
)

pipe = RubixPipeline(config)
# Re-run the model on the rubixdata object from the forward run.
rubixdata = pipe.run(rubixdata)

In [None]:
#NBVAL_SKIP
import matplotlib.pyplot as plt

wave = pipe.telescope.wave_seq
spectra = rubixdata.stars.datacube

# Spectrum stored at datacube index [0, 0], plotted against the telescope's
# wavelength grid (the last datacube axis).
fig, ax = plt.subplots()
ax.plot(wave, spectra[0, 0, :], label="current spectrum")
ax.set_xlabel("Wavelength")
ax.set_ylabel("Flux (datacube units)")
ax.set_title("Spectrum at datacube index [0, 0]")
ax.legend(loc="upper right")
plt.show()

# Gradient calculation

The RubixPipeline from rubix.core.pipeline_gradient also implements a gradient function. Call it with your rubixdata and it returns the gradient.

In [None]:
#NBVAL_SKIP
# Gradient of the pipeline with respect to the inputs held in ``rubixdata``.
# The result is read with the same ``.stars.<field>`` layout as rubixdata.
gradient = pipe.gradient(rubixdata)
stars_gradient = gradient.stars
gradient_age = stars_gradient.age

In [None]:
import jax.numpy as jnp

# Ground-truth parameters used to synthesise the target datacube that the
# optimisers below try to recover. Both star entries get the same values.
target_age, target_metallicity = 5.0, 0.005
rubixdata.stars.age = jnp.array([target_age, target_age])
rubixdata.stars.metallicity = jnp.array([target_metallicity, target_metallicity])

# Run the pipeline with these parameters and keep the resulting datacube.
target_rubixdata = pipe.run(rubixdata)
target_datacube = target_rubixdata.stars.datacube

In [None]:
import pickle
from pathlib import Path

# Precomputed loss/gradient values on an (age, metallicity) grid. No cell in
# this notebook writes this file; it has to be generated beforehand.
LOSS_GRID_PATH = Path("output/loss/data.pkl")
if not LOSS_GRID_PATH.exists():
    raise FileNotFoundError(
        f"{LOSS_GRID_PATH} not found: generate the precomputed loss grid "
        "before running this cell."
    )

# NOTE: pickle.load can execute arbitrary code; only load files you created.
with LOSS_GRID_PATH.open("rb") as f:
    data = pickle.load(f)

age_vals = data['age_vals']
metallicity_vals = data['metallicity_vals']
losses = data['losses']
grad_ages = data['grad_ages']
grad_metallicities = data['grad_metallicities']

# Gradient descent optimization

In [None]:
import jax
import jax.numpy as jnp
import optax

# Define the optimization function
def rubix_loss(params, rubixdata):
    """Loss (``pipe.loss_mse``) between the modelled and the target datacube.

    ``params`` unpacks to ``(age, metallicity)`` for the second entry of the
    star arrays; the first entry is held at (5.0, 0.005). The ``rubixdata``
    object passed in is modified in place: its ``stars.age`` and
    ``stars.metallicity`` are overwritten. Uses the notebook globals ``pipe``
    and ``target_datacube``.

    NOTE(review): under ``jax.value_and_grad`` the arrays written into
    ``rubixdata`` are tracers and remain on the object after the call; later
    cells overwrite both fields before using them.
    """
    free_age, free_metallicity = params
    rubixdata.stars.age = jnp.array([5.0, free_age])
    rubixdata.stars.metallicity = jnp.array([0.005, free_metallicity])
    model = pipe.run(rubixdata)
    return pipe.loss_mse(model.stars.datacube, target_datacube)

def gradient_descent_optimization(
    func,
    x_init,
    learning_rate=0.001,
    tol=1e-3,
    max_iter=100,
    data=None,
    bounds=((0.0, 10.3), (1e-4, 0.05)),
):
    """Minimise ``func`` with Adam, clipping every iterate into a box.

    Args:
        func: Loss ``func(x, data) -> scalar``; must be differentiable in ``x``.
        x_init: Initial parameter vector, here ``[age, metallicity]``.
        learning_rate: Adam step size.
        tol: Stop once the L2 norm of the step actually taken in one iteration
            (after clipping) drops below this value.
        max_iter: Maximum number of iterations.
        data: Second argument forwarded to ``func``. ``None`` (the default)
            uses the notebook-global ``rubixdata``.
        bounds: One ``(lower, upper)`` pair per parameter. The default keeps
            age in [0, 10.3] and metallicity in [1e-4, 0.05].

    Returns:
        ``(x_final, x_history, loss_history)``. ``x_history`` includes
        ``x_init`` and therefore has one more row than ``loss_history``.
    """
    if data is None:
        data = rubixdata  # notebook-global input data

    x = jnp.asarray(x_init)
    lower, upper = (jnp.asarray(limits) for limits in zip(*bounds))

    optimizer = optax.adam(learning_rate=learning_rate)
    optimizer_state = optimizer.init(x)
    # Build the differentiated function once instead of on every iteration.
    loss_and_grad = jax.value_and_grad(func)

    xlist = [x]
    loss_list = []
    for i in range(max_iter):
        loss, grad = loss_and_grad(x, data)
        loss_list.append(loss)

        updates, optimizer_state = optimizer.update(grad, optimizer_state)
        # Project back into the feasible parameter box.
        x_new = jnp.clip(optax.apply_updates(x, updates), lower, upper)
        xlist.append(x_new)

        # Measure the step that was actually taken: at an active bound the raw
        # Adam update stays large while the clipped iterate no longer moves,
        # so testing ``updates`` would never detect convergence there.
        step_size = jnp.linalg.norm(x_new - x)
        x = x_new
        if step_size < tol:
            print(f"Converged at iteration {i}")
            break

    return x, jnp.array(xlist), jnp.array(loss_list)

# Initial (age, metallicity) guesses: four near the corners of the allowed
# parameter box and one in the interior.
initial_guesses = [
    jnp.array([10.0, 0.05]),
    jnp.array([0.0, 0.05]),
    jnp.array([6.0, 0.006]),
    jnp.array([0.0, 0.0001]),
    jnp.array([10.0, 0.0001]),
]
x_init1, x_init2, x_init3, x_init4, x_init5 = initial_guesses

learning_rate = 0.1
tolerance = 1e-3
max_iterations = 500

# One Adam run per starting point, executed in order; each run returns
# (final params, parameter history, loss history).
runs = [
    gradient_descent_optimization(
        rubix_loss, start, learning_rate=learning_rate, tol=tolerance, max_iter=max_iterations
    )
    for start in initial_guesses
]
(
    (optimized_params1, param_history1, loss_history1),
    (optimized_params2, param_history2, loss_history2),
    (optimized_params3, param_history3, loss_history3),
    (optimized_params4, param_history4, loss_history4),
    (optimized_params5, param_history5, loss_history5),
) = runs

# Report the result of the first run.
optimized_age, optimized_metallicity = optimized_params1
print(f"Optimized Age: {optimized_age}")
print(f"Optimized Metallicity: {optimized_metallicity}")

In [None]:
# Compare how the loss evolves in the five Adam runs.
fig, ax = plt.subplots(figsize=(10, 5))
histories = (loss_history1, loss_history2, loss_history3, loss_history4, loss_history5)
for run_number, history in enumerate(histories, start=1):
    ax.plot(history, label=f"Loss {run_number}")
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.set_title("Loss vs. Iterations")
ax.legend()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Rebuild 2-D grids from the flat arrays loaded from data.pkl.
# Assumes the flat arrays reshape to (n_metallicity, n_age) in this order;
# TODO confirm against the script that produced data.pkl.
age_grid, metallicity_grid = np.meshgrid(
    np.unique(age_vals), np.unique(metallicity_vals)
)
loss_grid = np.asarray(losses).reshape(metallicity_grid.shape)
grad_age_grid = np.asarray(grad_ages).reshape(metallicity_grid.shape)
grad_metallicity_grid = np.asarray(grad_metallicities).reshape(metallicity_grid.shape)

# pcolormesh places every cell at its real grid coordinates. imshow would lay
# the rows out evenly in *linear* metallicity, which misplaces a (presumably)
# log-spaced grid once the y-axis is switched to log scale.
fig, ax = plt.subplots(figsize=(10, 8))
mesh = ax.pcolormesh(age_grid, metallicity_grid, loss_grid, shading="auto", cmap="jet")
fig.colorbar(mesh, ax=ax, label="Loss")
ax.set_yscale("log")

# Ground truth used for the target datacube; drawn in black so it is not
# confused with the red optimisation path.
ax.scatter(5, 0.005, color="black", marker="*", s=150, label="Point (5, 0.005)", zorder=200)

paths = [
    (param_history1, "r", "Optimization Path 1"),
    (param_history2, "g", "Optimization Path 2"),
    (param_history3, "b", "Optimization Path 3"),
    (param_history4, "y", "Optimization Path 4"),
    (param_history5, "m", "Optimization Path 5"),
]
for history, fmt, label in paths:
    ax.plot(history[:, 0], history[:, 1], fmt, label=label)

ax.set_xlabel("Age")
ax.set_ylabel("Metallicity (log scale)")
ax.set_title("Loss Heatmap with Optimization Path")
ax.legend()
plt.show()

In [None]:
# Metallicity values visited by the first Adam run.
metallicity_path1 = param_history1[:, 1]
metallicity_path1

# Levenberg-Marquardt from the optimistix package

In [None]:
from optimistix import ScipyLeastSquares

# Every call of residual_fn is recorded, so these histories have one entry per
# residual evaluation made by the solver, not one per solver iteration.
loss_history = []
param_history = []

def residual_fn(params):
    """Return the flattened datacube residual for ``params = (age, metallicity)``.

    Only the second entry of the star arrays is fitted; the first is held at
    (5.0, 0.005). Side effects: overwrites ``rubixdata.stars.age`` and
    ``rubixdata.stars.metallicity`` and appends to ``param_history`` and
    ``loss_history``.
    """
    param_history.append(params)  # Track the parameters
    age, metallicity = params
    # Do not rebind ``rubixdata`` in this function: any assignment to the name
    # makes it local, and the attribute writes below would then raise
    # UnboundLocalError.
    rubixdata.stars.age = jnp.array([5.0, age])
    rubixdata.stars.metallicity = jnp.array([0.005, metallicity])
    predicted = pipe.run(rubixdata)
    residuals = predicted.stars.datacube - target_datacube
    loss_history.append(jnp.sum(residuals**2))  # Track the loss
    return residuals.flatten()  # Residuals as a vector

# Start away from the ground truth (5.0, 0.005) used to build target_datacube;
# otherwise the fit starts at the answer and demonstrates nothing.
initial_guess = jnp.array([6.0, 0.006])

# NOTE(review): optimistix does not appear to export ``ScipyLeastSquares``.
# The ``Solver(fun).run(init)`` pattern matches ``jaxopt.ScipyLeastSquares``,
# whose ``run`` returns ``OptStep(params, state)``; in that case the fitted
# values are ``optimized_params.params[0]`` / ``[1]``. Confirm the intended
# package before relying on the printed values.
optimizer = ScipyLeastSquares(residual_fn)
optimized_params = optimizer.run(initial_guess)

print("Optimized Age:", optimized_params[0])
print("Optimized Metallicity:", optimized_params[1])

In [None]:
# Plot the loss history
import matplotlib.pyplot as plt

# loss_history gains one entry per residual_fn call, so the x-axis counts
# residual evaluations rather than solver iterations.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(loss_history, label="Loss")
ax.set_xlabel("Residual evaluations")
ax.set_ylabel("Loss")
ax.set_title("Loss vs. Residual Evaluations")
ax.legend()
plt.show()

In [None]:
# Stack the recorded parameter vectors into an (n_evaluations, 2) array so
# the age and metallicity columns can be sliced (later cells rely on this).
param_history = jnp.array(param_history)

fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(param_history[:, 0], param_history[:, 1], 'r.-', label="Trajectory")
ax.scatter(optimized_params[0], optimized_params[1], color='blue', label="Optimal Point")
ax.set_xlabel("Age")
ax.set_ylabel("Metallicity")
ax.set_title("Optimization Trajectory in Parameter Space")
ax.legend()
plt.show()

In [None]:
import numpy as np

# Brute-force loss landscape: one pipeline run per (age, metallicity) grid
# point (50 x 50 = 2500 runs), so this cell is slow.
age_values = jnp.linspace(0, 10.3, 50)
metallicity_values = jnp.logspace(np.log10(1e-4), np.log10(0.05), 50)
# JAX arrays are immutable (``x[i, j] = v`` raises TypeError), so fill a
# NumPy array instead.
loss_grid = np.zeros((len(age_values), len(metallicity_values)))

for i, age in enumerate(age_values):
    for j, metallicity in enumerate(metallicity_values):
        rubixdata.stars.age = jnp.array([5.0, age])
        rubixdata.stars.metallicity = jnp.array([0.005, metallicity])
        # Keep the output separate so the global ``rubixdata`` is not replaced
        # by the pipeline's return value on every iteration.
        predicted = pipe.run(rubixdata)
        residuals = predicted.stars.datacube - target_datacube
        loss_grid[i, j] = float(jnp.sum(residuals**2))

# pcolormesh uses the real (log-spaced) metallicity coordinates, which stay
# correct under the log y-axis; imshow would space the rows linearly.
fig, ax = plt.subplots(figsize=(12, 8))
mesh = ax.pcolormesh(
    np.asarray(age_values),
    np.asarray(metallicity_values),
    loss_grid.T,  # rows = metallicity, columns = age
    shading="auto",
    cmap="viridis",
)
fig.colorbar(mesh, ax=ax, label="Loss")
ax.plot(param_history[:, 0], param_history[:, 1], 'r.-', label="Optimization Path")
ax.scatter(optimized_params[0], optimized_params[1], color='blue', label="Optimal Point")
ax.set_xlabel("Age")
ax.set_ylabel("Metallicity")
ax.set_title("Loss Landscape and Optimization Path")
ax.legend()
ax.set_yscale("log")
plt.show()
