## Plot the shoreline 

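This notebook makes grids of cosmic-shoreline plots. Each panel shows a 2D slice through the 3D shoreline space (escape velocity, bolometric insolation, and stellar luminosity), with the third dimension restricted to one bin of its range.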

In [None]:
%matplotlib widget

In [None]:
from shoreline import * 

In [None]:
# Load the 'all' population subset for atmosphere kind 'any', together with
# the matching numpyro posterior (fit with uncertainties) stored as netCDF.
subset, kind, uncertainties = 'all', 'any', True
pops = load_organized_populations(subset=subset, kind=kind)[f'{subset}-{kind}']
posterior = az.from_netcdf(
    f'posteriors/{subset}-{kind}-uncertainties={uncertainties}-numpyro.nc'
)

In [None]:
def best_and_sampled(posterior, N_samples=100):
    """
    Summarize a posterior as one "best" parameter set plus a thinned set of draws.

    Parameters
    ----------
    posterior : arviz.InferenceData
        Posterior containing the shoreline parameters
        ``log_f_0``, ``p``, ``q``, and ``ln_w``.
    N_samples : int, optional
        Target number of draws to keep (default 100). Rows are taken with a
        fixed stride, so the count kept is approximately, not exactly, this.

    Returns
    -------
    best_parameters
        The ``"median"`` column of ``az.summary(..., stat_focus="median")``,
        i.e. the posterior median of every variable.
    sampled_parameters
        Every ``stride``-th row of the flattened draws of the four shoreline
        parameters.
    """
    summary = az.summary(posterior, kind="all", stat_focus="median")
    best_parameters = summary["median"]

    draws = posterior.to_dataframe(var_names=["log_f_0", "p", "q", "ln_w"])
    # Stride evenly through the table of draws. Clamp to >= 1: with fewer
    # rows than N_samples the old int(len / N_samples) was 0, and a zero
    # slice step raises ValueError.
    stride = max(1, len(draws) // N_samples)
    sampled_parameters = draws[::stride]
    return best_parameters, sampled_parameters

# 1000-point grids in log space for escape velocity (v), insolation (f) and
# stellar luminosity (L); the L grid runs from high to low.
log_v_1d, log_f_1d, log_L_1d = (
    jnp.linspace(start, stop, 1000)
    for start, stop in ((-4, 2), (-5, 5), (1.5, -3.5))
)

In [None]:
from exoatlas import * 

# Scratch cell: a 50x50 coordinate grid covering [-1, 1] x [-1, 1].
# (A bare `plt.contour` attribute reference here did nothing and was dropped.)
x = np.linspace(-1, 1)
x, y = np.meshgrid(x, x)

In [None]:
t = TransitingExoplanets()
# Table of transiting planets larger than 2 Earth radii whose host star is
# less luminous than 1e-3 Lsun (last expression, so the notebook displays it).
t[
    (t.radius() > 2.0 * u.Rearth)
    & (t.stellar_luminosity() < 10 ** (-3) * u.Lsun)
].create_table()

In [None]:
# Names of bodies to label in the shoreline panels. How exoatlas matches these
# against population names is not shown here; several systems appear under
# more than one spelling (e.g. 'LTT 1445Ab' / 'LTT1445Ab'), presumably to
# catch differing catalog names.
planets_to_annotate = [
    # Solar System planets, moons, and dwarf planets
    'Mercury', 'Venus', 'Earth', 'Mars',
    'Jupiter', 'Saturn', 'Uranus', 'Neptune',
    'Moon', 'Titan',
    'Pluto', 'Eris', 'Haumea', 'Makemake', 'Ceres',
    # other planetary systems
    '55 Cnc e', 'L 98-59b', 'K2 141 b', 'LHS 3844 b', 'GJ367b',
    'TOI-1685b', 'GJ1252b', 'GJ486b', 'GJ1132b',
    'TOI-1468 b', 'LHS 1140 c', 'LTT 1445Ab', 'LTT 1445Ac', 'GJ 3929b',
    'LHS 1140b', 'LTT1445Ab', 'TRAPPIST-1',
    'Trappist-1b', 'Trappist-1c', 'LHS 1478 b', 'Kepler-10b', 'Kepler-78b',
    'GJ 1214b', 'K2-18b', 'TOI-700d', 'TOI-700e', 'Kepler-62e', 'Kepler-62f',
    'WD 1856+534 b',
    # 'LP 791-18',
]

In [None]:
def plot_grid_of_shorelines(kind='any', f_lim=(10**-3.5, 10**4.5), v_lim=(10**-1.5, 10**1), L_lim=(10**0.5, 10**-3.5)):
    """
    Plot a 3 x 4 grid of 2D slices through the 3D cosmic-shoreline space.

    Each row uses two of (relative escape velocity, relative insolation,
    stellar luminosity) as its x/y axes. The row's third ("fixed") dimension
    is split into 4 equal-width log10 bins, one per column. Each panel shows:

    - the planets whose fixed-dimension value falls in that bin;
    - a background image of the probability of having an atmosphere,
      averaged over thinned posterior draws;
    - contours of that probability at 0.05, 0.5 and 0.95.

    Parameters
    ----------
    kind : str
        Atmosphere kind. Selects both the population grouping and the
        posterior file ``posteriors/all-{kind}-uncertainties=True-numpyro.nc``.
    f_lim, v_lim, L_lim : sequence of two floats
        Linear (not log) axis limits for bolometric flux relative to Earth,
        escape velocity relative to Earth, and stellar luminosity relative to
        the Sun. Limits may be given high-to-low (as the default ``L_lim`` is).

    Notes
    -----
    Labels are taken from the notebook-global ``planets_to_annotate``.
    Draws into a new figure and saves it to
    ``figures/grid-of-shorelines-{kind}.pdf``, creating ``figures/`` if needed.
    """
    import os

    subset = 'all'
    uncertainties = True
    N_columns = 4

    # Defaults are immutable tuples (no shared mutable default); the
    # plottables still receive fresh lists, exactly as before.
    f_lim, v_lim, L_lim = list(f_lim), list(v_lim), list(L_lim)

    pops = load_organized_populations(subset=subset, kind=kind)[f'{subset}-{kind}']
    posterior = az.from_netcdf(f'posteriors/{subset}-{kind}-uncertainties={uncertainties}-numpyro.nc')

    symbols = dict(relative_escape_velocity=r'$\sf v_{\rm esc}/v_{esc,\oplus}$',
                   relative_insolation=r'$\sf f/f_\oplus$',
                   stellar_luminosity=r'$\sf L_\star/L_\odot$')

    plottables = dict(relative_escape_velocity=RelativeEscapeVelocity(lim=v_lim, kludge=True),
                      relative_insolation=Flux(lim=f_lim, label='Bolometric Flux\n(relative to Earth)'),
                      stellar_luminosity=StellarLuminosity(lim=L_lim))

    # For every dimension build two things:
    # (a) a fine 1000-point log10 grid spanning its limits, used for the
    #     background image;
    # (b) N_columns equal-width log10 bins, used when it is the fixed dimension.
    # abs() keeps lowers < uppers even when limits are given high-to-low.
    log_1d, lowers, uppers, centers = {}, {}, {}, {}
    for dim, plottable in plottables.items():
        log_lower, log_upper = np.log10(plottable.lim)
        log_width = (log_upper - log_lower) / N_columns
        log_centers = log_lower + log_width * (np.arange(N_columns) + 0.5)
        half_width = np.abs(log_width / 2)

        log_1d[dim] = jnp.linspace(log_lower, log_upper, 1000)
        centers[dim] = 10**log_centers
        lowers[dim] = 10**(log_centers - half_width)
        uppers[dim] = 10**(log_centers + half_width)

    _, sampled_parameters = best_and_sampled(posterior)

    # exoatlas quantity name -> keyword expected by probability_of_atmosphere
    exoatlas_to_model = dict(relative_escape_velocity='log_v',
                             relative_insolation='log_f',
                             stellar_luminosity='log_L')
    parameter_names = ("p", "q", "log_f_0", "ln_w")

    def plot_shoreline_slice(x_dim='relative_escape_velocity',
                             y_dim='relative_insolation',
                             fixed_dim='stellar_luminosity',
                             slice_index=0,
                             ax=None, map_class=ErrorMap):
        """
        Plot one slice of the cosmic shoreline into `ax`.

        Shows planets and the posterior-averaged atmosphere probability in
        the (x_dim, y_dim) plane, restricted to bin `slice_index` of
        `fixed_dim`.
        """
        lower = lowers[fixed_dim][slice_index]
        upper = uppers[fixed_dim][slice_index]

        # 2D log10 grid for the background probability image
        log_x_2d, log_y_2d = jnp.meshgrid(log_1d[x_dim], log_1d[y_dim])

        if ax is None:
            ax = plt.gca()
        plt.sca(ax)

        # the map that planets get drawn onto
        m = map_class(
            xaxis=plottables[x_dim],
            yaxis=plottables[y_dim],
            ax=ax,
            size=100,
        )

        # draw every population that has members inside this slice's bin
        for group_name in pops:
            for pop_name in pops[group_name]:
                pop = pops[group_name][pop_name]
                fixed_values = pop.get(fixed_dim).value
                in_bin = (fixed_values > lower) & (fixed_values <= upper)
                if np.any(in_bin):
                    in_slice = pop[in_bin]
                    in_slice.annotate_planets = True
                    in_slice.annotate_kw = dict(adjust=False, ha='left', va='center',
                                                format='   {}', names=planets_to_annotate)
                    m.build(in_slice)

        # Average P(atmosphere) over the thinned posterior draws. Each draw is
        # paired with one fixed-dimension value, spaced evenly across the bin,
        # so the mean also integrates over the bin's width.
        # (this loop could probably be vectorized with jax.vmap)
        N_samples = len(sampled_parameters)
        log_fixed_values = np.linspace(np.log10(lower), np.log10(upper), N_samples)
        samples_of_P_2d = []
        for s in range(N_samples):
            inputs = {x_dim: log_x_2d, y_dim: log_y_2d, fixed_dim: log_fixed_values[s]}
            model_inputs = {exoatlas_to_model[name]: inputs[name] for name in exoatlas_to_model}
            parameters = {name: sampled_parameters[name].values[s] for name in parameter_names}
            samples_of_P_2d.append(probability_of_atmosphere(**model_inputs, **parameters))
        P_2d = np.mean(samples_of_P_2d, axis=0)

        plt.pcolormesh(
            10**log_x_2d,
            10**log_y_2d,
            P_2d,
            cmap=one2another("burlywood", "lightskyblue"),
            alpha=1,
            zorder=-1e9,
            rasterized=True,
        )
        plt.contour(
            10**log_x_2d,
            10**log_y_2d,
            P_2d,
            levels=[0.05, 0.5, 0.95],
            linestyles='--',
            alpha=0.25,
            colors=['gray', 'black', 'gray'],
        )

        plt.xscale("log")
        plt.yscale("log")
        plt.xlim(min(lowers[x_dim]), max(uppers[x_dim]))
        plt.ylim(min(lowers[y_dim]), max(uppers[y_dim]))
        plt.title(f'{np.log10(lower):.2}<log$_{{10}}${symbols[fixed_dim]}<{np.log10(upper):.2}', fontsize=9)

    # Row r cycles the dimensions: x = dims[r], y = dims[r+1],
    # fixed = dims[r+2]; column c is bin c of the fixed dimension.
    fig, ax = plt.subplots(3, N_columns, figsize=(9, 8), constrained_layout=True)
    dims = list(symbols)
    for r in range(3):
        for c in range(N_columns):
            plot_shoreline_slice(
                x_dim=dims[r % 3],
                y_dim=dims[(r + 1) % 3],
                fixed_dim=dims[(r + 2) % 3],
                slice_index=c,
                ax=ax[r, c],
            )
            # only the leftmost column keeps its y label and tick labels
            if c > 0:
                plt.ylabel('')
                plt.setp(ax[r, c].get_yticklabels(), visible=False)

    # label panels (a), (b), ... in reading order
    letters = "abcdefghijklmnopqrstuvwxyz"
    for letter, panel in zip(letters, ax.flatten()):
        panel.text(x=0.02, y=0.98, s=f"({letter})", transform=panel.transAxes,
                   va="top", ha="left", fontweight='bold')

    # savefig does not create missing directories
    os.makedirs('figures', exist_ok=True)
    plt.savefig(f'figures/grid-of-shorelines-{kind}.pdf')

In [None]:
# Default grid: kind='any' with the function's default axis limits; writes
# figures/grid-of-shorelines-any.pdf.
plot_grid_of_shorelines()

In [None]:
# Equilibrium-temperature range that sets the insolation axis of the CO2 grid.
# NOTE(review): the origin of 194 K and 1673 K is not stated here — confirm.
teq_limits = [194, 1673]*u.K
# Insolation corresponding to each T_eq via f = 4 * sigma * T_eq**4
# (presumably the zero-albedo, full-redistribution convention — confirm).
flux_limits = 4*con.sigma_sb*teq_limits**4
# the ~1360 W/m^2 that Earth receives from the Sun
earth_insolation = (1 * u.Lsun / 4 / np.pi / u.AU**2).to(u.W / u.m**2)

# Convert the ratio explicitly to a plain number. A bare `.value` would
# silently be off by a scale factor if astropy left a scaled composite unit.
f_lim = list((flux_limits/earth_insolation).to_value(u.dimensionless_unscaled))
plot_grid_of_shorelines(kind='CO2', f_lim=f_lim, v_lim=[0.1, 4])