In [1]:
import jax.numpy as jnp
import jax.random as jrandom
import jax.nn as jnn
import jax

# JAX Functions (autodiff reference implementations)

In [2]:
def jax_fc(w, b, x):
    """Fully connected (affine) layer: returns ``dot(x, w) + b``."""
    projected = jnp.dot(x, w)
    return projected + b

def jax_sigmoid(x):
    """Element-wise logistic sigmoid, delegated to ``jax.nn.sigmoid``."""
    activated = jnn.sigmoid(x)
    return activated

def jax_softmax(x):
    """Softmax delegated to ``jax.nn.softmax``, called with its default axis."""
    probabilities = jnn.softmax(x)
    return probabilities

@jax.grad
def get_jax_logistic_regression_grad(parmas, x, y_true):
    """Summed binary cross-entropy of ``jax_sigmoid(jax_fc(*parmas, x))``.

    Because of the ``jax.grad`` decorator, calling this name does not return
    the loss: it returns the gradient of this scalar with respect to
    ``parmas`` (the first positional argument).
    """
    logits = jax_fc(*parmas, x)
    prob = jax_sigmoid(logits)
    positive_term = y_true * jnp.log(prob)
    negative_term = (1 - y_true) * jnp.log(1 - prob)
    return -(jnp.sum(positive_term + negative_term))

@jax.grad
def get_jax_softmax_classification_grad(parmas, x, y_true):
    """Summed cross-entropy of ``jax_softmax(jax_fc(*parmas, x))``.

    Because of the ``jax.grad`` decorator, calling this name returns the
    gradient of this scalar with respect to ``parmas`` (the first positional
    argument), not the loss value.
    """
    logits = jax_fc(*parmas, x)
    class_probs = jax_softmax(logits)
    weighted_log_probs = y_true * jnp.log(class_probs)
    return -jnp.sum(weighted_log_probs)

# Manual Function Definitions (forward passes and hand-written gradients)

In [3]:
# linear
def linear(parmas, x):
    """Affine forward pass: ``dot(x, w) + b`` with ``(w, b)`` unpacked from ``parmas``."""
    weight, bias = parmas
    projected = jnp.dot(x, weight)
    return projected + bias

def grad_linear(parmas, layer_inputs, grad_cum_dot):
    """Backward pass of the affine layer ``dot(x, w) + b``.

    ``grad_cum_dot`` is the gradient arriving from the layer's output.
    Returns the tuple ``(dw, db, dx)``: gradients for the weight, the bias
    and the layer input, in that order.
    """
    weight, _ = parmas  # the bias value is not needed by any of the three gradients
    grad_weight = jnp.dot(layer_inputs.T, grad_cum_dot)
    grad_bias = jnp.sum(grad_cum_dot, axis=0)
    grad_input = jnp.dot(grad_cum_dot, weight.T)
    return (grad_weight, grad_bias, grad_input)

In [4]:
# sigmoid
def sigmoid(x):
    """Logistic sigmoid written out by hand: ``1 / (1 + exp(-x))``."""
    denominator = 1 + jnp.exp(-x)
    return 1 / denominator

def grad_sigmoid(x, grad_cum_dot):
    """Chain the upstream gradient through the sigmoid evaluated at ``x``.

    Uses the identity sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) and
    multiplies it element-wise with ``grad_cum_dot``.
    """
    activation = sigmoid(x)
    local_slope = activation * (1 - activation)
    return grad_cum_dot * local_slope

In [5]:
# softmax
def softmax(x):
    """Numerically stable softmax over axis 1 (the class axis of a batch).

    The per-row maximum is subtracted before exponentiating. Softmax is
    invariant to that shift, but it keeps ``exp`` from overflowing to ``inf``
    (which previously produced ``inf / inf = nan`` for large logits).

    Args:
        x: array of logits with at least two axes; normalisation runs over
            axis 1.

    Returns:
        Array with the same shape as ``x`` whose entries along axis 1 sum to 1.
    """
    shifted = x - jnp.max(x, axis=1, keepdims=True)
    exp_shifted = jnp.exp(shifted)
    # keepdims keeps the reduced axis so the division broadcasts row by row
    # without a hard-coded reshape(-1, 1).
    return exp_shifted / jnp.sum(exp_shifted, axis=1, keepdims=True)

def grad_softmax(x, grad_cum_dot):
    """Gradient w.r.t. the softmax input for a softmax + cross-entropy head.

    NOTE: this is the fused shortcut ``softmax(x) - one_hot(target)``, not a
    general softmax vector-Jacobian product. ``grad_cum_dot`` is used only to
    locate each row's target class, taken as the position of its smallest
    entry; the magnitudes of the upstream gradient are otherwise ignored.

    Args:
        x: logits that were fed to ``softmax``; rows are samples.
        grad_cum_dot: upstream gradient with one row per sample.

    Returns:
        Array shaped like ``softmax(x)`` with 1 subtracted at each row's
        target position.
    """
    probs = softmax(x)
    target_index = jnp.argmin(grad_cum_dot, axis=1)
    # Vectorised replacement for the per-element Python loop: no host
    # round-trip through .tolist(), and the function can be traced.
    return probs - jnn.one_hot(target_index, probs.shape[1], dtype=probs.dtype)

In [6]:
# binary cross entropy
def bce_loss(y_true, y_pre):
    """Mean binary cross-entropy between labels and predicted probabilities.

    Args:
        y_true: array of 0/1 labels.
        y_pre: predicted probabilities, broadcastable against ``y_true``.

    Returns:
        Scalar ``-mean(y * log(p) + (1 - y) * log(1 - p))``.

    Note:
        The reduction here is a mean, whereas ``grad_bce_loss`` in this file
        is the element-wise derivative without a 1/N factor (i.e. of the
        summed loss).
    """
    log_likelihood = y_true * jnp.log(y_pre) + (1 - y_true) * jnp.log(1 - y_pre)
    # Cross-entropy is the *negative* log-likelihood. Without the minus sign
    # the value is <= 0 and grows as predictions improve.
    return -jnp.mean(log_likelihood)

def grad_bce_loss(y_true, y_pre):
    """Element-wise derivative of binary cross-entropy w.r.t. ``y_pre``.

    Evaluates ``-(y / p - (1 - y) / (1 - p))`` per element; no batch
    reduction or 1/N factor is applied.
    """
    positive_part = y_true / y_pre
    negative_part = (1 - y_true) / (1 - y_pre)
    return -(positive_part - negative_part)

In [11]:
# cross entropy loss
def ce_loss(y_true, y_pre_pro):
    """Summed cross-entropy: ``-sum(y_true * log(y_pre_pro))`` over all entries."""
    log_probs = jnp.log(y_pre_pro)
    weighted = y_true * log_probs
    return -jnp.sum(weighted)

def grad_ce_loss(y_true, y_pre_pro):
    """Derivative of ``-sum(y_true * log(y_pre_pro))`` w.r.t. ``y_pre_pro``.

    Args:
        y_true: target weights per class (e.g. one-hot rows).
        y_pre_pro: predicted class probabilities, same shape as ``y_true``.

    Returns:
        ``-y_true / y_pre_pro`` element-wise: ``-1 / p`` at the target class
        and 0 wherever ``y_true`` is 0.

    Note:
        The earlier form ``1 / (y_true * y_pre_pro)`` had the wrong sign and
        evaluated to ``inf`` at every zero-label entry. With the correct
        derivative the target entry is still the smallest value in each
        one-hot row, so code that locates the target via ``argmin`` of this
        array keeps working.
    """
    return -y_true / y_pre_pro

# Gradient Calculation

## Logistic Regression

In [12]:
# Logistic-regression check: compute (dw, db) by hand with the manual
# forward/backward helpers so they can be compared with the jax.grad result
# produced in the next cell.
KEY = jrandom.PRNGKey(529)
INPUT_DIM = 3
BATCH = 4

# NOTE(review): the same KEY is passed to every sampler below, so the draws
# are not independent of each other. JAX recommends deriving a fresh subkey
# per draw (jrandom.split). Left unchanged here because the recorded outputs
# depend on these exact values.
parmas = (jrandom.normal(KEY, (INPUT_DIM, 1)), jrandom.normal(KEY, (1, )))
x = jrandom.normal(KEY, (BATCH ,INPUT_DIM))
y = jrandom.randint(KEY, (BATCH, 1), 0, 2)

# forward
linear_output = linear(parmas, x)
sigmoid_output = sigmoid(linear_output)

# loss (computed for reference only; the gradient steps below do not read it)
loss = bce_loss(y, sigmoid_output)

# gradient: chain rule applied from the loss back to the linear layer
dLoss_dSigomid = grad_bce_loss(y, sigmoid_output)
# Despite its name this holds dLoss/dLinear: grad_sigmoid already multiplies
# the sigmoid derivative by the upstream gradient passed in.
dSigmoid_dLinear = grad_sigmoid(linear_output, dLoss_dSigomid) 
dw, db, _ = grad_linear(parmas, x, dSigmoid_dLinear) 

# Display the manually derived weight and bias gradients.
(dw, db)

(DeviceArray([[-0.587844  ],
              [ 0.09830173],
              [-0.2256486 ]], dtype=float32),
 DeviceArray([0.81867844], dtype=float32))

In [13]:
# Reference gradient w.r.t. parmas from jax.grad, for comparison with the
# manually derived (dw, db) in the previous cell.
get_jax_logistic_regression_grad(parmas, x, y)

(DeviceArray([[-0.587844  ],
              [ 0.09830171],
              [-0.22564854]], dtype=float32),
 DeviceArray([0.8186784], dtype=float32))

## Softmax Classification

In [14]:
# Softmax-classification check: compute (dw, db) by hand with the manual
# forward/backward helpers so they can be compared with the jax.grad result
# produced in the next cell.
KEY = jrandom.PRNGKey(529)
INPUT_DIM = 3
BATCH = 4
N_CLASS = 3

# NOTE(review): the same KEY is passed to every sampler below, so the draws
# are not independent of each other. JAX recommends deriving a fresh subkey
# per draw (jrandom.split). Left unchanged here because the recorded outputs
# depend on these exact values.
parmas = (jrandom.normal(KEY, (INPUT_DIM, N_CLASS)), jrandom.normal(KEY, (N_CLASS, )))
x = jrandom.normal(KEY, (BATCH ,INPUT_DIM))
# Integer class ids are expanded into one-hot rows of width N_CLASS.
y = jnn.one_hot(jrandom.randint(KEY, (BATCH, ), 0, N_CLASS), N_CLASS)

# forward
linear_output = linear(parmas, x)
softmax_output = softmax(linear_output)

# loss (computed for reference only; the gradient steps below do not read it)
loss = ce_loss(y, softmax_output)

# gradient: chain rule applied from the loss back to the linear layer
dLoss_dSoftmax = grad_ce_loss(y, softmax_output)
# Despite its name this holds dLoss/dLinear: grad_softmax returns
# softmax - one_hot(target), using dLoss_dSoftmax only to find the target.
dSoftmax_dLinear = grad_softmax(linear_output, dLoss_dSoftmax) 
dw, db, _ = grad_linear(parmas, x, dSoftmax_dLinear) 

# Display the manually derived weight and bias gradients.
(dw, db)

(DeviceArray([[-1.4472896 ,  0.5539345 ,  0.8933551 ],
              [ 0.2626161 , -0.02851331, -0.23410276],
              [ 0.7812388 , -1.1439149 ,  0.36267602]], dtype=float32),
 DeviceArray([ 2.4134421 , -0.74843514, -1.6650069 ], dtype=float32))

In [15]:
# Reference gradient w.r.t. parmas from jax.grad, for comparison with the
# manually derived (dw, db) in the previous cell.
get_jax_softmax_classification_grad(parmas, x, y)

(DeviceArray([[-1.4472897 ,  0.5539345 ,  0.8933551 ],
              [ 0.2626161 , -0.02851331, -0.23410273],
              [ 0.78123885, -1.1439149 ,  0.36267602]], dtype=float32),
 DeviceArray([ 2.4134421 , -0.74843526, -1.6650071 ], dtype=float32))