In [18]:
import os
import sys
import time
import yaml
import torch
import pandas as pd
import numpy as np 
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from pytorch_lightning import LightningModule
from torch_geometric.nn import aggr
from torch.utils.checkpoint import checkpoint
import torch.optim as optim
from torch_scatter import scatter_add
from torch import Tensor
from torch_geometric.nn import MessagePassing

from torch_geometric.utils import to_networkx
from torch.nn import Sequential as Seq, Linear, ReLU, Sigmoid
from torch.optim.lr_scheduler import StepLR
from collections import namedtuple

from pytorch_lightning import LightningModule
import pytorch_lightning as pl


In [19]:

TORCH = torch.__version__
CUDA = torch.version.cuda

print(f"Formatted PyTorch Version: {TORCH}")
print(f"Formatted CUDA Version: {CUDA}")


Formatted PyTorch Version: 2.2.1+cu121
Formatted CUDA Version: 12.1


In [20]:
import os
import torch
from torch.utils.data import Dataset

class GraphDataset(Dataset):
    def __init__(self, home_dir, sub_dir, preprocess=True, hparams=None):
        self.base_path = os.path.join(home_dir, sub_dir)
        self.file_names = self._get_file_names()
        self.preprocess = preprocess
        self.hparams = hparams if hparams is not None else {}

    def _get_file_names(self):
        file_names = []
        for file_name in os.listdir(self.base_path):
            if file_name.endswith('.pyg'):  # Adjust this condition as needed
                file_names.append(file_name)
        return file_names

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        file_path = os.path.join(self.base_path, file_name)
        data = torch.load(file_path)
        print(f"Loaded data from {file_path}")

        # Remove 'scores' from data if it exists
        if 'scores' in data:
            del data['scores']

        
        if self.preprocess:
            data = self.preprocess_event(data)
        return data

    def __len__(self):
        return len(self.file_names)

    def preprocess_event(self, data):
        print("Preprocessing data")
        data = self.add_edge_features(data)
        return data

    def add_edge_features(self, data):
        edge_features = self.hparams.get("edge_features", [])
        data = handle_edge_features(data, edge_features)
        return data

def handle_edge_features(data, edge_features):
    src, dst = data.edge_index
    
    if "dr" in edge_features and not ("dr" in data.keys()):
        data.dr = data.r[dst] - data.r[src]
    
    return data

In [21]:

hparams = {
    "stage": "edge_classifier",
    "model": "InteractionGNN2",
    "input_dir": "/users/santoshp/standalone_IN_gnn/data/",
    "stage_dir": "/users/santoshp/standalone_IN_gnn/",
    "project": "Stand_alone",
    "gpus": 1,
    "nodes": 1,
    "data_split": [10, 5, 5],
    "dataset_class": "GraphDataset",
    "undirected": False,
    "weighting": [
        {"weight": 0.1, "conditions": {"y": False}},
        {"weight": 0.0, "conditions": {"y": True}},
        {
            "weight": 1.0,
            "conditions": {
                "y": True,
                "pt": [1000, float("inf")],
                "nhits": [3, float("inf")],
                "primary": True,
                "pdgId": ["not_in", [11, -11]],
                "radius": [0.0, 260.0],
                "eta_particle": [-4.0, 4.0],
                "redundant_split_edges": False,
            },
        },
    ],
    "edge_cut": 0.5,
    "node_features": [
        "r",
        "phi",
        "z",
        "eta",
        "cluster_r_1",
        "cluster_phi_1",
        "cluster_z_1",
        "cluster_eta_1",
        "cluster_r_2",
        "cluster_phi_2",
        "cluster_z_2",
        "cluster_eta_2",
    ],
    "node_scales": [
        1000.0,
        3.14159265359,
        1000.0,
        1.0,
        1000.0,
        3.14159265359,
        1000.0,
        1.0,
        1000.0,
        3.14159265359,
        1000.0,
        1.0,
    ],
    "edge_features": ["dr", "dphi", "dz", "deta", "phislope", "rphislope"],
    "hidden": 128,
    "n_graph_iters": 8,
    "n_node_encoder_layers": 3,
    "n_edge_encoder_layers": 3,
    "n_node_net_layers": 3,
    "n_edge_net_layers": 3,
    "n_node_decoder_layers": 3,
    "n_edge_decoder_layers": 3,
    "layernorm": False,
    "output_layer_norm": False,
    "edge_output_transform_final_layer_norm": False,
    "batchnorm": False,
    "output_batch_norm": False,
    "edge_output_transform_final_batch_norm": False,
    "bn_track_running_stats": False,
    "hidden_activation": "ReLU",
    "output_activation": "ReLU",
    "edge_output_transform_final_activation": None,
    "concat": True,
    "node_net_recurrent": False,
    "edge_net_recurrent": False,
    "in_out_diff_agg": True,
    "checkpointing": True,
    "warmup": 5,
    "lr": 0.0005,
    "min_lr": 0.000005,
    "factor": 0.9,
    "patience": 15,
    "max_epochs": 1,
    "max_training_graph_size": 2800000,
    "debug": False,
    "num_workers": [8, 8, 8],
}


In [22]:
home_dir = '/users/santoshp/standalone_IN_gnn/data/'

test = 'test_set/'
train = 'train_set/'
val = 'val_set/'


test_dataset = GraphDataset(home_dir, test, preprocess=True, hparams=hparams)
train_dataset = GraphDataset(home_dir, train, preprocess=True, hparams=hparams)
val_dataset = GraphDataset(home_dir, val, preprocess=True, hparams=hparams)


print(train_dataset[0])

Loaded data from /users/santoshp/standalone_IN_gnn/data/train_set/event005000002.pyg
Preprocessing data
DataBatch(x=[305932], edge_index=[2, 42890], y=[42890], cluster_x_2=[305932], z=[305932], r=[305932], norm_z_1=[305932], norm_y_2=[305932], cluster_z_2=[305932], cluster_y_1=[305932], hit_id=[305932], norm_y_1=[305932], phi_angle_1=[305932], cluster_r_1=[305932], cluster_phi_1=[305932], phi_angle_2=[305932], cluster_eta_2=[305932], norm_x_2=[305932], module_id=[305932], cluster_z_1=[305932], phi=[305932], norm_z_2=[305932], cluster_y_2=[305932], region=[305932], eta_angle_2=[305932], cluster_x_1=[305932], eta_angle_1=[305932], eta=[305932], norm_x_1=[305932], cluster_eta_1=[305932], cluster_phi_2=[305932], cluster_r_2=[305932], track_edges=[2, 15712], particle_id=[15712], pt=[15712], pdgId=[15712], radius=[15712], redundant_split_edges=[15712], primary=[15712], eta_particle=[15712], nhits=[15712], config=[2], event_id=[1], truth_map=[15712], phi_region_id=[1], eta_region_id=[1], weig

In [23]:
params = {'batch_size': 8, 'shuffle': True, 'num_workers': 0}
train_loader = DataLoader(train_dataset,**params)  #batches join graphs instead of splitting them therefore more than train set batches will make 1 batch only 
test_loader = DataLoader(test_dataset, **params)
val_loader = DataLoader(val_dataset, **params)

In [24]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [25]:
def make_mlp(
    input_size,
    sizes,
    hidden_activation="ReLU",
    output_activation=None,
    layer_norm=False,  # TODO : change name to hidden_layer_norm while ensuring backward compatibility
    output_layer_norm=False,
    batch_norm=False,  # TODO : change name to hidden_batch_norm while ensuring backward compatibility
    output_batch_norm=False,
    input_dropout=0,
    hidden_dropout=0,
    track_running_stats=False,
):
    """Construct an MLP with specified fully-connected layers."""
    hidden_activation = getattr(nn, hidden_activation)
    if output_activation is not None:
        output_activation = getattr(nn, output_activation)
    layers = []
    n_layers = len(sizes)
    sizes = [input_size] + sizes
    # Hidden layers
    for i in range(n_layers - 1):
        if i == 0 and input_dropout > 0:
            layers.append(nn.Dropout(input_dropout))
        layers.append(nn.Linear(sizes[i], sizes[i + 1]))
        if layer_norm:  # hidden_layer_norm
            layers.append(nn.LayerNorm(sizes[i + 1], elementwise_affine=False))
        if batch_norm:  # hidden_batch_norm
            layers.append(
                nn.BatchNorm1d(
                    sizes[i + 1],
                    eps=6e-05,
                    track_running_stats=track_running_stats,
                    affine=True,
                )  # TODO : Set BatchNorm and LayerNorm parameters in config file ?
            )
        layers.append(hidden_activation())
        if hidden_dropout > 0:
            layers.append(nn.Dropout(hidden_dropout))
    # Final layer
    layers.append(nn.Linear(sizes[-2], sizes[-1]))
    if output_activation is not None:
        if output_layer_norm:
            layers.append(nn.LayerNorm(sizes[-1], elementwise_affine=False))
        if output_batch_norm:
            layers.append(
                nn.BatchNorm1d(
                    sizes[-1],
                    eps=6e-05,
                    track_running_stats=track_running_stats,
                    affine=True,
                )  # TODO : Set BatchNorm and LayerNorm parameters in config file ?
            )
        layers.append(output_activation())
    return nn.Sequential(*layers)

In [26]:
class InteractionGNN2(LightningModule):

    """
    Interaction Network (L2IT version).
    Operates on directed graphs.
    Aggregate and reduce (sum) separately incomming and outcoming edges latents.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        hparams["batchnorm"] = (
            False if "batchnorm" not in hparams else hparams["batchnorm"]
        )
        hparams["output_batch_norm"] = hparams.get("output_batch_norm", False)
        hparams["edge_output_transform_final_batch_norm"] = hparams.get(
            "edge_output_transform_final_batch_norm", False
        )
        hparams["edge_output_transform_final_batch_norm"] = hparams.get(
            "edge_output_transform_final_batch_norm", False
        )
        hparams["track_running_stats"] = (
            False
            if "track_running_stats" not in hparams
            else hparams["track_running_stats"]
        )

        # TODO: Add equivalent check and default values for other model parameters ?
        # TODO: Use get() method

        # Define the dataset to be used, if not using the default
        #self.save_hyperparameters(hparams)

        # self.setup_layer_sizes()

        if hparams["concat"]:
            if hparams["in_out_diff_agg"]:
                in_node_net = hparams["hidden"] * 4
            else:
                in_node_net = hparams["hidden"] * 3
            in_edge_net = hparams["hidden"] * 6
        else:
            if hparams["in_out_diff_agg"]:
                in_node_net = hparams["hidden"] * 3
            else:
                in_node_net = hparams["hidden"] * 2
            in_edge_net = hparams["hidden"] * 3
        # node encoder
        self.node_encoder = make_mlp(
            input_size=len(hparams["node_features"]),
            sizes=[hparams["hidden"]] * hparams["n_node_net_layers"],
            output_activation=hparams["output_activation"],
            hidden_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
            output_batch_norm=hparams["output_batch_norm"],
            track_running_stats=hparams["track_running_stats"],
        )
     
        # edge encoder
        if "edge_features" in hparams and len(hparams["edge_features"]) != 0:
            self.edge_encoder = make_mlp(
                input_size=len(hparams["edge_features"]),
                sizes=[hparams["hidden"]] * hparams["n_edge_net_layers"],
                output_activation=hparams["output_activation"],
                hidden_activation=hparams["hidden_activation"],
                layer_norm=hparams["layernorm"],
                batch_norm=hparams["batchnorm"],
                output_batch_norm=hparams["output_batch_norm"],
                track_running_stats=hparams["track_running_stats"],
            )
        else:
            self.edge_encoder = make_mlp(
                input_size=2 * hparams["hidden"],
                sizes=[hparams["hidden"]] * hparams["n_edge_net_layers"],
                output_activation=hparams["output_activation"],
                hidden_activation=hparams["hidden_activation"],
                layer_norm=hparams["layernorm"],
                batch_norm=hparams["batchnorm"],
                output_batch_norm=hparams["output_batch_norm"],
                track_running_stats=hparams["track_running_stats"],
            )

        # edge network
        if hparams["edge_net_recurrent"]:
            self.edge_network = make_mlp(
                input_size=in_edge_net,
                sizes=[hparams["hidden"]] * hparams["n_edge_net_layers"],
                output_activation=hparams["output_activation"],
                hidden_activation=hparams["hidden_activation"],
                layer_norm=hparams["layernorm"],
                batch_norm=hparams["batchnorm"],
                output_batch_norm=hparams["output_batch_norm"],
                track_running_stats=hparams["track_running_stats"],
            )
        else:
            self.edge_network = nn.ModuleList(
                [
                    make_mlp(
                        input_size=in_edge_net,
                        sizes=[hparams["hidden"]] * hparams["n_edge_net_layers"],
                        output_activation=hparams["output_activation"],
                        hidden_activation=hparams["hidden_activation"],
                        layer_norm=hparams["layernorm"],
                        batch_norm=hparams["batchnorm"],
                        output_batch_norm=hparams["output_batch_norm"],
                        track_running_stats=hparams["track_running_stats"],
                    )
                    for i in range(hparams["n_graph_iters"])
                ]
            )
        # node network
        if hparams["node_net_recurrent"]:
            self.node_network = make_mlp(
                input_size=in_node_net,
                sizes=[hparams["hidden"]] * hparams["n_node_net_layers"],
                output_activation=hparams["output_activation"],
                hidden_activation=hparams["hidden_activation"],
                layer_norm=hparams["layernorm"],
                batch_norm=hparams["batchnorm"],
                output_batch_norm=hparams["output_batch_norm"],
                track_running_stats=hparams["track_running_stats"],
            )
        else:
            self.node_network = nn.ModuleList(
                [
                    make_mlp(
                        input_size=in_node_net,
                        sizes=[hparams["hidden"]] * hparams["n_node_net_layers"],
                        output_activation=hparams["output_activation"],
                        hidden_activation=hparams["hidden_activation"],
                        layer_norm=hparams["layernorm"],
                        batch_norm=hparams["batchnorm"],
                        output_batch_norm=hparams["output_batch_norm"],
                        track_running_stats=hparams["track_running_stats"],
                    )
                    for i in range(hparams["n_graph_iters"])
                ]
            )

        # edge decoder
        self.edge_decoder = make_mlp(
            input_size=hparams["hidden"],
            sizes=[hparams["hidden"]] * hparams["n_edge_decoder_layers"],
            output_activation=hparams["output_activation"],
            hidden_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
            output_batch_norm=hparams["output_batch_norm"],
            track_running_stats=hparams["track_running_stats"],
        )
        # edge output transform layer
        self.edge_output_transform = make_mlp(
            input_size=hparams["hidden"],
            sizes=[hparams["hidden"], 1],
            output_activation=hparams["edge_output_transform_final_activation"],
            hidden_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
            output_batch_norm=hparams["edge_output_transform_final_batch_norm"],
            track_running_stats=hparams["track_running_stats"],
        )

        # dropout layer
        self.dropout = nn.Dropout(p=0.1)
        # hyperparams
        # self.hparams = hparams



    ###############################

    def training_step(self, batch, batch_idx):
        # Forward pass
        output = self(batch)
        
        # Dummy loss calculation
        loss = output.mean()
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)

        return loss
        
    def validation_step(self, batch, batch_idx):
        # Forward pass
        output = self(batch)
        
        # Dummy loss calculation
        loss = output.mean()
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)

        return loss
        
    def testing_step(self, batch, batch_idx):
        # Forward pass
        output = self(batch)
        
        # Dummy loss calculation
        loss = output.mean()

        return loss
        
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.get("lr", 0.001))
        return optimizer


    def train_dataloader(self):
        """
        Load the training set.
        """
        return DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)

    
    def val_dataloader(self):
        """
        Load the val set.
        """
        return DataLoader(val_dataset, batch_size=1, shuffle=True, num_workers=4)

    
    def test_dataloader(self):
        """
        Load the test set.
        """
        return DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=4)
    ###########################

    def forward(self, batch):
        x = torch.stack(
            [batch[feature] for feature in self.hparams["node_features"]], dim=-1
        ).float()

        # Same features on the 3 channels in the STRIP ENDCAP TODO: Process it in previous stage
        mask = torch.logical_or(batch.region == 2, batch.region == 6).reshape(-1)
        x[mask] = torch.cat([x[mask, 0:4], x[mask, 0:4], x[mask, 0:4]], dim=1)
        # print(x[:, 8:12])
        

        if "edge_features" in self.hparams and len(self.hparams) != 0:
            edge_attr = torch.stack(
                [batch[feature] for feature in self.hparams["edge_features"]], dim=-1
            ).float()
        else:
            edge_attr = None

        x.requires_grad = True
        if edge_attr is not None:
            edge_attr.requires_grad = True

        # Get src and dst
        src, dst = batch.edge_index

####################
         # Call handle_edge_features function to handle edge features
        #handle_edge_features(batch, self.hparams.get("edge_features", []))
####################

        # Encode nodes and edges features into latent spaces
        if self.hparams["checkpointing"]:
            x = checkpoint(self.node_encoder, x)
            if edge_attr is not None:
                e = checkpoint(self.edge_encoder, edge_attr)
            else:
                e = checkpoint(self.edge_encoder, torch.cat([x[src], x[dst]], dim=-1))
        else:
            x = self.node_encoder(x)
            if edge_attr is not None:
                e = self.edge_encoder(edge_attr)
            else:
                e = self.edge_encoder(torch.cat([x[src], x[dst]], dim=-1))
        # Apply dropout
        # x = self.dropout(x)
        # e = self.dropout(e)

        # memorize initial encodings for concatenate in the gnn loop if request
        if self.hparams["concat"]:
            input_x = x
            input_e = e
        # Initialize outputs
        outputs = []
        # Loop over gnn layers
        for i in range(self.hparams["n_graph_iters"]):
            if self.hparams["checkpointing"]:
                if self.hparams["concat"]:
                    x = checkpoint(self.concat, x, input_x)
                    e = checkpoint(self.concat, e, input_e)
                if (
                    self.hparams["node_net_recurrent"]
                    and self.hparams["edge_net_recurrent"]
                ):
                    x, e, out = checkpoint(self.message_step, x, e, src, dst)
                else:
                    x, e, out = checkpoint(self.message_step, x, e, src, dst, i)
            else:
                if self.hparams["concat"]:
                    x = torch.cat([x, input_x], dim=-1)
                    e = torch.cat([e, input_e], dim=-1)
                if (
                    self.hparams["node_net_recurrent"]
                    and self.hparams["edge_net_recurrent"]
                ):
                    x, e, out = self.message_step(x, e, src, dst)
                else:
                    x, e, out = self.message_step(x, e, src, dst, i)
            outputs.append(out)
        return outputs[-1].squeeze(-1)

    def message_step(self, x, e, src, dst, i=None):
        #print("Shape of e tensor:", e.size())
        #print("Shape of x[src] tensor:", x[src].size())
        #print("Shape of x[dst] tensor:", x[dst].size())
        
        #assert e.shape[0] == x[src].shape[0] #"Number of rows in e and x[src] must match"
        x_src = x[src].unsqueeze(0)  # Add an extra dimension
        x_dst = x[dst].unsqueeze(0)  # Add an extra dimension


        edge_inputs = torch.cat([e, x[src], x[dst]], dim=-1)  # order dst src x ?
        if self.hparams["edge_net_recurrent"]:
            e_updated = self.edge_network(edge_inputs)
        else:
            e_updated = self.edge_network[i](edge_inputs)
        # Update nodes
        edge_messages_from_src = scatter_add(e_updated, dst, dim=0, dim_size=x.shape[0])
        edge_messages_from_dst = scatter_add(e_updated, src, dim=0, dim_size=x.shape[0])
        if self.hparams["in_out_diff_agg"]:
            node_inputs = torch.cat(
                [edge_messages_from_src, edge_messages_from_dst, x], dim=-1
            )  # to check : the order dst src  x ?
        else:
            # add message from src and dst ?? # edge_messages = edge_messages_from_src + edge_messages_from_dst
            edge_messages = edge_messages_from_src + edge_messages_from_dst
            node_inputs = torch.cat([edge_messages, x], dim=-1)
        # x_updated = self.dropout(self.node_network[i](node_inputs))
        if self.hparams["node_net_recurrent"]:
            x_updated = self.node_network(node_inputs)
        else:
            x_updated = self.node_network[i](node_inputs)

        return (
            x_updated,
            e_updated,
            self.edge_output_transform(self.edge_decoder(e_updated)),
        )

    def concat(self, x, y):
        return torch.cat([x, y], dim=-1)



In [27]:
model = InteractionGNN2(hparams)
model

InteractionGNN2(
  (node_encoder): Sequential(
    (0): Linear(in_features=12, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=128, bias=True)
    (5): ReLU()
  )
  (edge_encoder): Sequential(
    (0): Linear(in_features=6, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=128, bias=True)
    (5): ReLU()
  )
  (edge_network): ModuleList(
    (0-7): 8 x Sequential(
      (0): Linear(in_features=768, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): ReLU()
    )
  )
  (node_network): ModuleList(
    (0-7): 8 x Sequential(
      (0): Linear(in_features=512, out_features=128, bias=True)
      (1): ReLU(

In [28]:
trainer = pl.Trainer(max_epochs=2)
trainer.fit(model) 

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name                  | Type       | Params
-----------------------------------------------------
0 | node_encoder          | Sequential | 34.7 K
1 | edge_encoder          | Sequential | 33.9 K
2 | edge_network          | ModuleList | 1.1 M 
3 | node_network          | ModuleList | 789 K 
4 | edge_decoder          | Sequential | 49.5 K
5 | edge_output_transform | Sequential | 16.6 K
6 | dropout               | Dropout    | 0     
-----------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
7.904     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]Loaded data from /users/santoshp/standalone_IN_gnn/data/val_set/event005000003.pyg
Preprocessing data
Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s]                             Loaded data from /users/santoshp/standalone_IN_gnn/data/train_set/event005000002.pyg
Preprocessing data
Epoch 0: 100%|██████████| 1/1 [00:51<00:00,  0.02it/s, v_num=60, train_loss_step=0.055]Loaded data from /users/santoshp/standalone_IN_gnn/data/val_set/event005000003.pyg
Preprocessing data

Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|██████████| 1/1 [00:10<00:00,  0.09it/s][A
Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s, v_num=60, train_loss_step=0.055, val_loss=0.0182, train_loss_epoch=0.055]        Loaded data from /users/santoshp/standalone_IN_gnn/data/train_set/event005000002.pyg
Preprocessing data
Epoch

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 1/1 [01:06<00:00,  0.02it/s, v_num=60, train_loss_step=0.0192, val_loss=-0.0312, train_loss_epoch=0.0192]


In [29]:
torch.save(model, 'complete_model.pth')

In [30]:
model = torch.load('complete_model.pth')
model.eval()


InteractionGNN2(
  (node_encoder): Sequential(
    (0): Linear(in_features=12, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=128, bias=True)
    (5): ReLU()
  )
  (edge_encoder): Sequential(
    (0): Linear(in_features=6, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=128, bias=True)
    (5): ReLU()
  )
  (edge_network): ModuleList(
    (0-7): 8 x Sequential(
      (0): Linear(in_features=768, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): ReLU()
    )
  )
  (node_network): ModuleList(
    (0-7): 8 x Sequential(
      (0): Linear(in_features=512, out_features=128, bias=True)
      (1): ReLU(

In [31]:
with torch.no_grad():
    for batch in val_loader:
        edge_scores = model(batch)
        print(edge_scores)
        print(edge_scores.size(0))
        positive_scores = edge_scores[edge_scores > 0]
        print(positive_scores)
        print(positive_scores.size(0))

Loaded data from /users/santoshp/standalone_IN_gnn/data/val_set/event005000003.pyg
Preprocessing data
tensor([ 0.0078, -0.0145, -0.0326,  ..., -0.2122, -0.2273, -0.2223])
39573
tensor([0.0078, 0.0099, 0.0036,  ..., 0.0007, 0.0019, 0.0033])
16543
