Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Cirq-FT registers multi-dimensional #6200

Merged
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 7 additions & 4 deletions cirq-ft/cirq_ft/algos/and_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

from typing import Sequence, Tuple

import numpy as np
from numpy.typing import NDArray

import attr
import cirq
from cirq._compat import cached_property
Expand Down Expand Up @@ -110,16 +113,16 @@ def _decompose_single_and(

def _decompose_via_tree(
self,
controls: Sequence[cirq.Qid],
controls: NDArray[cirq.Qid], # type:ignore[type-var]
control_values: Sequence[int],
ancillas: Sequence[cirq.Qid],
ancillas: NDArray[cirq.Qid],
target: cirq.Qid,
) -> cirq.ops.op_tree.OpTree:
"""Decomposes multi-controlled `And` in-terms of an `And` ladder of size #controls- 2."""
if len(controls) == 2:
yield And(control_values, adjoint=self.adjoint).on(*controls, target)
return
new_controls = (ancillas[0], *controls[2:])
new_controls = np.concatenate([ancillas[0:1], controls[2:]])
new_control_values = (1, *control_values[2:])
and_op = And(control_values[:2], adjoint=self.adjoint).on(*controls[:2], ancillas[0])
if self.adjoint:
Expand All @@ -134,7 +137,7 @@ def _decompose_via_tree(
)

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
) -> cirq.OP_TREE:
control, ancilla, target = quregs['control'], quregs['ancilla'], quregs['target']
if len(self.cv) == 2:
Expand Down
10 changes: 6 additions & 4 deletions cirq-ft/cirq_ft/algos/and_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def random_cv(n: int) -> List[int]:
def test_multi_controlled_and_gate(cv: List[int]):
gate = cirq_ft.And(cv)
r = gate.registers
assert r['ancilla'].bitsize == r['control'].bitsize - 2
assert r['ancilla'].total_bits() == r['control'].total_bits() - 2
quregs = r.get_named_qubits()
and_op = gate.on_registers(**quregs)
circuit = cirq.Circuit(and_op)
Expand All @@ -54,7 +54,7 @@ def test_multi_controlled_and_gate(cv: List[int]):
qubit_order = gate.registers.merge_qubits(**quregs)

for input_control in input_controls:
initial_state = input_control + [0] * (r['ancilla'].bitsize + 1)
initial_state = input_control + [0] * (r['ancilla'].total_bits() + 1)
result = cirq.Simulator(dtype=np.complex128).simulate(
circuit, initial_state=initial_state, qubit_order=qubit_order
)
Expand All @@ -80,8 +80,10 @@ def test_and_gate_diagram():
qubit_regs = gate.registers.get_named_qubits()
op = gate.on_registers(**qubit_regs)
# Qubit order should be alternating (control, ancilla) pairs.
c_and_a = sum(zip(qubit_regs["control"][1:], qubit_regs["ancilla"] + [0]), ())[:-1]
qubit_order = qubit_regs["control"][0:1] + list(c_and_a) + qubit_regs["target"]
c_and_a = sum(zip(qubit_regs["control"][1:], qubit_regs["ancilla"]), ()) + (
qubit_regs["control"][-1],
)
qubit_order = np.concatenate([qubit_regs["control"][0:1], c_and_a, qubit_regs["target"]])
# Test diagrams.
cirq.testing.assert_has_diagram(
cirq.Circuit(op),
Expand Down
4 changes: 2 additions & 2 deletions cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
" return cirq.I\n",
"\n",
"apply_z_to_odd = cirq_ft.ApplyGateToLthQubit(\n",
" cirq_ft.SelectionRegisters.build(selection=(3, 4)),\n",
" cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 4)]),\n",
" nth_gate=_z_to_odd,\n",
" control_regs=cirq_ft.Registers.build(control=2),\n",
")\n",
Expand Down Expand Up @@ -123,4 +123,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
12 changes: 8 additions & 4 deletions cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,12 @@ def make_on(
) -> cirq.Operation:
"""Helper constructor to automatically deduce bitsize attributes."""
return cls(
infra.SelectionRegisters.build(
selection=(len(quregs['selection']), len(quregs['target']))
infra.SelectionRegisters(
[
infra.SelectionRegister(
'selection', len(quregs['selection']), len(quregs['target'])
)
]
),
nth_gate=nth_gate,
control_regs=infra.Registers.build(control=len(quregs['control'])),
Expand All @@ -76,8 +80,8 @@ def target_registers(self) -> infra.Registers:
return infra.Registers.build(target=self.selection_registers.total_iteration_size)

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
wire_symbols = ["@"] * self.control_registers.bitsize
wire_symbols += ["In"] * self.selection_registers.bitsize
wire_symbols = ["@"] * self.control_registers.total_bits()
wire_symbols += ["In"] * self.selection_registers.total_bits()
for it in itertools.product(*[range(x) for x in self.selection_regs.iteration_lengths]):
wire_symbols += [str(self.nth_gate(*it))]
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
Expand Down
10 changes: 6 additions & 4 deletions cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@
def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize):
greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True)
gate = cirq_ft.ApplyGateToLthQubit(
cirq_ft.SelectionRegisters.build(selection=(selection_bitsize, target_bitsize)),
cirq_ft.SelectionRegisters(
[cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)]
),
lambda _: cirq.X,
)
g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm))
# Upper bounded because not all ancillas may be used as part of unary iteration.
assert (
len(g.all_qubits)
<= target_bitsize + 2 * (selection_bitsize + gate.control_registers.bitsize) - 1
<= target_bitsize + 2 * (selection_bitsize + gate.control_registers.total_bits()) - 1
)

for n in range(target_bitsize):
Expand All @@ -52,7 +54,7 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize):
def test_apply_gate_to_lth_qubit_diagram():
# Apply Z gate to all odd targets and Identity to even targets.
gate = cirq_ft.ApplyGateToLthQubit(
cirq_ft.SelectionRegisters.build(selection=(3, 5)),
cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]),
lambda n: cirq.Z if n & 1 else cirq.I,
control_regs=cirq_ft.Registers.build(control=2),
)
Expand Down Expand Up @@ -87,7 +89,7 @@ def test_apply_gate_to_lth_qubit_diagram():

def test_apply_gate_to_lth_qubit_make_on():
gate = cirq_ft.ApplyGateToLthQubit(
cirq_ft.SelectionRegisters.build(selection=(3, 5)),
cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]),
lambda n: cirq.Z if n & 1 else cirq.I,
control_regs=cirq_ft.Registers.build(control=2),
)
Expand Down
8 changes: 6 additions & 2 deletions cirq-ft/cirq_ft/algos/arithmetic_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from typing import Iterable, Optional, Sequence, Tuple, Union, List, Iterator
from numpy.typing import NDArray

from cirq._compat import cached_property
import attr
Expand Down Expand Up @@ -153,7 +154,10 @@ def __repr__(self) -> str:
return f'cirq_ft.algos.BiQubitsMixer({self.adjoint})'

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
self,
*,
context: cirq.DecompositionContext,
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
) -> cirq.OP_TREE:
x, y, ancilla = quregs['x'], quregs['y'], quregs['ancilla']
x_msb, x_lsb = x
Expand Down Expand Up @@ -224,7 +228,7 @@ def __repr__(self) -> str:
return f'cirq_ft.algos.SingleQubitCompare({self.adjoint})'

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
) -> cirq.OP_TREE:
a = quregs['a']
b = quregs['b']
Expand Down
19 changes: 13 additions & 6 deletions cirq-ft/cirq_ft/algos/generic_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Gates for applying generic selected unitaries."""

from typing import Collection, Optional, Sequence, Tuple, Union
from numpy.typing import NDArray

import attr
import cirq
Expand Down Expand Up @@ -73,20 +74,26 @@ def control_registers(self) -> infra.Registers:

@cached_property
def selection_registers(self) -> infra.SelectionRegisters:
return infra.SelectionRegisters.build(
selection=(self.selection_bitsize, len(self.select_unitaries))
return infra.SelectionRegisters(
[
infra.SelectionRegister(
'selection', self.selection_bitsize, len(self.select_unitaries)
)
]
)

@cached_property
def target_registers(self) -> infra.Registers:
return infra.Registers.build(target=self.target_bitsize)

def decompose_from_registers(self, context, **qubit_regs: Sequence[cirq.Qid]) -> cirq.OP_TREE:
def decompose_from_registers(
self, context, **quregs: NDArray[cirq.Qid] # type:ignore[type-var]
) -> cirq.OP_TREE:
if self.control_val == 0:
yield cirq.X(*qubit_regs['control'])
yield super().decompose_from_registers(context=context, **qubit_regs)
yield cirq.X(*quregs['control'])
yield super().decompose_from_registers(context=context, **quregs)
if self.control_val == 0:
yield cirq.X(*qubit_regs['control'])
yield cirq.X(*quregs['control'])

def nth_operation( # type: ignore[override]
self,
Expand Down
79 changes: 47 additions & 32 deletions cirq-ft/cirq_ft/algos/hubbard_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
See the documentation for `PrepareHubbard` and `SelectHubbard` for details.
"""
from typing import Collection, Optional, Sequence, Tuple, Union
from numpy.typing import NDArray

import attr
import cirq
Expand Down Expand Up @@ -123,15 +124,18 @@ def control_registers(self) -> infra.Registers:

@cached_property
def selection_registers(self) -> infra.SelectionRegisters:
return infra.SelectionRegisters.build(
U=(1, 2),
V=(1, 2),
p_x=((self.x_dim - 1).bit_length(), self.x_dim),
p_y=((self.y_dim - 1).bit_length(), self.y_dim),
alpha=(1, 2),
q_x=((self.x_dim - 1).bit_length(), self.x_dim),
q_y=((self.y_dim - 1).bit_length(), self.y_dim),
beta=(1, 2),
return infra.SelectionRegisters(
[
infra.SelectionRegister('U', 1, 2),
infra.SelectionRegister('U', 1, 2),
infra.SelectionRegister('V', 1, 2),
infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim),
infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim),
infra.SelectionRegister('alpha', 1, 2),
infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim),
infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim),
infra.SelectionRegister('beta', 1, 2),
]
)

@cached_property
Expand All @@ -145,17 +149,22 @@ def registers(self) -> infra.Registers:
)

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
self,
*,
context: cirq.DecompositionContext,
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
) -> cirq.OP_TREE:
p_x, p_y, q_x, q_y = quregs['p_x'], quregs['p_y'], quregs['q_x'], quregs['q_y']
U, V, alpha, beta = quregs['U'], quregs['V'], quregs['alpha'], quregs['beta']
control, target = quregs.get('control', ()), quregs['target']

yield selected_majorana_fermion.SelectedMajoranaFermionGate(
selection_regs=infra.SelectionRegisters.build(
alpha=(1, 2),
p_y=(self.registers['p_y'].bitsize, self.y_dim),
p_x=(self.registers['p_x'].bitsize, self.x_dim),
selection_regs=infra.SelectionRegisters(
[
infra.SelectionRegister('alpha', 1, 2),
infra.SelectionRegister('p_y', self.registers['p_y'].total_bits(), self.y_dim),
infra.SelectionRegister('p_x', self.registers['p_x'].total_bits(), self.x_dim),
]
),
control_regs=self.control_registers,
target_gate=cirq.Y,
Expand All @@ -165,10 +174,12 @@ def decompose_from_registers(
yield swap_network.MultiTargetCSwap.make_on(control=V, target_x=p_y, target_y=q_y)
yield swap_network.MultiTargetCSwap.make_on(control=V, target_x=alpha, target_y=beta)

q_selection_regs = infra.SelectionRegisters.build(
beta=(1, 2),
q_y=(self.registers['q_y'].bitsize, self.y_dim),
q_x=(self.registers['q_x'].bitsize, self.x_dim),
q_selection_regs = infra.SelectionRegisters(
[
infra.SelectionRegister('beta', 1, 2),
infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim),
infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim),
]
)
yield selected_majorana_fermion.SelectedMajoranaFermionGate(
selection_regs=q_selection_regs, control_regs=self.control_registers, target_gate=cirq.X
Expand All @@ -190,12 +201,14 @@ def decompose_from_registers(
]

yield apply_gate_to_lth_target.ApplyGateToLthQubit(
selection_regs=infra.SelectionRegisters.build(
q_y=(self.registers['q_y'].bitsize, self.y_dim),
q_x=(self.registers['q_x'].bitsize, self.x_dim),
selection_regs=infra.SelectionRegisters(
[
infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim),
infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim),
]
),
nth_gate=lambda *_: cirq.Z,
control_regs=infra.Registers.build(control=1 + self.control_registers.bitsize),
control_regs=infra.Registers.build(control=1 + self.control_registers.total_bits()),
).on_registers(
q_x=q_x, q_y=q_y, control=[*V, *control], target=target_qubits_for_apply_to_lth_gate
)
Expand Down Expand Up @@ -280,15 +293,17 @@ def __attrs_post_init__(self):

@cached_property
def selection_registers(self) -> infra.SelectionRegisters:
return infra.SelectionRegisters.build(
U=(1, 2),
V=(1, 2),
p_x=((self.x_dim - 1).bit_length(), self.x_dim),
p_y=((self.y_dim - 1).bit_length(), self.y_dim),
alpha=(1, 2),
q_x=((self.x_dim - 1).bit_length(), self.x_dim),
q_y=((self.y_dim - 1).bit_length(), self.y_dim),
beta=(1, 2),
return infra.SelectionRegisters(
[
infra.SelectionRegister('U', 1, 2),
infra.SelectionRegister('V', 1, 2),
infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim),
infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim),
infra.SelectionRegister('alpha', 1, 2),
infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim),
infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim),
infra.SelectionRegister('beta', 1, 2),
]
)

@cached_property
Expand All @@ -300,7 +315,7 @@ def registers(self) -> infra.Registers:
return infra.Registers([*self.selection_registers, *self.junk_registers])

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
) -> cirq.OP_TREE:
p_x, p_y, q_x, q_y = quregs['p_x'], quregs['p_y'], quregs['q_x'], quregs['q_y']
U, V, alpha, beta = quregs['U'], quregs['V'], quregs['alpha'], quregs['beta']
Expand Down
15 changes: 10 additions & 5 deletions cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence
from numpy.typing import NDArray

import attr
import cirq
Expand Down Expand Up @@ -49,10 +49,15 @@ def registers(self) -> infra.Registers:
return infra.Registers([*self.control_registers, *self.selection_registers])

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid]
self,
*,
context: cirq.DecompositionContext,
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
) -> cirq.OP_TREE:
qm = context.qubit_manager
target_reg = {reg.name: qm.qalloc(reg.bitsize) for reg in self.encoder.target_registers}
target_reg = {
reg.name: qm.qalloc(reg.total_bits()) for reg in self.encoder.target_registers
}
target_qubits = self.encoder.target_registers.merge_qubits(**target_reg)
encoder_op = self.encoder.on_registers(**quregs, **target_reg)

Expand All @@ -73,6 +78,6 @@ def decompose_from_registers(
qm.qfree([*arctan_sign, *arctan_target, *target_qubits])

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
wire_symbols = ['@'] * self.control_registers.bitsize
wire_symbols += ['ROTy'] * self.selection_registers.bitsize
wire_symbols = ['@'] * self.control_registers.total_bits()
wire_symbols += ['ROTy'] * self.selection_registers.total_bits()
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def control_registers(self) -> cirq_ft.Registers:

@cached_property
def selection_registers(self) -> cirq_ft.SelectionRegisters:
return cirq_ft.SelectionRegisters.build(selection=(self.bitsize, 2**self.bitsize))
return cirq_ft.SelectionRegisters.build(selection=self.bitsize)

@cached_property
def target_registers(self) -> cirq_ft.Registers:
Expand Down