# Circuit learning module: Lambeq's QuantumTrainer

This module performs the optimization with Lambeq's native optimizer. Because the circuits are constructed with Lambeq and DisCoPy, this optimizer is the natural choice. The code is based on the workflow presented in https://github.com/CQCL/lambeq/blob/main/docs/examples/quantum_pipeline.ipynb.

In [1]:
import warnings
warnings.filterwarnings('ignore')

import json
import os
import glob
from pathlib import Path
from jax import numpy as np
#import numpy as np
import pickle
import math

from discopy.utils import loads
from pytket.extensions.qiskit import AerBackend
from lambeq import TketModel, NumpyModel
from lambeq import QuantumTrainer, SPSAOptimizer
from lambeq import Dataset

from calibrate import calibrate
from utils import read_diagrams, create_labeled_classes, bin_class_loss, multi_class_loss, bin_class_acc, multi_class_acc, visualize_results

from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

# Absolute path of the current working directory; the data, circuit and
# checkpoint paths used further down are all built relative to it.
this_folder = os.path.abspath(os.getcwd())
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

# Uncomment if you do not want to access GPU
# NOTE(review): jax is already imported above; the variable may need to be set
# before that import to take effect -- confirm.
#os.environ["JAX_PLATFORMS"] = "cpu"

# Training constants. EPOCHS and SEED are passed to the trainer below;
# BATCH_SIZE is not referenced in the cells visible here.
BATCH_SIZE = 64
EPOCHS = 3000
SEED = 0

# Settings for the shot-based TketModel alternative (the NumpyModel used in
# the training cell does not read them).
backend = AerBackend()
backend_config = dict(
    backend=backend,
    compilation=backend.default_compilation_pass(2),
    shots=3200,
)

ModuleNotFoundError: No module named 'jaxlib'

## Select workload and read the circuits

Select whether to perform binary or multi-class classification by giving the number of qubits used to encode the classes:
- 1 qubit -> 2^1 = 2 classes, i.e. binary classification
- 2 qubits -> 2^2 = 4 classes
- ...
- 5 qubits -> 2^5 = 32 classes, etc.

In [None]:
# Select workload
workload = "execution_time"
#workload = "cardinality"

# Select workload size
#workload_size = "small"
#workload_size = "medium"
#workload_size = "large"
workload_size = "main"

# Settings that identify which pre-built circuits to load; each of them is a
# component of the directory path assembled below.
classification = 1
layers = 1
single_qubit_params = 3
n_wire_count = 1

# classification == 1 (one qubit, two classes) uses the binary loss and
# accuracy; every other value uses the multi-class versions.
if classification == 1:
    loss, acc = bin_class_loss, bin_class_acc
else:
    loss, acc = multi_class_loss, multi_class_acc

# Access the selected circuits
path_name = (
    f"{this_folder}//simplified-JOB-diagrams//{workload_size}//circuits//"
    f"{classification}//{layers}_layer//"
    f"{single_qubit_params}_single_qubit_params//"
    f"{n_wire_count}_n_wire_count//"
)

# One pickle file per circuit; only files whose name starts with a digit match.
training_circuits_paths = glob.glob(f"{path_name}training//[0-9]*.p")
validation_circuits_paths = glob.glob(f"{path_name}validation//[0-9]*.p")
test_circuits_paths = glob.glob(f"{path_name}test//[0-9]*.p")

We read the circuits from the pickled files.

In [None]:
# Load the circuits of each split (training, validation, test) from the file
# paths collected above, in that order.
training_circuits, validation_circuits, test_circuits = map(
    read_diagrams,
    (training_circuits_paths, validation_circuits_paths, test_circuits_paths),
)

## Read training, validation and test data

In [None]:
training_data = test_data = validation_data = None
data_path = this_folder + "//data//" + workload + "//" + workload_size + "//"


def _load_split(file_name, key):
    """Read ``file_name`` from ``data_path`` and return the entry stored under ``key``."""
    with open(data_path + file_name, "r") as inputfile:
        return json.load(inputfile)[key]


# Each JSON file wraps its payload in a single top-level key named after the split.
training_data = _load_split("training_data.json", 'training_data')
test_data = _load_split("test_data.json", 'test_data')
validation_data = _load_split("validation_data.json", 'validation_data')

# Turn the raw data of each split into class labels for the chosen
# classification setting.
training_data_labels = create_labeled_classes(training_data, classification)
test_data_labels = create_labeled_classes(test_data, classification)
validation_data_labels = create_labeled_classes(validation_data, classification)

## Prepare circuits for the Lambeq optimizer

## Continuously optimize parameters

In [None]:
def get_symbols(circs):
    """Return the set of all free symbols used by the circuits in ``circs``.

    ``circs`` is a mapping whose values expose a ``free_symbols`` iterable.
    """
    return {
        symbol
        for circuit in circs.values()
        for symbol in circuit.free_symbols
    }


def construct_data_and_labels(circuits, labels):
    """Return two aligned lists: the circuits and their labels.

    Both lists follow the iteration order of ``circuits``; position ``i`` of
    the second list is the label stored in ``labels`` under the same key as
    the circuit at position ``i`` of the first. Keys that exist only in
    ``labels`` are ignored; a key missing from ``labels`` raises ``KeyError``.
    """
    keys = list(circuits)
    return [circuits[k] for k in keys], [labels[k] for k in keys]


def select_circuits(base_circuits, selected_circuits):
    """Return the entries of ``selected_circuits`` that use only known symbols.

    A circuit is kept when every one of its free symbols also occurs in some
    circuit of ``base_circuits``; circuits without free symbols are always
    kept. The result is a new dict with the original keys.
    """
    known_symbols = get_symbols(base_circuits)
    return {
        name: circuit
        for name, circuit in selected_circuits.items()
        if set(circuit.free_symbols).issubset(known_symbols)
    }

In [None]:
# Incremental training: start from the first training circuit and add one
# circuit per iteration. A model is (re)trained only when the added circuit
# introduces at least one new symbol; every training run writes a checkpoint.
eval_metrics = {"acc": acc}
syms = set()

# Create the checkpoint directory up front so that a missing folder does not
# surface only after a full training run has finished.
checkpoint_dir = this_folder + "//training_checkpoints//" + workload
os.makedirs(checkpoint_dir, exist_ok=True)

# Path of the most recent checkpoint that was actually written. Iterations
# that are skipped write no checkpoint, so "run - 1" is not a reliable index.
last_checkpoint_path = None

all_training_keys = list(training_circuits.keys())
current_training_circuits = {k: training_circuits[k] for k in all_training_keys[:1]}

for run, key in enumerate(all_training_keys[1:]):
    print("Progress: ", round(run/len(all_training_keys), 2))
    current_training_circuits[key] = training_circuits[key]

    # The circuit set only grows, so an unchanged symbol count means the new
    # circuit added no new parameters and retraining is skipped. The first
    # iteration always trains.
    current_syms = get_symbols(current_training_circuits)
    if run > 0 and len(current_syms) == len(syms):
        continue

    # Select those circuits from test and validation circuits which share the parameters with the current training circuits
    current_validation_circuits = select_circuits(current_training_circuits, validation_circuits)
    current_test_circuits = select_circuits(current_training_circuits, test_circuits)

    # Create lists with circuits and their corresponding label
    training_circuits_l, training_data_labels_l = construct_data_and_labels(current_training_circuits, training_data_labels)
    validation_circuits_l, validation_data_labels_l = construct_data_and_labels(current_validation_circuits, validation_data_labels)
    test_circuits_l, test_data_labels_l = construct_data_and_labels(current_test_circuits, test_data_labels)

    print(f"Number of training circuits: {len(training_circuits_l)}   ",
          f"Number of validation circuits: {len(validation_circuits_l)}   ",
          f"Number of test circuits: {len(test_circuits_l)}   ",
          f"Number of parameters in model: {len(current_syms)}")

    # Select model
    #model = TketModel.from_diagrams(training_circuits_l, backend_config = backend_config)
    model = NumpyModel.from_diagrams(training_circuits_l, use_jit=True)

    # Initialize the weights from the possible previous training checkpoint
    if last_checkpoint_path is not None:
        # NOTE(review): lambeq documents `from_checkpoint` as a classmethod
        # that returns a new model. If so, the return value is discarded here
        # and `model` is not warm-started; the earlier checkpoint also holds
        # fewer symbols than the current model. Confirm, and if a warm start
        # is intended, copy the weights of the shared symbols explicitly.
        model.from_checkpoint(last_checkpoint_path)
    else:
        model.initialise_weights()

    trainer = QuantumTrainer(
        model,
        loss_function=loss,
        epochs=EPOCHS,
        optimizer=SPSAOptimizer,
        optim_hyperparams={'a': 0.0053, 'c': 0.0185, 'A':0.01*EPOCHS},
        evaluate_functions=eval_metrics,
        evaluate_on_train=True,
        verbose = 'text',
        seed=SEED
        )

    train_dataset = Dataset(training_circuits_l, training_data_labels_l)
    val_dataset = Dataset(validation_circuits_l, validation_data_labels_l, shuffle=False)
    trainer.fit(train_dataset, val_dataset, evaluation_step=1, logging_step=100)

    # Checkpoints keep the index of the iteration that produced them.
    checkpoint_path = checkpoint_dir + "//checkpoint_" + str(run) + ".chk"
    model.make_checkpoint(checkpoint_path)
    last_checkpoint_path = checkpoint_path

    visualize_results(model, trainer, test_circuits_l, test_data_labels_l, acc)
    syms = current_syms

## Select the model

Select which model to use: `TketModel` or `NumpyModel`. `NumpyModel` can use JAX, which speeds up training.

## Define loss function and evaluation metrics

## Initialize the trainer and the datasets

## Train the model and visualize