# Graph Neural Network Training

Creation of graphs were done with 
```tools/training/GraphCreationModel.py```
Files can be found in: 
```/eos/cms/store/user/folguera/L1TMuon/INTREPID/Graphs_v240725_241015/```
in two flavours, with "all" connected layers and with up to "3-neighbour" layers connections. 

In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import os, sys

import matplotlib.pyplot as plt


In [3]:
class GATRegressor(torch.nn.Module):
    def __init__(self, num_node_features, hidden_dim, output_dim=3):
        super(GATRegressor, self).__init__()
        self.conv1 = GATConv(num_node_features, hidden_dim)
        self.conv2 = GATConv(hidden_dim, hidden_dim)
        self.fc1 = torch.nn.Linear(hidden_dim * 2, output_dim)

    def forward(self, data):
        # load nodel attributes: x, and edge attributes: deltaPhi and deltaEta
        x, edge_index, deltaPhi, deltaEta, batch = data.x.float(), data.edge_index, data.deltaPhi.float(), data.deltaEta.float(), data.batch

        # Combine deltaPhi and deltaEta into edge_attr
        edge_attr = torch.stack([deltaPhi, deltaEta], dim=1)
        # Apply graph convolutions
        x = F.relu(x)
        x = self.conv1(x, edge_index, edge_attr=edge_attr)  # Using GAT as it allow to use edge attributes
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_attr=edge_attr)

        # Global mean pooling to get graph-level output
        # x = gmp(x, batch)
        x = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # Fully connected layers for regression
        x = self.fc1(x)
        return x


## Load data

In [4]:
## check if EOS folder exists otherwise use local folder
if os.path.exists("/eos/cms/store/user/folguera/L1TMuon/INTREPID/Graphs_v240725_241015/"):
    GraphDIR = "/eos/cms/store/user/folguera/L1TMuon/INTREPID/Graphs_v240725_241015/"
else:
    GraphDIR = "../graph_folder/"
using_only = 5  ## number of files used

In [5]:
Allgraphs = []
all_files = os.listdir(GraphDIR)

# Filter for .pkl files
pkl_files = [f for f in all_files if f.endswith('.pkl') and '_3_' in f]
print(f"Using files: {pkl_files}")
if not pkl_files:
    print("No .pkl files found in the directory.")
    sys.exit()

Using files: ['vix_graph_3_15Oct_onlypt_001.pkl', 'vix_graph_3_15Oct_onlypt_002.pkl', 'vix_graph_3_15Oct_onlypt_003.pkl', 'vix_graph_3_15Oct_onlypt_004.pkl', 'vix_graph_3_15Oct_onlypt_005.pkl', 'vix_graph_3_15Oct_onlypt_006.pkl', 'vix_graph_3_15Oct_onlypt_007.pkl', 'vix_graph_3_15Oct_onlypt_008.pkl', 'vix_graph_3_15Oct_onlypt_009.pkl']


In [6]:
count_files = 0
for pkl_file in pkl_files:
    if count_files >= using_only: break
    file_path = os.path.join(GraphDIR, pkl_file)
    print(f"Loading file: {pkl_file}")
    with open(file_path, 'rb') as file:
        graphfile = torch.load(file)
        Allgraphs.append(graphfile)
    count_files+=1

Loading file: vix_graph_3_15Oct_onlypt_001.pkl


  graphfile = torch.load(file)


Loading file: vix_graph_3_15Oct_onlypt_002.pkl
Loading file: vix_graph_3_15Oct_onlypt_003.pkl
Loading file: vix_graph_3_15Oct_onlypt_004.pkl
Loading file: vix_graph_3_15Oct_onlypt_005.pkl


In [7]:
BatchSize=64

Graphs_for_training = sum(Allgraphs, [])
Graphs_for_training_reduced = Graphs_for_training
Graphs_for_training_filtered = [g for g in Graphs_for_training_reduced if g.edge_index.size(1) > 0]  # remove empty graphs

# remove extra dimenson in y
print(f"Total Graphs: {len(Graphs_for_training)}")
for i in range(0, len(Graphs_for_training)):
    Graphs_for_training_reduced[i].y = Graphs_for_training[i].y.mean(dim=0)

# Train and test split:
events = len(Graphs_for_training_reduced)
ntrain = int((events * 0.7) / BatchSize) * BatchSize  # to have full batches
print(f"Training events: {ntrain}")
train_dataset = Graphs_for_training_reduced[:ntrain]
test_dataset = Graphs_for_training_reduced[ntrain:ntrain * 2]

print("====================================")
print("Example of data:")
print(train_dataset[0].x)
print(train_dataset[0].edge_index)
print(train_dataset[0].deltaPhi)
print(train_dataset[0].deltaEta)
print(train_dataset[0].y)
print("====================================")

# Load data
train_loader = DataLoader(train_dataset, batch_size=BatchSize, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BatchSize, shuffle=False)



Total Graphs: 279187
Training events: 195392
Example of data:
tensor([[1.0331e+00, 9.8300e+02, 1.0060e+03, 6.0000e+00, 9.0000e+00],
        [1.0114e+00, 9.7700e+02, 1.0767e+03, 1.5000e+01, 5.0000e+00],
        [1.0440e+00, 9.9100e+02, 1.1959e+03, 7.0000e+00, 9.0000e+00],
        [1.1092e+00, 9.5400e+02, 4.1368e+02, 1.0000e+01, 5.0000e+00]],
       dtype=torch.float64)
tensor([[0, 0, 0, 1, 1, 2, 2, 3],
        [1, 2, 3, 0, 2, 0, 1, 0]])
tensor([ -6,   8, -29,  -6, -14,   8, -14, -29])
tensor([-0.0217,  0.0109,  0.0761, -0.0217, -0.0326,  0.0109, -0.0326,  0.0761])
tensor([-10.3202])


In [20]:
print(train_loader.dataset[0].deltaPhi)

tensor([ -6,   8, -29,  -6, -14,   8, -14, -29])


### Training loop

In [14]:
num_node_features = 5
hidden_dim = BatchSize
output_dim = 1
LearningRate=0.0005
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Using device: {device}")
model = GATRegressor(num_node_features, hidden_dim, output_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LearningRate, weight_decay=0.75)
loss_fn = torch.nn.MSELoss()
print("Model initialized")
print(model)


Using device: cpu
Model initialized
GATRegressor(
  (conv1): GATConv(5, 64, heads=1)
  (conv2): GATConv(64, 64, heads=1)
  (fc1): Linear(in_features=128, out_features=1, bias=True)
)


In [15]:
train_losses = []
test_losses = []
#
# path = "/eos/cms/store/user/folguera/L1TMuon/INTREPID/Model_v240725_241015/"
path = "../model_folder/"
if not os.path.exists(path):
    os.makedirs(path)

In [11]:
def train():
    model.train()
    total_loss = 0
    i = 0
    for data in train_loader:
        data = data.to(device)  # Mueve los datos al dispositivo
        optimizer.zero_grad()
        out = model(data)
        try:
            loss = loss_fn(out, data.y.view(out.size()))
        except KeyError as e:
            print(f"KeyError: {e}")
            # Maneja el error o proporciona un valor por defecto
            loss = torch.tensor(0.0, device=device)

        loss.backward()
        optimizer.step()
        total_loss += float(loss)
    return total_loss / len(train_loader.dataset)

def test(self):
    with torch.no_grad():
        model.eval()
        total_loss = 0
        for data in test_loader:
            data = data.to(device)
            out = model(data)
            loss = loss_fn(out, data.y.view(out.size()))
            total_loss += float(loss)
    return total_loss / len(test_loader.dataset)


print("Start training...")
for epoch in range(100):
    train_loss = train()
    test_loss = test()
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    
    if epoch == 0:
        torch.save(test_loss, f"{path}/testloss_{epoch + 1}.pt")
        torch.save(train_loss, f"{path}/trainloss_{epoch + 1}.pt")
    elif (epoch + 1) % 10 == 0:
        print(f'Epoch: {epoch + 1:02d}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}')
        torch.save(model, f"{path}/model_{epoch + 1}.pth")
        torch.save(test_loss, f"{path}/testloss_{epoch + 1}.pt")
        torch.save(train_loss, f"{path}/trainloss_{epoch + 1}.pt")

        plt.plot(train_losses, "b", label="Train loss")
        plt.plot(test_losses, "k", label="Test loss")
        plt.yscale('log')
        plt.savefig(f"{path}/loss_plot.png")


Start training...
DataBatch(x=[376, 5], edge_index=[2, 1328], y=[64], deltaPhi=[1328], deltaEta=[1328], batch=[376], ptr=[65])
DataBatch(x=[378, 5], edge_index=[2, 1234], y=[64], deltaPhi=[1234], deltaEta=[1234], batch=[378], ptr=[65])
DataBatch(x=[359, 5], edge_index=[2, 1152], y=[64], deltaPhi=[1152], deltaEta=[1152], batch=[359], ptr=[65])
DataBatch(x=[365, 5], edge_index=[2, 1148], y=[64], deltaPhi=[1148], deltaEta=[1148], batch=[365], ptr=[65])
DataBatch(x=[383, 5], edge_index=[2, 1258], y=[64], deltaPhi=[1258], deltaEta=[1258], batch=[383], ptr=[65])
DataBatch(x=[340, 5], edge_index=[2, 1000], y=[64], deltaPhi=[1000], deltaEta=[1000], batch=[340], ptr=[65])
DataBatch(x=[377, 5], edge_index=[2, 1268], y=[64], deltaPhi=[1268], deltaEta=[1268], batch=[377], ptr=[65])
DataBatch(x=[380, 5], edge_index=[2, 1250], y=[64], deltaPhi=[1250], deltaEta=[1250], batch=[380], ptr=[65])
DataBatch(x=[364, 5], edge_index=[2, 1192], y=[64], deltaPhi=[1192], deltaEta=[1192], batch=[364], ptr=[65])
D

KeyError: 'deltaPhi'