# **Kīauhōkū Model Offsets**
## This notebook is a tutorial on estimating the systematic uncertainties associated with various stellar model grids, as illustrated in Tayar et al. (submitted, 2020).
## The notebook is read-only. You can make changes, but they will not save unless you make your own copy. (So don't worry about messing it up for others!)

#### **Contact:**
- Jamie Tayar (jtayar@hawaii.edu) for information regarding the stellar model grids
- Zach Claytor (zclaytor@hawaii.edu) for information regarding kīauhōkū and detailed workings/upkeep of this notebook

#### **Resources:**
- kīauhōkū: https://github.com/zclaytor/kiauhoku
- models: https://zenodo.org/record/4307955

## First download and untar the model grids from Zenodo:

In [None]:
# Download the EEP grid archive from Zenodo and unpack it in the working directory.
# The loading cell below reads 'eep_grids/*.pqt', so the archive presumably
# extracts into ./eep_grids -- confirm after running.
!wget https://zenodo.org/record/4307955/files/eep_grids.tar.gz
!tar -xzvf eep_grids.tar.gz

## Next, install kiauhoku

In [None]:
# Install the kiauhoku package with pip; it is imported as `kh` in the next cell.
!pip install kiauhoku

In [None]:
import numpy as np
import pandas as pd
import kiauhoku as kh

## Load grids, unify column names, and cast to interpolators

In [None]:
# Read each model grid from its parquet file and keep only the part of every
# track between ZAMS and RGBump, selected here as 201 <= eep <= 605.
# The MIST grid is additionally restricted to 0.5 <= initial_mass <= 2.
yrec = (
    kh.stargrid
    .from_parquet(path='eep_grids/yrec.pqt')
    .query('201 <= eep <= 605')
)
mist = (
    kh.stargrid
    .from_parquet(path='eep_grids/mist.pqt')
    .query('0.5 <= initial_mass <= 2 and 201 <= eep <= 605')
)
dart = (
    kh.stargrid
    .from_parquet(path='eep_grids/dartmouth.pqt')
    .query('201 <= eep <= 605')
)
gars = (
    kh.stargrid
    .from_parquet(path='eep_grids/garstec.pqt')
    .query('201 <= eep <= 605')
)

In [None]:
# unify column names
# Each grid ships with its own column naming; the assignments below add a
# common set of aliases ('mass', 'teff', 'lum', 'met', 'logg', 'age') so the
# fitting code can address every grid the same way. Logarithmic columns are
# exponentiated, and ages are converted to Gyr.

# --- YREC ---
yrec['mass'] = yrec['Mass(Msun)']
yrec['teff'] = 10**yrec['Log Teff(K)']
# NOTE(review): 'L/Lsun' is exponentiated, so it is presumably stored as
# log10(L/Lsun) despite its name -- confirm against the grid.
yrec['lum'] = 10**yrec['L/Lsun']
# log10 of surface Z/X divided by 0.0253 (presumably this grid's solar Z/X -- confirm)
yrec['met'] = np.log10(yrec['Zsurf']/yrec['Xsurf']/0.0253)
yrec['age'] = yrec['Age(Gyr)']
# NOTE(review): no 'logg' alias is created for yrec, yet fit_all_grids selects
# 'logg'; presumably the grid already has a column with that name -- confirm.

# --- MIST ---
mist['mass'] = mist['star_mass']
mist['teff'] = 10**mist['log_Teff']
mist['lum'] = 10**mist['log_L']
# log10(Z_surf) - log10(X_surf * 0.0173), i.e. log10 of (Z/X) / 0.0173
mist['met'] = mist['log_surf_z'] - np.log10(mist['surface_h1']*0.0173)
mist['logg'] = mist['log_g']
mist['age'] = mist['star_age'] / 1e9  # divide by 1e9 (presumably yr -> Gyr)

# --- Dartmouth ---
# No mass column is used here; the mass is taken from the 'initial_mass'
# level of the grid index instead.
dart['mass'] = dart.index.to_frame()['initial_mass']
dart['teff'] = 10**dart['Log T']
dart['lum'] = 10**dart['Log L']
# log10 of surface Z/X divided by 0.0229 (presumably this grid's solar Z/X -- confirm)
dart['met'] = np.log10(dart['(Z/X)_surf']/0.0229)
dart['logg'] = dart['Log g']
dart['age'] = dart['Age (yrs)'] / 1e9  # yr -> Gyr

# --- GARSTEC ---
gars['mass'] = gars['M/Msun']
gars['teff'] = gars['Teff']  # copied as-is, unlike the other grids (no exponentiation)
gars['lum'] = 10**gars['Log L/Lsun']
# log10 of surface Z/X divided by 0.0245 (presumably this grid's solar Z/X -- confirm)
gars['met'] = np.log10(gars['Zsurf']/gars['Xsurf']/0.0245)
gars['age'] = gars['Age(Myr)'] / 1e3  # Myr -> Gyr
# NOTE(review): as with yrec, no 'logg' alias is created for garstec -- confirm
# that the grid already provides one.

# Build one interpolator per grid; gridsearch_fit calls fit_star on these.
yrec_interp = yrec.to_interpolator()
mist_interp = mist.to_interpolator()
dart_interp = dart.to_interpolator()
gars_interp = gars.to_interpolator()

## Define fitting function
##### This may get directly implemented into kiauhoku in the near future.

In [None]:
def gridsearch_fit(star, gridname, tol=1e-6, scale=(1000, 1, 0.1),
                   mass_step=0.1, met_step=0.2, eep_step=50):
    """Fit a star to one model grid by trying many starting points.

    A coarse lattice of (initial mass, initial metallicity, EEP) starting
    guesses is built, and ``interp.fit_star`` is run from each one until a
    fit with loss at or below ``tol`` is found. If none reaches ``tol``,
    the successful fit with the smallest loss is returned.

    Parameters
    ----------
    star : dict
        Observables to match, e.g. ``{'teff': 5772, 'lum': 1, 'met': 0}``.
        Passed unchanged to ``fit_star``.
    gridname : str
        One of ``'yrec'``, ``'mist'``, ``'dartmouth'``, ``'garstec'``.
    tol : float
        Loss at or below which a fit is accepted and the search stops.
    scale : tuple
        Passed to ``fit_star`` to bring the inputs to similar magnitudes.
    mass_step, met_step, eep_step : float
        Spacing of the starting-guess lattice along each index. Mass and
        metallicity span the grid's ``index_range``; EEP starts at 252 and
        stops before 606.

    Returns
    -------
    (model, fit)
        The best model, with ``'initial_mass'``, ``'initial_met'`` and
        ``'eep'`` added from the solution, and its fit result. If no
        attempt succeeded, ``(None, last_fit)`` is returned.

    Raises
    ------
    ValueError
        If ``gridname`` is not recognised, or if the step sizes produce an
        empty lattice of starting guesses.
    """
    # The module-level grids are looked up only in the matching branch.
    if gridname == 'yrec':
        grid, interp = yrec, yrec_interp
    elif gridname == 'mist':
        grid, interp = mist, mist_interp
    elif gridname == 'dartmouth':
        grid, interp = dart, dart_interp
    elif gridname == 'garstec':
        grid, interp = gars, gars_interp
    else:
        raise ValueError(f"Bad grid name: '{gridname}'")

    print(f'Fitting star with {gridname}...')

    # Construct a multi-index instead of using a triple-nested for-loop
    idxrange = grid.index_range
    mass_list = np.arange(*idxrange['initial_mass'], mass_step)
    met_list = np.arange(*idxrange['initial_met'], met_step)
    eep_list = np.arange(252, 606, eep_step)
    idx_list = pd.MultiIndex.from_product(
        [mass_list, met_list, eep_list])
    if len(idx_list) == 0:
        # Without this the loop never runs and `fit` would be unbound below.
        raise ValueError(
            f"Empty search grid for '{gridname}': check mass_step, "
            "met_step and eep_step.")

    # Start from +inf so that ANY successful fit becomes the first "best";
    # a finite starting value would discard successful fits with larger loss
    # and leave nothing to return.
    best_loss = np.inf
    best_model = None
    best_fit = None
    good_fit = False
    for idx in idx_list:
        model, fit = interp.fit_star(star, idx,
            None,  # Nelder-Mead does not accept bounds
            scale, # Scale input to put them ~ the same magnitude
            method='Nelder-Mead')
        if not fit.success:
            continue
        if fit.fun < best_loss:
            best_model = model
            best_fit = fit
            best_loss = fit.fun
            if fit.fun <= tol:
                good_fit = True
                print(f'{gridname}: success!')
                break

    # Check to see how the fit did, print comments if necessary.
    if best_fit is None:
        print(f'*!*!*!* {gridname} fit failed! Returning last attempt.')
        return None, fit
    if not good_fit:
        print(f'{gridname}: Fit not converged to within tolerance, but returning closest fit.')

    # add the indices and return
    m, z, e = best_fit.x
    best_model['initial_mass'] = m
    best_model['initial_met'] = z
    best_model['eep'] = e
    return best_model, best_fit


def fit_all_grids(star, *args, **kwargs):
    """Fit one star with every model grid and tabulate the successful fits.

    Parameters
    ----------
    star : dict
        Observables to match; passed unchanged to ``gridsearch_fit``.
    *args, **kwargs
        Forwarded to ``gridsearch_fit`` (e.g. ``tol``, ``scale``).

    Returns
    -------
    pandas.DataFrame
        One column per grid whose fit succeeded, named after the grid. The
        rows are the fitted indices and the unified quantities listed in
        ``fields`` below. If no grid produced a successful fit, a frame
        with those rows and no columns is returned.
    """
    fields = ['initial_mass', 'initial_met', 'eep', 'mass', 'teff', 'lum', 'met', 'logg', 'age']
    gridnames = []
    models = []
    for grid in ['yrec', 'mist', 'dartmouth', 'garstec']:
        model, fit = gridsearch_fit(star, grid, *args, **kwargs)
        # Skip failed fits; also skip a missing model so it is never indexed.
        if model is None or not fit.success:
            continue
        gridnames.append(grid)
        models.append(model[fields])

    if not models:
        # pd.concat raises ValueError on an empty list, so return an empty table.
        return pd.DataFrame(index=fields)

    models = pd.concat(models, axis=1)
    models.columns = gridnames

    return models

def compute_statistics(models, exclude=None):
    """Summarise the spread between grids for every fitted quantity.

    Parameters
    ----------
    models : pandas.DataFrame
        Table with one column per grid, as returned by ``fit_all_grids``.
        It is not modified.
    exclude : column label or list of labels, optional
        Columns to drop before the statistics are computed.

    Returns
    -------
    pandas.DataFrame
        The retained columns followed by ``'mean'``, ``'stdev'`` (sample
        standard deviation, ``ddof=1``) and ``'max offset'`` (row maximum
        minus row minimum), all computed across the retained columns.
    """
    table = models.copy()
    if exclude is not None:
        table = table.drop(columns=exclude)

    # Compute every summary before attaching any of them, so the summary
    # columns never feed into each other.
    row_max = table.max(axis=1)
    row_min = table.min(axis=1)
    summaries = [
        ('mean', table.mean(axis=1)),
        ('stdev', table.std(axis=1, ddof=1)),
        ('max offset', row_max - row_min),
    ]
    for label, values in summaries:
        table[label] = values

    return table

## Define stellar examples and run!

### $\pi$ Men

In [None]:
# Observed constraints for pi Men, keyed by the unified column names:
# effective temperature, luminosity and metallicity.
piMen = dict(teff=6037, lum=1.444, met=0.08)
models = fit_all_grids(piMen, tol=1e-6, scale=(1000, 1, 0.1))
models

In [None]:
# Append the mean, sample standard deviation and max-min offset across grids.
# Pass exclude=[<grid name>, ...] to leave those grids out of the statistics.
stats = compute_statistics(models, exclude=None)
stats

### TOI 197

In [None]:
# Observed constraints for TOI 197, keyed by the unified column names:
# effective temperature, luminosity and metallicity.
toi197 = dict(teff=5080, lum=5.15, met=-0.08)
models = fit_all_grids(toi197, tol=1e-6, scale=(1000, 1, 0.1))
models

In [None]:
# Append the mean, sample standard deviation and max-min offset across grids.
# Pass exclude=[<grid name>, ...] to leave those grids out of the statistics.
stats = compute_statistics(models, exclude=None)
stats

### Sun, using Teff and Luminosity

In [None]:
# The Sun constrained by effective temperature, luminosity and metallicity.
sun1 = dict(teff=5772, lum=1, met=0)
models = fit_all_grids(sun1, tol=1e-6, scale=(1000, 1, 0.1))
models

In [None]:
# Append the mean, sample standard deviation and max-min offset across grids.
# Pass exclude=[<grid name>, ...] to leave those grids out of the statistics.
stats = compute_statistics(models, exclude=None)
stats

### Sun, using Teff and logg

In [None]:
# The Sun constrained by effective temperature, surface gravity (logg)
# and metallicity instead of luminosity.
sun2 = dict(teff=5772, logg=4.44, met=0)
models = fit_all_grids(sun2, tol=1e-6, scale=(1000, 1, 0.1))
models

In [None]:
# Append the mean, sample standard deviation and max-min offset across grids.
# Pass exclude=[<grid name>, ...] to leave those grids out of the statistics.
stats = compute_statistics(models, exclude=None)
stats

### Sun, using Mass and Age

In [None]:
# The Sun constrained by age, mass and metallicity.
# NOTE(review): scale differs from the (1000, 1, 0.1) used in the other
# examples, presumably because age and mass are of order unity rather than
# thousands like teff -- confirm how fit_star pairs scale with the star keys.
sun3 = dict(age=4.57, mass=1, met=0)
models = fit_all_grids(sun3, tol=1e-6, scale=(1, 0.1, 0.1))
models

In [None]:
# Append the mean, sample standard deviation and max-min offset across grids.
# Pass exclude=[<grid name>, ...] to leave those grids out of the statistics.
stats = compute_statistics(models, exclude=None)
stats