In [26]:
%run setup.py

import jax.numpy as jnp
import jax
import numpy as np
from tqdm import tqdm
import scipy.interpolate

# Read in eligible models

In [215]:
stars = pd.read_excel('sample/samples.xlsx')

# Which subset of stars to fit: 'all', 'rgb' (the RGB sample) or 'ms' (stages esg/lsg/ms,
# i.e. the MS+SG sample).
subsets = 'all'

# Stars flagged for modelling, excluding the cluster and binary entries.
_modelling_mask = (stars['ifmodelling'] == 1) & (~np.isin(stars['names'], ['ngc6791', 'ngc6819', 'binary']))
# Allowed `stage` values per subset; None means no stage cut.
_subset_stages = {'all': None, 'rgb': ['rgb'], 'ms': ['esg', 'lsg', 'ms']}
if subsets not in _subset_stages:
    # Previously any unrecognised value silently selected the MS+SG sample.
    raise ValueError(f"subsets must be one of {sorted(_subset_stages)}, got {subsets!r}")

if _subset_stages[subsets] is None:
    idx = _modelling_mask
else:
    idx = _modelling_mask & np.isin(stars['stage'], _subset_stages[subsets])

stars = stars.loc[idx,:].reset_index(drop=True)

In [221]:
# Per-star model grids: an object array whose elements are dicts (indexed below by keys such as
# 'Nmodels', 'Nmodes', 'KIC'). allow_pickle=True unpickles arbitrary objects, which can execute
# code on load, so only point this at a trusted, locally produced file.
stardata = np.load('data/stellar_models_for_surface_optimisation.npy', allow_pickle=True)

In [222]:
# Per-star metadata used for the sample cuts.
Nmodels_star, Nmodes_star, kics_star = (
    np.array([star[key] for star in stardata]) for key in ('Nmodels', 'Nmodes', 'KIC')
)

# Keep stars with at least 4 modes (presumably because the cubic interpolation in the next
# cells needs >= 4 points) that also belong to the sample selected in `stars`.
_in_sample = np.isin(kics_star, stars['KIC'])
stardata = stardata[(Nmodes_star >= 4) & _in_sample]
Nstars = len(stardata)

In [223]:
Nstars  # stars kept after the Nmodes >= 4 and sample-membership (KIC) cuts

1162

In [224]:
# Pack every star's model grid into dense arrays of shape (Nmodels_max, Nmodes_max, Nstars)
# so the loss can be evaluated in one vectorised pass. Stars have different numbers of models
# and modes. Entries beyond a star's own (Nmodels, Nmodes) are finite placeholder padding,
# neutralised in the loss by zero model weights (weights_imod_istar) and by the mode mask
# (imode_imod_istar).
#
# Surface-term basis (nu = model frequency, I = mode inertia), both interpolated to nu = numax:
#   f = (nu/numax)**-1 / I,   g = (nu/numax)**3 / I
# model_linear solves
#   D1 = dfreq(numax)       = a1*f + a3*g
#   D2 = dfreq(scale*numax) = scale**-1 * a1*f + scale**3 * a3*g
# for (a1, a3). The value of `scale` is set in model_linear.


def _surface_basis_at_numax(mod_freqs, mod_inertias, numax):
    """Return (f, g) for one model: (nu/numax)**-1/I and (nu/numax)**3/I cubically interpolated (extrapolated if needed) to nu = numax."""
    ratio = mod_freqs / numax

    def at_numax(values):
        return scipy.interpolate.interp1d(mod_freqs, values, kind='cubic',
                                          fill_value='extrapolate')(numax)

    return at_numax(ratio**-1.0 / mod_inertias), at_numax(ratio**3.0 / mod_inertias)


Nmodes_max = np.max([stardata[istar]['Nmodes'] for istar in range(Nstars)])
Nmodels_max = np.max([stardata[istar]['Nmodels'] for istar in range(Nstars)])
_grid_shape = (Nmodels_max, Nmodes_max, Nstars)

chi2s_nonseis = np.zeros((Nmodels_max, Nstars))
# Padding placeholders only: these values just need to be finite.
gravs = np.full(_grid_shape, 2.)
Teffs = np.full(_grid_shape, 5778.)
fehs = np.full(_grid_shape, 1.)
numaxs = np.full(_grid_shape, 3091.)
fs, gs, fs_d, gs_d = [np.ones(_grid_shape) for _ in range(4)]
obs_freq_data, obs_efreq_data, mod_freq_data, mod_inertia_data, mod_efreq_sys = [np.ones(_grid_shape) for _ in range(5)]

Nmodes_imod_istar = np.ones((Nmodels_max, Nstars))
weights_imod_istar = np.ones((Nmodels_max, Nstars))
imode_imod_istar = np.zeros(_grid_shape)

# idxf[istar] is True when every model input of that star is finite over its real (non-padded) models.
idxf = np.zeros(Nstars, dtype=bool)

for istar in tqdm(range(Nstars)):
    star = stardata[istar]
    Nmodes = star['Nmodes']
    Nmodels = star['Nmodels']
    models = slice(0, Nmodels)
    modes = slice(0, Nmodes)

    mod_freq_data[models, modes, istar] = star['mod_freqs'][:, :]
    mod_inertia_data[models, modes, istar] = star['mod_inertias'][:, :]
    chi2s_nonseis[models, istar] = star['chi2_nonseis']

    # Observations and systematic errors are shared by all models of a star: broadcast over models.
    obs_freq_data[models, modes, istar] = star['obs_freq']
    obs_efreq_data[models, modes, istar] = star['obs_efreq']
    mod_efreq_sys[models, modes, istar] = star['mod_efreq_sys']

    numax = np.asarray(star['numax_scaling'])[:Nmodels]
    basis = np.array([
        _surface_basis_at_numax(star['mod_freqs'][imod, :], star['mod_inertias'][imod, :], numax[imod])
        for imod in range(Nmodels)
    ]).reshape(-1, 2)
    # One value per model, repeated across the mode axis.
    fs[models, :, istar] = basis[:, 0, None]
    gs[models, :, istar] = basis[:, 1, None]
    gravs[models, :, istar] = np.asarray(star['g'])[:Nmodels, None]
    Teffs[models, :, istar] = np.asarray(star['Teff'])[:Nmodels, None]
    fehs[models, :, istar] = np.asarray(star['feh'])[:Nmodels, None]
    numaxs[models, :, istar] = numax[:, None]

    Nmodes_imod_istar[:, istar] = Nmodes
    weights_imod_istar[:, istar] = 0.
    weights_imod_istar[models, istar] = 1/Nmodels
    imode_imod_istar[:, modes, istar] = 1.

    # Check only the star's real rows. The old np.sum(np.isfinite(...)) test also counted the
    # finite padding, so it was always True and never removed a star.
    idxf[istar] = all(np.all(np.isfinite(arr[models, :, istar]))
                      for arr in (gravs, Teffs, fehs, numaxs, fs, gs))

_n_dropped = int(np.sum(~idxf))
if _n_dropped:
    print(f'Dropping {_n_dropped} star(s) with non-finite model inputs.')

# Apply the same star mask to every per-star array (and to stardata) so that the star axes stay
# aligned. Previously mod_efreq_sys, weights_imod_istar, imode_imod_istar and
# Nmodes_imod_istar were left unfiltered.
gravs, Teffs, fehs, numaxs = gravs[:,:,idxf], Teffs[:,:,idxf], fehs[:,:,idxf], numaxs[:,:,idxf]
fs, gs = fs[:,:,idxf], gs[:,:,idxf]
fs_d, gs_d = fs_d[:,:,idxf], gs_d[:,:,idxf]
obs_freq_data, obs_efreq_data = obs_freq_data[:,:,idxf], obs_efreq_data[:,:,idxf]
mod_freq_data, mod_inertia_data = mod_freq_data[:,:,idxf], mod_inertia_data[:,:,idxf]
mod_efreq_sys = mod_efreq_sys[:,:,idxf]
imode_imod_istar = imode_imod_istar[:,:,idxf]
chi2s_nonseis = chi2s_nonseis[:,idxf]
weights_imod_istar = weights_imod_istar[:,idxf]
Nmodes_imod_istar = Nmodes_imod_istar[:,idxf]
stardata = stardata[idxf]
Nstars = len(stardata)

# Relative weights of the seismic and non-seismic chi2 terms in model_linear.
weight_seis = 1. # 1/10.
weight_nonseis = 1.

100%|████████████████████████████████████████████████████████████████████| 1162/1162 [00:49<00:00, 23.48it/s]


In [225]:
mod_freq_data.shape, gravs.shape  # each is (Nmodels_max, Nmodes_max, Nstars)

((1500, 21, 1162), (1500, 21, 1162))

In [226]:
# Make every array read by model_linear a JAX array (JAX's default dtype: float32 unless x64
# mode is enabled). obs_efreq_data, mod_efreq_sys and chi2s_nonseis were previously left as
# NumPy float64 arrays, so parts of the loss were evaluated eagerly in NumPy during tracing.
mod_freq_data = jnp.asarray(mod_freq_data)
numaxs = jnp.asarray(numaxs)
mod_inertia_data = jnp.asarray(mod_inertia_data)
obs_freq_data = jnp.asarray(obs_freq_data)
obs_efreq_data = jnp.asarray(obs_efreq_data)
mod_efreq_sys = jnp.asarray(mod_efreq_sys)

Teffs = jnp.asarray(Teffs)
gravs = jnp.asarray(gravs)
fehs = jnp.asarray(fehs)

gs = jnp.asarray(gs)
fs = jnp.asarray(fs)

chi2s_nonseis = jnp.asarray(chi2s_nonseis)
weights_imod_istar = jnp.asarray(weights_imod_istar)
imode_imod_istar = jnp.asarray(imode_imod_istar)

# Optimisation

In [227]:
# Loss to be minimised: negative log-likelihood of the two-anchor surface-correction model.
def model_linear(thetas, *s, scale=1.1): # use numax and scale*numax, feh
    """Return -sum over stars of log(sum over models of w * exp(-chi2/2)).

    The frequency offsets at numax (D1) and at scale*numax (D2) are each parameterised as
        D = c * gravs**p_g * (Teffs/5777)**p_T * (k_feh*fehs + 1),
    with (c, p_g, p_T, k_feh) = thetas[0:4] for D1 and thetas[4:8] for D2. Solving
        D1 = a1*fs + a3*gs,   D2 = scale**-1 * a1*fs + scale**3 * a3*gs
    for (a1, a3) gives the surface term
        dfreq = (a3*(nu/numax)**3 + a1*(nu/numax)**-1) / inertia,
    which is added to every model frequency.

    Each model's chi2 combines two terms:
      - seismic: the mean over the star's modes (masked by imode_imod_istar) of the squared
        residual divided by obs_efreq_data**2 + mod_efreq_sys**2, scaled by weight_seis;
      - non-seismic: chi2s_nonseis, scaled by weight_nonseis.
    Models are combined per star with logsumexp using weights_imod_istar, and the stars'
    log-probabilities are summed with nansum. All data arrays are read from module-level
    globals prepared in earlier cells.

    Parameters
    ----------
    thetas : array of length 8 (ordering above).
    *s : unused; accepted for backward compatibility.
    scale : float, keyword-only, default 1.1
        Ratio between the two anchor frequencies. Must not be 1, because the system for
        (a1, a3) is then singular.

    Returns
    -------
    Scalar JAX array.

    Note: nansum removes stars with a NaN log-probability from the value, but such stars can
    still make the gradient NaN (JAX propagates NaN through masked branches). Remove non-finite
    inputs before fitting.
    """
    if isinstance(scale, (int, float)) and scale == 1:
        raise ValueError('scale must differ from 1: D1 and D2 would be degenerate')

    D1 = thetas[0] * (gravs)**thetas[1] * (Teffs/5777.)**thetas[2] * (thetas[3]*fehs + 1) # surf corr at numax
    D2 = thetas[4] * (gravs)**thetas[5] * (Teffs/5777.)**thetas[6] * (thetas[7]*fehs + 1) # surf corr at scale*numax

    # Closed-form solution of the 2x2 linear system for the surface-term coefficients.
    surface_a3 = (scale**-1 * D1 - D2)/ ((scale**-1 - scale**3) *gs)
    surface_a1 = (scale**3 * D1 - D2)/ ((scale**3 - scale**-1) *fs)

    nu_over_numax = mod_freq_data/numaxs
    dfreq = (surface_a3 * nu_over_numax**3.0 + surface_a1 * nu_over_numax**-1.0)/mod_inertia_data
    mod_freq_corr_data = mod_freq_data + dfreq

    # Mean over the mode axis, restricted to each star's real modes.
    chi2s_seis = jnp.mean((mod_freq_corr_data-obs_freq_data)**2.0/(obs_efreq_data**2. + mod_efreq_sys**2.),
                          axis=1,
                          where=imode_imod_istar==1.)

    chi2 = chi2s_seis*weight_seis + chi2s_nonseis*weight_nonseis
    # Weighted marginalisation over models; padded models carry zero weight.
    logP_star = jax.scipy.special.logsumexp(-chi2/2, axis=0, b=weights_imod_istar)

    return -jnp.nansum(logP_star)

jit_model_linear = jax.jit(model_linear)

In [228]:
# Initial guess for the 8 parameters of model_linear:
#   thetas[0:4] -> D1 (offset at numax):       prefactor, gravs exponent, (Teffs/5777) exponent, fehs coefficient
#   thetas[4:8] -> D2 (offset at scale*numax): same ordering
paramsInit = jnp.array([
    -4.22, 1.0, -8.44, -0.41,   # D1
    -5.22, 1.0, -6.44, -0.41,   # D2
])

# Previously recorded RGB-sample result (presumably best-fit values, then their uncertainties):
# [-7.9282274 ,  1.2846141 , -6.3388844 , -0.01533486,
#  -8.494455  ,  1.1319928 , -5.3336873 , -0.30835932]
# [0.41842967, 0.02130897, 0.3689297 , 0.03597676,
#  0.4364696 , 0.0170249 , 0.3102947 , 0.03083902]

# Smoke tests: the jitted loss evaluates, and JAX can differentiate it.
for _check in (jit_model_linear, jax.grad(jit_model_linear)):
    print(_check(paramsInit))

2022-12-13 00:06:21.176150: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 4s:

  reduce.109 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2022-12-13 00:06:22.150347: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 4.9835s
Constant folding an instruction is taking > 4s:

  reduce.109 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).

This isn't nece

44991.23
[-2.4030268e+04 -3.6665669e+05 -1.9202863e+04  1.6515952e+03
  1.7273779e+04  3.2324522e+05  1.6820492e+04  4.5648476e+01]


In [230]:
%timeit jit_model_linear(paramsInit)

150 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
import optax

def fit(params, optimizer, n_steps=10000, print_every=200, loss_fn=None):
    """Minimise a scalar loss from `params` with an optax optimizer and return the final parameters.

    Parameters
    ----------
    params : initial parameters passed to the loss (e.g. paramsInit).
    optimizer : optax gradient transformation, e.g. optax.adam(...).
    n_steps : number of optimisation steps (default 10000).
    print_every : print the loss every this many steps (default 200); 0 or None disables printing.
    loss_fn : callable params -> scalar loss; defaults to the module-level model_linear.

    Returns
    -------
    The parameters after the last step.
    """
    if loss_fn is None:
        loss_fn = model_linear
    opt_state = optimizer.init(params)

    @jax.jit
    def step(params, opt_state):
        # loss_value is evaluated at the parameters *before* this step's update.
        loss_value, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    for i in range(n_steps):
        params, opt_state, loss_value = step(params, opt_state)
        if print_every and i % print_every == 0:
            print(f'step {i}, loss: {loss_value}')

    return params

# Finally, fit the parametrised surface-correction model using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=2e-2)
params = fit(paramsInit, optimizer)

step 0, loss: 44991.23046875
step 200, loss: 23910.59375
step 400, loss: 23828.685546875
step 600, loss: 23810.46875
step 800, loss: 23801.119140625
step 1000, loss: 23732.498046875
step 1200, loss: 23729.89453125
step 1400, loss: 23728.525390625
step 1600, loss: 23727.44140625
step 1800, loss: 23726.53515625
step 2000, loss: 23725.80859375
step 2200, loss: 23725.234375
step 2400, loss: 23724.771484375
step 2600, loss: 23724.392578125
step 2800, loss: 23724.06640625
step 3000, loss: 23723.7734375
step 3200, loss: 23723.494140625
step 3400, loss: 23723.25
step 3600, loss: 23723.06640625
step 3800, loss: 23722.875


In [None]:
# Fitted parameters returned by `fit` (same 8-parameter ordering as paramsInit).
params

In [None]:
# Approximate 1-sigma uncertainties from the Hessian: the loss is a negative log-likelihood,
# so the diagonal of the inverse Hessian at the optimum approximates the parameter variances.
import warnings

hessian_at_opt = jax.hessian(jit_model_linear)(params)
var_diag = jnp.diag(jnp.linalg.inv(hessian_at_opt))
# A non-positive variance means the Hessian is not positive definite here (not a minimum, or
# numerically ill-conditioned). The abs() below would silently hide that, so flag it.
if not bool(jnp.all(var_diag > 0)):
    warnings.warn('Inverse Hessian has non-positive diagonal entries; the optimiser may not have '
                  'converged, so the corresponding uncertainties are not meaningful.')
err = jnp.abs(var_diag)**0.5
err

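# Results
Each sample below lists the fitted `params` followed by their Hessian-based uncertainties `err`, in the same parameter order as `paramsInit`.
## All sample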
DeviceArray([-3.7443726,  1.0866573, -8.742777 , -1.3795667, 
             -4.9891973,  1.0541984, -8.15369  , -1.3154854], dtype=float32)

DeviceArray([0.13330452, 0.01283954, 0.22996646, 0.02094826, 
             0.13725656, 0.00977644, 0.18469958, 0.01802539], dtype=float32)
             
## RGB sample
DeviceArray([-3.8369734,  1.0976515, -8.83042  , -1.3841832, 
             -5.0435867,  1.0658274, -8.325036 , -1.3250351], dtype=float32)
             
DeviceArray([0.16418463, 0.01225858, 0.24061048, 0.02119218, 
             0.16585062, 0.00967092, 0.1921673 , 0.01829769], dtype=float32)

## MS+SG sample
DeviceArray([-3.715097  ,  0.6096728 , -1.6813153 , -0.08168457,
             -5.550722  ,  0.6479634 , -1.5653569 , -0.64631474],            dtype=float32)
   
DeviceArray([0.40209517, 0.0793009 , 0.82480085, 0.1394154 , 
             0.42107266, 0.05699116, 0.6225904 , 0.11118653], dtype=float32)

