In [1]:
import numpy as np
from sklearn.model_selection import train_test_split
import torch

from ariel import *

In [2]:
def sample_normal(mean, std, T, rng=None):
    """Draw ``T`` independent Gaussian samples for every entry of ``mean``.

    Parameters
    ----------
    mean : array-like
        Location of the normal distribution, with at least one dimension.
    std : array-like
        Scale of the normal distribution, broadcastable against ``mean``.
    T : int
        Number of samples to draw per entry.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to NumPy's global random state
        (``np.random``), which is what the loop-based version used.

    Returns
    -------
    numpy.ndarray
        Array whose axis 1 indexes the ``T`` samples, e.g. inputs of shape
        ``(n, d)`` give an output of shape ``(n, T, d)``.
    """
    if rng is None:
        rng = np.random
    mean = np.asarray(mean)
    std = np.asarray(std)
    # Draw everything in one vectorised call with the sample axis first ...
    sample_shape = (T,) + np.broadcast(mean, std).shape
    draws = rng.normal(loc=mean, scale=std, size=sample_shape)
    # ... then move it to axis 1, matching np.stack(..., axis=1) over T draws.
    return np.moveaxis(draws, 0, 1)

## Validation set

In [4]:
# Split the N sample indices 80/20 with a fixed seed so the split is repeatable.
ids = np.arange(N)
ids_train, ids_valid = train_test_split(ids, train_size=0.8, random_state=36)

# Build the raw data for each part of the split (train first, then validation).
data_train, data_valid = get_dataset(ids_train), get_dataset(ids_valid)

# The validation dataset is given the auxiliary mean/std stored on the
# training dataset.
trainset = NoisySpectraDataset(*data_train)
validset = SpectraDataset(
    *data_valid, trainset.auxiliary_train_mean, trainset.auxiliary_train_std
)

ids_train.shape, ids_valid.shape

((17589,), (4398,))

In [5]:
# Checkpoint names of the ensemble members.
# An earlier list ("dauntless-vortex-103", "hopeful-morning-102",
# "light-brook-101", "distinctive-cosmos-104", "misunderstood-glitter-105")
# used to be assigned to `modelnames` first and was immediately overwritten by
# the list below, so it never had any effect; it is kept here only as a note.
modelnames = ["polished-sea-115", "noble-moon-114", "radiant-dream-112", "wandering-cosmos-113", "lilac-blaze-111"]

# Despite the name, `state_dicts` holds the checkpoint *paths*.
state_dicts = ["models/" + modelname + ".pt" for modelname in modelnames]

# One freshly constructed model per checkpoint.
models = [Model(DEFAULT_HYPERPARAMETERS) for _ in state_dicts]

device = "cuda" if torch.cuda.is_available() else "cpu"
for model, state_dict in zip(models, state_dicts):
    # NOTE(review): torch.load unpickles the file unless weights_only=True is
    # passed; only load checkpoints from a trusted source.
    model.load_state_dict(torch.load(state_dict, map_location=torch.device(device)))
len(models)

5

In [6]:
# Run every ensemble member on the validation set; each result is unpacked
# below as a (mean, variance) pair.
outputs_valid = [model.predict(validset) for model in models]

# 1000 Gaussian draws per model, pooled across models along axis 1.
samples_valid = np.concatenate(
    [
        sample_normal(mu, torch.sqrt(sigma2), T=1000)
        for mu, sigma2 in outputs_valid
    ],
    axis=1,
)

# Quantiles of the pooled samples, compared against the validation set's
# stored quartiles.
quartiles_valid = np.quantile(samples_valid, QUARTILES, axis=1)
light_score(validset.quartiles, quartiles_valid)

990.1099696260763

In [None]:
# 995.126070254648 (test: 975.3316644649024)
# noisy data: 989.6687695924106 (test: 975.2577801950997)

In [None]:
# Regular-track score computed on only the first 500 entries (along axis 0) of
# the pooled validation samples and of the validation ids.
regular_score(samples_valid[:500], validset.ids[:500])

In [None]:
# 994.6852719977115 (test: 971.0450579106324)
# noisy data: 988.6791129818406 (test: 970.174913016858)

## Test set

In [7]:
# Read the raw test inputs for ids 0..799.
ids_test = np.arange(800)
spectra_test = read_spectra(ids_test, path="data/test/spectra.hdf5")
auxiliary_test_raw = read_auxiliary_table(ids_test, path="data/test/auxiliary_table.csv")

# Element 1 of the spectra read is the model input; it is passed through
# scale(). The auxiliary table is standardised with the mean/std stored on
# the training dataset.
X_test = scale(spectra_test[1])
auxiliary_test = standardise(
    auxiliary_test_raw,
    trainset.auxiliary_train_mean,
    trainset.auxiliary_train_std,
)

X_test.shape, X_test.dtype, auxiliary_test.shape, auxiliary_test.dtype

(torch.Size([800, 52]), torch.float32, torch.Size([800, 9]), torch.float32)

In [8]:
# Forward pass of every ensemble member on the test inputs, without autograd.
# Each result is unpacked below as a (mean, variance) pair.
with torch.no_grad():
    outputs_test = [model(X_test, auxiliary_test) for model in models]

# 250 Gaussian draws per model, pooled across models along axis 1.
samples_test = np.concatenate(
    [
        sample_normal(mu, torch.sqrt(sigma2), T=250)
        for mu, sigma2 in outputs_test
    ],
    axis=1,
)

# Quantiles of the pooled samples, formatted as the light-track table.
quartiles_test = np.quantile(samples_test, QUARTILES, axis=1)
light_track = light_track_format(quartiles_test)
light_track

Unnamed: 0_level_0,T_q1,T_q2,T_q3,log_H2O_q1,log_H2O_q2,log_H2O_q3,log_CO2_q1,log_CO2_q2,log_CO2_q3,log_CH4_q1,log_CH4_q2,log_CH4_q3,log_CO_q1,log_CO_q2,log_CO_q3,log_NH3_q1,log_NH3_q2,log_NH3_q3
planet_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
0,1103.389270,1146.014024,1187.612235,-12.827775,-10.919790,-8.883407,-4.721163,-4.514639,-4.306102,-10.541329,-9.194980,-7.804109,-7.830815,-5.264796,-3.049139,-6.707351,-6.369145,-6.118935
1,1605.334179,1698.918492,1784.570475,-5.264023,-5.076784,-4.886699,-5.093571,-4.851765,-4.622532,-9.012706,-7.586000,-6.064235,-9.859350,-7.190492,-4.501278,-9.166517,-7.470333,-5.664373
2,5331.072440,5599.635845,5810.826401,-13.481579,-11.191005,-9.038355,-9.237009,-8.304927,-7.486267,-9.677402,-7.916786,-6.361482,-5.350973,-4.700459,-4.183405,-9.725294,-8.124455,-6.503003
3,1929.096782,2040.641252,2147.707849,-4.347931,-4.052049,-3.805580,-9.807376,-8.208652,-6.538153,-10.821321,-8.895922,-6.862265,-7.399133,-5.226374,-2.985685,-10.221416,-8.302659,-6.429971
4,1006.852250,1055.804477,1103.661988,-4.829280,-4.185151,-3.827741,-9.778154,-8.272952,-6.890462,-9.057072,-7.563402,-6.236933,-7.516016,-5.481215,-3.558883,-9.901531,-8.213960,-6.342934
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
795,1316.601575,1368.157683,1417.533103,-4.704861,-4.487048,-4.352076,-6.222796,-5.905012,-5.600917,-6.142982,-5.603417,-5.193520,-5.834685,-4.612896,-3.675382,-4.834496,-4.691531,-4.549544
796,559.934759,600.023146,654.684094,-4.150851,-3.902576,-3.617141,-7.139496,-6.257116,-5.545966,-4.205410,-4.006593,-3.824686,-4.488960,-3.624485,-2.825176,-4.561154,-4.377094,-4.168876
797,433.538259,457.190548,478.938944,-5.034794,-4.844478,-4.603761,-4.887092,-4.531045,-4.184727,-9.543593,-7.999228,-6.407455,-5.056764,-4.253153,-3.439985,-7.983479,-6.872046,-6.188636
798,889.190058,937.975574,995.237805,-4.209382,-3.966837,-3.762159,-9.943196,-8.336208,-6.987999,-10.834459,-8.925924,-7.066892,-7.250345,-4.971090,-2.846873,-4.930739,-4.736758,-4.555167


In [9]:
# The return value of regular_track_format is discarded here.
# NOTE(review): presumably it is called for a side effect (e.g. writing the
# regular-track submission) — confirm against the ariel module.
regular_track_format(samples_test)
# Show the shape of the pooled test samples.
samples_test.shape

(800, 1250, 6)