# Dynamic Graph Neural Network


## Imports


In [None]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import uproot
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [None]:
class color:
    """ANSI escape sequences used to style printed output.

    Typical use: ``color.BOLD + "text" + color.END``; ``END`` resets styling.
    """

    # Foreground colours
    PURPLE = "\033[95m"
    CYAN = "\033[96m"
    DARKCYAN = "\033[36m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"

    # Text attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

    # Reset all attributes
    END = "\033[0m"

## Initialization


In [None]:
# Notebook switches:
#   sample -> open ./datasets/ttbar_sample.root instead of ./datasets/ttbar.root
#   read   -> when False, the "Extract Data" cell dumps every branch to text files
sample, read = False, True

## Read Data


In [None]:
# Open either the small sample file or the full dataset. Note that the TTree is
# called "events" in the sample file but "Events" in the full file.
if sample:
    file = uproot.open("./datasets/ttbar_sample.root")
    tree = file["events"]
else:
    file = uproot.open("./datasets/ttbar.root")
    tree = file["Events"]

# Read all branches of the selected tree (uproot `arrays()` with no filter).
branches = tree.arrays()

In [None]:
# Summarise the opened ROOT file, the selected tree and the loaded branches.
print(color.BOLD + "File:" + color.END, file)
print(
    color.BOLD + "File Keys are",
    # Count the FILE's keys so the number matches the list printed next
    # (this previously reported len(tree.keys())).
    len(file.keys()),
    "elements and are:" + color.END,
    file.keys(),
)
print(color.BOLD + "File Classnames:" + color.END, file.classnames())

print(color.BOLD + "Tree:" + color.END, tree)
print(
    color.BOLD + "Tree Keys are",
    len(tree.keys()),
    "elements and are:" + color.END,
    tree.keys(),
)

print(color.BOLD + "Branches:" + color.END, branches)
print(
    color.BOLD + "PV are",
    len(branches["PV_x"]),
    "elements and are:" + color.END,
    branches["PV_x"],
)

In [None]:
# Display the entry at index 1 of `branches`, converted with `.tolist()`.
# Kept as the cell's last bare expression so the notebook renders it.
branches[1].tolist()

In [None]:
# Summary statistics of the "nMuon" branch (muon count per entry).
n_muon = branches["nMuon"]
for label, statistic in (
    ("Mean:               ", np.mean),
    ("Standard Deviation: ", np.std),
    ("Minimum:            ", np.min),
    ("Maximum:            ", np.max),
):
    print(color.BOLD + label + color.END, statistic(n_muon))

In [None]:
# Histogram of "nMuon": 10 unit-width bins spanning 0..10.
plt.hist(branches["nMuon"], range=(0, 10), bins=10)
plt.xlabel("Number of muons in event")
plt.ylabel("Number of events")

## Extract Data


In [None]:
# List the names of all branches loaded into `branches`.
print(branches.fields)

We extract all branches anyway and then keep only the ones we need.


In [None]:
# When the cached text files are not being used (read is False), dump every
# branch to ./datasets/ttbar/<branch>.txt with one entry per line.
if not read:
    out_dir = "./datasets/ttbar"
    # open() does not create missing directories; make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)
    for branch_name in branches.fields:
        print(branch_name)
        with open(f"{out_dir}/{branch_name}.txt", "w", encoding="utf-8") as f:
            # One buffered call per file instead of one write per value.
            f.writelines(f"{value}\n" for value in branches[branch_name])

## Prepare Data

For primary-vertex identification in $t\bar{t}$ events, the key features to consider are:

1. Primary vertex coordinates: `PV_x`, `PV_y`, `PV_z`
2. Track-related features from muons (since they're good for vertex identification):

    - Impact parameter: `Muon_dxy`, `Muon_dxyErr`
    - Longitudinal impact parameter: `Muon_dz`, `Muon_dzErr`
    - Kinematics: `Muon_pt`, `Muon_eta`, `Muon_phi`

3. Jet features (can help identify the hard scatter vertex): `Jet_pt`, `Jet_eta`, `Jet_phi`
4. Number of primary vertices `PV_npvs` as auxiliary information


In [None]:
def _show_branch(label, values):
    """Print ``<label> are <count>:`` in bold, followed by the values themselves."""
    print(color.BOLD + f"{label} are {len(values)}:" + color.END, values)


# --- Primary vertices --------------------------------------------------------
PV_x, PV_y, PV_z, PV_npvs = (
    branches[key] for key in ("PV_x", "PV_y", "PV_z", "PV_npvs")
)
_show_branch("PV_x", PV_x)
_show_branch("PV_y", PV_y)
_show_branch("PV_z", PV_z)
_show_branch("PV_npvs", PV_npvs)

print("\n")

# --- Muons ---------------------------------------------------------------------
Muon_pt, Muon_eta, Muon_phi, Muon_dxy, Muon_dxyErr, Muon_dz, Muon_dzErr = (
    branches[key]
    for key in (
        "Muon_pt",
        "Muon_eta",
        "Muon_phi",
        "Muon_dxy",
        "Muon_dxyErr",
        "Muon_dz",
        "Muon_dzErr",
    )
)
_show_branch("Muon_pt", Muon_pt)
_show_branch("Muon_eta", Muon_eta)
_show_branch("Muon_phi", Muon_phi)
_show_branch("Muon_dxy", Muon_dxy)
_show_branch("Muon_dxyErr", Muon_dxyErr)
_show_branch("Muon_dz", Muon_dz)
_show_branch("Muon_dzErr", Muon_dzErr)

print("\n")

# --- Jets ----------------------------------------------------------------------
Jet_pt, Jet_eta, Jet_phi = (branches[key] for key in ("Jet_pt", "Jet_eta", "Jet_phi"))
_show_branch("Jet_pt", Jet_pt)
_show_branch("Jet_eta", Jet_eta)
_show_branch("Jet_phi", Jet_phi)

print("\n")

# --- Missing transverse energy -------------------------------------------------
MET_pt, MET_phi = (branches[key] for key in ("MET_pt", "MET_phi"))
_show_branch("MET_pt", MET_pt)
_show_branch("MET_phi", MET_phi)

In [None]:
# Group the extracted branches per physics object and tabulate each group.
# NOTE(review): prepare_data pads missing muons/jets, suggesting the Muon_*/Jet_*
# branches hold a variable number of entries per event; confirm pd.DataFrame
# accepts these arrays as columns.
PV = dict(PV_x=PV_x, PV_y=PV_y, PV_z=PV_z, PV_npvs=PV_npvs)

Muon = dict(
    Muon_pt=Muon_pt,
    Muon_eta=Muon_eta,
    Muon_phi=Muon_phi,
    Muon_dxy=Muon_dxy,
    Muon_dxyErr=Muon_dxyErr,
    Muon_dz=Muon_dz,
    Muon_dzErr=Muon_dzErr,
)

Jet = dict(Jet_pt=Jet_pt, Jet_eta=Jet_eta, Jet_phi=Jet_phi)

# Build all three frames first, then preview them.
PVdf, Muondf, Jetdf = (pd.DataFrame(group) for group in (PV, Muon, Jet))
for frame in (PVdf, Muondf, Jetdf):
    print(frame.head())

In [None]:
class VertexDataset(Dataset):
    """Map-style dataset pairing feature rows with their labels.

    Args:
        data: Array-like of feature rows; converted to a ``float32`` tensor.
        labels: Array-like with one label per row of ``data``; converted to a
            ``float32`` tensor.

    Raises:
        ValueError: if ``data`` or ``labels`` is a scalar, or if they do not
            contain the same number of rows.
    """

    def __init__(self, data, labels):
        # torch.as_tensor converts explicitly. The legacy torch.FloatTensor(n)
        # constructor silently allocates an *uninitialised* tensor of length n
        # when handed a bare integer instead of data.
        self.data = torch.as_tensor(data, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.float32)
        if self.data.dim() == 0 or self.labels.dim() == 0:
            raise ValueError("data and labels must be array-like, not scalars")
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"data has {len(self.data)} rows but labels has {len(self.labels)}"
            )

    def __len__(self):
        """Return the number of samples (rows of ``data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair at position ``idx``."""
        return self.data[idx], self.labels[idx]


def prepare_data(data, max_muons=5, max_jets=8):
    """
    Build the fixed-length feature vector for a single event.

    Layout of the returned vector::

        [PV_x, PV_y, PV_z, PV_npvs]                       4 values
        + [pt, eta, phi, dxy, dxyErr, dz, dzErr] per muon  7 * max_muons
        + [pt, eta, phi] per jet                           3 * max_jets

    Muons/jets beyond the limit are dropped; missing ones are zero-padded.

    Args:
        data: Mapping for ONE event: ``PV_*`` keys hold scalars and the
            ``Muon_*`` / ``Jet_*`` keys hold per-object sequences. Passing the
            whole dataset (e.g. all of ``branches``) is wrong: ``len()`` and
            ``[i]`` would then count/index events instead of muons or jets.
        max_muons: Number of muon slots.
        max_jets: Number of jet slots.

    Returns:
        1-D float64 ndarray of length ``4 + 7 * max_muons + 3 * max_jets``.
    """
    muon_keys = (
        "Muon_pt",
        "Muon_eta",
        "Muon_phi",
        "Muon_dxy",
        "Muon_dxyErr",
        "Muon_dz",
        "Muon_dzErr",
    )
    jet_keys = ("Jet_pt", "Jet_eta", "Jet_phi")

    # Primary-vertex information.
    features = [data["PV_x"], data["PV_y"], data["PV_z"], data["PV_npvs"]]

    # Muon slots (object count hoisted out of the loop).
    n_muons = len(data["Muon_pt"])
    for i in range(max_muons):
        if i < n_muons:
            features.extend(data[key][i] for key in muon_keys)
        else:
            features.extend([0] * len(muon_keys))

    # Jet slots.
    n_jets = len(data["Jet_pt"])
    for i in range(max_jets):
        if i < n_jets:
            features.extend(data[key][i] for key in jet_keys)
        else:
            features.extend([0] * len(jet_keys))

    return np.asarray(features, dtype=float)

In [None]:
# prepare_data() encodes ONE event, so call it per event and stack the rows:
# the result has one fixed-length feature vector per event (63 columns with the
# default max_muons=5, max_jets=8). Assumes iterating `branches` yields one
# record per event.
df = np.stack([prepare_data(event) for event in branches])

# if read:
#     columns = [
#         "PV_x",
#         "PV_y",
#         "PV_z",
#         "Muon_dxy",
#         "Muon_dxyErr",
#         "Muon_dz",
#         "Muon_dzErr",
#         "Muon_pt",
#         "Muon_eta",
#         "Muon_phi",
#         "Jet_pt",
#         "Jet_eta",
#         "Jet_phi",
#         "PV_npvs",
#     ]

#     dataframes = [
#         pd.read_csv(
#             f"./datasets/ttbar/{col}.txt",
#             header=None,
#             names=[col],
#             delim_whitespace=True,
#         )
#         for col in columns
#     ]
#     data = pd.concat(dataframes, axis=1)

#     print(data)

## Additional Features


In [None]:
# Feature Engineering
def engineer_vertex_features(df):
    """
    Engineer additional features for vertex identification.

    Adds these columns to ``df`` in place and returns the same object:

    - ``track_quality``: ``Muon_pt / Muon_ptErr``
    - ``dxy_significance``: ``Muon_dxy / Muon_dxyErr`` (signed)
    - ``dz_significance``: ``Muon_dz / Muon_dzErr`` (signed)
    - ``delta_phi``: muon-jet azimuthal separation, wrapped into [0, pi]
    - ``delta_eta``: ``|Muon_eta - Jet_eta|``

    Args:
        df: DataFrame-like object holding the ``Muon_*`` and ``Jet_*`` columns
            named above.

    Returns:
        ``df`` with the new columns. With pandas inputs, a zero error value
        yields inf/NaN in the corresponding ratio.
    """
    # Track quality metric
    df["track_quality"] = df["Muon_pt"] / df["Muon_ptErr"]

    # Impact parameter significance
    df["dxy_significance"] = df["Muon_dxy"] / df["Muon_dxyErr"]
    df["dz_significance"] = df["Muon_dz"] / df["Muon_dzErr"]

    # Angular separation. phi is periodic, so a raw difference such as
    # 3.0 - (-3.0) = 6.0 rad is really 2*pi - 6.0 ~= 0.28 rad: map it into
    # [-pi, pi) before taking the magnitude (same folding as calc_delta_r).
    raw_dphi = df["Muon_phi"] - df["Jet_phi"]
    df["delta_phi"] = abs((raw_dphi + np.pi) % (2 * np.pi) - np.pi)
    df["delta_eta"] = abs(df["Muon_eta"] - df["Jet_eta"])

    return df


# Pile-up Mitigation
def pileup_weighting(df):
    """
    Return per-row pile-up weights ``1 / (1 + 0.1 * PV_npvs)``.

    For non-negative ``PV_npvs``, rows with more reconstructed primary vertices
    get smaller weights. ``df`` itself is not modified.
    """
    npvs = df["PV_npvs"]
    return 1.0 / (1.0 + 0.1 * npvs)


# Quality Cuts
def apply_quality_cuts(
    df, max_dxy=0.2, min_pt=20, max_eta=2.4, max_dxy_significance=5
):
    """
    Apply basic muon quality cuts for vertex identification.

    Keeps rows with ``|Muon_dxy| < max_dxy``, ``Muon_pt > min_pt``,
    ``|Muon_eta| < max_eta`` and ``|dxy_significance| < max_dxy_significance``.
    The ``dxy_significance`` column must already exist (it is added by
    ``engineer_vertex_features``).

    Returns:
        The selected rows, via boolean-mask indexing; ``df`` is not modified.
    """
    mask = (
        (abs(df["Muon_dxy"]) < max_dxy)  # Impact parameter cut
        & (df["Muon_pt"] > min_pt)  # Minimum pt cut
        & (abs(df["Muon_eta"]) < max_eta)  # Eta acceptance
        # The significance is signed like Muon_dxy; cut on its magnitude so
        # large negative displacements are rejected as well.
        & (abs(df["dxy_significance"]) < max_dxy_significance)
    )
    return df[mask]


# Track Clustering
def cluster_tracks(df, max_dz=0.2):
    """
    Cluster rows with DBSCAN on the 2-D points ``(PV_z, Muon_dz)``.

    Args:
        df: Table with ``PV_z`` and ``Muon_dz`` columns (presumably one row per
            track).
        max_dz: DBSCAN ``eps``: Euclidean neighbourhood radius in that 2-D space.

    Returns:
        Array of cluster labels, one per row (DBSCAN labels noise as -1).
    """
    # Imported here so only this helper depends on scikit-learn.
    from sklearn.cluster import DBSCAN

    points = np.column_stack((df["PV_z"], df["Muon_dz"]))
    return DBSCAN(eps=max_dz, min_samples=2).fit(points).labels_


# Vertex Scoring
def score_vertex_candidates(df):
    """
    Score vertex candidates based on various criteria
    """
    scores = (
        df["track_quality"] * 0.3
        + (1.0 / df["dxy_significance"]) * 0.3
        + (df["Muon_pt"] / df["Muon_pt"].max()) * 0.4
    )
    return scores

## Model


**Model Architecture:**

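-   Graph neural network (`VertexGNN`) with three edge-convolution layers (hidden size 128 by default, max aggregation over neighbours) whose messages combine node features with edge (Δη, Δφ) features
-   Mean pooling over the nodes of each graph, concatenated with global event features and passed through a two-layer MLP
-   Classifier head 128→256→128→1 with ReLU activations and dropout (p = 0.2)
-   Sigmoid output for binary classification

**Training Pipeline:**

-   Custom `VertexDataset` class for tensor-based data handling
-   `train_gnn` runs one epoch over mini-batches with Binary Cross Entropy loss; the optimizer (e.g. Adam) is supplied by the caller
-   No validation loop is implemented yet for the graph model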


In [None]:
def _scatter_mean(src, index, num_groups):
    """Average the rows of ``src`` that share an ``index`` value.

    Returns a ``(num_groups, src.size(1))`` tensor; groups without rows are 0.
    """
    sums = src.new_zeros((num_groups, src.size(1))).index_add(0, index, src)
    counts = torch.bincount(index, minlength=num_groups).clamp(min=1)
    return sums / counts.unsqueeze(1).to(src.dtype)


class _EdgeConvWithAttr(nn.Module):
    """EdgeConv-style layer whose message MLP also sees the edge attributes.

    For every edge ``src -> dst`` (``edge_index[0] -> edge_index[1]``) the
    message is ``mlp([x_dst, x_src - x_dst, edge_attr])``; messages are
    max-aggregated at ``dst``. Nodes without incoming edges get zeros.
    """

    def __init__(self, mlp):
        super().__init__()
        self.mlp = mlp

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index[0], edge_index[1]
        x_dst = x[dst]
        messages = self.mlp(torch.cat([x_dst, x[src] - x_dst, edge_attr], dim=1))
        out = messages.new_zeros((x.size(0), messages.size(1)))
        # include_self=False: only real messages take part in the max.
        return out.scatter_reduce(
            0,
            dst.unsqueeze(1).expand_as(messages),
            messages,
            reduce="amax",
            include_self=False,
        )


class VertexGNN(nn.Module):
    """Graph classifier: three edge convolutions, mean pooling, global MLP.

    Args:
        node_features: Width of ``data.x`` (7 track + 3 jet features).
        edge_features: Width of ``data.edge_attr`` (eta-phi distance features).
        global_features: Number of per-graph features in ``data.global_features``.
        hidden_dim: Width of the convolution and global-MLP layers.

    ``forward(data)`` reads ``data.x`` (N, node_features), ``data.edge_index``
    (2, E) long, ``data.edge_attr`` (E, edge_features), optionally
    ``data.batch`` (N,) graph ids, and ``data.global_features`` with
    ``global_features`` values per graph. It returns (num_graphs, 1)
    probabilities from a sigmoid.

    The edge convolutions are implemented with plain torch: PyG's ``EdgeConv``
    does not accept ``edge_attr`` and feeds its MLP ``2 * in_channels`` inputs,
    so it cannot realise the ``2 * F + edge_features`` input width used here.
    """

    def __init__(
        self,
        node_features=10,  # 7 track + 3 jet features
        edge_features=2,  # η-φ distance features
        global_features=3,  # PV_npvs + MET features
        hidden_dim=128,
    ):
        super().__init__()

        def edge_mlp(in_dim):
            # Input is [x_dst, x_src - x_dst, edge_attr].
            return nn.Sequential(
                nn.Linear(2 * in_dim + edge_features, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )

        # Edge convolution layers
        self.edge_conv1 = _EdgeConvWithAttr(edge_mlp(node_features))
        self.edge_conv2 = _EdgeConvWithAttr(edge_mlp(hidden_dim))
        self.edge_conv3 = _EdgeConvWithAttr(edge_mlp(hidden_dim))

        # Combine with global features
        self.global_mlp = nn.Sequential(
            nn.Linear(hidden_dim + global_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Final classification layers
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, data):
        # Unpack the graph data
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        batch = getattr(data, "batch", None)
        if batch is None:
            # Single graph: every node belongs to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        num_graphs = int(batch.max()) + 1 if batch.numel() else 1

        # Apply edge convolutions
        x = self.edge_conv1(x, edge_index, edge_attr)
        x = self.edge_conv2(x, edge_index, edge_attr)
        x = self.edge_conv3(x, edge_index, edge_attr)

        # Global mean pooling per graph
        pooled = _scatter_mean(x, batch, num_graphs)

        # Accept global features as (num_graphs, k) or flattened to (num_graphs * k,),
        # e.g. when a 1-D per-graph attribute is concatenated across graphs.
        global_features = data.global_features.reshape(num_graphs, -1).to(pooled.dtype)
        x = torch.cat([pooled, global_features], dim=1)
        x = self.global_mlp(x)

        # Final classification
        return self.classifier(x)


# class VertexIdentificationGNN(nn.Module):
#     def __init__(self, input_dim, hidden_dim=128, output_dim=1):
#         super(VertexIdentificationGNN, self).__init__()
#         self.conv1 = GCNConv(input_dim, hidden_dim)
#         self.conv2 = GCNConv(hidden_dim, hidden_dim)
#         self.conv3 = GCNConv(hidden_dim, hidden_dim)
#         self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
#         self.fc2 = nn.Linear(hidden_dim // 2, output_dim)
#         self.dropout = nn.Dropout(0.2)

#     def forward(self, x, edge_index, batch):
#         x = self.conv1(x, edge_index)
#         x = F.relu(x)
#         x = self.dropout(x)
#         x = self.conv2(x, edge_index)
#         x = F.relu(x)
#         x = self.dropout(x)
#         x = self.conv3(x, edge_index)
#         x = F.relu(x)
#         x = self.dropout(x)

#         # Global mean pooling
#         x = global_mean_pool(x, batch)

#         x = self.fc1(x)
#         x = F.relu(x)
#         x = self.dropout(x)
#         x = self.fc2(x)
#         x = torch.sigmoid(x)
#         return x

## Train


In [None]:
def train_gnn(model, train_loader, optimizer, epoch, device=None):
    """
    Run one training epoch of ``model`` over ``train_loader`` with BCE loss.

    Args:
        model: Module whose output is a probability tensor (fed to ``nn.BCELoss``).
        train_loader: Iterable of batches; each batch must support
            ``.to(device)`` and carry targets in ``.y`` with one label per
            model output element.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch index; currently unused, kept for call compatibility.
        device: Device to move each batch to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        Mean training loss over the batches, or 0.0 if the loader is empty.
    """
    model.train()
    loss_fn = nn.BCELoss()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        # BCELoss requires input and target of identical shape; labels often
        # arrive as (N,) while the model emits (N, 1).
        target = batch.y.to(out.dtype).view_as(out)
        loss = loss_fn(out, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0


# def train_model(model, train_loader, val_loader, num_epochs=50, learning_rate=0.001):
#     criterion = nn.BCELoss()
#     optimizer = optim.Adam(model.parameters(), lr=learning_rate)

#     for epoch in range(num_epochs):
#         model.train()
#         train_loss = 0
#         for batch_data, batch_labels in train_loader:
#             optimizer.zero_grad()
#             outputs = model(batch_data)
#             loss = criterion(outputs, batch_labels.unsqueeze(1))
#             loss.backward()
#             optimizer.step()
#             train_loss += loss.item()

#         # Validation
#         model.eval()
#         val_loss = 0
#         with torch.no_grad():
#             for batch_data, batch_labels in val_loader:
#                 outputs = model(batch_data)
#                 loss = criterion(outputs, batch_labels.unsqueeze(1))
#                 val_loss += loss.item()

#         if (epoch + 1) % 5 == 0:
#             print(
#                 f"Epoch [{epoch+1}/{num_epochs}], "
#                 f"Train Loss: {train_loss/len(train_loader):.4f}, "
#                 f"Val Loss: {val_loss/len(val_loader):.4f}"
#             )


# # Calculate input dimension
# input_dim = 4 + (7 * 5) + (3 * 8)  # PV features + Muon features + Jet features

In [None]:
def prepare_graph_data(tracks, jets, global_features):
    """
    Build a fully connected graph from per-node track and jet features.

    Args:
        tracks: Tensor of track features, per the column comment below
            ``[pT, η, φ, dxy, dxyErr, dz, dzErr]``.
        jets: Tensor of ``[Jet_pT, Jet_eta, Jet_phi]``; concatenated column-wise
            with ``tracks``, so both must have the same number of rows N.
        global_features: Event-level features stored unchanged on the graph.

    Returns:
        ``Data`` with ``x`` (node features), ``edge_index`` (2, N*(N-1)) holding
        every ordered pair i != j, ``edge_attr`` from ``calc_delta_r`` on the
        (η, φ) columns, and ``global_features``.

    NOTE(review): ``Data`` is presumably ``torch_geometric.data.Data``; it is not
    among the notebook's visible imports.
    """
    # Node features: [pT, η, φ, dxy, dxyErr, dz, dzErr] + [Jet_pT, Jet_eta, Jet_phi]
    node_features = torch.cat([tracks, jets], dim=1)

    # Fully connected graph. torch.combinations yields each unordered pair once
    # (i < j); messages flow source -> target, so add the reversed pairs too,
    # otherwise e.g. node 0 would never receive a message.
    num_nodes = node_features.size(0)
    pairs = torch.combinations(
        torch.arange(num_nodes, device=node_features.device), r=2
    ).t()
    edge_index = torch.cat([pairs, pairs.flip(0)], dim=1)

    # Edge features from the (η, φ) columns of the endpoints.
    node_eta_phi = node_features[:, [1, 2]]
    edge_attr = calc_delta_r(node_eta_phi[edge_index[0]], node_eta_phi[edge_index[1]])

    return Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_attr,
        global_features=global_features,
    )


def calc_delta_r(point1, point2):
    """
    Return per-pair (Δη, Δφ) between two matching sets of (η, φ) points.

    Args:
        point1, point2: Tensors whose column 0 is η and column 1 is φ, with the
            same number of rows.

    Returns:
        Tensor with two columns: the signed ``η1 - η2`` and the azimuthal
        separation folded as ``min(|φ1 - φ2|, 2π - |φ1 - φ2|)``. Despite the name,
        the components are not combined into ΔR = sqrt(Δη² + Δφ²).
    """
    eta1, phi1 = point1[:, 0], point1[:, 1]
    eta2, phi2 = point2[:, 0], point2[:, 1]

    abs_dphi = (phi1 - phi2).abs()
    folded_dphi = torch.minimum(abs_dphi, 2 * torch.pi - abs_dphi)

    return torch.stack((eta1 - eta2, folded_dphi), dim=1)

In [None]:
# Template for training a model on flattened per-event feature vectors via
# VertexDataset. Fill in the four placeholders below before running.
X_train = ...  # Your training data
y_train = ...  # Your training labels
X_val = ...  # Your validation data
y_val = ...  # Your validation labels

# Fail fast with a clear message rather than an opaque tensor-conversion error.
if any(value is Ellipsis for value in (X_train, y_train, X_val, y_val)):
    raise NotImplementedError(
        "Fill in X_train, y_train, X_val and y_val before running this cell."
    )

# Flattened feature length produced by prepare_data() with its defaults:
# 4 PV features + 7 features x 5 muons + 3 features x 8 jets = 63.
input_dim = 4 + (7 * 5) + (3 * 8)

# NOTE(review): VertexIdentificationModel and train_model are not defined in
# this notebook (train_model exists only in a commented-out cell); define or
# import them before running.
model = VertexIdentificationModel(input_dim)

# Create datasets
train_dataset = VertexDataset(X_train, y_train)
val_dataset = VertexDataset(X_val, y_val)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# Train the model
train_model(model, train_loader, val_loader)

In [None]:
# class DynamicGNN(nn.Module):
#     def __init__(self, num_features):
#         super().__init__()
#         self.edge_conv1 = EdgeConv(
#             nn.Sequential(
#                 nn.Linear(2 * num_features, 128),
#                 nn.BatchNorm1d(128),
#                 nn.ReLU(),
#                 nn.Linear(128, 64),
#             )
#         )
#         self.edge_conv2 = EdgeConv(
#             nn.Sequential(
#                 nn.Linear(2 * 64, 256),
#                 nn.BatchNorm1d(256),
#                 nn.ReLU(),
#                 nn.Linear(256, 128),
#             )
#         )
#         self.edge_conv3 = EdgeConv(
#             nn.Sequential(
#                 nn.Linear(2 * 128, 512),
#                 nn.BatchNorm1d(512),
#                 nn.ReLU(),
#                 nn.Linear(512, 256),
#             )
#         )

#         self.node_update = nn.Sequential(
#             nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Linear(128, 64)
#         )

#         self.global_pool = global_max_pool

#         self.classifier = nn.Sequential(
#             nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1), nn.Sigmoid()
#         )

#     def forward(self, x, edge_index):
#         # x: (N, num_features), edge_index: (2, E)

#         x = self.edge_conv1(x, edge_index)
#         x = self.edge_conv2(x, edge_index)
#         x = self.edge_conv3(x, edge_index)

#         x = self.node_update(x)

#         graph_feature = self.global_pool(
#             x, torch.zeros(x.size(0), dtype=torch.long, device=x.device)
#         )

#         out = self.classifier(graph_feature)

#         return out

In [None]:
# # Example usage
# model = DynamicGNN(num_features=13)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# for epoch in range(num_epochs):
#     model.train()

#     # Prepare input data
#     track_features = ...  # (N, 13)
#     edge_index = ...  # (2, E)

#     output = model(track_features, edge_index)
#     loss = criterion(output, target)

#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

#     # Evaluation
#     model.eval()
#     with torch.no_grad():
#         eval_output = model(eval_track_features, eval_edge_index)
#         eval_loss = criterion(eval_output, eval_target)

#         # Calculate evaluation metrics
#         # ...

# Var 2


In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch_geometric.nn import EdgeConv, global_max_pool, global_mean_pool
# from torch_geometric.utils import k_hop_subgraph


# class DynamicEdgeConv(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super().__init__()
#         self.mlp = nn.Sequential(
#             nn.Linear(2 * in_channels, 2 * out_channels),
#             nn.BatchNorm1d(2 * out_channels),
#             nn.ReLU(),
#             nn.Linear(2 * out_channels, out_channels),
#         )

#     def forward(self, x, edge_index):
#         row, col = edge_index
#         edge_attr = torch.cat([x[row], x[col]], dim=1)
#         return self.mlp(edge_attr)


# class DynamicGNN(nn.Module):
#     def __init__(self, num_features, num_classes):
#         super().__init__()
#         self.edge_conv1 = DynamicEdgeConv(num_features, 64)
#         self.edge_conv2 = DynamicEdgeConv(64, 128)
#         self.edge_conv3 = DynamicEdgeConv(128, 256)

#         self.node_update1 = nn.Sequential(
#             nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Linear(128, 64)
#         )
#         self.node_update2 = nn.Sequential(
#             nn.Linear(64, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 16)
#         )

#         self.global_pool = global_max_pool
#         self.classifier = nn.Sequential(
#             nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, num_classes), nn.Sigmoid()
#         )

#     def forward(self, x, edge_index):
#         # x: (N, num_features), edge_index: (2, E)

#         x = self.edge_conv1(x, edge_index)
#         x = self.edge_conv2(x, edge_index)
#         x = self.edge_conv3(x, edge_index)

#         x = self.node_update1(x)
#         x = self.node_update2(x)

#         graph_feature = self.global_pool(
#             x, torch.zeros(x.size(0), dtype=torch.long, device=x.device)
#         )

#         out = self.classifier(graph_feature)

#         return out


# # Example usage
# model = DynamicGNN(num_features=13, num_classes=2)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# for epoch in range(num_epochs):
#     model.train()

#     # Prepare input data
#     track_features = ...  # (N, 13)
#     edge_index = ...  # (2, E)

#     output = model(track_features, edge_index)
#     loss = criterion(output, target)

#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

#     # Evaluation
#     model.eval()
#     with torch.no_grad():
#         eval_output = model(eval_track_features, eval_edge_index)
#         eval_loss = criterion(eval_output, eval_target)

#         # Calculate evaluation metrics
#         # ...


# # Dynamic Edge Construction
# def get_dynamic_edges(track_features, k=20):
#     """
#     Construct dynamic k-nearest neighbor edges based on track features.

#     Args:
#         track_features (torch.Tensor): (N, num_features)
#         k (int): Number of nearest neighbors

#     Returns:
#         edge_index (torch.Tensor): (2, E)
#     """
#     N = track_features.size(0)
#     edge_index = []

#     for i in range(N):
#         dists = torch.sum((track_features - track_features[i]) ** 2, dim=1)
#         _, indices = torch.topk(dists, k=k + 1)
#         for j in indices[1:]:
#             edge_index.append([i, j])
#             edge_index.append([j, i])

#     return torch.tensor(edge_index, dtype=torch.long).t()


# # Vertex Classification
# def classify_vertices(model, track_features, edge_index):
#     """
#     Classify primary vertices using the Dynamic Graph CNN model.

#     Args:
#         model (DynamicGNN): Trained model
#         track_features (torch.Tensor): (N, num_features)
#         edge_index (torch.Tensor): (2, E)

#     Returns:
#         vertex_scores (torch.Tensor): (N,) Vertex classification scores
#     """
#     model.eval()
#     with torch.no_grad():
#         output = model(track_features, edge_index)
#         vertex_scores = output.squeeze()
#     return vertex_scores

# Var 3


In [None]:
# class GraphNeuralNetwork(nn.Module):
#     def __init__(self, in_features, hidden_features, out_features):
#         super(GraphNeuralNetwork, self).__init__()
#         self.conv1 = GCNConv(in_features, hidden_features)
#         self.conv2 = GCNConv(hidden_features, out_features)

#     def forward(self, x, edge_index):
#         x = self.conv1(x, edge_index)
#         x = F.relu(x)
#         x = self.conv2(x, edge_index)
#         return x