<a href="https://colab.research.google.com/github/xinyanz-erin/Applied-Finance-Project/blob/Pui/Peter_American_Put_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# American Put

In [1]:
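# Changes in this version:
# - store discounted cash flows in a 1-D CF array instead of a CF matrix
# - price a put option
# - use large numpaths & numsteps to test the price
# - use jax.lax.scan for the backward-induction loop
# - set X=0 and Y=0 for out-of-the-money (OTM) paths
# - add an inner loop over batches, so numpaths can grow by increasing numbatches
# - if the session runs out of memory, try restarting it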

import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    """Simulate one path of correlated geometric Brownian motion for a basket.

    Each step multiplies the previous prices by
    ``exp((drift - sigma**2 / 2) * dt + sqrt(dt) * Z)`` with ``Z ~ N(0, cov)``
    and ``sigma = sqrt(diag(cov))``.

    Args:
      key: JAX PRNG key used to draw the Gaussian shocks.
      initial_stocks: float array of shape (numstocks,) with starting prices.
      numsteps: number of time steps; it fixes array shapes and the loop
        bound, so it must be static when the function is jitted.
      drift: per-stock drift, broadcastable against (numstocks,).
      cov: (numstocks, numstocks) covariance matrix of the shocks.
      T: time horizon; the step size is ``T / numsteps``.

    Returns:
      Array of shape (numsteps, numstocks) holding the prices at times
      dt, 2*dt, ..., T (the initial prices are not included).
    """
    numstocks = initial_stocks.shape[0]
    # numsteps + 1 zero-mean draws; row 0 is never used, but it is kept so the
    # random stream (and therefore every simulated path) matches earlier runs.
    noise = jax.random.multivariate_normal(
        key, jnp.array([0] * numstocks), cov, (numsteps + 1,))
    sigma = jnp.diag(cov) ** 0.5
    dt = T / numsteps
    # Per-step multiplicative growth factors, computed once outside the loop.
    growth = jnp.exp((drift - sigma ** 2. / 2.) * dt + jnp.sqrt(dt) * noise)

    # Row 0 holds the initial prices. `.at[...].set` replaces the removed
    # jax.ops.index_update (the pure equivalent of x[idx] = y).
    stocks_init = jnp.zeros((numsteps + 1, numstocks)).at[0].set(initial_stocks)

    def time_step(t, val):
        return val.at[t].set(val[t - 1] * growth[t])

    # Drop row 0 so only the simulated prices are returned.
    return jax.lax.fori_loop(1, numsteps + 1, time_step, stocks_init)[1:]

def optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T, numpaths, numbatches=10):
  """Price an American put on the basket average by Longstaff-Schwartz Monte Carlo.

  Paths come from the module-level ``batch_simple`` (the vmapped
  ``Brownian_motion``). At each date the underlying is the average of the
  stocks, and the exercise value is ``K - average``. Going backwards from
  maturity, the discounted future cash flow of in-the-money paths is regressed
  on the cubic basis ``[1, S, S**2, S**3]``. A path is exercised when the
  immediate payoff beats that fitted continuation value.

  Args:
    key: a single raw ``jax.random.PRNGKey`` (split into ``numpaths`` keys
      here) or an already-split array of ``numpaths`` raw keys.
    initial_stocks: starting prices, shape (numstocks,).
    numsteps: number of time steps; exercise is possible at dt, 2*dt, ..., T.
    drift: drift used to simulate the paths.
    r: discount rate; an array-like whose first element is used.
    cov: covariance matrix passed to the path simulator.
    K: strike.
    T: maturity.
    numpaths: number of Monte Carlo paths.
    numbatches: number of chunks the regression is split into, so that more
      paths fit in memory. It must divide ``numpaths``. Larger values allow
      more paths but take longer.

  Returns:
    Scalar Monte Carlo estimate of the option price at time 0.

  Raises:
    ValueError: if ``numpaths`` is not divisible by ``numbatches``.
  """
  if numpaths % numbatches != 0:
    raise ValueError(
        f"numpaths ({numpaths}) must be divisible by numbatches ({numbatches})")
  batch_size = numpaths // numbatches

  # One key per path. A pre-split (numpaths, 2) array of raw keys is used as-is.
  keys = key if jnp.ndim(key) > 1 else jax.random.split(key, numpaths)

  # Simulate once, then average the stocks on each path -> (numpaths, numsteps).
  out_avg = jnp.mean(batch_simple(keys, initial_stocks, numsteps, drift, cov, T), axis=2)

  dt = T / numsteps
  step_discount = jnp.exp(-jnp.ravel(jnp.asarray(r))[0] * dt)
  ITM_matrix = out_avg <= K  # True where the put is in the money

  def fit_batch(batch):
    """Least-squares fit on one batch; return the fitted continuation values."""
    basis_b, y_b = batch
    coef = jnp.linalg.lstsq(basis_b, y_b, rcond=-1)[0]
    return basis_b.dot(coef)

  def body_fun(cash_flows, i):
    """One backward-induction step at column -i (time T - (i-1)*dt)."""
    S = out_avg[:, -i]
    itm = ITM_matrix[:, -i]
    cash_flows = cash_flows * step_discount  # discount by one step

    # Zero the WHOLE basis row (the constant term included) for OTM paths.
    # All-zero rows with target 0 leave the normal equations unchanged, so
    # the fit uses only ITM paths. A row of [1, 0, 0, 0] would bias the intercept.
    basis = jnp.stack([jnp.ones_like(S), S, S ** 2, S ** 3], axis=1)
    basis = jnp.where(itm[:, None], basis, 0.)
    Y = jnp.where(itm, cash_flows, 0.)

    # Fit the batches one after another (lax.map runs sequentially) to bound memory.
    ECV = jax.lax.map(
        fit_batch,
        (basis.reshape((numbatches, batch_size, 4)), Y.reshape((numbatches, batch_size))),
    ).reshape(numpaths)

    payoff = K - S
    # An explicit mask replaces the old 10000 sentinel, which failed for K >= 10000.
    exercise = itm & (payoff > ECV)
    return jnp.where(exercise, payoff, cash_flows), None

  CF_array = jnp.maximum(K - out_avg[:, -1], 0)  # payoff at maturity
  CF_array, _ = jax.lax.scan(body_fun, CF_array, jnp.arange(2, numsteps + 1))
  # Discount the last step (from dt back to 0) and average over paths.
  return (CF_array * step_discount).mean()

numstocks = 1
numpaths = 100000
numsteps = 200

rng = jax.random.PRNGKey(np.random.randint(10000))
rng, key = jax.random.split(rng)
drift = jnp.array([0.05]*numstocks)
r = drift

cov = jnp.identity(numstocks)*0.25*0.25
initial_stocks = jnp.array([1.]*numstocks) # must be float

T = 1.0
K = 1.0

fast_simple = jax.jit(Brownian_motion, static_argnums=2)

keys = jax.random.split(key, numpaths)
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))

# option price
print(optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T, numpaths))
%timeit(optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T, numpaths))

# delta
goptionvalueavg = jax.grad(optionvalueavg,argnums=1)
print(goptionvalueavg(keys, initial_stocks, numsteps, drift, r, cov, K, T, numpaths))
%timeit(goptionvalueavg(keys, initial_stocks, numsteps, drift, r, cov, K, T, numpaths))

0.08003467
1 loop, best of 5: 1.17 s per loop
[-0.41563094]
1 loop, best of 5: 11.2 s per loop
