In [None]:
import numpy as np
import torch

from ariel import *

In [None]:
def sample_normal(mean, std, T, rng=None):
    """Draw ``T`` independent normal samples for every entry of ``mean``.

    Parameters
    ----------
    mean, std : array-like or torch.Tensor
        Location and scale of the normal distribution. ``std`` must broadcast
        against the shape of ``mean``. Torch tensors are detached and moved to
        the CPU before sampling, so tensors on another device or with
        ``requires_grad=True`` are accepted.
    T : int
        Number of draws per entry of ``mean``.
    rng : numpy.random.Generator, optional
        Generator used for the draws. Pass a seeded generator to make the
        sampling reproducible. When omitted, NumPy's global ``np.random`` state
        is used, exactly as before this parameter existed.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(mean.shape[0], T, *mean.shape[1:])``: the draws are
        generated with the sample axis first and then axes 0 and 1 are swapped.
        For a 0-d ``mean`` there is no second axis to swap with, so the result
        has shape ``(T,)``.
    """

    def _as_array(value):
        # Explicit conversion: implicit ndarray conversion of a tensor raises
        # for CUDA tensors and for tensors that require grad.
        if hasattr(value, "detach"):
            value = value.detach().cpu().numpy()
        return np.asarray(value)

    mean = _as_array(mean)
    std = _as_array(std)

    # Fall back to the legacy global-state sampler to stay backward-compatible
    # with callers that rely on np.random.seed().
    draw = np.random.normal if rng is None else rng.normal
    draws = draw(loc=mean, scale=std, size=(T, *mean.shape))

    if draws.ndim < 2:
        # Scalar mean: only the sample axis exists, nothing to swap.
        return draws
    return draws.swapaxes(0, 1)

In [None]:
# Run names of the checkpoints that form the ensemble.
modelnames = [
    "royal-dew-647",
    "warm-voice-646",
    "classic-valley-645",
    "autumn-violet-644",
    "sweet-surf-643",
    "efficient-smoke-642",
    "driven-silence-641",
    "whole-frost-640",
    "astral-terrain-639",
    "elated-frost-638",
    "light-bird-637",
    "ruby-field-636",
    "elated-salad-635",
    "fallen-waterfall-634",
    "clear-valley-633",
    "giddy-serenity-632",
    "colorful-shadow-631",
    "olive-bird-630",
    "fresh-elevator-629",
    "splendid-sea-628",
]

# Despite the name, these are the checkpoint file paths, one per run name.
state_dicts = [f"models/{run}.pt" for run in modelnames]

# One freshly constructed network per checkpoint path.
models = [Model(DEFAULT_HYPERPARAMETERS) for _ in state_dicts]

# Load each checkpoint into its network and switch the network to eval mode.
for model, state_dict in zip(models, state_dicts):
    model.load_state_dict(
        torch.load(state_dict, map_location=torch.device(DEVICE))
    )
    model.eval()

# Displayed as the cell output: the ensemble size.
len(models)

In [None]:
# Ids 0 .. N_ANNOTATED-1. N_ANNOTATED is not defined in this notebook
# (presumably provided by the `ariel` star import).
ids_train = np.arange(N_ANNOTATED)
# Dataset built for those ids; its auxiliary_mean / auxiliary_std attributes are
# read further down to standardise the test auxiliary table.
trainset = get_dataset(ids_train)

## Test set

In [None]:
# Upper bound for the test ids generated with np.arange(N_TEST)
# (presumably the number of examples under data/test — TODO confirm).
N_TEST = 800

In [None]:
# Read the raw test inputs for ids 0 .. N_TEST-1.
ids_test = np.arange(N_TEST)
spectra_test = read_spectra(ids_test, path="data/test/spectra.hdf5")
auxiliary_test = read_auxiliary_table(
    ids_test, path="data/test/auxiliary_table.csv"
)

# Element 1 of the read_spectra result is the model input; it is passed through scale().
X_test = scale(spectra_test[1])

# Standardise the auxiliary table with the statistics stored on the training set.
auxiliary_test = standardise(
    auxiliary_test,
    trainset.auxiliary_mean,
    trainset.auxiliary_std,
)

# Cell output: shapes and dtypes of both model inputs.
X_test.shape, X_test.dtype, auxiliary_test.shape, auxiliary_test.dtype

In [None]:
# Forward pass of every ensemble member, with gradient tracking disabled.
with torch.no_grad():
    outputs_test = [net(X_test, auxiliary_test) for net in models]

# Split a budget of 5000 draws evenly across the ensemble members.
T = 5000 // len(models)

# Each output is unpacked as a (mean, variance) pair; the square root of the
# second element is used as the scale of the normal draws.
samples_test = [
    sample_normal(mu, torch.sqrt(sigma2), T)
    for mu, sigma2 in outputs_test
]

# sample_normal puts the draw axis at position 1, so concatenating on axis 1
# pools the draws of all members.
sample_test = np.concatenate(samples_test, axis=1)

# Quantiles at the levels in QUARTILES, taken over the pooled draw axis.
quartiles_test = np.quantile(sample_test, QUARTILES, axis=1)

# Cell output: shapes of the pooled samples and of the quantiles.
sample_test.shape, quartiles_test.shape

In [None]:
# Pass the quantiles of the pooled samples to light_track_format, which is not
# defined in this notebook (presumably provided by the `ariel` star import).
light_track = light_track_format(quartiles_test)
# Bare last expression so the notebook displays the result.
light_track

In [None]:
# Pass the pooled samples to regular_track_format (not defined in this notebook;
# presumably provided by the `ariel` star import). Its return value is the cell output.
regular_track_format(sample_test)