# Tracing the results of Haug and Kim (PRL 133, 050603)

In [None]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import numpy as np
import scipy
import jax
import jax.numpy as jnp
import qujax

In [None]:
jax.config.update('jax_enable_x64', True)

## DQFIM of hardware-efficient ansatz with Haar-random training states

The goal of this exercise is to reproduce the blue lines of Figure 2(a). We will use qujax to construct
the hardware-efficient (HE) ansatz unitary $U(\theta)$ as a function of the $R_y$ and $R_z$ gate
parameters $\theta$, and directly compute the DQFIM $Q_{nm}$ of Equation (4) using Haar-random input
states for $\rho_L$.

In [None]:
def he_ansatz_args(num_qubits, num_layers):
    """Return the HE ansatz circuit as a (gates, qubit indices, params indices) argument to qujax.

    Each layer consists of, in this order:

    * one ``Ry`` on every qubit (parameters ``layer_offset + iq``),
    * one ``Rz`` on every qubit (parameters ``layer_offset + num_qubits + iq``),
    * ``CX`` on the pairs ``(0, 1), (2, 3), ...`` followed by ``CX`` on the pairs
      ``(1, 2), (3, 4), ..., (num_qubits - 1, 0)`` (the last pair wraps around).

    where ``layer_offset = 2 * num_qubits * layer_index``. The circuit therefore has
    ``2 * num_qubits * num_layers`` parameters in total.

    Args:
        num_qubits: Number of qubits; must be even.
        num_layers: Number of times the layer is repeated.

    Returns:
        A tuple ``(gates, qubit_inds, params_inds)`` of three equally long lists: the gate
        names, the qubit indices each gate acts on, and the parameter indices each gate
        reads (empty for ``CX``). Every inner list is a distinct object, so the result can
        be modified entry by entry without affecting other entries.

    Raises:
        ValueError: If ``num_qubits`` is odd.
    """
    if num_qubits % 2:
        raise ValueError('HE ansatz is only defined for even Nq')

    params_per_layer = 2 * num_qubits
    # Entangling pattern of one layer: pairs starting on even qubits first, then pairs
    # starting on odd qubits (the final odd pair wraps around to qubit 0).
    cx_pairs = [
        (iq, (iq + 1) % num_qubits)
        for start in (0, 1)
        for iq in range(start, num_qubits, 2)
    ]

    gates = []
    qubit_inds = []
    params_inds = []
    for il in range(num_layers):
        layer_offset = il * params_per_layer
        # Single-qubit rotations: all Ry first, then all Rz, each with its own parameter
        for gate, gate_offset in (('Ry', layer_offset), ('Rz', layer_offset + num_qubits)):
            for iq in range(num_qubits):
                gates.append(gate)
                qubit_inds.append([iq])
                params_inds.append([gate_offset + iq])
        # Parameter-free entangling gates; build fresh lists so that no entries are aliased
        for iq, iq_next in cx_pairs:
            gates.append('CX')
            qubit_inds.append([iq, iq_next])
            params_inds.append([])
    return gates, qubit_inds, params_inds
    

In [None]:
# Circuit size used by the cells below
num_qubits = 4
num_layers = 6
# Build the circuit description once and reuse it for the printout and the unitary
_ansatz_args = he_ansatz_args(num_qubits, num_layers)
# Print the circuit to confirm that it has the intended layout
qujax.print_circuit(*_ansatz_args)
# Function mapping the parameter values to the unitary of the circuit
unitary = qujax.get_params_to_unitarytensor_func(*_ansatz_args)
# Forward-mode Jacobian of the unitary with respect to the parameters.
# holomorphic=True means the parameters must be passed with a complex dtype.
d_unitary = jax.jacfwd(unitary, holomorphic=True)

In [None]:
# Shared NumPy random generator used to draw the gate parameters in dqfim_rank_mean.
# No seed is given, so the sampled parameter sets differ from run to run.
rng = np.random.default_rng()

def dqfim(params, rho):
    """Compute Q_nm for given theta and rho.

    Evaluates

        Q_nm = 4 Re[ tr(d_n U rho d_m U^dagger) - tr(d_n U rho U^dagger) tr(U rho d_m U^dagger) ]

    where U is the matrix returned by the module-level ``unitary`` and d_n U is its
    derivative with respect to parameter n, taken from the module-level ``d_unitary``.

    Args:
        params: 1D array of gate parameters. Must have a complex dtype because
            ``d_unitary`` is a holomorphic Jacobian.
        rho: Matrix of shape (2 ** num_qubits, 2 ** num_qubits) representing the input state.

    Returns:
        Real array of shape (num_params, num_params).
    """
    dim = 2 ** num_qubits
    u = unitary(params).reshape(dim, dim)
    du = d_unitary(params).reshape(dim, dim, -1)
    u_conj = u.conjugate()
    du_conj = du.conjugate()

    # First term: tr(d_n U rho d_m U^dagger)
    qmat = jnp.einsum('ijn,jk,ikm->nm', du, rho, du_conj)
    # Second term: the two traces are vectors over n and m respectively. They must be
    # combined as an outer product; a plain `*` would multiply them element-wise and the
    # result would then be broadcast across the rows of qmat.
    tr_du_rho_udag = jnp.einsum('ijn,jk,ik->n', du, rho, u_conj)
    tr_u_rho_dudag = jnp.einsum('ij,jk,ikm->m', u, rho, du_conj)
    qmat = qmat - tr_du_rho_udag[:, None] * tr_u_rho_dudag[None, :]
    return 4. * qmat.real

def dqfim_rank(params, rho):
    """Return the matrix rank of the DQFIM Q_nm evaluated at (params, rho)."""
    qmat = dqfim(params, rho)
    return jnp.linalg.matrix_rank(qmat)

# Batched rank calculation: maps over the leading axis of params while sharing a single rho,
# so that the rank can be averaged over many random theta values
v_dqfim_rank = jax.vmap(dqfim_rank, in_axes=(0, None))

def dqfim_rank_mean(rho, num_pset=10):
    """Average the DQFIM rank over randomly drawn parameter sets.

    Args:
        rho: Input state matrix passed through to the rank calculation.
        num_pset: Number of random parameter sets to average over.

    Returns:
        Mean of the ranks obtained for the ``num_pset`` parameter sets.
    """
    # Two parameters (Ry and Rz) per qubit per layer
    num_params = num_qubits * 2 * num_layers
    # Parameters are drawn uniformly from [0, 2), the range qujax is expected to use
    # (presumably angles in units of pi -- confirm against the qujax documentation).
    # Adding 0.j promotes the array to a complex dtype, which dqfim requires.
    params_set = 2. * rng.random((num_pset, num_params)) + 0.j
    ranks = np.asarray(v_dqfim_rank(params_set, rho))
    return np.mean(ranks)

In [None]:
# Number of Haar-random training states averaged into rho
num_samples = 1
# Hilbert space dimension
hdim = 2 ** num_qubits
# Draw random unitaries and force a (num_samples, hdim, hdim) shape so that the indexing
# below works for any num_samples
rand_uni = scipy.stats.unitary_group.rvs(hdim, size=num_samples)
rand_uni = rand_uni.reshape((num_samples, hdim, hdim))
# The first column of each random unitary serves as a random state vector
states = rand_uni[..., 0]
# Equal-weight mixture of the projectors |psi_s><psi_s|
rho = np.einsum('sa,sb->ab', states, states.conjugate()) / num_samples
dqfim_rank_mean(rho)