Reference implementations for rsqrt and native_layer_norm #79413
Conversation
Dr. CI: ✅ No failures (0 pending) as of commit f7307e1 (more details on the Dr. CI page). 💚 Looks good so far! There are no failures yet. This comment was automatically generated by Dr. CI; please report bugs/suggestions to the (internal) Dr. CI Users group.
It would be nice to have a way to display reference-implementation code in the documentation; it could compensate for an unclear description (while the description is being improved): #51455
cc @Chillee for native_layer_norm |
def _normalize(a, norm_dims, eps):
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a_acc = prims.convert_element_type(a, computation_dtype)
    biased_var, mean = var_mean(a_acc, dim=norm_dims, unbiased=False, keepdim=True)
Currently nvFuser doesn't understand var_mean, so this decomposition is not readily usable.
Same here.
Also, the reason the previous decomp used separate calls to var and mean was that it more closely matched the numerics of eager mode. @ngimel's opinion was that the discrepancy shouldn't matter, but just noting in case you need to adjust tolerances on any tests.
I will make mean work with the nvFuser executor in a separate PR. var will be implemented later, as it's currently a "prim" in PyTorch but a composite in nvFuser.
It is separate calls to var and mean under the hood; take a look at the var_mean function:
def var_mean(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[int] = None,
):
    v = var(a, dim, unbiased, keepdim, correction=correction)
    m = mean(a, dim, keepdim)
    return v, m
I believe @peterbell10 once told me that var_mean had some speed and stability problems in PyTorch (?). I don't remember the details, so I'll let him discuss this, but perhaps it's relevant here.
On CPU, the mean from var_mean is less accurate than calling mean separately. mean is implemented roughly as sum(x, dim) / x.size(dim), where sum uses a low-error summation algorithm. var_mean, on the other hand, computes both the mean and the variance in a single pass, but with a naive summation that is less accurate.
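To make the gap concrete, here is an illustrative sketch (not code from this PR; the exact error magnitudes depend on the input and the build):

```python
import torch

# Compare the single-pass mean returned by var_mean against the separate
# mean, using a float64 computation as the reference.
x = torch.rand(10_000_000, dtype=torch.float32)

ref = x.double().mean()
_, mean_fused = torch.var_mean(x)   # var_mean returns (var, mean)
mean_separate = x.mean()            # uses the low-error sum kernel on CPU

print("var_mean mean error:", (mean_fused.double() - ref).abs().item())
print("separate mean error:", (mean_separate.double() - ref).abs().item())
```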
var_mean here calls into the reference decomposition, which in turn computes mean and var separately (mean via sum and division).
For similar reasons, can we change this call to torch.var_mean as well?
Well, then, unless you decompose torch.var_mean further, you'd get an inaccurate CPU result.
Then I think var_mean should probably call into torch.var and torch.mean.
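A hedged sketch of that suggestion (simplified signature, not the actual reference code): delegate to the torch-level ops so the more accurate eager kernels are used and backends can still intercept the calls.

```python
import torch

def var_mean_via_torch(a, dim, unbiased=False, keepdim=False):
    # Call torch.var / torch.mean rather than the refs, so mean reuses the
    # low-error sum kernel and compilers can still see both ops.
    v = torch.var(a, dim, unbiased=unbiased, keepdim=keepdim)
    m = torch.mean(a, dim, keepdim=keepdim)
    return v, m
```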
func = op.get_op()
for sample_input in samples:
    if requires_grad:
        if None in sample_input.args:
Why is this needed?
Because this test cannot handle None as a positional argument to native_layer_norm. It has the signature: native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor). Both weight and bias can be None, and I added sample inputs for these cases.
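For illustration, a minimal call with the optional tensors omitted (a hedged sketch; shapes and eps are arbitrary):

```python
import torch

x = torch.randn(2, 3, 4)
# weight and bias are optional; passing None exercises the new sample inputs.
out, mean, rstd = torch.native_layer_norm(x, [4], None, None, 1e-5)
print(out.shape, mean.shape, rstd.shape)
```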
torch/_refs/__init__.py
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a_acc = prims.convert_element_type(a, computation_dtype)
    biased_var, mean = var_mean(a_acc, dim=norm_dims, unbiased=False, keepdim=True)
    rstd = rsqrt(biased_var + eps)
Can we just call torch.rsqrt here?
I prefer to use the torch._refs function here.
My understanding is that the current thinking is to prefer calling torch.foo vs. _refs.foo if convenient, as it's fairly easy to map from torch.foo => _refs.foo but not the other way around.
See @mruberry's comment here: #78689 (comment)
It's very easy to map _refs.rsqrt to torch.rsqrt because _refs.rsqrt is a wrapper over torch.rsqrt for regular and meta tensors, and test_decomp.py passing indicates that it's fine as it is.
I changed this to torch.rsqrt.
From the perspective of the "correctness" tests it's fine. But, if I understand correctly, _refs.rsqrt => prims.rsqrt, and then you're saying that prims.rsqrt simply calls torch.rsqrt under the hood.
But when we trace things out, using _refs.rsqrt leaves no interception point that would let us decompose native_layer_norm while leaving rsqrt itself un-decomposed.
TBH, for rsqrt this is a somewhat strange discussion haha, since _refs.rsqrt == prims.rsqrt == aten::rsqrt. But it matters more for var_mean.
For example, say that I'm a compiler that wants native_layer_norm decomposed, but I have special handling for var_mean. If we call _refs.var_mean, then our trace has no var_mean node to intercept; we'll just get var followed by mean.
If we use torch.var_mean though, then we'll have the choice to either further decompose it into var and mean or preserve var_mean in our trace.
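A rough sketch of the distinction (simplified names, not the actual PR code): calling the torch-level op leaves a single var_mean call site that a backend can either preserve or decompose further.

```python
import torch

def _normalize_sketch(a, norm_dims, eps):
    # torch.var_mean shows up as one interceptable call; a compiler with
    # special handling for var_mean can keep it fused in its trace, or let
    # it decompose into var followed by mean.
    biased_var, mean = torch.var_mean(a, dim=norm_dims, unbiased=False, keepdim=True)
    rstd = torch.rsqrt(biased_var + eps)
    return (a - mean) * rstd, mean, rstd
```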
Because of the way the current context mapping is implemented, it's a one-way torch.* -> refs.* transform, so we should get into the habit of preferring torch.* calls when possible. In some cases the torch operation doesn't have as much functionality as the ref; where that extra functionality is needed, the ref should be called explicitly instead.
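As a rough mental model only (hypothetical names; this is not the actual implementation): the tracing context can redirect torch-level calls to their references, but a direct _refs call bypasses that hook.

```python
import torch
import torch._refs as refs

# Hypothetical lookup table; the real mapping lives inside the tracing context.
TORCH_TO_REFS = {
    torch.rsqrt: refs.rsqrt,
    torch.var_mean: refs.var_mean,
}

def lookup_ref(torch_fn):
    # torch.foo -> _refs.foo is easy to look up; going the other way is not.
    return TORCH_TO_REFS.get(torch_fn, torch_fn)
```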
torch/_refs/__init__.py
def _normalize(a, norm_dims, eps):
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a_acc = prims.convert_element_type(a, computation_dtype)
Let's make this conversion conditional.
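One way that could look (a sketch against the quoted lines, reusing the utils and prims helpers from the surrounding code):

```python
computation_dtype = utils.get_computation_dtype(a.dtype)
# Skip the element-type conversion when the input is already in the
# computation dtype.
a_acc = a if a.dtype == computation_dtype else prims.convert_element_type(a, computation_dtype)
```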
torch/_refs/__init__.py
    return prims.slice_in_dim(a, start, start + length, axis=dim)


def _normalize(a, norm_dims, eps):
Add a comment for this function -- what it's meant to be used for and how it's intended to be used -- and we should probably add type annotations to it.
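One possible shape for the requested comment and annotations (a sketch, assuming the module's existing imports and helpers; not necessarily the exact code that landed):

```python
def _normalize(
    a: Tensor, norm_dims: List[int], eps: float
) -> Tuple[Tensor, Tensor, Tensor]:
    """Computes the mean and 1/std of a tensor along norm_dims.

    Helper for the layer_norm / native_layer_norm references: returns the
    normalized tensor, the mean, and the reciprocal of the standard
    deviation, all in the upcast computation dtype.
    """
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a_acc = prims.convert_element_type(a, computation_dtype)
    biased_var, mean = var_mean(a_acc, dim=norm_dims, unbiased=False, keepdim=True)
    rstd = rsqrt(biased_var + eps)
    out = (a_acc - mean) * rstd
    return out, mean, rstd
```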
torch/_refs/__init__.py
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]:
With the latest updates I think it's just all Tensor now
All CI is passing now. Since this PR adds a new OpInfo, it revealed a few problems and some tests had to be skipped. I submitted an issue: #79705. Except for
Great work, @IvanYashchuk! Nice attention to the bug fix and the issue.
Let's ship this as soon as the tests are happy.
LGTM, besides the minor nit about torch.var and torch.mean.
I don't want to block this PR on that though - this PR has been through a bunch of rounds already :)
If it's an issue for using this decomposition in the future we can just modify it to torch.var/torch.mean then.
@pytorchbot merge -g
@pytorchbot successfully started a merge job. Check the current status here.
Hey @IvanYashchuk. |
…79413)

Summary:
This PR adds references for:
- `torch.rsqrt`
- `torch.native_layer_norm`
- `torch.nn.functional.layer_norm`

`native_layer_norm` had a different number of dimensions if the input was 0-sized. I fixed that.

Pull Request resolved: #79413
Approved by: https://github.com/mruberry, https://github.com/Chillee

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/bc1fef96aff4aeffe7a7b39ef8b4ff467860df28

Reviewed By: malfet
Differential Revision: D37254233
fbshipit-source-id: d712aabb1c26fccb0b19f703f7c75df46f503396
This PR adds references for:
- torch.rsqrt
- torch.native_layer_norm
- torch.nn.functional.layer_norm

native_layer_norm had a different number of dimensions if the input was 0-sized. I fixed that.
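As a quick illustrative check of how the three relate (hedged; shapes and eps are arbitrary), layer_norm's output should match the first output of native_layer_norm:

```python
import torch

x = torch.randn(2, 3, 4)
w = torch.ones(4)
b = torch.zeros(4)

out_f = torch.nn.functional.layer_norm(x, [4], w, b, eps=1e-5)
out_n, mean, rstd = torch.native_layer_norm(x, [4], w, b, 1e-5)
print(torch.allclose(out_f, out_n))  # expected: True
```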