# Gradient calculation

In this notebook we show how to calculate the full Jacobian of the pipeline with respect to the input data.

First, we define the configuration for the pipeline, set up the pipeline, and run it.

In [None]:
import os

from rubix.core.pipeline import RubixPipeline

# Configuration for the first (unmodified) pipeline run.
# The Illustris API key comes from the environment and is None when the
# variable is unset. NOTE(review): presumably acceptable when "reuse" finds the
# galaxy file already saved under data/ — confirm.
config = {
    "pipeline": {"name": "calc_gradient"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {"id": 14, "reuse": True},
        # Only 2 stars are used (presumably to keep the Jacobian small — confirm).
        "subset": {"use_subset": True, "subset_size": 2},
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {"path": "data/galaxy-id-14.hdf5"},
    },
    "output_path": "output",
    "output_modified": False,
    "telescope": {
        "name": "TESTGRADIENT",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 0.5},
        "noise": {"signal_to_noise": 1, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },
    "ssp": {
        "template": {"name": "BruzualCharlot2003"},
    },
}

# Build the pipeline from the config and execute it once.
pipe = RubixPipeline(config)
rubixdata = pipe.run()

Then we can plot the datacube as usual.

In [None]:
import matplotlib.pyplot as plt
import jax.numpy as jnp

# Collapse the datacube along axis 2 (the wavelength axis: the spectrum plot
# below draws datacube[0, 0, :] against the telescope wavelength grid) to get
# a 2-D map over pixels.
flux_map = jnp.sum(rubixdata.stars.datacube, axis=2)

fig, ax = plt.subplots()
image = ax.imshow(flux_map, origin="lower", cmap="inferno")
fig.colorbar(image, ax=ax, label="datacube summed over wavelength")
ax.set(
    title="Datacube summed over wavelength",
    xlabel="pixel (axis 1)",
    ylabel="pixel (axis 0)",
)
plt.show()

In [None]:
#NBVAL_SKIP
import matplotlib.pyplot as plt

# Wavelength grid of the telescope, used as the x-axis of the spectrum.
wave = pipe.telescope.wave_seq

# The datacube holds one spectrum per spatial pixel, with wavelength along the
# last axis (it is not one spectrum per star).
spectra = rubixdata.stars.datacube
print(spectra.shape)

fig, ax = plt.subplots()
ax.plot(wave, spectra[0, 0, :])
ax.set(
    title="Spectrum of pixel (0, 0)",
    xlabel="wavelength (telescope.wave_seq)",
    ylabel="datacube value",
)
plt.show()

Now we are ready to calculate the Jacobian. To do so, we need the list of all transformer functions in the pipeline. From these we build the pipeline expression and compile it. We then differentiate it with `jax.jacrev`. Finally, we evaluate the Jacobian for the given input data.

In [None]:
import jax

from rubix.pipeline import linear_pipeline as ltp
from rubix.utils import read_yaml

# _get_pipeline_functions() returns the transformer functions in execution order.
# NOTE: this is a private RubixPipeline method and may change without notice.
transformers_list = pipe._get_pipeline_functions()
print(transformers_list)  # the JAX-compatible transformer functions

# Load all pipeline definitions and select "calc_gradient", the same pipeline
# name used in config["pipeline"]["name"] above. read_cfg is a dict whose
# entries contain the "Transformers" description.
read_cfg = read_yaml("../rubix/config/pipeline_config.yml")
pipeline_cfg = read_cfg["calc_gradient"]

# Combine the pipeline definition with the transformer functions and compile
# them into a single callable.
tp = ltp.LinearTransformerPipeline(
    pipeline_cfg,       # pipeline definition == read_cfg["calc_gradient"]
    transformers_list,  # transformer functions from RubixPipeline
)
compiled_fn = tp.compile_expression()

# Reverse-mode Jacobian of the compiled pipeline with respect to its input,
# rubixdata.
jac_fn = jax.jacrev(compiled_fn)
jacobian = jac_fn(rubixdata)
print(jacobian)


In [None]:
import jax.tree_util as jtu

# Flatten the Jacobian to inspect its leaves and overall tree structure.
flat_values, tree_def = jtu.tree_flatten(jacobian)
print("Number of elements in the tree:", len(flat_values))

# Print each leaf's type and shape ('N/A' for leaves without a shape).
for i, value in enumerate(flat_values):
    print(f"Element {i}: Type={type(value)}, Shape={getattr(value, 'shape', 'N/A')}")

# The tree definition shows how the leaves above are nested.
print(tree_def)


def process_leaf(path, leaf):
    """Print the key path, type, and shape of one leaf of a PyTree.

    Args:
        path: Key path of the leaf, as supplied by ``jtu.tree_map_with_path``;
            rendered with ``jtu.keystr`` for readability.
        leaf: The leaf value; ``'N/A'`` is printed when it has no ``shape``.

    Returns:
        None.
    """
    print(f"Path: {jtu.keystr(path)}, Type: {type(leaf)}, Shape: {getattr(leaf, 'shape', 'N/A')}")


# Walk the tree with key paths. The trailing ';' suppresses echoing the
# returned tree (whose leaves are all None) as the cell output.
jtu.tree_map_with_path(process_leaf, jacobian);


In [None]:
# Derivatives of the output datacube with respect to every input leaf; this
# sub-tree is indexed like the input rubixdata (e.g. .stars.age).
# Assign it and show it as the cell output in one expression.
(stars_gradient := jacobian.stars.datacube)

In [None]:
# Show, in order: the datacube shape, the input stellar ages, and the
# derivative of the datacube with respect to those ages.
for quantity in (
    rubixdata.stars.datacube.shape,
    rubixdata.stars.age,
    stars_gradient.stars.age,
):
    print(quantity)

In [None]:
# Display the stars component of the pipeline output before it is modified below.
getattr(rubixdata, "stars")

In [None]:
import os
import pickle

import jax.numpy as jnp

# Zero out the stellar ages, metallicities and masses. NOTE: this mutates
# rubixdata in place, so earlier cells' outputs no longer reflect its state.
# The (1, 2) shape presumably matches the 2-star subset with a leading batch
# axis — TODO confirm, and adjust if subset_size changes.
rubixdata.stars.age = jnp.array([[0.0, 0.0]])
rubixdata.stars.metallicity = jnp.array([[0.0, 0.0]])
rubixdata.stars.mass = jnp.array([[0.0, 0.0]])

# Persist the modified data so the next pipeline run can use it as
# "modified_path". The context manager guarantees the file is flushed and
# closed. Only unpickle this file from trusted sources.
os.makedirs("output", exist_ok=True)
with open("output/rubix_galaxy.pkl", "wb") as galaxy_file:
    pickle.dump(rubixdata, galaxy_file)

In [None]:
import os

from rubix.core.pipeline import RubixPipeline

# Second run: the same configuration as before, except that "output_modified"
# is enabled and "modified_path" points at the pickle written by the previous
# cell (presumably so the pipeline uses the modified galaxy — confirm).
config = {
    "pipeline": {"name": "calc_gradient"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {"id": 14, "reuse": True},
        "subset": {"use_subset": True, "subset_size": 2},
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {"path": "data/galaxy-id-14.hdf5"},
    },
    "output_path": "output",
    "output_modified": True,
    "modified_path": "output/rubix_galaxy.pkl",
    "telescope": {
        "name": "TESTGRADIENT",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 0.5},
        "noise": {"signal_to_noise": 1, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },
    "ssp": {
        "template": {"name": "BruzualCharlot2003"},
    },
}

# Rebuild the pipeline with the modified-input config and execute it.
pipe = RubixPipeline(config)
rubixdata = pipe.run()

In [None]:
import matplotlib.pyplot as plt
import jax.numpy as jnp

# Same map as before, now for the run using the modified galaxy: collapse the
# datacube along axis 2 (the wavelength axis, see the spectrum plot below).
flux_map = jnp.sum(rubixdata.stars.datacube, axis=2)

fig, ax = plt.subplots()
image = ax.imshow(flux_map, origin="lower", cmap="inferno")
fig.colorbar(image, ax=ax, label="datacube summed over wavelength")
ax.set(
    title="Datacube summed over wavelength (modified input)",
    xlabel="pixel (axis 1)",
    ylabel="pixel (axis 0)",
)
plt.show()

In [None]:
#NBVAL_SKIP
import matplotlib.pyplot as plt

# Wavelength grid of the telescope, used as the x-axis of the spectrum.
wave = pipe.telescope.wave_seq

# One spectrum per spatial pixel, wavelength along the last axis; this is the
# output of the run using the modified galaxy.
spectra = rubixdata.stars.datacube
print(spectra.shape)

fig, ax = plt.subplots()
ax.plot(wave, spectra[0, 0, :])
ax.set(
    title="Spectrum of pixel (0, 0) (modified input)",
    xlabel="wavelength (telescope.wave_seq)",
    ylabel="datacube value",
)
plt.show()