In [12]:
import matplotlib.pyplot as plt
import numpy as np
import cirq
from cirq.contrib.svg import SVGCircuit
import random as rd
from sympy import *
import tensorflow as tf
import tensorflow_quantum as tfq
import math
import re
import itertools
from numpy import linalg as LA
from scipy.stats import poisson
import time
from random import choices
from random import  uniform

In [13]:
k = 4 #length of clauses
n_var = 6 #number of variables
p = 4 #number of layers

nqubits = n_var #number of qubits in the circuit
all_vars = [i for i in range(-n_var,n_var+1)]
all_vars = [i for i in all_vars if i != 0]

qubits = []

for i in range(nqubits):
    qubits.append(cirq.GridQubit(0,i))        
    
qubits = list(qubits) #don't know why
all_qubits = [i for i in range(nqubits)]

In [14]:
r_by_k = {2 : 1, 3: 6.43, 4: 20.43, 5 : 45.7, 6: 70.21, 8: 176.54, 10: 708.92, 16: 45425.2}

def generate_instance(k: int, n: int) -> np.ndarray:
    #generate an instance of random k-SAT with n variables in the satisfiability threshold
    if not (r := r_by_k.get(k)):
        raise ValueError(f"k must be in {list(r_by_k)} (got {k})")
    m = poisson(r*n).rvs()
    #return np.random.choice(all_vars, size=(m, k))
    all_variables = []
    all_signs = []
    for i in range(m):
        #all_signs.append([rd.choice(l) for i in range(k)])
        all_variables.append(choices(all_vars, k = k))

    all_variables = np.array(all_variables)
    #all_signs = np.array(all_signs)
    return all_variables

In [15]:
def generate_binary_strings(bit_count):
    binary_strings = []
    def genbin(n, bs=''):
        if len(bs) == n:
            binary_strings.append(bs)
        else:
            genbin(n, bs + '0')
            genbin(n, bs + '1')

    genbin(bit_count)
    return binary_strings

binary_strings = generate_binary_strings(nqubits)

In [16]:
def dimacs_writer(dimacs_filename, cnf_array):
    #writes the dimacs file with the CNF
    cnf = cnf_array
    cnf_length = len(cnf)
    n_sat = len(cnf[0])
    var_num = np.max(cnf) 
    with open(dimacs_filename, "w") as f:

        f.write('c DIMACS file CNF '+str(n_sat)+'-SAT \n')
        f.write("p cnf {} {}\n".format(var_num, cnf_length))
        
        for i, clause in enumerate(cnf):
            line = clause.tolist()
            if i == cnf_length - 1:
                s = ' '.join(str(x) for x in line)+' 0'
                f.write(s)
            else: 
                s = ' '.join(str(x) for x in line)+' 0 \n'
                f.write(s)

In [17]:
class Verifier():
    #verifier from Qiskit page, takes a bit string and checks if cnf is satisfied
    def __init__(self, dimacs_file):
        with open(dimacs_file, 'r') as f:
            self.dimacs = f.read()

    def is_correct(self, guess):
        # Convert characters to bools & reverse
        guess = [bool(int(x)) for x in guess][::-1]
        for line in self.dimacs.split('\n'):
            line = line.strip(' 0')
            clause_eval = False
            for literal in line.split(' '):
                if literal in ['p', 'c']:
                    # line is not a clause
                    clause_eval = True
                    break
                if '-' in literal:
                    literal = literal.strip('-')
                    lit_eval = not guess[int(literal)-1]
                else:
                    lit_eval = guess[int(literal)-1]
                clause_eval |= lit_eval
            if clause_eval is False:
                return False
        return True

In [18]:
def my_gate(c, index):
    g = c * cirq.Z.on(qubits[index]) + cirq.I.on(qubits[index])
    return g

In [19]:
def ham_layer(diagonal, circuit, qubits, par):
    
    l = cirq.DiagonalGate(diagonal)._decompose_(qubits)
    l.pop(0)
    for j, gate in enumerate(l):

        if j % 2 == 0:
            dictn = gate._json_dict_()
            my_string = str(dictn['gate'])
            my_other_string = str(dictn['qubits'])
            number_p = re.findall("\d+\.\d+", my_string)
            res_p = [eval(i) for i in number_p]
            if '-' in my_string:
                sign = -1
            else:
                sign = 1
            
            number_q = re.findall(r'\d+', my_other_string)
            res_q = [eval(i) for i in number_q]
            kernel = sign*par*res_p[0]*np.pi
            rzgate = cirq.rz(kernel).on(qubits[res_q[1]])
            circuit.append(rzgate)
        else:
            circuit.append(gate)

In [20]:
def mixing_circuit(circuit, qubits, par):
    for i in range(len(qubits)):
        circuit.append(cirq.rx(par).on(qubits[i]))
    return circuit

In [21]:
class MyLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

    def __init__(self, initial_learning_rate):
        self.initial_learning_rate = initial_learning_rate

    def __call__(self, step):
        return self.initial_learning_rate / (step+1)

In [22]:
parameter_list = []
batches = 50
for batch in range(batches):

    valid_keys = []
    dimacs_filename = "random_cnf_BM.dimacs" 

    while not valid_keys:
        #only accepts satisfiable CNFs
        inst = generate_instance(k, n_var)
        dimacs_writer(dimacs_filename, inst)
        v = Verifier('random_cnf_BM.dimacs')

        for key in binary_strings:
            if v.is_correct(key) == True:
                valid_keys.append(key)
    with open('random_cnf_BM.dimacs', 'r') as f:
        dimacs = f.read()
    unsat_list = []

    for key in binary_strings:
        guess = [bool(int(x)) for x in key][::-1]

        clause_eval_list = []
        counter = 0
        for j, line in enumerate(dimacs.split('\n')):

            line = line.strip(' 0')
            clause_eval = False

            for literal in line.split(' '):
                if literal in ['p', 'c']:
                    #line is not a clause
                    clause_eval = True
                    break
                if '-' in literal:
                    literal = literal.strip('-')
                    lit_eval = not guess[int(literal)-1]
                else:
                    lit_eval = guess[int(literal)-1]
                clause_eval |= lit_eval
            if j > 1:
                counter += 1
                clause_eval_list.append(clause_eval)
        unsat_clauses = counter - sum(clause_eval_list)
        unsat_list.append(unsat_clauses)

    diagonal = unsat_list
    
    x = [1, -1]
    combinations = [p for p in itertools.product(x, repeat=nqubits)]

    ops_list = []
    for j, combination in enumerate(combinations):
        ops_list.append((diagonal[j]/2**nqubits)*math.prod([my_gate(combination[i], i) for i in range(nqubits)]))

    cost = np.sum(ops_list)
    cost_m = cost.matrix()
    gs_energy = np.real(min(LA.eig(cost_m)[0]))
    qaoa_circuit = cirq.Circuit()
    num_param = 2 * p 
    parameters = symbols("q0:%d" % num_param)

    #setting up the layers
    for i in range(p):
        ham_layer(diagonal, qaoa_circuit, qubits, parameters[2 * i])
        mixing_circuit(qaoa_circuit, qubits, parameters[2 * i + 1])
        
    initial = cirq.Circuit()

    for qubit in qubits:
        initial.append(cirq.H(qubit)) #applying Hadamard to all qubits before running circuit

    #setting up the model
    lr = 1e-1

    inputs = tfq.convert_to_tensor([initial])
    ins = tf.keras.layers.Input(shape = (), dtype = tf.dtypes.string)
    outs = tfq.layers.PQC(qaoa_circuit, cost)(ins)
    ksat = tf.keras.models.Model(inputs = ins, outputs = outs)
    opt = tf.keras.optimizers.Adam(learning_rate = 0.01)
    ksat.trainable_variables[0].assign([rd.uniform(-2*np.pi, 2*np.pi) for i in range(2*p)]) #initializing angles with some small noise
    
    losses = []
    error = 1e2*rd.random()
    tol = 1e-2

    start = time.time()

    j=0
    while j < 1e3:
        previous_error = error   
        with tf.GradientTape() as tape:
            error = ksat(inputs)

        grads = tape.gradient(error, ksat.trainable_variables)
        opt.apply_gradients(zip(grads, ksat.trainable_variables))
        error = error.numpy()[0,0]
        losses.append(error)

        print('Batch is '+str(batch)+' and absolute value of (ground state energy - error) is ' + str(abs(gs_energy - error)), end = '\r')

        #if abs(error - previous_error) < 1e-10:
        #    print('\n got stuck!')
        #    break
        j += 1
    params = ksat.get_weights()[0]
    end = time.time()
    parameter_list.append(params)

Batch is 49 and absolute value of (ground state energy - error) is 4.4422354698181155

In [24]:
average_parameters = 0
for i in range(40):
    average_parameters += parameter_list[i]
    
average_parameters/len(parameter_list)

array([ 0.3984816 , -0.23699822, -1.2622085 ,  0.38815212, -0.6662491 ,
       -0.6555451 ,  0.17665611,  0.6545369 ], dtype=float32)