# Quantum Error Correction for Dummies (QEC4D)
# Chapter III - Shor's repetition code

So far we've seen how to correct X errors and Z errors separately. In quantum computation, however, both types of error can occur in the same circuit, so we need a code capable of correcting **both** X **and** Z errors at the same time. This is precisely what ***Shor's repetition code*** (aka the 9-qbit repetition code) does. Note that Shor's code was the ***first quantum error-correcting code***.

## 0) Overview
Shor's code can:
- Encode one logical qbit into 9 physical ones.
- Correct an error on any single one of those 9 qbits.
- Correct an X error, a Z error, or both (a Y error) on that qbit.
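For reference, applying both layers of the encoding built below maps the two basis states onto the standard Shor codewords. They are written here with qbits 0 to 8 from left to right, each bracket being one block $\{0,1,2\}$, $\{3,4,5\}$, $\{6,7,8\}$:

$$
\begin{aligned}
|0\rangle \mapsto |0_L\rangle &= \frac{1}{2\sqrt{2}}\big(|000\rangle + |111\rangle\big)\big(|000\rangle + |111\rangle\big)\big(|000\rangle + |111\rangle\big)\\
|1\rangle \mapsto |1_L\rangle &= \frac{1}{2\sqrt{2}}\big(|000\rangle - |111\rangle\big)\big(|000\rangle - |111\rangle\big)\big(|000\rangle - |111\rangle\big)
\end{aligned}
$$

A single X flips one bit inside a bracket, which a majority vote within that block can undo, while a single Z flips the relative sign of one bracket, which a comparison across the three brackets can undo.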

In [1]:
# First, a few imports
import numpy as np
from copy import deepcopy
from typing import List, Tuple
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, transpile
from qiskit_aer import Aer  # before Qiskit 1.0 this was also available as `from qiskit import Aer`

## 2) Building our circuit

### 2-A) Encoding

In [2]:
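def parity_check(ctrl_qb_idx:int, target_qbs_idxs:List[int], c:QuantumCircuit) -> QuantumCircuit:
    """Returns a copy of c with a CNOT from ctrl_qb_idx to each of the target qbits."""
    circuit = deepcopy(c)
    for target_idx in target_qbs_idxs:
        circuit.cx(ctrl_qb_idx, target_idx)  # `cnot` was removed in Qiskit 1.0; `cx` works in all versions
    return circuit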

First, we reproduce the 3-qbit phase-flip repetition code on qbits 0, 3 and 6. This is exactly the same structure as we had previously on a 3-qbit circuit, only more spread out.

In [3]:
c = QuantumCircuit(QuantumRegister(9))
c = parity_check(0, [3,6], c)
c.draw('text')
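On computational-basis inputs, the CNOTs added by `parity_check(0, [3,6], c)` simply copy the value of qbit 0 onto qbits 3 and 6. The minimal pure-Python sketch below reproduces that classical action on plain bit lists; `apply_cnot` and `fan_out` are hypothetical helpers written for this illustration, not Qiskit functions. On a superposition $\alpha|0\rangle + \beta|1\rangle$ the same gates give $\alpha|000\rangle + \beta|111\rangle$ on qbits 0, 3, 6 by linearity, not three independent copies.

```python
from typing import List

def apply_cnot(bits: List[int], ctrl: int, target: int) -> List[int]:
    """Classical action of a CNOT on a basis state given as a list of bits."""
    out = list(bits)
    out[target] ^= out[ctrl]  # flip the target iff the control is 1
    return out

def fan_out(bits: List[int], ctrl: int, targets: List[int]) -> List[int]:
    """Same CNOT pattern as parity_check(ctrl, targets, c), applied to plain bits."""
    for t in targets:
        bits = apply_cnot(bits, ctrl, t)
    return bits

zero = [0] * 9         # |000000000>
one = [1] + [0] * 8    # qbit 0 set to 1
print(fan_out(zero, 0, [3, 6]))  # [0, 0, 0, 0, 0, 0, 0, 0, 0]
print(fan_out(one, 0, [3, 6]))   # [1, 0, 0, 1, 0, 0, 1, 0, 0]
```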

Now these 3 qbits will be the ones correcting for the phase-flip, so we want to apply a Hadamard transform to them to change from the Z basis to the X basis.

In [4]:
def hadamard_transform(target_qbs_idxs:List[int], c:QuantumCircuit) -> QuantumCircuit:
    circuit = deepcopy(c)
    for target_qb_idx in target_qbs_idxs:
        circuit.h(target_qb_idx)
    return circuit

In [5]:
c_with_hadamard = hadamard_transform([0,3,6], c)
c_with_hadamard.draw('text')
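Why the Hadamards help: conjugating by H exchanges Z and X, i.e. $HZH = X$ and $HXH = Z$. A Z error that happens between the encoding Hadamards and the decoding Hadamards therefore shows up as an X error once the second Hadamard is applied, and the repetition-code machinery can catch it. A quick numerical check with plain 2×2 matrices (pure Python, no Qiskit; the helper names are our own):

```python
import math
from typing import List

Matrix = List[List[complex]]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Product of two 2x2 matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)] for i in range(2)]

def close(a: Matrix, b: Matrix, tol: float = 1e-12) -> bool:
    """Entry-wise comparison up to floating-point tolerance."""
    return all(abs(a[i][j] - b[i][j]) < tol for i in range(2) for j in range(2))

s = 1 / math.sqrt(2)
H = [[s, s], [s, -s]]
X = [[0, 1], [1, 0]]
Z = [[1, 0], [0, -1]]

# Conjugating by H swaps the roles of Z and X
print(close(matmul(matmul(H, Z), H), X))  # True
print(close(matmul(matmul(H, X), H), Z))  # True
```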

These three qbits are, of course, the 3-qbit encoding of the single input qbit. But they can also be considered as 3 independent inputs in their own right. Seen that way, it makes sense to protect each of them with the 3-qbit bit-flip repetition code we saw in the last chapter.

In [6]:
def all_parity_checks(all_ctrl_qb_idxs:List[int], c:QuantumCircuit) -> QuantumCircuit:
    """For each control index i, applies parity_check(i, [i+1, i+2]): the 3-qbit bit-flip encoding of block {i, i+1, i+2}."""
    circuit = deepcopy(c)
    for ctrl_idx in all_ctrl_qb_idxs:
        circuit = parity_check(ctrl_idx, [ctrl_idx + 1, ctrl_idx + 2], circuit)
    return circuit

In [7]:
c_with_parity_checks = all_parity_checks([0,3,6], c_with_hadamard)
c_with_parity_checks.draw('text')

At this stage, input qbit 0 is encoded with a 3-qbit phase-flip repetition code into qbits 0, 3 and 6; in terms of indices, $0 \mapsto \{0,3,6\}$.
Each of these qbits can in turn be viewed as an independent input qbit, encoded with a 3-qbit bit-flip repetition code: respectively $0 \mapsto \{0,1,2\} \:, \: 3 \mapsto \{3,4,5\} \: , \: 6 \mapsto \{6,7,8\}$.
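The index map $0 \mapsto \{0,3,6\}$, then $c \mapsto \{c, c+1, c+2\}$, can be checked directly on the amplitudes. Below is a minimal pure-Python state-vector sketch (no Qiskit; `apply_h`, `apply_cx`, `shor_encode` and `nonzero` are hypothetical helpers written for this illustration) that replays the same encoding gates and lists the basis states that survive. Qbit $k$ is stored as bit $k$ of the basis index, the little-endian convention Qiskit also uses, and keys are printed with qbit 0 first.

```python
import math
from typing import Dict, List

N_QBITS = 9
State = List[complex]

def apply_h(state: State, q: int) -> State:
    """Hadamard on qbit q; qbit q is bit q of the basis-state index."""
    s = 1 / math.sqrt(2)
    out = list(state)
    for i in range(len(state)):
        if not (i >> q) & 1:  # visit each (qbit q = 0, qbit q = 1) pair once
            j = i | (1 << q)
            a, b = state[i], state[j]
            out[i], out[j] = s * (a + b), s * (a - b)
    return out

def apply_cx(state: State, ctrl: int, target: int) -> State:
    """CNOT: when ctrl is 1, swap amplitudes of states differing only in the target bit."""
    out = list(state)
    for i in range(len(state)):
        if (i >> ctrl) & 1 and not (i >> target) & 1:
            j = i | (1 << target)
            out[i], out[j] = state[j], state[i]
    return out

def shor_encode(bit: int) -> State:
    """Same gates as the encoding cells: phase-flip layer, Hadamards, bit-flip layer."""
    state = [0j] * (2 ** N_QBITS)
    state[bit] = 1 + 0j              # qbit 0 holds the input, qbits 1 to 8 start in |0>
    for t in (3, 6):                 # 0 -> {0, 3, 6}
        state = apply_cx(state, 0, t)
    for q in (0, 3, 6):
        state = apply_h(state, q)
    for c in (0, 3, 6):              # c -> {c, c+1, c+2}
        for t in (c + 1, c + 2):
            state = apply_cx(state, c, t)
    return state

def nonzero(state: State) -> Dict[str, float]:
    """Basis states with non-negligible amplitude, keyed as bitstrings q0 q1 ... q8."""
    return {format(i, "09b")[::-1]: round(a.real, 4)
            for i, a in enumerate(state) if abs(a) > 1e-9}

for bit in (0, 1):
    amps = nonzero(shor_encode(bit))
    print(bit, len(amps), sorted(set(amps.values())))
# 0 8 [0.3536]
# 1 8 [-0.3536, 0.3536]
```

Each codeword has 8 terms of magnitude $1/(2\sqrt{2}) \approx 0.3536$; for $|1_L\rangle$ the sign is negative whenever an odd number of blocks read `111`.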

### 2-B) Introducing random error
We have reached the point in the circuit where the error should occur. This time, to test how well the code works, we'll introduce errors at random.

In [8]:
def introduce_random_error(circuit:QuantumCircuit) -> Tuple[QuantumCircuit, int]:
    """Returns a copy of circuit with one random X, Y or Z error, plus the index of the affected qbit."""
    c = deepcopy(circuit)
    error_idx = np.random.randint(0,9) #random number >= 0 and <9
    y_error_happens = np.random.randint(0,4) #random number >= 0 and <4
    if y_error_happens == 3: #one chance out of four to have both errors on the same qbit
        c.y(error_idx) #Y = iXZ: same as having both X and Z, up to a global phase
    else:
        coin_flip = np.random.randint(0,2) #random number in {0,1}
        if coin_flip == 0: # bit flip
            c.x(error_idx)
        else: # phase flip
            c.z(error_idx)
    return c, error_idx
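As a sanity check on the comments in `introduce_random_error`: `np.random.randint(0, 4)` returns one of four values, so a Y error is drawn with probability 1/4, and X and Z each with probability 3/8. A small standard-library sketch (`fractions` for the exact values, `random` for an empirical check that mirrors the same branching; it does not call NumPy itself):

```python
import random
from fractions import Fraction

# Exact probabilities implied by the branching in introduce_random_error:
# randint(0, 4) is uniform over {0, 1, 2, 3} and only the value 3 gives a Y error;
# otherwise a fair coin picks X or Z.
p_y = Fraction(1, 4)
p_x = (1 - p_y) * Fraction(1, 2)
p_z = (1 - p_y) * Fraction(1, 2)
print(p_y, p_x, p_z)  # 1/4 3/8 3/8

# Empirical check with the standard-library RNG (seeded, so the run is reproducible)
rng = random.Random(0)
counts = {"X": 0, "Y": 0, "Z": 0}
n_draws = 100_000
for _ in range(n_draws):
    if rng.randrange(4) == 3:      # same range as np.random.randint(0, 4)
        counts["Y"] += 1
    elif rng.randrange(2) == 0:    # same range as np.random.randint(0, 2)
        counts["X"] += 1
    else:
        counts["Z"] += 1
freqs = {k: v / n_draws for k, v in counts.items()}
print(freqs)  # close to {'X': 0.375, 'Y': 0.25, 'Z': 0.375}
```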

In [9]:
c_with_errors, error_idx = introduce_random_error(c_with_parity_checks)
c_with_errors.draw('text')

Again, we add parity checks on each group of 3 qbits correcting for X errors:

In [10]:
c_with_second_layer_of_parity_checks = all_parity_checks([0,3,6], c_with_errors)
c_with_second_layer_of_parity_checks.draw('text')

### 2-C) Correcting errors
Next up we start correcting errors, using the Toffoli gate trick introduced in chapter I. Here we correct the bit-flip errors on the three sets of three qbits $\{0,1,2\}, \{3,4,5\}, \{6,7,8\}$.

In [11]:
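def apply_all_toffolis(c:QuantumCircuit, all_idxs:List[Tuple[int, int, int]]) -> QuantumCircuit:
    """Applies one Toffoli per (control, control, target) triple."""
    circuit = deepcopy(c)
    for idxs in all_idxs:
        circuit = apply_toffoli(circuit, idxs)
    return circuit

def apply_toffoli(c:QuantumCircuit, idxs:Tuple[int, int, int]) -> QuantumCircuit:
    """Returns a copy of c with a Toffoli: controls idxs[0] and idxs[1], target idxs[2]."""
    circuit = deepcopy(c)
    circuit.ccx(idxs[0], idxs[1], idxs[2])
    return circuit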

In [12]:
circuit_corrected_bitflips = apply_all_toffolis(c_with_second_layer_of_parity_checks, [(2,1,0), (5,4,3), (8,7,6)])
circuit_corrected_bitflips.draw('text')
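The Toffoli trick can be checked exhaustively on a single block, relabelled from $\{c, c+1, c+2\}$ to $\{0,1,2\}$. Because CNOT and Toffoli only permute computational-basis states, checking the classical truth table is enough to see that every basis component of the block gets its bit flip undone. A pure-Python sketch (helper names are hypothetical):

```python
from itertools import product
from typing import Dict, List, Optional, Tuple

def cx(bits: List[int], c: int, t: int) -> None:
    bits[t] ^= bits[c]                  # CNOT on plain bits

def ccx(bits: List[int], c1: int, c2: int, t: int) -> None:
    bits[t] ^= bits[c1] & bits[c2]      # Toffoli on plain bits

def decode_block(bits: List[int]) -> int:
    """Second layer of parity checks followed by the Toffoli (2, 1) -> 0."""
    cx(bits, 0, 1)
    cx(bits, 0, 2)
    ccx(bits, 2, 1, 0)  # flip qbit 0 only if both syndrome qbits are 1
    return bits[0]

results: Dict[Tuple[int, Optional[int]], int] = {}
for b, err in product((0, 1), (None, 0, 1, 2)):
    bits = [b, b, b]         # encoded value
    if err is not None:
        bits[err] ^= 1       # a single bit flip on qbit `err`
    results[(b, err)] = decode_block(bits)

print(all(out == b for (b, _), out in results.items()))  # True
```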

We put the Hadamards back and now start correcting phase-flip errors on qbits $\{0,3,6\}$.

In [13]:
circuit_change_basis_qbs036 = hadamard_transform([0,3,6], circuit_corrected_bitflips)
circuit_change_basis_qbs036.draw('text')

Perform the parity checks on the phase-flip correcting qbits.

In [14]:
circuit_with_phaseflip_parity_checks = parity_check(0, [3,6],circuit_change_basis_qbs036)
circuit_with_phaseflip_parity_checks.draw('text')

And finally correct the phase-flip error with a Toffoli.

In [15]:
final_circuit = apply_toffoli(circuit_with_phaseflip_parity_checks, (6,3,0))
final_circuit.draw('text')

## 3) Measuring and testing Shor's code behaviour

We're done: we've built the full circuit corresponding to Shor's code, including the encoding and decoding parts. All that's left to do is to try it with different inputs and errors and compare the results.

In [16]:
def simulate_measurements(circuit:QuantumCircuit, nb_shots:int=1024) -> dict:
    """Simulates measurement results."""
    backend_sim = Aer.get_backend('aer_simulator')
    job_sim = backend_sim.run(transpile(circuit, backend_sim), shots=nb_shots)
    result_sim = job_sim.result()
    counts = result_sim.get_counts()
    return counts

def most_common_output(counts:dict) -> int:
    """Returns the most frequently measured outcome as an integer.
    Qiskit writes classical bit 0 as the rightmost character of each key, so
    with the single classical bit used here the result is simply 0 or 1."""
    return int(max(counts, key=counts.get), 2)
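A note on bit ordering, since reading the counts depends on it: the keys of a Qiskit counts dictionary list the classical bits with cbit 0 as the rightmost character, so parsing a key in base 2 gives an integer whose least significant bit is cbit 0. A tiny illustration with a made-up counts dictionary (the numbers are invented for the example, not simulation results):

```python
# Hypothetical counts, shaped like the dict returned by Result.get_counts():
# each key is a bitstring with classical bit 0 as the RIGHTMOST character.
counts = {"01": 900, "10": 100, "00": 24}

most_common = max(counts, key=counts.get)   # key with the highest count
value_of_cbit0 = int(most_common[-1])       # rightmost character is cbit 0
as_integer = int(most_common, 2)            # cbit 0 is the least significant bit
print(most_common, value_of_cbit0, as_integer)  # 01 1 1
```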

In [17]:
nb_trials = 10

for i in range(nb_trials):
    circ = QuantumCircuit(QuantumRegister(9)) #generate circuit

    # Pick initial bit value and encode it
    init_bit_value = np.random.randint(0,2) #pick whether we encode a 0 or 1
    if init_bit_value == 1:
        circ.x(0) #encode the 1

    # Build the circuit
    circ = parity_check(0, [3,6], circ)
    circ = hadamard_transform([0,3,6], circ)
    circ = all_parity_checks([0,3,6], circ)
    circ, error_index = introduce_random_error(circ)
    circ = all_parity_checks([0,3,6], circ)
    circ = apply_all_toffolis(circ, [(2,1,0), (5,4,3), (8,7,6)])
    circ = hadamard_transform([0,3,6], circ)
    circ = parity_check(0, [3,6], circ)
    circ = apply_toffoli(circ, (6,3,0))

    # Measuring
    circ.add_register(ClassicalRegister(1)) #add one classical bit for outcome measurement
    circ.measure(0,0) #measure value of qbit 0 and store into cbit 0
    result_counts = simulate_measurements(circ)
    output_nb1 = most_common_output(result_counts)

    print(f"Encoded bit value: {init_bit_value} - output value: {output_nb1} \t values match: {init_bit_value == output_nb1}")


Encoded bit value: 1 - output value: 1 	 values match: True
Encoded bit value: 1 - output value: 1 	 values match: True
Encoded bit value: 0 - output value: 0 	 values match: True
Encoded bit value: 0 - output value: 0 	 values match: True
Encoded bit value: 0 - output value: 0 	 values match: True
Encoded bit value: 1 - output value: 1 	 values match: True
Encoded bit value: 0 - output value: 0 	 values match: True
Encoded bit value: 0 - output value: 0 	 values match: True
Encoded bit value: 0 - output value: 0 	 values match: True
Encoded bit value: 0 - output value: 0 	 values match: True
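Ten random trials only sample a handful of the 27 possible single-qbit errors (9 positions × {X, Y, Z}), and measuring qbit 0 in the computational basis only checks the bit value, not the phase. As a complementary check, here is a pure-Python state-vector sketch (standard library only, not Qiskit; all helper names are hypothetical) that replays the same gate sequence for every single-qbit Pauli error and for the inputs $|0\rangle$, $|1\rangle$ and $|+\rangle$. For $|+\rangle$, an extra Hadamard is applied to qbit 0 before reading it, so an uncorrected phase error would show up as a failure.

```python
import math
from itertools import product
from typing import List, Optional, Tuple

N_QBITS = 9
State = List[complex]
S = 1 / math.sqrt(2)

def h(st: State, q: int) -> State:
    """Hadamard on qbit q (qbit q = bit q of the basis index, little-endian as in Qiskit)."""
    out = list(st)
    for i in range(len(st)):
        if not (i >> q) & 1:
            j = i | (1 << q)
            out[i], out[j] = S * (st[i] + st[j]), S * (st[i] - st[j])
    return out

def mcx(st: State, ctrls: List[int], t: int) -> State:
    """X on qbit t controlled on every qbit in ctrls: CNOT for one control, Toffoli for two."""
    out = list(st)
    for i in range(len(st)):
        if all((i >> c) & 1 for c in ctrls) and not (i >> t) & 1:
            j = i | (1 << t)
            out[i], out[j] = st[j], st[i]
    return out

def pauli(st: State, q: int, kind: str) -> State:
    """Apply X, Z, or Y = iXZ to qbit q."""
    out = list(st)
    if kind in ("Z", "Y"):
        out = [-a if (i >> q) & 1 else a for i, a in enumerate(out)]
    if kind in ("X", "Y"):
        out = [out[i ^ (1 << q)] for i in range(len(out))]
    if kind == "Y":
        out = [1j * a for a in out]
    return out

def run_shor(st: State, error: Optional[Tuple[int, str]]) -> State:
    """Same gate sequence as the notebook's circuit, with an optional (qbit, 'X'|'Y'|'Z') error."""
    for t in (3, 6):
        st = mcx(st, [0], t)
    for q in (0, 3, 6):
        st = h(st, q)
    for c in (0, 3, 6):
        st = mcx(mcx(st, [c], c + 1), [c], c + 2)
    if error is not None:
        st = pauli(st, *error)
    for c in (0, 3, 6):
        st = mcx(mcx(st, [c], c + 1), [c], c + 2)
    for c in (0, 3, 6):
        st = mcx(st, [c + 2, c + 1], c)     # Toffoli (c+2, c+1) -> c
    for q in (0, 3, 6):
        st = h(st, q)
    for t in (3, 6):
        st = mcx(st, [0], t)
    return mcx(st, [6, 3], 0)               # Toffoli (6, 3) -> 0

def prepare(label: str) -> State:
    """Qbit 0 in |0>, |1> or |+>; qbits 1 to 8 in |0>."""
    st = [0j] * (2 ** N_QBITS)
    st[0] = 1 + 0j
    if label == "1":
        st = pauli(st, 0, "X")
    elif label == "+":
        st = h(st, 0)
    return st

def prob_qbit0(st: State, value: int) -> float:
    """Probability of measuring qbit 0 in state |value>."""
    return sum(abs(a) ** 2 for i, a in enumerate(st) if (i & 1) == value)

errors = [None] + [(q, k) for q, k in product(range(N_QBITS), "XYZ")]
failures = []
for label, err in product(("0", "1", "+"), errors):
    out = run_shor(prepare(label), err)
    if label == "+":
        out = h(out, 0)    # |+> should come back as |+>, i.e. |0> after one more H
        expected = 0
    else:
        expected = int(label)
    if prob_qbit0(out, expected) < 1 - 1e-9:
        failures.append((label, err))

print(len(errors), "error cases,", len(failures), "failures")  # 28 error cases, 0 failures
```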


In [18]:
circ.draw('text')