In [1]:
from ogb.nodeproppred import PygNodePropPredDataset
import os
import os.path as osp
import sys
from torch_geometric.data import Data
import torch
import pandas as pd
import numpy as np
from torch_geometric.data import InMemoryDataset, download_url

# Make the parent directory importable so the `utils` and `utils_data` packages
# imported below resolve. The membership check keeps the cell idempotent:
# re-running it no longer appends the same entry to sys.path again.
repo_root = osp.abspath("..")
if repo_root not in sys.path:
    sys.path.append(repo_root)
print(sys.path)

from utils.encoder import SentenceEncoder
from utils_data.custom_pyg import CustomPygDataset
from utils.dataloader import GetDataloader
from utils_data.cora import CoraPyGDataset

import warnings
warnings.filterwarnings("ignore")

['/home/prateek/graphs-with-llms-experiments/utils_data', '/home/prateek/miniconda3/envs/torch_pyg/lib/python310.zip', '/home/prateek/miniconda3/envs/torch_pyg/lib/python3.10', '/home/prateek/miniconda3/envs/torch_pyg/lib/python3.10/lib-dynload', '', '/home/prateek/miniconda3/envs/torch_pyg/lib/python3.10/site-packages', '/home/prateek/graphs-with-llms-experiments']


In [2]:
# Dataset root directory; passed as `dataRoot` to CoraPyGDataset in the next cell.
data_root = "../data"

In [3]:
# Text encoder handed to the dataset wrapper. "ST" selects the encoder variant and
# `device=1` is presumably a GPU index -- TODO confirm against utils.encoder.
LMencoder = SentenceEncoder("ST", root="../lang_models", device=1)
custom_cora = CoraPyGDataset(
    dataRoot=data_root,
    custom_dataRoot="../custom_data",
    sentence_encoder=LMencoder,
)
# Work on the dataset's underlying graph object, moved to the CPU, and display it.
cora = custom_cora._data.to("cpu")
cora

Data(x=[2708, 384], edge_index=[2, 10858], y=[2708], label_names=[7], num_nodes=2708, x_text_feat=[2708, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [4]:
import yaml
from datetime import date

# Read the experiment configuration shared with the training scripts.
with open("../config.yaml", "r") as f:
    args = yaml.safe_load(f)

# Derive / override entries in one step. Every value is computed from the
# freshly loaded config before any key is written back.
args.update({
    # NOTE(review): 123 looks like the config's sentinel for "run on CPU";
    # any other value is used as a CUDA device index -- confirm.
    "device": "cpu" if args["device"] == 123 else f"cuda:{args['device']}",
    "exp_name": f"Date -> {date.today()}. Experiment_{args['sentence_encoder']}_{args['exp_name']}",
    # Notebook-local overrides: paths are relative to this notebook's directory.
    "encoder_path": "../lang_models",
    "dataRoot": "../data",
    "custom_dataRoot": "../custom_data",
    "dataset": "cora",
    "batch_count": 5,
})

args

{'exp_name': 'Date -> 2024-02-07. Experiment_ST_ofa-1st-run',
 'dataRoot': '../data',
 'custom_dataRoot': '../custom_data',
 'dataset': 'cora',
 'sentence_encoder': 'ST',
 'encoder_path': '../lang_models',
 'state_dict_path': './state_dicts',
 'model_option': 2,
 'model_params': [{'name': 'SAGE', 'in_dim': 768},
  {'name': 'RGCN',
   'num_layers': 2,
   'in_dim': 768,
   'out_dim': 256,
   'edge_attr_dim': 768,
   'num_relations': 6,
   'heads': 1,
   'dropout': 0.3,
   'aggr': 'mean',
   'JK': None,
   'batch_norm': True}],
 'lr': 0.001,
 'epochs': 1,
 'batch_count': 5,
 'batch_size': 16,
 'val_check_interval': 100,
 'weight_decay': 0.001,
 'seed': None,
 'device': 'cuda:2',
 'n_way': 5,
 'n_shot': 2,
 'n_query': 1,
 'num_neighbors': [-1],
 'subgraph_type': 'induced'}

In [5]:
# Build the project's dataloader wrapper; every config entry is forwarded as a
# keyword argument.
dl = GetDataloader(**args)
dl

<utils.dataloader.GetDataloader at 0x7f0a40b6f8b0>

In [6]:
# Draw the first batch from `dl.trn_smplr` (presumably the training sampler).
# As the output below shows, a batch is a list of tasks, each a dict mapping a
# label id to a list of node indices.
batch1 = next(iter(dl.trn_smplr))
batch1

[{1: [2419, 1195, 660],
  5: [1519, 1562, 348],
  2: [174, 547, 1343],
  6: [772, 959, 1234],
  3: [1302, 2112, 1004]},
 {6: [1610, 772, 160],
  1: [1511, 2397, 495],
  3: [299, 2246, 856],
  2: [547, 1751, 954],
  0: [2458, 412, 2622]},
 {2: [658, 1499, 1696],
  4: [2067, 2519, 2559],
  3: [2033, 534, 2030],
  0: [1736, 886, 2325],
  5: [2482, 1895, 30]},
 {2: [845, 1751, 576],
  6: [2337, 1554, 2217],
  4: [1315, 2559, 330],
  3: [2137, 739, 1302],
  0: [1878, 1373, 2325]},
 {6: [2013, 1610, 2299],
  0: [886, 2622, 2458],
  3: [2246, 2116, 1004],
  4: [688, 2519, 1246],
  5: [1418, 2482, 2205]},
 {5: [1041, 348, 1895],
  1: [445, 1220, 2323],
  0: [2590, 1583, 1878],
  3: [993, 1302, 58],
  2: [1751, 2466, 2467]},
 {0: [412, 1583, 2325],
  2: [845, 2242, 1343],
  3: [534, 993, 1776],
  4: [2067, 1950, 40],
  1: [2701, 1017, 495]},
 {2: [845, 1602, 954],
  5: [1418, 2282, 1295],
  6: [2277, 2347, 1971],
  4: [2559, 1825, 2343],
  3: [993, 1004, 2030]},
 {1: [2701, 1503, 1129],
  6: [1

In [7]:
from torch_geometric.loader import NeighborLoader


def getitem(index):
    """Turn node indices into sampled neighbourhood subgraphs, keeping the nesting.

    Args:
        index: an ``int`` node index, or an arbitrarily nested ``list`` / ``dict``
            whose leaves are ``int`` node indices.

    Returns:
        For an ``int``: the first batch produced by a ``NeighborLoader`` over the
        global ``cora`` graph with that node as the only input node, configured
        from ``args["num_neighbors"]`` and ``args["subgraph_type"]``.
        For a ``list`` / ``dict``: a container of the same shape (same keys for a
        dict) with every leaf replaced by its subgraph.

    Raises:
        IndexError: if a leaf is neither ``int``, ``list`` nor ``dict``.
    """
    # Containers are resolved recursively so the result mirrors the input layout.
    if isinstance(index, list):
        return list(map(getitem, index))
    if isinstance(index, dict):
        return {label: getitem(members) for label, members in index.items()}
    if not isinstance(index, int):
        raise IndexError("Only integers, lists and dictionaries can be used as indices")

    # A loader seeded with a single node yields exactly one batch: that node's
    # sampled neighbourhood.
    sampler = NeighborLoader(
        data=cora,
        num_neighbors=args["num_neighbors"],
        input_nodes=torch.LongTensor([index]),
        subgraph_type=args["subgraph_type"],
    )
    sampled = next(iter(sampler))
    # Clear `batch_size` on the sampled subgraph (it no longer shows up in the
    # Data repr printed below).
    sampled.batch_size = None

    return sampled

# Replace every node index of the sampled batch with its neighbourhood subgraph;
# the list-of-dicts layout of `batch1` is preserved.
batch = getitem(batch1)
batch

[{1: [Data(x=[3, 384], edge_index=[2, 6], y=[3], label_names=[7], num_nodes=3, x_text_feat=[3, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[3], val_mask=[3], test_mask=[3], n_id=[3], e_id=[6], input_id=[1]),
   Data(x=[5, 384], edge_index=[2, 10], y=[5], label_names=[7], num_nodes=5, x_text_feat=[5, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[5], val_mask=[5], test_mask=[5], n_id=[5], e_id=[10], input_id=[1]),
   Data(x=[4, 384], edge_index=[2, 8], y=[4], label_names=[7], num_nodes=4, x_text_feat=[4, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[4], val_mask=[4], test_mask=[4], n_id=[4

In [8]:
def process_one_task(task):
    """Flatten one few-shot task into parallel graph / label / query-mask sequences.

    Args:
        task: dict mapping an original label id to the list of graphs sampled
            for that label.

    Returns:
        A 4-tuple ``(all_graphs, labels, query_mask, label_map)``:
          * ``all_graphs`` -- the graphs of all labels, concatenated in the
            dict's iteration order;
          * ``labels`` -- tensor holding, for each graph, the position of its
            label within ``label_map`` (a task-local id);
          * ``query_mask`` -- bool tensor with, per label, ``args["n_shot"]``
            ``False`` entries followed by ``args["n_query"]`` ``True`` entries;
          * ``label_map`` -- list of the original label ids, so that
            ``label_map[local_id]`` recovers the original id.

    Note:
        The mask is built from the global ``args`` rather than from the length
        of each graph list, so it only lines up with ``all_graphs`` when every
        label contributes exactly ``n_shot + n_query`` graphs (presumably the
        shots first, then the queries -- verify against the sampler).
    """
    label_map = [*task]  # original label ids, in iteration order
    label_map_reverse = dict(zip(label_map, range(len(label_map))))  # original id -> local id

    all_graphs, labels, query_mask = [], [], []
    for class_id, class_graphs in task.items():
        members = list(class_graphs)
        all_graphs += members
        query_mask += [False] * args["n_shot"]
        query_mask += [True] * args["n_query"]
        labels += [label_map_reverse[class_id]] * len(members)

    return all_graphs, torch.tensor(labels), torch.tensor(query_mask), label_map


In [9]:
# Process every task, then transpose the per-task 4-tuples into four parallel
# lists with one entry per task.
graphs, labels, query_mask, label_map = (
    list(column) for column in zip(*map(process_one_task, batch))
)
print("graphs = ", graphs)
print("labels = ", labels)
print("query_mask = ", query_mask)
print("label_map = ", label_map)

graphs =  [[Data(x=[3, 384], edge_index=[2, 6], y=[3], label_names=[7], num_nodes=3, x_text_feat=[3, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[3], val_mask=[3], test_mask=[3], n_id=[3], e_id=[6], input_id=[1]), Data(x=[5, 384], edge_index=[2, 10], y=[5], label_names=[7], num_nodes=5, x_text_feat=[5, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[5], val_mask=[5], test_mask=[5], n_id=[5], e_id=[10], input_id=[1]), Data(x=[4, 384], edge_index=[2, 8], y=[4], label_names=[7], num_nodes=4, x_text_feat=[4, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[4], val_mask=[4], test_mask=[4], n_id=[4

In [10]:
# Pick the first two subgraphs of the first task as a small example for the
# batching experiment in the next cell.
g1, g2 = graphs[0][0], graphs[0][1]
g1, g2

(Data(x=[3, 384], edge_index=[2, 6], y=[3], label_names=[7], num_nodes=3, x_text_feat=[3, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[3], val_mask=[3], test_mask=[3], n_id=[3], e_id=[6], input_id=[1]),
 Data(x=[5, 384], edge_index=[2, 10], y=[5], label_names=[7], num_nodes=5, x_text_feat=[5, 768], label_text_feat=[7, 768], edge_text_feat=[1, 768], prompt_text_edge_feat=[1, 768], prompt_text_feat=[1, 768], prompt_edge_feat=[1, 768], edge_label_feat=[2, 768], train_mask=[5], val_mask=[5], test_mask=[5], n_id=[5], e_id=[10], input_id=[1]))

In [11]:
from torch_geometric.data import Batch

# Collate the two subgraphs into a single DataBatch. Per the output below, node-level
# tensors are concatenated (3 + 5 = 8 nodes) and so are the attributes shared by every
# graph, e.g. `label_text_feat` grows to 7 + 7 = 14 rows (one copy per graph).
gg = Batch.from_data_list([g1, g2])
gg

DataBatch(x=[8, 384], edge_index=[2, 16], y=[8], label_names=[2], num_nodes=8, x_text_feat=[8, 768], label_text_feat=[14, 768], edge_text_feat=[2, 768], prompt_text_edge_feat=[2, 768], prompt_text_feat=[2, 768], prompt_edge_feat=[2, 768], edge_label_feat=[4, 768], train_mask=[8], val_mask=[8], test_mask=[8], n_id=[8], e_id=[16], input_id=[2], batch=[8], ptr=[3])

In [12]:
from torch_geometric.data import Batch
from itertools import chain

# Record the batch geometry while `graphs` / `label_map` are still nested per task.
num_task = len(graphs)          # number of tasks in the batch
task_len = len(graphs[0])       # graphs per task (taken from the first task)
num_labels = len(label_map[0])  # labels per task (taken from the first task)

print("num_task = ", num_task)
print("task_len = ", task_len)
print("num_labels = ", num_labels)

# Flatten task-major. NOTE: the names are rebound from per-task lists to flat
# objects, so this cell cannot be re-run without re-running the cell that
# builds the per-task lists first.
graphs = Batch.from_data_list(list(chain.from_iterable(graphs)))
labels = torch.cat(labels)
# Keep a stacked copy of the mask (one row per task) before flattening it.
b_mask = torch.stack(query_mask)
query_mask = torch.cat(query_mask)
label_map = list(chain.from_iterable(label_map))

print("graphs = ", graphs)
print("labels = ", labels)
print("b_mask = ", b_mask)
print("query_mask = ", query_mask)
print("label_map = ", label_map)

num_task =  16
task_len =  15
num_labels =  5
graphs =  DataBatch(x=[1184, 384], edge_index=[2, 2884], y=[1184], label_names=[240], num_nodes=1184, x_text_feat=[1184, 768], label_text_feat=[1680, 768], edge_text_feat=[240, 768], prompt_text_edge_feat=[240, 768], prompt_text_feat=[240, 768], prompt_edge_feat=[240, 768], edge_label_feat=[480, 768], train_mask=[1184], val_mask=[1184], test_mask=[1184], n_id=[1184], e_id=[2884], input_id=[240], batch=[1184], ptr=[241])
labels =  tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 0, 0, 0, 1, 1, 1, 2, 2, 2,
        3, 3, 3, 4, 4, 4, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 0, 0, 0,
        1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3,
        4, 4, 4, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 0, 0, 0, 1, 1, 1,
        2, 2, 2, 3, 3, 3, 4, 4, 4, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
        0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 0, 0, 0, 1, 1, 1, 2, 2, 2,
        3, 3, 3, 4, 4, 4, 0, 0, 0, 1, 1,

In [13]:
# Metagraph layout: node ids [0, labels.size(0)) are the sampled graphs; the label
# nodes follow, `num_labels` per task, so label k of task t has id
# labels.size(0) + t * num_labels + k.

# Every graph gets one outgoing edge per label: 0,0,..,0,1,1,..,1,...
metagraph_edge_source = torch.arange(labels.size(0)).repeat_interleave(num_labels)

# Local label index 0..num_labels-1 repeated for every graph ...
metagraph_edge_target = torch.arange(num_labels).repeat(labels.size(0))
# ... shifted to the label nodes of the graph's own task. Each task owns
# task_len * num_labels consecutive edges, which assumes every task contains
# exactly `task_len` graphs.
metagraph_edge_target += (torch.arange(num_task) * num_labels).repeat_interleave(task_len * num_labels) + labels.size(0)

metagraph_edge_index = torch.stack([metagraph_edge_source, metagraph_edge_target], dim=0)

# An edge is masked (True) when its source graph is a query graph.
metagraph_edge_mask = query_mask.repeat_interleave(num_labels)

# Edge label signal: +1 on the edge to the graph's true label, -1 on the others,
# and 0 on every edge leaving a query graph so its label is not leaked.
metagraph_edge_attr = torch.nn.functional.one_hot(labels, num_labels).float()
metagraph_edge_attr = metagraph_edge_attr.reshape(-1)
metagraph_edge_attr = (metagraph_edge_attr * 2 - 1) * (~metagraph_edge_mask)

# Final edge features are [is_query, signed_label]. The mask is cast explicitly so
# the stacked dtype does not depend on implicit bool/float promotion.
metagraph_edge_attr = torch.stack([metagraph_edge_mask.float(), metagraph_edge_attr], dim=1)

# Placeholder label features: row i stands in for the embedding of label i. The
# [7, 768] shape matches `label_text_feat` in the Data repr printed earlier.
label_meta = torch.arange(7 * 768).reshape(7, 768)

# One row per (task, label) pair, looked up by the original label id.
label_map = torch.tensor(label_map)
label_embeddings = label_meta[label_map]

# Pass num_labels explicitly: without it one_hot infers the width from
# labels.max() + 1, which need not equal num_labels and would disagree with the
# edge attributes built above.
labels_onehot = torch.nn.functional.one_hot(labels, num_labels).float()

In [14]:
# Display everything built for the metagraph.
graphs, label_embeddings, labels_onehot, metagraph_edge_index, metagraph_edge_attr, metagraph_edge_mask # mask is True on edges whose source graph is a query

(DataBatch(x=[1184, 384], edge_index=[2, 2884], y=[1184], label_names=[240], num_nodes=1184, x_text_feat=[1184, 768], label_text_feat=[1680, 768], edge_text_feat=[240, 768], prompt_text_edge_feat=[240, 768], prompt_text_feat=[240, 768], prompt_edge_feat=[240, 768], edge_label_feat=[480, 768], train_mask=[1184], val_mask=[1184], test_mask=[1184], n_id=[1184], e_id=[2884], input_id=[240], batch=[1184], ptr=[241]),
 tensor([[ 768,  769,  770,  ..., 1533, 1534, 1535],
         [3840, 3841, 3842,  ..., 4605, 4606, 4607],
         [1536, 1537, 1538,  ..., 2301, 2302, 2303],
         ...,
         [3072, 3073, 3074,  ..., 3837, 3838, 3839],
         [ 768,  769,  770,  ..., 1533, 1534, 1535],
         [3840, 3841, 3842,  ..., 4605, 4606, 4607]]),
 tensor([[1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.]]),
 tensor([[  0,   0,   0,  ..., 239, 239, 239],
 