In [None]:
#NBVAL_SKIP
import matplotlib.pyplot as plt
from rubix.core.pipeline import RubixPipeline 
import os
# Galaxy to model and the directory its snapshot cutout is stored in. Both the
# download location ("save_data_path") and the simulation input file ("path")
# are derived from these two constants so they cannot drift apart when the
# galaxy ID is changed.
GALAXY_ID = 11
DATA_DIR = "data"

config = {
    "pipeline": {"name": "calc_ifu"},

    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # Read from the environment so the key is never committed to the
            # notebook; os.environ.get yields None when the variable is unset.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars", "gas"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": DATA_DIR,
        },
        "load_galaxy_args": {
            "id": GALAXY_ID,
            "reuse": True,
        },
        "subset": {
            "use_subset": False,
            "subset_size": 1000,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            # Same naming pattern as the previously hard-coded
            # "data/galaxy-id-11.hdf5", now following DATA_DIR / GALAXY_ID.
            "path": f"{DATA_DIR}/galaxy-id-{GALAXY_ID}.hdf5",
        },
    },
    "output_path": "output",

    "telescope": {
        "name": "MUSE",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 0.5},
        "noise": {"signal_to_noise": 100, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},

    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },

    "ssp": {
        "template": {
            "name": "BruzualCharlot2003"
        },
    },
}

pipe = RubixPipeline(config)

data = pipe.run()

In [None]:
# Collapse the stellar cube along its last (spectral) axis into a 2-D image.
datacube = data.stars.datacube
img = datacube.sum(axis=2)

# imshow puts array rows (cube axis 0) on y and columns (cube axis 1) on x.
fig, ax = plt.subplots()
im = ax.imshow(img, origin="lower")
fig.colorbar(im, ax=ax, label="flux summed over wavelength")
ax.set(
    title=f"{config['simulation']['name']} {config['data']['load_galaxy_args']['id']}: stars",
    xlabel="datacube axis 1 [pixel]",
    ylabel="datacube axis 0 [pixel]",
);

In [None]:
#NBVAL_SKIP
import jax.numpy as jnp

# Wavelength grid of the pipeline's telescope; the cube's last axis is plotted
# against it below.
wavelengths = pipe.telescope.wave_seq

# Indices of the visible range, 4000-8000 Angstrom.
VISIBLE_MIN, VISIBLE_MAX = 4000, 8000
visible_indices = jnp.where((wavelengths >= VISIBLE_MIN) & (wavelengths <= VISIBLE_MAX))

# Spectrum of the central spaxel. Deriving it from the cube shape (instead of
# a fixed [12, 12]) avoids an IndexError on small cubes and stays centred if
# the field of view changes.
center_x, center_y = datacube.shape[0] // 2, datacube.shape[1] // 2
spec = datacube[center_x, center_y]

fig, ax = plt.subplots()
ax.plot(wavelengths[visible_indices], spec[visible_indices])
ax.set(
    title=f"Spectrum of spaxel ({center_x}, {center_y})",
    xlabel="wavelength [Angstrom]",
    ylabel="flux",
);

In [None]:
# Number of pixels along the cube's second axis, shown as the cell output.
n_pixels_axis1 = datacube.shape[1]
n_pixels_axis1

In [None]:
from rubix.core.telescope import get_telescope

# Build a telescope object from the same config; its spatial and spectral
# sampling are used for the FITS WCS header of the DATA extension.
# NOTE(review): this is a separate instance from `pipe.telescope` used for
# `wavelengths` — presumably equivalent; confirm.
telescope = get_telescope(config)
print(telescope, telescope.spatial_res, sep="\n")

In [None]:
import numpy as np
from astropy.io import fits

In [None]:
# Primary-HDU header: provenance of the simulated observation, copied from
# `config`. Standard FITS keywords are limited to 8 characters; longer names are
# written with an explicit "HIERARCH " prefix. That produces the same HIERARCH
# cards astropy would create automatically, but without a VerifyWarning for each.
hdr = fits.Header()
hdr['PIPELINE'] = config['pipeline']['name']
hdr['DIST_Z'] = config['galaxy']['dist_z']  # astropy upper-cases short keywords anyway
hdr['ROTATION'] = config['galaxy']['rotation']['type']
hdr['HIERARCH SIMULATION'] = config['simulation']['name']
hdr['HIERARCH GALAXY_ID'] = config['data']['load_galaxy_args']['id']
hdr['SNAPSHOT'] = config['data']['args']['snapshot']
hdr['HIERARCH PARTICLE_SUBSET'] = config['data']['subset']['use_subset']
hdr['SSP'] = config['ssp']['template']['name']
hdr['HIERARCH INSTRUMENT'] = config['telescope']['name']
hdr['PSF'] = config['telescope']['psf']['name']
# Unlike 'PSF' (a name), 'LSF' holds the numeric LSF sigma from the config.
hdr['LSF'] = config['telescope']['lsf']['sigma']
hdr['HIERARCH SIGNAL_TO_NOISE'] = config['telescope']['noise']['signal_to_noise']
hdr['HIERARCH NOISE_DISTRIBUTION'] = config['telescope']['noise']['noise_distribution']

In [None]:
# Header for the DATA extension: a 3-D WCS for the cube, which is written as
# `datacube.T`. Transposing reverses the axes, so FITS axis 1 follows numpy
# axis 0 of `datacube`, FITS axis 2 follows axis 1, and FITS axis 3 is the
# spectral axis.
hdr1 = fits.Header()
hdr1['EXTNAME'] = 'DATA'
hdr1['OBJECT'] = str(config['simulation']['name']) + ' ' + str(config['data']['load_galaxy_args']['id'])
hdr1['BUNIT'] = 'erg/(s*cm^2)'  # TODO(review): confirm whether this is per Angstrom

# FITS pixel coordinates are 1-based, so the centre of an N-pixel axis is
# (N + 1) / 2. Using (N - 1) / 2 (the 0-based numpy centre) would shift the
# reference point by one pixel.
hdr1['CRPIX1'] = (datacube.shape[0] + 1) / 2
hdr1['CRPIX2'] = (datacube.shape[1] + 1) / 2

# Spatial pixel scale, converted from arcsec to deg.
# NOTE(review): many viewers expect CD1_1 < 0 (RA increasing to the east/left);
# kept positive here — confirm the intended orientation.
pixel_scale_deg = float(telescope.spatial_res) / 3600
hdr1['CD1_1'] = pixel_scale_deg
hdr1['CD1_2'] = 0
hdr1['CD2_1'] = 0
hdr1['CD2_2'] = pixel_scale_deg
hdr1['CUNIT1'] = 'deg'
hdr1['CUNIT2'] = 'deg'
hdr1['CTYPE1'] = 'RA---TAN'
hdr1['CTYPE2'] = 'DEC--TAN'
hdr1['CRVAL1'] = 0
hdr1['CRVAL2'] = 0

# Linear spectral axis: pixel 1 sits at the first wavelength, with step wave_res.
# 'AWAV' declares air wavelengths — TODO(review): use 'WAVE' if the grid is vacuum.
# NOTE(review): assumes the WAVE extension (pipe.telescope.wave_seq) starts at
# wave_range[0] with step wave_res — confirm the two agree.
hdr1['CTYPE3'] = 'AWAV'
hdr1['CUNIT3'] = 'Angstrom'
hdr1['CD3_3'] = float(telescope.wave_res)
hdr1['CRPIX3'] = 1
hdr1['CRVAL3'] = float(telescope.wave_range[0])

# The spatial and spectral axes are not coupled.
hdr1['CD1_3'] = 0
hdr1['CD2_3'] = 0
hdr1['CD3_1'] = 0
hdr1['CD3_2'] = 0

In [None]:
# Assemble the FITS file:
#   - an empty primary HDU carrying the provenance header,
#   - the cube as the DATA extension,
#   - the wavelength grid as the WAVE extension.
empty_primary = fits.PrimaryHDU(header=hdr)
# Transposing moves the spectral axis (numpy axis 2 of `datacube`) to numpy
# axis 0, i.e. FITS axis 3, as the WCS in hdr1 expects. np.asarray hands
# astropy plain NumPy arrays (the pipeline outputs may be JAX arrays).
image_hdu1 = fits.ImageHDU(np.asarray(datacube.T), header=hdr1)
image_hdu2 = fits.ImageHDU(np.asarray(wavelengths), name='WAVE')

hdul = fits.HDUList([empty_primary, image_hdu1, image_hdu2])

# Write into the configured output directory. Create it first, because
# writeto cannot create missing directories.
filepath = config['output_path']
os.makedirs(filepath, exist_ok=True)
filename = f"{config['simulation']['name']}id{config['data']['load_galaxy_args']['id']}_stars2.fits"
hdul.writeto(os.path.join(filepath, filename), overwrite=True)
