Skip to content

Commit

Permalink
Refactor RNG ownership semantics (#611)
Browse files Browse the repository at this point in the history
- Change `stim::TableauSimulator`'s constructor to require a move-reference for its rng
- Change `stim::FrameSimulator`'s constructor to require a move-reference for its rng
- Change `stim::FrameSimulator`'s RNG from a reference to a normal object owned by the simulator
- Replace SHARED_TEST_RNG method with INDEPENDENT_TEST_RNG method
- Rewrite all tests to use INDEPENDENT_TEST_RNG, fixing potential overlaps in randomness
- Also fix some `pybind11::args args` -> `const pybind11::args &args` warnings
- Also fix some int vs size_t comparison warnings

Fixes #353
  • Loading branch information
Strilanc committed Aug 19, 2023
1 parent 0afe203 commit b56c1d7
Show file tree
Hide file tree
Showing 49 changed files with 526 additions and 432 deletions.
11 changes: 0 additions & 11 deletions glue/javascript/common.js.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,6 @@

using namespace stim;

static bool shared_rng_initialized;
static std::mt19937_64 shared_rng;

std::mt19937_64 &JS_BIND_SHARED_RNG() {
if (!shared_rng_initialized) {
shared_rng = externally_seeded_rng();
shared_rng_initialized = true;
}
return shared_rng;
}

uint32_t js_val_to_uint32_t(const emscripten::val &val) {
double v = val.as<double>();
double f = floor(v);
Expand Down
2 changes: 0 additions & 2 deletions glue/javascript/common.js.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

#include "stim/probability_util.h"

std::mt19937_64 &JS_BIND_SHARED_RNG();

template <typename T>
emscripten::val vec_to_js_array(const std::vector<T> &items) {
emscripten::val result = emscripten::val::array();
Expand Down
3 changes: 2 additions & 1 deletion glue/javascript/pauli_string.js.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ ExposedPauliString::ExposedPauliString(const emscripten::val &arg) : pauli_strin
}

ExposedPauliString ExposedPauliString::random(size_t n) {
return ExposedPauliString(PauliString<stim::MAX_BITWORD_WIDTH>::random(n, JS_BIND_SHARED_RNG()));
auto rng = externally_seeded_rng();
return ExposedPauliString(PauliString<stim::MAX_BITWORD_WIDTH>::random(n, rng));
}

ExposedPauliString ExposedPauliString::times(const ExposedPauliString &other) const {
Expand Down
3 changes: 2 additions & 1 deletion glue/javascript/tableau.js.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ ExposedTableau::ExposedTableau(int n) : tableau(n) {
}

ExposedTableau ExposedTableau::random(int n) {
return ExposedTableau(Tableau<MAX_BITWORD_WIDTH>::random(n, JS_BIND_SHARED_RNG()));
auto rng = externally_seeded_rng();
return ExposedTableau(Tableau<MAX_BITWORD_WIDTH>::random(n, rng));
}

ExposedTableau ExposedTableau::from_named_gate(const std::string &name) {
Expand Down
2 changes: 1 addition & 1 deletion glue/javascript/tableau_simulator.js.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ static JsCircuitInstruction args_to_target_pairs(TableauSimulator<MAX_BITWORD_WI
return result;
}

ExposedTableauSimulator::ExposedTableauSimulator() : sim(JS_BIND_SHARED_RNG(), 0) {
ExposedTableauSimulator::ExposedTableauSimulator() : sim(externally_seeded_rng(), 0) {
}

bool ExposedTableauSimulator::measure(size_t target) {
Expand Down
10 changes: 6 additions & 4 deletions src/stim/circuit/gate_data.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ std::pair<std::vector<PauliString<W>>, std::vector<PauliString<W>>> circuit_outp
if (circuit.count_measurements() > 1) {
throw std::invalid_argument("count_measurements > 1");
}
TableauSimulator<W> sim1(SHARED_TEST_RNG(), circuit.count_qubits(), -1);
TableauSimulator<W> sim2(SHARED_TEST_RNG(), circuit.count_qubits(), +1);
TableauSimulator<W> sim1(INDEPENDENT_TEST_RNG(), circuit.count_qubits(), -1);
TableauSimulator<W> sim2(INDEPENDENT_TEST_RNG(), circuit.count_qubits(), +1);
sim1.expand_do_circuit(circuit);
sim2.expand_do_circuit(circuit);
return {sim1.canonical_stabilizers(), sim2.canonical_stabilizers()};
Expand Down Expand Up @@ -163,14 +163,16 @@ TEST_EACH_WORD_SIZE_W(gate_data, stabilizer_flows_are_correct, {

Circuit c;
c.safe_append(g.id, targets, {});
auto r = check_if_circuit_has_stabilizer_flows(256, SHARED_TEST_RNG(), c, flows);
auto rng = INDEPENDENT_TEST_RNG();
auto r = check_if_circuit_has_stabilizer_flows(256, rng, c, flows);
for (uint32_t fk = 0; fk < (uint32_t)flows.size(); fk++) {
EXPECT_TRUE(r[fk]) << "gate " << g.name << " has an unsatisfied flow: " << flows[fk];
}
}
})

TEST_EACH_WORD_SIZE_W(gate_data, stabilizer_flows_are_also_correct_for_decomposed_circuit, {
auto rng = INDEPENDENT_TEST_RNG();
for (const auto &g : GATE_DATA.items) {
auto flows = g.flows<W>();
if (flows.empty()) {
Expand All @@ -194,7 +196,7 @@ TEST_EACH_WORD_SIZE_W(gate_data, stabilizer_flows_are_also_correct_for_decompose
}

Circuit c(g.extra_data_func().h_s_cx_m_r_decomposition);
auto r = check_if_circuit_has_stabilizer_flows(256, SHARED_TEST_RNG(), c, flows);
auto r = check_if_circuit_has_stabilizer_flows(256, rng, c, flows);
for (uint32_t fk = 0; fk < (uint32_t)flows.size(); fk++) {
EXPECT_TRUE(r[fk]) << "gate " << g.name << " has a decomposition with an unsatisfied flow: " << flows[fk];
}
Expand Down
3 changes: 2 additions & 1 deletion src/stim/circuit/stabilizer_flow.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
using namespace stim;

TEST_EACH_WORD_SIZE_W(stabilizer_flow, check_if_circuit_has_stabilizer_flows, {
auto rng = INDEPENDENT_TEST_RNG();
auto results = check_if_circuit_has_stabilizer_flows<W>(
256,
SHARED_TEST_RNG(),
rng,
Circuit(R"CIRCUIT(
R 4
CX 0 4 1 4 2 4 3 4
Expand Down
3 changes: 2 additions & 1 deletion src/stim/cmd/command_gen.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ TEST_EACH_WORD_SIZE_W(command_gen, no_noise_no_detections, {
}
CircuitGenParameters params(r, d, func.second.first);
auto circuit = func.second.second(params).circuit;
auto [det_samples, obs_samples] = sample_batch_detection_events<W>(circuit, 256, SHARED_TEST_RNG());
auto rng = INDEPENDENT_TEST_RNG();
auto [det_samples, obs_samples] = sample_batch_detection_events<W>(circuit, 256, rng);
EXPECT_FALSE(det_samples.data.not_zero() || obs_samples.data.not_zero())
<< "d=" << d << ", r=" << r << ", task=" << func.second.first << ", func=" << func.first;
}
Expand Down
3 changes: 2 additions & 1 deletion src/stim/cmd/command_sample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ int stim::command_sample(int argc, const char **argv) {

if (num_shots == 1 && !skip_reference_sample) {
TableauSimulator<MAX_BITWORD_WIDTH>::sample_stream(in, out, out_format.id, false, rng);
} else if (num_shots > 0) {
} else {
assert(num_shots > 0);
auto circuit = Circuit::from_file(in);
simd_bits<MAX_BITWORD_WIDTH> ref(0);
if (!skip_reference_sample) {
Expand Down
2 changes: 1 addition & 1 deletion src/stim/dem/detector_error_model.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ void stim_pybind::pybind_detector_error_model_methods(
c.def(
"compile_sampler",
[](const DetectorErrorModel &self, const pybind11::object &seed) -> DemSampler<MAX_BITWORD_WIDTH> {
return DemSampler<MAX_BITWORD_WIDTH>(self, *make_py_seeded_rng(seed), 1024);
return DemSampler<MAX_BITWORD_WIDTH>(self, make_py_seeded_rng(seed), 1024);
},
pybind11::kw_only(),
pybind11::arg("seed") = pybind11::none(),
Expand Down
17 changes: 11 additions & 6 deletions src/stim/io/measure_record_reader.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,8 @@ TEST_EACH_WORD_SIZE_W(MeasureRecordReader, read_records_into_RoundTrip, {
size_t n_shots = 100;
size_t n_results = 512 - 8;

auto shot_maj_data = simd_bit_table<W>::random(n_shots, n_results, SHARED_TEST_RNG());
auto rng = INDEPENDENT_TEST_RNG();
auto shot_maj_data = simd_bit_table<W>::random(n_shots, n_results, rng);
auto shot_min_data = shot_maj_data.transposed();
for (const auto &kv : format_name_to_enum_map()) {
SampleFormat format = kv.second.id;
Expand Down Expand Up @@ -556,13 +557,14 @@ TEST_EACH_WORD_SIZE_W(MeasureRecordReader, read_b8_detection_event_data_full_run
})

TEST_EACH_WORD_SIZE_W(MeasureRecordReader, start_and_read_entire_record, {
auto rng = INDEPENDENT_TEST_RNG();
size_t n = 512 - 8;
size_t no = 5;
size_t nd = n - no;

// Compute expected data.
simd_bits<W> test_data(n);
biased_randomize_bits(0.1, test_data.u64, test_data.u64 + test_data.num_u64_padded(), SHARED_TEST_RNG());
biased_randomize_bits(0.1, test_data.u64, test_data.u64 + test_data.num_u64_padded(), rng);
SparseShot sparse_test_data;
sparse_test_data.obs_mask = simd_bits<64>(no);
for (size_t k = 0; k < nd; k++) {
Expand Down Expand Up @@ -660,9 +662,10 @@ TEST_EACH_WORD_SIZE_W(MeasureRecordReader, start_and_read_entire_record_all_zero
})

TEST_EACH_WORD_SIZE_W(MeasureRecordReader, start_and_read_entire_record_ptb64_dense, {
auto rng = INDEPENDENT_TEST_RNG();
FILE *f = tmpfile();
auto saved1 = simd_bits<W>::random(64 * 71, SHARED_TEST_RNG());
auto saved2 = simd_bits<W>::random(64 * 71, SHARED_TEST_RNG());
auto saved1 = simd_bits<W>::random(64 * 71, rng);
auto saved2 = simd_bits<W>::random(64 * 71, rng);
for (size_t k = 0; k < 64 * 71 / 8; k++) {
putc(saved1.u8[k], f);
}
Expand All @@ -689,12 +692,13 @@ TEST_EACH_WORD_SIZE_W(MeasureRecordReader, start_and_read_entire_record_ptb64_de
})

TEST_EACH_WORD_SIZE_W(MeasureRecordReader, start_and_read_entire_record_ptb64_sparse, {
auto rng = INDEPENDENT_TEST_RNG();
FILE *tmp = tmpfile();
simd_bit_table<W> ground_truth(71, 64 * 5);
{
MeasureRecordBatchWriter writer(tmp, 64 * 5, stim::SAMPLE_FORMAT_PTB64);
for (size_t k = 0; k < 71; k++) {
ground_truth[k].randomize(64 * 5, SHARED_TEST_RNG());
ground_truth[k].randomize(64 * 5, rng);
writer.batch_write_bit<W>(ground_truth[k]);
}
writer.write_end();
Expand All @@ -719,6 +723,7 @@ TEST_EACH_WORD_SIZE_W(MeasureRecordReader, start_and_read_entire_record_ptb64_sp
})

TEST_EACH_WORD_SIZE_W(MeasureRecordReader, read_file_data_into_shot_table_vs_write_table, {
auto rng = INDEPENDENT_TEST_RNG();
for (const auto &format_data : format_name_to_enum_map()) {
SampleFormat format = format_data.second.id;
size_t num_shots = 500;
Expand All @@ -729,7 +734,7 @@ TEST_EACH_WORD_SIZE_W(MeasureRecordReader, read_file_data_into_shot_table_vs_wri

simd_bit_table<W> expected(num_shots, bits_per_shot);
for (size_t shot = 0; shot < num_shots; shot++) {
expected[shot].randomize(bits_per_shot, SHARED_TEST_RNG());
expected[shot].randomize(bits_per_shot, rng);
}
simd_bit_table<W> expected_transposed = expected.transposed();

Expand Down
3 changes: 3 additions & 0 deletions src/stim/mem/simd_bit_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ struct simd_bit_table {
/// Returns a subset of the table.
simd_bit_table slice_maj(size_t maj_start_bit, size_t maj_stop_bit) const;

/// Returns a copy of a column of the table.
simd_bits<W> read_across_majors_at_minor_index(size_t major_start, size_t major_stop, size_t minor_index) const;

/// Concatenates the contents of the two tables, along the major axis.
simd_bit_table<W> concat_major(const simd_bit_table<W> &second, size_t n_first, size_t n_second) const;
/// Overwrites a range of the table with a range from another table with the same minor size.
Expand Down
12 changes: 7 additions & 5 deletions src/stim/mem/simd_bit_table.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,18 +200,19 @@ TEST_EACH_WORD_SIZE_W(simd_bit_table, transposed, {
})

TEST_EACH_WORD_SIZE_W(simd_bit_table, random, {
auto t = simd_bit_table<W>::random(100, 90, SHARED_TEST_RNG());
auto rng = INDEPENDENT_TEST_RNG();
auto t = simd_bit_table<W>::random(100, 90, rng);
ASSERT_NE(t[99], simd_bits<W>(90));
ASSERT_EQ(t[100], simd_bits<W>(90));
t = t.transposed();
ASSERT_NE(t[89], simd_bits<W>(100));
ASSERT_EQ(t[90], simd_bits<W>(100));
ASSERT_NE(
simd_bit_table<W>::random(10, 10, SHARED_TEST_RNG()), simd_bit_table<W>::random(10, 10, SHARED_TEST_RNG()));
ASSERT_NE(simd_bit_table<W>::random(10, 10, rng), simd_bit_table<W>::random(10, 10, rng));
})

TEST_EACH_WORD_SIZE_W(simd_bit_table, slice_maj, {
auto m = simd_bit_table<W>::random(100, 64, SHARED_TEST_RNG());
auto rng = INDEPENDENT_TEST_RNG();
auto m = simd_bit_table<W>::random(100, 64, rng);
auto s = m.slice_maj(5, 15);
ASSERT_EQ(s[0], m[5]);
ASSERT_EQ(s[9], m[14]);
Expand Down Expand Up @@ -291,7 +292,8 @@ TEST(simd_bit_table, lg) {
}

TEST_EACH_WORD_SIZE_W(simd_bit_table, destructive_resize, {
simd_bit_table<W> table = table.random(5, 7, SHARED_TEST_RNG());
auto rng = INDEPENDENT_TEST_RNG();
simd_bit_table<W> table = table.random(5, 7, rng);
const uint8_t *prev_pointer = table.data.u8;
table.destructive_resize(5, 7);
ASSERT_EQ(table.data.u8, prev_pointer);
Expand Down
32 changes: 19 additions & 13 deletions src/stim/mem/simd_bits.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ TEST_EACH_WORD_SIZE_W(simd_bits, str, {
TEST_EACH_WORD_SIZE_W(simd_bits, randomize, {
simd_bits<W> d(1024);

d.randomize(64 + 57, SHARED_TEST_RNG());
auto rng = INDEPENDENT_TEST_RNG();
d.randomize(64 + 57, rng);
uint64_t mask = (1ULL << 57) - 1;
// Randomized.
ASSERT_NE(d.u64[0], 0);
Expand All @@ -138,7 +139,7 @@ TEST_EACH_WORD_SIZE_W(simd_bits, randomize, {
for (size_t k = 0; k < d.num_u64_padded(); k++) {
d.u64[k] = UINT64_MAX;
}
d.randomize(64 + 57, SHARED_TEST_RNG());
d.randomize(64 + 57, rng);
// Randomized.
ASSERT_NE(d.u64[0], 0);
ASSERT_NE(d.u64[0], SIZE_MAX);
Expand All @@ -151,8 +152,9 @@ TEST_EACH_WORD_SIZE_W(simd_bits, randomize, {
})

TEST_EACH_WORD_SIZE_W(simd_bits, xor_assignment, {
simd_bits<W> m0 = simd_bits<W>::random(512, SHARED_TEST_RNG());
simd_bits<W> m1 = simd_bits<W>::random(512, SHARED_TEST_RNG());
auto rng = INDEPENDENT_TEST_RNG();
simd_bits<W> m0 = simd_bits<W>::random(512, rng);
simd_bits<W> m1 = simd_bits<W>::random(512, rng);
simd_bits<W> m2(512);
m2 ^= m0;
ASSERT_EQ(m0, m2);
Expand Down Expand Up @@ -284,7 +286,7 @@ TEST_EACH_WORD_SIZE_W(simd_bits, right_shift_assignment, {
})

TEST_EACH_WORD_SIZE_W(simd_bits, fuzz_right_shift_assignment, {
auto rng = SHARED_TEST_RNG();
auto rng = INDEPENDENT_TEST_RNG();
for (int i = 0; i < 5; i++) {
std::uniform_int_distribution dist_bits(1, 1200);
int num_bits = dist_bits(rng);
Expand Down Expand Up @@ -334,7 +336,7 @@ TEST_EACH_WORD_SIZE_W(simd_bits, left_shift_assignment, {
})

TEST_EACH_WORD_SIZE_W(simd_bits, fuzz_left_shift_assignment, {
auto rng = SHARED_TEST_RNG();
auto rng = INDEPENDENT_TEST_RNG();
for (int i = 0; i < 5; i++) {
std::uniform_int_distribution dist_bits(1, 1200);
int num_bits = dist_bits(rng);
Expand All @@ -356,8 +358,9 @@ TEST_EACH_WORD_SIZE_W(simd_bits, fuzz_left_shift_assignment, {
TEST_EACH_WORD_SIZE_W(simd_bits, assignment, {
simd_bits<W> m0(512);
simd_bits<W> m1(512);
m0.randomize(512, SHARED_TEST_RNG());
m1.randomize(512, SHARED_TEST_RNG());
auto rng = INDEPENDENT_TEST_RNG();
m0.randomize(512, rng);
m1.randomize(512, rng);
auto old_m1 = m1.u64[0];
ASSERT_NE(m0, m1);
m0 = m1;
Expand Down Expand Up @@ -389,8 +392,9 @@ TEST_EACH_WORD_SIZE_W(simd_bits, swap_with, {
simd_bits<W> m1(512);
simd_bits<W> m2(512);
simd_bits<W> m3(512);
m0.randomize(512, SHARED_TEST_RNG());
m1.randomize(512, SHARED_TEST_RNG());
auto rng = INDEPENDENT_TEST_RNG();
m0.randomize(512, rng);
m1.randomize(512, rng);
m2 = m0;
m3 = m1;
ASSERT_EQ(m0, m2);
Expand All @@ -402,7 +406,8 @@ TEST_EACH_WORD_SIZE_W(simd_bits, swap_with, {

TEST_EACH_WORD_SIZE_W(simd_bits, clear, {
simd_bits<W> m0(512);
m0.randomize(512, SHARED_TEST_RNG());
auto rng = INDEPENDENT_TEST_RNG();
m0.randomize(512, rng);
ASSERT_TRUE(m0.not_zero());
m0.clear();
ASSERT_TRUE(!m0.not_zero());
Expand Down Expand Up @@ -471,8 +476,9 @@ TEST_EACH_WORD_SIZE_W(simd_bits, mask_assignment_or, {
})

TEST_EACH_WORD_SIZE_W(simd_bits, truncated_overwrite_from, {
simd_bits<W> dat = simd_bits<W>::random(1024, SHARED_TEST_RNG());
simd_bits<W> mut = simd_bits<W>::random(1024, SHARED_TEST_RNG());
auto rng = INDEPENDENT_TEST_RNG();
simd_bits<W> dat = simd_bits<W>::random(1024, rng);
simd_bits<W> mut = simd_bits<W>::random(1024, rng);
simd_bits<W> old = mut;

mut.truncated_overwrite_from(dat, 455);
Expand Down
4 changes: 2 additions & 2 deletions src/stim/mem/simd_bits_range_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ simd_bits_range_ref<W> simd_bits_range_ref<W>::operator<<=(int offset) {
}
while (offset >= 64) {
incoming_word = 0ULL;
for (int w = 0; w < num_u64_padded(); w++) {
for (size_t w = 0; w < num_u64_padded(); w++) {
cur_word = u64[w];
u64[w] = incoming_word;
incoming_word = cur_word;
Expand All @@ -114,7 +114,7 @@ simd_bits_range_ref<W> simd_bits_range_ref<W>::operator<<=(int offset) {
return *this;
}
incoming_word = 0ULL;
for (int w = 0; w < num_u64_padded(); w++) {
for (size_t w = 0; w < num_u64_padded(); w++) {
cur_word = u64[w];
u64[w] <<= offset;
u64[w] |= incoming_word;
Expand Down
Loading

0 comments on commit b56c1d7

Please sign in to comment.