<a href="https://colab.research.google.com/github/sid8123/qml/blob/master/Decomposition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from collections import defaultdict
from functools import reduce, lru_cache
from itertools import product
import numpy as np

In [3]:
# Single-qubit Pauli matrices keyed by their one-letter label. Insertion order
# (I, X, Y, Z) determines the order in which decompose() enumerates and prints
# Pauli strings.
PAULIS = {
    "I": np.eye(2, dtype=complex),
    "X": np.array([[0, 1], [1, 0]], dtype=complex),
    "Y": np.array([[0, -1j], [1j, 0]], dtype=complex),
    "Z": np.array([[1, 0], [0, -1]], dtype=complex),
}


def decompose(H):
    """Decompose a matrix into a linear sum of tensor products of Pauli matrices.

    Every 2^n x 2^n matrix can be written as ``H = sum_P c_P * P``, where ``P``
    ranges over all n-fold Kronecker products of {I, X, Y, Z}. The Pauli strings
    are orthogonal under the Hilbert-Schmidt inner product
    (``Tr(P^dagger Q) = 2^n`` if ``P == Q``, else 0), so each coefficient is
    ``c_P = Tr(P^dagger H) / 2^n``. For a Hermitian ``H`` every coefficient is
    real.

    Args:
        H (array_like): Square matrix of dimension (2^n x 2^n), n >= 0.
            Anything accepted by ``np.asarray`` works (ndarray, nested lists).

    Returns:
        components (defaultdict): Maps each Pauli string (e.g. ``"IXZ"``; the
        leftmost letter is the leftmost Kronecker factor) to its non-zero
        coefficient. Coefficients whose imaginary part is negligible are stored
        as real floats. A 1x1 input (n = 0) yields the single key ``""``.
        The dictionary is also printed.

    Raises:
        ValueError: If ``H`` is not a non-empty 2^n x 2^n matrix.
    """
    H = np.asarray(H)
    dims = H.shape[0] if H.ndim == 2 else 0
    # A positive power of two has exactly one bit set. Checking this on the
    # integer size avoids the float round-trip through np.log2.
    is_power_of_two = dims > 0 and (dims & (dims - 1)) == 0
    if H.shape != (dims, dims) or not is_power_of_two:
        raise ValueError("The input must be a 2^n x 2^n dimensional matrix.")
    n = dims.bit_length() - 1  # dims == 2 ** n

    components = defaultdict(int)

    # Iterating (label, matrix) pairs keeps every string's label aligned with
    # its factors. The enumeration order is unchanged.
    for factors in product(PAULIS.items(), repeat=n):
        key = "".join(label for label, _ in factors)
        # Seeding reduce with a 1x1 identity makes the n == 0 case (a 1x1 H)
        # well defined instead of failing on an empty sequence.
        basis_mat = reduce(
            np.kron, (mat for _, mat in factors), np.eye(1, dtype=complex)
        )
        # np.vdot flattens both arguments and conjugates the first, giving
        # exactly Tr(P^dagger H). An unconjugated dot would compute Tr(P^T H)
        # and flip the sign of every term with an odd number of Y factors.
        coeff = np.vdot(basis_mat, H) / dims
        coeff = np.real_if_close(coeff).item()

        if not np.allclose(coeff, 0):
            components[key] = coeff

    print(components)
    return components

In [4]:
from datetime import timedelta
from time import perf_counter

# Time building and decomposing an example operator.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; datetime.now() reads the wall clock, which can jump if the
# system clock is adjusted during the measurement.
start = perf_counter()

# 8x8 real symmetric (hence Hermitian) matrix, i.e. an operator on 3 qubits.
H = np.array([[1.5, 0, 0, 0.5, 0, 0.5, 0.5, 0],
              [0, 0.5, 0, 0, 0, 0, 0, 0.5],
              [0, 0, 0.5, 0, 0, 0, 0, 0.5],
              [0.5, 0, 0, -0.5, 0, 0, 0, 0],
              [0, 0, 0, 0, 0.5, 0, 0, 0.5],
              [0.5, 0, 0, 0, 0, -0.5, 0, 0],
              [0.5, 0, 0, 0, 0, 0, -0.5, 0],
              [0, 0.5, 0.5, 0, 0.5, 0, 0, -1.5]], dtype=np.complex128)

decompose(H)
# Wrap the elapsed seconds in a timedelta so the printed duration keeps the
# H:MM:SS.ffffff format.
print(timedelta(seconds=perf_counter() - start))

defaultdict(<class 'int'>, {'IIZ': 0.5, 'IXX': 0.25, 'IYY': -0.25, 'IZI': 0.5, 'XIX': 0.25, 'XXI': 0.25, 'YIY': -0.25, 'YYI': -0.25, 'ZII': 0.5})
0:00:00.018253
