
Conversation

facebook-github-bot (Contributor) commented on May 26, 2022


✅ No Failures (0 Pending)

As of commit e81a4d2 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


lezcano added a commit that referenced this pull request May 26, 2022
ghstack-source-id: cec5e59
Pull Request resolved: #78350
lezcano added the module: linear algebra, release notes: composability, topic: not user facing, and module: primTorch labels on May 26, 2022
lezcano changed the title from "Implement references for the linalg.norm* functions" to "Reference for linalg.vector_norm" on May 26, 2022
if p == 1.0:
return x
elif p == 2.0:
return prims.mul(x, x)
Collaborator (Author):

Is this an optimisation we want, or is pow(x, 2.0) better? Same for sqrt below, really.

Collaborator:

I prefer the explicit multiplication because otherwise I think we're just making more work for ourselves to convert the power to a multiplication later

Collaborator (Author):

My question stems from the fact that these operations are memory bound. If you do x * x, you are reading two inputs, while if you do prims.pow(x, 2) or x.square(), you are just reading the input once. Theoretically, one could do some aliasing checks in TI (or your executor of choice) and figure out that, if the inputs are the same, you do not need to read them twice. But of course, this is optional.
What should we do about this @ngimel ?
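
As an illustration of the memory-bound point above, here is a quick micro-benchmark sketch (not part of the PR; timings depend entirely on hardware and backend):

import torch
from torch.utils.benchmark import Timer

# Large 1-D tensor so the elementwise op is memory bound rather than compute bound.
x = torch.rand(10_000_000)

for stmt in ("x * x", "x.square()", "torch.pow(x, 2)"):
    t = Timer(stmt=stmt, globals={"x": x, "torch": torch})
    print(stmt, "->", t.timeit(100))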

Collaborator (Author) @lezcano commented Jun 1, 2022:

Actually, prims.pow has all the information to do this optimisation internally. In my opinion, we should have just pow as a prim and have square, sqrt, cbrt (I didn't even know that was a thing) be composite operations.
This way, executors would just need to make sure that their pow operation is as fast as possible, perhaps dispatching to different functions if necessary. Otherwise, this job falls on the user and, if you want anything fast, you have to write a ridiculous function of the form of fast_pow as I did. What's more, you need to know that all these prims exist, and you end up with questions such as "is square() faster than x * x?", "oh, actually, there's cbrt, which happens to be faster than pow(x, 1/3)", etc. Note that right now cbrt is just a call to pow(x, 1/3), but who knows...

Thoughts @mruberry @ngimel ?
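
A minimal sketch of the composite scheme being floated here (hypothetical helper definitions, assuming the torch._prims namespace used elsewhere in this PR):

import torch._prims as prims

def square(x):
    return prims.pow(x, 2.0)

def sqrt(x):
    return prims.pow(x, 0.5)

def cbrt(x):
    return prims.pow(x, 1.0 / 3.0)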

Collaborator:

I can see getting rid of square, but are pow(a, 1/2) and pow(a, 1/3) valid implementations for sqrt and cbrt? What we aren't doing at this time is requiring trace executors to make numerical accuracy fixes, in part because we also want to run accurately in eager mode. If every system has to remap pow(a, 1/2) into sqrt anyway, then aren't we just hiding the fact that it's a prim?

Collaborator:

I would be against using pow because it requires unusual smarts from the backend. I don't care about square or x*x.

Collaborator (Author):

The issue then is that those smarts need to come from the user every time they need to use prims.pow, which is less than ideal, but well, leaving it as is.
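
For reference, a hedged reconstruction of the fast_pow helper these fragments belong to; the PR only shows pieces of it, so the 0.5 branch and the generic fallback below are assumptions based on the discussion above:

import torch._prims as prims

def fast_pow(x, p):
    # Special-case the exponents that have cheaper equivalents,
    # and fall back to the generic prim otherwise.
    if p == 1.0:
        return x
    elif p == 2.0:
        return prims.mul(x, x)
    elif p == 0.5:
        return prims.sqrt(x)
    else:
        return prims.pow(x, p)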

ord_pow /= 2.0
else:
x = prims.abs(x)
return to_result_dtype(fast_pow(reduce_sum(fast_pow(x, ord_pow)), 1.0 / ord))
Collaborator:

I'm not sure this captures the out= behavior for norms, which requires out to be of the same dtype as dtype when dtype is specified.

Collaborator (Author):

I would assume so, as the result is of dtype dtype, and the out wrapper would use it to check against out. Now, I'll let @mruberry confirm

Collaborator:

The out wrapper copies to any out that's of a weakly higher type (e.g., a float result can be copied to a double or cfloat out).

Collaborator:

I think it's also OK for the reference to be divergent from the torch implementation here, right? Like ideally wouldn't the torch reference perform the safe copy, too?

Collaborator:

Idk, other reduction references are not divergent.

Collaborator:

We test that out doesn't allow unsafe casting, but we don't test that it supports safe casting currently, because our support for safe casting is spotty and I didn't want to skip the entire test. We should just add a different test for it (like the out warning test).

Collaborator:

See

# Case 4: out= with correct shape and device, but a dtype

Collaborator (Author):

I understand now, thank you!

So, at the moment, vector_norm is a structured kernel, so it supports as many (or as few) features as all the functions that are implemented as structured kernels. So, if this is a problem with structured kernels in general, it should be solved at the level of structured kernels. If it is not... well, then there's nothing to solve :)

Collaborator (Author):

Ok, so after re-reading this for the 10th time I think I now understand the issue (sorry). So, the issue here would pop up in the following example:

import torch

x = torch.rand((3,), dtype=torch.bfloat16)
out = torch.rand((), dtype=torch.float64)
torch.linalg.vector_norm(x, dtype=torch.float32, out=out)

This, in the ref, would perform the inner computation in float32 and then cast the result to out. On the other hand, it would break in linalg.vector_norm, as the returned tensor is specified to be of type float32 while the given out tensor is of type float64.

Now, I think the same would happen with any function implemented as a structured kernel (perhaps just those that are not TI-based? idk), as the check on the dtype of the result is an equality check, not a less-or-equal one. So, as mentioned above, if we want this behaviour in general, it should be implemented at the structured-kernel level. On this note, I am not sure whether it's even desirable. For one, it would not work for linalg functions that need to call into external libraries (e.g. matmul), as we need dtype equality there for obvious reasons. So I'm not sure how far we want to push this dtype equality enforcement, really... unless I'm still missing something, which, at this point, wouldn't even surprise me.
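
For contrast, a hedged sketch of the case the equality check does accept, i.e. when the dtype of out matches the requested computation dtype:

import torch

x = torch.rand((3,), dtype=torch.float32)
out = torch.empty((), dtype=torch.float32)
torch.linalg.vector_norm(x, dtype=torch.float32, out=out)  # dtypes match, so this succeeds
print(out)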

Collaborator:

We don't need to fix it in core, but the references should faithfully reproduce this behavior.

lezcano added a commit that referenced this pull request May 27, 2022
ghstack-source-id: 2bc3f35
Pull Request resolved: #78350
lezcano added a commit that referenced this pull request May 27, 2022
ghstack-source-id: d8aef52
Pull Request resolved: #78350
facebook-github-bot pushed a commit that referenced this pull request Jul 12, 2022
Summary:
Pull Request resolved: #78350
Approved by: https://github.com/mruberry

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/e9a9b50f48841adba6450301e28bd43e9ac8ca87

Reviewed By: mehtanirav

Differential Revision: D37749537

Pulled By: mehtanirav

fbshipit-source-id: 1696c09216ba7ce3a5696f6d8cac241454d1d941
facebook-github-bot deleted the gh/Lezcano/81/head branch on July 13, 2022 14:16
kit1980 added the Merged label on Mar 24, 2023