In [1]:
#to read files, get the stellar parameters, blur using a Gaussian kernel 
import numpy as np
import pandas as pd
import re, glob
from astropy.convolution import convolve, Gaussian1DKernel
#gives file name without directories - for obtaining stellar parameters
from pathlib import Path

#fixed wavelength grid matching desi wavelength range 
#to fix the output dimension from the emulator - out_dim = len(WL_GRID) 
#fixed spacing too, seems appropriate for the desi wavelength range 
L_MIN, L_MAX, DLAM = 3600.0, 9824.0, 0.8

#building the wavelength grid vector - used for emulator output axis
def make_uniform_grid(lmin=L_MIN, lmax=L_MAX, dlam=DLAM):
    """Return a uniformly spaced wavelength grid running from lmin up to and including lmax.

    Parameters
    ----------
    lmin, lmax : float
        First wavelength and upper bound of the grid (defaults L_MIN / L_MAX).
    dlam : float
        Constant spacing between neighbouring samples (default DLAM).

    Returns
    -------
    numpy.ndarray
        lmin, lmin + dlam, ... up to the last sample that does not exceed lmax.
        When (lmax - lmin) is a whole number of steps the final sample is
        exactly lmax. The array is empty when lmax < lmin.
    """
    #tolerance, in units of one step, used to absorb float round-off
    tol = 1e-9
    #whole number of steps that fit between lmin and lmax
    #without the tolerance a quotient such as 0.3/0.1 = 2.9999999999999996
    #floors to 2 and silently drops the sample at lmax
    n_steps = int(np.floor((lmax - lmin) / dlam + tol))
    grid = lmin + np.arange(n_steps + 1) * dlam
    #snap the final sample onto lmax when it differs only by round-off,
    #so the grid never overshoots the requested upper bound
    if grid.size and abs(grid[-1] - lmax) <= tol * abs(dlam):
        grid[-1] = lmax
    return grid

#def make_log_grid(lmin=3600., lmax=9800., R=3000., oversample=3):
    # pixel size ~ FWHM/R/oversample in ln λ
    #dln = 1.0/(R*oversample)
    #n = int(np.floor(np.log(lmax/lmin)/dln)) + 1
    #return lmin*np.exp(np.arange(n)*dln)

#shared wavelength axis - built once from the module-level bounds and spacing
WL_GRID = make_uniform_grid(lmin=L_MIN, lmax=L_MAX, dlam=DLAM)

#obtaining the stellar parameters from the file names - inputs to emulator
#the pattern is based on the file names in the data directory - looks for substrings in the filenames 
# _T<Teff>_g<logg>_m<FeH>_a<aFe>_c<CFe>_n<NFe>
# where <Teff>, <logg>, <FeH>, <aFe>, <CFe>, <NFe> are the stellar parameters 
# Teff in K, logg in cm/s^2... 
#Teff = positive float - the rest allow negative floats 
pattern = re.compile(
    r"_T(?P<Teff>\d+(?:\.\d+)?)_g(?P<logg>-?\d+(?:\.\d+)?)"
    r"_m(?P<FeH>-?\d+(?:\.\d+)?)_a(?P<aFe>-?\d+(?:\.\d+)?)"
    r"_c(?P<CFe>-?\d+(?:\.\d+)?)_n(?P<NFe>-?\d+(?:\.\d+)?)")

#reads one model file and recovers the stellar parameters from its file name
def read_model_dat(path):
    """Load a model spectrum and the six stellar parameters encoded in its file name.

    The file is read as whitespace-separated columns WAV, FLUX, CONT, FLUX_CONT
    after skipping two header lines. Only WAV and FLUX are returned.

    Parameters
    ----------
    path : str or path-like
        Location of the model file. Only the final path component is matched
        against the module-level `pattern`.

    Returns
    -------
    wl : numpy.ndarray
        Wavelengths, reordered to be ascending if the file was not already
        strictly increasing.
    flux : numpy.ndarray
        Flux values, reordered together with wl.
    params : numpy.ndarray
        Float vector [Teff, logg, FeH, aFe, CFe, NFe] of shape (6,).

    Raises
    ------
    ValueError
        If the file name does not contain the expected parameter pattern.
    """
    #two header lines are skipped; every column is parsed as float
    table = pd.read_csv(path, sep=r"\s+", skiprows=2,
                        names=["WAV", "FLUX", "CONT", "FLUX_CONT"],
                        dtype=float, engine="c")
    wl, flux = (table[column].to_numpy() for column in ("WAV", "FLUX"))
    #any non-positive step means the wavelengths are not strictly increasing,
    #so reorder both arrays by wavelength
    if (np.diff(wl) <= 0).any():
        order = np.argsort(wl)
        wl = wl[order]
        flux = flux[order]
    #the parameters live in the file name only - directories are ignored
    match = pattern.search(Path(path).name)
    if match is None:
        raise ValueError(f"Cannot parse params from {Path(path).name}")
    #fixed ordering of the parameter vector
    keys = ("Teff", "logg", "FeH", "aFe", "CFe", "NFe")
    params = np.array([float(match.group(key)) for key in keys], float)
    return wl, flux, params

In [6]:
#DESI spectral lines are broadened by the instrument
#apply Gaussian blurring to the flux to mimic this - the kernel width is needed in pixels
def estimate_sigma_pixels(wl, R=3000.0, min_sigma_px=0.3):
    """Convert a resolving power R into a Gaussian sigma measured in pixels.

    A single kernel width is derived for the whole spectrum from its median
    wavelength and median pixel size.

    Parameters
    ----------
    wl : array-like
        Increasing wavelength samples of the spectrum (at least two).
    R : float
        Resolving power, R = lambda / FWHM.
    min_sigma_px : float
        Lower bound on the returned sigma, so the kernel never becomes too
        narrow to be sampled. Defaults to 0.3 pixels.

    Returns
    -------
    float
        Gaussian sigma in pixels, never smaller than min_sigma_px.

    Raises
    ------
    ValueError
        If fewer than two samples are given or the median wavelength step is
        not positive - either case would otherwise give a meaningless width.
    """
    wl = np.asarray(wl, dtype=float)
    if wl.size < 2:
        raise ValueError("need at least two wavelength samples to estimate the pixel size")
    #dlam = median wavelength step = wavelength units per pixel
    #median is used so isolated irregular steps do not dominate
    dlam = np.median(np.diff(wl))
    if not dlam > 0:
        raise ValueError("median wavelength step must be positive (wl must be increasing)")
    #R = lambda / FWHM and sigma = FWHM / (2*sqrt(2*ln 2)) ~ FWHM / 2.3548
    FWHM_over_sigma = 2*np.sqrt(2*np.log(2))
    sigma_A = (np.median(wl) / R) / FWHM_over_sigma
    #convert from wavelength units to pixels and apply the lower bound
    return max(sigma_A / dlam, min_sigma_px)

#preprocessing for a single file 
#y output = what emulator will predict 
#info output = dictionary with parameters and intermediate data 
def preprocess_file_to_grid_logflux(path, wl_grid=WL_GRID, R=3000.0):
    '''Turn one model file into a log-flux spectrum sampled on wl_grid.

    Steps: read the model, blur the flux with a Gaussian kernel whose width is
    derived from the resolving power R, linearly interpolate onto wl_grid, and
    take the natural log.

    Parameters
    ----------
    path : str or path-like
        Model file, passed to read_model_dat.
    wl_grid : numpy.ndarray
        Output wavelength grid (defaults to WL_GRID).
    R : float
        Resolving power used to size the blurring kernel.

    Returns
    -------
    y : numpy.ndarray
        log(flux) on wl_grid. Grid points outside the wavelength range of the
        model are NaN.
    info : dict
        {"params": parameter vector from the file name,
         "sigma_px": Gaussian sigma in pixels used for the blur}
    '''
    #wl, flux = model spectrum, params = [Teff, logg, FeH, aFe, CFe, NFe]
    wl, flux, params = read_model_dat(path)
    #gaussian width in pixels for this spectrum - from the resolving power
    sig_px = estimate_sigma_pixels(wl, R=R)
    #blur on the native sampling, over the full model range
    #boundary="extend" repeats the edge values so the ends are not darkened
    flux_blur = convolve(flux, Gaussian1DKernel(sig_px), boundary="extend")
    #rebin by linear interpolation straight from the untrimmed blurred spectrum
    #trimming to [wl_grid[0], wl_grid[-1]] first would discard the model sample
    #just outside each grid edge, so an edge pixel lying between two model
    #samples became NaN even though the model covers it
    #only grid points truly outside the model coverage are NaN
    y_lin = np.interp(wl_grid, wl, flux_blur, left=np.nan, right=np.nan)
    #clip at 1e-6 before the log so zero or negative flux cannot give -inf/NaN
    #NaN pixels stay NaN
    #NOTE(review): assumes real flux values are well above 1e-6 - confirm units
    y = np.log(np.clip(y_lin, 1e-6, None))
    info = {"params": params, "sigma_px": sig_px}
    return y, info


In [9]:
#looping over many model files to build the training set 
#Y = targets matrix of shape (N, Npix) - log-flux spectra
#P = parameters matrix of shape (N, 6) - stellar parameters 
def build_targets(paths, wl_grid=WL_GRID, R=3000.0):
    """Preprocess every model file and stack the results into training matrices.

    Files are processed in sorted order, so row i of both outputs belongs to
    sorted(paths)[i].

    Parameters
    ----------
    paths : iterable of str
        Model files to preprocess.
    wl_grid : numpy.ndarray
        Wavelength grid forwarded to preprocess_file_to_grid_logflux.
    R : float
        Resolving power forwarded to preprocess_file_to_grid_logflux.

    Returns
    -------
    Y : numpy.ndarray
        (N, Npix) matrix of log-flux spectra.
    P : numpy.ndarray
        (N, 6) matrix of stellar parameters.
    """
    #one (log-flux, info) pair per file, in sorted path order
    processed = [preprocess_file_to_grid_logflux(path, wl_grid, R)
                 for path in sorted(paths)]
    #each 1D result becomes one row of the stacked matrix
    Y = np.vstack([spectrum[None, :] for spectrum, _ in processed])
    P = np.vstack([meta["params"][None, :] for _, meta in processed])
    return Y, P

#computing one mean and std per wavelength pixel
#standardising the outputs with these stabilises training
def compute_output_scaler(Y):
    """Per-pixel mean and standard deviation used to standardise the outputs.

    Parameters
    ----------
    Y : numpy.ndarray
        (N, Npix) matrix; NaN entries are ignored in both statistics.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Column-wise mean and column-wise std, the latter raised by 1e-6 so a
        constant pixel cannot cause a division by zero.
    """
    #small offset added to the spread to keep the later division safe
    floor = 1e-6
    centre = np.nanmean(Y, axis=0)
    spread = np.nanstd(Y, axis=0)
    return centre, spread + floor


In [None]:
#choose only the base .dat files (ignore the _broad.dat variants)
#the "*.dat" glob can never return .mod files, so no extra filter is needed
MODEL_DIR = Path("/export/linnunrata/jls/triple_models")
#sorted so that files[i] lines up with row i of Y and P
#(build_targets processes its inputs in sorted order)
files = sorted(f for f in glob.glob(str(MODEL_DIR / "*.dat"))
               if not f.endswith("_broad.dat"))
#fail early with a clear message instead of a cryptic np.vstack error
if not files:
    raise FileNotFoundError(f"no model .dat files found in {MODEL_DIR}")

#build training set
Y, P = build_targets(files, WL_GRID, R=3000.0)

#compute scalers
#y_mean = average log flux at each wavelength
#y_std = std of log flux at each wavelength + small constant to avoid division by zero
#stabilises gradients by removing large per-pixel variance differences
y_mean, y_std = compute_output_scaler(Y)
#Y_std = what the emulator learns to predict
#standardised output: subtract the per-pixel mean and divide by the per-pixel std
#NOTE: the scalers ignore NaN pixels, but NaNs present in Y still appear in Y_std
Y_std = (Y - y_mean) / y_std

In [None]:
#emulator 
#want to take 6 parameters (Teff, logg, FeH, aFe, CFe, NFe) 
#predict a spectrum on fixed wavelengths grid 

#jax = like numpy but with autodiff - for computing gradients
#jit = just-in-time compilation, speeds up the code - compiles the function to machine code
#flax = neural networks library for JAX, similar to TensorFlow 

#importing the necessary libraries
#jax version of numpy - can run on CPU, GPU, TPU 
import jax, jax.numpy as jnp
#flax = deep learning library for JAX
#linen = for defining neural networks - gives layers, activation functions, etc. 
from flax import linen as nn
#python standard library for type hints
#sequence [int] = a list of integers, represents the number of neurons in each layer
from typing import Sequence


#before running model 
#standardize inputs (subtract mean, divide by std for each of the 6–7 features)
#keeps scales comparable and gradients well-behaved
#after model: de-standardize the predicted spectrum to return to physical flux units on wavelength grid

#converting data to and from standardized form 
#mean/std come from training data, computed for each feature 
#standardise = subtract mean and divide by standard deviation 
#keeping scales comparable - otherwise some features might dominate the gradients (e.g. Teff vs NFe)
#destandardise = multiply by standard deviation and add mean - opposite
#to get real flux units on the wavelength grid after prediction 
def standardise(x, mean, std, eps=1e-6):
    """Return (x - mean) / (std + eps).

    eps keeps the division finite when std is zero for some feature.
    """
    return (x - mean) / (std + eps)


def destandardise(y, mean, std, eps=0.0):
    """Return y * (std + eps) + mean, mapping standardised values back.

    With the default eps=0.0 this is y * std + mean, the exact inverse of a
    plain (x - mean) / std scaling. To undo standardise() exactly, pass the
    same eps that was used there - otherwise the round trip is off by a factor
    std / (std + eps).
    """
    return y * (std + eps) + mean


class emulator(nn.Module):
    """Fully connected network mapping parameter vectors to spectra.

    Each hidden block is Dense -> LayerNorm -> ReLU -> Dropout, followed by a
    final Dense layer producing out_dim values. Every layer acts on the last
    axis only, so any leading batch axes pass through unchanged.

    By the notebook's design the last axis of the input holds the stellar
    parameters [Teff, logg, FeH, aFe, CFe, NFe]; the code does not enforce
    that length.
    """
    #number of wavelength pixels to predict - the spectrum length
    out_dim: int
    #width of each hidden layer, in order - five layers by default
    hidden: Sequence[int] = (512, 1024, 2048, 1024, 512)
    #dropout rate used after every hidden block while training
    dropout: float = 0.1

    #compact = submodules are declared inline inside __call__
    @nn.compact
    def __call__(self, x, *, train=False):
        """Forward pass: x of shape (..., n_params) -> output of shape (..., out_dim).

        train=True switches dropout on; with the default train=False dropout
        is a no-op, which is what inference should use.
        """
        features = x
        for width in self.hidden:
            #affine map of the last axis: features @ W + b, giving `width` new features
            #(the batch axes are untouched)
            features = nn.Dense(features=width)(features)
            #normalise each sample across its feature axis (not across the batch)
            #this helps training stability and convergence
            features = nn.LayerNorm()(features)
            #non-linearity: ReLU(v) = max(0, v)
            features = nn.relu(features)
            #randomly zero activations while training to limit overfitting
            #deterministic=True (i.e. train=False) disables it
            features = nn.Dropout(rate=self.dropout, deterministic=not train)(features)
        #final linear layer maps the last hidden representation to the spectrum
        return nn.Dense(features=self.out_dim)(features)
