In [1]:
#NBVAL_SKIP
import matplotlib.pyplot as plt
from rubix.core.pipeline import RubixPipeline 
import os
# --- Pipeline configuration, assembled section by section --------------------

# Logger settings: DEBUG level, no log file path set.
logger_config = {
    "log_level": "DEBUG",
    "log_file_path": None,
    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
}

# Input data: which galaxy to load and how many particles to keep.
# The API key is read from the environment; it is None when the variable is unset.
data_config = {
    "name": "IllustrisAPI",
    "args": {
        "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
        "particle_type": ["stars"],
        "simulation": "TNG50-1",
        "snapshot": 99,
        "save_data_path": "data",
    },
    "load_galaxy_args": {
        "id": 14,
        "reuse": True,
    },
    "subset": {
        "use_subset": True,
        "subset_size": 100,
    },
}

# Simulation file for the galaxy selected above (same id as in load_galaxy_args).
simulation_config = {
    "name": "IllustrisTNG",
    "args": {
        "path": "data/galaxy-id-14.hdf5",
    },
}

# Instrument model: telescope name plus PSF, LSF and noise parameters.
telescope_config = {
    "name": "MUSE",
    "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
    "lsf": {"sigma": 0.5},
    "noise": {"signal_to_noise": 1, "noise_distribution": "normal"},
}

# Observed galaxy settings: "dist_z" value and requested rotation type.
galaxy_config = {
    "dist_z": 0.1,
    "rotation": {"type": "edge-on"},
}

# Full configuration handed to the pipeline (key order kept as originally written).
config = {
    "pipeline": {"name": "calc_ifu"},
    "logger": logger_config,
    "data": data_config,
    "simulation": simulation_config,
    "output_path": "output",
    "telescope": telescope_config,
    "cosmology": {"name": "PLANCK15"},
    "galaxy": galaxy_config,
    "ssp": {"template": {"name": "BruzualCharlot2003"}},
}

# --- Run the pipeline ----------------------------------------------------------
pipe = RubixPipeline(config)
data = pipe.run()

# --- Quick look: collapse the datacube along axis 2 and show it as an image ---
# (axis 2 is presumably the spectral axis — confirm against the pipeline output)
datacube = data["datacube"]
img = datacube.sum(axis=2)
plt.imshow(img, origin="lower")

2024-06-04 21:49:34,153 - rubix - INFO - 
   ___  __  _____  _____  __
  / _ \/ / / / _ )/  _/ |/_/
 / , _/ /_/ / _  |/ /_>  <  
/_/|_|\____/____/___/_/|_|  
                            

2024-06-04 21:49:34,154 - rubix - INFO - Rubix version: 0.0.1.post55+g93a544d
2024-06-04 21:49:34,547 - rubix - INFO - Getting rubix data...
2024-06-04 21:49:34,547 - rubix - INFO - Rubix galaxy file already exists, skipping conversion
2024-06-04 21:49:36,490 - rubix - INFO - Data loaded with 1000 particles.
2024-06-04 21:49:36,491 - rubix - DEBUG - Data Shape: {'coords': (1000, 3), 'velocities': (1000, 3), 'metallicity': (1000,), 'mass': (1000,), 'age': (1000,)}
2024-06-04 21:49:36,491 - rubix - INFO - Setting up the pipeline...
2024-06-04 21:49:36,492 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {'type': 'face-on'}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'ar

rotating galaxy:  face-on


2024-06-04 21:49:37,425 - rubix - DEBUG - Calculation Finished! Spectra shape: (1, 1000, 842)
2024-06-04 21:49:37,425 - rubix - INFO - Scaling Spectra by Mass...
2024-06-04 21:49:37,431 - rubix - DEBUG - Doppler Shifted SSP Wave: (1, 1000, 842)
2024-06-04 21:49:37,432 - rubix - DEBUG - Telescope Wave Seq: (3721,)
2024-06-04 21:49:37,518 - rubix - INFO - Calculating Data Cube...
2024-06-04 21:49:37,522 - rubix - DEBUG - Datacube Shape: (25, 25, 3721)
2024-06-04 21:49:37,525 - rubix - INFO - Convolving with LSF...
2024-06-04 21:49:38.358776: E external/xla/xla/stream_executor/cuda/cuda_blas.cc:202] failed to create cublas handle: the resource allocation failed
2024-06-04 21:49:38.358811: E external/xla/xla/stream_executor/cuda/cuda_blas.cc:205] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your deep-learning framework may have preallocated more than its fair share), or may be because this binary was not built with support for th

XlaRuntimeError: INTERNAL: No BLAS support for stream

In [None]:
#NBVAL_SKIP
# Spectra returned by the pipeline run and the telescope wavelength grid.
spectra = data["spectra"]
wave = pipe.telescope.wave_seq

# Report the array shape, then plot the single spectrum at index [0, 3].
print(spectra.shape)
plt.plot(wave, spectra[0, 3])


Some of the spectra may be zero. This happens when the metallicity or age of a particle lies outside the range covered by the SSP model, and it is currently the expected behavior.

In [None]:
#NBVAL_SKIP
import jax 
import jax.numpy as jnp
# Create a function to calculate a single IFU cube
def calculate_ifu_cube(stars_spectra, pixel_indices, n_spaxels=25):
    """Sum per-particle spectra into an IFU cube.

    Each row of ``stars_spectra`` is added to the spaxel given by the matching
    entry of ``pixel_indices``; the flat spaxel axis is then unfolded into a
    square grid.

    Args:
        stars_spectra: Array of shape ``(n_particles, n_wave)``.
        pixel_indices: Integer array of shape ``(n_particles,)`` holding the
            flat spaxel index of every particle, in ``[0, n_spaxels**2)``.
            Indices outside that range are dropped by ``segment_sum``.
        n_spaxels: Number of spaxels per side of the (square) cube. Defaults
            to 25, the value that was previously hard-coded.

    Returns:
        Array of shape ``(n_spaxels, n_spaxels, n_wave)``.
    """
    # Take the spectral length from the input rather than hard-coding it, so the
    # function works for any wavelength grid (the old reshape assumed 3721 bins).
    n_wave = stars_spectra.shape[-1]

    # segment_sum needs a static segment count: one segment per spaxel.
    flat_cube = jax.ops.segment_sum(
        stars_spectra, pixel_indices, num_segments=n_spaxels * n_spaxels
    )

    # Unfold the flat spaxel axis into the 2D spatial grid.
    return flat_cube.reshape((n_spaxels, n_spaxels, n_wave))

# Inputs from the pipeline run: the spectra and their spaxel assignments.
assignments = data["pixel_assignment"]
spectra = data["spectra"]

# Build one IFU cube per entry along the leading axis of both inputs.
ifu_cubes = jax.vmap(calculate_ifu_cube, in_axes=(0, 0))(spectra, assignments)

# Collapse the leading axis by summing the individual cubes into one.
final_ifu_cube = ifu_cubes.sum(axis=0)
final_ifu_cube.shape

In [None]:
#NBVAL_SKIP
# Telescope wavelength grid.
wavelengths = pipe.telescope.wave_seq

# Indices of the grid points inside the visible window, 4000-8000 Angstroms
# (both bounds inclusive). Single-argument jnp.where returns a tuple of index arrays.
visible_indices = jnp.where(
    jnp.logical_and(wavelengths >= 4000, wavelengths <= 8000)
)



In [None]:
#NBVAL_SKIP
# Display the wavelength grid (assigned from pipe.telescope.wave_seq in an earlier cell).
wavelengths

In [None]:
#NBVAL_SKIP
# Display the spectrum stored at index [0, 7] of the spectra array for inspection.
spectra[0,7]

In [None]:
#NBVAL_SKIP
# Example spectrum: spaxel (12, 12) of the summed cube, visible range only,
# drawn on a logarithmic flux axis.
import matplotlib.pyplot as plt

spec = final_ifu_cube[12, 12]

fig, ax = plt.subplots()
ax.plot(wavelengths[visible_indices], spec[visible_indices])
ax.set_yscale("log")

In [None]:
#NBVAL_SKIP
# Keep only the visible-wavelength channels along the last axis of the cube.
visible_spectra = final_ifu_cube[..., visible_indices[0]]
visible_spectra.shape

In [None]:
#NBVAL_SKIP
# Collapse the visible spectra along axis 2 to get one value per spaxel,
# then show the result as an image with a colorbar.
image = visible_spectra.sum(axis=2)
heatmap = plt.imshow(image, cmap="inferno", origin="lower")
plt.colorbar(heatmap)