# Training and validating RNN + MPL model with the WHXE loss function

## Imports

In [14]:
import time
import pickle
import platform
import os
import imageio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

from random import sample 
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import plot_model
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from tensorflow import keras

from LSTM_model import get_LSTM_Classifier
from dataloader import LSSTSourceDataSet, load, get_augmented_data, get_static_features, ts_length
from loss import WHXE_Loss
from taxonomy import get_taxonomy_tree, get_prediction_probs, get_highest_prob_path, plot_colored_tree
from vizualizations import make_gif, plot_confusion_matrix, plot_roc_curves
from interpret_results import get_conditional_probabilites, save_all_cf_and_rocs, save_leaf_cf_and_rocs

In [2]:
print("Tensorflow version", tf.__version__)
print("Python version", platform.python_version())

Tensorflow version 2.15.0
Python version 3.10.12


## Load and balance the tensors:

This step takes a while because it has load from disc to memory...

In [3]:
X_ts = load("processed/train/x_ts.pkl")
X_static = load("processed/train/x_static.pkl")
Y = load("processed/train/y.pkl")
astrophysical_classes = load("processed/train/a_labels.pkl")

In [4]:
a, b = np.unique(astrophysical_classes, return_counts=True)
print(f"Total sample count = {np.sum(b)}")
pd.DataFrame(data = {'Class': a, 'Count': b})

Total sample count = 1081614


Unnamed: 0,Class,Count
0,AGN,76258
1,CART,8207
2,Cepheid,13771
3,Delta Scuti,20650
4,Dwarf Novae,8025
5,EB,66454
6,ILOT,7461
7,KN,4426
8,M-dwarf Flare,1859
9,PISN,63586


Small step to convert X_static from a dictionary to an array

In [5]:
for i in range(len(X_static)):
    
    if i%1000 == 0:
        print(f"{(i/len(X_static) * 100):.3f} %", end="\r")
        
    X_static[i] = get_static_features(X_static[i])

99.943 %

Balance the data set in some way

In [6]:
max_class_count = 30000

X_ts_balanced = []
X_static_balanced = []
Y_balanced = []
astrophysical_classes_balanced = []

for c in np.unique(astrophysical_classes):

    idx = list(np.where(np.array(astrophysical_classes) == c)[0])
    
    if len(idx) > max_class_count:
        idx = sample(idx, max_class_count)
 
    X_ts_balanced += [X_ts[i] for i in idx]
    X_static_balanced += [X_static[i] for i in idx]
    Y_balanced += [Y[i] for i in idx]
    astrophysical_classes_balanced += [astrophysical_classes[i] for i in idx]

# Print summary of the data set used for training and validation
a, b = np.unique(astrophysical_classes_balanced, return_counts=True)
data_summary = pd.DataFrame(data = {'Class': a, 'Count': b})
data_summary

Unnamed: 0,Class,Count
0,AGN,30000
1,CART,8207
2,Cepheid,13771
3,Delta Scuti,20650
4,Dwarf Novae,8025
5,EB,30000
6,ILOT,7461
7,KN,4426
8,M-dwarf Flare,1859
9,PISN,30000


Free up some memory

In [7]:
del X_ts, X_static, Y, astrophysical_classes

Split into train and validation

In [8]:
val_fraction = 0.05
X_ts_train, X_ts_val, X_static_train, X_static_val, Y_train, Y_val, astrophysical_classes_train, astrophysical_classes_val = train_test_split(X_ts_balanced, X_static_balanced, Y_balanced, astrophysical_classes_balanced, shuffle=True, random_state = 40, test_size = val_fraction)

Free up some more memory

In [9]:
del X_ts_balanced, X_static_balanced, Y_balanced, astrophysical_classes_balanced

Check make up of the validation data

## Define the Loss function, criterion, etc.

In [10]:
# Loss and optimizer
tree = get_taxonomy_tree()
loss_object = WHXE_Loss(tree, astrophysical_classes_train, alpha=0) 
criterion = loss_object.compute_loss

## Train the classifier using WHXE loss and save the model

In [11]:
optimizer = keras.optimizers.Adam(learning_rate=5e-4)

2024-06-12 07:28:04.580517: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31127 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:8a:00.0, compute capability: 7.0
2024-06-12 07:28:04.581992: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 31127 MB memory:  -> device: 1, name: Tesla V100-SXM2-32GB, pci bus id: 0000:b2:00.0, compute capability: 7.0


In [12]:
# Inputs for model
ts_dim = 5
static_dim = len(X_static_train[0])
output_dim = 26
latent_size = 64

num_epochs = 100
batch_size = 2048

In [13]:
model = get_LSTM_Classifier(ts_dim, static_dim, output_dim, latent_size)
keras.utils.plot_model(model, to_file='lstm.pdf', show_shapes=True, show_layer_names=True)
plt.close()

In [14]:
@tf.function
def train_step(x_ts, x_static, y):
    with tf.GradientTape() as tape:
        logits = model((x_ts, x_static), training=True)
        loss_value = criterion(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    return loss_value

In [15]:
avg_train_losses = []
for epoch in range(num_epochs):
    
    print(f"\nStart of epoch {epoch}")
    start_time = time.time()

    print("Augmenting time series lengths...")
    
    # Create the augmented data set for training
    X_ts_train_aug, X_static_train_aug, Y_train_aug, astrophysical_classes_train_aug = get_augmented_data(X_ts_train, X_static_train, Y_train, astrophysical_classes_train)
    train_dataset =  tf.data.Dataset.from_tensor_slices((X_ts_train_aug, X_static_train_aug, Y_train_aug, astrophysical_classes_train_aug)).batch(batch_size)
    
    # Array to keep tracking of the training loss
    train_loss_values = []
    
    # Iterate over the batches of the dataset.
    for step, (x_ts_batch_train, x_static_batch_train, y_batch_train, a_class_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_ts_batch_train, x_static_batch_train, y_batch_train)
    
    # Log the avg train loss
    avg_train_loss = np.mean(loss_value)
    avg_train_losses.append(avg_train_loss)
    print(f"Avg training loss: {float(avg_train_loss):.4f}")
    
    print(f"Time taken: {time.time() - start_time:.2f}s")
    model.save(f"models/lsst_rate_agnostic/lstm_epoch_{epoch}.h5")
    
    # Save the model with the smallest training loss
    best_model_epoch = np.argmin(avg_train_losses)
    loaded_model = keras.models.load_model(f"models/lsst_rate_agnostic/lstm_epoch_{best_model_epoch}.h5", compile=False)
    loaded_model.save(f"models/lsst_rate_agnostic/best_model.h5")


Start of epoch 0
Augmenting time series lengths...
100.000 %

2024-06-12 07:28:58.108839: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT32
    }
  }
}
 is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_FLOAT
    }
  }
}

	for Tuple type infernce function 0
	while inferring type of node 'cond_18/output/_21'
2024-06-12 07:28:58.588124: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:467] Loaded cuDNN version 90000
2024-06-12 07:28:59.283088: I external/local_xla/xla/service/service.cc:168] XLA service 0x14cec038f760 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-06-12 07:2

Avg training loss: 0.9533
Time taken: 75.88s


  saving_api.save_model(



Start of epoch 1
Augmenting time series lengths...
Avg training loss: 0.8633
Time taken: 59.32s

Start of epoch 2
Augmenting time series lengths...
Avg training loss: 0.8306
Time taken: 59.86s

Start of epoch 3
Augmenting time series lengths...
Avg training loss: 0.9151
Time taken: 49.72s

Start of epoch 4
Augmenting time series lengths...
Avg training loss: 0.8207
Time taken: 55.84s

Start of epoch 5
Augmenting time series lengths...
Avg training loss: 0.6935
Time taken: 63.93s

Start of epoch 6
Augmenting time series lengths...
Avg training loss: 0.6672
Time taken: 62.34s

Start of epoch 7
Augmenting time series lengths...
Avg training loss: 0.7832
Time taken: 53.06s

Start of epoch 8
Augmenting time series lengths...
Avg training loss: 0.6153
Time taken: 65.31s

Start of epoch 9
Augmenting time series lengths...
Avg training loss: 0.8155
Time taken: 49.91s

Start of epoch 10
Augmenting time series lengths...
Avg training loss: 0.7976
Time taken: 49.31s

Start of epoch 11
Augmenting

In [None]:
plt.scatter(range(num_epochs), avg_train_loss)
plt.xlabel("Epoch")
plt.ylabel("Avg log loss for training")

## Load the saved model and validate that everthing looks okay

In [11]:
best_model = keras.models.load_model(f"models/lsst_rate_agnostic/best_model.h5", compile=False)

2024-06-12 12:43:21.058951: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31125 MB memory:  -> device: 0, name: Tesla V100-SXM3-32GB, pci bus id: 0000:34:00.0, compute capability: 7.0
2024-06-12 12:43:21.060893: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 31125 MB memory:  -> device: 1, name: Tesla V100-SXM3-32GB, pci bus id: 0000:36:00.0, compute capability: 7.0


In [12]:
fractions = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

In [15]:
for f in fractions:

    print(f'Running inference for {int(f*100)}% light curves...')

    x1, x2, y_true, _ = get_augmented_data(X_ts_val, X_static_val, Y_val, astrophysical_classes_val, fraction=f)
    
    # Run inference on these
    y_pred = best_model.predict([x1, x2])

    # Get the conditional probabilities
    _, pseudo_conditional_probabilities = get_conditional_probabilites(y_pred, tree)
    
    print(f'For {int(f*100)}% of the light curve, these are the statistics:')
    
    # Print all the stats and make plots...
    save_all_cf_and_rocs(y_true, pseudo_conditional_probabilities, tree, f)
    save_leaf_cf_and_rocs(y_true, pseudo_conditional_probabilities, tree, f)
    
    plt.close()

Running inference for 10% light curves...
For 10% of the light curve, these are the statistics:
              precision    recall  f1-score   support

   Transient       0.92      1.00      0.96     14195
    Variable       1.00      0.76      0.86      5440

    accuracy                           0.93     19635
   macro avg       0.96      0.88      0.91     19635
weighted avg       0.94      0.93      0.93     19635

              precision    recall  f1-score   support

         AGN       1.00      0.57      0.72      1563
        Fast       0.65      0.91      0.76      1576
        Long       0.75      0.64      0.69      5278
    Periodic       1.00      0.84      0.91      3877
          SN       0.76      0.91      0.83      7341

    accuracy                           0.80     19635
   macro avg       0.83      0.77      0.78     19635
weighted avg       0.82      0.80      0.80     19635

               precision    recall  f1-score   support

          AGN       1.00      0.

## Making a cool animation:

In [16]:
cf_files = [f"gif/leaf_cf/{f}.png" for f in fractions]
make_gif(cf_files, 'gif/leaf_cf/leaf_cf.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [17]:
roc_files = [f"gif/leaf_roc/{f}.png" for f in fractions]
make_gif(roc_files, 'gif/leaf_roc/leaf_roc.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [18]:
cf_files = [f"gif/level_1_cf/{f}.png" for f in fractions]
make_gif(cf_files, 'gif/level_1_cf/level_1_cf.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [19]:
roc_files = [f"gif/level_1_roc/{f}.png" for f in fractions]
make_gif(roc_files, 'gif/level_1_roc/level_1_roc.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [20]:
cf_files = [f"gif/level_2_cf/{f}.png" for f in fractions]
make_gif(cf_files, 'gif/level_2_cf/level_2_cf.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [21]:
roc_files = [f"gif/level_2_roc/{f}.png" for f in fractions]
make_gif(roc_files, 'gif/level_2_roc/level_2_roc.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


For the love of everthing that is good in this world, please use a different notebook for testing and genearting statistics. Keep this notebook simple enought to be converted into a script. 