# Packages

In [None]:
!pip install cirq
!pip install git+https://github.com/qdevpsi3/quantum-nearest-classifier.git

# Utils

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn.metrics import accuracy_score

In [None]:
def iris_experiment(model):
    """Fit ``model`` on the Iris dataset and return its error on that same data.

    The model is trained and scored on the full dataset; there is no
    held-out split, so the returned figure is a training-set error.

    Args:
        model: object exposing ``fit(X, y)`` and ``predict(X)``.

    Returns:
        Classification error in percent, i.e. ``100 * (1 - accuracy)``.
    """
    dataset = datasets.load_iris()
    features, targets = dataset.data, dataset.target

    # Train, then predict on the very same samples.
    model.fit(features, targets)
    predictions = model.predict(features)

    accuracy = accuracy_score(targets, predictions)
    return 100 * (1. - accuracy)


def iris_plot(c_error, q_errors, labels=('100', '500', '1000')):
    """Draw a grouped bar chart of quantum errors against a classical baseline.

    Args:
        c_error: classical classification error (percent), drawn as a
            dashed horizontal reference line.
        q_errors: two sequences of errors (percent); ``q_errors[0]`` is
            plotted as 'w/o mitigation' and ``q_errors[1]`` as 'mitigation'.
            Each must have one entry per label.
        labels: x-axis tick labels, one per group of bars. Defaults to the
            repetition counts used originally ('100', '500', '1000').
    """
    labels = list(labels)

    x = np.arange(len(labels))
    width = 0.35

    fig, ax = plt.subplots()
    # Offset the two series by half a bar width so they sit side by side
    # around each tick position.
    ax.bar(x - width / 2, q_errors[0], width, label='w/o mitigation')
    ax.bar(x + width / 2, q_errors[1], width, label='mitigation')
    # Draw on the explicit axes rather than pyplot's implicit current axes.
    ax.axhline(y=c_error, color='black', linestyle='--', label='classical')

    ax.set_ylabel('Classification error %')
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.legend(loc='lower right')

    fig.tight_layout()

    plt.show()

# Experiment

In [None]:
from sklearn.neighbors import NearestCentroid

from quantum_ncs.classifier import QuantumNearestCentroid

# Classical baseline and the quantum model under test.
c_model = NearestCentroid()
q_model = QuantumNearestCentroid(error_rate=0.05)

c_error = iris_experiment(c_model)

# q_errors[0]: error mitigation off; q_errors[1]: error mitigation on.
# Each row holds one error value per repetition count, in order.
q_errors = []
for use_mitigation in (False, True):
    q_model.error_mitigation = use_mitigation
    row = []
    for n_reps in (100, 500, 1000):
        q_model.repetitions = n_reps
        row.append(iris_experiment(q_model))
    q_errors.append(row)

iris_plot(c_error, q_errors)