# Training and validating an LSTM + MLP model with the WHXE loss function

## Imports

In [58]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [79]:
import pickle
import os
import imageio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from LSTM_model import get_LSTM_Classifier
from dataloader import LSSTSourceDataSet, load, get_augmented_data, get_static_features, ts_length
from loss import WHXE_Loss
from taxonomy import get_taxonomy_tree, get_prediction_probs, get_highest_prob_path, plot_colored_tree
from vizualizations import make_gif, plot_confusion_matrix, plot_roc_curves

from argparse import ArgumentParser
from pathlib import Path
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import plot_model
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from tensorflow import keras

## Load and balance the tensors:

This step takes a while because it has to load the data from disk into memory...

In [60]:
# Load the pre-processed training tensors and labels from disk.
# NOTE(review): these are .pkl files; if `load` unpickles them, only use trusted inputs.
train_dir = "processed/train"
X_ts = load(f"{train_dir}/x_ts.pkl")
X_static = load(f"{train_dir}/x_static.pkl")
Y = load(f"{train_dir}/y.pkl")
lengths = load(f"{train_dir}/lengths.pkl")
astrophysical_classes = load(f"{train_dir}/a_labels.pkl")
elasticc_classes = load(f"{train_dir}/e_label.pkl")

A small step to convert each entry of X_static from a dictionary to an array of the selected static features.

In [61]:
# Replace every per-sample entry of X_static (in place) with only the requested
# static features, as returned by get_static_features.
static_list = ['MWEBV', 'MWEBV_ERR']
n_static = len(X_static)
for j in range(n_static):

    # Lightweight progress indicator, refreshed every 1000 samples on one line.
    if not j % 1000:
        print(f"{(j/n_static * 100):.3f} %", end="\r")

    X_static[j] = get_static_features(X_static[j], feature_list=static_list)

99.943 %

Balance the data set by capping the number of samples kept per astrophysical class (at most `max_class_count`).

In [62]:
# Cap on the number of samples kept per astrophysical class.
max_class_count = 20000

X_ts_balanced = []
X_static_balanced = []
Y_balanced = []
lengths_balanced = []
astrophysical_classes_balanced = []
elasticc_classes_balanced = []

# Build the class-label array once instead of on every loop iteration.
class_array = np.array(astrophysical_classes)

for c in np.unique(class_array):

    # Keep at most the first `max_class_count` samples of class `c`
    # (file order, not a random subsample). Slicing is a no-op for smaller classes.
    idx = np.where(class_array == c)[0][:max_class_count]

    # Every per-sample list is subset with the same indices so they stay aligned.
    X_ts_balanced += [X_ts[i] for i in idx]
    X_static_balanced += [X_static[i] for i in idx]
    Y_balanced += [Y[i] for i in idx]
    lengths_balanced += [lengths[i] for i in idx]
    astrophysical_classes_balanced += [astrophysical_classes[i] for i in idx]
    elasticc_classes_balanced += [elasticc_classes[i] for i in idx]

# Print summary of the data set used for training and validation
class_names, class_counts = np.unique(astrophysical_classes_balanced, return_counts=True)
data_summary = pd.DataFrame(data={'Class': class_names, 'Count': class_counts})
data_summary

Unnamed: 0,Class,Count
0,AGN,1000
1,CART,1000
2,Cepheid,1000
3,Delta Scuti,1000
4,Dwarf Novae,1000
5,EB,1000
6,ILOT,1000
7,KN,1000
8,M-dwarf Flare,1000
9,PISN,1000


In [63]:
# The raw (unbalanced) data is no longer needed once the balanced copies exist;
# drop the references so the memory can be reclaimed.
del X_ts, X_static, Y, astrophysical_classes, elasticc_classes

Split into train and validation

In [64]:
# Fraction of the balanced data held out for validation.
val_fraction = 0.05

# Shuffled split with a fixed random_state so it is reproducible across runs.
(
    X_ts_train, X_ts_val,
    X_static_train, X_static_val,
    Y_train, Y_val,
    astrophysical_classes_train, astrophysical_classes_val,
) = train_test_split(
    X_ts_balanced,
    X_static_balanced,
    Y_balanced,
    astrophysical_classes_balanced,
    shuffle=True,
    random_state=40,
    test_size=val_fraction,
)

In [65]:
# The balanced lists have been split into train/val copies; release them.
del (
    X_ts_balanced,
    X_static_balanced,
    Y_balanced,
    astrophysical_classes_balanced,
    elasticc_classes_balanced,
)

## Augment the Time Series lengths

In [66]:
# Training-set size before augmentation, kept for the before/after comparison.
# NOTE(review): "smaple" is a typo, but the name is referenced later, so it is kept as-is.
old_n_smaple = len(X_ts_train)

In [67]:
# Light-curve fractions used to augment the training set.
fractions = [0.25, 0.5, 0.75, 1]

# get_augmented_data is called once with all fractions. Judging by its log and the
# 4x growth in sample count, it presumably emits one copy per fraction — TODO confirm.
(
    X_ts_train,
    X_static_train,
    Y_train,
    astrophysical_classes_train,
    lc_fraction_train,
) = get_augmented_data(
    X_ts_train,
    X_static_train,
    Y_train,
    astrophysical_classes_train,
    fractions,
)
print("Finished augmenting training set...")

Augmenting light curve to 25.00%
Augmenting light curve to 50.00%
Augmenting light curve to 75.00%
Augmenting light curve to 100.00%
Finished augmenting training set...


In [68]:
# Finer grid of light-curve fractions for validation. Note this rebinds `fractions`,
# which the evaluation and animation cells below iterate over.
fractions = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

(
    X_ts_val,
    X_static_val,
    Y_val,
    astrophysical_classes_val,
    lc_fraction_val,
) = get_augmented_data(
    X_ts_val,
    X_static_val,
    Y_val,
    astrophysical_classes_val,
    fractions,
)
print("Finished augmenting validation set...")

Augmenting light curve to 10.00%
Augmenting light curve to 20.00%
Augmenting light curve to 30.00%
Augmenting light curve to 40.00%
Augmenting light curve to 50.00%
Augmenting light curve to 60.00%
Augmenting light curve to 70.00%
Augmenting light curve to 80.00%
Augmenting light curve to 90.00%
Augmenting light curve to 100.00%
Finished augmenting validation set...


In [69]:
# Compare training-set size before and after augmentation.
new_n_sample = X_ts_train.shape[0]
for stage, n_samples in (("before", old_n_smaple), ("after", new_n_sample)):
    print(f"Number of samples in training set {stage} data augmentation: {n_samples}")

Number of samples in training set before data augmentation: 18050
Number of samples in training set after data augmentation: 72200


## Declare the Loss function

In [70]:
# Build the taxonomy tree and the WHXE loss over it. The training class labels are
# passed in as well (presumably to derive per-class weights — TODO confirm).
# `criterion` is the bound method later handed to get_LSTM_Classifier as its loss.
tree = get_taxonomy_tree()
loss_object = WHXE_Loss(
    tree,
    astrophysical_classes_train,
    alpha=0,
)
criterion = loss_object.compute_loss

## Train the classifier using WHXE loss and save the model

In [71]:
# Training hyper-parameters
num_epochs = 10
batch_size = 1024

# Model dimensions
ts_dim = X_ts_train.shape[2]         # features per time step of the light-curve tensor
static_dim = X_static_train.shape[1] # number of static features per sample
# NOTE(review): presumably the number of taxonomy-tree nodes (width of Y) — confirm,
# and consider deriving it from the labels instead of hard-coding.
output_dim = 26
latent_size = 64                     # latent size passed to get_LSTM_Classifier

In [72]:
# Build the classifier with the WHXE criterion, then write an architecture diagram
# (with tensor shapes and layer names) to lstm.png.
model = get_LSTM_Classifier(ts_dim, static_dim, output_dim, latent_size, criterion)
keras.utils.plot_model(
    model,
    to_file='lstm.png',
    show_shapes=True,
    show_layer_names=True,
)
plt.close()

In [73]:
# Stop training once val_loss has not improved by at least `min_delta` for
# `patience` consecutive epochs; restore_best_weights is requested as well.
early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=5,
    min_delta=0.001,
    restore_best_weights=True,
)

history = model.fit(
    x=[X_ts_train, X_static_train],
    y=Y_train,
    validation_data=([X_ts_val, X_static_val], Y_val),
    epochs=num_epochs,
    batch_size=batch_size,
    callbacks=[early_stopping],
)

# Make sure the output directory exists so a long training run is not lost
# to a failed save at the very end.
os.makedirs("models", exist_ok=True)
model.save("models/lstm.keras")

Epoch 1/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m313s[0m 4s/step - loss: 1.2170 - val_loss: 1.1105
Epoch 2/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m314s[0m 4s/step - loss: 1.0709 - val_loss: 1.0596
Epoch 3/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m306s[0m 4s/step - loss: 1.0066 - val_loss: 1.0035
Epoch 4/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m304s[0m 4s/step - loss: 0.9508 - val_loss: 0.9486
Epoch 5/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m306s[0m 4s/step - loss: 0.8992 - val_loss: 0.9628
Epoch 6/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m306s[0m 4s/step - loss: 0.8844 - val_loss: 0.8657
Epoch 7/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m309s[0m 4s/step - loss: 0.8048 - val_loss: 0.8153
Epoch 8/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m310s[0m 4s/step - loss: 0.7607 - val_loss: 0.7902
Epoch 9/10
[1m71/71[0m [32m━━━━━━━━━━━━━━━━━━

## Load the saved model and validate that everything looks okay

In [74]:
# Reload the trained model for inference only. compile=False skips restoring the
# training configuration (presumably because the custom WHXE loss is not needed here).
model_file = "models/lstm.keras"
saved_model = keras.models.load_model(model_file, compile=False)

In [127]:
# Root-level evaluation: for each visible light-curve fraction, score the two
# top-level taxonomy nodes (columns 1:3 of the label / prediction vectors).
root_cols = slice(1, 3)
root_labels = list(loss_object.level_order_nodes)[root_cols]

for f in fractions:

    # round() avoids truncation (int(0.57*100) == 56) and float noise in titles
    # (0.3*100 == 30.000000000000004).
    pct = round(f * 100)
    print(f'Running inference for {pct}% light curves...')

    # Get all the indices for where 'f' fraction of the LC was used.
    # np.isclose tolerates float round-off in the stored fractions.
    idx = np.where(np.isclose(lc_fraction_val, f))[0]
    if idx.size == 0:
        print(f'No validation samples with {pct}% of the light curve, skipping.')
        continue

    # Run inference on these
    x1 = X_ts_val[idx, :, :]
    x2 = X_static_val[idx]
    y_pred = saved_model.predict([x1, x2])

    # Overwrite the root-level columns with the pseudo-probabilities
    # returned by get_prediction_probs for each sample.
    for i in range(y_pred.shape[0]):
        pseudo_probs, _ = get_prediction_probs(y_pred[[i], :])
        y_pred[i, root_cols] = pseudo_probs[0, root_cols]

    y_pred_label = np.argmax(y_pred[:, root_cols], axis=1)
    y_true_label = np.argmax(Y_val[idx, root_cols], axis=1)

    # Print the stats. Passing `labels` keeps target_names aligned even when a
    # class is missing from both y_true and y_pred for this fraction.
    print(f'For {pct}% of the light curve, these are the statistics')
    report = classification_report(
        y_true_label,
        y_pred_label,
        labels=np.arange(len(root_labels)),
        target_names=root_labels,
    )
    print(report)

    # Make plots (file names are keyed on f; the animation cells rely on them)
    plot_title = f"~{pct}% of each LC visible"
    cf_plot_file = f"gif/root_cf/{f}.png"
    roc_plot_file = f"gif/root_roc/{f}.png"
    os.makedirs(os.path.dirname(cf_plot_file), exist_ok=True)
    os.makedirs(os.path.dirname(roc_plot_file), exist_ok=True)

    plot_confusion_matrix(y_true_label, y_pred_label, root_labels, plot_title, cf_plot_file)
    plt.close()
    plot_roc_curves(Y_val[idx, root_cols], y_pred[:, root_cols], root_labels, plot_title, roc_plot_file)
    plt.close()

Running inference for 10% light curves...
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 112ms/step
For 10% of the light curve, these are the statistics
              precision    recall  f1-score   support

   Transient       0.92      0.98      0.95       706
    Variable       0.94      0.75      0.84       244

    accuracy                           0.92       950
   macro avg       0.93      0.87      0.89       950
weighted avg       0.93      0.92      0.92       950

Running inference for 20% light curves...
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 112ms/step
For 20% of the light curve, these are the statistics
              precision    recall  f1-score   support

   Transient       0.95      0.98      0.97       706
    Variable       0.95      0.84      0.89       244

    accuracy                           0.95       950
   macro avg       0.95      0.91      0.93       950
weighted avg       0.95      0.95      0.95       950

Runnin

In [128]:
# Leaf-level evaluation: for each visible light-curve fraction, score the leaf
# classes, which this notebook takes to be the last 19 columns of the vectors.
leaf_cols = slice(-19, None)
leaf_labels = list(loss_object.level_order_nodes)[leaf_cols]

for f in fractions:

    # round() avoids truncation and float noise (0.3*100 == 30.000000000000004).
    pct = round(f * 100)
    print(f'Running inference for {pct}% light curves...')

    # Get all the indices for where 'f' fraction of the LC was used.
    # np.isclose tolerates float round-off in the stored fractions.
    idx = np.where(np.isclose(lc_fraction_val, f))[0]
    if idx.size == 0:
        print(f'No validation samples with {pct}% of the light curve, skipping.')
        continue

    # Run inference on these
    x1 = X_ts_val[idx, :, :]
    x2 = X_static_val[idx]
    y_pred = saved_model.predict([x1, x2])

    # Overwrite the leaf columns with the leaf probabilities from
    # get_highest_prob_path applied to each sample's weighted tree.
    for i in range(y_pred.shape[0]):
        _, weighted_tree = get_prediction_probs(y_pred[[i], :])
        leaf_prob, _ = get_highest_prob_path(weighted_tree)
        y_pred[i, leaf_cols] = leaf_prob

    y_pred_label = np.argmax(y_pred[:, leaf_cols], axis=1)
    y_true_label = np.argmax(Y_val[idx, leaf_cols], axis=1)

    # Print the stats. With 19 leaf classes a fraction subset can easily miss one;
    # passing `labels` keeps target_names aligned instead of raising ValueError.
    print(f'For {pct}% of the light curve, these are the statistics')
    report = classification_report(
        y_true_label,
        y_pred_label,
        labels=np.arange(len(leaf_labels)),
        target_names=leaf_labels,
    )
    print(report)

    # Make plots (file names are keyed on f; the animation cells rely on them)
    plot_title = f"~{pct}% of each LC visible"
    cf_plot_file = f"gif/leaf_cf/{f}.png"
    roc_plot_file = f"gif/leaf_roc/{f}.png"
    os.makedirs(os.path.dirname(cf_plot_file), exist_ok=True)
    os.makedirs(os.path.dirname(roc_plot_file), exist_ok=True)

    plot_confusion_matrix(y_true_label, y_pred_label, leaf_labels, plot_title, cf_plot_file)
    plt.close()
    plot_roc_curves(Y_val[idx, leaf_cols], y_pred[:, leaf_cols], leaf_labels, plot_title, roc_plot_file)
    plt.close()

Running inference for 10% light curves...
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 116ms/step
For 10% of the light curve, these are the statistics
               precision    recall  f1-score   support

          AGN       0.32      0.48      0.38        52
         SNIa       0.00      0.00      0.00        41
       SNIb/c       0.10      0.08      0.09        50
        SNIax       0.19      0.10      0.13        50
      SNI91bg       0.33      0.12      0.18        73
         SNII       0.19      0.09      0.12        47
           KN       0.05      0.20      0.08        45
  Dwarf Novae       0.58      0.15      0.24        46
        uLens       0.62      0.19      0.29        52
M-dwarf Flare       0.34      0.22      0.27        50
         SLSN       1.00      0.05      0.10        60
          TDE       0.27      0.08      0.12        50
         ILOT       0.31      0.21      0.25        48
         CART       0.09      0.41      0.15        54
    

## Making a cool animation:

In [77]:
# Stitch the per-fraction root-level confusion matrices into one animation.
cf_files = ["gif/root_cf/{}.png".format(frac) for frac in fractions]
make_gif(cf_files, 'gif/root_cf/root_cf.gif')
plt.close()

In [129]:
# Stitch the per-fraction root-level ROC curves into one animation.
roc_files = ["gif/root_roc/{}.png".format(frac) for frac in fractions]
make_gif(roc_files, 'gif/root_roc/root_roc.gif')
plt.close()

In [78]:
# Stitch the per-fraction leaf-level confusion matrices into one animation.
cf_files = ["gif/leaf_cf/{}.png".format(frac) for frac in fractions]
make_gif(cf_files, 'gif/leaf_cf/leaf_cf.gif')
plt.close()

In [130]:
# Stitch the per-fraction leaf-level ROC curves into one animation.
roc_files = ["gif/leaf_roc/{}.png".format(frac) for frac in fractions]
make_gif(roc_files, 'gif/leaf_roc/leaf_roc.gif')
plt.close()

For the love of everything that is good in this world, please use a different notebook for testing and generating statistics. Keep this notebook simple enough to be converted into a script.