<a href="https://colab.research.google.com/github/probml/probml-notebooks/blob/main/notebooks/logspace_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Illustrate log-space computation in JAX

Adapted from the TF version: [Notebook](https://github.com/probml/probml-notebooks/blob/main/notebooks/logspace_tf.ipynb)


Code is derived from https://github.com/EEA-sensors/sequential-parallelization-examples



In [None]:
import jax.numpy as jnp
from jax import random
from jax import jit

In [None]:
# Jitted matrix-matrix product; jnp.matmul broadcasts over leading batch axes.
mm = jit(jnp.matmul)


@jit
def mv(A, b):
    """Return the matrix-vector product of ``A`` and ``b``.

    ``b`` is promoted to a column matrix so that ``jnp.matmul`` contracts the
    last axis of ``A`` with the last axis of ``b``; the temporary trailing
    axis is then removed again. Leading axes are treated as batch axes and
    broadcast by ``jnp.matmul``.

    Args:
        A: array of shape ``(..., m, n)``.
        b: array of shape ``(..., n)``.

    Returns:
        Array of shape ``(..., m)``.
    """
    # An elementwise ``jnp.multiply`` would broadcast ``b`` across the rows of
    # ``A`` and return an ``(m, n)`` array instead of contracting over ``n``.
    column = jnp.expand_dims(b, axis=-1)
    return jnp.squeeze(jnp.matmul(A, column), axis=-1)


@jit
def log_mv(log_A, log_b):
    """Apply ``mv`` to operands given in the log domain; return the log result.

    Evaluates ``log(mv(exp(log_A), exp(log_b)))`` without exponentiating the
    raw inputs: the maximum of each operand is subtracted first (so the
    largest exponentiated entry is 1) and added back after taking the log.

    Args:
        log_A: elementwise log of the first operand; its maximum is taken
            jointly over the last two axes.
        log_b: elementwise log of the second operand; its maximum is taken
            over the last axis.

    Returns:
        The log of ``mv`` applied to the exponentiated operands.
    """
    shift_A = jnp.max(log_A, axis=(-1, -2), keepdims=True)
    shift_b = jnp.max(log_b, axis=-1, keepdims=True)

    scaled_A = jnp.exp(log_A - shift_A)
    scaled_b = jnp.exp(log_b - shift_b)
    log_scaled = jnp.log(mv(scaled_A, scaled_b))

    # Drop one of the two kept axes of shift_A so it lines up with shift_b.
    return jnp.squeeze(shift_A, axis=-1) + shift_b + log_scaled


@jit
def semilog_mv(A, log_b):
    """Apply ``mv`` to a linear-domain ``A`` and a log-domain ``log_b``.

    Evaluates ``log(mv(A, exp(log_b)))``. Only ``log_b`` is rescaled: its
    maximum over the last axis is subtracted before exponentiating and added
    back to the log of the result.

    Args:
        A: first operand, in the linear (non-log) domain.
        log_b: elementwise log of the second operand.

    Returns:
        The log of ``mv(A, exp(log_b))``.
    """
    shift = jnp.max(log_b, axis=-1, keepdims=True)
    scaled_b = jnp.exp(log_b - shift)
    log_scaled = jnp.log(mv(A, scaled_b))
    return shift + log_scaled


@jit
def log_mm(log_A, log_B):
    """Multiply two matrices given by their elementwise logs, in log space.

    Evaluates ``log(mm(exp(log_A), exp(log_B)))``. Each operand has its
    maximum over the last two axes subtracted before exponentiating, and both
    maxima are added back to the log of the product.

    Args:
        log_A: elementwise log of the left matrix.
        log_B: elementwise log of the right matrix.

    Returns:
        Elementwise log of the matrix product.
    """
    shift_A = jnp.max(log_A, axis=(-1, -2), keepdims=True)
    shift_B = jnp.max(log_B, axis=(-1, -2), keepdims=True)

    scaled_A = jnp.exp(log_A - shift_A)
    scaled_B = jnp.exp(log_B - shift_B)
    log_scaled = jnp.log(mm(scaled_A, scaled_B))

    return shift_A + shift_B + log_scaled


@jit
def log_normalize(log_p):
    """Normalize log-weights along the last axis, staying in the log domain.

    Returns ``log_p - logsumexp(log_p)`` so that ``exp`` of the result sums to
    one over the last axis.

    Args:
        log_p: unnormalized log-weights; normalization is over the last axis.

    Returns:
        Array with the same shape as ``log_p`` holding normalized
        log-probabilities.
    """
    pmax = jnp.max(log_p, axis=-1, keepdims=True)
    shifted = log_p - pmax
    log_total = jnp.log(jnp.sum(jnp.exp(shifted), axis=-1, keepdims=True))
    # Subtract in the log domain rather than taking log(exp(.) / sum): an
    # entry whose exponential underflows to 0 would otherwise become -inf
    # even though its normalized log-probability is representable.
    return shifted - log_total

In [None]:
# Sanity checks: each log-space helper, once exponentiated, must agree with
# the same operation carried out directly in the linear domain.
print("Test utility functions:")
key = random.PRNGKey(5)

# Split off one fresh subkey per random draw; `key` itself is carried forward.
key, subkey1 = random.split(key)
key, subkey2 = random.split(key)

A = random.uniform(subkey1, shape=[4, 4])
B = random.uniform(subkey2, shape=[4, 4])
log_A = jnp.log(A)
log_B = jnp.log(B)
# log_mm check: r1 is mm in the linear domain, r2 goes through log space.
r1 = mm(A, B)
r2 = jnp.exp(log_mm(log_A, log_B))
assert jnp.allclose(r1, r2)
print(r1)
del subkey1, subkey2

key, subkey1 = random.split(key)

b = random.uniform(subkey1, shape=[4])
log_b = jnp.log(b)
# log_mv / semilog_mv check: both must reproduce mv(A, b) after exponentiation.
r1 = mv(A, b)
r2 = jnp.exp(log_mv(log_A, log_b))
r3 = jnp.exp(semilog_mv(A, log_b))
assert jnp.allclose(r1, r2)
assert jnp.allclose(r1, r3)
print(r1)

# log_normalize check: compare against dividing b by its sum directly.
r1 = b / jnp.sum(b, keepdims=True)
r2 = jnp.exp(log_normalize(log_b))
assert jnp.allclose(r1, r2)
print(r1)
del subkey1

Test utility functions:
[[1.3614577  0.8750447  0.5797691  1.0019269 ]
 [0.753297   0.56651133 0.37677762 0.3595949 ]
 [1.3744494  1.3743275  0.96036065 1.8512903 ]
 [0.21262836 0.66064924 0.50598025 0.67006475]]
[[0.4120565  0.02189492 0.51332504 0.7139932 ]
 [0.20964764 0.00740943 0.04417798 0.64985305]
 [0.5894005  0.10242306 0.8178628  0.31164294]
 [0.0631081  0.07607288 0.07739824 0.04151168]]
[0.33839598 0.04428024 0.32203266 0.2952911 ]
