In [None]:
import numpy as np
import torch

from ariel import *

In [None]:
def sample_normal(mean, std, T, rng=None):
    """Draw ``T`` Gaussian samples for every entry of ``mean``.

    Parameters
    ----------
    mean : array-like or torch.Tensor
        Means of the normal distributions. Must have at least one axis;
        axis 0 stays the leading axis of the result.
    std : array-like or torch.Tensor
        Standard deviations, broadcastable against ``mean``.
    T : int
        Number of samples to draw per entry of ``mean``.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, NumPy's global random state is
        used (the historical behaviour). Pass a seeded generator to make the
        draws reproducible.

    Returns
    -------
    numpy.ndarray
        Samples of shape ``(mean.shape[0], T, *mean.shape[1:])``.
    """

    def _as_numpy(values):
        # Tensors that track gradients or live on an accelerator cannot be
        # converted implicitly by NumPy, so unwrap them explicitly first.
        if hasattr(values, "detach"):
            values = values.detach().cpu().numpy()
        return np.asarray(values)

    mean = _as_numpy(mean)
    std = _as_numpy(std)

    draw = np.random.normal if rng is None else rng.normal
    # Samples are drawn with the sample axis first, then moved behind axis 0.
    samples = draw(loc=mean, scale=std, size=(T, *mean.shape))
    return samples.swapaxes(0, 1)

In [None]:
# Run names of the trained networks that form the ensemble.
modelnames = [
    "morning-grass-661",
    "fanciful-star-651",
    "clean-lake-649",
    "classic-durian-662",
    "scarlet-dew-666",
    "quiet-forest-655",
    "hopeful-hill-651",
    "whole-tree-648",
    "faithful-elevator-659",
    "spring-violet-650",
    "brisk-cloud-657",
    "drawn-dream-663",
    "pretty-smoke-660",
    "toasty-firebrand-671",
    "brisk-sunset-664",
    "lunar-firefly-656",
    "balmy-oath-668",
    "eternal-donkey-658",
    "hopeful-serenity-653",
    "stoic-jazz-670",
]

# Despite the name, these are checkpoint file paths, one per run name.
state_dicts = ["models/" + name + ".pt" for name in modelnames]

# Instantiate one network per checkpoint, restore its weights onto DEVICE and
# switch it to evaluation mode.
models = []
for state_dict in state_dicts:
    model = Model(DEFAULT_HYPERPARAMETERS)
    weights = torch.load(state_dict, map_location=torch.device(DEVICE))
    model.load_state_dict(weights)
    model.eval()
    models.append(model)

# Show how many ensemble members were loaded.
len(models)

In [None]:
# Build the training dataset for ids 0 .. N_ANNOTATED-1. Its auxiliary_mean and
# auxiliary_std attributes are read further down to standardise the test
# auxiliary table.
ids_train = np.arange(N_ANNOTATED)
trainset = get_dataset(ids_train)

## Test set

In [None]:
# Ids of the test examples: 0 .. N_TEST-1.
ids_test = np.arange(N_TEST)

# Read the test spectra; element 1 of the result is scaled and used as the
# model input.
spectra_test = read_spectra(ids_test, path="data/test/spectra.hdf5")
X_test = scale(spectra_test[1])

# Read the test auxiliary table and standardise it with the statistics stored
# on the training set.
auxiliary_test = standardise(
    read_auxiliary_table(ids_test, path="data/test/auxiliary_table.csv"),
    trainset.auxiliary_mean,
    trainset.auxiliary_std,
)

# Display shapes and dtypes of both inputs.
X_test.shape, X_test.dtype, auxiliary_test.shape, auxiliary_test.dtype

In [None]:
# Run every ensemble member on the test inputs with gradient tracking disabled.
# Each output is unpacked below as a (mean, variance) pair.
with torch.no_grad():
    outputs_test = [member(X_test, auxiliary_test) for member in models]

# Split a budget of 1000 samples evenly across the ensemble members.
T = 1000 // len(models)

# Draw T Gaussian samples per member from its predicted mean and standard
# deviation (the square root of the predicted variance).
samples_test = [
    sample_normal(mu, torch.sqrt(sigma2), T)
    for mu, sigma2 in outputs_test
]

# sample_normal puts the sample axis at position 1, so concatenating along
# axis 1 pools the draws of all members; quantiles are taken over that axis.
sample_test = np.concatenate(samples_test, axis=1)
quartiles_test = np.quantile(sample_test, q=QUARTILES, axis=1)

# Display the resulting shapes.
sample_test.shape, quartiles_test.shape

In [None]:
# Pass the quantiles computed over the pooled samples to light_track_format
# and display the result as the cell output.
light_track = light_track_format(quartiles_test)
light_track

In [None]:
# Hand the pooled ensemble samples to regular_track_format; whatever it returns
# becomes this cell's output.
regular_track_format(sample_test)