Skip to content

Commit

Permalink
Make Cirq-FT registers multi-dimensional (#6200)
Browse files Browse the repository at this point in the history
* Make Cirq-FT registers multi-dimensional

* Update gate_with_registers notebook

* Fix mypy, lint and coverage checks

* Change Registers.bitsize to Registers.total_bits()

* Fix mypy errors in arithmetic_gates

* Change SelectionRegisters.build to match Registers.build and address other nits

* Fix typo
  • Loading branch information
tanujkhattar committed Jul 17, 2023
1 parent a076858 commit 83ede36
Show file tree
Hide file tree
Showing 36 changed files with 427 additions and 250 deletions.
11 changes: 7 additions & 4 deletions cirq-ft/cirq_ft/algos/and_gate.py
Expand Up @@ -14,6 +14,9 @@

from typing import Sequence, Tuple

import numpy as np
from numpy.typing import NDArray

import attr
import cirq
from cirq._compat import cached_property
Expand Down Expand Up @@ -110,16 +113,16 @@ def _decompose_single_and(

def _decompose_via_tree(
self,
controls: Sequence[cirq.Qid],
controls: NDArray[cirq.Qid], # type:ignore[type-var]
control_values: Sequence[int],
ancillas: Sequence[cirq.Qid],
ancillas: NDArray[cirq.Qid],
target: cirq.Qid,
) -> cirq.ops.op_tree.OpTree:
"""Decomposes multi-controlled `And` in-terms of an `And` ladder of size #controls- 2."""
if len(controls) == 2:
yield And(control_values, adjoint=self.adjoint).on(*controls, target)
return
new_controls = (ancillas[0], *controls[2:])
new_controls = np.concatenate([ancillas[0:1], controls[2:]])
new_control_values = (1, *control_values[2:])
and_op = And(control_values[:2], adjoint=self.adjoint).on(*controls[:2], ancillas[0])
if self.adjoint:
Expand All @@ -134,7 +137,7 @@ def _decompose_via_tree(
)

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
) -> cirq.OP_TREE:
control, ancilla, target = quregs['control'], quregs['ancilla'], quregs['target']
if len(self.cv) == 2:
Expand Down
10 changes: 6 additions & 4 deletions cirq-ft/cirq_ft/algos/and_gate_test.py
Expand Up @@ -45,7 +45,7 @@ def random_cv(n: int) -> List[int]:
def test_multi_controlled_and_gate(cv: List[int]):
gate = cirq_ft.And(cv)
r = gate.registers
assert r['ancilla'].bitsize == r['control'].bitsize - 2
assert r['ancilla'].total_bits() == r['control'].total_bits() - 2
quregs = r.get_named_qubits()
and_op = gate.on_registers(**quregs)
circuit = cirq.Circuit(and_op)
Expand All @@ -54,7 +54,7 @@ def test_multi_controlled_and_gate(cv: List[int]):
qubit_order = gate.registers.merge_qubits(**quregs)

for input_control in input_controls:
initial_state = input_control + [0] * (r['ancilla'].bitsize + 1)
initial_state = input_control + [0] * (r['ancilla'].total_bits() + 1)
result = cirq.Simulator(dtype=np.complex128).simulate(
circuit, initial_state=initial_state, qubit_order=qubit_order
)
Expand All @@ -80,8 +80,10 @@ def test_and_gate_diagram():
qubit_regs = gate.registers.get_named_qubits()
op = gate.on_registers(**qubit_regs)
# Qubit order should be alternating (control, ancilla) pairs.
c_and_a = sum(zip(qubit_regs["control"][1:], qubit_regs["ancilla"] + [0]), ())[:-1]
qubit_order = qubit_regs["control"][0:1] + list(c_and_a) + qubit_regs["target"]
c_and_a = sum(zip(qubit_regs["control"][1:], qubit_regs["ancilla"]), ()) + (
qubit_regs["control"][-1],
)
qubit_order = np.concatenate([qubit_regs["control"][0:1], c_and_a, qubit_regs["target"]])
# Test diagrams.
cirq.testing.assert_has_diagram(
cirq.Circuit(op),
Expand Down
4 changes: 2 additions & 2 deletions cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb
Expand Up @@ -89,7 +89,7 @@
" return cirq.I\n",
"\n",
"apply_z_to_odd = cirq_ft.ApplyGateToLthQubit(\n",
" cirq_ft.SelectionRegisters.build(selection=(3, 4)),\n",
" cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 4)]),\n",
" nth_gate=_z_to_odd,\n",
" control_regs=cirq_ft.Registers.build(control=2),\n",
")\n",
Expand Down Expand Up @@ -123,4 +123,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
12 changes: 8 additions & 4 deletions cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py
Expand Up @@ -56,8 +56,12 @@ def make_on(
) -> cirq.Operation:
"""Helper constructor to automatically deduce bitsize attributes."""
return cls(
infra.SelectionRegisters.build(
selection=(len(quregs['selection']), len(quregs['target']))
infra.SelectionRegisters(
[
infra.SelectionRegister(
'selection', len(quregs['selection']), len(quregs['target'])
)
]
),
nth_gate=nth_gate,
control_regs=infra.Registers.build(control=len(quregs['control'])),
Expand All @@ -76,8 +80,8 @@ def target_registers(self) -> infra.Registers:
return infra.Registers.build(target=self.selection_registers.total_iteration_size)

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
wire_symbols = ["@"] * self.control_registers.bitsize
wire_symbols += ["In"] * self.selection_registers.bitsize
wire_symbols = ["@"] * self.control_registers.total_bits()
wire_symbols += ["In"] * self.selection_registers.total_bits()
for it in itertools.product(*[range(x) for x in self.selection_regs.iteration_lengths]):
wire_symbols += [str(self.nth_gate(*it))]
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
Expand Down
10 changes: 6 additions & 4 deletions cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py
Expand Up @@ -23,14 +23,16 @@
def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize):
greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True)
gate = cirq_ft.ApplyGateToLthQubit(
cirq_ft.SelectionRegisters.build(selection=(selection_bitsize, target_bitsize)),
cirq_ft.SelectionRegisters(
[cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)]
),
lambda _: cirq.X,
)
g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm))
# Upper bounded because not all ancillas may be used as part of unary iteration.
assert (
len(g.all_qubits)
<= target_bitsize + 2 * (selection_bitsize + gate.control_registers.bitsize) - 1
<= target_bitsize + 2 * (selection_bitsize + gate.control_registers.total_bits()) - 1
)

for n in range(target_bitsize):
Expand All @@ -52,7 +54,7 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize):
def test_apply_gate_to_lth_qubit_diagram():
# Apply Z gate to all odd targets and Identity to even targets.
gate = cirq_ft.ApplyGateToLthQubit(
cirq_ft.SelectionRegisters.build(selection=(3, 5)),
cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]),
lambda n: cirq.Z if n & 1 else cirq.I,
control_regs=cirq_ft.Registers.build(control=2),
)
Expand Down Expand Up @@ -87,7 +89,7 @@ def test_apply_gate_to_lth_qubit_diagram():

def test_apply_gate_to_lth_qubit_make_on():
gate = cirq_ft.ApplyGateToLthQubit(
cirq_ft.SelectionRegisters.build(selection=(3, 5)),
cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]),
lambda n: cirq.Z if n & 1 else cirq.I,
control_regs=cirq_ft.Registers.build(control=2),
)
Expand Down
8 changes: 6 additions & 2 deletions cirq-ft/cirq_ft/algos/arithmetic_gates.py
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from typing import Iterable, Optional, Sequence, Tuple, Union, List, Iterator
from numpy.typing import NDArray

from cirq._compat import cached_property
import attr
Expand Down Expand Up @@ -153,7 +154,10 @@ def __repr__(self) -> str:
return f'cirq_ft.algos.BiQubitsMixer({self.adjoint})'

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
self,
*,
context: cirq.DecompositionContext,
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
) -> cirq.OP_TREE:
x, y, ancilla = quregs['x'], quregs['y'], quregs['ancilla']
x_msb, x_lsb = x
Expand Down Expand Up @@ -224,7 +228,7 @@ def __repr__(self) -> str:
return f'cirq_ft.algos.SingleQubitCompare({self.adjoint})'

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
) -> cirq.OP_TREE:
a = quregs['a']
b = quregs['b']
Expand Down
19 changes: 13 additions & 6 deletions cirq-ft/cirq_ft/algos/generic_select.py
Expand Up @@ -15,6 +15,7 @@
"""Gates for applying generic selected unitaries."""

from typing import Collection, Optional, Sequence, Tuple, Union
from numpy.typing import NDArray

import attr
import cirq
Expand Down Expand Up @@ -73,20 +74,26 @@ def control_registers(self) -> infra.Registers:

@cached_property
def selection_registers(self) -> infra.SelectionRegisters:
return infra.SelectionRegisters.build(
selection=(self.selection_bitsize, len(self.select_unitaries))
return infra.SelectionRegisters(
[
infra.SelectionRegister(
'selection', self.selection_bitsize, len(self.select_unitaries)
)
]
)

@cached_property
def target_registers(self) -> infra.Registers:
return infra.Registers.build(target=self.target_bitsize)

def decompose_from_registers(self, context, **qubit_regs: Sequence[cirq.Qid]) -> cirq.OP_TREE:
def decompose_from_registers(
self, context, **quregs: NDArray[cirq.Qid] # type:ignore[type-var]
) -> cirq.OP_TREE:
if self.control_val == 0:
yield cirq.X(*qubit_regs['control'])
yield super(GenericSelect, self).decompose_from_registers(context=context, **qubit_regs)
yield cirq.X(*quregs['control'])
yield super(GenericSelect, self).decompose_from_registers(context=context, **quregs)
if self.control_val == 0:
yield cirq.X(*qubit_regs['control'])
yield cirq.X(*quregs['control'])

def nth_operation( # type: ignore[override]
self,
Expand Down
78 changes: 46 additions & 32 deletions cirq-ft/cirq_ft/algos/hubbard_model.py
Expand Up @@ -47,6 +47,7 @@
See the documentation for `PrepareHubbard` and `SelectHubbard` for details.
"""
from typing import Collection, Optional, Sequence, Tuple, Union
from numpy.typing import NDArray

import attr
import cirq
Expand Down Expand Up @@ -123,15 +124,17 @@ def control_registers(self) -> infra.Registers:

@cached_property
def selection_registers(self) -> infra.SelectionRegisters:
return infra.SelectionRegisters.build(
U=(1, 2),
V=(1, 2),
p_x=((self.x_dim - 1).bit_length(), self.x_dim),
p_y=((self.y_dim - 1).bit_length(), self.y_dim),
alpha=(1, 2),
q_x=((self.x_dim - 1).bit_length(), self.x_dim),
q_y=((self.y_dim - 1).bit_length(), self.y_dim),
beta=(1, 2),
return infra.SelectionRegisters(
[
infra.SelectionRegister('U', 1, 2),
infra.SelectionRegister('V', 1, 2),
infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim),
infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim),
infra.SelectionRegister('alpha', 1, 2),
infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim),
infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim),
infra.SelectionRegister('beta', 1, 2),
]
)

@cached_property
Expand All @@ -145,17 +148,22 @@ def registers(self) -> infra.Registers:
)

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
self,
*,
context: cirq.DecompositionContext,
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
) -> cirq.OP_TREE:
p_x, p_y, q_x, q_y = quregs['p_x'], quregs['p_y'], quregs['q_x'], quregs['q_y']
U, V, alpha, beta = quregs['U'], quregs['V'], quregs['alpha'], quregs['beta']
control, target = quregs.get('control', ()), quregs['target']

yield selected_majorana_fermion.SelectedMajoranaFermionGate(
selection_regs=infra.SelectionRegisters.build(
alpha=(1, 2),
p_y=(self.registers['p_y'].bitsize, self.y_dim),
p_x=(self.registers['p_x'].bitsize, self.x_dim),
selection_regs=infra.SelectionRegisters(
[
infra.SelectionRegister('alpha', 1, 2),
infra.SelectionRegister('p_y', self.registers['p_y'].total_bits(), self.y_dim),
infra.SelectionRegister('p_x', self.registers['p_x'].total_bits(), self.x_dim),
]
),
control_regs=self.control_registers,
target_gate=cirq.Y,
Expand All @@ -165,10 +173,12 @@ def decompose_from_registers(
yield swap_network.MultiTargetCSwap.make_on(control=V, target_x=p_y, target_y=q_y)
yield swap_network.MultiTargetCSwap.make_on(control=V, target_x=alpha, target_y=beta)

q_selection_regs = infra.SelectionRegisters.build(
beta=(1, 2),
q_y=(self.registers['q_y'].bitsize, self.y_dim),
q_x=(self.registers['q_x'].bitsize, self.x_dim),
q_selection_regs = infra.SelectionRegisters(
[
infra.SelectionRegister('beta', 1, 2),
infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim),
infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim),
]
)
yield selected_majorana_fermion.SelectedMajoranaFermionGate(
selection_regs=q_selection_regs, control_regs=self.control_registers, target_gate=cirq.X
Expand All @@ -190,12 +200,14 @@ def decompose_from_registers(
]

yield apply_gate_to_lth_target.ApplyGateToLthQubit(
selection_regs=infra.SelectionRegisters.build(
q_y=(self.registers['q_y'].bitsize, self.y_dim),
q_x=(self.registers['q_x'].bitsize, self.x_dim),
selection_regs=infra.SelectionRegisters(
[
infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim),
infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim),
]
),
nth_gate=lambda *_: cirq.Z,
control_regs=infra.Registers.build(control=1 + self.control_registers.bitsize),
control_regs=infra.Registers.build(control=1 + self.control_registers.total_bits()),
).on_registers(
q_x=q_x, q_y=q_y, control=[*V, *control], target=target_qubits_for_apply_to_lth_gate
)
Expand Down Expand Up @@ -280,15 +292,17 @@ def __attrs_post_init__(self):

@cached_property
def selection_registers(self) -> infra.SelectionRegisters:
return infra.SelectionRegisters.build(
U=(1, 2),
V=(1, 2),
p_x=((self.x_dim - 1).bit_length(), self.x_dim),
p_y=((self.y_dim - 1).bit_length(), self.y_dim),
alpha=(1, 2),
q_x=((self.x_dim - 1).bit_length(), self.x_dim),
q_y=((self.y_dim - 1).bit_length(), self.y_dim),
beta=(1, 2),
return infra.SelectionRegisters(
[
infra.SelectionRegister('U', 1, 2),
infra.SelectionRegister('V', 1, 2),
infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim),
infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim),
infra.SelectionRegister('alpha', 1, 2),
infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim),
infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim),
infra.SelectionRegister('beta', 1, 2),
]
)

@cached_property
Expand All @@ -300,7 +314,7 @@ def registers(self) -> infra.Registers:
return infra.Registers([*self.selection_registers, *self.junk_registers])

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
) -> cirq.OP_TREE:
p_x, p_y, q_x, q_y = quregs['p_x'], quregs['p_y'], quregs['q_x'], quregs['q_y']
U, V, alpha, beta = quregs['U'], quregs['V'], quregs['alpha'], quregs['beta']
Expand Down
15 changes: 10 additions & 5 deletions cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence
from numpy.typing import NDArray

import attr
import cirq
Expand Down Expand Up @@ -49,10 +49,15 @@ def registers(self) -> infra.Registers:
return infra.Registers([*self.control_registers, *self.selection_registers])

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
self,
*,
context: cirq.DecompositionContext,
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
) -> cirq.OP_TREE:
qm = context.qubit_manager
target_reg = {reg.name: qm.qalloc(reg.bitsize) for reg in self.encoder.target_registers}
target_reg = {
reg.name: qm.qalloc(reg.total_bits()) for reg in self.encoder.target_registers
}
target_qubits = self.encoder.target_registers.merge_qubits(**target_reg)
encoder_op = self.encoder.on_registers(**quregs, **target_reg)

Expand All @@ -73,6 +78,6 @@ def decompose_from_registers(
qm.qfree([*arctan_sign, *arctan_target, *target_qubits])

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
wire_symbols = ['@'] * self.control_registers.bitsize
wire_symbols += ['ROTy'] * self.selection_registers.bitsize
wire_symbols = ['@'] * self.control_registers.total_bits()
wire_symbols += ['ROTy'] * self.selection_registers.total_bits()
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
Expand Up @@ -38,7 +38,7 @@ def control_registers(self) -> cirq_ft.Registers:

@cached_property
def selection_registers(self) -> cirq_ft.SelectionRegisters:
return cirq_ft.SelectionRegisters.build(selection=(self.bitsize, 2**self.bitsize))
return cirq_ft.SelectionRegisters.build(selection=self.bitsize)

@cached_property
def target_registers(self) -> cirq_ft.Registers:
Expand Down

0 comments on commit 83ede36

Please sign in to comment.