# Some useful (?) tricks for Jupyter notebooks

## Jupyter Notebook Extensions
Install: conda install -c conda-forge jupyter_nbextensions_configurator
* Snippets: create one by modifying /jupyter/nbextensions/snippets/snippets.json
* Collapsible headings
* Autopep8
* Variable inspector
* Ruler (80 characters)
* Scratchpad
* Zenmode

In [None]:
%matplotlib inline
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
pd.options.display.max_columns = None

## Convert a Jupyter notebook to PDF
1. *conda install nbconvert*
2. install [pandoc](https://www.pandoc.org "pandoc")
3. install TeX: [MiKTeX](https://miktex.org "MiKTeX") for Windows or [MacTeX](http://tug.org/mactex/ "MacTeX") for macOS
4. *jupyter nbconvert --to pdf notebook.ipynb* or *File/Download As/pdf (.tex)*

## Share code
%pastebin *+ cell numbers*: uploads the given input cells to a pastebin service and returns the URL

In [None]:
%pastebin 1

# numba
Philosophy: find the bottlenecks in the code and speed them up.  
Use just-in-time compilation in nopython mode (@njit) rather than object mode.

## A few things I learned
1. Compiling takes time: compile only once (!) and only if necessary
2. Use built-in functions and NumPy: both are well integrated with numba
3. In most cases, the gains from jitting are much bigger than those from parallel computing
4. Loops are ok!
5. Pass only jitted functions as inputs to functions that are to be jitted
6. Function inputs should only be parameters that change
7. Parallel computing comes with an overhead
8. Be careful with global variables
9. Test and retest

## Examples:

In [None]:
import random
import numba
import time
from numba import njit, prange

In [None]:
# example 1: jitting interpolation function
m = 200
n = 30
v_x = np.linspace(0, 1, m)
mat_y = np.random.rand(m, n)
x0 = np.random.rand(n)

def interp_mat(v_x, mat_y, v_x0):
    """interpolation and extrapolation; nth value of v_x0 corresponds to
    nth col of mat_y"""
    v_y0 = np.empty_like(v_x0)
    for col in range(len(v_y0)):
        # index of the upper bracketing grid point, clamped to [1, len(v_x)-1]
        # so that i-1 never wraps around to the last element
        i = max(1, min(np.searchsorted(v_x, v_x0[col], 'left'), len(v_x)-1))
        v_y0[col] = (mat_y[i-1, col]
                     + (v_x0[col] - v_x[i-1])/(v_x[i] - v_x[i-1])
                     * (mat_y[i, col] - mat_y[i-1, col]))
    return v_y0

interp_jit = numba.njit(interp_mat)

interp_jit(v_x, mat_y, x0)

%timeit interp_mat(v_x, mat_y, x0)
%timeit interp_jit(v_x, mat_y, x0)
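
In the spirit of "test and retest", the interpolation can be checked against `np.interp`, applied column by column, for points inside the grid. The sketch below restates the function in self-contained form, with the index clamped so that `i-1` never wraps around. The seed and the array sizes are arbitrary choices. The same check can be run on the jitted version.

```python
import numpy as np


def interp_mat(v_x, mat_y, v_x0):
    """Linear interpolation; nth value of v_x0 corresponds to nth column of mat_y."""
    v_y0 = np.empty_like(v_x0)
    for col in range(len(v_y0)):
        # upper bracketing index, clamped to [1, len(v_x) - 1]
        i = max(1, min(np.searchsorted(v_x, v_x0[col]), len(v_x) - 1))
        slope = (mat_y[i, col] - mat_y[i-1, col]) / (v_x[i] - v_x[i-1])
        v_y0[col] = mat_y[i-1, col] + slope * (v_x0[col] - v_x[i-1])
    return v_y0


rng = np.random.default_rng(0)
m, n = 200, 30
v_x = np.linspace(0, 1, m)
mat_y = rng.random((m, n))
x0 = rng.random(n)   # all points lie inside [0, 1)

# reference: np.interp applied to each column separately
expected = np.array([np.interp(x0[j], v_x, mat_y[:, j]) for j in range(n)])
assert np.allclose(interp_mat(v_x, mat_y, x0), expected)
```

`np.interp` does not extrapolate (it returns the boundary values outside the grid), so the comparison only makes sense for points inside the grid.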

In [None]:
# example 2: sum of squared random variables
# loops are ok (for n large enough!)

n = 10_000  # also try 1_000_000 and 100_000_000

# normal function

start = time.time()

def test(n):
    sum = 0
    for i in range(n):
        x = random.uniform(0, 1)
        sum += x**2
    return sum
   
test(n)
print(f'loop: {time.time() - start}')

# jitted function (the timing includes compilation on the first call)

start = time.time()

jit_test = numba.njit(test) # or use decorator @njit
    
jit_test(n)
print(f'jitted loop: {time.time() - start}')


# straight numpy
start = time.time()

np.sum(np.random.uniform(0,1,n)**2)
print(f'numpy: {time.time() - start}')


# jitted numpy (the timing includes compilation on the first call)
start = time.time()

@njit
def np_test(n):
    return np.sum(np.random.uniform(0,1,n)**2)

np_test(n)
print(f'jitted numpy: {time.time() - start}')

In [None]:
# example 3: passing extra parameters can be costly

from numba import float64, int32

a = 0
b = 1
n = 10


@njit
def test():
    sum = 0
    for i in range(n):
        x = random.uniform(a, b)
        sum += x**2
    return sum

@njit
def test2(a, b, n):
    sum = 0
    for i in range(n):
        x = random.uniform(a, b)
        sum += x**2
    return sum

@njit(float64(int32, int32, int32))
def test3(a, b, n):
    sum = 0
    for i in range(n):
        x = random.uniform(a, b)
        sum += x**2
    return sum

test()
test2(a, b, n)
test3(a, b, n)

%timeit test()
%timeit test2(a, b, n)
%timeit test3(a, b, n)

In [None]:
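# example 4: be careful with global variables
# numba treats globals as compile-time constants: the value seen when the
# function is compiled (at its first call) is frozen into the compiled code

@njit
def add_a(x):
    return a + x

a = 1
print(add_a(1))  # compiled here with a = 1: prints 2

a = 2
print(add_a(1))  # still uses a = 1: prints 2 again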

In [None]:
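# example 5: parallel computing
from numba import njit, prange

@njit
def test(n):
    total = 0.0
    for i in range(n):
        # bounds are hard-coded: the globals a and b were changed in example 4
        x = random.uniform(0, 1)
        total += x**2
    return total

def loop(J, n):
    for j in range(J):
        test(n)

@njit(parallel=True)
def loop_par(J, n):
    for j in prange(J):
        test(n)

J, n = 1000, 1000
# warm-up calls so that compilation is not part of the timings
loop(J, n)
loop_par(J, n)
%timeit loop(J, n)
%timeit loop_par(J, n)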

# Dynamic programming with Python
[quantecon](https://www.quantecon.org "quantecon"): lecture notes, code, packages


## Replication of De Nardi, French and Jones (DFJ, 2010)

### The model

* individuals aged 70 to 100 (the maximum age) choose consumption every period,
* they start with a given initial wealth, receive income every period and pay for medical expenses,
* they are hit by health shocks, survival shocks and shocks to medical expenses,
* a consumption floor guarantees a minimal level of consumption.


The timing inside a period is as follows:
1. health status and medical expenses are realized
2. the individual consumes and saves
3. survival shock hits

Value function:
\begin{equation}
\begin{split}
    V_t(x_t, g, h_t, I, \zeta_t) = \max_{c_t, x_{t+1}}  \{ \frac{c_t^{1-\nu}} {1-\nu}
    + \beta s_{g,h,I,t} E_t V_{t+1}(x_{t+1}, g, h_{t+1}, I, \zeta_{t+1}) \}
\end{split}
\end{equation}

subject to:

\begin{equation}
    x_{t+1} = x_t - c_t + y_n(r(x_t-c_t) + y(g, I, t+1), \tau) + b_{t+1} - m_{t+1}    
\end{equation}

\begin{equation}
    \ln{m_t} = m(g,h,I,t) + \sigma(g,h,I,t) \cdot \psi_t    
\end{equation}

\begin{equation} 
    \pi_{j,k,g,I,t} = \Pr(h_{t+1} = k | h_t = j, g, I, t), \; j, k \in \{1,0\}.
\end{equation} 
where
\begin{equation}
    \begin{split}
        b_{t+1} = \max \{0, \underline{c} + m_{t+1} - [x_t - c_t + y_n(r(x_t - c_t) + y(g, I, t+1), \tau)]\}
    \end{split}    
\end{equation}

and
\begin{align}
    &\psi_t = \zeta_t + \xi_t, \; \xi_t \sim N(0,\sigma_\xi^2) \\
    &\zeta_t = \rho_m \zeta_{t-1} + \epsilon_{t}, \; \epsilon_{t} \sim N(0,\sigma_\epsilon^2)
\end{align}
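
To make the budget constraint and the shock process concrete, here is a minimal sketch in plain Python. Everything in it is an illustrative assumption and not the code in the DFJ modules: the function names, the flat tax standing in for $y_n(\cdot, \tau)$, starting the persistent component at $\zeta = 0$, and all numbers.

```python
import random


def after_tax_income(gross, tau):
    """Hypothetical flat-tax stand-in for y_n(., tau)."""
    return (1.0 - tau) * gross


def next_cash_on_hand(x, c, m_next, y_next, r, tau, c_floor):
    """Return (x_{t+1}, b_{t+1}) from the budget constraint with a consumption floor."""
    savings = x - c
    # resources before transfers and before paying medical expenses
    resources = savings + after_tax_income(r * savings + y_next, tau)
    # transfers fill the gap so that c_floor is still affordable after paying m_next
    b_next = max(0.0, c_floor + m_next - resources)
    return resources + b_next - m_next, b_next


def simulate_psi(T, rho_m, sigma_eps, sigma_xi, rng):
    """Simulate psi_t = zeta_t + xi_t, where zeta_t follows an AR(1) process."""
    zeta = 0.0   # assumed starting value of the persistent component
    psi = []
    for _ in range(T):
        zeta = rho_m * zeta + rng.gauss(0.0, sigma_eps)   # persistent part
        psi.append(zeta + rng.gauss(0.0, sigma_xi))       # plus transitory part
    return psi


# large medical expenses: transfers are positive and x_{t+1} equals the floor
x_next, b_next = next_cash_on_hand(x=10.0, c=4.0, m_next=20.0, y_next=2.0,
                                   r=0.05, tau=0.2, c_floor=1.0)
print(x_next, b_next)

psi = simulate_psi(T=30, rho_m=0.9, sigma_eps=0.2, sigma_xi=0.5,
                   rng=random.Random(0))
print(len(psi))
```

Whenever transfers are positive, next-period cash-on-hand equals the consumption floor (up to rounding). Otherwise transfers are zero and the budget constraint reduces to the version without $b_{t+1}$.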

### Basic code

In [None]:
import time
import DFJ_basic as DFJ

start = time.time()

cp = DFJ.common_params(3, 2, 10, 500_000) 
ip = DFJ.indiv_params(1, 0.5)
m_c, m_V = DFJ.solve_model(cp, ip)

print(f'time:{time.time()-start}')

# figure: consumption
ax = plt.subplot()
for per in [0, 10, 20]:
    ax.plot(cp.grid_x, m_c[per, :, 5], label=f'period {per}')
ax.legend()
ax.set_title('Decision function')
ax.set_xlabel('cash-on-hand')
ax.set_ylabel('consumption')
plt.show()

### Jitted code

In [None]:
import DFJ

start = time.time()

cp = DFJ.common_params(9, 10, 100, 500_000)
ip = DFJ.indiv_params(1, 0.5)
m_c, m_V = DFJ.solve_model(cp, ip)

print(f'time:{time.time()-start}')

# figure: consumption
ax = plt.subplot(1, 1, 1)
for per in [0, 10, 20]:
    ax.plot(cp.grid_x, m_c[per, :, 5], label=f'period {per}')
ax.legend()
ax.set_title('Decision function')
ax.set_xlabel('cash-on-hand')
ax.set_ylabel('consumption')
plt.show()

### Jitted + parallel code

In [None]:
import DFJ_parallel as DFJ

start = time.time()

cp = DFJ.common_params(9, 10, 100, 500_000)
ip = DFJ.indiv_params(1, 0.5)
m_c, m_V = DFJ.solve_model(cp, ip)

print(f'time:{time.time()-start}')

# figure: consumption
ax = plt.subplot(1, 1, 1)
for per in [0, 10, 20]:
    ax.plot(cp.grid_x, m_c[per, :, 5], label=f'period {per}')
ax.legend()
ax.set_title('Decision function')
ax.set_xlabel('cash-on-hand')
ax.set_ylabel('consumption')
plt.show()

## Comparing speed

In [None]:
import pandas as pd
import importlib
# number of gridpoints: (persistent shock, transitory shock, cash-on-hand)
params = [(3,2,10),(9,8,100),(18,16,1000)]
mod_list = ['DFJ_basic', 'DFJ', 'DFJ_parallel']
df = pd.DataFrame(index=mod_list, columns=params)
for mod in mod_list:
    DFJ = importlib.import_module(mod)
    ip = DFJ.indiv_params(1, 0.5)
    for col in params:
        # the non-jitted version takes too long on the largest grid: skip it
        if mod == 'DFJ_basic' and col == (18,16,1000):
            df.loc[mod, col] = 'too long'
            continue
        start = time.time()
        cp = DFJ.common_params(*col, 500_000)
        m_c, m_V = DFJ.solve_model(cp, ip)
        df.loc[mod, col] = time.time() - start

In [None]:
print('Speed comparison: basic, jitted, jitted+parallel\n' +
      'for # of gridpoints ' +
      '(persistent shock, transitory shock and cash-on-hand)')
display(df)