In [1]:
import h5py
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch

In [2]:
import ariel
from train import HYPERPARAMETER_DEFAULTS, Model

In [3]:
model = Model(HYPERPARAMETER_DEFAULTS)
model.load_state_dict(torch.load("models/radiant-sweep-12.pt"))

<All keys matched successfully>

In [4]:
X = ariel.read_spectra()
quartiles = ariel.read_quartiles_table()
Y = torch.from_numpy(quartiles[1]).float()

# train and validation set split
ids = torch.arange(ariel.N)
ids_train, ids_valid = train_test_split(ids, train_size=0.8, random_state=36)
idx_train = torch.zeros_like(ids, dtype=torch.bool)
idx_train[ids_train] = True
idx_valid = ~idx_train
X_train, X_valid = X[idx_train], X[idx_valid]
Y_train = Y[idx_train]

X_train_mean, X_train_std = X_train.mean(), X_train.std()
Y_train_mean, Y_train_std = Y_train.mean(dim=0), Y_train.mean(dim=0)

In [5]:
X_test = ariel.read_spectra(path="data/test/spectra.hdf5", n=500).cuda()
X_test.shape

torch.Size([500, 1, 52])

In [6]:
Y_test_pred = model.sample(X_test)
Y_test_pred = ariel.unstandardise(Y_test_pred, Y_train_mean, Y_train_std)
Y_test_pred.shape

torch.Size([500, 256, 6])

In [7]:
quartiles_test_pred = np.quantile(Y_test_pred, ariel.QUARTILES, axis=1)
light_track = ariel.light_track_format(quartiles_test_pred)
light_track

Unnamed: 0_level_0,T_q1,T_q2,T_q3,log_H2O_q1,log_H2O_q2,log_H2O_q3,log_CO2_q1,log_CO2_q2,log_CO2_q3,log_CH4_q1,log_CH4_q2,log_CH4_q3,log_CO_q1,log_CO_q2,log_CO_q3,log_NH3_q1,log_NH3_q2,log_NH3_q3
planet_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
0,1103.549805,1103.549805,1105.139746,-7.015512,-7.007380,-7.007380,-6.446874,-6.446874,-6.446874,-7.754189,-7.754162,-7.754162,-4.808358,-4.808347,-4.808347,-7.982824,-7.982824,-7.982824
1,1180.695630,1181.101807,1181.101807,-6.124898,-6.123118,-6.123118,-6.441237,-6.441237,-6.441237,-7.600441,-7.583046,-7.583046,-5.010478,-5.007578,-5.007578,-8.318925,-8.315218,-8.315218
2,1106.726196,1106.726196,1106.726196,-6.891134,-6.891134,-6.891134,-6.899675,-6.899675,-6.899675,-7.315840,-7.315840,-7.315840,-4.587857,-4.583075,-4.583075,-8.095128,-8.095128,-8.095128
3,1189.808594,1189.808594,1189.808594,-6.323099,-6.323086,-6.323086,-6.982583,-6.972243,-6.972243,-7.432796,-7.432796,-7.432796,-4.696952,-4.692652,-4.692652,-8.173092,-8.172715,-8.172715
4,1109.238843,1109.805908,1110.589258,-6.069759,-6.067734,-6.065557,-6.896208,-6.896208,-6.877727,-7.399947,-7.397564,-7.395387,-4.732304,-4.727652,-4.721579,-8.178293,-8.164770,-8.164770
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,1106.192993,1106.192993,1107.164941,-7.005617,-7.005617,-7.001334,-6.697402,-6.678648,-6.678648,-6.015561,-6.001676,-6.001676,-4.776592,-4.773964,-4.773887,-8.129156,-8.128888,-8.124713
496,1441.243286,1441.243286,1441.243286,-3.254073,-3.254073,-3.254073,-7.349569,-7.349569,-7.349569,-8.697643,-8.697643,-8.697643,-5.507694,-5.507694,-5.507694,-8.852551,-8.852551,-8.852551
497,1079.410376,1082.146240,1083.603174,-5.915623,-5.907012,-5.907012,-6.402429,-6.397555,-6.387394,-7.689545,-7.682267,-7.682079,-5.003278,-4.987299,-4.987299,-8.307274,-8.301463,-8.301170
498,1313.637085,1313.637085,1318.163086,-7.183120,-7.183120,-7.183120,-6.619023,-6.619023,-6.619023,-6.922572,-6.915458,-6.915458,-4.577365,-4.571909,-4.571909,-7.815989,-7.815428,-7.815428


In [8]:
regular_track = ariel.regular_track_format(Y_test_pred)