In [None]:
#to read files, get the stellar parameters, blur using a Gaussian kernel 
import numpy as np
import pandas as pd
import re, glob
from astropy.convolution import convolve, Gaussian1DKernel
#gives file name without directories - for obtaining stellar parameters
from pathlib import Path

#fixed wavelength grid matching desi wavelength range 
#to fix the output dimension from the emulator - out_dim = len(WL_GRID) 
#fixed spacing too, seems appropriate for the desi wavelength range 
L_MIN, L_MAX, DLAM = 3600.0, 9824.0, 0.8

#building the wavelength grid vector - used for emulator output axis 
def make_uniform_grid(lmin=L_MIN, lmax=L_MAX, dlam=DLAM):
    #making it inclusive of lmax
    #returns an array of wavelengths 
    n = int(np.floor((lmax - lmin)/dlam)) + 1
    a = lmin + np.arange(n) * dlam
    return a

#def make_log_grid(lmin=3600., lmax=9800., R=3000., oversample=3):
    # pixel size ~ FWHM/R/oversample in ln λ
    #dln = 1.0/(R*oversample)
    #n = int(np.floor(np.log(lmax/lmin)/dln)) + 1
    #return lmin*np.exp(np.arange(n)*dln)

#creating the grid
WL_GRID = make_uniform_grid()

#obtaining the stellar parameters from the file names - inputs to emulator
#the pattern is based on the file names in the data directory - looks for substrings in the filenames 
# _T<Teff>_g<logg>_m<FeH>_a<aFe>_c<CFe>_n<NFe>
# where <Teff>, <logg>, <FeH>, <aFe>, <CFe>, <NFe> are the stellar parameters 
# Teff in K, logg in cm/s^2... 
#Teff = positive float - the rest allow negative floats 
pattern = re.compile(
    r"_T(?P<Teff>\d+(?:\.\d+)?)_g(?P<logg>-?\d+(?:\.\d+)?)"
    r"_m(?P<FeH>-?\d+(?:\.\d+)?)_a(?P<aFe>-?\d+(?:\.\d+)?)"
    r"_c(?P<CFe>-?\d+(?:\.\d+)?)_n(?P<NFe>-?\d+(?:\.\d+)?)")

#function to read model file and obtain the stellar parameters 
#puts them in an array of shape (num_models, 6) 
#as well as the flux and wavlength arrays 
def read_model_dat(path):
    #opening file and reading the data 
    #skipping two header lines, assinging column names 
    df = pd.read_csv(path, sep=r"\s+", skiprows=2,
                     names=["WAV", "FLUX", "CONT", "FLUX_CONT"],
                     dtype=float, engine="c")
    #converting wavelengths and fluxes to numpy arrays 
    wl = df["WAV"].to_numpy()
    flux = df["FLUX"].to_numpy()
    #to make sure that the wavelengths are increasing
    #detects if any step is not increasing 
    if np.any(np.diff(wl) <= 0):
        #sorting the wavelengths and fluxes 
        idx = np.argsort(wl)
        wl, flux = wl[idx], flux[idx]
    #applying the pattern to the file name 
    #tells if it doesnt match 
    m = pattern.search(Path(path).name)
    if not m:
        raise ValueError(f"Cannot parse params from {Path(path).name}")
    #turns captured groups into a dictionary with float values 
    d = {k: float(v) for k, v in m.groupdict().items()}
    #put them into a 6 element vector in order of Teff, logg, FeH, aFe, CFe, NFe 
    params = np.array([d["Teff"], d["logg"], d["FeH"], d["aFe"], d["CFe"], d["NFe"]], float)
    #returns the wavelength, flux and parameters 
    return wl, flux, params

In [None]:
#spectral lines of desi broaded by instrument 
#apply gaussian blurring to the flux but need width in pixels 
def estimate_sigma_pixels(wl, R=3000.0):
    """converting resolving power R into gaussian sigma in pixels""" 
    #dlam = median wavelength step in Angstroms = angstroms per pixel 
    #use median to avoid outliers 
    dlam = np.median(np.diff(wl))
    #converting resolving power to sigma in Angstroms 
    #R = lambda / fwhm - sigma = fwhm / 2.3548.. 
    sigma_A = (np.median(wl) / R) / 2.354820045
    #converting sigma from angstroms to pixels by dividing by dlam = angs/pixel 
    #at least 0.5 pixels to avoid too narrow gaussian 
    return max(sigma_A / dlam, 0.5)

#preprocessing for a single file 
#y output = what emulator will predict 
#info output = dictionary with parameters and intermediate data 
def preprocess_file_to_grid_logflux(path, wl_grid=WL_GRID, R=3000.0):
    '''for each file: blur flux to DESI resolution,
       trim to [wl_grid.min, wl_grid.max], rebin to wl_grid, log-transform
       returns log(flux) on wl_grid and a dictionary with params + intermediate data''' 
    #reading the model data from the file 
    #returns wavelength, flux and parameters
    #wl = wavelength in Angstroms, flux = flux in erg/s/cm^2
    #params = [Teff, logg, FeH, aFe, CFe, NFe] 
    wl, flux, params = read_model_dat(path)
    #finding gaussian width in pixels for spectrum - from resolving power 
    sig_px = estimate_sigma_pixels(wl, R=R)
    #applying gaussian blur with the sigma in pixels
    #building kernel in pixel units + extending the boundary so end flux not lost 
    #to match desi finite resolution 
    #synthetic models not - might get line width mismatches and biased fits 
    flux_blur = convolve(flux, Gaussian1DKernel(sig_px), boundary="extend")
    #masking the wavelengths that are outside the grid range - trims the spectrum 
    mask = (wl >= wl_grid[0]) & (wl <= wl_grid[-1])
    wl_t, fx_t = wl[mask], flux_blur[mask]
    #rebinning to the fixed wavelength grid
    #linearly interpolating the blurred flux to the fixed wavelength grid
    #anything that falls outside model coverage will be NaN
    y_lin = np.interp(wl_grid, wl_t, fx_t, left=np.nan, right=np.nan)
    #converting flux to log flux to avoid negative values
    #clipping to avoid log(0) or log(negative) - set to 1e-6
    y = np.log(np.clip(y_lin, 1e-6, None))
    #returning the log flux on the fixed wavelength grid and a dictionary with parameters
    #and the sigma in pixels for the Gaussian kernel
    info = {"params": params, "sigma_px": sig_px}
    return y, info


In [27]:
def build_targets(paths, wl_grid=WL_GRID, R=3000.0):
    """
    Build training set from many files.
    Returns:
      Y : (N, Npix) matrix of log-flux spectra
      P : (N, 6) parameter matrix
    """
    Ys, Ps = [], []
    for p in sorted(paths):
        y, info = preprocess_file_to_grid_logflux(p, wl_grid, R)
        Ys.append(y[None, :])
        Ps.append(info["params"][None, :])
    Y, P = np.vstack(Ys), np.vstack(Ps)
    return Y, P

def compute_output_scaler(Y):
    """
    Per-pixel mean/std for output standardisation.
    """
    y_mean = np.nanmean(Y, axis=0)
    y_std = np.nanstd(Y, axis=0) + 1e-6
    return y_mean, y_std


In [None]:
# choose only the base .dat files (ignore _broad.dat and .mod)
files = [f for f in glob.glob("/export/linnunrata/jls/triple_models/*.dat")
         if not f.endswith("_broad.dat") and not f.endswith(".mod")]

# build training set
Y, P = build_targets(files, WL_GRID, R=3000.0)

# compute scalers
y_mean, y_std = compute_output_scaler(Y)
Y_std = (Y - y_mean) / y_std

In [None]:
#emulator 
#want to take 6 parameters (Teff, logg, FeH, aFe, CFe, NFe) 
#predict a spectrum on fixed wavelengths grid 

#jax = like numpy but with autodiff - for computing gradients
#jit = just-in-time compilation, speeds up the code - compiles the function to machine code
#flax = neural networks library for JAX, similar to TensorFlow 

#importing the necessary libraries
#jax version of numpy - can run on CPU, GPU, TPU 
import jax, jax.numpy as jnp
#flax = deep learning library for JAX
#linen = for defining neural networks - gives layers, activation functions, etc. 
from flax import linen as nn
#python standard library for type hints
#sequence [int] = a list of integers, represents the number of neurons in each layer
from typing import Sequence


#before running model 
#standardize inputs (subtract mean, divide by std for each of the 6–7 features)
#keeps scales comparable and gradients well-behaved
#after model: de-standardize the predicted spectrum to return to physical flux units on wavelength grid

#converting data to and from standardized form 
#mean/std come from training data, computed for each feature 
#standardise = subtract mean and divide by standard deviation 
#keeping scales comparable - otherwise some features might dominate the gradients (e.g. Teff vs NFe)
#destandardise = multiply by standard deviation and add mean - opposite
#to get real flux units on the wavelength grid after prediction 
def standardise(x, mean, std, eps=1e-6): return (x - mean) / (std + eps)
def destandardise(y, mean, std):         return y * std + mean


class emulator(nn.Module):
    #how many wavelengths to predict - the spectrum length 
    out_dim: int  
    #widths of hidden layers - 5 layers with different sizes 
    hidden: Sequence[int] = (512,1024,2048,1024,512)
    #random dropout rate - prevent overfitting by dropping neurons during training 
    dropout: float = 0.1

    #forward pass of the model 
    #transformations applied to the input data with learnable weights 
    #x: input features - [Teff, logg, FeH, aFe, CFe, NFe] = (...,6) 
    #input : ... can be any number, but the last dimension must be 6 - vector of stellar parameters
    #can have mutiple stars in a batch, so the input shape can be (2,6) 

    #compact = define the model layers in a concise way
    #define sublayers in the __call__ method 
    @nn.compact
    #__call__ = the forward pass of the model, takes input x and returns output
    def __call__(self, x, *, train=False):    
        h = x
        #for each layer in the hidden layers
        #apply a dense layer, layer normalization, ReLU activation, and dropout 
        for d in self.hidden:
            #dense layers act only on batch, keeping the last dimension = 6 
            #affine transformation, linear transformation with weights and biases 
            #h = h @ w + b - w = (inputs ,d), b = (d,) (d=2 for two stars) 
            #mixes all input features together linearly, to produce d new features 
            h = nn.Dense(d)(h)
            #layer normalization = normalizes the output of the dense layer
            #computing the mean and variance across features, not across the batch
            #normalizes each feature to have zero mean and unit variance
            #this helps with training stability and convergence
            #h = (h - mean) / sqrt(var + eps)
            #eps = small constant to avoid division by zero 
            h = nn.LayerNorm()(h)
            #ReLU activation = rectified linear unit, applies non-linearity
            #ReLU(x) = max(0, x) - sets negative values to zero 
            #network can do complex mappings - spectral lines shapes etc. 
            h = nn.relu(h)
            #during training, apply dropout to the output of the ReLU activation
            #dropout = randomly sets some neurons to zero, to prevent overfitting 
            #other neurons scale their output to keep the same expected value 
            #prevent network from relying too much on specific neurons 
            #only applied during training, not during inference 
            h = nn.Dropout(self.dropout, deterministic=not train)(h)
        #after all hidden layers, apply a final dense layer to produce the output
        #output shape = (batch_size, out_dim) = (2, 1000) for two stars and 1000 wavelengths e.g. 4000-8000 Angstroms
        #this layer maps the final hidden representation to the output dimension
        y = nn.Dense(self.out_dim)(h)
        return y


In [53]:
# train_flax_emulator.py
import jax, jax.numpy as jnp
import optax, pickle
from flax.training.train_state import TrainState
from flax.training import checkpoints

def make_state(rng, out_dim, x_mean, x_std, y_mean, y_std, lr=1e-3):
    model = FlaxEmulator(out_dim=out_dim)
    params = model.init(rng, jnp.zeros((1,6)), train=True)
    tx = optax.adamw(lr)
    return TrainState.apply_fn==model.apply, params==params, tx==tx, opt_state==tx.init(params), 
    # Use a small helper container if your env expects a single object.

def loss_fn(params, apply_fn, xb, yb, x_mean, x_std, y_mean, y_std):
    xbn = standardize(xb, x_mean, x_std)
    predn = apply_fn(params, xbn, train=True)
    pred  = destandardize(predn, y_mean, y_std)
    return jnp.mean((pred - yb)**2)

@jax.jit
def train_step(state, xb, yb, x_mean, x_std, y_mean, y_std):
    l, grads = jax.value_and_grad(loss_fn)(state.params, state.apply_fn, xb, yb, x_mean, x_std, y_mean, y_std)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)
    return state.replace(params=new_params, opt_state=new_opt_state), l

# Pseudocode usage:
# X (N,6), Y (N, L), wav (L,), compute x_mean/std, y_mean/std, then loop train_step(...)
# Save with checkpoints.save_checkpoint(..., target={'params':..., 'x_mean':..., 'x_std':..., 'y_mean':..., 'y_std':..., 'wav': wav})


In [54]:
# forward_model.py
import jax.numpy as jnp

def apply_extinction(flux, A_lambda, ebv):     # A_lambda per unit E(B-V) on the same wav grid
    return flux * 10**(-0.4 * A_lambda * ebv)

def blend(flux1, flux2, W):
    return W * flux1 + (1. - W) * flux2

def emulator_predict(params_emul, apply_fn, x_mean, x_std, y_mean, y_std, theta6):
    x = (theta6 - x_mean) / (x_std + 1e-6)
    yn = apply_fn(params_emul, x, train=False)
    return yn * y_std + y_mean

def model_fibre(params_emul, apply_fn, stats, theta1_6, ebv1, theta2_6, ebv2, W, A_lambda):
    spec1 = emulator_predict(params_emul, apply_fn, stats['x_mean'], stats['x_std'], stats['y_mean'], stats['y_std'], theta1_6)
    spec2 = emulator_predict(params_emul, apply_fn, stats, theta2_6)  # or pass stats fields explicitly
    spec1 = apply_extinction(spec1, A_lambda, ebv1)
    spec2 = apply_extinction(spec2, A_lambda, ebv2)
    return blend(spec1, spec2, W)


In [55]:
# photometry_jax.py
import jax.numpy as jnp

def trapz(y, x): dx = jnp.diff(x); return jnp.sum(0.5*(y[...,1:]+y[...,:-1])*dx, axis=-1)

def synth_mags(wav_um, flux, filters):
    """filters: dict[name] -> (T_on_wav, zp_offset). T already on wav_um grid; AB by default."""
    c = 2.99792458e10
    mags = {}
    for name, (T, zp) in filters.items():
        w = wav_um
        num = trapz(flux * T * w, w)      # photon counting
        den = trapz(T * w, w) + 1e-30
        f_lambda = num / den
        lam_cm = jnp.mean(w) * 1e-4
        f_nu = f_lambda * lam_cm**2 / c
        mags[name] = -2.5*jnp.log10(jnp.clip(f_nu, 1e-50)) - 48.6 + zp
    return mags


In [56]:
# logposterior.py
import jax, jax.numpy as jnp
from jax.nn import logsumexp

def loglike_spec(obs_flux, obs_ivar, mod_flux):
    resid = obs_flux - mod_flux
    return -0.5 * jnp.sum(resid**2 * obs_ivar)

def loglike_phot(m_obs, m_pred, sigma):
    return -0.5 * jnp.sum((m_obs - m_pred)**2 / (sigma**2))

def build_logposterior(apply_fn, params_emul, stats, wav_um, A_lambda, filters, priors):
    def unpack(theta):
        # theta = [θ1(6), ebv1, θ2(6), ebv2, W_A, W_B]
        t1 = theta[0:6]; ebv1 = theta[6]
        t2 = theta[7:13]; ebv2 = theta[13]
        W_A = jnp.clip(theta[14], 0., 1.); W_B = jnp.clip(theta[15], 0., 1.)
        return t1, ebv1, t2, ebv2, W_A, W_B

    def fibre_loglike(t1, ebv1, t2, ebv2, W_A, W_B, data):
        from forward_model import model_fibre
        mod_A = model_fibre(params_emul, apply_fn, stats, t1, ebv1, t2, ebv2, W_A, A_lambda)
        mod_B = model_fibre(params_emul, apply_fn, stats, t1, ebv1, t2, ebv2, W_B, A_lambda)
        ll = loglike_spec(data['A']['flux'], data['A']['ivar'], mod_A)
        ll+= loglike_spec(data['B']['flux'], data['B']['ivar'], mod_B)

        if 'phot' in data:
            from photometry_jax import synth_mags
            blend_for_phot = 0.5*(mod_A + mod_B)   # or a chosen mapping from fibres → imaging
            mpred = synth_mags(wav_um, blend_for_phot, filters)
            m_pred_vec = jnp.array([mpred[b] for b in data['phot']['bands']])
            ll += loglike_phot(data['phot']['m'], m_pred_vec, data['phot']['sigma'])
        return ll

    def logprior(t1, ebv1, t2, ebv2, W_A, W_B):
        return (priors['star'](t1) + priors['ebv'](ebv1) +
                priors['star'](t2) + priors['ebv'](ebv2) +
                priors['W'](W_A)   + priors['W'](W_B))

    def logpost(theta, data):
        t1, e1, t2, e2, WA, WB = unpack(theta)
        # two labelings: (1,2) and (2,1)
        lp  = logprior(t1,e1,t2,e2,WA,WB)
        ll1 = fibre_loglike(t1,e1,t2,e2,WA,WB, data)
        ll2 = fibre_loglike(t2,e2,t1,e1,WA,WB, data)
        # permutation-safe combine
        return logsumexp(jnp.array([lp + ll1, lp + ll2]))
    return logpost


In [59]:
# nuts_driver.py
import jax, jax.numpy as jnp, blackjax

def run_nuts(logpost, data, theta_init, n_warm=1000, n_samp=2000, seed=0):
    def logprob(theta): return logpost(theta, data)
    rng = jax.random.PRNGKey(seed)
    kernel = blackjax.nuts(logprob)
    state = kernel.init(theta_init, rng)

    @jax.jit
    def one_step(state, key):
        new_state, info = kernel.step(key, state)
        return new_state, info

    samples = []
    key = rng
    for i in range(n_warm + n_samp):
        key, sub = jax.random.split(key)
        state, _ = one_step(state, sub)
        if i >= n_warm:
            samples.append(state.position)
    return jnp.stack(samples, axis=0)


In [60]:
# priors.py
import jax.numpy as jnp

def box_gauss_vec(x, mu, sig, lo, hi):
    inside = jnp.all((x>=lo)&(x<=hi))
    return jnp.where(inside, -0.5*jnp.sum(((x-mu)/sig)**2), -jnp.inf)

def star_prior_factory(ranges):
    lo = jnp.array([a for a,b in ranges]); hi = jnp.array([b for a,b in ranges])
    mu = 0.5*(lo+hi); sig = (hi-lo)/3.
    return lambda x: box_gauss_vec(x, mu, sig, lo, hi)

def ebv_prior(mu=0.07, sig=0.05, lo=0., hi=1.0):
    def lp(x):
        inside = (x>=lo)&(x<=hi)
        return jnp.where(inside, -0.5*((x-mu)/sig)**2, -jnp.inf)
    return lp

def beta01(a=1., b=1.):
    from jax.scipy.special import betaln
    Z = betaln(a,b)
    def lp(w):
        inside = (0.<=w)*(w<=1.)
        return jnp.where(inside, (a-1.)*jnp.log(jnp.clip(w,1e-12,1.)) + (b-1.)*jnp.log(jnp.clip(1.-w,1e-12,1.)) - Z, -jnp.inf)
    return lp


In [61]:
# run_inference.py
import jax.numpy as jnp, jax
from flax.training import checkpoints

# 1) Load emulator checkpoint and stats
ckpt = checkpoints.restore_checkpoint("emulator_ckpt_dir", target=None)
params_emul = ckpt['params']
x_mean, x_std = ckpt['x_mean'], ckpt['x_std']
y_mean, y_std = ckpt['y_mean'], ckpt['y_std']
wav_um       = ckpt['wav']          # ensure this matches DESI grid used for data

emulator = FlaxEmulator(out_dim=len(wav_um))
apply_fn = emulator.apply
stats = {'x_mean': x_mean, 'x_std': x_std, 'y_mean': y_mean, 'y_std': y_std}

# 2) Prepare data dict
data = {
  'A': {'flux': flux_A, 'ivar': ivar_A},
  'B': {'flux': flux_B, 'ivar': ivar_B},
  'phot': {'m': m_obs, 'sigma': m_err, 'bands': bands},   # optional
}
A_lambda = A_lambda_on_wav      # array length = len(wav_um)
filters  = preinterp_filters    # dict: name -> (T_on_wav, zp)

# 3) Priors
ranges = jnp.array([
    [4000., 6500.],   # Teff
    [3.5, 5.0],       # logg
    [-2.5, 0.5],      # [Fe/H]
    [-0.2, 0.5],      # [alpha/Fe]
    [-0.5, 2.0],      # [C/Fe]
    [-0.5, 2.0],      # [N/Fe]
])
priors = {'star': star_prior_factory(ranges),
          'ebv':  ebv_prior(0.07,0.05,0.,1.0),
          'W':    beta01(1.,1.)}

# 4) Build log-posterior closure
logpost = build_logposterior(apply_fn, params_emul, stats, wav_um, A_lambda, filters, priors)

# 5) Initialisation (use photometry-only predictions if you have them)
theta0 = jnp.array([teff1,logg1,feh1,a1,c1,n1, ebv1,
                    teff2,logg2,feh2,a2,c2,n2, ebv2,
                    0.7, 0.3])

# 6) Sample
samples = run_nuts(logpost, data, theta0, n_warm=1000, n_samp=2000, seed=123)


TypeError: 'NoneType' object is not subscriptable