
Raw tensor type #68

Open
sbrunk opened this issue Dec 11, 2023 · 0 comments


sbrunk commented Dec 11, 2023

With compile-time tracked tensor shapes as discussed in #63, we'll probably need three type parameters, like Tensor[DType, Shape, Device]. While this is great for type safety, it makes the Tensor type a bit more convoluted.
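As a rough illustration of what the fully tracked type could look like (all names here are hypothetical sketches, not the actual storch API):

```scala
// Hypothetical encodings for the three type-level attributes.
sealed trait DType
trait Float32 extends DType

// A shape could e.g. be a tuple of literal Int types, like (2, 3).
type Shape = Tuple

sealed trait Device
trait CPU extends Device

// The tensor carries dtype, shape and device in its type:
class Tensor[D <: DType, S <: Shape, Dev <: Device]

// A 2x3 float32 tensor on the CPU:
val t: Tensor[Float32, (2, 3), CPU] = new Tensor
```

Every operation then has to thread all three parameters through its signature, which is where the convolution comes from.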

For use-cases like prototyping, it can be useful to have an escape hatch in the form of some kind of raw tensor type, similar to upstream PyTorch or NumPy etc., where these attributes are only tracked at runtime. This would of course be less safe, and we need to think about how both variants could coexist, how we can convert between them, etc.

Design considerations

If all our tensor type parameters were covariant, i.e. Tensor[+D <: DType, +S <: Shape, +DE <: Device], a raw tensor could perhaps just be Tensor[DType, Shape, Device] (DType, Shape and Device being the upper bounds of the type parameters). Currently they aren't covariant, though, and I'm not sure making them so is feasible, since tensors are currently mutable. Even if we had immutable tensors, covariance isn't without its own issues (I'm by no means a type system expert, this is just my current understanding).
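The covariant variant could be sketched like this (again purely illustrative, assuming hypothetical DType/Shape/Device encodings):

```scala
sealed trait DType
trait Float32 extends DType
type Shape = Tuple
sealed trait Device
trait CPU extends Device

// With covariant parameters, the fully general instantiation acts as a top type:
class Tensor[+D <: DType, +S <: Shape, +Dev <: Device]

// A "raw" tensor is then just the tensor with all parameters at their upper bound.
type RawTensor = Tensor[DType, Shape, Device]

// Any precisely typed tensor widens to it for free:
val typed: Tensor[Float32, (2, 3), CPU] = new Tensor
val raw: RawTensor = typed // compiles thanks to covariance
```

The catch is exactly the one mentioned above: covariance is only sound for read-like operations, which is at odds with mutable tensors.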

Without covariance, we could have a RawTensor type as a super-type of Tensor, with "unsafe" operations defined only as extension methods that we need to import explicitly. If these operations return a tensor, it would always be a RawTensor. That would probably make it quite easy (too easy?) to run unsafe operations on any tensor. Going from a RawTensor back to a typed tensor would always require an explicit unsafe cast. We still need to figure out whether this can cause name clashes, though.
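A minimal sketch of that supertype approach (all names hypothetical):

```scala
// RawTensor is the parent; the typed Tensor refines it.
class RawTensor
class Tensor[D, S, Dev] extends RawTensor

// Unsafe ops live in an object that must be imported explicitly.
object unsafeOps:
  extension (t: RawTensor)
    // Works on any tensor (typed or raw) via subtyping,
    // but the result is always raw.
    def reshapeUnchecked(dims: Int*): RawTensor = new RawTensor

  extension (t: RawTensor)
    // Recovering a typed tensor is an explicit, unchecked claim.
    def asTyped[D, S, Dev]: Tensor[D, S, Dev] =
      t.asInstanceOf[Tensor[D, S, Dev]]
```

After `import unsafeOps.*`, even a fully typed tensor silently has the unchecked operations available, which is exactly the "too easy?" concern.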

Yet another way could be a stricter separation without an inheritance hierarchy: RawTensor would live in a different package, with explicit conversions between the two. Unsafe ops could always work on typed tensors too, for convenience.
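That strict-separation variant might look roughly like this (hypothetical names, with objects standing in for packages to keep the sketch self-contained):

```scala
object raw:
  class RawTensor

object typed:
  class Tensor[D, S, Dev]:
    // Dropping static type information is always safe and explicit.
    def toRaw: raw.RawTensor = new raw.RawTensor

// Regaining static types is a separate, explicit, unchecked step.
extension (t: raw.RawTensor)
  def unsafeAs[D, S, Dev]: typed.Tensor[D, S, Dev] =
    new typed.Tensor[D, S, Dev]
```

Compared to the subtyping approach, nothing is available by accident: every crossing between the safe and unsafe worlds is a visible method call.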

Perhaps there are other options as well?

An open question is whether we'd also need to add support for raw tensor types in torch.nn (modules etc.), and whether we could generate/derive that to avoid extra overhead, but that's out of scope for this issue.
