In [1]:
import numpy as np

In [63]:
# Frequently used 2x2 complex matrices.
Identity_22 = np.identity(2, dtype=np.complex128)  # 2x2 identity
Pauli_x = np.array([[0, 1],
                    [1, 0]], dtype=np.complex128)   # Pauli X: swaps the two basis states

# Magnitudes below this value are treated as zero by the decomposition code.
thr = 1e-9

In [183]:
def is_unitary(A):
    """Return True if ``A`` is (numerically) a unitary matrix.

    Parameters
    ----------
    A : array_like
        Matrix to test; ndarrays and nested lists are both accepted.

    Returns
    -------
    bool
        True when ``A @ A^dagger`` is close to the identity under
        ``np.allclose`` default tolerances.

    Raises
    ------
    ValueError
        If ``A`` is not a square 2-D matrix.
    """
    # Convert *before* touching .shape so array-likes work; previously the
    # conversion came after the shape access and therefore had no effect.
    A = np.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("Matrix is not square.")
    n = A.shape[0]
    return np.allclose(np.eye(n), A @ A.conj().T)


def is_identity(A):
    """Return True if ``A`` is (numerically) the identity matrix.

    Parameters
    ----------
    A : array_like
        Matrix to test; ndarrays and nested lists are both accepted.

    Returns
    -------
    bool
        True when ``A`` is elementwise close to ``np.eye(n)`` under
        ``np.allclose`` default tolerances.

    Raises
    ------
    ValueError
        If ``A`` is not a square 2-D matrix.
    """
    # Same array-like handling and shape check as is_unitary.
    A = np.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("Matrix is not square.")
    return np.allclose(A, np.eye(A.shape[0]))


def elimination_matrix(a, b):
    """Build a 2x2 special unitary ``U`` that eliminates ``b`` from ``[a, b]``.

    The returned matrix satisfies ``[a, b] @ U == [r, 0]`` with
    ``r = sqrt(|a|**2 + |b|**2)`` real and non-negative, and ``det(U) == 1``.
    The ``{eq.N}`` tags refer to equations in the derivation this notebook
    follows (not included here).

    Parameters
    ----------
    a, b : complex
        Entries of a row in two neighbouring columns; either may be zero.

    Returns
    -------
    numpy.ndarray
        Complex 2x2 matrix of the form
        ``[[e^{i lambda} cos t,  e^{i mu} sin t],
           [-e^{-i mu} sin t,   e^{-i lambda} cos t]]``  {eq.7}.
    """
    # Real, non-negative theta with tan(theta) = |b|/|a| {eq.10}. arctan2 gives
    # the same angle as arctan(|b/a|) for a != 0, but is also defined for
    # a == 0 (theta = pi/2), where the old b/a division blew up.
    theta = np.arctan2(abs(b), abs(a))

    # lambda cancels the phase of a ("lambda" itself is a Python keyword).
    lamda = -np.angle(a)

    # mu is chosen so the second output entry cancels exactly {eq.12}.
    mu = np.pi + np.angle(b)

    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    return np.array(
        [[np.exp(1j * lamda) * cos_t, np.exp(1j * mu) * sin_t],
         [-np.exp(-1j * mu) * sin_t, np.exp(-1j * lamda) * cos_t]],
        dtype=np.complex128,
    )


def _two_level_step(a, b, is_last):
    """Pick the 2x2 unitary applied to columns (j-1, j) of the working matrix.

    ``a``/``b`` are the current row-``i`` entries in columns ``j-1``/``j``.
    ``is_last`` is True when ``j == i + 1``: column ``j-1`` is then the
    diagonal position, which must end up holding exactly 1. On that step
    ``|a|**2 + |b|**2 == 1`` for unitary input, because the rest of the row is
    already zero.
    """
    if abs(b) < thr:
        if is_last:
            # Nothing to eliminate; just cancel the phase of the diagonal.
            return np.array([[1 / a, 0], [0, a]], dtype=np.complex128)
        return Identity_22
    if abs(a) < thr:
        if is_last:
            # Move b onto the diagonal *and* cancel its phase:
            # [a, b] @ [[0, b], [1/b, 0]] == [1, a*b] with a ~ 0.
            # (The previous diag(1/b, b) never moved b, so the row check failed.)
            return np.array([[0, b], [1 / b, 0]], dtype=np.complex128)
        # Swap the columns so the non-zero entry moves one place left.
        return Pauli_x
    return elimination_matrix(a, b)


def two_level_decomp(A):
    """Decompose a unitary matrix into two-level (embedded 2x2) unitaries.

    Row by row, adjacent column pairs ``(j-1, j)`` of a working copy are mixed
    by 2x2 unitaries, sweeping right to left. When the sweep ends, the first
    ``n-2`` rows are unit vectors with 1 on the diagonal. The remaining
    lower-right 2x2 block is recorded as the final gate.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Matrix to decompose, assumed unitary (this is not validated). It is
        copied, not modified.

    Returns
    -------
    decomp : list of numpy.ndarray
        2x2 complex gates. Identity gates are omitted.
    indices : list of numpy.ndarray
        For each gate, the pair of adjacent indices it acts on. Embedding
        ``decomp[k]`` at ``indices[k]`` as ``E_k`` gives, within tolerance,
        ``A == E_last @ ... @ E_1 @ E_0``, i.e. the list multiplied in
        reverse order.

    Raises
    ------
    ValueError
        If ``A`` is not square.
    AssertionError
        If a row does not reduce to a unit diagonal, which is expected only
        for non-unitary input.
    """
    # Complex working copy: a real/int copy would silently drop (or truncate)
    # the complex products assigned back into it.
    A_c = np.array(A, dtype=np.complex128)
    if A_c.ndim != 2 or A_c.shape[0] != A_c.shape[1]:
        raise ValueError("Matrix is not square.")
    n = A_c.shape[0]
    decomp = []
    indices = []

    for i in range(n - 2):
        # Zero row i from the right edge down to the diagonal.
        for j in range(n - 1, i, -1):
            U_22 = _two_level_step(A_c[i, j - 1], A_c[i, j], j == i + 1)
            cols = [j - 1, j]
            A_c[:, cols] = A_c[:, cols] @ U_22

            # Only non-identity steps correspond to actual gates.
            if not is_identity(U_22):
                decomp.append(U_22.conj().T)
                indices.append(np.array(cols))

        # Row i is now e_i; by unitarity column i is e_i as well.
        assert np.allclose(A_c[i, i], 1.0)

    # Whatever remains in the lower-right 2x2 block is the last gate.
    lower_rh_matrix = A_c[n - 2:n, n - 2:n]
    if not is_identity(lower_rh_matrix):
        decomp.append(lower_rh_matrix)
        indices.append(np.array([n - 2, n - 1]))

    return decomp, indices


def gray_method(A):
    """Two-level decomposition with basis states ordered by Gray code.

    Rows and columns of ``A`` are reordered by the binary-reflected Gray code
    ``g(i) = i ^ (i >> 1)``. The adjacent index pairs used by
    ``two_level_decomp`` then correspond to basis states that differ in a
    single bit. The returned index pairs are mapped back to the original
    basis.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Matrix to decompose, with ``n`` a power of two. It is not modified;
        the previous implementation overwrote it in place.

    Returns
    -------
    decomp : list of numpy.ndarray
        2x2 gates, as returned by ``two_level_decomp`` on the permuted matrix.
    new_ind : list of numpy.ndarray
        Index pairs of each gate expressed in the original (unpermuted) basis.

    Raises
    ------
    ValueError
        If ``A`` is not square or ``n`` is not a power of two. For other
        sizes the Gray code would index outside ``A``.
    """
    A = np.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("Matrix is not square.")
    n = A.shape[0]
    if n < 1 or n & (n - 1):
        raise ValueError("Matrix dimension must be a power of two.")

    # Gray-code permutation of the basis indices.
    permutations = np.array([i ^ (i >> 1) for i in range(n)])

    # A_perm[r, c] == A[perm[r], perm[c]]: built as a new array so the
    # caller's matrix is left untouched.
    A_perm = A[np.ix_(permutations, permutations)]

    decomp, indices = two_level_decomp(A_perm)

    # Map each permuted index pair back to the original basis.
    new_ind = [permutations[pair] for pair in indices]

    return decomp, new_ind
    

In [184]:
from scipy.stats import unitary_group

# Decompose a random 2-qubit (4x4) unitary. A fixed random_state keeps this
# cell reproducible under Restart & Run All.
nq = 2
dim = 2**nq
A = unitary_group.rvs(dim, random_state=1234)
decomp, indices = two_level_decomp(A)

# Sanity check: embedding each 2x2 gate at its index pair and multiplying the
# embedded gates in reverse list order reproduces A.
reconstructed = np.eye(dim, dtype=np.complex128)
for gate, pair in zip(decomp, indices):
    embedded = np.eye(dim, dtype=np.complex128)
    embedded[np.ix_(pair, pair)] = gate
    reconstructed = embedded @ reconstructed
assert np.allclose(reconstructed, A)

decomp, indices

([array([[-0.27938904-0.94005548j,  0.07399077-0.18100504j],
         [-0.07399077-0.18100504j, -0.27938904+0.94005548j]]),
  array([[-0.26992115+4.74461663e-01j,  0.83787153-1.02609669e-16j],
         [-0.83787153-1.02609669e-16j, -0.26992115-4.74461663e-01j]]),
  array([[ 0.03435234-1.87689819e-01j,  0.98162745-1.20214691e-16j],
         [-0.98162745-1.20214691e-16j,  0.03435234+1.87689819e-01j]]),
  array([[-0.78987482-0.1433975j ,  0.58827772-0.09728433j],
         [-0.58827772-0.09728433j, -0.78987482+0.1433975j ]]),
  array([[-0.2908144 -3.75514655e-01j,  0.88000894-1.07770013e-16j],
         [-0.88000894-1.07770013e-16j, -0.2908144 +3.75514655e-01j]]),
  array([[-0.67174596-0.54306352j,  0.21235242+0.45688711j],
         [-0.43583427+0.25276051j, -0.86170921+0.06014864j]])],
 [array([2, 3]),
  array([1, 2]),
  array([0, 1]),
  array([2, 3]),
  array([1, 2]),
  array([2, 3])])

In [185]:
# Pass a copy so A is never modified, even by an implementation that reorders
# its argument in place; re-running the cells then stays consistent.
gray_method(A.copy())

([array([[ 0.07399077-0.18100504j, -0.27938904-0.94005548j],
         [ 0.27938904-0.94005548j,  0.07399077+0.18100504j]]),
  array([[-0.26992115+4.74461663e-01j,  0.83787153-1.02609669e-16j],
         [-0.83787153-1.02609669e-16j, -0.26992115-4.74461663e-01j]]),
  array([[ 0.03435234-1.87689819e-01j,  0.98162745+3.15715467e-16j],
         [-0.98162745+3.15715467e-16j,  0.03435234+1.87689819e-01j]]),
  array([[-0.78987482-0.1433975j , -0.58827772+0.09728433j],
         [ 0.58827772+0.09728433j, -0.78987482+0.1433975j ]]),
  array([[-0.2908144 -3.75514655e-01j,  0.88000894-1.07770013e-16j],
         [-0.88000894-1.07770013e-16j, -0.2908144 +3.75514655e-01j]]),
  array([[-0.43583427+0.25276051j,  0.86170921-0.06014864j],
         [-0.67174596-0.54306352j, -0.21235242-0.45688711j]])],
 [array([3, 2]),
  array([1, 3]),
  array([0, 1]),
  array([3, 2]),
  array([1, 3]),
  array([3, 2])])