In [None]:
from qiskit import QuantumCircuit, ClassicalRegister, QuantumRegister
from qiskit import Aer, execute
from qiskit.circuit.library import MCXGate
import matplotlib.pyplot as plt
import random
import math

# Requires the quantum fan-out and the multi-Toffoli technique introduced earlier.
def fanout_left(qc,target):
    """Emit the first (compute) half of a log-depth CNOT fan-out tree.

    Neighbouring entries of ``target`` are paired and
    ``qc.cx(target[2*j], target[2*j+1])`` is applied to each pair. Then only the
    even-position entries are kept, and the process repeats until a single entry
    (``target[0]``) is left. ``select`` places a gate on ``target[0]`` between
    this function and ``fanout_right`` (called with the same list), which replays
    these CNOTs in reverse.

    Args:
        qc: circuit-like object exposing ``cx(control, target)``.
        target: sequence of qubit indices. Empty and single-element sequences
            emit no gates (the original crashed on ``math.log2(0)`` for empty input).
    """
    level = list(target)
    while len(level) > 1:
        # An odd trailing element has no partner and simply moves up a level.
        for upper, lower in zip(level[0::2], level[1::2]):
            qc.cx(upper, lower)
        level = level[::2]
        
def fanout_right(qc,target):
    """Emit the second (uncompute) half of a log-depth CNOT fan-out tree.

    This recomputes the same ``(target[2*j], target[2*j+1])`` pairs that
    ``fanout_left`` produces for ``target``, level by level, and applies them as
    CNOTs in reverse order.

    Args:
        qc: circuit-like object exposing ``cx(control, target)``.
        target: sequence of qubit indices. Empty and single-element sequences
            emit no gates (the original crashed on ``math.log2(0)`` for empty input).
    """
    pairs = []
    level = list(target)
    while len(level) > 1:
        pairs.extend(zip(level[0::2], level[1::2]))
        level = level[::2]
    for upper, lower in reversed(pairs):
        qc.cx(upper, lower)
        
# control is the control qubit; target is the list of target qubits
def fanout(qc,control,target):
    """Copy ``control`` into every qubit of ``target`` (each ``t ^= control``).

    The fan-out uses a log-depth CNOT tree. First, neighbouring targets are
    chained with CNOTs, halving the list each level until only ``target[0]``
    remains. Then ``control`` is CNOT-ed onto ``target[0]``. Finally the tree's
    CNOTs are replayed in reverse, which spreads that flip to every target.

    Args:
        qc: circuit-like object exposing ``cx(control, target)``.
        control: the control qubit index.
        target: sequence of target qubit indices. An empty sequence is a no-op
            (the original crashed on ``math.log2(0)``).
    """
    pairs = []
    level = list(target)
    while len(level) > 1:
        for upper, lower in zip(level[0::2], level[1::2]):
            qc.cx(upper, lower)
            pairs.append((upper, lower))
        level = level[::2]
    if not level:
        return  # no targets: nothing to copy the control into
    qc.cx(control, level[0])
    for upper, lower in reversed(pairs):
        qc.cx(upper, lower)
        
# inl is the list of input qubits (inl[0] decides the sign of the fixed-point number); outl is the list of output qubits
def multi_toffoli(qc,inl,outl):
    """Apply len(inl)-1 Toffoli gates that all share the first control inl[0].

    For each i in range(len(inl) - 1), the gates emitted below follow the
    standard 6-CNOT / 7-T Clifford+T decomposition of a Toffoli gate with
    controls inl[0] and inl[i+1] and target outl[i]. The pairs are processed in
    lock-step, so every CNOT controlled by inl[0] is emitted as a single
    ``fanout`` call. The n-1 T gates that would each act on inl[0] are merged
    into one phase of (n-1)*pi/4.

    Args:
        qc: the QuantumCircuit being built.
        inl: input qubits. inl[0] is the shared control; inl[i+1] is the second
            control for target outl[i].
        outl: target qubits. ``fanout`` is applied to the whole list, so it is
            expected to have exactly len(inl) - 1 entries.
    """
    n = len(inl)
    # Put every target in the X basis.
    for i in range(n-1):
        qc.h(outl[i])
    for i in range(n-1):
        qc.cx(inl[i+1],outl[i])
    for i in range(n-1):
        qc.tdg(outl[i])
    fanout(qc,inl[0],outl) # first fanout: CNOT inl[0] -> every outl[i]
    for i in range(n-1):
        qc.t(outl[i])
    for i in range(n-1):
        qc.cx(inl[i+1],outl[i])
    for i in range(n-1):
        qc.tdg(outl[i])
    fanout(qc,inl[0],outl) # second fanout: CNOT inl[0] -> every outl[i]
    for i in range(n-1):
        qc.t(outl[i])
        qc.t(inl[i+1])
    # inl[0] needs T applied n-1 times. Since T**8 == I, emit T**((n-1) % 8)
    # using T (pi/4), S (pi/2), Z (pi) and Tdg (-pi/4 == 7*pi/4).
    temp = (n-1) % 8
    if temp == 1:
        qc.t(inl[0])
    elif temp == 2:
        qc.s(inl[0])
    elif temp == 3:
        qc.t(inl[0])
        qc.s(inl[0])
    elif temp == 4:
        qc.z(inl[0])
    elif temp == 5:
        qc.t(inl[0])
        qc.z(inl[0])
    elif temp == 6:
        qc.s(inl[0])
        qc.z(inl[0])
    elif temp == 7:
        qc.tdg(inl[0])
    #qc.barrier()               #optional visual separator; this line can be deleted
    # Bring the targets back to the Z basis.
    for i in range(n-1):
        qc.h(outl[i])
    fanout(qc,inl[0],inl[1:n]) #third fanout: CNOT inl[0] -> every inl[i+1]
    for i in range(n-1):
        qc.tdg(inl[i+1])
    fanout(qc,inl[0],inl[1:n]) #last fanout: undo the third fanout

#select
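# inputn is the number of input bits, outputn is the number of output bits, swapn is the number of qubits used as controls in the swap, and table is the complete table: every entry must be (in, out), sorted lexicographically by in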
def select(qc,inputn,outputn,swapn,table):
    """Write table entries into the output slots, addressed by the select qubits.

    Qubits ``0 .. swapn-1`` hold the swap part of the address. Qubits
    ``swapn .. inputn-1`` hold the select part. For every select value ``i``,
    and every swap value ``j``, the output bits of table entry
    ``j * 2**(inputn-swapn) + i`` are XOR-ed into output slot ``j``
    (qubits ``inputn + outputn*j + k``, k = 0..outputn-1). This happens
    conditioned on the select qubits equalling ``i``.

    Args:
        qc: the QuantumCircuit being built.
        inputn: number of address qubits.
        outputn: number of output bits per table entry.
        swapn: number of address qubits that later drive the swap network.
        table: ``2**inputn`` entries ``(in_bits, out_bits)`` of '0'/'1' strings,
            sorted by ``in_bits`` (so entry ``k`` has address ``k``).

    Raises:
        ValueError: if ``table`` does not have exactly ``2**inputn`` entries.
            Previously this only printed a message and then failed later
            or read the wrong entries.
    """
    if (2**inputn) != len(table):
        raise ValueError(
            f"Table error!!! expected {2**inputn} entries, got {len(table)}")
    n = inputn - swapn
    select_qubits = list(range(swapn, inputn))
    for i in range(2**n):
        # Every output qubit that must flip when the select register equals i.
        clist = []
        for j in range(2**swapn):
            entry_bits = table[j * (2**n) + i][1]
            for k in range(outputn):
                if entry_bits[k] == '1':
                    clist.append(inputn + outputn*j + k)
        if not clist:
            # All entries for this select value are zero: nothing to write
            # (the original crashed here on an empty clist).
            continue
        # table[i] has swap bits 0 and select bits equal to i (table is sorted),
        # so its '0' bits mark the select qubits to negate around the MCX.
        address_bits = table[i][0]
        zero_qubits = [swapn + j for j in range(n) if address_bits[swapn + j] == '0']
        for q in zero_qubits:
            qc.x(q)
        fanout_left(qc, clist)
        if n == 0:
            # No select qubits: the write is unconditional.
            qc.x(clist[0])
        else:
            qc.append(MCXGate(n), select_qubits + [clist[0]])
        fanout_right(qc, clist)
        for q in zero_qubits:
            qc.x(q)
    
#swap
def swap(qc,inputn,outputn,swapn,table):
    """Apply the controlled-swap network that routes the addressed slot to slot 0.

    Level ``level`` (0 .. swapn-1) is controlled by address qubit
    ``swapn - level - 1``. It pairs output slot ``b * 2**(level+1)`` with slot
    ``b * 2**(level+1) + 2**level``, for every ``b``, bit by bit. Each pair
    (lo, hi) is wrapped as cx(hi, lo); multi_toffoli(control, lo -> hi);
    cx(hi, lo). That is the usual CNOT-Toffoli-CNOT controlled-swap pattern,
    assuming ``multi_toffoli`` performs one Toffoli per (lo, hi) pair.

    ``table`` is accepted for signature symmetry with ``select`` but is unused.
    """
    for level in range(swapn):
        stride = 2**level
        control = swapn - level - 1
        lows = []
        highs = []
        for block in range(2**(swapn - level - 1)):
            low_base = inputn + block * (2 * stride) * outputn
            high_base = low_base + stride * outputn
            lows.extend(range(low_base, low_base + outputn))
            highs.extend(range(high_base, high_base + outputn))
        for lo, hi in zip(lows, highs):
            qc.cx(hi, lo)
        multi_toffoli(qc, [control] + lows, highs)
        for lo, hi in zip(lows, highs):
            qc.cx(hi, lo)

#selectswap
def selectswap(inputn,outputn,swapn,table,input_instance):
    """Build the full select-swap lookup circuit for one test address.

    The address register (qubits 0..inputn-1) is prepared in the basis state
    ``table[input_instance][0]``. Then ``select`` and ``swap`` are applied, and
    every qubit is measured into the classical bit with the same index.

    Returns:
        A QuantumCircuit with ``inputn + outputn * 2**swapn`` qubits and the
        same number of classical bits.
    """
    width = inputn + outputn * 2**swapn
    circuit = QuantumCircuit(width, width)
    address = table[input_instance][0]
    for qubit in range(inputn):
        if address[qubit] == '1':
            circuit.x(qubit)
    select(circuit, inputn, outputn, swapn, table)
    swap(circuit, inputn, outputn, swapn, table)
    # To read only the answer slot, measure qubits inputn .. inputn+outputn-1 instead.
    for qubit in range(width):
        circuit.measure(qubit, qubit)
    return circuit


In [None]:
# Demo: build a random table with 2**inputn entries, run select-swap on one
# address, and check that the input bits and output slot 0 match that entry.
inputn = 5
outputn = 5
swapn = 2
shots = 3
random.seed(0)  # fixed RNG seed so the random table (and the whole run) is reproducible
seed = "01"     # alphabet the random output bits are drawn from (not the RNG seed)
table = []
for address in range(2**inputn):
    newin = format(address, f"0{inputn}b")  # zero-padded binary, MSB first
    newout = "".join(random.choice(seed) for _ in range(outputn))
    table.append([newin,newout])

input_instance = 15
qc = selectswap(inputn,outputn,swapn,table,input_instance)
simulator = Aer.get_backend('qasm_simulator')
result = execute(qc, simulator, shots=shots, memory=True).result()
memory = result.get_memory(qc)
total_bits = inputn + outputn * (2**swapn)
expected_in, expected_out = table[input_instance]
check = 1
for i in range(shots):
    # Qiskit prints classical bit 0 last; reverse so that output[j] is bit j.
    output = [memory[i][total_bits - 1 - j] for j in range(total_bits)]
    print("output[",i,"]",output)
    if ("".join(output[:inputn]) != expected_in
            or "".join(output[inputn:inputn + outputn]) != expected_out):
        check = 0
print(table[input_instance])
print(table)
print("check passed" if check else "check FAILED: output slot 0 does not match the table entry")