Skip to content

Commit

Permalink
Speed up MPS simulator test (#4063)
Browse files Browse the repository at this point in the history
This was create 288 circuits and running 4 simulations per circuit.  The test took >9 seconds.  I've replaced this with random circuits and reduced the number to of circuits to 5 to get this to around 2 seconds.

I don't think this test was testing anything special about the two qubit circuit, but if so then this won't be a good fix so you should yell.
  • Loading branch information
dabacon committed Apr 28, 2021
1 parent 5f51fff commit ebf8fa3
Showing 1 changed file with 18 additions and 39 deletions.
57 changes: 18 additions & 39 deletions cirq-core/cirq/contrib/quimb/mps_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,47 +118,26 @@ def test_same_partial_trace():
qubit_order = cirq.LineQubit.range(2)
q0, q1 = qubit_order

angles = [0.0, 0.20160913, math.pi / 3.0, math.pi / 2.0, math.pi]

gate_cls = [cirq.rx, cirq.ry, cirq.rz]

for angle_0 in angles:
for gate_0 in gate_cls:
for angle_1 in angles:
for gate_1 in gate_cls:
for use_cnot in [False, True]:
op0 = gate_0(angle_0)
op1 = gate_1(angle_1)

circuit = cirq.Circuit()
circuit.append(op0(q0))
if use_cnot:
circuit.append(cirq.qft(q0, q1))
circuit.append(op1(q1))
if use_cnot:
circuit.append(cirq.qft(q1, q0))

for initial_state in range(4):
expected_density_matrix = cirq.final_density_matrix(
circuit, qubit_order=qubit_order, initial_state=initial_state
)
expected_partial_trace = cirq.partial_trace(
expected_density_matrix.reshape(2, 2, 2, 2), keep_indices=[0]
)
mps_simulator = ccq.mps_simulator.MPSSimulator()

for _ in range(50):
for initial_state in range(4):
circuit = cirq.testing.random_circuit(qubit_order, 3, 0.9)
expected_density_matrix = cirq.final_density_matrix(
circuit, qubit_order=qubit_order, initial_state=initial_state
)
expected_partial_trace = cirq.partial_trace(
expected_density_matrix.reshape(2, 2, 2, 2), keep_indices=[0]
)

mps_simulator = ccq.mps_simulator.MPSSimulator()
final_state = mps_simulator.simulate(
circuit, qubit_order=qubit_order, initial_state=initial_state
).final_state
actual_density_matrix = final_state.partial_trace([q0, q1])
actual_partial_trace = final_state.partial_trace([q0])
final_state = mps_simulator.simulate(
circuit, qubit_order=qubit_order, initial_state=initial_state
).final_state
actual_density_matrix = final_state.partial_trace([q0, q1])
actual_partial_trace = final_state.partial_trace([q0])

np.testing.assert_allclose(
actual_density_matrix, expected_density_matrix, atol=1e-4
)
np.testing.assert_allclose(
actual_partial_trace, expected_partial_trace, atol=1e-4
)
np.testing.assert_allclose(actual_density_matrix, expected_density_matrix, atol=1e-4)
np.testing.assert_allclose(actual_partial_trace, expected_partial_trace, atol=1e-4)


def test_probs_dont_sum_up_to_one():
Expand Down

0 comments on commit ebf8fa3

Please sign in to comment.