In [None]:
#NBVAL_SKIP
import os

import jax.numpy as jnp
import matplotlib.pyplot as plt
from rubix.core.pipeline import RubixPipeline

# Pipeline configuration. The Illustris API key is read from the environment
# variable ILLUSTRIS_API_KEY instead of being hard-coded in the notebook.
config = {
    "pipeline": {"name": "calc_gradient"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {
            "id": 14,
            "reuse": True,
        },
        "subset": {
            "use_subset": True,
            "subset_size": 100,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            "path": "data/galaxy-id-14.hdf5",
        },
    },
    "output_path": "output",
    "telescope": {
        "name": "TESTGRADIENT",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 0.5},
        "noise": {"signal_to_noise": 1, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },
    "ssp": {
        "template": {"name": "BruzualCharlot2003"},
    },
}

pipe = RubixPipeline(config)
rubixdata = pipe.run()

# Collapse the datacube along axis 2 and show the resulting 2-D map.
fig, ax = plt.subplots()
image = ax.imshow(jnp.sum(rubixdata.stars.datacube, axis=2), origin="lower", cmap="inferno")
fig.colorbar(image, ax=ax)
ax.set(
    title="RubixPipeline datacube summed over axis 2",
    xlabel="spaxel index (axis 1)",
    ylabel="spaxel index (axis 0)",
);

In [None]:
import jax
import jax.numpy as jnp
from rubix.galaxy.alignment import rotate_galaxy as rotate_galaxy_core
from rubix.telescope.utils import (
    calculate_spatial_bin_edges,
    square_spaxel_assignment,
    mask_particles_outside_aperture,
)
from rubix.core.telescope import get_telescope, get_spatial_bin_edges
from rubix.core.data import reshape_array
from rubix.core.ssp import get_lookup_interpolation_pmap
from rubix import config as rubix_config
from rubix.core.telescope import get_telescope
from rubix.core.ssp import get_lookup_interpolation_pmap, get_ssp
from rubix.spectra.ifu import (
    cosmological_doppler_shift,
    resample_spectrum,
    velocity_doppler_shift,
    calculate_cube,
)
from rubix.core.ifu import get_velocities_doppler_shift_vmap, get_resample_spectrum_pmap
from rubix.spectra.ifu import calculate_cube
from rubix.core.telescope import get_telescope
import jax


# Inputs for the hand-built pipeline below. These are the real particle data in
# the RubixData returned by pipe.run() in the previous cell, not synthetic data.
positions = rubixdata.stars.coords  # Shape (N, 3)? NOTE(review): pipeline() indexes pixel_assignment[:, 0], so a leading axis may be present — confirm
velocities = rubixdata.stars.velocity  # presumably same layout as positions — confirm
masses = rubixdata.stars.mass # presumably one mass per particle — confirm
halfmass_radius = rubixdata.galaxy.halfmassrad_stars
alpha, beta, gamma = 90.0, 0.0, 0.0  # angles forwarded to rotate_galaxy_core (units not shown here)

# Telescope, spatial binning and SSP lookup built from the same `config` used by pipe.run().
telescope = get_telescope(config)
spatial_bin_edges = get_spatial_bin_edges(config)
lookup_interpolation_pmap = get_lookup_interpolation_pmap(config)


def rotate_galaxy_wrapper(positions, velocities, masses, halfmass_radius, alpha, beta, gamma):
    """Rotate particle positions and velocities with rubix's rotate_galaxy.

    All arguments are forwarded by keyword to ``rotate_galaxy_core``.

    Returns:
        Tuple ``(rotated_positions, rotated_velocities)``.
    """
    rotated = rotate_galaxy_core(
        alpha=alpha,
        beta=beta,
        gamma=gamma,
        halfmass_radius=halfmass_radius,
        masses=masses,
        positions=positions,
        velocities=velocities,
    )
    rotated_positions, rotated_velocities = rotated
    return rotated_positions, rotated_velocities

def spaxel_assignment(positions, spatial_bin_edges):
    """Assign particle positions to square spaxels.

    Args:
        positions: particle coordinates, passed unchanged to
            ``square_spaxel_assignment``.
        spatial_bin_edges: edges of the spatial grid.

    Returns:
        The spaxel assignment computed by ``square_spaxel_assignment``.

    Raises:
        ValueError: if ``positions`` is None.
    """
    # Validate the argument itself rather than a global, so the helper works
    # for any input and never falls through to an unbound local variable.
    if positions is None:
        raise ValueError("positions must not be None")
    return square_spaxel_assignment(positions, spatial_bin_edges)

def reshape_data(positions, velocities, masses, metallicity, age):
    """Pass each particle field through ``reshape_array``.

    Returns:
        Tuple ``(coords, velocities, masses, metallicity, age)`` in the same
        order as the arguments, each reshaped by ``reshape_array``.
    """
    fields = (positions, velocities, masses, metallicity, age)
    return tuple(reshape_array(field) for field in fields)

def calculate_spectra(age_data, metallicity_data):
    """Look up SSP spectra for the given stellar ages and metallicities.

    Scalars are promoted to 1-D arrays with ``jnp.atleast_1d``. The module-level
    ``lookup_interpolation_pmap`` is then called as (metallicity, age); note that
    this is the reverse of this function's (age, metallicity) signature.

    Returns:
        The looked-up spectra as a JAX array.
    """
    metallicity_1d = jnp.atleast_1d(metallicity_data)
    age_1d = jnp.atleast_1d(age_data)
    looked_up = lookup_interpolation_pmap(metallicity_1d, age_1d)
    return jnp.array(looked_up)

def scale_spectrum_by_mass(spectra, mass):
    """Weight each spectrum by its mass.

    ``mass`` gets a trailing singleton axis so that it broadcasts against the
    last axis of ``spectra`` (presumably the wavelength axis — confirm).

    Returns:
        ``spectra`` multiplied element-wise by the expanded ``mass``.
    """
    weights = jnp.expand_dims(mass, axis=-1)
    return spectra * weights

# Disabled step 6 of pipeline(): Doppler shift + resampling to the telescope
# wavelength grid. The code is kept inside a bare string literal, so none of it
# executes and `doppler_shift_and_resampling` is never defined; the matching
# call in pipeline() is commented out.
"""
# The velocity component of the stars that is used to doppler shift the wavelength
velocity_direction = rubix_config["ifu"]["doppler"]["velocity_direction"]
# The redshift at which the user wants to observe the galaxy
galaxy_redshift = config["galaxy"]["dist_z"]
# Get the telescope wavelength bins
telescope = get_telescope(config)
telescope_wavelenght = telescope.wave_seq
# Get the SSP grid to doppler shift the wavelengths
ssp = get_ssp(config)
# Doppler shift the SSP wavelenght based on the cosmological distance of the observed galaxy
ssp_wave = cosmological_doppler_shift(z=galaxy_redshift, wavelength=ssp.wavelength)
# Function to Doppler shift the wavelength based on the velocity of the stars particles
# This binds the velocity direction, such that later we only need the velocity during the pipeline
doppler_shift = get_velocities_doppler_shift_vmap(ssp_wave, velocity_direction)

def doppler_shift_and_resampling(velocites, spectra):
    # Doppler shift the SSP Wavelengths based on the velocity of the stars
    doppler_shifted_ssp_wave = doppler_shift(velocites)
    # Function to resample the spectrum to the telescope wavelength grid
    resample_spectrum_pmap = get_resample_spectrum_pmap(telescope_wavelenght)
    # jax.debug.print("doppler shifted ssp wave {}", doppler_shifted_ssp_wave)
    # jax.debug.print("Spectra before resampling {}", inputs["spectra"])
    spectrum_resampled = resample_spectrum_pmap(
        spectra, doppler_shifted_ssp_wave
    )
    
    return spectrum_resampled
"""

# Spatial bin count from the telescope, bound into calculate_cube so that the
# pmapped function only needs the spectra and spaxel indices at call time.
num_spaxels = int(telescope.sbin)
calculate_cube_fn = jax.tree_util.Partial(calculate_cube, num_spaxels=num_spaxels)
calculate_cube_pmap = jax.pmap(calculate_cube_fn)


def calculate_datacube(spectra, pixel_assignment):
    """Accumulate per-particle spectra into an IFU datacube.

    ``calculate_cube`` runs under ``jax.pmap``, which produces one cube per
    entry of the leading (device) axis; those cubes are summed into one.

    Returns:
        The summed datacube as a JAX array.
    """
    per_device_cubes = calculate_cube_pmap(spectra=spectra, spaxel_index=pixel_assignment)
    summed_cube = per_device_cubes.sum(axis=0)
    return jnp.array(summed_cube)

def pipeline(positions, velocities, masses, metallicity, age, halfmass_radius, alpha, beta, gamma, spatial_bin_edges):
    """Hand-built, differentiable pipeline from particle data to a datacube.

    Steps, in order: rotate the galaxy, assign particles to spaxels, look up
    SSP spectra, weight them by mass, and accumulate them into a datacube.
    The reshape step and the Doppler-shift/resampling step are not applied.

    Returns:
        The datacube produced by ``calculate_datacube``.
    """
    rotated_positions, _rotated_velocities = rotate_galaxy_wrapper(
        positions, velocities, masses, halfmass_radius, alpha, beta, gamma
    )
    # Rotated velocities would only be needed by the disabled Doppler-shift step.
    spaxel_index = square_spaxel_assignment(rotated_positions, spatial_bin_edges)
    mass_weighted_spectra = scale_spectrum_by_mass(calculate_spectra(age, metallicity), masses)
    # NOTE(review): only column 0 of the assignment is used, wrapped in a new
    # leading axis for the pmapped cube builder — confirm this is intended.
    return calculate_datacube(mass_weighted_spectra, jnp.array([spaxel_index[:, 0]]))

datacube2 = pipeline(
    positions, velocities, masses,
    rubixdata.stars.metallicity, rubixdata.stars.age,
    halfmass_radius, alpha, beta, gamma, spatial_bin_edges,
)

# Same view as the first cell (datacube summed over axis 2), for a visual
# comparison with the full RubixPipeline output. jnp is imported at the top of this cell.
fig, ax = plt.subplots()
image = ax.imshow(jnp.sum(datacube2, axis=2), origin="lower", cmap="inferno")
fig.colorbar(image, ax=ax)
ax.set(
    title="Hand-built pipeline datacube summed over axis 2",
    xlabel="spaxel index (axis 1)",
    ylabel="spaxel index (axis 0)",
);

In [None]:
from jax import grad, jacrev, jacfwd

# Inputs to the hand-built `pipeline`, read from the RubixData produced by pipe.run().
positions = rubixdata.stars.coords
velocities = rubixdata.stars.velocity
masses = rubixdata.stars.mass
metallicity = rubixdata.stars.metallicity
age = rubixdata.stars.age
halfmass_radius = rubixdata.galaxy.halfmassrad_stars
alpha, beta, gamma = 90.0, 0.0, 0.0
spatial_bin_edges = get_spatial_bin_edges(config)

# Positional arguments of `pipeline`, in signature order; each Jacobian below
# selects the input it differentiates through `argnums`.
pipeline_args = (
    positions, velocities, masses, metallicity, age,
    halfmass_radius, alpha, beta, gamma, spatial_bin_edges,
)


def jacobian_wrt(argnum):
    """Reverse-mode Jacobian of `pipeline` w.r.t. its positional argument `argnum`."""
    return jacrev(pipeline, argnums=argnum)(*pipeline_args)


# NOTE(review): positions reach the datacube only through the spaxel
# assignment; if that binning is piecewise constant, this Jacobian is expected
# to be all zeros — confirm.
jac_positions = jacobian_wrt(0)  # w.r.t. positions
jac_metals = jacobian_wrt(3)     # w.r.t. metallicity
jac_age = jacobian_wrt(4)        # w.r.t. age

print("Jacobian w.r.t. positions:", jac_positions)
print("Jacobian w.r.t. metallicity:", jac_metals)
print("Jacobian w.r.t. age:", jac_age)

In [None]:
# Shape of the metallicity Jacobian: output dimensions followed by input dimensions.
jnp.shape(jac_metals)

In [None]:
# Shape of the positions Jacobian: output dimensions followed by input dimensions.
jnp.shape(jac_positions)

# Gradient for full pipeline

In [None]:
import os

import jax
from rubix.core.pipeline import RubixPipeline
from rubix.pipeline import linear_pipeline as ltp
from rubix.utils import read_yaml

# Same setup as the first cell, but with subset_size 10 instead of 100 so the
# Jacobian of the full pipeline stays small.
config = {
    "pipeline": {"name": "calc_gradient"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {
            "id": 14,
            "reuse": True,
        },
        "subset": {
            "use_subset": True,
            "subset_size": 10,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            "path": "data/galaxy-id-14.hdf5",
        },
    },
    "output_path": "output",
    "telescope": {
        "name": "TESTGRADIENT",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 0.5},
        "noise": {"signal_to_noise": 1, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },
    "ssp": {
        "template": {"name": "BruzualCharlot2003"},
    },
}
pipe = RubixPipeline(config)
rubixdata = pipe._prepare_data()

# _get_pipeline_functions() returns a list of the transformer functions in the correct order.
transformers_list = pipe._get_pipeline_functions()
print(transformers_list)  # debug: list the JAX-compatible transformer functions

# Read the transformer graph for the same pipeline name the RubixPipeline was
# configured with, so the two cannot silently disagree.
PIPELINE_CONFIG_PATH = "../rubix/config/pipeline_config.yml"
read_cfg = read_yaml(PIPELINE_CONFIG_PATH)
pipeline_name = config["pipeline"]["name"]
pipeline_cfg = read_cfg[pipeline_name]

tp = ltp.LinearTransformerPipeline(
    pipeline_cfg,       # transformer graph for `pipeline_name`
    transformers_list,  # the matching function objects from RubixPipeline
)

compiled_fn = tp.compile_expression()

# Reverse-mode Jacobian of the compiled pipeline w.r.t. its single argument,
# the rubixdata pytree.
jac_fn = jax.jacrev(compiled_fn)
jacobian = jac_fn(rubixdata)
print(jacobian)

In [None]:
# Inspect the structure of the Jacobian. jax.jacrev returns a pytree shaped like
# the pipeline's output, so tree_flatten() here is that output type's own
# flatten method, returning (children, aux_data).
children, aux_data = jacobian.tree_flatten()
#print("Children:", children)
#print("Auxiliary data:", aux_data)
print(children[1])

# Access specific elements of the Jacobian.
# NOTE(review): which output field children[0] corresponds to, and what its
# first element holds, depends on the output type's tree_flatten ordering,
# which is not visible here. The "first output / first input" reading in the
# label below is an assumption — confirm before relying on it.
first_output_first_input_jacobian = children[0][0]
print("Partial derivative of the first output with respect to the first input:", first_output_first_input_jacobian)