In [1]:
import pennylane as qml
from pennylane import numpy as np
import random
import math

In [147]:
dev = qml.device("default.qubit", wires = 5) #dim(B')=1, dim(B)=1, dim(A)=3

initstate = []
for i in range(16):
    initstate.append('{0:04b}'.format(i)) #binary strings of all 2^4 possible states
    
#########################################################
    
def unitary(angle, pauli): #emulating Figure 3B from the Quantum AutoEncoder paper
    #one rotation for each qubit
    for i in range(1,5):
        if pauli[i-1]==0:
            qml.RX(angle[i-1], wires=i)
        if pauli[i-1]==1:
            qml.RY(angle[i-1], wires=i)
        if pauli[i-1]==2:
            qml.RZ(angle[i-1], wires=i)
    
    index = 4
    #controlled pauli rotations between every permutation of qubits
    for i in range(1,5):
        for j in range(1,5):
            if i!=j:
                if pauli[index]==0:
                    qml.CRX(angle[index], wires=[i,j])
                if pauli[index]==1:
                    qml.CRY(angle[index], wires=[i,j])
                if pauli[index]==2:
                    qml.CRZ(angle[index], wires=[i,j])
                index += 1
                
    #one rotation for each qubit
    for i in range(1,5):
        if pauli[len(pauli)-5+i]==0:
            qml.RX(angle[len(pauli)-5+i], wires=i)
        if pauli[len(pauli)-5+i]==1:
            qml.RY(angle[len(pauli)-5+i], wires=i)
        if pauli[len(pauli)-5+i]==2:
            qml.RZ(angle[len(pauli)-5+i], wires=i)

#########################################################
            
def invert(angle, pauli):
    #invert the last 4 rotations in reverse order
    for i in range(4,0, -1):
        if pauli[len(pauli)-5+i]==0:
            qml.RX(angle[len(pauli)-5+i], wires=i)
        if pauli[len(pauli)-5+i]==1:
            qml.RY(angle[len(pauli)-5+i], wires=i)
        if pauli[len(pauli)-5+i]==2:
            qml.RZ(angle[len(pauli)-5+i], wires=i)
            
    index = 15
    #invert controlled pauli rotations in reverse order
    for i in range(4,0, -1):
        for j in range(4,0, -1):
            if i!=j:
                if pauli[index]==0:
                    qml.CRX(angle[index], wires=[i,j])
                if pauli[index]==1:
                    qml.CRY(angle[index], wires=[i,j])
                if pauli[index]==2:
                    qml.CRZ(angle[index], wires=[i,j])
                index -= 1 
                
    #invert first 4 rotations            
    for i in range(4,0,-1):
        if pauli[i-1]==0:
            qml.RX(angle[i-1], wires=i)
        if pauli[i-1]==1:
            qml.RY(angle[i-1], wires=i)
        if pauli[i-1]==2:
            qml.RZ(angle[i-1], wires=i)
            
#########################################################
            
angles = []
paulis = []
for i in range(20): #4+(4*3)+4
    paulis.append(random.randint(0,2))
    angles.append(random.uniform(0,2*math.pi))
invangles = []
for i in range(len(angles)):
    invangles.append(angles[i]*-1)

@qml.qnode(dev)
def total():
    unitary(angles, paulis)
    qml.SWAP(wires=[0,1]) #swap B and B': trash state is now wire 0
    invert(invangles, paulis)
    return qml.expval(qml.PauliZ(0)) #inner product |0>, which was the reference state

def qae():  
    angles = []
    paulis = []
    netfid = 0
    for i in range(20): #4+(4*3)+4
        paulis.append(random.randint(0,2))
        angles.append(random.uniform(0,2*math.pi))
    invangles = []
    for i in range(len(angles)):
        invangles.append(angles[i]*-1)    
        
    for m in range(len(initstate)):
        for n in range(1,5):
            if initstate[m][n-1]=='1':
                qml.PauliX(wires=n) #initialize each possible 2x2 binary grid as a 4-qubit state for wires 1-4
        
        sq_inner = (total())**2
        netfid += sq_inner
    return(netfid/len(initstate)) #avg of all fidelities (equal probs for states): cost function defined in paper

print(qae(), '\n')
for i in range(20):
    print(paulis[i], angles[i])

0.8615083810132819 

1 1.0441858983373202
2 1.27367955413923
0 5.750321782802544
2 4.882060098177182
2 4.682368937417197
2 5.205237834793725
0 2.4031500140094115
1 0.18467309569945117
0 5.379752648900845
0 5.167578133604119
2 0.48887152280862617
0 2.5074757660492186
2 1.2741396613497558
1 2.884294165605277
2 2.9939253509482024
2 2.258891849330684
1 6.280243144257122
2 5.584422606401422
0 3.8703226767207455
2 3.2390645927078796
