In [None]:
import jax.numpy as jnp

from qgrad.qgrad_qutip import basis, to_dm, Displace, destroy, coherent, fidelity, sigmax

In [35]:
def pad_thetas(hilbert_size, thetas):
    """
    Pads zeros to the end of a theta vector to fill it up to the Hilbert space cutoff

    Args:
    -----
        hilbert_size (int): Size of the Hilbert space
        thetas (:obj:`jnp.ndarray`): Vector of angles with at most ``hilbert_size`` entries

    Returns:
    --------
        thetas (:obj:`jnp.ndarray`): The angles followed by zeros, ``hilbert_size`` entries in total

    Raises:
    -------
        ValueError: If ``thetas`` has more entries than ``hilbert_size``

    """
    # Accept plain Python sequences as well, so the result is always an array.
    thetas = jnp.asarray(thetas)
    missing = hilbert_size - len(thetas)
    if missing < 0:
        # A negative pad width would otherwise fail inside jnp.pad with a
        # message that does not mention the actual cause.
        raise ValueError(
            f"thetas has {len(thetas)} entries, which exceeds "
            f"hilbert_size={hilbert_size}"
        )
    if missing:
        thetas = jnp.pad(thetas, (0, missing), mode="constant")
    return thetas

def snap(hilbert_size, thetas):
    """
    SNAP gate matrix

    Builds the diagonal matrix whose i-th diagonal entry is ``exp(1j * theta_i)``,
    where ``thetas`` is first zero-padded to ``hilbert_size`` entries (so the
    padded levels get a phase factor of 1).

    Args:
    -----
        hilbert_size (int): Hilbert space cutoff
        thetas (:obj:`jnp.ndarray`): A vector of theta values to apply SNAP operation

    Returns:
    --------
        op (:obj:`jnp.ndarray`): matrix representing the SNAP gate
    """
    thetas = jnp.asarray(pad_thetas(hilbert_size, thetas))
    # Summing exp(1j * theta_i) * |i><i| over all levels is a diagonal matrix,
    # so build it in one step instead of accumulating hilbert_size dense
    # projectors.
    # NOTE(review): relies on to_dm(basis(n, i)) being the projector |i><i| —
    # confirm against qgrad.
    return jnp.diag(jnp.exp(1j * thetas))

In [40]:
# Hilbert space cutoff passed to snap() by block() below.
N = 10
# Example displacement parameter.
alpha = 0.1
# Example SNAP angles; shorter than N, so snap() zero-pads them.
theta = jnp.asarray((0.1, 0.3, 0.5))

def block(alpha, theta):
    """Single building block, B

    Computes ``D(alpha)^dagger . SNAP(theta) . D(alpha)``, where ``D(alpha)``
    is the matrix returned by ``Displace(N)(alpha)`` and ``SNAP(theta)`` is
    ``snap(N, theta)``. The Hilbert space cutoff is the module-level ``N``.

    Args:
    ----
        alpha (float): Displacement parameter
        theta (jnp.array): SNAP gate parameters

    Returns:
    -------
        blk (jnp.ndarray): One block parameterization of U
    """
    # Use the same cutoff as snap(); a separately hard-coded size would give
    # mismatched matrix shapes as soon as N is changed.
    displace = Displace(N)
    # Evaluate the displacement once and reuse it on both sides.
    disp_op = displace(alpha)
    blk = jnp.dot(snap(N, theta), disp_op)
    # Conjugate transpose written with jnp directly: `dag` is not among the
    # names imported from qgrad.qgrad_qutip in this file.
    disp_op_dag = jnp.conjugate(jnp.transpose(disp_op))
    return jnp.dot(disp_op_dag, blk)