In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import numpy as np

BATCH_SIZE = 50
EPOCHS = 120
SEED = 2

In [3]:
def read_data(filename):
    labels, sentences1, sentences2 = [], [], []
    with open(filename) as f:
        for line in f:
            split_list = line.split(",")
            t = int(split_list[2])
            labels.append(t)
            sentences1.append(split_list[0])
            sentences2.append(split_list[1])
    return labels, sentences1, sentences2

all_labels, all_sentences1, all_sentences2 = read_data('MC1.TXT')
all_train_data = list(zip((list(zip(all_sentences1[0:80], all_sentences2[0:80]))), all_labels[0: 80]))

train_labels, train_data1, train_data2 = all_labels[0: 80], all_sentences1[0: 80], all_sentences2[0: 80]
test_labels, test_data1, test_data2 = all_labels[80: 100], all_sentences1[80: 100], all_sentences2[80: 100]
all_test_data =  list(zip(list(zip(all_sentences1[80:100], all_sentences2[80:100])), all_labels[80: 100]))
first_sentences = [x[0][0][:] for x in all_train_data]
second_sentences = [x[0][1][:] for x in all_train_data]

first_sentences_test = [x[0][0][:] for x in all_test_data]
second_sentences_test = [x[0][1][:] for x in all_test_data]


In [4]:
from lambeq import BobcatParser

parser = BobcatParser(verbose='text')

raw_train_sentences1_diagrams = parser.sentences2diagrams(first_sentences)
raw_train_sentences2_diagrams = parser.sentences2diagrams(second_sentences)
raw_test_sentences1_diagrams = parser.sentences2diagrams(first_sentences_test)
raw_test_sentences2_diagrams = parser.sentences2diagrams(second_sentences_test)

Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.
Tagging sentences.
Parsing tagged sentences.
Turning parse trees to diagrams.


In [5]:
from lambeq import remove_cups

train_s1_diagrams = [remove_cups(diagram) for diagram in raw_train_sentences1_diagrams]

test_s1_diagrams = [remove_cups(diagram) for diagram in raw_test_sentences1_diagrams]

train_s2_diagrams = [remove_cups(diagram) for diagram in raw_train_sentences2_diagrams]

test_s2_diagrams = [remove_cups(diagram) for diagram in raw_test_sentences2_diagrams]

In [6]:
from lambeq import AtomicType, IQPAnsatz

ansatz = IQPAnsatz({AtomicType.NOUN: 1, AtomicType.SENTENCE: 1},
                   n_layers=1, n_single_qubit_params=3)

train_s1_circuits = [ansatz(diagram) for diagram in train_s1_diagrams]

test_s1_circuits = [ansatz(diagram) for diagram in test_s1_diagrams]

train_s2_circuits = [ansatz(diagram) for diagram in train_s2_diagrams]

test_s2_circuits = [ansatz(diagram) for diagram in test_s2_diagrams]

sentence1_circuits = train_s1_circuits + test_s1_circuits
senetence2_circuits =  train_s2_circuits + test_s2_circuits
print(sentence1_circuits[17])

Ket(0, 0, 0) >> H @ Id(2) >> Id(1) @ H @ Id(1) >> Id(2) @ H >> CRz(creates__n.r@s@n.l_0) @ Id(1) >> Id(1) @ CRz(creates__n.r@s@n.l_1) >> Id(2) @ Rx(dish__n_2) >> Id(2) @ Rz(dish__n_1) >> Id(2) @ Rx(dish__n_0) >> Id(2) @ Bra(0) >> Rx(chef__n_2) @ Id(1) >> Rz(chef__n_1) @ Id(1) >> Rx(chef__n_0) @ Id(1) >> Bra(0) @ Id(1)


In [7]:
from pytket.circuit.display import render_circuit_jupyter

tket_circuit = sentence1_circuits[17].to_tk()

render_circuit_jupyter(tket_circuit)

In [8]:
from lambeq import NumpyModel, Model
class ClassificationModel(NumpyModel):
    
    def forward(self, Mytuple):
        a = np.array(Mytuple)
        y_hat = []
        output1 = NumpyModel.get_diagram_output(self, diagrams=a[:, 0])
        output2 = NumpyModel.get_diagram_output(self, diagrams=a[:, 1])
        final_o1 = np.nan_to_num(output1)
        final_o2 = np.nan_to_num(output2)
        results_tweaked = [np.abs(np.array(res) - 1e-9) for res in output1]
        assert len(results_tweaked[0]) == 2
        results_tweaked2 = [np.abs(np.array(res) - 1e-9) for res in output2]
        assert len(results_tweaked2[0]) == 2
        pred_labels = [np.round(res / np.sum(res)) for res in results_tweaked]
        pred_labels2 = [np.round(res / np.sum(res)) for res in results_tweaked2]
        print(pred_labels[0])
        print(pred_labels2[0])
        for idx in range(0, len(pred_labels)):
            if pred_labels[idx][0] == pred_labels2[idx][0] and pred_labels[idx][1] == pred_labels[idx][1]:
                y_hat.append(1)
                print('matched!')
            else:
                y_hat.append(0)
                print('unmatched!')
        return np.array(y_hat)

In [9]:

train_circuits2 = list(zip(train_s1_circuits , train_s2_circuits))

test_circuits2 = list(zip(test_s1_circuits , test_s2_circuits))
all_circuits = train_circuits2 + test_circuits2
sentence1_circuits= [ seq[0] for seq in all_circuits]
sentence2_circuits= [ seq[1] for seq in all_circuits]
all_circuits_lst = sentence1_circuits + sentence2_circuits

model = ClassificationModel.from_diagrams(all_circuits_lst, use_jit=True)


In [10]:
import numpy
epsilon=numpy.finfo('float').eps
loss = lambda y_hat, y: -np.sum(y * np.log(y_hat + epsilon)) / len(y)  # binary cross-entropy lossةخيثم
acc = lambda y_hat, y: np.sum(np.round(y_hat) == y) / len(y) / 2  # half due to double-counting

In [11]:
from lambeq import QuantumTrainer, SPSAOptimizer

trainer = QuantumTrainer(
    model,
    epochs=EPOCHS,
    loss_function=loss,
    optimizer=SPSAOptimizer,
    optim_hyperparams={'a': 0.05, 'c': 0.06, 'A':0.01*EPOCHS},
    evaluate_functions={'acc': acc},
    evaluate_on_train=True,
    verbose = 'text',
    seed=0
)

In [12]:
from lambeq import Dataset

train_dataset = Dataset(
            train_circuits2,
            train_labels,
            batch_size=BATCH_SIZE)

val_dataset = Dataset(test_circuits2, test_labels, shuffle=False)


In [13]:
trainer.fit(train_dataset, val_dataset, logging_step=12)



[0. 1.]
[0. 1.]
matched!
matched!
matched!
unmatched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
matched!
matched!
[0. 1.]
[0. 1.]
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
matched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!


Epoch 1:    train/loss: 8.5604   valid/loss: 5.4065   train/acc: 0.2313   valid/acc: 0.3000


[0. 1.]
[0. 1.]
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
matched!
matched!
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
[1. 0.]
[1. 0.]
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
[1. 0.]
[1. 0.]
matched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
matched!
unmatched!
match

Epoch 12:   train/loss: 7.6593   valid/loss: 5.4065   train/acc: 0.2562   valid/acc: 0.4000


[1. 0.]
[0. 1.]
unmatched!
matched!
matched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
matched!
unmatched!
unmatched!
matched!
matched!
matched!
[1. 0.]
[0. 1.]
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
[0. 1.]
[0. 1.]
matched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
unmatched!


Epoch 24:   train/loss: 10.3626   valid/loss: 1.8022   train/acc: 0.2188   valid/acc: 0.3500


[1. 0.]
[1. 0.]
matched!
matched!
matched!
unmatched!
unmatched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
matched!
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
[0. 1.]
[1. 0.]
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
matched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
[0. 1.]
[1. 0.]
unmatched!
unmatched!
unmatched!
unmatched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
match

Epoch 36:   train/loss: 8.3351   valid/loss: 9.0109   train/acc: 0.2687   valid/acc: 0.1500


[0. 1.]
[0. 1.]
matched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
unmatched!
[0. 1.]
[1. 0.]
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
matched!
matched!
matched!
matched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
[0. 1.]
[1. 0.]
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
mat

Epoch 48:   train/loss: 9.0109   valid/loss: 5.4065   train/acc: 0.2812   valid/acc: 0.3000


[1. 0.]
[0. 1.]
unmatched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
[1. 0.]
[1. 0.]
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
unmat

Epoch 60:   train/loss: 9.9120   valid/loss: 9.0109   train/acc: 0.2313   valid/acc: 0.2500


[1. 0.]
[1. 0.]
matched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
[0. 1.]
[1. 0.]
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
matched!
[0. 1.]
[1. 0.]
unmatched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
matched!
matched!
unmatched!
ma

Epoch 72:   train/loss: 7.6593   valid/loss: 7.2087   train/acc: 0.2750   valid/acc: 0.3000


[0. 1.]
[1. 0.]
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
[1. 0.]
[0. 1.]
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
matched!
matched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
matched!
[1. 0.]
[0. 1.]
unmatched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
matched!
matched!

Epoch 84:   train/loss: 6.9835   valid/loss: 5.4065   train/acc: 0.3438   valid/acc: 0.3000


[1. 0.]
[0. 1.]
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
matched!
unmatched!
unmatched!
[0. 1.]
[0. 1.]
matched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
unmatched!
[1. 0.]
[1. 0.]
matched!
matched!
matched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
matched!
matched!
matched!
matched!
unm

Epoch 96:   train/loss: 10.8131   valid/loss: 1.8022   train/acc: 0.2250   valid/acc: 0.3250


[0. 1.]
[0. 1.]
matched!
matched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
unmatched!
matched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
[1. 0.]
[1. 0.]
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
matched!
matched!
matched!
matched!
matched!
[0. 1.]
[1. 0.]
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatc

Epoch 108:  train/loss: 8.7856   valid/loss: 7.2087   train/acc: 0.2250   valid/acc: 0.2250


[0. 1.]
[0. 1.]
matched!
matched!
matched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
unmatched!
matched!
matched!
matched!
unmatched!
unmatched!
matched!
matched!
[1. 0.]
[0. 1.]
unmatched!
unmatched!
matched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
matched!
matched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
unmatched!
matched!
unmatched!
matched!
matched!
unmatched!
unmatched!
matched!
unmatched!
[0. 1.]
[0. 1.]
matched!
unmatched!
matched!
unmatched!
matched!
unmatched!
unmatched!
unmatched!
matched!
matched!
matched!
matched!
unmatched!
unmatched!
unmatched!
matched!
unmatched!
unm

Epoch 120:  train/loss: 9.0109   valid/loss: 3.6044   train/acc: 0.2188   valid/acc: 0.3500

Training completed!
