In [None]:
import os

# python-fsps locates its stellar-population data through SPS_HOME, so it must be
# set before the FSPS template is loaded. Respect a value the user already exported
# and only fall back to the original author's local install path otherwise.
os.environ.setdefault('SPS_HOME', '/home/annalena/sps_fsps')

# Load the SSP template from FSPS

In [None]:
from rubix.spectra.ssp.factory import get_ssp_template

# Name of the single-stellar-population template backend to load; SPS_HOME is
# set beforehand, presumably so the FSPS backend can find its data files.
SSP_TEMPLATE_NAME = "FSPS"
ssp_fsps = get_ssp_template(SSP_TEMPLATE_NAME)

In [None]:
# Show the template's age and metallicity grids; the indices chosen in the
# next cell pick the "true" parameters out of these arrays.
print(age_values := ssp_fsps.age)
print(metallicity_values := ssp_fsps.metallicity)

In [None]:
# Grid positions of the "true" parameters used to build the target:
#   index     -> position in metallicity_values
#   age_index -> position in age_values
index, age_index = 1, 100

# Configure and run the reference pipeline

In [None]:
from rubix.core.pipeline import RubixPipeline
import os

# Data source: galaxy 14 from TNG50-1 (snapshot 99), reused from disk when
# already downloaded, with a star subset of size 2 requested.
data_config = {
    "name": "IllustrisAPI",
    "args": {
        "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
        "particle_type": ["stars"],
        "simulation": "TNG50-1",
        "snapshot": 99,
        "save_data_path": "data",
    },
    "load_galaxy_args": {"id": 14, "reuse": True},
    "subset": {"use_subset": True, "subset_size": 2},
}

# Instrument model shared with the gradient pipeline further below.
telescope_config = {
    "name": "TESTGRADIENT",
    "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
    "lsf": {"sigma": 0.5},
    "noise": {"signal_to_noise": 1, "noise_distribution": "normal"},
}

config = {
    "pipeline": {"name": "calc_gradient"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": data_config,
    "simulation": {
        "name": "IllustrisTNG",
        "args": {"path": "data/galaxy-id-14.hdf5"},
    },
    "output_path": "output",
    "output_modified": False,
    "telescope": telescope_config,
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {"dist_z": 0.1, "rotation": {"type": "edge-on"}},
    "ssp": {"template": {"name": "FSPS"}},
}

pipe = RubixPipeline(config)
rubixdata = pipe.run()

# Set the target star parameters

In [None]:
import jax.numpy as jnp

# The grid point that both star particles of the target share.
target_age = age_values[age_index]
target_metallicity = metallicity_values[index]

rubixdata.stars.age = jnp.array([target_age] * 2)
rubixdata.stars.metallicity = jnp.array([target_metallicity] * 2)
# NOTE(review): mass is built with shape (1, 2) while age/metallicity are (2,)
# and velocity is (2, 3) — confirm which layout the gradient pipeline expects.
rubixdata.stars.mass = jnp.array([[1.0, 1.0]])
rubixdata.stars.velocity = jnp.zeros((2, 3))

In [None]:
import copy
import os

# This rebinds `RubixPipeline` to the gradient-capable pipeline; the cells
# below construct their pipeline through this name.
from rubix.core.pipeline_gradient import RubixPipeline

# Gradient-pipeline config. It is run on the `rubixdata` prepared above
# (see `pipe.run(rubixdata)`), so the data section only declares the
# particle type rather than a galaxy source.
config = {
    "pipeline": {"name": "calc_gradient"},

    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "args": {
            "particle_type": ["stars"],
        },
    },

    "output_path": "output",

    "telescope":
        {"name": "TESTGRADIENT",
         "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
         "lsf": {"sigma": 0.5},
         "noise": {"signal_to_noise": 1, "noise_distribution": "normal"},
         },
    "cosmology":
        {"name": "PLANCK15"},

    "galaxy":
        {"dist_z": 0.1,
         "rotation": {"type": "edge-on"},
        },

    "ssp": {
        "template": {
            "name": "FSPS"
        },
    },
}
pipe = RubixPipeline(config)
rubixdata = pipe.run(rubixdata)

# Keep an independent snapshot of the reference output as the target. A plain
# `target = rubixdata` makes both names refer to the same object, so any
# attribute reassignment on the pipeline input (as done when setting the
# target star parameters) would silently change the target as well.
target = copy.deepcopy(rubixdata)

# Calculate the loss and its gradient with respect to metallicity over the metallicity grid

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

# Initialize the gradient pipeline with the config used to build the target.
pipe = RubixPipeline(config)

# Loss and gradient from one evaluation per grid point, so the two are always
# consistent with each other. argnums=0 differentiates only w.r.t. the
# metallicity array; rubixdata and target are treated as constants.
loss_and_grad = jax.value_and_grad(pipe.loss_only_wrt_metallicity, argnums=0)

# Each entry: (metallicity_test, loss, gradient w.r.t. metallicity_test)
results = []

# Scan every metallicity in the SSP grid, giving both star particles the same value.
for metallicity in metallicity_values:
    metallicity_test = jnp.array([[metallicity, metallicity]])
    loss, grad_stars_metallicity = loss_and_grad(metallicity_test, rubixdata, target)
    results.append((metallicity_test, loss, grad_stars_metallicity))

In [None]:
import os

# Save the per-grid-point results as plain text. Create the output directory
# first so the write does not fail with FileNotFoundError on a fresh checkout.
output_dir = './output/loss_grid_fsps'
os.makedirs(output_dir, exist_ok=True)
with open(f'{output_dir}/gradient_age{age_index}_metallsvarying{index}.txt', 'w') as f:
    # Each result row holds the metallicity array that was tested (not an age).
    for metallicity_row, loss, grad in results:
        f.write(f"Metallicity: {metallicity_row}, Loss: {loss}, Gradient: {grad}\n")

In [None]:
# Unpack the (metallicity, loss, gradient) triples into stacked arrays.
metallicity = jnp.array([metal for metal, _, _ in results])
losses = jnp.array([loss_value for _, loss_value, _ in results])
gradients = jnp.array([grad_value for _, _, grad_value in results])

# First particle's metallicity and gradient per grid point; losses are scalars.
print(metallicity[:, :, 0])
print(losses.shape)
print(gradients[:, :, 0])

In [None]:
import matplotlib.pyplot as plt
import jax.numpy as jnp

metallicity = jnp.array([result[0] for result in results])
losses = jnp.array([result[1] for result in results])
gradients = jnp.array([result[2] for result in results])

# The true value is the grid point selected by `index` when building the target.
true_metallicity = float(metallicity_values[index])

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(metallicity[:, :, 0], gradients[:, :, 0], label='Gradient wrt metallicity')
ax.plot(metallicity[:, :, 0], losses, label='loss wrt metallicity')
# axvline spans the full axis height without dragging autoscale out to a fixed
# data range (a ±50 line would flatten the ~1e-3 gradient and ~1e-6 loss curves).
ax.axvline(true_metallicity, color='r', linestyle='dashed', label='True metallicity')
ax.set_xlabel('Metallicity')
ax.set_ylabel('Gradient / Loss')
ax.set_title('Gradient and Loss vs Metallicity')
ax.legend()
ax.grid(True)
plt.show()
plt.close(fig)

In [None]:
import matplotlib.pyplot as plt
import jax.numpy as jnp

metallicity = jnp.array([metal for metal, _, _ in results])
losses = jnp.array([loss_value for _, loss_value, _ in results])
gradients = jnp.array([grad_value for _, _, grad_value in results])
print(metallicity[:, :, 0])
print(gradients[:, :, 0])
print(losses)

# Shared x data and the reference line position for both panels.
x_values = metallicity[:, :, 0]
true_metallicity = metallicity_values[index]

# Per panel: (y data, legend label, y-axis label, title, y-limits)
panel_specs = (
    (gradients[:, :, 0], 'Gradient wrt metallicity', 'Gradient', 'Gradient vs Metallicity', (-1e-3, 2e-4)),
    (losses, 'Loss wrt metallicity', 'Loss', 'Loss vs Metallicity', (-2e-7, 2e-6)),
)

# Gradient on the left, loss on the right; crosses mark the evaluated grid points.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 6))
for ax, (y_values, curve_label, y_label, panel_title, y_limits) in zip((ax1, ax2), panel_specs):
    ax.plot(x_values, y_values, label=curve_label, marker='x')
    ax.vlines(true_metallicity, -50, 50, colors='r', linestyles='dashed', label='True metallicity')
    ax.set_xlabel('Metallicity')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.set_ylim(*y_limits)
    ax.legend()
    ax.grid(True)

plt.show()
plt.close()