In [None]:
from torch_geometric.datasets import UpdatedWebQSPDataset
from raw_qsp_dataset import RawWebQSPDataset
from torch_geometric.profile import profileit, timeit
from torch_geometric.profile.profile import GPUStats
from torch.profiler import  profile
from typing import Protocol, Type, List, Tuple, Any
from abc import abstractmethod
import torch

In [None]:
class Profilable(Protocol):
    """Structural type for dataset classes that `make_profilable` can wrap.

    A conforming class exposes a `model` and a `device` attribute and provides
    the two hooks `_build_graph` and `_retrieve_subgraphs`.
    """
    model: torch.nn.Module
    device: torch.device

    @abstractmethod
    def _build_graph(self) -> None:
        """Hook that `make_profilable` overrides with a profiled version."""
        pass

    @abstractmethod
    def _retrieve_subgraphs(self) -> None:
        """Hook that `make_profilable` overrides with a profiled version."""
        pass


def make_profilable(dataset_obj: Type[Profilable], device: str = "cuda") -> Type[Profilable]:
    """Return a subclass of `dataset_obj` whose graph hooks run under `profileit`.

    The subclass overrides `_build_graph` and `_retrieve_subgraphs`. Each
    override calls the parent implementation through a `profileit` wrapper and
    stores the statistics object returned by that wrapper in `self.desc`,
    keyed by the hook name.

    Args:
        dataset_obj: Class to extend. Instances must provide `model` and
            `device` by the time a hook is invoked.
        device: Device string forwarded to `profileit`. Defaults to "cuda",
            the value that used to be hard-coded.

    Returns:
        The profiling subclass (not an instance).
    """
    dec = profileit(device)

    def profiled_call(instance: Any, hook_name: str) -> Any:
        """Run the parent's `hook_name` on `instance` under the profiler."""
        # Resolve the parent implementation per call and per instance.
        # Zero-argument super() is not usable inside the lambda below, and a
        # super() proxy cached on the instance would stay bound to the original
        # object after copy.copy().
        parent_hook = getattr(super(ProfilableObject, instance), hook_name)
        # The wrapper is invoked as (model, tensor); the empty tensor only
        # carries the device. Presumably `profileit` inspects these arguments —
        # confirm against the PyG documentation.
        device_tensor = torch.Tensor().to(instance.device)
        wrapped = dec(lambda model, dev_tensor: parent_hook())
        ret, stats = wrapped(instance.model, device_tensor)
        instance.desc[hook_name] = stats
        return ret

    class ProfilableObject(dataset_obj):
        """`dataset_obj` with profiled `_build_graph` / `_retrieve_subgraphs`."""

        def __init__(self, *args, **kwargs) -> None:
            # Created before the parent constructor runs, in case that
            # constructor already invokes one of the profiled hooks.
            self.desc = {}
            super().__init__(*args, **kwargs)

        def _build_graph(self) -> Any:
            """Profiled `_build_graph`; returns whatever the parent returns."""
            return profiled_call(self, '_build_graph')

        def _retrieve_subgraphs(self) -> Any:
            """Profiled `_retrieve_subgraphs`; returns whatever the parent returns."""
            return profiled_call(self, '_retrieve_subgraphs')

    return ProfilableObject

In [None]:
# Build the dataset under "profiled_ds" without whole_graph_retrieval.
# force_reload=True presumably discards any previously processed copy — confirm
# against the dataset class.
ds = UpdatedWebQSPDataset(root="profiled_ds", force_reload=True)

In [None]:
# Same dataset class under its own root, this time with whole_graph_retrieval=True.
# NOTE: rebinds `ds`, so any dataset previously held under that name is no
# longer reachable by it.
ds = UpdatedWebQSPDataset(root="profiled_ds_wholegraph", force_reload=True, whole_graph_retrieval=True)

In [None]:
# Derive the profiling subclass of UpdatedWebQSPDataset (a class, not an instance).
profilable_ds: Type[UpdatedWebQSPDataset] = make_profilable(UpdatedWebQSPDataset)

In [None]:
# Instantiate the profiling subclass with limit=100.
# NOTE(review): uses root "profiled_ds", the same root another cell in this
# notebook builds into without profiling.
dataset: UpdatedWebQSPDataset = profilable_ds(root="profiled_ds", force_reload=True, limit=100)

In [None]:
# Statistics the profiling wrapper stored for the _retrieve_subgraphs hook.
dataset.desc['_retrieve_subgraphs']

In [None]:
# Record a torch.profiler trace (memory, Python stacks, tensor shapes) while a
# 10-sample dataset is built. `del ds` sits inside the profiled region, so the
# name is unbound before the profiler context exits.
with profile(profile_memory=True, with_stack=True, record_shapes=True) as prof:
    ds = UpdatedWebQSPDataset(root="profiled_ds", force_reload=True, limit=10)
    del ds

In [None]:
# Export the recorded profile to timeline.json via export_chrome_trace.
prof.export_chrome_trace('timeline.json')

In [None]:
# Export the recorded memory timeline to timeline_mem.html.
prof.export_memory_timeline('timeline_mem.html')

In [None]:
# Aggregate the recorded profiler events; `avgs` is printed as a table below.
avgs = prof.key_averages()

In [None]:
# Print the aggregated events as a table, sorted by the 'cpu_time' key.
print(avgs.table(sort_by='cpu_time'))

In [None]:
from torchmetrics.functional import pairwise_cosine_similarity
import torch

In [None]:
# Use the GPU when torch reports CUDA as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Stress test: pairwise cosine similarity between 4,700 and 1,000,000 random
# 1024-dimensional vectors.
# Allocate directly on the target device. The previous form built each tensor on
# the host first (i2 alone is 1,000,000 x 1024 values, ~4 GB at torch's default
# float32) and then copied it across.
i1 = torch.rand((4700, 1024), device=device)
i2 = torch.rand((1000000, 1024), device=device)
# NOTE: rebinds `i1` from the query matrix to the similarity result.
i1 = pairwise_cosine_similarity(i1, i2)

In [1]:
from torch_geometric.datasets import UpdatedWebQSPDataset
from torch_geometric.data import LargeGraphIndexer
from itertools import chain

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Small dataset (limit=10) rooted at 'small_ds', used by the graph and loader
# cells that follow.
ds = UpdatedWebQSPDataset('small_ds', force_reload=True, limit=10)

Processing...


Loading graph...
Encoding questions...
Retrieving subgraphs...


10it [00:00, 59.05it/s]00:00<?, ?it/s]
100%|██████████| 10/10 [00:01<00:00,  5.69it/s]

Saving subgraphs...



Done!


In [3]:
# Export the dataset's indexer to a graph object, naming the node features 'x'
# and the edge features 'edge_attr'.
graph = ds.indexer.to_data(node_feature_name='x', edge_feature_name='edge_attr')

In [4]:
# Rebind edge_index to its .contiguous() version (mutates `graph` in place).
# Presumably required by the samplers used below — TODO confirm.
graph.edge_index = graph.edge_index.contiguous()

In [22]:
from torch_geometric.loader import LinkNeighborLoader, LinkLoader
from torch_geometric.sampler import NeighborSampler, BaseSampler

In [None]:
from torch_geometric.sampler.base import EdgeSamplerInput, HeteroSamplerOutput, NegativeSampling, NodeSamplerInput, SamplerOutput


class IdentitySampler(BaseSampler):
    """Sampler that hands back its seed nodes unchanged and samples no edges.

    NOTE(review): work in progress — only node-level sampling is implemented;
    edge-level sampling raises NotImplementedError.
    """

    def sample_from_nodes(self, index: NodeSamplerInput, **kwargs) -> HeteroSamplerOutput | SamplerOutput:
        """Wrap the seed nodes of `index` in a sampler output without edges.

        Returns a `HeteroSamplerOutput` keyed by `index.input_type` when that
        type is set, otherwise a homogeneous `SamplerOutput`.
        """
        if index.input_type is not None: # Heterogeneous
            return HeteroSamplerOutput(node={index.input_type: index.node}, row=dict(), col=dict(), edge=dict())
        # Empty edge list: integer index tensors on the seed nodes' device
        # (the default torch.Tensor() would be float32 on the CPU).
        empty_row = torch.empty(0, dtype=torch.long, device=index.node.device)
        empty_col = torch.empty(0, dtype=torch.long, device=index.node.device)
        # `edge` is passed explicitly; no edge ids exist because nothing was sampled.
        return SamplerOutput(node=index.node, row=empty_row, col=empty_col, edge=None)

    def sample_from_edges(self, index: EdgeSamplerInput, neg_sampling: NegativeSampling | None = None) -> HeteroSamplerOutput | SamplerOutput:
        """Edge-level sampling; not implemented yet.

        Raises:
            NotImplementedError: always, until the edge-level output is defined.
        """
        # The original body was an unfinished draft that did not parse; fail
        # loudly rather than guess at the intended output.
        raise NotImplementedError("IdentitySampler.sample_from_edges is not implemented yet")

In [141]:
# Neighbor sampler over `graph` with num_neighbors=[1] and replace=True.
link_sampler = NeighborSampler(data=graph, num_neighbors=[1], replace=True)

In [142]:
# Link-level loader over `graph`, driven by the sampler in `link_sampler`.
load = LinkLoader(graph, link_sampler=link_sampler)

In [159]:
# Node count reported by the sampler.
link_sampler.num_nodes

11466

In [19]:
# Alternative loader with num_neighbors=[2]; `load2` is not referenced by any
# other cell visible in this part of the notebook.
load2 = LinkNeighborLoader(graph, num_neighbors=[2])

In [156]:
# Call the loader directly with the index list [1] rather than iterating it.
# Presumably this selects seed edge 1 — TODO confirm.
result = load([1])

In [154]:
# `e_id` attribute of the returned batch.
# NOTE(review): this cell's execution count (154) is lower than that of the
# cell assigning `result` (156), so the displayed output may be stale.
result.e_id

tensor([17505, 24971])

In [107]:
# Inspect the graph's edge_index tensor.
graph.edge_index

tensor([[10097,  4044,  5397,  ...,  7188,  8251,  3597],
        [10097,   673,  5827,  ...,  7594,   673, 10951]])

In [153]:
# Look up the 'pid' node attribute at index 4759.
ds.indexer.node_attr['pid'][4759]

'm.09g3c50'

In [155]:
# Look up the "e_pid" edge attribute at index 1; the recorded output is a
# 3-tuple of strings (presumably head, relation, tail — confirm).
ds.indexer.edge_attr["e_pid"][1]

('washington redskins at oakland raiders, 2009-12-13',
 'american_football.football_game.away_team',
 'washington redskins')