# Setup

## Imports

In [None]:
import copy
import numpy as np
import os
import pandas as pd
import tqdm

In [None]:
import cartopy.crs as ccrs

In [None]:
import yt
import unyt
import trident

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import palettable

In [None]:
import verdict
import trove
import helpers

## Parameters

In [None]:
# NOTE(review): presumably a signal-to-noise ratio; it is not referenced in the cells visible here — confirm it is still needed.
snr = 30

In [None]:
# Parameters for this notebook, built from the shared helpers.CONFIG.
# Values are looked up by key in later cells (e.g. pm['snum'], pm['ions'], pm['data_dir']).
pm = trove.link_params_to_config(
    helpers.CONFIG,
    script_id = 'nb.2',
    variation = 'm12i_md',
)

In [None]:
# Qualitative colour list (CARTOcolors "Vivid", 10 entries) in matplotlib-compatible form.
qual_colors = palettable.cartocolors.qualitative.Vivid_10.mpl_colors

In [None]:
# Shorthand for the two parameters used throughout the notebook:
# the simulation variation being analysed and the list of ions to process.
sim, ions = pm['variation'], pm['ions']

## Load Halo Data

In [None]:
# Halo catalogue file for the snapshot being analysed
halo_catalog_fn = f"halo_{pm['snum']}.hdf5"
halo_catalog_fp = os.path.join( pm['rockstar_data_dir'], halo_catalog_fn )

In [None]:
halo_data = verdict.Dict.from_hdf5( halo_catalog_fp )
# The most massive halo in the catalogue is taken as the main halo
index = halo_data['mass'].argmax()
center_ckpc = halo_data['position'][index]
# Divide by (1 + z) to go from comoving to physical coordinates
# NOTE(review): assumes catalogue positions are comoving kpc (per the _ckpc name) — confirm
center = center_ckpc / ( 1. + halo_data['snapshot:redshift'] )

## Load Simulation Data

In [None]:
# Snapshot directory name is the snapshot number zero-padded to three digits
snapdir_name = f"snapdir_{pm['snum']:03d}"
yt_sim_fp = os.path.join( pm['sim_data_dir'], snapdir_name )
ds = yt.load( yt_sim_fp )

In [None]:
# One kiloparsec as a quantity attached to this dataset; multiply plain arrays by it to give them kpc units.
kpc = ds.quan( 1, 'kpc' )

## Load Processed Data

In [None]:
# Processed data for this analysis; indexed by simulation name in later cells (data[sim][...]).
data_fp = os.path.join( pm['processed_data_dir'], 'data.h5' )
# NOTE(review): create_nonexistent=True presumably tolerates a missing file — confirm against verdict's API.
data = verdict.Dict.from_hdf5( data_fp, create_nonexistent=True )

# Make an On-Sky Map

## Get Sky Frame Coords

In [None]:
# The stored 'end' point for this simulation is used as the Sun's (observer's) position, tagged with kpc units.
sun_position = data[sim]['end'] * kpc

In [None]:
def convert_to_galactic_frame( raw_position ):
    """Express positions in the Sun-centred galactic frame.

    Relies on the notebook-level `sun_position` and on the basis vectors
    stored under data[sim]['galactic_frame'].

    Args:
        raw_position: Array of positions with one row per point and three
            columns, in units compatible with `sun_position`.

    Returns:
        A tuple (position_galactic, phi, theta), where position_galactic has
        one row per point holding the components along xhat, yhat and zhat,
        phi is the azimuthal angle arctan2(y, x) and theta is the polar
        angle arccos(z / r), both in radians.
    """

    # Shift the origin to the Sun
    offset = raw_position - sun_position
    distance = np.linalg.norm( offset, axis=1 )

    # Project onto each basis vector of the galactic frame
    frame = data[sim]['galactic_frame']
    components = np.array([
        np.dot( offset, frame[axis_key] )
        for axis_key in ( 'xhat', 'yhat', 'zhat' )
    ])

    # Spherical angles as seen from the Sun
    phi = np.arctan2( components[1], components[0] )
    theta = np.arccos( components[2] / distance )

    return components.transpose(), phi, theta

### Simulation Data

In [None]:
# Add ion fields to the dataset for every requested ion, so that the
# ('gas', '<atom>_p<n>_number_density') fields used below are available.
trident.ion_balance.add_ion_fields( ds, ions )

In [None]:
# Ion units, names, and internal names
ldb = trident.line_database.LineDatabase( 'lines.txt' )

# yt field tuple for each ion; the field name uses ( level - 1 ) as its suffix
number_densities_to_include = [
    ( 'gas', f'{atom}_p{lev - 1}_number_density' )
    for atom, lev in ldb.parse_subset_to_ions( ions )
]

# One unit string per field
number_density_units = [ 'log(cm**-3)' for field in number_densities_to_include ]

# Display names: the ion label with spaces stripped, e.g. 'H I' -> 'logHIdensity'
number_density_names = [
    f"log{ion.replace( ' ', '' )}density"
    for ion in ions
]

In [None]:
# Select everything within twice the main halo's radius of the Sun's position.
# NOTE(review): the catalogue radius is labelled "kpc" as-is, whereas the halo position above is
# converted from comoving to physical — confirm the radius does not need the same conversion.
sp = ds.sphere( sun_position, (2. * halo_data['radius'][index], "kpc") )

In [None]:
# Gas positions in the Sun-centred galactic frame, plus their azimuthal (phi) and polar (theta) angles.
position_galactic, phi_galactic, theta_galactic = convert_to_galactic_frame( sp[('gas', 'position')].to( 'kpc' ) )

### Sightlines

In [None]:
# Sightline start points in the same Sun-centred galactic frame as the gas.
# Indexed by `sim` (not a hard-coded simulation name) so the sightlines always belong to the
# simulation used for sun_position and the spectra loop.
position_galactic_sl, phi_sl, theta_sl = convert_to_galactic_frame( data[sim]['start'] * kpc )

In [None]:
# Sky coordinates of the sightlines in degrees.
# Longitude runs opposite to phi (hence 360 - phi); latitude is 90 degrees minus the polar angle.
ra_sl = 360 - phi_sl * 180. / np.pi
dec_sl = - ( theta_sl - np.pi / 2. ) * 180. / np.pi

## Spectra Properties

In [None]:
# Per-sightline observables; indexed below as observables_data[<sightline key>][<line name>]['column_density'].
observables_fp = os.path.join( pm['data_dir'], 'sightlines', 'observables.h5' )
observables_data = verdict.Dict.from_hdf5( observables_fp, )

In [None]:
# Line used to look up each ion's column density. This does not depend on the
# sightline, so it is resolved once here instead of once per sightline.
ion_lines = {}
for ion in ions:
    if ion == 'H I':
        ion_lines[ion] = 'Ly a'
    else:
        # Reset the subset so earlier ions do not leak into this lookup,
        # then use the first line the database lists for this ion.
        ldb.lines_subset = []
        ion_lines[ion] = ldb.parse_subset( ion )[0].name

ews = []
column_densities = verdict.Dict({})
for i, _start in enumerate( data[sim]['start'] ):

    # Sightlines are identified by a zero-padded three-digit index
    i_key = '{:03d}'.format( i )

    spectra_fp = os.path.join( pm['data_dir'], 'sightlines', 'spectrum_{:03d}.h5'.format( i ) )
    spectra_data = verdict.Dict.from_hdf5( spectra_fp )

    # Absorbed fraction summed over the spectrum.
    # NOTE(review): this is a plain sum over pixels with no pixel-width factor — confirm that is intended.
    w = (1. - np.exp( -spectra_data['tau'] ) )
    ews.append( w.sum() )

    # Total column density along this sightline for each ion
    for ion, line in ion_lines.items():
        colden = observables_data[i_key][line]['column_density'].sum()
        column_densities.setitem( ion, colden, i_key )

## Plotted Properties

### LoS velocity

In [None]:
# Relative velocity to sun, but not rotated to galactic coordinates.
# The Sun's velocity is the galaxy velocity plus the Sun's velocity relative to the galaxy (both stored unitless, tagged as km/s here).
velocity_relative = ( sp[('gas', 'velocity')] - ( data[sim]['galaxy_velocity'] + data[sim]['sun_relative_velocity'] ) * unyt.km / unyt.s ).to( 'km/s' )

In [None]:
# Separation vector from the Sun to each gas element and its length
position_relative = ( sp[( 'gas', 'position' )] - sun_position ).to( 'kpc' )
r = np.linalg.norm( position_relative, axis=1 )
# Row-wise dot product of position and velocity, divided by distance, gives the radial velocity.
# The leading minus sign makes motion toward the Sun positive.
velocity_los = -np.einsum( 'ij,ij->i', position_relative, velocity_relative ) / r

### Velocity magnitude
Gas speed measured relative to the galaxy's bulk velocity.

In [None]:
# Gas velocity with the galaxy's velocity subtracted, and its magnitude per gas element
velocity = ( sp[('gas', 'velocity')] - ( data[sim]['galaxy_velocity'] ) * unyt.km / unyt.s ).to( 'km/s' )
vmag = np.linalg.norm( velocity, axis=1 )

### Test

In [None]:
# Test distribution: points drawn uniformly in a cube, then trimmed to the
# inscribed sphere, with angles measured from the sphere's centre.
r_sp = sp.radius.to( 'kpc' )
n_draws = 10**6
positions_test = np.random.uniform( -r_sp, r_sp, ( n_draws, 3 ) )
r_test = np.linalg.norm( positions_test, axis=1 )

# Keep only the draws that fall inside the sphere
inside_sphere = r_test < r_sp
positions_test = positions_test[inside_sphere]
r_test = r_test[inside_sphere]

phi_test = np.arctan2( positions_test[:,1], positions_test[:,0] )
theta_test = np.arccos( positions_test[:,2] / r_test )

### Finalize

In [None]:
# One entry per map to draw. Each value is a dict of keyword arguments passed
# straight to plot_projected_hist in the plotting loop.
plot_types = {
    # Mass distribution on the sky
    'mass-weighted': {
        'weights': sp[('gas', 'mass')].to( 'Msun' ),
    },
    # Mass-weighted mean line-of-sight velocity, on a linear diverging scale
    # spanning +/- twice the halo's maximum circular velocity
    'LOS velocity': {
        'weights': sp[('gas', 'mass')].to( 'Msun' ),
        'color_axis': velocity_los,
        'cmap': 'bwr_r',
        'vmin': -halo_data['vel.circ.max'][index] * 2.,
        'vmax': halo_data['vel.circ.max'][index] * 2.,
        'norm': None,
    },
    # Mass-weighted mean speed, on a linear scale from zero to twice the Sun's speed relative to the galaxy
    'velocity magnitude': {
        'weights': sp[('gas', 'mass')].to( 'Msun' ),
        'color_axis': vmag,
        'cmap': 'PRGn',
        'vmin': 0,
        'vmax': np.linalg.norm( data[sim]['sun_relative_velocity'] ) * 2.,
        'norm': None,
    },
    # 'test': {
    #     'phi': phi_test,
    #     'theta': theta_test,
    # }
}

In [None]:
# Add one ion-number-density-weighted map per ion. 'los color' carries the
# per-sightline column densities used to colour the sightline markers.
for i, ion in enumerate( tqdm.tqdm( ions ) ):
    params = {}
    params['weights'] = sp[number_densities_to_include[i]]
    params['los color'] = column_densities[ion].array()
    params['vmin'] = 1e-15

    key = f'{ion}-weighted'
    plot_types[key] = params

## Plot

In [None]:
# Projection the axes are drawn in (Mollweide all-sky map)
proj = ccrs.Mollweide()
# Coordinate system of the data handed to the plotting calls (passed as their `transform`)
img_proj = ccrs.PlateCarree()

In [None]:
def plot_projected_hist(
    ax,
    phi=phi_galactic,
    theta=theta_galactic,
    weights=None,
    color_axis=None,
    cmap=matplotlib.cm.cubehelix_r,
    vmin=None,
    vmax=None,
    norm=matplotlib.colors.LogNorm(),
    n_bins=256,
):
    """Draw a 2D histogram of sky positions on a map axis.

    Points are binned in phi and cos(theta). With `color_axis` unset the
    histogram is normalised to a probability density; otherwise each bin
    shows the weighted mean of `color_axis`.

    Args:
        ax: Axis to draw on; the mesh is drawn with transform=img_proj.
        phi: Azimuthal angle of each point, in radians.
        theta: Polar angle of each point, in radians.
        weights: Optional weight of each point.
        color_axis: Optional per-point quantity to average in each bin.
        cmap: Colormap for the mesh.
        vmin, vmax: Optional colour limits.
        norm: Colour normalisation, or None for matplotlib's default linear
            scaling. The instance passed in is never modified.
        n_bins: Number of bin edges per axis (so n_bins - 1 bins).

    Returns:
        A tuple (hist2d, img) of the binned array and the mesh returned by
        ax.pcolormesh.
    """

    ra_edges = np.linspace( -np.pi, np.pi, n_bins )
    cosdec_edges = np.linspace( -1, 1, n_bins )
    bins = [ ra_edges, cosdec_edges ]
    cos_theta = np.cos( theta )

    # Make the histogram
    hist2d, ra_edges, _ = np.histogram2d(
        phi,
        cos_theta,
        bins = bins,
        weights = weights,
    )
    if color_axis is None:
        # Turn into a PDF
        hist_norm = hist2d.sum() * ( ra_edges[1] - ra_edges[0] ) * ( cosdec_edges[1] - cosdec_edges[0] )
        hist2d /= hist_norm
    else:
        # Weighted mean of color_axis per bin; unweighted if no weights given
        prop_weights = color_axis if weights is None else weights * color_axis
        hist2d_prop, ra_edges, _ = np.histogram2d(
            phi,
            cos_theta,
            bins = bins,
            weights = prop_weights,
        )
        # Empty bins give 0 / 0 = NaN, which is simply left undrawn
        with np.errstate( divide='ignore', invalid='ignore' ):
            hist2d = hist2d_prop / hist2d

    # Get centers
    ra_centers = 360 - 0.5 * ( ra_edges[1:] + ra_edges[:-1] ) * 180. / np.pi
    dec_edges =  np.pi / 2. - np.arccos( cosdec_edges )
    dec_edges *= 180. / np.pi
    dec_centers = 0.5 * ( dec_edges[1:] + dec_edges[:-1] )

    # The default norm is a single instance created when the function is
    # defined. Drawing with it would autoscale it once and reuse those limits
    # for every later plot, so always work on a private copy.
    # Colour limits go on the norm itself because matplotlib does not accept
    # vmin/vmax together with a norm.
    mesh_vmin, mesh_vmax = vmin, vmax
    if norm is not None:
        norm = copy.deepcopy( norm )
        if vmin is not None:
            norm.vmin = vmin
        if vmax is not None:
            norm.vmax = vmax
        mesh_vmin = mesh_vmax = None

    img = ax.pcolormesh(
        ra_centers,
        dec_centers,
        hist2d.transpose(),
        transform = img_proj,
        cmap = cmap,
        shading = 'nearest',
        vmin = mesh_vmin,
        vmax = mesh_vmax,
        norm = norm,
    )

    return hist2d, img

In [None]:
# Output directory for the sky maps; created if it does not exist yet.
projection_dir = os.path.join( pm['figure_dir'], 'projections' )
os.makedirs( projection_dir, exist_ok=True )

In [None]:
# Draw and save one all-sky map per entry in plot_types, with the sightline
# positions overplotted.
for key, plot_params in tqdm.tqdm( plot_types.items() ):

#     if not key in [ 'LOS velocity', 'velocity magnitude' ]:
#         continue

    width = pm['figure_width'] * 2
    height = width / 2.5
    fig = plt.figure( figsize=(width,height), facecolor='w' )
    ax = plt.axes( projection=proj )

    # Split off the sightline colours on a copy, so plot_types is left intact
    # and this cell can be re-run without losing them.
    hist_kwargs = dict( plot_params )
    s_colorbar = 'los color' in hist_kwargs
    c = hist_kwargs.pop( 'los color', 'k' )

    hist2d, img = plot_projected_hist( ax, **hist_kwargs )

    # Colour-mapping arguments only apply when the markers carry data values
    scatter_kwargs = {}
    if s_colorbar:
        scatter_kwargs['cmap'] = 'viridis'
        scatter_kwargs['norm'] = matplotlib.colors.LogNorm()
    s = ax.scatter(
        ra_sl,
        dec_sl,
        transform = img_proj,
        c = c,
        **scatter_kwargs
    )

    # Label in the upper-left corner. The text is passed positionally because
    # the keyword for it differs between matplotlib versions.
    ax.annotate(
        key,
        xy = ( 0, 1 ),
        xytext = ( 5, -5 ),
        xycoords = 'axes fraction',
        textcoords = 'offset points',
        fontsize = pm['footnote_fontsize'],
        ha = 'left',
        va = 'top',
    )

    plt.colorbar( img )
    if s_colorbar:
        plt.colorbar( s )

    ax.gridlines( color='0.6', )

    plt.tight_layout()

    projection_fn = key.replace( ' ', '_' ) + '.png'
    projection_fp = os.path.join( projection_dir, projection_fn )
    plt.savefig( projection_fp, dpi=300 )
        