In [None]:
import coolingFunction

In [None]:
import copy
import os
import sys

import h5py
import numpy as np
import pandas as pd
import scipy
import scipy.interpolate
import scipy.ndimage
import scipy.special
import tqdm
import unyt
import verdict

In [None]:
import kalepy as kale

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.patheffects as path_effects
import matplotlib.cm as cm
import matplotlib.colors as plt_colors
import matplotlib.gridspec as gridspec
import matplotlib.transforms
import palettable

In [None]:
import cartopy.crs as ccrs

In [None]:
import linefinder.analyze_data.worldlines as a_worldlines
import linefinder.analyze_data.worldline_set as worldline_set
import linefinder.analyze_data.plot_worldlines as p_worldlines
import linefinder.utils.presentation_constants as p_constants

In [None]:
import galaxy_dive.analyze_data.ahf as analyze_ahf
import galaxy_dive.plot_data.ahf as plot_ahf
import galaxy_dive.analyze_data.particle_data as particle_data
import galaxy_dive.plot_data.generic_plotter as generic_plotter
import galaxy_dive.plot_data.plotting as plotting
import galaxy_dive.utils.data_operations as data_operations
import galaxy_dive.utils.executable_helpers as exec_helpers

In [None]:
import linefinder.utils.file_management as file_management
import linefinder.config as config

In [None]:
import trove

In [None]:
from py2tex import py2tex

In [None]:
import helpers

In [None]:
%matplotlib inline

# Load Data

In [None]:
# Notebook parameters, passed as keyword arguments to trove.link_params_to_config below.
pm = {
    'snum': 600,
    'tables_dir': '/work/03057/zhafen/CoolingTables/',
    'study_duplicates': False,
    'ahf_index': 600,

    # For the fancy scatter plot we're visualizing: per-particle opacity
    # scaled by column density instead of a constant alpha.
    'variable_alpha': True,
}

In [None]:
# Combine the notebook defaults with the project's trove config for this variation.
trove_config_fp = '/home1/03057/zhafen/papers/Hot-Accretion-in-FIRE/analysis/hot_accretion.trove'
pm = trove.link_params_to_config(
    trove_config_fp,
    script_id = 'nb.11',
    variation = 'm12i_md',
    global_variation = '',
    **pm
)

In [None]:
# Used so often it's nice to not enclose
snum = pm['snum']
# Difference between the AHF halo-data index and the snapshot number.
# NOTE(review): `ind` is not used in the visible portion of this notebook.
ind = pm['ahf_index'] - snum

In [None]:
# Load the linefinder Worldlines for this run; the tag and data directories come from pm.
w = a_worldlines.Worldlines(
    tag = pm['tag'],
    data_dir = pm['base_data_dir'],
    halo_data_dir = pm['halo_data_dir'],
    ahf_index = pm['ahf_index'],
)

In [None]:
# Load the halo catalog data (presumably what provides w.m_vir and w.redshift used below).
w.retrieve_halo_data()

In [None]:
# Plot label: halo mass (log10 of w.m_vir, 2 significant figures) and redshift at snum.
# Raw strings keep LaTeX backslashes such as \odot from being parsed as (invalid)
# Python escape sequences; the resulting text is unchanged.
log_m_vir = np.log10( w.m_vir[snum] )
m_plot_label = r'$M_{\rm h} = 10^{' + '{:.02g}'.format( log_m_vir ) + r'} M_\odot$'
plot_label = m_plot_label + ', z={:.02}'.format( w.redshift[snum] )
print( plot_label )

In [None]:
# Base processed-data directory from the config (also used below for tot_momentums.hdf5),
# and the summary file stored there (opened with create_nonexistent=True).
# NOTE(review): default_data is not used in the visible portion of this notebook.
base_processed_data_dir = pm['config_parser'].get( 'DEFAULT', 'processed_data_dir' )
default_data_fp = os.path.join( base_processed_data_dir, 'summary.hdf5' )
default_data = verdict.Dict.from_hdf5( default_data_fp, create_nonexistent=True )

In [None]:
# Summary file in this variation's processed-data directory (opened with create_nonexistent=True).
# NOTE(review): data is not used in the visible portion of this notebook.
data_fp = os.path.join( pm['processed_data_dir'], 'summary.hdf5' )
data = verdict.Dict.from_hdf5( data_fp, create_nonexistent=True )

In [None]:
# Apply the configured matplotlib style sheet, if one is set.
if pm['plt_style'] is not None:
    plt.style.use( pm['plt_style'] )

## Labels

In [None]:
# Name of the central event, e.g. 'tcools_inds' -> 'tcools', and the key of the
# corresponding time-relative-to-event data, e.g. 't_tcools'.
tchange_key = pm['central_indices'].split( '_' )[0]
t_tchange_key = 't_' + tchange_key

In [None]:
# Axis label for the time relative to tchange, built by the project helpers.
# NOTE(review): not used in the visible portion; the plots call helpers.get_tchange_label instead.
t_tchange_label = helpers.get_t_tchange_label( pm )

## Calculate Central Indices

In [None]:
# Per-particle index of the central event (used below as a snapshot column index),
# computed by the Worldlines method named w.calc_<pm['central_indices']>.
# The 'tcools_inds' method additionally receives B = pm['logTcools'].
central_kwargs = dict(
    lookback_time_max = pm['lookback_time_max'],
    choose_first = pm['choose_first'],
)
if pm['central_indices'] == 'tcools_inds':
    central_kwargs['B'] = pm['logTcools']
calc_fn = getattr( w, 'calc_{}'.format( pm['central_indices'] ) )
inds = calc_fn( **central_kwargs )

In [None]:
# Keep only particles whose central index is not config.INT_FILL_VALUE, i.e. those for
# which a central event was found. (Alternative cut used previously: inds > pm['minInd'].)
valid = inds != config.INT_FILL_VALUE
valid_inds = inds[valid]
# Original particle indices of the kept particles (parallel to valid_inds)
particle_inds = np.arange( w.n_particles )[valid]

## Calculate $\vec j$

In [None]:
# Angular momentum vectors J (presumably specific); the loop below indexes them as
# J[:, :, i] -> (3, n_particles) per column i.
# NOTE(review): specific_mom is not used in the visible portion of this notebook.
specific_mom = w.get_data( 'J' )
# Store the magnitude |J| in w.data.
w.data['Jmag'] = w.get_data( 'Jmag' )

In [None]:
# Total angular momentum vector for this variation at snapshot snum, read from the
# tot_momentums.hdf5 file in the base processed-data directory.
tot_momentum_fp = os.path.join( base_processed_data_dir, 'tot_momentums.hdf5' )
tot_momentums = verdict.Dict.from_hdf5( tot_momentum_fp )
snapshot_key = 'snum{:03d}'.format( snum )
tot_ang_momentum = tot_momentums[pm['variation']][snapshot_key]

In [None]:
# Unit vector along the total angular momentum. calc_ang_momentum_alignment is called for
# its effect on w (the return value is discarded); presumably it provides the 'Jz' data
# read in the on-sky loop below.
tot_ang_momentum_normed = tot_ang_momentum / np.linalg.norm( tot_ang_momentum )
_ = w.calc_ang_momentum_alignment( tot_ang_momentum_normed )

In [None]:
# Setup coordinate system: right-handed orthonormal basis with z_hat along the total
# angular momentum (already a unit vector).
z_hat = tot_ang_momentum_normed
# x_hat is seeded from the y axis. If z_hat is (anti)parallel to y the cross product
# vanishes and normalizing it would give NaNs, so fall back to the x axis.
x_hat = np.cross( z_hat, np.array([ 0., 1., 0. ]) )
if np.isclose( np.linalg.norm( x_hat ), 0. ):
    x_hat = np.cross( z_hat, np.array([ 1., 0., 0. ]) )
x_hat /= np.linalg.norm( x_hat )
y_hat = np.cross( z_hat, x_hat )

In [None]:
# Calculate on-sky coordinates: decompose each particle's J into the (x_hat, y_hat, z_hat)
# frame and convert it to a polar angle ThetaJ (measured from z_hat) and an azimuth PhiJ.
# Work one column of Jmag (presumably one snapshot) at a time.
# The source arrays are fetched once; nothing in the loop modifies them.
j_all = w.get_data( 'J' )
jz_all = w.get_data( 'Jz' )
jmag_all = w.get_data( 'Jmag' )

theta = []
phi = []
jx = []
jy = []
jz = []
for i, jmag_i in enumerate( tqdm.tqdm( jmag_all.transpose() ) ):

    # Calculate angular momentum components
    j_i = j_all[:,:,i].transpose()
    jz_i = jz_all[:,i]
    # Component of J perpendicular to z_hat, projected onto x_hat and y_hat
    jperp_i = j_i - z_hat * jz_i[:,np.newaxis]
    jx_i = np.dot( x_hat, jperp_i.transpose() )
    jy_i = np.dot( y_hat, jperp_i.transpose() )

    # Theta and phi
    theta_i = np.arccos( jz_i / jmag_i )
    phi_i = np.arctan2( jy_i, jx_i )

    jx.append( jx_i )
    jy.append( jy_i )
    jz.append( jz_i )
    theta.append( theta_i )
    phi.append( phi_i )

# Store with the same axis order as Jmag
w.data['Jx'] = np.array( jx ).transpose()
w.data['Jy'] = np.array( jy ).transpose()
w.data['Jz'] = np.array( jz ).transpose()
w.data['ThetaJ'] = np.array( theta ).transpose()
w.data['PhiJ'] = np.array( phi ).transpose()

# Plot

## Settings

In [None]:
def get_cmap_and_norm( data_key ):
    '''Return the colormap and normalization used to color particles by data_key.

    Args:
        data_key (str): 'T' (temperature; log-normalized over 1e4 - 1e6) or
            'Jz/Jmag' (alignment; linearly normalized over [-1, 1]).

    Returns:
        tuple: (LinearSegmentedColormap, Normalize) for that quantity.

    Raises:
        ValueError: If data_key has no colormap defined.
    '''

    if data_key == 'T':
        # Diverging Berlin endpoints with white at the middle of the log range (1e5)
        cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
            'berlin_white',
            [
                palettable.scientific.diverging.Berlin_3.mpl_colors[0],
                'w',
                palettable.scientific.diverging.Berlin_3.mpl_colors[-1],
            ],
        )
        norm = plt_colors.LogNorm( vmin=1e4, vmax=1e6 )
    elif data_key == 'Jz/Jmag':
        # Diverging Tropic colors with white at zero alignment
        cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
            'tofino_white',
            [
                palettable.cartocolors.diverging.Tropic_7.mpl_colors[-2],
                'w',
                palettable.cartocolors.diverging.Tropic_7.mpl_colors[1],
            ],
        )
        norm = plt_colors.Normalize( vmin=-1, vmax=1 )
    else:
        # Previously fell through to an UnboundLocalError on `cmap`
        raise ValueError( 'No colormap defined for data_key {!r}'.format( data_key ) )

    return cmap, norm

In [None]:
# Default colormap/norm (temperature); the frame loop below re-selects them per data_key.
cmap, norm = get_cmap_and_norm( 'T' )

In [None]:
# Setup time steps: movie times relative to tchange (labeled in Myr as t * 1e3 below),
# from -1.5 to +0.5 in steps of delta_t, with 0 included exactly once.
delta_t = 0.01
t_movie_after = np.arange( 0., 0.5 + delta_t, delta_t )
# Negative times in ascending order, dropping 0 (already the first entry of t_movie_after)
t_movie_before = -np.arange( 0., 1.5 + delta_t, delta_t )[:0:-1]
t_movie = np.hstack( ( t_movie_before, t_movie_after ) )

In [None]:
class ContourCalc( object ):
    '''Convert enclosed-fraction percentiles of a histogram into contour levels.

    Finite values are sorted from highest to lowest and cumulatively summed, so the level
    for q percent is (by linear interpolation) the value at which the highest-valued bins
    account for q percent of the total.
    '''

    def __init__( self, arr ):
        '''
        Args:
            arr (array-like): Histogram values; NaN and infinite entries are ignored.

        Raises:
            ValueError: If fewer than two finite values remain, or they sum to zero.
        '''
        # Work in float so the in-place normalization below also works for integer input.
        arr = np.asarray( arr, dtype=float )
        values = arr[np.isfinite( arr )]
        if values.size < 2:
            raise ValueError( 'ContourCalc needs at least two finite values.' )

        # Descending order, so the running sum tracks the fraction enclosed so far.
        self.values_sorted = np.sort( values )[::-1]

        self.values_fraction = np.cumsum( self.values_sorted )
        total = self.values_fraction[-1]
        if total == 0.:
            raise ValueError( 'ContourCalc needs values with a nonzero sum.' )
        self.values_fraction /= total

        self.interp_fn = scipy.interpolate.interp1d( self.values_fraction, self.values_sorted )

    def get_level( self, q, f_min_is_average=True ):
        '''Return the histogram value(s) enclosing percentile(s) q of the total.

        Args:
            q (float or array-like): Enclosed percentage(s), 0 - 100.
            f_min_is_average (bool): Fractions below the interpolation range are clamped
                to the midpoint of the first two cumulative fractions if True, otherwise
                to the first cumulative fraction.

        Returns:
            np.ndarray: Interpolated level(s), with the same shape as q.

        Raises:
            ValueError: From the interpolator if q exceeds 100.
        '''
        f = np.asarray( q, dtype=float ) / 100.

        if f_min_is_average:
            f_min = 0.5 * ( self.values_fraction[0] + self.values_fraction[1] )
        else:
            f_min = self.values_fraction[0]

        # interp1d raises for fractions below its first x value, so clamp from below.
        f = np.maximum( f, f_min )

        return self.interp_fn( f )

## Get Time Offset
Find a time offset for each worldline, to account for the change (tchange) occurring between snapshots rather than exactly at one.

In [None]:
# Mask Data: reset the masker and apply the base masks read by get_mask() below.
w.data_masker.clear_masks()
w.data_masker.mask_data( 'n_out', -1, 1 ) # Only include particles that have never left the main galaxy
# Particle-type selection with data_value=0 (presumably gas) — TODO confirm whether this keeps or removes PType 0.
w.data_masker.mask_data( 'PType', data_value=0 )

In [None]:
# Find particles that are valid near tchange, i.e. usable for interpolating between the
# tchange snapshot (column valid_inds) and the snapshot after it (column valid_inds - 1).
mask_overall = w.data_masker.get_mask()
mask_at_tchange = mask_overall[particle_inds,valid_inds]
mask_after_tchange = mask_overall[particle_inds,valid_inds-1]
# A particle is excluded only if it is masked at BOTH snapshots.
# NOTE(review): if both snapshots must be unmasked, this should be np.logical_or.
masked_near_tchange = np.logical_and( mask_at_tchange, mask_after_tchange )
valid_near_tchange = np.invert( masked_near_tchange )
# With valid_inds == 0 there is no "after" snapshot: valid_inds - 1 == -1 silently wraps
# around to the last column, so exclude those particles explicitly.
valid_near_tchange = valid_near_tchange & ( valid_inds > 0 )

In [None]:
# Temperatures at the tchange snapshot and the one after it, as log10, one row per
# particle that is valid near tchange; these bracket the threshold crossing.
T_all = w.get_data( 'T' )
T_at_tchange = T_all[particle_inds,valid_inds][valid_near_tchange]
T_after_tchange = T_all[particle_inds,valid_inds-1][valid_near_tchange]
logT_interp = np.log10( np.column_stack( ( T_at_tchange, T_after_tchange ) ) )

In [None]:
# Times relative to tchange at the same two snapshots, paired row-by-row with logT_interp.
t_tchange_all = w.get_data( t_tchange_key )
t_tchange_at_tchange = t_tchange_all[particle_inds,valid_inds][valid_near_tchange]
t_tchange_after_tchange = t_tchange_all[particle_inds,valid_inds-1][valid_near_tchange]
t_tchange_interp = np.column_stack( ( t_tchange_at_tchange, t_tchange_after_tchange ) )
# Full time arrays for every particle with a central index
t_tchange = t_tchange_all[particle_inds]

In [None]:
if pm['central_indices'] == 'tcools_inds':
    # Get the offset: for each particle, linearly interpolate time against log T between
    # the two bracketing snapshots to find when log T equals pm['logTcools'].
    time_offsets = []
    n_wrong = 0
    for i, t_interp_i in enumerate( tqdm.tqdm( t_tchange_interp ) ):
        interp_fn = scipy.interpolate.interp1d( logT_interp[i], t_interp_i )
        try:
            time_offsets.append( interp_fn( pm['logTcools'] ) )
        except ValueError:
            # logTcools is outside the range of the two bracketing temperatures
            time_offsets.append( np.nan )
            n_wrong += 1

    time_offsets = np.array( time_offsets, dtype=float )
    n_total = t_tchange_interp.shape[0]
    print( n_wrong, n_wrong / n_total if n_total > 0 else np.nan )

    # Particles not valid near tchange get a NaN offset
    time_offsets_all = np.full( particle_inds.size, np.nan )
    time_offsets_all[valid_near_tchange] = time_offsets

    t_tchange_corrected = t_tchange - time_offsets_all[:,np.newaxis]

    # Diagnostic: show log T around tchange for the first particle whose offset could not
    # be found (cases where tchange appears to be calculated wrong).
    nan_offsets = np.flatnonzero( np.isnan( time_offsets ) )
    if nan_offsets.size > 0:
        weird_original_ind = particle_inds[valid_near_tchange][nan_offsets[0]]
        weird_tchange_ind = inds[weird_original_ind]
        print( np.log10( w.get_data( 'T' )[weird_original_ind] )[max( weird_tchange_ind - 2, 0 ):weird_tchange_ind + 1] )
else:
    t_tchange_corrected = t_tchange

## Get Data

In [None]:
# Reset the masker and re-apply the two base masks. The frame-data loop below assumes
# these are the only masks present (it deletes masks[2] at the start of each frame).
w.data_masker.clear_masks()
w.data_masker.mask_data( 'n_out', -1, 1 ) # Only include particles that have never left the main galaxy
# Particle-type selection with data_value=0 (presumably gas) — TODO confirm semantics.
w.data_masker.mask_data( 'PType', data_value=0 )

In [None]:
# For every movie time, pick each particle's snapshot whose offset-corrected time relative
# to tchange is closest to that time, and collect the unmasked values into one dict per frame.
ds = []
for i, t_frame in enumerate( tqdm.tqdm( t_movie ) ):

    # Per particle, the snapshot index closest to t_frame.
    # NOTE(review): np.argmin returns the first NaN's index if a row contains NaN
    # (e.g. particles with a NaN time offset, which then land on index 0).
    t_frame_inds = np.argmin( np.abs( t_tchange_corrected - t_frame ), axis=1 )

    # Prepare to mask: drop the previous frame's out-of-bounds mask (masks[2]),
    # leaving the two base masks.
    if len( w.data_masker.masks ) == 3:
        del w.data_masker.masks[2]

    # Mask data that's out of bounds: a selected index of 0 is treated as out of bounds.
    # The per-particle flag is broadcast to every snapshot to form a full-shape mask.
    out_of_bounds = t_frame_inds <= 0
    out_of_bounds_full = np.zeros( w.n_particles ).astype( 'bool' )
    out_of_bounds_full[particle_inds] = out_of_bounds
    out_of_bounds_mask = np.tile( out_of_bounds_full, ( w.n_snaps, 1 ) ).transpose()
    w.data_masker.mask_data( 'out_of_bounds', custom_mask=out_of_bounds_mask,  )

    # Base parameters: value at each particle's selected snapshot, masked entries dropped
    d = {}
    for key in [ 'Rx', 'Ry', 'Rz', 'M', 'Den', 'T', 'Jz/Jmag', 'ThetaJ', 'PhiJ', 'Jmag', 'Jx', 'Jy', 'Jz' ]:
        d[key] = w.get_selected_data( key, compress=False )[particle_inds,t_frame_inds].compressed()

    # Volume and smoothing length: rho = Den * m_p (Den presumably a number density in
    # cm^-3), V = M / rho in kpc^3, and h is the radius of a sphere of volume V.
    den_msunkpc3 = ( d['Den']*unyt.mp/unyt.cm**3  ).to( 'Msun/kpc**3' )
    d['Vol'] = d['M'] * unyt.Msun / den_msunkpc3
    d['h'] = ( 3. * d['Vol'] / 4. / np.pi )**(1./3.)

    ds.append( d )

## Make Frames

In [None]:
# Frame and scalebar size objects
def round_down(num, divisor):
    '''Round num down to a multiple of divisor.

    Uses Python's modulo, so negative inputs round toward -inf
    (e.g. round_down(-3, 10) == -10).
    '''
    remainder = num % divisor
    return num - remainder

In [None]:
# Half-width of the square movie frames (same units as Rx): the largest 95th percentile
# of |Rx| over frames with lim_t_min <= t_movie <= lim_t_max. Frames without particles
# are skipped and np.nanmax ignores all-NaN frames, so one empty frame can no longer
# turn the axis limits into NaN.
lim_t_min = -0.5
lim_t_max = 0.2
frame_lims = [
    np.nanpercentile( np.abs( ds[i]['Rx'] ), 95 )
    for i in np.flatnonzero( ( t_movie >= lim_t_min ) & ( t_movie <= lim_t_max ) )
    if len( ds[i]['Rx'] ) > 0
]
lim = np.nanmax( frame_lims )

In [None]:
# Render movie frames: a projected (Rx, Rz) scatter of the selected particles at each
# time in t_movie, once colored by alignment (Jz/Jmag) and once by temperature (T).

# Frame indices to render. Rendering every frame is slow, so this is limited to one
# frame while debugging; set to None to render the full movie.
frames_to_render = [ 150, ]

for data_key in [ 'Jz/Jmag', 'T' ]:

    cmap, norm = get_cmap_and_norm( data_key )

    # Counter for the separately saved "focused" frames (|t_movie| <= 0.150)
    i_focused = 0

    for i, d in enumerate( tqdm.tqdm( ds ) ):

        if frames_to_render is not None and i not in frames_to_render:
            continue

        fig = plt.figure( figsize=(12,10), facecolor='k' )
        ax = plt.gca()

        # Point size
        if len( d['M'] ) > 0:
            # Marker radius follows each particle's smoothing length h, converted from
            # data units using the axes width in pixels.
            # NOTE(review): the pixel -> point factor is 72/dpi, not dpi/72; it is kept
            # as is because `scale` was presumably tuned against this factor.
            width_in_data = 2 * lim
            width_in_pixels = ax.get_window_extent().width
            pixels_to_points = fig.dpi / 72.
            scale = 10.
            radius = d['h'] * ( width_in_pixels / width_in_data ) * pixels_to_points * scale
            s = ( radius )**2.

            # Colors
            colors = cmap( norm( d[data_key] ) )

            # Alpha
            if pm['variable_alpha']:
                # Opacity grows with log column density M / h^2; the 0.065 and
                # 50000 / w.n_particles factors are hard-coded scalings.
                column_den = d['M'] / d['h']**2.
                alpha = plt_colors.LogNorm( vmin=np.nanmin( column_den ), vmax=np.nanmax( column_den ) )( column_den ) * 0.065 * ( 50000 / w.n_particles)
                alpha[alpha>1.] = 1.
                alpha[alpha<0.] = 0.
                colors[:,3] = alpha
            else:
                colors[:,3] = 0.01

            # Plot itself
            ax.scatter(
                d['Rx'],
                d['Rz'],
                s = s,
                c = colors,
                edgecolors = 'none',
            )

        # Scale bar: at most 30, rounded down to a multiple of 10 (minimum 10)
        size = round_down( min( 30., 0.95 * lim ), 10 )
        if np.isclose( size, 0. ):
            size = 10.
        xy = ( -size, -0.95*lim )
        line = ax.plot(
            xy[0] + np.array([ 0., size ]),
            [ xy[1], xy[1] ],
            linewidth = 10,
            color = 'w',
            path_effects = [
                path_effects.Stroke(linewidth=12, foreground='black'),
                path_effects.Normal()
            ]
        )
        text = ax.annotate(
            text = '{:.2g} kpc'.format( size ),
            xy = xy,
            xycoords = 'data',
            xytext = ( 5, 10 ),
            textcoords = 'offset points',
            va = 'bottom',
            ha = 'left',
            color = 'w',
            fontsize = 42,
        )
        text.set_path_effects([
            path_effects.Stroke(linewidth=2.5, foreground='black'),
            path_effects.Normal()
        ])

        # Plot label: time relative to tchange in Myr. Round rather than truncate so
        # floating-point error in t_movie cannot shift the label by 1 Myr.
        annotate_label = (
            r'$t - ' +
            helpers.get_tchange_label( pm )[1:] +
            '= {} Myrs'.format( int( np.round( t_movie[i]*1e3 ) ) )
        )
        text = ax.annotate(
            text = annotate_label,
            xy = ( 1, 1 ),
            xycoords = 'axes fraction',
            xytext = ( -10, -10 ),
            textcoords = 'offset points',
            va = 'top',
            ha = 'right',
            color = 'w',
            fontsize = 42,
        )
        text.set_path_effects([
            path_effects.Stroke(linewidth=2.5, foreground='black'),
            path_effects.Normal()
        ])

        # Temperature: median log T of the particles in the frame
        if len( d['T'] ) > 0:
            temp_label = r'$\langle \log (T/{\rm K}) \rangle =' + py2tex.to_tex_scientific_notation( np.nanmedian( np.log10( d['T'] ) ), sig_figs=2 ) + r'$'
            text = ax.annotate(
                text = temp_label,
                xy = ( 1, 0 ),
                xycoords = 'axes fraction',
                xytext = ( -10, 10 ),
                textcoords = 'offset points',
                va = 'bottom',
                ha = 'right',
                color = 'w',
                fontsize = 34,
            )
            text.set_path_effects([
                path_effects.Stroke(linewidth=2, foreground='black'),
                path_effects.Normal()
            ])

        # Limits
        ax.set_xlim( -lim, lim )
        ax.set_ylim( -lim, lim )
        ax.set_aspect( 'equal' )

        # Ticks
        plt.tick_params(
            which = 'both',
            left = False,
            labelleft = False,
            bottom = False,
            labelbottom = False,
        )

        # Change colors
        ax.set_facecolor( 'k' )
        plt.setp( ax.spines.values(), color='w' )
        [m.set_linewidth(3) for m in ax.spines.values()]

        save_file_tag = {
            'T': '',
            'Jz/Jmag': '_alignment',
        }[data_key]
        plotting.save_fig(
            out_dir = os.path.join( pm['data_dir'], 'projected_frames' ),
            save_file = '{}{}_{:0>3d}.png'.format( pm['variation'], save_file_tag, i ),
            fig = fig,
            resolution = 150.,
        )

        # Save movie focused on time of cooling or accreting
        if np.abs( t_movie[i] ) <= 0.150:
            plotting.save_fig(
                out_dir = os.path.join( pm['data_dir'], 'projected_frames' ),
                save_file = '{}{}_focused_{:0>3d}.png'.format( pm['variation'], save_file_tag, i_focused ),
                fig = fig,
                resolution = 150.,
            )
            i_focused += 1

        # Close this specific figure (not just whichever is current) to free memory
        plt.close( fig )

## On-Sky Projection

In [None]:
def angular_momentum_projection( ax, d, img_proj, vmin=None, vmax=None, n_bins=50, img=True, contour=False, q_levels=[ 50, ], colors=[ 'k', ], upsample=None, smooth=None, outline=False, outline_linewidth=5.5, norm=None ):
    '''Histogram angular momentum directions on the sky and optionally draw them.

    Directions are binned uniformly in azimuth PhiJ (over [-pi, pi]) and in cos(ThetaJ)
    (over [-1, 1]), i.e. in equal solid-angle bins, weighted by mass. They are drawn at
    longitude = phi and latitude = 90 deg - theta, both in degrees.

    Args:
        ax: Axes to draw on (only touched when img or contour is True).
        d (dict): Frame data with 'PhiJ', 'ThetaJ' (radians) and 'M' arrays.
        img_proj: Transform (cartopy CRS) for the degree coordinates.
        vmin, vmax: Color limits for the image (leave None when passing norm).
        n_bins (int): Number of bin edges per axis (n_bins - 1 bins).
        img (bool): Draw the histogram with pcolormesh.
        contour (bool): Draw contours of the histogram.
        q_levels: Enclosed-mass percentiles turned into levels via ContourCalc;
            None lets matplotlib choose the levels.
        colors: Contour colors.
        upsample (int): Zoom factor applied to the histogram before smoothing.
        smooth (float): Gaussian sigma in original bins (multiplied by upsample).
        outline (bool): Add a black stroke around the contour lines. With matplotlib
            >= 3.8 this applies to all levels; older versions stroke the first level.
        outline_linewidth (float): Width of that stroke.
        norm: Color normalization for the image.

    Returns:
        np.ndarray: The (possibly upsampled and smoothed) histogram, indexed [phi, cos theta].
    '''

    ra_edges = np.linspace( -np.pi, np.pi, n_bins )
    cosdec_edges = np.linspace( -1, 1, n_bins )

    # Make the histogram
    hist2d, ra_edges, cosdec_edges = np.histogram2d(
        d['PhiJ'],
        np.cos( d['ThetaJ'] ),
        bins = [ ra_edges, cosdec_edges ],
        weights = d['M'],
    )

    # Upsample and smooth
    if upsample is not None:
        hist2d = scipy.ndimage.zoom( hist2d, upsample )
        n_edges = ( n_bins - 1 ) * upsample + 1
        ra_edges = np.linspace( -np.pi, np.pi, n_edges )
        cosdec_edges = np.linspace( -1, 1, n_edges )
    if smooth is not None:
        # smooth is given in original bins, so rescale it to the upsampled grid
        sigma = smooth if upsample is None else upsample * smooth
        # scipy.ndimage.filters is a deprecated alias namespace
        hist2d = scipy.ndimage.gaussian_filter( hist2d, sigma )

    # Get centers in degrees: longitude = phi, latitude = 90 deg - theta
    ra_centers = 0.5 * ( ra_edges[1:] + ra_edges[:-1] ) * 180. / np.pi
    dec_edges = np.pi / 2. - np.arccos( cosdec_edges )
    dec_edges *= 180. / np.pi
    dec_centers = 0.5 * ( dec_edges[1:] + dec_edges[:-1] )

    # Plot it
    if img:
        ax.pcolormesh(
            ra_centers,
            dec_centers,
            hist2d.transpose(),
            transform = img_proj,
            cmap = matplotlib.cm.cubehelix_r,
            shading = 'nearest',
            vmin = vmin,
            vmax = vmax,
            norm = norm,
        )
    if contour:

        if q_levels is not None:
            levels = ContourCalc( hist2d ).get_level( q_levels )
        else:
            levels = None

        contour_set = ax.contour(
            ra_centers,
            dec_centers,
            hist2d.transpose(),
            transform = img_proj,
            colors = colors,
            levels = levels,
            linewidths = 3,
        )

        if outline:
            outline_effects = [
                path_effects.Stroke(linewidth=outline_linewidth, foreground='k'),
                path_effects.Normal()
            ]
            if hasattr( contour_set, 'set_path_effects' ):
                # matplotlib >= 3.8: ContourSet is itself a Collection, and
                # ContourSet.collections is deprecated (removed in 3.10).
                contour_set.set_path_effects( outline_effects )
            else:
                contour_set.collections[0].set_path_effects( outline_effects )

    return hist2d

In [None]:
def plot_total_angular_momentum( ax, d, img_proj, color='k', edgecolor='k', linewidth=7, marker='x', s=100, zorder=100, **kwargs ):
    '''Mark the direction of the mass-weighted mean J vector on an on-sky plot.

    Args:
        ax: Axes to draw on.
        d (dict): Frame data with 'Jx', 'Jy', 'Jz' components and masses 'M'.
        img_proj: Transform (e.g. a cartopy CRS) for the degree coordinates.
        color, edgecolor, linewidth, marker, s, zorder, **kwargs: Passed to ax.scatter.

    Returns:
        The artist returned by ax.scatter (previously discarded).
    '''

    # Mass-weighted mean of the J vectors, normalized to a unit vector
    jvec = np.array([
        d['Jx'],
        d['Jy'],
        d['Jz'],
    ]).transpose()
    jmean = ( jvec * d['M'][:,np.newaxis] ).sum( axis=0 ) / d['M'].sum()
    jmean /= np.linalg.norm( jmean )

    # Same on-sky coordinates as the histogram: longitude = phi, latitude = 90 deg - theta
    phi_mean = np.arctan2( jmean[1], jmean[0] )
    theta_mean = np.arccos( jmean[2] )
    ra_mean = phi_mean * 180. / np.pi
    dec_mean = ( np.pi / 2. - theta_mean ) * 180. / np.pi

    return ax.scatter(
        [ ra_mean, ],
        [ dec_mean, ],
        transform = img_proj,
        zorder = zorder,
        color = color,
        marker = marker,
        s = s,
        edgecolor = edgecolor,
        linewidth = linewidth,
        **kwargs
    )

In [None]:
def onsky_plot_auxilliaries( i, d, proj, img_proj, ax=None, label=True, label_color='k', label_outline_color='w', label_outline_linewidth=3.5 ):
    '''Add gridlines, gridline labels, and a pole marker to an on-sky angular momentum map.

    Meridians are drawn every 45 deg in phi and parallels at jz/|j| = -1, -0.5, 0, 0.5, 1.
    Parallels are labeled with jz/|j| at phi = 180 deg and with theta at phi = 0 deg.

    Args:
        i, d: Frame index and frame data. NOTE(review): currently unused.
        proj: cartopy projection used to create new axes when ax is None.
        img_proj: CRS in which the gridline coordinates (degrees) are given.
        ax: Axes to decorate; a new 12x12 figure and axes are created if None.
        label, label_color, label_outline_color, label_outline_linewidth:
            NOTE(review): currently unused; labels are always drawn, so label=False
            has no effect.
    '''

    if ax is None:
        plt.figure( figsize=(12,12), facecolor='w' )
        ax = plt.axes( projection=proj )

    # Gridlines
    ra_gridlines = np.arange( -1., 1., 0.25 ) * 180.
    jzj_gridlines = np.arange(-1., 1.01, 0.5 )
    dec_gridlines = ( np.pi/2. - np.arccos( jzj_gridlines ) ) * 180. / np.pi
    # Configure the Gridliner returned by this call. Indexing ax._gridliners[0] would
    # pick the first gridliner ever added to ax, not necessarily this one.
    # Many interpolation steps make the curved gridlines render smoothly.
    gl = ax.gridlines( crs=img_proj, color='k', xlocs=ra_gridlines, ylocs=dec_gridlines, )
    gl.n_steps = 10000

    # Transform for placing labels at (longitude, latitude) positions in degrees
    label_transform = img_proj._as_mpl_transform( ax )

    # Gridline labels
    # DEC
    for j, dec_gridline in enumerate( dec_gridlines ):

        # jz/j labels; the first (jz/|j| = -1) also names the quantity
        label_text = '{:.1f}'.format( jzj_gridlines[j] )
        if j == 0:
            label_text = r'$j_z / \vert j \vert$ = ' + label_text
            va, yoffset, ha, xoffset = 'bottom', 5, 'center', 3
        else:
            va, yoffset, ha, xoffset = 'top', -5, 'left', 5
        annotation = ax.annotate(
            text = label_text,
            xy = ( 180., dec_gridline ),
            xycoords = label_transform,
            xytext = ( xoffset, yoffset ),
            textcoords = 'offset points',
            fontsize = 18,
            va = va,
            ha = ha,
        )
        annotation.set_path_effects([
            path_effects.Stroke(linewidth=3, foreground='w'),
            path_effects.Normal()
        ])

        # theta labels; none at jz/|j| = 1, and the first also names the quantity
        theta = np.arccos( jzj_gridlines[j] ) * 180. / np.pi
        label_text = '{:.0f}'.format( theta ) + r'$\degree$'
        if j == 0:
            label_text = r'$\theta$ = ' + label_text
            va, yoffset, ha, xoffset = 'top', -5, 'center', 13
        elif j == len( jzj_gridlines ) - 1:
            continue
        else:
            va, yoffset, ha, xoffset = 'bottom', 5, 'left', 5
        ax.annotate(
            text = label_text,
            xy = ( 0., dec_gridline ),
            xycoords = label_transform,
            xytext = ( xoffset, yoffset ),
            textcoords = 'offset points',
            fontsize = 18,
            va = va,
            ha = ha,
        )

    # RA: phi labels in [0, 360) degrees, placed at latitude -0.65 * 90 deg
    for j, ra_gridline in enumerate( ra_gridlines ):
        ra_label = ra_gridline + 360. if ra_gridline < 0. else ra_gridline
        label_text = '{:.0f}'.format( ra_label ) + r'$\degree$'
        if np.isclose( ra_gridline, 0. ):
            label_text = r'$\phi$ = ' + label_text
        # Hand-chosen alignment per meridian position around the projection
        ha, xoffset = ( 'left', 5 ) if j in [ 1, 5, 6 ] else ( 'right', -5 )
        va, yoffset = ( 'center', 0 ) if j in [ 0, 4 ] else ( 'bottom', 5 )
        ax.annotate(
            text = label_text,
            xy = ( ra_gridline, -0.65 * 90. ),
            xycoords = label_transform,
            xytext = ( xoffset, yoffset ),
            textcoords = 'offset points',
            fontsize = 18,
            ha = ha,
            va = va,
        )

    # Fill in the empty spot in the center (latitude 90 deg)
    ax.scatter(
        [ 0., ],
        [ 90. ],
        color = 'k',
        transform = img_proj,
    )


In [None]:
def angular_momentum_projection_plot( i, proj_str='AzimuthalEquidistant', img_proj='rotated', ax=None, vmin=None, vmax=None, n_bins=50, norm=None, ):
    '''Plot the on-sky angular momentum histogram for frame i of ds, with gridlines.

    Args:
        i (int): Index into the module-level frame list ds.
        proj_str (str): Name of the cartopy.crs projection used for new axes.
        img_proj (str): 'rotated' selects RotatedPole(pole_longitude=0, pole_latitude=0)
            as the data CRS; any other value selects PlateCarree.
        ax: Axes to draw on; a new 12x12 figure with `proj` axes is created if None.
        vmin, vmax, n_bins, norm: Passed to angular_momentum_projection.

    Returns:
        np.ndarray: The histogram from angular_momentum_projection.
    '''

    d = ds[i]

    # Set up coordinates
    proj = getattr( ccrs, proj_str )()
    if img_proj == 'rotated':
        img_proj = ccrs.RotatedPole(pole_longitude=0., pole_latitude=0.)
    else:
        img_proj = ccrs.PlateCarree()

    # The projection is drawn directly on ax, so the axes must exist first.
    # Previously ax=None crashed inside angular_momentum_projection.
    if ax is None:
        plt.figure( figsize=(12,12), facecolor='w' )
        ax = plt.axes( projection=proj )

    # Data itself
    hist2d = angular_momentum_projection( ax, d, img_proj, vmin, vmax, n_bins, norm=norm )

    onsky_plot_auxilliaries( i, d, proj, img_proj, ax )

    return hist2d
    

In [None]:
# Set up coordinates: azimuthal-equidistant map projection for the axes, and a
# rotated-pole CRS in which the (phi, 90 deg - theta) data coordinates are given.
proj = ccrs.AzimuthalEquidistant()
img_proj = ccrs.RotatedPole(pole_longitude=0., pole_latitude=0.)

### Multipanel

In [None]:
# For colormap: symmetric limits [-1, 1].
# NOTE(review): redefined in the Contour section before they are used.
z_min, z_max = -1., 1.
z_width = z_max - z_min

In [None]:
# Multipanel figure: on-sky maps of the angular momentum direction at six frames, with
# the mass-weighted mean direction marked in each panel.
fig, axs = plt.subplots(
    nrows = 3,
    ncols = 2,
    subplot_kw = { 'projection': ccrs.AzimuthalEquidistant() },
    figsize = (15.75, 22.75),
    facecolor = 'w',
)
fig.tight_layout()

# Shared log color-scale limits, set by hand (the following cells report the smallest
# positive and largest histogram values as a guide).
vmin = 7000
vmax = 4.7e6
# Indices into t_movie / ds shown in the panels, in row-major order
frame_inds = [ 50, 100, 135, 145, 150, 165 ]
axs_list = list( axs.flat )
hist2ds = []
for j, i in enumerate( tqdm.tqdm( frame_inds ) ):

    ax = axs_list[j]

    hist2d_i = angular_momentum_projection( ax, ds[i], img_proj=img_proj, n_bins=64, norm=matplotlib.colors.LogNorm( vmin=vmin, vmax=vmax ) )
    hist2ds.append( hist2d_i )

    onsky_plot_auxilliaries( i, ds[i], proj, img_proj, ax, )

    # Plot label: time relative to tchange in Myr, rounded (not truncated)
    annotate_label = (
        r'$t - ' +
        helpers.get_tchange_label( pm )[1:] +
        '= {} Myrs'.format( int( np.round( t_movie[i]*1e3 ) ) )
    )
    text = ax.annotate(
        text = annotate_label,
        xy = ( 0, 1 ),
        xycoords = 'axes fraction',
        xytext = ( 5, -5 ),
        textcoords = 'offset points',
        va = 'top',
        ha = 'left',
        color = 'k',
        fontsize = 19,
    )
    text.set_path_effects([
        path_effects.Stroke(linewidth=3.5, foreground='w'),
        path_effects.Normal()
    ])

    plot_total_angular_momentum( ax, ds[i], img_proj, edgecolor='w', linewidth=2, marker='o', s=150, )

plotting.save_fig(
    out_dir = os.path.join( pm['figure_dir'], 'on_sky' ),
    save_file = 'angular_momentum.png',
    fig = fig,
    resolution = 150.,
)

In [None]:
# Smallest positive histogram value across the multipanel maps (guide for choosing vmin)
np.nanmin( [ hist2d[hist2d > 0].min() for hist2d in hist2ds ] )

In [None]:
# Largest histogram value across the multipanel maps (guide for choosing vmax)
np.nanmax( hist2ds ) 

### Contour

In [None]:
# For colormap: range of t_movie mapped across the Roma colormap in the contour plot below
z_min, z_max = -0.5, 0.5
z_width = z_max - z_min

In [None]:
# Plot itself: for every interval-th frame with t_min <= t_movie <= t_max, overlay the
# 50%-enclosed-mass contour of the on-sky histogram and the mean J direction, colored by time.
fig = plt.figure( figsize=(16,16), facecolor='w' )
ax = plt.axes( projection=proj )

t_min = -0.5
t_max = 0.2
interval = 1
for i, t_i in enumerate( tqdm.tqdm( t_movie ) ):

    # Skip frames off the sampling interval or outside the time window
    if i % interval != 0 or t_i < t_min or t_i > t_max:
        continue

    d = ds[i]

    # Map the frame time onto [0, 1] over [z_min, z_max] for the colormap
    color_i = palettable.scientific.diverging.Roma_9.mpl_colormap( ( t_i - z_min ) / z_width )

    # Highlighted frames get a black outline and are drawn on top
    outline = i in [ 135, 150, 165 ]
    edgecolor, zorder = ( 'k', 101 ) if outline else ( 'none', 100 )

    hist2d_i = angular_momentum_projection( ax, d, img_proj, img=False, contour=True, q_levels=[ 50, ], colors=[ color_i, ], n_bins=128, upsample=3, smooth=2, outline=outline )
    plot_total_angular_momentum( ax, d, img_proj, color=color_i, marker='o', linewidth=1, edgecolor=edgecolor, s=200, zorder=zorder )

onsky_plot_auxilliaries( i, d, proj, img_proj, ax, label=False )

To refine this plot further, the parameters worth varying are
`n_bins`, `upsample`, `smooth`, `t_min`, `t_max`, `z_min`, and `interval`.