In [7]:
import os
import zipfile
import numpy as np
import torch
from torch_geometric.utils import dense_to_sparse
from six.moves import urllib
import torch_geometric_temporal.signal

In [10]:
# Print the signature and docstring of StaticGraphTemporalSignal. The class is
# defined in the `static_graph_temporal_signal` submodule of
# `torch_geometric_temporal.signal`, so reference it through that submodule.
help(torch_geometric_temporal.signal.static_graph_temporal_signal.StaticGraphTemporalSignal)

Help on class StaticGraphTemporalSignal in module torch_geometric_temporal.signal.static_graph_temporal_signal:

class StaticGraphTemporalSignal(builtins.object)
 |  StaticGraphTemporalSignal(edge_index: Union[numpy.ndarray, NoneType], edge_weight: Union[numpy.ndarray, NoneType], features: List[Union[numpy.ndarray, NoneType]], targets: List[Union[numpy.ndarray, NoneType]], **kwargs: List[numpy.ndarray])
 |  
 |  A data iterator object to contain a static graph with a dynamically
 |  changing constant time difference temporal feature set (multiple signals).
 |  The node labels (target) are also temporal. The iterator returns a single
 |  constant time difference temporal snapshot for a time period (e.g. day or week).
 |  This single temporal snapshot is a Pytorch Geometric Data object. Between two
 |  temporal snapshots the features and optionally passed attributes might change.
 |  However, the underlying graph is the same.
 |  
 |  Args:
 |      edge_index (Numpy array): Index tensor 

AttributeError: type object 'StaticGraphTemporalSignal' has no attribute '__file__'

In [15]:
# Path of the installed module that defines StaticGraphTemporalSignal.
# A class object has no `__file__` attribute, so ask the defining module instead.
getattr(torch_geometric_temporal.signal.static_graph_temporal_signal, "__file__")

'/home/yang/miniconda3/envs/pygt/lib/python3.8/site-packages/torch_geometric_temporal/signal/static_graph_temporal_signal.py'

import torch
import numpy as np
from typing import List, Union
from torch_geometric.data import Data


# Type aliases used in the constructor signature. A ``None`` value is passed
# through unchanged by the snapshot getters below.
Edge_Index = Union[np.ndarray, None]
Edge_Weight = Union[np.ndarray, None]
Node_Features = List[Union[np.ndarray, None]]
Targets = List[Union[np.ndarray, None]]
Additional_Features = List[np.ndarray]


class StaticGraphTemporalSignal(object):
    r"""A data iterator object to contain a static graph with a dynamically
    changing constant time difference temporal feature set (multiple signals).
    The node labels (target) are also temporal. The iterator returns a single
    constant time difference temporal snapshot for a time period (e.g. day or week).
    This single temporal snapshot is a Pytorch Geometric Data object. Between two
    temporal snapshots the features and optionally passed attributes might change.
    However, the underlying graph is the same.

    Indexing with an integer returns one snapshot; indexing with a slice
    returns a new ``StaticGraphTemporalSignal`` restricted to those time steps
    and sharing the same static graph arrays.

    Args:
        edge_index (Numpy array): Index tensor of edges.
        edge_weight (Numpy array): Edge weight tensor.
        features (List of Numpy arrays): List of node feature tensors.
        targets (List of Numpy arrays): List of node label (target) tensors.
        **kwargs (optional List of Numpy arrays): List of additional attributes.
    """

    def __init__(
        self,
        edge_index: Edge_Index,
        edge_weight: Edge_Weight,
        features: Node_Features,
        targets: Targets,
        **kwargs: Additional_Features
    ):
        self.edge_index = edge_index
        self.edge_weight = edge_weight
        self.features = features
        self.targets = targets
        # Every keyword argument becomes an attribute; its name is remembered so
        # the per-time-step value can be attached to each snapshot.
        self.additional_feature_keys = []
        for key, value in kwargs.items():
            setattr(self, key, value)
            self.additional_feature_keys.append(key)
        self._check_temporal_consistency()
        self._set_snapshot_count()

    def _check_temporal_consistency(self):
        """Ensure features, targets and all additional attributes have equal length.

        Raises:
            AssertionError: if any temporal length differs. An explicit
                ``raise`` is used instead of ``assert`` so the check is not
                stripped when Python runs with ``-O``.
        """
        target_count = len(self.targets)
        if len(self.features) != target_count:
            raise AssertionError("Temporal dimension inconsistency.")
        for key in self.additional_feature_keys:
            if len(getattr(self, key)) != target_count:
                raise AssertionError("Temporal dimension inconsistency.")

    def _set_snapshot_count(self):
        """Record the number of time steps (the length of ``features``)."""
        self.snapshot_count = len(self.features)

    @staticmethod
    def _to_tensor(array: np.ndarray):
        """Convert a per-time-step numpy array to a torch tensor by dtype kind.

        Signed integers become ``torch.LongTensor`` and floats become
        ``torch.FloatTensor``. Unsigned integers are cast to int64 first, so
        they also become ``LongTensor``. Any other kind (e.g. bool) goes
        through ``torch.as_tensor`` instead of being silently dropped as
        ``None``.
        """
        kind = array.dtype.kind
        if kind == "i":
            return torch.LongTensor(array)
        if kind == "u":
            # Cast first: not every unsigned numpy dtype converts to torch directly.
            return torch.LongTensor(array.astype(np.int64))
        if kind == "f":
            return torch.FloatTensor(array)
        return torch.as_tensor(array)

    def _get_edge_index(self):
        """Return the static edge index as a ``LongTensor``, or ``None`` if absent."""
        if self.edge_index is None:
            return self.edge_index
        else:
            return torch.LongTensor(self.edge_index)

    def _get_edge_weight(self):
        """Return the static edge weights as a ``FloatTensor``, or ``None`` if absent."""
        if self.edge_weight is None:
            return self.edge_weight
        else:
            return torch.FloatTensor(self.edge_weight)

    def _get_features(self, time_index: int):
        """Return node features at ``time_index`` as a ``FloatTensor`` (``None`` passes through)."""
        if self.features[time_index] is None:
            return self.features[time_index]
        else:
            return torch.FloatTensor(self.features[time_index])

    def _get_target(self, time_index: int):
        """Return the target at ``time_index`` converted by dtype (``None`` passes through)."""
        target = self.targets[time_index]
        if target is None:
            return target
        return self._to_tensor(target)

    def _get_additional_feature(self, time_index: int, feature_key: str):
        """Return attribute ``feature_key`` at ``time_index`` converted by dtype.

        A ``None`` entry is passed through, matching the other getters.
        """
        feature = getattr(self, feature_key)[time_index]
        if feature is None:
            return feature
        return self._to_tensor(feature)

    def _get_additional_features(self, time_index: int):
        """Map every additional attribute name to its tensor at ``time_index``."""
        additional_features = {
            key: self._get_additional_feature(time_index, key)
            for key in self.additional_feature_keys
        }
        return additional_features

    def __getitem__(self, time_index: Union[int, slice]):
        """Return one snapshot (integer index) or a sub-signal (slice index).

        With an integer, returns a ``Data`` object built from the node
        features (``x``), the static ``edge_index``, the edge weights (stored as
        ``edge_attr``), the target (``y``) and every additional attribute for
        that time step. With a slice, returns a new
        ``StaticGraphTemporalSignal`` over the selected time steps that reuses
        the same edge arrays.
        """
        if isinstance(time_index, slice):
            sliced_additional = {
                key: getattr(self, key)[time_index]
                for key in self.additional_feature_keys
            }
            return StaticGraphTemporalSignal(
                self.edge_index,
                self.edge_weight,
                self.features[time_index],
                self.targets[time_index],
                **sliced_additional
            )

        x = self._get_features(time_index)
        edge_index = self._get_edge_index()
        edge_weight = self._get_edge_weight()
        y = self._get_target(time_index)
        additional_features = self._get_additional_features(time_index)

        snapshot = Data(x=x, edge_index=edge_index, edge_attr=edge_weight,
                        y=y, **additional_features)
        return snapshot

    def __next__(self):
        """Return the next snapshot; reset the cursor and stop at the end."""
        # The cursor is normally initialised by __iter__. Default it to 0 so that
        # calling next() on a fresh object does not raise AttributeError.
        current = getattr(self, "t", 0)
        if current < len(self.features):
            snapshot = self[current]
            self.t = current + 1
            return snapshot
        self.t = 0
        raise StopIteration

    def __iter__(self):
        """Restart iteration from the first time step; the object is its own iterator."""
        self.t = 0
        return self