# Introduction

This is a companion tutorial to _Unitary Learning by Gradient Descent_ tutorial, but this time with qgrad (and JAX). It intends to juxtapose the ease of learning the parameters with `qgrad` as compared to doing from scratch. as is shown in the previous tutorial.

In [1]:
import matplotlib.pyplot as plt 
import jax.numpy as jnp
from jax import grad, jit
import tenpy
from jax.random import normal, PRNGKey
import numpy as onp
from qgrad.qgrad_qutip import fidelity, basis
from scipy.stats import unitary_group
from jax.scipy.linalg import expm

In [2]:
def make_dataset(m, d):
    r"""Prepares a dataset of input and output kets to be used for training.
    
    Args:
    ----
        m (int): Number of data points, 80% of which would be used for training
        d (int): Dimension of a (square) unitary matrix to be approximated
    
    Returns:
    --------
        data_points (tuple): tuple of lists containing (numpy arrays of) input and output kets respectively.
    """
    ket_input = []
    ket_output = [] 
    for i in range(m):
        idx = onp.random.randint(0, d)
        ket_input.append(basis(d, idx))
        ket_output.append(jnp.matmul(tar_unitr, ket_input[i]))  #Output data -- action of unitary on a ket states
    
    return (ket_input, ket_output)

m = 1000 # number of training data points
train_len = int(m * 0.8)
d = 2 #dimension of unitary
N = 5 #size of parameter vectors tau and t
tar_unitr = jnp.array(unitary_group.rvs(d))  # Fixed random d-dimensional target unitary matrix that we want to learn                                          
res = make_dataset(m, d)
ket_input, ket_output = res[0], res[1]



In [3]:
A = jnp.array(tenpy.linalg.random_matrix.GUE((d,d))) # tenpy for sampling A and B from GUE
B = jnp.array(tenpy.linalg.random_matrix.GUE((d,d))) 

def make_unitary(N, params):
    r"""Retruns a paramterized unitary matrix.
    
    : math:: \begin{equation}\label{decomp}
                U(\vec{t}, \vec{\tau}) = e^{-iB\tau_{N}}e^{-iAt_{N}} ... e^{-iB\tau_{1}}e^{-iAt_{1}}
             \end{equation}
             
    Args:
    ----
        N (int): Size of the parameter vectors, :math:`\tau` and :math:`\t`
        params (:obj:`np.ndarray`): parameter vector of size :math:`2 * N` where the first half parameters are  
                                   :math:`\vec{t}` params and the second half encodes \vec{\tau}) parameters.
                                   
    Returns:
        unitary (:obj:`np.ndarray`): numpy array representation of paramterized unitary matrix 
    """
    unitary = jnp.eye(d)
    params = jnp.array(params)
    for i in range (N): 
        unitary = jnp.matmul(jnp.matmul(expm(-1j * B * params[i + N][0]),
                                        expm(-1j * A * params[i][0])), unitary)
    
    return unitary 

In [None]:
def test_score(params, x, y):
    """Calculates the avergage fidelity between the predicted and output kets for a given 
       on the whole dataset.
       
       Args:
       ----
           params: parameters :math:`\t` and :math:`\tau` in :math:`U^{\dagger} U(\vec{t},\vec{\tau})`
           x: input kets :math:`|\psi_{l}>` in the dataset 
           y: output kets :math:`U(\vec{t}, \vec{\tau})*|ket\_input>` in the dataset
           
       Returns:
       -------
           fidel (float): fidelity between :math:`U(\vec{t}, \vec{\tau})*|ket\_input>` and
                          the output (label) kets for parameters :math:`\vec{t}, \vec{\tau}`
                          averaged over the entire training set.
       """
    fidel = 0
    for i in range(train_len):
        pred = np.matmul(make_unitary(N, params), x[i])
        step_fidel = fidelity(pred, y[i])
        fidel += step_fidel
        
        
    return fidel/train_len

In [108]:
def loss(params, inputs, outputs):
    loss_val = 0.0
    unit = make_unitary(N, params)
    for k in range(3):
        pred = jnp.dot(unit, inputs[k]) #prediction wth parametrized unitary
        loss_val += jnp.absolute(jnp.real(jnp.dot(outputs[k].conjugate().T, pred)))
    return loss_val

@jit
def cost(params, inputs, outputs):
    r"""Calculates the cost on the whole training dataset.
    
    Args:
    ----
        params: parameters :math:`\t` and :math:`\tau` in :math:`U^{\dagger} U(\vec{t},\vec{\tau})`
        inputs: input kets :math:`|\psi_{l}>` in the dataset
        outputs: output kets :math:`U(\vec{t}, \vec{\tau})*|ket\_input>` in the dataset
    
    Returns:
    -------
        cost (float): cost (evaluated on the entire dataset) of parametrizing 
                     :math:`\tau` in :math:`U^{\dagger} U(\vec{t},\vec{\tau})` with `params`                  
    """
    '''
    loss = 0.0
    for k in range(train_len):
        pred = jnp.dot(make_unitary(N, params), inputs[k]) 
        loss += jnp.absolute(jnp.real(jnp.dot(outputs[k].conjugate().T, pred)))
        # TODO check real and abs in loss above since   
        # it's not in the original paper
    '''

    loss_val = 1 - (1 / train_len) * loss(params, inputs, outputs)
    return loss_val[0][0]

In [109]:
#der_cost = grad(cost)
cost(random.normal(random.PRNGKey(1), (2 * N, 1)), ket_input, ket_output)

DeviceArray(0.99802923, dtype=float32)

In [110]:
grad(cost)(random.normal(random.PRNGKey(1), (2 * N, 1)), ket_input, ket_output)

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

In [81]:
from jax.experimental import loops
from jax import ops 

# Differentiation through loops
In order to capture the structured control-flow one has to use the higher-order JAX operations. Therefore, if the cost function uses loops, one needs to define loops as detailed [here](https://jax.readthedocs.io/en/latest/jax.experimental.loops.html)

In [161]:
def cost_jax(params, inputs, outputs):
    with loops.Scope() as s:
        s.loss = 0.0
        s.params = params
        s.inputs = inputs
        s.outputs = outputs
        s.N = N
        s.train_len = train_len
        for k in range(train_len):
            s.pred = jnp.dot(make_unitary(s.N, s.params), s.inputs[k])
            s.loss += jnp.absolute(jnp.real(jnp.dot(s.outputs[k].conjugate().T, s.pred)))
        s.loss = 1 - (1 / s.train_len) * s.loss
        return s.loss[0][0]

In [96]:
cost_jax(random.normal(random.PRNGKey(1), (2 * N, 1)), ket_input, ket_output)

NameError: name 'cost_jax' is not defined

In [162]:
grad(cost_jax)(onp.random.rand(2 * N, 1), ket_input, ket_output)

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

In [34]:
def test(a):
    sum_ = 0
    for i in range(len(a)):
        #print("a", a[i])
        sum_ += sum_ + a[i]
        #print(sum_)
    return sum_ / 5.

In [35]:
jnp.arange(5.)

DeviceArray([0., 1., 2., 3., 4.], dtype=float32)

In [36]:
test(jnp.arange(5))

DeviceArray(5.2, dtype=float32)

In [38]:
grad(test)(jnp.arange(10.))

DeviceArray([102.4,  51.2,  25.6,  12.8,   6.4,   3.2,   1.6,   0.8,
               0.4,   0.2], dtype=float32)

In [40]:
def test2(n):
    ans = 0
    for i in range(n):
        iden = jnp.eye(2)
        ans = jnp.dot(iden, basis(2))
    return ans

In [46]:
grad(test2)(4.)

TypeError: 'JVPTracer' object cannot be interpreted as an integer

In [90]:
from jax import random

In [91]:
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))

In [92]:
random_flattened_image

DeviceArray([-2.47086883e+00,  6.43303931e-01,  1.98767751e-01,
              1.04226995e+00, -3.28564614e-01, -1.39949054e-01,
             -1.61815274e+00,  5.27156554e-02,  1.56182075e+00,
              2.31161676e-02,  9.45196986e-01, -1.32468688e+00,
             -4.57311004e-01, -1.60654560e-01,  6.61273953e-04,
              3.99095118e-01, -9.17360544e-01, -5.29269695e-01,
             -4.47437018e-01,  8.68972898e-01, -1.78759992e-01,
              6.88066781e-01,  2.00557494e+00, -9.70265865e-01,
              7.63374984e-01,  4.68729347e-01,  1.00560963e+00,
             -8.98162067e-01, -9.02278304e-01, -9.10900295e-01,
              1.15207247e-01, -1.45289779e+00, -1.47773966e-01,
              7.89271474e-01,  4.46381688e-01, -1.34863889e+00,
              6.53553724e-01, -6.93057656e-01, -2.35830415e-02,
             -2.05141592e+00,  1.19544709e+00, -2.32006028e-01,
              1.74441651e-01, -7.04607069e-01, -1.71720314e+00,
              4.92502391e-01,  1.6793657

In [95]:
random.normal(random.PRNGKey(1), (2 * N, 1))

DeviceArray([[ 0.6908049 ],
             [-0.4874411 ],
             [-1.155789  ],
             [ 0.12108456],
             [-0.19598441],
             [-0.5078767 ],
             [ 0.9156865 ],
             [ 1.7096796 ],
             [-0.36749426],
             [ 0.14315681]], dtype=float32)