In [1]:
import h5py
import numpy as np
import ot
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from tqdm.notebook import tqdm_notebook

from ariel import *

In [2]:
# Build the network from the default hyperparameters and restore the saved weights.
model = Model(HYPERPARAMETER_DEFAULTS)
# map_location pins every loaded tensor to the CPU, so the checkpoint loads without a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load(
    "models/silver-sweep-24.pt", map_location=torch.device('cpu')))
# Switch to evaluation mode; as the cell's last expression this also displays the module.
model.eval()

Model(
  (input): Linear(in_features=61, out_features=2048, bias=True)
  (linear1): Linear(in_features=2048, out_features=2048, bias=True)
  (linear2): Linear(in_features=2048, out_features=2048, bias=True)
  (linear3): Linear(in_features=2048, out_features=2048, bias=True)
  (linear4): Linear(in_features=2048, out_features=2048, bias=True)
  (linear5): Linear(in_features=2048, out_features=2048, bias=True)
  (output): Linear(in_features=2048, out_features=12, bias=True)
)

In [3]:
# Fetch the train/validation datasets and pull out the standardisation statistics
# stored on the training set; they are reused later to standardise the test inputs.
trainset, validset = get_datasets()

X_train_mean = trainset.X_train_mean
X_train_std = trainset.X_train_std
auxiliary_train_mean, auxiliary_train_std = (
    trainset.auxiliary_train_mean,
    trainset.auxiliary_train_std,
)

# Display the four statistics as the cell output.
(X_train_mean, X_train_std, auxiliary_train_mean, auxiliary_train_std)

(tensor(0.0052),
 tensor(0.0074),
 tensor([5.6913e+02, 2.0330e+30, 8.5143e+08, 5.6658e+03, 1.0991e+27, 2.4572e+01,
         1.1997e-01, 4.4601e+07, 1.6370e+01]),
 tensor([4.7075e+02, 6.8741e+29, 4.6110e+08, 9.3514e+02, 8.2191e+27, 9.6469e+01,
         1.9630e-01, 3.5628e+07, 6.6791e+01]))

In [4]:
# Predict a mean and a variance for the validation set, convert both results
# to NumPy arrays on the host, and derive the standard deviation from the variance.
mean_valid, var_valid = (
    prediction.cpu().numpy() for prediction in model.predict(validset)
)
std_valid = np.sqrt(var_valid)

# Display the three array shapes as a sanity check.
(mean_valid.shape, var_valid.shape, std_valid.shape)

((4398, 6), (4398, 6), (4398, 6))

In [5]:
def get_quartiles(mean, std, quartiles=None):
    """Evaluate the normal quantile function at a set of probability levels.

    Parameters
    ----------
    mean, std : array_like
        Location and scale of the normal distributions; passed to ``norm.ppf``
        as ``loc`` and ``scale``.
    quartiles : iterable of float, optional
        Probability levels to evaluate. When omitted, the module-level
        ``QUARTILES`` constant is used, exactly as before.

    Returns
    -------
    numpy.ndarray
        One slice per probability level, stacked along a new leading axis.

    Notes
    -----
    ``norm`` and ``QUARTILES`` are not imported explicitly in this notebook;
    presumably both come from ``from ariel import *`` -- TODO confirm.
    """
    if quartiles is None:
        quartiles = QUARTILES
    return np.stack([norm.ppf(level, loc=mean, scale=std) for level in quartiles])

# Turn the predicted validation means/standard deviations into quantile predictions.
quartiles_valid_pred = get_quartiles(mean_valid, std_valid)
# Display the shape: the leading axis indexes the probability levels.
quartiles_valid_pred.shape

(3, 4398, 6)

In [6]:
# Score the predicted validation quartiles against the reference quartiles stored on validset.
light_score(validset.quartiles, quartiles_valid_pred)

975.8500492089638

In [7]:
def sample_normal(mean, std, T=5000, rng=None):
    """Draw ``T`` independent normal samples for every ``(mean, std)`` entry.

    Parameters
    ----------
    mean, std : array_like
        Location and scale of the normal distributions. Each draw has their
        broadcast shape, which must have at least one axis.
    T : int, optional
        Number of draws; must be at least 1.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Random source exposing ``normal(loc=..., scale=...)``. Pass a seeded
        generator for reproducible samples. When omitted, NumPy's global
        random state is used, exactly as before.

    Returns
    -------
    numpy.ndarray
        The draws stacked along a new axis 1, e.g. ``(N, D)`` inputs give an
        ``(N, T, D)`` result.

    Raises
    ------
    ValueError
        If ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError("T must be a positive integer")
    sampler = np.random if rng is None else rng
    # One draw per iteration keeps the random stream order of the original loop.
    return np.stack([sampler.normal(loc=mean, scale=std) for _ in range(T)], axis=1)

# Sample from the predicted validation distributions (default number of draws per entry).
Y_valid_pred = sample_normal(mean_valid, std_valid)
# Display the shape; the draws sit on axis 1.
Y_valid_pred.shape

(4398, 5000, 6)

In [8]:
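# Disabled: scoring Y_valid_pred with regular_score against the ground-truth traces; uncomment to run.
#regular_score("data/train/ground_truth/traces.hdf5", Y_valid_pred, validset.ids)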

In [9]:
# Number of test planets to read.
N_test = 500
spectra_test = read_spectra(path="data/test/spectra.hdf5", n=N_test)
# The element at index 1 of what read_spectra returns is used as the model input.
X_test = spectra_test[1]
auxiliary_test = read_auxiliary_table(path="data/test/auxiliary_table.csv", n=N_test)

# Standardise the test inputs with the statistics taken from the training set.
X_test = standardise(X_test, X_train_mean, X_train_std)
auxiliary_test = standardise(auxiliary_test, auxiliary_train_mean, auxiliary_train_std)

# Put the inputs on the device that holds the model's weights. The checkpoint was
# loaded with map_location=cpu, so moving the inputs to CUDA whenever a GPU merely
# exists can leave model and inputs on different devices and break the forward pass.
model_device = next(model.parameters()).device
X_test, auxiliary_test = X_test.to(model_device), auxiliary_test.to(model_device)

# Display shapes and dtypes as a sanity check.
X_test.shape, X_test.dtype, auxiliary_test.shape, auxiliary_test.dtype

(torch.Size([500, 52]), torch.float32, torch.Size([500, 9]), torch.float32)

In [10]:
# Run the model on the standardised test inputs without tracking gradients and
# bring the predicted mean and variance back to the host as NumPy arrays.
with torch.no_grad():
    mean_test, var_test = (
        output.cpu().numpy() for output in model(X_test, auxiliary_test)
    )

# Standard deviation from the variance, then the quantile predictions for the test set.
std_test = np.sqrt(var_test)
quartiles_test_pred = get_quartiles(mean_test, std_test)

# Display all four shapes as a sanity check.
(mean_test.shape, var_test.shape, std_test.shape, quartiles_test_pred.shape)

((500, 6), (500, 6), (500, 6), (3, 500, 6))

In [11]:
# Convert the test quartile predictions with light_track_format and display the resulting table.
light_track = light_track_format(quartiles_test_pred)
light_track

Unnamed: 0_level_0,T_q1,T_q2,T_q3,log_H2O_q1,log_H2O_q2,log_H2O_q3,log_CO2_q1,log_CO2_q2,log_CO2_q3,log_CH4_q1,log_CH4_q2,log_CH4_q3,log_CO_q1,log_CO_q2,log_CO_q3,log_NH3_q1,log_NH3_q2,log_NH3_q3
planet_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
0,1079.940998,1096.160034,1112.379071,-8.576547,-7.593754,-6.610961,-4.715073,-4.574337,-4.433601,-11.221022,-9.560147,-7.899273,-8.981201,-6.273302,-3.565402,-6.395953,-6.100485,-5.805016
1,1581.346336,1605.782104,1630.217873,-4.957945,-4.881715,-4.805484,-4.924945,-4.794497,-4.664048,-9.205572,-8.015659,-6.825746,-9.565350,-6.740133,-3.914917,-9.824878,-8.174597,-6.524316
2,5054.174185,5117.702637,5181.231088,-10.143639,-8.585912,-7.028184,-11.503229,-9.515368,-7.527508,-10.243748,-8.920275,-7.596801,-4.971709,-3.902348,-2.832986,-11.003402,-8.511102,-6.018801
3,1960.749678,1995.453247,2030.156817,-3.107384,-3.024620,-2.941856,-8.664764,-7.474135,-6.283507,-8.246195,-6.812896,-5.379596,-9.562721,-6.769558,-3.976395,-8.256723,-6.638976,-5.021228
4,1000.228975,1021.964050,1043.699126,-3.698923,-3.579576,-3.460230,-9.045152,-8.149136,-7.253119,-7.204746,-6.680446,-6.156146,-8.136288,-5.785867,-3.435446,-10.727084,-8.634022,-6.540959
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,1239.217755,1261.673706,1284.129658,-7.818596,-6.509627,-5.200658,-4.839828,-4.676985,-4.514142,-3.561601,-3.484992,-3.408382,-10.592431,-7.464346,-4.336260,-9.254283,-7.500788,-5.747292
496,1467.636761,1500.096313,1532.555866,-3.722367,-3.631026,-3.539685,-7.882856,-7.386600,-6.890343,-10.506809,-8.389006,-6.271203,-8.266145,-5.891376,-3.516607,-12.919004,-10.412751,-7.906499
497,933.397591,947.381348,961.365105,-4.066658,-3.979494,-3.892329,-4.403640,-4.262330,-4.121019,-11.240440,-9.558548,-7.876656,-9.327237,-6.543943,-3.760649,-10.495956,-8.698401,-6.900847
498,2173.674176,2198.622070,2223.569965,-11.198727,-9.486550,-7.774374,-5.867471,-5.771573,-5.675674,-6.583540,-6.215702,-5.847864,-8.893358,-6.550657,-4.207956,-5.264003,-5.211099,-5.158194


In [12]:
# Sample from the predicted test distributions (default number of draws per entry).
Y_test_pred = sample_normal(mean_test, std_test)
# NOTE(review): the return value is discarded -- presumably regular_track_format is
# called for a side effect such as writing the submission file; confirm.
regular_track_format(Y_test_pred)
# Display the shape of the sampled array.
Y_test_pred.shape

(500, 5000, 6)