# What is JIT?

Just-in-time compilation (JIT) is a method for selectively speeding up Python code by compiling specific functions within a Python script. Remember earlier we said coding languages are either compiled (C) or interpreted (Python). JIT is something of a middle ground between the two. With JIT, a specific function (or functions) is flagged for compilation at run time. The first time that function is called within the code, it is compiled on the fly by the JIT compiler. This compiled version of the function then lives in memory and is reused by later calls, which is where the large performance increase comes from. Note that the first call to the function is actually somewhat slower than if it were not JIT'd, as the compiler has to do its magic to make the compiled function. However, for an expensive function even one successive call can make up the difference. This does mean that JIT generally shines the brightest when applied to functions that are called repeatedly.

## What should I JIT?

In general, JIT compilers are very good at speeding up numerical code and loops in Python. They are not particularly good at, e.g., string manipulation. Moreover, JIT compilers do not magically know about every package ever written, although they do know some (almost always numpy). To steal an example from Numba, this code would benefit greatly from JIT:

In [None]:
import numpy as np

x = np.arange(100).reshape(10, 10)

def go_fast(a): # Once JIT'd, this is compiled to machine code on its first call
    trace = 0.0
    for i in range(a.shape[0]):   # Numba likes loops
        trace += np.tanh(a[i, i]) # Numba likes NumPy functions
    return a + trace              # Numba likes NumPy broadcasting

print(go_fast(x))

JIT-ing the following code, on the other hand, would do almost nothing, and indeed is not possible with some JIT compilers:

In [None]:
import pandas as pd

x = {'a': [1, 2, 3], 'b': [20, 30, 40]}

def use_pandas(a): # Function will not benefit from Numba jit
    df = pd.DataFrame.from_dict(a) # Numba doesn't know about pd.DataFrame
    df += 1                        # Numba doesn't understand what this is
    return df.cov()                # or this!

print(use_pandas(x))

The places where JIT is appropriate, however, are generally the performance-taxing portions of your code: the computations which take a lot of time, and not, say, plotting results.
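If it is not obvious which portions of a script are the expensive ones, a profiler can tell you before you reach for a JIT compiler. The following is a minimal sketch using the standard library's `cProfile`; the functions `heavy_computation`, `format_report`, and `pipeline` are hypothetical stand-ins and are not part of the original example.

```python
import cProfile
import io
import pstats

def heavy_computation(n):
    # Hypothetical stand-in for the numerical core of an analysis
    total = 0.0
    for i in range(n):
        total += (i % 7) ** 0.5
    return total

def format_report(value):
    # Hypothetical stand-in for cheap bookkeeping (or plotting) code
    return "result = {:.3f}".format(value)

def pipeline():
    return format_report(heavy_computation(200_000))

# Profile one run of the whole pipeline
profiler = cProfile.Profile()
profiler.enable()
report = pipeline()
profiler.disable()

# Sort by cumulative time so the most expensive functions come first
buffer = io.StringIO()
pstats.Stats(profiler, stream=buffer).sort_stats("cumulative").print_stats(5)
profile_text = buffer.getvalue()
print(profile_text)
```

Functions near the top of this listing that are dominated by loops and numerical work are the natural candidates for JIT.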

## How do I JIT?

JIT-ing a function is almost trivially easy. First you need a JIT compiler; in the case of Python, these are packages which can be conda (or pip) installed. The two most popular options by far are Numba and jax. For typical numerical code the performance of these two packages is broadly comparable, and they support largely the same features (e.g., direct compilation to GPU, automatic vectorization, etc.). They differ in their backend: XLA for jax vs LLVM for Numba. It's not particularly important what those backends are; the key point is that jax supports one large feature, automatic differentiation of arbitrary functions, that Numba does not support. The trade-off is that JIT-ing functions with jax imposes some tricky restrictions on what code you can write (e.g., jax arrays are completely immutable, jax functions must be pure). I would recommend sticking to Numba unless you know you need the features that jax provides. In that case, reading the [jax sharp bits documentation](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) can help you get the hang of coding with jax.
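To make the jax trade-off concrete, here is a minimal sketch of automatic differentiation combined with JIT, along with the functional style that immutable arrays force on you. The function `sum_of_squares` is a hypothetical example, and the import guard is only there so the snippet still runs on a machine without jax installed.

```python
try:
    import jax
    import jax.numpy as jnp
except ImportError:
    # jax is optional for this sketch; without it the demo below is skipped
    jax = None

def sum_of_squares(v):
    # Pure function: the output depends only on the input, with no side effects
    return jnp.sum(v ** 2)

if jax is not None:
    v = jnp.array([1.0, 2.0, 3.0])

    # Differentiate the function, then JIT-compile the resulting gradient function
    grad_fn = jax.jit(jax.grad(sum_of_squares))
    g = grad_fn(v)  # analytically, the gradient of sum(v**2) is 2*v

    # jax arrays are immutable: `v[0] = 10.0` raises an error.
    # Instead, build a new array with the functional update syntax.
    v_new = v.at[0].set(10.0)

    print(g, v, v_new)
```

Note that `v` is left untouched by the update; only `v_new` carries the changed value.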

## Coding with Numba

Coding with Numba is as easy as coding with numpy. In fact, Numba was built to work with numpy, and is natively very, very good at JIT-ing code that is largely numpy. First, we import Numba (and numpy). Then, to JIT a function, we simply prepend the function with the `@jit` decorator:

In [12]:
# Example courtesy of Numba

from numba import jit
import numpy as np
import time

x = np.arange(100).reshape(10, 10)

@jit(nopython=True)
def go_fast(a): # Function is compiled and runs in machine code
    trace = 0.0
    for i in range(a.shape[0]):
        trace += np.tanh(a[i, i])
    return a + trace

def go_slow(a): # Same function without @jit: runs in the Python interpreter
    trace = 0.0
    for i in range(a.shape[0]):
        trace += np.tanh(a[i, i])
    return a + trace

# DO NOT REPORT THIS... COMPILATION TIME IS INCLUDED IN THE EXECUTION TIME!
start = time.perf_counter()
go_fast(x)
end = time.perf_counter()
print("Elapsed (with compilation) = {}s".format((end - start)))

# NOW THE FUNCTION IS COMPILED, RE-TIME IT EXECUTING FROM CACHE
start = time.perf_counter()
go_fast(x)
end = time.perf_counter()
compiled = end - start
print("Elapsed (after compilation) = {}s".format(compiled))


Elapsed (with compilation) = 0.13916220399551094s
Elapsed (after compilation) = 3.3343909308314323e-05s
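
A single `perf_counter` measurement is fairly noisy, and the cell above never actually times the plain Python version. The following is a minimal sketch of a fairer comparison that averages over many calls with `timeit` after a warm-up call. The number of repetitions is an arbitrary choice, and the import guard (which swaps in a do-nothing decorator) is only there so the snippet still runs without Numba; in that case both timings measure the same interpreted function.

```python
import timeit
import numpy as np

try:
    from numba import jit
except ImportError:
    # Fallback so the sketch still runs without Numba: a do-nothing decorator
    def jit(*args, **kwargs):
        return lambda func: func

def go_slow(a):
    trace = 0.0
    for i in range(a.shape[0]):
        trace += np.tanh(a[i, i])
    return a + trace

# Same function body, wrapped by the JIT compiler; go_slow itself is unchanged
go_fast = jit(nopython=True)(go_slow)

x = np.arange(100).reshape(10, 10)

go_fast(x)  # warm-up call so compilation time is excluded from the timing

n_calls = 1000
t_slow = timeit.timeit(lambda: go_slow(x), number=n_calls) / n_calls
t_fast = timeit.timeit(lambda: go_fast(x), number=n_calls) / n_calls
print("Interpreted: {:.3e}s per call".format(t_slow))
print("JIT'd:       {:.3e}s per call".format(t_fast))
```

The size of the gap depends heavily on the array size and the machine, so treat any single number with caution.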


Note the argument `nopython=True` in the above `@jit` decorator. This tells the compiler to compile the code with no reference to the Python interpreter whatsoever. This is much faster; however, compilation will fail with an error if objects Numba does not understand are passed to or used within the function, as with the Pandas dataframe in `use_pandas` above. Without `nopython=True`, older versions of Numba would instead fall back to an object-aware mode ("object mode"), which involves the interpreter; from Numba 0.59 onward, nopython mode is the default for `@jit`. Object mode is much slower, and should be avoided if at all possible. In general, avoiding it involves "deconstructing" the objects which you pass to Numba. In the contrived example below, we have a data class which contains the data as well as other attributes, such as housekeeping data and the date. If we write a function that takes the full class instance as an argument, Numba will fail:

In [29]:
class TOD:
    def __init__(self, data, housekeeping, date):
        self.data = data
        self.housekeeping = housekeeping
        self.date = date

my_data = np.random.rand(150, 10000)
my_housekeeping = np.random.rand(5, 10000)
my_date = "June 1, 2024"

my_TOD = TOD(my_data, my_housekeeping, my_date)

@jit(nopython=True)
def jit_fft_bad(TOD): #IDK why we're taking a sin of our data
    return np.sin(TOD.data)

jit_fft_bad(my_TOD)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
During: typing of argument at /tmp/ipykernel_323893/3844633038.py (13)

File "../../../../../tmp/ipykernel_323893/3844633038.py", line 13:
<source missing, REPL/exec in use?> 

This error may have been caused by the following argument(s):
- argument 0: Cannot determine Numba type of <class '__main__.TOD'> 



If we instead write it to take a numpy array as its argument and pass it `my_TOD.data`, it will succeed:

In [32]:
@jit(nopython=True)
def jit_fft_good(data):
    return np.sin(data)

jit_fft_good(my_TOD.data)

array([[5.82749896e-01, 5.58756162e-01, 3.67827394e-01, ...,
        8.25043924e-01, 6.66373835e-01, 6.92616439e-01],
       [1.46209823e-01, 3.08533492e-02, 7.28769822e-01, ...,
        3.72495531e-01, 4.93774063e-01, 5.57428505e-01],
       [8.32212642e-01, 5.56117264e-01, 6.21674391e-01, ...,
        5.13366140e-01, 5.43360493e-01, 7.47630492e-01],
       ...,
       [5.28097610e-01, 5.24873373e-01, 4.13912628e-01, ...,
        6.29261231e-01, 2.35466710e-01, 7.75887593e-01],
       [5.05065489e-01, 5.34624834e-01, 4.70674399e-01, ...,
        7.21044899e-01, 2.76443731e-01, 8.14298097e-01],
       [6.69978058e-02, 4.79798892e-01, 7.57841785e-01, ...,
        2.95194044e-01, 2.03142847e-04, 8.41174483e-01]])
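
One way to keep the convenience of the class while still handing Numba only arrays is to put the numerical work in a small JIT'd kernel and call it from an ordinary Python method. The following is a minimal sketch of that pattern; the names `_sin_kernel` and `sin_of_data` are hypothetical, the arrays are smaller than above to keep it quick, and the import guard is only there so the snippet still runs without Numba.

```python
import numpy as np

try:
    from numba import jit
except ImportError:
    # Fallback so the sketch still runs without Numba: a do-nothing decorator
    def jit(*args, **kwargs):
        return lambda func: func

@jit(nopython=True)
def _sin_kernel(data):
    # Only sees a plain numpy array, so nopython mode can type it
    return np.sin(data)

class TOD:
    def __init__(self, data, housekeeping, date):
        self.data = data
        self.housekeeping = housekeeping
        self.date = date

    def sin_of_data(self):
        # Ordinary Python method: "deconstruct" the object, then call the kernel
        return _sin_kernel(self.data)

my_TOD = TOD(np.random.rand(150, 100), np.random.rand(5, 100), "June 1, 2024")
result = my_TOD.sin_of_data()
print(result.shape)
```

The method itself runs in the interpreter, but it does almost nothing, so the cost is negligible next to the compiled kernel.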