# Gradient calculation

In this notebook we show how you can calculate the full Jacobian of the pipeline with respect to the input data.

First of all, we define our config as input for the pipeline, set up the pipeline and let it run.

In [None]:
from rubix.core.pipeline import RubixPipeline
# Suppose you already have a user_config or path to config
#config = "../rubix/config/pipeline_config.yaml"
import os
# Pipeline configuration for the reference run, passed to RubixPipeline as a plain dict.
config = {
    # Pipeline name; the same key ("calc_gradient") is looked up in the pipeline YAML further below.
    "pipeline":{"name": "calc_gradient"},
    
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # The API key is read from the environment; os.environ.get yields None if the variable is unset.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        
        "load_galaxy_args": {
        "id": 14,
        "reuse": True,
        },
        
        # NOTE(review): presumably restricts the run to 2 particles; later cells hard-code two
        # age/metallicity values, so keep them in sync with subset_size — confirm.
        "subset": {
            "use_subset": True,
            "subset_size": 2,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            "path": "data/galaxy-id-14.hdf5",
        },
    
    },
    "output_path": "output",
    # False for this reference run; the second configuration further below sets it to True
    # together with a "modified_path" entry.
    "output_modified":  False,

    "telescope":
        {"name": "TESTGRADIENT",
         "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
         "lsf": {"sigma": 0.5},
         "noise": {"signal_to_noise": 1,"noise_distribution": "normal"},
         },
    "cosmology":
        {"name": "PLANCK15"},
        
    "galaxy":
        {"dist_z": 0.1,
         "rotation": {"type": "edge-on"},
        },
        
    "ssp": {
        "template": {
            "name": "BruzualCharlot2003"
        },
    },        
}
# Build the pipeline from the config and execute it.
pipe = RubixPipeline(config)
rubixdata = pipe.run()

# Keep the datacube of this run; the loss and optimisation cells below compare against it.
target_datacube = rubixdata.stars.datacube

Then we can plot the datacube as usual.

In [None]:
import matplotlib.pyplot as plt
import jax.numpy as jnp
# Sum the datacube over axis 2 (the axis that is plotted against `wave` further below)
# and show the resulting 2D image.
plt.imshow(jnp.sum(rubixdata.stars.datacube, axis=2), origin="lower", cmap="inferno")
plt.colorbar()

In [None]:
#NBVAL_SKIP
# Presumably the wavelength grid of the configured telescope — TODO confirm.
wave = pipe.telescope.wave_seq

spectra = rubixdata.stars.datacube # Datacube of the stars component; indexed below as [0, 0, :] to get one spectrum
print(spectra.shape)
# Keep the spectrum at index (0, 0) as reference; it is overplotted after the inputs are modified.
target_spectra = spectra[0,0,:]
plt.plot(wave, spectra[0,0,:])

Now we are ready to calculate the Jacobian. For this we need a list of all the transformers in the pipeline. From these we build an abstract form of the gradient, which we can finally evaluate for the given input data.

In [None]:
# _get_pipeline_functions() returns a list of the transformer functions in the correct order
transformers_list = pipe._get_pipeline_functions()
print(transformers_list)  # Debug: see the list of JAX-compatible functions

from rubix.utils import read_yaml
read_cfg = read_yaml("../rubix/config/pipeline_config.yml")

# read_cfg is indexed by pipeline name. We select the "calc_gradient" entry, i.e. the same
# name that is set in config["pipeline"]["name"].
pipeline_cfg = read_cfg["calc_gradient"]

from rubix.pipeline import linear_pipeline as ltp

tp = ltp.LinearTransformerPipeline(
    pipeline_cfg,      # pipeline_cfg == read_cfg["calc_gradient"]
    transformers_list, # The list of function objects from RubixPipeline
)

# Build the single callable that jax.jacrev differentiates below.
compiled_fn = tp.compile_expression()

# Evaluate pipeline
#output = compiled_fn(rubixdata)

# Calculate gradient
import jax

# Reverse-mode Jacobian of the compiled pipeline with respect to its input (rubixdata).
jac_fn = jax.jacrev(compiled_fn)
jacobian = jac_fn(rubixdata)
print(jacobian)


In [None]:
import jax.tree_util as jtu

# Flatten the Jacobian pytree into a list of leaves plus a tree definition to inspect its structure
flat_values, tree_def = jtu.tree_flatten(jacobian)
print("Number of elements in the tree:", len(flat_values))

# Print each element's type and shape ('N/A' for leaves without a .shape attribute)
for i, value in enumerate(flat_values):
    print(f"Element {i}: Type={type(value)}, Shape={getattr(value, 'shape', 'N/A')}")

# Optionally print the tree definition for clarity
print(tree_def)


# Function to print or process leaves in the PyTree
def process_leaf(path, leaf):
    """Print a one-line summary (path, type, shape) of a single PyTree leaf.

    Args:
        path: Key path of the leaf, as handed over by ``tree_map_with_path``.
        leaf: The leaf value itself.

    Returns:
        None. The summary is written to stdout; leaves without a ``shape``
        attribute are reported with shape ``N/A``.
    """
    leaf_type = type(leaf)
    leaf_shape = getattr(leaf, "shape", "N/A")
    print(f"Path: {path}, Type: {leaf_type}, Shape: {leaf_shape}")

# Walk through the tree structure: process_leaf is called with (key path, leaf) for every leaf.
# process_leaf returns None, so the cell output is a tree of None values.
jtu.tree_map_with_path(process_leaf, jacobian)


In [None]:
# Part of the Jacobian that belongs to the output datacube; its attributes are indexed by input field.
stars_gradient = jacobian.stars.datacube
# Derivative of the datacube with respect to the stellar metallicity (shown as cell output).
stars_gradient.stars.metallicity

In [None]:
# Compare the datacube shape and the stellar ages with the Jacobian block d(datacube)/d(age).
print(rubixdata.stars.datacube.shape)
print(rubixdata.stars.age)
print(stars_gradient.stars.age)

In [None]:
# Show the stars component of the pipeline output as the cell output.
rubixdata.stars

In [None]:
# Inspect the stellar inputs before age and metallicity are overwritten in the following cell.
print(rubixdata.stars.age)
print(rubixdata.stars.age.shape)
print(rubixdata.stars.metallicity)
print(rubixdata.stars.mass)

In [None]:
# Overwrite the stellar ages and metallicities with fixed test values. The modified object is
# pickled below to the path that the second configuration references via "modified_path".
rubixdata.stars.age = jnp.array([5.0, 5.0])
rubixdata.stars.metallicity = jnp.array([0.01, 0.01])
#rubixdata.stars.mass = jnp.array([[10000.0, 10000.0]])

#rubixdata.stars.age = rubixdata.stars.age.at[0, 0].set(5.0)
#rubixdata.stars.age = rubixdata.stars.age.at[0, 1].set(5.0)
#rubixdata.stars.metallicity = rubixdata.stars.metallicity.at[0, 0].set(0.01)
#rubixdata.stars.metallicity = rubixdata.stars.metallicity.at[0, 1].set(0.01)

import os
import pickle

# Make sure the target directory exists, and close the file deterministically so the pickle
# is fully flushed to disk before the next pipeline run reads it.
os.makedirs("output", exist_ok=True)
with open("output/rubix_galaxy.pkl", "wb") as pickle_file:
    pickle.dump(rubixdata, pickle_file)

In [None]:
#print(rubixdata.stars.age[0,0])
# Confirm the overwritten ages and their shape.
print(rubixdata.stars.age)
print(rubixdata.stars.age.shape)

In [None]:
from rubix.core.pipeline import RubixPipeline
# Suppose you already have a user_config or path to config
#config = "../rubix/config/pipeline_config.yaml"
import os
# Configuration for the second run. It repeats the first configuration; the only differences
# are "output_modified": True and the additional "modified_path" entry.
config = {
    "pipeline":{"name": "calc_gradient"},
    
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        
        "load_galaxy_args": {
        "id": 14,
        "reuse": True,
        },
        
        "subset": {
            "use_subset": True,
            "subset_size": 2,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            "path": "data/galaxy-id-14.hdf5",
        },
    
    },
    "output_path": "output",
    # Points at the pickle written in the cell above.
    # NOTE(review): presumably makes the pipeline load the modified data from that file — confirm.
    "output_modified": True,
    "modified_path": "output/rubix_galaxy.pkl",

    "telescope":
        {"name": "TESTGRADIENT",
         "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
         "lsf": {"sigma": 0.5},
         "noise": {"signal_to_noise": 1,"noise_distribution": "normal"},
         },
    "cosmology":
        {"name": "PLANCK15"},
        
    "galaxy":
        {"dist_z": 0.1,
         "rotation": {"type": "edge-on"},
        },
        
    "ssp": {
        "template": {
            "name": "BruzualCharlot2003"
        },
    },        
}
# Rebinds `pipe` and `rubixdata`: from here on they refer to the second run.
pipe = RubixPipeline(config)
rubixdata = pipe.run()

In [None]:
import jax
from rubix.utils import read_yaml
from rubix.pipeline import linear_pipeline as ltp

def calc_gradient(rubixdata, pipeline=None):
    """Compute the Jacobian of the compiled pipeline with respect to ``rubixdata``.

    The transformer functions of the pipeline are combined into a single
    callable via ``LinearTransformerPipeline.compile_expression()``, which is
    then differentiated with ``jax.jacrev`` and evaluated at ``rubixdata``.

    Args:
        rubixdata: Input data object at which the Jacobian is evaluated.
        pipeline: Pipeline object providing ``_get_pipeline_functions()``.
            Defaults to the notebook-level ``pipe`` so that existing
            single-argument calls keep working.

    Returns:
        The Jacobian returned by ``jax.jacrev``. The notebook accesses it as
        ``jacobian.stars.datacube.stars.<field>``, e.g. ``.age``.
    """
    # Fall back to the global pipeline only when the caller did not supply one.
    active_pipeline = pipe if pipeline is None else pipeline

    # _get_pipeline_functions() returns a list of the transformer functions in the correct order
    transformers_list = active_pipeline._get_pipeline_functions()

    read_cfg = read_yaml("../rubix/config/pipeline_config.yml")

    # Select the "calc_gradient" pipeline definition, matching config["pipeline"]["name"].
    pipeline_cfg = read_cfg["calc_gradient"]

    tp = ltp.LinearTransformerPipeline(
        pipeline_cfg,      # pipeline_cfg == read_cfg["calc_gradient"]
        transformers_list, # The list of function objects from the pipeline
    )

    compiled_fn = tp.compile_expression()
    jac_fn = jax.jacrev(compiled_fn)
    return jac_fn(rubixdata)

# Jacobian for the rubixdata returned by the second pipeline run.
gradient = calc_gradient(rubixdata)
print(gradient)

In [None]:
# Store the Jacobian under the name `jacobian`, which the following cells use.
# NOTE(review): this repeats the computation of the previous cell (`gradient`) — confirm whether one call suffices.
jacobian = calc_gradient(rubixdata)

In [None]:
#NBVAL_SKIP
# Presumably the wavelength grid of the configured telescope — TODO confirm.
wave = pipe.telescope.wave_seq

spectra = rubixdata.stars.datacube # Datacube of the second pipeline run
print(spectra.shape)

# Overlay the reference spectrum stored earlier (target_spectra) with index (0, 0) of the new datacube.
plt.plot(wave, target_spectra)
plt.plot(wave, spectra[0,0,:])
plt.show()

In [None]:
import jax
import jax.numpy as jnp

# Loss function
def loss_fn(config, target_datacube, RubixPipeline):
    """Mean squared error between a fresh pipeline run and ``target_datacube``.

    Args:
        config: Configuration handed to the pipeline constructor.
        target_datacube: Reference datacube to compare against.
        RubixPipeline: Pipeline class (or factory); it is called with
            ``config`` and the result must offer ``run()``.

    Returns:
        The mean of the squared element-wise difference between the
        ``stars.datacube`` of the run and ``target_datacube``.
    """
    simulated_datacube = RubixPipeline(config).run().stars.datacube
    residual = simulated_datacube - target_datacube
    return jnp.mean(residual**2)

# Loss of the current configuration against the target datacube (loss_fn runs the pipeline once more).
print(loss_fn(config, target_datacube, RubixPipeline))

In [None]:
# Datacube of the first (reference) run, shown as cell output.
target_datacube

In [None]:
# Datacube of the most recent pipeline run (rubixdata was rebound by the second run).
current_datacube = rubixdata.stars.datacube
print(current_datacube)

In [None]:
# Element-wise residual between the current and the target datacube (shown as cell output).
current_datacube - target_datacube

In [None]:
# Part of the Jacobian that belongs to the output datacube.
stars_gradient = jacobian.stars.datacube

# Derivatives of the datacube with respect to the stellar ages and metallicities.
grad_age = stars_gradient.stars.age
grad_metallicity = stars_gradient.stars.metallicity
print(grad_age)
print(grad_age.shape)

In [None]:
# Manual gradient step for the stellar ages. With L = mean(cube**2) the chain rule gives
#   dL/d(age_j) = 2/N * sum_i cube_i * d(datacube_i)/d(age_j),
# where N is the number of datacube entries. All sizes are derived from the arrays themselves,
# so the cell keeps working when the subset size or the cube dimensions change.
cube = (current_datacube - target_datacube)
cube_reshaped = cube.reshape(-1, 1)           # (N, 1): one row per datacube entry
print(cube_reshaped)
print(cube_reshaped.shape)
grad = grad_age
grad_reshaped = grad.reshape(cube.size, -1)   # (N, number of age values)
print(grad_reshaped)
print(grad_reshaped.shape)
update_age = 2 / cube.size * jnp.sum(cube_reshaped * grad_reshaped, axis=0)
update_age = update_age.reshape(1, -1)
print(update_age)
result_age = rubixdata.stars.age - update_age
print(rubixdata.stars.age)
print(result_age)

In [None]:
def _mse_parameter_gradient(residual, jacobian_leaf, parameter):
    """Gradient of mean(residual**2) with respect to one input parameter array.

    Chain rule: dL/dp_j = 2/N * sum_i residual_i * d(out_i)/d(p_j), where
    ``jacobian_leaf`` holds d(out)/d(p) with the output axes first (the layout
    the notebook's manual update cell relies on). The result has the shape of
    ``parameter``.
    """
    n_outputs = residual.size
    flat_residual = residual.reshape(n_outputs, 1)
    flat_jacobian = jacobian_leaf.reshape(n_outputs, -1)
    gradient = 2 / n_outputs * jnp.sum(flat_residual * flat_jacobian, axis=0)
    return gradient.reshape(parameter.shape)


# Gradient-based optimization loop
def optimize_inputs(config, target_datacube, RubixPipeline, learning_rate=0.1, max_iters=100, tol=1e-6):
    """Fit stellar age and metallicity to ``target_datacube`` by gradient descent.

    Each iteration runs the pipeline, computes the Jacobian with
    ``calc_gradient``, takes one gradient step on the mean-squared-error loss
    for ``stars.age`` and ``stars.metallicity``, and pickles the updated data
    to ``output/rubix_galaxy.pkl`` for the next pipeline run.

    Relies on notebook globals: ``calc_gradient``, the three-argument
    ``loss_fn(config, target_datacube, RubixPipeline)``, ``wave`` and
    ``target_spectra``. Note that the last cell of the notebook redefines
    ``loss_fn`` with a different signature.

    Args:
        config: Pipeline configuration passed to ``RubixPipeline``.
        target_datacube: Datacube the pipeline output should reproduce.
        RubixPipeline: Pipeline class; called with ``config``, must offer ``run()``.
        learning_rate: Step size of the gradient update.
        max_iters: Maximum number of iterations.
        tol: The loop stops once the loss drops below this value.

    Returns:
        Tuple ``(new_inputs, loss)`` of the last iteration, or
        ``(None, None)`` if ``max_iters`` is 0.
    """
    import os
    import pickle

    plot_dir = "output/gradient_plots"
    os.makedirs(plot_dir, exist_ok=True)  # savefig does not create missing directories

    # Defined up front so that max_iters == 0 does not raise UnboundLocalError.
    new_inputs, loss = None, None

    for i in range(max_iters):
        pipe = RubixPipeline(config)
        input_rubixdata = pipe.run()

        jacobian = calc_gradient(input_rubixdata)
        stars_gradient = jacobian.stars.datacube

        current_datacube = input_rubixdata.stars.datacube
        residual = current_datacube - target_datacube

        old_age = input_rubixdata.stars.age
        old_metallicity = input_rubixdata.stars.metallicity

        # Reduce the Jacobian with the residual to obtain dL/d(parameter). The raw
        # residual * Jacobian product has the wrong shape for a parameter update.
        grad_age = _mse_parameter_gradient(residual, stars_gradient.stars.age, old_age)
        grad_metallicity = _mse_parameter_gradient(
            residual, stars_gradient.stars.metallicity, old_metallicity
        )

        new_age = old_age - learning_rate * grad_age
        new_metallicity = old_metallicity - learning_rate * grad_metallicity
        print(grad_age)
        print(old_age)
        print(new_age)

        # Update inputs (in place on the object returned by this iteration's run)
        new_inputs = input_rubixdata
        new_inputs.stars.age = new_age
        new_inputs.stars.metallicity = new_metallicity

        # Persist the updated inputs for the next pipeline run; close the file deterministically.
        with open("output/rubix_galaxy.pkl", "wb") as pickle_file:
            pickle.dump(new_inputs, pickle_file)

        # One figure per iteration: compares the spectrum at index (0, 0) of this iteration's
        # datacube with the target. The figure is closed after saving so lines and legends do
        # not pile up; the plots are available as PNG files only.
        fig, ax = plt.subplots()
        ax.plot(wave, current_datacube[0, 0, :], label=f"Iteration {i}")
        ax.plot(wave, target_spectra, label="Target")
        ax.legend()
        fig.savefig(f"{plot_dir}/spectra_{i}.png")
        plt.close(fig)

        loss = loss_fn(config, target_datacube, RubixPipeline)
        # Check convergence
        if loss < tol:
            break

    return new_inputs, loss


In [None]:
# Run the optimisation with explicit hyper-parameters; the returned tuple is shown as cell output.
# NOTE(review): the next cell runs the same optimisation again and keeps the result — confirm both are needed.
optimize_inputs(config, target_datacube, RubixPipeline, learning_rate=0.1, max_iters=100, tol=1e-6)

In [None]:
# Run the optimisation with the default hyper-parameters and keep the optimised inputs and final loss.
# NOTE(review): the previous cell already runs the optimisation with the same values passed explicitly.
optimized_inputs, final_loss = optimize_inputs(config, target_datacube, RubixPipeline)

In [None]:
import jax
import jax.numpy as jnp

# Loss function
# NOTE: this cell redefines `loss_fn` and `optimize_inputs` from the earlier cells with
# different signatures; once it has run, the earlier config-based versions are shadowed.
def loss_fn(inputs, target_datacube, pipeline_fn):
    """Mean squared error between ``pipeline_fn(inputs)`` and ``target_datacube``.

    Args:
        inputs: Array or pytree of arrays accepted by ``pipeline_fn``.
        target_datacube: Reference array to compare against.
        pipeline_fn: Differentiable function mapping ``inputs`` to a datacube.

    Returns:
        Scalar mean of the squared element-wise difference.
    """
    output_datacube = pipeline_fn(inputs)
    return jnp.mean((output_datacube - target_datacube)**2)

# Gradient-based optimization loop
def optimize_inputs(initial_inputs, target_datacube, pipeline_fn, learning_rate=0.01, max_iters=100, tol=1e-6):
    """Minimise ``loss_fn`` over ``inputs`` with plain gradient descent.

    Args:
        initial_inputs: Starting point; a single array or any pytree of arrays.
        target_datacube: Reference array the pipeline output should match.
        pipeline_fn: Differentiable function mapping inputs to a datacube.
        learning_rate: Step size of each update.
        max_iters: Maximum number of update steps.
        tol: Stop once the loss evaluated before an update is below this value.

    Returns:
        Tuple ``(inputs, loss)``: the inputs after the last update and the loss
        evaluated before that update. With ``max_iters <= 0`` the initial
        inputs and their loss are returned.
    """
    inputs = initial_inputs
    loss = None
    # Loop-invariant: build the value-and-gradient function once.
    value_and_grad_fn = jax.value_and_grad(loss_fn)

    for _ in range(max_iters):
        # Compute loss and gradient
        loss, grads = value_and_grad_fn(inputs, target_datacube, pipeline_fn)

        # Update inputs leaf by leaf so that pytrees (dicts, dataclasses, ...) work as well
        # as plain arrays.
        inputs = jax.tree_util.tree_map(
            lambda param, grad: param - learning_rate * grad, inputs, grads
        )

        # Check convergence
        if loss < tol:
            break

    if loss is None:
        # No iteration ran; report the loss of the unchanged inputs instead of failing.
        loss = loss_fn(inputs, target_datacube, pipeline_fn)

    return inputs, loss

# Example usage
# NOTE(review): `initial_inputs` and `pipeline_fn` are not defined anywhere in the visible notebook,
# so this line raises NameError on a fresh kernel — define them (or drop the call) before running.
optimized_inputs, final_loss = optimize_inputs(initial_inputs, target_datacube, pipeline_fn)
