In [2]:
import functools
import time

import jax
import jax.numpy as jnp

## The lambda function

In [3]:
# A lambda is an anonymous, single-expression function; here it is bound to a name.
sq_fn = (lambda x: jnp.square(x))

In [4]:
# A lambda's repr reports its name as `<lambda>`, not `sq_fn`.
sq_fn

<function __main__.<lambda>(x)>

In [5]:
# A Python int goes in; a 0-d JAX Array (weakly typed int32) comes out.
sq_fn(3)

Array(9, dtype=int32, weak_type=True)

In [8]:
def reg_eq_fn(x):
    """Return the element-wise square of ``x``, defined with a regular ``def``."""
    squared = jnp.square(x)
    return squared

In [9]:
# A function defined with `def` carries its own name in its repr.
reg_eq_fn

<function __main__.reg_eq_fn(x)>

In [11]:
# Returns a 0-d JAX Array, just like calling the lambda.
reg_eq_fn(3)

Array(9, dtype=int32, weak_type=True)

## Memory location

In [12]:
# id() returns an integer unique to the object during its lifetime
# (in CPython, this is the object's memory address).
id(sq_fn)

140428313390960

In [13]:
# A different function object, so a different id.
id(reg_eq_fn)

140428322022448

## List comprehension

In [14]:
# Squares of 0..4 built in a single list-comprehension expression.
y_lc = [jnp.square(n) for n in range(5)]

In [15]:
# A list of 0-d JAX Arrays, one per element of range(5).
y_lc

[Array(0, dtype=int32, weak_type=True),
 Array(1, dtype=int32, weak_type=True),
 Array(4, dtype=int32, weak_type=True),
 Array(9, dtype=int32, weak_type=True),
 Array(16, dtype=int32, weak_type=True)]

In [16]:
# The same list of squares, built with an explicit loop for comparison.
y_reg = []
for x in range(5):
    y_reg += [jnp.square(x)]

In [17]:
# Same contents as the list-comprehension result.
y_reg

[Array(0, dtype=int32, weak_type=True),
 Array(1, dtype=int32, weak_type=True),
 Array(4, dtype=int32, weak_type=True),
 Array(9, dtype=int32, weak_type=True),
 Array(16, dtype=int32, weak_type=True)]

## The enumerate operation

In [18]:
# enumerate pairs each element with its position, counting from 0.
for idx, item in enumerate(y_lc, start=0):
    print(idx, idx * 2, item)

0 0 0
1 2 1
2 4 4
3 6 9
4 8 16


## Zip object

In [19]:
list_one, list_two = [1, 2, 3], [99, 98, 97]
# zip pairs the i-th elements of both lists; it returns a lazy, single-use iterator.
zip_obj = zip(list_one, list_two)

In [21]:
# zip() produces an object of type `zip`, not a list.
type(zip_obj)

zip

In [22]:
# Each item yielded by a zip object is a tuple; unpack it directly in the loop
# header. Note that `zip_obj` is a single-use iterator: re-running this cell
# without re-creating `zip_obj` prints nothing.
for num_1, num_2 in zip_obj:
    print(num_1, num_2, num_1 + num_2)

1 99 100
2 98 100
3 97 100


## Decorators (the `@` syntax JAX uses for `jax.jit` just-in-time compilation)

In [23]:
def decorator(func):
    """Wrap ``func`` so that a message is printed before and after each call.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and returns ``func``'s result. ``functools.wraps`` copies
        ``func``'s ``__name__``, ``__doc__``, etc. onto the wrapper so the
        decorated function still looks like the original.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print("Something is happening before the function is called.")
        result = func(*args, **kwargs)
        print("Something is happening after the function is called.")
        return result

    return wrapper

In [24]:
def print_hello_world():
    """Print a greeting; this function is not decorated."""
    message = "hello world!"
    print(message)

In [25]:
@decorator
def print_hi_world():
    """Print a greeting; ``@decorator`` adds messages before and after it."""
    greeting = "hi world!"
    print(greeting)

In [26]:
# Undecorated: only the function's own message is printed.
print_hello_world()

hello world!


In [27]:
# Decorated: the wrapper's messages surround the function's own output.
print_hi_world()

Something is happerinning before the function is called.
hi world!
Something is happerinning after the function is called.


## List slicing

In [28]:
# A plain Python list holding 1..5.
array_example = list(range(1, 6))

In [29]:
# Elements at indices 0, 1, 2 (the stop index is excluded).
array_example[:3]

[1, 2, 3]

In [30]:
# Elements at indices 2 and 3.
array_example[2:4]

[3, 4]

In [31]:
# Everything except the last element.
array_example[:-1]

[1, 2, 3, 4]

## Type hints

In [32]:
def add_two_integers(first_num: int, second_num: int) -> int:
    """Return ``first_num + second_num``.

    The ``int`` annotations are hints only; Python does not check them at
    runtime, so any operands supporting ``+`` (e.g. floats) are accepted.
    """
    total = first_num + second_num
    return total

In [33]:
# Called with ints, matching the annotations.
add_two_integers(1, 2)

3

In [34]:
# Annotations are not enforced at runtime: floats are accepted and added.
add_two_integers(1.1, 2.2)

3.3000000000000003

## Random keys in JAX

In [35]:
# JAX randomness is explicit: every random call takes a key derived from this seed.
key = jax.random.PRNGKey(seed=0)

In [36]:
# Split into two fresh keys and rebind `key`, so the old key is not reused.
key, new_key = jax.random.split(key, num=2)

## Vectorized mapping (vmap) in JAX

In [47]:
def vec_norm(vec):
    """Return the *squared* Euclidean length of a 2-component vector.

    Only ``vec[0]`` and ``vec[1]`` are used, and no square root is taken,
    so the result is ``vec[0]**2 + vec[1]**2``.
    """
    return jnp.square(vec[0]) + jnp.square(vec[1])

In [48]:
# 100,000 angles spanning [0, pi]: points on the upper half of the unit circle.
theta = jnp.linspace(0, jnp.pi, 100000)
x = jnp.cos(theta)[:, None]  # column vector, shape (100000, 1)
y = jnp.sin(theta)[:, None]  # column vector, shape (100000, 1)
vecs = jnp.hstack((x, y))    # each row is (cos theta, sin theta); shape (100000, 2)

In [50]:
# Time the per-vector Python loop (slow: one dispatch per vector).
# JAX dispatches work asynchronously, so block on every result before stopping
# the clock; otherwise unfinished device work could be left out of the timing.
# perf_counter is the clock intended for measuring intervals.
start_time_reg = time.perf_counter()
list_norm_reg = [vec_norm(vec) for vec in vecs]
for norm_value in list_norm_reg:
    norm_value.block_until_ready()
end_time_reg = time.perf_counter()
exec_time_reg = end_time_reg - start_time_reg

In [51]:
# Time the same computation vectorized with jax.vmap.
# The call can return before the device has finished (asynchronous dispatch),
# so block on the result before stopping the clock. This single un-jitted call
# may also include vmap's tracing overhead.
start_time_vm = time.perf_counter()
list_norm_vm = jax.vmap(vec_norm)(vecs).block_until_ready()
end_time_vm = time.perf_counter()
exec_time_vm = end_time_vm - start_time_vm

In [52]:
# Report both timings, rounded to hundredths of a second, one per line.
print(
    f"execution time (regular): {round(exec_time_reg, 2)} seconds",
    f"execution time (vmap): {round(exec_time_vm, 2)} seconds",
    sep="\n",
)

execution time (regular): 25.31 seconds
execution time (vmap): 0.09 seconds


## Automatic differentiation using jacrev in JAX

In [60]:
# r(t) = (t, t**2); jacrev builds dr/dt = (1, 2t) by reverse-mode differentiation.
dr_dt = jax.jacrev(
    lambda t: jnp.array([t, jnp.square(t)])
)

In [61]:
# Ten evenly spaced sample times from 0 to 2, endpoints included.
t_val = jnp.linspace(0, 2, num=10)

In [62]:
# The first sample time, t = 0.
t_val[0]

Array(0., dtype=float32)

In [63]:
# Derivative of r(t) = (t, t**2) at t = 0: (1, 2t) = (1, 0).
dr_dt(t_val[0])

Array([1., 0.], dtype=float32)

In [64]:
# Map dr_dt over the leading axis of its argument (one derivative per time value).
dr_dt_vec = jax.vmap(dr_dt, in_axes=0)

In [65]:
# One row per t in t_val; each row is (1, 2t).
dr_dt_vec(t_val)

Array([[1.        , 0.        ],
       [1.        , 0.44444445],
       [1.        , 0.8888889 ],
       [1.        , 1.3333334 ],
       [1.        , 1.7777778 ],
       [1.        , 2.2222223 ],
       [1.        , 2.6666667 ],
       [1.        , 3.1111112 ],
       [1.        , 3.5555556 ],
       [1.        , 4.        ]], dtype=float32)

## OOP basics

In [70]:
class TimeStep:
    """A 2-D point advanced by explicit Euler steps of fixed size ``dt``.

    Args:
        position: ``(x, y)`` starting position.
        velocity: ``(vx, vy)`` velocity applied by :meth:`move`.
        dt: Step size; defaults to 0.01 (the previously hard-coded value).
    """

    def __init__(self, position, velocity, dt=0.01):
        self.position = position
        self.velocity = velocity
        self.dt = dt

    def move(self):
        """Advance ``self.position`` by one step: ``p <- p + v * dt``."""
        new_x = self.position[0] + self.velocity[0] * self.dt
        new_y = self.position[1] + self.velocity[1] * self.dt

        self.position = (new_x, new_y)

    def __call__(self):
        """Take one step and return the new ``(x, y)`` position."""
        self.move()
        return self.position


class Motion(TimeStep):
    """Motion with constant horizontal velocity and constant vertical
    acceleration, integrated with :class:`TimeStep` steps."""

    def __init__(self, position, velocity, dt=0.01):
        super().__init__(position, velocity, dt)

    def compute_trajectory(self, tf, accel_y=2.0):
        """Step the motion from t = 0 while t < ``tf``.

        The horizontal velocity keeps its initial value. The vertical velocity
        after each step is ``vy0 + accel_y * t``, where ``vy0`` is the initial
        vertical velocity.

        Args:
            tf: Final time; steps are taken while ``step * dt < tf``.
            accel_y: Constant vertical acceleration (default 2.0, the
                previously hard-coded value).
        """
        vx0, vy0 = self.velocity[0], self.velocity[1]
        step = 0
        # Derive t from an integer step count instead of repeatedly adding dt:
        # the running float sum accumulates rounding error, which can make
        # `t < tf` true one step too many when tf is a multiple of dt.
        while step * self.dt < tf:
            self.move()
            step += 1
            t = step * self.dt
            self.velocity = (vx0, vy0 + accel_y * t)

    def __call__(self, tf, accel_y=2.0):
        """Run :meth:`compute_trajectory` up to ``tf`` and return the final position."""
        self.compute_trajectory(tf, accel_y)
        return self.position

In [71]:
projectile = Motion(position=(0.0, 0.0), velocity=(1.0, 0.0))
projectile_position_after_2s = projectile(tf=2.0)
# Round each coordinate to 2 decimals and print the pair as a tuple "(x, y)".
print(
    "Projectile position after 2 seconds: "
    f"{tuple(round(coord, 2) for coord in projectile_position_after_2s[:2])}"
)


Projectile position after 2 seconds: (2.0, 3.98)
