In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import sys
from torch_geometric.data import DataLoader, Dataset, Data
import lightning.pytorch as pl
import seaborn as sns
import pandas as pd
import os
from tqdm import tqdm
import torch
import itertools
import yaml
from pytorch_lightning.loggers import WandbLogger

import matplotlib.pyplot as plt

from epic_clustering.utils import plot_clusters, get_cluster_pos
from epic_clustering.models import MemberClassification

## 1. Load Data and Test Data Loading

In [None]:
# Directory holding the training CSV files. Set the EPIC_TRAIN_DIR environment
# variable to point somewhere else without editing the notebook.
input_dir = os.environ.get("EPIC_TRAIN_DIR", "/global/cfs/cdirs/m3443/data/PowerWeek/train/train")
# os.listdir returns entries in arbitrary order; sorting makes csv_files[0]
# (and therefore the event inspected below) reproducible between runs.
csv_files = sorted(os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.csv'))

In [None]:
events = pd.read_csv(csv_files[0])
# Select the rows of the first event id (in order of appearance) in this file.
# NOTE(review): EventDataset groups events by the `entry` column instead of
# `event` — confirm which column identifies an event.
first_event_id = events["event"].unique()[0]
event = events.loc[events["event"] == first_event_id]

In [None]:
# Preview the selected event rather than rendering the whole frame;
# use event.shape for its size.
event.head()

In [None]:
# Number of highest-E hits kept for pairing.
N_TOP_HITS = 40
# Rows of `event` ordered by descending E, truncated to the top N_TOP_HITS.
high_energy_hits = event.sort_values("E", ascending=False).head(N_TOP_HITS)

In [None]:
# Pair every high-energy hit with every hit in the event. indexing="ij" is
# torch's legacy default; passing it explicitly silences the meshgrid warning.
pairs = torch.meshgrid(
    torch.from_numpy(high_energy_hits.hit_number.values),
    torch.from_numpy(event.hit_number.values),
    indexing="ij",
)
# Flatten into a 2 x N array: row 0 is the high-energy hit, row 1 the partner
# hit — the same (2, N) layout EventDataset.create_training_pairs returns.
pairs = torch.stack(pairs).reshape(2, -1)

In [None]:
class EventDataset(Dataset):
    """
    In-memory PyG dataset built from CSV event files on disk.

    Each event becomes one ``Data`` object whose rows are candidate hit pairs.
    Each of the event's ``num_seed_hits`` highest-E hits is paired with every
    hit in the event. ``x`` holds the concatenated features of the two hits in
    a pair. ``y`` is True when both hits share a ``clusterID``.

    Args:
        input_dir: Directory containing the ``.csv`` event files.
        num_events: Keep only events with ``entry < num_events``. ``None``
            loads every CSV file in ``input_dir``.
        hparams: Optional hyper-parameters, stored on ``self.hparams``.
        transform, pre_transform, pre_filter: Forwarded to the PyG ``Dataset``.
        num_seed_hits: Number of highest-E hits per event that are paired with
            all hits (default 40).
        **kwargs: Accepted for call compatibility and ignored.
    """

    # Class-level default so create_training_pairs also works on instances
    # that bypassed __init__.
    num_seed_hits = 40

    def __init__(self, input_dir, num_events = None, hparams=None, transform=None, pre_transform=None, pre_filter=None, num_seed_hits=40, **kwargs):
        super().__init__(input_dir, transform, pre_transform, pre_filter)

        self.input_dir = input_dir
        self.hparams = hparams
        self.num_events = num_events
        self.num_seed_hits = num_seed_hits
        # Per-column divisors applied by scale_features.
        self.scales = {
                    "E": 30.,
                    "T": 100.,
                    "posx": 200.,
                    "posy": 200.,
                    "posz": 500.,
                }

        self.csv_events = self.load_datafiles_in_dir(self.input_dir, self.num_events)

        print("Converting to PyG data objects")
        # csv_events holds (entry, DataFrame) tuples from groupby.
        self.pyg_events = [self.convert_to_pyg(event) for _, event in tqdm(self.csv_events)]

    def load_datafiles_in_dir(self, input_dir, num_events):
        """
        Read CSV files from ``input_dir``, scale features, and group rows by ``entry``.

        Files are read in sorted name order so the selection is reproducible
        (``os.listdir`` order is arbitrary). When ``num_events`` is given, only
        the first ``num_events // 1000 + 1`` files are read and rows with
        ``entry >= num_events`` are dropped.

        Returns:
            List of ``(entry, DataFrame)`` tuples, one per event.

        Raises:
            ValueError: if no ``.csv`` file is selected.
        """
        csv_files = sorted(
            os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.csv')
        )
        if num_events is not None:
            # Each file is 1000 events, so need to load num_events//1000 + 1 files
            csv_files = csv_files[:num_events // 1000 + 1]
        if not csv_files:
            raise ValueError(f"No .csv files found in {input_dir!r}")

        events = pd.concat([pd.read_csv(f) for f in csv_files], ignore_index=True)
        if num_events is not None:
            # .copy() so the in-place scaling below writes to an independent
            # frame, not a slice of the concatenated data (SettingWithCopyWarning).
            events = events[events.entry < num_events].copy()

        events = self.scale_features(events)

        return list(events.groupby('entry'))

    def convert_to_pyg(self, event):
        """
        Build a PyG ``Data`` object of candidate hit pairs for one event.

        Args:
            event: DataFrame of one event's hits. It must have the columns
                ``entry``, ``clusterID``, ``posx``, ``posy``, ``posz`` and ``E``.

        Returns:
            ``Data`` with two fields:
            - ``x`` of shape ``(num_pairs, 8)``: the first hit's
              ``[posx, posy, posz, E]`` followed by the second hit's.
            - boolean ``y`` of shape ``(num_pairs,)``.
            ``num_nodes`` equals ``num_pairs`` because each pair is one sample.
        """
        event = event.reset_index(drop=True).drop(columns=['entry'])

        # (2, num_pairs) row positions into `event`.
        edge_index = self.create_training_pairs(event)
        src, dst = edge_index.numpy()

        # Compare plain arrays. Indexing the Series twice yields differently
        # labelled Series, and comparing those raises
        # "Can only compare identically-labeled Series objects".
        cluster_ids = event['clusterID'].to_numpy()
        y = torch.from_numpy(cluster_ids[src] == cluster_ids[dst])

        # NOTE(review): the tensor dtype follows the columns' numpy dtype
        # (float64 if they are float64); cast with .float() if the model expects
        # float32 — confirm.
        node_features = torch.from_numpy(event[['posx', 'posy', 'posz', 'E']].to_numpy())
        edge_features = torch.cat([node_features[edge_index[0]], node_features[edge_index[1]]], dim=1)

        data = Data(x=edge_features, y=y)
        data.num_nodes = data.x.shape[0]

        return data

    def len(self):
        """Number of events held in memory."""
        return len(self.pyg_events)

    def get(self, idx):
        """Return the pre-built ``Data`` object for event ``idx``."""
        return self.pyg_events[idx]

    def scale_features(self, event):
        """
        Divide every column listed in ``self.scales`` by its divisor.

        The frame is modified in place and also returned.
        """
        for feature in self.scales.keys():
            event[feature] = event[feature]/self.scales[feature]

        return event

    def create_training_pairs(self, event):
        """
        Pair each of the event's ``self.num_seed_hits`` highest-E hits with
        every hit in the event.

        Args:
            event: DataFrame of one event's hits with an ``E`` column.

        Returns:
            Integer tensor of shape ``(2, k * n_hits)``, where
            ``k = min(num_seed_hits, n_hits)``. Row 0 holds the high-E hit and
            row 1 its partner, both as row positions into ``event``.
        """
        # Use row positions rather than hit_number values, because the caller
        # uses these numbers to index rows of `event`.
        hits = event.reset_index(drop=True)
        high_energy_hits = hits.sort_values(by="E", ascending=False).iloc[:self.num_seed_hits]

        seeds = torch.from_numpy(high_energy_hits.index.to_numpy())
        all_hits = torch.arange(len(hits))
        # indexing="ij" is torch's legacy default; passing it explicitly
        # silences the meshgrid warning.
        pairs = torch.meshgrid(seeds, all_hits, indexing="ij")

        # Flatten the (k, n_hits) grids into a 2 x N array.
        return torch.stack(pairs).reshape(2, -1)

In [None]:
# Small subset of events for a quick data-loading check.
NUM_EVENTS = 100
dataset = EventDataset(input_dir=input_dir, num_events=NUM_EVENTS)

In [None]:
# Split dataset into train (80%), val (10%) and test (the remainder)
from torch.utils.data import random_split

# Fixed seed so the split is identical after a kernel restart.
SPLIT_SEED = 42

train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

## 2. Training Loop Test

In [2]:
# Build the model from its YAML config, then run its "fit" setup
# (the cell output shows it loading train/val/test events).
config_path = "member_classification.yaml"
with open(config_path) as config_file:
    member_classification_config = yaml.safe_load(config_file)

model = MemberClassification(member_classification_config)
model.setup(stage="fit")

Converting to PyG data objects


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 220/220 [00:02<00:00, 84.63it/s] 

Loaded 200 training events, 10 validation events and 10 testing events





In [3]:
# Take the logger from the same `lightning.pytorch` namespace as the Trainer.
# The top-level import pulls WandbLogger from the separate `pytorch_lightning`
# package, and mixing the two packages can break type checks between them.
logger = pl.loggers.WandbLogger(project=member_classification_config["project"])
trainer = pl.Trainer(devices=1, accelerator="gpu", max_epochs=100, logger=logger)
trainer.fit(model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m. Use [1m`wandb login --relogin`[0m to force relogin


  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | network | Sequential | 793 K 
---------------------------------------
793 K     Trainable params
0         Non-trainable params
793 K     Total params
3.172     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 200/200 [00:07<00:00, 27.67it/s, v_num=3s7c]



Epoch 3: 100%|██████████| 200/200 [00:07<00:00, 26.94it/s, v_num=3s7c]