Skip to content

Commit

Permalink
fix crash in sample (#2904)
Browse files Browse the repository at this point in the history
Co-authored-by: Shylock Hg <33566796+Shylock-Hg@users.noreply.github.com>
  • Loading branch information
critical27 and Shylock-Hg committed Sep 22, 2021
1 parent 32ff73f commit 371b011
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
8 changes: 7 additions & 1 deletion src/common/algorithm/ReservoirSampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ class ReservoirSampling final {
return false;
}

std::vector<T>&& samples() && { return std::move(samples_); }
std::vector<T> samples() {
auto result = std::move(samples_);
samples_.clear();
samples_.reserve(num_);
cnt_ = 0;
return result;
}

private:
std::vector<T> samples_;
Expand Down
22 changes: 12 additions & 10 deletions src/common/algorithm/test/ReservoirSamplingTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ TEST(ReservoirSamplingTest, Sample) {
sampler.sampling(std::move(i));
}

auto result = std::move(sampler).samples();
auto result = sampler.samples();
EXPECT_EQ(5, result.size());
for (auto i : result) {
EXPECT_LE(0, i);
Expand All @@ -27,16 +27,18 @@ TEST(ReservoirSamplingTest, Sample) {
}
{
ReservoirSampling<int64_t> sampler(5);
std::vector<int64_t> sampleSpace = {0, 1, 2};
for (auto i : sampleSpace) {
sampler.sampling(std::move(i));
}
for (size_t count = 0; count < 10; count++) {
std::vector<int64_t> sampleSpace = {0, 1, 2};
for (auto i : sampleSpace) {
sampler.sampling(std::move(i));
}

auto result = std::move(sampler).samples();
EXPECT_EQ(3, result.size());
EXPECT_EQ(0, result[0]);
EXPECT_EQ(1, result[1]);
EXPECT_EQ(2, result[2]);
auto result = sampler.samples();
EXPECT_EQ(3, result.size());
EXPECT_EQ(0, result[0]);
EXPECT_EQ(1, result[1]);
EXPECT_EQ(2, result[2]);
}
}
}
} // namespace algorithm
Expand Down
2 changes: 1 addition & 1 deletion src/storage/exec/GetNeighborsNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class GetNeighborsSampleNode : public GetNeighborsNode {
}

RowReaderWrapper reader;
auto samples = std::move(*sampler_).samples();
auto samples = sampler_->samples();
for (auto& sample : samples) {
auto columnIdx = std::get<4>(sample);
// add edge prop value to the target column
Expand Down

0 comments on commit 371b011

Please sign in to comment.