## JVP in JAX

A JVP (Jacobian-vector product) computes `J @ v` given a function, an evaluation point (the primals), and a tangent
vector `v`. It does so without materializing the full Jacobian `J`.

In [2]:
import jax.numpy as jnp
from jax import make_jaxpr, jvp, linearize
from functools import partial

In [3]:
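def my_fun(x):
  a = x ** 2.0                     # elementwise square, shape (1,)
  b = a @ jnp.array([[1.0, 2.0]])  # (1,) @ (1, 2) -> (2,)
  c = jnp.sin(b)                   # elementwise sine, shape (2,)
  # (2,) @ (2, 4) -> (4,)
  d = c @ jnp.array([[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]])
  return d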

In [4]:
primals_out, tangents_out = jvp(my_fun, (jnp.array([3.0]), ), (jnp.array([1.0]), ))
primals_out, tangents_out

(Array([-1.8408432, -2.1797118, -2.5185807, -2.8574495], dtype=float32),
 Array([18.304619, 20.761639, 23.218658, 25.675676], dtype=float32))
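As a sanity check on the `J @ v` claim, the sketch below materializes the full Jacobian with `jax.jacfwd` and multiplies it by the tangent. This is only sensible for a tiny function like this one; the names `x`, `v`, `J`, and `jacobian_product` are illustrative and not part of the original notebook.

```python
import jax
import jax.numpy as jnp

def my_fun(x):
  a = x ** 2.0
  b = a @ jnp.array([[1.0, 2.0]])
  c = jnp.sin(b)
  return c @ jnp.array([[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]])

x = jnp.array([3.0])  # primal (evaluation point)
v = jnp.array([1.0])  # tangent vector

# JVP: pushes the tangent forward without building J.
_, tangents_out = jax.jvp(my_fun, (x,), (v,))

# Reference: build the full (4, 1) Jacobian, then multiply.
J = jax.jacfwd(my_fun)(x)
jacobian_product = J @ v

print(J.shape)
print(jnp.allclose(tangents_out, jacobian_product, rtol=1e-4))
```

For functions with many inputs, building `J` this way costs one forward pass per input dimension, which is what `jvp` avoids.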

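In the Jaxpr below, note that the Jaxpr variables `a` and `b` (not the locals of the same name in `my_fun`) are the two
array constants defined inside `my_fun`, hoisted out as Jaxpr constant parameters. `c` is the input primal and `d` is the
input tangent. Primal and tangent operations are interleaved: each primal step is followed by its tangent counterpart.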

In [5]:
make_jaxpr(partial(jvp, my_fun))((jnp.array([3.0]), ), (jnp.array([1.0]), ))

{ lambda a:f32[1,2] b:f32[2,4]; c:f32[1] d:f32[1]. let
    e:f32[1] = pow c 2.0
    f:f32[] = sub 2.0 1.0
    g:f32[1] = pow c f
    h:f32[1] = mul 2.0 g
    i:f32[1] = mul d h
    j:f32[2] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] e a
    k:f32[2] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] i a
    l:f32[2] = sin j
    m:f32[2] = cos j
    n:f32[2] = mul k m
    o:f32[4] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] l b
    p:f32[4] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] n b
  in (o, p) }
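To make the Jaxpr above easier to read, here is a hand-written version of the same forward-mode computation, with comments mapping each line to the Jaxpr variables. It is a sketch for illustration only: `W1`, `W2`, and `manual_jvp` are hypothetical names, and the constants are pulled out of `my_fun` so both functions can share them.

```python
import jax
import jax.numpy as jnp

W1 = jnp.array([[1.0, 2.0]])                                    # Jaxpr constant a
W2 = jnp.array([[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]])    # Jaxpr constant b

def my_fun(x):
  return jnp.sin((x ** 2.0) @ W1) @ W2

def manual_jvp(x, v):
  sq = x ** 2.0               # e = pow c 2.0
  d_sq = v * (2.0 * x)        # i = d * (2.0 * c ** (2.0 - 1.0))
  lin = sq @ W1               # j
  d_lin = d_sq @ W1           # k
  s = jnp.sin(lin)            # l
  d_s = d_lin * jnp.cos(lin)  # n = k * cos j
  return s @ W2, d_s @ W2     # (o, p)

x = jnp.array([3.0])
v = jnp.array([1.0])

manual_primal, manual_tangent = manual_jvp(x, v)
ref_primal, ref_tangent = jax.jvp(my_fun, (x,), (v,))

print(jnp.allclose(manual_primal, ref_primal, rtol=1e-4))
print(jnp.allclose(manual_tangent, ref_tangent, rtol=1e-4))
```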

## Linearize in JAX

`linearize : (a -> b) -> a -> (b, T a -o T b)`. `linearize` partially evaluates a function at the given primals and stages out a linear tangent computation.
Computations that depend only on primals are evaluated immediately, while computations that depend on tangents are staged out as a Jaxpr.

The staged-out tangent computation is the linear map `v -> J @ v`, i.e. the linear part of the first-order approximation of the function at the primals.
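A minimal sketch of what this buys in practice: the primal work is done once, and the returned linear function can then be applied to several tangents. The variable names (`f_lin`, `t1`, `t2`, `t_ref`) are illustrative assumptions, not taken from the notebook.

```python
import jax
import jax.numpy as jnp

def my_fun(x):
  a = x ** 2.0
  b = a @ jnp.array([[1.0, 2.0]])
  c = jnp.sin(b)
  return c @ jnp.array([[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]])

x = jnp.array([3.0])

# Evaluate the primal part once; f_lin only holds the tangent computation.
primals_out, f_lin = jax.linearize(my_fun, x)

t1 = f_lin(jnp.array([1.0]))
t2 = f_lin(jnp.array([2.0]))  # reuses the stored residuals, no primal recomputation

# Reference JVP for the first tangent.
_, t_ref = jax.jvp(my_fun, (x,), (jnp.array([1.0]),))

print(jnp.allclose(t1, t_ref, rtol=1e-4))     # same result as jvp
print(jnp.allclose(t2, 2.0 * t1, rtol=1e-4))  # the map is linear in the tangent
```

The trade-off is memory: `linearize` keeps the primal-dependent residuals around so they can be reused, whereas a single `jvp` call does not need to store them.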

In [6]:
primals_out, linear_jvp = linearize(my_fun, jnp.array([3.0]))
primals_out, linear_jvp

(Array([-1.8408432, -2.1797118, -2.5185807, -2.8574495], dtype=float32),
 jax.tree_util.Partial(_HashableCallableShim(functools.partial(<function _lift_linearized at 0x7f2a8c03fb00>, { lambda a:f32[1] b:f32[1,2] c:f32[2] d:f32[2,4]; e:f32[1]. let
     f:f32[1] = pjit[
       name=_power
       jaxpr={ lambda ; g:f32[1] h:f32[1]. let i:f32[1] = mul g h in (i,) }
     ] e a
     j:f32[2] = pjit[
       name=matmul
       jaxpr={ lambda ; k:f32[1] l:f32[1,2]. let
           m:f32[2] = dot_general[
             dimension_numbers=(([0], [0]), ([], []))
             preferred_element_type=float32
           ] k l
         in (m,) }
     ] f b
     n:f32[2] = pjit[
       name=sin
       jaxpr={ lambda ; o:f32[2] p:f32[2]. let q:f32[2] = mul o p in (q,) }
     ] j c
     r:f32[4] = pjit[
       name=matmul
       jaxpr={ lambda ; s:f32[2] t:f32[2,4]. let
           u:f32[4] = dot_general[
             dimension_numbers=(([0], [0]), ([], []))
             preferred_element_type=float32
       

In [7]:
print(linear_jvp(jnp.array([1.0])))
print(make_jaxpr(linear_jvp)(jnp.array([1.0])))

[18.304619 20.761639 23.218658 25.675676]
{ lambda a:f32[1] b:f32[1,2] c:f32[2] d:f32[2,4]; e:f32[1]. let
    f:f32[1] = mul e a
    g:f32[2] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] f b
    h:f32[2] = mul g c
    i:f32[4] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] h d
  in (i,) }


## VJP in JAX

The VJP (vector-Jacobian product) function can be obtained from the linearized JVP function by:

- Replacing the tangent vector (`[1.0]`, shaped like the function's input) with a cotangent vector
  (`[1.0]*4`, shaped like the function's output).

- Transposing the linear JVP mapping. That means each primitive operation
  (e.g. matmul) is transposed, and the order of operations is reversed.
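The defining property of that transpose can be checked numerically: for any tangent `v` and cotangent `u`, the inner product of `u` with `J @ v` equals the inner product of the VJP of `u` with `v`. The sketch below uses arbitrary test vectors chosen for illustration; `u`, `v`, `lhs`, and `rhs` are not names from the notebook.

```python
import jax
import jax.numpy as jnp

def my_fun(x):
  a = x ** 2.0
  b = a @ jnp.array([[1.0, 2.0]])
  c = jnp.sin(b)
  return c @ jnp.array([[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]])

x = jnp.array([3.0])
v = jnp.array([0.7])                  # arbitrary tangent, input-shaped
u = jnp.array([1.0, -2.0, 0.5, 3.0])  # arbitrary cotangent, output-shaped

_, jv = jax.jvp(my_fun, (x,), (v,))   # J @ v, shape (4,)
_, vjp_fun = jax.vjp(my_fun, x)
(ju,) = vjp_fun(u)                    # J^T @ u, shape (1,)

lhs = jnp.vdot(u, jv)  # <u, J v>
rhs = jnp.vdot(ju, v)  # <J^T u, v>
print(jnp.allclose(lhs, rhs, rtol=1e-4))
```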

In [8]:
import jax
linear_vjp = jax.linear_transpose(linear_jvp, jnp.array([1.0]))
print(make_jaxpr(linear_vjp)(jnp.array([1.0, 1.0, 1.0, 1.0])))
print(linear_vjp(jnp.array([1.0, 1.0, 1.0, 1.0])))

{ lambda a:f32[2,4] b:f32[2] c:f32[1,2] d:f32[1]; e:f32[4]. let
    f:f32[2] = dot_general[
      dimension_numbers=(([0], [1]), ([], []))
      preferred_element_type=float32
    ] e a
    g:f32[2] = mul f b
    h:f32[1] = dot_general[
      dimension_numbers=(([0], [1]), ([], []))
      preferred_element_type=float32
    ] g c
    i:f32[1] = mul h d
  in (i,) }
(Array([87.96059], dtype=float32),)

For comparison, `jax.vjp` computes the same cotangent directly from `my_fun`:


In [10]:
primals = jnp.array([3.0])
print(jax.vjp(my_fun, primals)[1](jnp.array([1.0, 1.0, 1.0, 1.0])))

(Array([87.96059], dtype=float32),)
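One way to interpret the all-ones cotangent used above: it weights every output equally, so the VJP equals the gradient of the sum of the outputs. The sketch below checks this with `jax.grad`; the helper name `summed` is an assumption introduced for illustration.

```python
import jax
import jax.numpy as jnp

def my_fun(x):
  a = x ** 2.0
  b = a @ jnp.array([[1.0, 2.0]])
  c = jnp.sin(b)
  return c @ jnp.array([[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]])

def summed(x):
  # Scalar-valued wrapper so that jax.grad applies.
  return my_fun(x).sum()

primals = jnp.array([3.0])

# VJP with an all-ones cotangent.
_, vjp_fun = jax.vjp(my_fun, primals)
(cotangent_in,) = vjp_fun(jnp.ones(4))

# Gradient of the summed outputs.
grad_of_sum = jax.grad(summed)(primals)

print(jnp.allclose(cotangent_in, grad_of_sum, rtol=1e-4))
```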
