In [None]:
import jax.numpy as jnp
import jax.random as jrandom
import jax.nn as jnn
import jax

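# JAX reference gradients

These forward passes and losses are differentiated with `jax.grad`; the resulting gradients serve as ground truth for the hand-written backward passes below.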

In [None]:
def jax_linear(w, b, x):
    """Affine map used by the JAX reference models.

    Args:
        w: weight matrix, contracted with the last axis of ``x`` by ``jnp.dot``.
        b: bias, broadcast-added to the product.
        x: input batch.

    Returns:
        ``jnp.dot(x, w) + b``.
    """
    projected = jnp.dot(x, w)
    return projected + b

def jax_bce_loss(y_true, y_pre):
    """Binary cross-entropy summed over all elements.

    Args:
        y_true: 0/1 targets.
        y_pre: predicted probabilities, same shape as ``y_true``.

    Returns:
        ``-sum(y log p + (1 - y) log(1 - p))``.
    """
    positive_term = y_true * jnp.log(y_pre)
    negative_term = (1 - y_true) * jnp.log(1 - y_pre)
    return -jnp.sum(positive_term + negative_term)

def jax_ce_loss(y_true, y_pre):
    """Categorical cross-entropy summed over the batch.

    Args:
        y_true: integer class labels, one per row of ``y_pre``.
        y_pre: predicted class probabilities; the last axis indexes classes.

    Returns:
        ``-sum(one_hot(y_true) * log(y_pre))``.
    """
    # Encode the argument (not any global), and take the class count from the
    # prediction width so a batch that happens to miss a class still encodes
    # to the right number of columns.
    y_onehot = jnn.one_hot(y_true, y_pre.shape[-1])
    return -jnp.sum(y_onehot * jnp.log(y_pre))

def jax_mse_loss(y_true, y_pre):
    return jnp.sum((y_true - y_pre)**2)

In [None]:
@jax.grad
def get_jax_logistic_regression_grad(parmas, x, y_true):
    """Binary cross-entropy of a sigmoid-activated affine model.

    Wrapped in ``jax.grad``, so calling it returns the gradient of this loss
    with respect to ``parmas`` (a ``(dw, db)`` pytree matching ``parmas``).
    """
    logits = jax_linear(*parmas, x)
    probabilities = jnn.sigmoid(logits)
    return jax_bce_loss(y_true, probabilities)

@jax.grad
def get_jax_softmax_classification_grad(parmas, x, y_true):
    """Cross-entropy of a softmax-activated affine classifier.

    Wrapped in ``jax.grad``, so calling it returns the gradient of this loss
    with respect to ``parmas`` (a ``(dw, db)`` pytree matching ``parmas``).
    """
    logits = jax_linear(*parmas, x)
    class_probs = jnn.softmax(logits)
    return jax_ce_loss(y_true, class_probs)

@jax.grad
def get_jax_linear_regression_grad(parmas, x, y_true):
    """MSE loss of a ReLU network: hidden affine+ReLU layers, then a final affine layer.

    ``parmas`` is a sequence of ``(w, b)`` pairs; every pair except the last
    is followed by ReLU. Wrapped in ``jax.grad``, so calling it returns the
    gradient with respect to ``parmas`` in the same nested structure.
    """
    hidden_layers, output_layer = parmas[:-1], parmas[-1]
    activations = x
    for weight, bias in hidden_layers:
        activations = jnn.relu(jax_linear(weight, bias, activations))
    predictions = jax_linear(*output_layer, activations)
    return jax_mse_loss(y_true, predictions)

# Hand-written forward and backward functions

## Linear

In [None]:
# linear
def linear(parmas, x):
    """Forward pass of an affine layer.

    Args:
        parmas: ``(w, b)`` pair.
        x: input batch.

    Returns:
        ``jnp.dot(x, w) + b``.
    """
    weight, bias = parmas
    projected = jnp.dot(x, weight)
    return projected + bias

def grad_linear(parmas, layer_inputs, grad_cum_dot):
    """Backward pass of the affine layer ``layer_inputs @ w + b``.

    Args:
        parmas: ``(w, b)`` pair; only ``w`` is needed here.
        layer_inputs: the batch that was fed into the layer's forward pass.
        grad_cum_dot: upstream gradient dLoss/d(layer output).

    Returns:
        ``(dw, db, dx)``: gradients w.r.t. the weight, the bias (summed over
        the batch axis), and the layer inputs.
    """
    weight, _bias = parmas
    grad_inputs = jnp.dot(grad_cum_dot, weight.T)
    grad_bias = jnp.sum(grad_cum_dot, axis=0)
    grad_weight = jnp.dot(layer_inputs.T, grad_cum_dot)
    return (grad_weight, grad_bias, grad_inputs)

## Activation Functions

In [None]:
# sigmoid
def sigmoid(x):
    """Elementwise logistic function ``1 / (1 + exp(-x))``."""
    denominator = 1 + jnp.exp(-x)
    return 1 / denominator

def grad_sigmoid(x, grad_cum_dot):
    """Backward pass of the sigmoid.

    Args:
        x: the sigmoid's pre-activation input from the forward pass.
        grad_cum_dot: upstream gradient dLoss/d(sigmoid(x)).

    Returns:
        ``grad_cum_dot * s * (1 - s)`` with ``s = sigmoid(x)``.
    """
    sig = sigmoid(x)
    local_grad = sig * (1 - sig)
    return grad_cum_dot * local_grad

In [None]:
# softmax
def softmax(x, axis=-1):
    """Numerically stable softmax.

    Args:
        x: logits; for the 2-D ``(batch, n_class)`` inputs used in this notebook
            the default ``axis=-1`` normalizes each row.
        axis: axis holding the classes.

    Returns:
        Array shaped like ``x`` whose entries along ``axis`` are positive and sum to 1.
    """
    # Softmax is invariant to shifting the logits, so subtracting the max keeps
    # exp() from overflowing to inf (inf / inf would give nan) for large logits.
    shifted = x - jnp.max(x, axis=axis, keepdims=True)
    exps = jnp.exp(shifted)
    return exps / jnp.sum(exps, axis=axis, keepdims=True)

def grad_softmax(x, grad_cum_dot):
    """Gradient of cross-entropy w.r.t. the logits ``x``, fused through softmax.

    This is not a general softmax backward pass. The values of
    ``grad_cum_dot`` are used only to locate each row's target class, taken
    as the position of the row's smallest entry. The cross-entropy gradient
    returned by ``grad_ce_loss`` is smallest at the target class.

    Args:
        x: logits from the forward pass; assumed 2-D ``(batch, n_class)`` as in
            this notebook.
        grad_cum_dot: upstream cross-entropy gradient, same shape as ``x``.

    Returns:
        ``softmax(x) - one_hot(target)``, the derivative of
        ``-sum(one_hot(target) * log(softmax(x)))`` with respect to ``x``.
    """
    # NOTE(review): a general softmax VJP would be
    # s * (g - sum(g * s, axis=-1, keepdims=True)); that form needs a finite
    # upstream gradient, which this fused shortcut does not require.
    probs = softmax(x)
    # Vectorized argmin/one_hot: stays on-device and is traceable by jax.jit,
    # unlike a .tolist() + Python loop over every element.
    target_index = jnp.argmin(grad_cum_dot, axis=-1)
    return probs - jnn.one_hot(target_index, x.shape[-1], dtype=probs.dtype)

In [None]:
# relu
def relu(x):
    """Elementwise rectified linear unit ``max(x, 0)``."""
    return jnp.maximum(x, 0)

def grad_relu(x, grad_cum_dot):
    """Backward pass of ReLU.

    Args:
        x: ReLU's pre-activation input from the forward pass.
        grad_cum_dot: upstream gradient dLoss/d(relu(x)).

    Returns:
        ``grad_cum_dot`` where ``x > 0`` and 0 elsewhere (the subgradient at
        ``x == 0`` is taken as 0, matching ``jax.nn.relu``).
    """
    # An explicit mask avoids x / |x|, which evaluates 0 / 0 = nan at x == 0
    # and would leak nan into every upstream gradient.
    return jnp.where(x > 0, grad_cum_dot, 0)

## Loss Functions

In [None]:
# binary cross entropy gradient
def grad_bce_loss(y_true, y_pre):
    """Gradient of summed binary cross-entropy with respect to ``y_pre``.

    Args:
        y_true: 0/1 targets.
        y_pre: predicted probabilities, same shape as ``y_true``.

    Returns:
        Elementwise ``(1 - y) / (1 - p) - y / p``.
    """
    negative_part = (1 - y_true) / (1 - y_pre)
    positive_part = y_true / y_pre
    return negative_part - positive_part

In [None]:
# cross entropy loss gradient
def grad_ce_loss(y_true, y_pre):
    """Gradient of ``-sum(one_hot(y_true) * log(y_pre))`` with respect to ``y_pre``.

    Args:
        y_true: integer class labels, one per row of ``y_pre``.
        y_pre: predicted class probabilities; the last axis indexes classes.

    Returns:
        Array shaped like ``y_pre``: ``-1 / y_pre`` at each row's target class
        and 0 at every other class.
    """
    # Class count comes from the prediction width: len(set(y_true)) undercounts
    # when a batch misses a class, and set() over JAX arrays is not supported
    # in recent JAX releases (arrays are unhashable).
    one_hot = jnn.one_hot(y_true, y_pre.shape[-1], dtype=y_pre.dtype)
    return -one_hot / y_pre

In [None]:
# mse loss gradient
def grad_mse_loss(y_true, y_pre):
    """Gradient of ``sum((y_true - y_pre) ** 2)`` with respect to ``y_pre``.

    Returns:
        Elementwise ``2 * (y_pre - y_true)``.
    """
    return 2 * (y_pre - y_true)

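# Gradient Calculation

Each section below runs a forward pass and backpropagates by hand with the functions above. It then recomputes the same gradients with `jax.grad` for comparison.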

## Linear Regression (one hidden ReLU layer, MSE loss)

In [None]:
KEY = jrandom.PRNGKey(529)
INPUT_DIM = 3
HIDDEN_DIM = 4
OUTPUT_DIM = 1
BATCH = 10

# One subkey per draw: reusing a single PRNG key for several draws yields
# correlated rather than independent samples.
k_w1, k_b1, k_w2, k_b2, k_x, k_y = jrandom.split(KEY, 6)
parmas1 = (jrandom.normal(k_w1, (INPUT_DIM, HIDDEN_DIM)), jrandom.normal(k_b1, (HIDDEN_DIM, )))
parmas2 = (jrandom.normal(k_w2, (HIDDEN_DIM, OUTPUT_DIM)), jrandom.normal(k_b2, (OUTPUT_DIM, )))
x = jrandom.normal(k_x, (BATCH, INPUT_DIM))
y = jrandom.randint(k_y, (BATCH, OUTPUT_DIM), -10, 10)

# forward: linear -> relu -> linear
linear_output = linear(parmas1, x)
relu_output = relu(linear_output)
outputs = linear(parmas2, relu_output)

# backward (chain rule, output layer first); names are dLoss/d<tensor>
dLoss_dOutputs = grad_mse_loss(y, outputs)
dw2, db2, dLoss_dRelu = grad_linear(parmas2, relu_output, dLoss_dOutputs)
dLoss_dLinear = grad_relu(linear_output, dLoss_dRelu)
dw1, db1, _ = grad_linear(parmas1, x, dLoss_dLinear)

(dw1, db1, dw2, db2)



(DeviceArray([[ 5.0259308e+01, -2.9323208e+01, -1.6577687e+02,
                3.3996141e-01],
              [-1.6393526e+01, -7.4917145e+00,  4.0747322e+01,
                9.8981008e-02],
              [-1.6545301e+01,  1.6847119e+02,  2.0225949e+02,
               -1.4161079e+00]], dtype=float32),
 DeviceArray([-20.804895 , 114.67517  , 172.82722  ,  -1.0782892], dtype=float32),
 DeviceArray([[ 22.628633],
              [124.500916],
              [435.7387  ],
              [134.8935  ]], dtype=float32),
 DeviceArray([120.27084], dtype=float32))

In [None]:
jax_regression_grads = get_jax_linear_regression_grad([parmas1, parmas2], x, y)

# Fail loudly if any hand-written gradient drifts from JAX autodiff
# (leaves come out in the order w1, b1, w2, b2).
for manual, auto in zip((dw1, db1, dw2, db2), jax.tree_util.tree_leaves(jax_regression_grads)):
    assert manual.shape == auto.shape and jnp.allclose(manual, auto, rtol=1e-4, atol=1e-4), (manual, auto)

jax_regression_grads

[(DeviceArray([[ 5.0259308e+01, -2.9323208e+01, -1.6577687e+02,
                 3.3996141e-01],
               [-1.6393526e+01, -7.4917145e+00,  4.0747322e+01,
                 9.8981008e-02],
               [-1.6545301e+01,  1.6847119e+02,  2.0225949e+02,
                -1.4161079e+00]], dtype=float32),
  DeviceArray([-20.804895 , 114.67517  , 172.82722  ,  -1.0782892], dtype=float32)),
 (DeviceArray([[ 22.628632],
               [124.50092 ],
               [435.7387  ],
               [134.89351 ]], dtype=float32),
  DeviceArray([120.27084], dtype=float32))]

## Logistic Regression

In [None]:
KEY = jrandom.PRNGKey(528)
INPUT_DIM = 3
OUTPUT_DIM = 1
BATCH = 10

# One subkey per draw: reusing a single PRNG key correlates the samples.
k_w, k_b, k_x, k_y = jrandom.split(KEY, 4)
parmas = (jrandom.normal(k_w, (INPUT_DIM, OUTPUT_DIM)), jrandom.normal(k_b, (OUTPUT_DIM, )))
x = jrandom.normal(k_x, (BATCH, INPUT_DIM))
y = jrandom.randint(k_y, (BATCH, OUTPUT_DIM), 0, 2)

# forward: linear -> sigmoid
linear_output = linear(parmas, x)
sigmoid_output = sigmoid(linear_output)

# backward (chain rule); names are dLoss/d<tensor>
dLoss_dSigmoid = grad_bce_loss(y, sigmoid_output)
dLoss_dLinear = grad_sigmoid(linear_output, dLoss_dSigmoid)
dw, db, _ = grad_linear(parmas, x, dLoss_dLinear)

(dw, db)

(DeviceArray([[-0.35676906],
              [ 0.9949498 ],
              [ 4.8142233 ]], dtype=float32),
 DeviceArray([-0.19507158], dtype=float32))

In [None]:
jax_logistic_grads = get_jax_logistic_regression_grad(parmas, x, y)

# Fail loudly if the hand-written gradients drift from JAX autodiff.
for manual, auto in zip((dw, db), jax.tree_util.tree_leaves(jax_logistic_grads)):
    assert manual.shape == auto.shape and jnp.allclose(manual, auto, rtol=1e-4, atol=1e-4), (manual, auto)

jax_logistic_grads

(DeviceArray([[-0.3567691],
              [ 0.9949497],
              [ 4.8142233]], dtype=float32),
 DeviceArray([-0.1950717], dtype=float32))

## Softmax Classification

In [None]:
KEY = jrandom.PRNGKey(129)
INPUT_DIM = 3
BATCH = 10
N_CLASS = 3

# One subkey per draw: reusing a single PRNG key correlates the samples.
k_w, k_b, k_x, k_y = jrandom.split(KEY, 4)
parmas = (jrandom.normal(k_w, (INPUT_DIM, N_CLASS)), jrandom.normal(k_b, (N_CLASS, )))
x = jrandom.normal(k_x, (BATCH, INPUT_DIM))
# Balanced, shuffled labels: every class is a target at least once, so the
# gradient check exercises every class column.
assert BATCH >= N_CLASS
y = jrandom.permutation(k_y, jnp.arange(BATCH) % N_CLASS)

# forward: linear -> softmax
linear_output = linear(parmas, x)
softmax_output = softmax(linear_output)

# backward (chain rule); names are dLoss/d<tensor>
dLoss_dSoftmax = grad_ce_loss(y, softmax_output)
dLoss_dLinear = grad_softmax(linear_output, dLoss_dSoftmax)
dw, db, _ = grad_linear(parmas, x, dLoss_dLinear)

(dw, db)

(DeviceArray([[-1.3611741,  2.6924202, -1.3312461],
              [ 1.2818991,  1.4020957, -2.6839948],
              [-0.0522493, -0.670455 ,  0.7227044]], dtype=float32),
 DeviceArray([ 1.1519504, -4.9833374,  3.8313875], dtype=float32))

In [None]:
jax_softmax_grads = get_jax_softmax_classification_grad(parmas, x, y)

# Fail loudly if the hand-written gradients drift from JAX autodiff.
for manual, auto in zip((dw, db), jax.tree_util.tree_leaves(jax_softmax_grads)):
    assert manual.shape == auto.shape and jnp.allclose(manual, auto, rtol=1e-4, atol=1e-4), (manual, auto)

jax_softmax_grads

(DeviceArray([[-1.361174 ,  2.6924202, -1.331246 ],
              [ 1.2818991,  1.4020958, -2.6839945],
              [-0.0522493, -0.6704552,  0.7227043]], dtype=float32),
 DeviceArray([ 1.1519502, -4.9833374,  3.831387 ], dtype=float32))