In [48]:
from typing import List
from torch_geometric.data import InMemoryDataset
from torch_geometric.distributed.local_graph_store import LocalGraphStore
from torch_geometric.datasets.web_qsp_dataset import *
import torch
import datasets
import time

In [49]:
from raw_qsp_dataset import RawWebQSPDataset

In [50]:
dataset = RawWebQSPDataset(force_reload=True)

Processing...
Done!


In [51]:
dataset.raw_dataset

Dataset({
    features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
    num_rows: 4700
})

In [52]:
from large_graph_indexer import LargeGraphIndexer, TripletLike

In [53]:
def preprocess_triplet(triplet: TripletLike):
    return tuple([word.lower() for word in triplet])

Test basic collation

In [54]:
indexer = LargeGraphIndexer.from_triplets(dataset.raw_dataset[0]['graph'], pre_transform=preprocess_triplet)

In [55]:
indexer2 = LargeGraphIndexer.from_triplets(dataset.raw_dataset[1]['graph'], pre_transform=preprocess_triplet)

In [56]:
bigger_indexer = LargeGraphIndexer.collate([indexer, indexer2])

In [57]:
assert len(indexer._nodes) + len(indexer2._nodes) - len(indexer._nodes.keys() & indexer2._nodes.keys()) == len(bigger_indexer._nodes)
assert len(indexer._edges) + len(indexer2._edges) - len(indexer._edges.keys() & indexer2._edges.keys()) == len(bigger_indexer._edges)

In [58]:
assert len(set(bigger_indexer._nodes.values())) == len(bigger_indexer._nodes)
assert len(set(bigger_indexer._edges.values())) == len(bigger_indexer._edges)

In [59]:
for node in indexer._nodes.keys():
    assert indexer.node_attr['pid'][indexer._nodes[node]] == node, f'{node} is not {indexer.node_attr["pid"][indexer._nodes[node]]}'

In [60]:
import tqdm
from multiprocessing import Pool

Test collation on entire dataset

In [61]:
LIMIT=100
def get_next_graph(dataset, limit=None):
    i = 0
    for ds in dataset:
        if i == limit:
            break
        yield ds['graph']
        i += 1
graphs = get_next_graph(dataset.raw_dataset, limit=LIMIT)

In [62]:
indexers = []
def from_trips_with_pretransform(triplets):
    return LargeGraphIndexer.from_triplets(triplets, pre_transform=preprocess_triplet)
with Pool(40) as p:
    indexers = list(tqdm.tqdm(p.imap(from_trips_with_pretransform, graphs), total=len(dataset.raw_dataset) if LIMIT is None else LIMIT))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [63]:
start = time.time()
big_indexer = LargeGraphIndexer.collate(indexers)
time.time()-start

0.5566630363464355

Naive method, concatenating all the triplets

In [64]:
from itertools import chain

In [65]:
large_graph_dataset = chain.from_iterable(get_next_graph(dataset.raw_dataset, limit=LIMIT))

In [66]:
# This is just for TQDM to work well
total_size = 0
#for g in tqdm.tqdm(get_next_graph(dataset.raw_dataset), total=4700):
#    total_size += len(g)

In [67]:
large_indexer = LargeGraphIndexer.from_triplets(tqdm.tqdm(large_graph_dataset, total=total_size), pre_transform=preprocess_triplet)

423763it [00:02, 198983.48it/s]


In [68]:
assert set(large_indexer._nodes) == set(big_indexer._nodes)
assert set(large_indexer._edges) == set(big_indexer._edges)

## Phase I: Indexing Large Graph

In [69]:
node_attributes = list(big_indexer.get_unique_node_features())

In [70]:
edge_attributes = list(big_indexer.get_unique_edge_features("r"))

In [71]:
len(node_attributes)

105413

In [72]:
len(edge_attributes)

3095

In [73]:
from torch_geometric.nn.text import text2embedding, SentenceTransformer

In [74]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SentenceTransformer("sentence-transformers/all-roberta-large-v1").to(device)
model.eval()

inherit model weights from sentence-transformers/all-roberta-large-v1


SentenceTransformer(
  (bert_model): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 1024, padding_idx=1)
      (position_embeddings): Embedding(514, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-23): 24 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
           

In [75]:
print(device)

cuda


In [76]:
from more_itertools import chunked

In [77]:
nodes_to_process, edges_to_process = len(node_attributes), len(edge_attributes)

In [78]:
BATCH_SIZE = 32

In [79]:
# Indexing graph features
node_embs = []
for nbatch in tqdm.tqdm(chunked(node_attributes, BATCH_SIZE), total=nodes_to_process//BATCH_SIZE):
    node_embs.append(text2embedding(model, device, nbatch, BATCH_SIZE).cpu())
node_embs = torch.cat(node_embs, 0)



3295it [01:08, 47.83it/s]                          


In [80]:
node_embs.shape

torch.Size([105413, 1024])

In [81]:
big_indexer.add_node_feature(new_feature_name="embs", new_feature_vals=node_embs)

In [82]:
from typing import Iterable, Callable, Optional
from large_graph_indexer import TripletLike, ordered_set
from torch_geometric.typing import FeatureTensorType
from torch_geometric.data import Data

In [83]:
# Indexing graph features
edge_embs = []
for ebatch in tqdm.tqdm(chunked(edge_attributes, BATCH_SIZE), total=edges_to_process//BATCH_SIZE):
    edge_embs.append(text2embedding(model, device, ebatch).cpu())
edge_embs = torch.cat(edge_embs, 0)


97it [00:01, 49.56it/s]                        


In [84]:
edge_embs.shape

torch.Size([3095, 1024])

In [85]:
big_indexer.add_edge_feature(new_feature_name="embs", new_feature_vals=edge_embs, map_from_feature="r")

In [None]:
from large_graph_indexer import get_features_for_triplets

In [86]:
import networkx as nx

In [87]:
first_10_trips = dataset.raw_dataset[:10]['graph']

In [88]:
#TODO: Parallelize
first_10 = [get_features_for_triplets(big_indexer, trip_lst, pre_transform=preprocess_triplet) for trip_lst in tqdm.tqdm(first_10_trips)]

100%|██████████| 10/10 [00:01<00:00,  6.31it/s]


In [89]:
# Grab the first few samples from the old ds to test with LargeGraphIndexer
old_dataset = RawWebQSPDataset(force_reload=True, with_process=True, limit=10)

Processing...


inherit model weights from sentence-transformers/all-roberta-large-v1
Encoding graphs...


 30%|███       | 3/10 [00:10<00:21,  3.03s/it]

SBERT text embedding failed, returning torch.zeros((0, 1024))...


100%|██████████| 10/10 [00:35<00:00,  3.55s/it]
Done!


In [90]:
def results_are_close_enough(ground_truth: Data, new_method: Data, thresh=.8):
    def _sorted_tensors_are_close(tensor1, tensor2):
        return torch.all(torch.isclose(tensor1.sort(dim=0)[0], tensor2.sort(dim=0)[0]).float().mean(axis=1) > thresh)
    def _graphs_are_same(tensor1, tensor2):
        return nx.weisfeiler_lehman_graph_hash(nx.Graph(tensor1.T)) == nx.weisfeiler_lehman_graph_hash(nx.Graph(tensor2.T))
    return _sorted_tensors_are_close(ground_truth.x, new_method.x) \
        and _sorted_tensors_are_close(ground_truth.edge_attr, new_method.edge_attr) \
        and _graphs_are_same(ground_truth.edge_index, new_method.edge_index)

In [91]:
for ds in zip(old_dataset, first_10):
    print(results_are_close_enough(*ds))

True
True
True


RuntimeError: The size of tensor a (0) must match the size of tensor b (1988) at non-singleton dimension 0

Test Saving and Loading and eq

In [None]:
big_indexer.save('indexer')

In [None]:
assert big_indexer == LargeGraphIndexer.from_disk('indexer')

UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte