Conversation

@IvanYashchuk (Collaborator) commented Jun 13, 2022

This PR adds references for:

  • torch.rsqrt
  • torch.native_layer_norm
  • torch.nn.functional.layer_norm

native_layer_norm returned a different number of output dimensions when the input was 0-sized. I fixed that.
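
For context, a small standalone check of the 0-sized case (the exact mean/rstd shapes may vary across versions; the point is that empty and non-empty inputs should agree in the number of dimensions):

import torch

x = torch.randn(0, 3, 4)   # batch dimension of size 0
weight = torch.ones(4)
bias = torch.zeros(4)

out, mean, rstd = torch.native_layer_norm(x, (4,), weight, bias, 1e-5)
print(out.shape, mean.shape, rstd.shape)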

@facebook-github-bot (Contributor) commented Jun 13, 2022

✅ No Failures (0 Pending)

As of commit f7307e1 (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.

@vadimkantorov (Contributor):

It would be nice to have a way to display reference implementation code in the documentation; it could compensate for an unclear description while the description is being improved: #51455

@ngimel requested a review from Chillee on June 13, 2022 at 15:51
@ngimel (Collaborator) commented Jun 13, 2022

cc @Chillee for native_layer_norm

def _normalize(a, norm_dims, eps):
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a_acc = prims.convert_element_type(a, computation_dtype)
    biased_var, mean = var_mean(a_acc, dim=norm_dims, unbiased=False, keepdim=True)

Collaborator:

currently nvfuser doesn't understand var_mean, so this decomp is not readily usable

Collaborator:

Same here.

Also, the reason the previous decomp used separate calls to var and mean was that it more closely matched the numerics of eager mode. @ngimel's opinion was that the discrepancy shouldn't matter, but just noting in case you need to adjust tolerances on any tests.

Collaborator (Author):

I will make mean work with the nvFuser executor in a separate PR. var will be implemented later, as it's currently a "prim" in PyTorch but a composite in nvFuser.

It already is separate calls to var and mean; take a look at the var_mean function:

def var_mean(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[int] = None,
):
    v = var(a, dim, unbiased, keepdim, correction=correction)
    m = mean(a, dim, keepdim)
    return v, m
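
For reference, a quick check using the public torch APIs (not the _refs internals) that the fused call and the separate calls agree:

import torch

x = torch.randn(3, 5)
v, m = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)

# Should match the fused result within default tolerances; the numerics
# discussion below is about accumulated rounding error on large inputs.
torch.testing.assert_close(v, torch.var(x, dim=1, unbiased=False, keepdim=True))
torch.testing.assert_close(m, torch.mean(x, dim=1, keepdim=True))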

Collaborator (Author):

PR for mean: #79444
PR for var: #79517
And together they make nvfuser understand var_mean.

@lezcano (Collaborator) commented Jun 16, 2022

I believe @peterbell10 once told me that var_mean had some speed and stability problems in PyTorch (?). I do not remember very well so I'll let him discuss this, but perhaps this is relevant here.

Collaborator:

On CPU, the mean from var_mean is less accurate than calling mean separately. mean is implemented roughly as sum(x, dim) / x.size(dim), where sum uses a low-error summation algorithm. var_mean on the other hand computes both the mean and var in a single pass, but with a naive summation that is less accurate.
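
A small illustration of that effect (results depend on the backend and build, so treat it as a sketch rather than a guaranteed discrepancy):

import torch

# A large float32 input with a constant offset makes rounding error visible.
x = torch.randn(10_000_000, dtype=torch.float32) + 100.0
ref = x.double().mean()                 # float64 reference

mean_separate = x.mean()                # uses the low-error summation path
_, mean_fused = torch.var_mean(x)       # single-pass kernel on CPU

print("separate mean error:", (mean_separate.double() - ref).abs().item())
print("fused mean error:   ", (mean_fused.double() - ref).abs().item())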

Collaborator:

var_mean here calls into the reference decomposition, which in turn computes mean and var separately (mean via sum and division).

Collaborator:

For similar reasons, can we change this call to torch.var_mean as well?

Collaborator:

Well, then, unless you decompose torch.var_mean further, you'd get an inaccurate CPU result.

Collaborator:

Then I think var_mean should probably call into torch.var and torch.mean.

@mikaylagawarecki added the triaged label on Jun 13, 2022
func = op.get_op()
for sample_input in samples:
    if requires_grad:
        if None in sample_input.args:

Collaborator:

Why is this needed?

Collaborator (Author):

Because this test cannot handle None as a positional argument to native_layer_norm. It has the signature: native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor). Both weight and bias can be None and I added sample inputs for these cases.
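
For context, this is the kind of call the new sample inputs exercise (a standalone snippet, not the OpInfo code itself):

import torch

x = torch.randn(2, 3, 4)
# weight and bias are optional; the new sample inputs pass None for them.
out, mean, rstd = torch.native_layer_norm(x, (4,), None, None, 1e-5)
print(out.shape, mean.shape, rstd.shape)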

computation_dtype = utils.get_computation_dtype(a.dtype)
a_acc = prims.convert_element_type(a, computation_dtype)
biased_var, mean = var_mean(a_acc, dim=norm_dims, unbiased=False, keepdim=True)
rstd = rsqrt(biased_var + eps)

Collaborator:

Can we just call torch.rsqrt here?

Collaborator (Author):

I prefer to use the torch._refs function here.

Collaborator:

My understanding is that the current thinking is to prefer calling torch.foo vs. _refs.foo if convenient, as it's fairly easy to map from torch.foo => _refs.foo but not the other way around.

See @mruberry's comment here: #78689 (comment)

Collaborator (Author):

It's very easy to map _refs.rsqrt to torch.rsqrt because _refs.rsqrt is a wrapper around torch.rsqrt for regular and meta tensors, and test_decomp.py passing indicates that it's fine as it is.

I changed this to torch.rsqrt.

Collaborator:

From the perspective of the "correctness" tests it's fine. But, if I understand correctly, _refs.rsqrt => prims.rsqrt, and then you're saying that prims.rsqrt simply calls torch.rsqrt under the hood.

But when we trace things out, using _refs.rsqrt doesn't have an interception point to allow us to decompose native_layer_norm but not decompose _refs.rsqrt.

TBH, for rsqrt this is a somewhat strange discussion haha since _refs.rsqrt == prims.rsqrt == aten::rsqrt. But it matters more for var_mean.

For example, say that I'm a compiler that wants native_layer_norm decomposed, but I have special handling for var_mean. If we call _refs.var_mean, then our trace has no option to keep var_mean; we'll get var followed by mean.

If we use torch.var_mean though, then we'll have the choice to either further decompose into var and mean or preserve var_mean in our trace.
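
A rough sketch of what such an interception point could look like at the torch.* level, using torch.overrides.TorchFunctionMode purely for illustration (a real compiler backend would hook into its own tracing layer; the class name here is made up):

import torch
from torch.overrides import TorchFunctionMode

class PreserveVarMean(TorchFunctionMode):
    # Illustrative only: observe fused torch.var_mean calls while everything
    # else runs (and could be decomposed) as usual.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.var_mean:
            # A backend with special handling would emit a fused node here
            # instead of letting the call decompose into var + mean.
            print("intercepted torch.var_mean")
        return func(*args, **kwargs)

with PreserveVarMean():
    x = torch.randn(4, 8)
    torch.var_mean(x, dim=-1)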

Collaborator:

Because of the way the current context mapping is implemented it's a one-way torch. -> refs. transform, so we should get in the habit of preferring torch. calls when possible. In some cases the torch operation doesn't have as much functionality as the ref, and in cases where that extra functionality is used the ref should be called explicitly, instead.

def _normalize(a, norm_dims, eps):
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a_acc = prims.convert_element_type(a, computation_dtype)

Collaborator:

Let's make this conversion conditional.
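
Something along these lines, sketched with plain torch calls rather than the prims/utils helpers used in the file (the helper name is made up):

import torch

def _maybe_upcast(a: torch.Tensor) -> torch.Tensor:
    # Only convert when the accumulation dtype differs from the input dtype
    # (e.g. float16/bfloat16 accumulate in float32); in the reference this
    # avoids emitting a no-op convert_element_type into the trace.
    computation_dtype = (
        torch.float32 if a.dtype in (torch.float16, torch.bfloat16) else a.dtype
    )
    return a.to(computation_dtype) if a.dtype != computation_dtype else a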

    return prims.slice_in_dim(a, start, start + length, axis=dim)


def _normalize(a, norm_dims, eps):

Collaborator:

Add a comment for this function -- what it's meant to be used for and how it's intended to be used -- and we should probably add type annotations to it.
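
One possible shape for that, sketched with public torch ops instead of the prims/utils calls in the actual file (assumed, not the PR's final code):

from typing import List, Tuple
import torch

def _normalize(
    a: torch.Tensor, norm_dims: List[int], eps: float
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Computes the mean and 1/sqrt(var + eps) over norm_dims in an upcast
    accumulation dtype and returns (normalized, mean, rstd); intended as the
    shared core of the layer_norm references."""
    computation_dtype = (
        torch.float32 if a.dtype in (torch.float16, torch.bfloat16) else a.dtype
    )
    a_acc = a.to(computation_dtype)
    biased_var, mean = torch.var_mean(a_acc, dim=norm_dims, unbiased=False, keepdim=True)
    rstd = torch.rsqrt(biased_var + eps)
    out = (a_acc - mean) * rstd
    return out, mean, rstd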

    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]:

Collaborator:

With the latest updates I think it's just all Tensor now

@IvanYashchuk (Collaborator, Author):

All CI is passing now. Since this PR adds a new OpInfo, it revealed a few problems and some tests had to be skipped; I submitted issue #79705.

Except for the torch._refs.var_mean vs torch.var_mean question (https://github.com/pytorch/pytorch/pull/79413/files#r899336672), I addressed all the feedback.

@Chillee, @mruberry could you please take another look?

@mruberry added the ciflow/trunk label on Jun 16, 2022
@mruberry self-requested a review on June 16, 2022 at 23:23

@mruberry (Collaborator) left a comment:

Great work, @IvanYashchuk! Nice attention to the bug fix and the issue.

Let's ship this as soon as the tests are happy.

@Chillee (Collaborator) left a comment:

LGTM, besides the minor nit about torch.var and torch.mean.

I don't want to block this PR on that though - this PR has been through a bunch of rounds already :)

If it's an issue for using this decomposition in the future we can just modify it to torch.var/torch.mean then.

@IvanYashchuk (Collaborator, Author):

@pytorchbot merge -g

@pytorchmergebot (Collaborator):

@pytorchbot successfully started a merge job. Check the current status here

@github-actions (Contributor):

Hey @IvanYashchuk.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@mruberry added the topic: not user facing label on Jun 17, 2022
facebook-github-bot pushed a commit that referenced this pull request Jun 20, 2022
…79413)

Summary:
This PR adds references for:
- `torch.rsqrt`
- `torch.native_layer_norm`
-  `torch.nn.functional.layer_norm`

`native_layer_norm` had a different number of dimensions if the input was 0-sized. I fixed that.

Pull Request resolved: #79413
Approved by: https://github.com/mruberry, https://github.com/Chillee

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/bc1fef96aff4aeffe7a7b39ef8b4ff467860df28

Reviewed By: malfet

Differential Revision: D37254233

fbshipit-source-id: d712aabb1c26fccb0b19f703f7c75df46f503396

Labels

ciflow/trunk, cla signed, Merged, module: primTorch, open source, topic: not user facing, triaged
