In [None]:
# Download the data 
!mkdir graph_data
!wget https://www.physi.uni-heidelberg.de/~dittmeier/pytorch/graph_data/graphs_pT2GeV.zip
#wget https://www.physi.uni-heidelberg.de/~dittmeier/pytorch/graph_data/graphs_no_pTCut.zip # uncomment this line to download the data without pT cut

In [None]:
!unzip graphs_pT2GeV.zip -d graph_data
#!unzip graphs_no_pTCut.zip -d graph_data   # uncomment this line to unzip the data without pT cut

In [7]:
import os

# Default location: the archive unpacked into ./graph_data by the cells above.
data_dir = 'graph_data/pTge2GeV'
#data_dir = 'graph_data/nopTCut/' # uncomment this line to use the data without pT cut
## for local use: a pre-existing copy of the data set takes precedence, but
## only on machines where that directory actually exists.
local_data_dir = '/mnt/data0/Trackml_dataset_100_events/seb/metric_learning/pTge2GeV'
if os.path.isdir(local_data_dir):
    data_dir = local_data_dir

In [2]:
# Install the PyG packages this notebook relies on.
import os
import torch
# Export the installed torch version as the TORCH environment variable; the
# pip commands below reference it as ${TORCH} to select the wheel index page.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
# torch-scatter and torch-sparse are looked up on the wheel page for this exact
# torch version so that the binaries match the installed PyTorch.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# NOTE(review): this installs PyG from the repository head without a version
# pin, so results may change between runs of this cell.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

# Plotting setup: inline figures plus the graph/plot libraries.
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt

2.1.0+cu121


In [226]:
from torch_geometric.data import Data

# Path of one saved training event.
file_path = f'{data_dir}/trainset/event000021000.pyg'

# Load the event onto the CPU, the same way GraphDataset.get does, so the cell
# also works when the file was written from a GPU process.
# NOTE: torch.load unpickles the file; only open files from a trusted source.
data = torch.load(file_path, map_location=torch.device("cpu"))

# Show the object summary followed by a selection of its attributes.
print(data)
for attribute in (
    "num_nodes",
    "num_edges",
    "edge_index",
    "y",
    "truth_map",
    "track_edges",
    "x",
    "r",
    "phi",
    "z",
):
    print(getattr(data, attribute))

DataBatch(leta=[2718], module_index=[2718], lx=[2718], region=[2718], lphi=[2718], cell_val=[2718], weight=[2718], lz=[2718], phi=[2718], ly=[2718], x=[2718], hit_id=[2718], eta=[2718], geta=[2718], y=[7649], r=[2718], cell_count=[2718], z=[2718], gphi=[2718], track_edges=[2, 2477], particle_id=[2477], radius=[2477], pt=[2477], nhits=[2477], config=[2], event_id=[1], num_nodes=2718, batch=[2718], ptr=[2], edge_index=[2, 7649], truth_map=[2477])
tensor(2718)
7649
tensor([[  34,   35,    0,  ..., 2710, 2711, 2715],
        [   0,    0,   79,  ..., 2712, 2712, 2716]])
tensor([ True, False, False,  ..., False, False,  True])
tensor([   6,   12,   37,  ..., 7641, 7643, 7648])
tensor([[   2,    5,   10,  ..., 2706, 2709, 2716],
        [   7,    3,   11,  ..., 2707, 2711, 2715]])
tensor([-109.3260,  -63.7417,  -42.1189,  ..., -307.3140, -307.0520,
        -891.4610], dtype=torch.float64)
tensor([109.3261,  87.2665,  68.3407,  ..., 798.6319, 798.4037, 973.5359],
       dtype=torch.float64)
te

In [227]:
from utils import eval_utils
from utils.version_utils import get_pyg_data_keys
from utils import (
    load_datafiles_in_dir,
    run_data_tests,
    handle_weighting,
    handle_hard_cuts,
    remap_from_mask,
    handle_edge_features,
    get_optimizers,
    get_condition_lambda,
)

# Pick the compute device: GPU index 7 when the machine has that many GPUs,
# otherwise the first GPU, otherwise the CPU. Checking device_count() avoids
# naming a CUDA device that does not exist on smaller machines.
PREFERRED_GPU_INDEX = 7
if torch.cuda.is_available():
    if torch.cuda.device_count() > PREFERRED_GPU_INDEX:
        device = torch.device(f"cuda:{PREFERRED_GPU_INDEX}")
    else:
        device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [228]:
from torch_geometric.data import Dataset

class GraphDataset(Dataset):
    """
    Dataset that loads one saved graph (event) per file from ``input_dir``.

    Files are read lazily in :meth:`get`. Unless ``preprocess`` is False, every
    event gets a ``weights`` tensor and has its node features rescaled as
    configured in ``hparams``.

    Args:
        input_dir: Directory passed to ``load_datafiles_in_dir`` and used as
            the root of the underlying PyG ``Dataset``.
        data_name: Forwarded to ``load_datafiles_in_dir``.
        num_events: Forwarded to ``load_datafiles_in_dir``.
        stage: Stored on the instance; not read by this class.
        hparams: Optional dict. Keys read here are ``"weighting"``,
            ``"node_features"`` and ``"node_scales"``.
        transform: Optional PyG transform, applied by the base class when an
            item is fetched by index.
        pre_transform: Forwarded to the PyG ``Dataset`` constructor.
        pre_filter: Forwarded to the PyG ``Dataset`` constructor.
        preprocess: When False, :meth:`get` returns the event as loaded.
    """

    def __init__(
        self,
        input_dir,
        data_name=None,
        num_events=None,
        stage="fit",
        hparams=None,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        preprocess=True,
    ):
        if hparams is None:
            hparams = {}
        super().__init__(input_dir, transform, pre_transform, pre_filter)

        self.input_dir = input_dir
        self.data_name = data_name
        self.hparams = hparams
        self.num_events = num_events
        self.stage = stage
        self.preprocess = preprocess
        self.transform = transform

        self.input_paths = load_datafiles_in_dir(
            self.input_dir, self.data_name, self.num_events
        )
        self.input_paths.sort()  # We sort here for reproducibility

    def len(self):
        """Return the number of event files found."""
        return len(self.input_paths)

    def get(self, idx):
        """
        Load event ``idx`` from disk and, if enabled, preprocess it.

        The user ``transform`` is intentionally not applied here: the PyG base
        class calls ``get`` from ``__getitem__`` and then applies
        ``self.transform`` itself, so applying it in this method as well would
        run the transform twice on every indexed access.
        """
        event_path = self.input_paths[idx]
        event = torch.load(event_path, map_location=torch.device("cpu"))
        # convert DataBatch to Data instance because some transformations don't work on DataBatch
        event = Data(**event.to_dict())
        if self.preprocess:
            event = self.preprocess_event(event)
        return event

    def preprocess_event(self, event):
        """
        Process event before it is used in training and validation loops:
        attach the weights, then rescale the node features.
        """
        event = self.construct_weighting(event)
        event = self.scale_features(event)
        return event

    def construct_weighting(self, event):
        """
        Attach ``event.weights``.

        With a ``"weighting"`` entry in ``hparams`` the weights come from
        ``handle_weighting``; otherwise every entry of ``event.y`` gets weight 1.

        Raises:
            AssertionError: if the number of labels differs from the number of
                edges, or if the weighting config is not a non-empty list of
                dictionaries.
        """

        assert event.y.shape[0] == event.edge_index.shape[1], (
            f"Input graph has {event.edge_index.shape[1]} edges, but"
            f" {event.y.shape[0]} truth labels"
        )

        if self.hparams is not None and "weighting" in self.hparams.keys():
            weighting = self.hparams["weighting"]
            # `and` short-circuits, so a config of the wrong type or an empty
            # list fails with the message below instead of an indexing error.
            assert (
                isinstance(weighting, list)
                and len(weighting) > 0
                and all(isinstance(entry, dict) for entry in weighting)
            ), "Weighting must be a list of dictionaries"
            event.weights = handle_weighting(event, weighting)
        else:
            event.weights = torch.ones_like(event.y, dtype=torch.float32)

        return event

    def scale_features(self, event):
        """
        Divide each feature named in ``hparams["node_features"]`` by the entry
        of ``hparams["node_scales"]`` at the same position. Does nothing unless
        both keys are present.
        """

        if (
            self.hparams is not None
            and "node_scales" in self.hparams.keys()
            and "node_features" in self.hparams.keys()
        ):
            assert isinstance(
                self.hparams["node_scales"], list
            ), "Feature scaling must be a list of ints or floats"
            for i, feature in enumerate(self.hparams["node_features"]):
                assert feature in get_pyg_data_keys(
                    event
                ), f"Feature {feature} not found in event"
                event[feature] = event[feature] / self.hparams["node_scales"][i]

        return event

    def apply_score_cut(self, event, score_cut):
        """
        Apply a score cut to the event. This is used for the evaluation stage.

        Every tensor attribute whose first or last dimension equals the number
        of edges is masked along its last dimension with
        ``event.scores >= score_cut``; ``remap_from_mask`` is then called with
        the same mask.
        """
        passing_edges_mask = event.scores >= score_cut
        num_edges = event.edge_index.shape[1]
        for key in get_pyg_data_keys(event):
            if (
                isinstance(event[key], torch.Tensor)
                and event[key].shape
                and (
                    event[key].shape[0] == num_edges
                    or event[key].shape[-1] == num_edges
                )
            ):
                # NOTE(review): the mask is always applied to the LAST axis,
                # also for tensors matched through their first axis.
                event[key] = event[key][..., passing_edges_mask]

        remap_from_mask(event, passing_edges_mask)
        return event

    def get_y_node(self, event):
        """
        Attach ``event.y_node``: one entry per element of ``event.z``, set to 1
        for every node index that occurs in ``event.track_edges`` and 0
        otherwise.
        """
        y_node = torch.zeros(event.z.size(0))
        y_node[event.track_edges.view(-1)] = 1
        event.y_node = y_node
        return event

In [229]:
# Dataset configuration: which node attributes are rescaled and the constant
# each of them is divided by (same order in both lists).
parameters = dict(
    node_features=["r", "phi", "z"],
    node_scales=[1000, 3.14, 1000],
    # a "weighting" entry can be added here to experiment with edge weights
)

# One dataset per split, all sharing the same configuration.
trainset, valset, testset = (
    GraphDataset(f"{data_dir}/{split}", hparams=parameters)
    for split in ("trainset", "valset", "testset")
)

print("Number of samples in trainset:", len(trainset))
print("Number of samples in valset:", len(valset))
print("Number of samples in testset:", len(testset))


Number of samples in trainset: 80
Number of samples in valset: 10
Number of samples in testset: 10


In [230]:
from torch_geometric.loader import DataLoader

# Loader settings shared by the three splits.
batch_size = 1
num_workers = 1
loader_options = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers)

# One loader per split.
train_loader = DataLoader(trainset, **loader_options)
val_loader = DataLoader(valset, **loader_options)
test_loader = DataLoader(testset, **loader_options)

# Look at the first training batch only: scaled r values plus the shapes of
# the labels and of the weights.
for test_event in train_loader:
    for shown in (test_event.r, test_event.y.shape, test_event.weights.shape):
        print(shown)
    break

tensor([0.1093, 0.0873, 0.0683,  ..., 0.7986, 0.7984, 0.9735],
       dtype=torch.float64)
torch.Size([7649])
torch.Size([7649])


In [231]:
from torch_geometric.nn import MessagePassing
import torch
import torch.nn as nn

# https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html

class InteractionConv(MessagePassing):
    """
    One message-passing step that updates node features and then edge features.

    Node update: the edge features are aggregated onto both endpoints of every
    edge, concatenated with the current node features and passed through
    ``node_network``. Edge update: the updated features of the two endpoints
    are concatenated with the current edge features and passed through
    ``edge_network``.

    Args:
        hidden_dim: Width of the node and edge feature vectors.
        aggr: Aggregation scheme passed to ``MessagePassing``.
        aggr_kwargs: Optional keyword arguments for the aggregation. ``None``
            (the default) means no extra arguments.
        flow, node_dim, decomposed_layers, **kwargs: Forwarded to
            ``MessagePassing``.
    """

    def __init__(
        self,
        hidden_dim,
        aggr="add",
        *,
        aggr_kwargs=None,
        flow: str = "source_to_target",
        node_dim: int = -2,
        decomposed_layers: int = 1,
        **kwargs
    ):
        super().__init__(
            aggr,
            # A fresh dict per instance: a `{}` default in the signature would
            # be one object shared by every instance.
            aggr_kwargs={} if aggr_kwargs is None else aggr_kwargs,
            flow=flow,
            node_dim=node_dim,
            decomposed_layers=decomposed_layers,
            **kwargs
        )

        # Input: node features concatenated with the aggregated edge messages.
        self.node_network = nn.Sequential(
            nn.Linear(2*hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Input: features of both endpoint nodes concatenated with the edge features.
        self.edge_network = nn.Sequential(
            nn.Linear(3*hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

    def message(self, e):
        """Return the message of each edge: its feature vector ``e``, unchanged."""
        return e

    def aggregate(
        self,
        inputs: torch.Tensor,
        index: torch.Tensor,
        edge_index,
        x,
        ptr=None,
        dim_size=None,
    ) -> torch.Tensor:
        """
        Aggregate the edge messages onto both endpoints of every edge.

        ``inputs`` is the output of :meth:`message`. ``index``, ``ptr`` and
        ``dim_size`` are accepted but not used; the node indices are taken
        from ``edge_index`` and the output size from ``x``.
        """
        num_nodes = x.size(0)
        # Messages collected at the nodes listed in edge_index[1] ...
        at_second_endpoint = self.aggr_module(inputs, edge_index[1], dim_size=num_nodes)
        # ... and at the nodes listed in edge_index[0].
        at_first_endpoint = self.aggr_module(inputs, edge_index[0], dim_size=num_nodes)
        return at_second_endpoint + at_first_endpoint

    def update(self, inputs: torch.Tensor, x) -> torch.Tensor:
        """Compute new node features from the old ones and the aggregated messages."""
        x_in = torch.cat([x, inputs], dim=1)
        out = self.node_network(x_in)
        return out

    def edge_update(self, edge_index, x, e) -> torch.Tensor:
        """Compute new edge features from both endpoint nodes and the old edge features."""
        x_in = torch.cat([x[edge_index[0]], x[edge_index[1]], e], dim=1)
        out = self.edge_network(x_in)
        return out

    def forward(self, edge_index, x, e):
        """
        Run one step and return the updated ``(x, e)``.

        ``propagate`` calls message, aggregate and update to produce the new
        node features; ``edge_updater`` then calls edge_update with those new
        node features.
        """
        x = self.propagate(edge_index=edge_index, x=x, e=e)
        e = self.edge_updater(edge_index=edge_index, x=x, e=e)
        return x, e



In [232]:
class InteractionNetwork(nn.Module):
    """
    Edge classifier built from encoders, repeated message passing and a decoder.

    Args:
        input_dim: Number of input features per node.
        hidden_dim: Width of the latent node and edge features.
        n_iterations: How many times the (single, shared) ``InteractionConv``
            layer is applied.
        node_features: Names of the batch attributes stacked into the node
            feature matrix. Must contain ``input_dim`` names; defaults to
            ``("r", "phi", "z")``.

    Raises:
        ValueError: if ``len(node_features)`` differs from ``input_dim``.
    """

    def __init__(self, input_dim, hidden_dim, n_iterations=2, node_features=("r", "phi", "z")):
        super().__init__()
        node_features = tuple(node_features)
        if len(node_features) != input_dim:
            raise ValueError(
                f"input_dim is {input_dim} but {len(node_features)} node features"
                f" were given: {node_features}"
            )
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_iterations = n_iterations
        self.node_features = node_features

        self.node_encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Edge input: raw features of both endpoint nodes, concatenated.
        self.edge_encoder = nn.Sequential(
            nn.Linear(2*input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        self.conv = InteractionConv(hidden_dim)

        # Decoder input: latent features of both endpoints plus the edge.
        self.edge_classifier = nn.Sequential(
            nn.Linear(3* hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, batch):
        """
        Score every edge of ``batch``.

        Reads ``batch.edge_index`` and the attributes named in
        ``self.node_features``. Returns one raw (pre-sigmoid) value per edge,
        shaped ``[num_edges, 1]``.
        """
        # Node feature matrix, one column per configured attribute. These are
        # plain inputs, so no gradient is tracked for them.
        x = torch.stack(
            [getattr(batch, name) for name in self.node_features], dim=-1
        ).to(torch.float32)
        edge_index = batch.edge_index
        # To treat the graph as undirected, the flipped edges could be appended:
        # edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
        start, end = edge_index

        # Encode edges from the raw node features, then encode the nodes.
        e = self.edge_encoder(torch.cat([x[start], x[end]], dim=-1))
        x = self.node_encoder(x)

        # Message passing: the same layer is applied n_iterations times.
        for _ in range(self.n_iterations):
            x, e = self.conv(edge_index, x, e)

        # Decode edge features
        return self.edge_classifier(torch.cat([x[start], x[end], e], dim=-1))


In [233]:
# Build the network and push the batch inspected above through it once, as a
# quick check that the forward pass runs.
model = InteractionNetwork(
    input_dim=3,
    hidden_dim=32,
    n_iterations=2,
)
print(test_event)
test_run = model(test_event)
print(test_run)


DataBatch(x=[2718], edge_index=[2, 7649], y=[7649], leta=[2718], module_index=[2718], lx=[2718], region=[2718], lphi=[2718], cell_val=[2718], weight=[2718], lz=[2718], phi=[2718], ly=[2718], hit_id=[2718], eta=[2718], geta=[2718], r=[2718], cell_count=[2718], z=[2718], gphi=[2718], track_edges=[2, 2477], particle_id=[2477], radius=[2477], pt=[2477], nhits=[2477], config=[1], event_id=[1], num_nodes=2718, batch=[2718], truth_map=[2477], weights=[7649], ptr=[2])
tensor([[0.2116],
        [0.2464],
        [0.2155],
        ...,
        [0.2229],
        [0.2230],
        [0.2267]], grad_fn=<AddmmBackward0>)


In [234]:
# Training hyper-parameters and the optimiser acting on the model's parameters.
learning_rate, epochs = 1e-3, 5
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [235]:
import torch.nn.functional as F

def loss_function(output, batch):
    """
    Weighted binary cross-entropy over the edges of ``batch``.

    Edges are split into a positive set (``y == 1`` with weight > 0) and a
    negative set (``y == 0`` with non-zero weight, or any edge with negative
    weight). Each edge contributes its cross-entropy multiplied by the
    absolute value of its weight; edges in neither set are ignored. The summed
    contributions are divided by the number of contributing edges.

    Args:
        output: Raw model outputs (logits), shaped ``[num_edges]`` or
            ``[num_edges, 1]``.
        batch: Object with ``y`` (edge labels) and ``weights`` (edge weights).

    Returns:
        Tuple ``(total, positive, negative)``; the last two are detached and
        add up to the first.

    Raises:
        AssertionError: if ``batch`` has no ``y`` or no ``weights`` attribute.
    """

    assert hasattr(batch, "y"), (
        "The batch does not have a truth label. Please ensure the batch has a `y`"
        " attribute."
    )
    assert hasattr(batch, "weights"), (
        "The batch does not have a weighting label. Please ensure the batch"
        " weighting is handled in preprocessing."
    )

    # Drop the trailing axis of [num_edges, 1] logits so that logits, labels
    # and weights line up element-wise; otherwise the weights would broadcast
    # against the extra axis.
    logits = output.squeeze(-1) if output.dim() > 1 else output
    edge_weights = batch.weights.abs().to(logits.dtype)

    negative_mask = ((batch.y == 0) & (batch.weights != 0)) | (batch.weights < 0)
    positive_mask = (batch.y == 1) & (batch.weights > 0)

    negative_logits = logits[negative_mask]
    negative_loss = F.binary_cross_entropy_with_logits(
        negative_logits,
        torch.zeros_like(negative_logits),
        weight=edge_weights[negative_mask],
        reduction="sum",
    )

    positive_logits = logits[positive_mask]
    positive_loss = F.binary_cross_entropy_with_logits(
        positive_logits,
        torch.ones_like(positive_logits),
        weight=edge_weights[positive_mask],
        reduction="sum",
    )

    # At least 1, so a batch without contributing edges gives 0 instead of NaN.
    n = (positive_mask.sum() + negative_mask.sum()).clamp(min=1)
    return (
        (positive_loss + negative_loss) / n,
        positive_loss.detach() / n,
        negative_loss.detach() / n,
    )

In [236]:
# One pass of optimisation over the training data.
def train_loop(dataloader, model, loss_fn, optimizer, device):
    """
    Train ``model`` for one epoch.

    For every batch: forward pass, loss, backward pass, optimiser step, then
    gradient reset. Progress is printed on every 10th batch.

    Args:
        dataloader: Iterable of batches with a ``dataset`` attribute; each
            batch must have an ``event_id`` attribute with a length.
        model: Module called as ``model(batch)``.
        loss_fn: Callable ``loss_fn(output, batch)`` returning a 3-tuple whose
            first element is the loss to back-propagate.
        optimizer: Optimiser over the model's parameters.
        device: Accepted for interface compatibility; not used. The model and
            the batches are used wherever they already are.
    """
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()

    # Running count of processed events, so the progress line does not rely
    # on a notebook-level batch_size variable.
    seen = 0
    for batchid, batch in enumerate(dataloader):

        output = model(batch)
        loss, _, _ = loss_fn(output, batch)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        seen += len(batch.event_id)
        if batchid % 10 == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

        # Clear GPU memory
        del batch
        torch.cuda.empty_cache()



In [237]:
# evaluate the model's performance against the test dataset
def test_loop(dataloader, model, loss_fn, optimizer, device):
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    test_loss = 0
    with torch.no_grad():   
        for batchid, batch in enumerate(dataloader):

            output = model(batch)
            loss, pos_loss, neg_loss = loss_fn(output, batch)
            scores = torch.sigmoid(output)
            test_loss += loss

            # Clear GPU memory
            del batch
            torch.cuda.empty_cache()
    test_loss /= size
    print(f"Test Error: \n Avg loss: {test_loss:>8f} \n")
    return test_loss


In [238]:
# Full schedule: per epoch, one training pass followed by one validation pass.
# The validation loss of every epoch is collected in val_loss.
val_loss = []
for t in range(epochs):
    epoch_header = f"Epoch {t+1}\n-------------------------------"
    print(epoch_header)
    train_loop(train_loader, model, loss_function, optimizer, device)
    epoch_val_loss = test_loop(val_loader, model, loss_function, optimizer, device)
    val_loss.append(epoch_val_loss)
print("Done!")

Epoch 1
-------------------------------
loss: 0.741230  [    1/   80]
loss: 0.641859  [   11/   80]
loss: 0.613610  [   21/   80]
loss: 0.616529  [   31/   80]
loss: 0.554355  [   41/   80]
loss: 0.585642  [   51/   80]
loss: 0.573608  [   61/   80]
loss: 0.600520  [   71/   80]
Test Error: 
 Avg loss: 0.567785 

Epoch 2
-------------------------------
loss: 0.591930  [    1/   80]
loss: 0.579573  [   11/   80]
loss: 0.563805  [   21/   80]
loss: 0.574940  [   31/   80]
loss: 0.510036  [   41/   80]
loss: 0.550159  [   51/   80]
loss: 0.535475  [   61/   80]
loss: 0.562351  [   71/   80]
Test Error: 
 Avg loss: 0.527509 

Epoch 3
-------------------------------
loss: 0.546477  [    1/   80]
loss: 0.530489  [   11/   80]
loss: 0.501238  [   21/   80]
loss: 0.465706  [   31/   80]
loss: 0.404503  [   41/   80]
loss: 0.411440  [   51/   80]
loss: 0.395450  [   61/   80]
loss: 0.390389  [   71/   80]
Test Error: 
 Avg loss: 0.372886 

Epoch 4
-------------------------------
loss: 0.379105 

In [None]:
#for name, param in model.named_parameters():
#    print(f"Parameter name: {name}, Size: {param.size()}, Values: {param}")

# `val_loss` holds the validation loss of each epoch, filled by the training cell.
# The x-axis values get their own name: reusing `epochs` here would replace the
# integer epoch count and break a re-run of the training cell.
epoch_numbers = range(1, len(val_loss) + 1)
# Plain floats plot the same whether the entries are tensors or numbers.
val_loss_values = [float(value) for value in val_loss]

plt.figure(figsize=(6, 4))
plt.plot(epoch_numbers, val_loss_values, 'b', label='Validation Loss')
plt.title('Validation Loss vs Epoch')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()