In [1]:
from sklearn.model_selection import train_test_split
import torch
from torch import nn

from ariel import *

In [2]:
ids = np.arange(N)
ids_train, ids_valid = train_test_split(ids, train_size=0.8, random_state=36)
trainset = get_dataset(ids_train)

ids_test = np.arange(800)
spectra_test = read_spectra(ids_test, path="data/test/spectra.hdf5")
X_test = spectra_test[1]
auxiliary_test = read_auxiliary_table(ids_test, path="data/test/auxiliary_table.csv")

X_test = standardise(X_test, trainset.X_train_mean, trainset.X_train_std)
auxiliary_test = standardise(auxiliary_test, trainset.auxiliary_train_mean, trainset.auxiliary_train_std)
X_test.shape, X_test.dtype, auxiliary_test.shape, auxiliary_test.dtype

(torch.Size([800, 52]), torch.float32, torch.Size([800, 9]), torch.float32)

In [3]:
modelnames = [
    "deft-sweep-25",
    "dry-glitter-319",
    "clean-aardvark-320",
    "denim-smoke-321",
    "grateful-bush-321",
    "zany-shape-323",
    "fancy-capybara-324",
    "balmy-plasma-325",
    "efficient-butterfly-326",
    "divine-tree-327",
    "trim-puddle-328",
    "firm-surf-329",
    "faithful-field-330",
    "crimson-sea-331",
    "scarlet-pine-332",
    "dulcet-pine-333",
    "ruby-monkey-335",
    "wise-lake-334",
    "smooth-fog-336",
    "fanciful-mountain-337"]
state_dicts = ["models/" + modelname + ".pt" for modelname in modelnames]
models = [Model(DEFAULT_HYPERPARAMETERS) for state_dict in state_dicts]
device = "cuda" if torch.cuda.is_available() else "cpu"
for model, state_dict in zip(models, state_dicts):
    model.load_state_dict(torch.load(state_dict, map_location=torch.device(device)))

with torch.no_grad():
    outputs = [model(X_test, auxiliary_test) for model in models]

In [4]:
def sample_normal(mean, std, T):
    return np.stack([np.random.normal(loc=mean, scale=std) for i in range(T)], axis=1)

samples = np.concatenate([sample_normal(mean, std, T=250) for mean, std in outputs], axis=1)
samples.shape

(800, 5000, 6)

In [5]:
quartiles = np.quantile(samples, QUARTILES, axis=1)
quartiles.shape

(3, 800, 6)

In [6]:
light_track = light_track_format(quartiles)
light_track

Unnamed: 0_level_0,T_q1,T_q2,T_q3,log_H2O_q1,log_H2O_q2,log_H2O_q3,log_CO2_q1,log_CO2_q2,log_CO2_q3,log_CH4_q1,log_CH4_q2,log_CH4_q3,log_CO_q1,log_CO_q2,log_CO_q3,log_NH3_q1,log_NH3_q2,log_NH3_q3
planet_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
0,1090.716433,1097.975687,1107.408871,-9.463928,-9.179101,-8.048579,-4.890283,-4.579167,-4.230567,-9.762237,-9.581256,-9.342784,-6.336299,-4.791616,-3.557676,-6.657428,-6.190794,-5.760679
1,1585.455903,1602.654683,1619.753293,-5.029722,-4.723689,-4.476177,-5.289966,-4.842262,-4.499202,-8.395093,-7.458362,-6.867597,-8.210188,-7.191950,-5.544895,-9.182940,-8.586967,-7.053478
2,4670.168291,5115.920802,5551.408663,-9.728385,-8.068488,-5.606928,-10.277581,-8.559899,-6.821476,-10.343467,-8.947164,-7.977713,-8.889901,-3.825997,0.441239,-9.990380,-8.430192,-6.058551
3,1984.363310,2010.981127,2036.679678,-3.943845,-3.350724,-2.968161,-9.040220,-8.292156,-7.250538,-9.301231,-8.798516,-8.386143,-6.723263,-4.665374,-2.778529,-9.179791,-8.638097,-7.826443
4,1014.954203,1023.902959,1031.099005,-3.751043,-3.524807,-3.306738,-8.951218,-8.137231,-7.517610,-7.139204,-6.677237,-6.382870,-6.794920,-5.125984,-3.966852,-9.049853,-8.935476,-8.631140
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
795,1312.072727,1322.333085,1333.490502,-3.913629,-3.601396,-3.280923,-7.324058,-5.451666,-4.348249,-9.428836,-9.132997,-8.741256,-6.552801,-3.800329,-2.145992,-4.328632,-3.939408,-3.635960
796,599.034975,601.802919,604.752991,-3.656472,-3.536765,-3.365352,-7.587362,-5.789031,-4.835627,-4.120765,-3.889378,-3.643516,-6.195321,-3.716402,-2.379078,-4.432371,-4.275993,-4.036123
797,451.427779,454.500958,457.389182,-4.850579,-4.646527,-4.436558,-4.646517,-4.130662,-3.820897,-9.541288,-9.096791,-8.301703,-5.863027,-4.570384,-3.738395,-8.100591,-7.161968,-6.566181
798,925.231366,931.618223,937.813708,-3.832864,-3.502108,-3.319760,-9.147287,-8.493239,-7.194792,-9.382134,-9.258242,-8.703659,-7.452833,-5.853595,-4.718950,-4.949567,-4.722584,-4.577215


In [7]:
regular_track_format(samples)