In [2]:
import numpy as np
import torch

from ariel import *

In [3]:
def sample_normal(mean, std, T, rng=None):
    """Draw T independent Gaussian samples per element and stack them on axis 1.

    Parameters
    ----------
    mean, std : array-like
        Location and scale of the normal distribution. Anything ``np.asarray``
        accepts works (NumPy arrays, CPU ``torch`` tensors without grad); the
        two are broadcast against each other. ``std`` must be non-negative.
    T : int
        Number of samples to draw per element.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to NumPy's legacy global state
        (``np.random``), so ``np.random.seed`` controls the result.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(d0, T, d1, ...)`` where ``(d0, d1, ...)`` is
        the broadcast shape of ``mean`` and ``std``. With the default global
        state the values are identical to stacking T successive
        ``np.random.normal(mean, std)`` calls along axis 1.
    """
    sampler = np.random if rng is None else rng
    loc = np.asarray(mean)
    scale = np.asarray(std)
    shape = np.broadcast(loc, scale).shape
    # One vectorised call instead of T Python-level calls. Draws are generated
    # in C order over (T, *shape), i.e. the same order as T successive calls.
    draws = sampler.normal(loc=loc, scale=scale, size=(T,) + shape)
    # Move the sample axis from the front to position 1 (matches np.stack(..., axis=1)).
    return np.moveaxis(draws, 0, 1)

In [14]:
# Checkpoint names of the ensemble members; each is loaded from models/<name>.pt.
modelnames = [
    "frosty-brook-192",
    "curious-dew-193",
    "curious-firebrand-194",
    "breezy-blaze-195",
    "wandering-night-196",
    "treasured-bird-197",
    "radiant-pine-198",
    "wild-mountain-199",
    "woven-sun-200",
    "ancient-energy-201",
    "prime-dawn-202",
    "stilted-galaxy-203",
    "effortless-haze-204",
    "crimson-shape-205",
    "hearty-puddle-206",
    "gentle-monkey-207",
    "crimson-wave-208",
    "sweet-dew-209",
    "firm-surf-210",
    "iconic-mountain-211"
]
state_dicts = ["models/" + modelname + ".pt" for modelname in modelnames]
device = "cuda" if torch.cuda.is_available() else "cpu"

models = []
for state_dict in state_dicts:
    model = Model(DEFAULT_HYPERPARAMETERS)
    # NOTE(review): the model is never moved with .to(device), so its parameters
    # presumably stay wherever Model creates them; map_location only decides where
    # the checkpoint tensors are staged before load_state_dict copies them in.
    # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(state_dict, map_location=torch.device(device)))
    # Switch to inference mode (dropout off, running batch-norm statistics) in case
    # Model has such layers; torch.no_grad() used later does not do this.
    model.eval()
    models.append(model)
len(models)

20

In [8]:
# Training-set index range. NOTE(review): N is not defined in the visible cells --
# presumably it comes from `from ariel import *`. The dataset object is kept for
# its auxiliary-feature statistics, which are reused to standardise the test set.
ids_train = np.arange(0, N)
trainset = get_dataset(ids_train)

## Test set

In [15]:
# Load the 800 test planets: spectra and auxiliary table for the same ids.
# NOTE(review): element [1] of read_spectra's result is used as the model input
# (shape (800, 52) per the output below) -- confirm what the other elements hold.
ids_test = np.arange(800)
spectra_test = read_spectra(ids_test, path="data/test/spectra.hdf5")
auxiliary_test = read_auxiliary_table(ids_test, path="data/test/auxiliary_table.csv")

# Standardise each row of the spectra on its own values (mean/std over dim 1),
# and the auxiliary features with the *training-set* statistics.
X_test = spectra_test[1]
X_test = X_test.sub(X_test.mean(dim=1, keepdim=True)).div(X_test.std(dim=1, keepdim=True))
auxiliary_test = standardise(
    auxiliary_test,
    trainset.auxiliary_train_mean,
    trainset.auxiliary_train_std,
)
(X_test.shape, X_test.dtype, auxiliary_test.shape, auxiliary_test.dtype)

(torch.Size([800, 52]), torch.float32, torch.Size([800, 9]), torch.float32)

In [16]:
# One forward pass per ensemble member over the whole test set; no gradients
# are needed. Judging by how the sampling cell unpacks them, each member
# returns a (mean, var) pair.
with torch.no_grad():
    outputs_test = [
        member(X_test, auxiliary_test)
        for member in models
    ]

In [17]:
# Monte Carlo approximation of the ensemble's predictive distribution: draw
# SAMPLES_PER_MODEL values from each member's Gaussian (mean, sqrt(var)), pool
# every member's draws along axis 1, then take the QUARTILES over that axis.
SAMPLES_PER_MODEL = 250
SAMPLING_SEED = 0
# sample_normal draws from NumPy's global generator; seeding it here makes the
# quartiles (and therefore the submission) reproducible when the cell is re-run.
np.random.seed(SAMPLING_SEED)
samples_test = np.concatenate(
    [
        sample_normal(mean, torch.sqrt(var), T=SAMPLES_PER_MODEL)
        for mean, var in outputs_test
    ],
    axis=1,
)
quartiles_test = np.quantile(samples_test, QUARTILES, axis=1)
light_track = light_track_format(quartiles_test)
light_track

Unnamed: 0_level_0,T_q1,T_q2,T_q3,log_H2O_q1,log_H2O_q2,log_H2O_q3,log_CO2_q1,log_CO2_q2,log_CO2_q3,log_CH4_q1,log_CH4_q2,log_CH4_q3,log_CO_q1,log_CO_q2,log_CO_q3,log_NH3_q1,log_NH3_q2,log_NH3_q3
planet_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
0,1081.661061,1095.532350,1109.627587,-10.838108,-9.333102,-7.781693,-4.483110,-4.358104,-4.241192,-11.126384,-9.737979,-8.422520,-7.868731,-5.435276,-3.082241,-6.375481,-6.246645,-6.117220
1,1577.084998,1603.117275,1628.683797,-4.938303,-4.856052,-4.772862,-4.962060,-4.830818,-4.693182,-9.171554,-7.643885,-6.045562,-10.128436,-7.768980,-5.520299,-10.145376,-8.301453,-6.549208
2,4916.967392,5030.989510,5139.274239,-10.592396,-9.167437,-7.789282,-9.488160,-8.459079,-7.745590,-10.340314,-8.717696,-7.194212,-5.410964,-4.906198,-4.455802,-9.723925,-7.921939,-6.330726
3,1970.645953,2006.956689,2043.347841,-3.538706,-3.432690,-3.325291,-10.500261,-8.868319,-7.264776,-10.527673,-8.765256,-6.955516,-5.723255,-3.972423,-2.383848,-10.734509,-8.857200,-6.963085
4,995.252382,1025.296245,1055.357062,-3.831990,-3.660500,-3.487456,-9.365329,-8.221338,-7.054548,-7.015818,-6.564018,-6.167371,-6.979165,-5.146473,-3.414565,-10.545642,-8.804992,-6.928195
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
795,1318.825526,1333.262310,1353.309255,-3.998946,-3.844564,-3.714225,-6.182215,-5.820908,-5.421714,-5.513718,-5.256863,-5.008467,-4.090438,-3.638256,-3.246909,-4.612576,-4.435822,-4.192438
796,588.192543,600.068090,610.900238,-3.524735,-3.365035,-3.186885,-9.358524,-6.796299,-5.689306,-3.911157,-3.721583,-3.434344,-3.626706,-3.028674,-2.465989,-4.375167,-4.249932,-4.083143
797,448.890420,455.425729,463.462316,-4.902051,-4.640384,-4.465802,-4.976172,-4.647168,-4.356580,-9.086497,-7.611271,-6.439453,-4.635440,-4.063099,-3.483798,-8.294401,-6.857603,-6.289797
798,914.291333,934.391293,955.253535,-3.759469,-3.625806,-3.496914,-10.688611,-9.309029,-7.841380,-10.720921,-9.091017,-7.449468,-6.251113,-4.401581,-2.631734,-4.965862,-4.842651,-4.717348


In [18]:
# Keep the regular-track result: a bare call that is not the cell's last
# expression is neither stored nor displayed, so its return value would be lost.
# NOTE(review): presumably this builds the regular-track submission from the
# pooled samples -- confirm against its definition.
regular_track = regular_track_format(samples_test)
samples_test.shape

(800, 5000, 6)