In [None]:
import matplotlib.pyplot as plt
import numpy as np
import qiskit as qs
from qiskit.circuit.library import MCMTGate, ZGate, phase_estimation

In [None]:
from qiskit.transpiler import generate_preset_pass_manager
from qiskit_ibm_runtime import SamplerV2 as Sampler, EstimatorV2 as Estimator
from qiskit_ibm_runtime.fake_provider import FakeMelbourneV2
from qiskit.primitives import StatevectorSampler

# Execution backend. True: transpile with `pm` and sample on the FakeMelbourneV2 fake backend;
# False: sample the ideal circuit with the statevector primitive.
# NOTE(review): the default 2x2 instance needs 17 qubits (8 input + 8 auxiliary + 1 target);
# confirm the fake backend is large enough before enabling `hardware`.
hardware = False
SEED = 42  # shared seed so shot sampling is reproducible on either path

if hardware:
    fake_melbourne = FakeMelbourneV2()
    pm = generate_preset_pass_manager(backend=fake_melbourne, optimization_level=1)
    options = {"simulator": {"seed_simulator": SEED}}
    sampler = Sampler(mode=fake_melbourne, options=options)
else:
    sampler = StatevectorSampler(seed=SEED)

In [None]:
def grover_operator(oracle, n_input_qubits):
    """Return one Grover iteration: ``oracle`` followed by the diffusion operator.

    The diffusion acts only on the first ``n_input_qubits`` qubits of the oracle's
    register; all other qubits (auxiliaries, phase target) are left untouched.
    """
    width = oracle.num_qubits
    input_qubits = list(range(n_input_qubits))

    diffuser = qs.QuantumCircuit(width)
    # Map |s> (uniform superposition) to |0...0>, then to |1...1>.
    for qubit in input_qubits:
        diffuser.h(qubit)
        diffuser.x(qubit)
    # Phase-flip |1...1>: Z on the last input qubit controlled by all the others.
    diffuser.append(MCMTGate(ZGate(), n_input_qubits - 1, 1), input_qubits)
    # Undo the basis change.
    for qubit in input_qubits:
        diffuser.x(qubit)
        diffuser.h(qubit)

    return oracle.compose(diffuser, list(range(width)))

def initial_state(n_input_qubits, n_aux_qubits):
    """Return the Grover start circuit with ``n_input_qubits`` classical bits for readout.

    Layout: inputs, then auxiliaries, then one phase-kickback target (last qubit).
    Inputs are put in uniform superposition, auxiliaries stay |0>, and the target is
    prepared in |-> so the oracle's bit flip becomes a phase flip.
    """
    total_qubits = n_input_qubits + n_aux_qubits + 1
    target = total_qubits - 1

    circuit = qs.QuantumCircuit(total_qubits, n_input_qubits)
    for qubit in range(n_input_qubits):
        circuit.h(qubit)
    # |0> -> |1> -> |->
    circuit.x(target)
    circuit.h(target)
    return circuit

def opt_r(N, k):
    """Return the number of Grover iterations that maximises the success probability.

    With ``k`` marked items among ``N`` and theta = arcsin(sqrt(k / N)), the success
    probability after r iterations is sin^2((2r + 1) * theta), which peaks at
    r = pi / (4 * theta) - 1/2; that value is rounded to the nearest integer.

    Raises:
        ValueError: unless 0 < k <= N (theta would be zero or undefined otherwise).
    """
    if not 0 < k <= N:
        raise ValueError(f"need 0 < k <= N, got k={k}, N={N}")
    theta = np.arcsin(np.sqrt(k / N))
    r_float = np.pi / 4 / theta - 1 / 2
    return int(np.round(r_float))

In [None]:
# Size of the tessellation rectangle: x columns by y rows.
x = 2
y = 2

# Each piece is a list of [dx, dy] leg offsets relative to its anchor cell [0, 0].
pieces = [
    [[0, 0], [1, 0], [0, 1]],  # L-shaped tromino
    [[0, 0]],                  # monomino
]

In [None]:
def n_qubit(x):
    """Return ceil(log2(x)): the number of qubits needed to label ``x`` distinct values."""
    exponent = np.log2(x)
    return int(np.ceil(exponent))

def binary(k, bits):
    """Return the ``bits``-bit binary representation of ``k``, most significant bit first.

    The result is a NumPy integer array of length ``bits`` (empty when ``bits == 0``).

    Raises:
        ValueError: if ``k`` is negative or does not fit in ``bits`` bits.
    """
    if k < 0:
        raise ValueError("k must be non-negative")
    if k >= 2**bits:
        raise ValueError("k cannot be larger than 2^bits-1")
    # Extract bit (bits-1-i) for i = 0..bits-1; also correct for bits == 0.
    return np.array([(k >> (bits - 1 - i)) & 1 for i in range(bits)], dtype=int)

def check_site(site_x, site_y, n_enc_qubits_x, n_enc_qubits_y):
    """Return a circuit that flips a flag qubit iff an (x, y) register pair equals a site.

    Qubit layout: ``n_enc_qubits_x`` qubits holding x, then ``n_enc_qubits_y`` qubits
    holding y, then the flag qubit. Both registers are read little-endian (their first
    qubit is the least significant bit), the convention ``adder`` uses for the work
    registers this comparator is applied to.

    Raises whatever ``binary`` raises when a coordinate does not fit its register.
    """
    # binary() lists bits most-significant first; reverse so index i is the 2**i bit.
    x_bits = binary(site_x, n_enc_qubits_x)[::-1]
    y_bits = binary(site_y, n_enc_qubits_y)[::-1]
    n_reg = n_enc_qubits_x + n_enc_qubits_y

    # Qubits that must read 0 are X-conjugated so the multi-controlled X also fires on them.
    zero_qubits = [i for i in range(n_enc_qubits_x) if x_bits[i] == 0]
    zero_qubits += [n_enc_qubits_x + i for i in range(n_enc_qubits_y) if y_bits[i] == 0]

    qc = qs.QuantumCircuit(n_reg + 1)
    for q in zero_qubits:
        qc.x(q)
    if n_reg:
        qc.mcx(list(range(n_reg)), n_reg)
    else:
        # Zero-width registers can only hold the value 0, which always matches.
        qc.x(n_reg)
    for q in zero_qubits:
        qc.x(q)

    return qc

def adder(a, n_target_qubits, n_control_qubits):
    """Return a circuit that adds the integer constant ``a`` to a target register.

    Qubit layout: ``n_control_qubits`` control qubits followed by ``n_target_qubits``
    target qubits. The target register is little-endian (qubit ``n_control_qubits + j``
    has weight 2**j) and arithmetic is modulo 2**n_target_qubits. Every gate is also
    controlled on all control qubits, so the addition only acts when they are all |1>.
    A non-positive ``a`` subtracts ``abs(a)``; ``a == 0`` yields an empty circuit.

    The constant is applied one set bit at a time: adding (or subtracting) 2**k ripples
    a carry (or borrow) from bit k upwards.
    """
    width = n_target_qubits + n_control_qubits
    controls = list(range(n_control_qubits))
    subtract = not a > 0

    def power_of_two_step(k):
        # Walk from the most significant target qubit down to bit k, so each bit is
        # updated before the lower bits that decide its carry/borrow are changed.
        step = qs.QuantumCircuit(width)
        for target in reversed(range(n_control_qubits + k, width)):
            lower = list(range(n_control_qubits + k, target))
            if not (controls or lower):
                # Lowest affected bit with no controls at all: it always toggles.
                step.x(target)
                continue
            if subtract:
                # Borrow condition "all lower bits are 0" -> "all lower bits are 1".
                for q in lower:
                    step.x(q)
            step.mcx(controls + lower, target)
            if subtract:
                for q in lower:
                    step.x(q)
        return step

    circuit = qs.QuantumCircuit(width)
    for k, digit in enumerate(reversed(bin(abs(a))[2:])):
        if digit == "1":
            circuit = circuit.compose(power_of_two_step(k), list(range(width)))
    return circuit

def compute_occupied_site(piece_leg, x, y, max_ext_x, max_ext_y):
    """Return a circuit writing the cell covered by one leg of a placed piece into a work register.

    Qubit layout: anchor x (n_qubit(x)), anchor y (n_qubit(y)), two rotation qubits
    (r0, r1), work x (n_qubit(x + 2*max_ext_x)), work y (n_qubit(y + 2*max_ext_y)).
    All registers are little-endian, as in ``adder``. Starting from a zero work register,
    the circuit leaves ``(anchor_x + max_ext_x + dx, anchor_y + max_ext_y + dy)`` in it,
    where ``(dx, dy)`` is ``piece_leg = (l0, l1)`` rotated according to (r0, r1):

        (0, 0) -> ( l0,  l1)    (0, 1) -> ( l1, -l0)
        (1, 0) -> (-l1,  l0)    (1, 1) -> (-l0, -l1)

    The ``max_ext`` offsets keep the result non-negative provided they bound every
    rotated displacement (i.e. both are at least the largest |component| of any leg).
    A leg equal to (0, 0) is rotation-invariant, so only the anchor is copied.
    """
    n_enc_qubits_x = n_qubit(x)
    n_enc_qubits_y = n_qubit(y)
    n_enc_qubits_ext_x = n_qubit(x + 2*max_ext_x)
    n_enc_qubits_ext_y = n_qubit(y + 2*max_ext_y)

    rot0 = n_enc_qubits_x + n_enc_qubits_y  # rotation qubit r0
    rot1 = rot0 + 1                         # rotation qubit r1
    work_x_start = rot1 + 1
    work_y_start = work_x_start + n_enc_qubits_ext_x
    work_x = list(range(work_x_start, work_y_start))
    work_y = list(range(work_y_start, work_y_start + n_enc_qubits_ext_y))

    qc = qs.QuantumCircuit(work_y_start + n_enc_qubits_ext_y)

    # Copy the anchor coordinates into the work registers (expected to start in |0>).
    for i in range(n_enc_qubits_x):
        qc.cx(i, work_x[i])
    for i in range(n_enc_qubits_y):
        qc.cx(n_enc_qubits_x + i, work_y[i])

    # Offset the work coordinates so that negative leg displacements cannot wrap below 0.
    # QuantumCircuit.compose returns a new circuit unless inplace=True.
    qc.compose(adder(max_ext_x, n_enc_qubits_ext_x, 0), work_x, inplace=True)
    qc.compose(adder(max_ext_y, n_enc_qubits_ext_y, 0), work_y, inplace=True)

    def add_x(a):
        # Add a to the work x coordinate, controlled on both rotation qubits being |1>.
        if a != 0:
            qc.compose(adder(a, n_enc_qubits_ext_x, 2), [rot0, rot1] + work_x, inplace=True)

    def add_y(a):
        # Add a to the work y coordinate, controlled on both rotation qubits being |1>.
        if a != 0:
            qc.compose(adder(a, n_enc_qubits_ext_y, 2), [rot0, rot1] + work_y, inplace=True)

    if tuple(piece_leg) != (0, 0):
        l0, l1 = piece_leg[0], piece_leg[1]

        # (r0, r1) = (0, 0): X-conjugate both rotation qubits so the controls fire on 00.
        qc.x(rot0)
        qc.x(rot1)
        add_x(l0)
        add_y(l1)
        qc.x(rot0)
        qc.x(rot1)

        # (r0, r1) = (0, 1)
        qc.x(rot0)
        add_x(l1)
        add_y(-l0)
        qc.x(rot0)

        # (r0, r1) = (1, 0)
        qc.x(rot1)
        add_x(-l1)
        add_y(l0)
        qc.x(rot1)

        # (r0, r1) = (1, 1)
        add_x(-l0)
        add_y(-l1)

    return qc

def tessellation_oracle(x, y, pieces):
    """Build an oracle that flips a target qubit for piece placements tiling an x-by-y board.

    Input qubits: one slot per piece, in order, each holding the anchor x
    (n_qubit(x) qubits), anchor y (n_qubit(y) qubits) and two rotation qubits,
    little-endian. Auxiliary qubits: a work register for the leg currently evaluated
    (x part, then y part) followed by one parity flag per board site (site index
    s = sy*x + sx). The last qubit (index n_input_qubits + n_aux_qubits) is the target.
    It is flipped iff every site flag is 1 after XOR-ing in every leg of every piece.
    The flag computation is then undone (qc followed by qc.inverse()).

    Flags record coverage parity, so "all flags set" is equivalent to an exact tiling
    when the total number of legs over all pieces equals x*y.

    Returns:
        (oracle, n_input_qubits, n_aux_qubits)
    """
    n_enc_qubits_x = n_qubit(x)
    n_enc_qubits_y = n_qubit(y)

    # Rotations move a leg's y-offset onto the x axis (and vice versa), so both axes must
    # be padded for the largest offset component found on either axis.
    max_leg = max((max(abs(leg[0]), abs(leg[1])) for piece in pieces for leg in piece), default=0)
    # Anchors can also take the off-board values x .. 2**n_enc_qubits_x - 1. The extra
    # padding keeps anchor + offset + leg inside the work register so it never wraps
    # around onto a board site (no extra padding when x is a power of two).
    max_ext_x = max_leg + (2**n_enc_qubits_x - x)
    max_ext_y = max_leg + (2**n_enc_qubits_y - y)
    n_enc_qubits_ext_x = n_qubit(x + 2*max_ext_x)
    n_enc_qubits_ext_y = n_qubit(y + 2*max_ext_y)

    slot_size = n_enc_qubits_x + n_enc_qubits_y + 2  # anchor x, anchor y, 2 rotation qubits
    n_input_qubits = slot_size * len(pieces)
    n_work_qubits = n_enc_qubits_ext_x + n_enc_qubits_ext_y
    n_aux_qubits = n_work_qubits + x*y
    n = n_input_qubits + n_aux_qubits

    work = list(range(n_input_qubits, n_input_qubits + n_work_qubits))
    flags = list(range(n_input_qubits + n_work_qubits, n))

    # The site comparators do not depend on the piece or leg, so build them once
    # (same sx-outer / sy-inner order as the gates are applied).
    site_checks = [
        (sy*x + sx,
         check_site(sx + max_ext_x, sy + max_ext_y, n_enc_qubits_ext_x, n_enc_qubits_ext_y))
        for sx in range(x)
        for sy in range(y)
    ]

    qc = qs.QuantumCircuit(n)
    for i, piece in enumerate(pieces):
        slot_qubits = list(range(slot_size*i, slot_size*(i+1))) + work
        for leg in piece:
            occupied = compute_occupied_site(leg, x, y, max_ext_x, max_ext_y)
            qc = qc.compose(occupied, slot_qubits)
            # XOR "this leg covers site s" into flag s.
            for s, check in site_checks:
                qc = qc.compose(check, work + [flags[s]])
            # Uncompute the work register before the next leg.
            qc = qc.compose(occupied.inverse(), slot_qubits)

    oracle = qs.QuantumCircuit(n + 1)
    oracle = oracle.compose(qc, list(range(n)))
    # Flip the target iff every site is covered an odd number of times.
    oracle.mcx(flags, n)
    oracle = oracle.compose(qc.inverse(), list(range(n)))

    return oracle, n_input_qubits, n_aux_qubits

In [None]:
oracle, n_input_qubits, n_aux_qubits = tessellation_oracle(x, y, pieces)

# Number k of marked inputs, used to pick the Grover iteration count. For the 2x2 board
# with the L-tromino and monomino above: 4 tromino placements fit on the board, each leaves
# exactly one cell for the monomino, and the monomino's two rotation qubits are free (a
# [0, 0] leg ignores rotation), giving 4 * 4 = 16. Recompute if x, y or pieces change.
N_SOLUTIONS = 16

qc = initial_state(n_input_qubits, n_aux_qubits)

N = 2**n_input_qubits
r = opt_r(N, N_SOLUTIONS)

gr_op = grover_operator(oracle, n_input_qubits)

for _ in range(r):
    qc = qc.compose(gr_op, list(range(n_input_qubits+n_aux_qubits+1)))

qc.measure(list(range(n_input_qubits)), list(range(n_input_qubits)))

In [None]:
from qiskit.quantum_info import Statevector
import copy

# Classically verify the oracle. For every input i, prepare |i> (input qubit j <- bit j of i)
# with all other qubits in |0>, apply the oracle, and read the probability that the target
# (last qubit) was flipped. Inputs must be prepared BEFORE the oracle.
target_qubit = oracle.num_qubits - 1
marked_inputs = []
for i in range(2**n_input_qubits):
    probe = qs.QuantumCircuit(oracle.num_qubits)
    for j in range(n_input_qubits):
        if (i >> j) & 1:
            probe.x(j)
    probe = probe.compose(oracle)
    p_marked = Statevector.from_instruction(probe).probabilities([target_qubit])[1]
    if p_marked > 0.5:
        marked_inputs.append(i)

# Bit strings use Qiskit's order: input qubit 0 is the rightmost character.
print(f"{len(marked_inputs)} marked inputs:",
      [format(i, f"0{n_input_qubits}b") for i in marked_inputs])

In [None]:
shots = 2**15
# Fake-backend runs need the circuit transpiled to the device first.
circuit_to_run = pm.run(qc) if hardware else qc
job = sampler.run([circuit_to_run], shots=shots)
print(job.result()[0].data.c.get_counts())

In [None]:
op_counts = dict(qc.count_ops())  # gate name -> number of occurrences
print(op_counts)

In [None]:
# Quantum-counting sketch: estimate the eigenphase of the Grover operator, which encodes
# the number of marked inputs, using n_est evaluation qubits.
gr_op = grover_operator(oracle, n_input_qubits)

n_est = 1  # evaluation qubits; more give a finer phase estimate

n_work = gr_op.num_qubits  # input + auxiliary + phase-kickback target qubits
qc = qs.QuantumCircuit(n_est + n_work)

# phase_estimation puts its evaluation register first, so the work register starts at n_est.
# Prepare the same state as initial_state(): inputs in uniform superposition, target in |->.
for q in range(n_input_qubits):
    qc.h(n_est + q)
qc.x(n_est + n_work - 1)
qc.h(n_est + n_work - 1)

# compose() returns a new circuit unless inplace=True.
qc.compose(phase_estimation(n_est, gr_op), inplace=True)

qc.measure_all()

In [None]:
#shots = 2048
#if hardware:
#    job = sampler.run([pm.run(qc)], shots=shots)
#else:
#    job = sampler.run([qc], shots=shots)
#result = job.result()[0].data.meas.get_counts()