Skip to content

Commit

Permalink
Allow any qubit type in quimb density matrix (quantumlib#4547)
Browse files Browse the repository at this point in the history
Used to restrict to linequbit just to make positioning easier when drawing the tensor networks. Now instead of using `q.x`, we use `qubits.index(q)` to find a qubits linear position.
  • Loading branch information
mpharrigan authored and rht committed May 1, 2023
1 parent aabc851 commit 0a59503
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 31 deletions.
68 changes: 37 additions & 31 deletions cirq-core/cirq/contrib/quimb/density_matrix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import lru_cache
from typing import Sequence, Dict, Union, Tuple, List, Optional, cast, Iterable
from typing import Sequence, Dict, Union, Tuple, List, Optional

import numpy as np
import quimb
Expand All @@ -9,41 +9,45 @@


@lru_cache()
def _qpos_tag(qubits: Union[cirq.LineQubit, Tuple[cirq.LineQubit]]):
def _qpos_tag(qubits: Union[cirq.Qid, Tuple[cirq.Qid]]):
"""Given a qubit or qubits, return a "position tag" (used for drawing).
For multiple qubits, the tag is for the first qubit.
"""
if isinstance(qubits, cirq.LineQubit):
if isinstance(qubits, cirq.Qid):
return _qpos_tag((qubits,))
x = min(q.x for q in qubits)
x = min(qubits)
return f'q{x}'


@lru_cache()
def _qpos_y(qubits: Union[cirq.LineQubit, Tuple[cirq.LineQubit]], tot_n_qubits: int):
def _qpos_y(
qubits: Union[cirq.Qid, Tuple[cirq.Qid, ...]], all_qubits: Tuple[cirq.Qid, ...]
) -> float:
"""Given a qubit or qubits, return the position y value (used for drawing).
For multiple qubits, the position is the mean of the qubit indices.
This "flips" the coordinate so qubit 0 is at the maximal y position.
Args:
qubits: The qubits involved in the tensor.
tot_n_qubits: The total number of qubits in the circuit, allowing us
all_qubits: All qubits in the circuit, allowing us
to position the zero'th qubit at the top.
"""
if isinstance(qubits, cirq.LineQubit):
return _qpos_y((qubits,), tot_n_qubits)
x = np.mean([q.x for q in qubits]).item()
return tot_n_qubits - x - 1
if isinstance(qubits, cirq.Qid):
return _qpos_y((qubits,), all_qubits)

pos = [all_qubits.index(q) for q in qubits]
x = np.mean(pos).item()
return len(all_qubits) - x - 1


def _add_to_positions(
positions: Dict[Tuple[str, str], Tuple[float, float]],
mi: int,
qubits: Union[cirq.LineQubit, Tuple[cirq.LineQubit]],
qubits: Union[cirq.Qid, Tuple[cirq.Qid]],
*,
tot_n_qubits: int,
all_qubits: Tuple[cirq.Qid, ...],
x_scale,
y_scale,
x_nudge,
Expand All @@ -55,7 +59,7 @@ def _add_to_positions(
positions: The dictionary to update. Quimb will consume this for drawing
mi: Moment index (used for x-positioning)
qubits: The qubits (used for y-positioning)
tot_n_qubits: The total number of qubits in the circuit, allowing us
all_qubits: All qubits in the circuit, allowing us
to position the zero'th qubit at the top.
x_scale: Stretch coordinates in the x direction
y_scale: Stretch coordinates in the y direction
Expand All @@ -65,15 +69,15 @@ def _add_to_positions(
the lines.
yb_offset: Offset the "backwards" circuit by this much.
"""
qy = _qpos_y(qubits, tot_n_qubits)
qy = _qpos_y(qubits, all_qubits)
positions[(f'i{mi}f', _qpos_tag(qubits))] = (mi * x_scale + qy * x_nudge, y_scale * qy)
positions[(f'i{mi}b', _qpos_tag(qubits))] = (mi * x_scale, y_scale * qy + yb_offset)


# TODO(#3388) Add documentation for Raises.
# pylint: disable=missing-raises-doc
def circuit_to_density_matrix_tensors(
circuit: cirq.Circuit, qubits: Optional[Sequence[cirq.LineQubit]] = None
circuit: cirq.Circuit, qubits: Optional[Sequence[cirq.Qid]] = None
) -> Tuple[List[qtn.Tensor], Dict['cirq.Qid', int], Dict[Tuple[str, str], Tuple[float, float]]]:
"""Given a circuit with mixtures or channels, construct a tensor network
representation of the density matrix.
Expand All @@ -87,7 +91,8 @@ def circuit_to_density_matrix_tensors(
Args:
circuit: The circuit containing operations that support the
cirq.unitary() or cirq.kraus() protocols.
qubits: The qubits in the circuit.
qubits: The qubits in the circuit. The `positions` return argument
will position qubits according to their index in this list.
Returns:
tensors: A list of Quimb Tensor objects
Expand All @@ -100,7 +105,8 @@ def circuit_to_density_matrix_tensors(
"""
if qubits is None:
# coverage: ignore
qubits = sorted(cast(Iterable[cirq.LineQubit], circuit.all_qubits()))
qubits = sorted(circuit.all_qubits())
qubits = tuple(qubits)

qubit_frontier: Dict[cirq.Qid, int] = {q: 0 for q in qubits}
kraus_frontier = 0
Expand All @@ -113,12 +119,12 @@ def circuit_to_density_matrix_tensors(
n_qubits = len(qubits)
yb_offset = (n_qubits + 0.5) * y_scale

def _positions(mi, qubits):
def _positions(_mi, _these_qubits):
return _add_to_positions(
positions,
mi,
qubits,
tot_n_qubits=n_qubits,
_mi,
_these_qubits,
all_qubits=qubits,
x_scale=x_scale,
y_scale=y_scale,
x_nudge=x_nudge,
Expand All @@ -130,22 +136,22 @@ def _positions(mi, qubits):
for q in qubits:
tensors += [
qtn.Tensor(
data=quimb.up().squeeze(), inds=(f'nf0_q{q.x}',), tags={'Q0', 'i0f', _qpos_tag(q)}
data=quimb.up().squeeze(), inds=(f'nf0_q{q}',), tags={'Q0', 'i0f', _qpos_tag(q)}
),
qtn.Tensor(
data=quimb.up().squeeze(), inds=(f'nb0_q{q.x}',), tags={'Q0', 'i0b', _qpos_tag(q)}
data=quimb.up().squeeze(), inds=(f'nb0_q{q}',), tags={'Q0', 'i0b', _qpos_tag(q)}
),
]
_positions(0, q)

for mi, moment in enumerate(circuit.moments):
for op in moment.operations:
start_inds_f = [f'nf{qubit_frontier[q]}_q{q.x}' for q in op.qubits]
start_inds_b = [f'nb{qubit_frontier[q]}_q{q.x}' for q in op.qubits]
start_inds_f = [f'nf{qubit_frontier[q]}_q{q}' for q in op.qubits]
start_inds_b = [f'nb{qubit_frontier[q]}_q{q}' for q in op.qubits]
for q in op.qubits:
qubit_frontier[q] += 1
end_inds_f = [f'nf{qubit_frontier[q]}_q{q.x}' for q in op.qubits]
end_inds_b = [f'nb{qubit_frontier[q]}_q{q.x}' for q in op.qubits]
end_inds_f = [f'nf{qubit_frontier[q]}_q{q}' for q in op.qubits]
end_inds_b = [f'nb{qubit_frontier[q]}_q{q}' for q in op.qubits]

if cirq.has_unitary(op):
U = cirq.unitary(op).reshape((2,) * 2 * len(op.qubits)).astype(np.complex128)
Expand Down Expand Up @@ -190,7 +196,7 @@ def _positions(mi, qubits):

# pylint: enable=missing-raises-doc
def tensor_density_matrix(
circuit: cirq.Circuit, qubits: Optional[List[cirq.LineQubit]] = None
circuit: cirq.Circuit, qubits: Optional[List[cirq.Qid]] = None
) -> np.ndarray:
"""Given a circuit with mixtures or channels, contract a tensor network
representing the resultant density matrix.
Expand All @@ -202,12 +208,12 @@ def tensor_density_matrix(
is encouraged for your particular problem if performance is important.
"""
if qubits is None:
qubits = sorted(cast(Iterable[cirq.LineQubit], circuit.all_qubits()))
qubits = sorted(circuit.all_qubits())

tensors, qubit_frontier, _ = circuit_to_density_matrix_tensors(circuit=circuit, qubits=qubits)
tn = qtn.TensorNetwork(tensors)
f_inds = tuple(f'nf{qubit_frontier[q]}_q{q.x}' for q in qubits)
b_inds = tuple(f'nb{qubit_frontier[q]}_q{q.x}' for q in qubits)
f_inds = tuple(f'nf{qubit_frontier[q]}_q{q}' for q in qubits)
b_inds = tuple(f'nb{qubit_frontier[q]}_q{q}' for q in qubits)
if len(qubits) <= 6:
# Heuristic: don't try to determine best order for low qubit number
# Just contract in time.
Expand Down
11 changes: 11 additions & 0 deletions cirq-core/cirq/contrib/quimb/density_matrix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,14 @@ def test_tensor_density_matrix_4():
rho1 = cirq.final_density_matrix(circuit, dtype=np.complex128)
rho2 = ccq.tensor_density_matrix(circuit, qubits)
np.testing.assert_allclose(rho1, rho2, atol=1e-8)


def test_tensor_density_matrix_gridqubit():
qubits = cirq.GridQubit.rect(2, 2)
circuit = cirq.testing.random_circuit(qubits=qubits, n_moments=10, op_density=0.8)
cirq.DropEmptyMoments().optimize_circuit(circuit)
noise_model = cirq.ConstantQubitNoiseModel(cirq.DepolarizingChannel(p=1e-3))
circuit = cirq.Circuit(noise_model.noisy_moments(circuit.moments, qubits))
rho1 = cirq.final_density_matrix(circuit, dtype=np.complex128)
rho2 = ccq.tensor_density_matrix(circuit, qubits)
np.testing.assert_allclose(rho1, rho2, atol=1e-8)

0 comments on commit 0a59503

Please sign in to comment.