In [None]:
import numpy as np
from qiskit import QuantumRegister, ClassicalRegister, QuantumCircuit, Aer, execute
%matplotlib inline
from qiskit.providers.aer import AerError, QasmSimulator
from qiskit.visualization import plot_histogram

# Simulator backends. backend_sim (QASM simulator) executes the measured
# circuit further down; backend_state is not used by the cells below.
backend_sim = Aer.get_backend('qasm_simulator')
backend_state = Aer.get_backend('statevector_simulator')

In [None]:
## Input section
nNodes = int(input("Number of nodes : "))
# nColors <= nNodes is always enough: n colours suffice for any graph on n nodes.
nColors = int(input("Number of colours : "))

# Reject sizes the circuit construction cannot handle. With no colours, the
# row filter would address qubit s+nColors-1 = s-1, i.e. the previous row
# (or a negative index). With no nodes, there is no colouring matrix at all.
if nNodes < 1:
    raise ValueError("Number of nodes must be at least 1, got {}".format(nNodes))
if nColors < 1:
    raise ValueError("Number of colours must be at least 1, got {}".format(nColors))

In [None]:
# Calculation of required qubits.
# Qubit layout (indices into the quantum register):
#   [0, sc)      colouring matrix: nNodes rows of nc qubits each
#                (nColors colour qubits + 1 per-row ancilla at offset nColors)
#   [sc, sg)     one "same colour" flag per pair of vertices (n1 < n2, row-major)
#   [sg, nqbits) one edge flag per pair of vertices, same pair order
# The circuit register is created with nqbits+1 qubits; q[nqbits] is scratch.
nc = nColors + 1                  # qubits per row of the colouring matrix
nn2 = nNodes * (nNodes - 1) // 2  # number of unordered pairs of vertices (exact integer)
sc = nc * nNodes                  # qubits used by the colouring matrix
sg = sc + nn2                     # index of the first graph (edge) qubit
nqbits = sc + 2 * nn2             # total qubits, excluding the scratch qubit

In [None]:
# Build the circuit. The quantum register holds nqbits working qubits plus one
# extra qubit (index nqbits). The classical register has one bit per entry of
# the nNodes x nColors colouring matrix.
q = QuantumRegister(nqbits + 1)
c = ClassicalRegister(nNodes * nColors)
qc = QuantumCircuit(q, c)

In [None]:
# Importing graph visualization libraries
import networkx as nx
import matplotlib.pyplot as plt

# One vertex per node, labelled 1..nNodes to match the adjacency prompts.
# Adding the nodes up front keeps isolated vertices in the drawings and fixes
# G's node order to 1..nNodes. Otherwise nodes would only appear via add_edge,
# in edge-insertion order.
G = nx.Graph()
G.add_nodes_from(range(1, nNodes + 1))

In [None]:
# Graph adjacency input.
# The graph is undirected with no self-loops, so only the strict upper
# triangle of the adjacency matrix is read: nNodes*(nNodes-1)/2 yes/no
# answers, one per pair (i, j) with i < j, in row-major order. Answering 1
# flips the matching graph qubit q[sg + pair index] to |1> and adds the
# edge to the networkx graph G (node labels start at 1).


'''
3 nodes, needs 3 colors
0 1 1
1 0 1
1 1 0
=> 1 1 1
3 nodes, needs 2 colors
0 1 0
1 0 1
0 1 0
=> 1 0 1
'''
vertex_pairs = [(a, b) for a in range(nNodes) for b in range(a + 1, nNodes)]
for pair_idx, (a, b) in enumerate(vertex_pairs):
    answer = int(input("vertex {} adjacent to {} ? ( 1 : yes / 0 : no) : ".format(a+1,b+1)))
    if answer != 1:
        continue
    qc.x(sg + pair_idx)  # edge present: set its graph qubit to |1>
    G.add_edge(a + 1, b + 1)


qc.draw(output='mpl')

# Every node is drawn in the same colour before any colouring is computed.
colours = ["green"] * len(G)

nx.draw(G,node_color = colours, with_labels = True, font_color = "white")
plt.show()

In [None]:
# Put every colour qubit of the colouring matrix into superposition with a
# Hadamard. Each row's extra ancilla qubit (offset nColors in the row) is skipped.
colour_qubits = [row * nc + col for row in range(nNodes) for col in range(nColors)]
for qubit in colour_qubits:
    qc.h(qubit)

qc.draw(output='mpl')

In [None]:
# A filtering code used to generate only valid colouring matrices.
# Constraint: exactly one 1 in a row, i.e. every node gets exactly one colour.
# Row n occupies qubits s .. s+nColors-1 (colours) plus ancilla s+nColors,
# with s = n*nc. The resets make this step non-unitary.
s=0
for n in range(nNodes):
    # Remove extra colours: for each pair k < l with both bits set, clear bit k.
    # Only the first bit of a pair is ever cleared, so bit l is still unmodified
    # when pair (k, l) is processed. Net effect: only the highest-indexed colour
    # that was set survives.
    for k in range(nColors-1):
        for l in range(k+1,nColors):
            # Eliminate 11: ancilla = bit k AND bit l
            qc.ccx (s+k,s+l,s+nColors)
            # if both were set, clear bit k
            qc.cx (s+nColors,s+k)
            # return the ancilla to |0> before the next pair
            qc.reset(s+nColors)
    # Eliminate 0* (no colour assigned to the node n): give it the last colour.
    # X on every colour bit turns an all-zero row into all-ones so the
    # multi-controlled X can detect it on the ancilla.
    for k in range(nColors):
        qc.x(s+k) 
    cb=list(range(s,s+nColors) )
    qc.mcx (cb,s+nColors)
    # undo the X gates, restoring the row
    for k in range(nColors):
        qc.x(s+k) 
    # row was empty -> set colour nColors-1
    qc.cx (s+nColors,s+nColors-1)
    qc.reset(s+nColors)
    s=s+nc
    
qc.barrier()
print('end of colouring matrices')

qc.draw(output='mpl')

In [None]:
# Identify the pairs of vertices that have the same colour and set their flag
# qubits to |1>.
# q[nc*n + k] = |1> means node n has colour k. The flag of the p-th pair
# (n1 < n2, row-major order) is qubit nc*nNodes + p. A Toffoli flips it when
# both nodes have colour k. Rows are one-hot after the row filter, so each
# flag is flipped for at most one k.
flag_base = nc*nNodes
for k in range(nColors):
    flag = flag_base
    for n1 in range(nNodes-1):
        for n2 in range(n1+1,nNodes):
            qc.ccx(nc*n1 + k, nc*n2 + k, flag)
            flag += 1

print('end of pairs of nodes')
qc.barrier()

qc.draw(output='mpl')

In [None]:
# Compare to the graph: for every pair (n1, n2) that shares a colour AND is
# joined by an edge, clear n1's colour qubit. This leaves n1's row all-zero,
# which the zero-row filter uses to reject the whole colouring.
for k in range(nColors):
    s = nc*nNodes                   # same-colour flag of the current pair
    for n1 in range(nNodes-1):
        for n2 in range(n1+1,nNodes):
            n11 = nc*n1+k           # If q[n11]=|1> it means the node n1 has the color k
            # Edge flag of the same pair. The edge flags start nn2 qubits after
            # the same-colour flags (sg = nc*nNodes + nn2) and use the same pair
            # order, so the offset is nn2 (not nNodes; the two only coincide
            # for 3 nodes).
            pair = s+nn2
            ancil = n1*nc + nColors # row n1's ancilla qubit, |0> on entry
            # ancil = (n1 has colour k) AND (edge) AND (same colour)
            qc.mcx([n11,pair,s],ancil)
            # conflict: clear colour k of node n1
            qc.mcx([ancil,pair,s],n11)
            qc.reset(ancil)
            qc.barrier()
            s=s+1
print('end of compare to graph')

qc.draw(output='mpl')

In [None]:
#Filter 0 rows
#Eliminating invalid matrices
# After the row filter every row is one-hot, so a row can only be all-zero
# here if the graph comparison cleared its colour (a conflict). In that case
# the whole colouring matrix is zeroed, so every rejected colouring is
# measured as the all-zero bitstring.
s=0
for n in range(nNodes):
    
    cb=list(range(s,s+nColors))
 
    # ancilla of row n = (row n is all zero), computed with X / multi-controlled X / X
    qc.x(cb)
    qc.mcx (cb,s+nColors)
    qc.x(cb)
    # If the flag is set, swap every colour qubit of the matrix with the freshly
    # reset scratch qubit q[nqbits], leaving that colour qubit at |0>.
    pos=0
    for m in range(nNodes):
        for k in range(nColors):
            qc.reset(nqbits)
            qc.cswap(s+nColors,pos+k,nqbits)
        qc.barrier()
        pos=pos+nc
    # NOTE(review): the row-n flag is not reset; the measurement cell only reads colour qubits.
    s=s+nc

qc.barrier()

qc.draw(output='mpl')

In [None]:
# Measure only the colour qubits of the colouring matrix; the per-row ancilla
# qubits are skipped. Classical bit n*nColors + k receives colour k of node n.
for n in range(nNodes):
    row_start = n*(nColors+1)
    for k in range(nColors):
        qc.measure(row_start + k, n*nColors + k)
print('end of measures')

qc.draw(output='mpl')

In [None]:
# Run the circuit on the QASM simulator. This is quick on small graphs, but the
# simulator runs out of memory for 4 nodes and 3 colours (nqbits+1 qubits).
# A fixed simulator seed makes the sampled counts reproducible on re-run.
SHOTS = 1000
SEED = 42
job = execute(qc, backend_sim, shots=SHOTS, seed_simulator=SEED)
result = job.result()

print('end of execute')

In [None]:
# Take the results from the job. The all-zero bitstring is the outcome of
# rejected colourings, so it is dropped. If it is the only outcome, it is
# kept so the histogram shows that no valid colouring was found.
counts = result.get_counts(qc)
rejected = "0" * (nNodes * nColors)
if len(counts) > 1 and rejected in counts:
    del counts[rejected]

plot_histogram(counts, color='midnightblue', title="Colourings")

In [None]:
# Display each measured colouring on the graph.
# Qiskit count bitstrings are little-endian: classical bit 0 is the rightmost
# character. The measure cell writes colour k of node n into classical bit
# n*nColors + k, so the string is reversed before being split into one
# nColors-wide row per node (node 0 first, colour 0 first within a row).
colour = ["blue", "red", "green", "purple"]
if nColors > len(colour):
    raise ValueError("Only {} display colours are defined, but nColors = {}".format(len(colour), nColors))

for graph in counts:
    bits = graph[::-1]
    rows = [bits[n*nColors:(n+1)*nColors] for n in range(nNodes)]
    # Rejected colourings (e.g. the all-zero string, kept when it is the only
    # outcome) contain a node without a colour; there is nothing to draw.
    if any("1" not in row for row in rows):
        print("No valid colouring in {}".format(graph))
        continue
    print("Graph for colouring {}".format(graph))
    # Map node labels (1..nNodes, as used when G was built) to display colours,
    # then follow G's own node order, which add_edge does not guarantee to be 1..nNodes.
    node_colour = {n+1: colour[row.index("1")] for n, row in enumerate(rows)}
    colours = [node_colour[node] for node in G]
    nx.draw(G,node_color = colours, with_labels = True, font_color = "white")
    plt.show()