Changes from all commits (31 commits)
7808339  Remove special casing for 0-sized inputs (IvanYashchuk, Jun 13, 2022)
1703f58  Add OpInfo for native_layer_norm (IvanYashchuk, Jun 13, 2022)
e357a7a  Remove special casing for 0-sized in decompositions.py (IvanYashchuk, Jun 13, 2022)
09f6a13  Skip CPU bfloat16 for native_layer_norm (IvanYashchuk, Jun 13, 2022)
fc89dd3  Add refs for rsqrt, native_layer_norm (IvanYashchuk, Jun 13, 2022)
90b7bdf  Add samples with None arguments for weight and bias (IvanYashchuk, Jun 13, 2022)
3d68aba  Three cases for optional weight and bias (IvanYashchuk, Jun 13, 2022)
8b95bda  Add error_inputs_native_layer_norm (IvanYashchuk, Jun 13, 2022)
9b5db3f  Add input checks (IvanYashchuk, Jun 13, 2022)
f323b3e  Remove native_layer_norm decomp (IvanYashchuk, Jun 13, 2022)
f246459  formatting (IvanYashchuk, Jun 13, 2022)
dfe8a1d  Skip mypy (IvanYashchuk, Jun 13, 2022)
87cdca6  Remove test_comprehensive skip (IvanYashchuk, Jun 13, 2022)
472ddf6  Merge remote-tracking branch 'upstream/viable/strict' into native-lay… (IvanYashchuk, Jun 14, 2022)
0ecffe6  Revert "Remove test_comprehensive skip" (IvanYashchuk, Jun 14, 2022)
0e47507  Remove test_comprehensive skip (IvanYashchuk, Jun 14, 2022)
1013fe0  Add rsqrt to __all__ (IvanYashchuk, Jun 14, 2022)
7f836d3  Use torch.rsqrt per Horace's request (IvanYashchuk, Jun 14, 2022)
df7d6b5  Merge remote-tracking branch 'upstream/master' into native-layer-norm (IvanYashchuk, Jun 14, 2022)
5730c09  Add a comment to layer_norm_kernel.cu (IvanYashchuk, Jun 15, 2022)
8d95161  Use _maybe_convert_to_dtype (IvanYashchuk, Jun 15, 2022)
a11bda3  Add docstring to _normalize (IvanYashchuk, Jun 15, 2022)
6cd5a74  Change TensorLikeType->Tensor (IvanYashchuk, Jun 15, 2022)
5901865  Fix typo (IvanYashchuk, Jun 16, 2022)
2aaf514  Add refs.nn.functional.layer_norm (IvanYashchuk, Jun 16, 2022)
9198804  Merge remote-tracking branch 'upstream/viable/strict' into native-lay… (IvanYashchuk, Jun 16, 2022)
0027287  mypy fixes (IvanYashchuk, Jun 16, 2022)
e875261  Remove supports_expanded_weight (IvanYashchuk, Jun 16, 2022)
eeccfe4  xfail gradgrad test and link the issue (IvanYashchuk, Jun 16, 2022)
58b18bc  Skip jit test because gradgrad is failing (IvanYashchuk, Jun 16, 2022)
f7307e1  Skip test_correctness_with_reusing_ir (IvanYashchuk, Jun 16, 2022)
26 changes: 14 additions & 12 deletions aten/src/ATen/native/cuda/layer_norm_kernel.cu
@@ -850,23 +850,25 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_cuda(
auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true);
Tensor mean = at::empty({M}, X->options().dtype(acc_type));
Tensor rstd = at::empty({M}, X->options().dtype(acc_type));
// Calling the kernel for M==0 gives a CUDA error
// See: https://github.com/pytorch/pytorch/pull/28614
if (M > 0) {
LayerNormKernelImpl(*X, *gamma, *beta, M, N, eps, &Y, &mean, &rstd);
}
const auto input_shape = input.sizes();
Collaborator:
Let's add a comment for this block -- is this a BC-breaking change?

IvanYashchuk (Collaborator, Author), Jun 14, 2022, replying to "is this a BC-breaking change?":

Yes, it does change the shape of var and mean returned for input tensors with 0s in the shape. It's a bug fix.

IvanYashchuk (Collaborator, Author):
Behavior on master:

In [1]: import torch

In [2]: a = torch.randn(2, 0, 3, 3, device='cuda') # note 0 in the 2nd dim

In [3]: [x.shape for x in torch.native_layer_norm(a, (3, 3), None, None, 1e-5)]
Out[3]: [torch.Size([2, 0, 3, 3]), torch.Size([0]), torch.Size([0])] # shapes for var and mean are wrong

In [4]: a = torch.randn(2, 1, 3, 3, device='cuda') # note non-zero in the 2nd dim

In [5]: [x.shape for x in torch.native_layer_norm(a, (3, 3), None, None, 1e-5)]
Out[5]: [torch.Size([2, 1, 3, 3]), torch.Size([2, 1, 1, 1]), torch.Size([2, 1, 1, 1])]
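
For comparison, a minimal sketch of what this patch is expected to produce for the zero-sized case (the shapes below are inferred from the stat_shape construction in this diff, not copied from a run):

import torch

a = torch.randn(2, 0, 3, 3, device='cuda')  # note 0 in the 2nd dim
out, mean, rstd = torch.native_layer_norm(a, (3, 3), None, None, 1e-5)
# Expected with this change: mean and rstd take the shape
# input.shape[:axis] + (1,) * len(normalized_shape), i.e. (2, 0, 1, 1),
# matching the non-zero-sized case above.
print(out.shape, mean.shape, rstd.shape)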

const size_t axis = input.dim() - normalized_shape.size();
std::vector<int64_t> stat_shape;
for (size_t idx = 0; idx < axis; ++idx) {
  stat_shape.push_back(input_shape[idx]);
}
for (size_t idx = axis; idx < input.dim(); ++idx) {
  stat_shape.push_back(1);
}
mean = mean.view(stat_shape);
rstd = rstd.view(stat_shape);
return std::make_tuple(std::move(Y), std::move(mean), std::move(rstd));
}

4 changes: 0 additions & 4 deletions aten/src/ATen/native/layer_norm.cpp
@@ -29,10 +29,6 @@ void layer_norm_with_mean_rstd_out(
double eps,
int64_t M,
int64_t N) {
if (M <= 0) {
return;
}

LayerNormKernel(kCPU, input, gamma, beta, M, N, eps, &out, &mean, &rstd);
const auto input_shape = input.sizes();
const size_t axis = input.dim() - normalized_shape.size();
3 changes: 3 additions & 0 deletions test/test_decomp.py
@@ -450,6 +450,9 @@ def check_decomposed(aten_name):
func = op.get_op()
for sample_input in samples:
if requires_grad:
if None in sample_input.args:
Collaborator:
Why is this needed?

IvanYashchuk (Collaborator, Author):
Because this test cannot handle None as a positional argument to native_layer_norm. It has the signature: native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor). Both weight and bias can be None and I added sample inputs for these cases.
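
For illustration, a minimal sketch (shapes chosen arbitrarily, not taken from the actual OpInfo samples) of a call where both optional tensors are passed positionally as None:

import torch

x = torch.randn(2, 3, 4, requires_grad=True)
# weight and bias are the optional arguments this test cannot handle as None.
out, mean, rstd = torch.native_layer_norm(x, (3, 4), None, None, 1e-5)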

continue

fn, primals = normalize_op_input_output(func, sample_input)
primals = tree_map(
lambda x: x if isinstance(x, torch.Tensor) else x, primals
32 changes: 0 additions & 32 deletions torch/_decomp/decompositions.py
@@ -815,38 +815,6 @@ def normalize(input, norm_dims, eps):
return out, mean, rstd


@register_decomposition(aten.native_layer_norm.default)
def native_layer_norm(
input: Tensor,
normalized_shape: List[int],
weight: Optional[Tensor],
bias: Optional[Tensor],
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
computation_dtype = utils.get_computation_dtype(input.dtype)

axis = input.dim() - len(normalized_shape)
if prod(list(input.shape[:axis])) == 0:
mean = input.new_zeros((0,), dtype=computation_dtype)
rstd = input.new_zeros((0,), dtype=computation_dtype)
out = input
else:
reduction_dims = list(range(axis, input.dim()))
out, mean, rstd = normalize(input, reduction_dims, eps)

if weight is not None:
out = out * weight
if bias is not None:
out = out + bias

out = out.to(dtype=input.dtype)

if input.device.type == 'cpu':
mean = mean.to(dtype=input.dtype)
rstd = rstd.to(dtype=input.dtype)
return (out, mean, rstd)


@register_decomposition(aten.native_group_norm.default, disable_meta=True)
def native_group_norm(
input: Tensor,
99 changes: 96 additions & 3 deletions torch/_refs/__init__.py
@@ -192,11 +192,13 @@
"hsplit",
"hstack",
"narrow",
"native_layer_norm",
"permute",
"ravel",
"reshape",
"roll",
"rot90",
"rsqrt",
"stack",
"swap_axes", # alias for transpose
"squeeze",
@@ -626,6 +628,11 @@ def round(a):
return prims.round(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def rsqrt(a):
return prims.rsqrt(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sigmoid(a: TensorLikeType) -> TensorLikeType:
return true_divide(1, add(1, exp(neg(a))))
@@ -1426,7 +1433,7 @@ def _set_correction(
@out_wrapper
def var(
a: TensorLikeType,
dim: Union[Optional[int], Optional[List[int]]] = None,
dim: Optional[DimsType] = None,
unbiased: Optional[bool] = None,
keepdim: bool = False,
*,
@@ -1484,7 +1491,7 @@ def std(

def mean(
a: TensorLikeType,
dim: Union[Optional[int], Optional[List[int]]] = None,
dim: Optional[DimsType] = None,
keepdim: bool = False,
*,
dtype=None,
@@ -1539,7 +1546,7 @@ def std_mean(

def var_mean(
a: TensorLikeType,
dim: Union[Optional[int], Optional[List[int]]] = None,
dim: Optional[DimsType] = None,
unbiased: Optional[bool] = None,
keepdim: bool = False,
*,
@@ -1797,6 +1804,92 @@ def narrow(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeTy
return prims.slice_in_dim(a, start, start + length, axis=dim)


def _normalize(
a: Tensor, norm_dims: DimsType, eps: float
) -> Tuple[Tensor, Tensor, Tensor]:
"""Computes mean and 1/std of a tensor along norm_dims.

Used as a helper function for normalization layers.

Args:
a (Tensor): input tensor
norm_dims (DimsType): dimensions to normalize over
eps (float): epsilon for numerical stability

Returns:
out (Tensor): normalized tensor.
mean (Tensor): mean of the tensor along norm_dims.
rstd (Tensor): 1/std of the tensor along norm_dims.
"""
computation_dtype = utils.get_computation_dtype(a.dtype)
a_acc = _maybe_convert_to_dtype(a, computation_dtype)
assert isinstance(a_acc, TensorLike) # to avoid mypy error for var_mean
biased_var, mean = var_mean(a_acc, dim=norm_dims, unbiased=False, keepdim=True)
Collaborator:
currently nvfuser doesn't understand var_mean, so this decomp is not readily usable

Collaborator:
Same here.

Also, the reason the previous decomp used separate calls to var and mean was that it more closely matched the numerics of eager mode. @ngimel's opinion was that the discrepancy shouldn't matter, but just noting in case you need to adjust tolerances on any tests.

IvanYashchuk (Collaborator, Author): I will make mean work with the nvFuser executor in a separate PR. var will be implemented later, as it's currently a "prim" in PyTorch but a composite in nvFuser.

It is separate calls to var and mean under the hood; take a look at the var_mean function:

def var_mean(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[int] = None,
):
    v = var(a, dim, unbiased, keepdim, correction=correction)
    m = mean(a, dim, keepdim)
    return v, m

IvanYashchuk (Collaborator, Author):
PR for mean: #79444
PR for var: #79517
And together they make nvfuser understand var_mean.

lezcano (Collaborator), Jun 16, 2022:
I believe @peterbell10 once told me that var_mean had some speed and stability problems in PyTorch (?). I do not remember very well so I'll let him discuss this, but perhaps this is relevant here.

Collaborator:
On CPU, the mean from var_mean is less accurate than calling mean separately. mean is implemented roughly as sum(x, dim) / x.size(dim), where sum uses a low-error summation algorithm. var_mean on the other hand computes both the mean and var in a single pass, but with a naive summation that is less accurate.

Collaborator: var_mean here calls into the reference decomposition, which in turn computes mean and var separately (mean via sum and division).

Collaborator:
For similar reasons, can we change this call to a torch.var_mean as well?

Collaborator: Well then, unless you decompose torch.var_mean further, you'd get an inaccurate CPU result.

Collaborator:
Then I think var_mean should probably call into torch.var and torch.mean.
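
As a rough, hedged illustration of the accuracy point above (not a benchmark from this PR; the gap is hardware- and size-dependent):

import torch

x = torch.rand(10_000_000, dtype=torch.float32) + 1000.0  # large offset stresses naive summation
_, mean_fused = torch.var_mean(x, dim=0, unbiased=False)   # single-pass mean
mean_separate = torch.mean(x, dim=0)                       # mean computed separately (lower-error summation per the comment above)
reference = torch.mean(x.double(), dim=0)
# On CPU the fused mean is typically further from the float64 reference than
# the separately computed mean.
print((mean_fused.double() - reference).abs().item(),
      (mean_separate.double() - reference).abs().item())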

rstd = torch.rsqrt(biased_var + eps)
out = (a - mean) * rstd
return out, mean, rstd


@register_decomposition(torch.ops.aten.native_layer_norm)
def native_layer_norm(
input: Tensor,
normalized_shape: ShapeType,
weight: Optional[Tensor],
bias: Optional[Tensor],
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
normalized_ndim = len(normalized_shape)
utils.check(
normalized_ndim >= 1,
lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
+ "containing at least one element, but got normalized_shape = "
+ str(normalized_shape),
)
# torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False
# while torch.Size([1, 2, 3]) == (1, 2, 3) is True
# therefore we use tuple(normalized_shape)
utils.check(
weight is None or weight.shape == tuple(normalized_shape),
lambda: "Expected weight to be of same shape as normalized_shape, but got "
+ "weight of shape "
+ str(weight.shape) # type: ignore[union-attr]
+ " and normalized_shape = "
+ str(normalized_shape),
)
utils.check(
bias is None or bias.shape == tuple(normalized_shape),
lambda: "Expected bias to be of same shape as normalized_shape, but got "
+ "bias of shape "
+ str(bias.shape) # type: ignore[union-attr]
+ " and normalized_shape = "
+ str(normalized_shape),
)
utils.check(
input.ndim >= normalized_ndim
and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
lambda: "Given normalized_shape="
+ str(normalized_shape)
+ ", expected input with shape "
+ str(normalized_shape)
+ ", but got input of size "
+ str(input.shape),
)
axis = input.ndim - normalized_ndim
reduction_dims = list(range(axis, input.ndim))
out, mean, rstd = _normalize(input, reduction_dims, eps)
if weight is None and bias is not None:
out = out + bias
elif weight is not None and bias is None:
out = out * weight
elif weight is not None and bias is not None:
out = out * weight + bias
out = prims.convert_element_type(out, input.dtype)
if input.device.type == "cpu":
mean = prims.convert_element_type(mean, input.dtype)
rstd = prims.convert_element_type(rstd, input.dtype)
return (out, mean, rstd)
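
A hedged usage sketch of this reference (shapes chosen arbitrarily here), checking it against the eager op:

import torch
import torch._refs as refs

x = torch.randn(5, 3, 4)
weight = torch.randn(3, 4)
bias = torch.randn(3, 4)
out, mean, rstd = refs.native_layer_norm(x, (3, 4), weight, bias, 1e-5)
expected = torch.native_layer_norm(x, (3, 4), weight, bias, 1e-5)
# Expected to agree with eager up to normal floating-point tolerance.
print(all(torch.allclose(a, b, atol=1e-6, rtol=1e-5) for a, b in zip((out, mean, rstd), expected)))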


# TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode.
# test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu
@register_decomposition(torch.ops.aten.permute, disable_meta=True)
16 changes: 16 additions & 0 deletions torch/_refs/nn/functional/__init__.py
@@ -2,6 +2,7 @@

import torch._prims.utils as utils
from torch._prims.utils import (
ShapeType,
TensorLike,
TensorLikeType,
NumberType,
@@ -35,6 +36,8 @@
"tanhshrink",
]

Tensor = torch.Tensor

# celu is implemented specially because it has an alpha argument
# celu is very similar to elu
@register_decomposition(torch.ops.aten.celu)
@@ -146,6 +149,19 @@ def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
return torch.where(torch.le(a, 0), 0, a)


def layer_norm(
input: Tensor,
normalized_shape: ShapeType,
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
eps: float = 1e-5,
) -> Tensor:
"""
Reference implementation of :func:`torch.nn.functional.layer_norm`.
"""
return torch.native_layer_norm(input, normalized_shape, weight, bias, eps)[0]
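
For context, a minimal hedged example of calling this reference wrapper (module path as added in this PR; input shape is arbitrary):

import torch
import torch._refs.nn.functional as refs_nn_f

x = torch.randn(2, 3, 4)
y = refs_nn_f.layer_norm(x, (4,))  # weight/bias default to None, eps to 1e-5
y_eager = torch.nn.functional.layer_norm(x, (4,))
print(torch.allclose(y, y_eager, atol=1e-6))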


@register_decomposition(torch.ops.aten.leaky_relu)
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),