Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/RefImplementations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,11 @@ bool EmbeddingSpMDM_ref(
int64_t current = 0;

if (no_bag) {
// in edge cases, input_stride can be larger than output_stride, to make
// sure no memory overflow also be backward compatible, get the min value
// of input_stride and output_stride, and only copy the overlap part of
// data
auto copy_width = std::min(output_stride, input_stride);
for (int m = 0; m < output_size; ++m) {
int64_t idx = indices[m];

Expand All @@ -1242,7 +1247,7 @@ bool EmbeddingSpMDM_ref(
}
if constexpr (isOutput8bit) {
const InType* input_row_ptr = input + input_stride * idx;
memcpy(out, input_row_ptr, sizeof(InType) * input_stride);
memcpy(out, input_row_ptr, sizeof(InType) * copy_width);
} else {
memset(buf.data(), 0, sizeof(float) * block_size);
const float* scale_bias = reinterpret_cast<const float*>(
Expand Down
Loading