#### Pure functions
1. Pure functions are those that:
    - Return the same value whenever they are called with the same inputs, i.e. they do not depend on hidden state
    - Have no observable side effects (no mutation of external state, no I/O)
2. Several examples of impure functions are shown below:
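The sketch below is a first, minimal illustration of why the "no side effects" rule matters under *jit*. The names *trace_log*, *impure_scale* and *scale_jit* are hypothetical. The function appends to an external Python list, which is a side effect. That side effect runs only while JAX traces the function. Later calls with the same input shape and dtype hit the compilation cache and skip it entirely.

```python
import jax
import jax.numpy as jnp

# Hypothetical external state that the function mutates (a side effect)
trace_log = []

def impure_scale(x):
    # Side effect: executed only when JAX traces this function
    trace_log.append("traced")
    return 2 * x

scale_jit = jax.jit(impure_scale)

scale_jit(jnp.arange(3))        # first call: traced and compiled, log has 1 entry
scale_jit(jnp.arange(3) + 1)    # same shape/dtype: cached, side effect skipped
print(len(trace_log))           # 1

scale_jit(jnp.arange(4))        # new shape: retraced, side effect runs again
print(len(trace_log))           # 2
```

The computed values are still correct. What breaks is the expectation that the side effect happens on every call.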

##### Case I: Globals
1. In the following, the same function is called twice, with a different value of the global variable *counter* each time. Both calls nevertheless return the same result. When *jit* first traces *add_global_value*, the current value of *counter* (5) is baked into the compiled computation as a constant. The second call passes an argument of the same shape and dtype. JAX therefore reuses the cached compiled function from the first call instead of retracing it, and the updated global is never read.

In [1]:
import numpy as onp
import jax.numpy as jnp
from jax import grad, jit, vmap, make_jaxpr, random, lax

# A global variable
counter = 5

def add_global_value(x):
    """
    A function that relies on the global variable `counter` for doing some computation.
    """
    return x + counter

x = 2

# using global variable in first call
y = jit(add_global_value)(x)
print("[counter] = {}, calling y(x) gives {}".format(counter, y))

# using updated global variable in second call
counter = 10
y = jit(add_global_value)(x)
print("[counter] = {}, calling y(x) gives {}".format(counter, y))

# output: first and second calls give the same results



[counter] = 5, calling y(x) gives 7
[counter] = 10, calling y(x) gives 7
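One way to make this function pure is to pass the value explicitly as an argument instead of reading a global. The sketch below assumes that approach; *add_value* and *add_value_jit* are hypothetical names. Because *counter* is now a traced input rather than a constant captured at trace time, the cached compiled function picks up each new value.

```python
import jax

def add_value(x, counter):
    """Pure version: `counter` is an explicit argument, not a global."""
    return x + counter

add_value_jit = jax.jit(add_value)

x = 2
print(add_value_jit(x, 5))   # 7
print(add_value_jit(x, 10))  # 12: counter is a runtime input, so no stale constant
```

If *counter* should instead stay a compile-time constant, *jax.jit(add_value, static_argnums=1)* would retrace once per distinct value of *counter*. That trades extra compilations for constant folding.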


##### Case II: Iterators
1. In the following, two functions with the same logical behavior are defined. Both sum the elements of a sequence with *lax.fori_loop*. The first indexes into the array with the loop counter, while the second pulls each element from an iterator with *next*.
2. The first version gives the correct sum (10), while the second returns 0. An iterator is a stateful object, and *fori_loop* cannot handle that state. The loop body is traced only once, so *next(iterator)* is evaluated a single time during tracing. Its result, the first element (0), is baked into the compiled loop as a constant and added on every iteration, instead of the iterator advancing at runtime.

In [2]:
def add_elements(array, start, end, initial_value=0):

    # loop_fn's usage should look like val = loop_fn(i, val)
    def loop_fn(i, val):
        return val + array[i]

    return lax.fori_loop(start, end, loop_fn, initial_value)

array = jnp.arange(5)
print("Adding all elements yield {}".format(add_elements(array, 0, len(array), 0)))

def add_elements(iterator, start, end, initial_value=0):

    # loop_fn's usage should look like val = loop_fn(i, val)
    def loop_fn(i, val):
        return val + next(iterator)

    return lax.fori_loop(start, end, loop_fn, initial_value)

iterator = iter(onp.arange(5))
print("Adding all elements yield {}".format(add_elements(iterator, 0, 5, 0)))

Adding all elements yield 10
Adding all elements yield 0
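A possible fix is sketched below. It consumes the iterator in ordinary Python *before* tracing, so the loop body only ever sees an immutable array; *add_elements_from_iterable* is a hypothetical name. This is one workaround among several: *jnp.sum* or *lax.scan* over the materialized array would work just as well.

```python
import numpy as onp
import jax.numpy as jnp
from jax import lax

def add_elements_from_iterable(iterable, initial_value=0):
    # Drain the stateful iterator eagerly, outside of any traced code
    values = jnp.asarray(list(iterable))

    def loop_fn(i, val):
        # Pure body: the output depends only on (i, val) and the captured array
        return val + values[i]

    return lax.fori_loop(0, values.shape[0], loop_fn, initial_value)

print(add_elements_from_iterable(iter(onp.arange(5))))  # 10
```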


##### Pure functions with stateful objects
1. Using a stateful object does not by itself make a function impure. As long as the function does not read or modify any state outside its own scope (including I/O such as *print*), it is still considered pure.
2. An example is shown below. A *dict* is clearly a stateful object. Because it is created inside the function, however, its state is only updated within the function's scope, so the function neither depends on nor modifies any external state.

In [3]:
def pure_function_with_stateful_objects(array):
    # array_dict is local: its state never escapes this call
    array_dict = {}
    for i in range(len(array)):
        array_dict[i] = array[i] + 10
    return array_dict

array = jnp.arange(5)

# First call to the function
print(f"Value returned on first call: {jit(pure_function_with_stateful_objects)(array)}")
# Second call to the function with the same input gives the same result
print(f"\nValue returned on second call: {jit(pure_function_with_stateful_objects)(array)}")

Value returned on first call: {0: DeviceArray(10, dtype=int32), 1: DeviceArray(11, dtype=int32), 2: DeviceArray(12, dtype=int32), 3: DeviceArray(13, dtype=int32), 4: DeviceArray(14, dtype=int32)}

Value returned on second call: {0: DeviceArray(10, dtype=int32), 1: DeviceArray(11, dtype=int32), 2: DeviceArray(12, dtype=int32), 3: DeviceArray(13, dtype=int32), 4: DeviceArray(14, dtype=int32)}
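The same idea holds for other JAX transformations. The sketch below uses a local Python *list* as a mutable accumulator; *offset_elements* is a hypothetical name. Since the list never escapes the call, the function stays pure and composes with both *jit* and *vmap*, and it matches a plain per-row Python loop.

```python
import jax
import jax.numpy as jnp

def offset_elements(array):
    # Local, mutable Python list: its state never leaves this call
    out = []
    for i in range(array.shape[0]):  # shape is static, so this loop unrolls at trace time
        out.append(array[i] + 10)
    return jnp.stack(out)

batch = jnp.arange(6).reshape(2, 3)

eager = jnp.stack([offset_elements(row) for row in batch])
batched = jax.jit(jax.vmap(offset_elements))(batch)

print(batched)                                # [[10 11 12]
                                              #  [13 14 15]]
print(bool(jnp.array_equal(eager, batched)))  # True
```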
