In [1]:
from updated_qsp_dataset import UpdatedWebQSPDataset
from raw_qsp_dataset import RawWebQSPDataset
from unittest.mock import patch
from torch_geometric.data import Data
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def tokenizer_call_iter():
    """Generator handing out consecutive integer "embeddings".

    Prime it with ``next()``. Each subsequent ``send(texts)`` returns a float
    tensor of shape ``(len(texts), 1)`` holding the next ``len(texts)``
    integers, continuing from where the previous send stopped. Sending
    ``None`` is treated as an empty batch.
    """
    a = []
    i = 0
    while True:
        rv = torch.Tensor([list(range(i, i + len(a)))]).T
        i += len(a)
        a = yield rv
        if a is None:
            a = []


gen = tokenizer_call_iter()
next(gen)  # prime the generator so the first send() is accepted


def tokenizer_call(a, batch_size, *, _gen=gen):
    """Stand-in for ``SentenceTransformer.encode`` used as a mock side effect.

    Args:
        a: List of texts to "encode".
        batch_size: Accepted only to match the encoder call; ignored.
        _gen: Primed counter generator. It is bound as a default at definition
            time because the module-level name ``gen`` may be reassigned by
            later cells; a call-time lookup would silently route this function
            to that other generator.

    Returns:
        Float tensor of shape ``(len(a), 1)`` with consecutive integer ids.
    """
    return _gen.send(a)

In [3]:
# Build the raw (reference) dataset with the sentence encoder and PCST
# retrieval patched out, keeping handles on both mocks so their recorded
# calls can be inspected afterwards.
old_token_mock = old_ret_mock = None
with patch("torch_geometric.datasets.web_qsp_dataset.SentenceTransformer.encode") as model_mock:
    with patch("raw_qsp_dataset.retrieval_via_pcst") as pcst_mock:
        model_mock.side_effect = tokenizer_call
        # every retrieval call returns an empty Data() and an empty string
        pcst_mock.return_value = (Data(), "")
        old_dataset = RawWebQSPDataset(
            root='old_dataset',
            force_reload=True,
            limit=2,
            with_process=True,
            with_pcst=True,
        )
        old_token_mock, old_ret_mock = model_mock, pcst_mock

Processing...


Encoding questions...
Encoding graphs...


100%|██████████| 2/2 [00:00<00:00, 46.95it/s]
Done!


In [16]:
# Build a deterministic text -> integer id mapping from every text the raw
# pipeline (cell 3) sent to the patched encoder. `old_token_mock` is used
# rather than `model_mock`, which is rebound by every later patching cell.
# Sorting fixes the id assignment: iterating a set of strings depends on hash
# randomization, so `list(set)` would hand out different ids per kernel.
# (Assumes all encoded items are mutually comparable, e.g. all str.)
TOKEN_CALLS = set()
for call in old_token_mock.call_args_list:
    TOKEN_CALLS.update(call.args[0])  # first positional arg: the list of texts
TOKEN_CALLS_MAP = sorted(TOKEN_CALLS)
TOKEN_CALLS_REVERSE_MAP = {text: idx for idx, text in enumerate(TOKEN_CALLS_MAP)}

In [18]:
def tokenizer_from_map_iter():
    """Generator form of the text -> id lookup (kept for compatibility).

    Prime with ``next()``; each ``send(texts)`` returns a 1-D float tensor of
    ``TOKEN_CALLS_REVERSE_MAP[text]`` for every text. ``None`` counts as empty.
    """
    a = []
    while True:
        rv = torch.Tensor([TOKEN_CALLS_REVERSE_MAP[t] for t in a])
        a = yield rv
        if a is None:
            a = []


def tokenizer_from_map(a, batch_size):
    """Stand-in for ``SentenceTransformer.encode`` that embeds each text as its id.

    The lookup is stateless, so it is done directly instead of through a
    shared module-level generator. With a generator, one ``KeyError`` for an
    unseen text kills it and every later call raises ``StopIteration``.
    Rebinding the global ``gen`` also redirected ``tokenizer_call``.

    Args:
        a: List of texts (``None`` is treated as an empty list).
        batch_size: Accepted only to match the encoder call; ignored.

    Returns:
        1-D float tensor of ``TOKEN_CALLS_REVERSE_MAP`` ids, one per text.

    Raises:
        KeyError: If a text was not seen by the reference run.
    """
    if a is None:
        a = []
    return torch.Tensor([TOKEN_CALLS_REVERSE_MAP[t] for t in a])

In [19]:
# Rebuild the raw dataset, this time encoding each text as its id from
# TOKEN_CALLS_REVERSE_MAP so its embeddings are comparable with the updated
# pipeline's. Handles on both mocks are kept for inspecting their calls.
old_token_mock = old_ret_mock = None
with patch("torch_geometric.datasets.web_qsp_dataset.SentenceTransformer.encode") as model_mock:
    with patch("raw_qsp_dataset.retrieval_via_pcst") as pcst_mock:
        model_mock.side_effect = tokenizer_from_map
        # every retrieval call returns an empty Data() and an empty string
        pcst_mock.return_value = (Data(), "")
        old_dataset = RawWebQSPDataset(
            root='old_dataset',
            force_reload=True,
            limit=2,
            with_process=True,
            with_pcst=True,
        )
        old_token_mock, old_ret_mock = model_mock, pcst_mock

Processing...


Encoding questions...
Encoding graphs...


100%|██████████| 2/2 [00:00<00:00, 45.33it/s]
Done!


In [23]:
# Build the updated dataset under the same mocks so the inputs it hands to the
# encoder and to PCST retrieval can be compared with the raw pipeline's.
# It gets its own variable and root directory. Binding it to `old_dataset`
# with root='old_dataset' and force_reload=True would overwrite the reference
# dataset variable and presumably its processed files on disk.
new_token_mock, new_ret_mock = None, None
with (patch("torch_geometric.datasets.web_qsp_dataset.SentenceTransformer.encode") as model_mock,
      patch("updated_qsp_dataset.retrieval_via_pcst") as pcst_mock):
    model_mock.side_effect = tokenizer_from_map
    pcst_mock.return_value = Data(), ""
    new_dataset = UpdatedWebQSPDataset(root='updated_dataset', force_reload=True, limit=2, whole_graph_retrieval=False)
    new_token_mock = model_mock
    new_ret_mock = pcst_mock

Processing...


Encoding graph...
Encoding questions...
Retrieving subgraphs...


100%|██████████| 2/2 [00:00<00:00,  6.98it/s]
Done!


In [53]:
# Unpack what the raw pipeline passed to the mocked PCST retrieval: args[0]
# is the graph (its .x / .edge_attr are read), args[1] is treated as the
# question embedding. Values are cast to int numpy arrays for set comparison.
old_pcst_args = [call.args for call in old_ret_mock.call_args_list]
old_nodes_called = [args[0].x.int().numpy() for args in old_pcst_args]
old_edges_called = [args[0].edge_attr.int().numpy() for args in old_pcst_args]
old_question_called = [args[1].int().numpy() for args in old_pcst_args]

In [54]:
# Unpack what the updated pipeline passed to the mocked PCST retrieval: args[0]
# is the graph (its .x / .edge_attr are read), args[1] is treated as the
# question embedding. Values are cast to int numpy arrays for set comparison.
new_pcst_args = [call.args for call in new_ret_mock.call_args_list]
new_nodes_called = [args[0].x.int().numpy() for args in new_pcst_args]
new_edges_called = [args[0].edge_attr.int().numpy() for args in new_pcst_args]
new_question_called = [args[1].int().numpy() for args in new_pcst_args]

In [59]:
# Compare, call by call, what each pipeline handed to PCST retrieval.
# Node/edge values are compared as sets (order and duplicates ignored).
# The call counts must match first; otherwise pairing calls by position is
# meaningless (and indexing the shorter list would fail or truncate).
assert len(old_nodes_called) == len(new_nodes_called), (
    f"pipelines made a different number of PCST calls: "
    f"{len(old_nodes_called)} vs {len(new_nodes_called)}")
for old_nodes, new_nodes, old_edges, new_edges, old_q, new_q in zip(
        old_nodes_called, new_nodes_called,
        old_edges_called, new_edges_called,
        old_question_called, new_question_called):
    print(set(old_nodes) == set(new_nodes))
    print(set(old_edges) == set(new_edges))
    # .all() so a multi-element question embedding still yields one boolean
    print((old_q == new_q).all())

True
True
True
True
True
True


In [58]:
# Int-cast question values the raw pipeline passed to PCST retrieval, one per call.
old_question_called

[array(3073, dtype=int32), array(2115, dtype=int32)]

In [34]:
# Node values the updated pipeline passed to PCST retrieval, one array per call.
# NOTE(review): the saved output shows float torch tensors, but the cell that
# builds new_nodes_called produces int numpy arrays - the output is stale; re-run.
new_nodes_called

[tensor([   0., 3971., 3972.,  ..., 8044., 8046., 5478.]),
 tensor([  12., 3985., 3987.,  ..., 3953., 8035., 8038.])]

In [None]:
# Build the updated dataset outside any patch context, into its own root dir.
new_dataset = UpdatedWebQSPDataset(
    root='updated_dataset',
    force_reload=True,
    limit=2,
    whole_graph_retrieval=False,
)

In [6]:
# First sample of whichever dataset is currently bound to `old_dataset`.
# NOTE(review): several cells assign `old_dataset`; confirm which one ran last.
old_dataset[0]

Data(x=[12, 1024], edge_index=[2, 12], edge_attr=[12, 1024], question='Question: what is the name of justin bieber brother
Answer: ', label='jaxon bieber', desc='node_id,node_attr
15,justin bieber
151,pattie mallette
286,english language
294,jaxon bieber
346,yves bole
356,jeremy bieber
452,jazmyn bieber
545,m.0wfn4pm
551,m.0gxnnwp
933,m.0gxnnwc
1032,this is justin bieber
1359,m.0129jzth

src,edge_attr,dst
346,people.person.languages,286
1032,film.film.language,286
346,influence.influence_node.influenced_by,15
151,people.person.children,15
294,people.person.parents,356
545,people.sibling_relationship.sibling,151
933,people.sibling_relationship.sibling,452
1359,people.sibling_relationship.sibling,346
551,people.sibling_relationship.sibling,294
15,people.person.sibling_s,933
933,people.sibling_relationship.sibling,15
551,people.sibling_relationship.sibling,15
', num_nodes=12)

In [7]:
# First sample of the updated dataset.
# NOTE(review): in this saved notebook the `new_dataset = UpdatedWebQSPDataset(...)`
# cell has no execution count, so this output came from earlier kernel state.
new_dataset[0]

Data(x=[15, 1024], edge_index=[2, 20], edge_attr=[20, 1024], question='Question: what is the name of justin bieber brother
Answer: ', label='jaxon bieber', desc='index,node_id,node_attr
54,54,m.0gxnp5m
307,307,m.0129jzth
391,391,pattie mallette
650,650,english language
837,837,justin bieber
934,934,this is justin bieber
1022,1022,m.0gxnp5x
1123,1123,jeremy bieber
1149,1149,m.0gxnnwp
1847,1847,m.0gxnnwc
2294,2294,yves bole
2702,2702,m.0gxnp5d
2719,2719,jazmyn bieber
2880,2880,m.0wfn4pm
2882,2882,jaxon bieber

src,edge_attr,dst
2294,people.person.languages,650
934,film.film.language,650
2294,influence.influence_node.influenced_by,837
391,people.person.children,837
2882,people.person.parents,1123
2880,people.sibling_relationship.sibling,391
2882,people.person.sibling_s,1149
837,base.popstra.celebrity.hangout,1022
1847,people.sibling_relationship.sibling,2719
307,people.sibling_relationship.sibling,2294
391,people.person.sibling_s,2880
1149,people.sibling_relationship.sibling,2882
2294,peo

In [8]:
# Raw graph stored for the first sample by the old dataset (compare its size with old_dataset[0]).
old_dataset.raw_graphs[0]

Data(x=[1709, 1024], edge_index=[2, 9088], edge_attr=[9088, 1024], num_nodes=1709)

In [9]:
# Raw graph stored for the first sample by the updated dataset, for side-by-side comparison.
new_dataset.raw_graphs[0]

Data(x=[1709, 1024], edge_index=[2, 9088], edge_attr=[9088, 1024], num_nodes=1709, pid=[1709], e_pid=[9088], node_idx=[1709], edge_idx=[9088])

In [10]:
import networkx as nx
from torch_geometric.data import Data
import torch

In [25]:
def results_are_close_enough(ground_truth: Data, new_method: Data, thresh=.9):
    """Heuristically check whether two PyG graphs carry the same content.

    Node features (``x``) and edge features (``edge_attr``) are compared after
    sorting every column independently, so row order does not matter. A row
    passes when more than ``thresh`` of its entries are close. Topology is
    compared via the Weisfeiler-Lehman hash of the undirected simple graph
    built from ``edge_index``, so edge direction and duplicate edges are ignored.

    Args:
        ground_truth: Reference graph.
        new_method: Graph produced by the method under test.
        thresh: Fraction of close entries a sorted row must exceed.

    Returns:
        A 0-dim boolean tensor that is True only if all three checks pass.
    """
    def _sorted_tensors_are_close(tensor1, tensor2):
        # Differently-shaped tensors cannot match; torch.isclose would raise
        # (or broadcast) instead of reporting a mismatch.
        if tensor1.shape != tensor2.shape:
            return torch.tensor(False)
        sorted1 = tensor1.sort(dim=0).values
        sorted2 = tensor2.sort(dim=0).values
        vals = torch.isclose(sorted1, sorted2).float().mean(dim=1)
        return torch.all(vals > thresh)

    def _graphs_are_same(edge_index1, edge_index2):
        # Convert to plain Python ints first. Torch tensors hash by identity,
        # so feeding tensor rows to networkx makes every endpoint a distinct
        # node and the hash would only reflect the number of edges.
        graph1 = nx.Graph(edge_index1.T.tolist())
        graph2 = nx.Graph(edge_index2.T.tolist())
        return nx.weisfeiler_lehman_graph_hash(graph1) == nx.weisfeiler_lehman_graph_hash(graph2)

    val = _sorted_tensors_are_close(ground_truth.x, new_method.x)
    val &= _sorted_tensors_are_close(ground_truth.edge_attr, new_method.edge_attr)
    val &= _graphs_are_same(ground_truth.edge_index, new_method.edge_index)
    return val

In [28]:
# Compare the raw graphs the two pipelines built for the first sample.
results_are_close_enough(old_dataset.raw_graphs[0], new_dataset.raw_graphs[0])

tensor(False)

In [None]:
# Inspect what the old dataset stores under `questions` for the first sample.
old_dataset.questions[0]

In [None]:
# Inspect what the updated dataset stores under `questions` for the first sample.
new_dataset.questions[0]

In [None]:
# Number of distinct rows in the updated pipeline's raw edge_attr after a column-wise sort.
# NOTE(review): sort(dim 0) orders every column independently, so rows no longer
# correspond to real edges; `edge_attr.unique(dim=0)` may be the intended count.
torch.sort(new_dataset.raw_graphs[0].edge_attr, 0)[0].unique(dim=0).size(0)

In [None]:
# Same count for the old pipeline. `.size(0)` (not `.size()`) returns an int,
# directly comparable with the updated pipeline's count.
# NOTE(review): the column-wise sort scrambles rows; `edge_attr.unique(dim=0)`
# may be the intended metric.
torch.sort(old_dataset.raw_graphs[0].edge_attr, 0)[0].unique(dim=0).size(0)