# Intro

This notebook implements several quantum error-correcting codes:
- 9-qubit Shor code
- 5-qubit code
- TODO: 7-qubit Steane code
- TODO: Bacon–Shor code (code family?)

In [None]:
# reload modules
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
##################
# Import Modules #
##################
# General purpose
import numpy as np
import random
import os 

# For build circuit & others
from qiskit import QuantumRegister, ClassicalRegister, QuantumCircuit
import qiskit.circuit.library as qGate
from qiskit.extensions import UnitaryGate  # for two qubit, matrix to qGate

from qiskit.quantum_info import Statevector

# For run circuit
from qiskit import Aer, transpile, execute
from qiskit.providers.fake_provider import FakeSydney
from qiskit_aer import AerSimulator


# For draw figures
import matplotlib.pyplot as plt
from qiskit.visualization import plot_histogram


from IPython.display import Latex
from qiskit.visualization import state_visualization




# Shor code (9 qubits)
Reference: https://en.wikipedia.org/wiki/Quantum_error_correction

In [None]:
class ShorCode9:
    """Nine-qubit Shor code with passive (measurement-free) decoding.

    The information bit starts on ``qReg[0]``.  ``build_full_qc`` encodes it into
    the 9 physical qubits, injects optional Pauli faults, decodes back onto
    ``qReg[0]`` with CNOT/Toffoli gates (no ancillas, no syndrome measurement)
    and measures ``qReg[0]`` into ``cReg[0]``.
    """

    # Pauli fault types the channel is allowed to inject.
    _PAULI_ERRORS = ('x', 'y', 'z')

    def __init__(self):
        self.reset_qc()

    def reset_qc(self):
        """Replace the circuit with an empty one on fresh 9-qubit / 1-bit registers."""
        self.qReg = QuantumRegister(9, name='q')
        self.cReg = ClassicalRegister(1, name='c')
        self.qc = QuantumCircuit(self.qReg, self.cReg)

    def prepare_initial_state(self, info_bit=0):
        """Prepare the information bit on ``qReg[0]`` (LSB); all qubits start in |0>.

        Args:
            info_bit: 1 applies an X to ``qReg[0]``; any other value leaves it in |0>.
        """
        if info_bit == 1:
            self.qc.x(self.qReg[0])
        self.qc.barrier(self.qReg)

    def apply_channel(self, err_bit=None, err_type=None, extra_errors=()):
        """Inject Pauli faults on physical qubits, then add a barrier over ``qReg``.

        Args:
            err_bit: index into ``qReg`` of the primary fault; the primary fault is
                only applied when both ``err_bit`` and ``err_type`` are given.
            err_type: ``'x'``, ``'y'`` or ``'z'``.
            extra_errors: further ``(bit, type)`` pairs applied after the primary
                fault, e.g. to build multi-fault scenarios. Defaults to none.

        Raises:
            ValueError: if any fault type is not ``'x'``, ``'y'`` or ``'z'``
                (checked before any gate is added).
        """
        faults = []
        if err_bit is not None and err_type is not None:
            faults.append((err_bit, err_type))
        faults.extend(extra_errors)

        # validate everything first so a bad entry never leaves a half-built channel
        for _, kind in faults:
            if kind not in self._PAULI_ERRORS:
                raise ValueError(f"err_type must be one of {self._PAULI_ERRORS}, got {kind!r}")
        for bit, kind in faults:
            getattr(self.qc, kind)(self.qReg[bit])

        self.qc.barrier(self.qReg)

    def encode(self):
        """Encode ``qReg[0]`` into the 9-qubit Shor code (circuit from Wikipedia)."""
        # copy qReg[0] onto the leading qubit of each 3-qubit block (0, 3, 6)
        self.qc.cx(self.qReg[0], self.qReg[3])
        self.qc.cx(self.qReg[0], self.qReg[6])
        # within each block: H on the leading qubit, then fan it out to the other two
        for i in [0, 3, 6]:
            self.qc.h(self.qReg[i])
            self.qc.cx(self.qReg[i], self.qReg[i+1])
            self.qc.cx(self.qReg[i], self.qReg[i+2])
        self.qc.barrier(self.qReg)

    def decode_passive(self):
        """Invert ``encode`` using Toffoli majority votes, leaving the bit on ``qReg[0]``.

        No ancillas or mid-circuit measurements are used.
        """
        for i in [0, 3, 6]:
            # inner block: undo the fan-out, majority-vote onto the leading qubit, undo H
            self.qc.cx(self.qReg[i], self.qReg[i+1])
            self.qc.cx(self.qReg[i], self.qReg[i+2])
            self.qc.ccx(self.qReg[i+1], self.qReg[i+2], self.qReg[i])
            self.qc.h(self.qReg[i])
        # outer level: the same vote across the three block leaders
        self.qc.cx(self.qReg[0], self.qReg[3])
        self.qc.cx(self.qReg[0], self.qReg[6])
        self.qc.ccx(self.qReg[3], self.qReg[6], self.qReg[0])
        self.qc.barrier()

    # TODO: active error correction (stabilizer measurement on ancillas + syndrome
    # lookup table), following the pattern used by fiveQubitCode.

    def build_full_qc(self, info_bit, err_bit=None, err_type=None, extra_errors=()):
        """Build prepare -> encode -> channel -> passive decode -> measure ``qReg[0]``.

        Always starts from a fresh circuit (``reset_qc``). ``err_bit``, ``err_type``
        and ``extra_errors`` are forwarded to ``apply_channel``.
        """
        self.reset_qc()
        self.prepare_initial_state(info_bit)
        self.encode()
        self.apply_channel(err_bit=err_bit, err_type=err_type, extra_errors=extra_errors)
        self.decode_passive()
        self.qc.measure(self.qReg[0], self.cReg[0])

    def run(self, shots, noisy=True):
        """Simulate ``self.qc`` and return the measurement counts.

        Args:
            shots: number of shots to simulate.
            noisy: True (default) simulates with an AerSimulator built from the
                FakeSydney backend; False uses the ideal ``aer_simulator``.
        """
        if noisy:
            simulator = AerSimulator.from_backend(FakeSydney())
            result = simulator.run(transpile(self.qc, simulator), shots=shots).result()
        else:
            backend = Aer.get_backend('aer_simulator')
            result = execute(self.qc, backend, shots=shots).result()
        return result.get_counts()

# Demo: logical |1>, a single X (bit-flip) fault on qReg[0], passive decoding.
shor_code = ShorCode9()
shor_code.build_full_qc(info_bit=1, err_type='x', err_bit=0)
counts = shor_code.run(shots=100000)
display(plot_histogram(counts, title=r'9-qubit Shor Code: Correctable Fault w/ Correction'))
shor_code.qc.draw('mpl', filename='shor_correctable_single_fault_qc.svg')


# Five-qubit code
https://en.wikipedia.org/wiki/Five-qubit_error_correcting_code \
https://journals.aps.org/pra/pdf/10.1103/PhysRevA.93.042333 \
https://quantumcomputing.stackexchange.com/questions/14264/nielsenchuang-5-qubit-quantum-error-correction-encoding-gate

In [None]:

class fiveQubitCode:
    """Five-qubit error-correcting code with ancilla-based syndrome measurement.

    Registers (in circuit order): 4 syndrome ancillas ``qAnc`` measured into
    ``cAnc``, then 5 data qubits ``qReg`` measured into ``cReg``. The information
    bit starts on ``qReg[-1]``. ``build_full_qc`` runs prepare -> encode ->
    channel -> stabilizer measurement -> LUT correction -> measure ``qReg``.
    """

    # Pauli fault types the channel is allowed to inject.
    _PAULI_ERRORS = ('x', 'y', 'z')

    # Stabilizer generator g0 (Wikipedia); generator k is g0 cyclically shifted right by k.
    _STABILIZER = ('x', 'z', 'z', 'x', 'i')

    # (gate, data qubit, syndrome) correction table. Syndromes are cAnc bit strings,
    # whose order is the reverse of Wikipedia's; regenerate with find_correction_lut().
    _CORRECTION_LUT = (
        ('x', 0, '1000'),
        ('x', 1, '0001'),
        ('x', 2, '0011'),
        ('x', 3, '0110'),
        ('x', 4, '1100'),
        ('y', 0, '1101'),
        ('y', 1, '1011'),
        ('y', 2, '0111'),
        ('y', 3, '1111'),
        ('y', 4, '1110'),
        ('z', 0, '0101'),
        ('z', 1, '1010'),
        ('z', 2, '0100'),
        ('z', 3, '1001'),
        ('z', 4, '0010'),
    )

    # Faults the channel adds on top of (err_bit, err_type) unless told otherwise:
    # an X on qReg[2], so by default a circuit carries two faults (the notebook's
    # "non-correctable fault" demo relies on this). Pass extra_errors=() for a
    # genuine single-fault channel.
    _DEFAULT_EXTRA_ERRORS = ((2, 'x'),)

    def __init__(self):
        self.reset_qc()

    def reset_qc(self):
        """Replace the circuit with an empty one on fresh registers (ancillas first)."""
        self.qReg = QuantumRegister(5, name='q')
        self.cReg = ClassicalRegister(5, name='c')
        self.qAnc = QuantumRegister(4, name='syn_q')
        self.cAnc = ClassicalRegister(4, name='syn_c')
        self.qc = QuantumCircuit(self.qAnc, self.qReg, self.cAnc, self.cReg)

    def prepare_initial_state(self, info_bit=0):
        """Prepare the information bit on ``qReg[-1]`` (MSB); all qubits start in |0>.

        Args:
            info_bit: 1 applies an X to ``qReg[-1]``; any other value leaves it in |0>.
        """
        if info_bit == 1:
            self.qc.x(self.qReg[-1])
        self.qc.barrier()

    def apply_channel(self, err_bit=None, err_type=None, extra_errors=_DEFAULT_EXTRA_ERRORS):
        """Inject Pauli faults on data qubits, then add a barrier over ``qReg`` and ``qAnc``.

        Args:
            err_bit: index into ``qReg`` of the primary fault; the primary fault is
                only applied when both ``err_bit`` and ``err_type`` are given.
            err_type: ``'x'``, ``'y'`` or ``'z'``.
            extra_errors: further ``(bit, type)`` pairs applied after the primary
                fault. Defaults to an X on ``qReg[2]``; use ``()`` for none.

        Raises:
            ValueError: if any fault type is not ``'x'``, ``'y'`` or ``'z'``
                (checked before any gate is added).
        """
        faults = []
        if err_bit is not None and err_type is not None:
            faults.append((err_bit, err_type))
        faults.extend(extra_errors)

        # validate everything first so a bad entry never leaves a half-built channel
        for _, kind in faults:
            if kind not in self._PAULI_ERRORS:
                raise ValueError(f"err_type must be one of {self._PAULI_ERRORS}, got {kind!r}")
        for bit, kind in faults:
            getattr(self.qc, kind)(self.qReg[bit])

        self.qc.barrier(self.qReg, self.qAnc)

    def encode(self):
        """Encode the state on ``qReg[-1]`` into the five-qubit code (circuit from quantumcomputing.stackexchange)."""
        # H on the first four qubits, Z on the input, then CNOTs into the input qubit
        self.qc.h(self.qReg[0:-1])
        self.qc.z(self.qReg[-1])
        for i in range(self.qReg.size-1):
            self.qc.cx(self.qReg[i], self.qReg[-1])
        # CZ layer
        self.qc.cz(self.qReg[0], self.qReg[4])
        self.qc.cz(self.qReg[0], self.qReg[1])
        self.qc.cz(self.qReg[2], self.qReg[3])
        self.qc.cz(self.qReg[1], self.qReg[2])
        self.qc.cz(self.qReg[3], self.qReg[4])
        self.qc.barrier()

    def apply_error_detection(self):
        """Measure the 4 stabilizer generators onto ``qAnc`` and read them into ``cAnc``.

        Each ancilla is put in |+> with H, controls the Paulis of one generator on
        the data qubits, and is rotated back with H before measurement.
        """
        n_data = self.qReg.size
        period = len(self._STABILIZER)
        self.qc.h(self.qAnc)
        for gen_idx in range(n_data - 1):  # n - k = 5 - 1 = 4 generators
            for qb_idx in range(n_data):
                operation = self._STABILIZER[(qb_idx - gen_idx) % period]
                if operation == 'x':
                    self.qc.cx(self.qAnc[gen_idx], self.qReg[qb_idx])
                elif operation == 'z':
                    self.qc.cz(self.qAnc[gen_idx], self.qReg[qb_idx])
        self.qc.h(self.qAnc)
        self.qc.barrier(self.qAnc, self.qReg)
        self.qc.measure(self.qAnc, self.cAnc)
        self.qc.barrier(self.qAnc, self.qReg)

    def find_correction_lut(self):
        """Simulate every single-qubit Pauli fault and record the syndrome it produces.

        Prints one ``c_if`` correction line per fault (the format of
        ``_CORRECTION_LUT``'s original hand-written form) and returns the table as
        ``{syndrome: {'err_bit': ..., 'err_type': ...}}``.

        Note: the circuit is reset for every fault, so ``self.qc`` is left holding
        the last test circuit.
        """
        q5code_lut = {}
        lut_lines = []
        for err_type in self._PAULI_ERRORS:
            for err_bit in range(self.qReg.size):
                self.reset_qc()
                self.prepare_initial_state(info_bit=0)
                self.encode()
                # exactly one fault, otherwise the syndrome belongs to a combined error
                self.apply_channel(err_bit=err_bit, err_type=err_type, extra_errors=())
                self.apply_error_detection()
                counts = self.run(shots=1, noisy=False)
                # the cAnc bits are the last 4 characters of the counts key
                syndrome = next(iter(counts))[-4:]
                q5code_lut[syndrome] = {'err_bit': err_bit, 'err_type': err_type}
                lut_lines.append(f"self.qc.{err_type}(self.qReg[{err_bit}]).c_if(self.cAnc, int('{syndrome}',2))")
        print('\n'.join(lut_lines) + '\n')
        return q5code_lut

    def apply_error_correction(self):
        """Apply the Pauli selected by the syndrome in ``cAnc``, then measure ``qReg`` into ``cReg``."""
        for gate, bit, syndrome in self._CORRECTION_LUT:
            getattr(self.qc, gate)(self.qReg[bit]).c_if(self.cAnc, int(syndrome, 2))
        self.qc.barrier()
        self.qc.measure(self.qReg, self.cReg)

    def build_full_qc(self, info_bit, err_bit=None, err_type=None, extra_errors=_DEFAULT_EXTRA_ERRORS):
        """Build prepare -> encode -> channel -> detection -> correction (+ final measurement).

        Always starts from a fresh circuit (``reset_qc``). ``err_bit``, ``err_type``
        and ``extra_errors`` are forwarded to ``apply_channel``.
        """
        self.reset_qc()
        self.prepare_initial_state(info_bit)
        self.encode()
        self.apply_channel(err_bit=err_bit, err_type=err_type, extra_errors=extra_errors)
        self.apply_error_detection()
        self.apply_error_correction()

    def run(self, shots, noisy=False):
        """Simulate ``self.qc`` for ``shots`` shots and return the measurement counts.

        ``noisy=False`` (default) uses the ideal ``aer_simulator``; ``noisy=True``
        uses an AerSimulator built from the FakeSydney backend.
        """
        if noisy:
            simulator = AerSimulator.from_backend(FakeSydney())
            result = simulator.run(transpile(self.qc, simulator), shots=shots).result()
        else:
            backend = Aer.get_backend('aer_simulator')
            result = execute(self.qc, backend, shots=shots).result()
        return result.get_counts()

# Demo: logical |1> with a Y fault on qReg[4]; the channel also injects an X on
# qReg[2], so two faults are present and correction is expected to fail.
q5code = fiveQubitCode()
q5code.build_full_qc(1, err_type='y', err_bit=4)
counts = q5code.run(100000)
display(plot_histogram(counts, title=r'5-qubit Code: Non-correctable Fault w/ Correction'))
display(q5code.qc.draw('mpl'))


# Running on a real device

# Test Playground

In [None]:

# Playground: prepare the Bell state (|00> + |11>)/sqrt(2) with H + CNOT, check it
# against the explicit amplitude vector, then render it.
bell_qc = QuantumCircuit(2)
bell_qc.h(0)
bell_qc.cx(0, 1)
sv = Statevector.from_instruction(bell_qc)
assert sv.equiv(Statevector([1/2**0.5, 0, 0, 1/2**0.5])), 'circuit does not produce the expected Bell state'
print(sv)
display(sv.draw('latex'))