# Gradient calculation

In this notebook we show how to calculate the full Jacobian of the pipeline with respect to the input data.

First, we define the config for the pipeline, set up the pipeline, and run the forward model as usual. That is, we download the IllustrisTNG data, transform it through the linear pipeline, and end up with the datacube. This datacube is stored as the target for the gradient-based optimization later in the notebook.

In [None]:
import os

from rubix.core.pipeline import RubixPipeline

# Forward-model configuration. The same settings could instead come from a
# user config file (e.g. "../rubix/config/pipeline_config.yaml").
# The Illustris API key is read from the ILLUSTRIS_API_KEY environment
# variable (os.environ.get yields None when it is not set).
config = dict(
    pipeline=dict(name="calc_gradient"),
    logger=dict(
        log_level="DEBUG",
        log_file_path=None,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    ),
    data=dict(
        name="IllustrisAPI",
        args=dict(
            api_key=os.environ.get("ILLUSTRIS_API_KEY"),
            particle_type=["stars"],
            simulation="TNG50-1",
            snapshot=99,
            save_data_path="data",
        ),
        load_galaxy_args=dict(id=14, reuse=True),
        subset=dict(use_subset=True, subset_size=2),
    ),
    simulation=dict(
        name="IllustrisTNG",
        args=dict(path="data/galaxy-id-14.hdf5"),
    ),
    output_path="output",
    output_modified=False,
    telescope=dict(
        name="TESTGRADIENT",
        psf=dict(name="gaussian", size=5, sigma=0.6),
        lsf=dict(sigma=0.5),
        noise=dict(signal_to_noise=1, noise_distribution="normal"),
    ),
    cosmology=dict(name="PLANCK15"),
    galaxy=dict(dist_z=0.1, rotation=dict(type="edge-on")),
    ssp=dict(template=dict(name="BruzualCharlot2003")),
)

pipe = RubixPipeline(config)
rubixdata = pipe.run()

# Keep the result of this run as the optimization target.
target_datacube, target_age, target_metallicity = (
    rubixdata.stars.datacube,
    rubixdata.stars.age,
    rubixdata.stars.metallicity,
)

# rubixdata as input for the pipeline

For gradient-based optimization, it is convenient to pass a modified `rubixdata` object directly to the pipeline and compute the new datacube. You can do this by loading the pipeline from `rubix.core.pipeline_gradient` instead of `rubix.core.pipeline`. Set up the pipeline, then pass the `rubixdata` to its `run` function.

In [None]:
#NBVAL_SKIP
# Same model, but using the pipeline from rubix.core.pipeline_gradient, whose
# run() takes the existing rubixdata object (see the markdown above).
import os

from rubix.core.pipeline_gradient import RubixPipeline

config = dict(
    pipeline=dict(name="calc_gradient"),
    logger=dict(
        log_level="DEBUG",
        log_file_path=None,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    ),
    data=dict(args=dict(particle_type=["stars"])),
    output_path="output",
    output_modified=False,
    telescope=dict(
        name="TESTGRADIENT",
        psf=dict(name="gaussian", size=5, sigma=0.6),
        lsf=dict(sigma=0.5),
        noise=dict(signal_to_noise=1, noise_distribution="normal"),
    ),
    cosmology=dict(name="PLANCK15"),
    galaxy=dict(dist_z=0.1, rotation=dict(type="edge-on")),
    ssp=dict(template=dict(name="BruzualCharlot2003")),
)

pipe = RubixPipeline(config)
rubixdata = pipe.run(rubixdata)

In [None]:
#NBVAL_SKIP
import matplotlib.pyplot as plt

# Wavelength grid of the telescope configured for the gradient pipeline.
wave = pipe.telescope.wave_seq

# Compare the spectrum at datacube index [0, 0, :] of the current run with the
# same entry of the target datacube from the first forward run.
spectra = rubixdata.stars.datacube
target_spectra = target_datacube[0, 0, :]

fig, ax = plt.subplots()
ax.plot(wave, spectra[0, 0, :], label="current spectrum")
ax.plot(wave, target_spectra, label="target")
ax.set_xlabel("Wavelength")
ax.set_ylabel("Flux")
ax.set_title("Spectrum at datacube index [0, 0]: current vs. target")
ax.legend(loc="upper right");

# Gradient calculation

The `RubixPipeline` from `rubix.core.pipeline_gradient` also implements a `gradient` function. Call it with your `rubixdata` and it returns the gradient.

In [None]:
#NBVAL_SKIP
# Gradient of the pipeline output with respect to the inputs in `rubixdata`.
# The result is accessed with the same structure as rubixdata, e.g.
# `gradient.stars.age` presumably holds the derivatives w.r.t. the stellar ages.
gradient = pipe.gradient(rubixdata)

gradient_age = gradient.stars.age
gradient_age.shape

In [None]:
import jax.numpy as jnp

# Starting guess for the optimization below: give every star the same age and
# metallicity. Building the arrays from the existing ones keeps this correct
# for any `subset_size`, not only the 2 stars configured in the first cell,
# and preserves their dtype.
INITIAL_AGE = 10.0
INITIAL_METALLICITY = 0.01

rubixdata.stars.age = jnp.full_like(rubixdata.stars.age, INITIAL_AGE)
rubixdata.stars.metallicity = jnp.full_like(rubixdata.stars.metallicity, INITIAL_METALLICITY)

In [None]:
import jax.numpy as jnp
from rubix.core.pipeline_gradient import RubixPipeline

# Configuration for the gradient pipeline used by the optimization loop.
config = dict(
    pipeline=dict(name="calc_gradient"),
    logger=dict(
        log_level="DEBUG",
        log_file_path=None,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    ),
    data=dict(args=dict(particle_type=["stars"])),
    output_path="output",
    output_modified=False,
    telescope=dict(
        name="TESTGRADIENT",
        psf=dict(name="gaussian", size=5, sigma=0.6),
        lsf=dict(sigma=0.5),
        noise=dict(signal_to_noise=1, noise_distribution="normal"),
    ),
    cosmology=dict(name="PLANCK15"),
    galaxy=dict(dist_z=0.1, rotation=dict(type="edge-on")),
    ssp=dict(template=dict(name="BruzualCharlot2003")),
)

# Gradient-based optimization loop
def _mse_gradient_step(residual, jacobian, scale):
    """Return ``scale * d(MSE)/d(params)`` for one parameter array.

    ``residual`` is the flattened ``datacube - target`` of length N, and
    ``jacobian`` is the pipeline gradient for that parameter. The Jacobian is
    assumed to be laid out as (*datacube.shape, n_params), as a Jacobian w.r.t.
    a 1-D parameter array would be (TODO confirm against ``pipe.gradient``).
    Since MSE = mean(residual**2), its gradient is 2/N * sum_k residual_k * J_k.
    """
    n_elements = residual.size
    jacobian_2d = jacobian.reshape(n_elements, -1)
    return 2.0 / n_elements * jnp.sum(residual[:, None] * jacobian_2d, axis=0) * scale


def _save_spectrum_plot(wave, spectrum, target_spectrum, label, path):
    """Save a plot of ``spectrum`` against ``target_spectrum`` to ``path``."""
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.plot(wave, spectrum, label=label)
    ax.plot(wave, target_spectrum, label="Target")
    ax.legend(loc="upper right")
    fig.savefig(path)
    # Close explicitly: one figure is created per iteration.
    plt.close(fig)


def optimize_inputs(
    config,
    target_datacube,
    rubixdata,
    RubixPipeline,
    learning_rate=0.1,
    max_iters=100,
    tol=1e-6,
    age_scale=1e-3,
    metallicity_scale=1e-9,
    plot_dir="output/gradient_plots",
):
    """Fit stellar ages and metallicities so the pipeline reproduces a target datacube.

    Runs plain gradient descent on the mean squared error between the datacube
    produced by ``RubixPipeline(config).run(rubixdata)`` and ``target_datacube``.
    Each iteration does the following:
      1. runs the pipeline on the current inputs;
      2. evaluates ``pipe.loss_mse``;
      3. optionally saves a plot of spectrum [0, 0, :] against the target;
      4. stops if the loss is below ``tol``;
      5. otherwise updates ``stars.age`` and ``stars.metallicity`` using
         ``pipe.gradient``.

    Args:
        config: Pipeline configuration passed to ``RubixPipeline``.
        target_datacube: Datacube to match. It must have the same shape as the
            datacube the pipeline produces.
        rubixdata: Initial pipeline input. The ``stars.age`` and
            ``stars.metallicity`` attributes of the object returned by
            ``pipe.run`` are reassigned. If ``run`` returns its input (TODO
            confirm), the caller's object is modified in place.
        RubixPipeline: Pipeline class providing ``run``, ``gradient``,
            ``loss_mse`` and (when plotting) ``telescope.wave_seq``.
        learning_rate: Step size of the gradient-descent update.
        max_iters: Maximum number of iterations.
        tol: Stop as soon as the loss is below this value.
        age_scale: Extra factor applied to the age gradient step.
        metallicity_scale: Extra factor applied to the metallicity gradient step.
        plot_dir: Directory for one comparison plot per iteration. It is created
            if missing. ``None`` disables plotting.

    Returns:
        Tuple ``(rubixdata, loss)`` with the inputs after the last iteration and
        the loss computed in that iteration.
        - If the loop stopped because ``loss < tol``, ``rubixdata`` holds exactly
          the inputs that produced ``loss``.
        - Otherwise it holds the inputs after the final update, which have not
          been evaluated yet.
        - ``loss`` is ``None`` when ``max_iters <= 0``.
    """
    import os

    loss = None
    # The config does not change between iterations, so one pipeline suffices.
    pipe = RubixPipeline(config)
    target_spectrum = target_datacube[0, 0, :]
    if plot_dir is not None:
        os.makedirs(plot_dir, exist_ok=True)

    for i in range(max_iters):
        print(f"Iteration {i}")
        current = pipe.run(rubixdata)
        current_datacube = current.stars.datacube
        loss = pipe.loss_mse(current_datacube, target_datacube)

        if plot_dir is not None:
            # Label with the parameters that actually produced this spectrum.
            _save_spectrum_plot(
                pipe.telescope.wave_seq,
                current_datacube[0, 0, :],
                target_spectrum,
                f"Iteration {i}, {current.stars.age}, {current.stars.metallicity}",
                os.path.join(plot_dir, f"spectra_{i}.png"),
            )

        rubixdata = current
        # Check convergence before updating, so the returned inputs match the
        # returned loss; this also skips a needless gradient evaluation.
        if loss < tol:
            break

        print("calculating gradient................")
        stars_gradient = pipe.gradient(current)
        print("gradient calculated................")

        residual = (current_datacube - target_datacube).reshape(-1)

        update_age = _mse_gradient_step(residual, stars_gradient.stars.age, age_scale)
        print(f"update age: {update_age}")
        update_metallicity = _mse_gradient_step(
            residual, stars_gradient.stars.metallicity, metallicity_scale
        )
        print(f"update metallicity: {update_metallicity}")

        current.stars.age = current.stars.age - learning_rate * update_age
        current.stars.metallicity = current.stars.metallicity - learning_rate * update_metallicity
        print(f"new age...........: {current.stars.age.shape}")

    return rubixdata, loss


In [None]:
#NBVAL_SKIP
# Up to `max_iters` pipeline runs plus gradient evaluations: expensive, so this
# is skipped during notebook validation like the other heavy cells.
import os

# One comparison plot per iteration is written here; make sure it exists.
os.makedirs("output/gradient_plots", exist_ok=True)

optimized_rubixdata, final_loss = optimize_inputs(
    config,
    target_datacube,
    rubixdata,
    RubixPipeline,
    learning_rate=0.1,
    max_iters=100,
    tol=1e-6,
)
final_loss