# Inference ensemble

## Notebook set-up

In [None]:
# Set notebook root to project root
from helper_functions import set_project_root

# Silence tensorflow, except for errors
# NOTE(review): this and the CUDA setting below are assigned before
# `import tensorflow` further down this cell - presumably so they are already
# in place when TensorFlow/CUDA initialise. Keep this ordering.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Run on the GTX1080 GPU - fastest single worker/small memory performance
# Only the CUDA device with index 0 is made visible to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Helper imported from helper_functions at the top of this cell; according to
# the comment on that import it sets the notebook root to the project root
# (implementation not visible here).
set_project_root()

# Standard library imports
import pickle
import random
import time
from functools import partial
from pathlib import Path

# Third party imports
import h5py
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import optuna
import tensorflow as tf
from scipy.interpolate import griddata

# Local imports
import configuration as config

# Sampling configuration used by the prediction dataset cells below
planets = 550      # How many planets to evaluate
samples = 10       # How many random samples are drawn for each planet
sample_size = 883  # How many time points make up one sample

# Output location for figures; create it (and any missing parents) up front
figures_dir = f'{config.FIGURES_DIRECTORY}/model_training'
Path(figures_dir).mkdir(exist_ok=True, parents=True)

## 1. Load trained model

In [None]:
# Path of the saved Keras model inside the configured models directory
model_save_file = f'{config.MODELS_DIRECTORY}/optimized_cnn-13ksteps.keras'

# Restore the trained model from disk
model = tf.keras.models.load_model(model_save_file)

# Print the architecture as a quick check that the expected model was loaded
model.summary()

## 2. Data preparation

### 2.1. Training/validation split

In [None]:
planet_ids_file = f'{config.METADATA_DIRECTORY}/planet_ids.pkl'

# Fail fast with a clear message: the prediction dataset cell further down
# needs `validation_planet_ids`, and silently skipping the load would only
# surface later as a NameError.
if not Path(planet_ids_file).exists():
    raise FileNotFoundError(
        f'Training/validation split not found: {planet_ids_file}'
    )

# NOTE: unpickling can execute arbitrary code - only load a split file that
# was produced by this project.
with open(planet_ids_file, 'rb') as input_file:
    planet_ids = pickle.load(input_file)

# The pickle holds a dictionary with one collection of planet IDs per split
training_planet_ids = planet_ids['training']
validation_planet_ids = planet_ids['validation']

print('Loaded existing training/validation split')

### 2.2. Prediction dataset

In [None]:
def prediction_data_loader(planet_ids: list, data_file: str, sample_size: int = 100, n_samples: int = 10):
    '''Generator that endlessly yields one (signals, spectra) pair per planet.

    The planet IDs are cycled through in order, forever, so the consumer is
    responsible for stopping (e.g. by taking a fixed number of elements).
    For every planet, ``n_samples`` random subsets of ``sample_size`` frames
    are drawn without replacement from the planet's signal; the frames in each
    subset keep their original order.

    Args:
        planet_ids (list): List of planet IDs to include in the generator.
        data_file (str): Path to the HDF5 file containing the data. Each planet ID
            must be a group holding a 'signal' and a 'spectrum' dataset.
        sample_size (int, optional): Number of frames to draw from each planet
            per sample. Defaults to 100.
        n_samples (int, optional): Number of samples to draw from each planet.
            Defaults to 10.

    Yields:
        tuple: ``(signals, spectra)`` where ``signals`` stacks the ``n_samples``
            frame subsets along a new first axis and ``spectra`` repeats the
            planet's spectrum ``n_samples`` times along a new first axis.

    Raises:
        ValueError: If ``planet_ids`` is empty, or if ``sample_size`` is larger
            than the number of frames available for a planet.
    '''

    # Without this guard the `while True` loop below would spin forever
    # without ever yielding anything.
    if len(planet_ids) == 0:
        raise ValueError('planet_ids is empty: nothing to generate')

    with h5py.File(data_file, 'r') as hdf:

        while True:

            for planet_id in planet_ids:

                planet_group = hdf[planet_id]
                signal = planet_group['signal'][:]

                # The spectrum is the same for every sample of a planet, so
                # read it from disk once instead of once per sample.
                spectrum = planet_group['spectrum'][:]

                n_frames = signal.shape[0]

                if sample_size > n_frames:
                    raise ValueError(
                        f'sample_size ({sample_size}) exceeds the {n_frames} '
                        f'frames available for planet {planet_id}'
                    )

                samples = []

                for _ in range(n_samples):

                    # Draw frames without replacement; sorting the indices
                    # keeps the selected frames in their original order.
                    indices = sorted(random.sample(range(n_frames), sample_size))
                    samples.append(signal[indices, :])

                yield np.array(samples), np.array([spectrum] * n_samples)


In [None]:
# Bind every argument of the loader so the result is a zero-argument callable
# that can be handed to a consumer expecting a generator function.
prediction_data_generator = partial(
    prediction_data_loader,
    data_file=f'{config.PROCESSED_DATA_DIRECTORY}/train.h5',
    planet_ids=validation_planet_ids,  # Evaluate on the validation planets
    n_samples=samples,
    sample_size=sample_size
)

In [None]:
# Wrap the generator in a tf.data pipeline. Each element is one planet:
# a stack of sampled signals plus the matching stack of spectra.
prediction_dataset = tf.data.Dataset.from_generator(
    generator=prediction_data_generator,
    output_signature=(
        # Signals: one (sample_size, wavelengths) block per drawn sample
        tf.TensorSpec(dtype=tf.float64, shape=(samples, sample_size, config.WAVELENGTHS)),
        # Spectra: one wavelength vector per drawn sample
        tf.TensorSpec(dtype=tf.float64, shape=(samples, config.WAVELENGTHS))
    )
)

In [None]:
# Limit the dataset to the first `planets` elements (one element per planet)
validation_data = prediction_dataset.take(planets)

# Collect signals and spectra in a single pass over the dataset. Every
# iteration of the dataset re-runs the generator (re-reading the data file and
# drawing new random samples), so iterating once per array would double the
# work and take the two arrays from different passes.
signal_batches = []
spectrum_batches = []

for signal_batch, spectrum_batch in validation_data:
    signal_batches.append(signal_batch.numpy())
    spectrum_batches.append(spectrum_batch.numpy())

signals = np.array(signal_batches)
spectra = np.array(spectrum_batches)

print(f'Signals shape: {signals.shape}')
print(f'Spectra shape: {spectra.shape}')