# Evaluating the RNN + MLP model trained with the WHXE loss function on the test set

## Imports

In [1]:
import time
import pickle
import platform
import os
import imageio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import random 

from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import plot_model
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from tensorflow import keras
from tqdm import tqdm

from LSTM_model import get_LSTM_Classifier
from dataloader import LSSTSourceDataSet, load, get_augmented_data, get_static_features, ts_length
from loss import WHXE_Loss
from taxonomy import get_taxonomy_tree, get_prediction_probs, get_highest_prob_path, plot_colored_tree
from vizualizations import make_gif, plot_confusion_matrix, plot_roc_curves
from interpret_results import get_conditional_probabilites, save_all_cf_and_rocs, save_leaf_cf_and_rocs

2024-06-17 08:16:51.047795: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-17 08:16:52.466944: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-17 08:16:52.468465: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-17 08:16:52.664714: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-17 08:16:53.031900: I tensorflow/core/platform/cpu_feature_guar

In [2]:
# Record the library versions this evaluation was run with.
for _label, _version in (("Tensorflow version", tf.__version__),
                         ("Python version", platform.python_version())):
    print(_label, _version)

Tensorflow version 2.15.0
Python version 3.10.12


In [3]:
# Seed the RNGs used in this notebook. `random` drives the per-class
# subsampling below; numpy is seeded as well because the project helpers
# (e.g. get_augmented_data) may draw from it -- TODO confirm which RNG they use.
seed = 42
random.seed(seed)
np.random.seed(seed)

## Load and balance the tensors:

This step takes a while because it has to load the data from disk into memory...

In [4]:
# Load the four pickled test-split artefacts, in order: time-series inputs,
# static features, targets and astrophysical class labels.
X_ts, X_static, Y, astrophysical_classes = (
    load(pkl_path)
    for pkl_path in (
        "processed/test/x_ts.pkl",
        "processed/test/x_static.pkl",
        "processed/test/y.pkl",
        "processed/test/a_labels.pkl",
    )
)

In [5]:
# Per-class sample counts of the full, unbalanced test set.
a, b = np.unique(astrophysical_classes, return_counts=True)
print(f"Total sample count = {b.sum()}")
pd.DataFrame({'Class': a, 'Count': b})

Total sample count = 463528


Unnamed: 0,Class,Count
0,AGN,32681
1,CART,3517
2,Cepheid,5901
3,Delta Scuti,8849
4,Dwarf Novae,3439
5,EB,28480
6,ILOT,3197
7,KN,1896
8,M-dwarf Flare,796
9,PISN,27250


A small step to convert each entry of X_static from a dictionary to an array of static features.

In [6]:
# Replace each static-feature entry in place with its array form.
# NOTE(review): this mutates X_static, so re-running the cell on already
# converted entries is not idempotent.
for idx, static_entry in enumerate(tqdm(X_static)):
    X_static[idx] = get_static_features(static_entry)

100%|██████████| 463528/463528 [00:02<00:00, 170670.85it/s]


In [7]:
# Cap every astrophysical class at `max_class_count` randomly chosen samples so
# the evaluation set is class-balanced; smaller classes are kept whole.
max_class_count = 750

X_ts_balanced = []
X_static_balanced = []
Y_balanced = []
astrophysical_classes_balanced = []

# Convert the label list to an array once, instead of once per class: the full
# test set has hundreds of thousands of labels.
class_labels = np.asarray(astrophysical_classes)

for c in np.unique(class_labels):

    idx = list(np.where(class_labels == c)[0])

    if len(idx) > max_class_count:
        # Uses the `random` module seeded earlier, so the subset is reproducible.
        idx = random.sample(idx, max_class_count)

    X_ts_balanced += [X_ts[i] for i in idx]
    X_static_balanced += [X_static[i] for i in idx]
    Y_balanced += [Y[i] for i in idx]
    astrophysical_classes_balanced += [astrophysical_classes[i] for i in idx]

# Sanity check: the balanced containers must stay aligned sample-for-sample.
assert (len(X_ts_balanced) == len(X_static_balanced)
        == len(Y_balanced) == len(astrophysical_classes_balanced))

# Print summary of the data set used for training and validation
a, b = np.unique(astrophysical_classes_balanced, return_counts=True)
data_summary = pd.DataFrame(data = {'Class': a, 'Count': b})
data_summary

Unnamed: 0,Class,Count
0,AGN,750
1,CART,750
2,Cepheid,750
3,Delta Scuti,750
4,Dwarf Novae,750
5,EB,750
6,ILOT,750
7,KN,750
8,M-dwarf Flare,750
9,PISN,750


In [8]:
# Drop the full (unbalanced) data set; only the *_balanced lists are used below.
del X_ts
del X_static
del Y
del astrophysical_classes

## Load the best model and run inference on the test set

In [12]:
# Class taxonomy tree; consumed by get_conditional_probabilites and the
# save_*_cf_and_rocs helpers in the inference loop.
# NOTE(review): this cell ran after the model-loading cell (In [12] vs In [9]);
# it only needs to run before the inference loop.
tree = get_taxonomy_tree()

In [9]:
# compile=False skips restoring the training configuration (loss/optimizer),
# which is not needed for inference. The path is a plain string: the original
# f-string had no placeholders.
best_model = keras.models.load_model("models/lsst_rate_agnostic/best_model.h5", compile=False)

2024-06-17 08:18:04.626195: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31127 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:3a:00.0, compute capability: 7.0
2024-06-17 08:18:04.627505: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 31127 MB memory:  -> device: 1, name: Tesla V100-SXM2-32GB, pci bus id: 0000:b3:00.0, compute capability: 7.0


In [10]:
# Light-curve fractions to evaluate: 0.1, 0.2, ..., 0.9 and finally the full
# curve as the int 1. The values are also used verbatim in the gif frame file
# names ("0.1.png", ..., "1.png"), so keep the last entry an int.
fractions = [k / 10 for k in range(1, 10)] + [1]

In [13]:
# For each light-curve fraction, run the model on the balanced test set and
# report per-level statistics. The save_*_cf_and_rocs helpers presumably write
# the per-fraction figures that the gif cells below read -- TODO confirm.
for f in fractions:

    # round() instead of int(): e.g. int(0.29 * 100) == 28 due to float error.
    percent = round(f * 100)

    print(f'Running inference for {percent}% light curves...')

    # Presumably keeps fraction `f` of each light curve (see dataloader).
    # Unused outputs get named placeholders rather than `_`, because binding `_`
    # in IPython shadows the last-output cache.
    x1, x2, y_true, _astro_labels = get_augmented_data(X_ts_balanced, X_static_balanced, Y_balanced, astrophysical_classes_balanced, fraction=f)

    # Run inference on these
    y_pred = best_model.predict([x1, x2])

    # Get the conditional probabilities
    _conditional_probabilities, pseudo_conditional_probabilities = get_conditional_probabilites(y_pred, tree)

    print(f'For {percent}% of the light curve, these are the statistics:')

    # Print all the stats and make plots...
    save_all_cf_and_rocs(y_true, pseudo_conditional_probabilities, tree, f)
    save_leaf_cf_and_rocs(y_true, pseudo_conditional_probabilities, tree, f)

    # Close every figure the helpers opened (not just the current one) so
    # figures do not accumulate across iterations.
    plt.close('all')

Running inference for 10% light curves...


100%|██████████| 14250/14250 [00:01<00:00, 13739.11it/s]


For 10% of the light curve, these are the statistics:
              precision    recall  f1-score   support

   Transient       0.93      1.00      0.96     10500
    Variable       0.99      0.78      0.87      3750

    accuracy                           0.94     14250
   macro avg       0.96      0.89      0.92     14250
weighted avg       0.94      0.94      0.94     14250

              precision    recall  f1-score   support

         AGN       1.00      0.44      0.61       750
        Fast       0.84      0.77      0.80      3000
        Long       0.72      0.60      0.65      3750
    Periodic       0.99      0.86      0.92      3000
          SN       0.59      0.87      0.70      3750

    accuracy                           0.75     14250
   macro avg       0.83      0.71      0.74     14250
weighted avg       0.79      0.75      0.75     14250

               precision    recall  f1-score   support

          AGN       1.00      0.45      0.62       750
         CART      

100%|██████████| 14250/14250 [00:00<00:00, 14252.39it/s]


For 20% of the light curve, these are the statistics:
              precision    recall  f1-score   support

   Transient       0.95      1.00      0.97     10500
    Variable       0.99      0.85      0.92      3750

    accuracy                           0.96     14250
   macro avg       0.97      0.93      0.95     14250
weighted avg       0.96      0.96      0.96     14250

              precision    recall  f1-score   support

         AGN       1.00      0.58      0.74       750
        Fast       0.90      0.80      0.85      3000
        Long       0.77      0.67      0.71      3750
    Periodic       0.99      0.92      0.96      3000
          SN       0.64      0.88      0.74      3750

    accuracy                           0.80     14250
   macro avg       0.86      0.77      0.80     14250
weighted avg       0.82      0.80      0.80     14250

               precision    recall  f1-score   support

          AGN       1.00      0.60      0.75       750
         CART      

100%|██████████| 14250/14250 [00:00<00:00, 14324.42it/s]


For 30% of the light curve, these are the statistics:
              precision    recall  f1-score   support

   Transient       0.96      1.00      0.98     10500
    Variable       0.99      0.89      0.94      3750

    accuracy                           0.97     14250
   macro avg       0.98      0.94      0.96     14250
weighted avg       0.97      0.97      0.97     14250

              precision    recall  f1-score   support

         AGN       1.00      0.70      0.82       750
        Fast       0.93      0.84      0.88      3000
        Long       0.82      0.70      0.75      3750
    Periodic       0.99      0.94      0.97      3000
          SN       0.68      0.90      0.77      3750

    accuracy                           0.83     14250
   macro avg       0.88      0.82      0.84     14250
weighted avg       0.85      0.83      0.83     14250

               precision    recall  f1-score   support

          AGN       1.00      0.72      0.84       750
         CART      

100%|██████████| 14250/14250 [00:01<00:00, 14174.70it/s]


For 40% of the light curve, these are the statistics:
              precision    recall  f1-score   support

   Transient       0.97      1.00      0.99     10500
    Variable       0.99      0.92      0.96      3750

    accuracy                           0.98     14250
   macro avg       0.98      0.96      0.97     14250
weighted avg       0.98      0.98      0.98     14250

              precision    recall  f1-score   support

         AGN       1.00      0.80      0.89       750
        Fast       0.94      0.87      0.90      3000
        Long       0.85      0.71      0.78      3750
    Periodic       0.99      0.95      0.97      3000
          SN       0.70      0.91      0.79      3750

    accuracy                           0.85     14250
   macro avg       0.90      0.85      0.87     14250
weighted avg       0.87      0.85      0.85     14250

               precision    recall  f1-score   support

          AGN       0.99      0.82      0.90       750
         CART      

100%|██████████| 14250/14250 [00:00<00:00, 14329.98it/s]


For 50% of the light curve, these are the statistics:
              precision    recall  f1-score   support

   Transient       0.98      1.00      0.99     10500
    Variable       1.00      0.95      0.97      3750

    accuracy                           0.99     14250
   macro avg       0.99      0.97      0.98     14250
weighted avg       0.99      0.99      0.99     14250

              precision    recall  f1-score   support

         AGN       0.99      0.88      0.93       750
        Fast       0.95      0.90      0.92      3000
        Long       0.88      0.73      0.80      3750
    Periodic       1.00      0.96      0.98      3000
          SN       0.73      0.92      0.81      3750

    accuracy                           0.87     14250
   macro avg       0.91      0.88      0.89     14250
weighted avg       0.89      0.87      0.87     14250

               precision    recall  f1-score   support

          AGN       0.99      0.89      0.94       750
         CART      

100%|██████████| 14250/14250 [00:01<00:00, 14147.89it/s]


For 60% of the light curve, these are the statistics:
              precision    recall  f1-score   support

   Transient       0.99      1.00      0.99     10500
    Variable       0.99      0.96      0.98      3750

    accuracy                           0.99     14250
   macro avg       0.99      0.98      0.99     14250
weighted avg       0.99      0.99      0.99     14250

              precision    recall  f1-score   support

         AGN       0.99      0.93      0.96       750
        Fast       0.96      0.92      0.94      3000
        Long       0.91      0.75      0.82      3750
    Periodic       0.99      0.97      0.98      3000
          SN       0.75      0.93      0.83      3750

    accuracy                           0.89     14250
   macro avg       0.92      0.90      0.91     14250
weighted avg       0.90      0.89      0.89     14250

               precision    recall  f1-score   support

          AGN       0.99      0.94      0.96       750
         CART      

100%|██████████| 14250/14250 [00:01<00:00, 14121.94it/s]


For 70% of the light curve, these are the statistics:
              precision    recall  f1-score   support

   Transient       0.99      1.00      0.99     10500
    Variable       0.99      0.98      0.98      3750

    accuracy                           0.99     14250
   macro avg       0.99      0.99      0.99     14250
weighted avg       0.99      0.99      0.99     14250

              precision    recall  f1-score   support

         AGN       0.98      0.96      0.97       750
        Fast       0.97      0.93      0.95      3000
        Long       0.92      0.75      0.83      3750
    Periodic       0.99      0.98      0.99      3000
          SN       0.77      0.94      0.85      3750

    accuracy                           0.90     14250
   macro avg       0.93      0.91      0.92     14250
weighted avg       0.91      0.90      0.90     14250

               precision    recall  f1-score   support

          AGN       0.98      0.96      0.97       750
         CART      

100%|██████████| 14250/14250 [00:01<00:00, 14199.89it/s]


For 80% of the light curve, these are the statistics:
              precision    recall  f1-score   support

   Transient       0.99      0.99      0.99     10500
    Variable       0.98      0.98      0.98      3750

    accuracy                           0.99     14250
   macro avg       0.99      0.99      0.99     14250
weighted avg       0.99      0.99      0.99     14250

              precision    recall  f1-score   support

         AGN       0.97      0.97      0.97       750
        Fast       0.97      0.93      0.95      3000
        Long       0.94      0.76      0.84      3750
    Periodic       0.99      0.98      0.99      3000
          SN       0.78      0.95      0.86      3750

    accuracy                           0.91     14250
   macro avg       0.93      0.92      0.92     14250
weighted avg       0.92      0.91      0.91     14250

               precision    recall  f1-score   support

          AGN       0.97      0.97      0.97       750
         CART      

100%|██████████| 14250/14250 [00:00<00:00, 14319.37it/s]


For 90% of the light curve, these are the statistics:
              precision    recall  f1-score   support

   Transient       1.00      0.99      0.99     10500
    Variable       0.98      0.99      0.98      3750

    accuracy                           0.99     14250
   macro avg       0.99      0.99      0.99     14250
weighted avg       0.99      0.99      0.99     14250

              precision    recall  f1-score   support

         AGN       0.97      0.99      0.98       750
        Fast       0.98      0.94      0.96      3000
        Long       0.94      0.78      0.85      3750
    Periodic       0.98      0.99      0.99      3000
          SN       0.79      0.95      0.87      3750

    accuracy                           0.91     14250
   macro avg       0.93      0.93      0.93     14250
weighted avg       0.92      0.91      0.91     14250

               precision    recall  f1-score   support

          AGN       0.97      0.99      0.98       750
         CART      

100%|██████████| 14250/14250 [00:00<00:00, 14256.71it/s]


For 100% of the light curve, these are the statistics:
              precision    recall  f1-score   support

   Transient       1.00      0.99      0.99     10500
    Variable       0.98      0.99      0.99      3750

    accuracy                           0.99     14250
   macro avg       0.99      0.99      0.99     14250
weighted avg       0.99      0.99      0.99     14250

              precision    recall  f1-score   support

         AGN       0.97      0.99      0.98       750
        Fast       0.98      0.95      0.96      3000
        Long       0.94      0.79      0.86      3750
    Periodic       0.98      0.99      0.99      3000
          SN       0.81      0.95      0.87      3750

    accuracy                           0.92     14250
   macro avg       0.94      0.93      0.93     14250
weighted avg       0.92      0.92      0.92     14250

               precision    recall  f1-score   support

          AGN       0.96      0.99      0.98       750
         CART     

## Animating the results across light-curve fractions:

In [14]:
# Animate the leaf-level confusion matrices saved for each light-curve fraction.
cf_files = [f"gif/leaf_cf/{frac}.png" for frac in fractions]
make_gif(cf_files, 'gif/leaf_cf/leaf_cf.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [15]:
# Animate the leaf-level ROC curves saved for each light-curve fraction.
roc_files = [f"gif/leaf_roc/{frac}.png" for frac in fractions]
make_gif(roc_files, 'gif/leaf_roc/leaf_roc.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [16]:
# Animate the level-1 confusion matrices saved for each light-curve fraction.
cf_files = [f"gif/level_1_cf/{frac}.png" for frac in fractions]
make_gif(cf_files, 'gif/level_1_cf/level_1_cf.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [17]:
# Animate the level-1 ROC curves saved for each light-curve fraction.
roc_files = [f"gif/level_1_roc/{frac}.png" for frac in fractions]
make_gif(roc_files, 'gif/level_1_roc/level_1_roc.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [18]:
# Animate the level-2 confusion matrices saved for each light-curve fraction.
cf_files = [f"gif/level_2_cf/{frac}.png" for frac in fractions]
make_gif(cf_files, 'gif/level_2_cf/level_2_cf.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [19]:
# Animate the level-2 ROC curves saved for each light-curve fraction.
roc_files = [f"gif/level_2_roc/{frac}.png" for frac in fractions]
make_gif(roc_files, 'gif/level_2_roc/level_2_roc.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.
