Update on "[Inductor] [Quant] Enable lowering of quant per tensor and…
Browse files Browse the repository at this point in the history
… refactor quant pattern"


**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445. We can move the optimization of `decomposed quant/dequant` from Inductor decomposition into the lowering phase to avoid depending on those changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization patterns clearer in the pattern-matcher phase, since the `quant/dequant` nodes have not yet been decomposed (see the sketch below).
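
For illustration, here is a minimal sketch of such a non-decomposed pattern, built on the `torch._inductor.pattern_matcher` primitives; the keyword-argument names are illustrative assumptions, and the actual patterns in `torch/_inductor/fx_passes/quantization.py` cover more variants:

```python
# Hedged sketch: match a whole dequantize_per_tensor node, rather than the
# sub/mul chain its decomposition would produce. Argument names are illustrative.
import torch
from torch._inductor.pattern_matcher import CallFunction, KeywordArg
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401  (registers the ops)

quantized_decomposed = torch.ops.quantized_decomposed

dequantize_per_tensor_pattern = CallFunction(
    quantized_decomposed.dequantize_per_tensor.default,
    KeywordArg("x"),
    KeywordArg("x_scale"),
    KeywordArg("x_zp"),
    KeywordArg("x_quant_min"),
    KeywordArg("x_quant_max"),
    KeywordArg("x_dtype"),
)
```

Matching one such node is simpler and more robust than matching the `sub`/`mul` chain that the decomposed form would otherwise leave in the graph.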

**Changes in this PR**

- Move the optimization of `decomposed quant/dequant` from Inductor decomposition into the lowering phase (a sketch of the math follows this list).
- Make the corresponding changes in the quantization pattern matcher to ensure there are no BC-breaking changes.
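
The per-tensor quant/dequant ops that are now expanded at lowering time compute the arithmetic below. This is only an eager-mode sketch of the math with hypothetical helper names; the actual lowering emits Inductor IR rather than calling these functions:

```python
import torch

def quantize_per_tensor_ref(x, scale, zero_point, quant_min, quant_max, dtype):
    # quantize: rescale, round to nearest, shift by zero_point, clamp, cast
    q = torch.round(x * (1.0 / scale)) + zero_point
    return torch.clamp(q, quant_min, quant_max).to(dtype)

def dequantize_per_tensor_ref(q, scale, zero_point, out_dtype=torch.float32):
    # dequantize: shift back by zero_point and rescale
    return (q.to(out_dtype) - zero_point) * scale

# Illustrative round trip
x = torch.randn(4)
x_q = quantize_per_tensor_ref(x, 0.05, 128, 0, 255, torch.uint8)
x_dq = dequantize_per_tensor_ref(x_q, 0.05, 128)
```

Expanding this arithmetic at lowering time means the pattern-matcher passes still see whole `quantize_per_tensor`/`dequantize_per_tensor` nodes, while the generated code performs the same elementwise computation as before.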

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```


cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
leslie-fang-intel committed May 7, 2024
2 parents 8b0ea3a + f146f03 commit 76db2b1
Showing 109 changed files with 1,581 additions and 1,335 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-73b915b55d96553a0e370b2bab01f47b8c2a9e7c
+e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd
11 changes: 11 additions & 0 deletions .github/labeler.yml
@@ -58,6 +58,17 @@
- third_party/mkl-dnn.BUILD
- torch/csrc/jit/codegen/onednn/**
- test/test_jit_llga_fuser.py
- test/test_mkldnn.py

"ciflow/linux-aarch64":
- third_party/ideep
- caffe2/ideep/**
- caffe2/python/ideep/**
- cmake/Modules/FindMKLDNN.cmake
- third_party/mkl-dnn.BUILD
- torch/csrc/jit/codegen/onednn/**
- test/test_jit_llga_fuser.py
- test/test_mkldnn.py

"module: amp (automated mixed precision)":
- torch/amp/**
1 change: 1 addition & 0 deletions .lintrunner.toml
@@ -1622,6 +1622,7 @@ exclude_patterns = [
'torch/functional.py',
'torch/futures/__init__.py',
'torch/fx/__init__.py',
'torch/fx/_compatibility.py',
'torch/fx/_symbolic_trace.py',
'torch/fx/annotate.py',
'torch/fx/config.py',
153 changes: 0 additions & 153 deletions aten/src/ATen/autocast_mode.cpp
@@ -158,159 +158,6 @@ namespace {
Explicit registration for out-of-place ops
*****************************************/

#define AT_FORALL_LOWER_PRECISION_FP(_) \
_(_convolution, deprecated) \
_(_convolution) \
_(conv1d) \
_(conv2d) \
_(conv3d) \
_(conv_tbc) \
_(conv_transpose1d) \
_(conv_transpose2d, input) \
_(conv_transpose3d, input) \
_(convolution) \
_(prelu) \
_(addmm) \
_(addmv) \
_(addr) \
_(matmul) \
_(einsum) \
_(mm) \
_(mv) \
_(linalg_vecdot) \
_(linear) \
_(addbmm) \
_(baddbmm) \
_(bmm) \
_(chain_matmul) \
_(linalg_multi_dot) \
_(_thnn_fused_lstm_cell) \
_(_thnn_fused_gru_cell) \
_(lstm_cell) \
_(gru_cell) \
_(rnn_tanh_cell) \
_(rnn_relu_cell) \
_(_scaled_dot_product_flash_attention) \
_(scaled_dot_product_attention)

#define AT_FORALL_FP32(_) \
_(acos) \
_(asin) \
_(cosh) \
_(erfinv) \
_(exp) \
_(expm1) \
_(log) \
_(log10) \
_(log2) \
_(log1p) \
_(reciprocal) \
_(rsqrt) \
_(sinh) \
_(tan) \
_(pow, Tensor_Scalar) \
_(pow, Tensor_Tensor) \
_(pow, Scalar) \
_(softplus) \
_(layer_norm) \
_(native_layer_norm) \
_(group_norm) \
_(frobenius_norm, dim) \
_(nuclear_norm) \
_(nuclear_norm, dim) \
_(cosine_similarity) \
_(poisson_nll_loss) \
_(cosine_embedding_loss) \
_(nll_loss) \
_(nll_loss2d) \
_(hinge_embedding_loss) \
_(kl_div) \
_(l1_loss) \
_(smooth_l1_loss) \
_(huber_loss) \
_(mse_loss) \
_(margin_ranking_loss) \
_(multilabel_margin_loss) \
_(soft_margin_loss) \
_(triplet_margin_loss) \
_(multi_margin_loss) \
_(binary_cross_entropy_with_logits) \
_(dist) \
_(pdist) \
_(cdist) \
_(renorm) \
_(logsumexp) \
_(upsample_nearest1d) \
_(_upsample_nearest_exact1d) \
_(upsample_nearest2d) \
_(_upsample_nearest_exact2d) \
_(upsample_nearest3d) \
_(_upsample_nearest_exact3d) \
_(upsample_linear1d) \
_(upsample_bilinear2d) \
_(_upsample_bilinear2d_aa) \
_(upsample_trilinear3d) \
_(upsample_bicubic2d) \
_(_upsample_bicubic2d_aa)

#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
_(prod) \
_(prod, dim_int) \
_(prod, dim_Dimname) \
_(softmax, int) \
_(softmax, Dimname) \
_(log_softmax, int) \
_(log_softmax, Dimname) \
_(cumprod) \
_(cumprod, dimname) \
_(cumsum) \
_(cumsum, dimname) \
_(linalg_vector_norm) \
_(linalg_matrix_norm) \
_(linalg_matrix_norm, str_ord) \
_(sum) \
_(sum, dim_IntList) \
_(sum, dim_DimnameList)

#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \
_(ADD_NS(norm), \
"norm.Scalar", \
Tensor(const Tensor&, const Scalar&), \
Tensor(const Tensor&, const c10::optional<Scalar>&, ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, IntArrayRef, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
IntArrayRef, \
bool, \
ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.names_ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, DimnameList, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
DimnameList, \
bool, \
ScalarType), \
fp32_append_dtype)

#define AT_FORALL_PROMOTE(_) \
_(addcdiv) \
_(addcmul) \
_(atan2) \
_(bilinear) \
_(cross) \
_(dot) \
_(grid_sampler) \
_(index_put) \
_(tensordot) \
_(scatter_add)

TORCH_LIBRARY_IMPL(_, Autocast, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
155 changes: 155 additions & 0 deletions aten/src/ATen/autocast_mode.h
@@ -744,3 +744,158 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY)

// Op lists for different policies.
// To make sure other backends can reuse the policy op list.
#define AT_FORALL_LOWER_PRECISION_FP(_) \
_(_convolution, deprecated) \
_(_convolution) \
_(conv1d) \
_(conv2d) \
_(conv3d) \
_(conv_tbc) \
_(conv_transpose1d) \
_(conv_transpose2d, input) \
_(conv_transpose3d, input) \
_(convolution) \
_(prelu) \
_(addmm) \
_(addmv) \
_(addr) \
_(matmul) \
_(einsum) \
_(mm) \
_(mv) \
_(linalg_vecdot) \
_(linear) \
_(addbmm) \
_(baddbmm) \
_(bmm) \
_(chain_matmul) \
_(linalg_multi_dot) \
_(_thnn_fused_lstm_cell) \
_(_thnn_fused_gru_cell) \
_(lstm_cell) \
_(gru_cell) \
_(rnn_tanh_cell) \
_(rnn_relu_cell) \
_(_scaled_dot_product_flash_attention) \
_(scaled_dot_product_attention)

#define AT_FORALL_FP32(_) \
_(acos) \
_(asin) \
_(cosh) \
_(erfinv) \
_(exp) \
_(expm1) \
_(log) \
_(log10) \
_(log2) \
_(log1p) \
_(reciprocal) \
_(rsqrt) \
_(sinh) \
_(tan) \
_(pow, Tensor_Scalar) \
_(pow, Tensor_Tensor) \
_(pow, Scalar) \
_(softplus) \
_(layer_norm) \
_(native_layer_norm) \
_(group_norm) \
_(frobenius_norm, dim) \
_(nuclear_norm) \
_(nuclear_norm, dim) \
_(cosine_similarity) \
_(poisson_nll_loss) \
_(cosine_embedding_loss) \
_(nll_loss) \
_(nll_loss2d) \
_(hinge_embedding_loss) \
_(kl_div) \
_(l1_loss) \
_(smooth_l1_loss) \
_(huber_loss) \
_(mse_loss) \
_(margin_ranking_loss) \
_(multilabel_margin_loss) \
_(soft_margin_loss) \
_(triplet_margin_loss) \
_(multi_margin_loss) \
_(binary_cross_entropy_with_logits) \
_(dist) \
_(pdist) \
_(cdist) \
_(renorm) \
_(logsumexp) \
_(upsample_nearest1d) \
_(_upsample_nearest_exact1d) \
_(upsample_nearest2d) \
_(_upsample_nearest_exact2d) \
_(upsample_nearest3d) \
_(_upsample_nearest_exact3d) \
_(upsample_linear1d) \
_(upsample_bilinear2d) \
_(_upsample_bilinear2d_aa) \
_(upsample_trilinear3d) \
_(upsample_bicubic2d) \
_(_upsample_bicubic2d_aa)

#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
_(prod) \
_(prod, dim_int) \
_(prod, dim_Dimname) \
_(softmax, int) \
_(softmax, Dimname) \
_(log_softmax, int) \
_(log_softmax, Dimname) \
_(cumprod) \
_(cumprod, dimname) \
_(cumsum) \
_(cumsum, dimname) \
_(linalg_vector_norm) \
_(linalg_matrix_norm) \
_(linalg_matrix_norm, str_ord) \
_(sum) \
_(sum, dim_IntList) \
_(sum, dim_DimnameList)

#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \
_(ADD_NS(norm), \
"norm.Scalar", \
Tensor(const Tensor&, const Scalar&), \
Tensor(const Tensor&, const c10::optional<Scalar>&, ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, IntArrayRef, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
IntArrayRef, \
bool, \
ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.names_ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, DimnameList, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
DimnameList, \
bool, \
ScalarType), \
fp32_append_dtype)

#define AT_FORALL_PROMOTE(_) \
_(addcdiv) \
_(addcmul) \
_(atan2) \
_(bilinear) \
_(cross) \
_(dot) \
_(grid_sampler) \
_(index_put) \
_(tensordot) \
_(scatter_add)
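
These `AT_FORALL_*` tables are X-macro op lists that encode the autocast policy for each op (lower precision, fp32, fp32 with optional dtype, different redispatch signature, promotion); exposing them in the header lets other backends generate their own registrations from the same lists. As a small, hedged illustration of what these policies mean in practice (this shows standard autocast behavior, not code from this PR):

```python
import torch

a = torch.randn(8, 8)
b = torch.randn(8, 8)

# Under autocast, ops on the lower-precision list (e.g. mm) run in the autocast
# dtype, while ops on the FP32 lists are executed in float32.
with torch.autocast("cpu", dtype=torch.bfloat16):
    out = torch.mm(a, b)

print(out.dtype)  # torch.bfloat16
```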
2 changes: 1 addition & 1 deletion aten/src/ATen/core/functional.h
@@ -9,7 +9,7 @@ namespace c10 {
// const reference (const T&); taking T by non-const reference
// will result in an error like:
//
-// error: no type named 'type' in 'class std::result_of<foobar::__lambda(T)>'
+// error: no type named 'type' in 'class std::invoke_result<foobar::__lambda, T>'
//
// No explicit template parameters are required.
