torch.nn.LayerNorm support for arbitrary axis in order to allow NCHW application #71465

Open
dzhulgakov opened this issue Jan 19, 2022 · 21 comments
Labels
actionable, high priority, module: nn, triaged

Comments

@dzhulgakov
Collaborator

dzhulgakov commented Jan 19, 2022

🚀 The feature, motivation and pitch

LayerNorm is starting to be applied to image data on a per-channel basis (e.g. in the ConvNeXt model).

torch.nn.LayerNorm supports normalization only over the last several dimensions. If the input is an NCHW tensor (the default layout), normalizing over channels only requires an explicit NHWC conversion. Example from the ConvNeXt implementation: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py#L15

LayerNorm could instead allow specifying which axis the normalization is performed over, backed by an efficient kernel. Example of Keras accepting an axis argument: https://www.tensorflow.org/api_docs/python/tf/keras/layers/LayerNormalization

Additionally, it'd be nice to also support the memory_format=channels_last setting for the input tensor, though that's an orthogonal point: this issue is about semantics, not the underlying memory format.
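For illustration, a minimal sketch of the current constraint and the permute-based workaround (the shapes, eps, and affine params below are arbitrary, not taken from any particular model):

import torch
import torch.nn.functional as F

N, C, H, W = 2, 64, 7, 7
x = torch.randn(N, C, H, W)
weight, bias = torch.ones(C), torch.zeros(C)

# Today F.layer_norm only normalizes over trailing dims, so normalizing over C
# alone requires moving C to the last position and back (the ConvNeXt workaround).
y = F.layer_norm(x.permute(0, 2, 3, 1), (C,), weight, bias, 1e-6).permute(0, 3, 1, 2)

# The same thing written out by hand: stats over the channel dim only.
mean = x.mean(dim=1, keepdim=True)
var = x.var(dim=1, keepdim=True, unbiased=False)
y_ref = (x - mean) / torch.sqrt(var + 1e-6)
y_ref = y_ref * weight[:, None, None] + bias[:, None, None]
torch.testing.assert_close(y, y_ref, rtol=1e-4, atol=1e-5)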

P.S. related discussion: https://twitter.com/wightmanr/status/1481383606782087168

cc @ezyang @gchanan @zou3519 @kadeng @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @kshitij12345 @saketh-are

@vadimkantorov
Contributor

vadimkantorov commented Jan 19, 2022

Related nn.LayerNorm docs issue:

(it would be nice to have a super-compact reference PyTorch impl of LayerNorm in the docs, since it makes much clearer what's being aggregated over, and even a small Python-only snippet helps a lot in understanding). Maybe ConvNeXt could serve as a realistic example there of using LayerNorm in vision models with an NCHW semantic layout.
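Something along the lines of this Python-only sketch of what nn.LayerNorm computes today (the helper name below is made up for illustration; this is not the actual implementation):

import torch

def layer_norm_ref(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    # Stats are computed over the trailing len(normalized_shape) dims;
    # the affine params (if any) have shape normalized_shape.
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, keepdim=True, unbiased=False)
    y = (x - mean) / torch.sqrt(var + eps)
    if weight is not None:
        y = y * weight
    if bias is not None:
        y = y + bias
    return y

layer_norm_ref(x, (C,)) on an NHWC tensor matches nn.LayerNorm(C); the NCHW case is exactly what this can't express without permutes.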

Related nn.Linear issue:

(in the sense that Linear is currently also applied only along the last dimension, and sometimes it's more flexible code-wise to use nn.Linear on NCHW without permutes; the existing alternative is to translate it to a pointwise Conv2d, but it's nasty to maintain two versions in user code)

@H-Huang added the module: nn and triaged labels Jan 20, 2022
@albanD
Collaborator

albanD commented Jan 25, 2022

Note that for no-batch-dim support, we most likely don't want to add any case that would make the presence or absence of a batch dim ambiguous.

@rwightman

rwightman commented Jul 5, 2022

@dzhulgakov Can we bump this in priority? More and more networks would benefit from this. The lack of a native impl appears to be governing some design decisions in CNN-transformer hybrid models implemented in PyTorch (or causing issues with TensorFlow/JAX-originating models, where this isn't a concern).

It's not only ConvNeXt now, but also EdgeNeXt, CoAtNet (Google version), NesT, VisFormer, and PoolFormer (which would have used it but opted for GroupNorm(groups=1) instead, which is not equivalent). Now MobileViTv2 appears to be trying to use it but also opted for GroupNorm(groups=1), possibly without realizing the lack of equivalence, based on how they named it (apple/ml-cvnets#34).

@DmitryUlyanov
Contributor

I am not sure why you tagged me here, but why not just use InstanceNorm?

@rwightman

rwightman commented Jul 6, 2022

I am not sure why you tagged me here, but why not just use InstanceNorm?

Sorry, I was a bit quick on the @ autocomplete there :/

InstanceNorm wouldn't be equivalent. The LayerNorm op we want just computes stats over the C dim and applies the affine to the same dim. As it stands right now, you can only apply PyTorch LN over the last n dims of a tensor.

Three other non-equivalent options that some use (either on purpose or by mistake):

  • InstanceNorm would be stats over H, W with the affine applied to C
  • GroupNorm(groups=1) would be stats over C, H, W with the affine applied to C
  • Like GN, there is also LayerNorm([C, H, W]). This is the same as the GN above except for the affine part, which gets applied to C, H, W in the LN case

The easiest way to get what we want in PyTorch right now is F.layer_norm(x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) ... however those permutes aren't free, and this mucks up performance when you're trying to use memory_format=channels_last.
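Wrapped up as a module, that workaround looks roughly like the sketch below (an illustration based on the snippet above, not the actual timm layer):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm2d(nn.LayerNorm):
    # LayerNorm over the C dim of an NCHW tensor, via permutes.
    # The permutes are exactly the part a native dim/axis argument would remove.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC so C becomes the last dim
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.permute(0, 3, 1, 2)  # back to NCHW

e.g. LayerNorm2d(64) for an (N, 64, H, W) input.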

@jbschlosser
Contributor

Making high-pri due to activity.

@vadimkantorov
Contributor

The easiest way to get what we want in PyTorch right now is F.layer_norm(x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) ... however those permutes aren't free, and this mucks up performance when you're trying to use memory_format=channels_last.

@ngimel also highlighted that channels_last isn't generally well supported by norms: #72341 :(

@rwightman

@vadimkantorov yeah, I'd really want to have channels_last support; it would be much less useful without it.

@ngimel
Collaborator

ngimel commented Jul 6, 2022

Channels-last layer norm can kinda be done today: permute (free, only changes sizes and strides), regular layer norm, and unpermute (again free, and it will nicely result in a channels-last tensor to send to the next conv).
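Concretely (a sketch assuming a channels_last input; shapes arbitrary):

import torch
import torch.nn.functional as F

x = torch.randn(2, 64, 56, 56).to(memory_format=torch.channels_last)

xp = x.permute(0, 2, 3, 1)   # a view: only sizes/strides change, no copy
print(xp.is_contiguous())    # True, because the input was channels_last

y = F.layer_norm(xp, (64,), None, None, 1e-6).permute(0, 3, 1, 2)
print(y.is_contiguous(memory_format=torch.channels_last))  # True: ready for the next conv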

@rwightman

rwightman commented Jul 6, 2022

@ngimel as in this? F.layer_norm(x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)?

I've spent some time running through iterations of this, and while the above is 'okay', there didn't appear to be any improvement using CL vs. not, whereas if you replaced those LNs with BN there was. I was able to boost throughput notably w/ channels_last by doing some hacks (checking for contiguity and switching between that impl and a full custom mean/var LN impl as you see in ConvNeXt), with a ~10+% bump on Ampere cards, but it killed throughput on older cards.
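For reference, a sketch of that kind of switching hack (an illustration, not the actual timm code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm2dFast(nn.LayerNorm):
    # Use the permuted F.layer_norm path when the input is channels_last
    # (so the permute is contiguous and the fast kernel applies), otherwise
    # fall back to a manual per-channel LN in NCHW.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.is_contiguous(memory_format=torch.channels_last):
            y = F.layer_norm(x.permute(0, 2, 3, 1), self.normalized_shape,
                             self.weight, self.bias, self.eps)
            return y.permute(0, 3, 1, 2)
        u = x.mean(dim=1, keepdim=True)
        s = (x - u).pow(2).mean(dim=1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return x * self.weight[:, None, None] + self.bias[:, None, None]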

@ngimel
Collaborator

ngimel commented Jul 6, 2022

Yeah, that one. It would result in good perf only if the permuted tensor is contiguous, so your checks are correct. Strange that it kills throughput on older cards, though; Volta should have decent channels-last support. But older cards that don't support channels last would suffer a lot.

@ngimel
Collaborator

ngimel commented Jul 6, 2022

@rwightman what would be the command to run convnext channels-last fp16 benchmark?

@ezyang
Contributor

ezyang commented Jul 6, 2022

@ngimel If the layer norm kernel already has the smarts for this, we should just add the memory format support; should be easy, right?

@ngimel
Collaborator

ngimel commented Jul 6, 2022

The layer norm kernel doesn't have the smarts; it works on contiguous tensors and normalizes over the last dimension.

@ezyang
Contributor

ezyang commented Jul 7, 2022

Do you think our compilers, like nvFuser, are good enough to produce a good kernel? Maybe we should just have someone sit down and write the Triton kernel.

@ngimel
Collaborator

ngimel commented Jul 9, 2022

see huggingface/pytorch-image-models#1340

@rwightman

So, back to this thread with the original ask: LayerNorm w/ an arbitrary axis. @ngimel demo'd some hacks that can be used with current PyTorch codegen to get better performance from a custom LN layer for LN over the C dim in the 2D NCHW case. It can work, but it's got a lot of gotchas re: use of TorchScript, possible complications (or the need for a more basic impl) for proper ONNX export (haven't tested this yet), and relying on codegen still has baggage in current PyTorch (although it's getting better with every release).

Regardless of what we can squeeze out performance-wise, there is still torch as an API, and this type of layer is becoming common. I've already used it more than I've ever used InstanceNorm, for example. PyTorch stands out relative to other frameworks as not supporting an arbitrary feature axis/dim arg for LN (i.e. requiring the 'last' N dims). Whatever the performance issues are, I feel there is a need for built-in support for this so we don't have every modelling lib (like timm) trying to roll its own fast custom layer with a bunch of hacks (for a common use case).

JAX (via Flax) has separate reduction and feature axes, both defaulting to -1, so with the default tensor layout this use case is already covered for 2D tensors. The flexibility to cover other cases is nice, including matching the behaviour of GN with groups=1. I've bumped into the need for a more generic but efficient norm and affine over arbitrary dims a number of times (this Flax impl covers all of those use cases):

class LayerNorm(Module):
  """Layer normalization (https://arxiv.org/abs/1607.06450).

  LayerNorm normalizes the activations of the layer for each given example in a
  batch independently, rather than across a batch like Batch Normalization.
  i.e. applies a transformation that maintains the mean activation within
  each example close to 0 and the activation standard deviation close to 1.

  Attributes:
    epsilon: A small float added to variance to avoid dividing by zero.
    dtype: the dtype of the result (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    use_bias:  If True, bias (beta) is added.
    use_scale: If True, multiply by scale (gamma). When the next layer is linear
      (also e.g. nn.relu), this can be disabled since the scaling will be done
      by the next layer.
    bias_init: Initializer for bias, by default, zero.
    scale_init: Initializer for scale, by default, one.
    reduction_axes: Axes for computing normalization statistics.
    feature_axes: Feature axes for learned bias and scaling.
  """
  epsilon: float = 1e-6
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  use_bias: bool = True
  use_scale: bool = True
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
  scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
  reduction_axes: Axes = -1
  feature_axes: Axes = -1

  @compact
  def __call__(self, x):
    """Applies layer normalization on the input.

    Args:
      x: the inputs

    Returns:
      Normalized inputs (the same shape as inputs).
    """
    mean, var = _compute_stats(x, self.reduction_axes, self.dtype, None, None)

    return _normalize(
        self, x, mean, var, self.reduction_axes, self.feature_axes,
        self.dtype, self.param_dtype, self.epsilon,
        self.use_bias, self.use_scale,
        self.bias_init, self.scale_init)

Keras / TF also has an axis arg with arbitrary dims; as with JAX, this 2D use case is covered by default by their combination of default tensor layout and -1 for the normalization dim. The affine has to match the reduction here, though:


  Note that other implementations of layer normalization may choose to define
  `gamma` and `beta` over a separate set of axes from the axes being
  normalized across. For example, Group Normalization
  ([Wu et al. 2018](https://arxiv.org/abs/1803.08494)) with group size of 1
  corresponds to a Layer Normalization that normalizes across height, width,
  and channel and has `gamma` and `beta` span only the channel dimension.
  So, this Layer Normalization implementation will not match a Group
  Normalization layer with group size set to 1.
  Args:
    axis: Integer or List/Tuple. The axis or axes to normalize across. Typically
      this is the features axis/axes. The left-out axes are typically the batch
      axis/axes. This argument defaults to `-1`, the last dimension in the
      input.
    epsilon: Small float added to variance to avoid dividing by zero. Defaults
      to 1e-3
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored. Defaults to True.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. Defaults
      to True. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling will be done by the next layer.
    beta_initializer: Initializer for the beta weight. Defaults to zeros.
    gamma_initializer: Initializer for the gamma weight. Defaults to ones.
    beta_regularizer: Optional regularizer for the beta weight. None by default.
    gamma_regularizer: Optional regularizer for the gamma weight. None by
      default.
    beta_constraint: Optional constraint for the beta weight. None by default.
    gamma_constraint: Optional constraint for the gamma weight. None by default.
  Input shape:
    Arbitrary. Use the keyword argument `input_shape` (tuple of
    integers, does not include the samples axis) when using this layer as the

@xvjiarui

Just out of curiosity, is the channels-last LayerNorm or LayerNorm2d from timm equivalent to the Positional Norm here: https://github.com/Boyiliee/Positional-Normalization?

@hgaiser
Contributor

hgaiser commented Sep 13, 2023

Is there any progress on this? I'm running into the same issue when implementing ViTDet in torchvision.

@pbelcak

pbelcak commented Dec 28, 2023

Bumpidy bump. Would also greatly benefit from this being implemented (use case: CV models with NCHW tensor layout).

Looking at the support for fused index_select with matmul, maybe instead of having a completely new operator, we could just fuse .transpose(k, -1) (or, more generally, Tensor.permute) with layer norm?

@ezyang
Contributor

ezyang commented Jan 1, 2024

Someone should check if PT2 layer norm does well on this
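
One way someone could check (a sketch assuming PyTorch 2.x with torch.compile; the shapes below are arbitrary):

import torch
import torch.nn.functional as F

def ln_nchw(x, weight, bias, eps=1e-6):
    # permuted layer_norm over C; the question is whether the compiler can
    # fuse the permutes away and handle channels_last well
    y = F.layer_norm(x.permute(0, 2, 3, 1), weight.shape, weight, bias, eps)
    return y.permute(0, 3, 1, 2)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 96, 56, 56, device=device).to(memory_format=torch.channels_last)
w = torch.ones(96, device=device)
b = torch.zeros(96, device=device)

compiled = torch.compile(ln_nchw)
y = compiled(x, w, b)   # benchmark this against the eager version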
