# Important: do Cell > All Outputs > Clear before committing

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import qiskit
from qiskit import BasicAer, QuantumCircuit, ClassicalRegister, QuantumRegister, execute
from qiskit.aqua.circuits import FixedValueComparator

In [None]:
# Width, in qubits, of the x register that the circuits below operate on.
NUM_BITS = 3

In [None]:
# Registers shared by the circuit builders below.
# x: the NUM_BITS-qubit register holding the value being searched over.
x = QuantumRegister(NUM_BITS, name="x")
# work: NUM_BITS-1 ancilla qubits (registered under the name "a") used as
# scratch space by the oracles.
work = QuantumRegister(NUM_BITS-1, name="a")
# y: single qubit that the oracles XOR their result into.
y = QuantumRegister(1, name="y")
# x_meas: classical bits that x is measured into.
x_meas = ClassicalRegister(NUM_BITS, name="m")

# y ^= (x == k)
def u_omega_eq(k):
    """Build an oracle circuit that flips ``y`` exactly when ``x == k``.

    The circuit first applies X to every ``x[i]`` whose bit ``i`` in ``k`` is
    0, so that ``x == k`` becomes "all x qubits are 1".  It then ANDs the x
    qubits together with a chain of Toffoli gates through ``work``, XORs the
    result into ``y``, and finally undoes the chain and the X gates so that
    ``x`` and ``work`` are restored.

    Args:
        k: Integer to compare against; bit ``i`` of ``k`` is matched with
            ``x[i]``.

    Returns:
        A ``QuantumCircuit`` over the registers ``x``, ``work`` and ``y``.
    """
    qc = QuantumCircuit(x, work, y)

    # Positions where k has a 0 bit; these qubits are inverted around the
    # AND-chain so that a match shows up as all ones.
    zero_bits = [i for i in range(NUM_BITS) if not k & (1 << i)]

    qc.barrier()
    for i in zero_bits:
        qc.x(x[i])
    qc.barrier()

    # Compute: work[0] = x[0] & x[1], then work[i-1] = x[i] & work[i-2].
    qc.ccx(x[0], x[1], work[0])
    for i in range(2, NUM_BITS - 1):
        qc.ccx(x[i], work[i - 2], work[i - 1])

    # The chain above last writes work[NUM_BITS-3].  The work register holds
    # NUM_BITS-1 qubits, so work[-1] is never written and stays |0>; using it
    # as a control (as the previous version did) meant y was never flipped.
    qc.ccx(x[NUM_BITS - 1], work[NUM_BITS - 3], y)

    # Uncompute the chain in reverse order to return work to |0...0>.
    for i in range(NUM_BITS - 2, 1, -1):
        qc.ccx(x[i], work[i - 2], work[i - 1])
    qc.ccx(x[0], x[1], work[0])

    qc.barrier()
    for i in zero_bits:
        qc.x(x[i])
    qc.barrier()

    return qc

# y ^= (x < k)
def u_omega_cmp(k):
    """Build a comparison oracle circuit from Aqua's ``FixedValueComparator``.

    The comparator is configured with ``NUM_BITS`` state qubits, the fixed
    value ``k`` and ``geq=False``, and is built onto the qubits of ``x``
    followed by ``y[0]``, with ``work`` passed as its ancillas.

    Args:
        k: Fixed value handed to the comparator.

    Returns:
        A ``QuantumCircuit`` over the registers ``x``, ``work`` and ``y``.
    """
    oracle = QuantumCircuit(x, work, y)
    comparator = FixedValueComparator(NUM_BITS, value=k, geq=False)

    # All x qubits first, then the single y qubit as the last entry.
    qubits = [x[idx] for idx in range(NUM_BITS)]
    qubits.append(y[0])

    comparator.build(oracle, qubits, work)
    return oracle

# S = {0,2,4,6}
# N = 2**(NUM_BITS-1): the number of basis states in the search space S
# (4 when NUM_BITS = 3, matching the set listed above).
N = 1<<(NUM_BITS-1)
def a():
    """Build the state-preparation circuit on ``x``.

    Applies a Hadamard gate to every qubit of ``x`` except ``x[0]``, which is
    left untouched.

    Returns:
        A ``QuantumCircuit`` over the register ``x``.
    """
    prep = QuantumCircuit(x)
    for qubit_index in range(1, NUM_BITS):
        prep.h(x[qubit_index])
    return prep

def a_inv():
    """Return the inverse of the state-preparation circuit ``a()``.

    This simply returns a fresh ``a()`` circuit: ``a()`` consists only of
    Hadamard gates on distinct qubits, and a Hadamard is its own inverse.
    """
    inverse = a()
    return inverse

def diffusion():
    """Build the diffusion step used after the oracle in a Grover iteration.

    The circuit is the sequence ``a_inv()``, ``u_omega_eq(0)``, ``a()``:
    undo the state preparation, apply the "x == 0" oracle, then redo the
    state preparation.

    Returns:
        A ``QuantumCircuit`` starting from register ``x`` and extended by the
        appended sub-circuits.
    """
    circuit = QuantumCircuit(x)

    undo_prep = a_inv()
    circuit += undo_prep

    mark_zero = u_omega_eq(0)
    circuit += mark_zero

    redo_prep = a()
    circuit += redo_prep

    return circuit

def grover_iteration(u_omega):
    """Compose a single Grover iteration from an oracle circuit.

    Args:
        u_omega: Oracle circuit to apply first.

    Returns:
        A new ``QuantumCircuit`` consisting of ``u_omega``, a barrier, and
        then the circuit returned by ``diffusion()``.
    """
    step = QuantumCircuit()
    step += u_omega
    step.barrier()
    reflection = diffusion()
    step += reflection
    return step

In [None]:
# Estimate how many Grover iterations to run, assuming exactly one marked
# state among the N states of the search space.
# TODO: wrong for comparison
# NOTE(review): presumably "wrong" because the comparison oracle can mark
# more than one state, while this formula assumes a single one -- confirm.
# THETA: rotation angle per Grover iteration for one marked state out of N.
THETA = 2*np.arcsin(1/np.sqrt(N))
# Solve (2*ITERS + 1) * THETA/2 = pi/2 for ITERS.
ITERS = (np.pi/2 - THETA/2) / THETA

# Show the exact (fractional) value, then the rounded integer actually used.
print(ITERS)
ITERS = int(np.round(ITERS))
print(ITERS)


# Iteration counts to try: 0, ITERS, and every power of two up to ITERS.
num_grover_iters = {0, ITERS}
k = 1
while k <= ITERS:
    num_grover_iters.add(k)
    k *= 2
    
print(num_grover_iters)

def find_min(iters):
    """Work-in-progress Grover-based minimum search over the states of ``a()``.

    Samples one value from the ``a()`` superposition as the initial candidate
    minimum.  Then, ``iters`` times, builds one circuit per entry of the
    global ``num_grover_iters`` -- each applying that many Grover iterations
    with the ``u_omega_cmp(working_min)`` oracle -- runs the batch on the
    global ``simulator`` backend and prints the result object.

    NOTE(review): the candidate minimum is not yet updated from the results
    and nothing is returned; the search logic is still incomplete.

    Args:
        iters: Number of rounds of experiments to run.
    """
    # Draw the initial candidate by measuring the prepared state once.  (The
    # previous version called ``measure(x)``, which is not defined in this
    # notebook, and discarded the circuit it had just built.)
    sample = QuantumCircuit(x, x_meas)
    sample += a()
    sample.measure(x, x_meas)
    sampler = BasicAer.get_backend('qasm_simulator')
    counts = execute(sample, sampler, shots=1).result().get_counts(0)
    working_min = int(next(iter(counts)), 2)

    for _ in range(iters):
        experiments = []
        for num_iters in num_grover_iters:
            qc = QuantumCircuit(x, work, y)
            qc += a()
            # X followed by H takes y from |0> to |->, so that an oracle's
            # bit flip on y acts as a phase flip.
            qc.x(y)
            qc.h(y)
            # num_iters is an int, so it must be wrapped in range() to loop.
            for _step in range(num_iters):
                qc += grover_iteration(u_omega_cmp(working_min))
            experiments.append(qc)
        result = execute(experiments, simulator).result()
        print(result)

# Build a small demo circuit: prepare the search state, put y into |->,
# then apply the comparison oracle once with the fixed value 5.
qc = QuantumCircuit(x,work,y)
qc += a()
# X followed by H takes y from |0> to |->.
qc.x(y)
qc.h(y)
qc.barrier()
for _ in range(1):
    qc += u_omega_cmp(5)
    # Disabled experiment.  NOTE(review): this call passes two arguments,
    # but grover_iteration takes a single oracle circuit.
    #qc += grover_iteration(u_omega_eq, 0)


# Disabled experiment.  NOTE(review): no function named u_omega is defined
# in the visible notebook.
#qc += u_omega(4)

In [None]:

# Select the StatevectorSimulator from the Aer provider
simulator = BasicAer.get_backend('statevector_simulator')

# Execute and get the final statevector
result = execute(qc, simulator).result()
statevector = result.get_statevector(qc)

# qc is built over (x, work, y), and the low NUM_BITS bits of a basis-state
# index are the x qubits.  y therefore sits above all x and work bits.  The
# previous shift of NUM_BITS + NUM_BITS-2 was one short for a work register of
# NUM_BITS-1 qubits, so the printed "y" also contained the top work bit.
x_mask = (1 << NUM_BITS) - 1
y_shift = len(x) + len(work)

# Print every basis state with a non-zero amplitude as "amp |x>|y>".
for i, amplitude in enumerate(statevector):
    x_val = i & x_mask
    y_val = i >> y_shift
    if amplitude != 0.0:
        print("{:.3f} |{}>|{}>".format(amplitude, x_val, y_val))

In [None]:
# Value passed to the comparison oracle by the experiment cell below.
# NOTE(review): presumably a placeholder "larger than any x" bound -- confirm.
working_min = 10000000
# Shot-based simulator used for the measured experiments.
backend = BasicAer.get_backend('qasm_simulator')
# List the backends BasicAer offers.
print(BasicAer.backends())

In [None]:
# Build one measured circuit per Grover iteration count, run them on the
# qasm simulator, and print the measurement counts of the first circuit.
experiments = []
# Only the iteration count 0 is tried here.
for k in {0}:
    qc = QuantumCircuit(x,work,y,x_meas)
    qc += a()
    # X followed by H takes y from |0> to |->.
    qc.x(y)
    qc.h(y)
    qc.barrier()
    for _ in range(k):
        qc += grover_iteration(u_omega_cmp(working_min))
    # Read out the x register into the classical bits.
    qc.measure(x, x_meas)
    #res = qc.draw(output='mpl')
    experiments.append(qc)
    # NOTE(review): this assignment does not influence the iteration; the
    # for statement rebinds k from the set on each pass.
    k = max(k, 1) * 2
# Fixed seed so the 100 shots are reproducible.
result = execute(experiments, backend, shots=100, seed_simulator=1)
result = result.result()
print(result.backend_name)
counts = result.get_counts(0)
print(result.get_counts(0))

print(result.results[0].data)
print(counts)
def unpack_bin(counts):
    """Re-key a counts mapping from binary strings to integers.

    Args:
        counts: Mapping whose keys are strings of binary digits (e.g.
            ``"101"``) and whose values are the associated counts.

    Returns:
        A new ``dict`` with each key parsed as a base-2 integer and the
        values carried over unchanged.
    """
    return {int(bits, 2): hits for bits, hits in counts.items()}

# Convert the bit-string keys of the measurement counts to integers and show them.
counts = unpack_bin(counts)
print(counts)

# Disabled experiment: smallest x value with non-zero amplitude per result.
#print({min({i&((1<<NUM_BITS)-1) for i in range(len(res.data.statevector)) if res.data.statevector[i] != 0.0}) for res in result.results})
#res