In [1]:
import jax.numpy as jnp
import jax.random as jrandom
import jax.nn as jnn
import jax

# JAX Reference Functions

In [2]:
def jax_fc(w, b, x):
    """Fully connected (affine) layer: ``jnp.dot(x, w) + b``.

    Args:
        w: Weight array, used as the right operand of the dot product.
        b: Bias added (with broadcasting) to the projected inputs.
        x: Input array, used as the left operand of the dot product.

    Returns:
        The affine transform of ``x``.
    """
    projected = jnp.dot(x, w)
    return projected + b

def jax_sigmoid(x):
    """Apply the logistic sigmoid to ``x`` by delegating to ``jax.nn.sigmoid``."""
    activated = jnn.sigmoid(x)
    return activated

def jax_softmax(x):
    """Apply softmax to ``x`` by delegating to ``jax.nn.softmax``.

    No ``axis`` is passed, so ``jax.nn.softmax``'s own default axis is used.
    """
    normalized = jnn.softmax(x)
    return normalized

@jax.grad
def get_jax_logistic_regression_grad(parmas, x, y_true):
    """Summed binary cross-entropy of a logistic-regression model.

    Because of the ``jax.grad`` decorator, calling this name does not return
    the loss itself: it returns the gradient of the loss with respect to
    ``parmas`` (the first positional argument), with the same structure as
    ``parmas``.

    Args:
        parmas: Sequence unpacked as the leading arguments of ``jax_fc``
            (weight first, then bias).
        x: Model inputs forwarded to ``jax_fc``.
        y_true: Binary targets, broadcast against the predicted probabilities.
    """
    logits = jax_fc(*parmas, x)
    probs = jax_sigmoid(logits)
    # Per-element Bernoulli log-likelihood; the loss is its negated sum.
    log_likelihood = y_true * jnp.log(probs) + (1 - y_true) * jnp.log(1 - probs)
    return -(jnp.sum(log_likelihood))

# Manual Forward and Gradient Function Definitions

In [3]:
# linear
def linear(parmas, x):
    """Forward pass of a linear layer: ``jnp.dot(x, weight) + bias``.

    Args:
        parmas: Two-element sequence ``(weight, bias)``.
        x: Input array, used as the left operand of the dot product.

    Returns:
        The affine transform of ``x``.
    """
    weight, bias = parmas
    projected = jnp.dot(x, weight)
    return projected + bias

def grad_linear(parmas, layer_inputs, grad_cum_dot):
    """Backward pass of ``linear``: gradients for weight, bias and input.

    Args:
        parmas: Two-element sequence ``(weight, bias)``; only the weight is
            needed here.
        layer_inputs: The ``x`` that was fed to the linear layer in the
            forward pass.
        grad_cum_dot: Upstream gradient with respect to the layer output.

    Returns:
        Tuple ``(dw, db, dx)``:
        ``dw = layer_inputs.T . grad_cum_dot``,
        ``db = sum(grad_cum_dot, axis=0)``,
        ``dx = grad_cum_dot . weight.T``.
    """
    weight, _ = parmas
    grad_weight = jnp.dot(layer_inputs.T, grad_cum_dot)
    # The bias is broadcast across axis 0 in the forward pass, so its
    # gradient accumulates over that axis.
    grad_bias = jnp.sum(grad_cum_dot, axis=0)
    grad_inputs = jnp.dot(grad_cum_dot, weight.T)
    return (grad_weight, grad_bias, grad_inputs)

In [4]:
# sigmoid
def sigmoid(x):
    """Hand-written logistic sigmoid: ``1 / (1 + exp(-x))``, element-wise."""
    denominator = 1 + jnp.exp(-x)
    return 1 / denominator

def grad_sigmoid(x, grad_cum_dot):
    """Backward pass of ``sigmoid``.

    Args:
        x: The pre-activation input that was passed to ``sigmoid``.
        grad_cum_dot: Upstream gradient with respect to the sigmoid output.

    Returns:
        ``grad_cum_dot * s * (1 - s)`` where ``s = sigmoid(x)``.
    """
    activation = sigmoid(x)
    local_slope = activation * (1 - activation)
    return grad_cum_dot * local_slope

In [5]:
# binary cross entropy
def bce_loss(y_true, y_pre):
    """Binary cross-entropy loss, summed over every element.

    Computes ``-sum(y * log(p) + (1 - y) * log(1 - p))``.

    The leading minus sign turns the log-likelihood into a loss, and the sum
    reduction makes this the scalar whose derivative with respect to ``y_pre``
    is ``grad_bce_loss``; it is also the objective differentiated by
    ``get_jax_logistic_regression_grad``.

    Args:
        y_true: Binary (0/1) targets, broadcast-compatible with ``y_pre``.
        y_pre: Predicted probabilities. They are expected to lie strictly
            inside (0, 1); a value of exactly 0 or 1 makes a log term
            non-finite.

    Returns:
        Scalar array holding the total loss.
    """
    log_likelihood = y_true * jnp.log(y_pre) + (1 - y_true) * jnp.log(1 - y_pre)
    return -jnp.sum(log_likelihood)

def grad_bce_loss(y_true, y_pre):
    """Element-wise derivative of the summed binary cross-entropy w.r.t. ``y_pre``.

    Args:
        y_true: Binary (0/1) targets.
        y_pre: Predicted probabilities.

    Returns:
        ``-(y_true / y_pre - (1 - y_true) / (1 - y_pre))`` for each element.
    """
    positive_term = y_true / y_pre
    negative_term = (1 - y_true) / (1 - y_pre)
    return -(positive_term - negative_term)

# Gradient Calculation

## Logistic Regression

In [6]:
# Fixed PRNG seed so the random parameters and data below are reproducible.
KEY = jrandom.PRNGKey(529)
INPUT_DIM = 3  # number of input features
BATCH = 4  # number of samples

# NOTE(review): the same KEY is passed to all four random draws below, so the
# draws are not independent; JAX's PRNG guidance is to derive a separate
# subkey per draw with jrandom.split. Left unchanged here because the
# recorded outputs depend on these exact values.
# parmas = (weight of shape (INPUT_DIM, 1), bias of shape (1,))
parmas = (jrandom.normal(KEY, (INPUT_DIM, 1)), jrandom.normal(KEY, (1, )))
x = jrandom.normal(KEY, (BATCH ,INPUT_DIM))
# Integer labels drawn from [0, 2), i.e. 0 or 1.
y = jrandom.randint(KEY, (BATCH, 1), 0, 2)

# forward pass: linear layer followed by the sigmoid activation
linear_output = linear(parmas, x)
sigmoid_output = sigmoid(linear_output)

# loss (computed for reference; not used by the backward pass below)
loss = bce_loss(y, sigmoid_output)

# gradient: apply the chain rule backwards, one layer at a time
dLoss_dSigomid = grad_bce_loss(y, sigmoid_output)
# Despite its name, this already includes the upstream loss gradient, so it
# holds dLoss/dLinear rather than just dSigmoid/dLinear.
dSigmoid_dLinear = grad_sigmoid(linear_output, dLoss_dSigomid) 
# The gradient with respect to the inputs is not needed and is discarded.
dw, db, _ = grad_linear(parmas, x, dSigmoid_dLinear) 

# Display the hand-derived gradients for the weight and the bias.
(dw, db)



(DeviceArray([[-0.587844  ],
              [ 0.09830172],
              [-0.22564851]], dtype=float32),
 DeviceArray([0.8186784], dtype=float32))

In [7]:
# Reference gradient with respect to parmas computed by jax.grad; it should
# match the hand-derived (dw, db) displayed by the previous cell.
get_jax_logistic_regression_grad(parmas, x, y)

(DeviceArray([[-0.587844  ],
              [ 0.09830172],
              [-0.22564851]], dtype=float32),
 DeviceArray([0.8186784], dtype=float32))