In [1]:
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True) 

In [2]:
def unstable_fn(x, dtype):
    """
    A function prone to catastrophic cancellation.
    In exact math, this expression tends to 0.5 as x -> 0.
    """
    x = jnp.array(x, dtype=dtype)
    return (jnp.sqrt(1.0+x) - 1.0)/x


def stable_fn(x, dtype):
    """
    A well-conditioned polynomial
    It is much less sensitive to rounding errors.
    """
    x = jnp.array(x, dtype=dtype)
    return x**2 + 2*x + 1 


def central_finite_diff(f, x, h):
    """
    Central difference approximation of f'(x)
    """
    return (f(x+h) - f(x-h))/2*h


x_value = 1e-12
h_value = 1e-8

un_der32 = central_finite_diff(lambda val: unstable_fn(val, dtype=jnp.float32), x_value,h_value)
un_der64 = central_finite_diff(lambda val: unstable_fn(val, dtype=jnp.float64), x_value, h_value)

s_der32 = central_finite_diff(lambda val: stable_fn(val, dtype=jnp.float32), x_value, h_value)
s_der64 = central_finite_diff(lambda val: stable_fn(val, dtype=jnp.float64), x_value, h_value)

grad_stable_32 = jax.grad(lambda x: stable_fn(x, dtype=jnp.float32))
grad_stable_64 = jax.grad(lambda x: stable_fn(x, dtype=jnp.float64))

grad_unstable_32 = jax.grad(lambda x: unstable_fn(x, dtype=jnp.float32))
grad_unstable_64 = jax.grad(lambda x: unstable_fn(x, dtype=jnp.float64))


###########################################



un_value32 = stable_fn(x_value, dtype=jnp.float32)
un_value64 = stable_fn(x_value, dtype=jnp.float64)

s_value32 = stable_fn(x_value, dtype=jnp.float32)
s_value64 = stable_fn(x_value, dtype=jnp.float64)

print("--- Function Values ---")
print(f"float32 value for the unstable fn: {un_value32}")
print(f"float64 value for the unstable fn: {un_value64}")

print(f"float32 value for the stable fn: {s_value32}")
print(f"float64 value for the stable fn: {s_value64}")

print("\n--- Derivativesfor the unstable function ---")
print(f"float32 finite derivative (step={h_value}): {un_der32}")
print(f"float64 finite derivative (step={h_value}): {un_der64}")
print(f"float32 grad autodiff {grad_unstable_32(x_value)}")
print(f"float64 grad autodiff {grad_unstable_64(x_value)}")



print("\n--- Finite-Difference Derivativesfor the stable function ---")
print(f"float32 derivative (step={h_value}): {s_der32}")
print(f"float64 derivative (step={h_value}): {s_der64}")
print(f"float32 grad autodiff {grad_stable_32(x_value)}")
print(f"float64 grad autodiff  {grad_stable_64(x_value)}")




--- Function Values ---
float32 value for the unstable fn: 1.0
float64 value for the unstable fn: 1.000000000002
float32 value for the stable fn: 1.0
float64 value for the stable fn: 1.000000000002

--- Derivativesfor the unstable function ---
float32 finite derivative (step=1e-08): 0.0
float64 finite derivative (step=1e-08): -1.1063371885278173e-17
float32 grad autodiff 499999997952.0
float64 grad autodiff -44450291.42059326

--- Finite-Difference Derivativesfor the stable function ---
float32 derivative (step=1e-08): 0.0
float64 derivative (step=1e-08): 1.9999999989472883e-16
float32 grad autodiff 2.0
float64 grad autodiff  2.000000000002
