In [None]:
import os
# Tell FSPS where its data directory lives. `setdefault` keeps an SPS_HOME that
# is already exported in the environment; the path below is only a
# machine-specific fallback and must be adapted on other systems.
os.environ.setdefault('SPS_HOME', '/home/annalena/sps_fsps')

# Gradient calculation

This notebook shows how to calculate the full Jacobian of the pipeline with respect to the input data.

First, we define the config that serves as input for the pipeline, set up the pipeline, and run it as we normally do for forward modeling: the IllustrisTNG data is downloaded, transformed by the linear pipeline, and we end up with the datacube.

In [None]:
import os

from rubix.core.pipeline import RubixPipeline

# Forward-modeling run: assemble the pipeline configuration section by section,
# build the pipeline, and execute it once.
# A path to a YAML config (e.g. "../rubix/config/pipeline_config.yaml") can be
# used in place of the in-memory dictionary.

logger_cfg = {
    "log_level": "DEBUG",
    "log_file_path": None,
    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
}

# Data source: only star particles of galaxy 14, restricted to a 2-particle subset.
# The API key is read from the environment so it never lands in the notebook.
data_cfg = {
    "name": "IllustrisAPI",
    "args": {
        "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
        "particle_type": ["stars"],
        "simulation": "TNG50-1",
        "snapshot": 99,
        "save_data_path": "data",
    },
    "load_galaxy_args": {
        "id": 14,
        "reuse": True,
    },
    "subset": {
        "use_subset": True,
        "subset_size": 2,
    },
}

simulation_cfg = {
    "name": "IllustrisTNG",
    "args": {
        "path": "data/galaxy-id-14.hdf5",
    },
}

telescope_cfg = {
    "name": "TESTGRADIENT",
    "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
    "lsf": {"sigma": 0.5},
    "noise": {"signal_to_noise": 1, "noise_distribution": "normal"},
}

galaxy_cfg = {
    "dist_z": 0.1,
    "rotation": {"type": "edge-on"},
}

# SSP template in use; "BruzualCharlot2003" and "Mastar_CB19_SLOG_1_5" were the
# alternatives noted by the author.
ssp_cfg = {
    "template": {
        "name": "FSPS",
    },
}

config = {
    "pipeline": {"name": "calc_gradient"},
    "logger": logger_cfg,
    "data": data_cfg,
    "simulation": simulation_cfg,
    "output_path": "output",
    "output_modified": False,
    "telescope": telescope_cfg,
    "cosmology": {"name": "PLANCK15"},
    "galaxy": galaxy_cfg,
    "ssp": ssp_cfg,
}

pipe = RubixPipeline(config)
rubixdata = pipe.run()

# rubixdata as input for the pipeline

For gradient-based optimization, it is useful to pass the modified rubixdata object directly to the pipeline and compute the new datacube. This is possible if you load the pipeline from rubix.core.pipeline_gradient instead of rubix.core.pipeline. You set up the pipeline and then pass the rubixdata to the run function.

In [None]:
#NBVAL_SKIP
import os

from rubix.core.pipeline_gradient import RubixPipeline

# Gradient-enabled pipeline: no download/simulation sections are given here,
# because the already loaded `rubixdata` object is handed straight to `run`.
# A path to a YAML config (e.g. "../rubix/config/pipeline_config.yaml") can be
# used in place of the in-memory dictionary.

logger_section = {
    "log_level": "DEBUG",
    "log_file_path": None,
    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
}

data_section = {
    "args": {
        "particle_type": ["stars"],
    },
}

telescope_section = {
    "name": "TESTGRADIENT",
    "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
    "lsf": {"sigma": 0.5},
    "noise": {"signal_to_noise": 1, "noise_distribution": "normal"},
}

galaxy_section = {
    "dist_z": 0.1,
    "rotation": {"type": "edge-on"},
}

# SSP template in use; "BruzualCharlot2003" and "Mastar_CB19_SLOG_1_5" were the
# alternatives noted by the author.
ssp_section = {
    "template": {
        "name": "FSPS",
    },
}

config = {
    "pipeline": {"name": "calc_gradient"},
    "logger": logger_section,
    "data": data_section,
    "output_path": "output",
    "output_modified": False,
    "telescope": telescope_section,
    "cosmology": {"name": "PLANCK15"},
    "galaxy": galaxy_section,
    "ssp": ssp_section,
}

# Note: this rebinds both `pipe` and `rubixdata` from the forward-modeling cell.
pipe = RubixPipeline(config)
rubixdata = pipe.run(rubixdata)

In [None]:
#NBVAL_SKIP
import matplotlib.pyplot as plt

# x-axis values come from the telescope attached to the pipeline; the y-values
# are the spectrum stored at spatial index (0, 0) of the current datacube.
wave = pipe.telescope.wave_seq
spectra = rubixdata.stars.datacube

fig, ax = plt.subplots()
ax.plot(wave, spectra[0, 0, :], label="current spectrum")
ax.legend(loc="upper right")

# Gradient calculation

The RubixPipeline from rubix.core.pipeline_gradient also has the gradient function implemented. You can simply call the function, pass your rubixdata, and get the gradient returned.

In [None]:
#NBVAL_SKIP
# Call the pipeline's gradient function on the current rubixdata object.
gradient = pipe.gradient(rubixdata)

# Pull the `stars.age` field out of the returned gradient object.
gradient_age = gradient.stars.age

In [None]:
#NBVAL_SKIP
# Skipped under nbval like the cells it depends on: `pipe` must be the
# gradient-enabled pipeline (its `run` accepts a rubixdata argument), which is
# created in a cell that is itself marked NBVAL_SKIP.
import jax.numpy as jnp

# Define the optimization target: both star particles get age 5.0 and
# metallicity 0.005. This overwrites the values stored on `rubixdata`.
rubixdata.stars.age = jnp.array([5.0, 5.0])
rubixdata.stars.metallicity = jnp.array([0.005, 0.005])

# Calculate the datacube for the target age and metallicity; it serves as the
# reference in the loss computations of the following cells.
target_rubixdata = pipe.run(rubixdata)
target_datacube = target_rubixdata.stars.datacube

In [None]:
#NBVAL_SKIP
# Skipped under nbval: depends on the gradient-enabled `pipe` (created in an
# NBVAL_SKIP cell) and runs the pipeline once per grid point.
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Sweep the age and metallicity of the second star particle over a grid while
# the first particle stays at (age 5.0, metallicity 0.005). For every grid
# point we store the loss against `target_datacube` and the derivative of that
# loss with respect to the second particle's age and metallicity.

# Define the parameter ranges
age_values = jnp.linspace(0, 10.30103, 20)
metallicity_values = jnp.logspace(np.log10(1e-4), np.log10(0.05), 20)
n_age, n_metallicity = len(age_values), len(metallicity_values)

# One tuple per grid point, appended age-major (age is the outer loop).
results = []

# Iterate over age and metallicity values
for age_i in age_values:
    for metallicity_i in metallicity_values:
        rubixdata.stars.age = jnp.array([5.0, age_i])
        rubixdata.stars.metallicity = jnp.array([0.005, metallicity_i])

        # Calculate datacube and loss
        rubixdata = pipe.run(rubixdata)
        loss = pipe.loss_mse(rubixdata.stars.datacube, target_datacube)

        # Calculate gradient
        gradient = pipe.gradient(rubixdata)

        # Flatten the residual to a column vector and the gradients to
        # (n_cube, n_particles). Sizes are inferred from the data instead of
        # hard-coding the 6-element cube and the 2 particles.
        cube_reshaped = (rubixdata.stars.datacube - target_datacube).reshape(-1, 1)
        n_cube = cube_reshaped.shape[0]
        grad_age_reshaped = gradient.stars.age.reshape(n_cube, -1)
        grad_metallicity_reshaped = gradient.stars.metallicity.reshape(n_cube, -1)

        # Chain rule for a mean-squared-error loss:
        # dL/dp = (1 / N) * 2 * sum_i residual_i * d(cube_i)/dp
        update_age = (1 / n_cube) * 2 * jnp.sum(cube_reshaped * grad_age_reshaped, axis=0)
        update_metallicity = (1 / n_cube) * 2 * jnp.sum(cube_reshaped * grad_metallicity_reshaped, axis=0)

        # Store the results; index 1 selects the particle that is being varied
        results.append((age_i, metallicity_i, loss, update_age[1], update_metallicity[1]))

# Convert results to numpy array for easy plotting
results = np.array(results)
age_vals, metallicity_vals, losses, grad_ages, grad_metallicities = results.T

# Reshape data for plotting.
# np.meshgrid (default indexing="xy") returns arrays of shape
# (n_metallicity, n_age): rows follow metallicity (y), columns follow age (x).
age_grid, metallicity_grid = np.meshgrid(
    np.unique(age_vals), np.unique(metallicity_vals)
)

# The flat results are age-major, so they reshape to (n_age, n_metallicity) and
# must be transposed to match the meshgrid / imshow layout. Reshaping directly
# to `metallicity_grid.shape` would silently swap the two axes.
loss_grid = losses.reshape(n_age, n_metallicity).T
grad_age_grid = grad_ages.reshape(n_age, n_metallicity).T
grad_metallicity_grid = grad_metallicities.reshape(n_age, n_metallicity).T

# `grad_age_grid` / `grad_metallicity_grid` share the layout of `age_grid` and
# `metallicity_grid`, so they can be overlaid on the heatmap with plt.quiver.

In [None]:
#NBVAL_SKIP
# Skipped under nbval: needs `loss_grid` from the parameter sweep.
import os

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm  # was used without being imported

# savefig does not create missing directories.
os.makedirs("output/loss", exist_ok=True)

# Plotting the loss as a heatmap (rows = metallicity, columns = age)
plt.figure(figsize=(10, 8))
plt.imshow(
    loss_grid,
    extent=[age_values[0], age_values[-1], metallicity_values[0], metallicity_values[-1]],
    origin="lower",
    aspect="auto",
    cmap="jet",
    norm=LogNorm(),
)
plt.colorbar(label="Loss")

# Set log scale for the y-axis
plt.yscale("log")

# Add a red dot at age = 5 and metallicity = 0.005 (the target parameters)
plt.scatter(5, 0.005, color="red", label="Point (5, 0.005)", zorder=200)
# Show the scatter label, which was otherwise never displayed.
plt.legend(loc="upper right")

# Label the axes
plt.xlabel("Age")
plt.ylabel("Metallicity (log scale)")
plt.title("Loss Heatmap")
plt.savefig("output/loss/loss_heatmap_FSPS.png")

In [None]:
#NBVAL_SKIP
# Skipped under nbval: needs the arrays produced by the parameter sweep.
import os
import pickle

# Persist the flat sweep arrays (age_vals, metallicity_vals, losses, grad_ages,
# grad_metallicities) so the heatmap can be redrawn without re-running the sweep.
data = {
    'age_vals': age_vals,
    'metallicity_vals': metallicity_vals,
    'losses': losses,
    'grad_ages': grad_ages,
    'grad_metallicities': grad_metallicities
}

# open(..., 'wb') fails if the directory is missing, so create it first.
os.makedirs('output/loss', exist_ok=True)
with open('output/loss/data_FSPS.pkl', 'wb') as f:
    pickle.dump(data, f)

In [None]:
#NBVAL_SKIP
# Skipped under nbval: the pickle file only exists after the sweep was run.
import pickle

# Read back the file written by this notebook ('data_FSPS.pkl'); the previous
# path 'output/loss/data.pkl' is never created here.
# NOTE: pickle.load executes arbitrary code from the file -- only load files
# you produced yourself.
with open('output/loss/data_FSPS.pkl', 'rb') as f:
    data = pickle.load(f)

age_vals = data['age_vals']
metallicity_vals = data['metallicity_vals']
losses = data['losses']
grad_ages = data['grad_ages']
grad_metallicities = data['grad_metallicities']

In [None]:
#NBVAL_SKIP
# Skipped under nbval: needs the sweep arrays (computed or loaded from pickle).
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm  # was used without being imported

# Rebuild the plotting grid from the flat sweep arrays, so this cell works
# right after loading the pickle and does not rely on `loss_grid`,
# `age_values` or `metallicity_values` still being in memory.
age_axis = np.unique(age_vals)
metallicity_axis = np.unique(metallicity_vals)
# The sweep stores results age-major (age is the outer loop), hence the
# reshape to (n_age, n_metallicity) followed by a transpose to get
# rows = metallicity (y) and columns = age (x) as imshow expects.
loss_grid = np.asarray(losses).reshape(len(age_axis), len(metallicity_axis)).T

# savefig does not create missing directories.
os.makedirs("output/loss", exist_ok=True)

# Plotting the loss as a heatmap
plt.figure(figsize=(10, 8))
plt.imshow(
    loss_grid,
    extent=[age_axis[0], age_axis[-1], metallicity_axis[0], metallicity_axis[-1]],
    origin="lower",
    aspect="auto",
    cmap="jet",
    norm=LogNorm(),
)
plt.colorbar(label="Loss")

# Set log scale for the y-axis
plt.yscale("log")

# Add a red dot at age = 5 and metallicity = 0.005 (the target parameters)
plt.scatter(5, 0.005, color="red", label="Point (5, 0.005)", zorder=200)
# Show the scatter label, which was otherwise never displayed.
plt.legend(loc="upper right")

# Label the axes
plt.xlabel("Age")
plt.ylabel("Metallicity (log scale)")
plt.title("Loss Heatmap")
plt.savefig("output/loss/loss_heatmap_FSPS.png")