
[discussion] Have PyTorch functions support python scalars (like NumPy) + introduce convenience constants like torch.pi and torch.e and maybe analogue of scipy.constants namespace #110636

Open
vadimkantorov opened this issue Oct 5, 2023 · 17 comments
Labels
module: numpy Related to numpy support, and also numpy compatibility of our operators module: python frontend For issues relating to PyTorch's Python frontend triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@vadimkantorov
Contributor

vadimkantorov commented Oct 5, 2023

🚀 The feature, motivation and pitch

OP: #110351 (comment)

As an example, it would be nice to have torch.sqrt(python scalar) -> python scalar, without having to dispatch between torch.sqrt and math.sqrt, and to enable a bit more polymorphic code.

Another idea is to also have torch.pi and other constants (like NumPy), in order to avoid importing numpy or math in order to get these constants.

Please close this if it's a duplicate. I tried to search for similar previous discussions, but the keywords are a bit too generic :(

For torch.sqrt specifically, a polymorphic equivalent currently exists: x ** 0.5, which works for both tensor and Python scalar inputs. But it would be useful to have this behavior for many (at least simple) functions like torch.exp and so forth.
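As a minimal sketch of what the proposal would remove, this is the kind of dispatch helper users currently write by hand (sqrt_poly is a hypothetical name, not a torch API):

import math
import torch

def sqrt_poly(x):
    # Hypothetical user-side helper: manually dispatch between torch and math.
    if torch.is_tensor(x):
        return torch.sqrt(x)
    return math.sqrt(x)

print(sqrt_poly(4.0))                # 2.0 (Python float)
print(sqrt_poly(torch.tensor(4.0)))  # tensor(2.)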

Alternatives

No response

Additional context

No response

cc @mruberry @rgommers @albanD

@qqaatw
Collaborator

qqaatw commented Oct 5, 2023

The idea sounds interesting to me, but I think that without being wrapped in a Tensor, AD and other tracers would not work in this case?

@ezyang ezyang added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module module: numpy Related to numpy support, and also numpy compatibility of our operators module: python frontend For issues relating to PyTorch's Python frontend labels Oct 6, 2023
@ezyang
Contributor

ezyang commented Oct 6, 2023

There is some precedent for this, but we diverge fairly significantly from NumPy's behavior on operators that do support python scalars. For example:

>>> import torch
>>> torch.add(2, 3)
tensor(5)
>>> import numpy
>>> numpy.add(2, 3)
5
>>> type(numpy.add(2, 3))
<class 'numpy.int64'>
>>> numpy.sqrt(4)
2.0
>>> type(numpy.sqrt(4))
<class 'numpy.float64'>

I don't remember if the numpy devs consider numpy scalars to be a mistake. It'd be a pretty big modeling change for PyTorch though, so it doesn't feel like we're likely to do it.

@rgommers
Collaborator

rgommers commented Oct 6, 2023

I don't remember if the numpy devs consider numpy scalars to be a mistake.

Yes, we do consider them a mistake. 0-D arrays/tensors have turned out to work just fine in PyTorch, CuPy, JAX & co., and avoid a ton of complexity that comes with NumPy's array scalars (which are instances of dtypes).

@rgommers
Collaborator

rgommers commented Oct 6, 2023

Another idea is to also have torch.pi and other constants (like NumPy), in order to avoid importing numpy or math in order to get these constants.

This is indeed a good idea, and very easy to do. +1 for adding pi, nan, inf, and e as the four relevant numerical constants. And also add newaxis as an alias of None.
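Assuming a PyTorch version that includes these additions (per the status update later in this thread: the constants landed in 2.3.0, newaxis in 2.4.0), usage would look like:

import torch

# No `import math` or `import numpy` needed for the constants:
angles = torch.linspace(0, 2 * torch.pi, steps=5)
print(torch.e, torch.inf, torch.nan)

# newaxis as an alias of None, for NumPy/Array API style indexing:
x = torch.arange(6)
col = x[:, torch.newaxis]  # same as x[:, None]; shape (6, 1)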

@vadimkantorov
Contributor Author

For PyTorch, my main motivation was more polymorphic code, to avoid multiple distinct but very similar code paths for dealing with Python scalars and tensors (examples are found mainly in the optimizer code)

@ezyang
Contributor

ezyang commented Oct 7, 2023

Well, if you're OK with the operators promoting their results into 0d tensors, I don't think it's the worst to extend what ops take python scalars (in fact, we are quite a bit better about this in PrimTorch refs), though it can be pretty annoying to teach our codegen to do it.

@vadimkantorov
Contributor Author

I think that, in the optimizer code specifically, these multiple code paths exist because processing Python scalars is currently faster than processing CPU 0d tensors.

So promotion to 0d tensors should be fine, but ideally it would allow removing these duplicate code paths from the optimizers...
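For illustration, a simplified sketch of the kind of dual code path meant here (loosely modeled on a bias-correction term; not the actual torch.optim implementation):

import torch

def bias_correction(beta, step):
    # `step` may be a Python int (fast eager path) or a 0d tensor
    # (e.g. a capturable path); the math is identical, only the
    # dispatch cost differs.
    if torch.is_tensor(step):
        return 1 - torch.pow(beta, step)  # goes through kernel dispatch
    return 1 - beta ** step               # plain Python arithmetic

print(bias_correction(0.9, 10))                # Python float
print(bias_correction(0.9, torch.tensor(10)))  # 0d tensor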

@ezyang
Contributor

ezyang commented Oct 9, 2023

yeah, makes sense. Pretty hard to fix eager mode 0d perf :(

@vadimkantorov
Contributor Author

But it's also a bit surprising that this overhead is noticeable, given that much larger tensors are usually being manipulated in the optimizers, so these scalar norm calculations should only be a small fraction of the work.

@vadimkantorov
Contributor Author

vadimkantorov commented Oct 11, 2023

What do you think of the proposal in this issue, @janeyx99 ?

@vadimkantorov
Contributor Author

yeah, makes sense. Pretty hard to fix eager mode 0d perf :(

So maybe, as a hack, in eager mode torch.sqrt and some other functions could produce Python scalars when given Python scalars (I see that this would diverge from NumPy). And maybe at some point in the future they could be changed to produce 0d tensors, if that is found not to hurt perf too badly.

@vadimkantorov vadimkantorov changed the title [discussion] Have PyTorch functions support python scalars (like NumPy) [discussion] Have PyTorch functions support python scalars (like NumPy) + introduce convenience constants like torch.pi and torch.e Oct 12, 2023
@ezyang
Contributor

ezyang commented Oct 13, 2023

I'd oppose having torch.sqrt return a Python scalar, because torch.add(2, 3) doesn't return a Python scalar.

@janeyx99
Contributor

@vadimkantorov The reason Python scalar math is so much faster is that we don't need to dispatch into kernels. If we try to move everything to be torch ops, I feel like the perf hit still remains, and it would still be valuable to keep hyperparams as strict Python scalars. I'm not sure I see how the proposal would be better even if we had torch ops taking and returning scalars, because the kernel dispatch would still exist, right?

In fact, we've been migrating to the (not quite) "opposite" approach as more people adopt torch.compile() or play with CUDA graphs -- we want to accept more ScalarTensors instead of just Scalars for many of our foreach ops to enable dynamism.

I agree the branching in our code adds confusion, but the perf tradeoff is too significant.
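A rough micro-benchmark sketch of the overhead described above (per-op dispatch on a 0d CPU tensor vs. plain Python float arithmetic); exact numbers vary by machine and build, and this is only meant to illustrate that the gap exists:

import timeit
import torch

lr_float = 0.001
lr_tensor = torch.tensor(0.001)

# Same arithmetic, once on a Python float, once on a 0d tensor.
t_scalar = timeit.timeit(lambda: lr_float * 0.9 + 1e-8, number=100_000)
t_tensor = timeit.timeit(lambda: lr_tensor * 0.9 + 1e-8, number=100_000)
print(f"python float: {t_scalar:.4f}s  0d tensor: {t_tensor:.4f}s")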

@vadimkantorov
Contributor Author

vadimkantorov commented Nov 11, 2023

Regarding the scalars, it just feels strange that this dispatch overhead is noticeable in the optimizers / overall perf, given that a lot of number-crunching happens in the optimizers :)

Regarding constants, maybe it's worth duplicating something like the scipy.constants namespace in PyTorch (+ the constants from Python's math module): https://docs.scipy.org/doc/scipy/reference/constants.html

Again, I don't know whether PyTorch should include these as plain Python numbers or as some collection of scalar PyTorch tensors lazily allocated/cached on all devices :)

@vadimkantorov vadimkantorov changed the title [discussion] Have PyTorch functions support python scalars (like NumPy) + introduce convenience constants like torch.pi and torch.e [discussion] Have PyTorch functions support python scalars (like NumPy) + introduce convenience constants like torch.pi and torch.e and maybe analogue of scipy.constants namespace Nov 17, 2023
pytorchmergebot pushed a commit that referenced this issue Apr 27, 2024
…5026)

Fixes #65307

For consistency with Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and NumPy  (https://numpy.org/devdocs/reference/constants.html), I added `torch.newaxis = None`.

Note that the consistency is directly mentioned also in the `__init__.py`, right above the added export.

The `torch.newaxis` is also mentioned in #110636.

Pull Request resolved: #125026
Approved by: https://github.com/lezcano
petrex pushed a commit to petrex/pytorch that referenced this issue May 3, 2024
@rgommers
Collaborator

A status update:

This is indeed a good idea, and very easy to do. +1 for adding pi, nan, inf, and e as the four relevant numerical constants. And also add newaxis as an alias of None.

This was all done. The first four are available in 2.3.0, and newaxis was added in gh-125026 three weeks ago so should be in 2.4.0.

Well, if you're OK with the operators promoting their results into 0d tensors, I don't think it's the worst to extend what ops take python scalars

+1, it would yield easier-to-read (and easier-to-write) code. This hasn't moved much recently. With 2.3.0:

>>> t = torch.tensor([2, 3])
>>> torch.add(t, 1)
tensor([3, 4])
>>> torch.atan2(t, 1)
...
TypeError: atan2(): argument 'other' (position 2) must be Tensor, not int

>>> torch.maximum(t, 1)
...
TypeError: maximum(): argument 'other' (position 2) must be Tensor, not int

The ergonomic benefits are useful - to do this in user code in a generic way yields things like:

torch.maximum(t, torch.tensor(1, dtype=t.dtype, device=t.device))

which is quite verbose.
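A hypothetical user-side helper that hides that verbosity (wrap_scalar is an illustrative name, not a torch API), doing roughly what torch.maximum(t, 1) could do internally:

import torch

def wrap_scalar(t, value):
    # Wrap a Python scalar as a 0d tensor matching t's dtype and device.
    return torch.tensor(value, dtype=t.dtype, device=t.device)

t = torch.tensor([2, 3])
print(torch.maximum(t, wrap_scalar(t, 1)))  # tensor([2, 3])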

@vadimkantorov
Contributor Author

@rgommers What do you think about adding something like scipy.constants to torch? Was it actually useful for / used by SciPy users?

@rgommers
Collaborator

scipy.constants is a pretty niche module, so I wouldn't consider it for PyTorch. It's quite handy if you do, say, fundamental physics research, but that's a small group compared to the whole PyTorch audience - and there's no reason those users can't just use SciPy to get what they need.
