Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions examples/vqe2d_lattice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
This example demonstrates how to use the VQE algorithm to find the ground state
of a 2D Heisenberg model on a square lattice. It showcases the setup of the lattice,
the Heisenberg Hamiltonian, a suitable ansatz, and the optimization process.
"""

import time
import optax
import tensorcircuit as tc
from tensorcircuit.templates.lattice import SquareLattice, get_compatible_layers
from tensorcircuit.templates.hamiltonians import heisenberg_hamiltonian

# Use JAX for high-performance, especially on GPU.
K = tc.set_backend("jax")
tc.set_dtype("complex64")
# On Windows, cotengra's multiprocessing can cause issues, use threads instead.
tc.set_contractor("cotengra-8192-8192", parallel="threads")


def run_vqe():
"""Set up and run the VQE optimization for a 2D Heisenberg model."""
n, m, nlayers = 4, 4, 2
lattice = SquareLattice(size=(n, m), pbc=True, precompute_neighbors=1)
h = heisenberg_hamiltonian(lattice, j_coupling=[1.0, 1.0, 0.8]) # Jx, Jy, Jz
nn_bonds = lattice.get_neighbor_pairs(k=1, unique=True)
gate_layers = get_compatible_layers(nn_bonds)
n_params = nlayers * len(nn_bonds) * 3

def singlet_init(circuit):
# A good initial state for Heisenberg ground state search
nq = circuit._nqubits
for i in range(0, nq - 1, 2):
j = (i + 1) % nq
circuit.X(i)
circuit.H(i)
circuit.cnot(i, j)
circuit.X(j)
return circuit

def vqe_forward(param):
"""
Defines the VQE ansatz and computes the energy expectation.
The ansatz consists of nlayers of RZZ, RXX, and RYY entangling layers.
"""
c = tc.Circuit(n * m)
c = singlet_init(c)
param_idx = 0

for _ in range(nlayers):
for layer in gate_layers:
for j, k in layer:
c.rzz(int(j), int(k), theta=param[param_idx])
param_idx += 1
for layer in gate_layers:
for j, k in layer:
c.rxx(int(j), int(k), theta=param[param_idx])
param_idx += 1
for layer in gate_layers:
for j, k in layer:
c.ryy(int(j), int(k), theta=param[param_idx])
param_idx += 1

return tc.templates.measurements.operator_expectation(c, h)

vgf = K.jit(K.value_and_grad(vqe_forward))
param = tc.backend.implicit_randn(stddev=0.02, shape=[n_params])
optimizer = optax.adam(learning_rate=3e-3)
opt_state = optimizer.init(param)

@K.jit
def train_step(param, opt_state):
"""A single training step, JIT-compiled for maximum speed."""
loss_val, grads = vgf(param)
updates, opt_state = optimizer.update(grads, opt_state, param)
param = optax.apply_updates(param, updates)
return param, opt_state, loss_val

print("Starting VQE optimization...")
for i in range(1000):
time0 = time.time()
param, opt_state, loss = train_step(param, opt_state)
time1 = time.time()
if i % 10 == 0:
print(
f"Step {i:4d}: Loss = {loss:.6f} \t (Time per step: {time1 - time0:.4f}s)"
)

print("Optimization finished.")
print(f"Final Loss: {loss:.6f}")


if __name__ == "__main__":
run_vqe()
52 changes: 52 additions & 0 deletions tensorcircuit/templates/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Union,
TYPE_CHECKING,
cast,
Set,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1446,3 +1447,54 @@ def remove_sites(self, identifiers: List[SiteIdentifier]) -> None:
logger.info(
f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites."
)


def get_compatible_layers(bonds: List[Tuple[int, int]]) -> List[List[Tuple[int, int]]]:
"""
Partitions a list of pairs (bonds) into compatible layers for parallel
gate application using a greedy edge-coloring algorithm.

This function takes a list of pairs, representing connections like
nearest-neighbor (NN) or next-nearest-neighbor (NNN) bonds, and
partitions them into the minimum number of sets ("layers") where no two
pairs in a set share an index. This is a general utility for scheduling
non-overlapping operations.

:Example:

>>> from tensorcircuit.templates.lattice import SquareLattice
>>> sq_lattice = SquareLattice(size=(2, 2), pbc=False)
>>> nn_bonds = sq_lattice.get_neighbor_pairs(k=1, unique=True)

>>> gate_layers = get_compatible_layers(nn_bonds)
>>> print(gate_layers)
[[[0, 1], [2, 3]], [[0, 2], [1, 3]]]

:param bonds: A list of tuples, where each tuple represents a bond (i, j)
of site indices to be scheduled.
:type bonds: List[Tuple[int, int]]
:return: A list of layers. Each layer is a list of tuples, where each
tuple represents a bond. All bonds within a layer are non-overlapping.
:rtype: List[List[Tuple[int, int]]]
"""
uncolored_edges: Set[Tuple[int, int]] = {(min(bond), max(bond)) for bond in bonds}

layers: List[List[Tuple[int, int]]] = []

while uncolored_edges:
current_layer: List[Tuple[int, int]] = []
qubits_in_this_layer: Set[int] = set()

edges_to_process = sorted(list(uncolored_edges))

for edge in edges_to_process:
i, j = edge
if i not in qubits_in_this_layer and j not in qubits_in_this_layer:
current_layer.append(edge)
qubits_in_this_layer.add(i)
qubits_in_this_layer.add(j)

uncolored_edges -= set(current_layer)
layers.append(sorted(current_layer))

return layers
84 changes: 84 additions & 0 deletions tests/test_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
RectangularLattice,
SquareLattice,
TriangularLattice,
AbstractLattice,
get_compatible_layers,
)


Expand Down Expand Up @@ -1664,3 +1666,85 @@ def test_distance_matrix_invariants_for_all_lattice_types(self, lattice):
# "The specialized PBC implementation is significantly slower "
# "than the general-purpose implementation."
# )


def _validate_layers(bonds, layers) -> None:
"""
A helper function to scientifically validate the output of get_compatible_layers.
"""
# MODIFICATION: This function now takes the original bonds list for comparison.
expected_edges = set(tuple(sorted(b)) for b in bonds)
actual_edges = set(tuple(sorted(edge)) for layer in layers for edge in layer)

assert (
expected_edges == actual_edges
), "Completeness check failed: The set of all edges in the layers must "
"exactly match the input bonds."

for i, layer in enumerate(layers):
qubits_in_layer: set[int] = set()
for edge in layer:
q1, q2 = edge
assert (
q1 not in qubits_in_layer
), f"Compatibility check failed: Qubit {q1} is reused in layer {i}."
qubits_in_layer.add(q1)
assert (
q2 not in qubits_in_layer
), f"Compatibility check failed: Qubit {q2} is reused in layer {i}."
qubits_in_layer.add(q2)


@pytest.mark.parametrize(
"lattice_instance",
[
SquareLattice(size=(3, 2), pbc=False),
SquareLattice(size=(3, 3), pbc=True),
HoneycombLattice(size=(2, 2), pbc=False),
],
ids=[
"SquareLattice_3x2_OBC",
"SquareLattice_3x3_PBC",
"HoneycombLattice_2x2_OBC",
],
)
def test_layering_on_various_lattices(lattice_instance: AbstractLattice):
"""Tests gate layering for various standard lattice types."""
bonds = lattice_instance.get_neighbor_pairs(k=1, unique=True)
layers = get_compatible_layers(bonds)

assert len(layers) > 0, "Layers should not be empty for non-trivial lattices."
_validate_layers(bonds, layers)


def test_layering_on_1d_chain_pbc():
"""Test layering on a 1D chain with periodic boundaries (a cycle graph)."""
lattice_even = ChainLattice(size=(6,), pbc=True)
bonds_even = lattice_even.get_neighbor_pairs(k=1, unique=True)
layers_even = get_compatible_layers(bonds_even)
_validate_layers(bonds_even, layers_even)

lattice_odd = ChainLattice(size=(5,), pbc=True)
bonds_odd = lattice_odd.get_neighbor_pairs(k=1, unique=True)
layers_odd = get_compatible_layers(bonds_odd)
assert len(layers_odd) == 3, "A 5-site cycle graph should be 3-colorable."
_validate_layers(bonds_odd, layers_odd)


def test_layering_on_custom_star_graph():
"""Test layering on a custom lattice forming a star graph."""
star_edges = [(0, 1), (0, 2), (0, 3)]
layers = get_compatible_layers(star_edges)
assert len(layers) == 3, "A star graph S_4 requires 3 layers."
_validate_layers(star_edges, layers)


def test_layering_on_edge_cases():
"""Test various edge cases: empty, single-site, and no-edge lattices."""
layers_empty = get_compatible_layers([])
assert layers_empty == [], "Layers should be empty for an empty set of bonds."

single_edge = [(0, 1)]
layers_single = get_compatible_layers(single_edge)
assert layers_single == [[(0, 1)]]
_validate_layers(single_edge, layers_single)