diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 83e6343af6..37a6df3f40 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -1234,6 +1234,11 @@ bool EmbeddingSpMDM_ref( int64_t current = 0; if (no_bag) { + // in edge cases, input_stride can be larger than output_stride, to make + // sure no memory overflow also be backward compatible, get the min value + // of input_stride and output_stride, and only copy the overlap part of + // data + auto copy_width = std::min(output_stride, input_stride); for (int m = 0; m < output_size; ++m) { int64_t idx = indices[m]; @@ -1242,7 +1247,7 @@ bool EmbeddingSpMDM_ref( } if constexpr (isOutput8bit) { const InType* input_row_ptr = input + input_stride * idx; - memcpy(out, input_row_ptr, sizeof(InType) * input_stride); + memcpy(out, input_row_ptr, sizeof(InType) * copy_width); } else { memset(buf.data(), 0, sizeof(float) * block_size); const float* scale_bias = reinterpret_cast(