In [2]:
from sklearn.model_selection import train_test_split
import torch
from torch import nn

from ariel import *

In [3]:
def sample_normal(mean, std, T, rng=None):
    """Draw ``T`` independent Gaussian samples for every entry of ``mean``.

    Parameters
    ----------
    mean, std : array-like
        Location and scale of the normal distributions. They are broadcast
        against each other and must have at least one dimension. Anything
        ``np.asarray`` can convert is accepted.
    T : int
        Number of draws per entry. ``0`` yields an empty sample axis.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, NumPy's global random state is
        used, so ``np.random.seed`` still controls the draws.

    Returns
    -------
    numpy.ndarray
        Samples with the draws indexed along axis 1, e.g. shape ``(N, T, D)``
        for inputs of shape ``(N, D)`` and ``(N, T)`` for inputs of shape
        ``(N,)``.

    Raises
    ------
    ValueError
        If ``T`` is negative or the inputs are 0-dimensional (there would be
        no axis 1 to place the draws on).
    """
    if T < 0:
        raise ValueError(f"T must be non-negative, got {T}")
    loc, scale = np.broadcast_arrays(np.asarray(mean), np.asarray(std))
    if loc.ndim == 0:
        raise ValueError("mean/std must have at least one dimension")

    out_shape = loc.shape[:1] + (T,) + loc.shape[1:]
    if T == 0:
        # Nothing to draw; keep the output layout consistent with T > 0.
        return np.empty(out_shape)

    sampler = np.random if rng is None else rng
    # A singleton axis at position 1 lets loc/scale broadcast over the T draws,
    # replacing T separate Python-level calls with one vectorised draw.
    return sampler.normal(
        loc=np.expand_dims(loc, 1),
        scale=np.expand_dims(scale, 1),
        size=out_shape,
    )

In [4]:
# Checkpoint names of the trained networks that form the ensemble.
modelnames = [
    "fragrant-water-360",
    "astral-oath-360",
    "toasty-tree-360",
    "fluent-sun-360",
    "smart-wildflower-359"
]
state_dicts = ["models/" + modelname + ".pt" for modelname in modelnames]
# One freshly constructed model per checkpoint; weights are filled in below.
models = [Model(DEFAULT_HYPERPARAMETERS) for _ in state_dicts]
device = "cuda" if torch.cuda.is_available() else "cpu"
for model, state_dict in zip(models, state_dicts):
    model.load_state_dict(torch.load(state_dict, map_location=torch.device(device)))
    # The ensemble is only used for inference in this notebook, so switch off
    # training-time behaviour (dropout, batch-norm statistics updates).
    model.eval()
len(models)

5

## Validation set

In [5]:
# Reproducible 80/20 split of the ids into training and validation parts.
ids = np.arange(N)
ids_train, ids_valid = train_test_split(ids, random_state=36, train_size=0.8)

# The validation set receives the training set's auxiliary mean/std
# (presumably so both are standardised with the same statistics — confirm in ariel).
trainset = get_dataset(ids_train)
validset = get_dataset(
    ids_valid,
    trainset.auxiliary_train_mean,
    trainset.auxiliary_train_std,
)

In [6]:
# One prediction per ensemble member on the validation set; each entry is
# unpacked as a (mean, var) pair by the sampling step that follows.
outputs_valid = [
    member.predict(validset)
    for member in models
]

In [7]:
# Draw 250 samples from every ensemble member's predicted Gaussian and pool
# them along the sample axis (axis 1).
samples_valid = np.concatenate(
    [
        sample_normal(mu, torch.sqrt(sigma2), T=250)
        for mu, sigma2 in outputs_valid
    ],
    axis=1,
)
# Quantiles of the pooled samples, compared against the validation targets.
quartiles_valid = np.quantile(samples_valid, QUARTILES, axis=1)
light_score(validset.quartiles, quartiles_valid)

994.6414888705325

In [8]:
# Score only the first 500 validation entries: the recorded run took about
# six minutes for these 500, so the full set would be considerably slower.
regular_score(samples_valid[:500], validset.ids[:500])

 50%|█████     | 250/500 [03:06<02:52,  1.45it/s]RESULT MIGHT BE INACURATE
Max number of iteration reached, currently 100000. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher
  check_result(result_code)
100%|██████████| 500/500 [06:10<00:00,  1.35it/s]


993.503549047879

## Test set

In [9]:
# Load the 800 test entries.
ids_test = np.arange(800)
spectra_test = read_spectra(ids_test, path="data/test/spectra.hdf5")

# Auxiliary features are standardised with the *training* statistics.
auxiliary_test = standardise(
    read_auxiliary_table(ids_test, path="data/test/auxiliary_table.csv"),
    trainset.auxiliary_train_mean,
    trainset.auxiliary_train_std,
)

# Element 1 of the read_spectra result is used as the network input
# (NOTE(review): confirm in ariel which quantity index 1 holds).
# Each row is normalised by its own mean and standard deviation.
X_test = spectra_test[1]
X_test = (X_test - X_test.mean(dim=1, keepdim=True)) / X_test.std(dim=1, keepdim=True)

X_test.shape, X_test.dtype, auxiliary_test.shape, auxiliary_test.dtype

(torch.Size([800, 52]), torch.float32, torch.Size([800, 9]), torch.float32)

In [10]:
# Forward pass of every ensemble member on the test inputs; gradients are
# not needed for prediction, so autograd tracking is disabled.
with torch.no_grad():
    outputs_test = [
        member(X_test, auxiliary_test)
        for member in models
    ]

In [11]:
# Same sampling scheme as for the validation set: 250 draws per ensemble
# member, pooled along the sample axis (axis 1).
samples_test = np.concatenate(
    [
        sample_normal(mu, torch.sqrt(sigma2), T=250)
        for mu, sigma2 in outputs_test
    ],
    axis=1,
)
quartiles_test = np.quantile(samples_test, QUARTILES, axis=1)
# Table of the per-target quartiles, displayed as the cell output.
light_track = light_track_format(quartiles_test)
light_track

Unnamed: 0_level_0,T_q1,T_q2,T_q3,log_H2O_q1,log_H2O_q2,log_H2O_q3,log_CO2_q1,log_CO2_q2,log_CO2_q3,log_CH4_q1,log_CH4_q2,log_CH4_q3,log_CO_q1,log_CO_q2,log_CO_q3,log_NH3_q1,log_NH3_q2,log_NH3_q3
planet_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
0,1076.076721,1093.993232,1112.198700,-10.710290,-9.224858,-7.691718,-4.523450,-4.401793,-4.279614,-11.223670,-9.868146,-8.558330,-7.044384,-5.113300,-3.319796,-6.431118,-6.289752,-6.157399
1,1578.157268,1605.771601,1633.542829,-4.916562,-4.819096,-4.726319,-4.966552,-4.835785,-4.689770,-9.167509,-7.703569,-6.215824,-10.608084,-7.710256,-4.920106,-9.725181,-7.929946,-6.318779
2,4735.917124,4920.037567,5052.461464,-10.766681,-9.310719,-7.745955,-9.211385,-8.484583,-7.734853,-10.062416,-8.238234,-7.050261,-5.537786,-4.871410,-4.200258,-9.941263,-8.204023,-6.798083
3,1966.241505,2001.846782,2036.951042,-3.568627,-3.410249,-3.261246,-10.503792,-8.880066,-7.277046,-10.406737,-8.623702,-6.733381,-5.020951,-3.676174,-2.497464,-10.824980,-8.914745,-6.900460
4,1000.101763,1024.321557,1048.038196,-3.786114,-3.627065,-3.482640,-9.619157,-8.406438,-7.223049,-7.320368,-6.750464,-6.196757,-7.520239,-5.348226,-3.715546,-10.468859,-8.728149,-7.024228
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
795,1328.794292,1353.543164,1378.424896,-3.861751,-3.729490,-3.242247,-5.874909,-5.493866,-5.371856,-5.454578,-4.954393,-4.398589,-4.071540,-3.247708,-2.334339,-5.012351,-4.141050,-3.749678
796,599.832280,611.777500,622.378119,-3.730035,-3.444158,-3.079496,-9.012510,-6.468976,-5.743366,-4.221418,-3.506996,-3.175590,-3.021300,-2.664715,-2.287299,-4.556144,-4.238950,-4.034302
797,450.447913,456.565764,468.531882,-4.924446,-4.816423,-4.481833,-4.956322,-4.448502,-4.215037,-7.887900,-6.934254,-6.169931,-4.300748,-3.726184,-2.993722,-7.981302,-6.654283,-6.101185
798,914.692317,935.300150,956.049813,-3.673781,-3.532456,-3.399993,-10.929013,-9.350889,-7.876629,-10.597704,-8.903424,-7.204806,-6.177909,-4.521718,-2.860778,-4.893043,-4.751484,-4.619098


In [12]:
# Pass the pooled test-set samples to the regular-track formatter
# (presumably this produces/writes the regular-track submission — confirm in ariel).
regular_track_format(samples_test)