Skip to content

Commit

Permalink
And float32
Browse files Browse the repository at this point in the history
  • Loading branch information
malfet committed May 1, 2024
1 parent cf9a938 commit 954400e
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions aten/src/ATen/native/cpu/int8mm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
return vreinterpretq_f32_u32(vshlq_u32(as_int, shift));
}

inline float32x4_t load_as_float32x4(const float* ptr) {
return vld1q_f32(ptr);
}

inline float32x4x2_t load_as_float32x4x2(const float* ptr) {
return {vld1q_f32(ptr), vld1q_f32(ptr + 4)};
}

template <int BLOCK_M, int BLOCK_N, typename T>
inline void tinygemm_kernel_(
Expand Down Expand Up @@ -276,6 +283,18 @@ inline void tinygemm_kernel(
tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
const float* RESTRICT A,
const int8_t* RESTRICT B,
const float* RESTRICT scales,
float* RESTRICT C,
int lda,
int ldb,
int ldc,
int K) {
tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
}
#endif

// non-vectorized version
Expand Down

0 comments on commit 954400e

Please sign in to comment.