
Tensor duck #40

Open · wants to merge 3 commits into master
Conversation

@corwinjoy

As promised, here is the PR to upgrade the library to define a 'torch-like' protocol and use that as the base type, rather than using torch.Tensor directly. This lets users perform dimension checking on classes that support a Tensor interface but do not directly inherit from torch.Tensor. I think the change is fairly clear-cut, and I have added a test case to demonstrate and verify that dimensions are actually checked.
The only question I have is about the change to line 304 in typechecker.py (the last change below).
Is this test really necessary?
I had to change it to use default construction because protocols don't support issubclass() if they have non-method members such as properties.
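
For reference, a minimal sketch of what such a protocol could look like (the member names shape and dtype here are illustrative; the PR's actual definition may differ):

from typing import Protocol, runtime_checkable

import torch

@runtime_checkable
class TensorLike(Protocol):
    # Anything exposing tensor-style shape/dtype properties matches,
    # whether or not it inherits from torch.Tensor.
    @property
    def shape(self) -> torch.Size:
        ...

    @property
    def dtype(self) -> torch.dtype:
        ...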

Corwin Joy added 2 commits November 4, 2022 15:17
@@ -301,7 +300,7 @@ def check_type(*args, **kwargs):
     # Now check if it's annotating a tensor
     if is_torchtyping_annotation:
         base_cls, *all_metadata = get_args(expected_type)
-        if not issubclass(base_cls, torch.Tensor):
+        if not isinstance(base_cls(), TensorLike):
@corwinjoy (Author) · Nov 4, 2022

I'm not sure about this last change. As mentioned, the protocol class only supports isinstance(), not issubclass(), because it has properties. This means I had to require default construction.
But I think this test may be unnecessary; after all the other checks, I think we already know this is a TensorLike element?
I think it might be better to just get rid of this test. @patrick-kidger

@corwinjoy (Author)

In fact, it seems I have strong motivation to remove this. The case where I want to apply it is checking shape signatures on an abstract base class, so default construction may not be an option.
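
For instance, an abstract base class cannot be default-constructed at all (LinearOperator here is just an illustration):

import abc

class LinearOperator(abc.ABC):
    # Abstract interface; only concrete subclasses can be instantiated.
    @abc.abstractmethod
    def matmul(self, other):
        ...

LinearOperator()  # TypeError: Can't instantiate abstract class LinearOperator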

@corwinjoy (Author)

I have updated the PR accordingly.

@patrick-kidger (Owner)

Thanks for the PR! Unfortunately, this isn't quite the direction I had in mind.

Following on from the discussion in #39, perhaps it's worth making clear that I don't intend to make TorchTyping depend on JAX. Rather, that the plan is to simply copy over the non-JAX parts of the code. (Which is most of it.)

The idea would be to end up with annotations that look like Float[Tensor, "batch channel", ...], where ... is information about those PyTorch concepts that don't exist in JAX. (In particular, device and layout). And then add in some backward-compatible aliases, so that TensorType[...] is lowered to this new representation.

At a technical level, this should be essentially simple. The main hurdle, and the reason I've been putting off doing this, is writing up documentation that makes the transition clear.
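
A sketch of what such an annotation might look like in user code (assuming the ported names keep jaxtyping's existing spelling; the details may differ):

import torch
from jaxtyping import Float  # i.e. the non-JAX parts, copied over as described

def forward(x: Float[torch.Tensor, "batch channel"]) -> Float[torch.Tensor, "batch"]:
    # Reduces over the channel dimension; the annotations check both shapes.
    return x.sum(dim=-1)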

@corwinjoy (Author)

Thanks for the clarification! I can totally see why you want to pull over the jaxtyping code and have a single code base. I understand that this PR is perhaps not what you were looking for, but I think it could actually be an important step in generalizing what you have, and maybe even in merging the two code bases. Let's take an example snippet from jaxtyping where the dtype is extracted (array_types.py: 129-):

class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        if not isinstance(obj, cls.array_type):
            return False

        if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
            # JAX, numpy
            dtype = obj.dtype.type.__name__
        elif hasattr(obj.dtype, "as_numpy_dtype"):
            # TensorFlow
            dtype = obj.dtype.as_numpy_dtype.__name__
        else:
            # PyTorch
            repr_dtype = repr(obj.dtype).split(".")
            if len(repr_dtype) == 2 and repr_dtype[0] == "torch":
                dtype = repr_dtype[1]
            else:
                raise RuntimeError(
                    "Unrecognised array/tensor type to extract dtype from"
                )

        if cls.dtypes is not _any_dtype and dtype not in cls.dtypes:
            return False
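        # ... (in the full source, shape checking follows and the method
        # eventually returns True)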

I think you would agree that it's a bit awkward and somewhat hard to extend, since the supported classes have to be coded in advance.
Instead, with a design like the one in this PR, we could make _MetaAbstractArray use a protocol class to declare what kinds of properties it checks at runtime. For concreteness, let's say we define this protocol like:

from typing import Protocol

class _ArrayLike(Protocol):
    @property
    def dtype(self) -> AbstractDtype:  # AbstractDtype as in jaxtyping
        ...
...

Then, to type-check a concrete class like numpy.ndarray or torch.Tensor, we just use the adapter pattern to map the specialized methods onto the interface. (As an example of a simple name remapper, see "Adapter Method – Python Design Patterns".)

This would make it easy for folks like me to extend your library to array-type objects such as LinearOperator by just writing an adapter to the interface specified by the library.
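
A minimal sketch of such an adapter (MyOperator and its numpy_dtype property are hypothetical; only the dtype member of _ArrayLike is remapped here):

import numpy as np

class MyOperator:
    # Hypothetical array-like class whose dtype lives under a different name.
    def __init__(self, data):
        self._data = np.asarray(data)

    @property
    def numpy_dtype(self) -> np.dtype:
        return self._data.dtype

class MyOperatorAdapter:
    # Adapter: remaps MyOperator's interface onto the _ArrayLike protocol.
    def __init__(self, op: MyOperator):
        self._op = op

    @property
    def dtype(self) -> np.dtype:
        return self._op.numpy_dtype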

In addition, I think it could also let you merge these two libraries and make your life easier. You wrote that:
"Yep, I did consider merging things. This ends up being infeasible due to edge cases, e.g. JAX and PyTorch (eventually) having different PyTree manipulation routines."
Looking at pytree_type.py, it seems plausible that you could define a protocol class along the lines of _TreeLike(Protocol), with methods for accessing leaves.
Then, JAX pytree support can be done via an adapter, and this can be an optional import for those who don't want a JAX dependency.
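
A rough sketch of such a protocol (the leaves method is my invention for illustration; JAX itself exposes free functions such as jax.tree_util.tree_leaves rather than a method):

from typing import Any, Iterable, Protocol

class _TreeLike(Protocol):
    # Any container that can enumerate its leaf values.
    def leaves(self) -> Iterable[Any]:
        ...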
Anyway, I think this could be pretty nice, and I would be happy to help make it happen. I don't thoroughly understand the jaxtyping code, but I think this is doable, and I would be happy to help with the documentation!

@patrick-kidger (Owner)

Hmm. I suppose the practical implementation of such an adaptor would be via a registry:

import functools as ft

@ft.singledispatch
def get_dtype(obj):
    # Note that this default implementation does not explicitly
    # depend on any of PyTorch/etc; thus the singledispatch
    # hook is made available just for the sake of user-defined
    # custom types.
    if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
        # JAX, numpy
        return obj.dtype.type.__name__
    elif hasattr(obj.dtype, "as_numpy_dtype"):
        # TensorFlow
        return obj.dtype.as_numpy_dtype.__name__
    else:
        # PyTorch
        repr_dtype = repr(obj.dtype).split(".")
        if len(repr_dtype) == 2 and repr_dtype[0] == "torch":
            return repr_dtype[1]
        raise RuntimeError(
            "Unrecognised array/tensor type to extract dtype from"
        )

class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        ...
        dtype = get_dtype(obj)
        ...

and then in your user code, you could add a custom overload for your type.
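
For example (LinearOperator and its dtype_name attribute are stand-ins for whatever the user's type actually provides):

@get_dtype.register
def _(obj: LinearOperator) -> str:
    # dtype names are plain strings, matching the default implementation.
    return obj.dtype_name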

I'd be willing to accept a PR for this over in jaxtyping.
