# Setup

## Imports

In [None]:
import numpy as np
import os
import pandas as pd

In [None]:
import cartopy.crs as ccrs

In [None]:
import yt
import unyt
import trident

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import palettable

In [None]:
import verdict
import trove
import helpers

## Parameters

In [None]:
# Ions to model, each paired with the number-density field name of the same
# species (element_p<N>, where p<N> is charge N, e.g. 'C II' -> 'C_p1').
# Both lists are built from one table so they stay aligned index-for-index.
ions, fields = (
    list(column)
    for column in zip(*[
        ('H I', 'H_p0_number_density'),
        ('O I', 'O_p0_number_density'),
        ('C II', 'C_p1_number_density'),
        ('C III', 'C_p2_number_density'),
        ('C IV', 'C_p3_number_density'),
        ('N II', 'N_p1_number_density'),
        ('N III', 'N_p2_number_density'),
        ('Si II', 'Si_p1_number_density'),
        ('Si III', 'Si_p2_number_density'),
        ('Si IV', 'Si_p3_number_density'),
        ('N V', 'N_p4_number_density'),
        ('O VI', 'O_p5_number_density'),
        ('Mg II', 'Mg_p1_number_density'),
    ])
)
# Signal-to-noise ratio passed to the Gaussian-noise step of the spectra.
snr = 30

In [None]:
# Resolve this notebook's parameters from the shared project configuration.
pm = trove.link_params_to_config(
    helpers.CONFIG, script_id='nb.1', variation='m12i_md',
)

In [None]:
# Qualitative color cycle (cartocolors Vivid_10) as matplotlib colors.
qual_colors = (
    palettable.cartocolors.qualitative.Vivid_10
    .mpl_colors
)

In [None]:
# Variation label; used as the top-level key into the processed-data cache.
sim = pm["variation"]

## Load Halo Data

In [None]:
# Halo catalog file for the configured snapshot number.
halo_catalog_fn = f"halo_{pm['snum']}.hdf5"
halo_catalog_fp = os.path.join(pm['rockstar_data_dir'], halo_catalog_fn)

In [None]:
# Take the most massive halo in the catalog as the target galaxy.
halo_data = verdict.Dict.from_hdf5(halo_catalog_fp)
index = halo_data['mass'].argmax()

# Its position is presumably comoving kpc (per the variable name); dividing
# by (1 + z) gives the physical-kpc center used everywhere below.
center_ckpc = halo_data['position'][index]
center = center_ckpc / (1.0 + halo_data['snapshot:redshift'])

## Load Simulation Data

In [None]:
# Load the snapshot directory (zero-padded snapshot number) with yt.
yt_sim_fp = os.path.join(pm['sim_data_dir'], f"snapdir_{pm['snum']:03d}")
ds = yt.load(yt_sim_fp)

In [None]:
# One kpc as a dataset quantity; multiplying plain arrays by it attaches units.
kpc = ds.quan(1, "kpc")

## Load Processed Data

In [None]:
# Processed-data cache; created empty if the file does not exist yet.
data_fp = os.path.join(pm['processed_data_dir'], 'data.h5')
data = verdict.Dict.from_hdf5(data_fp, create_nonexistent=True)

# Sightline Coordinates

## One-time Calculations

In [None]:
# Dirty flag: set to True by any cell below that computes a value missing from
# the `data` cache, so the cache is written back to `data_fp`.
new_data = False

In [None]:
# Total angular momentum vector (kpc km/s) of the 10 kpc sphere around the
# halo center, cached in `data` under this simulation's key.
data_key = 'total_angular_momentum'
try:
    # Reuse the cached value when present.
    jtot = data[sim][data_key]
except KeyError:
    # Not cached: compute it, store it, and flag the cache for saving.
    new_data = True

    sp = ds.sphere(center * kpc, (10, "kpc"))
    # NOTE(review): requested with particle_type='PartType4'; confirm whether
    # yt's angular_momentum_vector also includes gas by default.
    jtot = (
        sp.quantities
        .angular_momentum_vector(particle_type='PartType4')
        .to('kpc * km / s')
        .value
    )

    data.setitem(sim, jtot, data_key)

In [None]:
# Bulk velocity (km/s) of the 10 kpc sphere around the halo center, cached in
# `data` under this simulation's key.
data_key = 'galaxy_velocity'
try:
    # Reuse the cached value when present.
    galaxy_velocity = data[sim][data_key]
except KeyError:
    # Not cached: compute it, store it, and flag the cache for saving.
    new_data = True

    sp_center = ds.sphere(center * kpc, (10, 'kpc'))
    galaxy_velocity = (
        sp_center.quantities
        .bulk_velocity()
        .to('km/s')
        .value
    )

    data.setitem(sim, galaxy_velocity, data_key)

In [None]:
# Write the cache back only if a cell above added something to it.
if new_data:
    data.to_hdf5(data_fp)

## End Position

### Position

In [None]:
# Galaxy frame: zhat along the total angular momentum; xhat a unit vector
# perpendicular to it, built by a cross product with a reference axis. The
# observer is placed along +xhat in the next cell.
zhat = jtot / np.linalg.norm(jtot)
xhat = np.cross([1, 0, 0], zhat)
# If the angular momentum is (nearly) parallel to the simulation x-axis, the
# cross product vanishes and normalizing it would produce NaNs. In that case,
# use the y-axis as the reference direction instead.
if np.linalg.norm(xhat) < 1e-8:
    xhat = np.cross([0, 1, 0], zhat)
xhat /= np.linalg.norm(xhat)

In [None]:
# Observer ("sun") position: the galaxy center offset along xhat by the
# configured galactocentric radius. All sightlines end here.
end = center + xhat * pm['sun_galactocentric_radius']

### Velocity

In [None]:
# Velocity (km/s) of the material around the observer position, relative to
# the galaxy's bulk velocity; cached in `data` under this simulation's key.
data_key = 'sun_relative_velocity'
try:
    # Reuse the cached value when present.
    sun_relative_velocity = data[sim][data_key]
except KeyError:
    # Not cached: compute it, store it, and flag the cache for saving.
    new_data = True

    # Bulk velocity within 2 kpc of `end`.
    sp_sun = ds.sphere(end * kpc, (2, 'kpc'))
    sun_velocity = (
        sp_sun.quantities
        .bulk_velocity(particle_type='PartType4')
        .to('km/s')
    )
    # `galaxy_velocity` is a plain array in km/s; attach units before
    # subtracting, then strip units for storage.
    sun_relative_velocity = (
        sun_velocity - galaxy_velocity * unyt.km / unyt.s
    ).value

    data.setitem(sim, sun_relative_velocity, data_key)

In [None]:
# Write the cache back if any one-time calculation above added to it.
if new_data:
    data.to_hdf5(data_fp)

## Start Positions

### Load On-Sky Coords

In [None]:
# On-sky sightline targets; later cells read the 'latitude' and 'longitude'
# columns (in degrees).
coords_fp = os.path.join(pm['processed_data_dir'], 'skycoords.txt')
# The file is delimited by ", " (comma + space). Only pandas' python engine
# supports multi-character separators, so request it explicitly instead of
# relying on the implicit fallback, which emits a ParserWarning.
skycoords = pd.read_csv(coords_fp, sep=', ', engine='python')

In [None]:
# Add a test coordinate at (latitude 30, longitude 0). Its sightline is used
# to sanity-check the sky-to-simulation transform and then dropped from `start`.
testcoord = pd.Series(data={'QSO ID': 'test', 'latitude': 30., 'longitude': 0.})
# DataFrame.append was removed in pandas 2.0, so concatenate a one-row frame
# instead. Transposing the mixed-type Series yields object columns;
# infer_objects() restores float dtypes for latitude/longitude.
skycoords = pd.concat(
    [skycoords, testcoord.to_frame().T.infer_objects()],
    ignore_index=True,
)

In [None]:
# Add a second test coordinate at (latitude 0, longitude 30). Its sightline is
# used to sanity-check the handedness of the sky frame and then dropped from
# `start`.
testcoord = pd.Series(data={'QSO ID': 'test', 'latitude': 0., 'longitude': 30.})
# DataFrame.append was removed in pandas 2.0, so concatenate a one-row frame
# instead; infer_objects() keeps latitude/longitude numeric.
skycoords = pd.concat(
    [skycoords, testcoord.to_frame().T.infer_objects()],
    ignore_index=True,
)

#### Plot

In [None]:
# Coordinate system the (longitude, latitude) data are given in ...
img_proj = ccrs.PlateCarree()
# ... and the all-sky projection used to display them.
proj = ccrs.Mollweide()

In [None]:
# All-sky scatter of the sightline targets.
fig = plt.figure()
ax = plt.axes(projection=proj)

# Longitude is negated, presumably so it increases to the left as on a
# conventional sky map.
ax.scatter(-skycoords['longitude'], skycoords['latitude'], transform=img_proj)

ax.gridlines()
# NOTE(review): a +/-1 degree longitude window around 180 deg -- confirm this
# produces the intended full-sky view (ax.set_global() may be clearer).
delta_extent = 1
ax.set_extent(
    [180 - delta_extent, 180 + delta_extent, -90, 90],
    crs=img_proj,
)

### Convert to Simulation Coordinates

In [None]:
# On-sky frame, expressed in simulation coordinates:
#   xskyhat -- from the observer toward the galaxy center (observer sits at +xhat),
#   zskyhat -- anti-parallel to the total angular momentum axis,
#   yskyhat -- completes the frame; intended to point left on a sky map.
# Sky-frame components multiply these vectors to give simulation directions.
xskyhat = -xhat
zskyhat = -zhat
yskyhat = np.cross(zskyhat, xskyhat)

In [None]:
# Longitude/latitude (degrees) -> spherical angles (radians) in the sky frame:
# azimuth phi = 360 - longitude, polar angle theta = 90 + latitude.
phi = (360 - skycoords["longitude"]) * np.pi / 180.0
theta = (90 + skycoords["latitude"]) * np.pi / 180.0

In [None]:
# Unit direction vectors in the sky frame; shape (3, n_sightlines), one row
# per sky-frame axis.
unitcoords_sky = np.array([
    np.cos(phi) * np.sin(theta),
    np.sin(phi) * np.sin(theta),
    np.cos(theta),
])

In [None]:
# Rotate the sky-frame directions into simulation coordinates: each row is
# sum_k component_k * axis_k, shape (n_sightlines, 3).
unitcoords = (
    np.outer(unitcoords_sky[0], xskyhat)
    + np.outer(unitcoords_sky[1], yskyhat)
    + np.outer(unitcoords_sky[2], zskyhat)
)

In [None]:
# Sightline start points: one path length from the observer along each direction.
start = end + pm['pathlength'] * unitcoords

In [None]:
# Sanity check with the second test coordinate (latitude 0, longitude 30),
# currently the last row of `start`. Its offset from the observer, projected
# onto cross(zhat, xhat) (which equals yskyhat), should be
# -pathlength * sin(30 deg). Per the original author, this checks that
# sightlines on the left of the map land on the side of the galaxy rotating
# toward the viewer.
np.testing.assert_allclose(
    np.dot(start[-1] - end, np.cross(zhat, xhat)),
    -pm['pathlength'] * 0.5,
)
# Check passed: drop the test sightline.
start = np.delete(start, -1, axis=0)

In [None]:
# Sanity check with the first test coordinate (latitude 30, longitude 0),
# now the last row because the previous cell removed the second one. Its
# offset from the observer, projected onto zhat, should be
# pathlength * cos(60 deg).
np.testing.assert_allclose(
    np.dot(start[-1] - end, zhat),
    pm['pathlength'] * 0.5,
)
# Check passed: drop this test sightline too.
start = np.delete(start, -1, axis=0)

In [None]:
# Record the sky-frame basis vectors under data[sim]['galactic_frame'].
for axis_name, axis_vector in (
    ('xhat', xskyhat),
    ('yhat', yskyhat),
    ('zhat', zskyhat),
):
    data.setitem(sim, axis_vector, 'galactic_frame', axis_name)

## Save

In [None]:
# Record the sightline endpoints (test coordinates already removed).
data.setitem(sim, start, 'start')
data.setitem(sim, end, 'end')

In [None]:
# Unconditionally persist the cache, now including the sightline geometry.
data.to_hdf5(data_fp)

# Generate Rays and Spectra

In [None]:
# Line database passed to the spectrum generator below.
# NOTE(review): LineDatabase(None) appears to construct a database without
# loading any line list; confirm against Trident's docs that the generated
# spectra actually contain the requested `ions` lines.
ldb = trident.LineDatabase(None)

## Generate Rays

In [None]:
# Trace one Trident ray per sightline, from each start position to the shared
# observer position `end`, writing each ray to <data_dir>/sightlines/ray_NNN.h5.
ray_dir = os.path.join(pm['data_dir'], 'sightlines')
# Loop-invariant: create the output directory once, not once per sightline.
os.makedirs(ray_dir, exist_ok=True)
# Every ray ends at the same point; attach units once, reusing the `kpc`
# quantity defined when the dataset was loaded.
end_position = end * kpc

for i, start_i in enumerate(start):

    print('Generating sightline {:03d}...'.format(i))

    ray_fp = os.path.join(ray_dir, 'ray_{:03d}.h5'.format(i))

    ray = trident.make_simple_ray(
        ds,
        start_position=start_i * kpc,
        end_position=end_position,
        data_filename=ray_fp,
        lines=ions,
    )

## Generate Spectra

In [None]:
# One spectrum generator, reused for every sightline; uses the database above.
sg = trident.SpectrumGenerator(line_database=ldb)

In [None]:
# For every ray written above, build a synthetic spectrum (absorption from
# `ions`, then the LSF step, then Gaussian noise at `snr`). Save it as HDF5
# data and a PNG plot next to the ray files.
ray_dir = os.path.join(pm['data_dir'], 'sightlines')
for i in range(len(start)):

    print('Generating spectra {:03d}...'.format(i))

    ray = yt.load(os.path.join(ray_dir, 'ray_{:03d}.h5'.format(i)))

    # Add Trident ion fields for the requested ions to the loaded ray.
    trident.add_ion_fields(ray, ions=ions)

    sg.make_spectrum(ray, lines=ions, store_observables=True, min_tau=1e-4)
    sg.apply_lsf()
    sg.add_gaussian_noise(snr)

    # Save data and plot under a shared file stem.
    spectrum_stem = os.path.join(ray_dir, 'spectrum_{:03d}'.format(i))
    sg.save_spectrum(spectrum_stem + '.h5')
    sg.plot_spectrum(spectrum_stem + '.png')