Skip to content

Commit ebdebc9

Browse files
committed
fix according to the review
1 parent 8c666fa commit ebdebc9

File tree

5 files changed

+246
-184
lines changed

5 files changed

+246
-184
lines changed

examples/vqe2d_lattice.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import time
2+
import optax
3+
import tensorcircuit as tc
4+
from tensorcircuit.templates.lattice import SquareLattice, get_compatible_layers
5+
from tensorcircuit.templates.hamiltonians import heisenberg_hamiltonian
6+
7+
# Use JAX for high-performance, especially on GPU.
8+
K = tc.set_backend("jax")
9+
tc.set_dtype("complex64")
10+
# On Windows, cotengra's multiprocessing can cause issues.
11+
tc.set_contractor("cotengra-8192-8192", parallel=False)
12+
13+
14+
def run_vqe():
15+
n, m, nlayers = 4, 4, 6
16+
lattice = SquareLattice(size=(n, m), pbc=True, precompute_neighbors=1)
17+
h = heisenberg_hamiltonian(lattice, j_coupling=[1.0, 1.0, 0.8]) # Jx, Jy, Jz
18+
nn_bonds = lattice.get_neighbor_pairs(k=1, unique=True)
19+
gate_layers = get_compatible_layers(nn_bonds)
20+
21+
def singlet_init(circuit):
22+
# A good initial state for Heisenberg ground state search
23+
nq = circuit._nqubits
24+
for i in range(0, nq - 1, 2):
25+
j = (i + 1) % nq
26+
circuit.X(i)
27+
circuit.H(i)
28+
circuit.cnot(i, j)
29+
circuit.X(j)
30+
return circuit
31+
32+
def vqe_forward(param):
33+
"""
34+
Defines the VQE ansatz and computes the energy expectation.
35+
The ansatz consists of nlayers of RZZ, RXX, and RYY entangling layers.
36+
"""
37+
c = tc.Circuit(n * m)
38+
c = singlet_init(c)
39+
40+
for i in range(nlayers):
41+
for layer in gate_layers:
42+
for j, k in layer:
43+
c.rzz(int(j), int(k), theta=param[i, 0])
44+
for layer in gate_layers:
45+
for j, k in layer:
46+
c.rxx(int(j), int(k), theta=param[i, 1])
47+
for layer in gate_layers:
48+
for j, k in layer:
49+
c.ryy(int(j), int(k), theta=param[i, 2])
50+
51+
return tc.templates.measurements.operator_expectation(c, h)
52+
53+
vgf = K.jit(K.value_and_grad(vqe_forward))
54+
param = tc.backend.implicit_randn(stddev=0.02, shape=[nlayers, 3])
55+
optimizer = optax.adam(learning_rate=3e-3)
56+
opt_state = optimizer.init(param)
57+
58+
@K.jit
59+
def train_step(param, opt_state):
60+
"""A single training step, JIT-compiled for maximum speed."""
61+
loss_val, grads = vgf(param)
62+
updates, opt_state = optimizer.update(grads, opt_state, param)
63+
param = optax.apply_updates(param, updates)
64+
return param, opt_state, loss_val
65+
66+
print("Starting VQE optimization...")
67+
for i in range(1000):
68+
time0 = time.time()
69+
param, opt_state, loss = train_step(param, opt_state)
70+
time1 = time.time()
71+
if i % 10 == 0:
72+
print(
73+
f"Step {i:4d}: Loss = {loss:.6f} \t (Time per step: {time1 - time0:.4f}s)"
74+
)
75+
76+
print("Optimization finished.")
77+
print(f"Final Loss: {loss:.6f}")
78+
79+
80+
if __name__ == "__main__":
81+
run_vqe()

tensorcircuit/templates/circuit_utils.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

tensorcircuit/templates/lattice.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Union,
1616
TYPE_CHECKING,
1717
cast,
18+
Set,
1819
)
1920

2021
logger = logging.getLogger(__name__)
@@ -1446,3 +1447,54 @@ def remove_sites(self, identifiers: List[SiteIdentifier]) -> None:
14461447
logger.info(
14471448
f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites."
14481449
)
1450+
1451+
1452+
def get_compatible_layers(bonds: List[Tuple[int, int]]) -> List[List[Tuple[int, int]]]:
1453+
"""
1454+
Partitions a list of pairs (bonds) into compatible layers for parallel
1455+
gate application using a greedy edge-coloring algorithm.
1456+
1457+
This function takes a list of pairs, representing connections like
1458+
nearest-neighbor (NN) or next-nearest-neighbor (NNN) bonds, and
1459+
partitions them into the minimum number of sets ("layers") where no two
1460+
pairs in a set share an index. This is a general utility for scheduling
1461+
non-overlapping operations.
1462+
1463+
:Example:
1464+
1465+
>>> from tensorcircuit.templates.lattice import SquareLattice
1466+
>>> sq_lattice = SquareLattice(size=(2, 2), pbc=False)
1467+
>>> nn_bonds = sq_lattice.get_neighbor_pairs(k=1, unique=True)
1468+
1469+
>>> gate_layers = get_compatible_layers(nn_bonds)
1470+
>>> print(gate_layers)
1471+
[[[0, 1], [2, 3]], [[0, 2], [1, 3]]]
1472+
1473+
:param bonds: A list of tuples, where each tuple represents a bond (i, j)
1474+
of site indices to be scheduled.
1475+
:type bonds: List[Tuple[int, int]]
1476+
:return: A list of layers. Each layer is a list of tuples, where each
1477+
tuple represents a bond. All bonds within a layer are non-overlapping.
1478+
:rtype: List[List[Tuple[int, int]]]
1479+
"""
1480+
uncolored_edges: Set[Tuple[int, int]] = {(min(bond), max(bond)) for bond in bonds}
1481+
1482+
layers: List[List[Tuple[int, int]]] = []
1483+
1484+
while uncolored_edges:
1485+
current_layer: List[Tuple[int, int]] = []
1486+
qubits_in_this_layer: Set[int] = set()
1487+
1488+
edges_to_process = sorted(list(uncolored_edges))
1489+
1490+
for edge in edges_to_process:
1491+
i, j = edge
1492+
if i not in qubits_in_this_layer and j not in qubits_in_this_layer:
1493+
current_layer.append(edge)
1494+
qubits_in_this_layer.add(i)
1495+
qubits_in_this_layer.add(j)
1496+
1497+
uncolored_edges -= set(current_layer)
1498+
layers.append(sorted(current_layer))
1499+
1500+
return layers

tests/test_circuit_utils.py

Lines changed: 0 additions & 125 deletions
This file was deleted.

0 commit comments

Comments
 (0)