# Error correction

In [None]:
# Standard library imports
from pathlib import Path

# Third party imports
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from matplotlib.ticker import FormatStrFormatter

# Project imports
import configuration as config
from ariel_data_preprocessing.data_generator_functions import make_training_datasets

## 2. Initialize data generators

In [None]:
# Sampling parameters handed to the dataset factory
sample_size = 372
n_samples = 10

# HDF5 file holding the processed training data
training_data_file = f'{config.PROCESSED_DATA_DIRECTORY}/train_no_smoothing.h5'

# The factory returns three datasets; only the last one is used in this notebook
_, _, evaluation_dataset = make_training_datasets(
    data_file=training_data_file,
    sample_size=sample_size,
    n_samples=n_samples,
    validation=True,
)

## 3. Create dataset

In [None]:
# Materialise the first 550 evaluation elements as NumPy arrays.
evaluation_data = evaluation_dataset.take(550)

# Collect signals and spectra in a SINGLE pass over the dataset. Iterating
# twice (once per array) re-runs the input pipeline, and if that pipeline
# shuffles or draws random samples on each iteration the signals would no
# longer line up with their spectra.
# NOTE(review): whether make_training_datasets is deterministic across
# iterations is not visible here - the single pass is safe either way.
signal_samples = []
spectrum_samples = []

for element in evaluation_data:
    signal_samples.append(element[0].numpy())
    spectrum_samples.append(element[1].numpy())

signals = np.array(signal_samples)
spectra = np.array(spectrum_samples)

print(f'Signals shape: {signals.shape}')
print(f'Spectra shape: {spectra.shape}')

## 4. Load model

In [None]:
# Load the saved Keras model from the project's models directory
model_file = f'{config.MODELS_DIRECTORY}/ariel-cnn-8.1M-43ksteps-tf2.11.keras'
model = tf.keras.models.load_model(model_file)

## 5. Make predictions

In [None]:
# Predict a spectrum for every sample of every planet, caching the result on
# disk so the (slow) inference only has to run once.
Path(f'{config.EXPERIMENT_RESULTS_DIRECTORY}/error_correction').mkdir(parents=True, exist_ok=True)

# np.save appends '.npy' to any filename that lacks it, so the extension must
# be part of the path - otherwise the is_file() check below looks for a file
# that is never written and the cache is silently never used.
predictions_file = f'{config.EXPERIMENT_RESULTS_DIRECTORY}/error_correction/predictions.npy'

if Path(predictions_file).is_file():
    # NOTE(review): cached predictions are only valid if `signals` holds the
    # same samples as when the cache was written - delete the file after
    # changing the model, the dataset, or the sampling parameters.
    spectrum_predictions = np.load(predictions_file)

else:
    # One predict call per planet: each `planet` holds that planet's samples
    spectrum_predictions = np.array([
        model.predict(planet, batch_size=10, verbose=0)
        for planet in signals
    ])

    np.save(predictions_file, spectrum_predictions)

# Ensemble statistics across the sample axis (axis 1) for each planet
spectrum_predictions_avg = np.mean(spectrum_predictions, axis=1)
spectrum_predictions_std = np.std(spectrum_predictions, axis=1)

# Take the first sample's spectrum as the per-planet reference
# (presumably the true spectrum is identical across a planet's samples - TODO confirm)
reference_spectra = spectra[:,0,:]

print(f'Spectrum predictions shape: {spectrum_predictions.shape}')
print(f'Spectrum predictions avg shape: {spectrum_predictions_avg.shape}')
print(f'Spectrum predictions std shape: {spectrum_predictions_std.shape}')

## 6. Evaluation

In [None]:
# Four-panel diagnostic figure comparing the prediction ensemble to the
# true spectra: predicted vs true, residuals, residual vs ensemble sigma,
# and the distribution of ensemble sigma.

# Flatten everything once so every scatter call receives matching 1-D arrays
true_values = spectra.flatten()
predicted_values = spectrum_predictions.flatten()
reference_values = reference_spectra.flatten()
averaged_values = spectrum_predictions_avg.flatten()
sample_sigmas = spectrum_predictions_std.flatten()

# Residuals are prediction minus truth, for individual samples and for the
# per-planet ensemble average
residuals = predicted_values - true_values
avg_residual = averaged_values - reference_values

fig, axs = plt.subplots(2, 2, figsize=(10,10))
axs = axs.flatten()

fig.suptitle('Spectral prediction evaluation')

# Panel 1: predicted vs true values
axs[0].set_title('Predicted vs true spectral signals')
axs[0].scatter(
    true_values,
    predicted_values,
    s=10,
    alpha=0.5,
    color='black',
    label='Sample predictions'
)

axs[0].scatter(
    reference_values,
    averaged_values,
    s=2.5,
    alpha=0.5,
    color='red',
    label='Averaged prediction'
)

axs[0].set_xlim(0,0.1)
axs[0].set_ylim(0,0.1)
axs[0].set_aspect('equal')
axs[0].set_xlabel('True spectral signal')
axs[0].set_ylabel('Predicted spectral signal')
axs[0].legend(loc='best', markerscale=2)

# Panel 2: residuals as a function of the true value
axs[1].set_title('Prediction residuals')
axs[1].scatter(
    true_values,
    residuals,
    s=10,
    alpha=0.5,
    color='black',
    label='Sample predictions'
)

axs[1].scatter(
    reference_values,
    avg_residual,
    s=2.5,
    alpha=0.5,
    color='red',
    label='Averaged prediction'
)

axs[1].set_xlabel('True spectral signal')
axs[1].set_ylabel('Prediction residuals')

# The series are labelled, so show the legend as in the first panel
axs[1].legend(loc='best', markerscale=2)

# Panel 3: does the ensemble spread track the actual error?
axs[2].set_title('Mean fit residual vs sample sigma')
axs[2].scatter(avg_residual, sample_sigmas, color='black', alpha=0.5, s=2.5)
axs[2].set_xlabel('Fit residual')
axs[2].set_ylabel('Standard deviation')

# Panel 4: distribution of the ensemble spread
axs[3].set_title('Standard deviation of predictions')
axs[3].hist(sample_sigmas, bins=100, color='black')
axs[3].xaxis.set_major_formatter(FormatStrFormatter('%.2e'))
axs[3].tick_params(axis='x', labelrotation=45)
axs[3].set_xlabel('Standard deviation')
axs[3].set_ylabel('Counts')

fig.tight_layout()

# plt.show() is the documented way to display figures; Figure.show() does
# not manage an event loop and warns under non-GUI backends
plt.show()

We still seem to have that weird problem: the true vs. predicted signal looks OK-ish - mostly along the diagonal - but for some reason the predictions appear to be capped at about 0.06, while the true data ranges up to ~0.08.

That aside for the moment, I think the more pressing issue is that our standard deviations are way too tight. Even if our predictions are pretty good (which I doubt, but for the sake of argument...), setting the error too low will tank the score. The scoring function evaluates the log-likelihood of the true value under a Gaussian with our predicted mean and sigma, like this:

```python
scipy.stats.norm.logpdf(true_spectrum, loc=our_predictions, scale=our_errors)
```

Two approaches to do a better job with the errors come to mind:

1. Figure out a way to use the observed signal to get a wavelength & signal dependent error for the instrument
2. Scale the prediction ensemble error until it encompasses the true value

The first approach is probably the more scientifically correct way to attack the problem - but I'm going to go with option 2. Seems like the easiest way to get validated error values.