In [1]:
from sklearn.model_selection import train_test_split
import torch
from torch import nn

from ariel import *

In [2]:
def sample_normal(mean, std, T, rng=None):
    """Draw ``T`` independent Gaussian samples for every entry of ``mean``.

    Parameters
    ----------
    mean : array-like
        Location of the normal distribution. Needs at least one dimension,
        because the draws are stacked along axis 1.
    std : array-like
        Scale of the normal distribution, broadcastable against ``mean``.
    T : int
        Number of draws per entry; must be at least 1.
    rng : numpy.random.Generator, optional
        Source of randomness. Pass e.g. ``np.random.default_rng(seed)`` to
        make the draws reproducible. When omitted, the global ``np.random``
        state is used, exactly as before this parameter existed.

    Returns
    -------
    numpy.ndarray
        The draws with a new sample axis of length ``T`` inserted at
        position 1, e.g. shape ``(N, T, D)`` for inputs of shape ``(N, D)``.

    Raises
    ------
    ValueError
        If ``T`` is 0, since ``np.stack`` needs at least one array.
    """
    # Pick the sampler once; both expose the same loc/scale keyword interface.
    draw = np.random.normal if rng is None else rng.normal
    return np.stack([draw(loc=mean, scale=std) for _ in range(T)], axis=1)

In [5]:
# Names of the trained checkpoints that form the ensemble.
modelnames = [
    "splendid-morning-378",
    "mild-leaf-376",
    "easy-valley-377",
    "rare-smoke-374",
    "solar-butterfly-375",
    "clean-wood-373",
    "vague-moon-372",
    "dainty-fog-371",
    "floral-sound-369",
    "vital-pine-370",
    "classic-water-365",
    "fresh-brook-365",
    "dainty-voice-368",
    "fresh-morning-367",
    "stellar-shape-364",
    "fragrant-water-360",
    "astral-oath-360",
    "toasty-tree-360",
    "fluent-sun-360",
    "smart-wildflower-359",
]
# Despite the name, these are checkpoint file paths, not loaded state dicts.
state_dicts = [f"models/{modelname}.pt" for modelname in modelnames]
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build one freshly initialised model per checkpoint and restore its weights.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# NOTE(review): the models are never switched to eval mode in this cell;
# confirm that Model.predict / Model.forward do not depend on it.
models = []
for state_dict in state_dicts:
    model = Model(DEFAULT_HYPERPARAMETERS)
    model.load_state_dict(torch.load(state_dict, map_location=torch.device(device)))
    models.append(model)
len(models)

20

## Validation set

In [6]:
# Split the N example ids 80/20 into training and validation ids; the fixed
# random_state keeps the split identical between runs.
ids = np.arange(N)
ids_train, ids_valid = train_test_split(
    ids,
    train_size=0.8,
    random_state=36,
)

# The validation dataset is built with the auxiliary mean/std taken from the
# training dataset (presumably so both are standardised the same way).
trainset = get_dataset(ids_train)
validset = get_dataset(
    ids_valid,
    trainset.auxiliary_train_mean,
    trainset.auxiliary_train_std,
)

In [7]:
# Collect every ensemble member's prediction on the validation set.
outputs_valid = []
for model in models:
    outputs_valid.append(model.predict(validset))

In [8]:
# Draw 250 samples from each model's predicted Gaussian (mean, variance) and
# pool the draws of all models along the sample axis (axis 1).
samples_valid = np.concatenate(
    [sample_normal(mu, sigma2.sqrt(), T=250) for mu, sigma2 in outputs_valid],
    axis=1,
)
# Quantiles of the pooled samples, compared against the validation targets.
quartiles_valid = np.quantile(samples_valid, QUARTILES, axis=1)
light_score(validset.quartiles, quartiles_valid)

995.6391278665945

In [13]:
# Score only the first 800 validation entries with the regular-track metric.
# The recorded progress bar shows roughly 2.8 s per item, which is presumably
# why a subset is used rather than the full validation set.
# NOTE(review): the recorded run emitted "Max number of iteration reached"
# warnings, so the score below may be inaccurate.
regular_score(samples_valid[:800], validset.ids[:800])

  0%|          | 0/800 [00:00<?, ?it/s]RESULT MIGHT BE INACURATE
Max number of iteration reached, currently 100000. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher
  check_result(result_code)
  0%|          | 1/800 [00:02<36:41,  2.76s/it]RESULT MIGHT BE INACURATE
Max number of iteration reached, currently 100000. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher
  0%|          | 2/800 [00:05<36:40,  2.76s/it]RESULT MIGHT BE INACURATE
Max number of iteration reached, currently 100000. 

995.2727768500861

## Test set

In [9]:
ids_test = np.arange(800)
spectra_test = read_spectra(ids_test, path="data/test/spectra.hdf5")

# Standardise every row of the second element returned by read_spectra
# (presumably the spectrum values) with that row's own mean and std.
X_test = spectra_test[1]
row_mean = X_test.mean(dim=1, keepdim=True)
row_std = X_test.std(dim=1, keepdim=True)
X_test = (X_test - row_mean) / row_std

# Auxiliary features are standardised with the training set's statistics.
auxiliary_test = standardise(
    read_auxiliary_table(ids_test, path="data/test/auxiliary_table.csv"),
    trainset.auxiliary_train_mean,
    trainset.auxiliary_train_std,
)
X_test.shape, X_test.dtype, auxiliary_test.shape, auxiliary_test.dtype

(torch.Size([800, 52]), torch.float32, torch.Size([800, 9]), torch.float32)

In [10]:
# Run each ensemble member on the test inputs without tracking gradients.
# NOTE(review): this calls the model directly, whereas the validation cell
# uses model.predict; confirm both paths run the model in the same mode.
outputs_test = []
with torch.no_grad():
    for model in models:
        outputs_test.append(model(X_test, auxiliary_test))

In [11]:
# Draw 250 samples from each model's predicted Gaussian (mean, variance) and
# pool the draws of all models along the sample axis (axis 1).
samples_test = np.concatenate(
    [sample_normal(mu, sigma2.sqrt(), T=250) for mu, sigma2 in outputs_test],
    axis=1,
)
# Quantiles of the pooled samples, arranged into the light-track table.
quartiles_test = np.quantile(samples_test, QUARTILES, axis=1)
light_track = light_track_format(quartiles_test)
light_track

Unnamed: 0_level_0,T_q1,T_q2,T_q3,log_H2O_q1,log_H2O_q2,log_H2O_q3,log_CO2_q1,log_CO2_q2,log_CO2_q3,log_CH4_q1,log_CH4_q2,log_CH4_q3,log_CO_q1,log_CO_q2,log_CO_q3,log_NH3_q1,log_NH3_q2,log_NH3_q3
planet_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
0,1081.733307,1094.584018,1108.485903,-10.764056,-9.254254,-7.794949,-4.504924,-4.393083,-4.280006,-11.004909,-9.700189,-8.396319,-7.359435,-5.122463,-2.997618,-6.381561,-6.243831,-6.109036
1,1580.411416,1605.391188,1630.695004,-4.927238,-4.851248,-4.772114,-4.960115,-4.822017,-4.691145,-9.194023,-7.657609,-6.066468,-10.067934,-7.720443,-5.482839,-9.851996,-8.055836,-6.189748
2,4842.623781,4973.844416,5102.586024,-10.731767,-9.218537,-7.770765,-10.073742,-8.800412,-7.979374,-10.339125,-8.645147,-7.149354,-5.123128,-4.722586,-4.311391,-9.409901,-7.484530,-6.413422
3,1971.706366,2008.321169,2044.688204,-3.540650,-3.434773,-3.326299,-10.505138,-8.944953,-7.307727,-10.684162,-8.843752,-7.012174,-4.990651,-3.569078,-2.391445,-10.617332,-8.778087,-6.907526
4,989.822504,1020.574319,1052.276534,-3.808575,-3.633753,-3.462550,-9.318316,-8.142801,-7.036601,-6.873658,-6.500531,-6.164196,-6.720862,-4.992676,-3.471742,-10.543742,-8.656210,-6.837647
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
795,1325.412478,1337.428118,1352.586412,-3.999443,-3.895300,-3.687787,-6.206094,-5.921465,-5.616191,-5.461241,-5.223599,-4.835906,-3.882832,-3.466363,-3.131401,-4.493545,-4.358410,-4.194950
796,591.857213,602.061762,612.226513,-3.603432,-3.446342,-3.272046,-8.023126,-6.397070,-5.646795,-3.884638,-3.729927,-3.554613,-3.613542,-3.052832,-2.597959,-4.408988,-4.287266,-4.152568
797,449.641900,456.222969,462.356654,-4.837315,-4.683305,-4.504382,-4.858388,-4.576393,-4.328714,-8.673032,-7.197262,-6.230657,-4.417719,-3.868509,-3.370168,-8.577981,-6.909549,-6.332987
798,914.108544,933.760304,953.451394,-3.745233,-3.613003,-3.472853,-10.782577,-9.327036,-7.944798,-10.760212,-9.109591,-7.510391,-5.525928,-4.072983,-2.679813,-4.944047,-4.827500,-4.711268


In [12]:
# Pass the pooled test samples to regular_track_format (presumably to produce
# the regular-track submission).
# NOTE(review): whether this writes a file or only returns an object is not
# visible from this notebook.
regular_track_format(samples_test)