In [17]:
from abc import abstractmethod
from typing import Any, Callable, List, Optional, Protocol, Tuple, Type

import torch
from torch_geometric.datasets import UpdatedWebQSPDataset
from torch_geometric.profile import profileit, timeit
from torch_geometric.profile.profile import GPUStats

In [18]:
class Profilable(Protocol):
    """Structural type for dataset classes that ``make_profilable`` can wrap.

    Implementers expose the model and device used when invoking the profiler,
    plus the two processing hooks whose execution gets profiled.
    """

    # Module handed to the profiling decorator as its ``model`` argument.
    model: torch.nn.Module
    # Device on which a placeholder tensor is allocated for the profiler.
    device: torch.device

    @abstractmethod
    def _build_graph(self) -> None:
        """Processing hook wrapped by the profiler (graph construction stage)."""

    @abstractmethod
    def _retrieve_subgraphs(self) -> None:
        """Processing hook wrapped by the profiler (subgraph retrieval stage)."""

def make_profilable(
    dataset_obj: "Type[Profilable]",
    device: str = "cuda",
    profiler: Optional[Callable[..., Any]] = None,
) -> "Type[Profilable]":
    """Return a subclass of ``dataset_obj`` whose processing hooks are profiled.

    The returned class overrides ``_build_graph`` and ``_retrieve_subgraphs``.
    Each override runs the parent implementation under a profiling decorator,
    stores the profiler's stats in ``self.desc[<hook name>]``, and returns the
    parent's result unchanged.

    Args:
        dataset_obj: Class to wrap. By the time either hook runs, instances
            must expose ``model`` (passed to the profiler as its model
            argument) and ``device`` (where a placeholder tensor is allocated).
        device: Device string forwarded to ``profileit`` when ``profiler`` is
            not given. Defaults to ``"cuda"``, the previously hard-coded value.
        profiler: Optional decorator used instead of ``profileit(device)``.
            It must take a callable ``(model, device_tensor) -> ret`` and
            return a callable with the same arguments that returns
            ``(ret, stats)``. Useful for testing or for swapping profilers.

    Returns:
        A new subclass of ``dataset_obj``.
    """
    dec = profileit(device) if profiler is None else profiler

    class ProfilableObject(dataset_obj):
        def __init__(self, *args, **kwargs) -> None:
            # Created before the parent constructor runs, because constructing
            # the dataset may already invoke the profiled hooks.
            self.desc = {}
            super().__init__(*args, **kwargs)

        def _run_profiled(self, name: str, parent_fn: Callable[[], Any]) -> Any:
            """Run ``parent_fn`` under the profiler and record its stats as ``name``."""
            # The profiler is called as (model, *tensors); it presumably uses
            # the tensor argument to pick the device to profile, so an empty
            # tensor on self.device is passed for that purpose.
            device_tensor = torch.Tensor().to(self.device)
            # The lambda ignores its arguments: they exist only to satisfy the
            # profiler's calling convention.
            wrapped = dec(lambda model, dev_tensor: parent_fn())
            ret, desc = wrapped(self.model, device_tensor)
            self.desc[name] = desc
            return ret

        def _build_graph(self) -> Any:
            # Bind the parent method here, in a frame whose first argument is
            # `self`. Calling zero-argument super() inside the lambda instead
            # would bind to the lambda's first argument (`model`) and raise
            # "TypeError: super(type, obj): obj must be an instance or subtype
            # of type".
            return self._run_profiled("_build_graph", super()._build_graph)

        def _retrieve_subgraphs(self) -> Any:
            # Parent method is bound here (not inside a lambda) for the same
            # reason as in _build_graph.
            return self._run_profiled(
                "_retrieve_subgraphs", super()._retrieve_subgraphs
            )

    return ProfilableObject

In [19]:
# Subclass of UpdatedWebQSPDataset whose processing hooks run under the profiler.
profilable_ds: Type[UpdatedWebQSPDataset] = make_profilable(
    dataset_obj=UpdatedWebQSPDataset,
)

In [21]:
# Constructing the dataset appears to trigger processing, which runs the
# profiled hooks; their stats end up in `dataset.desc`.
dataset: UpdatedWebQSPDataset = profilable_ds(
    root="profiled_ds",
    force_reload=True,  # presumably forces re-processing even if cached output exists
    limit=2,  # presumably caps how many samples are processed — confirm
)

Processing...


Encoding graph...


TypeError: super(type, obj): obj must be an instance or subtype of type