In [1]:
import os

import numpy as np

import matplotlib as mpl
import matplotlib.patheffects as pe
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter

import corner

import utils

In [2]:
# Load the A and B fiducial J-band retrievals, including posterior samples.
# NOTE(review): both retrievals are loaded with m_set='J1226_A'; confirm that
# this is intended for the B retrieval and not a copy-paste leftover.
_retrieval_prefixes = (
    '../retrieval_outputs/fiducial_J_A_ret_15_1column_n1000/test_',
    '../retrieval_outputs/fiducial_J_B_ret_51_1column_n1000/test_',
)
Res_A, Res_B = [
    utils.RetrievalResults(
        prefix=_prefix, m_set='J1226_A', w_set='J1226', load_posterior=True
    )
    for _prefix in _retrieval_prefixes
]

  analysing data from ../retrieval_outputs/fiducial_J_A_ret_15_1column_n1000/test_.txt
['modes', 'nested sampling global log-evidence', 'nested sampling global log-evidence error', 'global evidence', 'global evidence error', 'nested importance sampling global log-evidence', 'nested importance sampling global log-evidence error', 'marginals']
1399128.9058026907
1399123.4062811895
1399123.4062811895
  analysing data from ../retrieval_outputs/fiducial_J_B_ret_51_1column_n1000/test_.txt
['modes', 'nested sampling global log-evidence', 'nested sampling global log-evidence error', 'global evidence', 'global evidence error', 'nested importance sampling global log-evidence', 'nested importance sampling global log-evidence error', 'marginals']
1395569.7135439706
1395556.204842826
1395556.204842826


In [3]:
# Base colours for the A (orange) and B (blue) posteriors.
color_m_A, color_m_B = '#FF622E', '#396ED8'
color_A, color_B = color_m_A, color_m_B


def _white_to_color(color):
    """Build a white-to-``color`` colormap plus four RGBA swatches from it.

    The swatches are sampled at 0, 1/3, 2/3 and 1 along the colormap. The
    first swatch (white) is made fully transparent; the others get alpha 0.5.

    Returns
    -------
    (cmap, swatches) : the colormap and a (4, 4) RGBA array.
    """
    cmap = mpl.colors.LinearSegmentedColormap.from_list('', ['w', color])
    swatches = cmap([0, 1/3, 2/3, 1])
    swatches[:, 3] = [0.0, 0.5, 0.5, 0.5]
    return cmap, swatches


cmap_m_A, env_colors_m_A = _white_to_color(color_m_A)
cmap_m_B, env_colors_m_B = _white_to_color(color_m_B)

In [318]:
def get_ranges(posterior, q=(0.16, 0.84), n=4):
    """Stretch a quantile interval about the median to get plotting ranges.

    For every parameter (column), each bound is computed as
    ``median + n * (quantile - median)``. Each side of the interval therefore
    extends ``n`` times the median-to-quantile distance, while the median
    itself stays fixed.

    Parameters
    ----------
    posterior : array_like
        Posterior samples. Quantiles and the median are taken along axis 0,
        so rows are samples and columns are parameters.
    q : sequence of float, optional
        Quantiles defining the base interval. A tuple is used as the default
        to avoid a shared mutable default; lists are still accepted.
    n : float, optional
        Stretch factor applied to each quantile's offset from the median.

    Returns
    -------
    numpy.ndarray
        For a 2-D ``posterior``, an array of shape ``(n_params, len(q))``.
        Row ``k`` holds the bounds for parameter ``k``.
    """
    posterior = np.asarray(posterior)
    bounds = np.quantile(posterior, q=q, axis=0)   # shape (len(q), n_params)
    med = np.median(posterior, axis=0)             # shape (n_params,)
    # Same shift-scale-shift order as before, so results are bit-identical.
    return ((bounds - med) * n + med).T

# Parameters shown in the corner plots, as
# (axis label, column index into the posterior samples, hand-tuned range).
# The hand-tuned ranges are currently not used: `ranges` is set to None below,
# and the axis limits come from get_ranges instead.
_param_table = [
    (r'$v_\mathrm{rad}$',          5,  (10, 30)),
    (r'$v\ \sin{i}$',              4,  (10, 30)),
    #(r'$\epsilon_\mathrm{limb}$', 3,  (0, 1)),
    (r'$\log\ g$',                 2,  (3.55, 3.73)),
    (r'$\log\ \mathrm{H_2O}$',     12, (-4.8, -4.65)),
    (r'$\log\ \mathrm{K}$',        14, (-7.9, -7.7)),
    (r'$\log\ \mathrm{Na}$',       15, (-6.6, -6.3)),
    (r'$\log\ \mathrm{HF}$',       13, (-8.75, -8.45)),
    (r'$\log\ (\mathrm{FeH})_0$',  9,  (-9.2, -8.8)),
]

labels = np.array([label for label, _, _ in _param_table], dtype=str)
indices = np.array([column for _, column, _ in _param_table], dtype=int)
ranges = None

# A's columns are taken in reversed label order (it is drawn with corner's
# reverse=True layout); B's columns follow the table order.
post_A = Res_A.posterior[:, indices[::-1]]
post_B = Res_B.posterior[:, indices]
ranges_A, ranges_B = get_ranges(post_A), get_ranges(post_B)

# corner.corner keyword arguments shared by both panels. The colour-dependent
# entries are filled in per component in the loop below.
kwargs = dict(
    bins=20,
    fill_contours=True,
    plot_datapoints=True,
    labels=None,   # axis labels are annotated manually afterwards
    max_n_ticks=3,
)

# Make offset sub-figures: two idx x idx squares on an nrows x ncols grid.
# The B panel sits bottom-left and the A panel top-right, so the two corner
# triangles interlock with a small offset.
nrows, ncols = 22, 19
idx = 18

fig = plt.figure(figsize=(9, 9*nrows/ncols))
gs = fig.add_gridspec(nrows=nrows, ncols=ncols, hspace=0., wspace=0)
subfig = np.array([
    fig.add_subfigure(gs[-idx:, :idx]),   # bottom-left: B
    fig.add_subfigure(gs[:idx, -idx:]),   # top-right:   A
])
for subfig_i in subfig:
    # Transparent backgrounds so the overlapping sub-figures don't mask each other.
    subfig_i.set_facecolor('none')

# One corner plot per component, in the original drawing order (A first).
# A is drawn with reverse=True so its triangle points the opposite way.
for _target, _post, _range, _color, _env, _reverse in (
    (1, post_A, ranges_A, color_m_A, env_colors_m_A, True),
    (0, post_B, ranges_B, color_m_B, env_colors_m_B, False),
):
    kwargs['color'] = _color
    kwargs['hist_kwargs'] = {'edgecolor': _color, 'facecolor': _env[1], 'fill': True}
    # Axes.contour takes `colors` (plural). A `color` key is not a contour
    # keyword: matplotlib ignores it and warns that the kwarg was unused.
    kwargs['contour_kwargs'] = {'linewidths': 1.0, 'colors': _color}
    kwargs['reverse'] = _reverse
    subfig[_target] = corner.corner(
        fig=subfig[_target], data=_post, range=_range, **kwargs
    )

# x/y-label kwargs and padding. Labels are placed in axes-fraction
# coordinates; `labelpad` is how far (in axes fractions) the outer axis labels
# sit beyond the panel edge.
ann_kwargs = dict(xycoords='axes fraction', ha='center', va='center')
labelpad = 0.6

# Styling for the bold "Luhman 16A" / "Luhman 16B" call-outs.
ann2_kwargs = dict(
    xycoords='axes fraction', clip_on=False, fontsize=15, fontweight='bold',
    #path_effects=[pe.withStroke(linewidth=5, foreground='w', alpha=0.9)],
)

for h, subfig_h in enumerate(subfig):
    # h == 0: bottom-left panel (B, normal layout, labels left/bottom).
    # h == 1: top-right panel (A, reversed layout, labels top/right).

    # Recover the square grid of axes. This assumes corner added them row by row.
    ax = np.array(subfig_h.axes)
    n_dim = int(np.sqrt(len(ax)))
    ax = ax.reshape((n_dim, n_dim))

    # Remove spacing between axes
    subfig_h.subplots_adjust(wspace=0.0, hspace=0.0)

    for i in range(n_dim):
        for j in range(n_dim):

            # x- and y-labels on the outer edges of each triangle
            if h == 0 and i != 0 and j == 0:
                ax[i, j].annotate(labels[i], xy=(-labelpad, 0.5), rotation=90, **ann_kwargs)
            if h == 0 and i == n_dim - 1:
                ax[i, j].annotate(labels[j], xy=(0.5, -labelpad), **ann_kwargs)

            if h == 1 and i == 0:
                ax[i, j].annotate(labels[j], xy=(0.5, 1 + labelpad), **ann_kwargs)
            if h == 1 and i != n_dim - 1 and j == n_dim - 1:
                ax[i, j].annotate(labels[i], xy=(1 + labelpad, 0.5), rotation=90, **ann_kwargs)

            # Parameter name beside each diagonal (1-D histogram) panel
            if h == 0 and i == j:
                ax[i, j].annotate(labels[i], xy=(0.5, 1 + 0.15), **ann_kwargs)
            if h == 1 and i == j:
                ax[i, j].annotate(labels[i], xy=(0.5, -0.19), **ann_kwargs)

            # x- and y-lims. A's samples are stored in reversed label order,
            # hence the [::-1]. Diagonal panels leave the y-range untouched.
            xlim, ylim = None, None
            if h == 1:
                xlim = ranges_A[::-1][j]
                if i != j:
                    ylim = ranges_A[::-1][i]
            if h == 0:
                xlim = ranges_B[j]
                if i != j:
                    ylim = ranges_B[i]

            ax[i, j].set(xlim=xlim, ylim=ylim)

            # Ticks on all four sides, crossing the axis line
            ax[i, j].tick_params(
                top=True, right=True, bottom=True, left=True, direction='inout'
            )

    # Component call-outs, drawn with a curved leader line from the first panel
    if h == 0:
        ax[0, 0].annotate(
            'Luhman 16B', xy=(0.15, 1), xytext=(-0.7, 1.25), ha='left', va='bottom', color=color_m_B,
            arrowprops={
                'arrowstyle': '-', 'connectionstyle': 'angle3,angleA=90,angleB=-30',
                'shrinkA': 0, 'shrinkB': 0, 'lw': 1.5, 'color': color_m_B,
            },
            **ann2_kwargs
        )
    if h == 1:
        ax[0, 0].annotate(
            'Luhman 16A', xy=(0.15, 0), xytext=(-0.7, -0.3), ha='left', va='top', color=color_m_A,
            arrowprops={
                'arrowstyle': '-', 'connectionstyle': 'angle3,angleA=90,angleB=30',
                'shrinkA': 0, 'shrinkB': 0, 'lw': 1.5, 'color': color_m_A,
            },
            **ann2_kwargs
        )

# Adjust the margins. Act on `fig` explicitly instead of pyplot's implicit
# "current figure".
margin = 0.08
fig.subplots_adjust(left=margin, bottom=margin, top=1-margin, right=1-margin)

# Make sure the output directory exists, so saving works on a fresh checkout.
os.makedirs('./plots', exist_ok=True)
fig.savefig('./plots/J_band_corner.pdf')
plt.close(fig)

  ax.contour(X2, Y2, H2.T, V, **contour_kwargs)
