In [None]:
%load_ext jupyternotify
import pennylane as qml
from pennylane import numpy as np
import tensorflow as tf
from sympy.physics.quantum.dagger import Dagger


In [None]:
# Single-qubit Pauli matrices and the 2x2 identity, stored as complex arrays.
# They form the operator basis used to build the discriminator's matrices.
px = np.array([[0 + 0j, 1 + 0j], [1 + 0j, 0 + 0j]])
py = np.array([[0 + 0j, 0 - 1j], [0 + 1j, 0 + 0j]])
pz = np.array([[1 + 0j, 0 + 0j], [0 + 0j, -1 + 0j]])
iden = np.array([[1 + 0j, 0 + 0j], [0 + 0j, 1 + 0j]])

n = 1

# Regularisation strength. The builtin ``float`` is used because the
# ``np.float`` alias was deprecated in NumPy 1.20 and removed in NumPy 1.24
# (it was only ever an alias for the builtin, so the value is unchanged).
lamb = float(2)

# Scalar prefactors of the regularisation term, all derived from ``lamb``.
s = np.exp(-1 / (2 * lamb)) - 1
cst1 = (s / 2 + 1) ** 2
cst2 = (s / 2) * (s / 2 + 1)
cst3 = (s / 2) ** 2

In [None]:
def fidelity(rsv, gsv):
    """Return the squared magnitude of the overlap between two state vectors.

    The overlap is the sum of ``conj(rsv) * gsv`` over the leading axis, and
    the returned value is that overlap multiplied by its own complex
    conjugate. For complex inputs the result is therefore a complex number
    whose imaginary part is zero.

    Args:
        rsv: amplitudes of the first ("real") state.
        gsv: amplitudes of the second ("generated") state, same length as
            ``rsv``.

    Returns:
        ``overlap * conj(overlap)`` for ``overlap = sum(conj(rsv) * gsv)``.
    """
    overlap = sum(np.conj(rsv) * gsv)
    return overlap * np.conj(overlap)

# Two separate single-qubit simulators: one runs the target ("real") circuit,
# the other runs the generator circuit, so each device's ``.state`` can be
# read independently after its circuit has been executed.
dev1 = qml.device("default.qubit", wires=1)  # real
dev2 = qml.device("default.qubit", wires=1)  # generator

# Observables measured on wire 0, in the fixed order I, X, Y, Z. The
# discriminator indexes expectation values by this order.
obs_list = [
    observable(0)
    for observable in (qml.Identity, qml.PauliX, qml.PauliY, qml.PauliZ)
]

In [None]:
def real_circuit(real_weights, wires, **kwargs):
    """Circuit template that prepares the target ("real") single-qubit state.

    Applies RX, RY and RZ rotations, in that order, on wire 0 using the first
    three entries of ``real_weights`` as angles.

    Args:
        real_weights: indexable holding at least three rotation angles.
        wires: accepted for compatibility with the template call signature;
            not used, the rotations always act on wire 0.
        **kwargs: accepted and ignored.
    """
    rotation_gates = (qml.RX, qml.RY, qml.RZ)
    for position, gate in enumerate(rotation_gates):
        gate(real_weights[position], wires=0)

# Map the target circuit over every observable in ``obs_list`` on ``dev1``;
# calling ``qnodes_real(weights)`` is used below to obtain one expectation
# value per observable.
# NOTE(review): ``qml.map`` is not available in recent PennyLane releases;
# confirm the PennyLane version this notebook is pinned to.
qnodes_real = qml.map(
    real_circuit,
    obs_list,
    dev1,
    measure="expval",
    interface="tf",
)


In [None]:
def gen_circuit(gen_weights, wires, **kwargs):
    """Circuit template for the generator's single-qubit state.

    Applies RX, RY and RZ rotations, in that order, on wire 0 using the first
    three entries of ``gen_weights`` as angles.

    Args:
        gen_weights: indexable holding at least three rotation angles.
        wires: accepted for compatibility with the template call signature;
            not used, the rotations always act on wire 0.
        **kwargs: accepted and ignored.
    """
    rotation_gates = (qml.RX, qml.RY, qml.RZ)
    for position, gate in enumerate(rotation_gates):
        gate(gen_weights[position], wires=0)

# Map the generator circuit over every observable in ``obs_list`` on ``dev2``;
# calling ``qnodes_gen(weights)`` is used below to obtain one expectation
# value per observable.
# NOTE(review): ``qml.map`` is not available in recent PennyLane releases;
# confirm the PennyLane version this notebook is pinned to.
qnodes_gen = qml.map(
    gen_circuit,
    obs_list,
    dev2,
    measure="expval",
    interface="tf",
)


In [None]:
def discriminator(real_exp, gen_exp, disc_weights):
    """Evaluate the discriminator's two expectation values and its regulariser.

    ``disc_weights`` holds eight coefficients in the operator basis
    (I, X, Y, Z): entries 0-3 define the operator ``phi`` and entries 4-7
    define the operator ``psi``.

    Besides its arguments, this function reads the module-level names
    ``iden``, ``px``, ``py``, ``pz``, ``lamb``, ``cst1``, ``cst2``, ``cst3``,
    ``gen_sv`` and ``real_sv``; they must be defined (and the two state
    vectors up to date) before it is called.

    Args:
        real_exp: indexable whose entries 0-3 are combined with the ``psi``
            coefficients (presumably the I, X, Y, Z expectation values of the
            real state, as returned by ``qnodes_real``; confirm with caller).
        gen_exp: indexable whose entries 0-3 are combined with the ``phi``
            coefficients (presumably the same expectation values for the
            generator state; confirm with caller).
        disc_weights: indexable with at least eight coefficients.

    Returns:
        Tuple ``(psi_exp, phi_exp, regterm)`` where ``psi_exp`` and
        ``phi_exp`` are the weighted sums of the supplied expectation values
        and ``regterm`` is the regularisation value as a plain Python number.
    """
    w = disc_weights

    # Operators phi and psi expanded in the (I, X, Y, Z) basis.
    phi_matrix = w[0] * iden + w[1] * px + w[2] * py + w[3] * pz
    psi_matrix = w[4] * iden + w[5] * px + w[6] * py + w[7] * pz

    # The same linear combinations applied to the measured expectation values.
    phi_exp = w[0] * gen_exp[0] + w[1] * gen_exp[1] + w[2] * gen_exp[2] + w[3] * gen_exp[3]
    psi_exp = w[4] * real_exp[0] + w[5] * real_exp[1] + w[6] * real_exp[2] + w[7] * real_exp[3]

    # Matrix exponentials, converted to NumPy once each.
    # NOTE(review): after ``.numpy()`` the regulariser is presumably detached
    # from any TensorFlow gradient tape; confirm this is intended.
    A = tf.linalg.expm((-1 / lamb) * phi_matrix).numpy()
    B = tf.linalg.expm((1 / lamb) * psi_matrix).numpy()

    def _bracket(bra, operator, ket):
        # <bra| operator |ket>: conjugate-transpose the bra with NumPy instead
        # of routing numeric arrays through sympy's symbolic ``Dagger``.
        return np.matmul(np.conj(bra).T, np.matmul(operator, ket))

    term1 = _bracket(gen_sv, A, gen_sv)
    term2 = _bracket(real_sv, B, real_sv)
    term3 = _bracket(gen_sv, B, real_sv)
    term4 = _bracket(real_sv, A, gen_sv)
    term5 = _bracket(gen_sv, A, real_sv)
    term6 = _bracket(real_sv, B, gen_sv)
    term7 = _bracket(gen_sv, B, gen_sv)
    term8 = _bracket(real_sv, A, real_sv)

    regterm = (
        lamb
        / np.e
        * (
            cst1 * term1 * term2
            - cst2 * term3 * term4
            - cst2 * term5 * term6
            + cst3 * term7 * term8
        )
    ).item()

    return psi_exp, phi_exp, regterm

In [None]:
def disc_loss(disc_weights):
    """Loss minimised by the discriminator for the given weights.

    Evaluates ``discriminator`` on the module-level ``real_exp`` and
    ``gen_exp`` and returns the negated real part of
    ``psi_exp - phi_exp - regterm``.

    Args:
        disc_weights: discriminator coefficients forwarded to
            ``discriminator``.

    Returns:
        ``-real(psi_exp - phi_exp - regterm)``.
    """
    psi_value, phi_value, penalty = discriminator(real_exp, gen_exp, disc_weights)
    objective = np.real(psi_value - phi_value - penalty)
    return -objective


def gen_loss(gen_weights):
    """Loss minimised by the generator for the given circuit weights.

    Re-runs the generator QNodes with ``gen_weights``, refreshes the
    module-level generator state vector, and returns the real part of
    ``psi_exp - phi_exp - regterm`` from ``discriminator``.

    Reads the module-level ``real_exp`` and ``disc_weights`` and writes the
    module-level ``gen_sv``.

    Args:
        gen_weights: rotation angles forwarded to ``qnodes_gen``.

    Returns:
        ``real(psi_exp - phi_exp - regterm)``.
    """
    # ``discriminator`` reads the module-level ``gen_sv``. Without this
    # declaration the assignment below only bound an unused local, so the
    # regulariser was computed from whatever state the global still held.
    global gen_sv

    gen_exp = qnodes_gen(gen_weights)
    gen_sv = dev2.state

    psi_exp, phi_exp, regterm = discriminator(real_exp, gen_exp, disc_weights)
    loss = np.real(psi_exp - phi_exp - regterm)

    return loss


In [None]:

# Target ("real") state: draw three rotation angles from [0, 2), evaluate the
# expectation values, then read the state left on dev1.
real_weights = np.random.uniform(low=0, high=2, size=3)
real_exp = qnodes_real(real_weights)
real_sv = dev1.state

# Generator: same procedure with its own angles on dev2.
gen_weights = np.random.uniform(low=0, high=2, size=3)
gen_exp = qnodes_gen(gen_weights)
gen_sv = dev2.state

# With low == high the draw is degenerate, so all eight discriminator
# coefficients start at 1.
# NOTE(review): the weights are later handed to tf.keras optimizers, which
# presumably expect tf.Variable objects; confirm before training.
disc_weights = np.random.uniform(low=1, high=1, size=8)

# Smoke check: show (psi_exp, phi_exp, regterm) for the initial parameters.
discriminator(real_exp, gen_exp, disc_weights)

In [None]:
# Independent Adam optimizers for the two players, both with step size 0.1.
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)

In [None]:
%%time
%%notify

# Fidelity between the target and generator states before any training.
fid = fidelity(real_sv, gen_sv)

# Histories: fidelity after every iteration (starting with the initial value)
# and the discriminator / generator losses recorded after each update.
f = [fid]
dloss = []
gloss = []

niter = 0

# Alternate one discriminator step and one generator step until the fidelity
# reaches 0.99 or 1000 iterations have been performed.
while fid < 0.99 and niter < 1000:

    # Discriminator update, then log its loss at the updated weights.
    costd = lambda: disc_loss(disc_weights)
    discriminator_optimizer.minimize(costd, disc_weights)
    dloss.append(disc_loss(disc_weights))

    # Generator update, then log its loss at the updated weights.
    costg = lambda: gen_loss(gen_weights)
    generator_optimizer.minimize(costg, gen_weights)
    gloss.append(gen_loss(gen_weights))

    # Refresh the generator's expectation values and state for the next round.
    gen_exp = qnodes_gen(gen_weights)
    gen_sv = dev2.state

    fid = fidelity(real_sv, gen_sv)
    f.append(fid)

    niter += 1
    
