Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow any qubit type in quimb density matrix #4547

Merged
merged 4 commits into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
68 changes: 37 additions & 31 deletions cirq-core/cirq/contrib/quimb/density_matrix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import lru_cache
from typing import Sequence, Dict, Union, Tuple, List, Optional, cast, Iterable
from typing import Sequence, Dict, Union, Tuple, List, Optional

import numpy as np
import quimb
Expand All @@ -9,41 +9,45 @@


@lru_cache()
def _qpos_tag(qubits: Union[cirq.LineQubit, Tuple[cirq.LineQubit]]):
def _qpos_tag(qubits: Union[cirq.Qid, Tuple[cirq.Qid]]):
"""Given a qubit or qubits, return a "position tag" (used for drawing).

For multiple qubits, the tag is for the first qubit.
"""
if isinstance(qubits, cirq.LineQubit):
if isinstance(qubits, cirq.Qid):
return _qpos_tag((qubits,))
x = min(q.x for q in qubits)
x = min(qubits)
return f'q{x}'


@lru_cache()
def _qpos_y(qubits: Union[cirq.LineQubit, Tuple[cirq.LineQubit]], tot_n_qubits: int):
def _qpos_y(
qubits: Union[cirq.Qid, Tuple[cirq.Qid, ...]], all_qubits: Tuple[cirq.Qid, ...]
) -> float:
"""Given a qubit or qubits, return the position y value (used for drawing).

For multiple qubits, the position is the mean of the qubit indices.
This "flips" the coordinate so qubit 0 is at the maximal y position.

Args:
qubits: The qubits involved in the tensor.
tot_n_qubits: The total number of qubits in the circuit, allowing us
all_qubits: All qubits in the circuit, allowing us
to position the zero'th qubit at the top.
"""
if isinstance(qubits, cirq.LineQubit):
return _qpos_y((qubits,), tot_n_qubits)
x = np.mean([q.x for q in qubits]).item()
return tot_n_qubits - x - 1
if isinstance(qubits, cirq.Qid):
return _qpos_y((qubits,), all_qubits)

pos = [all_qubits.index(q) for q in qubits]
x = np.mean(pos).item()
return len(all_qubits) - x - 1


def _add_to_positions(
positions: Dict[Tuple[str, str], Tuple[float, float]],
mi: int,
qubits: Union[cirq.LineQubit, Tuple[cirq.LineQubit]],
qubits: Union[cirq.Qid, Tuple[cirq.Qid]],
*,
tot_n_qubits: int,
all_qubits: Tuple[cirq.Qid, ...],
x_scale,
y_scale,
x_nudge,
Expand All @@ -55,7 +59,7 @@ def _add_to_positions(
positions: The dictionary to update. Quimb will consume this for drawing
mi: Moment index (used for x-positioning)
qubits: The qubits (used for y-positioning)
tot_n_qubits: The total number of qubits in the circuit, allowing us
all_qubits: All qubits in the circuit, allowing us
to position the zero'th qubit at the top.
x_scale: Stretch coordinates in the x direction
y_scale: Stretch coordinates in the y direction
Expand All @@ -65,15 +69,15 @@ def _add_to_positions(
the lines.
yb_offset: Offset the "backwards" circuit by this much.
"""
qy = _qpos_y(qubits, tot_n_qubits)
qy = _qpos_y(qubits, all_qubits)
positions[(f'i{mi}f', _qpos_tag(qubits))] = (mi * x_scale + qy * x_nudge, y_scale * qy)
positions[(f'i{mi}b', _qpos_tag(qubits))] = (mi * x_scale, y_scale * qy + yb_offset)


# TODO(#3388) Add documentation for Raises.
# pylint: disable=missing-raises-doc
def circuit_to_density_matrix_tensors(
circuit: cirq.Circuit, qubits: Optional[Sequence[cirq.LineQubit]] = None
circuit: cirq.Circuit, qubits: Optional[Sequence[cirq.Qid]] = None
) -> Tuple[List[qtn.Tensor], Dict['cirq.Qid', int], Dict[Tuple[str, str], Tuple[float, float]]]:
"""Given a circuit with mixtures or channels, construct a tensor network
representation of the density matrix.
Expand All @@ -87,7 +91,8 @@ def circuit_to_density_matrix_tensors(
Args:
circuit: The circuit containing operations that support the
cirq.unitary() or cirq.kraus() protocols.
qubits: The qubits in the circuit.
qubits: The qubits in the circuit. The `positions` return argument
will position qubits according to their index in this list.

Returns:
tensors: A list of Quimb Tensor objects
Expand All @@ -100,7 +105,8 @@ def circuit_to_density_matrix_tensors(
"""
if qubits is None:
# coverage: ignore
qubits = sorted(cast(Iterable[cirq.LineQubit], circuit.all_qubits()))
qubits = sorted(circuit.all_qubits())
qubits = tuple(qubits)

qubit_frontier: Dict[cirq.Qid, int] = {q: 0 for q in qubits}
kraus_frontier = 0
Expand All @@ -113,12 +119,12 @@ def circuit_to_density_matrix_tensors(
n_qubits = len(qubits)
yb_offset = (n_qubits + 0.5) * y_scale

def _positions(mi, qubits):
def _positions(_mi, _these_qubits):
return _add_to_positions(
positions,
mi,
qubits,
tot_n_qubits=n_qubits,
_mi,
_these_qubits,
all_qubits=qubits,
x_scale=x_scale,
y_scale=y_scale,
x_nudge=x_nudge,
Expand All @@ -130,22 +136,22 @@ def _positions(mi, qubits):
for q in qubits:
tensors += [
qtn.Tensor(
data=quimb.up().squeeze(), inds=(f'nf0_q{q.x}',), tags={'Q0', 'i0f', _qpos_tag(q)}
data=quimb.up().squeeze(), inds=(f'nf0_q{q}',), tags={'Q0', 'i0f', _qpos_tag(q)}
),
qtn.Tensor(
data=quimb.up().squeeze(), inds=(f'nb0_q{q.x}',), tags={'Q0', 'i0b', _qpos_tag(q)}
data=quimb.up().squeeze(), inds=(f'nb0_q{q}',), tags={'Q0', 'i0b', _qpos_tag(q)}
),
]
_positions(0, q)

for mi, moment in enumerate(circuit.moments):
for op in moment.operations:
start_inds_f = [f'nf{qubit_frontier[q]}_q{q.x}' for q in op.qubits]
start_inds_b = [f'nb{qubit_frontier[q]}_q{q.x}' for q in op.qubits]
start_inds_f = [f'nf{qubit_frontier[q]}_q{q}' for q in op.qubits]
start_inds_b = [f'nb{qubit_frontier[q]}_q{q}' for q in op.qubits]
for q in op.qubits:
qubit_frontier[q] += 1
end_inds_f = [f'nf{qubit_frontier[q]}_q{q.x}' for q in op.qubits]
end_inds_b = [f'nb{qubit_frontier[q]}_q{q.x}' for q in op.qubits]
end_inds_f = [f'nf{qubit_frontier[q]}_q{q}' for q in op.qubits]
end_inds_b = [f'nb{qubit_frontier[q]}_q{q}' for q in op.qubits]

if cirq.has_unitary(op):
U = cirq.unitary(op).reshape((2,) * 2 * len(op.qubits)).astype(np.complex128)
Expand Down Expand Up @@ -190,7 +196,7 @@ def _positions(mi, qubits):

# pylint: enable=missing-raises-doc
def tensor_density_matrix(
circuit: cirq.Circuit, qubits: Optional[List[cirq.LineQubit]] = None
circuit: cirq.Circuit, qubits: Optional[List[cirq.Qid]] = None
) -> np.ndarray:
"""Given a circuit with mixtures or channels, contract a tensor network
representing the resultant density matrix.
Expand All @@ -202,12 +208,12 @@ def tensor_density_matrix(
is encouraged for your particular problem if performance is important.
"""
if qubits is None:
qubits = sorted(cast(Iterable[cirq.LineQubit], circuit.all_qubits()))
qubits = sorted(circuit.all_qubits())

tensors, qubit_frontier, _ = circuit_to_density_matrix_tensors(circuit=circuit, qubits=qubits)
tn = qtn.TensorNetwork(tensors)
f_inds = tuple(f'nf{qubit_frontier[q]}_q{q.x}' for q in qubits)
b_inds = tuple(f'nb{qubit_frontier[q]}_q{q.x}' for q in qubits)
f_inds = tuple(f'nf{qubit_frontier[q]}_q{q}' for q in qubits)
b_inds = tuple(f'nb{qubit_frontier[q]}_q{q}' for q in qubits)
if len(qubits) <= 6:
# Heuristic: don't try to determine best order for low qubit number
# Just contract in time.
Expand Down
11 changes: 11 additions & 0 deletions cirq-core/cirq/contrib/quimb/density_matrix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,14 @@ def test_tensor_density_matrix_4():
rho1 = cirq.final_density_matrix(circuit, dtype=np.complex128)
rho2 = ccq.tensor_density_matrix(circuit, qubits)
np.testing.assert_allclose(rho1, rho2, atol=1e-8)


def test_tensor_density_matrix_gridqubit():
qubits = cirq.GridQubit.rect(2, 2)
circuit = cirq.testing.random_circuit(qubits=qubits, n_moments=10, op_density=0.8)
cirq.DropEmptyMoments().optimize_circuit(circuit)
noise_model = cirq.ConstantQubitNoiseModel(cirq.DepolarizingChannel(p=1e-3))
circuit = cirq.Circuit(noise_model.noisy_moments(circuit.moments, qubits))
rho1 = cirq.final_density_matrix(circuit, dtype=np.complex128)
rho2 = ccq.tensor_density_matrix(circuit, qubits)
np.testing.assert_allclose(rho1, rho2, atol=1e-8)