In [90]:
%reload_ext autoreload
%autoreload 2

In [92]:
from typing import List

from qiskit_qec.circuits.code_circuit import CodeCircuit
from qiskit import QuantumRegister, ClassicalRegister, QuantumCircuit

class BaconShorCodeCircuit():

    def __init__(
            self,
            d: int,
            T: int,
            xbasis: bool = False,
            resets: bool = False,
            bridge_qubits: bool = False,
            barriers: bool = False,
    ): 
        
        if bridge_qubits == False:
            raise NotImplementedError("Bacon-Shor code without bridge qubits is not implemented yet.")
        
        # super().__init__()

        self.d = d
        self.T = 0

        self._bridge_qubits = bridge_qubits
        self._resets = resets
        self._xbasis = xbasis
        self._barriers = barriers

        self.code_qubits = QuantumRegister(d**2, "code_qubit")
        self.x_ancilla_qubits = QuantumRegister((d-1)*d, "x_ancilla_qubit")
        self.z_ancilla_qubits = QuantumRegister((d-1)*d, "z_ancilla_qubit")
        self.bulk_bridge_qubits = QuantumRegister(2*(d-1)*d, "bulk_bridge_qubit") if bridge_qubits else None        
        self.boundary_bridge_qubits = QuantumRegister(2*(d-1), "boundary_bridge_qubit") if bridge_qubits else None        
        self.qubit_registers = {"code_qubit", "x_aniclla_qubit", "z_ancilla_qubit", "bulk_bridge_qubit", "boundary_bridge_qubit"}

        self.x_ancilla_bits = []
        self.z_ancilla_bits = []
        self.code_bits = ClassicalRegister(d**2, "code_bit")

        self.circuit = {}
        for log in ["0", "1"]:
            self.circuit[log] = QuantumCircuit(self.code_qubits, self.bulk_bridge_qubits, 
                                          self.boundary_bridge_qubits, self.x_ancilla_qubits, 
                                          self.z_ancilla_qubits, name=log)
        
        self._preparation()

        for _ in range(T-1):
            self.syndrome_measurement()

        if T !=0:
            self.syndrome_measurement(final=True)
            self.readout()

    def x(self, logs=("0", "1")):
        # TODO: This is wrong, apply a logical gate instead of a transversal
        for log in logs:
            if self._xbasis:
                self.circuit[log].z(self.code_qubits)
            else:
                self.circuit[log].x(self.code_qubits)

    def _preparation(self):
        for log in ["0", "1"]:
            if self._xbasis:
                self.circuit[log].h(self.code_qubits)
        self.x(["1"]) # Apply an X for the logical 1 circuit

    def readout(self):
        """Readout of all code qubits"""
        for log in ["0", "1"]:
            if self._xbasis:
                self.circuit[log].h(self.code_qubits)
            self.circuit[log].add_register(self.code_bits)
            self.circuit[log].measure(self.code_qubits, self.code_bits)

    def syndrome_measurement(self, final: bool = False):

        bridge_CX = BaconShorCodeCircuit.bridge_CX
        
        self.x_ancilla_bits.append(ClassicalRegister((self.d-1)*self.d, f"round_{self.T}_x_link_bit"))
        self.z_ancilla_bits.append(ClassicalRegister((self.d-1)*self.d, f"round_{self.T}_z_link_bit"))

        for log in ["0", "1"]:

            # Entangling gates
            self.circuit[log].barrier() if self._barriers else None
            ### Z stabilizers             
            for y in range(self.d): # L->R CXs
                for x in range(self.d-1):
                    c_idx = y*self.d + x
                    a_idx = y*(self.d-1) + x
                    if y == 0: # boundary bridge qubits
                        b_idx = x
                        bridge_CX(self.circuit[log], self.code_qubits[c_idx], 
                                                       self.z_ancilla_qubits[a_idx], self.boundary_bridge_qubits[b_idx])
                    else: # bulk bridge qubits
                        b_idx = 2*(y-1)*self.d + x + self.d
                        bridge_CX(self.circuit[log], self.code_qubits[c_idx], 
                                                       self.z_ancilla_qubits[a_idx], self.bulk_bridge_qubits[b_idx])

            for y in range(self.d): # R->L CXs
                for x in range(self.d-1):
                    c_idx = y*self.d + x + 1
                    a_idx = y*(self.d-1) + x
                    if y == self.d-1: # boundary bridge qubits
                        b_idx = self.d-1 + x
                        bridge_CX(self.circuit[log], self.code_qubits[c_idx], 
                                                       self.z_ancilla_qubits[a_idx], self.boundary_bridge_qubits[b_idx])
                    else: # bulk bridge qubits
                        b_idx = 2*y*self.d + x + 1
                        bridge_CX(self.circuit[log], self.code_qubits[c_idx], 
                                                       self.z_ancilla_qubits[a_idx], self.bulk_bridge_qubits[b_idx])


            ### X stabilizers
            self.circuit[log].barrier() if self._barriers else None
            self.circuit[log].h(self.x_ancilla_qubits)

            for y in range(self.d-1): # T->B CXs
                for x in range(self.d):
                    c_idx = y*self.d + x
                    a_idx = y*self.d + x 
                    b_idx = 2*y*self.d + x
                    bridge_CX(self.circuit[log], self.x_ancilla_qubits[a_idx], 
                                                   self.code_qubits[c_idx], self.bulk_bridge_qubits[b_idx])
            
            for y in range(self.d-1): # B->T CXs
                for x in range(self.d):
                    c_idx = (y+1)*self.d + x 
                    a_idx = y*self.d + x
                    b_idx = 2*y*self.d + x + self.d
                    bridge_CX(self.circuit[log], self.x_ancilla_qubits[a_idx], 
                                                   self.code_qubits[c_idx], self.bulk_bridge_qubits[b_idx])

            self.circuit[log].h(self.x_ancilla_qubits)


            # Measurements
            self.circuit[log].barrier() if self._barriers else None

            self.circuit[log].add_register(self.z_ancilla_bits[self.T])
            self.circuit[log].add_register(self.x_ancilla_bits[self.T])

            self.circuit[log].measure(self.z_ancilla_qubits, self.z_ancilla_bits[self.T])
            self.circuit[log].measure(self.x_ancilla_qubits, self.x_ancilla_bits[self.T])

            # Resets
            if self._resets and not final:
                self.circuit[log].barrier() if self._barriers else None
                self.circuit[log].reset(self.z_ancilla_qubits)
                self.circuit[log].reset(self.x_ancilla_qubits)


        self.T += 1

    @staticmethod
    def bridge_CX(circ: QuantumCircuit, control: QuantumRegister, target: QuantumRegister, bridge: QuantumRegister):
        """CX gate over a bridge qubit using 4 CXs instead of 2 SWAPs and 1 CX."""
        circ.cx(control, bridge)
        circ.cx(bridge, target)
        circ.cx(control, bridge)
        circ.cx(bridge, target)

In [93]:
code = BaconShorCodeCircuit(2, 3, bridge_qubits=True, barriers=True, xbasis=False, resets=True)
code.circuit["0"].draw(fold=-1)

# Simulator

In [94]:
from qiskit_aer import AerSimulator

sim = AerSimulator()
job = sim.run(code.circuit["0"], shots=1024)
result = job.result()
result.get_counts()

{'1111 10 11 01 11 01 00': 1,
 '0000 11 11 11 11 00 00': 2,
 '1111 01 00 10 00 01 00': 4,
 '1111 11 11 00 11 00 00': 2,
 '1010 00 00 11 00 11 00': 3,
 '0000 00 11 00 00 00 00': 2,
 '0101 11 11 11 00 11 00': 3,
 '1111 00 11 00 11 11 00': 3,
 '1111 01 11 01 00 10 00': 2,
 '0000 01 11 01 00 10 00': 4,
 '0101 01 00 01 00 10 00': 4,
 '0000 00 00 11 00 00 00': 2,
 '1111 11 11 00 11 11 00': 1,
 '1010 01 00 01 11 01 00': 4,
 '1010 11 11 00 00 00 00': 4,
 '0000 01 00 01 00 01 00': 2,
 '1111 10 00 10 11 10 00': 2,
 '0000 11 11 11 11 11 00': 3,
 '0101 00 11 00 11 11 00': 3,
 '0101 00 00 00 00 00 00': 2,
 '1111 10 00 01 00 01 00': 3,
 '1111 11 00 11 11 11 00': 2,
 '0000 00 11 00 11 11 00': 3,
 '0101 10 11 10 11 10 00': 4,
 '1010 11 00 00 11 00 00': 3,
 '1111 01 11 01 11 10 00': 5,
 '1111 01 00 10 00 10 00': 2,
 '1010 00 00 00 00 11 00': 2,
 '1111 00 00 00 00 11 00': 2,
 '0101 00 00 11 11 00 00': 3,
 '1010 00 11 11 11 11 00': 6,
 '1010 11 11 11 11 11 00': 6,
 '1111 01 11 01 11 01 00': 6,
 '1111 10 

# Profiling

In [33]:
import cProfile

cProfile.run("BaconShorCodeCircuit(20, 20, bridge_qubits=True, barriers=True)")

         20842996 function calls (20842824 primitive calls) in 5.940 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    60800    0.062    0.000    4.707    0.000 323930224.py:155(bridge_CX)
        1    0.000    0.000    0.002    0.002 323930224.py:60(x)
        1    0.000    0.000    0.002    0.002 323930224.py:68(_preparation)
        1    0.000    0.000    0.006    0.006 323930224.py:74(readout)
        1    0.000    0.000    5.929    5.929 323930224.py:8(__init__)
       20    0.070    0.004    5.907    0.295 323930224.py:82(syndrome_measurement)
   792286    0.136    0.000    0.227    0.000 <frozen abc>:117(__instancecheck__)
   243401    0.066    0.000    0.091    0.000 <frozen importlib._bootstrap>:405(parent)
    35116    0.006    0.000    0.015    0.000 <string>:1(<lambda>)
        1    0.011    0.011    5.940    5.940 <string>:1(<module>)
   243200    0.096    0.000    0.128    0.000 _utils.py:78(_ctrl_state_to_i