From 4f07e08368045c0e3647e971822d44e7e8591c1d Mon Sep 17 00:00:00 2001 From: Emma Lin Date: Mon, 24 Nov 2025 11:24:19 -0800 Subject: [PATCH] fix memory overflow (#5169) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2166 In edge cases, the weight input_stride and output_stride can differ. In that case, out advances by output_stride, but each memcpy copies input_stride elements. If input_stride is larger than output_stride, the memcpy can overflow. This fix copies only min(input_stride, output_stride) elements, which guarantees a safe memcpy and is backward compatible for cases where input_stride <= output_stride. Differential Revision: D87734295 --- src/RefImplementations.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 83e6343af6..37a6df3f40 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -1234,6 +1234,11 @@ bool EmbeddingSpMDM_ref( int64_t current = 0; if (no_bag) { + // in edge cases, input_stride can be larger than output_stride, to make + // sure no memory overflow also be backward compatible, get the min value + // of input_stride and output_stride, and only copy the overlap part of + // data + auto copy_width = std::min(output_stride, input_stride); for (int m = 0; m < output_size; ++m) { int64_t idx = indices[m]; @@ -1242,7 +1247,7 @@ bool EmbeddingSpMDM_ref( } if constexpr (isOutput8bit) { const InType* input_row_ptr = input + input_stride * idx; - memcpy(out, input_row_ptr, sizeof(InType) * input_stride); + memcpy(out, input_row_ptr, sizeof(InType) * copy_width); } else { memset(buf.data(), 0, sizeof(float) * block_size); const float* scale_bias = reinterpret_cast(