In [5]:
from qgrad.qgrad_qutip import basis, to_dm, dag, Displace, fidelity
import jax.numpy as jnp
from jax import grad, jit
from functools import reduce
import matplotlib.pyplot as plt

In [6]:
def pad_thetas(hilbert_size, thetas):
    """
    Pads zeros to the end of a theta vector to fill it upto the Hilbert space cuttoff
    
    Args:
    -----
        hilbert_size (int): Size of the hilbert space
        thetas (:obj:`jnp.ndarray`): List of angles thetas
    
    Returns:
    --------
        thetas (:obj:`jnp.ndarray`): List of angles padded with zeros in place of Hilbert space cutoff

    """
    if len(thetas) != hilbert_size:
        thetas = jnp.pad(thetas, (0, hilbert_size - len(thetas)), mode="constant")
    return thetas

def snap(hilbert_size, thetas):
    """
    SNAP gate matrix
    
    Args:
    -----
        hilbert_size (int): Hilbert space cuttoff
        thetas (:obj:`jnp.ndarray`): A vector of theta values to apply SNAP operation
    
    Returns:
    --------
        op (:obj:`jnp.ndarray`): matrix representing the SNAP gate
    """
    thetas = pad_thetas(hilbert_size, thetas)
    op = 0 * jnp.eye(hilbert_size)
    for i, theta in enumerate(thetas):
        op += jnp.exp(1j * theta) * to_dm(basis(hilbert_size, i))
    return op

In [7]:
N = 10 # dimension of Fock space
alphas = jnp.array([1., 0.5, 1.], dtype=complex)
theta1, theta2, theta3 = [0.5], [0.5, 0, -1.5], [0.5, -1.5j, 0.2j] 
# NOTE: No input values to JAX differentiable functions should be int                                  
theta1, theta2, theta3 = pad_thetas(N, theta1), pad_thetas(N, theta2), pad_thetas(N, theta3)


In [30]:
displace = Displace(N)

def add_blocks(initial, T, hilbert_size, params):
    """Applies T blocks of operators to the initial state, where each of the T blocks is
       a collection of operators, 

           :: math : `D(\alpha) SNAP(\theta) D(-\alpha),
   
       where :math:`D(\alpha)` is the displacement operator with displacement :math:`alpha` 
       and :math:`SNAP(\theta)` is the SNAP gate with parameter vector :math:`\theta` of length N, 
       the size of the Hilbert space.

    Args:
    ------
        initial (jnp.ndarray): initial state to apply blocks on (ket |0> in our case)
        T (int): number of blocks to apply
        hilbert_size (int): Size of the Hilbert space
        params (jnp.ndarray): parameter array of alphas and that's of size :math: `T * hilbert_size + T`, 
        wherein the first T parameters are alphas and the rest are T hilbert_size-dimensional vectors 
        representing corresponding theta vectors.

    Returns:
    -----------
         evolved (jnp.array): hilbert_size * 1-dimensional array representing the action of the 
         blocks on the vacuum state
    """
    for t in range(T):
        blk = jnp.dot(displace(params[0 + t]), initial)
        print(jnp.linalg.norm(blk))
        blk = jnp.dot(snap(hilbert_size, params[T + (t * hilbert_size) : T + (t + 1) * hilbert_size]), blk)
        print(jnp.linalg.norm(blk))
        initial = jnp.dot(displace(-params[0 + t]), blk) # displace(alpha)^{\dagger} = displace(-alpha)
        print(jnp.linalg.norm(initial))

    evolved = initial
    return evolved

def cost(params, initial, target, T, hilbert_size):
    """Calculates the cost, in this case fidelity, between the target state and 
    the one evolved by the action of three blocks.
    
    Args:
    -----
        params (jnp.array): alpha and theta params of Displace and SNAP respectively, with first three
                           being alpha and rest three being theta for each of the SNAP
        initial (jnp.array): initial state to apply the blocks on
        target (jnp.array): desired state
    
    Returns:
    --------
        fidelity (float): Fidelity between the target state and the evolved state
    """
    '''
    tmp = jnp.dot(displace(params[0]), initial)
    tmp = jnp.dot(snap(N, params[3:3+N]), tmp)
    tmp = jnp.dot(displace(-params[0]), tmp)
    
    tmp = jnp.dot(displace(params[1]), tmp)
    tmp = jnp.dot(snap(N, params[3+N:3+2*N]), tmp)
    tmp = jnp.dot(displace(-params[1]), tmp)
    
    tmp = jnp.dot(displace(params[2]), tmp)
    tmp = jnp.dot(snap(N, params[3+2*N:]), tmp)
    tmp = jnp.dot(displace(-params[2]), tmp)
    
    return 1 - fidelity(target, tmp)[0][0]
    '''
    evo = add_blocks(initial, T, hilbert_size, params)
    return 1 - fidelity(target, evo)[0][0]

In [49]:
epochs = 10
T = 3
N = 10
lr = 0.01 #learning rate
tol = 1e-7
diff = 1 # diff of new and prev weights should be less than diff
max_iters = 2
iters = 0
params = jnp.concatenate((alphas, theta1, theta2, theta3)).reshape(3 * N + 3, 1)
der_cost = grad(cost) #autodiff of the cost function
initial = basis(N, 0)
target = (jnp.sqrt(3) * basis(10, 3) +  basis(10, 9)) / 2.0

for epoch in range (epochs):
    iters = 0
    diff = 1
    tol = 1e-7
    while iters < max_iters:
        prev_params = params
        der = der_cost(prev_params, initial, target, T, N)
        params = prev_params - lr * der 
        iters += 1
        diff = jnp.absolute(params - prev_params)
    cost_val = cost(params, initial, target, T, N)
    progress = [epoch+1, cost_val]
    if ((epoch) % 1 == 0):
        print("Epoch: {:2f} | Cost: {:3f}".format(*jnp.asarray(progress)))

Epoch: 1.000000 | Cost: -0.049726
Epoch: 2.000000 | Cost: 0.034475
Epoch: 3.000000 | Cost: 0.123562
Epoch: 4.000000 | Cost: 0.214488
Epoch: 5.000000 | Cost: 0.304598
Epoch: 6.000000 | Cost: 0.392123
Epoch: 7.000000 | Cost: 0.475542
Epoch: 8.000000 | Cost: 0.553245
Epoch: 9.000000 | Cost: 0.623634
Epoch: 10.000000 | Cost: 0.685484


In [40]:
initial = basis(N, 0)
target = (jnp.sqrt(3) * basis(10, 3) +  basis(10, 9)) / 2.0
params = jnp.concatenate((alphas, theta1, theta2, theta3)).reshape(3 * N + 3, 1)

tmp = jnp.dot(displace(params[0]), initial)
tmp = jnp.dot(snap(N, params[3:3+N]), tmp)
tmp = jnp.dot(displace(-params[0]), tmp)
#print(jnp.linalg.norm(tmp))

tmp = jnp.dot(displace(params[1]), tmp)
tmp = jnp.dot(snap(N, params[3+N:3+2*N]), tmp)
tmp = jnp.dot(displace(-params[1]), tmp)
#print(jnp.linalg.norm(tmp))

tmp = jnp.dot(displace(params[2]), tmp)
print(jnp.linalg.norm(tmp))
tmp = jnp.dot(snap(N, params[3+2*N:3+3*N]), tmp)
print(jnp.linalg.norm(tmp))
tmp = jnp.dot(displace(-params[2]), tmp)

print(jnp.linalg.norm(tmp))


0.9999977
2.8582945
2.8582935


In [31]:
add_blocks(initial, 3, 10, params)

0.9999995
0.9999995
0.9999991
0.99999857
0.9999985
0.99999815
0.9999977
2.8582945
2.8582935


DeviceArray([[ 1.7000009 +1.0685443j ],
             [ 0.25991967-0.48506147j],
             [-1.0563033 -0.26160926j],
             [ 1.1297278 +0.44781005j],
             [-0.82733935-0.343898j  ],
             [ 0.49013975+0.20022798j],
             [-0.25088373-0.09952512j],
             [ 0.11501405+0.04432671j],
             [-0.04734494-0.01780776j],
             [ 0.02078515+0.00759719j]], dtype=complex64)

In [28]:
jnp.linalg.norm(add_blocks(initial, params, 3, 10))

0.9999995
0.9999995
0.9999991
0.99999857
0.9999985
0.99999815
0.9999977
2.8582945
2.8582935


DeviceArray(2.8582935, dtype=float32)

# Decreasing Norm
With continued application of each block, actually each operation within each block, the norm of the ket seems to be decreasing