# Experiments on VQC
### Extra info and data (Figure 4)

In [None]:
import warnings
warnings.filterwarnings("ignore")

import numpy as np

from qiskit import QuantumCircuit, transpile
from qiskit.providers.aer.backends import AerSimulator

from qclib.machine_learning.datasets import digits
from qclib.state_preparation import BaaLowRankInitialize
from qclib.state_preparation.util.baa import adaptive_approximation

seed = 42

In [None]:
# Load the two-class (0 and 1) digits dataset with a fixed random seed.

# Side length used to reshape a flat sample into a square image for plotting.
matrix_dim = 8

sample_total, training_input, test_input, class_labels = digits.load(
    classes=[0, 1],
    training_size=40,
    test_size=10,
    random_seed=seed,
)

In [None]:
# Shared BAA settings. They are passed as 'strategy' and 'use_low_rank' to
# BaaLowRankInitialize (circuit statistics) and adaptive_approximation (plots),
# so both analyses run with the same configuration.
baa_strategy = 'brute_force'
baa_low_rank = True

In [None]:
# Dataset numerical analysis.
# Average number of CNOTs, depth, and fidelity.

def _fidelity(input_state, transpiled_circuit):
    """Return the fidelity between ``input_state`` and the circuit's output.

    The circuit is simulated with ``AerSimulator`` and the resulting
    statevector ``ket`` is compared with the target as
    ``|<input_state|ket>|**2``.

    Args:
        input_state: Target amplitudes (anything ``np.conj`` accepts).
        transpiled_circuit: Circuit to simulate. It is left unmodified.

    Returns:
        The squared magnitude of the overlap between target and output.
    """
    # Simulate a copy: save_statevector() adds an instruction to the circuit,
    # and the caller's circuit should not change as a side effect of
    # measuring its fidelity.
    circuit = transpiled_circuit.copy()
    circuit.save_statevector()

    backend = AerSimulator()
    ket = backend.run(circuit).result().get_statevector()
    bra = np.conj(input_state)

    return np.abs(bra.dot(ket))**2

def _counts(input_state, l, result, method='baa'):
    """Measure one state-preparation circuit and record its statistics.

    A circuit preparing ``input_state`` is built, transpiled to the
    ``u1``/``u2``/``u3``/``cx`` basis at optimization level 3, and the row
    ``[l, n_cx, depth, fidelity]`` is appended to ``result``.

    Args:
        input_state: Amplitudes of the state to prepare.
        l: Value passed as ``max_fidelity_loss`` to ``BaaLowRankInitialize``;
            it is also stored as the first entry of the appended row.
        result: List that receives the new row (modified in place).
        method: ``'qiskit'`` uses ``QuantumCircuit.initialize``; any other
            value uses ``BaaLowRankInitialize``.
    """
    if method != 'qiskit':
        opt_params = {
            'max_fidelity_loss': l,
            'strategy': baa_strategy,
            'use_low_rank': baa_low_rank,
        }
        circuit = BaaLowRankInitialize(input_state, opt_params=opt_params).definition
    else:
        num_qubits = int(np.log2(len(input_state)))
        circuit = QuantumCircuit(num_qubits)
        circuit.initialize(input_state)

    compiled = transpile(
        circuit,
        basis_gates=['u1', 'u2', 'u3', 'cx'],
        optimization_level=3,
    )

    # Gate statistics are collected before the simulation step.
    cnot_total = compiled.count_ops().get('cx', 0)
    circuit_depth = compiled.depth()

    state_fidelity = _fidelity(input_state, compiled)

    result.append([l, cnot_total, circuit_depth, state_fidelity])
        
def _grid_search(input_state, fidelity_loss=None):
    result = []
    if fidelity_loss is None:
        fidelity_loss = [i/10 for i in range(11)]

    n = int(np.log2(len(input_state)))
    for l in fidelity_loss:
        _counts(input_state, l=l, result=result, method='baa')
    
    return result

def _print_results(results):
    fidelity_loss = [r[0] for r in results[0]]

    n_cx = {}
    n_dp = {}
    fidelity = {}
    for l in fidelity_loss:
        n_cx[l] = []
        n_dp[l] = []
        fidelity[l] = []

    for result in results:
        for l in fidelity_loss:
            n_cx[l].extend([r[1] for r in result if r[0]==l])
            n_dp[l].extend([r[2] for r in result if r[0]==l])
            fidelity[l].extend([r[3] for r in result if r[0]==l])
    
    print('AVG:')
    for l in fidelity_loss:
        avg_cx = sum(n_cx[l]) / len(n_cx[l])
        avg_dp = sum(n_dp[l]) / len(n_dp[l])
        avg_fidelity = sum(fidelity[l]) / len(fidelity[l])
        print('l={3}\tCNOTs={0}\tdepth={1}\tfidelity={2}'.format(avg_cx, avg_dp, avg_fidelity, l))
    
    print('STD:')
    for l in fidelity_loss:
        std_cx = np.std(n_cx[l])
        std_dp = np.std(n_dp[l])
        std_fidelity = np.std(fidelity[l])
        print('l={3}\tCNOTs={0}\tdepth={1}\tfidelity={2}'.format(std_cx, std_dp, std_fidelity, l))

# Run the grid search on every training and test sample, then write the
# aggregated statistics to a text file.
from contextlib import redirect_stdout

results = []
for set_name, dataset in (('training set', training_input), ('test set', test_input)):
    print(set_name)
    for i in dataset:
        print(f'class {i}', end=' ')
        for input_state in dataset[i]:
            print('.', end='')  # one dot per processed sample
            results.append(_grid_search(input_state))
        print()

# redirect_stdout sends the printed report to the file and restores stdout
# afterwards, even if _print_results raises.
with open('save/digits_dataset_info.txt', 'w') as f, redirect_stdout(f):
    _print_results(results)

In [None]:
# Plots examples of the states that represent the digits.

def plot_digits(digits, labels, text):
    """Plot a row of digit images, save it as a PDF, and show it.

    Args:
        digits: Sequence of flat arrays. The first ``matrix_dim**2`` entries
            of each are reshaped to ``matrix_dim x matrix_dim`` and drawn.
            (Inside this function the name shadows the imported ``digits``
            module.)
        labels: Titles for the images, paired with ``digits`` by position.
        text: Suffix of the output file, ``save/digits_{text}.pdf``.
    """
    import matplotlib.pyplot as plt

    # squeeze=False keeps ``axes`` a 2-D array even for a single column, so
    # plotting one digit works the same as plotting several.
    _, axes = plt.subplots(nrows=1, ncols=len(digits), figsize=(10, 3), squeeze=False)
    for ax, digit, label in zip(axes[0], digits, labels):
        ax.set_axis_off()
        image = digit[:int(matrix_dim**2)].reshape(matrix_dim, matrix_dim)
        ax.imshow(image, cmap=plt.cm.gray_r, interpolation='none')
        ax.set_title(label)

    plt.savefig(f'save/digits_{text}.pdf')
    plt.show()

# For the first three training samples of each class, plot the approximated
# state obtained for every fidelity-loss bound.
_l_max_values = [i/10 for i in range(11)]  # 0.0, 0.1, ..., 1.0

for sample in range(3):
    print(f'sample: {sample}')
    # One original state per class for this sample index.
    original = {i: training_input[i][sample] for i in training_input}

    for digit in training_input:
        _digits = []
        _labels = []
        for l_max in _l_max_values:
            node = adaptive_approximation(original[digit], max_fidelity_loss=l_max, strategy=baa_strategy, use_low_rank=baa_low_rank)
            # Squared amplitudes are what gets drawn as pixel intensities.
            _digits.append(node.state_vector()**2)
            _labels.append(str(l_max))

        plot_digits(_digits, _labels, f'{digit}_{sample}')
