From fa834fde181803236740526cd4cb605dc8f6a599 Mon Sep 17 00:00:00 2001 From: 0x177 Date: Thu, 28 Aug 2025 13:10:54 +0300 Subject: [PATCH 1/5] make final_density_matrix work with classical control --- cirq-core/cirq/sim/mux.py | 34 +++++++++++++++++++++++++++++++--- cirq-core/cirq/sim/mux_test.py | 22 ++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/sim/mux.py b/cirq-core/cirq/sim/mux.py index e811623f74a..bc809885532 100644 --- a/cirq-core/cirq/sim/mux.py +++ b/cirq-core/cirq/sim/mux.py @@ -28,6 +28,8 @@ from cirq.sim import density_matrix_simulator, sparse_simulator from cirq.sim.clifford import clifford_simulator from cirq.transformers import measurement_transformers +from cirq.linalg.transformations import partial_trace +from cirq.protocols.qid_shape_protocol import num_qubits, qid_shape if TYPE_CHECKING: import cirq @@ -291,11 +293,17 @@ def final_density_matrix( return sparse_result.density_matrix_of() else: # noisy case: use DensityMatrixSimulator with dephasing + + if ignore_measurement_results: + noise_applied = circuit_like.with_noise(noise) + defered = measurement_transformers.defer_measurements(noise_applied) + dephased = measurement_transformers.dephase_measurements(defered) + density_result = density_matrix_simulator.DensityMatrixSimulator( - dtype=dtype, noise=noise, seed=seed + dtype=dtype, noise=None if ignore_measurement_results else noise, seed=seed ).simulate( program=( - measurement_transformers.dephase_measurements(circuit_like) + dephased if ignore_measurement_results else circuit_like ), @@ -303,4 +311,24 @@ def final_density_matrix( qubit_order=qubit_order, param_resolver=param_resolver, ) - return density_result.final_density_matrix + + res = density_result.final_density_matrix + + if ignore_measurement_results: + nq = num_qubits(circuit_like) + qids = qid_shape(circuit_like) + + #assuming that the ancella bits from the transformations are at the end + keep = list(range(nq)) + + dephased_qids = qid_shape(dephased) + tensor_form = np.reshape(res, dephased_qids + dephased_qids) + + reduced_form = partial_trace(tensor_form, keep) + + width = np.prod(qids) + r = np.reshape(reduced_form, (width,width)) + + return r + + return res diff --git a/cirq-core/cirq/sim/mux_test.py b/cirq-core/cirq/sim/mux_test.py index 21692174b1f..419282cb1c5 100644 --- a/cirq-core/cirq/sim/mux_test.py +++ b/cirq-core/cirq/sim/mux_test.py @@ -410,3 +410,25 @@ def test_ps_initial_state_dmat(): cirq.final_density_matrix(cirq.Circuit(cirq.H(q0), cirq.I(q1))), cirq.final_density_matrix(cirq.Circuit(cirq.I.on_each(q0, q1)), initial_state=sp0), ) + +def test_dm_classical_control(): + q0, q1 = cirq.LineQubit.range(2) + + circuit = cirq.Circuit( + cirq.H(q0), + cirq.measure(q0, key='a'), + cirq.H(q1).with_classical_controls('a'), + cirq.measure(q1, key='b'), + ) + + dm = cirq.final_density_matrix(circuit) + + assert dm.shape == (4,4) + + expected = np.zeros((4,4), dtype=np.complex64) + + expected[0,0] = 0.5 + 0.0j + expected[2,2] = 0.25 + 0.0j + expected[3,3] = 0.25 + 0.0j + + np.testing.assert_allclose(dm,expected) From e74fee0440ae4d50384ef8e16f894e0bdf2d787f Mon Sep 17 00:00:00 2001 From: 0x177 Date: Fri, 29 Aug 2025 17:03:51 +0300 Subject: [PATCH 2/5] fix formatting, handle terminal measurements more efficiently --- cirq-core/cirq/sim/mux.py | 48 ++++++++++++++++++++-------------- cirq-core/cirq/sim/mux_test.py | 17 ++++++------ 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/cirq-core/cirq/sim/mux.py b/cirq-core/cirq/sim/mux.py index bc809885532..3ece200585d 100644 --- a/cirq-core/cirq/sim/mux.py +++ b/cirq-core/cirq/sim/mux.py @@ -25,11 +25,11 @@ from cirq import circuits, devices, ops, protocols, study, value from cirq._doc import document +from cirq.linalg import transformations +from cirq.protocols import qid_shape_protocol from cirq.sim import density_matrix_simulator, sparse_simulator from cirq.sim.clifford import clifford_simulator from cirq.transformers import measurement_transformers -from cirq.linalg.transformations import partial_trace -from cirq.protocols.qid_shape_protocol import num_qubits, qid_shape if TYPE_CHECKING: import cirq @@ -293,20 +293,26 @@ def final_density_matrix( return sparse_result.density_matrix_of() else: # noisy case: use DensityMatrixSimulator with dephasing + has_classical_control = circuit_like != measurement_transformers.defer_measurements(circuit_like) + handling_classical_control = ignore_measurement_results and has_classical_control - if ignore_measurement_results: - noise_applied = circuit_like.with_noise(noise) + if handling_classical_control: + # case 1: classical control + noise_applied = circuit_like.with_noise(noise) if noise is not None else circuit_like defered = measurement_transformers.defer_measurements(noise_applied) dephased = measurement_transformers.dephase_measurements(defered) + program = dephased + elif ignore_measurement_results: + # case 2: no classical control, only terminal measurement + program = measurement_transformers.dephase_measurements(circuit_like) + else: + # case 3: no measurement + program = circuit_like density_result = density_matrix_simulator.DensityMatrixSimulator( - dtype=dtype, noise=None if ignore_measurement_results else noise, seed=seed + dtype=dtype, noise=None if handling_classical_control else noise, seed=seed ).simulate( - program=( - dephased - if ignore_measurement_results - else circuit_like - ), + program, initial_state=initial_state, qubit_order=qubit_order, param_resolver=param_resolver, @@ -314,21 +320,23 @@ def final_density_matrix( res = density_result.final_density_matrix - if ignore_measurement_results: - nq = num_qubits(circuit_like) - qids = qid_shape(circuit_like) + if handling_classical_control: + num_qubits = qid_shape_protocol.num_qubits(circuit_like) + qid_shape = qid_shape_protocol.qid_shape(circuit_like) + + # assuming that the ancella bits from the transformations are at the end + keep = list(range(num_qubits)) + + dephased_qid_shape = qid_shape_protocol.qid_shape(dephased) - #assuming that the ancella bits from the transformations are at the end - keep = list(range(nq)) + tensor_form = np.reshape(res, dephased_qid_shape + dephased_qid_shape) - dephased_qids = qid_shape(dephased) - tensor_form = np.reshape(res, dephased_qids + dephased_qids) + reduced_form = transformations.partial_trace(tensor_form, keep) - reduced_form = partial_trace(tensor_form, keep) + width = np.prod(qid_shape) - width = np.prod(qids) r = np.reshape(reduced_form, (width,width)) - + return r return res diff --git a/cirq-core/cirq/sim/mux_test.py b/cirq-core/cirq/sim/mux_test.py index 419282cb1c5..360633d4b61 100644 --- a/cirq-core/cirq/sim/mux_test.py +++ b/cirq-core/cirq/sim/mux_test.py @@ -413,22 +413,21 @@ def test_ps_initial_state_dmat(): def test_dm_classical_control(): q0, q1 = cirq.LineQubit.range(2) - + circuit = cirq.Circuit( cirq.H(q0), cirq.measure(q0, key='a'), cirq.H(q1).with_classical_controls('a'), cirq.measure(q1, key='b'), ) - - dm = cirq.final_density_matrix(circuit) - assert dm.shape == (4,4) + dm = cirq.final_density_matrix(circuit) + assert dm.shape == (4, 4) - expected = np.zeros((4,4), dtype=np.complex64) + expected = np.zeros((4, 4), dtype=np.complex64) - expected[0,0] = 0.5 + 0.0j - expected[2,2] = 0.25 + 0.0j - expected[3,3] = 0.25 + 0.0j + expected[0, 0] = 0.5 + 0.0j + expected[2, 2] = 0.25 + 0.0j + expected[3, 3] = 0.25 + 0.0j - np.testing.assert_allclose(dm,expected) + np.testing.assert_allclose(dm, expected) From e12cab71b94549a6229579cc1b0b2999e549e07b Mon Sep 17 00:00:00 2001 From: 0x177 Date: Fri, 29 Aug 2025 19:35:23 +0300 Subject: [PATCH 3/5] nits, lint, formatting --- cirq-core/cirq/sim/mux.py | 18 +++++++------- cirq-core/cirq/sim/mux_test.py | 43 +++++++++++++++++----------------- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/cirq-core/cirq/sim/mux.py b/cirq-core/cirq/sim/mux.py index 3ece200585d..f6f23a176a6 100644 --- a/cirq-core/cirq/sim/mux.py +++ b/cirq-core/cirq/sim/mux.py @@ -26,7 +26,6 @@ from cirq import circuits, devices, ops, protocols, study, value from cirq._doc import document from cirq.linalg import transformations -from cirq.protocols import qid_shape_protocol from cirq.sim import density_matrix_simulator, sparse_simulator from cirq.sim.clifford import clifford_simulator from cirq.transformers import measurement_transformers @@ -293,7 +292,9 @@ def final_density_matrix( return sparse_result.density_matrix_of() else: # noisy case: use DensityMatrixSimulator with dephasing - has_classical_control = circuit_like != measurement_transformers.defer_measurements(circuit_like) + has_classical_control = circuit_like != measurement_transformers.defer_measurements( + circuit_like + ) handling_classical_control = ignore_measurement_results and has_classical_control if handling_classical_control: @@ -317,17 +318,16 @@ def final_density_matrix( qubit_order=qubit_order, param_resolver=param_resolver, ) - res = density_result.final_density_matrix if handling_classical_control: - num_qubits = qid_shape_protocol.num_qubits(circuit_like) - qid_shape = qid_shape_protocol.qid_shape(circuit_like) + num_qubits = protocols.qid_shape_protocol.num_qubits(circuit_like) + qid_shape = protocols.qid_shape_protocol.qid_shape(circuit_like) # assuming that the ancella bits from the transformations are at the end keep = list(range(num_qubits)) - dephased_qid_shape = qid_shape_protocol.qid_shape(dephased) + dephased_qid_shape = protocols.qid_shape_protocol.qid_shape(dephased) tensor_form = np.reshape(res, dephased_qid_shape + dephased_qid_shape) @@ -335,8 +335,6 @@ def final_density_matrix( width = np.prod(qid_shape) - r = np.reshape(reduced_form, (width,width)) - - return r - + res = np.reshape(reduced_form, (width, width)) + return res diff --git a/cirq-core/cirq/sim/mux_test.py b/cirq-core/cirq/sim/mux_test.py index 360633d4b61..d79945c333c 100644 --- a/cirq-core/cirq/sim/mux_test.py +++ b/cirq-core/cirq/sim/mux_test.py @@ -380,6 +380,28 @@ def test_final_density_matrix_noise(): ) +def test_final_density_matrix_classical_control(): + q0, q1 = cirq.LineQubit.range(2) + + circuit = cirq.Circuit( + cirq.H(q0), + cirq.measure(q0, key='a'), + cirq.H(q1).with_classical_controls('a'), + cirq.measure(q1, key='b'), + ) + + dm = cirq.final_density_matrix(circuit) + assert dm.shape == (4, 4) + + expected = np.zeros((4, 4), dtype=np.complex64) + + expected[0, 0] = 0.5 + 0.0j + expected[2, 2] = 0.25 + 0.0j + expected[3, 3] = 0.25 + 0.0j + + np.testing.assert_allclose(dm, expected) + + def test_ps_initial_state_wfn(): q0, q1 = cirq.LineQubit.range(2) s00 = cirq.KET_ZERO(q0) * cirq.KET_ZERO(q1) @@ -410,24 +432,3 @@ def test_ps_initial_state_dmat(): cirq.final_density_matrix(cirq.Circuit(cirq.H(q0), cirq.I(q1))), cirq.final_density_matrix(cirq.Circuit(cirq.I.on_each(q0, q1)), initial_state=sp0), ) - -def test_dm_classical_control(): - q0, q1 = cirq.LineQubit.range(2) - - circuit = cirq.Circuit( - cirq.H(q0), - cirq.measure(q0, key='a'), - cirq.H(q1).with_classical_controls('a'), - cirq.measure(q1, key='b'), - ) - - dm = cirq.final_density_matrix(circuit) - assert dm.shape == (4, 4) - - expected = np.zeros((4, 4), dtype=np.complex64) - - expected[0, 0] = 0.5 + 0.0j - expected[2, 2] = 0.25 + 0.0j - expected[3, 3] = 0.25 + 0.0j - - np.testing.assert_allclose(dm, expected) From 2828ba929b95be0ce97de8909166f8563da6c934 Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Tue, 2 Sep 2025 11:08:53 -0700 Subject: [PATCH 4/5] Few syntax tweaks, no change in function - access protocol functions from the `protocols` module - eliminate single-use variables - remove unnecessary blank lines --- cirq-core/cirq/sim/mux.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/cirq-core/cirq/sim/mux.py b/cirq-core/cirq/sim/mux.py index f6f23a176a6..513ab282690 100644 --- a/cirq-core/cirq/sim/mux.py +++ b/cirq-core/cirq/sim/mux.py @@ -318,23 +318,15 @@ def final_density_matrix( qubit_order=qubit_order, param_resolver=param_resolver, ) - res = density_result.final_density_matrix + result = density_result.final_density_matrix if handling_classical_control: - num_qubits = protocols.qid_shape_protocol.num_qubits(circuit_like) - qid_shape = protocols.qid_shape_protocol.qid_shape(circuit_like) - - # assuming that the ancella bits from the transformations are at the end - keep = list(range(num_qubits)) - - dephased_qid_shape = protocols.qid_shape_protocol.qid_shape(dephased) - - tensor_form = np.reshape(res, dephased_qid_shape + dephased_qid_shape) - + # assuming that the ancilla qubits from the transformations are at the end + keep = list(range(protocols.num_qubits(circuit_like))) + dephased_qid_shape = protocols.qid_shape(dephased) + tensor_form = np.reshape(result, dephased_qid_shape + dephased_qid_shape) reduced_form = transformations.partial_trace(tensor_form, keep) + width = np.prod(protocols.qid_shape(circuit_like)) + result = np.reshape(reduced_form, (width, width)) - width = np.prod(qid_shape) - - res = np.reshape(reduced_form, (width, width)) - - return res + return result From 2940ee8ab6913ff121aa42d2d917de339bb6bf32 Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Tue, 2 Sep 2025 11:13:59 -0700 Subject: [PATCH 5/5] Adopt style of other final_density_matrix tests in the test module - use np.diag to construct the expected density matrix - remove test of matrix shape; it is already checked in assert_allclose --- cirq-core/cirq/sim/mux_test.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/cirq-core/cirq/sim/mux_test.py b/cirq-core/cirq/sim/mux_test.py index d79945c333c..fd5c3781d7f 100644 --- a/cirq-core/cirq/sim/mux_test.py +++ b/cirq-core/cirq/sim/mux_test.py @@ -382,24 +382,16 @@ def test_final_density_matrix_noise(): def test_final_density_matrix_classical_control(): q0, q1 = cirq.LineQubit.range(2) - circuit = cirq.Circuit( cirq.H(q0), cirq.measure(q0, key='a'), cirq.H(q1).with_classical_controls('a'), cirq.measure(q1, key='b'), ) - - dm = cirq.final_density_matrix(circuit) - assert dm.shape == (4, 4) - - expected = np.zeros((4, 4), dtype=np.complex64) - - expected[0, 0] = 0.5 + 0.0j - expected[2, 2] = 0.25 + 0.0j - expected[3, 3] = 0.25 + 0.0j - - np.testing.assert_allclose(dm, expected) + np.testing.assert_allclose( + cirq.final_density_matrix(circuit), + np.diag(np.array([0.5, 0.0, 0.25, 0.25], dtype=np.complex64)), + ) def test_ps_initial_state_wfn():