# Training and validating an RNN + MLP model with the WHXE loss function

## Imports

In [1]:
import time
import pickle
import platform
import os
import imageio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

from random import sample 
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import plot_model
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from tensorflow import keras

from LSTM_model import get_LSTM_Classifier
from dataloader import LSSTSourceDataSet, load, get_augmented_data, get_static_features, ts_length
from loss import WHXE_Loss
from taxonomy import get_taxonomy_tree, get_prediction_probs, get_highest_prob_path, plot_colored_tree
from vizualizations import make_gif, plot_confusion_matrix, plot_roc_curves
from interpret_results import get_conditional_probabilites, save_all_cf_and_rocs

2024-06-12 07:25:29.910227: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-12 07:25:30.646380: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-12 07:25:30.647894: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-12 07:25:30.814408: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-12 07:25:31.139021: I tensorflow/core/platform/cpu_feature_guar

In [2]:
# Log the runtime versions so results can be tied back to an environment.
print("Tensorflow version", tf.__version__, sep=" ")
print("Python version", platform.python_version(), sep=" ")

Tensorflow version 2.15.0
Python version 3.10.12


## Load and balance the tensors:

This step takes a while because it has to load the data from disk into memory...

In [None]:
# Load the pre-processed training tensors, labels and class names (in this order).
# NOTE(review): these are .pkl files; if `load` unpickles them, only use trusted files.
X_ts, X_static, Y, astrophysical_classes = (
    load("processed/train/x_ts.pkl"),
    load("processed/train/x_static.pkl"),
    load("processed/train/y.pkl"),
    load("processed/train/a_labels.pkl"),
)

In [None]:
# Per-class sample counts of the full (not yet balanced) training set.
a, b = np.unique(astrophysical_classes, return_counts=True)
print(f"Total sample count = {b.sum()}")
pd.DataFrame({'Class': a, 'Count': b})

Convert each entry of X_static from a dictionary to a feature array.

In [None]:
# Replace every static-feature dict with the array produced by get_static_features.
# The list is updated in place, so running this cell twice would re-convert the arrays.
n_static = len(X_static)
for idx, static_dict in enumerate(X_static):

    # Progress indicator roughly every 1000 entries (overwrites the same output line).
    if idx % 1000 == 0:
        print(f"{(idx / n_static * 100):.3f} %", end="\r")

    X_static[idx] = get_static_features(static_dict)

Balance the data set by randomly subsampling every class to at most `max_class_count` examples.

In [None]:
# Cap every class at `max_class_count` examples by random subsampling; classes that are
# already smaller are kept whole.
max_class_count = 30000
# Seed the subsampling so the balanced set (and therefore the train/val split) is reproducible.
balance_seed = 40

X_ts_balanced = []
X_static_balanced = []
Y_balanced = []
astrophysical_classes_balanced = []

rng = np.random.default_rng(balance_seed)
# Convert the labels to an array once, instead of once per class inside the loop.
class_labels = np.asarray(astrophysical_classes)

for c in np.unique(class_labels):

    idx = np.flatnonzero(class_labels == c)

    if len(idx) > max_class_count:
        # Uniform sample without replacement, like random.sample, but seeded.
        idx = rng.choice(idx, size=max_class_count, replace=False)

    idx = idx.tolist()
    X_ts_balanced += [X_ts[i] for i in idx]
    X_static_balanced += [X_static[i] for i in idx]
    Y_balanced += [Y[i] for i in idx]
    astrophysical_classes_balanced += [astrophysical_classes[i] for i in idx]

# np.asarray may have copied the label list; drop it so freeing the originals later
# actually releases that memory.
del class_labels

# Print summary of the data set used for training and validation
a, b = np.unique(astrophysical_classes_balanced, return_counts=True)
data_summary = pd.DataFrame(data = {'Class': a, 'Count': b})
data_summary

Free up some memory

In [None]:
# Drop the unbalanced originals; only the *_balanced lists are used from here on.
del X_ts
del X_static
del Y
del astrophysical_classes

Split into train and validation

In [None]:
# Hold out 5% of the balanced data for validation (shuffled, fixed random_state).
val_fraction = 0.05
(
    X_ts_train, X_ts_val,
    X_static_train, X_static_val,
    Y_train, Y_val,
    astrophysical_classes_train, astrophysical_classes_val,
) = train_test_split(
    X_ts_balanced,
    X_static_balanced,
    Y_balanced,
    astrophysical_classes_balanced,
    test_size=val_fraction,
    shuffle=True,
    random_state=40,
)

Free up some more memory

In [None]:
# The train/val lists now hold everything needed; release the combined balanced lists.
del X_ts_balanced
del X_static_balanced
del Y_balanced
del astrophysical_classes_balanced

Check the make-up of the validation data (**TODO**: no cell currently performs this check).

## Define the Loss function, criterion, etc.

In [None]:
# Loss: WHXE over the taxonomy tree, constructed from the training-split class labels
# with alpha=0. criterion is the bound compute_loss method, called as criterion(y, logits).
tree = get_taxonomy_tree()
loss_object = WHXE_Loss(tree, astrophysical_classes_train, alpha=0)
criterion = loss_object.compute_loss

## Train the classifier using WHXE loss and save the model

In [None]:
# Adam with a fixed learning rate of 5e-4.
optimizer = keras.optimizers.Adam(learning_rate=0.0005)

In [None]:
# Inputs for model
ts_dim = 5  # NOTE(review): presumably features per time step — confirm against the dataloader
static_dim = len(X_static_train[0])  # length of one converted static-feature vector
output_dim = 26  # NOTE(review): presumably the number of taxonomy-tree nodes — confirm against get_taxonomy_tree()
latent_size = 64

# Training schedule
num_epochs, batch_size = 100, 2048

In [None]:
model = get_LSTM_Classifier(ts_dim, static_dim, output_dim, latent_size)

# Write an architecture diagram (with tensor shapes and layer names) to lstm.pdf.
keras.utils.plot_model(
    model,
    to_file='lstm.pdf',
    show_shapes=True,
    show_layer_names=True,
)
plt.close()

In [None]:
@tf.function
def train_step(x_ts, x_static, y):
    """Run one optimisation step on a single batch.

    Uses the notebook-level ``model``, ``criterion`` and ``optimizer``.

    Args:
        x_ts: batch of time-series inputs.
        x_static: batch of static-feature inputs.
        y: batch of targets passed to ``criterion``.

    Returns:
        The loss value computed by ``criterion`` for this batch.
    """
    with tf.GradientTape() as tape:
        predictions = model((x_ts, x_static), training=True)
        batch_loss = criterion(y, predictions)

    # Read the weights after the forward pass so a lazily built model has created them.
    variables = model.trainable_weights
    gradients = tape.gradient(batch_loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return batch_loss

In [None]:
# Train for `num_epochs`, regenerating the length-augmented training set every epoch and
# checkpointing each epoch plus the epoch with the lowest average training loss.
model_dir = "models/lsst_rate_agnostic"
os.makedirs(model_dir, exist_ok=True)  # model.save fails if the directory does not exist

avg_train_losses = []
for epoch in range(num_epochs):

    print(f"\nStart of epoch {epoch}")
    start_time = time.time()

    print("Augmenting time series lengths...")

    # Create the augmented data set for training. train_step does not use the class
    # labels, so they are kept out of the tf.data pipeline.
    X_ts_train_aug, X_static_train_aug, Y_train_aug, astrophysical_classes_train_aug = get_augmented_data(X_ts_train, X_static_train, Y_train, astrophysical_classes_train)
    train_dataset = tf.data.Dataset.from_tensor_slices((X_ts_train_aug, X_static_train_aug, Y_train_aug)).batch(batch_size)

    # Per-batch training losses for this epoch
    train_loss_values = []

    # Iterate over the batches of the dataset.
    for x_ts_batch_train, x_static_batch_train, y_batch_train in train_dataset:
        loss_value = train_step(x_ts_batch_train, x_static_batch_train, y_batch_train)
        # reduce_mean leaves a scalar loss unchanged and averages a per-sample loss vector.
        train_loss_values.append(float(tf.reduce_mean(loss_value)))

    if not train_loss_values:
        raise ValueError("The augmented training set produced no batches.")

    # Average over every batch in the epoch, not just the final batch.
    avg_train_loss = float(np.mean(train_loss_values))
    avg_train_losses.append(avg_train_loss)
    print(f"Avg training loss: {avg_train_loss:.4f}")

    print(f"Time taken: {time.time() - start_time:.2f}s")
    model.save(f"{model_dir}/lstm_epoch_{epoch}.h5")

    # Keep best_model.h5 equal to the epoch with the smallest average training loss
    # (ties keep the earlier epoch, as np.argmin returns the first minimum).
    if int(np.argmin(avg_train_losses)) == epoch:
        model.save(f"{model_dir}/best_model.h5")

## Load the saved model and validate that everything looks okay

In [12]:
# Reload the checkpoint chosen by the training loop; compile=False since only inference follows.
best_model = keras.models.load_model(
    "models/lsst_rate_agnostic/best_model.h5",
    compile=False,
)

In [13]:
# Fractions of each light curve to evaluate on: 0.1, 0.2, ..., 0.9, then the full curve (int 1).
fractions = [k / 10 for k in range(1, 10)] + [1]

In [49]:
# Evaluate the best model on increasingly long prefixes (fractions) of each validation light curve.
for f in fractions:

    # round(), not int(): float error would truncate e.g. int(0.29 * 100) to 28.
    percent = round(f * 100)
    print(f'Running inference for {percent}% light curves...')

    x1, x2, y_true, _ = get_augmented_data(X_ts_val, X_static_val, Y_val, astrophysical_classes_val, fraction=f)

    # Run inference on these
    y_pred = best_model.predict([x1, x2])

    # Get the conditional probabilities
    _, pseudo_conditional_probabilities = get_conditional_probabilites(y_pred, tree)

    print(f'For {percent}% of the light curve, these are the statistics:')

    # Print all the stats and make plots...
    save_all_cf_and_rocs(y_true, pseudo_conditional_probabilities, tree, f)
    # NOTE(review): save_leaf_cf_and_rocs is not imported in this notebook's imports cell;
    # on a fresh kernel this line raises NameError unless it is imported first
    # (presumably from interpret_results — confirm).
    save_leaf_cf_and_rocs(y_true, pseudo_conditional_probabilities, tree, f)

    # The helpers above may open several figures per fraction; close all of them, not just
    # the current one, so figures do not accumulate across iterations.
    plt.close('all')

Running inference for 10% light curves...
For 10% of the light curve, these are the statistics:
              precision    recall  f1-score   support

   Transient       0.88      1.00      0.94     14195
    Variable       0.98      0.65      0.78      5440

    accuracy                           0.90     19635
   macro avg       0.93      0.82      0.86     19635
weighted avg       0.91      0.90      0.89     19635

              precision    recall  f1-score   support

         AGN       1.00      0.27      0.43      1563
        Fast       0.60      0.85      0.70      1576
        Long       0.65      0.59      0.62      5278
    Periodic       0.98      0.81      0.89      3877
          SN       0.73      0.90      0.81      7341

    accuracy                           0.74     19635
   macro avg       0.79      0.68      0.69     19635
weighted avg       0.77      0.74      0.73     19635

               precision    recall  f1-score   support

          AGN       1.00      0.

## Make animations of the confusion matrices and ROC curves across light-curve fractions

In [17]:
# Stitch the per-fraction leaf-level cf frames into a single animation.
cf_files = [f"gif/leaf_cf/{frac}.png" for frac in fractions]
make_gif(cf_files, 'gif/leaf_cf/leaf_cf.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [18]:
# Stitch the per-fraction leaf-level ROC frames into a single animation.
roc_files = [f"gif/leaf_roc/{frac}.png" for frac in fractions]
make_gif(roc_files, 'gif/leaf_roc/leaf_roc.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [19]:
# Stitch the per-fraction level-1 cf frames into a single animation.
cf_files = [f"gif/level_1_cf/{frac}.png" for frac in fractions]
make_gif(cf_files, 'gif/level_1_cf/level_1_cf.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [20]:
# Stitch the per-fraction level-1 ROC frames into a single animation.
roc_files = [f"gif/level_1_roc/{frac}.png" for frac in fractions]
make_gif(roc_files, 'gif/level_1_roc/level_1_roc.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [21]:
# Stitch the per-fraction level-2 cf frames into a single animation.
cf_files = [f"gif/level_2_cf/{frac}.png" for frac in fractions]
make_gif(cf_files, 'gif/level_2_cf/level_2_cf.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [22]:
# Stitch the per-fraction level-2 ROC frames into a single animation.
roc_files = [f"gif/level_2_roc/{frac}.png" for frac in fractions]
make_gif(roc_files, 'gif/level_2_roc/level_2_roc.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


For the love of everything that is good in this world, please use a different notebook for testing and generating statistics. Keep this notebook simple enough to be converted into a script.