In [1]:
import sympy
import itertools as it
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
%matplotlib inline

# Default matplotlib font size (in points) for every figure drawn in this notebook.
font = dict(size=22)
matplotlib.rc('font', **font)

In [2]:
def generate_pauli_dict_2(p_I, p_X, p_Y, p_Z):
    """Pauli-weight dictionary of rho (x) rho, given rho's Pauli vector.

    Args:
        p_I, p_X, p_Y, p_Z: weights of the single-copy state rho on I, X, Y, Z
            (plain numbers or sympy expressions).

    Returns:
        dict mapping every two-letter Pauli string to the product of its
        single-copy weights, e.g. ``"XY" -> p_X * p_Y``.  Keys are ordered
        "II", "IX", ..., "ZZ" (I, X, Y, Z order, first letter outermost).
    """
    weight_of = {"I": p_I, "X": p_X, "Y": p_Y, "Z": p_Z}
    table = {}
    for letters in it.product(weight_of, repeat=2):
        value = 1
        for letter in letters:
            value *= weight_of[letter]
        table["".join(letters)] = value
    return table

def generate_pauli_dict_3(p_I, p_X, p_Y, p_Z):
    """Pauli-weight dictionary of rho (x) rho (x) rho, given rho's Pauli vector.

    Args:
        p_I, p_X, p_Y, p_Z: weights of the single-copy state rho on I, X, Y, Z
            (plain numbers or sympy expressions).

    Returns:
        dict mapping every three-letter Pauli string to the product of its
        single-copy weights, e.g. ``"XXX" -> p_X * p_X * p_X`` and
        ``"XXY" -> p_X * p_X * p_Y``.  Keys are ordered "III", "IIX", ..., "ZZZ".
    """
    weight_of = {"I": p_I, "X": p_X, "Y": p_Y, "Z": p_Z}
    table = {}
    for letters in it.product(weight_of, repeat=3):
        value = 1
        for letter in letters:
            value *= weight_of[letter]
        table["".join(letters)] = value
    return table


def generate_pauli_dict_4(p_I, p_X, p_Y, p_Z):
    """Pauli-weight dictionary of four copies of rho, given rho's Pauli vector.

    Args:
        p_I, p_X, p_Y, p_Z: weights of the single-copy state rho on I, X, Y, Z
            (plain numbers or sympy expressions).

    Returns:
        dict mapping every four-letter Pauli string to the product of its
        single-copy weights, e.g. ``"XYZI" -> p_X * p_Y * p_Z * p_I``.
        Keys are ordered "IIII", "IIIX", ..., "ZZZZ".
    """
    weight_of = {"I": p_I, "X": p_X, "Y": p_Y, "Z": p_Z}
    table = {}
    for letters in it.product(weight_of, repeat=4):
        value = 1
        for letter in letters:
            value *= weight_of[letter]
        table["".join(letters)] = value
    return table


def sum_up_dict(pauli_dict, lst_pauli_str):
    """Total weight of the listed Pauli strings.

    Args:
        pauli_dict: mapping from Pauli string (e.g. "XX") to its weight.
        lst_pauli_str: Pauli strings to add up; repeats are counted each time.

    Returns:
        The sum of ``pauli_dict[s]`` over ``s`` in ``lst_pauli_str`` (0 if the
        list is empty).

    Raises:
        KeyError: if a listed string is not a key of ``pauli_dict``.
    """
    selected = [pauli_dict[key] for key in lst_pauli_str]
    total = 0
    for term in selected:
        total += term
    return total

def distillation_statistics(state, pillars):
    """Normalised output Pauli vector and success probability of a protocol.

    Args:
        state: dict mapping multi-copy Pauli strings to their weights, e.g. the
            output of ``generate_pauli_dict_2/3/4``.
        pillars: sequence of four lists of Pauli strings; the weights of the
            strings in ``pillars[k]`` contribute to output component
            k = 0, 1, 2, 3 (p_I, p_X, p_Y, p_Z).

    Returns:
        ``([p_I, p_X, p_Y, p_Z], p_suc)`` where ``p_suc`` is the total weight of
        every string in every pillar and each component is its own pillar's
        weight divided by ``p_suc``.

    Raises:
        ZeroDivisionError: if ``p_suc`` evaluates to numeric zero.
    """
    # The normaliser is shared by all four components; with sympy weights it is
    # an expensive symbolic sum, so build it once instead of once per component.
    p_suc = sum_up_dict(state, list(it.chain.from_iterable(pillars)))
    components = [sum_up_dict(state, pillars[k]) / p_suc for k in range(4)]
    return components, p_suc

def permute_paulies(pauli_vector):
    """All orderings of the X, Y, Z weights of a Pauli vector, with I kept first.

    Args:
        pauli_vector: sequence ``(p_I, p_X, p_Y, p_Z)``.

    Returns:
        list of the 3! = 6 lists ``[p_I, p_a, p_b, p_c]``, where ``(a, b, c)``
        runs over every permutation of (X, Y, Z) exactly once, in
        ``itertools.permutations`` order: XYZ, XZY, YXZ, YZX, ZXY, ZYX.
    """
    # Generated rather than hand-listed so that every ordering appears exactly
    # once (no duplicates, none missing).
    identity_weight = pauli_vector[0]
    return [
        [identity_weight] + [pauli_vector[k] for k in order]
        for order in it.permutations((1, 2, 3))
    ]

def get_bb84_rate(pauli_vector, num_pair, p_suc):
    """Largest BB84 key rate over all relabellings of the X, Y, Z weights.

    For each ordering ``p = [p_I, p_a, p_b, p_c]`` returned by
    ``permute_paulies``, the bit-flip error is ``p_a + p_b`` and the phase-flip
    error is ``p_b + p_c``.  The candidate rate is
    ``p_suc * (1 - h(e_bit) - h(e_phase)) / num_pair`` with ``h = bin_entropy``,
    and negative candidates are replaced by 0.

    Args:
        pauli_vector: ``[p_I, p_X, p_Y, p_Z]`` of the state.
        num_pair: divisor applied to the rate (this notebook passes the number
            of input copies, 2, 3 or 4).
        p_suc: success probability multiplying the rate.

    Returns:
        The maximum candidate rate (never negative).
    """
    def rate_for(ordering):
        e_bit = ordering[1] + ordering[2]
        e_phase = ordering[2] + ordering[3]
        raw = p_suc * (1 - bin_entropy(e_bit) - bin_entropy(e_phase)) / num_pair
        return 0 if raw < 0 else raw

    return max(rate_for(ordering) for ordering in permute_paulies(pauli_vector))

def bin_entropy(p):
    """Binary Shannon entropy h(p) = -p log2 p - (1 - p) log2(1 - p), in bits.

    Returns 0 at the endpoints p == 0 and p == 1, which are the limits of the
    formula, so ``np.log2`` is never evaluated at 0 there.
    """
    if p != 0 and p != 1:
        q = 1 - p
        return -p * np.log2(p) - q * np.log2(q)
    return 0

In [3]:
# Pillars of the n-copy protocols (n = 2, 3, 4): AD_n_pillars[k] lists the
# n-letter Pauli strings whose weights `distillation_statistics` adds into
# output component k = I, X, Y, Z.  The lists follow a parity pattern:
#   I pillar: strings over {I, Z} with an even number of Z
#   X pillar: strings over {X, Y} with an even number of Y
#   Y pillar: strings over {X, Y} with an odd number of Y
#   Z pillar: strings over {I, Z} with an odd number of Z
AD_2_pillars = [
    ['II', 'ZZ'],
    ['XX', 'YY'],
    ['YX', 'XY'],
    ['ZI', 'IZ'],
]

AD_3_pillars = [
    ['III', 'ZZI', 'ZIZ', 'IZZ'],
    ['XXX', 'YYX', 'YXY', 'XYY'],
    ['YXX', 'XYX', 'XXY', 'YYY'],
    ['ZII', 'IZI', 'IIZ', 'ZZZ'],
]

AD_4_pillars = [
    ['IIII', 'ZZII', 'ZIZI', 'ZIIZ', 'IZZI', 'IZIZ', 'IIZZ', 'ZZZZ'],
    ['XXXX', 'YYXX', 'YXYX', 'YXXY', 'XYYX', 'XYXY', 'XXYY', 'YYYY'],
    ['YXXX', 'XYXX', 'XXYX', 'XXXY', 'YYYX', 'YYXY', 'YXYY', 'XYYY'],
    ['ZIII', 'IZII', 'IIZI', 'IIIZ', 'ZZZI', 'ZZIZ', 'ZIZZ', 'IZZZ'],
]

In [4]:
# Symbolic input fidelity F.  The single-copy state has weight F on I and the
# remaining 1 - F split evenly over the three Pauli errors X, Y and Z.
F = sympy.symbols('F')

p_I = F
p_X = p_Y = p_Z = (1 - F) / 3

# Pauli-weight dictionaries for two, three and four copies of that state.
werner_state_2, werner_state_3, werner_state_4 = (
    generate(p_I, p_X, p_Y, p_Z)
    for generate in (generate_pauli_dict_2, generate_pauli_dict_3, generate_pauli_dict_4)
)

In [5]:
# Output Pauli vector (as sympy expressions in F) and success probability of
# each n-copy protocol.  One call per protocol: every call performs the full
# symbolic pillar sums, so it must not be repeated for each component.
(AD_2_I, AD_2_X, AD_2_Y, AD_2_Z), AD_2_p_suc = distillation_statistics(werner_state_2, AD_2_pillars)
(AD_3_I, AD_3_X, AD_3_Y, AD_3_Z), AD_3_p_suc = distillation_statistics(werner_state_3, AD_3_pillars)
(AD_4_I, AD_4_X, AD_4_Y, AD_4_Z), AD_4_p_suc = distillation_statistics(werner_state_4, AD_4_pillars)

In [6]:
# Sample every symbolic quantity on a grid of 500 input fidelities F in [0.5, 1].
fid_list = np.linspace(0.5, 1, 500)


def _on_fid_grid(expr):
    """Values of the sympy expression ``expr`` (in the symbol F) at each point of fid_list."""
    return [float(expr.subs({F: f})) for f in fid_list]


def _bb84_rates_on_grid(i_vals, x_vals, y_vals, z_vals, p_suc_vals, num_pair):
    """``get_bb84_rate`` evaluated point by point along the fidelity grid."""
    return [
        get_bb84_rate([i, x, y, z], num_pair, s)
        for i, x, y, z, s in zip(i_vals, x_vals, y_vals, z_vals, p_suc_vals)
    ]


# Input state, no distillation.
ID_I_comp, ID_X_comp, ID_Y_comp, ID_Z_comp = [_on_fid_grid(e) for e in (p_I, p_X, p_Y, p_Z)]

# Output Pauli components and success probability of each n-copy protocol.
AD_2_I_comp, AD_2_X_comp, AD_2_Y_comp, AD_2_Z_comp, AD_2_p_suc_num = [
    _on_fid_grid(e) for e in (AD_2_I, AD_2_X, AD_2_Y, AD_2_Z, AD_2_p_suc)
]
AD_3_I_comp, AD_3_X_comp, AD_3_Y_comp, AD_3_Z_comp, AD_3_p_suc_num = [
    _on_fid_grid(e) for e in (AD_3_I, AD_3_X, AD_3_Y, AD_3_Z, AD_3_p_suc)
]
AD_4_I_comp, AD_4_X_comp, AD_4_Y_comp, AD_4_Z_comp, AD_4_p_suc_num = [
    _on_fid_grid(e) for e in (AD_4_I, AD_4_X, AD_4_Y, AD_4_Z, AD_4_p_suc)
]

# BB84 key rates; the last argument (number of input copies) divides the rate.
AD_2_key_rate = _bb84_rates_on_grid(AD_2_I_comp, AD_2_X_comp, AD_2_Y_comp, AD_2_Z_comp, AD_2_p_suc_num, 2)
AD_3_key_rate = _bb84_rates_on_grid(AD_3_I_comp, AD_3_X_comp, AD_3_Y_comp, AD_3_Z_comp, AD_3_p_suc_num, 3)
AD_4_key_rate = _bb84_rates_on_grid(AD_4_I_comp, AD_4_X_comp, AD_4_Y_comp, AD_4_Z_comp, AD_4_p_suc_num, 4)

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
cmap = plt.get_cmap("tab10")
ax.set_title('2-1 Repetition code')
ax.set_xlabel("Input fidelity")
ax.set_ylabel("Components")

# Dashed: input state (its X, Y and Z weights are equal, so one curve covers all three).
ax.plot(fid_list, ID_I_comp, label='I (no distillation)', linestyle='--', color=cmap(0))
ax.plot(fid_list, ID_X_comp, label='X, Y, Z (no distillation)', linestyle='--', color=cmap(4))
# Solid: output Pauli components after the 2-copy protocol.
ax.plot(fid_list, AD_2_I_comp, label='I', color=cmap(0))
ax.plot(fid_list, AD_2_X_comp, label='X', color=cmap(1))
ax.plot(fid_list, AD_2_Y_comp, label='Y', color=cmap(2))
ax.plot(fid_list, AD_2_Z_comp, label='Z', color=cmap(3))
ax.invert_xaxis()
ax.legend(loc='best');

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
cmap = plt.get_cmap("tab10")
ax.set_title('3-1 Repetition code')
ax.set_xlabel("Input fidelity")
ax.set_ylabel("Components")

# Dashed: input state (its X, Y and Z weights are equal, so one curve covers all three).
ax.plot(fid_list, ID_I_comp, label='I (no distillation)', linestyle='--', color=cmap(0))
ax.plot(fid_list, ID_X_comp, label='X, Y, Z (no distillation)', linestyle='--', color=cmap(4))
# Solid: output Pauli components after the 3-copy protocol.
ax.plot(fid_list, AD_3_I_comp, label='I', color=cmap(0))
ax.plot(fid_list, AD_3_X_comp, label='X', color=cmap(1))
ax.plot(fid_list, AD_3_Y_comp, label='Y', color=cmap(2))
ax.plot(fid_list, AD_3_Z_comp, label='Z', color=cmap(3))
ax.invert_xaxis()
ax.legend(loc='best');

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
cmap = plt.get_cmap("tab10")
ax.set_title('4-1 Repetition code')
ax.set_xlabel("Input fidelity")
ax.set_ylabel("Components")

# Dashed: input state (its X, Y and Z weights are equal, so one curve covers all three).
ax.plot(fid_list, ID_I_comp, label='I (no distillation)', linestyle='--', color=cmap(0))
ax.plot(fid_list, ID_X_comp, label='X, Y, Z (no distillation)', linestyle='--', color=cmap(4))
# Solid: output Pauli components after the 4-copy protocol.
ax.plot(fid_list, AD_4_I_comp, label='I', color=cmap(0))
ax.plot(fid_list, AD_4_X_comp, label='X', color=cmap(1))
ax.plot(fid_list, AD_4_Y_comp, label='Y', color=cmap(2))
ax.plot(fid_list, AD_4_Z_comp, label='Z', color=cmap(3))
ax.invert_xaxis()
ax.legend(loc='best');

In [None]:
# LaTeX source of the (unsimplified) Z component of the 4-copy protocol output, as a function of F.
print(sympy.latex(AD_4_Z))

In [None]:
# Same Z component after sympy.simplify, printed as LaTeX source.
print(sympy.latex(sympy.simplify(AD_4_Z)))

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
cmap = plt.get_cmap("tab10")
ax.set_title('BB84 key rates')
ax.set_xlabel("Input fidelity")
ax.set_ylabel("Rates")

# One curve per n-copy protocol; rates are already divided by n (see get_bb84_rate).
ax.plot(fid_list, AD_2_key_rate, label='2-1 Repetition code', color=cmap(0))
ax.plot(fid_list, AD_3_key_rate, label='3-1 Repetition code', color=cmap(1))
ax.plot(fid_list, AD_4_key_rate, label='4-1 Repetition code', color=cmap(2))

ax.invert_xaxis()
ax.legend(loc='best');