## Set up paths to the starwave code and data

In [None]:
import os

# Location of the starwave source checkout (added to sys.path below) and of the
# directory holding the LMC data files (.pbz2 pickles).
# Both can be overridden with the STARWAVE_CODEPATH / STARWAVE_DATAPATH
# environment variables so the notebook runs on other machines unchanged;
# the defaults below are machine-specific.
codepath = os.environ.get('STARWAVE_CODEPATH', '/Users/gennaro/starwave/')
datapath = os.environ.get('STARWAVE_DATAPATH', '/Users/gennaro/Desktop/LMC_starwave_data/')

## Imports

In [None]:
import sys
sys.path.append(codepath)
import starwave

import numpy as np
import matplotlib.pyplot as plt
import bz2
import pandas as pd
from matplotlib.colors import LogNorm
import torch
import corner
from sbi import utils
import sbi
from sbi.analysis import pairplot

%matplotlib notebook

### Load Isochrones

In [None]:
# Isochrone library: a bz2-compressed pickled pandas object.
# NOTE: read_pickle executes arbitrary code on load -- only use trusted files.
with bz2.BZ2File(datapath + '/isolib_LMC_df_corr.pbz2') as fh:
    isodf = pd.read_pickle(fh)

In [None]:
# Peek at the first five isochrone rows (cell output).
isodf.head(5)

### Load Artificial Stars

In [None]:
# Artificial-star catalog, loaded from a bz2-compressed pickle and wrapped in a DataFrame.
with bz2.BZ2File(datapath + '/AScat.pbz2') as fh:
    asdf = pd.DataFrame(pd.read_pickle(fh))
print('Number of ASTs:', len(asdf))

#### Rename columns to the expected format and set the output magnitudes of undetected stars to NaN

In [None]:
# Per-star detection flag. The bool cast guarantees that `~detected` is a
# logical NOT (on 0/1 integers `~` would give -1/-2 and break the mask).
detected = asdf['AS_det'].to_numpy().astype(bool)

# Copy the catalog's generic magnitude columns into band-named columns
# (mag1 -> F110W, mag2 -> F160W).
asdf['F110W_in'] = asdf['AS_mag1_in']
asdf['F110W_out'] = asdf['AS_mag1_out']

asdf['F160W_in'] = asdf['AS_mag2_in']
asdf['F160W_out'] = asdf['AS_mag2_out']

# Undetected stars have no recovered magnitude: mark their outputs as NaN.
# A single .loc assignment writes into `asdf` itself; the chained form
# asdf['col'].loc[mask] = ... would modify a temporary Series under pandas
# Copy-on-Write and the NaNs would be silently lost.
asdf.loc[~detected, ['F110W_out', 'F160W_out']] = np.nan

#### Plot the input and output CMDs for the ASTs

In [None]:
f, ax = plt.subplots(1, 2, sharex=True, sharey=True)

# (bluer band column, redder band column, panel title) for each panel.
panels = (
    ('F110W_in', 'F160W_in', 'Input magnitudes'),
    ('F110W_out', 'F160W_out', 'Output magnitudes'),
)
for axis, (bluer, redder, title) in zip(ax, panels):
    # A fresh LogNorm per panel so each image gets its own colour scaling.
    axis.hist2d(
        asdf[bluer] - asdf[redder],
        asdf[redder],
        bins=250,
        norm=LogNorm(),
        range=((-1.5, 0), (18, 32)),
    )
    axis.set_xlabel('$F110W - F160W$')
    axis.set_title(title)

ax[0].set_ylabel('$F160W$')
# The axes are shared, so these limits (including the inverted magnitude
# axis) apply to both panels.
ax[1].set_xlim(-1.5, 0.)
ax[1].set_ylim(32, 18)
f.tight_layout()

#### Plot completeness as a function of magnitude in each band

In [None]:
f, ax = plt.subplots(2, 1)

# One panel per band: histograms of injected ("_in") and recovered ("_out")
# magnitudes, plus the recovered/injected ratio (completeness) on a twin axis.
for panel, band in zip(ax, ['F110W', 'F160W']):
    n1, b, p = panel.hist(asdf[band + '_in'], bins=50)
    # Reuse the input bin edges so the two histograms divide bin by bin.
    n2, b, p = panel.hist(asdf[band + '_out'], bins=b)

    # Bins with no injected stars have undefined completeness: leave them NaN
    # rather than dividing by zero.
    completeness = np.full_like(n1, np.nan, dtype=float)
    np.divide(n2, n1, out=completeness, where=n1 > 0)

    twin = panel.twinx()
    twin.plot(0.5 * (b[:-1] + b[1:]), completeness, color='red')
    twin.set_ylim(0, 1.1)
    twin.set_ylabel('Completeness', color='red')
    twin.tick_params(axis="y", labelcolor='red')
    # Title each panel with the band it actually shows.
    panel.set_title(band)

f.tight_layout()

### Load SFH Grid

In [None]:
# 2-D star-formation-history grid, stored as a bz2-compressed pickle.
with bz2.BZ2File(datapath + '/SFH_2d_frommarginals_metalbound.pbz2') as fh:
    sfh_grid = pd.read_pickle(fh)

#### Plot the 2D SFH

In [None]:
# Ranges of the SFH grid axes; also reused by the sanity checks further down.
minmet = sfh_grid['mets'].min()
maxmet = sfh_grid['mets'].max()
minage = sfh_grid['ages'].min()
maxage = sfh_grid['ages'].max()

f, ax = plt.subplots(1, 1)

# Image of the grid probabilities: [Fe/H] along x, age along y.
ax.imshow(
    sfh_grid['probabilities'],
    origin='lower',
    interpolation='none',
    extent=[minmet, maxmet, minage, maxage],
)
ax.set_aspect(0.1)
ax.set_xlabel('[Fe/H]')
ax.set_ylabel('Age (Gyr)')
ax.set_title('Star Formation History')
f.tight_layout()

### Load HST LMC Data

In [None]:
# Observed HST LMC catalog, loaded from a bz2-compressed pickle.
with bz2.BZ2File(datapath + '/catalog.pbz2') as fh:
    lmc = pd.DataFrame(pd.read_pickle(fh))

# Keep only detected sources.
# - astype(bool) makes the row mask unambiguous: an integer 0/1 column would
#   otherwise be interpreted as column labels.
# - .copy() makes `lmc` an independent frame, so adding columns to it later
#   does not raise pandas' SettingWithCopyWarning.
lmc = lmc.loc[lmc['dat_det'].astype(bool)].copy()

In [None]:
# Expose the catalog magnitudes under band names (mag1 -> F110W, mag2 -> F160W).
for band_col, source_col in (('F110W', 'dat_mag1'), ('F160W', 'dat_mag2')):
    lmc[band_col] = lmc[source_col]

In [None]:
f, ax = plt.subplots(1, 1)

# Observed CMD: F110W - F160W colour on x, F160W magnitude on y.
ax.scatter(
    lmc['F110W'] - lmc['F160W'],
    lmc['F160W'],
    s=1,
    alpha=0.5,
    color='k',
)

# Magnitudes grow fainter downward, so put small values at the top.
ax.invert_yaxis()

ax.set_xlabel('$F110W - F160W$')
ax.set_ylabel('$F160W$')
ax.set_title('Data CMD')

f.tight_layout()

## A few sanity checks

In [None]:
# Compare the age range covered by the isochrone index with that of the SFH grid.
print('Min and max ages available in the isochrones: {} , {}'.format(
    isodf.index.get_level_values('age').min(),
    isodf.index.get_level_values('age').max(),
))
print('***')
print('Min and max ages in the provided 2d SFH: {} {}'.format(minage, maxage))

In [None]:
# Compare the [Fe/H] range covered by the isochrone index with that of the SFH grid.
print('Min and max [Fe/H] available in the isochrones: {} , {}'.format(
    isodf.index.get_level_values('[Fe/H]').min(),
    isodf.index.get_level_values('[Fe/H]').max(),
))
print('***')
print('Min and max [Fe/H] in the provided 2d SFH: {} {}'.format(minmet, maxmet))

## Set up starwave

### The names of the photometric bands to use

In [None]:
# Photometric bands to model, in the order used for magnitudes throughout.
shortbands = ['F160W', 'F110W']
print('Photometric passbands to use:', shortbands, sep='\n')

### Instantiate a starwave object

In [None]:
sw = starwave.StarWave(
    isodf=isodf,
    asdf=asdf,
    imf_type='spl',
    bands=shortbands,
    # One wavelength per entry of `bands`, same order; presumably the
    # WFC3/IR pivot wavelengths in Angstrom -- confirm units with starwave.
    band_lambdas=[15369., 11534.],
    sfh_type='grid',
    sfh_grid=sfh_grid,
)

### Set up the parameters: fix some and leave the others free to be fitted

In [None]:
# Starting values for the two free parameters.
logint = 4.8
slope = -2.5

# Fixed parameters: 'dm', 'bf' and 'av' (presumably distance modulus, binary
# fraction and V-band extinction -- confirm against starwave).
sw.params['dm'].set(value=18.52, fixed=True, bounds=[18, 19])
# Free parameters get bounds around their starting values.
sw.params['log_int'].set(value=logint, bounds=(logint - 0.5, logint + 0.25))
sw.params['bf'].set(value=0.4, fixed=True)
sw.params['slope'].set(value=slope, bounds=(-4, -1))
# 3.1 * 0.075 looks like R_V * E(B-V) -- NOTE(review): confirm.
sw.params['av'].set(value=3.1 * 0.075, bounds=(0, 1), fixed=True)


In [None]:
# Display sw.params as the cell output to check the settings made above.
sw.params

### Generate a synthetic CMD that looks somewhat like the LMC one, then try to recover its parameters

In [None]:
# Parameters that generate the synthetic CMD, ordered [log_int, slope].
# Note the slope (-2.3) differs from the starting value `slope` on sw.params.
params = torch.tensor([logint, -2.3])
# sample_cmd returns a sequence; element [1] is the CMD used below.
eg_cmd = sw.sample_cmd(params, model='spl')[1]

### Plot LMC and sampled CMD

In [None]:
# (n_stars, 2) magnitude array with columns in the order of `shortbands`.
lmc_mags = np.column_stack((lmc[shortbands[0]], lmc[shortbands[1]]))
lmc_cmd = sw.make_cmd(lmc_mags)

f, ax = plt.subplots(1, 1)
# Overlay the observed CMD and the synthetic one.
for cmd in (lmc_cmd, eg_cmd):
    starwave.plot_cmd(cmd, bands=shortbands)

plt.gca().invert_yaxis()

### Run Starwave to fit the fake LMC CMD

#### Train the posterior

In [None]:
# Train the posterior on the synthetic CMD (sequential rounds of simulations).
posteriors = sw.fit_cmd(
    eg_cmd,
    cores=1,
    n_sims=100,
    n_rounds=3,
    gamma=None,
    gamma_kw={'q': 0.68, 'NN': 5, 'fac': 1},
)

#### Sample from the trained posterior and plot the results

In [None]:
# Take the last posterior stored on the StarWave object (presumably the one
# from the final training round).
posterior = sw.posteriors[-1]

In [None]:
# Print the posterior object's textual representation.
print(posterior)

In [None]:
# Draw one million samples conditioned on the observation stored on `sw`.
posterior_samples = posterior.sample((1_000_000,), x=sw.obs)

In [None]:
# Corner plot of the posterior samples. The truth markers come from `params`,
# the values that actually generated eg_cmd ([log_int, slope] with
# slope = -2.3). The `slope` variable (-2.5) is only the fit's starting value,
# so using it would mark the wrong "true" slope.
f = corner.corner(
    np.array(posterior_samples),
    show_titles=True,
    labels=list(sw.param_mapper.keys()),
    title_kwargs=dict(fontsize=14),
    truths=params.tolist(),
)