### Issue
Multiple WebNN operators still have limited ranks (convolve, pool, and resample are all fixed to a 4D tensor rank with 2D reduction axes). That restriction was historically adopted for backends that might be more limited, but current backends (CoreML, DML, LiteRT, ORT) generally support more ranks than that, and this WebNN limitation has been problematic (e.g. Whisper uses 1D convolution, so ORT's WebNN EP needed to add extra reshape calls that prepend 1's to the inputs).
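A rough sketch of that workaround (hypothetical shapes and default conv options assumed, not the EP's actual code):

```js
// Emulating a 1D convolution (NCW input, OIW filter) through the 4D-only
// conv2d by inserting a dummy height dimension of 1, in the spirit of what
// ORT's WebNN EP does today. Default strides/padding/dilations are assumed
// so the output width arithmetic stays simple.
function conv1d(builder, input, filter,
                [batch, channels, width],
                [outChannels, inChannels, filterWidth]) {
  const input4d = builder.reshape(input, [batch, channels, 1, width]); // NCW -> NCHW
  const filter4d = builder.reshape(filter, [outChannels, inChannels, 1, filterWidth]); // OIW -> OIHW
  const output4d = builder.conv2d(input4d, filter4d);
  // Drop the dummy height; with default options the width shrinks by filterWidth - 1.
  return builder.reshape(output4d, [batch, outChannels, width - filterWidth + 1]);
}
```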
#### Backend operator rank support
- Convolve (filtered windowed reduction)
- Pool* (windowed reduction)
  - CoreML supports 3D-5D input with corresponding 1-3 reduction axes.
  - DirectML supports 4D-5D input with corresponding 2-3 reduction axes (3D is supported with some leading 1's).
  - ONNX supports 3D-5D input with corresponding 1-3 reduction axes.
  - TFLite/LiteRT supports an indeterminate (TBD?) input rank with 2 reduction axes.
- Resample
### Proposal
Extend these operators to cover the range of ranks the underlying backends support.
#### API Approaches
Surveying the approaches other libraries took:
- (A) Bake the axis count directly into the operator name (TF, PyTorch): `torch.nn.AvgPool1d`, `torch.nn.AvgPool2d`, `torch.nn.AvgPool3d`, `tf.batchNorm2d`, `tf.batchNorm3d`, `tf.batchNorm4d`…
  - ➖ 😢 This pollutes the API surface with 3+ different functions that are basically identical except for their axis count. It's weird too, given that most other operators simply take different ranks/axes without different operator names (there is no `torch.Add2`, `torch.Add3`...).
  - ➕ On the plus side, you can mix and match a different number of axes with input ranks more easily (TF and PyTorch still have some limits, but it's less rigid).
- (B) Use a single operator name, with an implicit axis count based on the input rank (e.g. CoreML and ONNX) 🤔:
  - ➕ Consolidated operator name.
  - ➖ There's a hard-coded association between the input rank and the number of axes, meaning you can do {1D reduction on 3D input, 2D on 4D, 3D on 5D}, but not {1D reduction on 4D input, 3D on 4D, 2D on 5D…}.
- (C) Pass the reduction axis count separately from the input rank:
  - ➕ Consolidated operator name.
  - ➕ More flexible number of axes with input ranks (still limited to the rightmost axes though, requiring a transpose if you want intermediate dimensions).
- (D) Pass the explicit axes. I haven't seen this in any ML API for conv/pool, but it would be consistent with the reduction operators (and resample), would obviate the "layout" enum for pooling, and would enable more primitive compositions without clumsy transpose calls (like the `localResponseNormalization` decomposition, which could have used `averagePool` with axis 1 if axes were a parameter rather than implicitly only the rightmost axes). See the sketch after this list.
  - ➕ Consolidated operator name.
  - ➕ More flexible number of axes with input ranks.
  - ➕ Facilitates rarer operator compositions and nicely obviates layout enums.
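To make the difference concrete, here is a hypothetical sketch of call sites under options (C) and (D); the names and option fields are illustrative, not proposed IDL:

```js
// (C) Axis count inferred from the length of windowDimensions:
// 1D max pooling over the rightmost axis of a 3D input.
const pooledC = builder.maxPool(input3d, {
  windowDimensions: [4],  // one entry => one (rightmost) reduction axis
  strides: [4],
});

// (D) Explicit axes: a windowed average along axis 1 of a 4D input, the shape
// a localResponseNormalization decomposition needs, with no transposes and no
// layout enum.
const pooledD = builder.averagePool(input4d, {
  windowDimensions: [5],
  axes: [1],              // pool along the channel axis directly
  padding: [2, 2],
});
```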
### Implementation details (TBD)
#### Allowed ranks
We can reflect platform rank differences through `MLOpSupportLimits` (though such variance is complicating for callers, and so if a little bit of emulation at the level of a reshape permits uniformity, then I'd push for that). Ideally any axis count 1-3 would be legal in WebNN so long as the axis count <= the input rank, for both windowed reduction (pooling) and filtered windowed reduction (convolution), just like normal reduction, which accepts any count of axes <= the input rank. Then there wouldn't be any implicit hard-coded expectation of exactly two non-spatial dimensions (batch and channel) per input tensor, as with CoreML's and ONNX's convolution, since those dimensions can always be trivially flattened with a reshape before reaching those backends (see the sketch after the table). That is:
| | Input 1D | Input 2D | Input 3D | Input 4D | Input 5D |
|---|---|---|---|---|---|
| Reduction axes 1 | ✅ | ✅ | ✅ | ✅ | ✅ |
| Reduction axes 2 | ✖️ n/a | ✅ | ✅ | ✅ | ✅ |
| Reduction axes 3 | ✖️ n/a | ✖️ n/a | ✅ | ✅ | ✅ |
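As a sketch of why no hard-coded batch/channel pair is needed: an implementation can fold any extra leading dimensions into the batch via reshape before reaching a backend that insists on [batch, channels, ...spatial] (illustrative pseudologic, not any backend's actual lowering; assumes input rank > spatial rank):

```js
// Compute the reshape target that folds all leading dimensions except the
// channel dimension into a single batch dimension. The inverse reshape
// restores the original leading dimensions afterwards.
function foldLeadingIntoBatch(shape, spatialRank) {
  const split = shape.length - spatialRank - 1;  // keep one channel dimension
  const batch = shape.slice(0, split).reduce((a, b) => a * b, 1);
  return [batch, shape[split], ...shape.slice(split + 1)];
}

foldLeadingIntoBatch([2, 3, 8, 64, 64], 2);  // => [6, 8, 64, 64]
```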
TBD: TFLite appears to support 3D convolution, yet not 3D pooling (can average pooling be implemented via convolution?).
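On that TBD: in principle yes; a sketch of the usual decomposition as a depthwise (grouped) convolution whose constant filter holds 1/(window size), assuming NCHW/OIHW layouts and grouped convolution support (note that real average pooling may exclude padding from the divisor, so this matches only the no-padding case):

```js
// averagePool2d emulated via conv2d: each channel is convolved with its own
// constant kH x kW filter of 1/(kH * kW), using groups == channels.
function averagePoolViaConv(builder, input, channels, kH, kW) {
  const weights = new Float32Array(channels * kH * kW).fill(1 / (kH * kW));
  const filter = builder.constant(
    { dataType: 'float32', shape: [channels, 1, kH, kW] },  // one input channel per group
    weights);
  return builder.conv2d(input, filter, { groups: channels });
}
```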
#### Naming
I really want to avoid adding a zoo of new function names (foo1, foo2, foo3, bar1, bar2, bar3) and instead just use the operator name without a suffix (e.g. like `reduceMax`, we'd have `maxPool2d` -> `maxPool` and `MLPool2dOptions` -> `MLPoolOptions`), either using the `windowDimensions` field to know how many axes are being reduced (option C) or passing an explicit list of axes (option D) like resample takes. So we would remove the 2D suffix from these:
- `conv2d` -> `conv`
- `convTranspose2d` -> `convTranspose`
- `averagePool2d` -> `averagePool`
- `l2Pool2d` -> `l2Pool`
- `maxPool2d` -> `maxPool`
- `resample2d` -> `resample`
(conveniently that would obviate the naming issue #821)
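For example (hypothetical, assuming the option C behavior where the filter rank implies the axis count), the Whisper-style 1D convolution from the opening sketch would become a direct call rather than a reshape sandwich:

```js
// A 3D input [batch, channels, width] with a 3D filter
// [outChannels, inChannels, filterWidth] implies one spatial axis directly;
// no dummy dimension or trailing reshape needed.
const output = builder.conv(input3d, filter3d, { strides: [1] });
```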