torch.nn.LayerNorm support for arbitrary axis in order to allow NCHW application #71465
Comments
Related nn.LayerNorm docs issue: (it would be nice to have a super-compact reference PyTorch impl of LayerNorm in the docs, since it's much clearer about what gets aggregated, and even a small Python-only snippet helps a lot in understanding). Maybe ConvNeXt could be a realistic example there of using LayerNorm in vision models with NCHW semantic layout. Related nn.Linear issue: (in the sense that Linear is currently also applied only on the last dimension, and sometimes it's more flexible code-wise to use nn.Linear on NCHW without doing permutes; the existing alternative is to translate it to a pointwise Conv2d, but it's nasty to maintain two versions in user code).
Note that for no-batch-dim support, we most likely don't want to add any case that would make the existence or not of a batch dim ambiguous.
@dzhulgakov Can we bump this in priority? More and more networks would benefit from this. The lack of a native impl appears to be driving some design decisions in CNN-transformer hybrid models implemented in PyTorch (or causing issues with TensorFlow/JAX-originated models where this isn't a concern). It's not only ConvNeXt now, but EdgeNeXt, CoAtNet (Google version), NesT, Visformer, and PoolFormer (which would have used it but opted for GroupNorm(groups=1) instead, which is not equivalent). Now MobileViTv2 appears to be trying to use it but also opted for GroupNorm(groups=1), possibly without realizing the lack of equivalence, based on how they named it (apple/ml-cvnets#34).
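To make the non-equivalence mentioned above concrete, here is a small illustrative check (shapes are made up; this is just a sketch, not code from any of the models listed):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)  # N, C, H, W
# GroupNorm(groups=1) normalizes over (C, H, W) jointly, per sample...
gn = F.group_norm(x, 1)
# ...whereas the LayerNorm wanted here normalizes over C only, per spatial location
ln = F.layer_norm(x.permute(0, 2, 3, 1), (8,)).permute(0, 3, 1, 2)
print(torch.allclose(gn, ln))  # False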
I am not sure why you tagged me here, but why don't you just use InstanceNorm?
Sorry, I was a bit quick on the @ autocomplete there :/ InstanceNorm wouldn't be equivalent. The LayerNorm op we want just computes stats over the C dim and applies the affine to the same dim. As it stands right now you can only apply PyTorch LayerNorm over the last n dims of a tensor. Three other non-equivalent variants are sometimes used (either on purpose or by mistake):
The easiest way to get what we want in PyTorch right now is a custom module along the lines of the sketch below.
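For reference, a minimal sketch of such a channels-first LayerNorm (a hypothetical LayerNorm2d module, not the exact snippet originally linked here), computing stats and applying the affine over the C dim of an NCHW tensor:

import torch
import torch.nn as nn

class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dim of an NCHW tensor (stats and affine over C only)."""
    def __init__(self, num_channels: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # mean/var over the channel dim only, independently per (n, h, w) location;
        # note this differs from GroupNorm(groups=1) / InstanceNorm, which also reduce over H, W
        u = x.mean(dim=1, keepdim=True)
        s = (x - u).pow(2).mean(dim=1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        # affine over C, broadcast over H, W
        return self.weight[:, None, None] * x + self.bias[:, None, None]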
Making high-pri due to activity.
@ngimel also highlighted that channels_last isn't generally well supported by the norms: #72341 :(
@vadimkantorov yeah, I'd really want to have channels_last support, it would be much less useful without it.
Channels-last layer norm can kinda be done today with a permute (free, it only changes sizes and strides), a regular layer norm, and an unpermute (again free, and it will nicely result in a channels-last tensor to send to the next conv).
@ngimel as in this? I've spent some time running through iterations of this, and while the above is 'okay', there didn't appear to be any improvement using channels-last vs not, whereas if you replaced those LN with BN there was. I was able to boost throughput notably w/ channels_last by doing some hacks (checking for contiguous and switching between that impl and a fully custom mean/var LN impl as you see in ConvNeXt), with a ~10+% bump on Ampere cards, but it killed throughput on older cards.
Yeah, that one. It would result in good perf only if the permuted tensor is contiguous, so your checks are correct. Strange that it kills throughput on older cards - Volta should have decent channels-last support? But older cards that don't support channels last would suffer a lot.
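For reference, a rough sketch of the permute-based approach being discussed (a hypothetical helper, not the exact code linked above), including the contiguity check mentioned:

import torch
import torch.nn.functional as F

def layer_norm_nchw(x, weight, bias, eps=1e-6):
    # NCHW -> NHWC: only swaps sizes/strides; if x is in channels_last memory format
    # this permuted view is already contiguous
    y = x.permute(0, 2, 3, 1)
    if not y.is_contiguous():
        y = y.contiguous()  # non-channels_last inputs pay for a real copy here
    y = F.layer_norm(y, y.shape[-1:], weight, bias, eps)
    # NHWC -> NCHW: again just a stride change, yielding a channels_last NCHW tensor
    return y.permute(0, 3, 1, 2)

Here weight and bias would be parameters of shape (C,); the fast path only applies when the input is in channels_last memory format.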
@rwightman what would be the command to run the convnext channels-last fp16 benchmark?
@ngimel If the layer norm kernel already has the smarts for this, we should just add the memory format support - should be easy, right?
The layer norm kernel doesn't have the smarts; it works on contiguous tensors and normalizes over the last dimension.
Do you think our compilers, like nvFuser, are good enough to produce a good kernel? Maybe we should just have someone sit down and write a Triton kernel.
So, back to this thread with the original ask: LayerNorm w/ arbitrary axis. @ngimel demo'd some hacks that can be used with current PyTorch codegen to get better performance out of a custom LN layer for the LN-over-C-dim, 2D NCHW case. It can work, but it has a lot of gotchas regarding the use of TorchScript, possible complications (or the need for a more basic impl) for proper ONNX export (haven't tested this yet), and relying on codegen still has baggage in current PyTorch (although it's getting better with every release). Regardless of what we can squeeze out performance-wise, there is still torch as an API, and this type of layer is becoming common. I've already used it more than I've ever used InstanceNorm, for example. PyTorch stands out relative to other frameworks in not supporting an arbitrary feature axis/dim arg for LN (i.e. requiring the 'last' N dims). Whatever the performance issues are, I feel there is a need for builtin support for this so we don't have every modelling lib (like timm) trying to roll their own fast custom layer with a bunch of hacks (for a common use case). JAX (via Flax) has separate reduction and feature axes, both defaulting to -1, so with the default (channels-last) tensor layout this use case is already covered out of the box. The flexibility to cover other cases is nice, including matching the behaviour of GN with groups=1. I've bumped into the need for a more generic, but efficient, norm + affine over arbitrary dims a number of times (this Flax impl covers all of those use cases):

class LayerNorm(Module):
  """Layer normalization (https://arxiv.org/abs/1607.06450).

  LayerNorm normalizes the activations of the layer for each given example in a
  batch independently, rather than across a batch like Batch Normalization.
  i.e. applies a transformation that maintains the mean activation within
  each example close to 0 and the activation standard deviation close to 1.

  Attributes:
    epsilon: A small float added to variance to avoid dividing by zero.
    dtype: the dtype of the result (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    use_bias: If True, bias (beta) is added.
    use_scale: If True, multiply by scale (gamma). When the next layer is linear
      (also e.g. nn.relu), this can be disabled since the scaling will be done
      by the next layer.
    bias_init: Initializer for bias, by default, zero.
    scale_init: Initializer for scale, by default, one.
    reduction_axes: Axes for computing normalization statistics.
    feature_axes: Feature axes for learned bias and scaling.
  """
  epsilon: float = 1e-6
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  use_bias: bool = True
  use_scale: bool = True
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
  scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
  reduction_axes: Axes = -1
  feature_axes: Axes = -1

  @compact
  def __call__(self, x):
    """Applies layer normalization on the input.

    Args:
      x: the inputs

    Returns:
      Normalized inputs (the same shape as inputs).
    """
    mean, var = _compute_stats(x, self.reduction_axes, self.dtype, None, None)
    return _normalize(
        self, x, mean, var, self.reduction_axes, self.feature_axes,
        self.dtype, self.param_dtype, self.epsilon,
        self.use_bias, self.use_scale,
        self.bias_init, self.scale_init)

Keras / TF also has an axis arg that accepts arbitrary dims; as with JAX, the common use case is covered by default thanks to the combination of their default (channels-last) tensor layout and -1 as the default normalization axis. The affine params have to match the reduction axes there, though.
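For illustration, a small sketch of the Keras axis argument applied to an assumed NCHW-shaped tensor (the shapes here are made up):

import tensorflow as tf

# hypothetical NCHW-shaped input; axis=1 makes LayerNormalization reduce
# (and apply its affine) over the channel dim instead of the default last dim
x = tf.random.normal([8, 64, 56, 56])          # N, C, H, W
ln = tf.keras.layers.LayerNormalization(axis=1)
y = ln(x)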
Just out of curiosity, is the channels-last LayerNorm or LayerNorm2d from timm equivalent to Positional Norm here: https://github.com/Boyiliee/Positional-Normalization?
Is there any progress on this? I'm running into the same issue when implementing ViTDet in torchvision.
Bumpidy bump. Would also greatly benefit from this being implemented (use case: CV models with NCHW tensor layout). Looking at the support for fused
Someone should check if PT2 layer norm does well on this |
🚀 The feature, motivation and pitch
LayerNorm starts to be applied to image data on a per-channel basis (e.g. in the ConvNeXt model). torch.nn.LayerNorm supports normalization only over the last several dimensions. If the input is an NCHW tensor (default layout), it requires explicit NHWC conversion in order to do normalization over channels only. Example from the ConvNeXt implementation: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py#L15

LayerNorm could instead allow specifying which axis the normalization should be performed over, and be backed by an efficient kernel. Example in Keras, which accepts an axis argument: https://www.tensorflow.org/api_docs/python/tf/keras/layers/LayerNormalization

Additionally, it'd be nice to also support the memory_format=channels_last setting for the input tensor, though that's an orthogonal point - this issue talks about semantics, not about the underlying memory format.

P.S. related discussion: https://twitter.com/wightmanr/status/1481383606782087168
cc @ezyang @gchanan @zou3519 @kadeng @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @kshitij12345 @saketh-are