# Week 2: Numerical Differentiation, Tensors with jax

## Example: Numerical differentiation

In [14]:
def diff(f, x0, eps=1e-6):
    return (f(x0 + eps) - f(x0))/eps

def f(x):
    return x**2

diff(f, 3)

6.000001000927568

## Tensors

In [2]:
# pip install jax jaxlib
import jax
import jax.numpy as jnp

## Vectors and Coordinate transformations

In [22]:
v1 = jnp.array([1, 2, 3], dtype=jnp.float32)
v2 = jnp.array([4, 5, 6], dtype=jnp.float32)

### 3D rotations

In [27]:
def rot_x(theta_x):  # f: R -> R^3
    r_x = jnp.array(
        [[1., 0.              , 0.               ],
         [0., jnp.cos(theta_x), -jnp.sin(theta_x)],
         [0., jnp.sin(theta_x),  jnp.cos(theta_x)]
        ])
    return r_x
    
def rot_y(theta_y): # f: R -> R^3
    r_y = jnp.array(
        [[ jnp.cos(theta_y), 0., jnp.sin(theta_y)],
         [0.               , 1., 0.              ],
         [-jnp.sin(theta_y), 0.,  jnp.cos(theta_y)]
        ])
    return r_y
    
def rot_z(theta_z): # f: R -> R^3
    r_z = jnp.array(
        [[jnp.cos(theta_z), -jnp.sin(theta_z), 0.],
         [jnp.sin(theta_z),  jnp.cos(theta_z), 0.],
         [0.              , 0.               , 1.]
        ])
    return r_z    

In [33]:
v1prime = jnp.einsum('ij,j->i', rot_x(jnp.pi), v1)
print("v1' = R_x(pi).v1: [{:.0f}, {:.0f}, {:.0f}]".format(*v1prime))
print("v1' = R_x(pi).v1: {}".format(v1prime))

v1' = R_x(pi).v1: [1, -2, -3]
v1' = R_x(pi).v1: [ 1.        -1.9999998 -3.0000002]


Look at rotated coordinate system

In [46]:
e1, e2, e3 = jnp.array([1., 0., 0.]), jnp.array([0., 1., 0.]), jnp.array([0., 0., 1.])
o = jnp.einsum('ij,jk,kl->il', rot_x(jnp.sqrt(17)), rot_y(-1.234), rot_z(1/5))
print(o)

[[ 0.32387784 -0.06565329 -0.9438182 ]
 [ 0.65857905 -0.70056975  0.2747286 ]
 [-0.6792473  -0.7105574  -0.18366113]]


In [47]:
b1, b2, b3 = jnp.einsum('ij,j', o, e1), jnp.einsum('ij,j', o, e2), jnp.einsum('ij,j', o, e3)

In [48]:
print("b1={}, b2={}, b3={}".format(b1, b2, b3))

b1=[ 0.32387784  0.65857905 -0.6792473 ], b2=[-0.06565329 -0.70056975 -0.7105574 ], b3=[-0.9438182   0.2747286  -0.18366113]


### Contra-variant vector

For a some random vector $v$ we get

In [63]:
v = 1 * e1 + 2 * e2 - 3 * e3
vprime = jnp.einsum('ij,j', o, v)
print("v  = {}".format(v))
print("v' = {}".format(vprime))

v  = [ 1.  2. -3.]
v' = [ 3.024026  -1.5667462 -1.5493786]


For a contra-variant vector, the trafo $\vec{v}\to \vec{v}'(\vec{v})=A\cdot\vec{v}$ should be given by the Jacobian $A^i_j = J^i_j = \frac{\partial (x^\prime)^i}{\partial x^j}$

In [76]:
from jax import grad
def vprime(v, i):
    return jnp.einsum('ij,j', o, v)[i]
j = jnp.array([[jax.grad(vprime)(v, i)[j] for j in range(len(v))] for i in range(len(v))])
print(j)

[[ 0.32387784 -0.06565329 -0.9438182 ]
 [ 0.65857905 -0.70056975  0.2747286 ]
 [-0.6792473  -0.7105574  -0.18366113]]


Since computing the Jacobian is something that needs to be done very frequently, it has its own function. In fact, it even has two implementations that do the same thing (i.e., both compute the Jacobian). In the case of endomorphisms (i.e., maps from V to V), where we compute $v^\prime_i(v_j)$, this does not matter. 

More generally, we could imagine maps $\mathbb{R}^N\to\mathbb{R}^M$, in which case we should compute $f_i(x_j)$ where $i=1,..,N$ and $j=1,..,M$. 

If $N\leq M$, the $\texttt{jacfwd}$ implementation is faster and if $N>M$, the $\texttt{jacrev}$ implementation is faster.

In [145]:
def vprime(v):
    return jnp.einsum('ij,j', o, v)
j = jax.jacfwd(vprime)(v)
print(j)
print(o)

[[ 0.32387784 -0.06565329 -0.9438182 ]
 [ 0.65857905 -0.70056975  0.2747286 ]
 [-0.6792473  -0.7105574  -0.18366113]]
[[ 0.32387784 -0.06565329 -0.9438182 ]
 [ 0.65857905 -0.70056975  0.2747286 ]
 [-0.6792473  -0.7105574  -0.18366113]]


Finally, let us look at a covariant vector (like the gradient itself). We define some random function $\phi: \mathbb{R}^3\to\mathbb{R}$,
$$\phi(v) = \sin(v_x)\cos(v_y)\tan(v_z)+ e^{-v_x} v_y v_z^2\,,$$ take the gradient $\vec{\nabla}\phi$, and see how it transforms under coordiate changes $v\to v'(v)$

In [141]:
def phi(v):
    x, y, z = v
    return jnp.sin(x) * jnp.cos(y) * jnp.tan(z) + jnp.exp(-x) * y * z**2

def vorig(vp):
    return jnp.einsum('ij,j', jnp.linalg.inv(o), vp)

In [131]:
v  = 1 * e1 + 2 * e2 - 3 * e3
vp = vprime(v)
print('v=              ', v)
print('\phi(v)=        ', phi(v))
print('\\nabla\phi(v)=  ', jax.grad(phi)(v))
print('')
print('v\'=             ', vp)
print('\phi(v\')=       ', phi(vp))
print('\\nabla\phi(v\')= ', jax.grad(phi)(vp))

v=               [ 1.  2. -3.]
\phi(v)=         6.5719137
\nabla\phi(v)=   [-6.653881   3.201846  -4.7718444]

v'=              [ 3.024026  -1.5667462 -1.5493786]
\phi(v')=        -0.20498562
\nabla\phi(v')=  [ 0.37057406 -5.3590355   1.2717581 ]


In [146]:
jtilde = jax.jacfwd(vorig)(jax.grad(phi)(vp))
print(jnp.einsum('ij->ji', jtilde))
print(o)

[[ 0.3238778  -0.06565327 -0.9438182 ]
 [ 0.6585789  -0.7005696   0.27472857]
 [-0.6792472  -0.7105573  -0.1836611 ]]
[[ 0.32387784 -0.06565329 -0.9438182 ]
 [ 0.65857905 -0.70056975  0.2747286 ]
 [-0.6792473  -0.7105574  -0.18366113]]


## Tensor manipulations

In [9]:
v1 = jnp.array([1, 2, 3], dtype=jnp.float32)
v2 = jnp.array([4, 5, 6], dtype=jnp.float32)
m = jnp.array([[1,-2,3], [4,-5,6], [7,-8,9]], dtype=jnp.float32)

In [17]:
s = jnp.einsum('i,i->', v1, v2)
print("Inner product: v1.v2\n{:.0f}\n".format(s))

t = jnp.einsum('i,j->ij', v1, v2)
print("Outer product: v1 x v2\n{:}\n".format(t))

v3 = jnp.einsum('ij,j->i', m, v2)
print("Matrix times vector: m.v1:\n{:}\n".format(v3))

n = jnp.einsum('ij,jk->ik', m, m)
print("Matrix times matrix: m.m:\n{:}\n".format(n))

u = jnp.einsum('ij,kl,m,k->ijlm', m, m, v1, v2)
print("More complicated tensor: \sum_k m_ij.m_kl v1_m v2_l = u_ijlm:\n{:}\n".format(u))

Inner product: v1.v2
32

Outer product: v1 x v2
[[ 4.  5.  6.]
 [ 8. 10. 12.]
 [12. 15. 18.]]

Matrix times vector: m.v1:
[12. 27. 42.]

Matrix times matrix: m.m:
[[ 14. -16.  18.]
 [ 26. -31.  36.]
 [ 38. -46.  54.]]

More complicated tensor: \sum_k m_ij.m_kl v1_m v2_l = u_ijlm:
[[[[   66.   132.   198.]
   [  -81.  -162.  -243.]
   [   96.   192.   288.]]

  [[ -132.  -264.  -396.]
   [  162.   324.   486.]
   [ -192.  -384.  -576.]]

  [[  198.   396.   594.]
   [ -243.  -486.  -729.]
   [  288.   576.   864.]]]


 [[[  264.   528.   792.]
   [ -324.  -648.  -972.]
   [  384.   768.  1152.]]

  [[ -330.  -660.  -990.]
   [  405.   810.  1215.]
   [ -480.  -960. -1440.]]

  [[  396.   792.  1188.]
   [ -486.  -972. -1458.]
   [  576.  1152.  1728.]]]


 [[[  462.   924.  1386.]
   [ -567. -1134. -1701.]
   [  672.  1344.  2016.]]

  [[ -528. -1056. -1584.]
   [  648.  1296.  1944.]
   [ -768. -1536. -2304.]]

  [[  594.  1188.  1782.]
   [ -729. -1458. -2187.]
   [  864.  1728.  2592

Note that this is pretty fast even for big matrices. Let for example
$$ m\in\mathbb{R}^{10000\times1000}\,,\qquad v_1, v_2\in\mathbb{R}^{1000}$$
We compute 
$$ s=v_1^T \cdot m^T \cdot m \cdot v_2 $$

In [15]:
import time
key = jax.random.PRNGKey(seed=1)
m = jax.random.uniform(key, shape=(10000, 1000))
v1 = jax.random.uniform(key, shape=(1000,))
v2 = jax.random.uniform(key, shape=(1000,))
t1 = time.time()
s = jnp.einsum('i,ji,jk,k->', v1, m, m, v2)
t2 = time.time()
print("The result is {:.3f}. The computation took {:.3f} seconds".format(s, t2-t1))

The result is 591693504.000. The computation took 0.029 seconds
