In [1]:
import jax.numpy as jnp
import numpy as np

# Creating a NumPy array
np_array = np.array([1.0, 2.0, 3.0])

# Creating a JAX Device Array
jax_array = jnp.array([1.0, 2.0, 3.0])

print("NumPy Array:", np_array)
print("JAX Device Array:", jax_array)

NumPy Array: [1. 2. 3.]
JAX Device Array: [1. 2. 3.]


In [11]:
import jax.scipy as jsp

# Example: Computing the exponential integral function
eig_val = jsp.special.exp1(jnp.array([1.0]))

print("Exponential Integral of 1.0:", eig_val)

# Example: Computing the gamma function
gamma_val = jsp.special.gamma(jnp.array(2.5))

print("Gamma function of 2.5:", gamma_val)

# Example: Computing the modified Bessel function of the first kind
bessel_val = jsp.special.i0(jnp.array(1.0))

print("Modified Bessel function I_0(1.0):", bessel_val)

Exponential Integral of 1.0: [0.21938422]
Gamma function of 2.5: 1.3293406
Modified Bessel function I_0(1.0): 1.266066


In [12]:
from jax import lax

# Example: Conditional computation using lax.cond
x = 5
out = lax.cond(x > 0, lambda x: x + 1, lambda x: x - 1, x)

print("Result of lax.cond:", out)

Result of lax.cond: 6


In [4]:
import numpy as np

# Define the function y = x1^2 * sin(x1^2) + e^(x1/x2)
def f(x1, x2):
    return x1**2 * np.sin(x1**2) + np.exp(x1 / x2)

# Define the derivative using central difference with a small h
def numerical_derivative(f, x1, x2, h=1e-5):
    return (f(x1 + h, x2) - f(x1 - h, x2)) / (2 * h)

# Set the value of x1 and x2
x1 = 0.5
x2 = 1.5

# Calculate the value of the function and its numerical derivative at x = 0.5
y = f(x1, x2)
dy_dx1 = numerical_derivative(f, x1, x2)

print("Function value y at x=0.5:", y)
print("Numerical derivative dy/dx1 at x=0.5:", dy_dx1)

Function value y at x=0.5: 1.4574634148997203
Numerical derivative dy/dx1 at x=0.5: 1.4200403482433896


In [4]:
from jax import numpy as jnp
from jax import jacfwd

# Define the function y = x1^2 * sin(x1^2) + e^(x1/x2)
def f(x1, x2):
    return x1**2 * jnp.sin(x1**2) + jnp.exp(x1 / x2)

# Get the derivative function using JAX
df = jacfwd(f, argnums=(0,))

# Set the value of x1 and x2
x1 = 0.5
x2 = 1.5

# Calculate the value of the function and its numerical derivative at x = 0.5
y = f(x1, x2)
dy_dx1 = df(x1, x2)

print("Function value y at x=0.5:", y)
print("Algorithmic derivative dy/dx1 at x=0.5:", dy_dx1)

Function value y at x=0.5: 1.4574635
Algorithmic derivative dy/dx1 at x=0.5: (Array(1.4200404, dtype=float32, weak_type=True),)


In [5]:
import jax
import jax.numpy as jnp

@jax.jit
def example_function(x, y):
    return jnp.sin(x) * jnp.cos(y) + jnp.sin(x)/(1+ jnp.cos(y))

# Example usage
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[5.0, 6.0], [7.0, 8.0]])
result = example_function(x, y)
print(result)

[[19. 22.]
 [43. 50.]]


In [60]:
import jax
import jax.numpy as jnp
import numpy as np
import time

@jax.jit
def example_function(x, y):
    return jnp.sin(x) * jnp.cos(y) + jnp.sin(x)/(1+ jnp.cos(y))

def example_function_numpy(x, y):
    return jnp.sin(x) * jnp.cos(y) + jnp.sin(x)/(1 + jnp.cos(y))

x_jax = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y_jax = jnp.array([[5.0, 6.0], [7.0, 8.0]])

x_np = np.array([[1.0, 2.0], [3.0, 4.0]])
y_np = np.array([[5.0, 6.0], [7.0, 8.0]])

# Timing JAX implementation
start = time.time()
jax_result = example_function(x_jax, y_jax).block_until_ready()  # block_until_ready() ensures all operations complete
end = time.time()
print("JAX JIT execution time: {:.5} ms".format( (end - start)*1000 ))

# Timing NumPy implementation
start = time.time()
np_result = example_function_numpy(x_np, y_np)
end = time.time()
print("NumPy execution time: {:.5} ms".format( (end - start)*1000 ))

# JAX JIT execution time: 17.981 ms
# NumPy execution time: 1.0002 ms

JAX JIT execution time: 16.37 ms
NumPy execution time: 1.0004 ms


In [61]:
# Timing repeated JAX implementation

N_rep = 10000

start = time.time()
for _ in range(N_rep):
    jax_result = example_function(x_jax, y_jax).block_until_ready()
end = time.time()
print("JAX JIT execution time (10,000 runs): {:.5} ms".format( (end - start)*1000 ))

# Timing repeated NumPy implementation
start = time.time()
for _ in range(N_rep):
    np_result = example_function_numpy(x_np, y_np)
end = time.time()
print("NumPy execution time (10,000 runs): {:.5} ms".format( (end - start)*1000 ))

# JAX JIT execution time (10,000 runs): 56.354 ms
# NumPy execution time (10,000 runs): 661.88 ms

JAX JIT execution time (10,000 runs): 56.106 ms
NumPy execution time (10,000 runs): 665.18 ms
