In [1]:
from abc import ABC, abstractmethod
from typing import Optional, Any, List, Iterable, Tuple, Dict, Union, Set
from collections.abc import Hashable
from torch_geometric.data import FeatureStore, InMemoryDataset, GraphStore
from torch_geometric.data.feature_store import TensorAttr, FeatureTensorType
from torch_geometric.distributed.local_graph_store import LocalGraphStore
from torch_geometric.datasets.web_qsp_dataset import WebQSPDataset
import torch
import datasets

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class RawWebQSPDataset(WebQSPDataset):

    def __init__(
        self,
        root: str = "",
        force_reload: bool = False,
    ) -> None:
        self._check_dependencies()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        super(InMemoryDataset, self).__init__(root, None, None, force_reload=force_reload)
        self._load_raw_data()
    
    @property
    def raw_file_names(self) -> List[str]:
        return ["raw_data", "split_idxs"]

    def _save_raw_data(self) -> None:
        self.raw_dataset.save_to_disk(self.raw_paths[0])
        torch.save(self.split_idxs, self.raw_paths[1]) 

    def _load_raw_data(self) -> None:
        self.raw_dataset = datasets.load_from_disk(self.raw_paths[0])
        self.split_idxs = torch.load(self.raw_paths[1])
    
    def download(self) -> None:
        super().download()
        self._save_raw_data()

    def process(self) -> None:
        pass

In [3]:
dataset = RawWebQSPDataset()

Processing...
Done!


In [4]:
dataset.raw_dataset[0]['graph']

[['P!nk', 'freebase.valuenotation.is_reviewed', 'Gender'],
 ['1Club.FM: Power', 'broadcast.content.artist', 'P!nk'],
 ['Somebody to Love', 'music.recording.contributions', 'm.0rqp4h0'],
 ['Rudolph Valentino', 'freebase.valuenotation.is_reviewed', 'Place of birth'],
 ['Ice Cube', 'broadcast.artist.content', '.977 The Hits Channel'],
 ['Colbie Caillat', 'broadcast.artist.content', 'Hot Wired Radio'],
 ['Stephen Melton', 'people.person.nationality', 'United States of America'],
 ['Record producer',
  'music.performance_role.regular_performances',
  'm.012m1vf1'],
 ['Justin Bieber', 'award.award_winner.awards_won', 'm.0yrkc0l'],
 ['1.FM Top 40', 'broadcast.content.artist', 'Geri Halliwell'],
 ['2011 Teen Choice Awards',
  'award.award_ceremony.awards_presented',
  'm.0yrkr34'],
 ['m.012bm2v1', 'celebrities.friendship.friend', 'Miley Cyrus'],
 ['As Long As You Love Me (Ferry Corsten radio)',
  'common.topic.notable_types',
  'Musical Recording'],
 ['Toby Gad', 'music.artist.genre', 'Rhythm 

In [5]:
test_graph_store = LocalGraphStore()

In [6]:
attr = dict(
    edge_type=None,
    layout='coo',
    size=(2,2),
    is_sorted=False
)
test_graph_store.put_edge_index(torch.Tensor([[0,1], [1,2]]), **attr)

True

In [7]:
test_graph_store.get_all_edge_attrs()

[EdgeAttr(edge_type=None, layout=<EdgeLayout.COO: 'coo'>, is_sorted=False, size=(2, 2))]

In [8]:
from large_graph_indexer import LargeGraphIndexer

In [9]:
indexer = LargeGraphIndexer.from_triplets(dataset.raw_dataset[0]['graph'])

In [10]:
indexer2 = LargeGraphIndexer.from_triplets(dataset.raw_dataset[1]['graph'])

In [11]:
bigger_indexer = LargeGraphIndexer.collate([indexer, indexer2])

In [12]:
assert len(indexer.nodes) + len(indexer2.nodes) - len(indexer.nodes.keys() & indexer2.nodes.keys()) == len(bigger_indexer.nodes)
assert len(indexer.edges) + len(indexer2.edges) - len(indexer.edges.keys() & indexer2.edges.keys()) == len(bigger_indexer.edges)

In [13]:
assert len(set(bigger_indexer.nodes.values())) == len(bigger_indexer.nodes)
assert len(set(bigger_indexer.edges.values())) == len(bigger_indexer.edges)

In [14]:
for node in indexer.nodes.keys():
    assert indexer.node_attr[indexer.nodes[node]]["pid"] == node, f'{node} is not {indexer.node_attr[indexer.nodes[node]]["pid"]}'

In [15]:
import tqdm
from multiprocessing import Pool

In [16]:
indexers = []
#TODO: can probably be parallelized
with Pool(40) as p:
    indexers = list(tqdm.tqdm(p.imap(LargeGraphIndexer.from_triplets, [ds['graph'] for ds in dataset.raw_dataset]), total=len(dataset.raw_dataset)))

KeyboardInterrupt: 

In [16]:
#FIXME: right now this is really slow
big_indexer = LargeGraphIndexer.collate(tqdm.tqdm(indexers), skip_shared_check=True)

NameError: name 'indexers' is not defined

In [17]:
from itertools import chain

In [18]:
large_graph_dataset = chain.from_iterable([ds['graph'] for ds in dataset.raw_dataset])

In [19]:
total_size = 0
for ds in tqdm.tqdm(dataset.raw_dataset):
    total_size += len(ds['graph'])

100%|██████████| 4700/4700 [00:51<00:00, 92.08it/s] 


In [20]:
large_indexer = LargeGraphIndexer.from_triplets(tqdm.tqdm(large_graph_dataset, total=total_size))

100%|██████████| 19986134/19986134 [00:20<00:00, 956393.67it/s] 


In [21]:
node_attributes = list(large_indexer.get_unique_node_features("pid"))
node_attributes = [i.lower() for i in node_attributes]

In [22]:
edge_attributes = list(large_indexer.get_unique_edge_features("r"))
edge_attributes = [i.lower() for i in edge_attributes]

In [23]:
len(edge_attributes)

6094

In [24]:
from torch_geometric.nn.text import text2embedding, SentenceTransformer

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SentenceTransformer("sentence-transformers/all-roberta-large-v1").to(device)
model.eval()

inherit model weights from sentence-transformers/all-roberta-large-v1


SentenceTransformer(
  (bert_model): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 1024, padding_idx=1)
      (position_embeddings): Embedding(514, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-23): 24 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
           

In [26]:
print(device)

cuda


In [27]:
from more_itertools import chunked

In [28]:
# Indexing graph features
BATCH_SIZE = 256
node_embs = []
for nbatch in tqdm.tqdm(chunked(node_attributes, BATCH_SIZE), total=len(node_attributes)//BATCH_SIZE):
    node_embs.append(text2embedding(model, device, nbatch))

#node_embs = text2embedding(model, device, node_attributes)
#edge_embs = text2embedding(model, device, edge_attributes)



  0%|          | 0/5071 [00:00<?, ?it/s]

  7%|▋         | 362/5071 [01:41<19:57,  3.93it/s]

In [None]:
# Indexing graph features
BATCH_SIZE = 256
edge_embs = []
for ebatch in tqdm.tqdm(chunked(edge_attributes, BATCH_SIZE), total=len(edge_attributes)//BATCH_SIZE):
    edge_embs.append(text2embedding(model, device, ebatch))


100%|██████████| 1/1 [00:00<00:00,  2.79it/s]
100%|██████████| 1/1 [00:00<00:00, 77.62it/s]/s]
100%|██████████| 1/1 [00:00<00:00, 80.12it/s]/s]
100%|██████████| 1/1 [00:00<00:00, 81.32it/s]/s]
100%|██████████| 1/1 [00:00<00:00, 80.06it/s]/s]
100%|██████████| 1/1 [00:00<00:00, 75.51it/s]/s]
100%|██████████| 1/1 [00:00<00:00, 80.41it/s]/s]
100%|██████████| 1/1 [00:00<00:00, 82.53it/s]/s]
100%|██████████| 1/1 [00:00<00:00, 80.63it/s]/s]
100%|██████████| 1/1 [00:00<00:00, 79.81it/s]/s]
100%|██████████| 1/1 [00:00<00:00, 83.74it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 83.16it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 86.36it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 77.53it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 78.36it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 86.84it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 79.24it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 80.22it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 75.86it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 83.67it/s]t/s]
100%|████████

SBERT text embedding failed, returning torch.zeros((0, 1024))...


100%|██████████| 1/1 [00:00<00:00, 84.56it/s]
100%|██████████| 1/1 [00:00<00:00, 80.32it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 84.78it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 81.12it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 84.48it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 80.79it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 86.20it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 73.87it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 83.26it/s]1s/it]
100%|██████████| 1/1 [00:00<00:00, 85.60it/s]1it/s]
100%|██████████| 1/1 [00:00<00:00, 79.75it/s]0it/s]
100%|██████████| 1/1 [00:00<00:00, 87.54it/s]t/s]  
100%|██████████| 1/1 [00:00<00:00, 81.08it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 86.16it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 80.72it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 67.81it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 85.91it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 88.85it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 82.38it/s]t/s]
100%|██████████| 1/1 [00:00<00:00, 86.72it/s]t