In [1]:
from profiling_utils import create_remote_backend_from_triplets
from rag_feature_store import SentenceTransformerFeatureStore
from rag_graph_store import NeighborSamplingRAGGraphStore
from rag_loader import RagQueryLoader
from torch_geometric.datasets import UpdatedWebQSPDataset
from torch_geometric.nn.nlp import SentenceTransformer
from torch_geometric.datasets.updated_web_qsp_dataset import preprocess_triplet
from torch_geometric.data import get_features_for_triplets_groups, Data
from itertools import chain
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
ds = UpdatedWebQSPDataset("small_ds", force_reload=True, limit=10)

Processing...


Loading graph...
Encoding questions...
Retrieving subgraphs...


10it [00:00, 59.04it/s]00:00<?, ?it/s]
100%|██████████| 10/10 [00:01<00:00,  5.82it/s]

Saving subgraphs...



Done!


In [3]:
triplets = list(chain.from_iterable((d['graph'] for d in ds.raw_dataset)))

In [4]:
questions = ds.raw_dataset['question']
questions

['what is the name of justin bieber brother',
 'what character did natalie portman play in star wars',
 'what country is the grand bahama island in',
 'what kind of money to take to bahamas',
 'what character did john noble play in lord of the rings',
 'who does joakim noah play for',
 'where are the nfl redskins from',
 'where did saki live',
 'who did draco malloy end up marrying',
 'which countries border the us']

In [5]:
ground_truth_graphs = get_features_for_triplets_groups(ds.indexer, (d['graph'] for d in ds.raw_dataset), pre_transform=preprocess_triplet)
num_edges = len(ds.indexer._edges)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SentenceTransformer().to(device)

In [7]:
fs, gs = create_remote_backend_from_triplets(triplets=triplets, node_embedding_model=model, node_method_to_call="encode", path="backend", pre_transform=preprocess_triplet, node_method_kwargs={"batch_size": 256}, graph_db=NeighborSamplingRAGGraphStore, feature_db=SentenceTransformerFeatureStore).load()

In [8]:
query_loader = RagQueryLoader(data=(fs, gs))

In [9]:
# Accuracy Metrics to be added to Profiler
def _eidx_helper(subg: Data, ground_truth: Data):
    subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx
    if isinstance(subg_eidx, torch.Tensor):
        subg_eidx = subg_eidx.tolist()
    if isinstance(gt_eidx, torch.Tensor):
        gt_eidx = gt_eidx.tolist()
    subg_e = set(subg_eidx)
    gt_e = set(gt_eidx)
    return subg_e, gt_e
def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int):
    subg_e, gt_e = _eidx_helper(subg, ground_truth)
    total_e = set(range(num_edges))
    tp = len(subg_e & gt_e)
    tn = len(total_e-(subg_e | gt_e))
    return (tp+tn)/num_edges
def check_retrieval_precision(subg: Data, ground_truth: Data):
    subg_e, gt_e = _eidx_helper(subg, ground_truth)
    return len(subg_e & gt_e) / len(subg_e)
def check_retrieval_recall(subg: Data, ground_truth: Data):
    subg_e, gt_e = _eidx_helper(subg, ground_truth)
    return len(subg_e & gt_e) / len(gt_e)

In [10]:
for subg, gt in zip((query_loader.query(q) for q in questions), ground_truth_graphs):
    print(check_retrieval_accuracy(subg, gt, num_edges), check_retrieval_precision(subg, gt), check_retrieval_recall(subg, gt))

10it [00:00, 59.40it/s]


0.47500327696945865 0.22471144012078118 0.5290677674578603
0.6829990824485516 0.10915320606950563 0.2714877039201363
0.4894219425874951 0.05582922824302135 0.5027726432532348
0.5146939310525626 0.14810335349092907 0.47180385288966725
0.48978896316686327 0.038508934072704865 0.5047106325706595
0.556979944946913 0.07315452458620272 0.4376581134803036
0.4486826582776249 0.1033375226923614 0.5568096313017307
0.6349980338183248 0.06187745246000604 0.35482475118996104
0.6690260846768908 0.1052010529182173 0.29521141110545085
0.5178660374885307 0.16869006042463552 0.47266231748990367
