In [1]:
import numpy as np
import itertools as it
import matplotlib.pyplot as plt


def base(M, n):
    """Return the image under M of the group spanned by columns n+1 .. 2n-1.

    With the column convention used by ``pillars`` (column 0 of M is the
    image of X_1 and column n is the image of Z_1), columns n+1 .. 2n-1 are
    presumably the images of Z_2, ..., Z_n.

    The result starts with the all-zero vector of length 2n.  It is followed
    by the sum of every non-empty combination of those columns, ordered by
    combination size and then lexicographically.  That gives 2**(n-1)
    vectors in total.

    M: binary matrix (presumably 2n x 2n over GF(2); only rows 0 .. 2n-1
       are read).
    n: number of pairs.
    """
    generators = [M[0:2*n, col] for col in range(n + 1, 2 * n)]

    images = [vector(GF(2), 2*n)]  # the identity element: all-zero vector
    for size in range(1, len(generators) + 1):
        for subset in it.combinations(generators, size):
            # A product of Paulis is the sum of their binary vectors,
            # e.g. IZZ = IZI + IIZ.
            images.append(vector(sum(subset)))
    return images


def pillars(M, n):
    """Return the images under M of the four pillars, as [pI, pX, pY, pZ].

    Each pillar is a coset of ``base(M, n)``.  The coset representative is
    the image of the first pair's Pauli:
    - X_1 is column 0 of M;
    - Z_1 is column n of M;
    - Y_1 is their sum.

    M: binary matrix (presumably 2n x 2n over GF(2)).
    n: number of pairs.

    pI is the list returned by ``base(M, n)``.  pX, pY and pZ are new lists
    of the same length, holding ``P_1 + b`` for each b in pI.
    """
    x_image = vector(M[0:2*n, 0])
    z_image = vector(M[0:2*n, n])
    y_image = x_image + z_image  # Y = X + Z in the binary representation

    base_images = base(M, n)
    shifted = [
        [offset + elem for elem in base_images]
        for offset in (x_image, y_image, z_image)
    ]
    return [base_images] + shifted

def generate_pauli_dict_2(p_I, p_X, p_Y, p_Z):
    """Return the Pauli weights of rho ⊗ rho for rho = (p_I, p_X, p_Y, p_Z).

    The keys are the 16 two-letter strings over "IXYZ", in the order
    "II", "IX", ..., "ZZ".  Each value is the product of the two
    single-letter weights, e.g. {"XX": p_X * p_X, "XY": p_X * p_Y, ...}.
    """
    weights = {"I": p_I, "X": p_X, "Y": p_Y, "Z": p_Z}
    table = {}
    for letters in it.product("IXYZ", repeat=2):
        value = 1
        for letter in letters:
            value *= weights[letter]
        table["".join(letters)] = value
    return table


def generate_pauli_dict_3(p_I, p_X, p_Y, p_Z):
    """Return the Pauli weights of rho ⊗ rho ⊗ rho for rho = (p_I, p_X, p_Y, p_Z).

    The keys are the 64 three-letter strings over "IXYZ", in the order
    "III", "IIX", ..., "ZZZ".  Each value is the product of the three
    single-letter weights, e.g.
    {"XXX": p_X * p_X * p_X, "XXY": p_X * p_X * p_Y, ...}.
    """
    # Lookup table instead of a chain of per-letter if-statements.
    weights = {"I": p_I, "X": p_X, "Y": p_Y, "Z": p_Z}
    result = {}
    # it.product yields the same order as three nested loops over "IXYZ".
    for letters in it.product("IXYZ", repeat=3):
        value = 1
        for letter in letters:
            value *= weights[letter]
        result["".join(letters)] = value
    return result

def generate_pauli_dict_2_non_identical(p_I_1, p_X_1, p_Y_1, p_Z_1, p_I_2, p_X_2, p_Y_2, p_Z_2):
    """Return the Pauli weights of rho_1 ⊗ rho_2.

    The two inputs are rho_1 = (p_I_1, p_X_1, p_Y_1, p_Z_1) and
    rho_2 = (p_I_2, p_X_2, p_Y_2, p_Z_2).

    The keys are the 16 two-letter strings over "IXYZ", in the order
    "II", "IX", ..., "ZZ".  The first letter is weighted by rho_1 and the
    second by rho_2, e.g. {"XX": p_X_1 * p_X_2, "XY": p_X_1 * p_Y_2, ...}.
    """
    # Lookup tables replace the per-letter if-chains.
    first = {"I": p_I_1, "X": p_X_1, "Y": p_Y_1, "Z": p_Z_1}
    second = {"I": p_I_2, "X": p_X_2, "Y": p_Y_2, "Z": p_Z_2}
    result = {}
    for a, b in it.product("IXYZ", repeat=2):
        value = 1
        value *= first[a]
        value *= second[b]
        result[a + b] = value
    return result

# Bit pair "xz" (X bit first, Z bit second) -> single-qubit Pauli letter.
_BIN_TO_PAULI = {"00": "I", "10": "X", "01": "Z", "11": "Y"}


def bin_to_pauli(input):
    """Convert a two-character bit string to its Pauli letter.

    The first character is the X bit and the second the Z bit:
    "00" -> "I", "10" -> "X", "01" -> "Z", "11" -> "Y".

    Raises:
        ValueError: if ``input`` is not one of the four bit pairs above.
            (Previously ``None`` was returned, which only surfaced later
            as an obscure TypeError when the result was concatenated.)
    """
    # NOTE: the parameter name shadows the builtin ``input``; it is kept
    # for backward compatibility with keyword callers.
    try:
        return _BIN_TO_PAULI[input]
    except (KeyError, TypeError):  # TypeError: unhashable input
        raise ValueError(f"not a valid bit pair for a Pauli: {input!r}") from None
    
def translate_pauli_2(pillars):
    """Render the pillars of a two-pair symplectic operator as Pauli strings.

    ``pillars`` is expected in the layout returned by ``pillars(M, 2)``:
    four pillars (I, X, Y, Z), each holding at least two binary vectors of
    length 4.  Only the first two vectors of each pillar are read.

    Each vector becomes one two-letter string.  For pair k, the bits at
    indices (k, k + 2) go through ``bin_to_pauli``, presumably the
    (x_k, z_k) bits of that pair.

    Per the original author's note, the DEJMPS matrix yields
    [["II", "YY"], ["XX", "ZZ"], ["ZX", "XZ"], ["YI", "IY"]].
    """
    n_pairs = 2
    elems_per_pillar = 2
    rendered = []
    for pillar_idx in range(4):
        pillar = pillars[pillar_idx]
        labels = []
        for elem_idx in range(elems_per_pillar):
            vec = pillar[elem_idx]
            labels.append("".join(
                bin_to_pauli(str(vec[k]) + str(vec[k + n_pairs]))
                for k in range(n_pairs)
            ))
        rendered.append(labels)
    return rendered

def translate_pauli_3(pillars):
    """Render the pillars of a three-pair symplectic operator as Pauli strings.

    ``pillars`` is expected in the layout returned by ``pillars(M, 3)``:
    four pillars (I, X, Y, Z), each holding at least four binary vectors of
    length 6.  Only the first four vectors of each pillar are read.

    Each vector becomes one three-letter string.  For pair k, the bits at
    indices (k, k + 3) go through ``bin_to_pauli``, presumably the
    (x_k, z_k) bits of that pair.
    """
    n_pairs = 3
    elems_per_pillar = 4
    rendered = []
    for pillar_idx in range(4):
        pillar = pillars[pillar_idx]
        labels = []
        for elem_idx in range(elems_per_pillar):
            vec = pillar[elem_idx]
            labels.append("".join(
                bin_to_pauli(str(vec[k]) + str(vec[k + n_pairs]))
                for k in range(n_pairs)
            ))
        rendered.append(labels)
    return rendered


def sum_up_dict(pauli_dict, lst_pauli_str):
    """Return the total weight of the listed Pauli strings.

    pauli_dict: mapping such as {"XX": ..., "XY": ..., ...}.
    lst_pauli_str: Pauli strings to look up, e.g. ["XX", "YY"].  Duplicates
        are counted once per occurrence; an empty list gives 0.

    Raises KeyError if a string is missing from ``pauli_dict``.
    """
    return sum(pauli_dict[label] for label in lst_pauli_str)

def get_components_psuc_2(input_vector, M, N):
    """Apply distillation circuit M to two identical copies of a Pauli-diagonal state.

    input_vector: Pauli vector (p_I, p_X, p_Y, p_Z) of each input pair.
    M: symplectic (binary) matrix of the circuit.
    N: number of pairs (2 here).

    Returns ([p_I, p_X, p_Y, p_Z], p_suc):
    - p_suc is the total weight of every Pauli string that appears in any
      of the four pillars;
    - each output component is its own pillar's weight divided by p_suc.

    If p_suc is 0 the division follows the operand type (ZeroDivisionError
    for Python floats, inf/nan for NumPy floats).
    """
    pauli_dict = generate_pauli_dict_2(input_vector[0], input_vector[1], input_vector[2], input_vector[3])
    pil_str = translate_pauli_2(pillars(M, N))

    # The normaliser is the same for all four components: compute it once.
    p_suc = sum_up_dict(pauli_dict, list(it.chain.from_iterable(pil_str)))

    # Pauli vector of the state after distillation, in pillar order I, X, Y, Z.
    components = [sum_up_dict(pauli_dict, pillar) / p_suc for pillar in pil_str]
    return components, p_suc

def get_components_psuc_2_non_identical(input_vector_1, input_vector_2, M, N):
    """Apply distillation circuit M to two possibly different Pauli-diagonal pairs.

    input_vector_1, input_vector_2: Pauli vectors (p_I, p_X, p_Y, p_Z) of
        the first and second input pair.
    M: symplectic (binary) matrix of the circuit.
    N: number of pairs (2 here).

    Returns ([p_I, p_X, p_Y, p_Z], p_suc):
    - p_suc is the total weight of every Pauli string that appears in any
      of the four pillars;
    - each output component is its own pillar's weight divided by p_suc.

    If p_suc is 0 the division follows the operand type (ZeroDivisionError
    for Python floats, inf/nan for NumPy floats).
    """
    pauli_dict = generate_pauli_dict_2_non_identical(
        input_vector_1[0], input_vector_1[1], input_vector_1[2], input_vector_1[3],
        input_vector_2[0], input_vector_2[1], input_vector_2[2], input_vector_2[3],
    )
    pil_str = translate_pauli_2(pillars(M, N))

    # The normaliser is the same for all four components: compute it once.
    p_suc = sum_up_dict(pauli_dict, list(it.chain.from_iterable(pil_str)))

    # Pauli vector of the state after distillation, in pillar order I, X, Y, Z.
    components = [sum_up_dict(pauli_dict, pillar) / p_suc for pillar in pil_str]
    return components, p_suc

def get_components_psuc_3(input_vector, M, N):
    """Apply distillation circuit M to three identical copies of a Pauli-diagonal state.

    input_vector: Pauli vector (p_I, p_X, p_Y, p_Z) of each input pair.
    M: symplectic (binary) matrix of the circuit.
    N: number of pairs (3 here).

    Returns ([p_I, p_X, p_Y, p_Z], p_suc):
    - p_suc is the total weight of every Pauli string that appears in any
      of the four pillars;
    - each output component is its own pillar's weight divided by p_suc.

    If p_suc is 0 the division follows the operand type (ZeroDivisionError
    for Python floats, inf/nan for NumPy floats).
    """
    pauli_dict = generate_pauli_dict_3(input_vector[0], input_vector[1], input_vector[2], input_vector[3])
    pil_str = translate_pauli_3(pillars(M, N))

    # The normaliser is the same for all four components: compute it once.
    p_suc = sum_up_dict(pauli_dict, list(it.chain.from_iterable(pil_str)))

    # Pauli vector of the state after distillation, in pillar order I, X, Y, Z.
    components = [sum_up_dict(pauli_dict, pillar) / p_suc for pillar in pil_str]
    return components, p_suc

def get_werner(fidelity):
    """Return the Werner state of the given fidelity as a Pauli vector.

    The result is [p_I, p_X, p_Y, p_Z].  p_I equals ``fidelity``, and the
    remaining weight 1 - fidelity is split evenly over X, Y and Z.
    """
    error_weight = (1 - fidelity) / 3
    return [fidelity] + [error_weight] * 3

def permute_paulies(p_vector):
    """Return all 6 relabellings of the X, Y, Z weights of a Pauli vector.

    p_vector: Pauli vector [p_I, p_X, p_Y, p_Z]; only indices 0..3 are read.

    p_I is kept in place.  The list of (p_X, p_Y, p_Z) orderings follows
    ``itertools.permutations`` order:
        (1,2,3), (1,3,2), (2,1,3), (2,3,1), (3,1,2), (3,2,1)
    where the numbers are indices into p_vector.

    The previous hand-written version listed (2,3,1) twice and never
    produced (3,2,1).
    """
    p_I = p_vector[0]
    errors = (p_vector[1], p_vector[2], p_vector[3])
    return [[p_I] + list(order) for order in it.permutations(errors)]

def bin_entropy(p):
    """Binary entropy h(p) = -p log2 p - (1 - p) log2(1 - p).

    Returns the integer 0 at p == 0 and p == 1, where the formula's 0*log(0)
    terms are taken as 0.
    """
    if p != 0 and p != 1:
        q = 1 - p
        return -p * np.log2(p) - q * np.log2(q)
    return 0
    
def get_bb84_rate(p_vector, p_ED1, p_ED2, p_AD):
    """Return the BB84 secret-key rate for a distilled Pauli-diagonal state.

    p_vector: Pauli vector [p_I, p_X, p_Y, p_Z] of the final state.
    p_ED1, p_ED2, p_AD: success probabilities of the three distillation
        stages.

    The asymptotic rate 1 - h(e_bit) - h(e_phase) is divided by a
    resource prefactor built from the three success probabilities:
    ((2 / p_ED1 + 1) / p_ED2) * 3 / p_AD.
    - e_bit = p_X + p_Y
    - e_phase = p_Y + p_Z

    Negative rates are clipped to 0.
    """
    p_x, p_y, p_z = p_vector[1], p_vector[2], p_vector[3]
    bitflip_error = p_x + p_y
    phaseflip_error = p_y + p_z

    prefactor = ((2 / p_ED1 + 1) / p_ED2) * 3 / p_AD

    rate = (1 - bin_entropy(bitflip_error) - bin_entropy(phaseflip_error)) / prefactor
    return 0 if rate < 0 else rate

In [2]:
input_fid_list = np.linspace(0.5, 1, 500)

# Number of local Pauli relabellings tried after each stage; matches the
# number of orderings returned by permute_paulies.
num_local_pauli = 6

output_fids = np.zeros(len(input_fid_list))
output_keyrates = np.zeros(len(input_fid_list))

# The circuit matrices do not depend on the fidelity or on the permutation
# indices.  Load and invert them once, instead of 36 times per fidelity in
# the innermost loop.
M_ED = load('../ED_transversal/2_pair/ED2_11.sobj').inverse()
M_CNOT = load('../CNOTs/3_pair/repetition_3.sobj').inverse()

# For each fidelity
for idx_fidelity, input_fid in enumerate(input_fid_list):
    werner_state = get_werner(input_fid)

    # DEJMPS - stage 1 and stage 2 depend only on the input fidelity.
    state_DEJMPS_1, p_ED1 = get_components_psuc_2(werner_state, M_ED, 2)
    state_DEJMPS_2, p_ED2 = get_components_psuc_2_non_identical(state_DEJMPS_1, werner_state, M_ED, 2)
    permuted_DEJMPS_2 = permute_paulies(state_DEJMPS_2)

    temp_rate = []
    temp_fids = []
    for idx_local_pauli in range(num_local_pauli):
        state_permute = permuted_DEJMPS_2[idx_local_pauli]

        # AD stage: feed the permuted state.  Previously state_DEJMPS_2 was
        # passed here, so the idx_local_pauli loop had no effect.
        state_cnot, p_AD = get_components_psuc_3(state_permute, M_CNOT, 3)
        permuted_cnot = permute_paulies(state_cnot)

        for idx_local_pauli_2 in range(num_local_pauli):
            state_final_permute = permuted_cnot[idx_local_pauli_2]
            temp_fids.append(state_final_permute[0])
            temp_rate.append(get_bb84_rate(state_final_permute, p_ED1, p_ED2, p_AD))

    # Fidelity and key rate are maximised independently over permutations.
    output_fids[idx_fidelity] = max(temp_fids)
    output_keyrates[idx_fidelity] = max(temp_rate)

np.save("../data/werner/fids_DEJMPS3_3", output_fids)
np.save("../data/werner/bb84_key_rate_DEJMPS3_3", output_keyrates)
