# Circuit learning module: training lambeq circuits manually with SPSA and JAX

This module optimizes the parametrized circuits manually instead of using lambeq's automatic `QuantumTrainer` class. I wrote it to have more control over the optimization process and to make debugging easier. The code is based on the workflow presented in https://github.com/CQCL/Quanthoven.

In [1]:
import warnings
warnings.filterwarnings('ignore')

import json
import os
import glob
from pathlib import Path
from jax import numpy as np
import numpy
import pickle

from discopy.utils import loads
from discopy.quantum import Circuit
from discopy.tensor import Tensor

from jax import jit
from noisyopt import minimizeSPSA

# Absolute path of the current working directory; every data path below is built from it.
this_folder = os.path.abspath(os.getcwd())
# NOTE(review): presumably set for the HuggingFace `tokenizers` library (not used directly in
# this notebook) — confirm it is still needed.
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

# Training configuration.
# NOTE(review): the "Trainer" cell reassigns all three constants (EPOCHS = 200 there), so these
# values are overridden before optimisation runs. BATCH_SIZE is not used anywhere in the code shown.
BATCH_SIZE = 32
EPOCHS = 100
SEED = 0

## Read circuit data

We read the circuits from the pickled files.

In [2]:
# Root directory of the pickled circuits; the training/ and test/ subfolders hold one file per query.
# os.path.join keeps the paths portable. Sorting makes the file order deterministic, because
# glob returns matches in arbitrary, filesystem-dependent order.
CIRCUITS_DIR = os.path.join(this_folder, "simplified-JOB-diagrams", "circuits", "binary_classification")
training_circuits_paths = sorted(glob.glob(os.path.join(CIRCUITS_DIR, "training", "[0-9]*.p")))
#validation_circuits_paths = sorted(glob.glob(os.path.join(CIRCUITS_DIR, "validation", "[0-9]*.p")))
test_circuits_paths = sorted(glob.glob(os.path.join(CIRCUITS_DIR, "test", "[0-9]*.p")))

def read_diagrams(circuit_paths):
    """Load pickled diagrams and key them by file stem.

    Args:
        circuit_paths: iterable of paths to pickle files.

    Returns:
        dict mapping each file's stem (file name without directory and
        extension) to the unpickled object.
    """
    circuits = {}
    for serialized_diagram in circuit_paths:
        base_name = Path(serialized_diagram).stem
        # `with` guarantees the file handle is closed (the original leaked one per file).
        # NOTE: pickle.load can execute arbitrary code — only load trusted files.
        with open(serialized_diagram, "rb") as f:
            circuits[base_name] = pickle.load(f)
    return circuits


training_circuits = read_diagrams(training_circuits_paths)
test_circuits = read_diagrams(test_circuits_paths)

# Fail early: if a glob matched nothing, the empty dicts would otherwise only surface later
# as confusing errors when building prediction functions.
assert training_circuits, "No training circuits found; check the simplified-JOB-diagrams directory."
assert test_circuits, "No test circuits found; check the simplified-JOB-diagrams directory."

## Read training and test data

In [3]:
# Load the per-query records for each split; time_to_states below reads their "name" and "time" keys.
with open(this_folder + "//data//training_data.json", "r") as training_file:
    training_data = json.load(training_file)['training_data']

with open(this_folder + "//data//test_data.json", "r") as test_file:
    test_data = json.load(test_file)['test_data']
    

def time_to_states(data, circuits, threshold=5000):
    """Turn query run times into one-hot class labels.

    Args:
        data: iterable of records, each with at least "name" and "time" keys.
        circuits: mapping keyed by query name; records whose name is not a key
            are skipped.
        threshold: run times strictly below this value get the label [1, 0]
            (the |0> state); all others get [0, 1] (the |1> state). Defaults
            to 5000, in the same units as the records' "time" field.

    Returns:
        dict mapping query name -> one-hot label list.
    """
    labeled_data = {}
    for elem in data:
        name = elem["name"]
        if name not in circuits:
            continue
        # A fresh list per entry, so labels never alias each other.
        labeled_data[name] = [1, 0] if elem["time"] < threshold else [0, 1]
    return labeled_data


# Label every circuit of each split that has a matching record in the loaded data.
training_data_labels, test_data_labels = (
    time_to_states(split_records, split_circuits)
    for split_records, split_circuits in (
        (training_data, training_circuits),
        (test_data, test_circuits),
    )
)

## Organise circuits and labels for the optimizer

In [4]:
# Pair each labelled query with its circuit. Iterating over the label dicts keeps circuits and
# labels in the same order, which the cost function relies on.
training_circuits_l = [training_circuits[name] for name in training_data_labels]
training_data_labels_l = list(training_data_labels.values())

test_circuits_l = [test_circuits[name] for name in test_data_labels]
test_data_labels_l = list(test_data_labels.values())

all_circuits = [*training_circuits_l, *test_circuits_l]

# Every symbol used by a test circuit must also occur in some training circuit; otherwise its
# value is never updated by training.
train_syms = {sym for circuit in training_circuits.values() for sym in circuit.free_symbols}
test_syms = {sym for circuit in test_circuits.values() for sym in circuit.free_symbols}

print("Test circuits need to share training circuits' parameters. The parameters that are not covered: ", test_syms.difference(train_syms))

print("Total number of circuits: ", len(all_circuits))
print("Total number of variables: ", len(train_syms))

Test circuits need to share training circuits' parameters. The parameters that are not covered:  set()
Total number of circuits:  1168
Total number of variables:  240


## Model

In [5]:
from sympy import default_sort_key

# Canonical, deterministic ordering of every free symbol across all circuits. The same order is
# used when lambdifying circuits, so position i of a parameter vector always means the same symbol.
parameters = sorted(
    set().union(*(circuit.free_symbols for circuit in all_circuits)),
    key=default_sort_key,
)

len(parameters)

240

In [6]:
def normalise(predictions):
    """Rescale ``predictions`` into a probability distribution.

    Takes absolute values and adds 1e-9 to every entry, so that no entry is
    exactly zero (the cost function takes the log of these values). Then
    divides by the total over all entries.
    """
    magnitudes = np.abs(predictions) + 1e-9
    total = magnitudes.sum()
    return magnitudes / total

def make_pred_fn(circuits):
    """Build a prediction function for a fixed list of parametrised circuits.

    Each circuit is lambdified over the global, canonically ordered
    ``parameters`` list, so a single parameter vector drives every circuit.

    Args:
        circuits: list of parametrised circuits (objects with ``lambdify``).

    Returns:
        ``predict(params)``, which returns an array holding one normalised
        output per circuit, in the same order as ``circuits``.
    """
    circuit_fns = [c.lambdify(*parameters) for c in circuits]

    def predict(params):
        # Evaluate circuit by circuit: ``Circuit.eval(c1, c2, ...)`` returns a list only when
        # given two or more circuits, and a single tensor for one circuit, which the comprehension
        # below would then iterate incorrectly. Per-circuit eval works for any list length.
        outputs = [circuit_fn(*params).eval() for circuit_fn in circuit_fns]
        return np.array([normalise(output.array) for output in outputs])
    return predict


# jit-compile the training predictor: the optimiser calls it on every cost evaluation.
# NOTE(review): the test predictor is left un-jitted, presumably because it is evaluated only
# once after training — confirm.
train_pred_fn = jit(make_pred_fn(training_circuits_l))
#dev_pred_fn = jit(make_pred_fn(dev_circuits))
test_pred_fn = make_pred_fn(test_circuits_l)

## Loss function and evaluation

In [7]:
def make_cost_fn(pred_fn, labels):
    """Build a cross-entropy cost function that also logs cost and accuracy.

    Args:
        pred_fn: callable mapping a parameter vector to an array of normalised
            predictions, one row per circuit.
        labels: two-class one-hot labels, one per circuit, in the same order as
            the circuits behind ``pred_fn``. A plain list of lists is accepted.

    Returns:
        ``(cost_fn, costs, accuracies)``. ``cost_fn(params, **kwargs)`` returns
        the cross-entropy averaged over circuits. Every call also appends the
        cost to ``costs`` and the accuracy to ``accuracies``.
    """
    # Convert once, up front: jax arrays refuse arithmetic and comparisons with plain Python
    # lists, so ``labels * ...`` and ``... == labels`` raise TypeError on a list of lists.
    labels = np.asarray(labels)
    costs, accuracies = [], []

    def cost_fn(params, **kwargs):
        # **kwargs absorbs extra keyword arguments the optimiser may pass (e.g. a seed); unused.
        predictions = pred_fn(params)

        cost = -np.sum(labels * np.log(predictions)) / len(labels)  # cross-entropy, averaged over circuits
        costs.append(cost)

        # A correctly classified one-hot row matches in both of its entries, so divide by 2
        # to count it once.
        acc = np.sum(np.round(predictions) == labels) / len(labels) / 2
        accuracies.append(acc)

        return cost

    return cost_fn, costs, accuracies

## Trainer

In [None]:
# NOTE(review): these reassign the constants from the first cell (EPOCHS was 100 there); keep a
# single definition in the config cell. BATCH_SIZE is not used by the SPSA loop below.
BATCH_SIZE = 32
EPOCHS = 200
SEED = 0

# This avoids TracerArrayConversionError from jax.
# Must run before the first call of the jitted train_pred_fn below, because jit traces the
# function on its first call.
Tensor.np = np

# Random initial parameters in [0, 1), one per symbol in `parameters`.
rng = numpy.random.default_rng(SEED)
init_params_spsa = np.array(rng.random(len(parameters)))
# Also seed numpy's legacy global RNG.
# NOTE(review): presumably because minimizeSPSA draws its random perturbations from it — confirm.
numpy.random.seed(SEED)

train_cost_fn, train_costs, train_accs = make_cost_fn(train_pred_fn, training_data_labels_l)
#dev_cost_fn, dev_costs, dev_accs = make_cost_fn(dev_pred_fn, dev_labels)

# Evaluate the initial cost. cost_store_spsa[0] is this pre-optimisation value;
# callback_fn appends one entry per iteration after it.
cost_store_spsa = [train_cost_fn(init_params_spsa)]

def callback_fn(xk):
    """Optimiser callback: record and print the training cost at ``xk``.

    Re-evaluates the training cost at the current parameters (this also appends
    to ``train_costs`` / ``train_accs``), stores it in ``cost_store_spsa`` and
    prints a progress line.

    Args:
        xk: current parameter vector passed by the optimiser.
    """
    cost_val = train_cost_fn(xk)
    cost_store_spsa.append(cost_val)

    # cost_store_spsa[0] holds the initial cost evaluated before optimisation, so the number of
    # completed iterations is one less than the list length (previously off by one).
    iteration_num = len(cost_store_spsa) - 1
    print(
            f"Params = {xk}, "
            f"Iteration = {iteration_num}, "
            f"Cost = {cost_val}"
        )

In [None]:
# Run SPSA for EPOCHS iterations. Each cost evaluation covers all training circuits (full batch).
# a and c are SPSA gain constants (step-size and perturbation-size scales); no rationale is
# recorded for these particular values. callback_fn logs the cost after each iteration.
result = minimizeSPSA(train_cost_fn, x0=init_params_spsa, a=0.91, c=0.12, niter=EPOCHS, callback=callback_fn)

In [None]:
import matplotlib.pyplot as plt

# Training curves. train_costs / train_accs receive one entry per call of train_cost_fn:
# the initial evaluation, each callback evaluation and (presumably) the optimiser's own
# evaluations. The x-axis therefore counts cost-function evaluations, not iterations.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, sharex=True, figsize=(10, 4))
fig.suptitle('Training set')
ax_loss.set_title('Loss')
ax_acc.set_title('Accuracy')
for ax in (ax_loss, ax_acc):
    ax.set_xlabel('Cost-function evaluations')
ax_loss.set_ylabel('Cross-entropy')
ax_acc.set_ylabel('Accuracy')

colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
ax_loss.plot(numpy.asarray(train_costs, dtype=float), color=next(colours))
ax_acc.plot(numpy.asarray(train_accs, dtype=float), color=next(colours))
plt.show()

# Evaluate the optimised parameters on the held-out test set. There is no development
# set in this notebook, so the old "Development set" panels and the undefined
# `trainer` / `model` / `acc` references (from the QuantumTrainer version) are gone.
test_predictions = numpy.asarray(test_pred_fn(result.x))
test_labels = numpy.asarray(test_data_labels_l)
for prediction, label in zip(test_predictions, test_labels):
    print(prediction, label)

# Same accuracy definition as make_cost_fn: element-wise matches of the rounded prediction with
# the one-hot label, halved because each correct row matches in both entries.
test_acc = numpy.sum(numpy.round(test_predictions) == test_labels) / len(test_labels) / 2
print('Test accuracy:', float(test_acc))