# In [None]:
import yt
from yt.units import Mpc, Msun
import glob
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import AxesGrid
from satellite_analysis.consistentscripts import consistentcatalogreader as consistent

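# TODO: define parse() to build the `args` mapping read by the driver code below (keys 'input_dir' and 'VELA_dir').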

def _exact_mass_filter(mass):
    """Return a yt particle-filter function that keeps particles of exactly *mass*.

    The mass is bound as an argument of this factory rather than read from the
    caller's loop variable, so each filter keeps its own value even if yt
    evaluates it after the caller's loop has moved on (Python closures bind
    names late, not values).
    """
    def mass_filter(pfilter, data):
        # Exact equality is intended: the caller passes values taken from
        # np.unique over this same field.
        return data[(pfilter.filtered_type, 'particle_mass')] == mass
    return mass_filter


def plotting_dm_filters_and_annotate_circle(par_type, tomer_x, tomer_y, tomer_z, tomer_radius,
                                            rockstar_x, rockstar_y, rockstar_z, rockstar_radius,
                                            dataset=None, zoom=10):
    """Plot x-y, y-z and z-x particle views, one grid row per distinct particle mass.

    For every unique value of ``(par_type, 'particle_mass')`` a yt particle
    filter selecting exactly that mass is registered. Three ``ParticlePlot``
    panels are then drawn into a shared ``AxesGrid`` (one row per mass, one
    column per view). Every panel is annotated with the Tomer halo (red
    circle, radius in kpc) and the Rockstar halo (black circle, radius in
    kpc/h).

    Parameters
    ----------
    par_type : str
        yt particle type to split by mass (e.g. ``'darkmatter'``).
    tomer_x, tomer_y, tomer_z : float
        Centre of the Tomer halo, passed unchanged to ``annotate_sphere``.
    tomer_radius : float
        Tomer halo radius in kpc.
    rockstar_x, rockstar_y, rockstar_z : float
        Centre of the Rockstar halo, passed unchanged to ``annotate_sphere``.
    rockstar_radius : float
        Rockstar halo radius in kpc/h.
    dataset : yt dataset, optional
        Dataset to plot. Defaults to the module-level ``ds`` set by the
        driver loop, which is what this function always used.
    zoom : float, optional
        Zoom factor applied to every panel (default 10).

    Side effects
    ------------
    Sets the module globals ``ad``, ``masses``, ``filter_name`` and ``index1``
    (kept in case other code reads them). Registers one yt particle filter
    per mass and creates a new matplotlib figure.
    """
    global ad, masses, filter_name, index1
    if dataset is None:
        dataset = ds
    ad = dataset.all_data()
    masses = np.unique(ad[(par_type, 'particle_mass')])

    # First set up the grid to plot onto: one row per mass, three views per row.
    fig = plt.figure()
    grid = AxesGrid(fig, (0.075, 0.075, 5, 15),
                    nrows_ncols=(len(masses), 3),
                    axes_pad=0.05,
                    label_mode="L",
                    share_all=True,
                    cbar_location="right",
                    cbar_mode="single",
                    cbar_size="3%",
                    cbar_pad="0%")
    # (centre, radius) pairs handed to every panel explicitly.
    tomer = ([tomer_x, tomer_y, tomer_z], tomer_radius)
    rockstar = ([rockstar_x, rockstar_y, rockstar_z], rockstar_radius)

    for index, mass in enumerate(masses):
        index1 = index
        filter_name = par_type + str(index)
        yt.add_particle_filter(filter_name, function=_exact_mass_filter(mass),
                               filtered_type=par_type, requires=['particle_mass'])
        dataset.add_particle_filter(filter_name)

        plotting_xyz_projection(filter_name, dataset, zoom, index,
                                figure=fig, axes_grid=grid,
                                tomer=tomer, rockstar=rockstar)


def _draw_particle_panel(filter_name, ds, zoom, x_axis, y_axis, tomer, rockstar,
                         figure, axes, cax):
    """Build one annotated ParticlePlot of *filter_name* and render it onto *axes*.

    *x_axis* / *y_axis* are 'x', 'y' or 'z'. The colour field is the particle
    mass, shown in Msun. *tomer* and *rockstar* are ``(centre, radius)`` pairs
    drawn as red (radius in kpc) and black (radius in kpc/h) circles.
    """
    mass_field = (filter_name, 'particle_mass')
    panel = yt.ParticlePlot(ds,
                            (filter_name, 'particle_position_' + x_axis),
                            (filter_name, 'particle_position_' + y_axis),
                            mass_field)
    # NOTE(review): the centres carry no units, so yt applies its default
    # length unit to them (presumably code_length) -- confirm they match.
    panel.annotate_sphere(tomer[0], radius=(tomer[1], 'kpc'),
                          circle_args={'color': 'red'})
    panel.annotate_sphere(rockstar[0], radius=(rockstar[1], 'kpc/h'),
                          circle_args={'color': 'black'})
    panel.set_unit(mass_field, 'Msun')
    panel.set_figure_size(5)
    panel.zoom(zoom)

    # Re-home yt's internal plot onto the shared AxesGrid cell and colour bar,
    # then redraw it there (yt's multi-panel AxesGrid recipe).
    plot = panel.plots[mass_field]
    plot.figure = figure
    plot.axes = axes
    plot.cax = cax
    panel._setup_plots()


def plotting_xyz_projection(filter_name, ds, zoom, index, figure=None, axes_grid=None,
                            tomer=None, rockstar=None):
    """Draw x-y, y-z and z-x particle plots of *filter_name* into row *index* of a grid.

    Parameters
    ----------
    filter_name : str
        Name of a particle filter expected to be already added to *ds*.
    ds : yt dataset
        Dataset to plot.
    zoom : float
        Zoom factor applied to each panel.
    index : int
        Grid row. The panels go to ``axes_grid[3*index]`` .. ``axes_grid[3*index + 2]``
        and all of them share ``axes_grid.cbar_axes[0]``.
    figure, axes_grid : optional
        Matplotlib figure and ``AxesGrid`` to draw into. When omitted, the
        module-level ``fig`` / ``grid`` are used, as this function always did.
    tomer, rockstar : tuple(list, float), optional
        ``(centre, radius)`` of the circles to draw. The radii are in kpc and
        kpc/h respectively. When omitted, they are built from the module-level
        ``tomer_*`` / ``rockstar_*`` variables, as this function always did.
    """
    if figure is None:
        figure = fig
    if axes_grid is None:
        axes_grid = grid
    if tomer is None:
        tomer = ([tomer_x, tomer_y, tomer_z], tomer_radius)
    if rockstar is None:
        rockstar = ([rockstar_x, rockstar_y, rockstar_z], rockstar_radius)

    # Columns 0, 1, 2 of the row hold the x-y, y-z and z-x views.
    for column, (x_axis, y_axis) in enumerate((('x', 'y'), ('y', 'z'), ('z', 'x'))):
        _draw_particle_panel(filter_name, ds, zoom, x_axis, y_axis, tomer, rockstar,
                             figure, axes_grid[index * 3 + column].axes,
                             axes_grid.cbar_axes[0])


# Driver: for every snapshot in the consistent-trees catalogue, load the
# matching VELA 10 Mpc box file with yt and draw its dark-matter grid.
# NOTE(review): `args` is not defined in this file; it must be a mapping with
# 'input_dir' and 'VELA_dir' keys.
input_dir = args['input_dir']
VELA_dir = args['VELA_dir']

consistent.consistent_catalog_reader(input_dir, subhalos=True)

# Find the VELA files to load into yt.
VELA_snaps = sorted(glob.glob(VELA_dir + '/10MpcBox*'))

# Scale-factor label of each file: the text between the last two '.' of its
# path. VELA_index[i] describes VELA_snaps[i]. This list is deliberately not
# re-sorted on its own, because that would pair labels with the wrong files.
# Paths with fewer than two '.' get None so the alignment is preserved.
VELA_index = []
for snap in VELA_snaps:
    pieces = snap.rsplit('.', 2)
    VELA_index.append(pieces[1] if len(pieces) == 3 else None)

for index in consistent.snapshot_index:
    VELA_a = consistent.consistent_file_index[index]
    print('Generating Graph Grid for snapshot:', VELA_a)
    position = [pos for pos, loc in enumerate(VELA_index) if loc == VELA_a]
    if not position:
        print('Could not find corresponding VELA 10Mpc File for snapshot:', VELA_a)
        continue
    if len(position) > 1:
        print('Found multiple VELA 10Mpc Files for snapshot:', VELA_a,
              [VELA_snaps[pos] for pos in position])
        continue

    ds = yt.load(VELA_snaps[position[0]])
    domain_width = ds.domain_width.in_units('Mpc')

    # TODO: add the Tomer catalogue reader -- r_vir_tomer, x_tomer, y_tomer
    # and z_tomer are not defined in this file, and entry 15 is hard-coded.
    tomer_radius = r_vir_tomer[15]
    tomer_x = x_tomer[15]
    tomer_y = y_tomer[15]
    tomer_z = z_tomer[15]

    # TODO: fix the Rockstar coordinates. halo_data_num_p_mvir is not defined
    # in this file (presumably it comes from the consistent-trees reader --
    # confirm). It is indexed by this loop's snapshot key. Row [0][0] is used;
    # columns 4 and 8-10 are read as radius and x, y, z.
    # NOTE(review): every value is multiplied by 0.7 (presumably h), and the
    # radius is then drawn in 'kpc/h' -- confirm the intended h convention.
    rockstar_row = halo_data_num_p_mvir[index][0][0]
    rockstar_radius = float(rockstar_row[4]) * .7
    rockstar_x = float(rockstar_row[8]) * .7
    rockstar_y = float(rockstar_row[9]) * .7
    rockstar_z = float(rockstar_row[10]) * .7

    plotting_dm_filters_and_annotate_circle('darkmatter', tomer_x, tomer_y, tomer_z, tomer_radius,
                                            rockstar_x, rockstar_y, rockstar_z, rockstar_radius)


#plt.savefig('postfixVELA08Zoom.png', bbox_inches='tight')
plt.show()