In [None]:
import numpy as np
import torch

from ariel import *

In [None]:
def sample_normal(mean, std, T, rng=None):
    """Draw ``T`` independent samples from elementwise normal distributions.

    Parameters
    ----------
    mean, std : array-like or torch.Tensor
        Location and scale of the normals, broadcast against each other the
        way ``np.random.normal`` broadcasts ``loc``/``scale``. Torch tensors
        are detached and copied to host memory first, so tensors on any
        device (or tensors that require grad) are accepted. ``std`` must be
        non-negative, and the broadcast shape must have at least one axis.
    T : int
        Number of samples per distribution. ``T == 0`` yields an empty
        sample axis.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. ``None`` (default) uses NumPy's global random
        state, exactly as before; with the same global seed the values are
        identical to drawing ``T`` times in a loop.

    Returns
    -------
    numpy.ndarray
        float64 array shaped like the broadcast of ``mean``/``std`` with a new
        axis of length ``T`` inserted at position 1: ``(d0, T, d1, ...)``.
    """
    def _as_numpy(x):
        # np.random.normal reads tensors via __array__, which fails for CUDA
        # tensors and for tensors that require grad; convert explicitly.
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    loc = _as_numpy(mean)
    scale = _as_numpy(std)
    shape = np.broadcast(loc, scale).shape
    draw = np.random.normal if rng is None else rng.normal
    # One vectorised draw with the sample axis leading: values are generated
    # in C order, i.e. sample 0 for every element, then sample 1, ... -- the
    # same stream order as T successive draws of `shape`.
    samples = draw(loc=loc, scale=scale, size=(T,) + tuple(shape))
    # Put the sample axis at position 1, as np.stack(..., axis=1) did.
    return np.ascontiguousarray(np.moveaxis(samples, 0, 1))

In [None]:
# Ensemble members: one checkpoint per name, stored as models/<name>.pt.
# Two names are commented out and therefore excluded from the ensemble
# (the reason is not recorded in this notebook).
modelnames = [
    "fresh-rain-186",
    "super-fire-187",
    "earnest-water-188",
    "lively-frost-189",
    "glorious-fire-190",
    "snowy-dew-191",
    #"legendary-dust-192",
    "electric-silence-193",
    "pretty-serenity-194",
    #"autumn-voice-195",
    "devout-capybara-196",
    "fancy-shadow-197",
    "smart-eon-198",
    "cosmic-blaze-199",
    "divine-yogurt-200",
    "whole-durian-201",
    "swift-firebrand-202",
    "playful-snowball-203",
    "glad-frost-204",
    "clear-snowball-205",
]
state_dicts = [f"models/{name}.pt" for name in modelnames]

# Build every model from the shared default hyperparameters first ...
models = [Model(DEFAULT_HYPERPARAMETERS) for _ in state_dicts]

# ... then fill in the trained weights and switch each to eval mode.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
for model, state_dict in zip(models, state_dicts):
    model.load_state_dict(
        torch.load(state_dict, map_location=torch.device(DEVICE))
    )
    model.eval()
len(models)

In [None]:
# Pre-training split: ids N_ANNOTATED .. N-1. In this notebook it is only
# needed for its auxiliary-feature mean/std, used to standardise the test set.
ids_pretrain = np.arange(N_ANNOTATED, N)
pretrainset = get_dataset(ids_pretrain, pretrain=True)

## Test set

In [None]:
# Load the test spectra and auxiliary table for ids 0 .. N_TEST-1.
ids_test = np.arange(N_TEST)
spectra_test = read_spectra(ids_test, path="data/test/spectra.hdf5")
auxiliary_test = read_auxiliary_table(ids_test, path="data/test/auxiliary_table.csv")

# Model inputs: element 1 of the spectra result, scaled; auxiliary features
# standardised with the *pre-training* statistics (not the test set's own).
X_test = scale(spectra_test[1])
auxiliary_test = standardise(
    auxiliary_test,
    pretrainset.auxiliary_mean,
    pretrainset.auxiliary_std,
)
(X_test.shape, X_test.dtype, auxiliary_test.shape, auxiliary_test.dtype)

In [None]:
# Ensemble prediction: every model returns a (mean, variance) pair; we sample
# from each model's normal distribution and pool the samples across models.
N_SAMPLES_MAX = 5000  # pooled-sample budget per target (value from the original cell)
SAMPLING_SEED = 0     # fixed so re-running the notebook reproduces the same samples

with torch.no_grad():
    outputs_test = [model(X_test, auxiliary_test) for model in models]

# Equal share per model; floor division keeps the pooled total <= N_SAMPLES_MAX.
T = N_SAMPLES_MAX // len(models)
assert T > 0, "more models than the sample budget allows"

np.random.seed(SAMPLING_SEED)
samples_test = [
    # Move outputs to host memory explicitly: NumPy cannot read GPU tensors.
    # (assumes both outputs are torch tensors -- var is, given torch.sqrt)
    sample_normal(mean.cpu().numpy(), torch.sqrt(var).cpu().numpy(), T)
    for mean, var in outputs_test
]
sample_test = np.concatenate(samples_test, axis=1)  # pool along the sample axis
quartiles_test = np.quantile(sample_test, QUARTILES, axis=1)
sample_test.shape, quartiles_test.shape

In [None]:
# Light-track submission, built from the pooled-sample quantiles
# (`quartiles_test`, taken at QUARTILES in the previous cell).
light_track = light_track_format(quartiles_test)
light_track

In [None]:
# Regular-track submission, built from the pooled samples. Bind the result
# (as the light-track cell does with `light_track`) so it can be inspected
# or written out later instead of existing only as displayed output.
regular_track = regular_track_format(sample_test)
regular_track