## Paper: Particle Transformer for Jet Tagging — https://arxiv.org/abs/2202.03772v2

In [2]:
import sys
import os
from tqdm import tqdm
import numpy as np
# load torch modules

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
import glob

In [3]:
sys.path.insert(0, "../../")
from models.ParT.ParticleTransformerEncoder import ParTEncoder
from models.ParT.utils import calculate_cartesian_components, generate_mask

In [4]:
# Use the first CUDA device when one is available (listing every visible GPU);
# otherwise fall back to the CPU.
world_size = torch.cuda.device_count()
if world_size:
    device = torch.device("cuda:0")
    for i in range(world_size):
        print(f"Device {i}: {torch.cuda.get_device_name(i)}", flush=True)
else:
    device = torch.device("cpu")
    # Print to stdout like the GPU branch; this previously referenced an
    # undefined `logfile`, which raised NameError on CPU-only machines.
    print("Device: CPU", flush=True)

Device 0: NVIDIA L4


# Load some data

In [5]:
def load_data(dataset_path, flag, n_files=-1):
    """Load the raw 6-feature jet arrays for one dataset split.

    Files are read from ``{dataset_path}/{flag}/processed/6_features_raw/data/``
    in index order (``data_0.npy``, ``data_1.npy``, ...).

    Args:
        dataset_path: Root directory of the dataset.
        flag: Split sub-directory name (e.g. ``"train"``).
        n_files: Maximum number of files to load. ``-1`` (or any negative
            value) loads every ``data_*.npy`` file found; ``0`` loads none.

    Returns:
        List of numpy arrays, one per loaded file, in index order.
    """
    data_dir = f"{dataset_path}/{flag}/processed/6_features_raw/data"
    # Identifier used only in the progress messages, e.g. "train-6_features".
    path_id = f"{flag}-6_features"
    # Count only files that follow the data_{i}.npy naming scheme, so a stray
    # file in the directory does not make us look for a non-existent index.
    n_available = len(glob.glob(f"{data_dir}/data_*.npy"))
    n_load = n_available if n_files < 0 else min(n_files, n_available)

    data = []
    for i in range(n_load):
        data.append(np.load(f"{data_dir}/data_{i}.npy"))
        print(f"--- loaded file {i} from `{path_id}` directory")
    return data


def load_labels(dataset_path, flag, n_files=-1):
    """Load the label arrays for one dataset split.

    Files are read from ``{dataset_path}/{flag}/processed/6_features_raw/labels/``
    in index order (``labels_0.npy``, ``labels_1.npy``, ...).

    Args:
        dataset_path: Root directory of the dataset.
        flag: Split sub-directory name (e.g. ``"train"``).
        n_files: Maximum number of files to load. ``-1`` (or any negative
            value) loads every ``labels_*.npy`` file found; ``0`` loads none.

    Returns:
        List of numpy arrays, one per loaded file, in index order.
    """
    labels_dir = f"{dataset_path}/{flag}/processed/6_features_raw/labels"
    # Count only files that follow the labels_{i}.npy naming scheme, so a stray
    # file in the directory does not make us look for a non-existent index.
    n_available = len(glob.glob(f"{labels_dir}/labels_*.npy"))
    n_load = n_available if n_files < 0 else min(n_files, n_available)

    labels = []
    for i in range(n_load):
        labels.append(np.load(f"{labels_dir}/labels_{i}.npy"))
        print(f"--- loaded label file {i} from `{flag}` directory")
    return labels

In [34]:
# Smoke-test subset: a single file from the training split.
data = load_data("/ssl-jet-vol-v3/toptagging", "train", 1)
labels = load_labels("/ssl-jet-vol-v3/toptagging", "train", 1)

# Keep only a handful of jets for the forward-pass check below
# (do not use more than 1000, as memory usage will blow up).
num_jets = 10
tr_dat_in = torch.from_numpy(np.concatenate(data, axis=0)[:num_jets]).to(device)
tr_lab_in = torch.from_numpy(np.concatenate(labels, axis=0)[:num_jets]).to(device)
print(f"training data shape: {tr_dat_in.shape}")
print(f"training labels shape: {tr_lab_in.shape}")

--- loaded file 0 from `train-6_features` directory
--- loaded label file 0 from `train` directory
training data shape: torch.Size([10, 6, 50])
training labels shape: torch.Size([10])


# Forward pass

In [32]:
# Extra encoder inputs derived from the raw features: the Cartesian
# components (expected shape (batch_size, 4, 50)) and the particle mask
# (expected shape (batch_size, 1, 50)).
v = calculate_cartesian_components(tr_dat_in).to(device)
print(f"shape of v: {v.shape}")
mask = generate_mask(tr_dat_in)
print(f"shape of mask: {mask.shape}")

shape of v: torch.Size([10, 4, 50])
shape of mask: torch.Size([10, 1, 50])


In [33]:
# The last embedding dimension is also the latent-space dimension.
output_dim = 128
encoder = ParTEncoder(input_dim=6, embed_dims=[128, 512, output_dim]).to(device)
print(encoder)
# This pass is only a shape check, so skip building the autograd graph;
# that keeps memory use down for the demo batch.
# NOTE(review): the encoder is still in train mode here, so its BatchNorm
# layers update their running statistics from this batch; call
# encoder.eval() first if that should be avoided.
with torch.no_grad():
    reps = encoder(tr_dat_in.to(torch.float32), v.to(torch.float32), mask)
print("shape of latent space:", reps.shape)  # Should print (batch_size, output_dim)

ParticleTransformerEncoder(
  (embed): Embed(
    (input_bn): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (embed): Sequential(
      (0): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
      (1): Linear(in_features=6, out_features=128, bias=True)
      (2): GELU(approximate='none')
      (3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (4): Linear(in_features=128, out_features=512, bias=True)
      (5): GELU(approximate='none')
      (6): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (7): Linear(in_features=512, out_features=128, bias=True)
      (8): GELU(approximate='none')
    )
  )
  (pair_embed): PairEmbed(
    (embed): Sequential(
      (0): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): GELU(approximate='none')
      (4

# Training ParticleTransformerEncoder with a linear classification layer

In [35]:
# Re-slice the already-loaded arrays with a much larger subset; this is
# affordable here because training processes small mini-batches.
num_jets = 10000
tr_dat_in, tr_lab_in = (
    torch.from_numpy(np.concatenate(arrays, axis=0))[:num_jets].to(device)
    for arrays in (data, labels)
)
print(f"training data shape: {tr_dat_in.shape}")
print(f"training labels shape: {tr_lab_in.shape}")

training data shape: torch.Size([10000, 6, 50])
training labels shape: torch.Size([10000])


In [27]:
from sklearn.metrics import accuracy_score
import warnings
# NOTE: no category is given, so this silences every warning process-wide.
warnings.simplefilter("ignore")


def Projector(mlp, embedding):
    """Build the classification head placed on top of the encoder.

    Args:
        mlp: Layer widths after the input, either an int (a single output
            layer, e.g. ``2``) or a dash-separated string (e.g. ``"256-2"``).
        embedding: Input width, i.e. the encoder's latent dimension.

    Returns:
        ``nn.Sequential`` made of one ``Linear -> BatchNorm1d -> ReLU`` group
        per hidden width, followed by a bias-free ``Linear`` output layer.
    """
    widths = [int(w) for w in f"{embedding}-{mlp}".split("-")]
    modules = []
    # Hidden layers: every consecutive width pair except the final one.
    for d_in, d_out in zip(widths[:-2], widths[1:-1]):
        modules += [nn.Linear(d_in, d_out), nn.BatchNorm1d(d_out), nn.ReLU()]
    modules.append(nn.Linear(widths[-2], widths[-1], bias=False))
    return nn.Sequential(*modules)
proj = Projector(2, output_dim)
# Move both modules to the target device before building the optimizer, so it
# is constructed over the final device-resident parameters.
proj.to(device)
encoder.to(device)

# The encoder is updated very gently (lr 1e-6) while the freshly initialised
# head uses the default lr of 1e-4.
optimizer = optim.Adam(
    [{"params": proj.parameters()}, {"params": encoder.parameters(), "lr": 1e-6}],
    lr=1e-4,
)
loss = nn.CrossEntropyLoss(reduction="mean")
softmax = torch.nn.Softmax(dim=1)
batch_size = 32
encoder.train()
proj.train()
num_epochs = 30
for epoch in range(num_epochs):
    # Draw a fresh permutation every epoch. Reusing one fixed split would
    # show the model identical mini-batches in identical order each epoch.
    indices_list = torch.split(torch.randperm(tr_dat_in.shape[0]), batch_size)
    loss_e = []
    predicted_e, correct_e = [], []
    for indices in indices_list:
        optimizer.zero_grad()
        x = tr_dat_in[indices, :, :].to(device)
        v = calculate_cartesian_components(x).to(device)
        mask = generate_mask(x)
        # CrossEntropyLoss expects integer class indices. Cast directly
        # instead of going through the legacy torch.Tensor(...) constructor.
        y = tr_lab_in[indices].to(device).long()
        reps = encoder(x.to(torch.float32), v.to(torch.float32), mask)
        out = proj(reps)
        batch_loss = loss(out, y)
        batch_loss.backward()
        optimizer.step()
        loss_e.append(batch_loss.detach().cpu().item())
        # Keep this batch's class probabilities and targets for the
        # per-epoch accuracy below.
        predicted_e.append(softmax(out.detach()).cpu().numpy())
        correct_e.append(y.cpu().numpy())
    print(f"Epoch {epoch} loss {np.mean(loss_e)}")
    predicted = np.concatenate(predicted_e)
    target = np.concatenate(correct_e)

    # Accuracy on this epoch's training batches (model in train mode), with
    # the class-1 probability thresholded at 0.5.
    accuracy = accuracy_score(target, predicted[:, 1] > 0.5)
    print(
        "epoch: " + str(epoch) + ", accuracy: " + str(round(accuracy, 5)),
    )

Epoch 0 loss 0.6696202198918254
epoch: 0, accuracy: 0.5833
Epoch 1 loss 0.6143759173897508
epoch: 1, accuracy: 0.672
Epoch 2 loss 0.5357150441160599
epoch: 2, accuracy: 0.7496
Epoch 3 loss 0.4497538974966866
epoch: 3, accuracy: 0.8033
Epoch 4 loss 0.3990943685126381
epoch: 4, accuracy: 0.8301
Epoch 5 loss 0.37819552007384194
epoch: 5, accuracy: 0.8429
Epoch 6 loss 0.36639379283871515
epoch: 6, accuracy: 0.8459
Epoch 7 loss 0.3550738359030824
epoch: 7, accuracy: 0.8541
Epoch 8 loss 0.3476545858783082
epoch: 8, accuracy: 0.8549
Epoch 9 loss 0.3426943522529861
epoch: 9, accuracy: 0.86
Epoch 10 loss 0.33955272008626225
epoch: 10, accuracy: 0.8599
Epoch 11 loss 0.3362997557027652
epoch: 11, accuracy: 0.8621
Epoch 12 loss 0.3311624353210004
epoch: 12, accuracy: 0.8643
Epoch 13 loss 0.3287777806432864
epoch: 13, accuracy: 0.8627
Epoch 14 loss 0.3252396984888723
epoch: 14, accuracy: 0.8648
Epoch 15 loss 0.32175163988964245
epoch: 15, accuracy: 0.8675
Epoch 16 loss 0.31911054448769116
epoch: 16