# Train and test the PFN model

Classify pions, photons, and **axion2s**

In [1]:
# Add import paths
# The parent directory is appended to sys.path before the project-local
# imports below (presumably `utils` lives there — TODO confirm).
import sys
sys.path.append("..")

# Set TensorFlow's C++ log level. This is done before `import tensorflow`
# further down so the variable is already in the environment at import time.
# NOTE(review): "2" presumably hides INFO and WARNING messages — confirm
# against the TensorFlow documentation.
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import numpy as np
import datetime as dt
from utils import print_gpu_memory, model_dir

# Make tensorflow not use too much memory
# Memory growth is switched on for every GPU that TensorFlow lists; with no
# GPU present the loop body simply never runs.
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [3]:
# Load the train / validation / test splits of inputs (X) and labels (Y).
from data import get_data

X_train, X_val, X_test, Y_train, Y_val, Y_test = get_data(
    "processed/axion2_test/all_jets_point_cloud.npz"
)

# Quick sanity check on the training array shapes.
print(X_train.shape, Y_train.shape)

(240000, 960, 4) (240000, 3)


In [4]:
# Construct the PFN classifier.
from model import PFN

# Layer widths handed to PFN as Phi_sizes and F_sizes; both use the same
# layout (four layers of 128 followed by three of 64).
# NOTE(review): presumably Phi is the per-particle network and F the network
# applied after aggregation — confirm in model.py.
Phi_sizes = [128] * 4 + [64] * 3
F_sizes = list(Phi_sizes)

# Input dimensions come straight from the training array, whose axes are
# unpacked here as (samples, particles, features).
_, n_particles, n_features = X_train.shape
model = PFN(
    n_outputs=Y_train.shape[1],  # one output per label column
    n_particles=n_particles,
    n_features=n_features,
    F_sizes=F_sizes,
    Phi_sizes=Phi_sizes,
)

In [5]:
from collections import defaultdict
# Maps a metric name to the list of its per-epoch values. train_iteration
# appends to these lists, so curves from successive training stages are
# concatenated. NOTE: re-running this cell discards everything accumulated.
history = defaultdict(list)

In [6]:
# Utility functions for training
from train_model import train_model
import yaml

def train_iteration(lr, epochs):
    """Run one training stage on the notebook's global ``model``.

    Calls ``train_model`` with the global training/validation arrays, then
    appends the per-epoch ``loss``, ``val_loss``, ``accuracy`` and
    ``val_accuracy`` values to the global ``history`` mapping so that
    successive stages form one continuous curve.

    Args:
        lr: Learning rate forwarded to ``train_model``.
        epochs: Number of epochs forwarded to ``train_model``.

    Returns:
        Whatever ``train_model`` returned; its ``.history`` attribute is read
        as a mapping from metric name to per-epoch values.

    Raises:
        KeyError: if one of the four tracked metrics is missing from the
            returned ``.history`` mapping.
    """
    tracked_metrics = ("loss", "val_loss", "accuracy", "val_accuracy")
    training_data = (X_train, X_val, Y_train, Y_val)

    result = train_model(model=model, data=training_data, lr=lr, epochs=epochs)

    print(f"\nSaving history...")
    per_epoch = result.history
    for metric in tracked_metrics:
        history[metric].extend(per_epoch[metric])
    return result

def save_model(name):
    """Save the global ``model`` under ``model_dir``, tagged with today's date.

    The destination is ``<model_dir>/<name>_<YYYY-MM-DD>``, where the date is
    taken from the local clock at call time.

    Args:
        name: Prefix for the saved model; an empty string gives a
            destination that starts with ``_``.
    """
    date_tag = dt.datetime.now().strftime("%Y-%m-%d")
    target = f"{model_dir}/{name}_{date_tag}"
    model.save(target)

In [7]:
# Print GPU memory usage before any training cell has run.
print_gpu_memory()

GPU memory usage
  current: 556.75 KB
  peak:    618.25 KB


In [8]:
# Stage 1: train for 45 epochs at lr=2e-4. The timestamp marks when the run
# started. The trailing ';' keeps the returned History object's repr out of
# the cell output.
print(f"=== Training [{dt.datetime.now()}] ===")
train_iteration(lr=2e-4, epochs=45);

=== Training [2023-06-10 00:39:27.080289] ===
Epoch 1/45
Epoch 2/45
Epoch 3/45
Epoch 4/45
Epoch 5/45
Epoch 6/45
Epoch 7/45
Epoch 8/45
Epoch 9/45
Epoch 10/45
Epoch 11/45
Epoch 12/45
Epoch 13/45
Epoch 14/45
Epoch 15/45
Epoch 16/45
Epoch 17/45
Epoch 18/45
Epoch 19/45
Epoch 20/45
Epoch 21/45
Epoch 22/45
Epoch 23/45
Epoch 24/45
Epoch 25/45
Epoch 26/45
Epoch 27/45
Epoch 28/45
Epoch 29/45
Epoch 30/45
Epoch 31/45
Epoch 32/45
Epoch 33/45
Epoch 34/45
Epoch 35/45
Epoch 36/45
Epoch 37/45
Epoch 38/45
Epoch 39/45
Epoch 40/45
Epoch 41/45
Epoch 42/45
Epoch 43/45
Epoch 44/45
Epoch 45/45

Saving history...


<keras.callbacks.History at 0x2ae5ec42ed10>

In [9]:
# Stage 2: continue training the same global model for 45 more epochs with a
# 10x smaller learning rate (2e-5). The trailing ';' keeps the returned
# History object's repr out of the cell output.
print(f"=== Training [{dt.datetime.now()}] ===")
train_iteration(lr=2e-5, epochs=45);

=== Training [2023-06-10 01:13:56.713851] ===
Epoch 1/45
Epoch 2/45
Epoch 3/45
Epoch 4/45
Epoch 5/45
Epoch 6/45
Epoch 7/45
Epoch 8/45
Epoch 9/45
Epoch 10/45
Epoch 11/45
Epoch 12/45
Epoch 13/45
Epoch 14/45
Epoch 15/45
Epoch 16/45
Epoch 17/45
Epoch 18/45
Epoch 19/45
Epoch 20/45
Epoch 21/45
Epoch 22/45
Epoch 23/45
Epoch 24/45
Epoch 25/45
Epoch 26/45
Epoch 27/45
Epoch 28/45
Epoch 29/45
Epoch 30/45
Epoch 31/45
Epoch 32/45
Epoch 33/45
Epoch 34/45
Epoch 35/45
Epoch 36/45
Epoch 37/45
Epoch 38/45
Epoch 39/45
Epoch 40/45
Epoch 41/45
Epoch 42/45
Epoch 43/45
Epoch 44/45
Epoch 45/45

Saving history...


<keras.callbacks.History at 0x2ae5ee1107f0>

In [None]:
# Stage 3: continue training the same global model for 30 more epochs at
# lr=2e-6. The trailing ';' keeps the returned History object's repr out of
# the cell output.
print(f"=== Training [{dt.datetime.now()}] ===")
train_iteration(lr=2e-6, epochs=30);

=== Training [2023-06-10 01:48:28.294255] ===
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30

In [None]:
# Save the trained model; save_model writes to f"{model_dir}/{name}_{date}".
# NOTE(review): with an empty name the destination is just "_<YYYY-MM-DD>",
# so runs of different models on the same day share one path. Consider
# passing a descriptive name.
save_model("")

### Evaluate model

In [None]:
from matplotlib import pyplot as plt

# Plot the curves accumulated in `history` across all training stages:
# loss on the left panel, accuracy on the right.
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
ax1, ax2 = axs

# Labels are attached to each line as it is plotted so a legend entry cannot
# drift away from the curve it describes.
ax1.plot(history["loss"], label="loss")
ax1.plot(history["val_loss"], label="val_loss")
ax1.legend()
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss");

# Fix: the training-accuracy curve was previously labelled "loss".
ax2.plot(history["accuracy"], label="accuracy")
ax2.plot(history["val_accuracy"], label="val_accuracy")
ax2.legend()
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy");

In [None]:
# Test model
from test_model import test_model, plot_cm
# Evaluate the global model on the test split. test_model's result is
# unpacked as (accuracy, cm); cm is presumably a confusion matrix — confirm
# in test_model.py.
test_accuracy, cm = test_model(model, (X_test, Y_test))
print(f"Testing accuracy: {test_accuracy}")

# NOTE(review): the label order below is assumed to match the column order of
# the Y arrays — verify against the data pipeline.
plot_cm(cm, ["pion", "photon", "axion2"])