In [1]:
import h5py
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch

In [2]:
from ariel import *

In [3]:
# Rebuild the network from the default hyperparameters, then restore the
# trained weights of the "fancy-sweep-6" run, remapping every stored tensor
# onto the CPU regardless of the device it was saved from.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints;
# on torch >= 1.13 consider weights_only=True for a plain state_dict.
model = Model(HYPERPARAMETER_DEFAULTS)
model.load_state_dict(
    torch.load("models/fancy-sweep-6.pt", map_location="cpu")
)
model

Model(
  (input): Linear(in_features=52, out_features=128, bias=True)
  (linear1): Linear(in_features=128, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=128, bias=True)
  (linear3): Linear(in_features=128, out_features=128, bias=True)
  (linear4): Linear(in_features=128, out_features=128, bias=True)
  (linear5): Linear(in_features=128, out_features=128, bias=True)
  (linear6): Linear(in_features=128, out_features=128, bias=True)
  (linear7): Linear(in_features=128, out_features=128, bias=True)
  (output): Linear(in_features=128, out_features=12, bias=True)
)

In [4]:
# Recreate the 80/20 train/validation split of the training spectra so that the
# statistics below are computed on the training portion only.
X = read_spectra()
# Size the index range from the loaded data rather than the global N, so the
# boolean mask is guaranteed to have exactly one entry per row of X.
ids = torch.arange(len(X))
ids_train, ids_valid = train_test_split(ids, train_size=0.8, random_state=36)
# A boolean mask (instead of X[ids_train]) keeps the selected rows in their
# original order.
idx_train = torch.zeros(len(X), dtype=torch.bool)
idx_train[ids_train] = True
idx_valid = ~idx_train
X_train, X_valid = X[idx_train], X[idx_valid]
# Scalar statistics over every element of X_train (no dim argument is given).
# NOTE(review): these values are not applied to X_test before model(X_test) in
# the visible cells; confirm whether Model standardises its input internally
# or whether the test spectra should be normalised with them.
X_train_mean, X_train_std = X_train.mean(), X_train.std()

In [5]:
# Load the first 500 test spectra and cast them to float32 before they are fed
# to the model; the displayed shape/dtype confirm what was loaded.
X_test = read_spectra(n=500, path="data/test/spectra.hdf5").to(torch.float32)
X_test.shape, X_test.dtype

(torch.Size([500, 1, 52]), torch.float32)

In [6]:
# Switch to inference mode: an nn.Module starts in training mode, and any
# dropout / batch-norm behaviour inside Model.forward (if present) must be
# disabled when predicting.
model.eval()
with torch.no_grad():
    # The model returns a (mean, variance) pair, one value per target.
    mean, var = model(X_test)
# Tensors produced under no_grad do not require grad, so .numpy() is safe here.
mean, var = mean.cpu().numpy(), var.cpu().numpy()
std = np.sqrt(var)
# NOTE(review): norm and QUARTILES presumably come from `from ariel import *`
# (norm appears to be scipy.stats.norm) — confirm.
# One array of Gaussian quantiles per entry of QUARTILES, stacked on axis 0.
quartiles_test_pred = np.stack(
    [norm.ppf(quartile, loc=mean, scale=std) for quartile in QUARTILES]
)
mean.shape, var.shape, std.shape, quartiles_test_pred.shape

((500, 6), (500, 6), (500, 6), (3, 500, 6))

In [7]:
# Turn the predicted quartiles into the "light track" table: judging from the
# rendered output, one row per planet_ID with q1/q2/q3 columns for each target.
light_track = light_track_format(quartiles_test_pred)
light_track

Unnamed: 0_level_0,T_q1,T_q2,T_q3,log_H2O_q1,log_H2O_q2,log_H2O_q3,log_CO2_q1,log_CO2_q2,log_CO2_q3,log_CH4_q1,log_CH4_q2,log_CH4_q3,log_CO_q1,log_CO_q2,log_CO_q3,log_NH3_q1,log_NH3_q2,log_NH3_q3
planet_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
0,1158.223097,1394.091797,1629.960497,-9.166416,-7.311687,-5.456957,-8.428810,-6.511472,-4.594134,-10.065075,-8.427371,-6.789667,-7.481150,-5.600060,-3.718971,-9.697487,-7.989668,-6.281849
1,1255.488593,1490.359497,1725.230401,-8.392666,-6.631400,-4.870133,-8.395246,-6.498709,-4.602171,-9.845758,-8.219120,-6.592482,-7.662621,-5.750456,-3.838290,-9.757142,-8.062422,-6.367701
2,1167.464621,1408.383545,1649.302469,-9.196788,-7.382762,-5.568736,-8.742389,-6.750426,-4.758462,-9.942980,-8.272850,-6.602720,-7.412048,-5.522152,-3.632256,-9.652509,-7.918513,-6.184517
3,1254.599255,1496.885498,1739.171741,-8.641920,-6.827581,-5.013243,-8.853370,-6.849727,-4.846085,-9.969913,-8.300599,-6.631286,-7.589212,-5.658641,-3.728070,-9.613815,-7.879974,-6.146133
4,1147.981337,1381.161377,1614.341417,-8.017157,-6.338706,-4.660254,-8.790646,-6.855668,-4.920690,-9.847233,-8.232354,-6.617476,-7.431976,-5.540589,-3.649202,-9.653526,-7.979519,-6.305512
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,1224.409665,1459.322632,1694.235599,-9.229428,-7.348453,-5.467477,-8.520826,-6.554562,-4.588298,-8.734452,-7.064068,-5.393685,-7.597824,-5.715401,-3.832978,-9.641434,-7.889858,-6.138283
496,1313.063769,1465.301636,1617.539502,-3.916268,-3.776666,-3.637064,-8.465579,-7.737267,-7.008954,-9.304114,-8.951287,-8.598460,-8.574662,-7.296610,-6.018558,-9.348820,-8.763162,-8.177503
497,1160.636348,1392.832886,1625.029424,-8.030318,-6.353488,-4.676659,-8.124043,-6.276061,-4.428078,-10.015324,-8.408980,-6.802637,-7.586403,-5.672863,-3.759323,-9.757895,-8.086556,-6.415218
498,1501.606469,1735.968384,1970.330299,-9.574255,-7.894539,-6.214823,-8.295975,-6.450222,-4.604468,-9.552429,-7.904387,-6.256345,-7.913795,-6.073674,-4.233553,-9.462896,-7.738726,-6.014556


In [8]:
# Monte-Carlo samples from each predicted per-target Gaussian N(mean, std).
# A seeded Generator makes the samples reproducible under Restart & Run All,
# and one broadcast draw replaces 5000 separate calls plus the np.stack copy.
N_POSTERIOR_SAMPLES = 5000
sampling_rng = np.random.default_rng(36)
# Insert the sample axis at position 1 so the layout matches the original
# np.stack(..., axis=1): (n_planets, N_POSTERIOR_SAMPLES, *remaining dims).
Y_test_pred = sampling_rng.normal(
    loc=mean[:, np.newaxis],
    scale=std[:, np.newaxis],
    size=(mean.shape[0], N_POSTERIOR_SAMPLES) + mean.shape[1:],
)
Y_test_pred.shape

(500, 5000, 6)

In [9]:
# Format the Monte-Carlo samples (shape (500, 5000, 6) in the previous cell's
# output) for the "regular track"; regular_track_format presumably comes from
# `from ariel import *` — confirm.
regular_track = regular_track_format(Y_test_pred)