Reference implementations for rsqrt and native_layer_norm #79413
In the first changed file, the test now skips samples whose args contain None:

@@ -450,6 +450,9 @@ def check_decomposed(aten_name):
             func = op.get_op()
             for sample_input in samples:
                 if requires_grad:
+                    if None in sample_input.args:
+                        continue
+
                     fn, primals = normalize_op_input_output(func, sample_input)
                     primals = tree_map(
                         lambda x: x if isinstance(x, torch.Tensor) else x, primals

Review thread on the added skip:
- "Why is this needed?"
- "Because this test cannot handle `None` in `sample_input.args`."
In the second changed file, native_layer_norm and rsqrt are added to the list of reference ops:

@@ -192,11 +192,13 @@
     "hsplit",
     "hstack",
     "narrow",
+    "native_layer_norm",
     "permute",
     "ravel",
     "reshape",
     "roll",
     "rot90",
+    "rsqrt",
     "stack",
     "swap_axes",  # alias for transpose
     "squeeze",
The rsqrt reference is an elementwise unary reference that forwards to the corresponding prim:

@@ -626,6 +628,11 @@ def round(a):
     return prims.round(a)


+@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
+def rsqrt(a):
+    return prims.rsqrt(a)
+
+
 @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
 def sigmoid(a: TensorLikeType) -> TensorLikeType:
     return true_divide(1, add(1, exp(neg(a))))
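Since the refs are meant to mirror the eager ops, a quick sanity check is to compare the new reference against torch.rsqrt directly. A minimal sketch, assuming a build that includes this change and that torch._refs is importable:

```python
import torch
import torch._refs as refs

# Floating-point input: the reference should agree with the eager kernel.
x = torch.rand(4, 4) + 0.1  # keep values away from zero
torch.testing.assert_close(refs.rsqrt(x), torch.rsqrt(x))

# Integer input: INT_TO_FLOAT promotion makes the result floating point.
ints = torch.arange(1, 5)
print(refs.rsqrt(ints).dtype)  # expected: the default float dtype
```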
The var, mean, and var_mean signatures are changed to use DimsType for the dim argument:

@@ -1426,7 +1433,7 @@ def _set_correction(
 @out_wrapper
 def var(
     a: TensorLikeType,
-    dim: Union[Optional[int], Optional[List[int]]] = None,
+    dim: Optional[DimsType] = None,
     unbiased: Optional[bool] = None,
     keepdim: bool = False,
     *,
@@ -1484,7 +1491,7 @@ def std(

 def mean(
     a: TensorLikeType,
-    dim: Union[Optional[int], Optional[List[int]]] = None,
+    dim: Optional[DimsType] = None,
     keepdim: bool = False,
     *,
     dtype=None,
@@ -1539,7 +1546,7 @@ def std_mean(

 def var_mean(
     a: TensorLikeType,
-    dim: Union[Optional[int], Optional[List[int]]] = None,
+    dim: Optional[DimsType] = None,
     unbiased: Optional[bool] = None,
     keepdim: bool = False,
     *,
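With the annotation change, dim accepts either a single dimension or a sequence of dimensions, which is what DimsType covers (assuming the usual definition in the prims utilities). A small sketch of both call styles against torch._refs:

```python
import torch
import torch._refs as refs

a = torch.rand(2, 3, 4)
v1, m1 = refs.var_mean(a, dim=1, keepdim=True)       # single dim
v2, m2 = refs.var_mean(a, dim=(1, 2), keepdim=True)  # sequence of dims
print(v1.shape, v2.shape)  # torch.Size([2, 1, 4]) torch.Size([2, 1, 1])
```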
The bulk of the change adds the _normalize helper and the native_layer_norm reference:

@@ -1797,6 +1804,92 @@ def narrow(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType:
     return prims.slice_in_dim(a, start, start + length, axis=dim)


+def _normalize(
+    a: Tensor, norm_dims: DimsType, eps: float
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """Computes mean and 1/std of a tensor along norm_dims.
+
+    Used as a helper function for normalization layers.
+
+    Args:
+        a (Tensor): input tensor
+        norm_dims (DimsType): dimensions to normalize over
+        eps (float): epsilon for numerical stability
+
+    Returns:
+        out (Tensor): normalized tensor.
+        mean (Tensor): mean of the tensor along norm_dims.
+        rstd (Tensor): 1/std of the tensor along norm_dims.
+    """
+    computation_dtype = utils.get_computation_dtype(a.dtype)
+    a_acc = _maybe_convert_to_dtype(a, computation_dtype)
+    assert isinstance(a_acc, TensorLike)  # to avoid mypy error for var_mean
+    biased_var, mean = var_mean(a_acc, dim=norm_dims, unbiased=False, keepdim=True)
Review thread on the var_mean call:
- "currently nvfuser doesn't understand var_mean, so this decomp is not readily usable"
- "Same here. Also, the reason the previous decomp used separate calls to …"
- "I will make … It's separate calls to …"
- "I believe @peterbell10 once told me that …"
- "On CPU, the mean from …"
A second thread on the same block:
- "For similar reasons, can we change this call to a …"
- "Well, then, unless you decompose …"
- "Then I think …"
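For context, the alternative these threads allude to is a decomposition built from separate reductions rather than a fused var_mean call. A rough sketch of that two-pass style, purely as an illustration and not the code in this PR (the diff continues below):

```python
import torch

def _normalize_two_pass(a, norm_dims, eps):
    # Hypothetical two-pass variant: one reduction for the mean,
    # one for the biased variance, no fused var_mean.
    mean = a.mean(dim=norm_dims, keepdim=True)
    biased_var = ((a - mean) ** 2).mean(dim=norm_dims, keepdim=True)
    rstd = torch.rsqrt(biased_var + eps)
    return (a - mean) * rstd, mean, rstd

# Quick numerical check against the fused path.
a = torch.randn(2, 3, 4, dtype=torch.float64)
out, mean, rstd = _normalize_two_pass(a, (1, 2), 1e-5)
v, m = torch.var_mean(a, dim=(1, 2), unbiased=False, keepdim=True)
torch.testing.assert_close(mean, m)
torch.testing.assert_close(rstd, torch.rsqrt(v + 1e-5))
```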
+    rstd = torch.rsqrt(biased_var + eps)
+    out = (a - mean) * rstd
+    return out, mean, rstd
+
+
+@register_decomposition(torch.ops.aten.native_layer_norm)
+def native_layer_norm(
+    input: Tensor,
+    normalized_shape: ShapeType,
+    weight: Optional[Tensor],
+    bias: Optional[Tensor],
+    eps: float,
+) -> Tuple[Tensor, Tensor, Tensor]:
+    normalized_ndim = len(normalized_shape)
+    utils.check(
+        normalized_ndim >= 1,
+        lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
+        + "containing at least one element, but got normalized_shape = "
+        + str(normalized_shape),
+    )
+    # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False
+    # while torch.Size([1, 2, 3]) == (1, 2, 3) is True
+    # therefore we use tuple(normalized_shape)
+    utils.check(
+        weight is None or weight.shape == tuple(normalized_shape),
+        lambda: "Expected weight to be of same shape as normalized_shape, but got "
+        + "weight of shape "
+        + str(weight.shape)  # type: ignore[union-attr]
+        + " and normalized_shape = "
+        + str(normalized_shape),
+    )
+    utils.check(
+        bias is None or bias.shape == tuple(normalized_shape),
+        lambda: "Expected bias to be of same shape as normalized_shape, but got "
+        + "bias of shape "
+        + str(bias.shape)  # type: ignore[union-attr]
+        + " and normalized_shape = "
+        + str(normalized_shape),
+    )
+    utils.check(
+        input.ndim >= normalized_ndim
+        and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
+        lambda: "Given normalized_shape="
+        + str(normalized_shape)
+        + ", expected input with shape "
+        + str(normalized_shape)
+        + ", but got input of size "
+        + str(input.shape),
+    )
+    axis = input.ndim - normalized_ndim
+    reduction_dims = list(range(axis, input.ndim))
+    out, mean, rstd = _normalize(input, reduction_dims, eps)
+    if weight is None and bias is not None:
+        out = out + bias
+    elif weight is not None and bias is None:
+        out = out * weight
+    elif weight is not None and bias is not None:
+        out = out * weight + bias
+    out = prims.convert_element_type(out, input.dtype)
+    if input.device.type == "cpu":
+        mean = prims.convert_element_type(mean, input.dtype)
+        rstd = prims.convert_element_type(rstd, input.dtype)
+    return (out, mean, rstd)
+
+
 # TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode.
 # test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu
 @register_decomposition(torch.ops.aten.permute, disable_meta=True)
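With the full decomposition in place, it can be exercised end to end against the eager op. A minimal sketch, assuming a build containing this change and that torch._refs is importable:

```python
import torch
import torch.nn.functional as F
import torch._refs as refs

x = torch.randn(2, 3, 4)
weight = torch.randn(3, 4)
bias = torch.randn(3, 4)
eps = 1e-5

# The reference returns (out, mean, rstd), mirroring aten.native_layer_norm.
out, mean, rstd = refs.native_layer_norm(x, (3, 4), weight, bias, eps)

# The primary output should match eager layer_norm.
torch.testing.assert_close(out, F.layer_norm(x, (3, 4), weight, bias, eps))

# mean and rstd keep the reduced dimensions, so they broadcast against the input.
print(mean.shape, rstd.shape)  # torch.Size([2, 1, 1]) torch.Size([2, 1, 1])
```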
A final review thread asks about backward compatibility:
- "Let's add a comment for this block -- is this a BC-breaking change?"
- "Yes, it does change the shape of var and mean returned for input tensors with 0s in the shape. It's a bug fix."
- "Behavior on master: …"
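The shape difference the thread refers to can be probed directly. A small sketch (assuming torch._refs is importable) that prints the shapes each implementation returns for an input with a zero-sized dimension:

```python
import torch
import torch._refs as refs

a = torch.empty(0, 3)

# Eager behavior for comparison.
v, m = torch.var_mean(a, dim=1, keepdim=True)
print("eager:", v.shape, m.shape)

# Reference behavior with this PR's dim/keepdim handling.
rv, rm = refs.var_mean(a, dim=1, keepdim=True)
print("refs :", rv.shape, rm.shape)
```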