-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add 2D GHZ state generation #7767
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,150 @@ | ||
| # Copyright 2025 The Cirq Developers | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Functions for generating and transforming 2D GHZ circuits.""" | ||
|
|
||
| import networkx as nx | ||
| import numpy as np | ||
|
|
||
| import cirq.circuits as circuits | ||
| import cirq.devices as devices | ||
| import cirq.ops as ops | ||
| import cirq.protocols as protocols | ||
| import cirq.transformers as transformers | ||
|
|
||
|
|
||
| def _transform_circuit(circuit: circuits.Circuit) -> circuits.Circuit: | ||
| """Transforms a Cirq circuit by applying a series of modifications. | ||
|
|
||
| This is an internal helper function used exclusively by | ||
| `generate_2d_ghz_circuit` when `add_dd_and_align_right` is True. | ||
|
|
||
| The transformations for a circuit include: | ||
| 1. Adding a measurement to all qubits with a key 'm'. | ||
| It serves as a stopping gate for the DD operation. | ||
| 2. Aligning the circuit and merging single-qubit gates. | ||
| 3. Stratifying the operations based on qubit count | ||
| (1-qubit and 2-qubit gates). | ||
| 4. Applying dynamical decoupling to mitigate noise. | ||
| 5. Removing the final measurement operation to yield | ||
| the state preparation circuit. | ||
|
|
||
| Args: | ||
| circuit: A cirq.Circuit object. | ||
|
|
||
| Returns: | ||
| The modified cirq.Circuit object. | ||
| """ | ||
| qubits = list(circuit.all_qubits()) | ||
| circuit = circuit + circuits.Circuit(ops.measure(*qubits, key="m")) | ||
| circuit = transformers.align_right(transformers.merge_single_qubit_gates_to_phxz(circuit)) | ||
| circuit = transformers.stratified_circuit( | ||
| circuit[::-1], categories=[lambda op: protocols.num_qubits(op) == k for k in (1, 2)] | ||
| )[::-1] | ||
| circuit = transformers.add_dynamical_decoupling(circuit) | ||
| circuit = circuits.Circuit(circuit[:-1]) | ||
| return circuit | ||
|
|
||
|
|
||
| def generate_2d_ghz_circuit( | ||
| center: devices.GridQubit, | ||
| graph: nx.Graph, | ||
| num_qubits: int, | ||
| randomized: bool = False, | ||
| rng_or_seed: int | np.random.Generator | None = None, | ||
| add_dd_and_align_right: bool = False, | ||
| ) -> circuits.Circuit: | ||
| """Generates a 2D GHZ state circuit with 'num_qubits' qubits using BFS. | ||
|
|
||
| The circuit is constructed by connecting qubits | ||
| sequentially based on graph connectivity, | ||
| starting from the 'center' qubit. | ||
| The GHZ state is built using a series of H-CZ-H | ||
| gate sequences. | ||
|
|
||
|
|
||
| Args: | ||
| center: The starting qubit for the GHZ state. | ||
| graph: The connectivity graph of the qubits. | ||
| num_qubits: The number of qubits for the final | ||
| GHZ state. Must be greater than 0, | ||
| and less than or equal to | ||
| the total number of qubits | ||
| on the processor. | ||
| randomized: If True, neighbors are | ||
| added to the circuit in a random order. | ||
| If False, they are | ||
| added by distance from the center. | ||
| rng_or_seed: An optional seed or numpy random number | ||
| generator. Used only when randomized is True | ||
| add_dd_and_align_right: If True, adds dynamical | ||
| decoupling and aligns right. | ||
|
|
||
| Returns: | ||
| A cirq.Circuit object for the GHZ state. | ||
|
|
||
| Raises: | ||
| ValueError: If num_qubits is non-positive or exceeds the total | ||
| number of qubits on the processor. | ||
| """ | ||
shashwatk1998 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if num_qubits <= 0: | ||
| raise ValueError("num_qubits must be a positive integer.") | ||
|
|
||
| if num_qubits > len(graph.nodes): | ||
| raise ValueError("num_qubits cannot exceed the total number of qubits on the processor.") | ||
|
|
||
| if randomized: | ||
| rng = ( | ||
| rng_or_seed | ||
| if isinstance(rng_or_seed, np.random.Generator) | ||
| else np.random.default_rng(rng_or_seed) | ||
| ) | ||
|
|
||
| def sort_neighbors_fn(neighbors: list) -> list: | ||
| """If 'randomized' is True, sort the neighbors randomly.""" | ||
| neighbors = list(neighbors) | ||
| rng.shuffle(neighbors) | ||
| return neighbors | ||
|
|
||
| else: | ||
|
|
||
| def sort_neighbors_fn(neighbors: list) -> list: | ||
| """If 'randomized' is False, sort the neighbors as per | ||
| distance from the center. | ||
| """ | ||
| return sorted( | ||
| neighbors, key=lambda q: (q.row - center.row) ** 2 + (q.col - center.col) ** 2 | ||
| ) | ||
|
|
||
| bfs_tree = nx.bfs_tree(graph, center, sort_neighbors=sort_neighbors_fn) | ||
| qubits_to_include = list(bfs_tree.nodes)[:num_qubits] | ||
| final_tree = bfs_tree.subgraph(qubits_to_include) | ||
|
|
||
| ghz_ops = [] | ||
|
|
||
| for node in nx.topological_sort(final_tree): | ||
| # Handling the center qubit first | ||
| if node == center: | ||
| ghz_ops.append(ops.H(node)) | ||
| continue | ||
|
|
||
| for parent in final_tree.predecessors(node): | ||
| ghz_ops.extend([ops.H(node), ops.CZ(parent, node), ops.H(node)]) | ||
|
|
||
| circuit = circuits.Circuit(ghz_ops) | ||
|
|
||
| if add_dd_and_align_right: | ||
| return _transform_circuit(circuit) | ||
| else: | ||
| return circuit | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,155 @@ | ||
| # Copyright 2025 The Cirq Developers | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Tests for generating and validating 2D GHZ state circuits.""" | ||
|
|
||
| from typing import cast | ||
|
|
||
| import networkx as nx | ||
| import numpy as np | ||
| import pytest | ||
|
|
||
| import cirq | ||
| import cirq.experiments.ghz_2d as ghz_2d | ||
|
|
||
|
|
||
| def _create_mock_graph(): | ||
| qubits = cirq.GridQubit.rect(6, 6) | ||
| g = nx.Graph() | ||
| for q in qubits: | ||
| g.add_node(q) | ||
| if q.col + 1 < 6: | ||
| g.add_edge(q, cirq.GridQubit(q.row, q.col + 1)) | ||
| if q.row + 1 < 6: | ||
| g.add_edge(q, cirq.GridQubit(q.row + 1, q.col)) | ||
| return g, cirq.GridQubit(3, 3) | ||
|
|
||
|
|
||
| graph, center_qubit = _create_mock_graph() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("num_qubits", list(range(1, len(graph.nodes) + 1))) | ||
| @pytest.mark.parametrize("randomized", [True, False]) | ||
| @pytest.mark.parametrize("add_dd_and_align_right", [True, False]) | ||
| def test_ghz_circuits_size(num_qubits: int, randomized: bool, add_dd_and_align_right: bool) -> None: | ||
| """Tests the size of the GHZ circuits.""" | ||
| circuit = ghz_2d.generate_2d_ghz_circuit( | ||
| center_qubit, | ||
| graph, | ||
| num_qubits=num_qubits, | ||
| randomized=randomized, | ||
| add_dd_and_align_right=add_dd_and_align_right, | ||
| ) | ||
| assert len(circuit.all_qubits()) == num_qubits | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("num_qubits", [2, 3, 4, 5, 6, 8, 10]) | ||
| @pytest.mark.parametrize("randomized", [True, False]) | ||
| @pytest.mark.parametrize("add_dd_and_align_right", [True, False]) # , True | ||
| def test_ghz_circuits_state( | ||
| num_qubits: int, randomized: bool, add_dd_and_align_right: bool | ||
| ) -> None: | ||
| """Tests the state vector form of the GHZ circuits.""" | ||
|
|
||
| circuit = ghz_2d.generate_2d_ghz_circuit( | ||
| center_qubit, | ||
| graph, | ||
| num_qubits=num_qubits, | ||
| randomized=randomized, | ||
| add_dd_and_align_right=add_dd_and_align_right, | ||
| ) | ||
|
|
||
| simulator = cirq.Simulator() | ||
| result = simulator.simulate(circuit) | ||
| state = result.final_state_vector | ||
|
|
||
| np.testing.assert_allclose(np.abs(state[0]), 1 / np.sqrt(2), atol=1e-7) | ||
| np.testing.assert_allclose(np.abs(state[-1]), 1 / np.sqrt(2), atol=1e-7) | ||
|
|
||
| if num_qubits > 1: | ||
| np.testing.assert_allclose(state[1:-1], 0) | ||
|
|
||
|
|
||
| def test_transform_circuit_properties() -> None: | ||
| """Tests that _transform_circuit preserves circuit properties.""" | ||
| circuit = ghz_2d.generate_2d_ghz_circuit( | ||
| center_qubit, graph, num_qubits=9, randomized=False, add_dd_and_align_right=False | ||
| ) | ||
| transformed_circuit = ghz_2d._transform_circuit(circuit) | ||
|
|
||
| assert transformed_circuit.all_qubits() == circuit.all_qubits() | ||
|
|
||
| assert len(transformed_circuit) >= len(circuit) | ||
|
|
||
| final_moment = transformed_circuit[-1] | ||
| assert not any(isinstance(op.gate, cirq.MeasurementGate) for op in final_moment) | ||
|
|
||
| assert cirq.equal_up_to_global_phase(circuit.unitary(), transformed_circuit.unitary()) | ||
|
|
||
|
|
||
| def manhattan_distance(q1: cirq.GridQubit, q2: cirq.GridQubit) -> int: | ||
| """Calculates the Manhattan distance between two GridQubits.""" | ||
| return abs(q1.row - q2.row) + abs(q1.col - q2.col) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("num_qubits", [2, 4, 9, 15, 20]) | ||
| def test_ghz_circuits_bfs_order(num_qubits: int) -> None: | ||
| """Verifies that the circuit construction maintains BFS order""" | ||
|
|
||
| circuit = ghz_2d.generate_2d_ghz_circuit( | ||
| center_qubit, | ||
| graph, | ||
| num_qubits=num_qubits, | ||
| randomized=False, # Test must run on the deterministic BFS order | ||
| add_dd_and_align_right=False, # Test must run on the raw circuit | ||
| ) | ||
|
|
||
| max_dist_seen = 0 | ||
|
|
||
| for moment in circuit: | ||
| for op in moment: | ||
| if isinstance(op.gate, cirq.CZPowGate): | ||
| qubits = op.qubits | ||
|
|
||
| dist_q0 = manhattan_distance(center_qubit, cast(cirq.GridQubit, qubits[0])) | ||
| dist_q1 = manhattan_distance(center_qubit, cast(cirq.GridQubit, qubits[1])) | ||
|
|
||
| child_qubit_distance = max(dist_q0, dist_q1) | ||
|
|
||
| if child_qubit_distance > max_dist_seen: | ||
| assert child_qubit_distance == max_dist_seen + 1 | ||
| max_dist_seen = child_qubit_distance | ||
|
|
||
| assert child_qubit_distance <= max_dist_seen | ||
|
|
||
| included_qubits = circuit.all_qubits() | ||
| if included_qubits: | ||
| max_dist_required = max( | ||
| manhattan_distance(center_qubit, cast(cirq.GridQubit, q)) for q in included_qubits | ||
| ) | ||
| assert max_dist_seen == max_dist_required | ||
|
|
||
|
|
||
| def test_ghz_invalid_inputs(): | ||
| """Tests that the function raises errors for invalid inputs.""" | ||
|
|
||
| with pytest.raises(ValueError, match="num_qubits must be a positive integer."): | ||
| ghz_2d.generate_2d_ghz_circuit(center_qubit, graph, num_qubits=0) # invalid | ||
|
|
||
| with pytest.raises( | ||
| ValueError, match="num_qubits cannot exceed the total number of qubits on the processor." | ||
| ): | ||
| ghz_2d.generate_2d_ghz_circuit( | ||
| center_qubit, graph, num_qubits=len(graph.nodes) + 1 # invalid | ||
| ) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.