In [2]:
from sklearn.model_selection import train_test_split
import torch
from torch import nn

from ariel import *

In [3]:
def sample_normal(mean, std, T, rng=None):
    """Draw ``T`` independent Gaussian samples for every entry of ``mean``.

    Each of the ``T`` draws has the broadcast shape of ``mean`` and ``std``;
    the draws are stacked along a new axis inserted at position 1, so inputs
    of shape ``(n, d)`` give an output of shape ``(n, T, d)``.

    Parameters
    ----------
    mean : array-like
        Location of the normal distribution. Must broadcast with ``std`` to
        an array with at least one dimension (the new axis is inserted at
        position 1).
    std : array-like
        Scale (standard deviation) of the normal distribution, non-negative.
    T : int
        Number of draws to stack.
    rng : optional
        Object exposing ``normal(loc=..., scale=...)``, e.g. a
        ``numpy.random.Generator``. Pass a seeded generator to make the
        samples reproducible. When ``None`` (the default) the global
        ``np.random`` state is used, exactly as before.

    Returns
    -------
    numpy.ndarray
        The stacked samples.

    Raises
    ------
    ValueError
        If ``T`` is not positive (``np.stack`` needs at least one array).
    """
    # Fall back to the legacy global sampler so existing callers are unaffected.
    draw = np.random.normal if rng is None else rng.normal
    return np.stack([draw(loc=mean, scale=std) for _ in range(T)], axis=1)

## Validation set

In [18]:
# Sanity check: sizes of the train / validation id arrays.
# NOTE(review): ids_train and ids_valid are created by the train_test_split
# cell that appears *after* this one, so on a fresh kernel (Restart & Run All)
# this cell raises NameError. It should be moved below the split cell.
ids_train.shape, ids_valid.shape

((17589,), (4398,))

In [6]:
# Split the N sample ids 80/20 into train / validation with a fixed seed,
# so the partition is the same on every run.
ids = np.arange(N)
ids_train, ids_valid = train_test_split(
    ids,
    train_size=0.8,
    random_state=36,
)

# Materialise both partitions (train first, then validation).
data_train, data_valid = get_dataset(ids_train), get_dataset(ids_valid)

trainset = NoisySpectraDataset(*data_train)

# The validation dataset is handed the auxiliary mean / std taken from the
# training set rather than statistics of its own.
validset = SpectraDataset(
    *data_valid,
    trainset.auxiliary_train_mean,
    trainset.auxiliary_train_std,
)

In [19]:
# Ensemble members, identified by run name; the weights of each member are
# read from "models/<name>.pt".
#
# Superseded list, kept for reference only. It used to be assigned to
# `modelnames` and was immediately overwritten by the line below (dead code):
#   ["dauntless-vortex-103", "hopeful-morning-102", "light-brook-101",
#    "distinctive-cosmos-104", "misunderstood-glitter-105"]
modelnames = ["polished-sea-115", "noble-moon-114", "radiant-dream-112", "wandering-cosmos-113", "lilac-blaze-111"]
state_dicts = ["models/" + modelname + ".pt" for modelname in modelnames]

# One freshly constructed model per checkpoint; the path itself is not needed
# at construction time.
models = [Model(DEFAULT_HYPERPARAMETERS) for _ in state_dicts]

device = "cuda" if torch.cuda.is_available() else "cpu"
for model, state_dict in zip(models, state_dicts):
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    model.load_state_dict(torch.load(state_dict, map_location=torch.device(device)))

# Displayed as the cell output: number of ensemble members loaded.
len(models)

5

In [20]:
# Every ensemble member yields a (mean, variance) pair for the validation set.
outputs_valid = [net.predict(validset) for net in models]

# Draw 1000 Gaussian samples per member and pool all members along the
# sample axis (axis 1).
samples_valid = np.concatenate(
    [
        sample_normal(mu, torch.sqrt(sigma2), T=1000)
        for mu, sigma2 in outputs_valid
    ],
    axis=1,
)

# Reduce the pooled sample axis to the requested quantiles and score them
# against the validation set's reference quartiles.
quartiles_valid = np.quantile(samples_valid, QUARTILES, axis=1)
light_score(validset.quartiles, quartiles_valid)

990.1215765221007

In [None]:
# 995.126070254648 (test: 975.3316644649024)
# noisy data: 989.6687695924106 (test: 975.2577801950997)

In [14]:
# Regular-track score computed on the first 500 validation entries only.
# The recorded run below took roughly 3.3 s per item and emitted solver
# "max number of iteration reached" warnings.
regular_score(samples_valid[:500], validset.ids[:500])

  0%|          | 0/500 [00:00<?, ?it/s]RESULT MIGHT BE INACURATE
Max number of iteration reached, currently 100000. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher
  0%|          | 1/500 [00:03<27:33,  3.31s/it]RESULT MIGHT BE INACURATE
Max number of iteration reached, currently 100000. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher
  0%|          | 2/500 [00:06<27:20,  3.30s/it]RESULT MIGHT BE INACURATE
Max number of iteration reached, currently 100000. Sometimes iterations go on i

994.6852719977115

In [1]:
# 994.6852719977115 (test: 971.0450579106324)
# noisy data: 988.6791129818406 (test: 970.174913016858)

## Test set

In [21]:
# Test planets are addressed by the ids 0..799.
ids_test = np.arange(800)

# Only element 1 of what read_spectra returns is used as model input,
# after passing it through scale().
spectra_test = read_spectra(ids_test, path="data/test/spectra.hdf5")
X_test = scale(spectra_test[1])

# Auxiliary features are standardised with the *training* set's mean / std.
auxiliary_test = standardise(
    read_auxiliary_table(ids_test, path="data/test/auxiliary_table.csv"),
    trainset.auxiliary_train_mean,
    trainset.auxiliary_train_std,
)

# Displayed as the cell output: shapes and dtypes of both model inputs.
X_test.shape, X_test.dtype, auxiliary_test.shape, auxiliary_test.dtype

(torch.Size([800, 52]), torch.float32, torch.Size([800, 9]), torch.float32)

In [22]:
# Forward pass of every ensemble member on the test inputs, without autograd.
# NOTE(review): the models are called directly here, whereas the validation
# cell goes through model.predict(); confirm both paths run the models in the
# same (evaluation) mode.
with torch.no_grad():
    outputs_test = [member(X_test, auxiliary_test) for member in models]

# Draw 250 Gaussian samples per member and pool all members along the sample
# axis (axis 1).
samples_test = np.concatenate(
    [
        sample_normal(mu, torch.sqrt(sigma2), T=250)
        for mu, sigma2 in outputs_test
    ],
    axis=1,
)

# Reduce the pooled sample axis to the requested quantiles and format them
# for the light track.
quartiles_test = np.quantile(samples_test, QUARTILES, axis=1)
light_track = light_track_format(quartiles_test)
light_track

Unnamed: 0_level_0,T_q1,T_q2,T_q3,log_H2O_q1,log_H2O_q2,log_H2O_q3,log_CO2_q1,log_CO2_q2,log_CO2_q3,log_CH4_q1,log_CH4_q2,log_CH4_q3,log_CO_q1,log_CO_q2,log_CO_q3,log_NH3_q1,log_NH3_q2,log_NH3_q3
planet_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
0,1053.019046,1104.416014,1167.004644,-12.573129,-10.635120,-8.791491,-4.671328,-4.479246,-4.285585,-10.808493,-9.216853,-7.590198,-7.238516,-4.989355,-2.919106,-6.608383,-6.323585,-6.087413
1,1605.950844,1676.945302,1754.010354,-5.383540,-5.118953,-4.943428,-5.055111,-4.806292,-4.564204,-9.219417,-7.630106,-6.121266,-9.848332,-7.312692,-4.513991,-8.784598,-7.148990,-5.718957
2,5387.209835,5596.358667,5857.232570,-12.765936,-10.665808,-8.760651,-9.121009,-8.216633,-7.514515,-9.952523,-8.356574,-6.614797,-5.561352,-4.983759,-4.407334,-9.460484,-8.023615,-6.578949
3,1918.100778,2020.079254,2122.536105,-4.509922,-4.116116,-3.839331,-9.515294,-7.915428,-6.542063,-10.271054,-8.454118,-6.577283,-6.785296,-4.810212,-2.922835,-9.729568,-7.932235,-6.086650
4,998.860435,1042.693835,1089.406616,-5.226905,-4.272354,-3.911495,-9.913257,-8.443795,-7.137819,-8.864538,-7.628075,-6.408043,-7.831807,-5.710475,-3.732843,-10.249789,-8.435502,-6.521128
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
795,1322.022719,1375.157526,1426.722844,-4.809508,-4.440980,-4.241363,-6.107583,-5.847860,-5.638342,-5.957846,-5.462834,-5.112699,-5.857657,-4.481164,-3.575252,-4.811678,-4.621760,-4.434327
796,568.432919,605.604577,641.860577,-4.144334,-3.844117,-3.585769,-7.224425,-6.210217,-5.522646,-4.114192,-3.939319,-3.761853,-4.604752,-3.672151,-2.882897,-4.541156,-4.365813,-4.184161
797,429.670502,448.802541,470.528941,-4.935657,-4.734087,-4.530865,-4.881294,-4.524541,-4.184052,-9.007423,-7.672142,-6.299457,-4.784057,-4.011015,-3.323635,-7.641284,-6.675043,-6.126854
798,868.036849,917.911645,968.074866,-4.200729,-3.917230,-3.686479,-10.368964,-8.720743,-7.083012,-10.764961,-8.984531,-7.211610,-7.175859,-4.889669,-2.859371,-4.933004,-4.751200,-4.576850


In [23]:
# Format the pooled ensemble samples of the test set for the regular track.
# The return value is only displayed, not bound to a name.
regular_track_format(samples_test)