diff --git a/cirq-core/cirq/sim/mux.py b/cirq-core/cirq/sim/mux.py index e811623f74a..513ab282690 100644 --- a/cirq-core/cirq/sim/mux.py +++ b/cirq-core/cirq/sim/mux.py @@ -25,6 +25,7 @@ from cirq import circuits, devices, ops, protocols, study, value from cirq._doc import document +from cirq.linalg import transformations from cirq.sim import density_matrix_simulator, sparse_simulator from cirq.sim.clifford import clifford_simulator from cirq.transformers import measurement_transformers @@ -291,16 +292,41 @@ def final_density_matrix( return sparse_result.density_matrix_of() else: # noisy case: use DensityMatrixSimulator with dephasing + has_classical_control = circuit_like != measurement_transformers.defer_measurements( + circuit_like + ) + handling_classical_control = ignore_measurement_results and has_classical_control + + if handling_classical_control: + # case 1: classical control + noise_applied = circuit_like.with_noise(noise) if noise is not None else circuit_like + defered = measurement_transformers.defer_measurements(noise_applied) + dephased = measurement_transformers.dephase_measurements(defered) + program = dephased + elif ignore_measurement_results: + # case 2: no classical control, only terminal measurement + program = measurement_transformers.dephase_measurements(circuit_like) + else: + # case 3: no measurement + program = circuit_like + density_result = density_matrix_simulator.DensityMatrixSimulator( - dtype=dtype, noise=noise, seed=seed + dtype=dtype, noise=None if handling_classical_control else noise, seed=seed ).simulate( - program=( - measurement_transformers.dephase_measurements(circuit_like) - if ignore_measurement_results - else circuit_like - ), + program, initial_state=initial_state, qubit_order=qubit_order, param_resolver=param_resolver, ) - return density_result.final_density_matrix + result = density_result.final_density_matrix + + if handling_classical_control: + # assuming that the ancilla qubits from the transformations are at the end + keep = list(range(protocols.num_qubits(circuit_like))) + dephased_qid_shape = protocols.qid_shape(dephased) + tensor_form = np.reshape(result, dephased_qid_shape + dephased_qid_shape) + reduced_form = transformations.partial_trace(tensor_form, keep) + width = np.prod(protocols.qid_shape(circuit_like)) + result = np.reshape(reduced_form, (width, width)) + + return result diff --git a/cirq-core/cirq/sim/mux_test.py b/cirq-core/cirq/sim/mux_test.py index 21692174b1f..fd5c3781d7f 100644 --- a/cirq-core/cirq/sim/mux_test.py +++ b/cirq-core/cirq/sim/mux_test.py @@ -380,6 +380,20 @@ def test_final_density_matrix_noise(): ) +def test_final_density_matrix_classical_control(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.H(q0), + cirq.measure(q0, key='a'), + cirq.H(q1).with_classical_controls('a'), + cirq.measure(q1, key='b'), + ) + np.testing.assert_allclose( + cirq.final_density_matrix(circuit), + np.diag(np.array([0.5, 0.0, 0.25, 0.25], dtype=np.complex64)), + ) + + def test_ps_initial_state_wfn(): q0, q1 = cirq.LineQubit.range(2) s00 = cirq.KET_ZERO(q0) * cirq.KET_ZERO(q1)