# Introduction

In this example, we shall try to learn an arbitrary $2 \times 2$ 
unitary matrix $U$ via gradient descent. We start with a 
parameterized unitary matrix $U(\phi, \theta, \omega)$ with randomly
initialized parameters, which is a sequence of rotations about the $Z$, $Y$, and $Z$ axes.
The $2 \times 2$ unitary matrix takes the form


$U(\phi, \theta, \omega) = RZ(\omega)RY(\theta)RZ(\phi)= \begin{bmatrix}
 e^{-i(\phi+\omega)/2}\cos(\theta/2) & -e^{i(\phi-\omega)/2}\sin(\theta/2) \\
 e^{-i(\phi-\omega)/2}\sin(\theta/2) & e^{i(\phi+\omega)/2}\cos(\theta/2)
 \end{bmatrix}$
 
 
This parameterization is available in `qgrad` as `rot`.
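As a quick numerical check of the equation above, the following minimal sketch composes the three rotations and compares the product with the closed-form matrix. It assumes the convention $RZ(a) = e^{-iaZ/2}$ and $RY(a) = e^{-iaY/2}$, which is the one the matrix above implies. The helper names `rz`, `ry`, and `zyz_closed_form` are hypothetical; this does not call `qgrad`'s `rot`, it only mirrors the formula.

```python
import jax.numpy as jnp


def rz(a):
    """RZ(a) = exp(-i a Z / 2) = diag(e^{-ia/2}, e^{ia/2}) (assumed convention)."""
    return jnp.diag(jnp.exp(jnp.array([-0.5j, 0.5j]) * a))


def ry(a):
    """RY(a) = exp(-i a Y / 2), which is a real 2x2 rotation matrix."""
    c, s = jnp.cos(a / 2), jnp.sin(a / 2)
    return jnp.array([[c, -s], [s, c]], dtype=jnp.complex64)


def zyz_closed_form(phi, theta, omega):
    """The entries of U(phi, theta, omega) exactly as written in the equation above."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array(
        [
            [jnp.exp(-0.5j * (phi + omega)) * c, -jnp.exp(0.5j * (phi - omega)) * s],
            [jnp.exp(-0.5j * (phi - omega)) * s, jnp.exp(0.5j * (phi + omega)) * c],
        ]
    )


phi, theta, omega = 0.3, 1.1, -0.8          # arbitrary test angles
product = rz(omega) @ ry(theta) @ rz(phi)   # RZ(phi) acts first, RZ(omega) last
closed = zyz_closed_form(phi, theta, omega)

print("product matches closed form:", bool(jnp.allclose(product, closed, atol=1e-5)))
print("unitary:", bool(jnp.allclose(closed.conj().T @ closed, jnp.eye(2), atol=1e-5)))
print("determinant:", complex(jnp.linalg.det(closed)))
```

The determinant is always 1 (each factor has unit determinant), so $U(\phi, \theta, \omega)$ lies in $SU(2)$; an arbitrary target $U$ can therefore only be matched up to a global phase.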


Here the input dataset consists of $2 \times 1$ kets, which we denote
$|\Psi_{i}\rangle$, and the output dataset is the action of the 
target unitary $U$ on these kets, $U |\Psi_{i}\rangle$. The 
index $i$ runs up to $80$, so we use only 80
data points (kets, in this case) to learn the 
target unitary $U$.
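The training setup described so far can be sketched end to end. This is a minimal sketch under our own assumptions, not the notebook's code: it uses plain gradient descent instead of the `optimizers` module imported in the code cell, draws the target unitary and the 80 kets with `jax.random` instead of `scipy.stats.unitary_group` and `qutip.rand_ket`, and takes the cost to be one minus the average fidelity $|\langle \Psi_i | U(\phi, \theta, \omega)^\dagger U | \Psi_i \rangle|^2$. The helper names (`zyz`, `random_unitary`, `random_kets`, `cost`, `step`), the learning rate, and the step count are hypothetical choices.

```python
import jax
import jax.numpy as jnp

N_KETS = 80           # size of the training set, as in the text
LEARNING_RATE = 0.5   # assumed step size for plain gradient descent
N_STEPS = 1000        # assumed number of updates


def zyz(params):
    """U(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi), using the closed form above."""
    phi, theta, omega = params[0], params[1], params[2]
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array(
        [
            [jnp.exp(-0.5j * (phi + omega)) * c, -jnp.exp(0.5j * (phi - omega)) * s],
            [jnp.exp(-0.5j * (phi - omega)) * s, jnp.exp(0.5j * (phi + omega)) * c],
        ]
    )


def random_unitary(key):
    """Haar-random 2x2 unitary: QR of a complex Gaussian matrix, with column phases fixed."""
    k_re, k_im = jax.random.split(key)
    z = jax.random.normal(k_re, (2, 2)) + 1j * jax.random.normal(k_im, (2, 2))
    q, r = jnp.linalg.qr(z)
    d = jnp.diag(r)
    return q * (d / jnp.abs(d))


def random_kets(key, n):
    """n random normalized 2x1 kets, stored as the rows of an (n, 2) array."""
    k_re, k_im = jax.random.split(key)
    psi = jax.random.normal(k_re, (n, 2)) + 1j * jax.random.normal(k_im, (n, 2))
    return psi / jnp.linalg.norm(psi, axis=1, keepdims=True)


key_target, key_kets, key_init = jax.random.split(jax.random.PRNGKey(0), 3)
target = random_unitary(key_target)       # the unknown U to be learned
inputs = random_kets(key_kets, N_KETS)    # row i is |Psi_i>
outputs = inputs @ target.T               # row i is U|Psi_i>


def cost(params):
    """1 - mean fidelity between U(params)|Psi_i> and U|Psi_i> over the training set."""
    predicted = inputs @ zyz(params).T
    overlaps = jnp.sum(jnp.conj(outputs) * predicted, axis=1)
    return 1.0 - jnp.mean(jnp.real(overlaps * jnp.conj(overlaps)))


@jax.jit
def step(params):
    """One plain gradient-descent update of (phi, theta, omega)."""
    return params - LEARNING_RATE * jax.grad(cost)(params)


params = jax.random.uniform(key_init, (3,), minval=0.0, maxval=2 * jnp.pi)
initial_cost = float(cost(params))
for _ in range(N_STEPS):
    params = step(params)
final_cost = float(cost(params))

# |Tr(V^dagger U)| / 2 equals 1 exactly when V = e^{i alpha} U for some global phase alpha.
learned = zyz(params)
phase_free_overlap = float(jnp.abs(jnp.trace(learned.conj().T @ target)) / 2)
print(f"cost {initial_cost:.4f} -> {final_cost:.2e}; |Tr(V^dag U)|/2 = {phase_free_overlap:.6f}")
```

Success is measured by $|\mathrm{Tr}(V^\dagger U)|/2$ rather than by comparing matrices entry by entry, because fidelity ignores global phase and $U(\phi, \theta, \omega)$ always has determinant 1 while the random target generally does not.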


This tutorial differs from the 
[Qubit Rotation](https://github.com/qgrad/qgrad/blob/master/examples/QubitRotation.py)
tutorial in that the learned unitary does not merely take one _specific_
state to another _fixed_ state. Here the unitary
$U(\phi, \theta, \omega)$ is learned so that it evolves _any_ ket of the
same dimension the way the target unitary $U$ would, up to a global phase.
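To see why a single state-to-state pair is not enough, consider the Z rotation $RZ(a) = \mathrm{diag}(e^{-ia/2}, e^{ia/2})$ alone. It sends $|0\rangle$ to $e^{-ia/2}|0\rangle$, so its fidelity with $|0\rangle$ is 1 for every angle $a$, while on $|+\rangle = (|0\rangle + |1\rangle)/\sqrt{2}$ the fidelity is $\cos^2(a/2)$. A minimal sketch, with hypothetical helper names `rz` and `state_fidelity`:

```python
import jax.numpy as jnp


def rz(a):
    """RZ(a) = diag(e^{-ia/2}, e^{ia/2}), the Z rotation used in U(phi, theta, omega)."""
    return jnp.diag(jnp.exp(jnp.array([-0.5j, 0.5j]) * a))


def state_fidelity(psi, chi):
    """|<psi|chi>|^2 for normalized kets stored as 1-D complex arrays."""
    return float(jnp.abs(jnp.vdot(psi, chi)) ** 2)


ket_zero = jnp.array([1.0, 0.0], dtype=jnp.complex64)
ket_plus = jnp.array([1.0, 1.0], dtype=jnp.complex64) / jnp.sqrt(2.0)

results = []
for a in [0.0, 0.5, 1.5, 3.0]:
    u = rz(a)
    f_zero = state_fidelity(ket_zero, u @ ket_zero)  # stays 1 for every a
    f_plus = state_fidelity(ket_plus, u @ ket_plus)  # equals cos^2(a / 2)
    results.append((a, f_zero, f_plus))
    print(f"a = {a:.1f}: F(|0>) = {f_zero:.4f}, F(|+>) = {f_plus:.4f}")
```

Every angle $a$ is a perfect solution for the fixed pair $|0\rangle \to |0\rangle$, but only multiples of $2\pi$ (where $RZ(a) = \pm I$) also act as the identity on $|+\rangle$. Training on many kets removes this kind of ambiguity.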


**Note**: Another version of this tutorial, implemented 
without `qgrad`, uses the parameterization from 
[Seth Lloyd and Reevu Maity, 2020](https://arxiv.org/pdf/1901.03431.pdf)
and verifies the results of that paper. This tutorial 
shows similar results with a different unitary parameterization,
the $U(\phi, \theta, \omega)$ shown above, because the
parameterization in the original paper places Hamiltonians 
in exponents (matrix exponentials), whose automatic differentiation is
not currently supported in JAX. For further reading 
on this incompatibility, please 
refer to the companion 
[blog post](https://araza6.github.io/posts/hamiltonian-differentiation/).


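```python
import jax.numpy as jnp
from jax import grad
from jax.random import PRNGKey, uniform

# The optimizers module moved from jax.experimental to jax.example_libraries
# in newer JAX releases; fall back to the old location on older versions.
try:
    from jax.example_libraries import optimizers
except ImportError:
    from jax.experimental import optimizers

import numpy as onp

# Visualization
import matplotlib.pyplot as plt
from matplotlib import cm

from qgrad.qgrad_qutip import fidelity, rot, Unitary

from qutip import rand_ket  # only to make the dataset

from scipy.stats import unitary_group
```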