In [1]:
# The goal of the following is to let users write tensor-like types with variadic
# shapes that are amendable to both detailed static type checking and runtime type
# checking.
#
# There are already libraries like
# - https://github.com/patrick-kidger/torchtyping
# - https://github.com/ramonhagenaars/nptyping
#
# The former targets shape-based runtime type checking functionality in a pretty
# general, but hacky  way. Its shape information is not amenable to static analyis. The
# latter is more about expressing really verbose shapes... and stuff. It supports
# runtime checking against *single* array expressions, but nothing beyond that. So,
# quite limited. It also only works with mypy -- depends on mypy plugin.
#
#
# This is a really rough proof of concept in which I hijack the machinery of
# `__class_getitem__` so that `Tensor[A, B]` returns a "phantom type" whose
# `__instancecheck__` enforces both isisntance(<obj>, Tensor) and <obj>.shape == (A, B)
# This means that static type checkers see `Tensor[A, B]`, and runtime type-checkers can
# call isinstance against the phantom type. Thus this method plugs and plays with all
# runtime type checkers
#
# Some next steps:
# - Add memoization so that the following passes: `assert Tensor[A, B] is Tensor[A, B]`
# - (DONE) Make some weird context manager for binding shape-dimension symbols to particular
#   values, so that we can check for mutual consistency across multiple tensors within
#   that context. E.g. `def f(x: Tensor[A, B], y: Tensor[B, A]): ...`
# - Write a `parse` function that makes it trivial to validate and statically cast
#   tensors to shape-typed narrowed versions of themselves
#

# requirements
# - Python 3.9+
# - typing extensions
# - phantom-types
# - pytorch


from collections import defaultdict
from torch import Tensor as _Tensor
import torch as tr
from phantom import Phantom, PhantomMeta

from typing import (
    Generic,
    NewType,
    TYPE_CHECKING,
    Any,
    cast,
    Hashable,
    Optional,
    Callable,
    TypeVar,
    overload, Type, Iterator
)
from typing_extensions import TypeVarTuple, Unpack, TypeGuard, TypeAlias

from functools import wraps


In [2]:

F = TypeVar('F', bound=Callable[..., Any])

class DimBinder:
    bindings: Optional[dict[Any, int]] = None

class DimBindContext:
    _depth: int = 0

    def __enter__(self):
        self._depth += 1
        if self._depth == 1:
            DimBinder.bindings = {}

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._depth -= 1
        if self._depth == 0:
            DimBinder.bindings = None

    def __call__(self, func: F) -> F:
        @wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return cast(F, wrapper)

dim_binding_scope = DimBindContext()

In [37]:
from __future__ import annotations
A = NewType("A", int)
B = NewType("B", int)
C = NewType("C", int)

Shape = TypeVarTuple("Shape")


def check(shape_type: tuple[Hashable, ...], shape: tuple[int, ...]) -> bool:

    if len(shape_type) != len(shape):  # TODO: permit arbitrary len-shape
        return False

    # E.g. Tensor[A, B, B, C] :: matches == {A: [0], B: [1, 2], C: [3]}
    matches: defaultdict[Any, list[int]] = defaultdict(list)
    for n, a in enumerate(shape_type):
        matches[a].append(n)

    for symbol, indices in matches.items():
        if len(indices) == 1 and DimBinder.bindings is None:
            continue

        first, *rest = indices
        if DimBinder.bindings is None or symbol is int:
            a = shape[first]
        else:
            _a = shape[first]
            a = DimBinder.bindings.setdefault(symbol, _a)
            if a != _a:
                return False
        if not all(a == shape[b] for b in rest):
            return False
    return True
        
        

class NewMeta(PhantomMeta, type(_Tensor)):
    ...
    
class Tensor(Generic[Unpack[Shape]], _Tensor):
    if not TYPE_CHECKING:
        _cache = {}
        @classmethod
        def __class_getitem__(cls, key):
            if not isinstance(key, tuple):
                key = (key,)
                
            kk = tuple(k.__name__ for k in key)
            if kk in cls._cache:
                return cls._cache[kk]

            class PhantomTensor(
                _Tensor,
                Phantom,
                metaclass=NewMeta,
                predicate=lambda x: check(key, x.shape),
            ):
                _shape = key
            cls._cache[kk] = PhantomTensor
            return PhantomTensor

    @property
    def shape(self) -> tuple[Unpack[Shape]]: ...


In [41]:
from __future__ import annotations

T1 = TypeVar("T1", bound=_Tensor)
T2 = TypeVar("T2", bound=_Tensor)
T3 = TypeVar("T3", bound=_Tensor)
T4 = TypeVar("T4", bound=_Tensor)

@overload
def parse(
    __a: tuple[_Tensor, Type[T1]],
    __b: tuple[_Tensor, Type[T2]],
    __c: tuple[_Tensor, Type[T3]],
    __4: tuple[_Tensor, Type[T4]],
) -> tuple[T1, T2, T3, T4]:
    ...


@overload
def parse(
    __a: tuple[_Tensor, Type[T1]],
    __b: tuple[_Tensor, Type[T2]],
    __c: tuple[_Tensor, Type[T3]],
) -> tuple[T1, T2, T3]:
    ...


@overload
def parse(
    __a: _Tensor,
    __b: Type[T1],
) -> T1:
    ...


@overload
def parse(
    __a: tuple[_Tensor, Type[T1]],
    __b: tuple[_Tensor, Type[T2]],
) -> tuple[T1, T2]:
    ...


@overload
def parse(__a: tuple[_Tensor, Type[T1]]) -> T1:
    ...


@overload
def parse(
    *tensor_type_pairs: tuple[_Tensor, Type[_Tensor]] | _Tensor | Type[_Tensor]
) -> _Tensor | tuple[_Tensor, ...]:
    ...


@dim_binding_scope
def parse(
    *tensor_type_pairs: tuple[_Tensor, Type[_Tensor]] | _Tensor | Type[_Tensor]
) -> _Tensor | tuple[_Tensor, ...]:
    out = []
    if len(tensor_type_pairs) == 0: raise ValueError("")
    if len(tensor_type_pairs) == 2 and not isinstance(tensor_type_pairs[0], tuple):
        tensor_type_pairs = (tensor_type_pairs,) # type: ignore

    for tensor, type_ in tensor_type_pairs:
        if not isinstance(tensor, type_.__bound__):  # type: ignore
            raise TypeError(f"Expected Tensor, got: {type(tensor)}")

        type_shape = type_._shape  # type: ignore
        if not check(type_shape, tensor.shape):
            assert DimBinder.bindings is not None
            type_str = ", ".join(
                f"{p.__name__}={DimBinder.bindings[p]}" for p in type_shape
            )
            raise TypeError(
                f"shape-{tuple(tensor.shape)} doesn't match shape-type ({type_str})"
            )
        out.append(tensor)
    if len(out) == 1:
        return out[0]
    return tuple(out)


In [42]:
x = parse(tr.rand(2, 3), Tensor[A, B])
x, y = parse(
    (tr.rand(2, 3), Tensor[A, B]),
    (tr.rand(2, 3), Tensor[A, B]),
)


In [43]:
x = parse((tr.rand(2, 3), Tensor[A, B]))
x, y = parse(
    (tr.rand(4, 3), Tensor[A, B]),
    (tr.rand(4, 3), Tensor[A, B]),
)


In [44]:
with dim_binding_scope:

    assert isinstance(tr.rand(2), Tensor[A])  # binds A=2
    assert not isinstance(tr.rand(3), Tensor[A])  # nope!
    assert isinstance(tr.rand(2), Tensor[A])  # yep!


    assert isinstance(tr.rand(2, 4), Tensor[A, B])  # binds B=4
    assert not isinstance(tr.rand(2), Tensor[B])  # nope!
    assert isinstance(tr.rand(4), Tensor[B])  # yep!
    assert isinstance(tr.rand(4, 2, 2, 4), Tensor[B, A, A, B])  # yep!

assert isinstance(tr.rand(1, 3, 3, 1), Tensor[B, A, A, B])  # no dims bound
assert isinstance(tr.rand(1, 4, 4, 1), Tensor[B, A, A, B])  # no dims bound


In [None]:
oo = parse((tr.rand(2), Tensor[A]))

In [None]:
(x @ y) - z

In [45]:
import pytest

with dim_binding_scope:

    parse(tr.rand(2), Tensor[A])  # binds A=2
    
    with pytest.raises(TypeError):
        parse(tr.rand(3), Tensor[A])  # nope!
    
    parse(tr.rand(2), Tensor[A])  # yep!


    parse(tr.rand(2, 4), Tensor[A, B])  # binds B=4
    
    with pytest.raises(TypeError):
        parse(tr.rand(2), Tensor[B])  # nope!

    parse(tr.rand(4), Tensor[B])  # yep!
    parse(tr.rand(4, 2, 2, 4), Tensor[B, A, A, B])  # yep!

x = parse(tr.rand(1, 3, 3, 1), Tensor[B, A, A, B])  # no dims bound
parse(tr.rand(1, 4, 4, 1), Tensor[B, A, A, B]);  # no dims bound

In [119]:
Tensor._cache

{('A',): __main__.Tensor.__class_getitem__.<locals>.PhantomTensor,
 ('A', 'B'): __main__.Tensor.__class_getitem__.<locals>.PhantomTensor,
 ('B',): __main__.Tensor.__class_getitem__.<locals>.PhantomTensor,
 ('B',
  'A',
  'A',
  'B'): __main__.Tensor.__class_getitem__.<locals>.PhantomTensor}

In [58]:
from beartype import beartype

@dim_binding_scope 
# ^ ensures A, B, C consistent across all input/output tensor shapes
#   within scope of function
@beartype
def matrix_multiply(x: Tensor[A, B], y: Tensor[B, C]) -> Tensor[A, C]:
    a, b = x.shape
    b, c = y.shape
    return cast(Tensor[A, C], tr.rand(a, c))

@beartype
def needs_vector(x: Tensor[int]): ...


x, y = parse(
    (tr.rand(3, 4), Tensor[A, B]),
    (tr.rand(4, 5), Tensor[B, C]),
)
x  # type revealed: Tensor[A, B]
y  # type revealed: Tensor[B, C]

z = matrix_multiply(x, y)
z  # type revealed: Tensor[A, C]

with pytest.raises(Exception):
    needs_vector(z)  # beartype will roar!

with pytest.raises(Exception):
    matrix_multiply(x, x)  # beartype will roar!

In [47]:
x = parse((tr.rand(3, 4), Tensor[A, B]))
y = parse((tr.rand(3, 4), Tensor[A, B]))

matrix_multiply(x, y)

BeartypeCallHintParamViolation: @beartyped matrix_multiply() parameter y="tensor([[0.3709, 0.7728, 0.2731, 0.9063],
        [0.0662, 0.8832, 0.8433, 0.8805],
       .... violates type hint <class '__main__.Tensor.__class_getitem__.<locals>.PhantomTensor'>, as "tensor([[0.3709, 0.7728, 0.2731, 0.9063],
        [0.0662, 0.8832, 0.8433, 0.8805],
       .... not instance of <protocol "__main__.PhantomTensor">.

In [121]:
x = tr.rand(8, 9)
y = tr.rand(9, 8)
out = f(x, y)

In [122]:
x = tr.rand(2, 4)
y = tr.rand(1, 2)  # should be (2, 4)
f(x, y)

BeartypeCallHintParamViolation: @beartyped f() parameter y="tensor([[0.9299, 0.4118]])" violates type hint <class '__main__.Tensor.__class_getitem__.<locals>.PhantomTensor'>, as "tensor([[0.9299, 0.4118]])" not instance of <protocol "__main__.PhantomTensor">.

In [1]:
1**"str"

TypeError: unsupported operand type(s) for ** or pow(): 'int' and 'str'

In [10]:
H = NewType("A", int)
B = NewType("B", int)
C = NewType("C", int)


In [71]:
from typing import TypeVar
T1 = TypeVar("T1", bound=int)
T2 = TypeVar("T2", bound=int)



def make_tensor(shape) -> Tensor[A, B, B, C]: ...

def outer_transpose(x: Tensor[T1, Unpack[Shape], T2]) -> Tensor[T2, Unpack[Shape], T1]: ...

x = make_tensor(...)

reveal_type(x)  # reveals Tensor[A, B, B, C]
reveal_type(outer_transpose(x))  # reveals Tensor[C, B, B, A]

y = outer_transpose(x)


y.shape

NameError: name 'reveal_type' is not defined

In [43]:
DimBinder.bindings

{<function typing.NewType.<locals>.new_type(x)>: 2,
 <function typing.NewType.<locals>.new_type(x)>: 4}

In [49]:
from pytest import raises
from beartype import beartype


@beartype
def make_tensor(shape: tuple[int, int, int]) -> Tensor[A, B, A]:
    return cast(Tensor[A, B, A], tr.rand(shape))

@beartype
def f(x: Tensor[A, B]): ...

@beartype
def g(x: Tensor[A, B, A]): ...

x = make_tensor((2, 4, 2))

with raises(Exception):  # <- beartype raises
    make_tensor((1, 2))  # pyright flags as error

with raises(Exception):  # <- beartype raises
    f(x)  # pyright flags as error

g(x)

In [8]:
make_tensor((1, 2, 2))

BeartypeCallHintReturnViolation: @beartyped make_tensor() return "tensor([[[0.5726, 0.9630],
         [0.7850, 0.8574]]])" violates type hint <class '__main__.Tensor.__class_getitem__.<locals>.PhantomTensor'>, as "tensor([[[0.5726, 0.9630],
         [0.7850, 0.8574]]])" not instance of <protocol "__main__.PhantomTensor">.

In [75]:
make_tensor((1, 2, 1))

tensor([[[0.0346],
         [0.4760]]])

In [None]:
register_tensor_shit = lambda x: x
any_third_party_type_checker = lambda x: x

In [None]:
z = 1
def parse(*a: Any) -> Any: ...

In [None]:
x, y, z = parse((x, Tensor[A, B, B]), (y, Tensor[B]), (z, Tensor[B, A, C]))

In [77]:
@any_third_party_type_checker
@register_tensor_shit
def some_tensor_op(x: Tensor[A, B, B], y: Tensor[B], z: Tensor[B, A, C]) -> Tensor[A, C]: ...

In [76]:
make_tensor((1, 2, 2))

BeartypeCallHintReturnViolation: @beartyped make_tensor() return "tensor([[[0.9308, 0.6804],
         [0.0165, 0.4240]]])" violates type hint <class '__main__.Tensor.__class_getitem__.<locals>.PhantomTensor'>, as "tensor([[[0.9308, 0.6804],
         [0.0165, 0.4240]]])" not instance of <protocol "__main__.PhantomTensor">.

In [56]:
def make_tensor() -> Tensor[A, B]: ...

def f(x: Tensor[A, B], y: Tensor[B, A]): ...

x = make_tensor()

f(x, x)  # <- should also raise runtime error via beartype

f(x, outer_transpose(x))



False

In [57]:
assert not isinstance([1], Tensor[A, A])

assert isinstance(tr.rand(2, 2), Tensor[A, A])
assert not isinstance(tr.rand(2, 3), Tensor[A, A])

assert isinstance(tr.rand(2, 3), Tensor[A, B])

In [59]:
assert isinstance(tr.rand(1, 2, 3, 2, 1), Tensor[A, B, C, B, A])
assert not isinstance(tr.rand(1, 2, 3, 2, 1), Tensor[A, B, C, A, A])

In [6]:
issubclass(A, int)

TypeError: issubclass() arg 1 must be a class

In [5]:
def f() -> Tensor[A, B]: ...

def g(x: Tensor[B, A]): ...

x = f()
x.shape

<class '__main__.Tensor'> ((<function NewType.<locals>.new_type at 0x0000022F38CD7D30>, <function NewType.<locals>.new_type at 0x0000022F38CD7E50>),) {}
<class 'super'>


In [7]:
isinstance(tr.rand((2,1)), Tensor)

False