<a href="https://colab.research.google.com/github/theerfan/Maqenta/blob/main/src/QuGan_pennyLane_improved.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# !pip install tensorflow pennylane pennylane-cirq
# !pip install protobuf==3.13.0

In [None]:
import pennylane as qml
import numpy as np
import tensorflow as tf

In [None]:
# Three-qubit Cirq simulator. Wire 0 holds the data qubit (real or generated),
# wire 1 is used only by the generator, and wire 2 only by the discriminator,
# whose PauliZ expectation on that wire is the classification output.
dev = qml.device("cirq.simulator", wires=3)

In [None]:
def real_data(angles, **kwargs):
  """Prepare the target ("real") single-qubit state on wire 0.

  Applies a Hadamard and then a general rotation ``qml.Rot(*angles)``.

  Args:
    angles: the three rotation angles, unpacked positionally into ``qml.Rot``.
    **kwargs: accepted but unused, so the function can also be handed to
      ``qml.map`` (which may forward extra keyword arguments -- verify against
      the installed PennyLane version).
  """
  target_wire = 0
  qml.Hadamard(wires=target_wire)
  qml.Rot(*angles, wires=target_wire)

In [None]:
def generator(w, **kwargs):
    """Variational generator circuit on wires 0 and 1.

    Gate sequence:
      1. Hadamard on wire 0.
      2. RX, RY, RZ rotations alternating between wire 0 and wire 1
         (weights w[0]..w[5]).
      3. CNOT from wire 0 to wire 1.
      4. RX, RY, RZ on wire 0 (weights w[6]..w[8]).

    Args:
        w: indexable collection of 9 trainable angles.
        **kwargs: accepted but unused, so the function can also be handed to
            ``qml.map``.
    """
    qml.Hadamard(wires=0)

    # Entangling-layer preamble: for each axis, rotate wire 0 and then wire 1.
    # The order of this table is the order the gates are applied.
    first_layer = (
        (qml.RX, 0), (qml.RX, 1),
        (qml.RY, 0), (qml.RY, 1),
        (qml.RZ, 0), (qml.RZ, 1),
    )
    for idx, (rotation, wire) in enumerate(first_layer):
        rotation(w[idx], wires=wire)

    qml.CNOT(wires=[0, 1])

    # Final single-qubit rotation of the data qubit with the last three weights.
    for idx, rotation in enumerate((qml.RX, qml.RY, qml.RZ), start=6):
        rotation(w[idx], wires=0)

def discriminator(w):
    """Variational discriminator circuit on wires 0 and 2.

    Gate sequence:
      1. Hadamard on wire 0.
      2. RX, RY, RZ rotations alternating between wire 0 and wire 2
         (weights w[0]..w[5]).
      3. CNOT from wire 0 to wire 2.
      4. RX, RY, RZ on wire 2 (weights w[6]..w[8]).

    Args:
        w: indexable collection of 9 trainable angles.
    """
    qml.Hadamard(wires=0)

    # Alternate between the data wire (0) and the discriminator wire (2) for
    # each axis; the table order is the gate order.
    layer = (
        (qml.RX, 0), (qml.RX, 2),
        (qml.RY, 0), (qml.RY, 2),
        (qml.RZ, 0), (qml.RZ, 2),
    )
    for idx, (rotation, wire) in enumerate(layer):
        rotation(w[idx], wires=wire)

    qml.CNOT(wires=[0, 2])

    # Closing rotations act only on the discriminator's wire.
    for idx, rotation in enumerate((qml.RX, qml.RY, qml.RZ), start=6):
        rotation(w[idx], wires=2)

In [None]:
# QNodes that feed either the real data or the generator's output into the
# discriminator and return <PauliZ> on wire 2 as the discriminator's verdict.
@qml.qnode(dev, interface="tf")
def real_disc_circuit(phi, theta, omega, disc_weights):
  """Run the discriminator on the real state prepared from (phi, theta, omega).

  Returns:
    Expectation value of PauliZ on wire 2.
  """
  real_angles = [phi, theta, omega]
  real_data(real_angles)
  discriminator(disc_weights)
  return qml.expval(qml.PauliZ(2))

@qml.qnode(dev, interface="tf")
def gen_disc_circuit(gen_weights, disc_weights):
  """Run the discriminator on the state produced by the generator.

  Returns:
    Expectation value of PauliZ on wire 2.
  """
  generator(gen_weights)
  discriminator(disc_weights)
  verdict = qml.PauliZ(2)
  return qml.expval(verdict)

In [None]:
def _expval_as_probability(expval):
  # Affine map of a PauliZ expectation value (range [-1, 1]) onto [0, 1].
  return (expval + 1) / 2


def prob_real_true(disc_weights):
  """Probability that the discriminator classifies the real data as real.

  The real state's angles are read from the module-level ``phi``, ``theta``
  and ``omega`` at call time.
  """
  return _expval_as_probability(real_disc_circuit(phi, theta, omega, disc_weights))


def prob_fake_true(gen_weights, disc_weights):
  """Probability that the discriminator classifies the generator's output as real."""
  return _expval_as_probability(gen_disc_circuit(gen_weights, disc_weights))

def disc_cost(disc_weights):
  """Discriminator loss: P(fake judged real) - P(real judged real).

  The generator side uses the module-level ``gen_weights`` (read at call time).
  """
  fake_score = prob_fake_true(gen_weights, disc_weights)
  real_score = prob_real_true(disc_weights)
  return fake_score - real_score

def gen_cost(gen_weights):
  """Generator loss: the negated probability that its output is judged real.

  The discriminator side uses the module-level ``disc_weights`` (read at call time).
  """
  return -prob_fake_true(gen_weights, disc_weights)

In [None]:
# Rotation angles of the target ("real") state.
phi, theta, omega = np.pi / 6, np.pi / 2, np.pi / 7

# Seed NumPy's legacy global RNG so both initialisations are reproducible.
# The generator noise is drawn before the discriminator weights; swapping the
# two draws would change both arrays.
np.random.seed(0)
eps = 1e-2  # std-dev of the noise added to the generator's starting point

# Generator starts near (pi, 0, ..., 0); discriminator starts from N(0, 1) noise.
init_gen_weights = np.concatenate(([np.pi], np.zeros(8))) + np.random.normal(scale=eps, size=(9, ))
init_disc_weights = np.random.normal(size=(9, ))

# Trainable TensorFlow variables seeded from the NumPy initialisations.
# The dtype is spelled out; it matches what TF infers from the float64 arrays.
gen_weights, disc_weights = (
    tf.Variable(init_gen_weights, dtype=tf.float64),
    tf.Variable(init_disc_weights, dtype=tf.float64),
)

In [None]:
# One plain-SGD optimizer (learning rate 0.4) shared by the discriminator and
# generator updates.
opt = tf.keras.optimizers.SGD(0.4)

# Since TF 2.11 the default Keras optimizers register only the variables seen on
# their first update and raise a KeyError ("optimizer cannot recognize variable")
# when later asked to update a different one. `opt` alternately updates
# `disc_weights` and `gen_weights`, so register both up front. The guard skips
# this on optimizer versions that have no `build` method.
if hasattr(opt, "build"):
  opt.build([disc_weights, gen_weights])

In [None]:
def disc_iteration(steps=50, log_every=5):
  """Run a block of gradient steps on the discriminator weights.

  Minimizes ``disc_cost`` with respect to the module-level ``disc_weights``
  using the module-level optimizer ``opt``; ``gen_weights`` is not updated.

  Args:
    steps: number of optimizer steps (default 50).
    log_every: print the cost on steps where ``step % log_every == 0``; the
      printed value is computed after that step's update. A falsy value
      disables the per-step log.
  """
  cost = lambda: disc_cost(disc_weights)

  print("####### Minimizing discriminator cost #######")

  for step in range(steps):
    # Keras documents `var_list` as a list/tuple of variables. Passing a bare
    # tf.Variable only works on optimizer versions that happen to flatten it.
    opt.minimize(cost, [disc_weights])

    if log_every and step % log_every == 0:
      cost_val = cost().numpy()
      print("Step {}: cost = {}".format(step, cost_val))

  print("####### Finished minimizing discriminator cost #######")

  print("Prob(real classified as real): ", prob_real_true(disc_weights).numpy())
  print("Prob(fake classified as real): ", prob_fake_true(gen_weights, disc_weights).numpy())

In [None]:
def gen_iteration(steps=50, log_every=5):
  """Run a block of gradient steps on the generator weights.

  Minimizes ``gen_cost`` with respect to the module-level ``gen_weights``
  using the module-level optimizer ``opt``; ``disc_weights`` is not updated.

  Args:
    steps: number of optimizer steps (default 50).
    log_every: print the cost on steps where ``step % log_every == 0``; the
      printed value is computed after that step's update. A falsy value
      disables the per-step log.
  """
  cost = lambda: gen_cost(gen_weights)

  print("####### Minimizing generator cost #######")

  for step in range(steps):
    # Keras documents `var_list` as a list/tuple of variables. Passing a bare
    # tf.Variable only works on optimizer versions that happen to flatten it.
    opt.minimize(cost, [gen_weights])
    if log_every and step % log_every == 0:
      cost_val = cost().numpy()
      print("Step {}: cost = {}".format(step, cost_val))

  print("####### Finished minimizing generator cost #######")

  print("Prob(fake classified as real): ", prob_fake_true(gen_weights, disc_weights).numpy())

In [None]:
def compare_data():
  """Print the Bloch vectors of the real state and of the generator's state.

  Both state-preparation functions are measured with PauliX, PauliY and PauliZ
  on wire 0 via ``qml.map`` (TensorFlow interface). The real state uses the
  module-level ``phi``/``theta``/``omega``; the generated one uses ``gen_weights``.
  """
  pauli_xyz = [qml.PauliX(0), qml.PauliY(0), qml.PauliZ(0)]

  real_qnodes = qml.map(real_data, pauli_xyz, dev, interface="tf")
  fake_qnodes = qml.map(generator, pauli_xyz, dev, interface="tf")

  real_vector = real_qnodes([phi, theta, omega])
  print("Real Bloch vector: {}".format(real_vector))
  fake_vector = fake_qnodes(gen_weights)
  print("Generator Bloch vector: {}".format(fake_vector))

In [None]:
# The training loop: each round trains the discriminator, then the generator,
# and finally prints the real and generated Bloch vectors side by side.
NUM_ROUNDS = 5

for i in range(NUM_ROUNDS):
  disc_iteration()
  gen_iteration()
  compare_data()