In [1]:
%load_ext jupyternotify
from pytket import Circuit, Qubit, Bit, OpType
from pytket.utils.operators import QubitPauliOperator
from sympy import symbols
from openfermion import QubitOperator
from random import sample
import numpy as np
# from pytket.backends.projectq import ProjectQBackend
from pytket.backends.ibm import AerStateBackend, AerBackend, AerUnitaryBackend, IBMQBackend
from scipy.linalg import expm, sinm, cosm
from sympy.physics.quantum.dagger import Dagger
import functools
import operator
import itertools    
from openfermion import get_sparse_operator
from scipy.optimize import minimize, LinearConstraint, Bounds
import matplotlib.pyplot as plt


def fidelity(rsv, gsv): 
    overlap = np.vdot(rsv, gsv)
    return abs(overlap)**2

#constants 
n = 1 #number of qubits 
lamb = np.float(2)
s = np.exp(-1 / (2 * lamb)) - 1
cst1 = (s / 2 + 1) ** 2
cst2 = (s / 2) * (s / 2 + 1)
cst3 = (s / 2) ** 2

<IPython.core.display.Javascript object>

In [2]:
def real(n, weights):
    
    real_circ = Circuit()
    qubits = real_circ.add_q_register('q', n)

    real_circ.Rx(weights[0], qubits[0])
        
    real_circ.Ry(weights[1], qubits[0])
        
    real_circ.Rz(weights[2], qubits[0])
        

    # backend.compile_circuit(real_circ)
    # statevector = backend.get_state(real_circ)
    
    # return real_circ, statevector
    return real_circ


In [3]:
def generator(n, weights):
    
    generator_circ = Circuit()
    qubits = generator_circ.add_q_register('q', n)

    generator_circ.Rx(weights[0], qubits[0])

    generator_circ.Ry(weights[1], qubits[0])

    generator_circ.Rz(weights[2], qubits[0])
        
    # backend.compile_circuit(generator_circ)
    # statevector = backend.get_state(generator_circ)
    
    # return generator_circ, statevector
    return generator_circ


def generator_symbolic(n):
    
    generator_circ = Circuit()
    qubits = generator_circ.add_q_register('q', n)
    weight_symbols = symbols([f'theta_{i}' for i in range(3)])

    generator_circ.Rx(weight_symbols[0], qubits[0])

    generator_circ.Ry(weight_symbols[1], qubits[0])

    generator_circ.Rz(weight_symbols[2], qubits[0])
        
    # backend.compile_circuit(generator_circ)
    # statevector = backend.get_state(generator_circ)
    
    # return generator_circ, statevector
    return generator_circ,weight_symbols


In [16]:
def operator_inner(left, operator_matrix, right):
    return np.vdot(left, operator_matrix.dot(right))

In [20]:
class Discriminator:
    def __init__(self, init_weights):
        self.set_weights(init_weights)
    def set_weights(self, _init_weights):
        halfway = len(_init_weights)//2
        self.psi_weights = _init_weights[:halfway]
        self.phi_weights = _init_weights[halfway:]
        assert(len(self.phi_weights) == len(self.psi_weights))
        self.n_qubits = len(self.phi_weights)//4

    def make_operator(self, weights_list):
        iden =  weights_list[-1] * QubitOperator(" ")
        tuple_list = [(weight, tup[0], tup[1]) for weight, tup in zip(weights_list[:-1], itertools.product(['X', 'Y', 'Z'], range(self.n_qubits)))]
        measurements = functools.reduce(operator.add, (weight * QubitOperator(f'{a}{n}') for weight, a, n in tuple_list))

        return iden + measurements


    def calculate_loss(self, real_sv, gen_sv):
        # construct operators
        psi = self.make_operator(self.psi_weights)
        phi = self.make_operator(self.phi_weights)
        #convert phi and psi operators to matrix 
        psi_matrix = np.array(get_sparse_operator(psi).todense())
        phi_matrix = np.array(get_sparse_operator(phi).todense())
        
        #calculate expectation values 
        
        psi_exp = operator_inner(real_sv, psi_matrix, real_sv)
        phi_exp = operator_inner(gen_sv, phi_matrix, gen_sv)

        #calculate values for A and B which go into the calculation for the regterm
        A = expm(np.float(-1 / lamb) * phi_matrix)
        B = expm(np.float(1 / lamb) * psi_matrix)

        term1 = operator_inner(gen_sv, A, gen_sv)
        term2 = operator_inner(real_sv, B, real_sv)
        term3 = operator_inner(gen_sv, B, real_sv)
        term4 = operator_inner(real_sv, A, gen_sv)
        term5 = operator_inner(gen_sv, A, real_sv)
        term6 = operator_inner(real_sv, B, gen_sv)
        term7 = operator_inner(gen_sv, B, gen_sv)
        term8 = operator_inner(real_sv, A, real_sv)

        regterm = (lamb / np.e * (cst1 * term1 * term2 - cst2 * term3 * term4 - cst2 * term5 * term6 + cst3 * term7 * term8)).item()

        return np.real(psi_exp - phi_exp - regterm) 
    
    

In [6]:
def make_disc_loss(real_state, gen_state):
    def disc_loss(disc_weights):
        disc = Discriminator(disc_weights)

        return -disc.calculate_loss(real_state, gen_state)
    return disc_loss

def make_gen_loss(base_circuit, symb_weights, real_state, backend, discriminator):
    
    def gen_loss(gen_weights):
        gen_circ = base_circuit.copy()
        gen_circ.symbol_substitution(dict(zip(symb_weights, gen_weights)))
        gen_state = backend.get_state(gen_circ)
        return discriminator.calculate_loss(real_state, gen_state)


    return gen_loss



In [21]:
%%time
%%notify

backend = AerStateBackend()
gen_circ_base, symbolic_weights = generator_symbolic(n)
backend.compile_circuit(gen_circ_base)
np.random.seed(3)


for r in range(1):
    
    #generate a random real state and a initial fake state and the disc params 
    real_weights = np.random.uniform(0,2,3) 
    init_gen_weights = np.random.uniform(0,2,3)
    init_disc_weights = np.random.uniform(0,2,8)
    
    real_circ = real(n, real_weights)
    backend.compile_circuit(real_circ)
    real_sv = backend.get_state(real_circ)

    gen_circ = gen_circ_base.copy()
    gen_circ.symbol_substitution(dict(zip(symbolic_weights, init_gen_weights)))

    curr_gen_sv = backend.get_state(gen_circ)
    init_fid = fidelity(real_sv, curr_gen_sv)

    fid = init_fid


    # dloss = []
    gloss = []
    f = []
    
    # f.append(init_fid)
    # dloss.append(disc_loss(init_disc_weights))
    # gloss.append(gen_loss(init_gen_weights))

    curr_gen_weights = init_gen_weights
    curr_disc_weights = init_disc_weights
    niter = 0 


    while fid <0.99:

        #calculate the updated gen circ and statevector 
        curr_gen_circ = gen_circ_base.copy()
        curr_gen_circ.symbol_substitution(dict(zip(symbolic_weights, curr_gen_weights)))
        curr_gen_sv = backend.get_state(curr_gen_circ)
        fid = fidelity(real_sv, curr_gen_sv)

        f.append(fid)

        #maximise loss for disc
        disc_loss = make_disc_loss(real_sv, curr_gen_sv)
        gloss.append(-disc_loss(curr_disc_weights))

        def disc_callback(x):
            print("D train", -disc_loss(x))
            return False
        disc_result = minimize(disc_loss, curr_disc_weights, method='Powell',bounds=Bounds(0,1),  options={ 'maxiter': 10, 'ftol':1e-5}, callback=disc_callback)
        curr_disc_weights = disc_result.x
        print(curr_disc_weights)
        # dloss.append(float(disc_result.fun))

        #minimise loss for gen
        curr_disc = Discriminator(curr_disc_weights)
        gen_loss = make_gen_loss(gen_circ_base, symbolic_weights, real_sv, backend, curr_disc)
        def gen_callback(x):
            print(x)
            print("G train", gen_loss(x))
            return False
        gloss.append(gen_loss(curr_gen_weights))
        print("start", curr_gen_weights)
        print(disc_loss(curr_disc_weights), gen_loss(curr_gen_weights))
        gen_result = minimize(gen_loss, curr_gen_weights, method='Powell', bounds=Bounds(0,2.0), options={ 'maxiter': 10, 'ftol':1e-10}, callback=gen_callback )
        curr_gen_weights = gen_result.x
        
        # gloss.append(float(gen_result.fun))
        

        niter += 1

        if niter == 5:
            break 

            
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,5))

    y = np.real(f)
    x =list(range(0, len(y)))

    y1 = -np.array(gloss)
    x1 =list(range(0, len(y1)))
    
    y2 = gloss
    x2 =list(range(0, len(y2)))


    ax1.plot(x,y)
    ax2.plot(x1, y1, label = "disc loss")
    ax2.plot(x2, y2,  label = "gen loss")
    ax2.legend()


    ax1.set_xlabel('While loop iterations')
    ax1.set_ylabel('Fidelity')

    ax2.set_xlabel('While loop iterations')
    ax2.set_ylabel('Wasserstein  Loss')

D train 0.4040901634762579
D train 0.40649340627670094
D train 0.4065452686864435
D train 0.4065638200925539
D train 0.40696750404506377
D train 0.40713734260168977
D train 0.407810072073749
D train 0.4078137546956657
[-0.7905246   1.09880422  0.46650803  0.05979353 -1.26410731  0.84390092
  1.7535759   1.39597324]
start [1.02165521 1.78589391 1.79258618]
-0.4078137546956657 0.4078137546956657
[2.00034994 1.75716179 1.79601981]
G train -2.472636876703091
[2.00034963 1.75714761 1.7960198 ]
G train -2.4726368781375347
[2.00034963 1.75714761 1.79601979]
G train -2.472636878137535
D train -0.17395756972673482
D train 0.00896880470422623
D train 0.04742632490282306
D train 0.04746066093955936
D train 0.04746345037707922
D train 0.047478451392419796
D train 0.04751834968788815
D train 0.04777189169787155
D train 0.048980834370491166
D train 0.05260800519710607
[-5.14078204  7.73866873  2.36441022 -1.26181348 -3.32045162  2.72744708
  4.01508228  0.3822102 ]
start [2.00034963 1.75714761 1.796

KeyboardInterrupt: 