In [1]:
from sklearn.model_selection import train_test_split
import torch
from torch import nn

from ariel import *

In [2]:
def sample_normal(mean, std, T, rng=None):
    """Draw ``T`` independent Gaussian samples for every element of ``mean``/``std``.

    Args:
        mean: Array-like of means (NumPy array or CPU torch tensor), at least 1-D.
        std: Array-like of standard deviations, broadcastable against ``mean``.
        T: Number of samples to draw per element.
        rng: Optional ``np.random.Generator`` (or ``np.random.RandomState``) to
            draw from. ``None`` (default) uses NumPy's legacy global RNG, as the
            original implementation did, so ``np.random.seed`` still controls it.

    Returns:
        ``np.ndarray`` (float64) whose shape is the broadcast shape of
        ``mean``/``std`` with a new sample axis of length ``T`` inserted at
        axis 1, i.e. the same layout as ``np.stack(..., axis=1)``.
    """
    loc = np.asarray(mean)
    scale = np.asarray(std)
    source = np.random if rng is None else rng
    # One vectorised draw instead of T Python-level calls. The sample axis is
    # drawn first (axis 0) and then moved to axis 1 to keep the original layout.
    draws = source.normal(loc=loc, scale=scale, size=(T, *np.broadcast(loc, scale).shape))
    return np.moveaxis(draws, 0, 1)

In [3]:
# Ensemble members: one checkpoint per training run, stored as models/<run-name>.pt.
modelnames = [
    "hearty-firefly-352",
    "efficient-lion-351",
    "fearless-capybara-350",
    "floral-plasma-349",
    "resilient-salad-348",
    "fresh-lake-347",
    "breezy-yogurt-346",
    "dulcet-haze-345",
    "confused-aardvark-344",
    "hardy-armadillo-343",
    "devout-shadow-342",
    "wise-resonance-341",
    "graceful-darkness-340",
    "proud-night-339",
    "smart-eon-338",
    "lemon-cosmos-337",
    "fine-violet-336",
    "usual-bee-335",
    "misunderstood-water-334",
    "genial-leaf-333"
]
state_dicts = ["models/" + modelname + ".pt" for modelname in modelnames]
device = "cuda" if torch.cuda.is_available() else "cpu"

models = []
for state_dict in state_dicts:
    model = Model(DEFAULT_HYPERPARAMETERS)
    # Checkpoint tensors are mapped onto `device` and then copied into the
    # model's existing parameters; the model itself is not moved with .to(device).
    # torch.load unpickles the file, so only load checkpoints from a trusted source.
    model.load_state_dict(torch.load(state_dict, map_location=torch.device(device)))
    # Switch to inference mode (dropout off, running batch-norm statistics), if the
    # architecture has such layers. The later torch.no_grad() block does NOT do this.
    model.eval()
    models.append(model)

## Validation set

In [4]:
# 80/20 split of all N planets with a fixed seed.
# NOTE(review): assumes the ensemble was trained on this same split (same seed),
# otherwise the validation score below is optimistic — confirm.
ids = np.arange(N)
ids_train, ids_valid = train_test_split(ids, random_state=36, train_size=0.8)
trainset = get_dataset(ids_train)
# Validation auxiliary features are standardised with the *training* statistics.
train_mean, train_std = trainset.auxiliary_train_mean, trainset.auxiliary_train_std
validset = get_dataset(ids_valid, train_mean, train_std)

In [5]:
# One (mean, var) prediction per ensemble member on the validation set.
outputs_valid = []
for member in models:
    outputs_valid.append(member.predict(validset))

In [6]:
SAMPLING_SEED = 36
# sample_normal draws from NumPy's global RNG; seed it so the sampled quartiles,
# and therefore the score, are reproducible under Restart & Run All.
np.random.seed(SAMPLING_SEED)
# 250 draws from each member's predictive Gaussian (std = sqrt(var)), pooled
# across the ensemble along axis 1, the sample axis produced by sample_normal.
samples_valid = np.concatenate(
    [sample_normal(mean, torch.sqrt(var), T=250) for mean, var in outputs_valid],
    axis=1,
)
quartiles_valid = np.quantile(samples_valid, QUARTILES, axis=1)
light_score(validset.quartiles, quartiles_valid)

988.9590532905528

In [None]:
# Score only the first 500 validation planets; presumably to keep
# regular_score's runtime manageable (TODO confirm).
n_regular_valid = 500
regular_score(samples_valid[:n_regular_valid], validset.ids[:n_regular_valid])

## Test set

In [7]:
N_TEST = 800  # number of test planets (matches the 800 rows shown in the output)
ids_test = np.arange(N_TEST)
spectra_test = read_spectra(ids_test, path="data/test/spectra.hdf5")
auxiliary_test_raw = read_auxiliary_table(ids_test, path="data/test/auxiliary_table.csv")

# Per-planet standardisation of each spectrum along dim 1.
# NOTE(review): index 1 of read_spectra's result is presumably the flux; and
# torch's .std() is unbiased by default. Confirm both match get_dataset's
# training-time preprocessing.
X_test_raw = spectra_test[1]
X_test_centred = X_test_raw - X_test_raw.mean(dim=1, keepdim=True)
X_test = X_test_centred / X_test_raw.std(dim=1, keepdim=True)

# Auxiliary features reuse the training-set statistics, as for the validation set.
auxiliary_test = standardise(
    auxiliary_test_raw,
    trainset.auxiliary_train_mean,
    trainset.auxiliary_train_std,
)

X_test.shape, X_test.dtype, auxiliary_test.shape, auxiliary_test.dtype

(torch.Size([800, 52]), torch.float32, torch.Size([800, 9]), torch.float32)

In [8]:
# torch.no_grad() only disables autograd; it does not switch dropout or
# batch-norm layers to inference behaviour. Put every ensemble member in eval
# mode explicitly before calling the raw forward pass. (The validation path goes
# through Model.predict, which may already do this internally. NOTE(review): confirm.)
for member in models:
    member.eval()
with torch.no_grad():
    outputs_test = [member(X_test, auxiliary_test) for member in models]

In [9]:
SAMPLING_SEED = 36
# Reseed NumPy's global RNG (used by sample_normal). This makes the submitted
# quartiles reproducible and independent of whether, or how often, the
# validation cells above were run.
np.random.seed(SAMPLING_SEED)
# 250 draws from each member's predictive Gaussian (std = sqrt(var)), pooled
# across the ensemble along axis 1, the sample axis produced by sample_normal.
samples_test = np.concatenate(
    [sample_normal(mean, torch.sqrt(var), T=250) for mean, var in outputs_test],
    axis=1,
)
quartiles_test = np.quantile(samples_test, QUARTILES, axis=1)
light_track = light_track_format(quartiles_test)
light_track

Unnamed: 0_level_0,T_q1,T_q2,T_q3,log_H2O_q1,log_H2O_q2,log_H2O_q3,log_CO2_q1,log_CO2_q2,log_CO2_q3,log_CH4_q1,log_CH4_q2,log_CH4_q3,log_CO_q1,log_CO_q2,log_CO_q3,log_NH3_q1,log_NH3_q2,log_NH3_q3
planet_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
0,1093.984406,1098.333067,1102.495163,-9.639672,-9.385441,-9.017010,-4.524674,-4.308253,-4.115872,-9.959359,-9.744446,-9.526174,-5.693941,-4.945644,-4.272630,-6.488044,-6.178168,-5.899204
1,1599.206882,1604.366345,1610.034211,-4.976118,-4.729993,-4.445066,-5.085149,-4.800401,-4.520850,-7.936440,-7.664780,-7.304749,-8.267015,-7.342932,-6.425355,-8.624533,-7.990357,-7.323595
2,5077.802045,5117.820663,5162.060737,-9.640170,-8.779598,-7.541153,-9.472256,-8.280484,-7.552075,-9.228012,-8.555046,-7.586657,-5.796461,-3.076877,-0.322550,-9.655059,-8.243859,-6.668634
3,2002.826223,2009.364145,2016.249519,-3.673455,-3.386767,-3.010125,-9.337507,-8.956978,-8.598464,-9.109046,-8.803363,-8.402114,-4.913286,-3.684740,-2.411437,-9.105685,-8.678651,-8.200187
4,1022.062707,1026.196919,1030.209470,-3.730180,-3.555705,-3.375681,-8.443032,-8.100254,-7.805621,-6.924347,-6.637887,-6.392794,-5.616676,-4.904572,-4.229582,-8.912300,-8.603757,-8.246953
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
795,1320.653752,1325.298043,1330.076292,-4.214476,-3.831195,-3.416117,-6.251784,-5.729686,-5.282403,-8.688281,-6.695692,-5.619516,-5.462477,-3.503170,-2.173817,-4.769914,-4.383545,-3.985823
796,600.086197,602.569860,604.624983,-3.638403,-3.290490,-3.042403,-7.179546,-6.237291,-5.349426,-3.716082,-3.488728,-3.204942,-3.677119,-3.106534,-2.484282,-4.414235,-4.140867,-3.922510
797,452.033643,454.304132,456.519976,-4.824469,-4.618367,-4.450129,-4.689590,-4.402164,-4.168032,-9.162821,-8.472205,-7.855183,-4.545819,-4.061104,-3.311648,-9.066060,-7.755781,-6.769745
798,928.473302,932.228214,935.718724,-3.676912,-3.495889,-3.296919,-9.559867,-9.176007,-8.836075,-9.331978,-9.131983,-8.919507,-4.937381,-4.234254,-3.649447,-4.921446,-4.692590,-4.463272


In [10]:
# Format the pooled test-set samples for the regular track.
# NOTE(review): confirm whether regular_track_format writes a submission file
# or only returns a table.
regular_track = regular_track_format(samples_test)
regular_track