In [2]:
# The goal of the following is to let users write tensor-like types with variadic 
# shapes that are amendable to both detailed static type checking and runtime type 
# checking.
#
# There are already libraries like
# - https://github.com/patrick-kidger/torchtyping 
# - https://github.com/ramonhagenaars/nptyping
#
# The former targets shape-based runtime type checking functionality in a pretty 
# general, but hacky  way. Its shape information is not amenable to static analyis. The 
# latter is more about expressing really verbose shapes... and stuff. It supports 
# runtime checking against *single* array expressions, but nothing beyond that. So, 
# quite limited. It also only works with mypy -- depends on mypy plugin.
#
#
# This is a really rough proof of concept in which I hijack the machinery of 
# `__class_getitem__` so that `Tensor[A, B]` returns a "phantom type" whose
# `__instancecheck__` enforces both isisntance(<obj>, Tensor) and <obj>.shape == (A, B)
# This means that static type checkers see `Tensor[A, B]`, and runtime type-checkers can
# call isinstance against the phantom type. Thus this method plugs and plays with all 
# runtime type checkers
#
# Some next steps: 
# - Add memoization so that the following passes: `assert Tensor[A, B] is Tensor[A, B]`
# - Make some weird context manager for binding shape-dimension symbols to particular 
#   values, so that we can check for mutual consistency across multiple tensors within 
#   that context. E.g. `def f(x: Tensor[A, B], y: Tensor[B, A]): ...`
# - Write a `parse` function that makes it trivial to validate and statically cast 
#   tensors to shape-typed narrowed versions of themselves
#



# requirements
# - Python 3.9+  
# - typing extensions
# - phantom-types
# - pytorch

from collections import defaultdict
from torch import Tensor as _Tensor
import torch as tr
from phantom import Phantom, PhantomMeta

from typing import Generic, NewType, TYPE_CHECKING, Any, cast, Hashable, Optional
from typing_extensions import TypeVarTuple, Unpack

A = NewType("A", int)
B = NewType("B", int)
C = NewType("C", int)

Shape = TypeVarTuple("Shape")


class NewMeta(PhantomMeta, type(_Tensor)):
    ...

class DimBinder:
    bindings: Optional[dict[Any, int]] = None

def check(shape_type: tuple[Hashable, ...], shape: tuple[int, ...]) -> bool:
    if len(shape_type) != len(shape):
        return False

    matches: defaultdict[Any, list[int]] = defaultdict(list)
    for n, a in enumerate(shape_type):
        matches[a].append(n)

    for indices in matches.values():
        if len(indices) == 1:
            continue
        first, *rest = indices
        a = shape[first]
        if not all(a == shape[b] for b in rest):
            return False
    return True
        

class Tensor(Generic[Unpack[Shape]], _Tensor):
    if not TYPE_CHECKING:
        @classmethod
        def __class_getitem__(cls, key):
            class PhantomTensor(
                _Tensor,
                Phantom,
                metaclass=NewMeta,
                predicate=lambda x: check(key, x.shape),
            ):
                ...

            return PhantomTensor

    @property
    def shape(self) -> tuple[Unpack[Shape]]: ...


In [70]:
H = NewType("A", int)
B = NewType("B", int)
C = NewType("C", int)


In [71]:
from typing import TypeVar
T1 = TypeVar("T1", bound=int)
T2 = TypeVar("T2", bound=int)



def make_tensor(shape) -> Tensor[A, B, B, C]: ...

def outer_transpose(x: Tensor[T1, Unpack[Shape], T2]) -> Tensor[T2, Unpack[Shape], T1]: ...

x = make_tensor(...)

reveal_type(x)  # reveals Tensor[A, B, B, C]
reveal_type(outer_transpose(x))  # reveals Tensor[C, B, B, A]

y = outer_transpose(x)


y.shape

NameError: name 'reveal_type' is not defined

In [7]:
from pytest import raises
from beartype import beartype


@beartype
def make_tensor(shape: tuple[int, int, int]) -> Tensor[A, B, A]:
    return cast(Tensor[A, B, A], tr.rand(shape))

@beartype
def f(x: Tensor[A, B]): ...

@beartype
def g(x: Tensor[A, B, A]): ...

x = make_tensor((1, 2, 1))

with raises(Exception):  # <- beartype raises
    make_tensor((1, 2))  # pyright flags as error

with raises(Exception):  # <- beartype raises
    f(x)  # pyright flags as error

g(x)

In [8]:
make_tensor((1, 2, 2))

BeartypeCallHintReturnViolation: @beartyped make_tensor() return "tensor([[[0.5726, 0.9630],
         [0.7850, 0.8574]]])" violates type hint <class '__main__.Tensor.__class_getitem__.<locals>.PhantomTensor'>, as "tensor([[[0.5726, 0.9630],
         [0.7850, 0.8574]]])" not instance of <protocol "__main__.PhantomTensor">.

In [75]:
make_tensor((1, 2, 1))

tensor([[[0.0346],
         [0.4760]]])

In [None]:
register_tensor_shit = lambda x: x
any_third_party_type_checker = lambda x: x

In [None]:
z = 1
def parse(*a: Any) -> Any: ...

In [None]:
x, y, z = parse((x, Tensor[A, B, B]), (y, Tensor[B]), (z, Tensor[B, A, C]))

In [77]:
@any_third_party_type_checker
@register_tensor_shit
def some_tensor_op(x: Tensor[A, B, B], y: Tensor[B], z: Tensor[B, A, C]) -> Tensor[A, C]: ...

In [76]:
make_tensor((1, 2, 2))

BeartypeCallHintReturnViolation: @beartyped make_tensor() return "tensor([[[0.9308, 0.6804],
         [0.0165, 0.4240]]])" violates type hint <class '__main__.Tensor.__class_getitem__.<locals>.PhantomTensor'>, as "tensor([[[0.9308, 0.6804],
         [0.0165, 0.4240]]])" not instance of <protocol "__main__.PhantomTensor">.

In [56]:
def make_tensor() -> Tensor[A, B]: ...

def f(x: Tensor[A, B], y: Tensor[B, A]): ...

x = make_tensor()

f(x, x)  # <- should also raise runtime error via beartype

f(x, outer_transpose(x))



False

In [57]:
assert not isinstance([1], Tensor[A, A])

assert isinstance(tr.rand(2, 2), Tensor[A, A])
assert not isinstance(tr.rand(2, 3), Tensor[A, A])

assert isinstance(tr.rand(2, 3), Tensor[A, B])

In [59]:
assert isinstance(tr.rand(1, 2, 3, 2, 1), Tensor[A, B, C, B, A])
assert not isinstance(tr.rand(1, 2, 3, 2, 1), Tensor[A, B, C, A, A])

In [6]:
issubclass(A, int)

TypeError: issubclass() arg 1 must be a class

In [5]:
def f() -> Tensor[A, B]: ...

def g(x: Tensor[B, A]): ...

x = f()
x.shape

<class '__main__.Tensor'> ((<function NewType.<locals>.new_type at 0x0000022F38CD7D30>, <function NewType.<locals>.new_type at 0x0000022F38CD7E50>),) {}
<class 'super'>


In [7]:
isinstance(tr.rand((2,1)), Tensor)

False