# Solving a tessellation problem with the Grover algorithm
In this notebook we solve a tessellation problem with the Grover algorithm, which provides a quadratic speedup over classical unstructured search. Namely, we consider a rectangle of size $x\times y$ and ask how it can be tiled by $p$ polyomino pieces (geometric shapes composed of unit squares). To this end, we encode the $x$ and $y$ coordinates of each polyomino piece in a suitable number of qubits and its rotational state in two additional qubits. From these, a cover map is computed; the Grover oracle applies a sign flip when all squares are covered.

The speedup of Grover's algorithm is only quadratic; on the other hand, there exist powerful tree-based classical algorithms for solving the tessellation problem that are much faster than checking each possible distribution of all the pieces over the rectangle. Therefore, it remains unclear whether there is any instance of the tessellation problem for which the quantum algorithm provides an actual speedup over the best classical algorithm.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import qiskit as qs
from qiskit.circuit.library import MCMTGate, ZGate, phase_estimation

In [None]:
from qiskit.transpiler import generate_preset_pass_manager
from qiskit_ibm_runtime import SamplerV2 as Sampler, EstimatorV2 as Estimator
from qiskit_ibm_runtime.fake_provider import FakeMelbourneV2
from qiskit.primitives import StatevectorSampler

# Choose the execution backend: an exact statevector sampler by default, or the
# fake Melbourne device backend (circuits transpiled with a preset pass manager)
# when `hardware` is True.
hardware = False

if not hardware:
    sampler = StatevectorSampler()
else:
    fake_melbourne = FakeMelbourneV2()
    pm = generate_preset_pass_manager(backend=fake_melbourne, optimization_level=1)
    # fixed simulator seed so that sampled counts are reproducible
    options = {"simulator": {"seed_simulator": 42}}
    sampler = Sampler(mode=fake_melbourne, options=options)

Definition of the Grover operator, the initial state and the optimal number of Grover iterations:

In [None]:
def grover_operator(oracle, n_input_qubits):
    """Return one Grover iterate: `oracle` followed by the diffusion operator.

    The diffusion (H, X, multi-controlled Z, X, H) acts only on the first
    `n_input_qubits` qubits; it is the reflection about the uniform
    superposition up to a global phase. All remaining qubits of `oracle`
    (auxiliary register and phase ancilla) are untouched by it.
    """
    input_qubits = list(range(n_input_qubits))
    diffusion = qs.QuantumCircuit(oracle.num_qubits)

    for qubit in input_qubits:
        diffusion.h(qubit)
        diffusion.x(qubit)
    # Z controlled on all other input qubits: phase -1 on |1...1>
    diffusion.append(MCMTGate(ZGate(), n_input_qubits - 1, 1), input_qubits)
    for qubit in input_qubits:
        diffusion.x(qubit)
        diffusion.h(qubit)

    return oracle.compose(diffusion, list(range(oracle.num_qubits)))

def initial_state(n_input_qubits, n_aux_qubits):
    """Return the circuit preparing the Grover start state.

    The circuit has n_input_qubits + n_aux_qubits + 1 qubits and
    n_input_qubits classical bits. Input qubits are put into uniform
    superposition, auxiliary qubits stay |0>, and the last qubit (the
    phase ancilla) is prepared in |-> = H X |0>.
    """
    ancilla = n_input_qubits + n_aux_qubits
    circuit = qs.QuantumCircuit(ancilla + 1, n_input_qubits)
    for qubit in range(n_input_qubits):
        circuit.h(qubit)
    circuit.x(ancilla)
    circuit.h(ancilla)
    return circuit

def opt_r(N, k):
    """Return the optimal number of Grover iterations for k marked out of N states.

    Uses round(pi / (4 * theta) - 1/2) with theta = arcsin(sqrt(k / N)).
    """
    theta = np.arcsin(np.sqrt(k / N))
    iterations = np.pi / 4 / theta - 1 / 2
    return int(np.round(iterations))

Setup of the problem:

In [None]:
# Rectangle to be tiled: extension x in x direction, y in y direction.
x, y = 2, 2

# Polyomino pieces, each given by its legs: the [dx, dy] offsets of its unit
# squares relative to the piece position ([0, 0] is the square at the position).
pieces = [
    [[0, 0], [1, 0], [0, 1]],  # L-shaped tromino
    [[0, 0]],                  # monomino
]

In order to decide whether a configuration is a valid tessellation, we loop through the pieces and their legs. For each piece, we copy its position to an auxiliary register and then add the extension of the respective leg to this position. After this operation, we loop through all sites, where each site has its own qubit indicating whether it is occupied or not. For each site we check whether the position in the auxiliary register agrees with the coordinates of the site and, if so, flip the site qubit (in the end, all site qubits must be flipped in order to trigger the overall sign flip). Note that this implies that if a site is occupied twice, the corresponding qubit is flipped back, indicating that it is not occupied. However, if the total number of squares of all pieces equals the number of available squares in the rectangle, a configuration with overlapping pieces cannot be a solution anyway, so this is no problem. If the total number of squares of all pieces is chosen greater than the number of available squares, however, our circuit implements the somewhat peculiar problem of finding a complete tessellation where a site counts as occupied if an odd number of pieces occupies it, and as unoccupied if an even number of pieces occupies it.

In [None]:
def n_qubit(x):
    """Return ceil(log2(x)): the number of qubits needed to label x distinct values."""
    exponent = np.log2(x)
    return int(np.ceil(exponent))

def binary(k, bits):
    """Return the `bits`-bit binary representation of k, most significant bit first.

    The result is an integer numpy array of length `bits`. `bits == 0` is
    allowed for k == 0 and yields an empty array (the previous implementation
    crashed on np.zeros(-1) here).

    Raises:
        Exception: if k >= 2**bits.
        ValueError: if k is negative.
    """
    if k >= 2**bits:
        raise Exception("k cannot be larger than 2^bits-1")
    if k < 0:
        raise ValueError("k must be non-negative")
    # extract bits from the most significant (bits-1) down to bit 0
    return np.array([(k >> shift) & 1 for shift in range(bits - 1, -1, -1)], dtype=int)

def check_site(site_x, site_y, n_enc_qubits_x, n_enc_qubits_y):
    """Return a comparator circuit for one site.

    The circuit has n_enc_qubits_x + n_enc_qubits_y + 1 qubits. It flips the
    last qubit iff the first n_enc_qubits_x qubits encode `site_x` and the next
    n_enc_qubits_y qubits encode `site_y` (both least-significant bit first).
    """
    n_enc = n_enc_qubits_x + n_enc_qubits_y
    # least-significant-bit-first patterns of both coordinates, concatenated
    pattern = (list(reversed(binary(site_x, n_enc_qubits_x)))
               + list(reversed(binary(site_y, n_enc_qubits_y))))
    # qubits that must be |0> for a match; X turns them into |1> controls
    zero_qubits = [qubit for qubit, bit in enumerate(pattern) if bit == 0]

    circuit = qs.QuantumCircuit(n_enc + 1)
    for qubit in zero_qubits:
        circuit.x(qubit)
    circuit.mcx(list(range(n_enc)), n_enc)
    for qubit in zero_qubits:
        circuit.x(qubit)
    return circuit

#implement a simple circuit that adds a constant integer to a register
def adder(a, n_target_qubits, n_control_qubits):
    """Return a circuit that adds the constant `a` to a target register.

    Qubit layout: qubits [0, n_control_qubits) are controls, followed by
    n_target_qubits target qubits, least-significant bit first. The addition
    is modulo 2**n_target_qubits and is applied only when all controls are
    |1> (unconditionally if n_control_qubits == 0). A negative `a` subtracts
    |a|.

    |a| is decomposed into powers of two. Each 2**k is added (subtracted) by
    flipping target bits from the most significant one down to bit k, where
    bit j flips iff bits k..j-1 are all 1 (all 0 for subtraction).
    """
    circuit = qs.QuantumCircuit(n_target_qubits + n_control_qubits)
    controls = list(range(n_control_qubits))
    subtract = a <= 0
    # positions k of the set bits of |a|, lowest first
    shifts = [k for k, digit in enumerate(reversed(bin(abs(a))[2:])) if digit == "1"]

    for k in shifts:
        for j in range(n_target_qubits - 1, k - 1, -1):
            target = n_control_qubits + j
            lower = list(range(n_control_qubits + k, target))
            if not controls and not lower:
                circuit.x(target)
                continue
            if subtract:
                # condition on the lower bits being 0 instead of 1
                for qubit in lower:
                    circuit.x(qubit)
            circuit.mcx(controls + lower, target)
            if subtract:
                for qubit in lower:
                    circuit.x(qubit)
    return circuit

def _ext_register_width(extent, max_ext):
    """Return the number of qubits of one axis of the working register.

    Encoded piece positions along an axis of length `extent` range over
    0 .. 2**n_qubit(extent) - 1, which exceeds extent - 1 whenever `extent`
    is not a power of two. After the shift by `max_ext` and the addition of a
    leg offset in [-max_ext, max_ext], a coordinate lies in
    0 .. 2**n_qubit(extent) - 1 + 2*max_ext. The register must hold that whole
    range; otherwise the (modular) adder wraps around and a piece at an
    out-of-range position can spuriously mark a site of the rectangle.
    """
    return n_qubit(2**n_qubit(extent) + 2 * max_ext)


#copy position of piece to working register, add the position of its leg under consideration
def compute_occupied_site(piece_leg, x, y, max_ext):
    """Build the circuit writing the shifted site covered by one leg of a piece.

    Qubit layout of the returned circuit (nx = n_qubit(x), ny = n_qubit(y)):
      [0, nx)            piece x position, least-significant bit first
      [nx, nx+ny)        piece y position
      nx+ny, nx+ny+1     rotation qubits (r0, r1)
      next ext_x qubits  working register, x coordinate
      next ext_y qubits  working register, y coordinate
    where ext_x/ext_y come from _ext_register_width. Starting from a |0>
    working register, it ends up holding position + max_ext + rotated leg on
    each axis. The shift by max_ext keeps coordinates of legs pointing out of
    the rectangle non-negative.

    Rotation convention, (r0, r1) -> leg [dx, dy] becomes:
      (0, 0): ( dx,  dy)    (0, 1): ( dy, -dx)
      (1, 0): (-dy,  dx)    (1, 1): (-dx, -dy)
    """
    n_enc_qubits_x = n_qubit(x)
    n_enc_qubits_y = n_qubit(y)
    n_enc_qubits_ext_x = _ext_register_width(x, max_ext)
    n_enc_qubits_ext_y = _ext_register_width(y, max_ext)

    rot = [n_enc_qubits_x + n_enc_qubits_y, n_enc_qubits_x + n_enc_qubits_y + 1]
    work_x_start = n_enc_qubits_x + n_enc_qubits_y + 2
    work_y_start = work_x_start + n_enc_qubits_ext_x
    work_x = list(range(work_x_start, work_y_start))
    work_y = list(range(work_y_start, work_y_start + n_enc_qubits_ext_y))

    qc = qs.QuantumCircuit(work_y_start + n_enc_qubits_ext_y)

    #copy to working register
    for i in range(n_enc_qubits_x):
        qc.cx(i, work_x[i])
    for i in range(n_enc_qubits_y):
        qc.cx(n_enc_qubits_x + i, work_y[i])

    #shift such that legs of pieces cannot lie outside the extended area
    qc = qc.compose(adder(max_ext, n_enc_qubits_ext_x, 0), work_x)
    qc = qc.compose(adder(max_ext, n_enc_qubits_ext_y, 0), work_y)

    def add_controlled(circuit, a, register):
        # add constant a to `register`, controlled on both rotation qubits being |1>
        if a == 0:
            return circuit
        return circuit.compose(adder(a, len(register), 2), rot + register)

    def add_leg(circuit, dx, dy):
        # add the (already rotated) leg offset to both working coordinates
        circuit = add_controlled(circuit, dx, work_x)
        return add_controlled(circuit, dy, work_y)

    #depending on rotation of piece, add extension of piece leg, unless the piece leg is (0,0)
    if any(piece_leg):
        dx, dy = piece_leg[0], piece_leg[1]

        # rotation (0, 0): flip both rotation qubits so this branch drives the controls
        qc.x(rot[0])
        qc.x(rot[1])
        qc = add_leg(qc, dx, dy)
        qc.x(rot[0])
        qc.x(rot[1])

        # rotation (0, 1)
        qc.x(rot[0])
        qc = add_leg(qc, dy, -dx)
        qc.x(rot[0])

        # rotation (1, 0)
        qc.x(rot[1])
        qc = add_leg(qc, -dy, dx)
        qc.x(rot[1])

        # rotation (1, 1): controls already satisfied, no flips needed
        qc = add_leg(qc, -dx, -dy)

    return qc

#compose the previously defined functions to an oracle
def tessellation_oracle(x, y, pieces):
    """Build the phase oracle marking complete tessellations of an x-by-y rectangle.

    Args:
        x, y: extensions of the rectangle.
        pieces: list of pieces, each a list of legs [dx, dy].

    Returns:
        (oracle, n_input_qubits, n_aux_qubits). The oracle acts on
        n_input_qubits + n_aux_qubits + 1 qubits:
          [0, n_input_qubits)  per piece: x position, y position, 2 rotation qubits
          next ext_x + ext_y   working register (returned to |0>)
          next x*y             site qubits, site (sx, sy) at offset sy*x + sx
                               (returned to |0>)
          last qubit           target of the final mcx; prepared in |-> it yields
                               a sign flip
        The final mcx fires iff every site qubit is |1>. Each leg covering a site
        flips its qubit, so this means every site is covered an odd number of times.
    """
    n_enc_qubits_x = n_qubit(x)
    n_enc_qubits_y = n_qubit(y)

    # largest absolute leg offset over all pieces and both axes
    max_ext = max(
        (max(abs(leg[0]), abs(leg[1])) for piece in pieces for leg in piece),
        default=0,
    )

    n_enc_qubits_ext_x = _ext_register_width(x, max_ext)
    n_enc_qubits_ext_y = _ext_register_width(y, max_ext)

    piece_block = n_enc_qubits_x + n_enc_qubits_y + 2
    n_input_qubits = piece_block * len(pieces)
    n_work_qubits = n_enc_qubits_ext_x + n_enc_qubits_ext_y
    n_aux_qubits = n_work_qubits + x * y
    n = n_input_qubits + n_aux_qubits

    work = list(range(n_input_qubits, n_input_qubits + n_work_qubits))
    site_start = n_input_qubits + n_work_qubits

    # the site comparators do not depend on piece or leg: build them only once
    site_checks = [
        (check_site(sx + max_ext, sy + max_ext, n_enc_qubits_ext_x, n_enc_qubits_ext_y),
         site_start + sy * x + sx)
        for sx in range(x)
        for sy in range(y)
    ]

    qc = qs.QuantumCircuit(n)

    for i, piece in enumerate(pieces):
        piece_qubits = list(range(piece_block * i, piece_block * (i + 1)))
        for leg in piece:
            occupied = compute_occupied_site(leg, x, y, max_ext)
            qc = qc.compose(occupied, piece_qubits + work)

            # flip the qubit of the site the working register points to
            for check, site_qubit in site_checks:
                qc = qc.compose(check, work + [site_qubit])

            #uncompute the addition
            qc = qc.compose(occupied.inverse(), piece_qubits + work)

    oracle = qs.QuantumCircuit(n + 1)
    oracle = oracle.compose(qc, list(range(n)))
    oracle.mcx(list(range(site_start, site_start + x * y)), n)
    # uncompute the site qubits so only the phase remains
    oracle = oracle.compose(qc.inverse(), list(range(n)))

    return oracle, n_input_qubits, n_aux_qubits

Define a classical function that checks the correctness of a solution:

In [None]:
def check_correctness(input_qubits, x, y, pieces, n_enc_qubits_x, n_enc_qubits_y):
    """Classically check whether a bit assignment covers the whole x-by-y rectangle.

    Args:
        input_qubits: sequence of bits (list, tuple or 1-D array) in qubit order.
            For each piece the layout is: x position (n_enc_qubits_x bits), then
            y position (n_enc_qubits_y bits), then 2 rotation bits. Positions are
            least-significant bit first.
        x, y: extensions of the rectangle.
        pieces: list of pieces, each a list of legs [dx, dy].
        n_enc_qubits_x, n_enc_qubits_y: position widths per piece.

    Returns:
        True if every site of the rectangle is covered by at least one leg.
        Legs that fall outside the rectangle are ignored.
    """
    def to_int(bits):
        # little-endian bit sequence -> integer
        return sum(2**i for i, b in enumerate(bits) if b)

    # rotation bits (r0, r1) -> rotated leg offset; must match the convention
    # implemented by compute_occupied_site
    rotations = {
        (0, 0): lambda dx, dy: (dx, dy),
        (0, 1): lambda dx, dy: (dy, -dx),
        (1, 0): lambda dx, dy: (-dy, dx),
        (1, 1): lambda dx, dy: (-dx, -dy),
    }

    occupancies = np.zeros((x, y), dtype=bool)
    piece_block = n_enc_qubits_x + n_enc_qubits_y + 2
    for i, piece in enumerate(pieces):
        start = i * piece_block
        y_start = start + n_enc_qubits_x
        rot_start = y_start + n_enc_qubits_y
        # the piece's position and rotation are shared by all of its legs
        pos_x = to_int(input_qubits[start:y_start])
        pos_y = to_int(input_qubits[y_start:rot_start])
        # normalise to a tuple so list, tuple and ndarray inputs all work
        rot = tuple(1 if b else 0 for b in input_qubits[rot_start:start + piece_block])
        rotate = rotations[rot]

        for leg in piece:
            dx, dy = rotate(leg[0], leg[1])
            site_x = pos_x + dx
            site_y = pos_y + dy
            if 0 <= site_x < x and 0 <= site_y < y:
                occupancies[site_x, site_y] = True

    return bool(occupancies.all())

For debugging, it can be helpful to check the output of the oracle on all input states (this is only feasible for a small number of input qubits):

In [None]:
from qiskit.quantum_info import Statevector

oracle, n_input_qubits, n_aux_qubits = tessellation_oracle(x, y, pieces)

# For every computational basis state of the input register (auxiliary qubits
# in |0>, last qubit in |->), apply the oracle and print the input bits
# (qubit 0 first) together with whether the amplitude of that basis state
# became negative, i.e. whether the oracle marked it.
n_total = oracle.num_qubits
for i in range(2**n_input_qubits):
    # full-width bit string, qubit 0 as the rightmost character
    key = format(i, f"0{n_total}b")
    input_bits = [int(c) for c in reversed(key)][:n_input_qubits]

    probe = qs.QuantumCircuit(n_total)
    for qubit, bit in enumerate(input_bits):
        if bit:
            probe.x(qubit)
    probe.x(n_total - 1)
    probe.h(n_total - 1)
    probe = probe.compose(oracle)
    sv = Statevector.from_instruction(probe)

    print(input_bits, sv.to_dict()[key].real < 0)

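Since we do not know the number of solutions $k$ in advance, we run Grover's algorithm several times with the guesses $k = 1, 2, 4, \dots$. We stop as soon as some outcome is measured with more than half the frequency expected for $k$ equally likely solutions: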

In [None]:
oracle, n_input_qubits, n_aux_qubits = tessellation_oracle(x, y, pieces)
gr_op = grover_operator(oracle, n_input_qubits)

N = 2**n_input_qubits

shots = 2**15

# Default so that the verification cell still runs (printing nothing) if no
# guess of k leads to convergence.
result = set()

# The number of solutions k is unknown: try k = 1, 2, 4, ... and accept the
# first guess for which some outcome appears more often than half the expected
# frequency shots/k.
for i in range(n_input_qubits):
    k = 2**i
    qc = initial_state(n_input_qubits, n_aux_qubits)
    r = opt_r(N, k)
    for _ in range(r):
        qc = qc.compose(gr_op, list(range(n_input_qubits + n_aux_qubits + 1)))
    qc.measure(list(range(n_input_qubits)), list(range(n_input_qubits)))

    if hardware:
        job = sampler.run([pm.run(qc)], shots=shots)
    else:
        job = sampler.run([qc], shots=shots)
    counts = job.result()[0].data.c.get_counts()
    #print(counts)
    candidates = {bits for bits, v in counts.items() if v > 0.5 * shots / k}
    if candidates:
        print(f"Grover algorithm converged for {k} solutions.")
        result = candidates
        break
else:
    print("Grover algorithm did not converge for any tested number of solutions.")

Check whether the measured candidates are true solutions:

In [None]:
# Verify every measured candidate classically, in a deterministic (sorted) order.
# Bit strings from the sampler have qubit 0 as the rightmost character, so they
# are reversed to obtain the qubit-ordered bit list check_correctness expects.
for res in sorted(result):
    bits = [int(b) for b in reversed(res)]
    print(res, check_correctness(bits, x, y, pieces, n_qubit(x), n_qubit(y)))