# RUBIX pipeline

RUBIX is designed as a linear pipeline: the individual functions are constructed and chained together into a pipeline. This allows us to execute the whole data transformation, from a cosmological hydrodynamical simulation of a galaxy to an IFU cube, in two lines of code. This notebook shows how to execute the pipeline. To see how the pipeline is executed step by step, one function at a time, we refer to the notebook `rubix_pipeline_stepwise.ipynb`.

## How to use the Pipeline
1) Define a `config`
2) Set up the `pipeline yaml`
3) Run the RUBIX pipeline
4) Do science with the mock-data

## Step 1: Config

The `config` contains all the information needed to run the pipeline; these are run-specific settings. Currently only Illustris is supported as a simulation source, but extensions to other simulations (e.g. NIHAO) are planned.

For the `config` you can choose the following options:
- `pipeline`: you specify the name of the pipeline that is stored in the yaml file in rubix/config/pipeline_config.yml
- `logger`: RUBIX implements a logger that reports to the user what is happening during pipeline execution and issues warnings
- `data - args - particle_type`: load only star particles ("particle_type": ["stars"]), only gas particles ("particle_type": ["gas"]), or both ("particle_type": ["stars", "gas"])
- `data - args - simulation`: choose the Illustris simulation (e.g. "simulation": "TNG50-1")
- `data - args - snapshot`: which time step of the simulation (99 for present day)
- `data - args - save_data_path`: set the path to save the downloaded Illustris data
- `data - load_galaxy_args - id`: define which Illustris galaxy is downloaded
- `data - load_galaxy_args - reuse`: if True and a file for this galaxy id already exists in the save_data_path directory, the download is skipped and the existing file is used
- `data - subset`: only a defined number of stars/gas particles is used and stored for the pipeline. This may be helpful for quick testing
- `simulation - name`: currently only IllustrisTNG is supported
- `simulation - args - path`: where the data is stored and how the file will be named
- `output_path`: where the hdf5 file is stored, which is then the input to the RUBIX pipeline
- `telescope - name`: define the telescope instrument that observes the simulation. Some telescopes are predefined, e.g. MUSE. If your instrument is not predefined, you can easily define it in rubix/telescope/telescopes.yaml
- `telescope - psf`: define the point spread function that is applied to the mock data
- `telescope - lsf`: define the line spread function that is applied to the mock data
- `telescope - noise`: define the noise that is applied to the mock data
- `cosmology`: specify the cosmology you want to use, standard for RUBIX is "PLANCK15"
- `galaxy - dist_z`: specify at which redshift the mock-galaxy is observed
- `galaxy - rotation`: specify the orientation of the galaxy. You can set the types edge-on or face-on or specify the angles alpha, beta and gamma as rotations around x-, y- and z-axis
- `ssp - template`: specify the simple stellar population (SSP) lookup template used to obtain the stellar spectrum of each star particle. RUBIX commonly uses "BruzualCharlot2003".
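In the list above, a name such as `data - args - particle_type` is a path through the nested `config` dictionary. A minimal sketch of resolving such a path; the helper `get_option` is hypothetical (not a RUBIX function) and assumes the `config` is a plain nested dict, as it is in this notebook:

```python
def get_option(config: dict, path: str):
    """Resolve a dashed option path like 'data - args - snapshot' in a nested dict.

    Hypothetical helper for illustration; RUBIX may access its options differently.
    """
    node = config
    for key in path.split(" - "):
        node = node[key]  # raises KeyError if an option is missing
    return node


example_config = {
    "data": {
        "args": {"particle_type": ["stars"], "simulation": "TNG50-1", "snapshot": 99},
        "load_galaxy_args": {"id": 14, "reuse": True},
    },
    "galaxy": {"dist_z": 0.1, "rotation": {"type": "edge-on"}},
}

print(get_option(example_config, "data - args - snapshot"))  # 99
print(get_option(example_config, "galaxy - rotation"))       # {'type': 'edge-on'}
```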

In [None]:
#NBVAL_SKIP
import matplotlib.pyplot as plt
from rubix.core.pipeline import RubixPipeline 
import os
config = {
    "pipeline":{"name": "calc_ifu"},
    
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        
        "load_galaxy_args": {
        "id": 14,
        "reuse": True,
        },
        
        "subset": {
            "use_subset": True,
            "subset_size": 1,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            "path": "data/galaxy-id-14.hdf5",
        },
    
    },
    "output_path": "output",

    "telescope":
        {"name": "MUSE",
         "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
         "lsf": {"sigma": 0.5},
         "noise": {"signal_to_noise": 1,"noise_distribution": "normal"},},
    "cosmology":
        {"name": "PLANCK15"},
        
    "galaxy":
        {"dist_z": 0.1,
         "rotation": {"type": "edge-on"},
        },
        
    "ssp": {
        "template": {
            "name": "BruzualCharlot2003"
        },
    },        
}

## Step 2: Pipeline yaml

To run the RUBIX pipeline, you need a yaml file (stored in `rubix/config/pipeline_config.yml`) that defines which functions are used during pipeline execution. Below is the example pipeline yaml for computing a stellar IFU cube.

```yaml
calc_ifu:
  Transformers:
    rotate_galaxy:
      name: rotate_galaxy
      depends_on: null
      args: []
      kwargs:
        type: "face-on"
    filter_particles:
      name: filter_particles
      depends_on: rotate_galaxy
      args: []
      kwargs: {}
    spaxel_assignment:
      name: spaxel_assignment
      depends_on: filter_particles
      args: []
      kwargs: {}

    reshape_data:
      name: reshape_data
      depends_on: spaxel_assignment
      args: []
      kwargs: {}

    calculate_spectra:
      name: calculate_spectra
      depends_on: reshape_data
      args: []
      kwargs: {}

    scale_spectrum_by_mass:
      name: scale_spectrum_by_mass
      depends_on: calculate_spectra
      args: []
      kwargs: {}
    doppler_shift_and_resampling:
      name: doppler_shift_and_resampling
      depends_on: scale_spectrum_by_mass
      args: []
      kwargs: {}
    calculate_datacube:
      name: calculate_datacube
      depends_on: doppler_shift_and_resampling
      args: []
      kwargs: {}
    convolve_psf:
      name: convolve_psf
      depends_on: calculate_datacube
      args: []
      kwargs: {}
    convolve_lsf:
      name: convolve_lsf
      depends_on: convolve_psf
      args: []
      kwargs: {}
    apply_noise:
      name: apply_noise
      depends_on: convolve_lsf
      args: []
      kwargs: {}
```

There is one thing you need to know about the function names in this yaml: for a function to be usable inside the pipeline, its name must exactly match the name of the function returned by the corresponding core module function!
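The `depends_on` entries link the transformers into a chain, and each `name` is looked up among the functions that the core modules return. A simplified sketch of that idea, assuming a strictly linear chain; the stand-in functions, `REGISTRY`, `execution_order` and `build_chain` are hypothetical and do not reproduce how `RubixPipeline` is actually implemented:

```python
from typing import Callable, Dict, List


# Hypothetical stand-ins for functions returned by the RUBIX core modules;
# each one just records its own name so the execution order is visible.
def rotate_galaxy(data: List[str]) -> List[str]:
    return data + ["rotate_galaxy"]


def filter_particles(data: List[str]) -> List[str]:
    return data + ["filter_particles"]


def spaxel_assignment(data: List[str]) -> List[str]:
    return data + ["spaxel_assignment"]


# Registry keyed by each function's own __name__, i.e. the name the yaml must match
REGISTRY: Dict[str, Callable] = {
    f.__name__: f for f in (rotate_galaxy, filter_particles, spaxel_assignment)
}

# Same structure as the `Transformers` section of the yaml (order deliberately shuffled)
transformers = {
    "spaxel_assignment": {"name": "spaxel_assignment", "depends_on": "filter_particles"},
    "rotate_galaxy": {"name": "rotate_galaxy", "depends_on": None},
    "filter_particles": {"name": "filter_particles", "depends_on": "rotate_galaxy"},
}


def execution_order(transformers: dict) -> List[str]:
    """Follow depends_on links from the root step (depends_on: null); assumes a linear chain."""
    dependent_of = {spec["depends_on"]: key for key, spec in transformers.items()}
    order, current = [], dependent_of[None]
    while current is not None:
        order.append(current)
        current = dependent_of.get(current)
    return order


def build_chain(transformers: dict, registry: Dict[str, Callable]) -> Callable:
    steps = []
    for key in execution_order(transformers):
        name = transformers[key]["name"]
        if name not in registry:
            raise KeyError(f"no function named {name!r}: yaml names must match exactly")
        steps.append(registry[name])

    def run(data):
        for step in steps:
            data = step(data)
        return data

    return run


run_chain = build_chain(transformers, REGISTRY)
print(run_chain([]))  # ['rotate_galaxy', 'filter_particles', 'spaxel_assignment']
```

Because the lookup is by exact string, a yaml entry such as `rotateGalaxy` fails with a `KeyError` instead of silently skipping the step.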

## Step 3: Run the pipeline

After defining the `config` and the `pipeline_config`, you can run the whole pipeline with these two lines of code.

In [None]:
#NBVAL_SKIP
pipe = RubixPipeline(config)

rubixdata = pipe.run()

## Step 4: Mock-data

Now we have our final datacube and can use the mock data to do science. Here we take a quick look at the optical wavelength range of the mock datacube and show the spectrum of a central spaxel and a spatial image.

In [None]:
#NBVAL_SKIP
import jax.numpy as jnp

wave = pipe.telescope.wave_seq
# get the indices of the visible wavelengths of 4000-8000 Angstroms
visible_indices = jnp.where((wave >= 4000) & (wave <= 8000))


This is how you can access the spectrum of an individual spaxel; the wavelengths can be accessed via `pipe.telescope.wave_seq`.

In [None]:
#NBVAL_SKIP
wave = pipe.telescope.wave_seq

spectra = rubixdata.stars.datacube  # IFU datacube: one spectrum per spaxel, shape (ny, nx, n_wave)
print(spectra.shape)

plt.plot(wave, spectra[12,12,:])
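The index `[12, 12]` above picks a spaxel near the centre of the cube. A small sketch that derives the central spaxel from the cube shape instead of hard-coding it, shown on a synthetic NumPy array (the shape `(25, 25, 100)` is an assumption for illustration, not the actual MUSE cube size):

```python
import numpy as np


def central_spectrum(datacube):
    """Return the spectrum of the central spaxel of an (ny, nx, n_wave) cube."""
    ny, nx, _ = datacube.shape
    return datacube[ny // 2, nx // 2, :]


cube = np.zeros((25, 25, 100))
cube[12, 12, :] = 1.0  # mark the central spaxel of the synthetic cube
spec = central_spectrum(cube)
print(spec.shape, spec.sum())  # (100,) 100.0
```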


Plot a spatial image of the datacube

In [None]:
#NBVAL_SKIP
# get the spectra of the visible wavelengths from the ifu cube
visible_spectra = rubixdata.stars.datacube[:, :, visible_indices[0]]
#visible_spectra.shape

# Sum up all spectra to create an image
image = jnp.sum(visible_spectra, axis = 2)
plt.imshow(image, origin="lower", cmap="inferno")
plt.colorbar()
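The image above collapses the cube over 4000 to 8000 Angstroms using the indices from `jnp.where`. A NumPy sketch of the same operation with a boolean mask, plus a variant that multiplies by the mask instead of indexing; the second form keeps the output shape independent of the data, which is what `jax.jit` needs (boolean-mask indexing is not allowed inside jitted code). The synthetic wavelength grid and cube are assumptions for illustration:

```python
import numpy as np


def collapse_cube(datacube, wave, wmin, wmax):
    """Sum an (ny, nx, n_wave) cube over wmin <= wave <= wmax to form an image."""
    mask = (wave >= wmin) & (wave <= wmax)  # boolean mask along the wavelength axis
    return datacube[:, :, mask].sum(axis=2)


def collapse_cube_masked(datacube, wave, wmin, wmax):
    """Same result by multiplying with the mask; the output shape never depends on the data."""
    mask = (wave >= wmin) & (wave <= wmax)
    return (datacube * mask).sum(axis=2)


wave = np.linspace(3000.0, 9000.0, 61)  # 100 Angstrom sampling
cube = np.ones((4, 5, wave.size))
image = collapse_cube(cube, wave, 4000.0, 8000.0)
print(image.shape, image[0, 0])  # (4, 5) 41.0
```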

## DONE!

Congratulations, you have successfully run the RUBIX pipeline to create your own mock-observed IFU datacube! Now enjoy playing around with the RUBIX pipeline and doing amazing science with RUBIX :)

# Experimental work on the gradient

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jacrev, jacfwd

In [None]:
def pipeline(x):
    # Example pipeline function
    y = jnp.sin(x)
    z = jnp.sum(y ** 2)
    return z

In [None]:
# Define the gradient function
pipeline_grad = grad(pipeline)

# Example input
x = jnp.array([1.0, 2.0, 3.0])

# Compute the gradient
grad_result = pipeline_grad(x)
print("Gradient:", grad_result)
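A quick way to trust an automatic gradient is to compare it with the analytic one. For the toy function above, d/dx sum(sin(x)^2) = 2 sin(x) cos(x) = sin(2x). A sketch of that check (the name `sin_squared_sum` is just a local copy of the toy pipeline):

```python
import jax
import jax.numpy as jnp


def sin_squared_sum(x):
    # Same toy function as the pipeline cell above
    return jnp.sum(jnp.sin(x) ** 2)


x = jnp.array([1.0, 2.0, 3.0])
auto = jax.grad(sin_squared_sum)(x)
analytic = jnp.sin(2.0 * x)  # 2 sin(x) cos(x)
print(jnp.allclose(auto, analytic, atol=1e-6))  # True
```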

In [None]:
def pipeline_multi_output(x):
    # Example pipeline function with multiple outputs
    y = jnp.sin(x)
    z = jnp.cos(x)
    return y, z

# Define the Jacobian function
pipeline_jacobian = jacrev(pipeline_multi_output)

# Compute the Jacobian
jacobian_result = pipeline_jacobian(x)
print("Jacobian:", jacobian_result)
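For a function with several outputs, `jacrev` returns one Jacobian per output, here a tuple of two 3x3 matrices. Because `sin` and `cos` act elementwise, both are diagonal. The sketch below checks this and confirms that forward mode (`jacfwd`) gives the same result; reverse mode costs roughly one backward pass per output element and forward mode one pass per input element, so `jacfwd` tends to be preferable when there are fewer inputs than outputs:

```python
import jax
import jax.numpy as jnp


def sin_cos(x):
    return jnp.sin(x), jnp.cos(x)


x = jnp.array([1.0, 2.0, 3.0])
jac_rev = jax.jacrev(sin_cos)(x)  # tuple: (d sin/dx, d cos/dx), each 3x3
jac_fwd = jax.jacfwd(sin_cos)(x)

# Elementwise outputs give diagonal Jacobians
print(jnp.allclose(jac_rev[0], jnp.diag(jnp.cos(x))))   # True
print(jnp.allclose(jac_rev[1], jnp.diag(-jnp.sin(x))))  # True
print(all(jnp.allclose(r, f) for r, f in zip(jac_rev, jac_fwd)))  # True
```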

In [None]:
def pipeline_multi_input(x, y):
    # Example pipeline function with multiple inputs
    z = x ** 2 + (y-1) ** 2
    return z

In [None]:
x = jnp.array([5.0])
y = jnp.array([2.0])

pipeline_grad_x = jacrev(pipeline_multi_input, argnums=0)
pipeline_grad_y = jacrev(pipeline_multi_input, argnums=1)
pipeline_grad_xy = jacrev(pipeline_multi_input, argnums=(0, 1))

grad_x_result = pipeline_grad_x(x, y)
grad_y_result = pipeline_grad_y(x, y)
grad_xy_result = pipeline_grad_xy(x, y)

print("Gradient with respect to x:", grad_x_result)
print("Gradient with respect to y:", grad_y_result)
print("Gradient with respect to x and y:", grad_xy_result)


In [None]:
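# jacrev returns (1, 1) Jacobians; summing over the output axis gives the gradient of sum(z),
# which keeps x and y at shape (1,) instead of growing to (1, 1)
x = x - 0.1 * grad_x_result.sum(axis=0)
y = y - 0.1 * grad_y_result.sum(axis=0)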

print("Updated x:", x)
print("Updated y:", y)

grad_x_result = pipeline_grad_x(x, y)
grad_y_result = pipeline_grad_y(x, y)

print("Gradient with respect to x:", grad_x_result)
print("Gradient with respect to y:", grad_y_result)
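The update above can be repeated in a loop. For a scalar loss, `jax.grad` returns gradients with exactly the shapes of the inputs, so the update stays shape-stable. A minimal gradient-descent sketch on the same two-input example, written as a scalar `loss` (a hypothetical name) whose minimum lies at x = 0, y = 1:

```python
import jax
import jax.numpy as jnp


def loss(x, y):
    # Scalar version of the two-input example: minimum at x = 0, y = 1
    return jnp.sum(x ** 2 + (y - 1.0) ** 2)


grad_loss = jax.grad(loss, argnums=(0, 1))  # gradients w.r.t. x and y

x = jnp.array([5.0])
y = jnp.array([2.0])
learning_rate = 0.1
for _ in range(100):
    gx, gy = grad_loss(x, y)
    x = x - learning_rate * gx  # each step multiplies x by 0.8
    y = y - learning_rate * gy

print(x, y)  # close to [0.] and [1.]
```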

# Gradient for rotation function

In [None]:
import jax
import jax.numpy as jnp
from rubix.galaxy.alignment import rotate_galaxy as rotate_galaxy_core

# Particle data from the pipeline run above
positions = rubixdata.stars.coords  # Shape (N, 3)
velocities = rubixdata.stars.velocity  # Shape (N, 3)
masses = rubixdata.stars.mass # Shape (N,)
halfmass_radius = rubixdata.galaxy.halfmassrad_stars
alpha, beta, gamma = 90.0, 0.0, 0.0

# Wrapper for JAX compatibility
def rotate_galaxy_wrapper(positions, velocities, masses, halfmass_radius, alpha, beta, gamma):
    new_positions, new_velocities = rotate_galaxy_core(
        positions=positions,
        velocities=velocities,
        masses=masses,
        halfmass_radius=halfmass_radius,
        alpha=alpha,
        beta=beta,
        gamma=gamma,
    )
    return new_positions, new_velocities

# Compute the Jacobian with respect to positions
jac_positions = jax.jacrev(rotate_galaxy_wrapper, argnums=0)(
    positions, velocities, masses, halfmass_radius, alpha, beta, gamma
)

# Compute the Jacobian with respect to velocities
jac_velocities = jax.jacrev(rotate_galaxy_wrapper, argnums=1)(
    positions, velocities, masses, halfmass_radius, alpha, beta, gamma
)

# Jacobian w.r.t. mass
jac_mass = jax.jacrev(rotate_galaxy_wrapper, argnums=2)(
    positions, velocities, masses, halfmass_radius, alpha, beta, gamma
)

# Jacobian w.r.t. halfmass_radius
jac_halfmass_radius = jax.jacrev(rotate_galaxy_wrapper, argnums=3)(
    positions, velocities, masses, halfmass_radius, alpha, beta, gamma
)

# Jacobian w.r.t. alpha
jac_alpha = jax.jacrev(rotate_galaxy_wrapper, argnums=4)(
    positions, velocities, masses, halfmass_radius, alpha, beta, gamma
)

# Jacobian w.r.t. beta
jac_beta = jax.jacrev(rotate_galaxy_wrapper, argnums=5)(
    positions, velocities, masses, halfmass_radius, alpha, beta, gamma
)

# Jacobian w.r.t. gamma
jac_gamma = jax.jacrev(rotate_galaxy_wrapper, argnums=6)(
    positions, velocities, masses, halfmass_radius, alpha, beta, gamma
)


print("Jacobian w.r.t. positions:", jac_positions)
print("Jacobian w.r.t. velocities:", jac_velocities)
print("Jacobian w.r.t. mass:", jac_mass)
print("Jacobian w.r.t. halfmass_radius:", jac_halfmass_radius)
print("Jacobian w.r.t. alpha:", jac_alpha)
print("Jacobian w.r.t. beta:", jac_beta)
print("Jacobian w.r.t. gamma:", jac_gamma)
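The structure of these Jacobians is easy to see on a plain rotation. In the sketch below, `rotation_x` and `rotate` are hypothetical stand-ins (not RUBIX's `rotate_galaxy`): a rotation about the x-axis applied to every particle. The Jacobian w.r.t. positions has shape (N, 3, N, 3); each particle's diagonal block is the rotation matrix itself and all cross-particle blocks are zero, which is also why the full Jacobian grows as (3N)^2 and only stays manageable for small particle subsets:

```python
import jax
import jax.numpy as jnp


def rotation_x(alpha_deg):
    """Rotation matrix about the x-axis (angle in degrees)."""
    a = jnp.deg2rad(alpha_deg)
    c, s = jnp.cos(a), jnp.sin(a)
    return jnp.array([[1.0, 0.0, 0.0],
                      [0.0, c, -s],
                      [0.0, s, c]])


def rotate(positions, alpha_deg):
    # positions: (N, 3); every particle is rotated with the same matrix
    return positions @ rotation_x(alpha_deg).T


positions = jnp.array([[1.0, 2.0, 3.0], [0.5, -1.0, 0.0]])
alpha = 90.0

jac_pos = jax.jacrev(rotate, argnums=0)(positions, alpha)    # shape (N, 3, N, 3)
jac_alpha = jax.jacrev(rotate, argnums=1)(positions, alpha)  # shape (N, 3)

R = rotation_x(alpha)
print(jac_pos.shape)                           # (2, 3, 2, 3)
print(jnp.allclose(jac_pos[0, :, 0, :], R))    # True: particle 0's block is R
print(jnp.allclose(jac_pos[0, :, 1, :], 0.0))  # True: particles do not interact
```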

# Gradient for spaxel assignment

In [None]:
from rubix.telescope.utils import (
    calculate_spatial_bin_edges,
    square_spaxel_assignment,
    mask_particles_outside_aperture,
)
from rubix.core.telescope import get_telescope, get_spatial_bin_edges

telescope = get_telescope(config)
if telescope.pixel_type not in ["square"]:
    raise ValueError(f"Pixel type {telescope.pixel_type} not supported")
spatial_bin_edges = get_spatial_bin_edges(config)

coords = rubixdata.stars.coords

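def spaxel_assignment(coords, spatial_bin_edges):
    # Assign each particle to a square spaxel; the result is an integer index array
    if coords is None:
        raise ValueError("No particle coordinates available for spaxel assignment")
    pixel_assignment = square_spaxel_assignment(coords, spatial_bin_edges)
    return pixel_assignment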

# Compute the Jacobian with respect to positions
# it is not possible to use jax.jacrev here, because the function returns an integer array here
#jac_positions2 = jax.jacrev(spaxel_assignment, argnums=0)(
#    coords, spatial_bin_edges
#)

# Compute the value and pullback function
value, pullback = jax.vjp(spaxel_assignment, coords, spatial_bin_edges)

# The "value" is the function output
print("Function output (value):", value)

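# The pullback function takes a cotangent vector and returns the gradients.
# The spaxel indices are integers, so the expected cotangent has dtype float0
# (JAX's tangent type for integer outputs) rather than jnp.ones_like(value)
import numpy as np
output_grad = np.zeros(value.shape, dtype=jax.dtypes.float0)
grad_coords, grad_edges = pullback(output_grad)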

# Gradients with respect to coords and spatial_bin_edges
print("Gradient w.r.t. coords:", grad_coords)
print("Gradient w.r.t. spatial_bin_edges:", grad_edges)
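The gradients come out as zero because spaxel assignment maps continuous coordinates to integer bin indices, a piecewise-constant function. A NumPy sketch of a square-grid assignment makes this visible; it assumes shared bin edges for x and y and a row-major flat index `iy * n_bins + ix` (both assumptions, RUBIX's `square_spaxel_assignment` may use a different convention), and leaves particles outside the grid to be masked separately:

```python
import numpy as np


def square_assignment(coords, bin_edges):
    """Map (N, 2+) particle coordinates to flat spaxel indices on a square grid."""
    n_bins = len(bin_edges) - 1
    # np.digitize returns i such that edges[i-1] <= v < edges[i]; shift to 0-based
    ix = np.digitize(coords[:, 0], bin_edges) - 1
    iy = np.digitize(coords[:, 1], bin_edges) - 1
    return iy * n_bins + ix


edges = np.linspace(-1.0, 1.0, 5)  # 4 x 4 spaxels
coords = np.array([[-0.9, -0.9, 0.0],
                   [0.10, 0.6, 0.0],
                   [0.11, 0.6, 0.0]])  # a small shift in x leaves the index unchanged
print(square_assignment(coords, edges))  # [ 0 14 14]
```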

In [None]:
import jax
import jax.numpy as jnp

def pipeline(positions, velocities, masses, halfmass_radius, alpha, beta, gamma, spatial_bin_edges):
    # Step 1: Rotate the galaxy
    rotated_positions, _ = rotate_galaxy_wrapper(
        positions, velocities, masses, halfmass_radius, alpha, beta, gamma
    )
    # Step 2: Assign spaxels
    pixel_assignment = square_spaxel_assignment(rotated_positions, spatial_bin_edges)
    return pixel_assignment


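# Compute the Jacobian w.r.t. positions (argnums=0).
# jax.jacrev requires floating-point outputs, but `pipeline` returns integer spaxel
# indices, so this call raises a TypeError; catch it to keep the notebook running
try:
    jac_positions_pipeline = jax.jacrev(pipeline, argnums=0)(
        positions, velocities, masses, halfmass_radius, alpha, beta, gamma, spatial_bin_edges
    )
    print("Jacobian w.r.t. positions in pipeline:", jac_positions_pipeline)
except TypeError as err:
    print("jax.jacrev failed:", err)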
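Before differentiating a composed function, it can help to inspect its output dtype without running it. `jax.eval_shape` traces the function abstractly and returns only shapes and dtypes. A sketch with a hypothetical integer-valued `assign` function standing in for the spaxel assignment:

```python
import jax
import jax.numpy as jnp


def assign(coords, edges):
    # Integer-valued output, like a spaxel assignment
    return jnp.digitize(coords[:, 0], edges)


coords = jnp.zeros((10, 3))
edges = jnp.linspace(-1.0, 1.0, 5)

out = jax.eval_shape(assign, coords, edges)  # no computation, only shape and dtype
print(out.shape, out.dtype)
is_differentiable = jnp.issubdtype(out.dtype, jnp.floating)
print(is_differentiable)  # False: jax.jacrev would raise a TypeError here
```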

# Gradient for reshape data

In [None]:
from rubix.core.data import reshape_array

def reshape_data(coords, velocities, masses, metallicity, age):
    reshaped_coords = reshape_array(coords)
    reshaped_velocities = reshape_array(velocities)
    reshaped_masses = reshape_array(masses)
    reshaped_metallicity = reshape_array(metallicity)
    reshaped_age = reshape_array(age)

    return reshaped_coords, reshaped_velocities, reshaped_masses, reshaped_metallicity, reshaped_age
        
# Compute Jacobian with respect to coords
jac_coords = jax.jacrev(reshape_data, argnums=0)(
    rubixdata.stars.coords, rubixdata.stars.velocity, rubixdata.stars.mass, rubixdata.stars.metallicity, rubixdata.stars.age
)

# Compute Jacobian with respect to velocities
jac_velocities = jax.jacrev(reshape_data, argnums=1)(
    rubixdata.stars.coords, rubixdata.stars.velocity, rubixdata.stars.mass, rubixdata.stars.metallicity, rubixdata.stars.age
)

# Compute Jacobian with respect to masses
jac_masses = jax.jacrev(reshape_data, argnums=2)(
    rubixdata.stars.coords, rubixdata.stars.velocity, rubixdata.stars.mass, rubixdata.stars.metallicity, rubixdata.stars.age
)

# Compute Jacobian with respect to metallicity
jac_metallicity = jax.jacrev(reshape_data, argnums=3)(
    rubixdata.stars.coords, rubixdata.stars.velocity, rubixdata.stars.mass, rubixdata.stars.metallicity, rubixdata.stars.age
)

# Compute Jacobian with respect to age
jac_age = jax.jacrev(reshape_data, argnums=4)(
    rubixdata.stars.coords, rubixdata.stars.velocity, rubixdata.stars.mass, rubixdata.stars.metallicity, rubixdata.stars.age
)

print("Jacobian w.r.t. coords:", jac_coords)
print("Jacobian w.r.t. velocities:", jac_velocities)
print("Jacobian w.r.t. masses:", jac_masses)
print("Jacobian w.r.t. metallicity:", jac_metallicity)
print("Jacobian w.r.t. age:", jac_age)
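A pure reshape only moves values around, so its Jacobian contains just zeros and ones. The sketch below illustrates this with `split_for_devices`, a hypothetical stand-in that assumes `reshape_array` pads the particle axis and splits it across devices for `jax.pmap` (an assumption about its behaviour, not its actual code). Padded entries receive no gradient:

```python
import jax
import jax.numpy as jnp


def split_for_devices(arr, n_devices):
    """Hypothetical stand-in: pad the particle axis and reshape to (n_devices, per_device, ...)."""
    n = arr.shape[0]
    per_device = -(-n // n_devices)  # ceiling division
    pad = per_device * n_devices - n
    padded = jnp.pad(arr, [(0, pad)] + [(0, 0)] * (arr.ndim - 1))
    return padded.reshape((n_devices, per_device) + arr.shape[1:])


masses = jnp.array([1.0, 2.0, 3.0])
out = split_for_devices(masses, 2)
print(out)  # [[1. 2.] [3. 0.]]

jac = jax.jacrev(split_for_devices)(masses, 2)  # shape (2, 2, 3): output shape + input shape
print(jac.shape, jac.sum())  # (2, 2, 3) 3.0: one 1 per real particle, none for padding
```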

# Gradient for spectra calculation

In [None]:
from rubix.core.ssp import get_lookup_interpolation_pmap
lookup_interpolation_pmap = get_lookup_interpolation_pmap(config)

def calculate_spectra(age_data, metallicity_data):
    # Promote scalar inputs to 1D arrays so the lookup always receives arrays
    age = jnp.atleast_1d(age_data)
    metallicity = jnp.atleast_1d(metallicity_data)

    # The lookup takes (metallicity, age) in this order; mass scaling happens in a later step
    spectra = lookup_interpolation_pmap(metallicity, age)
    spectra_jax = jnp.array(spectra)
    return spectra_jax

# Compute the Jacobian with respect to age
jac_age = jax.jacrev(calculate_spectra, argnums=0)(
    rubixdata.stars.age, rubixdata.stars.metallicity
)

# Compute the Jacobian with respect to metallicity
jac_metallicity = jax.jacrev(calculate_spectra, argnums=1)(
    rubixdata.stars.age, rubixdata.stars.metallicity
)

print("Jacobian w.r.t. age:", jac_age)
print("Jacobian w.r.t. metallicity:", jac_metallicity)
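Gradients through an SSP lookup exist because the lookup interpolates linearly between grid points, so the derivative is the local slope of the grid. A one-dimensional sketch with `jnp.interp` on a made-up grid (RUBIX interpolates in both metallicity and age; the 1D grid and its values here are assumptions for illustration). Outside the grid, `jnp.interp` clamps to the end values, so the gradient there is zero:

```python
import jax
import jax.numpy as jnp

# Hypothetical 1D SSP grid: flux at one wavelength as a function of log(age)
log_age_grid = jnp.array([8.0, 9.0, 10.0])
flux_grid = jnp.array([3.0, 2.0, 0.5])


def flux_at(log_age):
    # Piecewise-linear interpolation: derivative equals the slope of the current segment
    return jnp.interp(log_age, log_age_grid, flux_grid)


d_flux = jax.grad(flux_at)
print(d_flux(8.5))   # -1.0  (slope between 8 and 9)
print(d_flux(9.5))   # -1.5  (slope between 9 and 10)
print(d_flux(11.0))  # 0.0   (clamped outside the grid)
```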

# Gradient for scaling by mass

In [None]:
def scale_spectrum_by_mass(spectra, mass):

    # Broadcast each particle's mass over the wavelength axis
    mass = jnp.expand_dims(mass, axis=-1)
    spectra_mass = spectra * mass
    return spectra_mass

# Compute the Jacobian with respect to spectra
jac_spectra = jax.jacrev(scale_spectrum_by_mass, argnums=0)(
    rubixdata.stars.spectra, rubixdata.stars.mass
)

# Compute the Jacobian with respect to mass
jac_mass = jax.jacrev(scale_spectrum_by_mass, argnums=1)(
    rubixdata.stars.spectra, rubixdata.stars.mass
)

print("Jacobian w.r.t. spectra:", jac_spectra)
print("Jacobian w.r.t. mass:", jac_mass)
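The Jacobian w.r.t. mass is sparse: only particle i's own mass changes particle i's spectrum, so `jac[i, :, j]` equals `spectra[i]` when i == j and zero otherwise. A sketch on toy arrays (the function `scale_by_mass` mirrors the cell above), which also shows that a vector-Jacobian product yields the same information without materializing the full tensor:

```python
import jax
import jax.numpy as jnp


def scale_by_mass(spectra, mass):
    # spectra: (N, n_wave), mass: (N,) broadcast over the wavelength axis
    return spectra * jnp.expand_dims(mass, axis=-1)


spectra = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # N = 3, n_wave = 2
mass = jnp.array([10.0, 20.0, 30.0])

jac_mass = jax.jacrev(scale_by_mass, argnums=1)(spectra, mass)
print(jac_mass.shape)  # (3, 2, 3)
expected = jnp.einsum("iw,ij->iwj", spectra, jnp.eye(3))
print(jnp.allclose(jac_mass, expected))  # True

# VJP with a ones cotangent: gradient of the total flux w.r.t. each mass
_, vjp_fn = jax.vjp(scale_by_mass, spectra, mass)
_, grad_mass = vjp_fn(jnp.ones_like(spectra))
print(grad_mass)  # [ 3.  7. 11.]  (row sums of spectra)
```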

# Gradient for Doppler shifting and resampling

In [None]:
from rubix import config as rubix_config
from rubix.core.telescope import get_telescope
from rubix.core.ssp import get_lookup_interpolation_pmap, get_ssp
from rubix.spectra.ifu import (
    cosmological_doppler_shift,
    resample_spectrum,
    velocity_doppler_shift,
    calculate_cube,
)
from rubix.core.ifu import get_velocities_doppler_shift_vmap, get_resample_spectrum_pmap

# The velocity component of the stars that is used to doppler shift the wavelength
velocity_direction = rubix_config["ifu"]["doppler"]["velocity_direction"]

# The redshift at which the user wants to observe the galaxy
galaxy_redshift = config["galaxy"]["dist_z"]

# Get the telescope wavelength bins
telescope = get_telescope(config)
telescope_wavelength = telescope.wave_seq

# Get the SSP grid to doppler shift the wavelengths
ssp = get_ssp(config)

# Doppler shift the SSP wavelength based on the cosmological distance of the observed galaxy
ssp_wave = cosmological_doppler_shift(z=galaxy_redshift, wavelength=ssp.wavelength)

# Function to Doppler shift the wavelength based on the velocity of the star particles
# This binds the velocity direction, such that later we only need the velocity during the pipeline
doppler_shift = get_velocities_doppler_shift_vmap(ssp_wave, velocity_direction)

# Function to resample the spectrum to the telescope wavelength grid (built once, outside the function)
resample_spectrum_pmap = get_resample_spectrum_pmap(telescope_wavelength)


def doppler_shift_and_resampling(velocities, spectra):
    # Doppler shift the SSP wavelengths based on the velocity of the stars
    doppler_shifted_ssp_wave = doppler_shift(velocities)
    # Resample each particle's spectrum onto the telescope wavelength grid
    spectrum_resampled = resample_spectrum_pmap(
        spectra, doppler_shifted_ssp_wave
    )

    return spectrum_resampled

# Compute the Jacobian with respect to velocities
jac_velocities = jax.jacrev(doppler_shift_and_resampling, argnums=0)(
    rubixdata.stars.velocity, rubixdata.stars.spectra_mass
)

# Compute the Jacobian with respect to spectra
jac_spectra = jax.jacrev(doppler_shift_and_resampling, argnums=1)(
    rubixdata.stars.velocity, rubixdata.stars.spectra_mass
)

print("Jacobian w.r.t. velocities:", jac_velocities)
print("Jacobian w.r.t. spectra:", jac_spectra)
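A worked NumPy sketch of the two shifts and the resampling step: a Gaussian line at 5000 Angstroms is shifted by the cosmological factor (1 + z) and by a non-relativistic line-of-sight Doppler factor (1 + v/c), then linearly interpolated onto a fixed telescope grid. The formulas, the velocity sign convention and the interpolation scheme are assumptions for illustration; RUBIX's `cosmological_doppler_shift`, `velocity_doppler_shift` and `resample_spectrum` may differ in detail (for example in flux conservation):

```python
import numpy as np

C_KMS = 299792.458  # speed of light in km/s


def observed_wavelength(rest_wave, z, v_los_kms):
    """Apply cosmological redshift and a non-relativistic line-of-sight Doppler shift."""
    return rest_wave * (1.0 + z) * (1.0 + v_los_kms / C_KMS)


rest_wave = np.linspace(4500.0, 5500.0, 2001)                # 0.5 Angstrom sampling
spectrum = np.exp(-0.5 * ((rest_wave - 5000.0) / 2.0) ** 2)  # Gaussian line at 5000 Angstroms

obs_wave = observed_wavelength(rest_wave, z=0.1, v_los_kms=300.0)

# Resample onto a fixed telescope grid (zero flux outside the shifted SSP range)
telescope_wave = np.arange(5400.0, 5600.0, 0.25)
resampled = np.interp(telescope_wave, obs_wave, spectrum, left=0.0, right=0.0)

peak = telescope_wave[np.argmax(resampled)]
print(peak)  # close to 5505.5 = 5000 * 1.1 * (1 + 300 / c)
```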

# Gradient for datacube

In [None]:
from rubix.spectra.ifu import calculate_cube
from rubix.core.telescope import get_telescope
import jax

telescope = get_telescope(config)
num_spaxels = int(telescope.sbin)

# Bind the num_spaxels to the function
calculate_cube_fn = jax.tree_util.Partial(calculate_cube, num_spaxels=num_spaxels)
calculate_cube_pmap = jax.pmap(calculate_cube_fn)


def calculate_datacube(spectra, pixel_assignment):
    ifu_cubes = calculate_cube_pmap(
        spectra=spectra,
        spaxel_index=pixel_assignment,
    )
    datacube = jnp.sum(ifu_cubes, axis=0)
    datacube_jax = jnp.array(datacube)
    return datacube_jax

"""
# Compute the Jacobian with respect to spectra
jac_spectra = jax.jacrev(calculate_datacube, argnums=0)(
    rubixdata.stars.spectra, rubixdata.stars.pixel_assignment
)

# Compute the Jacobian with respect to pixel_assignment
jac_pixel_assignment = jax.jacrev(calculate_datacube, argnums=1)(
    rubixdata.stars.spectra, rubixdata.stars.pixel_assignment
)

print("Jacobian w.r.t. spectra:", jac_spectra)
print("Jacobian w.r.t. pixel_assignment:", jac_pixel_assignment)
"""

"""
# Compute VJP for the datacube function
datacube_value, vjp_fn = jax.vjp(calculate_datacube, rubixdata.stars.spectra, rubixdata.stars.pixel_assignment)

# Define a gradient direction (e.g., ones)
output_grad = jnp.ones_like(datacube_value)

# Compute gradient with respect to inputs
grad_spectra, grad_pixel_assignment = vjp_fn(output_grad)

print("Gradient w.r.t. spectra:", grad_spectra)
print("Gradient w.r.t. pixel_assignment:", grad_pixel_assignment)
"""


# Compute the Jacobian with forward-mode differentiation
jac_spectra_fwd = jax.jacfwd(calculate_datacube, argnums=0)(
    rubixdata.stars.spectra, rubixdata.stars.pixel_assignment
)

print("Jacobian w.r.t. spectra (forward-mode):", jac_spectra_fwd)

"""
def chunked_jacobian(func, input_data, chunk_size, *other_args):
    n_chunks = (len(input_data) + chunk_size - 1) // chunk_size
    chunks = jnp.array_split(input_data, n_chunks)
    
    jacobian_chunks = []
    for chunk in chunks:
        jacobian_chunk = jax.jacrev(func, argnums=0)(chunk, *other_args)
        jacobian_chunks.append(jacobian_chunk)
    return jnp.concatenate(jacobian_chunks, axis=0)

# Use chunking for spectra
chunk_size = 100  # Adjust based on available memory
jac_spectra_chunked = chunked_jacobian(
    calculate_datacube, rubixdata.stars.spectra, chunk_size, rubixdata.stars.pixel_assignment
)

print("Chunked Jacobian w.r.t. spectra:", jac_spectra_chunked)
"""
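Assuming `calculate_cube` adds each particle's spectrum into its assigned spaxel (a scatter-add), the datacube is linear in the spectra: the gradient w.r.t. a particle's spectrum is simply the cotangent of its spaxel, while the integer spaxel indices have no gradient. A sketch using `jax.ops.segment_sum` as a swapped-in equivalent (this is not RUBIX's implementation):

```python
import jax
import jax.numpy as jnp


def datacube_from_particles(spectra, spaxel_index, num_spaxels):
    """Sum particle spectra (N, n_wave) into spaxels (num_spaxels, n_wave)."""
    return jax.ops.segment_sum(spectra, spaxel_index, num_segments=num_spaxels)


spectra = jnp.array([[1.0, 1.0], [2.0, 0.0], [0.0, 3.0]])
spaxel_index = jnp.array([0, 2, 0])  # integer assignment, not differentiated
cube = datacube_from_particles(spectra, spaxel_index, num_spaxels=4)
print(cube)  # [[1. 4.] [0. 0.] [2. 0.] [0. 0.]]


def flux_in_spaxel0(spectra):
    # Total flux collected in spaxel 0
    return datacube_from_particles(spectra, spaxel_index, num_spaxels=4)[0].sum()


print(jax.grad(flux_in_spaxel0)(spectra))  # [[1. 1.] [0. 0.] [1. 1.]]
```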

In [None]:
from rubix.telescope.psf.psf import get_psf_kernel, apply_psf

m, n = config["telescope"]["psf"]["size"], config["telescope"]["psf"]["size"]
sigma = config["telescope"]["psf"]["sigma"]
psf_kernel = get_psf_kernel("gaussian", m, n, sigma=sigma)


# Define the function to convolve the datacube with the PSF kernel
def convolve_psf(datacube, size, sigma):
    """Convolve the input datacube with the PSF kernel."""
    psf_kernel = get_psf_kernel("gaussian", size, size, sigma=sigma)
    datacube_psf = apply_psf(datacube, psf_kernel)
    return datacube_psf

# Compute the Jacobian with respect to the datacube, using the PSF settings from the config
jac_datacube = jax.jacfwd(convolve_psf, argnums=0)(
    rubixdata.stars.datacube, m, sigma
)

print("Jacobian w.r.t. datacube:", jac_datacube)
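The full Jacobian of a PSF convolution has (ny * nx * n_wave)^2 entries, which quickly becomes too large to store. Since the convolution is linear, a Jacobian-vector product is just the convolution applied to the tangent, so `jax.jvp` (or `jax.vjp`) is usually enough. A sketch using `jax.scipy.signal.convolve2d` with a normalized Gaussian kernel as a swapped-in stand-in for `apply_psf`; the toy cube size is an assumption:

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def gaussian_kernel(size, sigma):
    """Normalized 2D Gaussian kernel of shape (size, size)."""
    ax = jnp.arange(size) - (size - 1) / 2.0
    g = jnp.exp(-0.5 * (ax / sigma) ** 2)
    kernel = jnp.outer(g, g)
    return kernel / kernel.sum()


def convolve_cube(cube, kernel):
    # Convolve every wavelength slice of a (ny, nx, n_wave) cube with the same kernel
    per_slice = jax.vmap(lambda img: convolve2d(img, kernel, mode="same"), in_axes=2, out_axes=2)
    return per_slice(cube)


kernel = gaussian_kernel(5, 0.6)
cube = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 3))
tangent = jax.random.normal(jax.random.PRNGKey(1), (8, 8, 3))

# For a linear map, the JVP equals the map applied to the tangent,
# so the full Jacobian never has to be materialized.
out, jvp_out = jax.jvp(lambda c: convolve_cube(c, kernel), (cube,), (tangent,))
print(jnp.allclose(jvp_out, convolve_cube(tangent, kernel), atol=1e-5))  # True
```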

In [None]:
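# NOTE: at this point `pipeline` refers to the rotation + spaxel-assignment function
# defined above, which expects particle arrays (not the config dict) and returns
# integer spaxel indices, so both calls below raise an error as written.
rubix_grad = grad(pipeline)

grad_result = rubix_grad(config)
print(grad_result)

rubix_grad = jacrev(pipeline)

grad_result = rubix_grad(config)
print(grad_result)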