In [9]:
from jax import numpy as jnp, random
from jax import grad, jacfwd, jacrev, jit, vmap

In [10]:
key = random.PRNGKey(0)

# Linear layers

#### Real-valued linear layer
$f: \mathbb{R}^n \times \mathbb{R}^p \rightarrow \mathbb{R}$, $x \mapsto x^T W + b^T$ , $W \in \mathbb{R}^{n}, b \in \mathbb{R}$  

In [11]:
# input and output dimensions
d_input = 3
d_output = 1

1. No batch

In [12]:
def layer(x, W, b):
    return jnp.dot(x, W) + b

def grad_layer(x, W, b):
    return grad(layer,  argnums=(1, 2))(x, W, b)

In [13]:
# initialize layer input
x = random.normal(key, (d_input,))

# initialize layer coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (d_input,))
b = random.normal(b_key, ())

# evaluate y
y = layer(x, W, b)
print(y)

# evaluate gradient
W_grad_y, b_grad_y = grad_layer(x, W, b)
print(W_grad_y == x)
print(b_grad_y == 1)

1.2867186
[ True  True  True]
True


In [75]:
#y
#x, W_grad_y, b_grad_y
len(grad_layer(x, W, b))

2

2. With batch

In [5]:
def layer_batch(inputs, W, b):
    return vmap(layer, in_axes=(0, None, None), out_axes=0)(inputs, W, b)

def grad_layer_batch(inputs, W, b):
    return vmap(grad_layer, in_axes=(0, None, None), out_axes=None)(inputs, W, b)

def grad_layer_batch2(inputs, W, b):
    return jacfwd(layer_batch, argnums=(1, 2))(inputs, W, b)

In [6]:
# input, output dimensions and batch size
batch_size = 10

# initialize layer input
inputs = random.normal(key, (batch_size, d_input))

# initialize layer coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (d_input, d_output))
b = random.normal(b_key, (d_output,))

# evaluate y
y = layer_batch(inputs, W, b)
print(y.shape)

# evaluate gradient
W_grad_y, b_grad_y = grad_layer_batch2(inputs, W, b)
W_grad_y = W_grad_y.squeeze()
b_grad_y = b_grad_y.squeeze()

# check gradient
#print((W_grad_y[:, 0] == 2).all())
#print((grad_y[:, 1] == 1).all())

(10, 1)


In [105]:
#inputs, W_grad_y, b_grad_y

## Vector-valued linear layer
$f: \mathbb{R}^n \times \mathbb{R}^m \rightarrow \mathbb{R}^p$, $x \mapsto x^T W + b^T$ , $W \in \mathbb{R}^{n \times m}, b \in \mathbb{R}^m$  
3. No batch

In [108]:
# input and output dimensions
d_input = 3
d_output = 4

In [112]:
def jac_fwd_layer(x, W, b):
    return jacfwd(layer, (1, 2))(x, W, b)

def jac_rev_layer(x, W, b):
    return jacrev(layer, (1, 2))(x, W, b)

In [113]:
# initialize layer input
x = random.normal(key, (d_input,))

# initialize layer coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (d_input, d_output))
b = random.normal(b_key, (d_output,))

# evaluate y
y = layer(x, W, b)
print(y)

# evaluate jacobian
jac_y = jac_fwd_layer(x, W, b) # forward ad
#jac_y = jac_rev_layer(x, W, b) # backward ad

# check shapes
W_jac_y, b_jac_y = jac_y
print(W_jac_y.shape)
print(b_jac_y.shape)

[-3.2433186   1.4805036  -0.14534098  2.4608994 ]
(4, 3, 4)
(4, 4)


AttributeError: 'tuple' object has no attribute 'shape'

4. with batch

In [114]:
def jac_fwd_layer_batch(inputs, W, b):
    return vmap(jac_fwd_layer, in_axes=(None, None, 0), out_axes=0)(inputs, W, b)

#def jac_rev_layer_batch(W, b, inputs):
#    return vmap(jac_rev_layer, in_axes=(None, None, 0), out_axes=0)(W, b, inputs)

In [111]:
batch_size = 10

# initialize layer input
inputs = random.normal(key, (batch_size, d_input))

# initialize layer coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (d_input, d_output))
b = random.normal(b_key, (d_output,))

# evaluate y
y = layer_batch(W, b, inputs)
print(y.shape)

# evaluate jacobian
#jac_y = jac_fwd_layer_batch(W, b, inputs) # forward ad
jac_y = jac_rev_layer_batch(W, b, inputs) # backward ad
#print(jac_y.shape)

(3, 10, 3)


NameError: name 'jac_rev_layer_batch' is not defined

In [20]:
# check shapes
W_jac_y, b_jac_y = jac_y
print(W_jac_y.shape)
print(b_jac_y.shape)

(10, 4, 3, 4)
(10, 4, 4)
