Type annotations for a tensor's shape, dtype, names, ...
Welcome! For new projects I now strongly recommend using my newer jaxtyping project instead. It supports PyTorch, doesn't actually depend on JAX, and unlike TorchTyping it is compatible with static type checkers. :)
Turn this:

```python
def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)

    return x.unsqueeze(-1) * y.unsqueeze(-2)
```
into this:

```python
def batch_outer_product(x:   TensorType["batch", "x_channels"],
                        y:   TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:
    return x.unsqueeze(-1) * y.unsqueeze(-2)
```
with programmatic checking that the shape (dtype, ...) specification is met.
Bye-bye bugs! Say hello to enforced, clear documentation of your code.
If (like me) you find yourself littering your code with comments like `# x has shape (batch, hidden_state)` or statements like `assert x.shape == y.shape`, just to keep track of what shape everything is, then this is for you.
```bash
pip install torchtyping
```
Requires Python >=3.7 and PyTorch >=1.7.0.
If using `typeguard` then it must be a version <3.0.0.
torchtyping allows for type annotating:
- shape: size, number of dimensions;
- dtype (float, integer, etc.);
- layout (dense, sparse);
- names of dimensions as per named tensors;
- arbitrary number of batch dimensions with `...`;
- ...plus anything else you like, as `torchtyping` is highly extensible.
If `typeguard` is (optionally) installed then at runtime the types can be checked to ensure that the tensors really are of the advertised shape, dtype, etc.
```python
# EXAMPLE

from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

@typechecked
def func(x: TensorType["batch"],
         y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # works
func(rand(3), rand(1))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3.
```
`typeguard` also has an import hook that can be used to automatically test an entire module, without needing to manually add `@typeguard.typechecked` decorators.
If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes. If you're not already using `typeguard` for your regular Python programming, then strongly consider using it. It's a great way to squash bugs. Both `typeguard` and `torchtyping` also integrate with `pytest`, so if you're concerned about any performance penalty then they can be enabled during tests only.
```python
torchtyping.TensorType[shape, dtype, layout, details]
```
The core of the library. Each of `shape`, `dtype`, `layout` and `details` is optional.
- The `shape` argument can be any of:
  - An `int`: the dimension must be of exactly this size. If it is `-1` then any size is allowed.
  - A `str`: the size of the dimension passed at runtime will be bound to this name, and all tensors checked that the sizes are consistent.
  - A `...`: an arbitrary number of dimensions of any sizes.
  - A `str: int` pair (technically it's a slice), combining both `str` and `int` behaviour. (Just a `str` on its own is equivalent to `str: -1`.)
  - A `str: str` pair, in which case the size of the dimension passed at runtime will be bound to both names, and all dimensions with either name must have the same size. (Some people like to use this as a way to associate multiple names with a dimension, for extra documentation purposes.)
  - A `str: ...` pair, in which case the multiple dimensions corresponding to `...` will be bound to the name specified by `str`, and again checked for consistency between arguments.
  - `None`, which when used in conjunction with `is_named` below, indicates a dimension that must not have a name in the sense of named tensors.
  - A `None: int` pair, combining both `None` and `int` behaviour. (Just a `None` on its own is equivalent to `None: -1`.)
  - A `None: str` pair, combining both `None` and `str` behaviour. (That is, it must not have a named dimension, but must be of a size consistent with other uses of the string.)
  - A `typing.Any`: any size is allowed for this dimension (equivalent to `-1`).
  - Any tuple of the above. For example `TensorType["batch": ..., "length": 10, "channels", -1]`. If you just want to specify the number of dimensions then use for example `TensorType[-1, -1, -1]` for a three-dimensional tensor.
- The `dtype` argument can be any of:
  - `torch.float32`, `torch.float64`, etc.;
  - `int`, `bool`, `float`, which are converted to their corresponding PyTorch types. `float` is specifically interpreted as `torch.get_default_dtype()`, which is usually `float32`.
- The `layout` argument can be either `torch.strided` or `torch.sparse_coo`, for dense and sparse tensors respectively.
- The `details` argument offers a way to pass an arbitrary number of additional flags that customise and extend `torchtyping`. Two flags are built-in by default. `torchtyping.is_named` causes the names of tensor dimensions to be checked, and `torchtyping.is_float` can be used to check that arbitrary floating point types are passed in. (Rather than just a specific one as with e.g. `TensorType[torch.float32]`.) For discussion on how to customise `torchtyping` with your own `details`, see the further documentation.
- Check multiple things at once by just putting them all together inside a single `[]`. For example `TensorType["batch": ..., "length", "channels", float, is_named]`.
`torchtyping` integrates with `typeguard` to perform runtime type checking. `torchtyping.patch_typeguard()` should be called at the global level, and will patch `typeguard` to check `TensorType` annotations.
This function is safe to run multiple times. (It does nothing after the first run).
- If using `@typeguard.typechecked`, then `torchtyping.patch_typeguard()` should be called any time before using `@typeguard.typechecked`. For example you could call it at the start of each file using `torchtyping`.
- If using `typeguard`'s import hook, then `torchtyping.patch_typeguard()` should be called any time before defining the functions you want checked. For example you could call `torchtyping.patch_typeguard()` just once, at the same time as installing the `typeguard` import hook. (The order of the hook and the patch doesn't matter.)
- If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes.
`torchtyping` offers a `pytest` plugin to automatically run `torchtyping.patch_typeguard()` before your tests. `pytest` will automatically discover the plugin; you just need to pass the `--torchtyping-patch-typeguard` flag to enable it. Packages can then be passed to `typeguard` as normal, either by using `typeguard`'s import hook, or the `typeguard` `pytest` plugin itself.
See the further documentation for:
- How to write custom extensions to `torchtyping`;
- Resources and links to other libraries and materials on this topic;
- More examples.