
[RFC]A suggestion of channels last memory format implementation for 3D tensor #74935

Open
KevinYuk opened this issue Mar 30, 2022 · 12 comments
Labels: feature, module: memory format, triaged

@KevinYuk
Contributor

KevinYuk commented Mar 30, 2022

🚀 The feature, motivation and pitch

Motivation

PyTorch already supports ChannelsLast(2d) for 4D tensors (N, C, H, W) and ChannelsLast3d for 5D tensors (N, C, D, H, W), but it doesn't support ChannelsLast1d for 3D tensors (N, C, L). See below:

ChannelsLast for 4D tensor works fine:

>>> import torch
>>> N, C, H, W = 8, 3, 32, 32
>>> _4d_tensor = torch.empty(N, C, H, W)
>>> _4d_tensor_cl = _4d_tensor.to(memory_format=torch.channels_last)
>>> _4d_tensor_cl.is_contiguous(memory_format=torch.channels_last)
True
>>> _4d_tensor_cl.stride()
(3072, 1, 96, 3)

ChannelsLast for 5D tensor works fine:

>>> import torch
>>> N, C, D, H, W = 8, 3, 32, 32, 32
>>> _5d_tensor = torch.empty(N, C, D, H, W)
>>> _5d_tensor_cl = _5d_tensor.to(memory_format=torch.channels_last_3d)
>>> _5d_tensor_cl.is_contiguous(memory_format=torch.channels_last_3d)
True
>>> _5d_tensor_cl.stride()
(98304, 1, 3072, 96, 3)

ChannelsLast for a 3D tensor doesn't work:

>>> import torch
>>> N, C, L = 8, 3, 32
>>> _3d_tensor = torch.empty(N, C, L)
>>> _3d_tensor_cl = _3d_tensor.to(memory_format=torch.channels_last_1d)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: module 'torch' has no attribute 'channels_last_1d'
>>> _3d_tensor_cl = _3d_tensor.to(memory_format=torch.channels_last)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: required rank 4 tensor to use channels_last format
>>> _3d_tensor_cl = _3d_tensor.to(memory_format=torch.channels_last_3d)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: required rank 5 tensor to use channels_last_3d format

However, operators such as conv1d and pool1d need ChannelsLast1d to get a better performance boost, due to the natural advantages of the channels last memory format.

Usage

After the feature is supported, it works as below:

>>> import torch
>>> N, C, L = 8, 3, 32
>>> _3d_tensor = torch.empty(N, C, L)
>>> _3d_tensor_cl = _3d_tensor.to(memory_format=torch.channels_last_1d)
>>> _3d_tensor_cl.is_contiguous(memory_format=torch.channels_last_1d)
True
>>> _3d_tensor_cl.stride()
(96, 1, 3)
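
For reference, the expected channels-last-1d strides for a tensor of shape (N, C, L) are (L*C, 1, C); with (N, C, L) = (8, 3, 32) this gives (32*3, 1, 3) = (96, 1, 3), matching the output above.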

Value

The ChannelsLast1d feature would benefit models such as time-series analysis models, deep learning models based on Lidar data, and speech models such as wav2vec.

Proposed Implementation

Proposal 1:

The general implementation principles of proposal 1 are as below:

  1. ChannelsLast1d aligns with the usage conventions of ChannelsLast(2d) and ChannelsLast3d to provide a consistent user experience;
  2. No extra bits are added to the TensorImpl structure;
  3. No overhead is introduced in the important function refresh_contiguous(); the computation of the existing ChannelsLast(2d) and ChannelsLast3d flags is unaffected;
  4. The feature is transparent to end users who don't use it, both in functionality and performance.

The details are as follows:

Regarding 1: Users can use it as below:

_3d_tensor_cl = _3d_tensor.to(memory_format=torch.channels_last_1d)
_4d_tensor_cl = _4d_tensor.to(memory_format=torch.channels_last)
_5d_tensor_cl = _5d_tensor.to(memory_format=torch.channels_last_3d)

Regarding 2 and 3: For ChannelsLast(2d) and ChannelsLast3d, there are associated flags in the TensorImpl structure, as below:

  bool is_channels_last_ : 1;
  bool is_channels_last_contiguous_ : 1;
  bool is_channels_last_3d_ : 1;
  bool is_channels_last_3d_contiguous_ : 1;

Then refresh_contiguous() updates these flags to track the tensor's memory format information. APIs such as is_contiguous(), is_strides_like_channels_last(), is_strides_like_channels_last_3d(), etc. work based on these flags.
To avoid introducing extra bits into the TensorImpl structure, we do not define flags such as bool is_channels_last_1d_ : 1; and bool is_channels_last_1d_contiguous_ : 1; for ChannelsLast1d, so no overhead is introduced in the key function refresh_contiguous(). If the associated APIs (e.g. is_contiguous()) need the memory format information for ChannelsLast1d, we compute it on demand, as below:

  TENSORIMPL_MAYBE_VIRTUAL bool is_contiguous(
      at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const {
    ......
    if (memory_format == at::MemoryFormat::ChannelsLast1d) {
      return compute_channels_last_contiguous_1d();  // <----------------- calculate it on demand for ChannelsLast1d
    } else if (memory_format == at::MemoryFormat::ChannelsLast) {
      return is_channels_last_contiguous_; // <--------------------------- just read the flag for ChannelsLast(2d), because it has been updated by refresh_contiguous()
    } else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
      return is_channels_last_3d_contiguous_; // <------------------------ the same as ChannelsLast(2d) above
    }
    ......
  }
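
For reference, a minimal self-contained sketch of what such an on-demand check could look like for a 3D tensor (N, C, L), assuming channels-last-1d strides of (L*C, 1, C). The function name and signature are illustrative only, and the size-0/size-1 edge cases that the real TensorImpl helpers handle are omitted:

  #include <cstdint>
  #include <vector>

  // Sketch: true iff a rank-3 tensor's strides are exactly (L*C, 1, C),
  // i.e. the tensor is channels-last-1d contiguous.
  bool compute_channels_last_contiguous_1d_sketch(
      const std::vector<int64_t>& sizes,
      const std::vector<int64_t>& strides) {
    if (sizes.size() != 3) {
      return false;  // only rank-3 (N, C, L) tensors qualify
    }
    // Walk the dims in channels-last-1d memory order: C (dim 1), L (dim 2), N (dim 0).
    const int order[3] = {1, 2, 0};
    int64_t expected = 1;
    for (int d : order) {
      if (strides[d] != expected) {
        return false;
      }
      expected *= sizes[d];
    }
    return true;
  }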

Regarding 4: If users don't use ChannelsLast1d, they don't need to do anything. If users want to use ChannelsLast1d, they get the same user experience as with ChannelsLast(2d) and ChannelsLast3d.

Proposal 2:

Although proposal 1 doesn't introduce extra bits into the TensorImpl structure or any overhead in functions such as refresh_contiguous(), its implementation is not as smooth and elegant as those of ChannelsLast(2d) or ChannelsLast3d.
Besides, the overhead in refresh_contiguous() is almost negligible, as explained later. First, let's focus on the proposal 2 implementation.

  1. ChannelsLast1d still aligns with the usage conventions of ChannelsLast(2d) and ChannelsLast3d to provide a consistent user experience;
  2. Only 2 extra bits are added to the TensorImpl structure. (Note: there are 11 bit fields today. Although 2 extra bit fields are added, 11 bit fields and 13 (11+2) bit fields occupy the same number of bytes; in other words, these 2 bit fields don't actually grow the TensorImpl at all.)
  3. The function refresh_contiguous() is updated for ChannelsLast1d;
  4. The feature is still transparent to end users who don't use it.

The details are as follows:

Regarding 1: Users still use it as below:

_3d_tensor_cl = _3d_tensor.to(memory_format=torch.channels_last_1d)
_4d_tensor_cl = _4d_tensor.to(memory_format=torch.channels_last)
_5d_tensor_cl = _5d_tensor.to(memory_format=torch.channels_last_3d)

Regarding 2: only 2 extra bits are added to the TensorImpl structure, as below:

  bool is_channels_last_1d_ : 1; //<----------------
  bool is_channels_last_1d_contiguous_ : 1; //<-----
  bool is_channels_last_ : 1;
  bool is_channels_last_contiguous_ : 1;
  bool is_channels_last_3d_ : 1;
  bool is_channels_last_3d_contiguous_ : 1;

Regarding 3: The key code snippet is as below:
[image: refresh_contiguous() code snippet from the proposal]

Let's carefully look at refresh_contiguous(). The newly added functions compute_channels_last_contiguous_1d() and compute_strides_like_channels_last_1d() for channels last 1d do not introduce additional computation for channels last 2d (4D tensors) or channels last 3d (5D tensors) in refresh_contiguous(). For example, suppose we create a new 5D tensor, which triggers refresh_contiguous(); the call tree is as below:
[image: call tree of refresh_contiguous() for a 5D tensor]

Although compute_channels_last_contiguous_1d() is called in refresh_contiguous() for a 5D tensor, it does not do any additional computation and directly returns false (illustrated by the red arrow above; dim == 5 falls into the default path).

Although compute_strides_like_channels_last_1d() is called in refresh_contiguous() for a 5D tensor, it does not do any additional computation and directly returns false (illustrated by the green arrow above; dim == 5 falls into the default path).
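
To make that call structure concrete, here is a minimal, self-contained sketch of the proposal 2 refresh path. The names mirror the existing helpers, but the bodies are illustrative only and not the actual TensorImpl code; the point is that each rank computes only its own flags, so 4D and 5D tensors pay nothing extra for the new 1d flags:

  #include <cstdint>
  #include <vector>

  // Illustrative stand-in for the TensorImpl bit fields.
  struct MemoryFormatFlags {
    bool is_channels_last_1d_contiguous = false;  // new in proposal 2
    bool is_channels_last_contiguous = false;     // existing 2d (NHWC) flag
    bool is_channels_last_3d_contiguous = false;  // existing 3d (NDHWC) flag
  };

  // Hypothetical helper: true iff the strides match a channels-last layout
  // when the dims are walked in the given memory order (innermost first).
  static bool strides_match(const std::vector<int64_t>& sizes,
                            const std::vector<int64_t>& strides,
                            const std::vector<int>& order) {
    int64_t expected = 1;
    for (int d : order) {
      if (strides[d] != expected) return false;
      expected *= sizes[d];
    }
    return true;
  }

  // Sketch of refresh_contiguous(): the switch on rank means the rank-3 case
  // adds no computation for 4D/5D tensors (the "default path" above).
  void refresh_contiguous_sketch(const std::vector<int64_t>& sizes,
                                 const std::vector<int64_t>& strides,
                                 MemoryFormatFlags& flags) {
    flags = MemoryFormatFlags{};  // clear all channels-last flags first
    switch (sizes.size()) {
      case 3:  // (N, C, L): the only candidate for channels-last-1d
        flags.is_channels_last_1d_contiguous = strides_match(sizes, strides, {1, 2, 0});
        break;
      case 4:  // (N, C, H, W): unchanged 2d path
        flags.is_channels_last_contiguous = strides_match(sizes, strides, {1, 3, 2, 0});
        break;
      case 5:  // (N, C, D, H, W): unchanged 3d path
        flags.is_channels_last_3d_contiguous = strides_match(sizes, strides, {1, 4, 3, 2, 0});
        break;
      default:
        break;  // other ranks: all channels-last flags stay false
    }
  }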

Regarding 4: If users don't use ChannelsLast1d, they still don't need to do anything. If users want to use ChannelsLast1d, they get the same user experience as with ChannelsLast(2d) and ChannelsLast3d.

The 2 proposals are compared in the table below:

           | User friendly | TensorImpl modification | Overhead in refresh_contiguous() | Implementation
Proposal 1 | Yes           | No                      | No                               | ugly
Proposal 2 | Yes           | Yes, but only 2 bits    | No                               | elegant, aligns with ChannelsLast3d

Metrics

As is known, the channels last format performs better than the channels first format for most operators, such as convolution.

  • Conv1d with channels last format on Intel CPU achieves a maximum performance boost of about 1.99x compared with conv1d with channels first format, across different shapes from the wav2vec model.

How we teach this

ChannelsLast1d aligns with the usage conventions of ChannelsLast(2d) and ChannelsLast3d to provide a consistent user experience, as below, so we believe there is no learning cost for users.

_3d_tensor_cl = _3d_tensor.to(memory_format=torch.channels_last_1d)
_4d_tensor_cl = _4d_tensor.to(memory_format=torch.channels_last)
_5d_tensor_cl = _5d_tensor.to(memory_format=torch.channels_last_3d)

Alternatives

See Proposal 2.

Additional context

No response

cc @VitalyFedyunin @jamesr66a

@ngimel added the feature, module: memory format, and triage review labels Mar 30, 2022
@ezyang
Contributor

ezyang commented Apr 4, 2022

I am conceptually OK with ChannelsLast1D. I am very unhappy with the bitfield situation. I agree that absent any refactoring, we should do Proposal 2.

@mruberry added the triaged label and removed the triage review label Apr 4, 2022
@albanD
Collaborator

albanD commented Apr 4, 2022

For reference, the RFC from the repo: pytorch/rfcs#42

@mruberry
Collaborator

mruberry commented Apr 4, 2022

Thank you for the detailed RFC. The Core team isn't pursuing additional channels last memory format support at this time, however, and may not be available to review related PRs. We think it would require adding numerous kernels to provide this kind of support properly.

@jgong5
Collaborator

jgong5 commented Apr 7, 2022

> Thank you for the detailed RFC. The Core team isn't pursuing additional channels last memory format support at this time, however, and may not be available to review related PRs. We think it would require adding numerous kernels to provide this kind of support properly.

Thanks @mruberry for the feedback. Per the suggestions from @VitalyFedyunin today, we will first check and handle the channels last 1D format inside individual performance-critical op kernels (e.g., conv1d and maxpool1d) without explicitly exposing torch.channels_last_1d to users. Later on, once a large portion of ops honor the channels last 1D format, that would be a good motivation to support torch.channels_last_1d explicitly, and we can then revisit this RFC.

@matthijsvk

any updates on this?

@ezyang
Contributor

ezyang commented Mar 28, 2023

1d still not added.

@jgong5
Collaborator

jgong5 commented Mar 29, 2023

> any updates on this?

Are you considering any particular case that needs this support?

@matthijsvk

matthijsvk commented Mar 29, 2023

Recent language/speech/vision models often combine 1D convolutions and attention layers.
Attention usually uses the format [batch, time, channel], so it already uses channels_last.
Convolutions use [batch, channel, ...], which requires a transpose or permute before and after every convolution; this is inefficient.
Many normalization layers also assume channels is the second dimension (e.g., InstanceNorm, GroupNorm, BatchNorm), whereas LayerNorm can work with channels as the last dimension. This is one reason LayerNorm is preferred in attention models.

See for example the EfficientFormer paper (specifically first bullet in bottom half of page 4)

@jgong5
Collaborator

jgong5 commented Mar 29, 2023

> Recent language/speech/vision models often combine 1D convolutions and attention layers.
> Attention usually uses the format [batch, time, channel], so it already uses channels_last.

You may consider doing this implicitly in the implementation of convolution, i.e., check whether the input is channels last 1d by its strides, do not force it contiguous, and keep the channels last 1d layout for the output. It is not necessary to expose it as an explicit memory format.
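
For illustration, a minimal sketch of that implicit handling; the names are hypothetical and this is not the actual implementation of any PyTorch or IPEX kernel:

  #include <cstdint>
  #include <vector>

  // Hypothetical helper: a rank-3 input (N, C, L) is treated as
  // channels-last-1d when its strides are exactly (L*C, 1, C).
  bool is_channels_last_1d_input(const std::vector<int64_t>& sizes,
                                 const std::vector<int64_t>& strides) {
    return sizes.size() == 3 &&
           strides[1] == 1 &&
           strides[2] == sizes[1] &&
           strides[0] == sizes[1] * sizes[2];
  }

  // Sketch of the dispatch decision inside a conv1d-like kernel: keep the
  // incoming layout instead of forcing contiguous when it already matches.
  enum class Conv1dPath { KeepChannelsLast1d, ForceContiguous };

  Conv1dPath choose_conv1d_path(const std::vector<int64_t>& sizes,
                                const std::vector<int64_t>& strides) {
    return is_channels_last_1d_input(sizes, strides)
        ? Conv1dPath::KeepChannelsLast1d  // run the NLC kernel, keep NLC output
        : Conv1dPath::ForceContiguous;    // fall back to the contiguous path
  }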

@matthijsvk

What do you mean by 'in the implementation of convolution'?

My problem is that I don't know how to call the cuDNN conv1d kernel with channels_last from PyTorch's Conv1D, or how to convert the convolution module weights to channels_last so they match the inputs.

Also, I'm not sure which dimension BatchNorm operates on: the documentation says it expects [batch, channels, ...], but if I don't use torch's .to(memory_format=torch.channels_last), it may operate across the wrong dimension, no?

@matthijsvk

So I'm looking for how to do a conv_btc, similar to how facebookresearch/fairseq#172 uses torch.conv_tbc.

@jgong5
Collaborator

jgong5 commented Mar 30, 2023

> What do you mean by 'in the implementation of convolution'?

Like this: https://github.com/intel/intel-extension-for-pytorch/blob/10e458ea56e909af040e2330b62294aff8c3fadd/csrc/cpu/aten/Conv.cpp#L67

Note that the code I referred to actually forces the format to channels last 1d; however, if you ignore that forcing code, the conv implementation assumes the input is channels last 1d (if it is 3d) and does not force contiguous on it.

> Also, I'm not sure which dimension BatchNorm operates on: the documentation says it expects [batch, channels, ...], but if I don't use torch's .to(memory_format=torch.channels_last), it may operate across the wrong dimension, no?

I believe BatchNorm can take a similar implicit approach. The problem with promoting channels last 1d as a frontend API is that the ops that would need to cover it are not that many compared with channels last 2d/3d.
