# LRSP: CNOT count, depth and fidelity
### Section 2.1 of the Supplementary Information (Fig. 3)

In [None]:
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import random

from qiskit import transpile
from qiskit.providers.aer.backends import AerSimulator

from qclib.state_preparation.schmidt import initialize, cnot_count as cnots

In [None]:
def _fidelity(input_state, transpiled_circuit):
    """Return |<input_state|psi>|^2, where psi is the simulated output of the circuit.

    Args:
        input_state: target amplitudes (presumably normalized and in the same
            qubit ordering as the simulator's statevector -- TODO confirm).
        transpiled_circuit: circuit to simulate. It is NOT modified.

    Returns:
        The fidelity as a non-negative float.
    """
    # Work on a copy: save_statevector() appends an instruction. Doing that on
    # the caller's circuit would change its op counts/depth, and a repeated
    # call would add a second save instruction.
    circuit = transpiled_circuit.copy()
    circuit.save_statevector()

    backend = AerSimulator()
    ket = np.asarray(backend.run(circuit).result().get_statevector())

    # np.vdot conjugates its first argument, giving <input_state|ket>.
    return np.abs(np.vdot(input_state, ket))**2

def _counts(input_state, r=0, result=None, plot=False, partition=None):
    """Build, transpile and evaluate a low-rank state-preparation circuit.

    Prints the CNOT count, depth and fidelity of the transpiled circuit,
    together with qclib's estimated CNOT count.

    Args:
        input_state: state vector to prepare.
        r: value forwarded to qclib as ``low_rank``.
        result: optional list. If given, ``[log2(r), n_cx, depth, fidelity]``
            is appended. The first entry is ``None`` when ``r <= 0``.
        plot: if True, ``_plot_density_matrix`` is called on the
            untranspiled circuit.
        partition: forwarded to ``initialize``.
    """
    circuit = initialize(input_state, low_rank=r, partition=partition)
    transpiled_circuit = transpile(circuit, basis_gates=['u1', 'u2', 'u3', 'cx'], optimization_level=3)

    if plot:
        # NOTE(review): _plot_density_matrix is not defined in this notebook;
        # plot=True raises NameError unless it is defined elsewhere.
        _plot_density_matrix(circuit)

    # count_ops() only contains gates that actually occur, so default to 0.
    n_cx = transpiled_circuit.count_ops().get('cx', 0)
    n_dp = transpiled_circuit.depth()

    fidelity = _fidelity(input_state, transpiled_circuit)

    if result is not None:
        # log2(0) is -inf and int(-inf) raises OverflowError, so a
        # non-positive rank (e.g. the default r=0) is recorded as None.
        log_rank = int(np.log2(r)) if r > 0 else None
        result.append([log_rank, n_cx, n_dp, fidelity])

    # NOTE(review): `partition` is not forwarded to the estimate; confirm
    # whether cnot_count should receive it as well.
    est_cx = cnots(input_state, low_rank=r, method='estimate')

    print('\tCNOTs = {0}\tdepth = {1}\tfidelity={2}\test.cx={3}'.format(n_cx, n_dp, fidelity, est_cx))

def _grid_search(input_state, plot=False):
    """Run ``_counts`` for every rank 2**m, from m = n//2 down to m = 0.

    ``n`` is log2 of the state length (the number of qubits).

    Returns:
        The list of rows appended by ``_counts``, highest rank first.
    """
    num_qubits = int(np.log2(len(input_state)))
    rows = []

    for exponent in reversed(range(num_qubits // 2 + 1)):
        rank = 2**exponent
        print('rank = {0}:'.format(rank), end='')
        _counts(input_state, r=rank, result=rows, plot=plot)

    return rows

# Random dense states

In [None]:
# Fixed seed so the random states (and thus the figures) are reproducible.
rnd = np.random.RandomState(42)

min_n = 11
max_n = 14

result = {}
offset = 0.001
for n in range(min_n, max_n + 1):
    N = 2**n
    # Real and imaginary parts are drawn uniformly from [0.001, 1.0).
    # The real part is drawn first, matching the original draw order.
    real_part = (1.0 - offset) * rnd.rand(N) + offset
    imag_part = (1.0 - offset) * rnd.rand(N) + offset
    input_state = real_part + imag_part * 1j
    input_state = input_state / np.linalg.norm(input_state)

    print('\nn = {0}:'.format(n))
    result[n] = _grid_search(input_state)


# Plot dense results

In [None]:
def plot1(n, r, fidelity):
    """Plot fidelity against ``r`` and save it as ``save/fidelity<n>qubits.pdf``.

    Args:
        n: number of qubits (used only in the output file name).
        r: x-axis values. ``plot_fidelity`` passes the rank exponents here.
        fidelity: one fidelity value per entry of ``r``.
    """
    # libraries (local so the rest of the notebook runs without matplotlib)
    import os
    import matplotlib.pyplot as plt
    import pandas as pd
    from matplotlib.ticker import FormatStrFormatter

    plt.rcParams["font.family"] = "Times New Roman"

    df = pd.DataFrame({'x_values': r, 'fidelity': fidelity})

    # The column name 'fidelity' becomes the legend label.
    plt.plot('x_values', 'fidelity', 'C2', data=df, marker='s', markersize=8)

    plt.legend(loc='upper left', fontsize=20)
    # Place ticks only at the given x values (forces integer ticks).
    plt.xticks(r, fontsize=20)
    plt.yticks(fontsize=20)
    plt.gcf().set_dpi(500)

    # Two decimals on the y axis, fixed range.
    ax = plt.gca()
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.set_ylim([0.7, 1.05])

    # savefig does not create missing directories.
    os.makedirs('save', exist_ok=True)
    plt.savefig(f'save/fidelity{n}qubits.pdf')
    plt.show()

def plot2(n, r, n_cx, n_dp):
    """Plot CNOT count and depth against ``r``; save as ``save/hyperparameter<n>qubits.pdf``.

    Args:
        n: number of qubits (used only in the output file name).
        r: x-axis values. ``plot_result`` passes the rank exponents here.
        n_cx: CNOT count for each entry of ``r``.
        n_dp: circuit depth for each entry of ``r``.
    """
    # libraries (local so the rest of the notebook runs without matplotlib)
    import os
    import matplotlib.pyplot as plt
    import pandas as pd

    plt.rcParams["font.family"] = "Times New Roman"

    df = pd.DataFrame({'x_values': r, 'cnots': n_cx, 'depth': n_dp})

    # Column names 'cnots' and 'depth' become the legend labels.
    plt.plot('x_values', 'cnots', data=df, marker='o', markersize=8)
    plt.plot('x_values', 'depth', data=df, marker='^', markersize=8)

    plt.legend(fontsize=20)
    # Place ticks only at the given x values (forces integer ticks).
    plt.xticks(r, fontsize=20)
    plt.yticks(fontsize=20)
    plt.gcf().set_dpi(500)

    # savefig does not create missing directories.
    os.makedirs('save', exist_ok=True)
    plt.savefig(f'save/hyperparameter{n}qubits.pdf')
    plt.show()

def plot_result(n, result):
    """Plot CNOT count and depth versus rank exponent for the ``n``-qubit rows in ``result``."""
    rows = result[n]
    exponents = [row[0] for row in rows]
    cnot_counts = [row[1] for row in rows]
    depths = [row[2] for row in rows]

    plot2(n, exponents, cnot_counts, depths)

def plot_fidelity(n, result):
    """Plot fidelity versus rank exponent for the ``n``-qubit rows in ``result``."""
    rows = result[n]
    exponents = [row[0] for row in rows]
    fidelities = [row[3] for row in rows]

    plot1(n, exponents, fidelities)

# Produce both figures for every register size computed above.
qubit_counts = range(min_n, max_n + 1)
for n in qubit_counts:
    plot_result(n, result)
    plot_fidelity(n, result)
