Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions cirq-core/cirq/sim/mux.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from cirq import circuits, devices, ops, protocols, study, value
from cirq._doc import document
from cirq.linalg import transformations
from cirq.sim import density_matrix_simulator, sparse_simulator
from cirq.sim.clifford import clifford_simulator
from cirq.transformers import measurement_transformers
Expand Down Expand Up @@ -291,16 +292,41 @@ def final_density_matrix(
return sparse_result.density_matrix_of()
else:
# noisy case: use DensityMatrixSimulator with dephasing
has_classical_control = circuit_like != measurement_transformers.defer_measurements(
circuit_like
)
handling_classical_control = ignore_measurement_results and has_classical_control

if handling_classical_control:
# case 1: classical control
noise_applied = circuit_like.with_noise(noise) if noise is not None else circuit_like
defered = measurement_transformers.defer_measurements(noise_applied)
dephased = measurement_transformers.dephase_measurements(defered)
program = dephased
elif ignore_measurement_results:
# case 2: no classical control, only terminal measurement
program = measurement_transformers.dephase_measurements(circuit_like)
else:
# case 3: no measurement
program = circuit_like

density_result = density_matrix_simulator.DensityMatrixSimulator(
dtype=dtype, noise=noise, seed=seed
dtype=dtype, noise=None if handling_classical_control else noise, seed=seed
).simulate(
program=(
measurement_transformers.dephase_measurements(circuit_like)
if ignore_measurement_results
else circuit_like
),
program,
initial_state=initial_state,
qubit_order=qubit_order,
param_resolver=param_resolver,
)
return density_result.final_density_matrix
result = density_result.final_density_matrix

if handling_classical_control:
# assuming that the ancilla qubits from the transformations are at the end
keep = list(range(protocols.num_qubits(circuit_like)))
dephased_qid_shape = protocols.qid_shape(dephased)
tensor_form = np.reshape(result, dephased_qid_shape + dephased_qid_shape)
reduced_form = transformations.partial_trace(tensor_form, keep)
width = np.prod(protocols.qid_shape(circuit_like))
result = np.reshape(reduced_form, (width, width))

return result
14 changes: 14 additions & 0 deletions cirq-core/cirq/sim/mux_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,20 @@ def test_final_density_matrix_noise():
)


def test_final_density_matrix_classical_control():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.H(q0),
cirq.measure(q0, key='a'),
cirq.H(q1).with_classical_controls('a'),
cirq.measure(q1, key='b'),
)
np.testing.assert_allclose(
cirq.final_density_matrix(circuit),
np.diag(np.array([0.5, 0.0, 0.25, 0.25], dtype=np.complex64)),
)


def test_ps_initial_state_wfn():
q0, q1 = cirq.LineQubit.range(2)
s00 = cirq.KET_ZERO(q0) * cirq.KET_ZERO(q1)
Expand Down