From ad9a27f3e5b1d2b69b2b078663d556e51f7a9daf Mon Sep 17 00:00:00 2001
From: PHLens
Date: Mon, 6 May 2024 20:31:15 +0000
Subject: [PATCH] Move autocast op list to autocast_mode.h to make sure other
 backends can reuse it. (#125114)

This PR refactors the op list added in #124051 so that other backends can
reuse it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125114
Approved by: https://github.com/albanD
---
 aten/src/ATen/autocast_mode.cpp | 153 ------------------------------
 aten/src/ATen/autocast_mode.h   | 155 ++++++++++++++++++++++++++++++++
 2 files changed, 155 insertions(+), 153 deletions(-)

diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index 0b99b11430956..c233f17b44580 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -158,159 +158,6 @@ namespace {
 Explicit registration for out-of-place ops
 *****************************************/
 
-#define AT_FORALL_LOWER_PRECISION_FP(_) \
-  _(_convolution, deprecated) \
-  _(_convolution) \
-  _(conv1d) \
-  _(conv2d) \
-  _(conv3d) \
-  _(conv_tbc) \
-  _(conv_transpose1d) \
-  _(conv_transpose2d, input) \
-  _(conv_transpose3d, input) \
-  _(convolution) \
-  _(prelu) \
-  _(addmm) \
-  _(addmv) \
-  _(addr) \
-  _(matmul) \
-  _(einsum) \
-  _(mm) \
-  _(mv) \
-  _(linalg_vecdot) \
-  _(linear) \
-  _(addbmm) \
-  _(baddbmm) \
-  _(bmm) \
-  _(chain_matmul) \
-  _(linalg_multi_dot) \
-  _(_thnn_fused_lstm_cell) \
-  _(_thnn_fused_gru_cell) \
-  _(lstm_cell) \
-  _(gru_cell) \
-  _(rnn_tanh_cell) \
-  _(rnn_relu_cell) \
-  _(_scaled_dot_product_flash_attention) \
-  _(scaled_dot_product_attention)
-
-#define AT_FORALL_FP32(_) \
-  _(acos) \
-  _(asin) \
-  _(cosh) \
-  _(erfinv) \
-  _(exp) \
-  _(expm1) \
-  _(log) \
-  _(log10) \
-  _(log2) \
-  _(log1p) \
-  _(reciprocal) \
-  _(rsqrt) \
-  _(sinh) \
-  _(tan) \
-  _(pow, Tensor_Scalar) \
-  _(pow, Tensor_Tensor) \
-  _(pow, Scalar) \
-  _(softplus) \
-  _(layer_norm) \
-  _(native_layer_norm) \
-  _(group_norm) \
-  _(frobenius_norm, dim) \
-  _(nuclear_norm) \
-  _(nuclear_norm, dim) \
-  _(cosine_similarity) \
-  _(poisson_nll_loss) \
-  _(cosine_embedding_loss) \
-  _(nll_loss) \
-  _(nll_loss2d) \
-  _(hinge_embedding_loss) \
-  _(kl_div) \
-  _(l1_loss) \
-  _(smooth_l1_loss) \
-  _(huber_loss) \
-  _(mse_loss) \
-  _(margin_ranking_loss) \
-  _(multilabel_margin_loss) \
-  _(soft_margin_loss) \
-  _(triplet_margin_loss) \
-  _(multi_margin_loss) \
-  _(binary_cross_entropy_with_logits) \
-  _(dist) \
-  _(pdist) \
-  _(cdist) \
-  _(renorm) \
-  _(logsumexp) \
-  _(upsample_nearest1d) \
-  _(_upsample_nearest_exact1d) \
-  _(upsample_nearest2d) \
-  _(_upsample_nearest_exact2d) \
-  _(upsample_nearest3d) \
-  _(_upsample_nearest_exact3d) \
-  _(upsample_linear1d) \
-  _(upsample_bilinear2d) \
-  _(_upsample_bilinear2d_aa) \
-  _(upsample_trilinear3d) \
-  _(upsample_bicubic2d) \
-  _(_upsample_bicubic2d_aa)
-
-#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
-  _(prod) \
-  _(prod, dim_int) \
-  _(prod, dim_Dimname) \
-  _(softmax, int) \
-  _(softmax, Dimname) \
-  _(log_softmax, int) \
-  _(log_softmax, Dimname) \
-  _(cumprod) \
-  _(cumprod, dimname) \
-  _(cumsum) \
-  _(cumsum, dimname) \
-  _(linalg_vector_norm) \
-  _(linalg_matrix_norm) \
-  _(linalg_matrix_norm, str_ord) \
-  _(sum) \
-  _(sum, dim_IntList) \
-  _(sum, dim_DimnameList)
-
-#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \
-  _(ADD_NS(norm), \
-    "norm.Scalar", \
-    Tensor(const Tensor&, const Scalar&), \
-    Tensor(const Tensor&, const c10::optional<Scalar>&, ScalarType), \
-    fp32_append_dtype) \
-  _(ADD_NS(norm), \
-    "norm.ScalarOpt_dim", \
-    Tensor(const Tensor&, const c10::optional<Scalar>&, IntArrayRef, bool), \
-    Tensor( \
-        const Tensor&, \
-        const c10::optional<Scalar>&, \
-        IntArrayRef, \
-        bool, \
-        ScalarType), \
-    fp32_append_dtype) \
-  _(ADD_NS(norm), \
-    "norm.names_ScalarOpt_dim", \
-    Tensor(const Tensor&, const c10::optional<Scalar>&, DimnameList, bool), \
-    Tensor( \
-        const Tensor&, \
-        const c10::optional<Scalar>&, \
-        DimnameList, \
-        bool, \
-        ScalarType), \
-    fp32_append_dtype)
-
-#define AT_FORALL_PROMOTE(_) \
-  _(addcdiv) \
-  _(addcmul) \
-  _(atan2) \
-  _(bilinear) \
-  _(cross) \
-  _(dot) \
-  _(grid_sampler) \
-  _(index_put) \
-  _(tensordot) \
-  _(scatter_add)
-
 TORCH_LIBRARY_IMPL(_, Autocast, m) {
   m.fallback(torch::CppFunction::makeFallthrough());
 }
diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h
index de072617c0f2c..59a91848a5175 100644
--- a/aten/src/ATen/autocast_mode.h
+++ b/aten/src/ATen/autocast_mode.h
@@ -744,3 +744,158 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
       REGISTER_SIGNATURE, \
       REDISPATCH_SIGNATURE, \
       POLICY)
+
+// Op lists for different policies.
+// To make sure other backends can reuse the policy op list.
+#define AT_FORALL_LOWER_PRECISION_FP(_) \
+  _(_convolution, deprecated) \
+  _(_convolution) \
+  _(conv1d) \
+  _(conv2d) \
+  _(conv3d) \
+  _(conv_tbc) \
+  _(conv_transpose1d) \
+  _(conv_transpose2d, input) \
+  _(conv_transpose3d, input) \
+  _(convolution) \
+  _(prelu) \
+  _(addmm) \
+  _(addmv) \
+  _(addr) \
+  _(matmul) \
+  _(einsum) \
+  _(mm) \
+  _(mv) \
+  _(linalg_vecdot) \
+  _(linear) \
+  _(addbmm) \
+  _(baddbmm) \
+  _(bmm) \
+  _(chain_matmul) \
+  _(linalg_multi_dot) \
+  _(_thnn_fused_lstm_cell) \
+  _(_thnn_fused_gru_cell) \
+  _(lstm_cell) \
+  _(gru_cell) \
+  _(rnn_tanh_cell) \
+  _(rnn_relu_cell) \
+  _(_scaled_dot_product_flash_attention) \
+  _(scaled_dot_product_attention)
+
+#define AT_FORALL_FP32(_) \
+  _(acos) \
+  _(asin) \
+  _(cosh) \
+  _(erfinv) \
+  _(exp) \
+  _(expm1) \
+  _(log) \
+  _(log10) \
+  _(log2) \
+  _(log1p) \
+  _(reciprocal) \
+  _(rsqrt) \
+  _(sinh) \
+  _(tan) \
+  _(pow, Tensor_Scalar) \
+  _(pow, Tensor_Tensor) \
+  _(pow, Scalar) \
+  _(softplus) \
+  _(layer_norm) \
+  _(native_layer_norm) \
+  _(group_norm) \
+  _(frobenius_norm, dim) \
+  _(nuclear_norm) \
+  _(nuclear_norm, dim) \
+  _(cosine_similarity) \
+  _(poisson_nll_loss) \
+  _(cosine_embedding_loss) \
+  _(nll_loss) \
+  _(nll_loss2d) \
+  _(hinge_embedding_loss) \
+  _(kl_div) \
+  _(l1_loss) \
+  _(smooth_l1_loss) \
+  _(huber_loss) \
+  _(mse_loss) \
+  _(margin_ranking_loss) \
+  _(multilabel_margin_loss) \
+  _(soft_margin_loss) \
+  _(triplet_margin_loss) \
+  _(multi_margin_loss) \
+  _(binary_cross_entropy_with_logits) \
+  _(dist) \
+  _(pdist) \
+  _(cdist) \
+  _(renorm) \
+  _(logsumexp) \
+  _(upsample_nearest1d) \
+  _(_upsample_nearest_exact1d) \
+  _(upsample_nearest2d) \
+  _(_upsample_nearest_exact2d) \
+  _(upsample_nearest3d) \
+  _(_upsample_nearest_exact3d) \
+  _(upsample_linear1d) \
+  _(upsample_bilinear2d) \
+  _(_upsample_bilinear2d_aa) \
+  _(upsample_trilinear3d) \
+  _(upsample_bicubic2d) \
+  _(_upsample_bicubic2d_aa)
+
+#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
+  _(prod) \
+  _(prod, dim_int) \
+  _(prod, dim_Dimname) \
+  _(softmax, int) \
+  _(softmax, Dimname) \
+  _(log_softmax, int) \
+  _(log_softmax, Dimname) \
+  _(cumprod) \
+  _(cumprod, dimname) \
+  _(cumsum) \
+  _(cumsum, dimname) \
+  _(linalg_vector_norm) \
+  _(linalg_matrix_norm) \
+  _(linalg_matrix_norm, str_ord) \
+  _(sum) \
+  _(sum, dim_IntList) \
+  _(sum, dim_DimnameList)
+
+#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \
+  _(ADD_NS(norm), \
+    "norm.Scalar", \
+    Tensor(const Tensor&, const Scalar&), \
+    Tensor(const Tensor&, const c10::optional<Scalar>&, ScalarType), \
+    fp32_append_dtype) \
+  _(ADD_NS(norm), \
+    "norm.ScalarOpt_dim", \
+    Tensor(const Tensor&, const c10::optional<Scalar>&, IntArrayRef, bool), \
+    Tensor( \
+        const Tensor&, \
+        const c10::optional<Scalar>&, \
+        IntArrayRef, \
+        bool, \
+        ScalarType), \
+    fp32_append_dtype) \
+  _(ADD_NS(norm), \
+    "norm.names_ScalarOpt_dim", \
+    Tensor(const Tensor&, const c10::optional<Scalar>&, DimnameList, bool), \
+    Tensor( \
+        const Tensor&, \
+        const c10::optional<Scalar>&, \
+        DimnameList, \
+        bool, \
+        ScalarType), \
+    fp32_append_dtype)
+
+#define AT_FORALL_PROMOTE(_) \
+  _(addcdiv) \
+  _(addcmul) \
+  _(atan2) \
+  _(bilinear) \
+  _(cross) \
+  _(dot) \
+  _(grid_sampler) \
+  _(index_put) \
+  _(tensordot) \
+  _(scatter_add)
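
Usage note (illustration only, not part of the patch): each AT_FORALL_* list above is an X-macro that applies a caller-supplied macro `_` to every op, passing the overload name as a second argument where one exists. A backend that includes autocast_mode.h reuses a list by handing it its own registration macro. The sketch below is a minimal, self-contained example under assumed names: DEMO_FORALL_LOWER_PRECISION_FP is a trimmed stand-in for the real list, and PRINT_OP/PICK are hypothetical callbacks that merely print each entry so the snippet compiles without PyTorch.

#include <iostream>

// Hypothetical, trimmed stand-in for AT_FORALL_LOWER_PRECISION_FP.
#define DEMO_FORALL_LOWER_PRECISION_FP(_) \
  _(conv2d)                               \
  _(matmul)                               \
  _(conv_transpose2d, input)

// Entries arrive as `op` or `op, overload`, so dispatch the callback on arity.
#define PRINT_OP_1(op) std::cout << #op << '\n';
#define PRINT_OP_2(op, overload) std::cout << #op "." #overload << '\n';
#define PICK(_1, _2, NAME, ...) NAME
#define PRINT_OP(...) PICK(__VA_ARGS__, PRINT_OP_2, PRINT_OP_1, unused)(__VA_ARGS__)

int main() {
  // Expands to one print statement per op in the list:
  // conv2d, matmul, conv_transpose2d.input
  DEMO_FORALL_LOWER_PRECISION_FP(PRINT_OP)
  return 0;
}

In real registration code the callback is not a print statement: autocast_mode.cpp expands these lists with its KERNEL-style registration macros inside TORCH_LIBRARY_IMPL blocks so that each listed op gets an autocast kernel with the matching cast policy, and with this refactor an out-of-tree backend can do the same against the lists exported by autocast_mode.h.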