In [5]:
import torch
from torchtyping import TensorType

from torch_geometric.data import Data, Batch
from torch_geometric.tp import Graph, typecheck

In [11]:
@typecheck
def forward(d: Graph["x": torch.Tensor]):
    """Print ``d`` after ``typecheck`` has validated it against the annotation.

    The annotation asks for a graph whose ``x`` attribute is a ``torch.Tensor``.
    """
    print(d)


# Build a graph carrying a single attribute: a tensor stored under "x".
d = Data()
d.x = torch.randn(1, 2, 3)

# "x" is a tensor, so the call passes the check and the graph is printed.
forward(d)

Data(x=[1, 2, 3])


In [18]:
d = Data()
d.x = "test"  # deliberately a str rather than a tensor

@typecheck
def forward(d: Graph["x": torch.Tensor]):
    """Print ``d``; ``typecheck`` validates it against the annotation first.

    The annotation asks for a graph whose ``x`` attribute is a ``torch.Tensor``.
    """
    print(d)

# "x" is not a tensor, so the call is expected to raise TypeError.
# The exception *is* the demonstration: catch and display it so that
# Restart Kernel -> Run All keeps going past this cell.
try:
    forward(d)
except TypeError as err:
    print(f"{type(err).__name__}: {err}")
else:
    # Fail loudly if the check silently stops rejecting the bad attribute.
    raise AssertionError("expected TypeError for non-tensor attribute 'x'")

TypeError: <class 'str'> is not a subtype of <class 'torch.Tensor'> for graph attribute x

In [19]:
d = Data()
d.x = torch.randn((1,2,3))
d.y = "testing"  # extra attribute that the annotation below does not list

@typecheck
def forward(d: Graph["x": torch.Tensor]):
    """Print ``d``; ``typecheck`` validates it against the annotation first.

    Without a trailing ``...`` the annotation lists the complete set of
    attributes the graph may carry -- here only the tensor ``x``.
    """
    print(d)

# The graph also carries "y", so the call is expected to raise TypeError.
# The exception *is* the demonstration: catch and display it so that
# Restart Kernel -> Run All keeps going past this cell.
try:
    forward(d)
except TypeError as err:
    print(f"{type(err).__name__}: {err}")
else:
    # Fail loudly if the check silently stops rejecting the extra attribute.
    raise AssertionError("expected TypeError for unexpected attribute 'y'")

TypeError: d Data attributes {'x', 'y'} do not match required set {'x'}

In [20]:
@typecheck
def forward(d: Graph["x": torch.Tensor, ...]):
    """Print ``d`` after ``typecheck`` has validated it against the annotation.

    The trailing ``...`` makes the attribute list partial: ``x`` must be a
    ``torch.Tensor``, and additional attributes are tolerated.
    """
    print(d)


# Same graph as in the failing example: tensor "x" plus an extra string "y".
d = Data()
d.x = torch.randn(1, 2, 3)
d.y = "testing"

# Only "x" is constrained; thanks to `...` the extra "y" is accepted.
forward(d)

Data(x=[1, 2, 3], y='testing')


In [23]:
@typecheck
def forward(d: Graph["x": TensorType[-1, -1, 3], ...]):
    """Print ``d`` after ``typecheck`` has validated it against the annotation.

    ``x`` is constrained by shape via ``TensorType[-1, -1, 3]``; the trailing
    ``...`` lets the graph carry other attributes as well.
    """
    print(d)


# "x" has shape (1, 2, 3); "y" is an unconstrained extra attribute.
d = Data()
d.x = torch.randn(1, 2, 3)
d.y = "testing"

# The shape (1, 2, 3) satisfies TensorType[-1, -1, 3], so the call succeeds.
forward(d)

Data(x=[1, 2, 3], y='testing')


In [24]:
d = Data()
d.x = torch.randn((1,2,3))  # second dimension is 2, the annotation wants 20
d.y = "testing"

@typecheck
def forward(d: Graph["x": TensorType[-1, 20, 3], ...]):
    """Print ``d``; ``typecheck`` validates it against the annotation first.

    ``x`` is constrained by shape via ``TensorType[-1, 20, 3]``; the trailing
    ``...`` lets the graph carry other attributes as well.
    """
    print(d)

# "x" has the wrong shape, so the call is expected to raise TypeError.
# The exception *is* the demonstration: catch and display it so that
# Restart Kernel -> Run All keeps going past this cell.
try:
    forward(d)
except TypeError as err:
    print(f"{type(err).__name__}: {err}")
else:
    # Fail loudly if the check silently stops rejecting the wrong shape.
    raise AssertionError("expected TypeError for mis-shaped attribute 'x'")

TypeError: x must be of type TensorType[-1, 20, 3], got type TensorType[1, 2, 3] instead.

In [27]:
# Equivalent for Batches
from torch_geometric.tp import GraphBatch

@typecheck
def forward(d: GraphBatch["x": TensorType[-1, -1, 3], ...]):
    """Print the batch ``d`` after ``typecheck`` has validated it.

    Same attribute syntax as ``Graph``, applied to a batch: ``x`` is
    constrained by shape, and the trailing ``...`` permits other attributes.
    """
    print(d)


# A single graph with tensor "x" and an extra string attribute "y" ...
d = Data()
d.x = torch.randn(1, 2, 3)
d.y = "testing"

# ... collated into a one-element batch, which is what `forward` receives.
b = Batch.from_data_list([d])

forward(b)

DataBatch(x=[1, 2, 3], y=[1], batch=[1], ptr=[2])
