diff --git a/torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py b/torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py
index bf30f6705c..a2a9c51c3e 100644
--- a/torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py
+++ b/torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import List, Optional
 
 import torch
@@ -239,22 +245,22 @@
   long col = 0;
   for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) {
     auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col);
-    auto tmp1 = tmp0 * vec_sum_scale;
-    auto tmp2 = tmp1.round();
-    auto tmp3 = tmp2 + vec_beta1;
+    auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
+    auto tmp3 = tmp1.round();
     auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
-    store(tmp_out + col, tmp4);
     auto tmp6 = at::vec::convert(tmp4);
+    auto tmp7 = at::vec::convert(tmp6);
+    tmp7.store(tmp_out + col, vec_size);
     vec_tmp_sum += tmp6;
   }
   if (col < kvBlockSize) {
     auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col);
-    auto tmp1 = tmp0 * vec_sum_scale;
-    auto tmp2 = tmp1.round();
-    auto tmp3 = tmp2 + vec_beta1;
+    auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
+    auto tmp3 = tmp1.round();
     auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
-    store(tmp_out + col, tmp4, kvBlockSize - col);
     auto tmp6 = at::vec::convert(tmp4);
+    auto tmp7 = at::vec::convert(tmp6);
+    tmp7.store(tmp_out + col, kvBlockSize - col);
     vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp6, kvBlockSize - col);
   }
   sum_a_ptr[row] += vec_tmp_sum.reduce_add() * beta2;
@@ -341,17 +347,15 @@
   long col = 0;
   for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) {
     auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col);
-    auto tmp1 = tmp0 * vec_sum_scale;
-    auto tmp2 = tmp1.round();
-    auto tmp3 = tmp2 + vec_beta1;
+    auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
+    auto tmp3 = tmp1.round();
     auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
     store(tmp_out + col, tmp4);
   }
   if (col < kvBlockSize) {
     auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col);
-    auto tmp1 = tmp0 * vec_sum_scale;
-    auto tmp2 = tmp1.round();
-    auto tmp3 = tmp2 + vec_beta1;
+    auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
+    auto tmp3 = tmp1.round();
     auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
     store(tmp_out + col, tmp4, kvBlockSize - col);
   }
@@ -406,9 +410,8 @@
     auto tmp2 = tmp1 - vec_sum_a;
     auto tmp3 = tmp2 + vec_beta1;
     auto tmp4 = at::vec::convert(tmp3);
-    auto tmp5 = tmp4 * vec_alpha;
-    auto tmp6 = tmp5.round();
-    auto tmp7 = tmp6 + vec_beta2;
+    auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2);
+    auto tmp7 = tmp5.round();
     auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val);
     store(tmp_out + col, tmp8);
   }
@@ -419,9 +422,8 @@
     auto tmp2 = tmp1 - vec_sum_a;
     auto tmp3 = tmp2 + vec_beta1;
     auto tmp4 = at::vec::convert(tmp3);
-    auto tmp5 = tmp4 * vec_alpha;
-    auto tmp6 = tmp5.round();
-    auto tmp7 = tmp6 + vec_beta2;
+    auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2);
+    auto tmp7 = tmp5.round();
     auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val);
     store(tmp_out + col, tmp8, N - col);
   }
@@ -463,9 +465,8 @@
     auto tmp3 = tmp1 - vec_sum_a;
     // auto tmp3 = tmp2 + vec_beta1;
     auto tmp4 = at::vec::convert(tmp3);
-    auto tmp5 = tmp4 * vec_alpha;
-    auto tmp6 = tmp5.round();
-    auto tmp7 = tmp6 + vec_beta2;
+    auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2);
+    auto tmp7 = tmp5.round();
     auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val);
     store(tmp_out + col, tmp8);
   }
@@ -473,9 +474,8 @@
     auto tmp1 = at::vec::Vectorized::loadu(tmp_in + col, N - col);
     auto tmp3 = tmp1 - vec_sum_a;
     auto tmp4 = at::vec::convert(tmp3);
-    auto tmp5 = tmp4 * vec_alpha;
-    auto tmp6 = tmp5.round();
-    auto tmp7 = tmp6 + vec_beta2;
+    auto tmp5 = at::vec::fmadd(tmp4, vec_alpha, vec_beta2);
+    auto tmp7 = tmp5.round();
     auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val);
     store(tmp_out + col, tmp8, N - col);
   }
@@ -1384,7 +1384,7 @@
     q_sum_ptr, static_cast(0), qSplitSize);
   {%- endif %}
   const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1;
-  
+
   for (int64_t l = 0; l < rkvSlice; l++) {
     int64_t n = l * kvSplitSize;
     int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
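The recurring change above replaces a separate multiply, round, and add with a single at::vec::fmadd followed by one round; note the bias is now added before rounding rather than after. Below is a minimal standalone sketch of that pattern (not part of the patch), assuming an ATen build with <ATen/cpu/vec/vec.h>; the function and parameter names here are illustrative only, not taken from the torchao template.

    // Sketch: fused scale + bias, then round and clamp, over a float buffer.
    // Mirrors the diff's main-loop / masked-tail structure; names are illustrative.
    #include <ATen/cpu/vec/vec.h>
    #include <cstdint>

    void scale_round_clamp(const float* tmp_in, float* tmp_out, int64_t kvBlockSize,
                           float sum_scale, float beta1, float min_val, float max_val) {
      using Vec = at::vec::Vectorized<float>;
      const int64_t vec_size = Vec::size();
      const Vec vec_sum_scale(sum_scale), vec_beta1(beta1);
      const Vec vec_min_val(min_val), vec_max_val(max_val);
      int64_t col = 0;
      for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) {
        auto tmp0 = Vec::loadu(tmp_in + col);
        // Fused multiply-add: tmp0 * vec_sum_scale + vec_beta1 in one call,
        // replacing the previous multiply, round, add sequence.
        auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
        auto tmp3 = tmp1.round();
        auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
        tmp4.store(tmp_out + col);
      }
      if (col < kvBlockSize) {
        // Remainder branch: partial load/store over the last kvBlockSize - col lanes.
        auto tmp0 = Vec::loadu(tmp_in + col, kvBlockSize - col);
        auto tmp1 = at::vec::fmadd(tmp0, vec_sum_scale, vec_beta1);
        auto tmp3 = tmp1.round();
        auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val);
        tmp4.store(tmp_out + col, kvBlockSize - col);
      }
    }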