In [2]:
import json
import os
from itertools import product
from pathlib import Path
from typing import List, Tuple, Union

import corner
import jax
import jax.numpy as jnp
import keras
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as tck
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
import tensorflow as tf
from keras import layers
from sklearn.decomposition import PCA
from tensorflow.keras.models import Model

In [3]:
from scripts.PCANN import PCANN
from scripts.WMSE import WMSE

ModuleNotFoundError: No module named 'scripts.PCANN'

In [3]:
# Name of the training run; the emulator directory and the log / PCA /
# checkpoint / history file names in the next cell are all built from it.
run_name = "smiley_stingray"

In [4]:
# Directory and file names are all derived from run_name.
emulatorpath = f"./{run_name}/"

logfile = f"log_{run_name}.json"
pcafile = f"pca_{run_name}.json"
checkpointfile = f"{run_name}_checkpoint.h5"
historyfile = f"history_{run_name}.json"

# Parse the run's JSON log; the file only needs to stay open for json.load,
# the individual settings are unpacked afterwards.
with open(os.path.join(emulatorpath, logfile), "r") as fp:
    data = json.load(fp)

# Grid location recorded in the log
gridpath = data["gridpath"]
gridfile = data["gridfile"]
grid = os.path.join(gridpath, gridfile)

# Settings recorded in the log
seed = data["seed"]
n_components = data["n_components"]

batch_size_exp = data["batch_size_exp"]
epochs = data["epochs"]
test_size = data["test_size"]
fractrain = data["fractrain"]

# Column names of the network inputs and outputs
inputs = data["inputs"]
classical_outputs = data["classical_outputs"]
nmin = data["nmin"]
nmax = data["nmax"]

# One "nu_0_<k>" label for every k from nmin to nmax (both inclusive)
astero_outputs = [f"nu_0_{order}" for order in range(nmin, nmax + 1)]
outputs = classical_outputs + astero_outputs

## Helper functions: loading the emulator, scaling/descaling, and calling it

In [5]:
def load_pca_components(
    emulatorpath: str, pcafile: str
) -> Tuple[np.ndarray, np.ndarray]:
    """Read the stored PCA components and mean from a JSON file.

    Parameters
    ----------
    emulatorpath : str
        Directory that contains ``pcafile``.
    pcafile : str
        Name of a JSON file holding the keys ``"pca_comps"`` and ``"pca_mean"``.

    Returns
    -------
    (pca_comps, pca_mean) : tuple of np.ndarray
        The two JSON entries converted to NumPy arrays.

    Raises
    ------
    FileNotFoundError
        If the file does not exist.
    KeyError
        If either expected key is missing from the JSON document.
    """
    with open(os.path.join(emulatorpath, pcafile), "r") as fp:
        data = json.load(fp)

    pca_comps = np.array(data["pca_comps"])
    pca_mean = np.array(data["pca_mean"])
    return pca_comps, pca_mean


def get_weights_and_biases(tf_model: Model) -> Tuple[list, list]:
    """Split a model's variables into two lists of JAX arrays.

    ``tf_model.weights`` is treated as a flat sequence: entries at even
    positions are returned as ``weights`` and entries at odd positions as
    ``biases``, each converted with ``jnp.asarray``.

    NOTE(review): this assumes every layer contributes exactly one
    (kernel, bias) pair, in that order. A layer without a bias, or with extra
    variables, would shift the pairing -- confirm against the architecture.

    Returns
    -------
    (weights, biases) : tuple of list
        Two lists of JAX arrays.
    """
    variables = tf_model.weights
    weights = [jnp.asarray(kernel) for kernel in variables[::2]]
    biases = [jnp.asarray(bias) for bias in variables[1::2]]
    return weights, biases


def load_tf_model(
    emulatorpath: str,
    checkpointfile: str,
    pcafile: str | None = None,
    pca_comps: np.ndarray | None = None,
    pca_mean: np.ndarray | None = None,
    n: int = 25,
):
    """Load the saved Keras model from ``emulatorpath/checkpointfile``.

    The checkpoint is deserialised with two custom objects: ``PCANN`` built
    from the PCA components and mean, and ``WMSE`` built from ``np.ones(n)``.

    Parameters
    ----------
    emulatorpath : str
        Directory holding the checkpoint (and the PCA file, if used).
    checkpointfile : str
        File name of the saved model.
    pcafile : str, optional
        PCA JSON file. Only read when ``pca_comps`` or ``pca_mean`` is
        missing, in which case it must be given.
    pca_comps, pca_mean : np.ndarray, optional
        Already-loaded PCA components and mean.
    n : int
        Length of the all-ones vector handed to ``WMSE``.

    Returns
    -------
    The object returned by ``tf.keras.models.load_model``.
    """
    pca_is_complete = pca_comps is not None and pca_mean is not None
    if not pca_is_complete:
        # Fall back to reading both arrays from disk.
        assert pcafile is not None
        pca_comps, pca_mean = load_pca_components(emulatorpath, pcafile)

    custom_objects = dict(
        PCANN=PCANN(pca_comps, pca_mean),
        WMSE=WMSE(np.ones(n)),
    )
    checkpoint_path = os.path.join(emulatorpath, checkpointfile)
    return tf.keras.models.load_model(checkpoint_path, custom_objects=custom_objects)


def load_emulator(run_name: str, emulatorpath: str, checkpointfile: str, pcafile: str):
    """Assemble the tuple that ``call_emulator`` consumes.

    Loads the PCA arrays and the saved model, extracts the model's weights and
    biases as JAX arrays, and bundles them with three hard-coded index lists.

    ``run_name`` is accepted but not used inside this function.

    Returns
    -------
    tuple
        ``(weights, biases, stem_map, ctine_map, atine_map, pca_comps,
        pca_mean)``.
    """
    pca_comps, pca_mean = load_pca_components(emulatorpath, pcafile)

    weights, biases = get_weights_and_biases(
        load_tf_model(
            emulatorpath,
            checkpointfile,
            pca_comps=pca_comps,
            pca_mean=pca_mean,
        )
    )

    # Positions into `weights` / `biases` for each part of the network.
    # NOTE(review): these are fixed for one specific layer layout -- confirm
    # they match the checkpoint being loaded.
    stem_map, ctine_map, atine_map = (
        [0, 1],
        [-5, -3, -1],
        [-10, -9, -8, -7, -6, -4, -2],
    )

    return (weights, biases, stem_map, ctine_map, atine_map, pca_comps, pca_mean)

In [6]:
logcols = [
    "initial_mass",
    "alphaMLT",
    "radius",
    "luminosity",
    "mass",
]
logcols += ["error_" + col for col in logcols]


def scale(
    data: Union[pd.DataFrame, np.ndarray],
    logcols: List[str] = logcols,
    col_names: List[str] | None = None,
    verbose: bool = False,
) -> Union[pd.DataFrame, np.ndarray]:
    if isinstance(data, np.ndarray):
        if col_names is None:
            raise ValueError("col_names must be provided when data is a NumPy array.")
        df_unnorm = pd.DataFrame(data, columns=col_names)
    else:
        df_unnorm = data

    if col_names is None:
        col_names = df_unnorm.columns
        cols = df_unnorm.values.T
    else:
        cols = data.T

    df_norm = df_unnorm.copy()
    for col_name, col in zip(col_names, cols):
        if col_name in logcols:
            if verbose:
                print(f"{col_name} scaled with log10")
            df_norm[col_name] = np.log10(col)
        elif col_name in ["initial_y", "error_initial_y"]:
            if verbose:
                print(f"{col_name} scaled by multiply with 4 and log10")
            df_norm[col_name] = np.log10(col * 4)
        elif col_name in ["age", "error_age"]:
            if verbose:
                print(f"{col_name} scaled by dividing with 1000 and then log10")
            df_norm[col_name] = np.log10(col / 1000)
        else:
            if verbose:
                print(f"{col_name} not scaled")

    if isinstance(data, np.ndarray):
        return df_norm.values
    return df_norm


def descale(
    data: Union[pd.DataFrame, np.ndarray],
    logcols: List[str] = logcols,
    col_names: List[str] | None = None,
    verbose: bool = False,
) -> Union[pd.DataFrame, np.ndarray]:
    if isinstance(data, np.ndarray):
        if col_names is None:
            raise ValueError("col_names must be provided when data is a NumPy array.")
        df_norm = pd.DataFrame(data, columns=col_names)
    else:
        df_norm = data

    if col_names is None:
        col_names = df_norm.columns
        cols = df_norm.values.T

    df_unnorm = df_norm.copy()
    for col_name, col in zip(col_names, cols):
        if col_name in logcols:
            if verbose:
                print(f"{col_name} descaled using inverse log10")
            df_unnorm[col_name] = 10 ** (col)
        elif col_name == "initial_y":
            if verbose:
                print(f"{col_name} descaled by inverse log10 and then divide by 4")
            df_unnorm[col_name] = (10 ** (col)) / 4
        elif col_name == "age":
            if verbose:
                print(
                    f"{col_name} descaled by inverse log10 and then multiply with 1000"
                )
            df_unnorm[col_name] = (10 ** (col)) * 1000
        else:
            if verbose:
                print(f"{col_name} not descaled")

    if isinstance(data, np.ndarray):
        return df_unnorm.values
    return df_unnorm

In [7]:
def call_emulator(
    input_norm: np.ndarray,
    emulator: Tuple[
        np.ndarray, np.ndarray, list[int], list[int], list[int], np.ndarray, np.ndarray
    ],
    scale_dimensions: List[str] | None = None,
) -> jax.Array:
    stem = input_norm

    (weights, biases, stem_map, ctine_map, atine_map, pca_comps, pca_mean) = emulator

    for index in stem_map:
        stem = jax.nn.elu(jnp.dot(stem, weights[index]) + biases[index])
    xx = jnp.copy(stem)

    for i, cindex in enumerate(ctine_map[:-1]):
        if i == 0:
            ctine = jax.nn.elu(jnp.dot(stem, weights[cindex]) + biases[cindex])
        else:
            ctine = jax.nn.elu(jnp.dot(ctine, weights[cindex]) + biases[cindex])
    ctine_out = jnp.dot(ctine, weights[ctine_map[-1]]) + biases[ctine_map[-1]]

    for i, aindex in enumerate(atine_map[:-1]):
        if i == 0:
            atine = jax.nn.elu(jnp.dot(stem, weights[aindex]) + biases[aindex])
        else:
            atine = jax.nn.elu(jnp.dot(atine, weights[aindex]) + biases[aindex])
    atine_out = jnp.dot(atine, weights[atine_map[-1]]) + biases[atine_map[-1]]
    atine_out = jnp.dot(atine_out, pca_comps) + pca_mean

    out_norm = jnp.concatenate((ctine_out, atine_out), axis=-1)
    return out_norm


def call_emulator_with_df(
    input_norm: pd.DataFrame,
    emulator: Tuple[
        np.ndarray, np.ndarray, list[int], list[int], list[int], np.ndarray, np.ndarray
    ],
    outputcolumns: list[str, ...],
    verbose: bool = False,
) -> pd.DataFrame:
    """Run the emulator on every row of a scaled DataFrame and descale the result.

    Parameters
    ----------
    input_norm : pd.DataFrame
        Scaled inputs, one row per evaluation.
    emulator : tuple
        Emulator tuple as accepted by ``call_emulator``.
    outputcolumns : list of str
        Column labels for the emulator output; its length fixes the width each
        row's output is reshaped to.
    verbose : bool
        Forwarded to ``descale``.

    Returns
    -------
    pd.DataFrame
        Descaled outputs with columns ``outputcolumns``. An empty input gives
        an empty frame with those columns.
    """
    n_out = len(outputcolumns)

    # Evaluate row by row, as before, but collect the pieces and stack once
    # instead of re-stacking the accumulated result on every iteration.
    rows_norm = [
        np.asarray(call_emulator(row_norm, emulator)).reshape(-1, n_out)
        for row_norm in input_norm.to_numpy()
    ]
    if rows_norm:
        output_norm = np.vstack(rows_norm)
    else:
        # No rows to evaluate: keep the column layout without failing.
        output_norm = np.empty((0, n_out))

    df_output_norm = pd.DataFrame(data=output_norm, columns=outputcolumns, dtype=float)
    # descale works column-wise, so a single call on the stacked frame is
    # equivalent to descaling each row separately.
    df_output_unnorm = descale(df_output_norm, verbose=verbose)
    return df_output_unnorm

In [8]:
def get_test_input(
    test_star_linear: dict | None = None,
    df: pd.DataFrame | None = None,
    idx: int | None = None,
    inputs: list[str, ...] = inputs,
) -> np.ndarray:
    """Build a single-row network input of shape ``(1, n_inputs)``.

    Exactly one source is used:

    * ``test_star_linear`` -- a dict of unscaled column values. It is turned
      into a DataFrame and scaled, and the **first row** of the scaled frame
      is returned.
    * ``df`` and ``idx`` -- the ``inputs`` columns of ``df.loc[idx]`` are
      returned as-is (no scaling is applied here).

    ``test_star_linear`` takes precedence when both sources are given.

    Raises
    ------
    AssertionError
        If neither source is supplied, or if descaling the scaled star does
        not reproduce the original values within floating-point tolerance.
    """
    assert (test_star_linear is not None) or (
        df is not None and idx is not None
    ), "Provide either test_star_linear or both df and idx."

    if test_star_linear is not None:
        # This is a manual input using a dict to make a star
        df_test_star_linear = pd.DataFrame(data=test_star_linear)
        df_test_star = scale(df_test_star_linear)

        # Sanity check of the scale/descale round trip. log10 followed by
        # 10** is not bit-exact in floating point, so compare with a tolerance
        # rather than with exact equality.
        assert np.allclose(
            df_test_star_linear.to_numpy(dtype=float),
            descale(df_test_star).to_numpy(dtype=float),
        ), "descale(scale(x)) does not reproduce x"

        test_input = np.array([df_test_star.values.tolist()[0]])
    else:
        # This looks up df[idx] and uses this as test input
        test_input = np.array([df.loc[idx][inputs].values])
    return test_input

## Make a simulated star

In [45]:
# Two hand-made stars in unscaled units; each entry holds one value per star.
test_star_linear = {
    "initial_mass": np.asarray([1.0, 0.9]),
    "MeH": [0.0, -0.1],
    "alphaFe": [0.0, 0.2],
    "initial_y": [0.26, 0.26],
    "alphaMLT": [2.0, 2.0],
    "eta": [0.01, 0.01],
    "normalised_age": [0, 0.2],
}

# Scale the whole frame directly, which keeps both stars.
# Alternative (not used here):
# get_test_input(test_star_linear=test_star_linear)
test_input_norm = scale(pd.DataFrame(test_star_linear))

In [46]:
# Show the scaled inputs that are passed to the model below.
test_input_norm

Unnamed: 0,initial_mass,MeH,alphaFe,initial_y,alphaMLT,eta,normalised_age
0,0.0,0.0,0.0,0.017033,0.30103,0.01,0.0
1,-0.045757,-0.1,0.2,0.017033,0.30103,0.01,0.2


In [47]:
# Load the saved Keras model; the PCA arrays are read from `pcafile` because
# none are passed in.
# NOTE(review): `n` sets the length of the all-ones vector given to WMSE.
# `astero_outputs` has nmax - nmin + 1 entries and load_tf_model's default is
# n=25, so confirm that nmax - nmin is the intended length.
tf_model = load_tf_model(
    emulatorpath=emulatorpath,
    checkpointfile=checkpointfile,
    pcafile=pcafile,
    n=nmax - nmin,
)



In [48]:
# Predict with the Keras model on the scaled inputs.
modelprediction = tf_model.predict(test_input_norm)
# Join the prediction pieces along axis 1 so that each row is one star
# (presumably one piece per output head -- TODO confirm).
preds_norm = tf.concat(modelprediction, 1)
# Label the columns with `outputs` (classical outputs followed by nu_0_* columns).
df_preds_norm = pd.DataFrame(preds_norm, columns=outputs, dtype=float)
# Convert the scaled predictions back with descale's default column rules.
df_preds = descale(df_preds_norm)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 258ms/step


In [49]:
# Show the descaled predictions, one row per test star.
df_preds

Unnamed: 0,radius,luminosity,mass,age,nu_0_1,nu_0_2,nu_0_3,nu_0_4,nu_0_5,nu_0_6,...,nu_0_12,nu_0_13,nu_0_14,nu_0_15,nu_0_16,nu_0_17,nu_0_18,nu_0_19,nu_0_20,nu_0_21
0,2.260889,2.978988,0.999034,12282.85152,62.477612,114.506554,162.430389,208.015564,252.467346,295.544159,...,532.782654,572.258728,612.109375,651.628601,692.060242,732.139465,772.875977,813.355042,854.031616,894.85907
1,2.546391,3.667733,0.899824,15569.359825,49.285847,90.772896,128.693268,164.908096,200.13887,233.358337,...,419.85083,451.191528,482.874695,514.497253,546.805176,578.810303,611.28949,643.438843,675.483032,707.394043


## How I use the emulator in my model (commented-out sketch, not executed)

In [None]:
# def model(
#     emulator,
#     error_obs: dict | pd.DataFrame,
#     obs: dict | pd.DataFrame | None = None,
#     modelparams: dict | None = None,
#     modelhyperparams: dict | None = None,
#     fillvalue: float | None = None,
# ):
#     if fillvalue is None:
#         fillvalue = constants.fillvalue

#     # Do stuff with priors

#     # Emulate observables
#     input_norm = jnp.stack(
#         [
#             jnp.log10(initial_mass),
#             initial_MeH,
#             alphaFe,
#             jnp.log10(initial_y * 4),
#             jnp.log10(alphaMLT),
#             eta,
#             normalised_age,
#         ],
#         axis=-1,
#     )

#     output_norm = call_emulator(input_norm=input_norm, emulator=emulator)

#     # Unpack and descale
#     radius = numpyro.deterministic("radius", 10 ** output_norm[:, 0])
#     luminosity = numpyro.deterministic("luminosity", 10 ** output_norm[:, 1])
#     teff = numpyro.deterministic(
#         "teff", constants.teff_sun * (radius ** (-2) * luminosity) ** (0.25)
#     )
#     mass = numpyro.deterministic("mass", 10 ** output_norm[:, 2])
#     age = numpyro.deterministic("age", (10 ** output_norm[:, 3]) * 1000)

#     # Compute numax for the sole purpose of being a scale in the surface correction
#     dnu = jnp.median(jnp.diff(output_norm[:, len(classical_outputs) :]))  # len(classical_outputs) = 4
#     numax = (dnu / 0.263) ** (1 / 0.772)  # Stello et al

#     if obs is not None:
#         numpyro.sample(
#             f"observed_teff",
#             dist.StudentT(5, teff, error_obs["teff"]),
#             obs=obs["teff"],
#         )
#         numpyro.sample(
#             f"observed_luminosity",
#             dist.StudentT(5, luminosity, error_obs["luminosity"]),
#             obs=obs["luminosity"],
#         )

#         # etc