In [1]:
from scipy.stats import norm

In [2]:
norm.cdf(0, loc=0, scale=1)

0.5

In [3]:
import jax
import jax.numpy as np

from jax import custom_jvp

In [4]:
@custom_jvp
def g(x, y):
  return np.sin(x) * y

@g.defjvp
def g_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = g(x, y)
  tangent_out = np.cos(x) * x_dot * y + np.sin(x) * y_dot
  return primal_out, tangent_out

In [5]:
def f(x, y):
    return x @ y, [x, y]

In [6]:
A = np.ones((4, 5))
B = np.ones((5, 6))

In [499]:
jac = jax.jacfwd(f, argnums=[0, 1])(A, B)
jac[0][0].shape, jac[0][1].shape, jac[1][0][0].shape

((4, 6, 4, 5), (4, 6, 5, 6), (4, 5, 4, 5))

In [482]:
jac = jax.jacfwd(f, argnums=[0, 1])(A, B)

In [483]:
jax.jacfwd(f, argnums=[0, 1])(A, B)[0][1].shape

(4, 6, 5, 6)

In [437]:
from functools import partial

In [456]:
jvp_partial = lambda partials, tangents: jax.jvp(f, partials, tangents)
jvp_vmap = jax.vmap(jvp_partial, in_axes=((None, None), (0, 0)), out_axes=(-1, -1))

In [439]:
jvp_vmap = jax.vmap(
    jax.vmap(
        jvp_partial, 
        in_axes=((None, None), (None, 0)), 
        out_axes=(None, 0)
    ), 
    in_axes=((None, None), (0, None)), 
    out_axes=(0, None)
)

In [440]:
jax._src.api._std_basis(A).shape

(20, 4, 5)

In [409]:
f(A, B)[0].shape, f(A, B)[1].shape

((6,), (6,))

In [441]:
primals_out, tangents_out = jvp_vmap((A, B), jax._src.api._std_basis((A, B)))

jac1 =  tangents_out[0][:, :, :, :4*5*6].reshape((4, 5, 6, 4, 5, 6))

ValueError: vmap has mapped output but out_axes is None

In [444]:
primals_out, tangents_out = jax.jvp(f, (A, B), (A, B))

In [448]:
primals_out[0].shape, primals_out[1].shape, primals_out[2].shape

((4, 6), (4, 5), (5, 6))

In [452]:
tangents_out[0].shape, tangents_out[1].shape, tangents_out[2].shape

((4, 6), (4, 5), (5, 6))

In [447]:
tangents_out[0].shape, tangents_out[1].shape

((4, 6), (4, 5))

In [466]:
jac = jax.jacfwd(f, argnums=(0, 1))(A, B)
jac[0][0].shape, jac[0][1].shape, jac[1][0].shape, jac[1][1].shape

((4, 6, 4, 5), (4, 6, 5, 6), (4, 5, 4, 5), (4, 5, 5, 6))

In [386]:
np.allclose(jac[0][0], jac1)

Array(True, dtype=bool)

In [472]:
primals_out, tangents_out = jvp_vmap((A, B), jax._src.api._std_basis((A, B)))
assert np.allclose(tangents_out[0][:, :, :4*5].reshape(4, 6, 4, 5), jac[0][0])
assert np.allclose(tangents_out[0][:, :, 4*5:].reshape(4, 6, 5, 6), jac[0][1])
assert np.allclose(tangents_out[1][:, :, :4*5].reshape(4, 5, 4, 5), jac[1][0])
assert np.allclose(tangents_out[1][:, :, 4*5:].reshape(4, 5, 5, 6), jac[1][1])

In [475]:
jax._src.api._std_basis((A, B))[0].shape, jax._src.api._std_basis((A, B))[1].shape

((50, 4, 5), (50, 5, 6))

In [478]:
jax._src.api._std_basis((A, B))[0].shape, jax._src.api._std_basis((A, B))[1].shape

((50, 4, 5), (50, 5, 6))

In [479]:
jac

((Array([[[[1., 1., 1., 1., 1.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]],
  
          [[1., 1., 1., 1., 1.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]],
  
          [[1., 1., 1., 1., 1.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]],
  
          [[1., 1., 1., 1., 1.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]],
  
          [[1., 1., 1., 1., 1.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]],
  
          [[1., 1., 1., 1., 1.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]]],
  
  
         [[[0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]],
  
          [[0., 0., 0., 0., 0.],
           

In [504]:
A

Array([[ 1.1467233 , -1.17104   ,  0.09341183],
       [-0.22301869,  0.32389048,  0.2273164 ]], dtype=float32)

In [156]:
def g(A, B):
    C = A @ B
    D = A * A
    return C, D
    
key = jax.random.key(42)
A_key, B_key = jax.random.split(key, 2)
X1 = jax.random.normal(A_key, shape=(2, 3))
X2 = jax.random.normal(B_key, shape=(3, 4))

Y1, Y2 = g(X1, X2)
print(C.shape, D.shape)

jac_fwd = jax.jacfwd(g, argnums=(0, 1))(X1, X2)
jac_rev = jax.jacrev(g, argnums=(0, 1))(X1, X2)

() (2,)


In [16]:
print(jac_fwd[0][0].shape, jac_fwd[0][1].shape, jac_fwd[1][0].shape, jac_fwd[1][1].shape)

(2, 4, 2, 3) (2, 4, 3, 4) (2, 3, 2, 3) (2, 3, 3, 4)


In [17]:
jac_rev[0][0].shape, jac_rev[0][1].shape, jac_rev[1][0].shape, jac_rev[1][1].shape

((2, 4, 2, 3), (2, 4, 3, 4), (2, 3, 2, 3), (2, 3, 3, 4))

In [21]:
jac[0][0].shape, V1.shape

((2, 4, 2, 3), (2, 3))

In [29]:
def contract(A, B):
    return np.einsum("ijkl,kl -> ij", A, B)

In [117]:
# let us choose x to be the (1, 2) entry of X1
# below are three ways of obtaining the same answer in jax
# along with their clock time

# first way
# indexing the result of jacfwd

jac = jax.jacfwd(g, argnums=(0, 1))(X1, X2)
jvp_via_indexing = (jac[0][0][:, :, 1, 2], jac[1][0][:, :, 1, 2])

# second way
# contracting with V1 and V2

V1, V2 = np.zeros(X1.shape), np.zeros(X2.shape)
V1 = V1.at[1, 2].set(1)

def contract(A, B):
    return np.einsum("ijkl,kl -> ij", A, B)

jvp_after_jac_is_computed = (
    contract(jac[0][0], V1) + contract(jac[0][1], V2), 
    contract(jac[1][0], V1) + contract(jac[1][1], V2)
)

# third way
# use jax's jvp

primals, tangents = jax.jvp(g, (X1, X2), (V1, V2))

print(np.allclose(jvp_via_indexing[0], tangents[0]))
print(np.allclose(jvp_via_indexing[1], tangents[1]))
print(np.allclose(jvp_after_jac_is_computed[0], tangents[0]))
print(np.allclose(jvp_after_jac_is_computed[1], tangents[1]))

ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

In [49]:
tuple(range(5))

(0, 1, 2, 3, 4)

In [None]:
def jvp_inefficient(fun, primals, tangents):
    num_positional_args = tuple(range(len(primals)))
    jac = jax.jacfwd(fun, argnums=num_positional_args)(*primals)
    

In [75]:
import numpy as npy

In [85]:
def matmul(A, B):
    return npy.matmul(A, B)

In [107]:
X1, X2

(Array([[ 1.1467233 , -1.17104   ,  0.09341183],
        [-0.22301869,  0.32389048,  0.2273164 ]], dtype=float32),
 Array([[-0.70990354,  1.8192104 , -0.39635026,  0.4026267 ],
        [-0.7557475 ,  0.8272896 ,  0.7292502 ,  0.00616852],
        [ 0.22729982,  1.089758  , -1.7545393 , -2.0472434 ]],      dtype=float32))

In [108]:
X1.shape

(2, 3)

In [112]:
result_shape = jax.core.ShapedArray((2, 4), X1.dtype)
jax.pure_callback(npy.matmul, result_shape, X1, X2)

Array([[ 0.09218019,  1.2191379 , -1.47238   ,  0.26324108],
       [-0.03478869,  0.10995318, -0.07424483, -0.55316734]],      dtype=float32)

In [123]:
def matmul(A, B):
    result_shape = jax.core.ShapedArray((A.shape[0], B.shape[1]), A.dtype)
    return jax.pure_callback(npy.matmul, result_shape, A, B)

In [125]:
@jax.custom_jvp
def matmul(A, B):
    result_shape = jax.core.ShapedArray((A.shape[0], B.shape[1]), A.dtype)
    return jax.pure_callback(npy.matmul, result_shape, A, B)

@matmul.defjvp
def matmul_jvp(primals, tangents):
    A, B = primals
    A_dot, B_dot = tangents
    primal_out = matmul(A, B)
    tangent_out = matmul(A_dot, B) + matmul(A, B_dot)
    return primal_out, tangent_out

In [122]:
def g2(A, B):
    C = matmul(A, B)
    D = A * A
    return C, D

In [126]:
primals, tangents = jax.jvp(g2, (X1, X2), (V1, V2))

In [146]:
partial

NameError: name 'partial' is not defined

In [151]:
from jax import make_jaxpr

make_jaxpr(g)(X1, X2)

{ lambda ; a:f32[2,3] b:f32[3,4]. let
    c:f32[2,4] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] a b
    d:f32[2,3] = mul a a
  in (c, d) }

In [160]:
# make note of computation shapes here

# let us choose y to be the (1, 2) entry of Y1
# below are three ways of obtaining the same answer in jax

# first way
# indexing the result of jacrev (the jacobian)

jac = jax.jacrev(g, argnums=(0, 1))(X1, X2)
vjp_via_indexing = (jac[0][0][1, 2, :, :], jac[0][1][1, 2, :, :])

# second way
# contracting the jacobian with W1 and W2

W1, W2 = np.zeros(Y1.shape), np.zeros(Y2.shape)
W1 = W1.at[1, 2].set(1)

def contract(A, B):
    return np.einsum("ij,ijkl->kl", A, B)

vjp_after_jac_is_computed = (
    contract(W1, jac[0][0]) + contract(W2, jac[1][0]), 
    contract(W1, jac[0][1]) + contract(W2, jac[1][1])
)

# third way
# use vjp in jax

primals, vjpfun = jax.vjp(g, X1, X2)
cotangents = vjpfun((W1, W2))

print(np.allclose(vjp_via_indexing[0], cotangents[0]))  # true
print(np.allclose(vjp_via_indexing[1], cotangents[1]))  # true
print(np.allclose(vjp_after_jac_is_computed[0], cotangents[0]))  # true
print(np.allclose(vjp_after_jac_is_computed[1], cotangents[1]))  # true

True
True
True
True
