# CNN model

## Notebook set-up

In [None]:
# Set notebook root to project root
from helper_functions import set_project_root

# Silence tensorflow, except for errors
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Run on the GTX1080 GPU - fastest single worker/small memory performance
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

set_project_root()

# Standard library imports
import pickle
import time
from pathlib import Path

# Third party imports
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import optuna
import tensorflow as tf
from scipy.interpolate import griddata

# Local imports
from ariel_data_preprocessing.data_preprocessing import DataProcessor
import configuration as config

# Output locations: create the figures and models directories up front so
# later cells can save into them without checking
figures_dir = f'{config.FIGURES_DIRECTORY}/model_training'

for output_directory in (figures_dir, config.MODELS_DIRECTORY):
    Path(output_directory).mkdir(parents=True, exist_ok=True)

# Best settings from ~400 Optuna optimization trials
# (see model_training/optimize_cnn.py)

# Data and training schedule
sample_size = 372
batch_size = 4
steps = 431

# Regularization strengths passed to the dense layer's L1L2 regularizer
l_one = 0.9381346432258663
l_two = 0.36282682418942663

# Network architecture
cnn_layers = 3
first_filter_set = 73
second_filter_set = 34
third_filter_set = 48
first_filter_size = 2
second_filter_size = 5
third_filter_size = 3
dense_units = 104

# Adam optimizer settings
learning_rate = 0.0007103203515277739
beta_one = 0.72
beta_two = 0.93
amsgrad = True
weight_decay = 0.016
use_ema = True

# Long training run
epochs = 1000

# Evaluation settings
samples = 10   # Number of samples to draw per planet
planets = 550  # Number of planets to evaluate

# File names: the model and its training history share one stem that
# encodes the total number of training steps (in thousands)
total_ksteps = int(epochs * steps / 1000)
model_file_stem = f'{config.MODELS_DIRECTORY}/ariel-cnn-8.4M-{total_ksteps}ksteps'
model_save_file = f'{model_file_stem}.keras'
training_results_save_file = f'{model_file_stem}.pkl'

## 1. Hyperparameter optimization results

### 1.1. Load Optuna study results

In [None]:
# Connect to the Optuna storage backend and pull the finished study
# NOTE(review): config.STUDY_NAME is used as the database name in the URL
# while the study name itself is hard-coded - confirm this is intended
storage_url = (
    f'postgresql://{config.USER}:{config.PASSWD}'
    f'@{config.HOST}:{config.PORT}/{config.STUDY_NAME}'
)

loaded_study = optuna.load_study(study_name='cnn_optimization', storage=storage_url)

# One row per trial; keep only completed trials whose loss is below 0.02
# (drops extreme high-loss outliers), ordered from best to worst
all_trials = loaded_study.trials_dataframe()
is_complete = all_trials['state'] == 'COMPLETE'
is_low_loss = all_trials['value'] < 0.02
results_df = all_trials[is_complete & is_low_loss].sort_values('value', ascending=True)

# Hyperparameter columns used by the plots below
param_columns = [
    'params_batch_size',
    'params_cnn_layers',
    'params_dense_layers',
    'params_first_dense_units',
    'params_first_filter_set',
    'params_first_filter_size',
    'params_l_one',
    'params_l_two',
    'params_learning_rate',
    'params_sample_size',
    'params_second_dense_units',
    'params_second_filter_set',
    'params_second_filter_size',
    'params_steps',
    'params_third_dense_units',
    'params_third_filter_set',
    'params_third_filter_size',
]

results_df.head()

In [None]:
# Print the column dtypes and non-null counts of the filtered trials table
results_df.info()

In [None]:
# Summary statistics, one row per column for easier reading
results_df.describe().T

### 1.2. Validation loss distribution

In [None]:
# Histogram of the final validation RMSE across the retained trials
hist_figsize = (config.STD_FIG_WIDTH/1.25, config.STD_FIG_WIDTH/1.5)

plt.figure(figsize=hist_figsize)
plt.title('CNN hyperparameter optimization\nvalidation RMSE distribution')
plt.hist(results_df['value'], bins=30, color='black', edgecolor='black')
plt.xlabel('Validation RMSE')
plt.ylabel('Number of Trials')
# plt.yscale('log')
plt.tight_layout()

# Write the figure to the model training figures directory
rmse_distribution_file = f'{figures_dir}/03.2.1-validation_RMSE_distribution.jpg'
plt.savefig(rmse_distribution_file, dpi=config.STD_FIG_DPI, bbox_inches='tight')


### 1.3. Hyperparameter sampling distributions

In [None]:
# One histogram panel per hyperparameter on a 3 x 6 grid
fig, axs = plt.subplots(3, 6, figsize=(16, 6))
axs = axs.flatten()

fig.suptitle('CNN hyperparameter distributions')
fig.supxlabel('Hyperparameter value')
fig.supylabel('Number of trials')

for panel_index, column_name in enumerate(param_columns):
    panel = axs[panel_index]
    panel.set_title(column_name.replace('params_', ''))
    panel.hist(results_df[column_name], bins=20, color='black')

# Hide the final panel of the grid
axs[-1].axis('off')

fig.tight_layout()

# Write the figure to the model training figures directory
distributions_file = f'{figures_dir}/03.2.2-hyperparameter_distributions.jpg'
plt.savefig(distributions_file, dpi=config.STD_FIG_DPI, bbox_inches='tight')

### 1.4. Hyperparameter heatmap matrix

In [None]:
# Square grid of panels: one row and one column per hyperparameter.
# Diagonal panels show the sampling histogram of a single hyperparameter;
# off-diagonal panels show validation loss interpolated over a pair.
n_params = len(param_columns)
fig, axes = plt.subplots(n_params, n_params, figsize=(10, 9))

# Global validation loss range - used only by the shared colorbar
vmin, vmax = results_df['value'].min(), results_df['value'].max()

# The interpolated quantity is the same for every panel
z_data = results_df['value'].values

for i, param_y in enumerate(param_columns):
    for j, param_x in enumerate(param_columns):
        ax = axes[i, j]

        if i == j:

            # Diagonal: distribution of the sampled values
            ax.hist(
                results_df[param_x],
                bins=20,
                alpha=0.7,
                color='lightblue',
                edgecolor='black'
            )

        else:

            # Off-diagonal: scattered (x, y, loss) points to interpolate
            x_data = results_df[param_x].values
            y_data = results_df[param_y].values

            x_min, x_max = x_data.min(), x_data.max()
            y_min, y_max = y_data.min(), y_data.max()

            # Pad the sampled range by 5% on each side
            x_margin = (x_max - x_min) * 0.05
            y_margin = (y_max - y_min) * 0.05

            # Regular 50 x 50 grid covering the padded range
            x_grid = np.linspace(x_min - x_margin, x_max + x_margin, 50)
            y_grid = np.linspace(y_min - y_margin, y_max + y_margin, 50)
            X_grid, Y_grid = np.meshgrid(x_grid, y_grid)

            # Linear interpolation; grid points that cannot be
            # interpolated are filled with NaN
            Z_interp = griddata(
                (x_data, y_data),
                z_data,
                (X_grid, Y_grid),
                method='linear',
                fill_value=np.nan
            )

            # Each panel is colour-scaled to its own interpolated range
            z_min_local = np.nanmin(Z_interp)
            z_max_local = np.nanmax(Z_interp)

            im = ax.imshow(
                Z_interp,
                origin='lower',
                aspect='auto',
                cmap='viridis',
                vmin=z_min_local,
                vmax=z_max_local,
            )

        # Parameter names only on the outer edge of the matrix:
        # first column gets the row name, bottom row gets the column name
        if j == 0:
            ax.set_ylabel(param_y.replace('params_', ''), rotation='horizontal', ha='right')

        if i == n_params - 1:
            ax.set_xlabel(param_x.replace('params_', ''), rotation='vertical')

        # No tick marks or tick labels on any panel
        ax.tick_params(
            labelbottom=False, labelleft=False,
            bottom=False, left=False, top=False, right=False
        )

# Single colorbar on the left of the figure showing the global loss range
# Note: Individual heatmaps are standardized to their own ranges
cbar_ax = fig.add_axes([0.02, 0.08, 0.05, 0.87])  # [left, bottom, width, height]

# The colorbar is driven by a stand-alone mappable, not by any one panel
norm = plt.Normalize(vmin=vmin, vmax=vmax)
sm = cm.ScalarMappable(norm=norm, cmap='viridis')
sm.set_array([])

cbar = fig.colorbar(sm, cax=cbar_ax)
cbar.ax.tick_params(labelsize=9)

# Leave room on the left for the colorbar and the row labels
plt.subplots_adjust(left=0.27, right=0.95, top=0.95, bottom=0.08, hspace=0.05, wspace=0.05)

fig.suptitle('Validation RMSE Heatmap Matrix')

# Write the figure to the model training figures directory
heatmap_file = f'{figures_dir}/03.2.3-hyperparameter_validation_RMSE_heatmaps.jpg'
plt.savefig(heatmap_file, dpi=config.STD_FIG_DPI, bbox_inches='tight')

plt.show()

## 2. Long training run

### 2.1. Initialize data generators

In [None]:
# Data processor in training mode, reading raw data and writing processed data
# to the directories named in the project configuration
data_preprocessor = DataProcessor(
    input_data_path=config.RAW_DATA_DIRECTORY,
    output_data_path=config.PROCESSED_DATA_DIRECTORY,
    mode='train',
)

# Generator settings: the Optuna-selected sample size, the per-planet sample
# count used for evaluation, and validation enabled
generator_settings = {
    'sample_size': sample_size,
    'n_samples': samples,
    'validation': True,
}

data_preprocessor.initialize_data_generators(**generator_settings)

### 2.2. CNN

#### 2.2.1. Model definition

In [None]:
def compile_model(
        samples: int=sample_size,
        wavelengths: int=config.WAVELENGTHS,
        learning_rate: float=learning_rate,
        l1: float=l_one,
        l2: float=l_two,
        first_filter_set: int=first_filter_set,
        second_filter_set: int=second_filter_set,
        third_filter_set: int=third_filter_set,
        first_filter_size: int=first_filter_size,
        second_filter_size: int=second_filter_size,
        third_filter_size: int=third_filter_size,
        dense_units: int=dense_units,
        beta_one: float=beta_one,
        beta_two: float=beta_two,
        amsgrad: bool=amsgrad,
        weight_decay: float=weight_decay,
        use_ema: bool=use_ema
) -> tf.keras.Model:

    '''Builds and compiles the convolutional neural network regression model.

    The network takes inputs of shape (samples, wavelengths, 1), passes them
    through three Conv2D + MaxPooling2D stages, flattens the result, applies
    one L1L2-regularized dense layer and finishes with a linear dense layer
    of `wavelengths` units.

    Args:
        samples: size of the first input axis
        wavelengths: size of the second input axis and of the output layer
        learning_rate: Adam learning rate
        l1: L1 penalty for the dense layer kernel
        l2: L2 penalty for the dense layer kernel
        first_filter_set: number of filters in the first convolution
        second_filter_set: number of filters in the second convolution
        third_filter_set: number of filters in the third convolution
        first_filter_size: kernel size of the first convolution
        second_filter_size: kernel size of the second convolution
        third_filter_size: kernel size of the third convolution
        dense_units: units in the regularized dense layer
        beta_one: Adam beta_1
        beta_two: Adam beta_2
        amsgrad: passed to Adam as amsgrad
        weight_decay: passed to Adam as weight_decay
        use_ema: passed to Adam as use_ema

    Returns:
        The compiled model, with mean squared error loss and a root mean
        squared error metric named 'RMSE'.
    '''

    # L1L2 penalty, applied to the dense layer kernel only
    regularizer = tf.keras.regularizers.L1L2(l1=l1, l2=l2)

    # (filters, kernel size) for each convolution stage, in network order
    conv_stages = (
        (first_filter_set, first_filter_size),
        (second_filter_set, second_filter_size),
        (third_filter_set, third_filter_size),
    )

    # Assemble the layer stack: input, then conv + pool per stage
    layer_stack = [tf.keras.layers.Input((samples, wavelengths, 1))]

    for n_filters, kernel_size in conv_stages:
        layer_stack.append(
            tf.keras.layers.Conv2D(
                n_filters,
                kernel_size,
                padding='same',
                activation='relu',
            )
        )
        layer_stack.append(tf.keras.layers.MaxPooling2D())

    # Regression head: flatten, regularized dense layer, linear output
    layer_stack.append(tf.keras.layers.Flatten())
    layer_stack.append(
        tf.keras.layers.Dense(
            dense_units,
            kernel_regularizer=regularizer,
            activation='relu',
        )
    )
    layer_stack.append(tf.keras.layers.Dense(wavelengths, activation='linear'))

    model = tf.keras.Sequential(layer_stack)

    # Adam optimizer configured from the arguments
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=learning_rate,
        beta_1=beta_one,
        beta_2=beta_two,
        amsgrad=amsgrad,
        weight_decay=weight_decay,
        use_ema=use_ema
    )

    # Train on mean squared error and also report root mean squared error
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.MeanSquaredError(name='MSE'),
        metrics=[tf.keras.metrics.RootMeanSquaredError(name='RMSE')]
    )

    return model

In [None]:
# Build the network with its default arguments and print the layer summary
model = compile_model()
model.summary()

#### 2.2.2. Training

In [None]:
# Train once per step budget: if both the saved model and its training
# results are already on disk, reuse them instead of retraining
model_path = Path(model_save_file)
results_path = Path(training_results_save_file)

if model_path.exists() and results_path.exists():

    print(f'Found existing model for {total_ksteps} ksteps, skipping training.')

    # Replace the freshly compiled model with the saved one
    model = tf.keras.models.load_model(model_save_file)

    # Restore the training results written by the earlier run
    with open(training_results_save_file, 'rb') as input_file:
        training_results = pickle.load(input_file)

else:

    print(f'Training model for {total_ksteps} ksteps')
    start_time = time.time()

    # Both datasets are batched here; validation runs the same number of
    # steps as a training epoch
    training_results = model.fit(
        data_preprocessor.training.batch(batch_size),
        validation_data=data_preprocessor.validation.batch(batch_size),
        epochs=epochs,
        steps_per_epoch=steps,
        validation_steps=steps,
        verbose=1
    )

    print(f'Training complete in {(time.time() - start_time)/60:.1f} minutes')

    # Persist the model and the object returned by fit() so the next run
    # of this cell can skip training
    model.save(model_save_file)

    with open(training_results_save_file, 'wb') as output_file:
        pickle.dump(training_results, output_file)

In [None]:
# Set-up a 1x2 figure: mean squared error loss on the left,
# root mean squared error on the right
fig, axs=plt.subplots(1,2, figsize=(12,4))

# Add the main title
fig.suptitle('CNN training curves', size='large')

# Plot training and validation loss
axs[0].set_title('Training loss (mean squared error)')
axs[0].plot(np.array(training_results.history['loss']), alpha=0.5, label='Training')
axs[0].plot(np.array(training_results.history['val_loss']), alpha=0.5, label='Validation')
axs[0].set_xlabel('Epoch')
axs[0].set_ylabel('loss')
# axs[0].set_ylim(21, 25)
# axs[0].set_yscale('log')
axs[0].legend(loc='upper right')

# Plot training and validation RMSE on a log scale
axs[1].set_title('Root mean squared error')
axs[1].plot(training_results.history['RMSE'], alpha=0.5, label='Training')
axs[1].plot(training_results.history['val_RMSE'], alpha=0.5, label='Validation')
axs[1].set_xlabel('Epoch')
axs[1].set_ylabel('RMSE')
# axs[1].set_ylim(top=0.014)
axs[1].set_yscale('log')

# The series above are labelled, so show the legend on this panel too
# (matches the loss panel)
axs[1].legend(loc='upper right')

# Tighten the layout and save the figure
fig.tight_layout()
fig.savefig(
    f'{figures_dir}/03.2.4-ariel_cnn_training_curves_8.4M-{total_ksteps}ksteps.jpg',
    dpi=config.STD_FIG_DPI,
    bbox_inches='tight'
)

## 3. Model evaluation (validation set)

### 3.1. Evaluation dataset

In [None]:
# Take the first `planets` elements of the evaluation dataset
evaluation_data = data_preprocessor.evaluation.take(planets)

# Collect signals and spectra in a SINGLE pass over the dataset. Iterating
# it once per array runs the input pipeline twice, and the two passes are
# only guaranteed to line up if the pipeline yields identical elements in
# identical order every time it is iterated.
signal_arrays = []
spectrum_arrays = []

for element in evaluation_data:
    signal_arrays.append(element[0].numpy())
    spectrum_arrays.append(element[1].numpy())

signals = np.array(signal_arrays)
spectra = np.array(spectrum_arrays)

print(f'Signals shape: {signals.shape}')
print(f'Spectra shape: {spectra.shape}')

### 3.2. Predictions

In [None]:
# Predict every sample of every planet with ONE call to model.predict.
# predict() is built for large batched inputs and carries per-call overhead,
# so calling it once per planet inside a Python loop is slow.
#
# `signals` has the planet axis first and the per-planet sample axis second;
# the two are merged into a single batch axis for the model and split apart
# again afterwards. Each batch of `samples` rows still holds one planet's
# samples whenever every planet contributes `samples` samples.
n_planets, n_draws = signals.shape[:2]

stacked_signals = signals.reshape((n_planets * n_draws,) + signals.shape[2:])
stacked_predictions = model.predict(stacked_signals, batch_size=samples, verbose=0)

# Back to (planet, sample, ...) so predictions can be averaged per planet
spectrum_predictions = stacked_predictions.reshape(
    (n_planets, n_draws) + stacked_predictions.shape[1:]
)

# Per-planet mean over the sample axis, then flatten everything to 1D
# for the scatter plot
spectrum_predictions_avg = np.mean(spectrum_predictions, axis=1).flatten()
spectrum_predictions = spectrum_predictions.flatten()

# The first sample's spectrum of each planet is the reference for the
# averaged predictions
# NOTE: `spectra` is overwritten with its flattened form below, so this cell
# cannot be re-run without first re-running the evaluation dataset cell
reference_spectra = spectra[:,0,:].flatten()
spectra = spectra.flatten()

### 3.3. Plot

In [None]:
# Scatter of predicted against true values: every individual sample
# prediction in black, with the per-planet averaged prediction over it in red
plt.title('Predicted vs true spectral signals')

# (true values, predicted values, marker size, color, legend label)
scatter_layers = (
    (spectra, spectrum_predictions, 10, 'black', 'Sample predictions'),
    (reference_spectra, spectrum_predictions_avg, 2.5, 'red', 'Averaged prediction'),
)

for true_values, predicted_values, marker_size, marker_color, layer_label in scatter_layers:
    plt.scatter(
        true_values,
        predicted_values,
        s=marker_size,
        alpha=0.5,
        color=marker_color,
        label=layer_label
    )

plt.xlabel('True spectral signal')
plt.ylabel('Predicted spectral signal')
plt.legend(loc='best', markerscale=2)

# Write the figure to the model training figures directory
scatter_file = f'{figures_dir}/03.2.5-ariel_cnn_predicted_vs_true_spectra-8.4M-{total_ksteps}ksteps.jpg'
plt.savefig(scatter_file, dpi=config.STD_FIG_DPI, bbox_inches='tight')

plt.show()