In [None]:
from discopy.quantum import *
from discopy.drawing import draw
import numpy as np
import sympy as sy
import itertools
from matplotlib import pyplot as plt

In [None]:
def symbols_struct(prefix: str, shape):
    """Create a numpy array of sympy symbols arranged in ``shape``.

    The symbols are generated with sympy's range syntax. For example,
    ``prefix='t'`` and ``shape=(2, 3)`` produce the spec string
    ``'t_:2_:3'``. The flat tuple returned by ``sympy.symbols`` is then
    reshaped into ``shape``.

    Parameters
    ----------
    prefix : str
        Name prefix shared by every generated symbol.
    shape : int or iterable of int
        Target array shape. A bare integer is treated as the 1-d shape
        ``(shape,)``.

    Returns
    -------
    numpy.ndarray
        Array of the requested shape whose entries are sympy symbols.
    """
    # Accept a bare integer as shorthand for a 1-d shape; iterating over
    # an int would otherwise raise TypeError.
    if np.ndim(shape) == 0:
        shape = (shape,)
    # Materialise once: the spec string and the reshape both need the
    # dimensions, so a one-shot iterator must not be consumed twice.
    shape = tuple(shape)
    sym_spec = '_'.join(f':{k}' for k in shape)
    return np.reshape(sy.symbols(f'{prefix}_{sym_spec}'), shape)

def array_subs(array, subs):
    """Placeholder for substituting ``subs`` into the entries of ``array``.

    Not implemented yet. It always raises ``NotImplementedError``,
    whatever ``array`` is. Previously an unsized input raised a
    ``TypeError`` from an unused ``len()`` call first.

    Parameters
    ----------
    array : numpy.ndarray or sequence
        Container whose entries would receive the substitution.
    subs : object
        Substitution specification. NOTE(review): presumably a sympy-style
        mapping or list of (old, new) pairs; confirm before implementing.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("array_subs is not implemented yet")

def _prepare_vars(variables):
    variables = np.asarray(variables, dtype=np.object).flatten()
    variables = [sy.symbols(v) if isinstance(v, str) else v\
                 for v in variables]
    return variables

def jacobian(circ, variables):
    """Differentiate each output amplitude of ``circ`` w.r.t. each variable.

    Output bitstrings are enumerated over the ``len(circ.cod)`` output
    qubits in lexicographic order: all zeros first, last qubit varying
    fastest. For bitstring ``bits_i`` and variable ``v_j``, entry
    ``[i, j]`` is ``(circ >> Bra(*bits_i)).grad(v_j)``.

    Parameters
    ----------
    circ : discopy.quantum.Circuit
        Circuit with an empty domain (``circ.dom == qubit ** 0``).
    variables : variable, sequence or array
        Differentiation variables. Strings become sympy symbols and nested
        containers are flattened (see ``_prepare_vars``).

    Returns
    -------
    numpy.ndarray
        Object array of shape ``(2 ** len(circ.cod), n_variables)``
        holding the gradient diagrams.

    Raises
    ------
    AssertionError
        If ``circ`` does not have an empty domain.
    NotImplementedError
        If ``circ`` is not a ``Circuit``.
    """
    from discopy.quantum import Circuit, Bra, qubit
    assert circ.dom == qubit**0
    variables = _prepare_vars(variables)
    if not isinstance(circ, Circuit):
        raise NotImplementedError
    # Every computational-basis bitstring over the output qubits, as tuples
    # of plain Python ints.
    bras = list(itertools.product((0, 1), repeat=len(circ.cod)))
    # Preallocate and assign element-wise so numpy stores each diagram as an
    # opaque object rather than trying to interpret it as a nested sequence.
    jac = np.empty((len(bras), len(variables)), dtype=object)
    for i, bits in enumerate(bras):
        closed = circ >> Bra(*bits)
        for j, var in enumerate(variables):
            jac[i, j] = closed.grad(var)
    return jac

def grad(circ, variables):
    """Return the first derivatives of ``circ``, one per variable.

    ``variables`` is normalised by ``_prepare_vars`` (flattened, with
    strings converted to sympy symbols). The k-th list entry is
    ``circ.grad`` applied to the k-th normalised variable.
    """
    derivatives = []
    for var in _prepare_vars(variables):
        derivatives.append(circ.grad(var))
    return derivatives

def hessian(circ, variables):
    """Compute the matrix of second derivatives of a scalar circuit.

    Entry ``[i, j]`` is ``grad(circ, variables)[i].grad(v_j)``: the
    derivative w.r.t. ``v_i`` (via ``grad``) followed by the derivative
    w.r.t. ``v_j``.

    Parameters
    ----------
    circ : discopy diagram
        Scalar diagram; its codomain must be ``qubit ** 0`` or ``Dim(1)``.
    variables : variable, sequence or array
        Differentiation variables, normalised by ``_prepare_vars``.

    Returns
    -------
    numpy.ndarray
        Object array of shape ``(n_variables, n_variables)``.

    Raises
    ------
    AssertionError
        If ``circ`` is not scalar-valued.
    """
    variables = _prepare_vars(variables)
    from discopy.quantum import qubit
    from discopy.tensor import Dim
    assert circ.cod in (qubit**0, Dim(1))
    first_derivs = grad(circ, variables)
    # ``np.object`` was removed in NumPy 1.24. Assigning element-wise keeps
    # numpy from interpreting the diagrams as nested sequences.
    hess = np.empty((len(variables), len(variables)), dtype=object)
    for i, d_i in enumerate(first_derivs):
        for j, var in enumerate(variables):
            hess[i, j] = d_i.grad(var)
    return hess

In [None]:
# 2x2 object array of sympy symbols with prefix 't' (built by symbols_struct).
params = symbols_struct('t', (2, 2))
# Start from Ket(0, 0) and apply discopy's real-amplitude ansatz with linear
# entanglement. NOTE(review): how the (2, 2) parameter shape maps to layers
# and qubits is defined by real_amp_ansatz; confirm in the discopy docs.
c = Ket(0, 0) >> real_amp_ansatz(params, entanglement='linear')
draw(c)

In [None]:
# Jacobian of the ansatz: one row per output bitstring (2 ** len(c.cod)) and
# one column per flattened entry of params (params.size).
jacobian(c, params).shape

In [None]:
# hessian asserts a scalar codomain. Closing c with Bra(0, 0) provides one,
# assuming c has two output qubits. Result shape is (params.size, params.size).
hessian(c >> Bra(0, 0), params).shape

In [None]:
# Minimal two-parameter example: independent Ry rotations on Ket(0, 0), then CX.
# TODO: move this import into the imports cell at the top of the notebook.
from sympy.abc import alpha, beta
circ = Ket(0, 0) >> (Ry(alpha) @ Ry(beta)) >> CX
# One row per output bitstring of circ, one column per variable (2 here).
jacobian(circ, [alpha, beta]).shape