# Neural network for track classification

Tracks detected in particle colliders like the LHC can originate from various sources, such as:

- Prompt tracks: Directly from the collision point.
- Pile-up tracks: From additional, unwanted collisions.
- Non-prompt tracks: From decays of B and D mesons and other hadrons.

Understanding the origin of these tracks is crucial for many physics analyses and will help in creating a powerful lepton isolation algorithm.
In this notebook, we aim to classify these tracks using machine learning techniques.

**Objective:** Our main objective is to build a classification algorithm that can accurately determine the origin of each track. We will:

- Use track information (impact parameters, transverse momentum, relative momentum w.r.t. closest lepton) to predict their origin.
- Implement a fully connected neural network using PyTorch, a leading library for building deep learning models.

**Tools:** To achieve our goals, we will make use of the following tools and libraries:

- PyTorch: For designing and training the neural network.
- scikit-learn (sklearn): To preprocess the data and for implementing other useful machine learning utilities.
- scikit-plot: For visualizing the results, such as the confusion matrix.
- atlas-ftag-tools Python package: Specifically designed for handling track data in the context of the ATLAS experiment.

Eventually, you can use the output of the network to build a track-based isolation variable (e.g. by computing a per-track discriminant and summing over the discriminants of all tracks inside a cone).
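
A minimal sketch of such an isolation variable follows. It assumes, purely for illustration, that the per-track discriminant is the log-ratio of the prompt probability to the sum of the other two class probabilities, and that the cone is a cut on the lepton-track distance. The function name `track_isolation`, the cone size of 0.3 and the discriminant definition are hypothetical choices, not something fixed by this notebook.

```python
import numpy as np

def track_isolation(probs, dr, cone_size=0.3, eps=1e-9):
    """Sum a per-track discriminant over all tracks inside a cone around the lepton.

    probs: array of shape (n_tracks, 3) with columns (prompt, pile-up, other)
    dr:    array of shape (n_tracks,) with the lepton-track angular distance
    """
    probs = np.asarray(probs, dtype=float)
    dr = np.asarray(dr, dtype=float)
    # hypothetical discriminant: log-ratio of prompt vs. non-prompt probability
    disc = np.log((probs[:, 0] + eps) / (probs[:, 1] + probs[:, 2] + eps))
    # keep only the tracks inside the cone
    in_cone = dr < cone_size
    return float(disc[in_cone].sum())

# three toy tracks: two inside the default cone, one outside
probs = np.array([[0.8, 0.1, 0.1], [0.2, 0.5, 0.3], [0.6, 0.2, 0.2]])
dr = np.array([0.05, 0.2, 0.5])
iso = track_isolation(probs, dr)
```

Other discriminant definitions (for example weighting each track by its transverse momentum) would fit the same pattern; which one works best is something to study with the trained network.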

## Import of python modules

In [None]:
import numpy as np
import pandas as pd
from ftag.hdf5 import H5Reader

import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from sklearn.preprocessing import StandardScaler

## Preparation: Getting the `tracks` dataset

In [None]:
fname = "maya.h5"
reader = H5Reader(fname, batch_size=100, jets_name="muons")

muon_variables = ["pt", "eta", "ptvarcone30Rel", "iffClass"]
track_variables = [
    "pt", "eta", "phi", 
    "ptfrac", "dr_trackjet", "dr_leptontrack", 
    "btagIp_d0", "btagIp_z0SinTheta",
    "btagIp_d0_significance", "btagIp_z0SinTheta_significance", 
    "leptonID",
    "ftagTruthOriginLabel", "ftagTruthTypeLabel", "ftagTruthVertexIndex",
    "valid"
]
data = reader.load({"muons": muon_variables, "muon_tracks": track_variables}, num_jets=3_000) 

In [None]:
# get tracks dataset
tracks = data['muon_tracks']
tracks = tracks.flatten()
tracks = tracks[np.where(tracks["valid"])]
df = pd.DataFrame(tracks)

In [None]:
# create feature matrix X and label vector y
X = df.drop(columns=["valid", "ftagTruthOriginLabel", "ftagTruthTypeLabel", "ftagTruthVertexIndex"])
y = df[["ftagTruthOriginLabel"]]

In [None]:
# look at the feature matrix
X.head()

In [None]:
# look at the target labels
y.head()

## Exploratory Analysis: Correlation of features

Before training neural networks (or any machine learning model), it makes sense to conduct an exploratory analysis of the dataset. Otherwise you just blindly trust a non-deterministic optimisation procedure, which is rarely a wise strategy.

One component of this analysis is examining the correlation between features. It shows which features carry similar information and how they relate to each other. Strongly correlated features may introduce redundancy into your training dataset.
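
As a complement to a heatmap, the strongly correlated pairs can also be listed explicitly. The sketch below uses a toy `DataFrame`; the helper name `correlated_pairs` and the threshold of 0.9 are illustrative assumptions.

```python
import numpy as np
import pandas as pd

def correlated_pairs(frame, threshold=0.9):
    """Return (feature_a, feature_b, correlation) for pairs with |corr| >= threshold."""
    corr = frame.corr()
    cols = corr.columns
    pairs = []
    for i in range(len(cols)):
        for j in range(i + 1, len(cols)):   # upper triangle only, skip the diagonal
            value = corr.iloc[i, j]
            if abs(value) >= threshold:
                pairs.append((cols[i], cols[j], float(value)))
    return pairs

rng = np.random.default_rng(0)
a = rng.normal(size=500)
toy = pd.DataFrame({
    "a": a,
    "a_scaled": 2.0 * a + 0.01 * rng.normal(size=500),  # nearly a copy of "a"
    "b": rng.normal(size=500),                           # independent of "a"
})
pairs = correlated_pairs(toy)
```

Applied to the track features, `correlated_pairs(X)` would return the pairs that are candidates for dropping one of the two features.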

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 7))
corr_X = X.corr()

# plot correlations among variables as heatmap
ax = sns.heatmap(
    corr_X,
    vmin=-1, vmax=1, center=0,
    cmap=sns.diverging_palette(20, 220, n=200),
    square=True,  annot=True, fmt='.2f', ax = ax)
ax.set_title("Correlation")
plt.tight_layout()
plt.show()

## Preprocessing of training data

When working with numerical data in machine learning, it is crucial to bring all features to a common scale. This ensures that no single feature dominates the others during model training, leading to more stable and meaningful results. The `StandardScaler` from `sklearn.preprocessing` is a popular choice for this purpose: it shifts each feature to zero mean and scales it to unit variance.
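
To make the transformation concrete, the following sketch standardises a small toy matrix by hand and compares the result with `StandardScaler`. The array `X_toy` is made up for illustration.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# two features on very different scales
X_toy = np.array([[1.0, 100.0], [2.0, 200.0], [3.0, 600.0]])

# manual standardisation: subtract the column mean, divide by the column standard deviation
mean = X_toy.mean(axis=0)
std = X_toy.std(axis=0)   # population standard deviation (ddof=0), as used by StandardScaler
manual = (X_toy - mean) / std

# the same transformation with scikit-learn
scaled = StandardScaler().fit_transform(X_toy)
```

Both `manual` and `scaled` have zero mean and unit variance in every column.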

In [None]:
scaler = StandardScaler()
scaled_X = scaler.fit_transform(X)

scaled_X

## Dataset for pytorch

We will use PyTorch for training and provide the data as a torch dataset.
To add some functionality, we define our own dataset class based on the torch `Dataset` class. This class bundles all of the steps above (except for the exploratory analysis), so that you only have to create a `TrackDataset` and all the hard work is done for you under the hood.

In [None]:
class TrackDataset(Dataset):
    def __init__(self, file_path="maya.h5", num_jets=3000):
        super().__init__()
        reader = H5Reader(file_path, batch_size=100, jets_name="muons")

        muon_variables = ["pt", "eta", "ptvarcone30Rel", "iffClass"]
        track_variables = [
            "pt", "eta", "phi", 
            "ptfrac", "dr_trackjet", "dr_leptontrack", 
            "btagIp_d0", "btagIp_z0SinTheta",
            "btagIp_d0_significance", "btagIp_z0SinTheta_significance", 
            "leptonID",
            "ftagTruthOriginLabel", "ftagTruthTypeLabel", "ftagTruthVertexIndex",
            "valid"
        ]
        data = reader.load({"muons": muon_variables, "muon_tracks": track_variables}, num_jets=num_jets)
        # get tracks dataset
        tracks = data['muon_tracks']
        tracks = tracks.flatten()
        tracks = tracks[np.where(tracks["valid"])]
        df = pd.DataFrame(tracks)
        df.drop(columns=["valid"], inplace=True)

        # split into classes
        df_prompt = df[df['ftagTruthOriginLabel'] == 2]
        df_pileup = df[df['ftagTruthOriginLabel'] == 0]
        df_nonprompt = df[(df['ftagTruthOriginLabel'] != 0) & (df['ftagTruthOriginLabel'] != 2)]

        # create a balanced dataset
        n_events = min([len(df_prompt), len(df_pileup), len(df_nonprompt)])
        df_prompt = df_prompt.sample(n=n_events, random_state=42)
        df_pileup = df_pileup.sample(n=n_events, random_state=42)
        df_nonprompt = df_nonprompt.sample(n=n_events, random_state=42)

        # merge and shuffle
        df = pd.concat([df_prompt, df_pileup, df_nonprompt])
        df = df.sample(frac=1, random_state=42)
        
        # store the inputs and outputs as values
        X = df.drop(columns=["ftagTruthOriginLabel", "ftagTruthTypeLabel", "ftagTruthVertexIndex"]).values
        # y = df["ftagTruthOriginLabel"].values

        # convert y to a three column dataset for one-hot encoding
        is_prompt = df['ftagTruthOriginLabel'] == 2
        is_pileup = df['ftagTruthOriginLabel'] == 0
        
        df['is_prompt'] = is_prompt.astype(int)
        df['is_pileup'] = is_pileup.astype(int)
        df['is_other'] = (~(is_prompt | is_pileup)).astype(int)

        target_columns = ['is_prompt', 'is_pileup', 'is_other']
        y = df[target_columns].values
        
        # standardise X
        scale = StandardScaler()
        scaled_X = scale.fit_transform(X)
        
        # make tensor
        self.X = torch.tensor(scaled_X, dtype=torch.float32)
        self.y = torch.squeeze(torch.tensor(y, dtype=torch.float32))
            
    def __len__(self):
        return len(self.X)
 
    def __getitem__(self, idx):
        return [self.X[idx], self.y[idx]]
 
    def get_splits(self, n_test=0.1, n_val=0.1):
        # determine sizes
        test_size = round(n_test * len(self.X))
        val_size = round(n_val * len(self.X))
        train_size = len(self.X) - test_size - val_size
        # calculate the split
        return random_split(self, [train_size, val_size, test_size])

To simplify our lives, we now use this class to provide the dataset.

In [None]:
dataset = TrackDataset("maya.h5")

We can now inspect the dataset and observe that it contains the feature matrix and the target labels. Finally we check if the dataset is balanced, i.e. if every class is equally populated (which we have implemented in the dataset).

In [None]:
dataset.X

In [None]:
dataset.y

In [None]:
# check if the dataset really is balanced
dataset.y.sum(dim=0)

With the `TrackDataset` in place, we now define a function to load the train, validation and test datasets.

We will use the validation dataset to monitor the training, while the final evaluation of the performance will be done using the test dataset.
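
The split itself can be sketched on a toy dataset as follows. Passing a seeded `torch.Generator` to `random_split` is an addition made here for reproducibility and is an assumption of this sketch; the toy `TensorDataset` is also made up for illustration.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# toy dataset with 100 samples, one feature each
toy = TensorDataset(torch.arange(100).float().unsqueeze(1), torch.zeros(100))

# 10% test, 10% validation, the rest for training
n_test = round(0.1 * len(toy))
n_val = round(0.1 * len(toy))
n_train = len(toy) - n_test - n_val

# a seeded generator makes the split reproducible between runs
gen = torch.Generator().manual_seed(0)
train, val, test = random_split(toy, [n_train, n_val, n_test], generator=gen)

# the three subsets do not share any sample
train_idx, val_idx, test_idx = set(train.indices), set(val.indices), set(test.indices)
```

Because the subsets are disjoint, the test set stays untouched until the final evaluation.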

In [None]:
import random

# set all the random seeds (NumPy and Python's `random`) inside each DataLoader worker
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

# define function for providing train, validation and test dataloaders
def prepare_data():
    # load the dataset
    dataset = TrackDataset()
    print(len(dataset))  # number of tracks in the balanced dataset

    # calculate split
    train, val, test = dataset.get_splits()

    # prepare data loaders for training
    train_dl = DataLoader(train, batch_size=100, shuffle=True,
                          worker_init_fn=seed_worker, generator=g)
    val_dl = DataLoader(val, batch_size=100, shuffle=False)
    test_dl = DataLoader(test, batch_size=100, shuffle=False,
                         worker_init_fn=seed_worker, generator=g)
    return train_dl,  val_dl, test_dl

In [None]:
train_dl, val_dl, test_dl = prepare_data()

## Defining the machine learning model: fully connected neural network

In this section, we will define the architecture of our machine learning model, which is a fully connected (dense) neural network. Fully connected networks are a type of deep learning model where each neuron in one layer is connected to all neurons in the next layer. This connectivity pattern makes them particularly suited for learning from structured data like the track features in our dataset.

**Model Structure**

Our model, `NNClassifier`, allows for setting the number of hidden layers and the size of each layer when defining it. Here’s a brief outline of its components:

- Input Layer: The first layer in the network, which receives the input features. The number of neurons in this layer matches the dimensionality of our input data (default: 11 features).
- Hidden Layers: These layers allow the network to learn more complex patterns in the data. Each hidden layer is followed by a `ReLU` activation function, which introduces non-linearity into the model, enabling it to capture more complex relationships between the inputs and outputs. The use of multiple hidden layers helps in deep representation learning, which is important for accurately classifying the track data.
- Output Layer: The final layer of our model has three neurons, corresponding to the three classes of track origins (prompt, pile-up, and other). It uses a `Softmax` activation function to output a probability distribution over the three classes. This setup ensures that each output can be interpreted as the probability that a given input belongs to one of the three classes.
- Forward Pass: Defines how the input data flows through these layers. Each layer's output becomes the next layer's input, culminating in the output layer that produces the final prediction probabilities.
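
The structure described above can be sketched in a compact form with `nn.Sequential`. The helper `build_mlp` and the layer sizes are illustrative assumptions; the sketch only shows how the components fit together and that the softmax output is a probability distribution.

```python
import torch
import torch.nn as nn

def build_mlp(layer_sizes, n_classes=3):
    """Stack Linear+ReLU blocks for each (in, out) pair, then a Softmax output layer."""
    layers = []
    for n_in, n_out in layer_sizes:
        layers.append(nn.Linear(n_in, n_out))
        layers.append(nn.ReLU())
    # output layer: one neuron per class, followed by Softmax over the class dimension
    layers.append(nn.Linear(layer_sizes[-1][1], n_classes))
    layers.append(nn.Softmax(dim=1))
    return nn.Sequential(*layers)

toy_model = build_mlp([(11, 20), (20, 20)])
batch = torch.randn(4, 11)   # 4 tracks, 11 features each
probs = toy_model(batch)     # forward pass through all layers
```

Each row of `probs` contains three non-negative values that sum to one, one per track-origin class.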

In [None]:
import torch.nn as nn

In [None]:
class NNClassifier(nn.Module):
    def __init__(self, num_hidden_layers, hidden_layer_size):
        super().__init__()
        self.layers = nn.ModuleList()

        # Assume hidden_layer_size includes input to first hidden layer and all subsequent layer sizes
        # Add the input layer
        self.layers.append(nn.Linear(hidden_layer_size[0][0], hidden_layer_size[0][1]))
        self.layers.append(nn.ReLU())

        # Add hidden layers
        for i in range(1, num_hidden_layers - 1):  # Adjusted to properly index hidden_layer_size
            self.layers.append(nn.Linear(hidden_layer_size[i][0], hidden_layer_size[i][1]))
            self.layers.append(nn.ReLU())

        # Add the output layer for three classes
        self.layers.append(nn.Linear(hidden_layer_size[-2][1], 3))  # Output size set to 3 for three classes
        self.layers.append(nn.Softmax(dim=1))  # Use Softmax for multi-class classification

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [None]:
# define hyperparameters and structure of network
# note: num_hidden_layers must equal len(hidden_layer_size); the last entry describes the output layer
num_hidden_layers = 6
num_features = dataset.X.shape[1] # number of columns in X

# you can try different configurations of the network, e.g. adding or removing layers
hidden_layer_size = [(num_features, 20), (20, 40), (40, 80), (80, 60), (60, 20), (20, 3)]

# this is the network
model = NNClassifier(num_hidden_layers, hidden_layer_size)
print(model)

## Training the model

Training a model in machine learning involves iteratively adjusting the model's parameters to minimize the difference between the predicted output and the actual output.
This process requires not only a suitably defined machine learning model (here: a fully connected neural network) but also data loaders to ensure efficient data flow, as well as metrics for optimizing the parameters and assessing model performance.

**Setting Up the Training Loop**

We structure our training function to handle both training and validation phases within each epoch. This approach allows us to monitor the model's performance on unseen data (validation set) while it learns from the training set. The training loop has the following components:

- Model Transfer to Device: We ensure that the model operates on the correct device (CPU or GPU). If you train on large datasets on a CPU, you will have a bad (and long) time.
- Loss and Accuracy Tracking: We maintain a history dictionary to track training and validation loss and accuracy. This data will help us evaluate the model's learning progression and tune its hyperparameters.
- Batch Processing: The model processes data in batches, which helps to balance computational efficiency and memory usage. Batch processing is also beneficial for gradient descent optimization, as it introduces some noise into the gradient calculations, potentially helping the training avoid getting stuck in local minima.
- Forward and Backward Passes: For each batch, the model performs a forward pass to compute predictions and a backward pass to compute the gradients of the loss function with respect to the weights.
- Optimizer Step: After the backward pass, the optimizer updates the model parameters based on the gradients computed during backpropagation.
- Validation Phase: After updating the model on the entire training data, we evaluate its performance on the validation set without making any further adjustments to the model's parameters.


**Epoch-wise Iteration**

The training process is divided into epochs. An epoch represents one full cycle through the entire training dataset. Here’s what happens during each epoch:

- Training Phase: The model learns from the training data, adjusting its weights to minimize the loss function.
- Validation Phase: The model's performance is assessed on a separate validation dataset that it has not seen during training. This helps estimate how well the model is likely to perform on general, unseen data.

**Performance Metrics**

We calculate and log the average loss and accuracy for both training and validation datasets after each epoch. We also track the time taken to complete the training.
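
The averaging deserves a short illustration: since the last batch of an epoch can be smaller than the others, each batch loss is weighted by its batch size before dividing by the total number of samples. The numbers below are made up for illustration.

```python
import numpy as np

# per-batch mean losses and batch sizes; the last batch is smaller than the others
batch_losses = np.array([0.50, 0.40, 0.10])
batch_sizes = np.array([100, 100, 20])

# accumulate loss * batch_size, then divide by the total number of samples
weighted_avg = (batch_losses * batch_sizes).sum() / batch_sizes.sum()

# a plain mean over batches would over-weight the small final batch
naive_avg = batch_losses.mean()

# accuracy from class probabilities and one-hot labels: compare the argmax of both
probs = np.array([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.4, 0.5, 0.1]])
one_hot = np.array([[1, 0, 0], [0, 0, 1], [1, 0, 0]])
accuracy = (probs.argmax(axis=1) == one_hot.argmax(axis=1)).mean()
```

The weighted average is the true per-sample mean loss of the epoch, whereas the naive mean differs whenever the batch sizes are unequal.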

In [None]:
import time

In [None]:
def train(model, optimizer, loss_fn, train_dl, val_dl, epochs=30, device='cpu'):
    model = model.to(device)
    print(f'train() called: model={type(model).__name__}, optimizer={type(optimizer).__name__}(lr={optimizer.param_groups[0]["lr"]}), epochs={epochs}, device={device}')

    # we will collect information about loss and accuracy on
    # train and validation datasets per epoch in this dictionary
    history = {
        'loss': [],
        'val_loss': [],
        'acc': [],
        'val_acc': []
    }

    # start the stopwatch and begin training loop
    start_time_sec = time.time()
    for epoch in range(1, epochs+1):
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        # iterate over batches in training dataloader
        for x, y in train_dl:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * x.size(0)
            _, predicted = torch.max(outputs, 1)
            _, labels = torch.max(y, 1)
            train_correct += (predicted == labels).sum().item()
            train_total += x.size(0)

        avg_train_loss = train_loss / train_total
        train_acc = train_correct / train_total

        # Validation phase
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0

        with torch.no_grad():
            for x, y in val_dl:
                x, y = x.to(device), y.to(device)
                outputs = model(x)
                loss = loss_fn(outputs, y)
                val_loss += loss.item() * x.size(0)
                _, predicted = torch.max(outputs, 1)
                _, labels = torch.max(y, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += x.size(0)

        avg_val_loss = val_loss / val_total
        val_acc = val_correct / val_total

        # Logging
        history['loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        print(f'Epoch {epoch}/{epochs}, Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}')

    end_time_sec = time.time()
    print('Training completed in: {} seconds'.format(end_time_sec - start_time_sec))

    return history

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 
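# the model already ends in a Softmax layer and therefore outputs probabilities,
# so we use BCELoss (expects probabilities) instead of BCEWithLogitsLoss (expects raw logits)
loss_module = nn.BCELoss()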

In [None]:
history = train(model=model, optimizer=optimizer, loss_fn=loss_module, train_dl=train_dl, val_dl=val_dl, epochs=100)

## Evaluation of trained model

After training a machine learning model, the next step is the evaluation of its performance. The following metrics are considered:

- Accuracy: This is a primary metric for evaluating classification models. It gives us a straightforward percentage of correctly predicted instances out of the total number of instances evaluated.
- Confusion Matrix: Provides a detailed breakdown of the model's performance across all classes, showing true positives, false positives, true negatives, and false negatives for each class. This matrix is particularly useful for identifying which classes are well-predicted by the model and which are often confused with others.
- Receiver Operating Characteristic (ROC) Curve and Area Under Curve (AUC): These are helpful for assessing the performance of a classification model at various threshold settings. The ROC curve plots the true positive rate against the false positive rate at different threshold levels.
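
These three metrics can be computed directly with scikit-learn. The sketch below uses made-up labels and probabilities for six tracks and treats the ROC AUC for class 0 in a one-vs-rest fashion.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score

# true class indices and predicted class probabilities for six toy tracks
y_true = np.array([0, 0, 1, 1, 2, 2])
y_prob = np.array([
    [0.8, 0.1, 0.1],
    [0.4, 0.5, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.6, 0.3],
    [0.1, 0.2, 0.7],
    [0.2, 0.2, 0.6],
])
y_pred = y_prob.argmax(axis=1)   # predicted class = highest probability

acc = accuracy_score(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred)   # rows: true class, columns: predicted class
# one-vs-rest AUC for class 0: how well does the class-0 score separate class 0 from the rest?
auc_class0 = roc_auc_score((y_true == 0).astype(int), y_prob[:, 0])
```

In this toy example one class-0 track is misclassified as class 1, which shows up as an off-diagonal entry in the first row of `cm`, while the class-0 score still ranks both class-0 tracks above all others.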

**Evaluation Function Implementation**

In our evaluation function `eval_model`, the following steps are implemented:

- Model State: The model is set to evaluation mode, which switches layers such as dropout and batch normalization to their inference behaviour (dropout is disabled, batch normalization uses its running statistics), ensuring consistent results across different evaluations.
- Data Handling: Data from the loader is processed batch by batch without any gradient calculations (`torch.no_grad()`), which minimizes memory consumption and computational overhead.
- Accuracy Calculation: As predictions are made, they are compared to the true labels to compute the overall accuracy of the model.
- Confusion Matrix Visualization: Using `scikitplot` (you might need to install this), we visually represent how well the model has predicted across different classes, highlighting potential areas for improvement.
- ROC and AUC Calculation: For each class, the ROC curve is calculated along with the AUC score. AUC provides a scalar value summarizing the overall ability of the model to discriminate between positive and negative classes.
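
The effect of evaluation mode and `torch.no_grad()` can be seen on a tiny toy module. The module with a `Dropout` layer below is only an illustration; the classifier in this notebook does not contain dropout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(2, 4)

layer.eval()            # dropout becomes the identity
with torch.no_grad():   # no computation graph is recorded
    out_a = layer(x)
    out_b = layer(x)

layer.train()           # dropout is active again
out_train = layer(x)    # outside no_grad: gradients are tracked
```

In evaluation mode repeated calls give identical outputs, and tensors produced under `torch.no_grad()` carry no gradient information.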

In [None]:
! pip install scikit-plot

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
from itertools import cycle

def eval_model(model, data_loader, device='cpu', class_count=3):
    model.eval()  # Set model to eval mode
    true_preds, num_preds = 0, 0
    pred_list = []
    target_list = []
    outputs_list = []

    with torch.no_grad():
        for data_inputs, data_target in data_loader:
            inputs = data_inputs.to(device)
            targets = data_target.to(device)

            # Get model predictions
            outputs = model(inputs)
            _, pred_labels = torch.max(outputs, 1)
            _, true_labels = torch.max(targets, 1)

            # Collect for analysis
            pred_list.extend(pred_labels.cpu().numpy())
            target_list.extend(true_labels.cpu().numpy())
            outputs_list.extend(outputs.cpu().numpy())

            # Evaluate accuracy
            true_preds += (pred_labels == true_labels).sum().item()
            num_preds += targets.size(0)

    # Calculate accuracy
    accuracy = true_preds / num_preds * 100
    print(f"Accuracy of the model: {accuracy:.2f}%")

    # Generate confusion matrix plot
    skplt.metrics.plot_confusion_matrix(target_list, pred_list, figsize=(8, 8))
    plt.title("Confusion Matrix")
    plt.show()

    # Prepare data for ROC calculation
    target_array = label_binarize(target_list, classes=list(range(class_count)))
    # the model already applies Softmax, so the outputs are probabilities and need no further softmax
    outputs_array = np.array(outputs_list)

    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    colors = cycle(['blue', 'red', 'green', 'cyan', 'magenta', 'yellow', 'black'])
    plt.figure(figsize=(10, 8))

    for i, color in zip(range(class_count), colors):
        fpr[i], tpr[i], _ = roc_curve(target_array[:, i], outputs_array[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
        plt.plot(fpr[i], tpr[i], color=color, lw=2, label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))

    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Multi-class Receiver Operating Characteristic')
    plt.legend(loc="lower right")
    plt.show()

    return accuracy


In [None]:
eval_model(model, data_loader=test_dl)