In [1]:
from functools import reduce
import networkx as nx
import numpy as np
import tensornetwork as tn

# Sketch of QAOA tensor network

The figure shows the simplest case, a chain, for which the tree tensor network reduces to a matrix product state (MPS). The code below handles general trees.

![Tensor network structure](./figures/qaoa_tn_chain.png)

# Algorithms

## Tree algorithms

In [2]:
def regular_tree(degree, depth):
    """Build a tree in which every internal node has ``degree`` neighbours.

    Nodes are tuples of child indices describing the path from the root ``()``.
    The root gets ``degree`` children; every other internal node gets
    ``degree - 1`` children (its parent is its remaining neighbour). Nodes whose
    tuple has length ``depth`` are leaves.

    Args:
        degree: Number of neighbours of each internal node.
        depth: Length of the path from the root to every leaf; must be >= 0.

    Returns:
        An undirected ``nx.Graph``. Nodes are inserted in depth-first preorder,
        and each node's adjacency lists its parent first, then its children in
        index order (so ``nx.dfs_tree`` yields successors in index order).

    Raises:
        ValueError: If ``depth`` is negative.
    """
    if depth < 0:
        raise ValueError("depth must be non-negative, got {}".format(depth))
    tree = nx.Graph()
    root = tuple()
    # Explicit stack instead of recursion: the recursive version exceeded Python's
    # recursion limit for long chains (degree=2, large depth). Each node is attached
    # to its parent when it is popped, which reproduces recursive preorder insertion.
    pending = [root]
    while pending:
        node = pending.pop()
        if node == root:
            tree.add_node(root)
        else:
            tree.add_edge(node[:-1], node)
        if len(node) >= depth:
            continue
        num_children = degree if node == root else degree - 1
        # Push in reverse so that child 0 is popped (visited) first.
        pending.extend(node + (suffix,) for suffix in reversed(range(num_children)))
    return tree

## Tensor contraction algorithms

### `TensorStack` class.

This class exposes a stack of elementary tensors (where the tensors act on a few qubits and the stack grows in the depth direction) as a tensor with many indices which can then be efficiently contracted with other such tensors.

In [3]:
class TensorStack:
    """A stack of elementary tensors exposed as one tensor with many open edges.

    Every elementary tensor acts on a root qubit (index 0) and ``num_branches``
    branch qubits. Open edges are collected per qubit in stack order, so two
    stacks can be glued together position by position.

    Attributes:
        tn_nodes: The ``tn.Node`` objects making up the stack.
        root_edges: Open edges on the root qubit, in stack order.
        branches_edges: ``branches_edges[b]`` lists the open edges on branch qubit ``b``.
        num_branches: ``len(branches_edges)``.
    """

    def __init__(self, tn_nodes, root_edges, branches_edges):
        self.tn_nodes = tn_nodes
        self.root_edges = root_edges
        self.branches_edges = branches_edges
        self.num_branches = len(branches_edges)

    @classmethod
    def from_elementary_tensors(cls, num_branches, tensor_list):
        """Wrap vectors/matrices on ``1 + num_branches`` qubits as a stack.

        Vectors must have shape ``(2,) * (1 + num_branches)`` and contribute one
        edge per qubit. Matrices must have shape ``(2,) * (2 * (1 + num_branches))``:
        the first ``1 + num_branches`` axes index rows and the rest index columns.
        A matrix contributes two edges per qubit: the column edge first, then the
        row edge.

        Raises:
            ValueError: If a tensor has neither allowed shape.
        """
        vector_shape = (2,) * (1 + num_branches)
        matrix_shape = vector_shape * 2
        for tensor_idx, tensor in enumerate(tensor_list):
            if tensor.shape != vector_shape and tensor.shape != matrix_shape:
                raise ValueError("Invalid shape for tensor {}: {}, {} or {} expected".format(tensor_idx, tensor.shape, vector_shape, matrix_shape))
        root_edges = []
        branches_edges = [[] for _ in range(num_branches)]
        tn_nodes = []
        for tensor in tensor_list:
            tn_node = tn.Node(tensor)
            tn_nodes.append(tn_node)
            if tensor.shape == matrix_shape:
                # Column (input) edges come first ...
                root_edges.append(tn_node.edges[num_branches + 1])
                for branch in range(num_branches):
                    branches_edges[branch].append(tn_node.edges[num_branches + 2 + branch])
            # ... followed by row (output) edges; a vector only has these.
            root_edges.append(tn_node.edges[0])
            for branch in range(num_branches):
                branches_edges[branch].append(tn_node.edges[1 + branch])
        return cls(tn_nodes, root_edges, branches_edges)

    def contract_branches(self, branches_vectors):
        """Glue one branch-free stack onto every branch and contract greedily.

        Connects ``self.branches_edges[b][i]`` to ``branches_vectors[b].root_edges[i]``
        for every branch ``b`` and position ``i``. The edges are connected in place,
        so ``self`` and the branch stacks should be treated as consumed afterwards.

        Args:
            branches_vectors: One ``TensorStack`` without branch edges per branch.

        Returns:
            The contracted ``tn.Node``, whose edges are ordered like ``self.root_edges``.

        Raises:
            ValueError: On a wrong number of branch stacks, a branch stack that
                has branch edges, or mismatched edge counts.
        """
        if self.num_branches != len(branches_vectors):
            raise ValueError("Unexpected number of branches vectors: {}, {} expected".format(len(branches_vectors), self.num_branches))
        for branch in range(self.num_branches):
            branch_vector = branches_vectors[branch]
            if len(branch_vector.branches_edges):
                raise ValueError("Vector for branch {} has branches edges.".format(branch))
            if len(self.branches_edges[branch]) != len(branch_vector.root_edges):
                raise ValueError("Invalid number of edges for branch {}: {}, {} expected".format(branch, len(branch_vector.root_edges), len(self.branches_edges[branch])))
        # Everything is validated before any edge is connected, so a failure
        # above leaves every stack untouched.
        for own_edges, branch_vector in zip(self.branches_edges, branches_vectors):
            for own_edge, branch_edge in zip(own_edges, branch_vector.root_edges):
                own_edge ^ branch_edge  # tensornetwork: `^` connects two dangling edges
        # Flat comprehension instead of sum(lists, []), which is quadratic.
        branch_nodes = [node for branch_vector in branches_vectors for node in branch_vector.tn_nodes]
        return tn.contractors.greedy(
            branch_nodes + self.tn_nodes,
            output_edge_order=self.root_edges
        )

    def contract_branches_and_root(self, branches_vectors, root_vector):
        """Contract the branches (see ``contract_branches``), then close the root edges.

        The contracted node's edges (ordered like ``self.root_edges``) are connected
        position-wise to ``root_vector.root_edges``, and everything is contracted.

        Returns:
            The result of ``tn.contractors.greedy``; presumably a scalar node when
            ``root_vector`` has no other open edges.

        Raises:
            ValueError: If the root edge counts differ, or on any error raised by
                ``contract_branches``.
        """
        if len(self.root_edges) != len(root_vector.root_edges):
            raise ValueError("Unexpected number of edges for root: {}, expected {}".format(len(root_vector.root_edges), len(self.root_edges)))
        branches_contracted = self.contract_branches(branches_vectors)
        for own_edge, root_edge in zip(branches_contracted.edges, root_vector.root_edges):
            own_edge ^ root_edge
        return tn.contractors.greedy([branches_contracted] + root_vector.tn_nodes)

### Elementary matrices occurring in the definition of each tensor stack.

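They are built from the QAOA mixer $U_B(\beta)$ and phase-separation $U_C(\gamma)$ unitaries. In the convention used below, each edge $(i, j)$ contributes a factor $e^{-i \gamma Z_i Z_j / 2}$ to $U_C(\gamma)$, and each qubit contributes a factor $e^{-i \beta X_i / 2}$ to $U_B(\beta)$. Note the factor $1/2$ in both exponents.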

In [4]:
def ising_rotation_matrix(num_qubits, qubit1, qubit2, gamma):
    """Dense matrix of exp(-i * gamma / 2 * Z_qubit1 Z_qubit2) on ``num_qubits`` qubits.

    Qubit 0 is the leftmost (most significant) Kronecker factor.

    Args:
        num_qubits: Total number of qubits.
        qubit1, qubit2: The two distinct qubits coupled by the ZZ term.
        gamma: Rotation angle.

    Returns:
        A complex ``2**num_qubits`` square (diagonal) matrix.

    Raises:
        ValueError: If ``qubit1 == qubit2``. The two-term expansion below only
            holds for distinct qubits; previously this case silently produced a
            non-unitary matrix.
    """
    if qubit1 == qubit2:
        raise ValueError("qubit1 and qubit2 must differ, got {} for both".format(qubit1))
    pauli_z = np.diag([1, -1])
    # exp(-i g/2 ZZ) = cos(g/2) I - i sin(g/2) ZZ, because (ZZ)^2 = I.
    tensor_factors = [np.eye(2)] * num_qubits
    tensor_factors[qubit1] = np.cos(gamma / 2) * np.eye(2, dtype=complex)
    matrix = reduce(np.kron, tensor_factors)
    tensor_factors = [np.eye(2)] * num_qubits
    tensor_factors[qubit1] = -1j * np.sin(gamma / 2) * pauli_z
    tensor_factors[qubit2] = pauli_z
    matrix += reduce(np.kron, tensor_factors)
    return matrix

def x_rotation_matrix(num_qubits, qubit, beta):
    """Dense matrix of exp(-i * beta / 2 * X) applied to ``qubit`` of ``num_qubits`` qubits.

    Every other qubit receives the 2x2 identity; qubit 0 is the leftmost Kronecker factor.
    """
    half_angle = beta / 2
    single_qubit_rotation = (
        np.cos(half_angle) * np.eye(2)
        - 1j * np.sin(half_angle) * np.array([[0, 1], [1, 0]])
    )
    factors = [np.eye(2) for _ in range(num_qubits)]
    factors[qubit] = single_qubit_rotation
    return reduce(np.kron, factors)

def qaoa_step_matrix(num_branches, beta, gamma, even):
    """One QAOA layer on a root qubit (index 0) plus ``num_branches`` branch qubits.

    Applies ``ising_rotation_matrix(..., 0, b, gamma)`` between the root and every
    branch ``b``. When ``even`` is true, it additionally applies
    ``x_rotation_matrix(..., q, beta)`` to every qubit ``q`` afterwards.

    Args:
        num_branches: Number of branch qubits. With 0 branches there are no
            Ising terms and the identity is used instead (previously
            ``reduce`` raised ``TypeError`` on the empty list).
        beta: Mixer angle (only used when ``even``).
        gamma: Ising angle.
        even: Whether to include the X-rotation layer.

    Returns:
        A complex ``2**(1 + num_branches)`` square matrix.

    Raises:
        ValueError: If ``num_branches`` is negative.
    """
    if num_branches < 0:
        raise ValueError("num_branches must be non-negative, got {}".format(num_branches))
    num_qubits = 1 + num_branches
    ising_layers = [
        ising_rotation_matrix(num_qubits, 0, branch_qubit, gamma)
        for branch_qubit in range(1, num_qubits)
    ]
    if ising_layers:
        matrix = reduce(np.matmul, ising_layers)
    else:
        matrix = np.eye(2 ** num_qubits, dtype=complex)
    if even:
        mixer = reduce(
            np.matmul,
            [x_rotation_matrix(num_qubits, qubit, beta) for qubit in range(num_qubits)]
        )
        # Left multiplication: the mixer acts after the Ising terms.
        matrix = mixer @ matrix
    return matrix

### Constructors for tensor stacks.

In [5]:
def qaoa_bulk_stack(num_branches, betas, gammas, even, observable=None):
    """Tensor stack for a node with ``num_branches`` children.

    Let ``M_k = qaoa_step_matrix(num_branches, betas[k], gammas[k], even)`` for
    k = 1..p. The stack holds, from first to last:

    * even: ``M_1, ..., M_{p-1}, M_p^H O M_p, M_{p-1}^H, ..., M_1^H``, where
      ``O`` is ``observable`` (the identity if None).
    * odd: the vector ``M_1 |+>``, ``M_2, ..., M_p, M_p^H, ..., M_2^H``, and the
      vector ``|+> @ M_1^H``. Here ``|+>`` is the uniform superposition on all
      ``1 + num_branches`` qubits, and ``observable`` is ignored.

    Args:
        num_branches: Number of branch qubits (children).
        betas, gammas: Angles, one pair per layer; same non-zero length.
        even: Parity of the stack.
        observable: ``2**(1 + num_branches)`` square matrix, used only when ``even``.

    Returns:
        A ``TensorStack`` with ``num_branches`` branches.

    Raises:
        ValueError: If no angles are given or ``betas`` and ``gammas`` differ in length.
    """
    if not len(betas):
        raise ValueError("Angles must be specified")
    if len(betas) != len(gammas):
        raise ValueError("betas and gammas must have the same length: {} != {}".format(len(betas), len(gammas)))
    num_qubits = 1 + num_branches
    vector_shape = (2,) * num_qubits
    matrix_shape = vector_shape * 2
    # Each layer's matrix appears twice (plain and daggered); build it only once.
    steps = [qaoa_step_matrix(num_branches, beta, gamma, even) for beta, gamma in zip(betas, gammas)]
    if even:
        if observable is None:
            observable = np.eye(2 ** num_qubits)
        last = steps[-1]
        matrices = (
            steps[:-1]
            + [last.T.conj() @ observable @ last]
            + [step.T.conj() for step in reversed(steps[:-1])]
        )
        tensor_list = [matrix.reshape(matrix_shape) for matrix in matrices]
    else:
        plus = (1 / np.sqrt(2)) ** num_qubits * np.ones(2 ** num_qubits)
        first = steps[0]
        tensor_list = (
            [(first @ plus).reshape(vector_shape)]
            + [step.reshape(matrix_shape) for step in steps[1:]]
            + [step.T.conj().reshape(matrix_shape) for step in reversed(steps[1:])]
            + [(plus @ first.T.conj()).reshape(vector_shape)]
        )
    return TensorStack.from_elementary_tensors(num_branches, tensor_list)

def qaoa_boundary_stack(betas, even, observable=None):
    """Tensor stack for a leaf: a single qubit with no branches.

    With ``R_k = x_rotation_matrix(1, 0, betas[k])`` and p = len(betas):

    * even: ``R_1, ..., R_{p-1}, R_p^H O R_p, R_{p-1}^H, ..., R_1^H``, where ``O``
      is ``observable`` (the 2x2 identity if None).
    * odd: ``|+>``, then ``2p - 2`` identity matrices, then ``|+>`` again;
      ``observable`` is ignored. (Presumably the leaf's rotations are supplied by
      its parent's even stack, whose step matrices rotate every branch qubit.)

    Raises:
        ValueError: If ``betas`` is empty.
    """
    if not len(betas):
        raise ValueError("Angles must be specified")
    if not even:
        num_identities = 2 * len(betas) - 2
        boundary_tensors = (
            [1 / np.sqrt(2) * np.ones(2)]
            + [np.eye(2) for _ in range(num_identities)]
            + [1 / np.sqrt(2) * np.ones(2)]
        )
        return TensorStack.from_elementary_tensors(0, boundary_tensors)
    if observable is None:
        observable = np.eye(2)
    rotations = [x_rotation_matrix(1, 0, beta) for beta in betas]
    innermost = rotations[-1]
    boundary_tensors = (
        rotations[:-1]
        + [innermost.T.conj() @ observable @ innermost]
        + [rotation.T.conj() for rotation in rotations[-2::-1]]
    )
    return TensorStack.from_elementary_tensors(0, boundary_tensors)

### Final QAOA evaluation algorithm

In [6]:
def qaoa_evaluate(tree, root, betas, gammas, observables):
    """Contract the QAOA tree tensor network built over ``tree``.

    Stack parity alternates with depth. The root's children are split in half:
    the first half is contracted with the root as an odd node, the second half
    with the root as an even node. The two resulting stacks are then glued along
    the root's edges.

    Args:
        tree: Directed tree (e.g. from ``nx.dfs_tree``); ``tree.successors(v)``
            must list the children of ``v``.
        root: Root node of ``tree``.
        betas, gammas: QAOA angles, one per layer.
        observables: Mapping from node to a 2x2 matrix; missing nodes get the identity.

    Returns:
        The node returned by ``TensorStack.contract_branches_and_root``; its
        ``.tensor`` is read as the expectation value by the cells below.
    """
    def build_stack(even, node, children):
        # Observables are folded only into even stacks: on the node itself and its children.
        if even:
            local_observable = reduce(
                np.kron,
                [observables.get(qubit, np.eye(2)) for qubit in [node] + children]
            )
        else:
            local_observable = None
        if not children:
            return qaoa_boundary_stack(betas, even, local_observable)
        stack = qaoa_bulk_stack(len(children), betas, gammas, even, local_observable)
        child_stacks = []
        for child in children:
            child_stacks.append(build_stack(not even, child, list(tree.successors(child))))
        contracted = stack.contract_branches(child_stacks)
        return TensorStack([contracted], contracted.edges, [])

    children_of_root = list(tree.successors(root))
    half = len(children_of_root) // 2
    odd_half = build_stack(False, root, children_of_root[:half])
    even_half = build_stack(True, root, children_of_root[half:])
    return odd_half.contract_branches_and_root([], even_half)

# Tests

## 3-regular graphs

### $p = 1$

In [7]:
pauli_z = np.diag([1, -1])
tree = nx.dfs_tree(regular_tree(3, 7))
# <Z Z> on the edge between the root and its first child.
evaluation = qaoa_evaluate(
    tree=tree,
    root=(),
    betas=[-np.pi / 4],
    gammas=[np.arctan(1 / np.sqrt(2))],
    observables={(): pauli_z, (0,): pauli_z},
)
# Probability that this edge is cut: (1 - <Z Z>) / 2.
0.5 * (1 - evaluation.tensor.real)

0.6924500897298674

### $p = 2$

In [8]:
pauli_z = np.diag([1, -1])
tree = nx.dfs_tree(regular_tree(3, 7))
# <Z Z> on the edge between the root and its first child.
evaluation = qaoa_evaluate(
    tree=tree,
    root=(),
    betas=[2 * beta for beta in (2.12560098, -0.2923307)],
    gammas=[-0.4878635, 2.24375996],
    observables={(): pauli_z, (0,): pauli_z},
)
# Probability that this edge is cut: (1 - <Z Z>) / 2.
0.5 * (1 - evaluation.tensor.real)

0.7559064492764007

### $p = 3$

In [9]:
pauli_z = np.diag([1, -1])
tree = nx.dfs_tree(regular_tree(3, 7))
# <Z Z> on the edge between the root and its first child.
evaluation = qaoa_evaluate(
    tree=tree,
    root=(),
    betas=[2 * beta for beta in (0.9619, 2.6820, 1.8064)],
    gammas=[2.7197, 5.4848, 2.2046],
    observables={(): pauli_z, (0,): pauli_z},
)
# Probability that this edge is cut: (1 - <Z Z>) / 2.
0.5 * (1 - evaluation.tensor.real)

0.7923984115833747

### $p = 4$

In [10]:
pauli_z = np.diag([1, -1])
tree = nx.dfs_tree(regular_tree(3, 7))
# <Z Z> on the edge between the root and its first child.
evaluation = qaoa_evaluate(
    tree=tree,
    root=(),
    betas=[2 * beta for beta in (5.6836, 1.1365, 5.9864, 4.8714)],
    gammas=[0.4088, 0.7806, 0.9880, 4.2985],
    observables={(): pauli_z, (0,): pauli_z},
)
# Probability that this edge is cut: (1 - <Z Z>) / 2.
0.5 * (1 - evaluation.tensor.real)

0.8168765522352619