Reference for linalg.vector_norm #78350
Conversation
Dr. CI: ✅ No failures (0 pending) as of commit e81a4d2.
torch/_refs/linalg/__init__.py (outdated):

    if p == 1.0:
        return x
    elif p == 2.0:
        return prims.mul(x, x)
Is this an optimisation we want, or is pow(x, 2.0) better? Same for sqrt below, really.
I prefer the explicit multiplication, because otherwise I think we're just making more work for ourselves to convert the power to a multiplication later.
My question stems from the fact that these operations are memory bound. If you do x * x, you are reading two inputs, while if you do prims.pow(x, 2) or x.square(), you are just reading the input once. Theoretically, one could do some aliasing checks in TI (or your executor of choice) and figure out that, if the inputs are the same, you do not need to read them twice. But of course, this is optional.
What should we do about this, @ngimel?
Actually, prims.pow has all the information to do these optimisations inside. In my opinion, we should have just pow as a prim and have square, sqrt, cbrt (didn't even know that this was a thing) be composite operations.
This way, executors can make sure that their pow operation is as fast as possible, perhaps dispatching to different functions if necessary. Otherwise, this job falls on the user and, if you want anything that's fast, you have to write a ridiculous function of the form of fast_pow as I did. Even more, you need to know that all these prims are a thing, and you end up with questions such as "is square() faster than x * x?", "oh, actually, there's cbrt, which happens to be faster than pow(x, 1/3)", etc. Note that right now cbrt is just a call to pow(x, 1/3), but who knows...
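For context, a fast_pow-style helper of the kind described above might look roughly like this. This is a minimal sketch, not the actual helper from this PR: the 0.5 branch and the exact prims names (mul, sqrt, pow in torch._prims) are assumptions based on the discussion.

    import torch._prims as prims

    def fast_pow(x, p):
        # Dispatch common exponents to cheaper primitives instead of
        # always calling the generic pow (sketch only; the branches for
        # p == 1.0 and p == 2.0 mirror the snippet quoted above).
        if p == 1.0:
            return x
        elif p == 2.0:
            return prims.mul(x, x)
        elif p == 0.5:
            return prims.sqrt(x)  # assumed specialization, see discussion
        else:
            return prims.pow(x, p)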
I can see getting rid of square, but are pow(a, 1/2) and pow(a, 1/3) valid implementations for sqrt and cbrt? What we aren't doing at this time is requiring trace executors to make numerical accuracy fixes, in part because we also want to run accurately in eager mode. If every system has to remap pow(a, 1/2) into sqrt anyway, then aren't we just hiding the fact that it's a prim?
I would be against using pow, because it requires unusual smarts from the backend. I don't care about square or x*x.
The issue then is that those smarts need to come from the user every time they need to use prims.pow, which is less than ideal, but well, leaving it as is.
torch/_refs/linalg/__init__.py (outdated):

        ord_pow /= 2.0
    else:
        x = prims.abs(x)
    return to_result_dtype(fast_pow(reduce_sum(fast_pow(x, ord_pow)), 1.0 / ord))
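For reference, the returned expression computes the vector norm by its textbook definition, with the (not shown) even-ord branch presumably squaring x and halving ord_pow so that the abs can be skipped:

    \lVert x \rVert_{\mathrm{ord}} = \Big( \sum_i \lvert x_i \rvert^{\mathrm{ord}} \Big)^{1/\mathrm{ord}}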
I'm not sure this captures out= behavior for norms, which requires out to be of the same type as dtype if specified.
I would assume so, as the result is of dtype dtype, and the out wrapper would use it to check against out. Now, I'll let @mruberry confirm.
The out wrapper copies to any out that's of a weakly higher type (e.g. a float result can be copied to a double or cfloat out).
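For illustration, the safe-casting rule described above could be checked roughly like this. This is a minimal sketch using torch.can_cast; the function name check_safe_out_cast is hypothetical and not the wrapper actually used in torch._refs.

    import torch

    def check_safe_out_cast(result: torch.Tensor, out: torch.Tensor) -> None:
        # Allow copying into an out tensor of a weakly higher dtype,
        # e.g. a float32 result into a float64 or complex64 out.
        if not torch.can_cast(result.dtype, out.dtype):
            raise RuntimeError(
                f"result dtype {result.dtype} cannot be safely cast to out dtype {out.dtype}"
            )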
I think it's also OK for the reference to be divergent from the torch implementation here, right? Like ideally wouldn't the torch reference perform the safe copy, too?
Idk, other reduction references are not divergent.
We test that out doesn't allow unsafe casting, but we don't test that it supports safe casting currently, because our support for safe casting is spotty and I didn't want to skip the entire test. We should just add a different test for it (like the out warning test).
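As a rough illustration of the unsafe-cast rejection being described, something like the following would be expected to fail (a hypothetical snippet, not the actual OpInfo test):

    import torch

    x = torch.rand((3,), dtype=torch.float32)
    out = torch.empty((), dtype=torch.int64)
    try:
        # float32 result cannot be safely cast to an int64 out tensor,
        # so this should be rejected.
        torch.linalg.vector_norm(x, out=out)
    except RuntimeError as e:
        print("unsafe cast rejected:", e)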
See line 877 in d990277:

    # Case 4: out= with correct shape and device, but a dtype ...
I understand now, thank you!
So, at the moment, vector_norm is a structured kernel, so it supports as many (or as few) features as all the functions that are implemented as structured kernels. So, if this is a problem with structured kernels in general, it should be solved at the level of structured kernels. If it is not... well, then there's nothing to solve :)
Ok, so after re-reading this for the 10th time I think I now understand the issue (sorry). So, the issue here would pop up in the following example:

    x = torch.rand((3,), dtype=torch.bfloat16)
    out = torch.rand((), dtype=torch.float64)
    torch.linalg.vector_norm(x, dtype=torch.float32, out=out)

This, in the ref, would perform the inner computation in float32, and then cast the result to out. On the other hand, it would break in linalg.vector_norm, as the returned tensor is specified to be of type float32, and the given out tensor is of type float64. Now, I think that the same would happen with any function implemented as a structured kernel (perhaps just those not TI-based? idk), as the check on the dtype of the result is an equality check, not a less-than-or-equal check. So, as mentioned above, if we want this behaviour in general, it should be implemented at the structured kernel level. Now, on this note, I am not sure whether it's desirable. For one, it would not work for linalg functions that require calling external libraries (e.g. matmul), as we need dtype equality there for obvious reasons. So I'm not sure how far we want to push this dtype equality enforcement really... unless I'm still missing something, which, at this point, wouldn't even surprise me.
We don't need to fix it in the core, but references should be faithfully reproducing this behavior.
Summary: Pull Request resolved: #78350
Approved by: https://github.com/mruberry
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/e9a9b50f48841adba6450301e28bd43e9ac8ca87
Reviewed By: mehtanirav
Differential Revision: D37749537
Pulled By: mehtanirav
fbshipit-source-id: 1696c09216ba7ce3a5696f6d8cac241454d1d941
Stack from ghstack:
cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano @ezyang @ngimel @peterbell10