
RFC: Common memory-layout handling infra #19299

@AdrianLundell

Description


🚀 The feature, motivation and pitch

In PyTorch, a tensor is defined by two things: the layout of its data in memory, and the stride/dim_order and shape metadata determining how to interpret that data. Additionally, convolution and pooling operators are by default assumed to be channels-first. Many ExecuTorch backends differ from this definition in two ways: their IR has no stride concept, and their accelerated kernels only handle channels-last operations.
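As a quick illustration in plain PyTorch, the same logical (N, C, H, W) tensor can live in memory in different orders; only the stride metadata differs:

```python
import torch

x = torch.randn(2, 3, 4, 4)                   # contiguous, channels-first
y = x.to(memory_format=torch.channels_last)   # same values, NHWC in memory

print(x.shape == y.shape)   # True: identical logical shape
print(x.stride())           # (48, 16, 4, 1) -> dim_order (0, 1, 2, 3)
print(y.stride())           # (48, 1, 12, 3) -> dim_order (0, 2, 3, 1)
print(torch.equal(x, y))    # True: operators see the same logical tensor
```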

Currently, the seemingly intended way to work around this issue is via the dim_order concept: export models in non-contiguous format and then lower a (0,2,3,1) dim_order strided (N, C, H, W) tensor to a contiguous (N, H, W, C) tensor in the non-strided IR (see #8037). This strategy has been shown to lead to a number of issues:

1. Backend fragmentation
Letting all backends implement their own dim-order based solution leads to massive duplicated effort and friction when combining backends.

2. Not all graphs can be exported in non-contiguous format
Exporting the graph in channels-last requires using torch.channels_last or torch.channels_last_3d, which are only supported for rank 4 and rank 5 tensors respectively, leaving graphs designed for rank 3 in channels-first. There is also a limitation in the view operator (https://docs.pytorch.org/docs/stable/generated/torch.Tensor.view.html) causing certain views to crash at export time for channels-last input.
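Both restrictions can be reproduced directly in eager PyTorch; this is a sketch of the failure modes, not of any specific export trace:

```python
import torch

# channels_last is defined only for rank-4 tensors
# (channels_last_3d for rank 5), so a rank-3 graph stays channels-first.
try:
    torch.randn(3, 4, 4).to(memory_format=torch.channels_last)
except RuntimeError as e:
    print("rank-3 rejected:", e)

# view() requires stride-compatible shapes; flattening a channels_last
# tensor fails where the same view on contiguous input would succeed.
x = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)
try:
    x.view(2, 48)
except RuntimeError as e:
    print("view rejected:", e)
```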

3. Dim-order is not useable for lowering to a non-strided format without a lot of edge-case handling
The fundamental issue with trying to use dim-order and stride is, as hinted at above, that they simply cannot be represented in a non-strided format. For example, the permute operator in PyTorch changes only the stride and shape without moving any data, which a non-strided format cannot express. Conversely, the to_dim_order operator changes only the data layout without changing the shape, which is equally impossible to express. This leads to subtle differences between the ATen graph and the backend representation, which are likely to cause bugs. Additionally, since the dim_order is computed from strides rather than being a defining property of the tensor, it frequently changes through the graph in unexpected ways:

  • Several operator implementations do not handle dim_order correctly
  • Dim-order is ambiguously defined for certain shapes
  • Dim-order changes for squeezes/unsqueezes are not easily predictable

While this is no problem in PyTorch, since its operators handle any dim-order, for non-strided formats it produces invalid graphs unless caught and carefully adjusted for. In short, lowering to the non-strided format as the shape permuted by the dim-order has proven not to be a feasible solution, instead requiring a per-backend custom solution on top of the already complex dim-order situation. The intention of the suggested handling is to simplify this by instead modifying the ATen graph to be as similar as possible to the lowered IR: if the lowered IR is contiguous, make the ATen graph contiguous when lowering it.
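The metadata-only nature of permute described above can be checked directly in PyTorch:

```python
import torch

x = torch.randn(2, 3, 4, 4)
y = x.permute(0, 2, 3, 1)   # logical NHWC view of the same buffer

print(y.data_ptr() == x.data_ptr())  # True: no data was moved
print(y.is_contiguous())             # False: only the strides changed
# Forcing contiguity is what actually copies the data:
print(y.contiguous().data_ptr() == x.data_ptr())  # False: a real copy
```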

Alternatives

No response

Additional context

No response

RFC (Optional)

The solution would be a generalization of the solution proposed for the Arm backend in #19015. That PR introduces operators which act in channels-last but on a contiguous data format, and replaces channels-first operators with the channels-last operator surrounded by permutes. The permutes are then fused using general optimization passes.
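The rewrite pattern can be sketched numerically in eager PyTorch; conv2d_nhwc here is a hypothetical stand-in for the proposed channels-last contiguous op, not an existing ExecuTorch operator:

```python
import torch
import torch.nn.functional as F

def conv2d_nhwc(x_nhwc, weight, bias=None):
    """Hypothetical channels-last op: NHWC in, NHWC out, contiguous data."""
    x_nchw = x_nhwc.permute(0, 3, 1, 2)  # reference implementation only
    return F.conv2d(x_nchw, weight, bias).permute(0, 2, 3, 1).contiguous()

x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)

reference = F.conv2d(x, w)  # original channels-first op
# Rewritten form: permute -> conv_nhwc -> permute; the inserted permutes
# are what the later fusion passes are expected to eliminate.
rewritten = conv2d_nhwc(x.permute(0, 2, 3, 1).contiguous(), w).permute(0, 3, 1, 2)

print(torch.allclose(reference, rewritten, atol=1e-6))  # True
```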

Extending this solution to make it useable for all backends would require:

  • General channels-last contiguous convolution and pooling ops
  • A "true" permute operator which does not change dim-order and instead shuffles data.
  • A pass for normalizing a graph to being fully contiguous by replacing ops with the operators defined above and removing to_dim_order ops.
  • A pass for normalizing non-contiguous input/output to a to_dim_order op + contiguous permute (see picture below)
  • One or multiple general permute fusing passes
  • A helper-pass for removing input/output permutes in cases where they are not fully fused, if the user prefers to supply already permuted data.
  • Extensive testing ensuring that the general memory-handling does not produce unnecessary permutes.
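The "true" permute from the list above can be sketched in terms of existing PyTorch ops; the name true_permute is hypothetical:

```python
import torch

def true_permute(x, dims):
    """Hypothetical data-moving permute for a non-strided IR: unlike
    torch.permute, it shuffles the data so the result stays contiguous."""
    return x.permute(dims).contiguous()  # force the copy that permute avoids

x = torch.randn(2, 3, 4, 4)
y = true_permute(x, (0, 2, 3, 1))

print(y.is_contiguous())             # True: output is plain contiguous
print(y.data_ptr() != x.data_ptr())  # True: the data was actually moved
```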

Each backend can then use these passes to handle the layout issue in the same way, replacing the general channels-last operator with its own channels-last implementation. Since the infra is implemented as optional passes, this solution also allows any customization each backend needs.

Going one step further, this lowering could be run before the transform_for_annotation step, which would allow transposes to be optimized globally across multiple backends. This would require a more significant effort, since it would mean handling operators such as LSTM, which may be decomposed into multiple linear ops each requiring channels-last handling, but the possibility can at least be considered.

This image details the handling of non-contiguous input in an example graph with a single convolution. The channels-last input is first dim-order permuted (metadata only) to contiguous format, and then true-permuted back to the same shape the input had originally. Even if this at first seemingly means two extra permutes, the data-shuffling permute will be fused with the convolution, and only the metadata op will be left in the graph, which should be essentially free performance-wise.

[Image: example graph showing the input normalization for a single convolution]
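The two-step normalization in the picture can be sketched in plain PyTorch, assuming a channels_last graph input:

```python
import torch

x = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)

# Step 1 (to_dim_order analogue): reinterpret as (N, H, W, C).
# For channels_last input this is metadata-only and already contiguous.
step1 = x.permute(0, 2, 3, 1)

# Step 2 (true permute): shuffle the data back to the original shape.
# In the proposed flow this copy is fused into the convolution.
step2 = step1.permute(0, 3, 1, 2).contiguous()

print(step1.data_ptr() == x.data_ptr())  # True: step 1 is essentially free
print(step2.shape == x.shape)            # True: original shape restored
print(step2.is_contiguous())             # True: now plain channels-first
```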

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell

Metadata

Labels

partner: arm — For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm
