In [1]:
%run setup.py

import jax.numpy as jnp
import jax
import numpy as np
from tqdm import tqdm
import scipy.interpolate

# Read in eligible models

In [26]:
stars = pd.read_excel('sample/samples.xlsx')

# Which subset of the sample to fit: 'ms' (stages ms + esg + lsg), 'rgb', or 'all'.
subsets = 'ms'
# subsets = 'rgb'
# subsets = 'all'

# Stage labels kept for each subset; None means no cut on the 'stage' column.
subset_stages = {
    'all': None,
    'rgb': ['rgb'],
    'ms': ['esg', 'lsg', 'ms'],
}
if subsets not in subset_stages:
    # An unrecognised value used to fall through silently to the 'ms' selection.
    raise ValueError(f"unknown subsets={subsets!r}; expected one of {sorted(subset_stages)}")

# Cuts shared by every subset: rows flagged with ifmodelling == 1 whose 'names'
# entry is not one of the excluded labels.
idx = (stars['ifmodelling'] == 1) & (~np.isin(stars['names'], ['ngc6791', 'ngc6819', 'binary']))
if subset_stages[subsets] is not None:
    idx = np.isin(stars['stage'], subset_stages[subsets]) & idx

stars = stars.loc[idx, :].reset_index(drop=True)

In [27]:
# Per-star records used for the surface-term optimisation; each entry is indexed by
# string keys further down (e.g. 'Nmodels', 'Nmodes', 'KIC').
# NOTE(review): allow_pickle=True unpickles the file contents, which can execute
# arbitrary code -- only load this file from a trusted source.
stardata = np.load('data/stellar_models_for_surface_optimisation.npy', allow_pickle=True)

In [28]:
# Per-star bookkeeping pulled out of the records, in the same order as `stardata`.
Nmodels_star = np.array([star['Nmodels'] for star in stardata])
Nmodes_star = np.array([star['Nmodes'] for star in stardata])
kics_star = np.array([star['KIC'] for star in stardata])

# Keep a star only if it has exactly 200 models, at least 5 modes,
# and its KIC appears in the selected `stars` table.
has_full_grid = Nmodels_star == 200
has_enough_modes = Nmodes_star >= 5
in_sample = np.isin(kics_star, stars['KIC'])
stardata = stardata[has_full_grid & has_enough_modes & in_sample]
Nstars = len(stardata)

In [29]:
# Number of stars left after the filtering step.
Nstars

67

In [30]:
# Pack the per-star records into dense arrays of shape (Nmodels_max, Nmodes_max, Nstars)
# so the loss can be evaluated for all stars, models and modes at once.
Nmodes_max = np.max([stardata[istar]['Nmodes'] for istar in range(Nstars)])
Nmodels_max = np.max([stardata[istar]['Nmodels'] for istar in range(Nstars)])
grid_shape = (Nmodels_max, Nmodes_max, Nstars)

# NaN-padded tables: slots that are never filled stay NaN.
chi2s_nonseis = np.full((Nmodels_max, Nstars), np.nan)
gravs, Teffs, fehs, numaxs = [np.full(grid_shape, np.nan) for _ in range(4)]
# Original notes on the two surface-term basis functions:
#   f = (numax/nuac)^-1/I
#   g = (numax/nuac)^3/I
#   D1 = dfreq(numax)       = a1 * f + a3 * g
#   D2 = dfreq(scale*numax) = s^-1 * a1 * f + s^3 * a3 * g   (scale = 1.2)
# In this cell fs / gs hold (nu/numax)^-1 / inertia and (nu/numax)^3 / inertia,
# cubic-interpolated over each model's modes and evaluated at nu = numax.
# fs_d / gs_d are allocated but never filled.
fs, gs, fs_d, gs_d = [np.full(grid_shape, np.nan) for _ in range(4)]
# One-padded tables: unused mode slots keep the value 1, which model_linear
# relies on through its `mod_freq_data != 1.` mask.
obs_freq_data, obs_efreq_data, mod_freq_data, mod_inertia_data, mod_efreq_sys = [
    np.ones(grid_shape) for _ in range(5)
]

Nmodes_imod_istar = np.ones((Nmodels_max, Nstars))
weights_imod_istar = np.ones((Nmodels_max, Nstars))

# True for stars that end up with at least one finite entry in `gravs`.
idxf = np.zeros(Nstars, dtype=bool)

for istar in tqdm(range(Nstars)):
    star = stardata[istar]
    Nmodes = star['Nmodes']
    Nmodels = star['Nmodels']
    mod_freq_data[0:Nmodels, 0:Nmodes, istar] = star['mod_freqs'][:, :]
    mod_inertia_data[0:Nmodels, 0:Nmodes, istar] = star['mod_inertias'][:, :]
    chi2s_nonseis[0:Nmodels, istar] = star['chi2_nonseis']

    # The observed quantities are the same for every model of a star, so they are
    # broadcast over the model axis once instead of being re-assigned per model.
    obs_freq_data[0:Nmodels, 0:Nmodes, istar] = star['obs_freq']
    obs_efreq_data[0:Nmodels, 0:Nmodes, istar] = star['obs_efreq']
    mod_efreq_sys[0:Nmodels, 0:Nmodes, istar] = star['mod_efreq_sys']

    for imod in range(Nmodels):
        freq_imod = star['mod_freqs'][imod, :]
        inertia_imod = star['mod_inertias'][imod, :]
        numax_imod = star['numax_scaling'][imod]

        # g: (nu/numax)^3 / inertia, interpolated to nu = numax.
        fi = scipy.interpolate.interp1d(freq_imod,
              (freq_imod/numax_imod)**3.0/inertia_imod,
              kind='cubic', fill_value='extrapolate')
        gs[imod, :, istar] = fi(numax_imod)

        # f: (nu/numax)^-1 / inertia, interpolated to nu = numax.
        fi = scipy.interpolate.interp1d(freq_imod,
              (freq_imod/numax_imod)**-1.0/inertia_imod,
              kind='cubic', fill_value='extrapolate')
        fs[imod, :, istar] = fi(numax_imod)

        # Per-model scalars are repeated along the whole mode axis.
        gravs[imod, :, istar] = star['g'][imod]
        Teffs[imod, :, istar] = star['Teff'][imod]
        fehs[imod, :, istar] = star['feh'][imod]
        numaxs[imod, :, istar] = numax_imod

    Nmodes_imod_istar[:, istar] = Nmodes
    # Equal log-weight for every real model; padded model slots get a very negative log-weight.
    weights_imod_istar[0:Nmodels, istar] = np.log(1/Nmodels)
    weights_imod_istar[Nmodels:, istar] = -10000
    idxf[istar] = np.any(np.isfinite(gravs[:, :, istar]))

# Drop stars without any finite entry. Every array carrying a star axis is filtered
# with the same mask; previously mod_efreq_sys, weights_imod_istar and
# Nmodes_imod_istar were skipped, which leaves their star axis out of step with
# the other arrays whenever a star is dropped.
gravs, Teffs, fehs, numaxs = gravs[:, :, idxf], Teffs[:, :, idxf], fehs[:, :, idxf], numaxs[:, :, idxf]
fs, gs = fs[:, :, idxf], gs[:, :, idxf]
fs_d, gs_d = fs_d[:, :, idxf], gs_d[:, :, idxf]
obs_freq_data, obs_efreq_data = obs_freq_data[:, :, idxf], obs_efreq_data[:, :, idxf]
mod_freq_data, mod_inertia_data = mod_freq_data[:, :, idxf], mod_inertia_data[:, :, idxf]
mod_efreq_sys = mod_efreq_sys[:, :, idxf]
chi2s_nonseis = chi2s_nonseis[:, idxf]
Nmodes_imod_istar = Nmodes_imod_istar[:, idxf]
weights_imod_istar = weights_imod_istar[:, idxf]

# Relative weights of the seismic and non-seismic chi2 terms in the loss.
weight_seis = 1. # 1/10.
weight_nonseis = 1.

100%|███████████████████████████████████████████████████████| 67/67 [00:01<00:00, 45.85it/s]


In [31]:
# Sanity check: both padded arrays should share one shape.
mod_freq_data.shape, gravs.shape  # Nmodels_max, Nmodes_max, Nstars

((200, 21, 67), (200, 21, 67))

In [32]:
# Convert the padded tables to JAX arrays in a single unpacking assignment.
(
    mod_freq_data, numaxs, mod_inertia_data, obs_freq_data,
    Teffs, gravs, fehs,
    gs, fs,
    weights_imod_istar,
) = (
    jnp.asarray(host_array)
    for host_array in (
        mod_freq_data, numaxs, mod_inertia_data, obs_freq_data,
        Teffs, gravs, fehs,
        gs, fs,
        weights_imod_istar,
    )
)

# Optimisation

In [33]:
# Objective function handed to the optimiser.
def model_linear(thetas, *s):
    """Negative summed log-probability of the sample for a surface-term prescription.

    The surface correction at numax is modelled as
    ``thetas[0] * gravs**thetas[1] * (Teffs/5777)**thetas[2] * (thetas[3]*fehs + 1)``.
    It is converted to a cubic-term coefficient via ``gs`` and applied to every
    model frequency as ``a3 * (nu/numax)**3 / inertia``.

    For each (model, star) pair the seismic chi2 is the mean over the mode axis
    (axis 1) of the squared, error-normalised frequency residuals, using only
    slots where ``mod_freq_data != 1``. It is combined with ``chi2s_nonseis``,
    turned into a per-star log-probability with a log-sum-exp over the model
    axis (axis 0), and the NaN-ignoring sum over stars is negated.

    All data arrays are read from module-level globals. Extra positional
    arguments ``*s`` are accepted and ignored.
    """
    amplitude, grav_exp, teff_exp, feh_slope = thetas[0], thetas[1], thetas[2], thetas[3]

    # Surface correction evaluated at numax, and the matching cubic coefficient.
    corr_at_numax = amplitude * (gravs)**grav_exp * (Teffs/5777.)**teff_exp * (feh_slope*fehs + 1)
    a3 = corr_at_numax/gs

    # Shift every model frequency by the cubic surface term.
    freq_shift = (a3*(mod_freq_data/numaxs)**3.0)/mod_inertia_data
    corrected_freqs = mod_freq_data + freq_shift

    # Mean normalised squared residual over the filled mode slots.
    residual_sq = (corrected_freqs-obs_freq_data)**2.0
    variance = obs_efreq_data**2. + mod_efreq_sys**2.
    filled_slots = mod_freq_data!=1.
    chi2s_seis = jnp.mean(residual_sq/variance, axis=1, where=filled_slots)

    chi2_total = chi2s_seis*weight_seis + chi2s_nonseis*weight_nonseis
    log_terms = -chi2_total/2+weights_imod_istar
    logP_star = jax.scipy.special.logsumexp(log_terms, axis=0)

    return -jnp.nansum(logP_star)

jit_model_linear = jax.jit(model_linear)

In [34]:
# Starting point for the four parameters of model_linear.
paramsInit = jnp.asarray([-4.22, 1.0, -8.44, -0.41])

# Smoke test: the jitted loss evaluates at the starting point ...
loss_at_init = jit_model_linear(paramsInit)
print(loss_at_init)

# ... and so does its gradient.
grad_at_init = jax.grad(jit_model_linear)(paramsInit)
print(grad_at_init)

596.0312
[ 31.261253  287.25397    12.425717   -2.5516622]


In [35]:
import optax

def fit(params, optimizer, n_steps=3000, log_every=300, loss_fn=None):
    """Minimise a loss with an optax optimizer and return the final parameters.

    Parameters
    ----------
    params : array
        Initial parameter vector.
    optimizer : optax gradient transformation
        Must provide ``init`` and ``update`` (e.g. ``optax.adam(...)``).
    n_steps : int, optional
        Number of update steps to run (default 3000).
    log_every : int, optional
        Print the loss every ``log_every`` steps (default 300); 0 or None
        disables logging.
    loss_fn : callable, optional
        Scalar function of ``params`` to minimise. Defaults to the module-level
        ``model_linear``, looked up when ``fit`` is called.

    Returns
    -------
    params : array
        Parameters after the last update step.
    """
    if loss_fn is None:
        loss_fn = model_linear
    opt_state = optimizer.init(params)

    @jax.jit
    def step(params, opt_state):
        # One optimiser update: loss and gradient, then apply the transformed gradient.
        loss_value, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    for i in range(n_steps):
        params, opt_state, loss_value = step(params, opt_state)
        # The printed loss is the value before this step's update was applied.
        if log_every and i % log_every == 0:
            print(f'step {i}, loss: {loss_value}')

    return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(paramsInit, optimizer)

step 0, loss: 596.0311889648438
step 300, loss: 583.40869140625
step 600, loss: 582.3173828125
step 900, loss: 582.309326171875
step 1200, loss: 582.3093872070312
step 1500, loss: 582.3093872070312
step 1800, loss: 582.3093872070312
step 2100, loss: 582.3094482421875
step 2400, loss: 582.3093872070312
step 2700, loss: 582.3093872070312


In [38]:
# Estimated values: the parameter vector returned by fit().
params

DeviceArray([-4.186023  ,  0.7813644 , -5.7118735 , -0.07299965], dtype=float32)

In [39]:
# Approximate uncertainties: square root of the diagonal of the inverse Hessian
# of the loss, evaluated at the fitted parameters.
hessian_at_fit = jax.hessian(jit_model_linear)(params)
covariance_approx = jnp.linalg.inv(hessian_at_fit)
err = jnp.diag(covariance_approx)**0.5
err

DeviceArray([0.3921687 , 0.06778456, 0.5822848 , 0.13034512], dtype=float32)

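# Results

For each sample below, the first array is the fitted parameter vector `thetas` of `model_linear` (amplitude, surface-gravity exponent, Teff exponent, [Fe/H] slope), and the second array is the corresponding approximate uncertainty from the inverse Hessian.

## All sample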
DeviceArray([-4.1485624,  0.9450712, -5.479969 , -1.0995963], dtype=float32)
             
DeviceArray([0.13009433, 0.01215709, 0.23190716, 0.02584277], dtype=float32)
             
## RGB sample
DeviceArray([-4.181521  ,  0.95523196, -5.6379514 , -1.1099834 ], dtype=float32)

DeviceArray([0.13015431, 0.01164923, 0.22260946, 0.02455296], dtype=float32)

## MS+SG sample
DeviceArray([-4.186023  ,  0.7813644 , -5.7118735 , -0.07299965], dtype=float32)

DeviceArray([0.3921687 , 0.06778456, 0.5822848 , 0.13034512], dtype=float32)