diff --git a/examples/vqe2d_lattice.py b/examples/vqe2d_lattice.py new file mode 100644 index 00000000..13d78876 --- /dev/null +++ b/examples/vqe2d_lattice.py @@ -0,0 +1,93 @@ +""" +This example demonstrates how to use the VQE algorithm to find the ground state +of a 2D Heisenberg model on a square lattice. It showcases the setup of the lattice, +the Heisenberg Hamiltonian, a suitable ansatz, and the optimization process. +""" + +import time +import optax +import tensorcircuit as tc +from tensorcircuit.templates.lattice import SquareLattice, get_compatible_layers +from tensorcircuit.templates.hamiltonians import heisenberg_hamiltonian + +# Use JAX for high-performance, especially on GPU. +K = tc.set_backend("jax") +tc.set_dtype("complex64") +# On Windows, cotengra's multiprocessing can cause issues, use threads instead. +tc.set_contractor("cotengra-8192-8192", parallel="threads") + + +def run_vqe(): + """Set up and run the VQE optimization for a 2D Heisenberg model.""" + n, m, nlayers = 4, 4, 2 + lattice = SquareLattice(size=(n, m), pbc=True, precompute_neighbors=1) + h = heisenberg_hamiltonian(lattice, j_coupling=[1.0, 1.0, 0.8]) # Jx, Jy, Jz + nn_bonds = lattice.get_neighbor_pairs(k=1, unique=True) + gate_layers = get_compatible_layers(nn_bonds) + n_params = nlayers * len(nn_bonds) * 3 + + def singlet_init(circuit): + # A good initial state for Heisenberg ground state search + nq = circuit._nqubits + for i in range(0, nq - 1, 2): + j = (i + 1) % nq + circuit.X(i) + circuit.H(i) + circuit.cnot(i, j) + circuit.X(j) + return circuit + + def vqe_forward(param): + """ + Defines the VQE ansatz and computes the energy expectation. + The ansatz consists of nlayers of RZZ, RXX, and RYY entangling layers. + """ + c = tc.Circuit(n * m) + c = singlet_init(c) + param_idx = 0 + + for _ in range(nlayers): + for layer in gate_layers: + for j, k in layer: + c.rzz(int(j), int(k), theta=param[param_idx]) + param_idx += 1 + for layer in gate_layers: + for j, k in layer: + c.rxx(int(j), int(k), theta=param[param_idx]) + param_idx += 1 + for layer in gate_layers: + for j, k in layer: + c.ryy(int(j), int(k), theta=param[param_idx]) + param_idx += 1 + + return tc.templates.measurements.operator_expectation(c, h) + + vgf = K.jit(K.value_and_grad(vqe_forward)) + param = tc.backend.implicit_randn(stddev=0.02, shape=[n_params]) + optimizer = optax.adam(learning_rate=3e-3) + opt_state = optimizer.init(param) + + @K.jit + def train_step(param, opt_state): + """A single training step, JIT-compiled for maximum speed.""" + loss_val, grads = vgf(param) + updates, opt_state = optimizer.update(grads, opt_state, param) + param = optax.apply_updates(param, updates) + return param, opt_state, loss_val + + print("Starting VQE optimization...") + for i in range(1000): + time0 = time.time() + param, opt_state, loss = train_step(param, opt_state) + time1 = time.time() + if i % 10 == 0: + print( + f"Step {i:4d}: Loss = {loss:.6f} \t (Time per step: {time1 - time0:.4f}s)" + ) + + print("Optimization finished.") + print(f"Final Loss: {loss:.6f}") + + +if __name__ == "__main__": + run_vqe() diff --git a/tensorcircuit/templates/lattice.py b/tensorcircuit/templates/lattice.py index d4e3e541..52f152c9 100644 --- a/tensorcircuit/templates/lattice.py +++ b/tensorcircuit/templates/lattice.py @@ -15,6 +15,7 @@ Union, TYPE_CHECKING, cast, + Set, ) logger = logging.getLogger(__name__) @@ -1446,3 +1447,54 @@ def remove_sites(self, identifiers: List[SiteIdentifier]) -> None: logger.info( f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites." ) + + +def get_compatible_layers(bonds: List[Tuple[int, int]]) -> List[List[Tuple[int, int]]]: + """ + Partitions a list of pairs (bonds) into compatible layers for parallel + gate application using a greedy edge-coloring algorithm. + + This function takes a list of pairs, representing connections like + nearest-neighbor (NN) or next-nearest-neighbor (NNN) bonds, and + partitions them into the minimum number of sets ("layers") where no two + pairs in a set share an index. This is a general utility for scheduling + non-overlapping operations. + + :Example: + + >>> from tensorcircuit.templates.lattice import SquareLattice + >>> sq_lattice = SquareLattice(size=(2, 2), pbc=False) + >>> nn_bonds = sq_lattice.get_neighbor_pairs(k=1, unique=True) + + >>> gate_layers = get_compatible_layers(nn_bonds) + >>> print(gate_layers) + [[[0, 1], [2, 3]], [[0, 2], [1, 3]]] + + :param bonds: A list of tuples, where each tuple represents a bond (i, j) + of site indices to be scheduled. + :type bonds: List[Tuple[int, int]] + :return: A list of layers. Each layer is a list of tuples, where each + tuple represents a bond. All bonds within a layer are non-overlapping. + :rtype: List[List[Tuple[int, int]]] + """ + uncolored_edges: Set[Tuple[int, int]] = {(min(bond), max(bond)) for bond in bonds} + + layers: List[List[Tuple[int, int]]] = [] + + while uncolored_edges: + current_layer: List[Tuple[int, int]] = [] + qubits_in_this_layer: Set[int] = set() + + edges_to_process = sorted(list(uncolored_edges)) + + for edge in edges_to_process: + i, j = edge + if i not in qubits_in_this_layer and j not in qubits_in_this_layer: + current_layer.append(edge) + qubits_in_this_layer.add(i) + qubits_in_this_layer.add(j) + + uncolored_edges -= set(current_layer) + layers.append(sorted(current_layer)) + + return layers diff --git a/tests/test_lattice.py b/tests/test_lattice.py index d332e6cd..12354f13 100644 --- a/tests/test_lattice.py +++ b/tests/test_lattice.py @@ -23,6 +23,8 @@ RectangularLattice, SquareLattice, TriangularLattice, + AbstractLattice, + get_compatible_layers, ) @@ -1664,3 +1666,85 @@ def test_distance_matrix_invariants_for_all_lattice_types(self, lattice): # "The specialized PBC implementation is significantly slower " # "than the general-purpose implementation." # ) + + +def _validate_layers(bonds, layers) -> None: + """ + A helper function to scientifically validate the output of get_compatible_layers. + """ + # MODIFICATION: This function now takes the original bonds list for comparison. + expected_edges = set(tuple(sorted(b)) for b in bonds) + actual_edges = set(tuple(sorted(edge)) for layer in layers for edge in layer) + + assert ( + expected_edges == actual_edges + ), "Completeness check failed: The set of all edges in the layers must " + "exactly match the input bonds." + + for i, layer in enumerate(layers): + qubits_in_layer: set[int] = set() + for edge in layer: + q1, q2 = edge + assert ( + q1 not in qubits_in_layer + ), f"Compatibility check failed: Qubit {q1} is reused in layer {i}." + qubits_in_layer.add(q1) + assert ( + q2 not in qubits_in_layer + ), f"Compatibility check failed: Qubit {q2} is reused in layer {i}." + qubits_in_layer.add(q2) + + +@pytest.mark.parametrize( + "lattice_instance", + [ + SquareLattice(size=(3, 2), pbc=False), + SquareLattice(size=(3, 3), pbc=True), + HoneycombLattice(size=(2, 2), pbc=False), + ], + ids=[ + "SquareLattice_3x2_OBC", + "SquareLattice_3x3_PBC", + "HoneycombLattice_2x2_OBC", + ], +) +def test_layering_on_various_lattices(lattice_instance: AbstractLattice): + """Tests gate layering for various standard lattice types.""" + bonds = lattice_instance.get_neighbor_pairs(k=1, unique=True) + layers = get_compatible_layers(bonds) + + assert len(layers) > 0, "Layers should not be empty for non-trivial lattices." + _validate_layers(bonds, layers) + + +def test_layering_on_1d_chain_pbc(): + """Test layering on a 1D chain with periodic boundaries (a cycle graph).""" + lattice_even = ChainLattice(size=(6,), pbc=True) + bonds_even = lattice_even.get_neighbor_pairs(k=1, unique=True) + layers_even = get_compatible_layers(bonds_even) + _validate_layers(bonds_even, layers_even) + + lattice_odd = ChainLattice(size=(5,), pbc=True) + bonds_odd = lattice_odd.get_neighbor_pairs(k=1, unique=True) + layers_odd = get_compatible_layers(bonds_odd) + assert len(layers_odd) == 3, "A 5-site cycle graph should be 3-colorable." + _validate_layers(bonds_odd, layers_odd) + + +def test_layering_on_custom_star_graph(): + """Test layering on a custom lattice forming a star graph.""" + star_edges = [(0, 1), (0, 2), (0, 3)] + layers = get_compatible_layers(star_edges) + assert len(layers) == 3, "A star graph S_4 requires 3 layers." + _validate_layers(star_edges, layers) + + +def test_layering_on_edge_cases(): + """Test various edge cases: empty, single-site, and no-edge lattices.""" + layers_empty = get_compatible_layers([]) + assert layers_empty == [], "Layers should be empty for an empty set of bonds." + + single_edge = [(0, 1)] + layers_single = get_compatible_layers(single_edge) + assert layers_single == [[(0, 1)]] + _validate_layers(single_edge, layers_single)