In [8]:
import argparse
import copy
import math
import os
from typing import Dict, List

import numpy as np
import torch
from inferred_stypes import dataset2inferred_stypes
from model import Model
from text_embedder import GloveTextEmbedding
from torch.nn import BCEWithLogitsLoss, L1Loss
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.seed import seed_everything
from tqdm import tqdm

from relbench.data import NodeTask, RelBenchDataset
from relbench.data.task_base import TaskType
from relbench.datasets import get_dataset
from relbench.external.graph import get_node_train_table_input, make_pkey_fkey_graph

from torch_geometric.data import HeteroData
from torch_geometric.explain import Explainer, CaptumExplainer
import dgl

from dgl.nn import HeteroGNNExplainer
from torch_geometric.data import HeteroData


In [12]:
# Hyperparameters are declared as command-line style options so the notebook
# stays in step with a script entry point. Each row is (flag, type, default).
cli_options = [
    ("--dataset", str, "rel-stackex"),
    ("--task", str, "rel-stackex-engage"),
    ("--lr", float, 0.01),
    ("--epochs", int, 10),
    ("--batch_size", int, 512),
    ("--channels", int, 128),
    ("--aggr", str, "sum"),
    ("--num_layers", int, 2),
    ("--num_neighbors", int, 128),
    ("--temporal_strategy", str, "uniform"),
    ("--num_workers", int, 1),
]

parser = argparse.ArgumentParser()
for option_flag, option_type, option_default in cli_options:
    parser.add_argument(option_flag, type=option_type, default=option_default)

# An explicit empty argument list makes argparse ignore the kernel's own
# sys.argv, so every option takes the default declared above.
args = parser.parse_args([])


In [10]:
# Seed every RNG up front so the run is repeatable.
seed_everything(42)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Base directory; the materialized-graph cache is created underneath it.
root_dir = "./data"

# TODO: remove process=True once correct data/task is uploaded.
dataset: RelBenchDataset = get_dataset(
    name=args.dataset,
    process=True,
)
task: NodeTask = dataset.get_task(
    args.task,
    process=True,
)

making Database object from raw files...
done in 60.87 seconds.
reindexing pkeys and fkeys...
done in 4.43 seconds.


In [11]:
# Per-column stype assignments looked up for the selected dataset.
col_to_stype_dict = dataset2inferred_stypes[args.dataset]

# Text columns go through the GloVe embedder in batches of 256.
glove_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device),
    batch_size=256,
)

# One cache directory per dataset, placed under root_dir.
materialized_cache_dir = os.path.join(
    root_dir, f"{args.dataset}_materialized_cache"
)

# Turn the relational database into a heterogeneous graph plus column stats.
data, col_stats_dict = make_pkey_fkey_graph(
    dataset.db,
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=glove_embedder_cfg,
    cache_dir=materialized_cache_dir,
)


TypeError: Converting from datetime64[ns] to int32 is not supported. Do obj.astype('int64').astype(dtype) instead

In [ ]:
# Per-layer neighbor budget: args.num_neighbors divided by 2**layer and
# truncated to int. It is the same for every split, so compute it once.
fanout_per_layer = [
    int(args.num_neighbors / 2**layer) for layer in range(args.num_layers)
]
keep_workers_alive = args.num_workers > 0

split_tables = {
    "train": task.train_table,
    "val": task.val_table,
    "test": task.test_table,
}

# One NeighborLoader per split; only the training loader shuffles its seeds.
loader_dict: Dict[str, NeighborLoader] = {}
for split, table in split_tables.items():
    table_input = get_node_train_table_input(table=table, task=task)
    # First element of the loader's input_nodes argument.
    # NOTE(review): presumably the entity node type name -- confirm.
    entity_table = table_input.nodes[0]
    loader_dict[split] = NeighborLoader(
        data,
        # Fresh list per loader, as in the original per-iteration construction.
        num_neighbors=list(fanout_per_layer),
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=args.batch_size,
        temporal_strategy=args.temporal_strategy,
        shuffle=(split == "train"),
        num_workers=args.num_workers,
        persistent_workers=keep_workers_alive,
    )

In [ ]:
# Connectivity of the full graph, keyed by edge type.
edge_index_dict = {
    edge_type: data[edge_type].edge_index for edge_type in data.edge_types
}

# Node inputs of the full graph, keyed by node type.
# NOTE(review): `to_tensor` is a guessed method name on the `tf` object; when
# the attribute does not exist the `tf` object itself is stored unchanged.
x_dict = {}
for node_type in data.node_types:
    node_frame = data[node_type].tf
    if hasattr(node_frame, 'to_tensor'):
        x_dict[node_type] = node_frame.to_tensor()
    else:
        x_dict[node_type] = node_frame
#print(x_dict)
#print("node types:\n",data.node_types)

# if __name__ == '__main__':
#     # Your code that creates and starts processes should go here
#     # This ensures that multiprocessing is handled correctly
#     for batch in loader_dict["train"]:
#         batch = batch.to(device)
#         print("batch:\n", batch)


#print("task.train_table\n",task.train_table)
#print("task.entity_table\n",task.entity_table)

# print("table_input:",table_input)
# print("table_input.nodes:",table_input.nodes)
# print("table_input.nodes[0]:",table_input.nodes[0])


In [ ]:
# homogeneous_data = data.to_homogeneous()
# print("homogeneous_data:\n",homogeneous_data)

# Re-express every PyG edge type as a DGL (source ids, destination ids) pair.
# Row 0 of a PyG edge_index holds the source ids, row 1 the destination ids.
# `.cpu()` is a no-op for CPU tensors and keeps `.numpy()` valid otherwise.
graph_data = {
    canonical_etype: (
        edge_index[0].cpu().numpy(),
        edge_index[1].cpu().numpy(),
    )
    for canonical_etype, edge_index in data.edge_index_dict.items()
}

# Give DGL the true node count of every node type. Without it DGL sizes each
# type as (largest node id that appears in an edge) + 1, which silently drops
# trailing nodes that have no edges and then makes per-node feature tensors
# taken from `data` fail to attach because their row counts no longer match.
num_nodes_dict = {
    node_type: data[node_type].num_nodes for node_type in data.node_types
}

# Create the DGL heterograph
dgl_graph = dgl.heterograph(graph_data, num_nodes_dict=num_nodes_dict)
print("dgl_graph:\n",dgl_graph)
# NOTE(review): 'users' is hard-coded; it is only valid for datasets that have
# a node type with that name.
print("dgl_graph.nodes['users']:\n",dgl_graph.nodes['users'])



In [ ]:
def _tensor_frame_to_matrix(tensor_frame) -> torch.Tensor:
    """Flatten a tensor frame into one dense ``(num_rows, dim)`` float32 matrix.

    Every entry of ``tensor_frame.feat_dict`` is reduced to a dense tensor,
    flattened to two dimensions and concatenated along the feature axis.
    Integer-valued entries are cast to float32 as-is (no one-hot encoding).

    Args:
        tensor_frame: Object exposing ``feat_dict`` and ``num_rows``.
            NOTE(review): assumed to be a torch_frame ``TensorFrame`` -- confirm.

    Returns:
        Tensor of shape ``(num_rows, total_dim)`` with dtype ``torch.float32``.

    Raises:
        TypeError: If an entry cannot be viewed as a dense tensor with one row
            per node (for example ragged or dict-valued data).
    """
    num_rows = tensor_frame.num_rows
    blocks = []
    for key, value in tensor_frame.feat_dict.items():
        # Plain tensors are used directly. Wrapper types are expected to keep
        # a dense payload in `.values` (presumably MultiEmbeddingTensor --
        # TODO confirm against the installed torch_frame version).
        if isinstance(value, torch.Tensor):
            dense = value
        else:
            dense = getattr(value, "values", None)
        is_row_aligned = (
            isinstance(dense, torch.Tensor)
            and dense.dim() >= 2
            and dense.size(0) == num_rows
        )
        if not is_row_aligned:
            raise TypeError(f"Unsupported data type for {key} in TensorFrame.")
        # Collapse trailing axes so every block is (num_rows, k), and unify
        # the dtype so the blocks can be concatenated.
        blocks.append(dense.flatten(start_dim=1).to(torch.float32))
    return torch.cat(blocks, dim=1)


# Attach one feature matrix per node type to the DGL graph. The matrix is
# built in full and assigned once, so re-running this cell overwrites
# 'features' instead of concatenating onto the previous run's result.
for node_type in data.node_types:
    tensor_frame = data[node_type]['tf']
    if not tensor_frame.feat_dict:
        # Nothing to attach for this node type.
        continue
    dgl_graph.nodes[node_type].data['features'] = _tensor_frame_to_matrix(tensor_frame)

# # Add node features for each node type from HeteroData to DGL graph
# for node_type in data.node_types:
#     # Assuming a direct conversion of TensorFrame to tensor, adjust as needed
#     #dgl_graph.nodes[node_type].data['h'] = data[node_type]['tf'].to_tensor()
#     print("data[node_type]['tf']:\n",data[node_type]['tf'])
#     dgl_graph.nodes[node_type].data['h'] = data[node_type]['tf']
#
#
#
# feat = {ntype: dgl_graph.nodes[ntype].data['h'] for ntype in dgl_graph.ntypes}
#
# print("feat:\n",feat)



In [ ]:
# explainer = Explainer(
#     model,  # It is assumed that model outputs a single tensor.
#     algorithm=CaptumExplainer('IntegratedGradients'),
#     explanation_type='model',
#     node_mask_type='attributes',
#     edge_mask_type='object',
#     model_config = dict(
#         mode='binary_classification',
#         task_level="node",
#         return_type='probs',  # Model returns probabilities.
#     ),
# )
#
# hetero_explanation = explainer(
#     x_dict,
#     edge_index_dict,
#     index=torch.tensor([1, 3]),
#  )
# print(hetero_explanation.edge_mask_dict)
# print(hetero_explanation.node_mask_dict)

In [ ]:

# Location of the saved model. The default is the path this notebook was
# written against; set the SAVED_MODEL_PATH environment variable to load the
# file from elsewhere without editing the cell.
model_path = os.environ.get(
    "SAVED_MODEL_PATH",
    "C:\\Users\\Shreya Reddy\\Downloads\\relbenchmain\\examples\\saved_model.pth",
)

# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# files that come from a trusted source.
model = torch.load(model_path, map_location=torch.device('cpu'))

# Explainer over the loaded model, limited to one hop.
explainer = HeteroGNNExplainer(model, num_hops=1)

In [ ]:
# `feat` was never defined in any live cell, so this call raised NameError.
# Build it here: one feature tensor per node type, read from the 'features'
# entry stored on the DGL graph's nodes.
feat = {
    ntype: dgl_graph.nodes[ntype].data['features'] for ntype in dgl_graph.ntypes
}
feat_mask, edge_mask = explainer.explain_graph(dgl_graph, feat)