In [1]:
from ogb.nodeproppred import PygNodePropPredDataset
import os
import os.path as osp
import sys
from torch_geometric.data import Data
import torch
import pandas as pd
import numpy as np
from torch_geometric.data import InMemoryDataset, download_url

# Make the parent directory importable so the project-local packages
# (`utils`, `utils_data`) imported below can be resolved.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
project_root = osp.abspath("..")
if project_root not in sys.path:
    sys.path.append(project_root)
print(sys.path)

from utils.encoder import SentenceEncoder
from utils_data.custom_pyg import CustomPygDataset
from utils.dataloader import GetDataloader
from utils_data.cora import CoraPyGDataset

import warnings
warnings.filterwarnings("ignore")

['/home/prateek/graphs-with-llms-experiments/utils_data', '/home/prateek/miniconda3/envs/torch_pyg/lib/python310.zip', '/home/prateek/miniconda3/envs/torch_pyg/lib/python3.10', '/home/prateek/miniconda3/envs/torch_pyg/lib/python3.10/lib-dynload', '', '/home/prateek/miniconda3/envs/torch_pyg/lib/python3.10/site-packages', '/home/prateek/graphs-with-llms-experiments']


In [2]:
# Dataset root directory, given relative to the notebook's working directory.
data_root = "../data"

In [3]:
# Build the "ST" sentence encoder, with `root` pointing at ../lang_models.
# NOTE(review): the device index is hard-coded to 1 here — confirm that this GPU is the intended one.
LMencoder = SentenceEncoder("ST", root="../lang_models", device=1)
# Construct the Cora dataset wrapper, handing it the encoder created above.
custom_cora = CoraPyGDataset(dataRoot=data_root, custom_dataRoot="../custom_data", sentence_encoder=LMencoder)
# Grab the underlying data object (accessed through the private `_data` attribute).
cora = custom_cora._data
# The return value is discarded; presumably `.to` moves the object's tensors in place — TODO confirm.
cora.to("cpu")
cora

Data(x=[2708, 384], edge_index=[2, 10858], y=[2708], label_names=[7], num_nodes=2708, x_text_feat=[2708, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [4]:
import yaml
from datetime import date

# Value of the `device` field in config.yaml that requests CPU execution;
# any other value is interpolated into a "cuda:<value>" device string.
CPU_DEVICE_SENTINEL = 123

# Explicit encoding so the read does not depend on the platform's default locale.
with open("../config.yaml", "r", encoding="utf-8") as f:
    args = yaml.safe_load(f)

args["device"] = "cpu" if args["device"] == CPU_DEVICE_SENTINEL else f"cuda:{args['device']}"
args["exp_name"] = f"Date -> {date.today()}. Experiment_{args['sentence_encoder']}_{args['exp_name']}"

# Notebook-specific overrides: paths are relative to this notebook rather than
# to the repository root. `dataRoot` reuses `data_root` so the dataset location
# is defined in exactly one place.
args.update(
    encoder_path="../lang_models",
    dataRoot=data_root,
    custom_dataRoot="../custom_data",
    dataset="cora",
    batch_count=5,
)

args

{'exp_name': 'Date -> 2024-02-06. Experiment_ST_evaluation-mode',
 'dataRoot': '../data',
 'custom_dataRoot': '../custom_data',
 'dataset': 'cora',
 'sentence_encoder': 'ST',
 'encoder_path': '../lang_models',
 'state_dict_path': './state_dicts',
 'model_option': 2,
 'model_params': [{'name': 'SAGE', 'in_dim': 768},
  {'name': 'RGCN',
   'num_layers': 2,
   'in_dim': 768,
   'out_dim': 256,
   'edge_attr_dim': 768,
   'num_relations': 6,
   'heads': 1,
   'dropout': 0.3,
   'aggr': 'mean',
   'JK': None,
   'batch_norm': True}],
 'lr': 0.001,
 'epochs': 200,
 'batch_count': 5,
 'batch_size': 100,
 'weight_decay': 0.001,
 'seed': None,
 'device': 'cuda:0',
 'n_way': 5,
 'n_shot': 2,
 'n_query': 1,
 'num_neighbors': [-1],
 'subgraph_type': 'induced'}

In [5]:
# Build the project's dataloader helper, passing every config entry as a keyword argument.
dl = GetDataloader(**args)
dl

<utils.dataloader.GetDataloader at 0x7fb19078caf0>

In [6]:
# Pull the first item produced by the training sampler.
# As the output below shows, it is a list of tasks, each a dict mapping a
# class label to a list of node indices.
batch1 = next(iter(dl.trn_smplr))
batch1

[{4: [688, 330, 1440],
  3: [2116, 1302, 1004],
  1: [495, 1511, 380],
  0: [412, 144, 1583],
  2: [1343, 547, 1602]},
 {4: [1315, 2559, 1825],
  1: [2397, 1129, 1220],
  2: [2466, 547, 845],
  3: [2263, 534, 2246],
  5: [2205, 1295, 1418]},
 {2: [1751, 547, 658],
  1: [2701, 250, 380],
  0: [2622, 2414, 2590],
  3: [2112, 505, 2263],
  4: [330, 1642, 507]},
 {1: [660, 1662, 445],
  2: [658, 1696, 2466],
  3: [1004, 505, 58],
  6: [1610, 1971, 1970],
  5: [1562, 358, 2183]},
 {3: [58, 856, 2263],
  6: [1554, 1970, 2347],
  0: [1186, 2590, 2414],
  5: [1519, 2594, 1418],
  4: [330, 2363, 688]},
 {3: [993, 841, 1004],
  6: [2347, 2261, 797],
  5: [1895, 2181, 348],
  1: [72, 380, 250],
  4: [1642, 1315, 516]},
 {2: [1602, 2242, 954],
  0: [518, 2622, 1878],
  3: [2137, 2555, 1776],
  6: [2191, 1087, 2013],
  5: [1418, 2282, 1741]},
 {2: [2466, 1696, 5],
  4: [516, 688, 2363],
  5: [2183, 358, 789],
  6: [728, 2217, 2191],
  0: [2622, 2325, 412]},
 {6: [2191, 959, 2217],
  1: [660, 1662, 

In [7]:
from torch_geometric.loader import NeighborLoader


def getitem(index):
    """Turn node indices into sampled neighbourhood subgraphs.

    The container structure of `index` is preserved: lists map to lists and
    dicts map to dicts (same keys), with every integer leaf replaced by the
    subgraph sampled around that node.

    Relies on the notebook-level globals `cora` (the graph to sample from) and
    `args` (supplies `num_neighbors` and `subgraph_type`).

    Args:
        index: an integer node index, or an arbitrarily nested list / dict of
            them. Any `numbers.Integral` is accepted, so NumPy integer scalars
            coming out of a sampler work as well as built-in ints.

    Returns:
        The subgraph for a single index, or a list / dict of subgraphs that
        mirrors the structure of `index`.

    Raises:
        IndexError: if a leaf is neither an integer, a list nor a dict.
    """
    import numbers  # local import keeps this cell self-contained

    if isinstance(index, list):
        return [getitem(i) for i in index]
    if isinstance(index, dict):
        return {key: getitem(value) for key, value in index.items()}
    if not isinstance(index, numbers.Integral):
        raise IndexError("Only integers, lists and dictionaries can be used as indices")

    # One loader per seed node; only its first (and only) batch is consumed.
    loader = NeighborLoader(data=cora,
                            num_neighbors=args["num_neighbors"],
                            input_nodes=torch.LongTensor([int(index)]),
                            subgraph_type=args["subgraph_type"])
    subgraph = next(iter(loader))
    # Clear the loader-assigned batch size on the returned subgraph.
    subgraph.batch_size = None

    return subgraph

# Replace every node index in the sampled batch with its subgraph,
# keeping the list-of-dicts structure of `batch1`.
batch = getitem(batch1)
batch

[{4: [Data(x=[2, 384], edge_index=[2, 2], y=[2], label_names=[7], num_nodes=2, x_text_feat=[2, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[2], val_mask=[2], test_mask=[2], n_id=[2], e_id=[2], input_id=[1]),
   Data(x=[2, 384], edge_index=[2, 2], y=[2], label_names=[7], num_nodes=2, x_text_feat=[2, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[2], val_mask=[2], test_mask=[2], n_id=[2], e_id=[2], input_id=[1]),
   Data(x=[9, 384], edge_index=[2, 26], y=[9], label_names=[7], num_nodes=9, x_text_feat=[9, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[9], val_mask=[9], test_mask=[9], n_id=[9]

In [8]:
def process_one_task(task):
    """Flatten one sampled task into parallel graph / label / query-mask sequences.

    Reads `n_shot` and `n_query` from the notebook-level `args` dict.

    Args:
        task: dict mapping an original class label to the graphs sampled for
            that class. The mask built here marks the first `n_shot` graphs of
            each class as support and the following `n_query` as query.

    Returns:
        A 4-tuple `(all_graphs, labels, query_mask, label_map)`:
          * all_graphs: the graphs of all classes concatenated in dict order.
          * labels: tensor with, for each graph, the task-local class index
            (the position of its class in `label_map`).
          * query_mask: bool tensor, True for query graphs, False for support.
          * label_map: list whose i-th entry is the original label of
            task-local class i.

    Raises:
        ValueError: if a class does not hold exactly `n_shot + n_query`
            graphs, which would leave `labels` and `query_mask` misaligned.
    """
    n_shot, n_query = args["n_shot"], args["n_query"]
    per_class = n_shot + n_query

    label_map = list(task)  # dict keys, i.e. the original class labels, in task order
    all_graphs = []
    labels = []
    query_mask = []
    # enumerate() yields the same task-local index a reverse lookup into
    # label_map would, because dict keys are unique and ordered.
    for local_label, (label, graphs) in enumerate(task.items()):
        graphs = list(graphs)
        if len(graphs) != per_class:
            raise ValueError(
                f"class {label!r} has {len(graphs)} graphs, expected "
                f"n_shot + n_query = {per_class}"
            )
        all_graphs.extend(graphs)
        query_mask.extend([False] * n_shot)
        query_mask.extend([True] * n_query)
        labels.extend([local_label] * len(graphs))
    return all_graphs, torch.tensor(labels), torch.tensor(query_mask), label_map


In [9]:
# Run every task of the batch through `process_one_task`, then transpose the
# resulting per-task 4-tuples into four parallel lists (one entry per task).
per_task = [process_one_task(task) for task in batch]
graphs, labels, query_mask, label_map = (list(column) for column in zip(*per_task))

for caption, value in (
    ("graphs = ", graphs),
    ("labels = ", labels),
    ("query_mask = ", query_mask),
    ("label_map = ", label_map),
):
    print(caption, value)

graphs =  [[Data(x=[2, 384], edge_index=[2, 2], y=[2], label_names=[7], num_nodes=2, x_text_feat=[2, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[2], val_mask=[2], test_mask=[2], n_id=[2], e_id=[2], input_id=[1]), Data(x=[2, 384], edge_index=[2, 2], y=[2], label_names=[7], num_nodes=2, x_text_feat=[2, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[2], val_mask=[2], test_mask=[2], n_id=[2], e_id=[2], input_id=[1]), Data(x=[9, 384], edge_index=[2, 26], y=[9], label_names=[7], num_nodes=9, x_text_feat=[9, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[9], val_mask=[9], test_mask=[9], n_id=[9]

In [10]:
# Pick the first two subgraphs of the first task for a small batching experiment.
g1, g2 = graphs[0][0], graphs[0][1]
g1, g2

(Data(x=[2, 384], edge_index=[2, 2], y=[2], label_names=[7], num_nodes=2, x_text_feat=[2, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[2], val_mask=[2], test_mask=[2], n_id=[2], e_id=[2], input_id=[1]),
 Data(x=[2, 384], edge_index=[2, 2], y=[2], label_names=[7], num_nodes=2, x_text_feat=[2, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[2], val_mask=[2], test_mask=[2], n_id=[2], e_id=[2], input_id=[1]))

In [11]:
from torch_geometric.data import Batch

# Collate the two subgraphs into one batch to inspect how attributes are merged.
# In the output below, attributes that every subgraph carries a full copy of
# (e.g. label_text_feat, [7, 768] per graph) come out concatenated ([14, 768]).
gg = Batch.from_data_list([g1, g2])
gg

DataBatch(x=[4, 384], edge_index=[2, 4], y=[4], label_names=[2], num_nodes=4, x_text_feat=[4, 768], label_text_feat=[14, 768], edge_text_feat=[2, 768], prompt_text_edge_feat=[2, 768], prompt_text_feat=[2, 768], prompt_edge_feat=[2, 768], edge_label_feat=[4, 768], train_mask=[4], val_mask=[4], test_mask=[4], n_id=[4], e_id=[4], input_id=[2], batch=[4], ptr=[3])

In [12]:
from torch_geometric.data import Batch
from itertools import chain

# Record the batch geometry first: the nested per-task lists are flattened
# (and their names rebound) further down.
num_task = len(graphs)
task_len = len(graphs[0])
num_labels = len(label_map[0])

for caption, value in (
    ("num_task = ", num_task),
    ("task_len = ", task_len),
    ("num_labels = ", num_labels),
):
    print(caption, value)

# NOTE: from here on `graphs`, `labels`, `query_mask` and `label_map` are
# rebound to their flattened forms, so this cell cannot be re-run without
# re-running the cell that produced the per-task lists.
flat_graphs = list(chain.from_iterable(graphs))
graphs = Batch.from_data_list(flat_graphs)
# The per-task mask matrix must be built before `query_mask` is flattened.
b_mask = torch.stack(query_mask)
labels = torch.cat(labels)
query_mask = torch.cat(query_mask)
label_map = list(chain.from_iterable(label_map))

for caption, value in (
    ("graphs = ", graphs),
    ("labels = ", labels),
    ("b_mask = ", b_mask),
    ("query_mask = ", query_mask),
    ("label_map = ", label_map),
):
    print(caption, value)

num_task =  100
task_len =  15
num_labels =  5
graphs =  DataBatch(x=[7326, 384], edge_index=[2, 17202], y=[7326], label_names=[1500], num_nodes=7326, x_text_feat=[7326, 768], label_text_feat=[10500, 768], edge_text_feat=[1500, 768], prompt_text_edge_feat=[1500, 768], prompt_text_feat=[1500, 768], prompt_edge_feat=[1500, 768], edge_label_feat=[3000, 768], train_mask=[7326], val_mask=[7326], test_mask=[7326], n_id=[7326], e_id=[17202], input_id=[1500], batch=[7326], ptr=[1501])
labels =  tensor([0, 0, 0,  ..., 4, 4, 4])
b_mask =  tensor([[False, False,  True,  ..., False, False,  True],
        [False, False,  True,  ..., False, False,  True],
        [False, False,  True,  ..., False, False,  True],
        ...,
        [False, False,  True,  ..., False, False,  True],
        [False, False,  True,  ..., False, False,  True],
        [False, False,  True,  ..., False, False,  True]])
query_mask =  tensor([False, False,  True,  ..., False, False,  True])
label_map =  [4, 3, 1, 0, 2, 4, 

In [13]:
# Metagraph layout: the batched graphs are numbered 0 .. num_graphs-1 and the
# label nodes follow, starting at num_graphs, with `num_labels` consecutive
# label nodes per task. Every graph gets one edge to each label node of its
# own task.
num_graphs = labels.size(0)

# Source side: each graph index repeated once per label.
metagraph_edge_source = torch.arange(num_graphs).repeat_interleave(num_labels)

# Target side: task-local label index, shifted to the task's block of label
# nodes and then past all graph indices.
task_offset = (torch.arange(num_task) * num_labels).repeat_interleave(task_len * num_labels)
metagraph_edge_target = torch.arange(num_labels).repeat(num_graphs) + task_offset + num_graphs

metagraph_edge_index = torch.stack([metagraph_edge_source, metagraph_edge_target], dim=0)

# An edge is masked when the graph it starts from is a query graph.
metagraph_edge_mask = query_mask.repeat_interleave(num_labels)

# Edge value: +1 for the true label, -1 for the others, 0 on masked (query) edges.
metagraph_edge_attr = torch.nn.functional.one_hot(labels, num_labels).float()
metagraph_edge_attr = metagraph_edge_attr.reshape(-1)
metagraph_edge_attr = (metagraph_edge_attr * 2 - 1) * (~metagraph_edge_mask)

# Two columns per edge: [is_query, signed label indicator]. The bool mask is
# cast explicitly instead of relying on implicit dtype promotion in stack().
metagraph_edge_attr = torch.stack([metagraph_edge_mask.float(), metagraph_edge_attr], dim=1)

# Placeholder label table filled with running integers, sized after the
# dataset's label text features instead of hard-coded dimensions.
total_classes, text_dim = cora.label_text_feat.shape
label_meta = torch.arange(total_classes * text_dim).reshape(total_classes, text_dim)

# Look up one row per (task, label) slot using the original class labels.
label_map = torch.tensor(label_map)
label_embeddings = label_meta[label_map]

# num_classes is passed explicitly (as for the edge attributes above) so the
# width is always `num_labels`, not whatever the largest label present implies.
labels_onehot = torch.nn.functional.one_hot(labels, num_labels).float()

In [14]:
# Display everything built so far as one tuple: the collated graphs, the label
# embeddings, the one-hot labels and the metagraph edge index / attributes / mask.
graphs, label_embeddings, labels_onehot, metagraph_edge_index, metagraph_edge_attr, metagraph_edge_mask # True for query_mask

(DataBatch(x=[7326, 384], edge_index=[2, 17202], y=[7326], label_names=[1500], num_nodes=7326, x_text_feat=[7326, 768], label_text_feat=[10500, 768], edge_text_feat=[1500, 768], prompt_text_edge_feat=[1500, 768], prompt_text_feat=[1500, 768], prompt_edge_feat=[1500, 768], edge_label_feat=[3000, 768], train_mask=[7326], val_mask=[7326], test_mask=[7326], n_id=[7326], e_id=[17202], input_id=[1500], batch=[7326], ptr=[1501]),
 tensor([[3072, 3073, 3074,  ..., 3837, 3838, 3839],
         [2304, 2305, 2306,  ..., 3069, 3070, 3071],
         [ 768,  769,  770,  ..., 1533, 1534, 1535],
         ...,
         [2304, 2305, 2306,  ..., 3069, 3070, 3071],
         [ 768,  769,  770,  ..., 1533, 1534, 1535],
         [1536, 1537, 1538,  ..., 2301, 2302, 2303]]),
 tensor([[1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.]]),
 tensor([[   0,    0,    0,  ..., 149