Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use partial_trace to factor density matrix #4300

Merged
merged 3 commits into from
Jul 8, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 6 additions & 5 deletions cirq-core/cirq/linalg/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def fractional_swap(target):
return out


def partial_trace(tensor: np.ndarray, keep_indices: List[int]) -> np.ndarray:
def partial_trace(tensor: np.ndarray, keep_indices: Sequence[int]) -> np.ndarray:
"""Takes the partial trace of a given tensor.

The input tensor must have shape `(d_0, ..., d_{k-1}, d_0, ..., d_{k-1})`.
Expand Down Expand Up @@ -620,12 +620,13 @@ def factor_density_matrix(
`remainder` means the sub-matrix on the remaining axes, in the same
order as the original density matrix.
"""
axes1 = list(axes) + [i + int(t.ndim / 2) for i in axes]
extracted, remainder = factor_state_vector(t, axes1, validate=False)
extracted = partial_trace(t, axes)
remaining_axes = [i for i in range(t.ndim // 2) if i not in axes]
remainder = partial_trace(t, remaining_axes)
if validate:
t1 = density_matrix_kronecker_product(extracted, remainder)
axes2 = list(axes) + [i for i in range(int(t.ndim / 2)) if i not in axes]
t2 = transpose_density_matrix_to_axis_order(t1, axes2)
product_axes = list(axes) + remaining_axes
t2 = transpose_density_matrix_to_axis_order(t1, product_axes)
if not np.allclose(t2, t, atol=atol):
raise ValueError('The tensor cannot be factored by the requested axes')
return extracted, remainder
Expand Down
10 changes: 10 additions & 0 deletions cirq-core/cirq/sim/density_matrix_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,3 +1520,13 @@ def test_final_density_matrix_is_not_last_object():
assert result.final_density_matrix is not initial_state
assert not np.shares_memory(result.final_density_matrix, initial_state)
np.testing.assert_equal(result.final_density_matrix, initial_state)


def test_density_matrices_same_with_or_without_split_untangled_states():
sim = cirq.DensityMatrixSimulator(split_untangled_states=False)
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(cirq.H(q0), cirq.CX.on(q0, q1), cirq.reset(q1))
result1 = sim.simulate(circuit).final_density_matrix
sim = cirq.DensityMatrixSimulator(split_untangled_states=True)
result2 = sim.simulate(circuit).final_density_matrix
assert np.allclose(result1, result2)