In [35]:
from typing import List, Dict, Any, Set, Iterable
from dataclasses import dataclass

import stim

In [36]:



def print_2d(values: Dict[complex, Any]):
    assert all(v.real == int(v.real) for v in values)
    assert all(v.imag == int(v.imag) for v in values)
    assert all(v.real >= 0 for v in values)
    assert all(v.imag >= 0 for v in values)
    
    w = int(max((v.real for v in values), default = 0) + 1)
    h = int(max((v.imag for v in values), default = 0) + 1)
    s = ""
    for y in range(h):
        for x in range(w):
            s += str(values.get(x + y*1j, "_"))
        s +="\n"
    print(s)    



def torus(c: complex, *, distance: int) -> complex:
    r = c.real % (distance * 4)
    i = c.imag % (distance * 6)
    return r + i*1j


@dataclass
class EdgeType:
    pauli: str
    hex_to_hex_delta: complex
    hex_to_qubit_delta: complex

def sorted_complex(xs: Iterable[complex]) -> List[complex]:
    return sorted(xs, key=lambda v:(v.real, v.imag))


def generate_circuit(distance: int, rounds: int) -> stim.Circuit:

    hex_centers: Dict[complex, int] = {}
    for row in range(3 * distance):
        for col in range(2 * distance):
            center = row * 2j + 2 * col - 1j * (col % 2)
            category = (-row - col % 2) % 3
            hex_centers[torus(center, distance=distance)] = category


    edge_types = [
        EdgeType(pauli="X", hex_to_hex_delta=2 - 3j, hex_to_qubit_delta=1 - 1j),
        EdgeType(pauli="Y", hex_to_hex_delta=2 + 3j, hex_to_qubit_delta=1 + 1j),
        EdgeType(pauli="Z", hex_to_hex_delta=4, hex_to_qubit_delta=1)
    ]

    qubit_coordinates: Set[complex] = set()
    for h in hex_centers:
        for e in edge_types:
            for sign in [-1, +1]:
                q = h + e.hex_to_qubit_delta * sign
                qubit_coordinates.add(torus(q, distance=distance))

    fused_dict = dict(hex_centers)
    for q in qubit_coordinates:
        fused_dict[q] = "q"
    print_2d(fused_dict)

    q2i: Dict[complex, int] = {q:i for i, q in enumerate(sorted_complex(qubit_coordinates))}

    round_circuits = []
    for r in range(3):
        relevant_hexes = [h for h, category in hex_centers.items() if category == r]
        edge_groups: Dict[str, List[frozenset[complex]]] = {"X": [], "Y": [], "Z": []}
        for h in relevant_hexes:
            for edge_type in edge_types:
                q1 = torus(h + edge_type.hex_to_qubit_delta, distance=distance)
                q2 = torus(h + edge_type.hex_to_hex_delta - edge_type.hex_to_qubit_delta, distance=distance)
                edge_groups[edge_type.pauli].append(frozenset([q1,q2]))
        circuit = stim.Circuit()
        x_qubits = [q2i[q] for pair in edge_groups["X"] for q in sorted_complex(pair)]
        y_qubits = [q2i[q] for pair in edge_groups["Y"] for q in sorted_complex(pair)]

        #Make all the parity operations Z basis parities
        circuit.append_operation("H", x_qubits)
        circuit.append_operation("H_YZ", x_qubits)

        #Turn parity observables into single qubit observables
        pair_targets = [
            q2i[q]
            for group in edge_groups.values()
            for pair in group
            for q in sorted_complex(pair)
        ]
        circuit.append_operation("CNOT", pair_targets)

        #Measure
        circuit.append_operation("M", pair_targets[1::2])

        #Restore qubit bases
        circuit.append_operation("CNOT", pair_targets)
        circuit.append_operation("H_YZ", x_qubits)
        circuit.append_operation("H", x_qubits)
        
        round_circuits.append(circuit)

    full_circuit = stim.Circuit()
    cycle = round_circuits[0] + round_circuits[1] + round_circuits[2]
    for q, i in q2i.items(): 
        full_circuit.append_operation("QUBIT_COORDS", [i], [q.real, q.imag])
    full_circuit += cycle * rounds
    return full_circuit

                 
def main():
    circuit = generate_circuit(distance=3, rounds=50)
    samples = circuit.compiler_sampler.sample(10)
    for sample in samples:
        print("".join("_1"[e] for e in sample))
    

if __name__ == '__main__':
    main()


0q_q
_q1q
2q_q
_q0q
1q_q
_q2q

REPEAT 1000 {
    H 4 5 7 8
    H_YZ 4 5 7 8
    CX 4 5 7 8 1 2 10 11 0 6 3 9
    M 5 8 2 11 6 9
    CX 4 5 7 8 1 2 10 11 0 6 3 9
    H_YZ 4 5 7 8
    H 4 5 7 8 6 11 2 3
    H_YZ 6 11 2 3
    CX 6 11 2 3 8 9 0 5 1 7 4 10
    M 11 3 9 5 7 10
    CX 6 11 2 3 8 9 0 5 1 7 4 10
    H_YZ 6 11 2 3
    H 6 11 2 3 9 10 0 1
    H_YZ 9 10 0 1
    CX 9 10 0 1 6 7 3 4 5 11 2 8
    M 10 1 7 4 11 8
    CX 9 10 0 1 6 7 3 4 5 11 2 8
    H_YZ 9 10 0 1
    H 9 10 0 1
}
