In [61]:
import jax.numpy as jnp
import jax.random as jrandom
import jax.nn as jnn
import jax

# JAX Function

In [62]:
def jax_fc(w, b, x):
    return jnp.dot(x, w) + b

def jax_sigmoid(x):
    return jnn.sigmoid(x)

def jax_softmax(x):
    return jnn.softmax(x)

def jax_relu(x):
    return jnn.relu(x)

@jax.grad
def get_jax_logistic_regression_grad(parmas, x, y_true):
    y_pre = jax_sigmoid(jax_fc(*parmas, x))
    loss = -(jnp.sum(y_true * jnp.log(y_pre) + (1 - y_true) * jnp.log(1 - y_pre)))
    return loss

@jax.grad
def get_jax_softmax_classification_grad(parmas, x, y_true):
    pre  = jax_softmax(jax_fc(*parmas, x))
    loss = -jnp.sum(y_true * jnp.log(pre))
    return loss

@jax.grad
def get_jax_linear_regression_grad(parmas, x, y_true):
    for w, b in parmas[:-1]:
        x = jax_fc(w, b, x)
        x = jax_relu(x)
    y_pre = jax_fc(*parmas[-1], x)
    loss = jnp.sum((y_true - y_pre)**2)
    return loss

# Function Definition

In [63]:
# linear
def linear(parmas, x):
    w, b = parmas
    return jnp.dot(x, w) + b

def grad_linear(parmas, layer_inputs, grad_cum_dot):
    w, b = parmas
    dw = jnp.dot(layer_inputs.T, grad_cum_dot)
    db = jnp.sum(grad_cum_dot, axis=0)
    dx = jnp.dot(grad_cum_dot, w.T)
    return (dw, db, dx)

In [64]:
# sigmoid
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

def grad_sigmoid(x, grad_cum_dot):
    return grad_cum_dot * (sigmoid(x) * (1 - sigmoid(x)))

In [65]:
# softmax
def softmax(x):
    return jnp.exp(x) / jnp.sum(jnp.exp(x), axis=1).reshape(-1,1)

def grad_softmax(x, grad_cum_dot):
    target_indexs = grad_cum_dot.argmin(axis=1).tolist()
    return jnp.array([[s-1 if target_index == idx else s for idx, s in enumerate(i)] for i, target_index in zip(softmax(x), target_indexs)])

In [66]:
# relu
def relu(x):
    return jnp.maximum(0, x)

def grad_relu(x, grad_cum_dot):
    return grad_cum_dot * jnp.maximum(0, x/jnp.abs(x))

In [67]:
# binary cross entropy
def bce_loss(y_true, y_pre):
    return jnp.mean(y_true * jnp.log(y_pre) + (1 - y_true) * jnp.log(1 - y_pre))

def grad_bce_loss(y_true, y_pre):
    return - (y_true * 1 / y_pre + (1 - y_true) * 1 / (1 - y_pre) * -1)

In [68]:
# cross entropy loss
def ce_loss(y_true, y_pre_pro):
    return -jnp.sum(y_true * jnp.log(y_pre_pro))

def grad_ce_loss(y_true, y_pre_pro):
    return 1 / (y_true * y_pre_pro)

In [69]:
# mse loss
def mse_loss(y_true, y_pre):
    return jnp.sum((y_true - y_pre)**2)

def grad_mse_loss(y_true, y_pre):
    return -2 * (y_true - y_pre)

# Gradient Calculation

## Linear Regression

In [70]:
KEY = jrandom.PRNGKey(520)
INPUT_DIM = 3
HIDDEN_DIM = 4
OUTPUT_DIM = 1
BATCH = 3

parmas1 = (jrandom.normal(KEY, (INPUT_DIM, HIDDEN_DIM)), jrandom.normal(KEY, (HIDDEN_DIM, )))
parmas2 = (jrandom.normal(KEY, (HIDDEN_DIM, OUTPUT_DIM)), jrandom.normal(KEY, (OUTPUT_DIM, )))
x = jrandom.normal(KEY, (BATCH ,INPUT_DIM))
y = jrandom.randint(KEY, (BATCH, OUTPUT_DIM), -10, 1)

# forward
linear_output = linear(parmas1, x)
relu_output = relu(linear_output)
outputs = linear(parmas2, relu_output)

# loss
loss = mse_loss(y, outputs)

# gradient
dLoss_dLinear = grad_mse_loss(y, outputs)
dw2, db2, dx2 = grad_linear(parmas2, relu_output, dLoss_dLinear) 
dRelu_dLinear = grad_relu(linear_output, dx2) 
dw1, db1, _ = grad_linear(parmas1, x, dRelu_dLinear)


(dw1, db1, dw2, db2)

(DeviceArray([[ 0.09599687,  0.08536603,  0.        ,  0.03540015],
              [-0.01222258, -0.01086903,  0.        , -0.00450724],
              [ 0.02458649,  0.02186375,  0.        ,  0.0090666 ]],            dtype=float32),
 DeviceArray([0.0590255 , 0.05248893, 0.        , 0.02176646], dtype=float32),
 DeviceArray([[-0.17297728],
              [-0.05039567],
              [ 0.        ],
              [-0.21061537]], dtype=float32),
 DeviceArray([22.230576], dtype=float32))

In [71]:
get_jax_linear_regression_grad([parmas1, parmas2], x, y)

[(DeviceArray([[ 0.09599687,  0.08536603,  0.        ,  0.03540015],
               [-0.01222258, -0.01086903,  0.        , -0.00450724],
               [ 0.02458649,  0.02186375,  0.        ,  0.0090666 ]],            dtype=float32),
  DeviceArray([0.0590255 , 0.05248893, 0.        , 0.02176646], dtype=float32)),
 (DeviceArray([[-0.17297728],
               [-0.05039567],
               [ 0.        ],
               [-0.21061537]], dtype=float32),
  DeviceArray([22.230576], dtype=float32))]

## Logistic Regression

In [72]:
KEY = jrandom.PRNGKey(529)
INPUT_DIM = 3
OUTPUT_DIM = 1
BATCH = 4

parmas = (jrandom.normal(KEY, (INPUT_DIM, OUTPUT_DIM)), jrandom.normal(KEY, (OUTPUT_DIM, )))
x = jrandom.normal(KEY, (BATCH ,INPUT_DIM))
y = jrandom.randint(KEY, (BATCH, OUTPUT_DIM), 0, 2)

# forward
linear_output = linear(parmas, x)
sigmoid_output = sigmoid(linear_output)

# loss
loss = bce_loss(y, sigmoid_output)

# gradient
dLoss_dSigomid = grad_bce_loss(y, sigmoid_output)
dSigmoid_dLinear = grad_sigmoid(linear_output, dLoss_dSigomid) 
dw, db, _ = grad_linear(parmas, x, dSigmoid_dLinear) 

(dw, db)

(DeviceArray([[-0.587844  ],
              [ 0.09830173],
              [-0.2256486 ]], dtype=float32),
 DeviceArray([0.81867844], dtype=float32))

In [73]:
get_jax_logistic_regression_grad(parmas, x, y)

(DeviceArray([[-0.587844  ],
              [ 0.09830171],
              [-0.22564854]], dtype=float32),
 DeviceArray([0.8186784], dtype=float32))

## Softmax Classification

In [74]:
KEY = jrandom.PRNGKey(129)
INPUT_DIM = 3
BATCH = 4
N_CLASS = 3

parmas = (jrandom.normal(KEY, (INPUT_DIM, N_CLASS)), jrandom.normal(KEY, (N_CLASS, )))
x = jrandom.normal(KEY, (BATCH ,INPUT_DIM))
y = jnn.one_hot(jrandom.randint(KEY, (BATCH, ), 0, N_CLASS), N_CLASS)

# forward
linear_output = linear(parmas, x)
softmax_output = softmax(linear_output)

# loss
loss = ce_loss(y, softmax_output)

# gradient
dLoss_dSoftmax = grad_ce_loss(y, softmax_output)
dSoftmax_dLinear = grad_softmax(linear_output, dLoss_dSoftmax) 
dw, db, _ = grad_linear(parmas, x, dSoftmax_dLinear) 

(dw, db)

(DeviceArray([[-1.1197042 ,  1.4028358 , -0.28313178],
              [ 0.62128055, -1.0922139 ,  0.47093317],
              [-0.28436232,  0.50908446, -0.22472213]], dtype=float32),
 DeviceArray([ 1.43662  , -2.0455873,  0.6089672], dtype=float32))

In [75]:
get_jax_softmax_classification_grad(parmas, x, y)

(DeviceArray([[-1.119704  ,  1.402836  , -0.2831316 ],
              [ 0.6212805 , -1.0922139 ,  0.47093332],
              [-0.28436238,  0.5090845 , -0.22472215]], dtype=float32),
 DeviceArray([ 1.43662 , -2.045587,  0.608967], dtype=float32))