Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cirq-core/cirq/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,5 @@
z_phase_calibration_workflow as z_phase_calibration_workflow,
calibrate_z_phases as calibrate_z_phases,
)

from cirq.experiments.ghz_2d import generate_2d_ghz_circuit as generate_2d_ghz_circuit
150 changes: 150 additions & 0 deletions cirq-core/cirq/experiments/ghz_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright 2025 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for generating and transforming 2D GHZ circuits."""

import networkx as nx
import numpy as np

import cirq.circuits as circuits
import cirq.devices as devices
import cirq.ops as ops
import cirq.protocols as protocols
import cirq.transformers as transformers


def _transform_circuit(circuit: circuits.Circuit) -> circuits.Circuit:
"""Transforms a Cirq circuit by applying a series of modifications.

This is an internal helper function used exclusively by
`generate_2d_ghz_circuit` when `add_dd_and_align_right` is True.

The transformations for a circuit include:
1. Adding a measurement to all qubits with a key 'm'.
It serves as a stopping gate for the DD operation.
2. Aligning the circuit and merging single-qubit gates.
3. Stratifying the operations based on qubit count
(1-qubit and 2-qubit gates).
4. Applying dynamical decoupling to mitigate noise.
5. Removing the final measurement operation to yield
the state preparation circuit.

Args:
circuit: A cirq.Circuit object.

Returns:
The modified cirq.Circuit object.
"""
qubits = list(circuit.all_qubits())
circuit = circuit + circuits.Circuit(ops.measure(*qubits, key="m"))
circuit = transformers.align_right(transformers.merge_single_qubit_gates_to_phxz(circuit))
circuit = transformers.stratified_circuit(
circuit[::-1], categories=[lambda op: protocols.num_qubits(op) == k for k in (1, 2)]
)[::-1]
circuit = transformers.add_dynamical_decoupling(circuit)
circuit = circuits.Circuit(circuit[:-1])
return circuit


def generate_2d_ghz_circuit(
center: devices.GridQubit,
graph: nx.Graph,
num_qubits: int,
randomized: bool = False,
rng_or_seed: int | np.random.Generator | None = None,
add_dd_and_align_right: bool = False,
) -> circuits.Circuit:
"""Generates a 2D GHZ state circuit with 'num_qubits' qubits using BFS.

The circuit is constructed by connecting qubits
sequentially based on graph connectivity,
starting from the 'center' qubit.
The GHZ state is built using a series of H-CZ-H
gate sequences.


Args:
center: The starting qubit for the GHZ state.
graph: The connectivity graph of the qubits.
num_qubits: The number of qubits for the final
GHZ state. Must be greater than 0,
and less than or equal to
the total number of qubits
on the processor.
randomized: If True, neighbors are
added to the circuit in a random order.
If False, they are
added by distance from the center.
rng_or_seed: An optional seed or numpy random number
generator. Used only when randomized is True
add_dd_and_align_right: If True, adds dynamical
decoupling and aligns right.

Returns:
A cirq.Circuit object for the GHZ state.

Raises:
ValueError: If num_qubits is non-positive or exceeds the total
number of qubits on the processor.
"""
if num_qubits <= 0:
raise ValueError("num_qubits must be a positive integer.")

if num_qubits > len(graph.nodes):
raise ValueError("num_qubits cannot exceed the total number of qubits on the processor.")

if randomized:
rng = (
rng_or_seed
if isinstance(rng_or_seed, np.random.Generator)
else np.random.default_rng(rng_or_seed)
)

def sort_neighbors_fn(neighbors: list) -> list:
"""If 'randomized' is True, sort the neighbors randomly."""
neighbors = list(neighbors)
rng.shuffle(neighbors)
return neighbors

else:

def sort_neighbors_fn(neighbors: list) -> list:
"""If 'randomized' is False, sort the neighbors as per
distance from the center.
"""
return sorted(
neighbors, key=lambda q: (q.row - center.row) ** 2 + (q.col - center.col) ** 2
)

bfs_tree = nx.bfs_tree(graph, center, sort_neighbors=sort_neighbors_fn)
qubits_to_include = list(bfs_tree.nodes)[:num_qubits]
final_tree = bfs_tree.subgraph(qubits_to_include)

ghz_ops = []

for node in nx.topological_sort(final_tree):
# Handling the center qubit first
if node == center:
ghz_ops.append(ops.H(node))
continue

for parent in final_tree.predecessors(node):
ghz_ops.extend([ops.H(node), ops.CZ(parent, node), ops.H(node)])

circuit = circuits.Circuit(ghz_ops)

if add_dd_and_align_right:
return _transform_circuit(circuit)
else:
return circuit
155 changes: 155 additions & 0 deletions cirq-core/cirq/experiments/ghz_2d_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2025 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for generating and validating 2D GHZ state circuits."""

from typing import cast

import networkx as nx
import numpy as np
import pytest

import cirq
import cirq.experiments.ghz_2d as ghz_2d


def _create_mock_graph():
qubits = cirq.GridQubit.rect(6, 6)
g = nx.Graph()
for q in qubits:
g.add_node(q)
if q.col + 1 < 6:
g.add_edge(q, cirq.GridQubit(q.row, q.col + 1))
if q.row + 1 < 6:
g.add_edge(q, cirq.GridQubit(q.row + 1, q.col))
return g, cirq.GridQubit(3, 3)


graph, center_qubit = _create_mock_graph()


@pytest.mark.parametrize("num_qubits", list(range(1, len(graph.nodes) + 1)))
@pytest.mark.parametrize("randomized", [True, False])
@pytest.mark.parametrize("add_dd_and_align_right", [True, False])
def test_ghz_circuits_size(num_qubits: int, randomized: bool, add_dd_and_align_right: bool) -> None:
"""Tests the size of the GHZ circuits."""
circuit = ghz_2d.generate_2d_ghz_circuit(
center_qubit,
graph,
num_qubits=num_qubits,
randomized=randomized,
add_dd_and_align_right=add_dd_and_align_right,
)
assert len(circuit.all_qubits()) == num_qubits


@pytest.mark.parametrize("num_qubits", [2, 3, 4, 5, 6, 8, 10])
@pytest.mark.parametrize("randomized", [True, False])
@pytest.mark.parametrize("add_dd_and_align_right", [True, False]) # , True
def test_ghz_circuits_state(
num_qubits: int, randomized: bool, add_dd_and_align_right: bool
) -> None:
"""Tests the state vector form of the GHZ circuits."""

circuit = ghz_2d.generate_2d_ghz_circuit(
center_qubit,
graph,
num_qubits=num_qubits,
randomized=randomized,
add_dd_and_align_right=add_dd_and_align_right,
)

simulator = cirq.Simulator()
result = simulator.simulate(circuit)
state = result.final_state_vector

np.testing.assert_allclose(np.abs(state[0]), 1 / np.sqrt(2), atol=1e-7)
np.testing.assert_allclose(np.abs(state[-1]), 1 / np.sqrt(2), atol=1e-7)

if num_qubits > 1:
np.testing.assert_allclose(state[1:-1], 0)


def test_transform_circuit_properties() -> None:
"""Tests that _transform_circuit preserves circuit properties."""
circuit = ghz_2d.generate_2d_ghz_circuit(
center_qubit, graph, num_qubits=9, randomized=False, add_dd_and_align_right=False
)
transformed_circuit = ghz_2d._transform_circuit(circuit)

assert transformed_circuit.all_qubits() == circuit.all_qubits()

assert len(transformed_circuit) >= len(circuit)

final_moment = transformed_circuit[-1]
assert not any(isinstance(op.gate, cirq.MeasurementGate) for op in final_moment)

assert cirq.equal_up_to_global_phase(circuit.unitary(), transformed_circuit.unitary())


def manhattan_distance(q1: cirq.GridQubit, q2: cirq.GridQubit) -> int:
"""Calculates the Manhattan distance between two GridQubits."""
return abs(q1.row - q2.row) + abs(q1.col - q2.col)


@pytest.mark.parametrize("num_qubits", [2, 4, 9, 15, 20])
def test_ghz_circuits_bfs_order(num_qubits: int) -> None:
"""Verifies that the circuit construction maintains BFS order"""

circuit = ghz_2d.generate_2d_ghz_circuit(
center_qubit,
graph,
num_qubits=num_qubits,
randomized=False, # Test must run on the deterministic BFS order
add_dd_and_align_right=False, # Test must run on the raw circuit
)

max_dist_seen = 0

for moment in circuit:
for op in moment:
if isinstance(op.gate, cirq.CZPowGate):
qubits = op.qubits

dist_q0 = manhattan_distance(center_qubit, cast(cirq.GridQubit, qubits[0]))
dist_q1 = manhattan_distance(center_qubit, cast(cirq.GridQubit, qubits[1]))

child_qubit_distance = max(dist_q0, dist_q1)

if child_qubit_distance > max_dist_seen:
assert child_qubit_distance == max_dist_seen + 1
max_dist_seen = child_qubit_distance

assert child_qubit_distance <= max_dist_seen

included_qubits = circuit.all_qubits()
if included_qubits:
max_dist_required = max(
manhattan_distance(center_qubit, cast(cirq.GridQubit, q)) for q in included_qubits
)
assert max_dist_seen == max_dist_required


def test_ghz_invalid_inputs():
"""Tests that the function raises errors for invalid inputs."""

with pytest.raises(ValueError, match="num_qubits must be a positive integer."):
ghz_2d.generate_2d_ghz_circuit(center_qubit, graph, num_qubits=0) # invalid

with pytest.raises(
ValueError, match="num_qubits cannot exceed the total number of qubits on the processor."
):
ghz_2d.generate_2d_ghz_circuit(
center_qubit, graph, num_qubits=len(graph.nodes) + 1 # invalid
)