In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
import torch

from ariel import *

## Validation set

In [None]:
# Pretraining set: ids from N_ANNOTATED up to N (presumably the examples
# without annotations -- confirm against ariel.get_data).
ids_pretrain = np.arange(N_ANNOTATED, N)
data_pretrain = get_data(ids_pretrain, pretrain=True)
pretrainset = NoisySpectraDataset(**data_pretrain)

# 80/20 split of the first N_ANNOTATED ids; the fixed random_state keeps the
# split identical from run to run.
ids_train, ids_valid = train_test_split(
    np.arange(N_ANNOTATED), train_size=0.8, random_state=36)
data_train = get_data(ids_train)
data_valid = get_data(ids_valid)

# Both datasets reuse the auxiliary mean/std taken from the pretraining set.
auxiliary_stats = dict(
    auxiliary_mean=pretrainset.auxiliary_mean,
    auxiliary_std=pretrainset.auxiliary_std,
)
trainset = NoisySpectraDataset(**data_train, **auxiliary_stats)
validset = SpectraDataset(**data_valid, **auxiliary_stats)

In [None]:
# Run names of the ensemble members; each one has a checkpoint file at
# models/<name>.pt.
modelnames = [
    "rose-frog-439",
    "eager-water-440",
    "ancient-deluge-441",
    "fast-cherry-442",
    "trim-dew-443",
    "comic-violet-444",
    "dutiful-star-445",
    "vague-dew-446",
    "skilled-vortex-447",
    "fluent-sea-448",
    "cosmic-terrain-449",
    "twilight-haze-450",
    "hearty-snowflake-451",
    "swift-gorge-452",
    "efficient-water-453",
    "usual-breeze-454",
    "soft-pyramid-455",
    "silvery-sound-456",
    "tough-fire-457",
    "pretty-hill-458",
]
# Despite the name, these are checkpoint *paths*, not loaded state dicts.
state_dicts = ["models/" + modelname + ".pt" for modelname in modelnames]

# Build one model per checkpoint, load its weights onto DEVICE and switch it
# to evaluation mode.
models = []
for state_dict in state_dicts:
    model = Model(DEFAULT_HYPERPARAMETERS)
    model.load_state_dict(torch.load(state_dict, map_location=torch.device(DEVICE)))
    model.eval()
    models.append(model)
len(models)

In [None]:
def sample_normal(mean, L, T):
    """Draw T samples from a multivariate normal and swap the first two axes.

    Args:
        mean: location of the distribution, passed to MultivariateNormal as-is.
        L: lower-triangular scale factor, passed as ``scale_tril``.
        T: number of samples to draw.

    Returns:
        The sampled tensor with dimensions 0 and 1 transposed, so the sample
        axis (which ``sample`` places first) ends up second.
    """
    gaussian = distributions.MultivariateNormal(mean, scale_tril=L)
    draws = gaussian.sample((T,))
    return draws.transpose(0, 1)

In [None]:
# Each model's predict() result is consumed below as a (mean, L) pair.
outputs_valid = [model.predict(validset) for model in models]

# Split a total budget of 5000 draws evenly across the ensemble members.
T = 5000 // len(models)
samples_valid = [sample_normal(*mean_and_tril, T) for mean_and_tril in outputs_valid]

# Pool every model's draws along axis 1, then take the quantiles over that axis.
sample_valid = np.concatenate(samples_valid, axis=1)
quartiles_valid = np.quantile(sample_valid, q=QUARTILES, axis=1)
light_score(validset.quartiles, quartiles_valid)

In [None]:
# Score only the first 500 entries; samples and ids are cut with the same slice.
first_500 = slice(500)
regular_score(sample_valid[first_500], validset.ids[first_500])

## Test set

In [None]:
# Load the 800 test examples: spectra from the HDF5 file, auxiliary features
# from the CSV table.
ids_test = np.arange(800)
spectra_test = read_spectra(ids_test, path="data/test/spectra.hdf5")
auxiliary_test = read_auxiliary_table(ids_test, path="data/test/auxiliary_table.csv")

# Only element 1 of the read_spectra result is used as model input
# (presumably the spectrum values -- confirm against ariel.read_spectra).
X_test = scale(spectra_test[1])
# Standardise with the pretraining-set statistics, as for the train/valid sets.
auxiliary_test = standardise(
    auxiliary_test,
    pretrainset.auxiliary_mean,
    pretrainset.auxiliary_std,
)
X_test.shape, X_test.dtype, auxiliary_test.shape, auxiliary_test.dtype

In [None]:
# Inference only: run every ensemble member on the test inputs without
# tracking gradients. Each output is consumed below as a (mean, L) pair.
with torch.no_grad():
    outputs_test = [model(X_test, auxiliary_test) for model in models]

# Draw the same total budget of 5000 samples used for the validation set,
# split evenly across the ensemble. Deriving the per-model count from
# len(models) (rather than hard-coding 250) keeps the total consistent if the
# list of models changes.
samples_test = [
    sample_normal(mean, L, T=5000 // len(models)) for mean, L in outputs_test
]
# Pool every model's draws along axis 1, then take the quantiles over that axis.
sample_test = np.concatenate(samples_test, axis=1)
quartiles_test = np.quantile(sample_test, QUARTILES, axis=1)
sample_test.shape, quartiles_test.shape

In [None]:
# Pass the test-set quartiles through light_track_format (not defined in this
# notebook; presumably provided by the ariel star-import -- confirm) and
# display the result as the cell output.
light_track = light_track_format(quartiles_test)
light_track

In [None]:
# Pass the pooled test samples through regular_track_format; its return value
# is discarded here, so it is presumably called for a side effect such as
# writing a submission file -- TODO confirm in ariel.
regular_track_format(sample_test)
# Display the shape of the sample array as a sanity check.
sample_test.shape