# Circuit learning module: training lambeq circuits manually with SPSA and JAX

This module optimizes the parametrized circuits manually instead of relying on lambeq's automatic QuantumTrainer class. I wrote it to gain more control over the optimization process and to make debugging easier. The code is based on the workflow presented in https://github.com/CQCL/Quanthoven.

In [1]:
import warnings
import json
import os
import sys
import glob
from math import ceil
from pathlib import Path
from jax import numpy as np
from sympy import default_sort_key
import numpy
import pickle
import matplotlib.pyplot as plt

import jax
from jax import jit
from noisyopt import minimizeSPSA, minimizeCompass

from discopy.quantum import Circuit
from discopy.tensor import Tensor
from discopy.utils import loads
#from pytket.extensions.qiskit import AerBackend
#from pytket.extensions.qulacs import QulacsBackend
#from pytket.extensions.cirq import CirqStateSampleBackend
backend = None

from utils import *
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

# Keep notebook output readable by silencing library warnings.
warnings.filterwarnings('ignore')
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
#os.environ["JAX_PLATFORMS"] = "cpu"

# Start-up working directory; the data, circuit and result paths below are built from it.
this_folder = os.path.abspath(os.getcwd())

# Make discopy's Tensor use jax.numpy as its array module.
# This avoids TracerArrayConversionError from jax when circuits are evaluated under jit.
Tensor.np = np

SEED = 0
EPOCHS = 500  # reassigned later in the notebook before training starts

# Seed numpy's global RNG and create a dedicated Generator
# (the Generator `rng` is what the parameter initialisation code draws from).
numpy.random.seed(SEED)
rng = numpy.random.default_rng(SEED)

## Read circuit data

We read the circuits from the pickled files. Choose between binary and multi-class classification by giving the number of qubits; $n$ qubits yield $2^n$ classes:
- 1 qubit -> 2^1 = 2 classes, i.e. binary classification
- 2 qubits -> 2^2 = 4 classes
- ...
- 5 qubits -> 2^5 = 32 classes, etc.

In [2]:
# Select workload
workload = "execution_time"
#workload = "cardinality"

# Select workload size
#workload_size = "small"
#workload_size = "medium"
#workload_size = "large"
workload_size = "main"

# Number of qubits (2**classification classes) and circuit hyper-parameters;
# all four values are also encoded in the circuit directory path below.
classification = 1
layers = 1
single_qubit_params = 3
n_wire_count = 1

# Binary loss/accuracy for a single qubit, multi-class versions otherwise.
loss, acc = multi_class_loss, multi_class_acc
if classification == 1:
    loss, acc = bin_class_loss, bin_class_acc

# Access the selected circuits (the path keeps its trailing "//" separator).
path_name = "//".join([
    this_folder,
    "simplified-JOB-diagrams",
    workload,
    workload_size,
    "circuits",
    str(classification),
    f"{layers}_layer",
    f"{single_qubit_params}_single_qubit_params",
    f"{n_wire_count}_n_wire_count",
    "",
])

# One sub-directory per split; circuit pickles are files whose names start with a digit.
training_circuits_paths, validation_circuits_paths, test_circuits_paths = [
    glob.glob(path_name + split + "//[0-9]*.p")
    for split in ("training", "validation", "test")
]

In [3]:
# Unpickle every split's circuits (read_diagrams presumably comes from `utils`).
training_circuits, validation_circuits, test_circuits = [
    read_diagrams(split_paths)
    for split_paths in (training_circuits_paths, validation_circuits_paths, test_circuits_paths)
]

## Read training, validation and test data

In [4]:
data_path = this_folder + "//data//" + workload + "//" + workload_size + "//"


def _load_split(split_name):
    """Return the value stored under key ``split_name`` in ``<data_path><split_name>.json``.

    The file is decoded explicitly as UTF-8 (the encoding required for JSON
    interchange) so the result does not depend on the platform's default
    locale encoding.
    """
    with open(data_path + split_name + ".json", "r", encoding="utf-8") as inputfile:
        return json.load(inputfile)[split_name]


training_data = _load_split("training_data")
test_data = _load_split("test_data")
validation_data = _load_split("validation_data")

# Turn each split's records into class labels for the chosen number of qubits
# (create_labeled_classes presumably comes from `utils`).
training_data_labels = create_labeled_classes(training_data, classification, workload)
test_data_labels = create_labeled_classes(test_data, classification, workload)
validation_data_labels = create_labeled_classes(validation_data, classification, workload)

## Lambeq optimizer

## Model

In [5]:
def make_pred_fn(circuits):
    """Build a prediction function for a fixed list of parametrised circuits.

    Each circuit is lambdified over the module-level ``parameters`` symbols,
    so the returned ``predict(params)`` expects a parameter vector in that
    same order. ``predict`` evaluates all circuits together and returns an
    array with one normalised distribution per evaluated output.
    """
    # Optional compilation through a pytket backend.
    # Currently does not work properly.
    if backend:
        compiled = backend.get_compiled_circuits([circuit.to_tk() for circuit in circuits])
        circuits = [Circuit.from_tk(tk_circuit) for tk_circuit in compiled]

    lambdified = [circuit.lambdify(*parameters) for circuit in circuits]

    def _normalise(output):
        # Small offset keeps every entry strictly positive (presumably to
        # avoid log(0) in the loss); then rescale to sum to one.
        weights = np.abs(output.array) + 1e-9
        return weights / weights.sum()

    def predict(params):
        evaluated = Circuit.eval(*(fn(*params) for fn in lambdified), backend = backend)
        return np.array([_normalise(output) for output in evaluated])

    return predict

## Loss function and evaluation

In [6]:
def make_cost_fn(pred_fn, labels, loss_fn=None, acc_fn=None):
    """Build a cost function that also records loss and accuracy per evaluation.

    Args:
        pred_fn: callable mapping a parameter vector to predictions.
        labels: targets passed unchanged to the loss and accuracy functions.
        loss_fn: optional ``loss_fn(predictions, labels)``. When omitted, the
            module-level ``loss`` chosen in the configuration cell is used
            (looked up at call time, as before).
        acc_fn: optional ``acc_fn(predictions, labels)``. When omitted, the
            module-level ``acc`` is used.

    Returns:
        ``(cost_fn, costs, accuracies)``. ``cost_fn(params, **kwargs)``
        returns the loss and appends the loss and the accuracy to ``costs``
        and ``accuracies``, which the caller can read as a training history.
    """
    costs, accuracies = [], []

    def cost_fn(params, **kwargs):
        # Extra keyword arguments are accepted and ignored; presumably the
        # optimizer passes some (e.g. a seed) — confirm against noisyopt.
        predictions = pred_fn(params)

        cost = (loss if loss_fn is None else loss_fn)(predictions, labels)
        costs.append(cost)

        accuracy = (acc if acc_fn is None else acc_fn)(predictions, labels)
        accuracies.append(accuracy)

        return cost

    return cost_fn, costs, accuracies

## Minimization with noisyopt

In [7]:
def initialize_parameters(old_params, old_values, new_params):
    """Merge previously optimised parameters with newly introduced ones.

    Args:
        old_params: list of symbols whose values are already known.
        old_values: values aligned with ``old_params`` (e.g. an optimizer's ``x``).
        new_params: list of symbols required by newly added circuits; may
            overlap with ``old_params``.

    Returns:
        ``(parameters, values)``. ``parameters`` is the union of both symbol
        lists, sorted with sympy's ``default_sort_key``. ``values`` is a
        jax.numpy array aligned with ``parameters``: old symbols keep their
        value, and genuinely new ones receive a uniform random draw from the
        module-level ``rng``.
    """
    # One draw per entry of new_params, even for symbols that already exist,
    # so the RNG stream is consumed exactly as before.
    fresh = list(rng.random(len(new_params)))
    known = dict(zip(old_params, old_values))
    merged = sorted(set(old_params + new_params), key=default_sort_key)
    values = [known[symbol] if symbol in known else fresh.pop() for symbol in merged]
    return merged, np.array(values)

In [8]:
EPOCHS = 4000
limit = False
initial_number_of_circuits = 409

# Shared file stem for the results log, the figure and the saved parameters.
result_file = (workload_size + "_noisyopt_2" + str(classification)
               + "_" + str(layers) + "_" + str(single_qubit_params))

# Seed the training set with the first initial_number_of_circuits + 1 circuits.
# NOTE(review): the main loop iterates from index initial_number_of_circuits,
# so its first key is already part of this initial set — confirm the +1 is intended.
all_training_keys = list(training_circuits.keys())
initial_circuit_keys = all_training_keys[:initial_number_of_circuits + 1]
current_training_circuits = {k: training_circuits[k] for k in initial_circuit_keys}

syms = get_symbols(current_training_circuits)
parameters = sorted(syms, key=default_sort_key)

# Resume from previously saved parameters when available; otherwise start randomly.
# NOTE(review): the loaded vector is not checked against len(parameters).
points_file = "points//" + result_file + ".npz"
resume = initial_number_of_circuits > 5 and os.path.exists(points_file)
if not resume:
    print("Initializing new parameters")
    init_params_spsa = np.array(rng.random(len(parameters)))
else:
    with open(points_file, "rb") as f:
        print("Loading parameters from file " + result_file)
        npzfile = np.load(f)
        init_params_spsa = npzfile['arr_0']

result = None
run = 0

Loading parameters from file main_noisyopt_21_1_3


In [None]:
# Incremental training: starting from the initial circuit set, extend the
# training set one circuit per iteration, retrain with SPSA warm-started from
# the previous optimum, and log/plot every round.
for i, key in enumerate(all_training_keys[initial_number_of_circuits:]):
    print("Progress: ", round((i + initial_number_of_circuits)/len(all_training_keys), 3))
    
    # `syms` is refreshed at the end of the previous iteration *before* that
    # iteration's key was added. Equal lengths therefore mean the last added key
    # brought no new symbols; in that case the current key is added here and the
    # parameter vector is rebuilt. Skipped on the first iteration (i == 0).
    if len(syms) == len(get_symbols(current_training_circuits)) and i > 0:
        # NOTE(review): i never exceeds len(all_training_keys) - initial_number_of_circuits - 1,
        # so this is only False when initial_number_of_circuits == 0 and i is the last index.
        if i != len(all_training_keys[1:]):
            current_training_circuits[key] = training_circuits[key]
            new_parameters = sorted(get_symbols({key: training_circuits[key]}), key=default_sort_key)
            if result:
                # Warm start: keep optimised values, draw random values for new symbols.
                parameters, init_params_spsa = initialize_parameters(parameters, result.x, new_parameters)
            else:
                # No previous optimum yet: random initialisation for every symbol.
                syms = get_symbols(current_training_circuits)
                parameters = sorted(syms, key=default_sort_key)
                init_params_spsa = np.array(rng.random(len(parameters)))
    else:
        run += 1
    
    # Select those circuits from test and validation circuits which share the parameters with the current training circuits
    current_validation_circuits = select_circuits(current_training_circuits, validation_circuits)
    current_test_circuits = select_circuits(current_training_circuits, test_circuits)
    
    if len(current_validation_circuits) == 0 or len(current_test_circuits) == 0:
        continue
    
    # Create lists with circuits and their corresponding label
    training_circuits_l, training_data_labels_l = construct_data_and_labels(current_training_circuits, training_data_labels)
    validation_circuits_l, validation_data_labels_l = construct_data_and_labels(current_validation_circuits, validation_data_labels)
    test_circuits_l, test_data_labels_l = construct_data_and_labels(current_test_circuits, test_data_labels)
    
    # Limit the number of validation and test circuits to 20% of number of the training circuits
    if limit:
        # ceil keeps at least one circuit whenever the training set is non-empty.
        val_test_circ_size = ceil(0.2 * len(current_training_circuits))
        if len(current_validation_circuits) > val_test_circ_size:
            validation_circuits_l = validation_circuits_l[:val_test_circ_size]
            validation_data_labels_l = validation_data_labels_l[:val_test_circ_size]
        if len(current_test_circuits) > val_test_circ_size:
            test_circuits_l = test_circuits_l[:val_test_circ_size]
            test_data_labels_l = test_data_labels_l[:val_test_circ_size]
    
    stats = f"Number of training circuits: {len(training_circuits_l)}   "\
        + f"Number of validation circuits: {len(validation_circuits_l)}   "\
        + f"Number of test circuits: {len(test_circuits_l)}   "\
        + f"Number of parameters in model: {len(set([sym for circuit in training_circuits_l for sym in circuit.free_symbols]))}"
    
    with open("results//" + result_file + ".txt", "a") as f:
        f.write(stats + "\n")
    
    print(stats)
    
    # Prediction functions must be rebuilt each round: the circuit sets
    # (and the global `parameters` they are lambdified over) change.
    train_pred_fn = jit(make_pred_fn(training_circuits_l))
    dev_pred_fn = jit(make_pred_fn(validation_circuits_l))
    test_pred_fn = make_pred_fn(test_circuits_l)
    
    # Fresh cost functions and history lists for this round; callback_fn reads
    # these module-level names when minimizeSPSA calls it.
    train_cost_fn, train_costs, train_accs = make_cost_fn(train_pred_fn, training_data_labels_l)
    dev_cost_fn, dev_costs, dev_accs = make_cost_fn(dev_pred_fn, validation_data_labels_l)
    
    def callback_fn(xk):
        """Evaluate the validation loss after an SPSA step and log every 200 iterations."""
        valid_loss = dev_cost_fn(xk)
        # The training cost appears to be evaluated twice per SPSA iteration
        # (hence the last two entries and the division by 2 below).
        # NOTE(review): min is used for both loss and accuracy, so the two logged
        # values may come from different evaluations — confirm intent.
        train_loss = numpy.around(min(float(train_costs[-1]), float(train_costs[-2])), 4)
        train_acc = numpy.around(min(float(train_accs[-1]), float(train_accs[-2])), 4)
        valid_acc = numpy.around(float(dev_accs[-1]), 4)
        iters = int(len(train_accs)/2)
        if iters % 200 == 0:
            info = f"Epoch: {iters}   "\
            + f"train/loss: {train_loss}   "\
            + f"valid/loss: {numpy.around(float(valid_loss), 4)}   "\
            + f"train/acc: {train_acc}   "\
            + f"valid/acc: {valid_acc}"
        
            with open("results//" + result_file + ".txt", "a") as f:
                f.write(info + "\n")
                
            print(info, file=sys.stderr)
        return valid_loss
    
    # SPSA constants passed to noisyopt as `a` (step-size scale) and `c` (perturbation scale).
    a_value = 0.0053
    c_value = 0.0185

    result = minimizeSPSA(train_cost_fn, x0=init_params_spsa, a = a_value, c = c_value, niter=EPOCHS, callback=callback_fn)
    #result = minimizeCompass(train_cost_fn, x0=init_params_spsa, redfactor=2.0, deltainit=1.0, deltatol=0.001, feps=1e-15, errorcontrol=True, funcNinit=30, funcmultfactor=2.0, paired=True, alpha=0.05, callback=callback_fn)

    figure_path = this_folder + "//results//" + result_file + ".png"
    visualize_result_noisyopt(result, make_cost_fn, test_pred_fn, test_data_labels_l, train_costs, train_accs, dev_costs, dev_accs, figure_path, result_file)
    
    run += 1
    # Snapshot the symbols *before* adding the next key (used by the check at the top).
    syms = get_symbols(current_training_circuits)
    
    # Extend for the next optimization round
    current_training_circuits[key] = training_circuits[key]
    new_parameters = sorted(get_symbols({key: training_circuits[key]}), key=default_sort_key)
    parameters, init_params_spsa = initialize_parameters(parameters, result.x, new_parameters)

Progress:  0.913
Number of training circuits: 408   Number of validation circuits: 113   Number of test circuits: 112   Number of parameters in model: 286


Epoch: 200   train/loss: 0.3159   valid/loss: 0.3892   train/acc: 0.8627   valid/acc: 0.8584
Epoch: 400   train/loss: 0.2866   valid/loss: 0.39   train/acc: 0.8873   valid/acc: 0.8673
Epoch: 600   train/loss: 0.2727   valid/loss: 0.3955   train/acc: 0.8922   valid/acc: 0.8673
Epoch: 800   train/loss: 0.2954   valid/loss: 0.3936   train/acc: 0.8848   valid/acc: 0.8584
Epoch: 1000   train/loss: 0.2753   valid/loss: 0.3955   train/acc: 0.8995   valid/acc: 0.8673
Epoch: 1200   train/loss: 0.2832   valid/loss: 0.3944   train/acc: 0.8946   valid/acc: 0.8673
Epoch: 1400   train/loss: 0.3006   valid/loss: 0.3956   train/acc: 0.8799   valid/acc: 0.8673
Epoch: 1600   train/loss: 0.2768   valid/loss: 0.3957   train/acc: 0.8848   valid/acc: 0.8584
Epoch: 1800   train/loss: 0.2911   valid/loss: 0.3973   train/acc: 0.875   valid/acc: 0.8496
Epoch: 2000   train/loss: 0.2771   valid/loss: 0.3945   train/acc: 0.8873   valid/acc: 0.8673
Epoch: 2200   train/loss: 0.2824   valid/loss: 0.3973   train/acc: 

Test accuracy: 0.81250006
Progress:  0.915
Number of training circuits: 409   Number of validation circuits: 113   Number of test circuits: 112   Number of parameters in model: 286


Epoch: 200   train/loss: 0.2931   valid/loss: 0.3934   train/acc: 0.868   valid/acc: 0.8496
Epoch: 400   train/loss: 0.3068   valid/loss: 0.3921   train/acc: 0.8704   valid/acc: 0.8584
Epoch: 600   train/loss: 0.278   valid/loss: 0.3896   train/acc: 0.8924   valid/acc: 0.8584
Epoch: 800   train/loss: 0.2716   valid/loss: 0.3909   train/acc: 0.89   valid/acc: 0.8496
Epoch: 1000   train/loss: 0.2989   valid/loss: 0.3915   train/acc: 0.8729   valid/acc: 0.8496
Epoch: 1200   train/loss: 0.2708   valid/loss: 0.393   train/acc: 0.8924   valid/acc: 0.8496
Epoch: 1400   train/loss: 0.2816   valid/loss: 0.3931   train/acc: 0.8851   valid/acc: 0.8496
Epoch: 1600   train/loss: 0.2769   valid/loss: 0.3912   train/acc: 0.8875   valid/acc: 0.8673
Epoch: 1800   train/loss: 0.2751   valid/loss: 0.3905   train/acc: 0.8973   valid/acc: 0.8496
Epoch: 2000   train/loss: 0.2758   valid/loss: 0.391   train/acc: 0.8851   valid/acc: 0.8673
Epoch: 2200   train/loss: 0.2921   valid/loss: 0.3911   train/acc: 0.8

Test accuracy: 0.81250006
Progress:  0.917
Number of training circuits: 410   Number of validation circuits: 113   Number of test circuits: 112   Number of parameters in model: 286


Epoch: 200   train/loss: 0.2782   valid/loss: 0.3986   train/acc: 0.8829   valid/acc: 0.8496
Epoch: 400   train/loss: 0.3094   valid/loss: 0.3918   train/acc: 0.8707   valid/acc: 0.8584
Epoch: 600   train/loss: 0.2872   valid/loss: 0.3926   train/acc: 0.8805   valid/acc: 0.8673
Epoch: 800   train/loss: 0.2758   valid/loss: 0.3908   train/acc: 0.8927   valid/acc: 0.8673
Epoch: 1000   train/loss: 0.3093   valid/loss: 0.3906   train/acc: 0.8756   valid/acc: 0.8584
Epoch: 1200   train/loss: 0.2747   valid/loss: 0.3908   train/acc: 0.8854   valid/acc: 0.8673
Epoch: 1400   train/loss: 0.2689   valid/loss: 0.391   train/acc: 0.8902   valid/acc: 0.8673
Epoch: 1600   train/loss: 0.276   valid/loss: 0.3909   train/acc: 0.8927   valid/acc: 0.8673
Epoch: 1800   train/loss: 0.2695   valid/loss: 0.3906   train/acc: 0.8878   valid/acc: 0.8584
Epoch: 2000   train/loss: 0.2764   valid/loss: 0.3907   train/acc: 0.8927   valid/acc: 0.8673
Epoch: 2200   train/loss: 0.2834   valid/loss: 0.3905   train/acc:

Test accuracy: 0.81250006
Progress:  0.92
Number of training circuits: 411   Number of validation circuits: 113   Number of test circuits: 112   Number of parameters in model: 286


Epoch: 200   train/loss: 0.2769   valid/loss: 0.3837   train/acc: 0.8905   valid/acc: 0.8584
Epoch: 400   train/loss: 0.2894   valid/loss: 0.3838   train/acc: 0.8954   valid/acc: 0.8673
Epoch: 600   train/loss: 0.3035   valid/loss: 0.3838   train/acc: 0.8832   valid/acc: 0.8584
Epoch: 800   train/loss: 0.273   valid/loss: 0.3832   train/acc: 0.8929   valid/acc: 0.8673
Epoch: 1000   train/loss: 0.2913   valid/loss: 0.3834   train/acc: 0.8929   valid/acc: 0.8673
Epoch: 1200   train/loss: 0.2885   valid/loss: 0.3834   train/acc: 0.8832   valid/acc: 0.8673
Epoch: 1400   train/loss: 0.2723   valid/loss: 0.3835   train/acc: 0.8954   valid/acc: 0.8673
Epoch: 1600   train/loss: 0.2764   valid/loss: 0.3836   train/acc: 0.8978   valid/acc: 0.8673
Epoch: 1800   train/loss: 0.2998   valid/loss: 0.3833   train/acc: 0.8832   valid/acc: 0.8673
Epoch: 2000   train/loss: 0.291   valid/loss: 0.3836   train/acc: 0.8832   valid/acc: 0.8673
Epoch: 2200   train/loss: 0.2669   valid/loss: 0.3834   train/acc:

Test accuracy: 0.80357146
Progress:  0.922
Number of training circuits: 412   Number of validation circuits: 113   Number of test circuits: 112   Number of parameters in model: 286


Epoch: 200   train/loss: 0.2889   valid/loss: 0.3849   train/acc: 0.8883   valid/acc: 0.8673
Epoch: 400   train/loss: 0.292   valid/loss: 0.3854   train/acc: 0.8859   valid/acc: 0.8673
Epoch: 600   train/loss: 0.293   valid/loss: 0.3858   train/acc: 0.8786   valid/acc: 0.8673
Epoch: 800   train/loss: 0.2976   valid/loss: 0.3867   train/acc: 0.8835   valid/acc: 0.8584
Epoch: 1000   train/loss: 0.2961   valid/loss: 0.3884   train/acc: 0.8908   valid/acc: 0.8673
Epoch: 1200   train/loss: 0.2964   valid/loss: 0.3911   train/acc: 0.8883   valid/acc: 0.8673
Epoch: 1400   train/loss: 0.3053   valid/loss: 0.3948   train/acc: 0.8859   valid/acc: 0.8673
Epoch: 1600   train/loss: 0.2925   valid/loss: 0.3943   train/acc: 0.8786   valid/acc: 0.8673
Epoch: 1800   train/loss: 0.2934   valid/loss: 0.3929   train/acc: 0.8811   valid/acc: 0.8584
Epoch: 2000   train/loss: 0.286   valid/loss: 0.3933   train/acc: 0.8786   valid/acc: 0.8673
Epoch: 2200   train/loss: 0.2814   valid/loss: 0.3917   train/acc: 

Test accuracy: 0.80357146
Progress:  0.924
Number of training circuits: 413   Number of validation circuits: 113   Number of test circuits: 112   Number of parameters in model: 286


Epoch: 200   train/loss: 0.3024   valid/loss: 0.3908   train/acc: 0.8741   valid/acc: 0.8584
Epoch: 400   train/loss: 0.2735   valid/loss: 0.3918   train/acc: 0.8935   valid/acc: 0.8673
Epoch: 600   train/loss: 0.2922   valid/loss: 0.3934   train/acc: 0.8789   valid/acc: 0.8673
Epoch: 800   train/loss: 0.2853   valid/loss: 0.3931   train/acc: 0.891   valid/acc: 0.8673
Epoch: 1000   train/loss: 0.2967   valid/loss: 0.3949   train/acc: 0.8789   valid/acc: 0.8673
Epoch: 1200   train/loss: 0.2907   valid/loss: 0.3924   train/acc: 0.8959   valid/acc: 0.8673
Epoch: 1400   train/loss: 0.301   valid/loss: 0.3958   train/acc: 0.8692   valid/acc: 0.8673
Epoch: 1600   train/loss: 0.2757   valid/loss: 0.3946   train/acc: 0.8886   valid/acc: 0.8673
Epoch: 1800   train/loss: 0.2926   valid/loss: 0.3941   train/acc: 0.8886   valid/acc: 0.8673
Epoch: 2000   train/loss: 0.2775   valid/loss: 0.3931   train/acc: 0.8838   valid/acc: 0.8584
Epoch: 2200   train/loss: 0.2893   valid/loss: 0.3929   train/acc: