From 83ede366b76f2e29d371e1ac664cf3d07341706e Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 17 Jul 2023 13:48:02 -0700 Subject: [PATCH] Make Cirq-FT registers multi-dimensional (#6200) * Make Cirq-FT registers multi-dimensional * Update gate_with_registers notebook * Fix mypy, lint and coverage checks * Change Registers.bitsize to Registers.total_bits() * Fix mypy errors in arithmetic_gates * Change SelectionRegisters.build to match Registers.build and address other nits * Fix typo --- cirq-ft/cirq_ft/algos/and_gate.py | 11 +- cirq-ft/cirq_ft/algos/and_gate_test.py | 10 +- .../algos/apply_gate_to_lth_target.ipynb | 4 +- .../cirq_ft/algos/apply_gate_to_lth_target.py | 12 +- .../algos/apply_gate_to_lth_target_test.py | 10 +- cirq-ft/cirq_ft/algos/arithmetic_gates.py | 8 +- cirq-ft/cirq_ft/algos/generic_select.py | 19 ++- cirq-ft/cirq_ft/algos/hubbard_model.py | 78 +++++++----- .../mean_estimation/complex_phase_oracle.py | 15 ++- .../complex_phase_oracle_test.py | 2 +- .../mean_estimation_operator.py | 10 +- .../mean_estimation_operator_test.py | 18 +-- .../algos/multi_control_multi_target_pauli.py | 12 +- .../algos/prepare_uniform_superposition.py | 14 ++- .../prepare_uniform_superposition_test.py | 2 +- .../algos/programmable_rotation_gate_array.py | 17 +-- .../programmable_rotation_gate_array_test.py | 8 +- cirq-ft/cirq_ft/algos/qrom.py | 24 ++-- cirq-ft/cirq_ft/algos/qrom_test.py | 14 ++- .../algos/qubitization_walk_operator.py | 13 +- .../algos/qubitization_walk_operator_test.py | 2 +- .../cirq_ft/algos/reflection_using_prepare.py | 15 ++- cirq-ft/cirq_ft/algos/select_swap_qrom.py | 19 ++- .../algos/selected_majorana_fermion.py | 38 ++++-- .../algos/selected_majorana_fermion_test.py | 30 +++-- cirq-ft/cirq_ft/algos/state_preparation.py | 17 ++- cirq-ft/cirq_ft/algos/swap_network.py | 29 ++--- cirq-ft/cirq_ft/algos/swap_network_test.py | 7 +- cirq-ft/cirq_ft/algos/unary_iteration.ipynb | 5 +- cirq-ft/cirq_ft/algos/unary_iteration_gate.py | 23 ++-- .../algos/unary_iteration_gate_test.py | 26 ++-- .../cirq_ft/infra/gate_with_registers.ipynb | 8 +- cirq-ft/cirq_ft/infra/gate_with_registers.py | 115 ++++++++++++------ .../cirq_ft/infra/gate_with_registers_test.py | 33 +++-- cirq-ft/cirq_ft/infra/testing.py | 4 +- cirq-ft/cirq_ft/infra/testing_test.py | 5 +- 36 files changed, 427 insertions(+), 250 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/and_gate.py b/cirq-ft/cirq_ft/algos/and_gate.py index b3646fdf825..aab45a978e7 100644 --- a/cirq-ft/cirq_ft/algos/and_gate.py +++ b/cirq-ft/cirq_ft/algos/and_gate.py @@ -14,6 +14,9 @@ from typing import Sequence, Tuple +import numpy as np +from numpy.typing import NDArray + import attr import cirq from cirq._compat import cached_property @@ -110,16 +113,16 @@ def _decompose_single_and( def _decompose_via_tree( self, - controls: Sequence[cirq.Qid], + controls: NDArray[cirq.Qid], # type:ignore[type-var] control_values: Sequence[int], - ancillas: Sequence[cirq.Qid], + ancillas: NDArray[cirq.Qid], target: cirq.Qid, ) -> cirq.ops.op_tree.OpTree: """Decomposes multi-controlled `And` in-terms of an `And` ladder of size #controls- 2.""" if len(controls) == 2: yield And(control_values, adjoint=self.adjoint).on(*controls, target) return - new_controls = (ancillas[0], *controls[2:]) + new_controls = np.concatenate([ancillas[0:1], controls[2:]]) new_control_values = (1, *control_values[2:]) and_op = And(control_values[:2], adjoint=self.adjoint).on(*controls[:2], ancillas[0]) if self.adjoint: @@ -134,7 +137,7 @@ def _decompose_via_tree( ) def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: control, ancilla, target = quregs['control'], quregs['ancilla'], quregs['target'] if len(self.cv) == 2: diff --git a/cirq-ft/cirq_ft/algos/and_gate_test.py b/cirq-ft/cirq_ft/algos/and_gate_test.py index 117f3a6ccac..f41b6a271c1 100644 --- a/cirq-ft/cirq_ft/algos/and_gate_test.py +++ b/cirq-ft/cirq_ft/algos/and_gate_test.py @@ -45,7 +45,7 @@ def random_cv(n: int) -> List[int]: def test_multi_controlled_and_gate(cv: List[int]): gate = cirq_ft.And(cv) r = gate.registers - assert r['ancilla'].bitsize == r['control'].bitsize - 2 + assert r['ancilla'].total_bits() == r['control'].total_bits() - 2 quregs = r.get_named_qubits() and_op = gate.on_registers(**quregs) circuit = cirq.Circuit(and_op) @@ -54,7 +54,7 @@ def test_multi_controlled_and_gate(cv: List[int]): qubit_order = gate.registers.merge_qubits(**quregs) for input_control in input_controls: - initial_state = input_control + [0] * (r['ancilla'].bitsize + 1) + initial_state = input_control + [0] * (r['ancilla'].total_bits() + 1) result = cirq.Simulator(dtype=np.complex128).simulate( circuit, initial_state=initial_state, qubit_order=qubit_order ) @@ -80,8 +80,10 @@ def test_and_gate_diagram(): qubit_regs = gate.registers.get_named_qubits() op = gate.on_registers(**qubit_regs) # Qubit order should be alternating (control, ancilla) pairs. - c_and_a = sum(zip(qubit_regs["control"][1:], qubit_regs["ancilla"] + [0]), ())[:-1] - qubit_order = qubit_regs["control"][0:1] + list(c_and_a) + qubit_regs["target"] + c_and_a = sum(zip(qubit_regs["control"][1:], qubit_regs["ancilla"]), ()) + ( + qubit_regs["control"][-1], + ) + qubit_order = np.concatenate([qubit_regs["control"][0:1], c_and_a, qubit_regs["target"]]) # Test diagrams. cirq.testing.assert_has_diagram( cirq.Circuit(op), diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb index 3cc9c31a14a..90e4ac086fe 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb @@ -89,7 +89,7 @@ " return cirq.I\n", "\n", "apply_z_to_odd = cirq_ft.ApplyGateToLthQubit(\n", - " cirq_ft.SelectionRegisters.build(selection=(3, 4)),\n", + " cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 4)]),\n", " nth_gate=_z_to_odd,\n", " control_regs=cirq_ft.Registers.build(control=2),\n", ")\n", @@ -123,4 +123,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py index 021021cc4e1..e796b1b05f0 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py @@ -56,8 +56,12 @@ def make_on( ) -> cirq.Operation: """Helper constructor to automatically deduce bitsize attributes.""" return cls( - infra.SelectionRegisters.build( - selection=(len(quregs['selection']), len(quregs['target'])) + infra.SelectionRegisters( + [ + infra.SelectionRegister( + 'selection', len(quregs['selection']), len(quregs['target']) + ) + ] ), nth_gate=nth_gate, control_regs=infra.Registers.build(control=len(quregs['control'])), @@ -76,8 +80,8 @@ def target_registers(self) -> infra.Registers: return infra.Registers.build(target=self.selection_registers.total_iteration_size) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ["@"] * self.control_registers.bitsize - wire_symbols += ["In"] * self.selection_registers.bitsize + wire_symbols = ["@"] * self.control_registers.total_bits() + wire_symbols += ["In"] * self.selection_registers.total_bits() for it in itertools.product(*[range(x) for x in self.selection_regs.iteration_lengths]): wire_symbols += [str(self.nth_gate(*it))] return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py index 17148febef2..da285792d36 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py @@ -23,14 +23,16 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize): greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True) gate = cirq_ft.ApplyGateToLthQubit( - cirq_ft.SelectionRegisters.build(selection=(selection_bitsize, target_bitsize)), + cirq_ft.SelectionRegisters( + [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] + ), lambda _: cirq.X, ) g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm)) # Upper bounded because not all ancillas may be used as part of unary iteration. assert ( len(g.all_qubits) - <= target_bitsize + 2 * (selection_bitsize + gate.control_registers.bitsize) - 1 + <= target_bitsize + 2 * (selection_bitsize + gate.control_registers.total_bits()) - 1 ) for n in range(target_bitsize): @@ -52,7 +54,7 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize): def test_apply_gate_to_lth_qubit_diagram(): # Apply Z gate to all odd targets and Identity to even targets. gate = cirq_ft.ApplyGateToLthQubit( - cirq_ft.SelectionRegisters.build(selection=(3, 5)), + cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]), lambda n: cirq.Z if n & 1 else cirq.I, control_regs=cirq_ft.Registers.build(control=2), ) @@ -87,7 +89,7 @@ def test_apply_gate_to_lth_qubit_diagram(): def test_apply_gate_to_lth_qubit_make_on(): gate = cirq_ft.ApplyGateToLthQubit( - cirq_ft.SelectionRegisters.build(selection=(3, 5)), + cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]), lambda n: cirq.Z if n & 1 else cirq.I, control_regs=cirq_ft.Registers.build(control=2), ) diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index 8f51d019794..d477c598e77 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Iterable, Optional, Sequence, Tuple, Union, List, Iterator +from numpy.typing import NDArray from cirq._compat import cached_property import attr @@ -153,7 +154,10 @@ def __repr__(self) -> str: return f'cirq_ft.algos.BiQubitsMixer({self.adjoint})' def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, + *, + context: cirq.DecompositionContext, + **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: x, y, ancilla = quregs['x'], quregs['y'], quregs['ancilla'] x_msb, x_lsb = x @@ -224,7 +228,7 @@ def __repr__(self) -> str: return f'cirq_ft.algos.SingleQubitCompare({self.adjoint})' def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: a = quregs['a'] b = quregs['b'] diff --git a/cirq-ft/cirq_ft/algos/generic_select.py b/cirq-ft/cirq_ft/algos/generic_select.py index e70f70b35ce..68d62cf98f6 100644 --- a/cirq-ft/cirq_ft/algos/generic_select.py +++ b/cirq-ft/cirq_ft/algos/generic_select.py @@ -15,6 +15,7 @@ """Gates for applying generic selected unitaries.""" from typing import Collection, Optional, Sequence, Tuple, Union +from numpy.typing import NDArray import attr import cirq @@ -73,20 +74,26 @@ def control_registers(self) -> infra.Registers: @cached_property def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters.build( - selection=(self.selection_bitsize, len(self.select_unitaries)) + return infra.SelectionRegisters( + [ + infra.SelectionRegister( + 'selection', self.selection_bitsize, len(self.select_unitaries) + ) + ] ) @cached_property def target_registers(self) -> infra.Registers: return infra.Registers.build(target=self.target_bitsize) - def decompose_from_registers(self, context, **qubit_regs: Sequence[cirq.Qid]) -> cirq.OP_TREE: + def decompose_from_registers( + self, context, **quregs: NDArray[cirq.Qid] # type:ignore[type-var] + ) -> cirq.OP_TREE: if self.control_val == 0: - yield cirq.X(*qubit_regs['control']) - yield super(GenericSelect, self).decompose_from_registers(context=context, **qubit_regs) + yield cirq.X(*quregs['control']) + yield super(GenericSelect, self).decompose_from_registers(context=context, **quregs) if self.control_val == 0: - yield cirq.X(*qubit_regs['control']) + yield cirq.X(*quregs['control']) def nth_operation( # type: ignore[override] self, diff --git a/cirq-ft/cirq_ft/algos/hubbard_model.py b/cirq-ft/cirq_ft/algos/hubbard_model.py index 7093a734c7b..520d305062a 100644 --- a/cirq-ft/cirq_ft/algos/hubbard_model.py +++ b/cirq-ft/cirq_ft/algos/hubbard_model.py @@ -47,6 +47,7 @@ See the documentation for `PrepareHubbard` and `SelectHubbard` for details. """ from typing import Collection, Optional, Sequence, Tuple, Union +from numpy.typing import NDArray import attr import cirq @@ -123,15 +124,17 @@ def control_registers(self) -> infra.Registers: @cached_property def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters.build( - U=(1, 2), - V=(1, 2), - p_x=((self.x_dim - 1).bit_length(), self.x_dim), - p_y=((self.y_dim - 1).bit_length(), self.y_dim), - alpha=(1, 2), - q_x=((self.x_dim - 1).bit_length(), self.x_dim), - q_y=((self.y_dim - 1).bit_length(), self.y_dim), - beta=(1, 2), + return infra.SelectionRegisters( + [ + infra.SelectionRegister('U', 1, 2), + infra.SelectionRegister('V', 1, 2), + infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim), + infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim), + infra.SelectionRegister('alpha', 1, 2), + infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim), + infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim), + infra.SelectionRegister('beta', 1, 2), + ] ) @cached_property @@ -145,17 +148,22 @@ def registers(self) -> infra.Registers: ) def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, + *, + context: cirq.DecompositionContext, + **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: p_x, p_y, q_x, q_y = quregs['p_x'], quregs['p_y'], quregs['q_x'], quregs['q_y'] U, V, alpha, beta = quregs['U'], quregs['V'], quregs['alpha'], quregs['beta'] control, target = quregs.get('control', ()), quregs['target'] yield selected_majorana_fermion.SelectedMajoranaFermionGate( - selection_regs=infra.SelectionRegisters.build( - alpha=(1, 2), - p_y=(self.registers['p_y'].bitsize, self.y_dim), - p_x=(self.registers['p_x'].bitsize, self.x_dim), + selection_regs=infra.SelectionRegisters( + [ + infra.SelectionRegister('alpha', 1, 2), + infra.SelectionRegister('p_y', self.registers['p_y'].total_bits(), self.y_dim), + infra.SelectionRegister('p_x', self.registers['p_x'].total_bits(), self.x_dim), + ] ), control_regs=self.control_registers, target_gate=cirq.Y, @@ -165,10 +173,12 @@ def decompose_from_registers( yield swap_network.MultiTargetCSwap.make_on(control=V, target_x=p_y, target_y=q_y) yield swap_network.MultiTargetCSwap.make_on(control=V, target_x=alpha, target_y=beta) - q_selection_regs = infra.SelectionRegisters.build( - beta=(1, 2), - q_y=(self.registers['q_y'].bitsize, self.y_dim), - q_x=(self.registers['q_x'].bitsize, self.x_dim), + q_selection_regs = infra.SelectionRegisters( + [ + infra.SelectionRegister('beta', 1, 2), + infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim), + infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim), + ] ) yield selected_majorana_fermion.SelectedMajoranaFermionGate( selection_regs=q_selection_regs, control_regs=self.control_registers, target_gate=cirq.X @@ -190,12 +200,14 @@ def decompose_from_registers( ] yield apply_gate_to_lth_target.ApplyGateToLthQubit( - selection_regs=infra.SelectionRegisters.build( - q_y=(self.registers['q_y'].bitsize, self.y_dim), - q_x=(self.registers['q_x'].bitsize, self.x_dim), + selection_regs=infra.SelectionRegisters( + [ + infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim), + infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim), + ] ), nth_gate=lambda *_: cirq.Z, - control_regs=infra.Registers.build(control=1 + self.control_registers.bitsize), + control_regs=infra.Registers.build(control=1 + self.control_registers.total_bits()), ).on_registers( q_x=q_x, q_y=q_y, control=[*V, *control], target=target_qubits_for_apply_to_lth_gate ) @@ -280,15 +292,17 @@ def __attrs_post_init__(self): @cached_property def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters.build( - U=(1, 2), - V=(1, 2), - p_x=((self.x_dim - 1).bit_length(), self.x_dim), - p_y=((self.y_dim - 1).bit_length(), self.y_dim), - alpha=(1, 2), - q_x=((self.x_dim - 1).bit_length(), self.x_dim), - q_y=((self.y_dim - 1).bit_length(), self.y_dim), - beta=(1, 2), + return infra.SelectionRegisters( + [ + infra.SelectionRegister('U', 1, 2), + infra.SelectionRegister('V', 1, 2), + infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim), + infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim), + infra.SelectionRegister('alpha', 1, 2), + infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim), + infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim), + infra.SelectionRegister('beta', 1, 2), + ] ) @cached_property @@ -300,7 +314,7 @@ def registers(self) -> infra.Registers: return infra.Registers([*self.selection_registers, *self.junk_registers]) def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: p_x, p_y, q_x, q_y = quregs['p_x'], quregs['p_y'], quregs['q_x'], quregs['q_y'] U, V, alpha, beta = quregs['U'], quregs['V'], quregs['alpha'], quregs['beta'] diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py index 75e4f87a741..74f992b0f5d 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence +from numpy.typing import NDArray import attr import cirq @@ -49,10 +49,15 @@ def registers(self) -> infra.Registers: return infra.Registers([*self.control_registers, *self.selection_registers]) def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, + *, + context: cirq.DecompositionContext, + **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: qm = context.qubit_manager - target_reg = {reg.name: qm.qalloc(reg.bitsize) for reg in self.encoder.target_registers} + target_reg = { + reg.name: qm.qalloc(reg.total_bits()) for reg in self.encoder.target_registers + } target_qubits = self.encoder.target_registers.merge_qubits(**target_reg) encoder_op = self.encoder.on_registers(**quregs, **target_reg) @@ -73,6 +78,6 @@ def decompose_from_registers( qm.qfree([*arctan_sign, *arctan_target, *target_qubits]) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ['@'] * self.control_registers.bitsize - wire_symbols += ['ROTy'] * self.selection_registers.bitsize + wire_symbols = ['@'] * self.control_registers.total_bits() + wire_symbols += ['ROTy'] * self.selection_registers.total_bits() return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle_test.py b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle_test.py index 6d701f3f571..7e6b61527c2 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle_test.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle_test.py @@ -38,7 +38,7 @@ def control_registers(self) -> cirq_ft.Registers: @cached_property def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters.build(selection=(self.bitsize, 2**self.bitsize)) + return cirq_ft.SelectionRegisters.build(selection=self.bitsize) @cached_property def target_registers(self) -> cirq_ft.Registers: diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py index b8c45633b05..f3fff37ecd7 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Collection, Optional, Sequence, Tuple, Union +from numpy.typing import NDArray import attr import cirq @@ -110,7 +111,10 @@ def registers(self) -> infra.Registers: return infra.Registers([*self.control_registers, *self.selection_registers]) def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, + *, + context: cirq.DecompositionContext, + **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: select_reg = {reg.name: quregs[reg.name] for reg in self.select.registers} reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.registers} @@ -125,7 +129,9 @@ def decompose_from_registers( def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: wire_symbols = [] if self.cv == () else [["@(0)", "@"][self.cv[0]]] - wire_symbols += ['U_ko'] * (self.registers.bitsize - self.control_registers.bitsize) + wire_symbols += ['U_ko'] * ( + self.registers.total_bits() - self.control_registers.total_bits() + ) if self.power != 1: wire_symbols[-1] = f'U_ko^{self.power}' return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py index 914b31eeae9..5460c8ad194 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py @@ -33,7 +33,7 @@ class BernoulliSynthesizer(cirq_ft.PrepareOracle): @cached_property def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters.build(q=(self.nqubits, 2)) + return cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('q', self.nqubits, 2)]) def decompose_from_registers( # type:ignore[override] self, context, q: Sequence[cirq.Qid] @@ -60,7 +60,9 @@ def control_registers(self) -> cirq_ft.Registers: @cached_property def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters.build(q=(self.selection_bitsize, 2)) + return cirq_ft.SelectionRegisters( + [cirq_ft.SelectionRegister('q', self.selection_bitsize, 2)] + ) @cached_property def target_registers(self) -> cirq_ft.Registers: @@ -126,7 +128,7 @@ def satisfies_theorem_321( overlap_sum = 0.0 eigvals, eigvects = cirq.linalg.unitary_eig(u) - for (eig_val, eig_vect) in zip(eigvals, eigvects.T): + for eig_val, eig_vect in zip(eigvals, eigvects.T): theta = np.abs(np.angle(eig_val)) hav_theta = np.sin(theta / 2) overlap_prob = overlap(prep_state, eig_vect) @@ -173,7 +175,7 @@ class GroverSynthesizer(cirq_ft.PrepareOracle): @cached_property def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters.build(selection=(self.n, 2**self.n)) + return cirq_ft.SelectionRegisters.build(selection=self.n) def decompose_from_registers( # type:ignore[override] self, *, context, selection: Sequence[cirq.Qid] @@ -200,7 +202,7 @@ def control_registers(self) -> cirq_ft.Registers: @cached_property def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters.build(selection=(self.n, 2**self.n)) + return cirq_ft.SelectionRegisters.build(selection=self.n) @cached_property def target_registers(self) -> cirq_ft.Registers: @@ -209,8 +211,10 @@ def target_registers(self) -> cirq_ft.Registers: def decompose_from_registers( # type:ignore[override] self, context, *, selection: Sequence[cirq.Qid], target: Sequence[cirq.Qid] ) -> cirq.OP_TREE: - selection_cv = [*bit_tools.iter_bits(self.marked_item, self.selection_registers.bitsize)] - yval_bin = [*bit_tools.iter_bits(self.marked_val, self.target_registers.bitsize)] + selection_cv = [ + *bit_tools.iter_bits(self.marked_item, self.selection_registers.total_bits()) + ] + yval_bin = [*bit_tools.iter_bits(self.marked_val, self.target_registers.total_bits())] for b, q in zip(yval_bin, target): if b: diff --git a/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py b/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py index 8e39bb21095..6ab7e65e51f 100644 --- a/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py +++ b/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Tuple +from typing import Tuple +from numpy.typing import NDArray import attr import cirq @@ -38,11 +39,14 @@ def registers(self) -> infra.Registers: return infra.Registers.build(control=1, targets=self._num_targets) def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, + *, + context: cirq.DecompositionContext, + **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ): control, targets = quregs['control'], quregs['targets'] - def cnots_for_depth_i(i: int, q: Sequence[cirq.Qid]) -> cirq.OP_TREE: + def cnots_for_depth_i(i: int, q: NDArray[cirq.Qid]) -> cirq.OP_TREE: for c, t in zip(q[: 2**i], q[2**i : min(len(q), 2 ** (i + 1))]): yield cirq.CNOT(c, t) @@ -77,7 +81,7 @@ def registers(self) -> infra.Registers: return infra.Registers.build(controls=len(self.cvs), target=1) def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence['cirq.Qid'] + self, *, context: cirq.DecompositionContext, **quregs: NDArray['cirq.Qid'] ) -> cirq.OP_TREE: controls, target = quregs['controls'], quregs['target'] qm = context.qubit_manager diff --git a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py index a1ce5256499..e75f735bfe9 100644 --- a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py +++ b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Tuple +from typing import Tuple +from numpy.typing import NDArray import attr import cirq @@ -56,12 +57,15 @@ def __repr__(self) -> str: def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: control_symbols = ["@" if cv else "@(0)" for cv in self.cv] - target_symbols = ['target'] * self.registers['target'].bitsize + target_symbols = ['target'] * self.registers['target'].total_bits() target_symbols[0] = f"UNIFORM({self.n})" return cirq.CircuitDiagramInfo(wire_symbols=control_symbols + target_symbols) def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, + *, + context: cirq.DecompositionContext, + **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: controls, target = quregs['controls'], quregs['target'] # Find K and L as per https://arxiv.org/abs/1805.03662 Fig 12. @@ -69,13 +73,13 @@ def decompose_from_registers( while n > 1 and n % 2 == 0: k += 1 n = n // 2 - l, logL = int(n), self.registers['target'].bitsize - k + l, logL = int(n), self.registers['target'].total_bits() - k logL_qubits = target[:logL] yield [ op.controlled_by(*controls, control_values=self.cv) for op in cirq.H.on_each(*target) ] - if not logL_qubits: + if not len(logL_qubits): return ancilla = context.qubit_manager.qalloc(1) diff --git a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py index 63e88c1749d..b89ac0bb698 100644 --- a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py +++ b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py @@ -51,7 +51,7 @@ def test_prepare_uniform_superposition_t_complexity(n: int): result = cirq_ft.t_complexity(gate) # TODO(#233): Controlled-H is currently counted as a separate rotation, but it can be # implemented using 2 T-gates. - assert result.rotations <= 2 + 2 * gate.registers.bitsize + assert result.rotations <= 2 + 2 * gate.registers.total_bits() assert result.t <= 12 * (n - 1).bit_length() diff --git a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py index 5d7a825218e..47e43857247 100644 --- a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py +++ b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py @@ -14,6 +14,7 @@ import abc from typing import Sequence, Tuple +from numpy.typing import NDArray import cirq import numpy as np @@ -95,13 +96,15 @@ def rotation_gate(self, exponent: int = -1) -> cirq.Gate: return cirq.pow(self._rotation_gate, power) @abc.abstractmethod - def interleaved_unitary(self, index: int, **qubit_regs: Sequence[cirq.Qid]) -> cirq.Operation: + def interleaved_unitary( + self, index: int, **qubit_regs: NDArray[cirq.Qid] # type:ignore[type-var] + ) -> cirq.Operation: pass @cached_property def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters.build( - selection=(self._selection_bitsize, len(self.angles[0])) + return infra.SelectionRegisters( + [infra.SelectionRegister('selection', self._selection_bitsize, len(self.angles[0]))] ) @cached_property @@ -129,7 +132,7 @@ def registers(self) -> infra.Registers: ) def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: selection, kappa_load_target = quregs.pop('selection'), quregs.pop('kappa_load_target') rotations_target = quregs.pop('rotations_target') @@ -138,7 +141,7 @@ def decompose_from_registers( # 1. Find a convenient way to process batches of size kappa. num_bits = sum(max(thetas).bit_length() for thetas in self.angles) iteration_length = self.selection_registers[0].iteration_length - selection_bitsizes = [s.bitsize for s in self.selection_registers] + selection_bitsizes = [s.total_bits() for s in self.selection_registers] angles_bits = np.zeros(shape=(iteration_length, num_bits), dtype=int) angles_bit_pow = np.zeros(shape=(num_bits,), dtype=int) angles_idx = np.zeros(shape=(num_bits,), dtype=int) @@ -192,13 +195,13 @@ def __init__( ): super().__init__(*angles, kappa=kappa, rotation_gate=rotation_gate) if not interleaved_unitaries: - identity_gate = cirq.IdentityGate(self.rotations_target.bitsize) + identity_gate = cirq.IdentityGate(self.rotations_target.total_bits()) interleaved_unitaries = (identity_gate,) * (len(angles) - 1) assert len(interleaved_unitaries) == len(angles) - 1 assert all(cirq.num_qubits(u) == self._target_bitsize for u in interleaved_unitaries) self._interleaved_unitaries = tuple(interleaved_unitaries) - def interleaved_unitary(self, index: int, **qubit_regs: Sequence[cirq.Qid]) -> cirq.Operation: + def interleaved_unitary(self, index: int, **qubit_regs: NDArray[cirq.Qid]) -> cirq.Operation: return self._interleaved_unitaries[index].on(*qubit_regs['rotations_target']) @cached_property diff --git a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py index 3813e6de7ce..a860bb07b6e 100644 --- a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py +++ b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence +from numpy.typing import NDArray import cirq import cirq_ft @@ -23,7 +23,9 @@ class CustomProgrammableRotationGateArray(cirq_ft.ProgrammableRotationGateArrayBase): - def interleaved_unitary(self, index: int, **qubit_regs: Sequence[cirq.Qid]) -> cirq.Operation: + def interleaved_unitary( + self, index: int, **qubit_regs: NDArray[cirq.Qid] # type:ignore[type-var] + ) -> cirq.Operation: two_qubit_ops_factory = [ cirq.X(*qubit_regs['unrelated_target']).controlled_by(*qubit_regs['rotations_target']), cirq.Z(*qubit_regs['unrelated_target']).controlled_by(*qubit_regs['rotations_target']), @@ -91,7 +93,7 @@ def rotation_ops(theta: int) -> cirq.OP_TREE: # Set bits in initial_state s.t. selection register stores `selection_integer`. qubit_vals = {x: 0 for x in g.all_qubits} qubit_vals.update( - zip(g.quregs['selection'], iter_bits(selection_integer, g.r['selection'].bitsize)) + zip(g.quregs['selection'], iter_bits(selection_integer, g.r['selection'].total_bits())) ) initial_state = [qubit_vals[x] for x in g.all_qubits] # Actual circuit simulation. diff --git a/cirq-ft/cirq_ft/algos/qrom.py b/cirq-ft/cirq_ft/algos/qrom.py index b5b0defe091..e07ffe29b86 100644 --- a/cirq-ft/cirq_ft/algos/qrom.py +++ b/cirq-ft/cirq_ft/algos/qrom.py @@ -89,15 +89,19 @@ def control_registers(self) -> infra.Registers: @cached_property def selection_registers(self) -> infra.SelectionRegisters: if len(self.data[0].shape) == 1: - return infra.SelectionRegisters.build( - selection=(self.selection_bitsizes[0], self.data[0].shape[0]) + return infra.SelectionRegisters( + [ + infra.SelectionRegister( + 'selection', self.selection_bitsizes[0], self.data[0].shape[0] + ) + ] ) else: - return infra.SelectionRegisters.build( - **{ - f'selection{i}': (sb, len) + return infra.SelectionRegisters( + [ + infra.SelectionRegister(f'selection{i}', sb, len) for i, (len, sb) in enumerate(zip(self.data[0].shape, self.selection_bitsizes)) - } + ] ) @cached_property @@ -119,7 +123,7 @@ def _load_nth_data( self, selection_idx: Tuple[int, ...], gate: Callable[[cirq.Qid], cirq.Operation], - **target_regs: Sequence[cirq.Qid], + **target_regs: NDArray[cirq.Qid], # type: ignore[type-var] ) -> cirq.OP_TREE: for i, d in enumerate(self.data): target = target_regs[f'target{i}'] @@ -128,7 +132,7 @@ def _load_nth_data( yield gate(q) def decompose_zero_selection( - self, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: controls = self.control_registers.merge_qubits(**quregs) target_regs = {k: v for k, v in quregs.items() if k in self.target_registers} @@ -157,9 +161,9 @@ def nth_operation( def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: wire_symbols = ["@"] * self.num_controls - wire_symbols += ["In"] * self.selection_registers.bitsize + wire_symbols += ["In"] * self.selection_registers.total_bits() for i, target in enumerate(self.target_registers): - wire_symbols += [f"QROM_{i}"] * target.bitsize + wire_symbols += [f"QROM_{i}"] * target.total_bits() return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) def __pow__(self, power: int): diff --git a/cirq-ft/cirq_ft/algos/qrom_test.py b/cirq-ft/cirq_ft/algos/qrom_test.py index bc81ba603b7..7cb67bb84d4 100644 --- a/cirq-ft/cirq_ft/algos/qrom_test.py +++ b/cirq-ft/cirq_ft/algos/qrom_test.py @@ -33,14 +33,19 @@ def test_qrom_1d(data, num_controls): decomposed_circuit = cirq.Circuit(cirq.decompose(g.operation, context=g.context)) inverse = cirq.Circuit(cirq.decompose(g.operation**-1, context=g.context)) - assert len(inverse.all_qubits()) <= g.r.bitsize + g.r['selection'].bitsize + num_controls + assert ( + len(inverse.all_qubits()) <= g.r.total_bits() + g.r['selection'].total_bits() + num_controls + ) assert inverse.all_qubits() == decomposed_circuit.all_qubits() for selection_integer in range(len(data[0])): for cval in range(2): qubit_vals = {x: 0 for x in g.all_qubits} qubit_vals.update( - zip(g.quregs['selection'], iter_bits(selection_integer, g.r['selection'].bitsize)) + zip( + g.quregs['selection'], + iter_bits(selection_integer, g.r['selection'].total_bits()), + ) ) if num_controls: qubit_vals.update(zip(g.quregs['control'], [cval] * num_controls)) @@ -131,11 +136,12 @@ def test_qrom_multi_dim(data, num_controls): inverse = cirq.Circuit(cirq.decompose(g.operation**-1, context=g.context)) assert ( - len(inverse.all_qubits()) <= g.r.bitsize + qrom.selection_registers.bitsize + num_controls + len(inverse.all_qubits()) + <= g.r.total_bits() + qrom.selection_registers.total_bits() + num_controls ) assert inverse.all_qubits() == decomposed_circuit.all_qubits() - lens = tuple(reg.bitsize for reg in qrom.selection_registers) + lens = tuple(reg.total_bits() for reg in qrom.selection_registers) for idxs in itertools.product(*[range(dim) for dim in data[0].shape]): qubit_vals = {x: 0 for x in g.all_qubits} for cval in range(2): diff --git a/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py b/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py index 4d2bac72cd0..6910de45905 100644 --- a/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py +++ b/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Collection, Optional, Sequence, Tuple, Union +from numpy.typing import NDArray import attr import cirq @@ -84,20 +85,22 @@ def reflect(self) -> reflection_using_prepare.ReflectionUsingPrepare: ) def decompose_from_registers( - self, context: cirq.DecompositionContext, **qubit_regs: Sequence[cirq.Qid] + self, + context: cirq.DecompositionContext, + **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: - select_reg = {reg.name: qubit_regs[reg.name] for reg in self.select.registers} + select_reg = {reg.name: quregs[reg.name] for reg in self.select.registers} select_op = self.select.on_registers(**select_reg) - reflect_reg = {reg.name: qubit_regs[reg.name] for reg in self.reflect.registers} + reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.registers} reflect_op = self.reflect.on_registers(**reflect_reg) for _ in range(self.power): yield select_op yield reflect_op def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ['@' if self.control_val else '@(0)'] * self.control_registers.bitsize - wire_symbols += ['W'] * (self.registers.bitsize - self.control_registers.bitsize) + wire_symbols = ['@' if self.control_val else '@(0)'] * self.control_registers.total_bits() + wire_symbols += ['W'] * (self.registers.total_bits() - self.control_registers.total_bits()) wire_symbols[-1] = f'W^{self.power}' if self.power != 1 else 'W' return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py b/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py index da7fa561518..8ea413661da 100644 --- a/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py +++ b/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py @@ -31,7 +31,7 @@ def walk_operator_for_pauli_hamiltonian( ham_coeff, probability_epsilon=eps ) select = cirq_ft.GenericSelect( - prepare.selection_registers.bitsize, select_unitaries=ham_dps, target_bitsize=len(q) + prepare.selection_registers.total_bits(), select_unitaries=ham_dps, target_bitsize=len(q) ) return cirq_ft.QubitizationWalkOperator(select=select, prepare=prepare) diff --git a/cirq-ft/cirq_ft/algos/reflection_using_prepare.py b/cirq-ft/cirq_ft/algos/reflection_using_prepare.py index 503b07c5444..361644fc5ee 100644 --- a/cirq-ft/cirq_ft/algos/reflection_using_prepare.py +++ b/cirq-ft/cirq_ft/algos/reflection_using_prepare.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Collection, Optional, Sequence, Tuple, Union +from numpy.typing import NDArray import attr import cirq @@ -69,15 +70,17 @@ def registers(self) -> infra.Registers: return infra.Registers([*self.control_registers, *self.selection_registers]) def decompose_from_registers( - self, context: cirq.DecompositionContext, **qubit_regs: Sequence[cirq.Qid] + self, + context: cirq.DecompositionContext, + **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: qm = context.qubit_manager # 0. Allocate new ancillas, if needed. - phase_target = qm.qalloc(1)[0] if self.control_val is None else qubit_regs.pop('control')[0] + phase_target = qm.qalloc(1)[0] if self.control_val is None else quregs.pop('control')[0] state_prep_ancilla = { - reg.name: qm.qalloc(reg.bitsize) for reg in self.prepare_gate.junk_registers + reg.name: qm.qalloc(reg.total_bits()) for reg in self.prepare_gate.junk_registers } - state_prep_selection_regs = qubit_regs + state_prep_selection_regs = quregs prepare_op = self.prepare_gate.on_registers( **state_prep_selection_regs, **state_prep_ancilla ) @@ -99,8 +102,8 @@ def decompose_from_registers( qm.qfree([phase_target]) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ['@' if self.control_val else '@(0)'] * self.control_registers.bitsize - wire_symbols += ['R_L'] * self.selection_registers.bitsize + wire_symbols = ['@' if self.control_val else '@(0)'] * self.control_registers.total_bits() + wire_symbols += ['R_L'] * self.selection_registers.total_bits() return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) def __repr__(self): diff --git a/cirq-ft/cirq_ft/algos/select_swap_qrom.py b/cirq-ft/cirq_ft/algos/select_swap_qrom.py index a7b0dda29f0..5773a617c4b 100644 --- a/cirq-ft/cirq_ft/algos/select_swap_qrom.py +++ b/cirq-ft/cirq_ft/algos/select_swap_qrom.py @@ -13,13 +13,13 @@ # limitations under the License. from typing import List, Optional, Sequence, Tuple +from numpy.typing import NDArray import cirq import numpy as np from cirq._compat import cached_property from cirq_ft import infra from cirq_ft.algos import qrom, swap_network -from numpy.typing import NDArray def find_optimal_log_block_size(iteration_length: int, target_bitsize: int) -> int: @@ -139,8 +139,12 @@ def __init__( @cached_property def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters.build( - selection=(self.selection_q + self.selection_r, self._iteration_length) + return infra.SelectionRegisters( + [ + infra.SelectionRegister( + 'selection', self.selection_q + self.selection_r, self._iteration_length + ) + ] ) @cached_property @@ -177,7 +181,10 @@ def __repr__(self) -> str: ) def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, + *, + context: cirq.DecompositionContext, + **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: # Divide each data sequence and corresponding target registers into # `self.num_blocks` batches of size `self.block_size`. @@ -208,7 +215,7 @@ def decompose_from_registers( selection=q, **qrom_gate.target_registers.split_qubits(ordered_target_qubits) ) swap_with_zero_gate = swap_network.SwapWithZeroGate( - k, self.target_registers.bitsize, self.block_size + k, self.target_registers.total_bits(), self.block_size ) swap_with_zero_op = swap_with_zero_gate.on_registers( selection=r, **swap_with_zero_gate.target_registers.split_qubits(ordered_target_qubits) @@ -231,7 +238,7 @@ def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: wire_symbols = ["In_q"] * self.selection_q wire_symbols += ["In_r"] * self.selection_r for i, target in enumerate(self.target_registers): - wire_symbols += [f"QROAM_{i}"] * target.bitsize + wire_symbols += [f"QROAM_{i}"] * target.total_bits() return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) def _value_equality_values_(self): diff --git a/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py b/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py index 7f6ee786dc3..501b23bb786 100644 --- a/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py +++ b/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence +from typing import Sequence, Union +from numpy.typing import NDArray import attr import cirq +import numpy as np + from cirq._compat import cached_property from cirq_ft import infra from cirq_ft.algos import unary_iteration_gate @@ -45,11 +48,20 @@ class SelectedMajoranaFermionGate(unary_iteration_gate.UnaryIterationGate): target_gate: cirq.Gate = cirq.Y @classmethod - def make_on(cls, *, target_gate=cirq.Y, **quregs: Sequence[cirq.Qid]) -> cirq.Operation: + def make_on( + cls, + *, + target_gate=cirq.Y, + **quregs: Union[Sequence[cirq.Qid], NDArray[cirq.Qid]], # type: ignore[type-var] + ) -> cirq.Operation: """Helper constructor to automatically deduce selection_regs attribute.""" return cls( - selection_regs=infra.SelectionRegisters.build( - selection=(len(quregs['selection']), len(quregs['target'])) + selection_regs=infra.SelectionRegisters( + [ + infra.SelectionRegister( + 'selection', len(quregs['selection']), len(quregs['target']) + ) + ] ), target_gate=target_gate, ).on_registers(**quregs) @@ -71,20 +83,20 @@ def extra_registers(self) -> infra.Registers: return infra.Registers.build(accumulator=1) def decompose_from_registers( - self, context: cirq.DecompositionContext, **qubit_regs: Sequence[cirq.Qid] + self, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: - qubit_regs['accumulator'] = context.qubit_manager.qalloc(1) - control = qubit_regs[self.control_regs[0].name] if self.control_registers.bitsize else [] - yield cirq.X(*qubit_regs['accumulator']).controlled_by(*control) + quregs['accumulator'] = np.array(context.qubit_manager.qalloc(1)) + control = quregs[self.control_regs[0].name] if self.control_registers.total_bits() else [] + yield cirq.X(*quregs['accumulator']).controlled_by(*control) yield super(SelectedMajoranaFermionGate, self).decompose_from_registers( - context=context, **qubit_regs + context=context, **quregs ) - context.qubit_manager.qfree(qubit_regs['accumulator']) + context.qubit_manager.qfree(quregs['accumulator']) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ["@"] * self.control_registers.bitsize - wire_symbols += ["In"] * self.selection_registers.bitsize - wire_symbols += [f"Z{self.target_gate}"] * self.target_registers.bitsize + wire_symbols = ["@"] * self.control_registers.total_bits() + wire_symbols += ["In"] * self.selection_registers.total_bits() + wire_symbols += [f"Z{self.target_gate}"] * self.target_registers.total_bits() return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) def nth_operation( # type: ignore[override] diff --git a/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py b/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py index 7fbfc16f37e..44870200d86 100644 --- a/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py +++ b/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py @@ -24,11 +24,13 @@ def test_selected_majorana_fermion_gate(selection_bitsize, target_bitsize, target_gate): greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True) gate = cirq_ft.SelectedMajoranaFermionGate( - cirq_ft.SelectionRegisters.build(selection=(selection_bitsize, target_bitsize)), + cirq_ft.SelectionRegisters( + [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] + ), target_gate=target_gate, ) g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm)) - assert len(g.all_qubits) <= gate.registers.bitsize + selection_bitsize + 1 + assert len(g.all_qubits) <= gate.registers.total_bits() + selection_bitsize + 1 sim = cirq.Simulator(dtype=np.complex128) for n in range(target_bitsize): @@ -64,7 +66,9 @@ def test_selected_majorana_fermion_gate(selection_bitsize, target_bitsize, targe def test_selected_majorana_fermion_gate_diagram(): selection_bitsize, target_bitsize = 3, 5 gate = cirq_ft.SelectedMajoranaFermionGate( - cirq_ft.SelectionRegisters.build(selection=(selection_bitsize, target_bitsize)), + cirq_ft.SelectionRegisters( + [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] + ), target_gate=cirq.X, ) circuit = cirq.Circuit(gate.on_registers(**gate.registers.get_named_qubits())) @@ -97,7 +101,9 @@ def test_selected_majorana_fermion_gate_diagram(): def test_selected_majorana_fermion_gate_decomposed_diagram(): selection_bitsize, target_bitsize = 2, 3 gate = cirq_ft.SelectedMajoranaFermionGate( - cirq_ft.SelectionRegisters.build(selection=(selection_bitsize, target_bitsize)), + cirq_ft.SelectionRegisters( + [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] + ), target_gate=cirq.X, ) greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True) @@ -105,11 +111,13 @@ def test_selected_majorana_fermion_gate_decomposed_diagram(): context = cirq.DecompositionContext(greedy_mm) circuit = cirq.Circuit(cirq.decompose_once(g.operation, context=context)) ancillas = sorted(set(circuit.all_qubits()) - set(g.operation.qubits)) - qubits = ( - g.quregs['control'] - + [q for qs in zip(g.quregs['selection'], ancillas[1:]) for q in qs] - + ancillas[0:1] - + g.quregs['target'] + qubits = np.concatenate( + [ + g.quregs['control'], + [q for qs in zip(g.quregs['selection'], ancillas[1:]) for q in qs], + ancillas[0:1], + g.quregs['target'], + ] ) cirq.testing.assert_has_diagram( circuit, @@ -138,7 +146,9 @@ def test_selected_majorana_fermion_gate_decomposed_diagram(): def test_selected_majorana_fermion_gate_make_on(): selection_bitsize, target_bitsize = 3, 5 gate = cirq_ft.SelectedMajoranaFermionGate( - cirq_ft.SelectionRegisters.build(selection=(selection_bitsize, target_bitsize)), + cirq_ft.SelectionRegisters( + [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] + ), target_gate=cirq.X, ) op = gate.on_registers(**gate.registers.get_named_qubits()) diff --git a/cirq-ft/cirq_ft/algos/state_preparation.py b/cirq-ft/cirq_ft/algos/state_preparation.py index 9dfed3c43d3..cb7fd397b03 100644 --- a/cirq-ft/cirq_ft/algos/state_preparation.py +++ b/cirq-ft/cirq_ft/algos/state_preparation.py @@ -20,7 +20,8 @@ largest absolute error that one can tolerate in the prepared amplitudes. """ -from typing import List, Sequence +from typing import List +from numpy.typing import NDArray import attr import cirq @@ -34,7 +35,6 @@ select_and_prepare, swap_network, ) -from numpy.typing import NDArray @cirq.value_equality() @@ -106,7 +106,9 @@ def from_lcu_probs( ) N = len(lcu_probabilities) return StatePreparationAliasSampling( - selection_registers=infra.SelectionRegisters.build(selection=((N - 1).bit_length(), N)), + selection_registers=infra.SelectionRegisters( + [infra.SelectionRegister('selection', (N - 1).bit_length(), N)] + ), alt=np.array(alt), keep=np.array(keep), mu=mu, @@ -118,7 +120,7 @@ def sigma_mu_bitsize(self) -> int: @cached_property def alternates_bitsize(self) -> int: - return self.selection_registers.bitsize + return self.selection_registers.total_bits() @cached_property def keep_bitsize(self) -> int: @@ -126,7 +128,7 @@ def keep_bitsize(self) -> int: @cached_property def selection_bitsize(self) -> int: - return self.selection_registers.bitsize + return self.selection_registers.total_bits() @cached_property def junk_registers(self) -> infra.Registers: @@ -157,7 +159,10 @@ def __repr__(self) -> str: ) def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, + *, + context: cirq.DecompositionContext, + **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: selection, less_than_equal = quregs['selection'], quregs['less_than_equal'] sigma_mu, alt, keep = quregs['sigma_mu'], quregs['alt'], quregs['keep'] diff --git a/cirq-ft/cirq_ft/algos/swap_network.py b/cirq-ft/cirq_ft/algos/swap_network.py index f88114b3bf0..480bf0e4f3c 100644 --- a/cirq-ft/cirq_ft/algos/swap_network.py +++ b/cirq-ft/cirq_ft/algos/swap_network.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence +from typing import Sequence, Union +from numpy.typing import NDArray import attr import cirq @@ -36,7 +37,9 @@ class MultiTargetCSwap(infra.GateWithRegisters): bitsize: int @classmethod - def make_on(cls, **quregs: Sequence[cirq.Qid]) -> cirq.Operation: + def make_on( + cls, **quregs: Union[Sequence[cirq.Qid], NDArray[cirq.Qid]] # type: ignore[type-var] + ) -> cirq.Operation: """Helper constructor to automatically deduce bitsize attributes.""" return cls(bitsize=len(quregs['target_x'])).on_registers(**quregs) @@ -45,7 +48,7 @@ def registers(self) -> infra.Registers: return infra.Registers.build(control=1, target_x=self.bitsize, target_y=self.bitsize) def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: control, target_x, target_y = quregs['control'], quregs['target_x'], quregs['target_y'] yield [cirq.CSWAP(*control, t_x, t_y) for t_x, t_y in zip(target_x, target_y)] @@ -81,7 +84,7 @@ class MultiTargetCSwapApprox(MultiTargetCSwap): """ def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: control, target_x, target_y = quregs['control'], quregs['target_x'], quregs['target_y'] @@ -143,25 +146,23 @@ def __attrs_post_init__(self): @cached_property def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters.build( - selection=(self.selection_bitsize, self.n_target_registers) + return infra.SelectionRegisters( + [infra.SelectionRegister('selection', self.selection_bitsize, self.n_target_registers)] ) @cached_property def target_registers(self) -> infra.Registers: - return infra.Registers.build( - **{f'target{i}': self.target_bitsize for i in range(self.n_target_registers)} - ) + return infra.Registers.build(target=(self.n_target_registers, self.target_bitsize)) @cached_property def registers(self) -> infra.Registers: return infra.Registers([*self.selection_registers, *self.target_registers]) def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: - selection, target_regs = quregs.pop('selection'), quregs - assert len(target_regs) == self.n_target_registers + selection, target = quregs['selection'], quregs['target'] + assert target.shape == (self.n_target_registers, self.target_bitsize) cswap_n = MultiTargetCSwapApprox(self.target_bitsize) # Imagine a complete binary tree of depth `logN` with `N` leaves, each denoting a target # register. If the selection register stores index `r`, we want to bring the value stored @@ -179,8 +180,8 @@ def decompose_from_registers( # The inner loop is executed at-most `N - 1` times, where `N:= len(target_regs)`. yield cswap_n.on_registers( control=selection[len(selection) - j - 1], - target_x=target_regs[f'target{i}'], - target_y=target_regs[f'target{i + 2**j}'], + target_x=target[i], + target_y=target[i + 2**j], ) def __repr__(self) -> str: diff --git a/cirq-ft/cirq_ft/algos/swap_network_test.py b/cirq-ft/cirq_ft/algos/swap_network_test.py index 73845694c2e..c934ca4dec3 100644 --- a/cirq-ft/cirq_ft/algos/swap_network_test.py +++ b/cirq-ft/cirq_ft/algos/swap_network_test.py @@ -33,12 +33,9 @@ def test_swap_with_zero_gate(selection_bitsize, target_bitsize, n_target_registe # Allocate selection and target qubits. all_qubits = cirq.LineQubit.range(cirq.num_qubits(gate)) selection = all_qubits[:selection_bitsize] - targets = { - f'target{i}': all_qubits[st : st + target_bitsize] - for i, st in enumerate(range(selection_bitsize, len(all_qubits), target_bitsize)) - } + target = np.array(all_qubits[selection_bitsize:]).reshape((n_target_registers, target_bitsize)) # Create a circuit. - circuit = cirq.Circuit(gate.on_registers(selection=selection, **targets)) + circuit = cirq.Circuit(gate.on_registers(selection=selection, target=target)) # Load data[i] in i'th target register; where each register is of size target_bitsize data = [random.randint(0, 2**target_bitsize - 1) for _ in range(n_target_registers)] diff --git a/cirq-ft/cirq_ft/algos/unary_iteration.ipynb b/cirq-ft/cirq_ft/algos/unary_iteration.ipynb index 1aab46f7e4f..003941203c3 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration.ipynb +++ b/cirq-ft/cirq_ft/algos/unary_iteration.ipynb @@ -286,7 +286,6 @@ " num_controls = self._control_bitsize + self._selection_bitsize\n", " for target_bit in range(self._target_bitsize):\n", " bit_pattern = iter_bits(target_bit, self._selection_bitsize)\n", - " num_controls = self._selection_bitsize + self._control_bitsize\n", " control_values = [1]*self._control_bitsize + list(bit_pattern)\n", " yield cirq.X.controlled(\n", " num_controls=num_controls,\n", @@ -472,7 +471,7 @@ "metadata": {}, "outputs": [], "source": [ - "from cirq_ft import Registers, SelectionRegisters, UnaryIterationGate\n", + "from cirq_ft import Registers, SelectionRegister, SelectionRegisters, UnaryIterationGate\n", "from cirq._compat import cached_property\n", "\n", "class ApplyXToLthQubit(UnaryIterationGate):\n", @@ -487,7 +486,7 @@ "\n", " @cached_property\n", " def selection_registers(self) -> SelectionRegisters:\n", - " return SelectionRegisters.build(selection=(self._selection_bitsize, self._target_bitsize))\n", + " return SelectionRegisters([SelectionRegister('selection', self._selection_bitsize, self._target_bitsize)])\n", "\n", " @cached_property\n", " def target_registers(self) -> Registers:\n", diff --git a/cirq-ft/cirq_ft/algos/unary_iteration_gate.py b/cirq-ft/cirq_ft/algos/unary_iteration_gate.py index 668f9f2d753..e4dc67c8596 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration_gate.py +++ b/cirq-ft/cirq_ft/algos/unary_iteration_gate.py @@ -14,8 +14,11 @@ import abc from typing import Dict, Iterator, List, Sequence, Tuple +from numpy.typing import NDArray import cirq +import numpy as np + from cirq._compat import cached_property from cirq_ft import infra from cirq_ft.algos import and_gate @@ -134,7 +137,7 @@ def _unary_iteration_multi_controls( and_ancilla = qm.qalloc(num_controls - 2) and_target = qm.qalloc(1)[0] multi_controlled_and = and_gate.And((1,) * len(controls)).on_registers( - control=controls, ancilla=and_ancilla, target=and_target + control=np.array(controls), ancilla=np.array(and_ancilla), target=and_target ) ops.append(multi_controlled_and) yield from _unary_iteration_single_control(ops, and_target, selection, l_iter, r_iter, qm) @@ -282,14 +285,16 @@ def nth_operation( """ def decompose_zero_selection( - self, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, + context: cirq.DecompositionContext, + **quregs: NDArray[cirq.Qid], # type: ignore[type-var] ) -> cirq.OP_TREE: """Specify decomposition of the gate when selection register is empty By default, if the selection register is empty, the decomposition will raise a `NotImplementedError`. The derived classes can override this method and specify a custom decomposition that should be used if the selection register is empty, - i.e. `self.selection_registers.bitsize == 0`. + i.e. `self.selection_registers.total_bits() == 0`. The derived classes should specify the following arguments as `**kwargs`: 1) Register names in `self.control_registers`: Each argument corresponds to a @@ -302,9 +307,9 @@ def decompose_zero_selection( raise NotImplementedError("Selection register must not be empty.") def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: - if self.selection_registers.bitsize == 0: + if self.selection_registers.total_bits() == 0: return self.decompose_zero_selection(context=context, **quregs) num_loops = len(self.selection_registers) @@ -352,7 +357,7 @@ def unary_iteration_loops( r_iter=self.selection_registers[nested_depth].iteration_length, flanking_ops=ops, controls=controls, - selection=quregs[self.selection_registers[nested_depth].name], + selection=[*quregs[self.selection_registers[nested_depth].name]], qubit_manager=context.qubit_manager, ) for op_tree, control_qid, n in ith_for_loop: @@ -371,7 +376,7 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.Circ Descendants are encouraged to override this with more descriptive circuit diagram information. """ - wire_symbols = ["@"] * self.control_registers.bitsize - wire_symbols += ["In"] * self.selection_registers.bitsize - wire_symbols += [self.__class__.__name__] * self.target_registers.bitsize + wire_symbols = ["@"] * self.control_registers.total_bits() + wire_symbols += ["In"] * self.selection_registers.total_bits() + wire_symbols += [self.__class__.__name__] * self.target_registers.total_bits() return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py b/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py index 62af16eaa2c..ffa50adc940 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py +++ b/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py @@ -35,8 +35,8 @@ def control_registers(self) -> cirq_ft.Registers: @cached_property def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters.build( - selection=(self._selection_bitsize, self._target_bitsize) + return cirq_ft.SelectionRegisters( + [cirq_ft.SelectionRegister('selection', self._selection_bitsize, self._target_bitsize)] ) @cached_property @@ -63,7 +63,6 @@ def test_unary_iteration_gate(selection_bitsize, target_bitsize, control_bitsize assert len(g.all_qubits) <= 2 * (selection_bitsize + control_bitsize) + target_bitsize - 1 for n in range(target_bitsize): - # Initial qubit values qubit_vals = {q: 0 for q in g.operation.qubits} # All controls 'on' to activate circuit @@ -89,10 +88,13 @@ def control_registers(self) -> cirq_ft.Registers: @cached_property def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters.build( - i=((self._target_shape[0] - 1).bit_length(), self._target_shape[0]), - j=((self._target_shape[1] - 1).bit_length(), self._target_shape[1]), - k=((self._target_shape[2] - 1).bit_length(), self._target_shape[2]), + return cirq_ft.SelectionRegisters( + [ + cirq_ft.SelectionRegister( + 'ijk'[i], (self._target_shape[i] - 1).bit_length(), self._target_shape[i] + ) + for i in range(3) + ] ) @cached_property @@ -120,10 +122,12 @@ def test_multi_dimensional_unary_iteration_gate(target_shape: Tuple[int, int, in greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True) gate = ApplyXToIJKthQubit(target_shape) g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm)) - assert len(g.all_qubits) <= gate.registers.bitsize + gate.selection_registers.bitsize - 1 + assert ( + len(g.all_qubits) <= gate.registers.total_bits() + gate.selection_registers.total_bits() - 1 + ) max_i, max_j, max_k = target_shape - i_len, j_len, k_len = tuple(reg.bitsize for reg in gate.selection_registers) + i_len, j_len, k_len = tuple(reg.total_bits() for reg in gate.selection_registers) for i, j, k in itertools.product(range(max_i), range(max_j), range(max_k)): qubit_vals = {x: 0 for x in g.operation.qubits} # Initialize selection bits appropriately: @@ -143,7 +147,9 @@ def test_multi_dimensional_unary_iteration_gate(target_shape: Tuple[int, int, in def test_unary_iteration_loop(): n_range, m_range = (3, 5), (6, 8) - selection_registers = cirq_ft.SelectionRegisters.build(n=(3, 5), m=(3, 8)) + selection_registers = cirq_ft.SelectionRegisters( + [cirq_ft.SelectionRegister('n', 3, 5), cirq_ft.SelectionRegister('m', 3, 8)] + ) selection = selection_registers.get_named_qubits() target = {(n, m): cirq.q(f't({n}, {m})') for n in range(*n_range) for m in range(*m_range)} qm = cirq_ft.GreedyQubitManager("ancilla", maximize_reuse=True) diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb b/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb index 1f9abd71809..70e4a6e59ba 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb @@ -39,7 +39,7 @@ "source": [ "## `Registers`\n", "\n", - "`Register` objects have a name and a bitsize. `Registers` is an ordered collection of `Register` with some helpful methods." + "`Register` objects have a name and a shape. `Registers` is an ordered collection of `Register` with some helpful methods." ] }, { @@ -51,8 +51,8 @@ "source": [ "from cirq_ft import Register, Registers\n", "\n", - "control_reg = Register(name='control', bitsize=2)\n", - "target_reg = Register(name='target', bitsize=3)\n", + "control_reg = Register(name='control', shape=(2,))\n", + "target_reg = Register(name='target', shape=(3,))\n", "control_reg, target_reg" ] }, @@ -239,4 +239,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.py b/cirq-ft/cirq_ft/infra/gate_with_registers.py index 0a8047227e0..04fb9e88ee7 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.py @@ -13,7 +13,9 @@ # limitations under the License. import abc +import itertools from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union, overload +from numpy.typing import NDArray import attr import cirq @@ -26,14 +28,27 @@ class Register: Args: name: The string name of the register - bitsize: The number of (qu)bits in the register. + shape: Shape of the multi-dimensional qubit register. """ name: str - bitsize: int + shape: Tuple[int, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, int) else tuple(v) + ) + + def all_idxs(self) -> Iterable[Tuple[int, ...]]: + """Iterate over all possible indices of a multidimensional register.""" + yield from itertools.product(*[range(sh) for sh in self.shape]) + + def total_bits(self) -> int: + """The total number of bits in this register. + + This is the product of bitsize and each of the dimensions in `shape`. + """ + return int(np.product(self.shape)) def __repr__(self): - return f'cirq_ft.Register("{self.name}", {self.bitsize})' + return f'cirq_ft.Register(name="{self.name}", shape={self.shape})' class Registers: @@ -52,13 +67,12 @@ def __init__(self, registers: Iterable[Register]): def __repr__(self): return f'cirq_ft.Registers({self._registers})' - @property - def bitsize(self) -> int: - return sum(reg.bitsize for reg in self) + def total_bits(self) -> int: + return sum(reg.total_bits() for reg in self) @classmethod - def build(cls, **registers: int) -> 'Registers': - return cls(Register(name=k, bitsize=v) for k, v in registers.items()) + def build(cls, **registers: Union[int, Tuple[int, ...]]) -> 'Registers': + return cls(Register(name=k, shape=v) for k, v in registers.items()) @overload def __getitem__(self, key: int) -> Register: @@ -91,35 +105,51 @@ def __iter__(self): def __len__(self) -> int: return len(self._registers) - def split_qubits(self, qubits: Sequence[cirq.Qid]) -> Dict[str, Sequence[cirq.Qid]]: + def split_qubits( + self, qubits: Sequence[cirq.Qid] + ) -> Dict[str, NDArray[cirq.Qid]]: # type: ignore[type-var] qubit_regs = {} base = 0 for reg in self: - qubit_regs[reg.name] = qubits[base : base + reg.bitsize] - base += reg.bitsize + qubit_regs[reg.name] = np.array(qubits[base : base + reg.total_bits()]).reshape( + reg.shape + ) + base += reg.total_bits() return qubit_regs - def merge_qubits(self, **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid]]) -> List[cirq.Qid]: + def merge_qubits( + self, **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]] + ) -> List[cirq.Qid]: ret: List[cirq.Qid] = [] for reg in self: assert reg.name in qubit_regs, "All qubit registers must pe present" qubits = qubit_regs[reg.name] - qubits = [qubits] if isinstance(qubits, cirq.Qid) else qubits + qubits = np.array([qubits] if isinstance(qubits, cirq.Qid) else qubits) assert ( - len(qubits) == reg.bitsize - ), f"{reg.name} register must of length {reg.bitsize} but is of length {len(qubits)}" - ret += qubits + qubits.shape == reg.shape + ), f'{reg.name} register must of shape {reg.shape} but is of shape {qubits.shape}' + ret += qubits.flatten().tolist() return ret - def get_named_qubits(self) -> Dict[str, List[cirq.Qid]]: - def qubits_for_reg(name: str, bitsize: int): - return ( - [cirq.NamedQubit(f"{name}")] - if bitsize == 1 - else cirq.NamedQubit.range(bitsize, prefix=name) + def get_named_qubits(self) -> Dict[str, NDArray[cirq.Qid]]: + def _qubit_array(reg: Register): + qubits = np.empty(reg.shape, dtype=object) + for ii in reg.all_idxs(): + qubits[ii] = cirq.NamedQubit(f'{reg.name}[{", ".join(str(i) for i in ii)}]') + return qubits + + def _qubits_for_reg(reg: Register): + if len(reg.shape) > 1: + return _qubit_array(reg) + + return np.array( + [cirq.NamedQubit(f"{reg.name}")] + if reg.total_bits() == 1 + else cirq.NamedQubit.range(reg.total_bits(), prefix=reg.name), + dtype=object, ) - return {reg.name: qubits_for_reg(reg.name, reg.bitsize) for reg in self} + return {reg.name: _qubits_for_reg(reg) for reg in self._registers} def __eq__(self, other) -> bool: return self._registers == other._registers @@ -138,13 +168,24 @@ class SelectionRegister(Register): iteration_length: int = attr.field() + @iteration_length.default + def _default_iteration_length(self): + return 2 ** self.shape[0] + @iteration_length.validator def validate_iteration_length(self, attribute, value): - if not (0 <= value <= 2**self.bitsize): - raise ValueError(f'iteration length must be in range [0, 2^{self.bitsize}]') + if len(self.shape) != 1: + raise ValueError(f'Selection register {self.name} should be flat. Found {self.shape=}') + if not (0 <= value <= 2 ** self.shape[0]): + raise ValueError(f'iteration length must be in range [0, 2^{self.shape[0]}]') def __repr__(self) -> str: - return f'cirq_ft.SelectionRegister("{self.name}", {self.bitsize}, {self.iteration_length})' + return ( + f'cirq_ft.SelectionRegister(' + f'name="{self.name}", ' + f'shape={self.shape}, ' + f'iteration_length={self.iteration_length})' + ) class SelectionRegisters(Registers): @@ -203,16 +244,8 @@ def total_iteration_size(self) -> int: return int(np.product(self.iteration_lengths)) @classmethod - def build(cls, **registers: Union[int, Tuple[int, int]]) -> 'SelectionRegisters': - reg_dict: Dict[str, Tuple[int, int]] = { - k: v if isinstance(v, tuple) else (v, 2**v) for k, v in registers.items() - } - return SelectionRegisters( - [ - SelectionRegister(name=k, bitsize=v[0], iteration_length=v[1]) - for k, v in reg_dict.items() - ] - ) + def build(cls, **registers: Union[int, Tuple[int, ...]]) -> 'SelectionRegisters': + return cls(SelectionRegister(name=k, shape=v) for k, v in registers.items()) @overload def __getitem__(self, key: int) -> SelectionRegister: @@ -294,10 +327,10 @@ def registers(self) -> Registers: ... def _num_qubits_(self) -> int: - return self.registers.bitsize + return self.registers.total_bits() def decompose_from_registers( - self, *, context: cirq.DecompositionContext, **quregs: Sequence[cirq.Qid] + self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: return NotImplemented @@ -312,7 +345,9 @@ def _decompose_with_context_( def _decompose_(self, qubits: Sequence[cirq.Qid]) -> cirq.OP_TREE: return self._decompose_with_context_(qubits) - def on_registers(self, **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid]]) -> cirq.Operation: + def on_registers( + self, **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]] + ) -> cirq.Operation: return self.on(*self.registers.merge_qubits(**qubit_regs)) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: @@ -322,7 +357,7 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.Circ """ wire_symbols = [] for reg in self.registers: - wire_symbols += [reg.name] * reg.bitsize + wire_symbols += [reg.name] * reg.total_bits() wire_symbols[0] = self.__class__.__name__ return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py index 69ba4aa26aa..a3442eb0554 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py @@ -14,13 +14,14 @@ import cirq import cirq_ft +import numpy as np import pytest from cirq_ft.infra.jupyter_tools import execute_notebook def test_register(): r = cirq_ft.Register("my_reg", 5) - assert r.bitsize == 5 + assert r.shape == (5,) def test_registers(): @@ -50,9 +51,9 @@ def test_registers(): qubits = cirq.LineQubit.range(8) qregs = regs.split_qubits(qubits) - assert qregs["r1"] == cirq.LineQubit.range(5) - assert qregs["r2"] == cirq.LineQubit.range(5, 5 + 2) - assert qregs["r3"] == [cirq.LineQubit(7)] + assert qregs["r1"].tolist() == cirq.LineQubit.range(5) + assert qregs["r2"].tolist() == cirq.LineQubit.range(5, 5 + 2) + assert qregs["r3"].tolist() == [cirq.LineQubit(7)] qubits = qubits[::-1] merged_qregs = regs.merge_qubits(r1=qubits[:5], r2=qubits[5:7], r3=qubits[-1]) @@ -63,7 +64,11 @@ def test_registers(): "r2": cirq.NamedQubit.range(2, prefix="r2"), "r3": [cirq.NamedQubit("r3")], } - assert regs.get_named_qubits() == expected_named_qubits + + named_qregs = regs.get_named_qubits() + for reg_name in expected_named_qubits: + assert np.array_equal(named_qregs[reg_name], expected_named_qubits[reg_name]) + # Python dictionaries preserve insertion order, which should be same as insertion order of # initial registers. for reg_order in [[r1, r2, r3], [r2, r3, r1]]: @@ -76,7 +81,9 @@ def test_registers(): @pytest.mark.parametrize('n, N, m, M', [(4, 10, 5, 19), (4, 16, 5, 32)]) def test_selection_registers_indexing(n, N, m, M): - reg = cirq_ft.SelectionRegisters.build(x=(n, N), y=(m, M)) + reg = cirq_ft.SelectionRegisters( + [cirq_ft.SelectionRegister('x', n, N), cirq_ft.SelectionRegister('y', m, M)] + ) assert reg.iteration_lengths == (N, M) for x in range(N): for y in range(M): @@ -89,7 +96,15 @@ def test_selection_registers_consistent(): with pytest.raises(ValueError, match="iteration length must be in "): _ = cirq_ft.SelectionRegister('a', 3, 10) - selection_reg = cirq_ft.SelectionRegisters.build(n=(3, 5), m=(4, 12)) + with pytest.raises(ValueError, match="should be flat"): + _ = cirq_ft.SelectionRegister('a', (3, 5), 5) + + selection_reg = cirq_ft.SelectionRegisters( + [ + cirq_ft.SelectionRegister('n', shape=3, iteration_length=5), + cirq_ft.SelectionRegister('m', shape=4, iteration_length=12), + ] + ) assert selection_reg[0] == cirq_ft.SelectionRegister('n', 3, 5) assert selection_reg['n'] == cirq_ft.SelectionRegister('n', 3, 5) assert selection_reg[1] == cirq_ft.SelectionRegister('m', 4, 12) @@ -101,7 +116,9 @@ def test_registers_getitem_raises(): with pytest.raises(IndexError, match="must be of the type"): _ = g[2.5] - selection_reg = cirq_ft.SelectionRegisters.build(n=(3, 5)) + selection_reg = cirq_ft.SelectionRegisters( + [cirq_ft.SelectionRegister('n', shape=3, iteration_length=5)] + ) with pytest.raises(IndexError, match='must be of the type'): _ = selection_reg[2.5] diff --git a/cirq-ft/cirq_ft/infra/testing.py b/cirq-ft/cirq_ft/infra/testing.py index cf31ce91b62..6ceb21d5c2a 100644 --- a/cirq-ft/cirq_ft/infra/testing.py +++ b/cirq-ft/cirq_ft/infra/testing.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, Sequence, Tuple - +from numpy.typing import NDArray import cirq import numpy as np from cirq._compat import cached_property @@ -42,7 +42,7 @@ def r(self) -> gate_with_registers.Registers: return self.gate.registers @cached_property - def quregs(self) -> Dict[str, List[cirq.Qid]]: + def quregs(self) -> Dict[str, NDArray[cirq.Qid]]: # type: ignore[type-var] """A dictionary of named qubits appropriate for the registers for the gate.""" return self.r.get_named_qubits() diff --git a/cirq-ft/cirq_ft/infra/testing_test.py b/cirq-ft/cirq_ft/infra/testing_test.py index 8af3fee6a0a..ad3c8a83674 100644 --- a/cirq-ft/cirq_ft/infra/testing_test.py +++ b/cirq-ft/cirq_ft/infra/testing_test.py @@ -14,6 +14,7 @@ import cirq import cirq_ft +import numpy as np import pytest @@ -34,11 +35,13 @@ def test_gate_helper(): g = cirq_ft.testing.GateHelper(cirq_ft.And(cv=(1, 0, 1, 0))) assert g.gate == cirq_ft.And(cv=(1, 0, 1, 0)) assert g.r == cirq_ft.Registers.build(control=4, ancilla=2, target=1) - assert g.quregs == { + expected_quregs = { 'control': cirq.NamedQubit.range(4, prefix='control'), 'ancilla': cirq.NamedQubit.range(2, prefix='ancilla'), 'target': [cirq.NamedQubit('target')], } + for key in expected_quregs: + assert np.array_equal(g.quregs[key], expected_quregs[key]) assert g.operation.qubits == tuple(g.all_qubits) assert len(g.circuit) == 1