<a href="https://colab.research.google.com/github/xinyanz-erin/Applied-Finance-Project/blob/Erin/American_jax_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np
from torch.utils.dlpack import from_dlpack

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))
    stocks_init = jax.ops.index_update(stocks_init,   # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
                            jax.ops.index[0],         # initialization of stock prices
                            initial_stocks)
    noise = jax.random.multivariate_normal(key,  jnp.array([0]*initial_stocks.shape[0]), cov, (numsteps+1,)) # noise must have mean 0
    sigma = jnp.diag(cov) ** 0.5
    dt = T / numsteps
    def time_step(t, val):
        #dx =  drift + noise[t,:] # no need to multiply by sigma here because noise generated by cov not corr
        dx2 = jnp.exp((drift - sigma ** 2. / 2.) * dt + jnp.sqrt(dt) * noise[t,:])
        val = jax.ops.index_update(val,
                            jax.ops.index[t],
                            val[t-1] * dx2)
        return val
    return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)

def optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T, numpaths):  
  out = batch_simple(keys, initial_stocks, numsteps, drift, cov, T)
  out_avg = jnp.mean(batch_simple(keys, initial_stocks, numsteps, drift, cov, T), axis=2)
  # print(out_avg)
  dt = T / numsteps
  ITM_matrix = out_avg >= K # matrix that is true if ITM, else false
  # print(ITM_matrix)
  CF_matrix = jnp.zeros((numpaths, numsteps)) # matrix that will store cash flows

  # update time T
  CF_matrix = jax.ops.index_update(x=CF_matrix, 
                                   idx=jax.ops.index[:,-1], 
                                   y=jnp.maximum(K-out_avg[:,-1], 0))
  # print(CF_matrix) 
  # jnp array is immutable, so we'll need to reassign it every time we update. This is simply x[idx]=y
 
  # update t=T-1 to t=0
  for i in range(2, numsteps+1):
    ITM_index = jnp.where(ITM_matrix[:,-i])[0]
    # print(ITM_index)
    X = out_avg[ITM_index,-i] # underlying price that are in the money
    Y = CF_matrix[ITM_index,-i:].dot(jnp.exp(-r[0]*jnp.arange(i) * dt)) # Y is discounted futrue cashflow
    X_matrix = jnp.array([X**0,X**1,X**2]).transpose()

    ###############################################################
    sig = jnp.diag(jnp.array([1]*len(X)))
    # beta_temp = jnp.dot(jnp.dot(jnp.dot(jnp.linalg.inv(jnp.dot(jnp.dot(X_matrix.transpose(), jnp.linalg.inv(sig)), X_matrix)),\
    #                                     X_matrix.transpose()), jnp.linalg.inv(sig)), Y)
    beta_temp = jnp.dot(jnp.dot(jnp.linalg.inv(jnp.dot(X_matrix.transpose(), X_matrix)),X_matrix.transpose()), Y)
    # print(beta_temp)
    ##############################################################

    # print(X_matrix)
    # ECV = X_matrix.dot(jax.numpy.linalg.lstsq(X_matrix, Y)[0]) # predicted Y - ECV - expected continue value
    ECV = X_matrix.dot(beta_temp)
    # print(ECV)
    # print(X-K)
    # update column i in CF_matrix
    CF_matrix = jax.ops.index_update(x=CF_matrix, 
                                     idx=jax.ops.index[jnp.array(ITM_index),-i],  # index subsetting ITM paths for each timestep
                                     y=(K-X > ECV).astype(int) * (K-X)) # X-K if X-K > ECV, else 0
    # let columns after the first non-zero number be 0 in each row in CF_matrix
    CF_matrix = jax.ops.index_update(x=CF_matrix, 
                                  idx=jax.ops.index[ITM_index[jnp.where(K-X > ECV)[0]], -(i-1):], 
                                  # only for paths where X-K > ECV, update columns after to 0
                                  y=0)
    # print(CF_matrix)

  return  CF_matrix.dot(jnp.exp(-r[0]*jnp.arange(numsteps) * dt)).mean() # discount cash matrix and take mean for all paths to get price

numstocks = 1
numpaths = 10000
numsteps = 20

rng = jax.random.PRNGKey(1)
rng, key = jax.random.split(rng)

drift = jnp.array([0.01]*numstocks)
r = drift # let r = drift to match B-S

cov = jnp.identity(numstocks)*0.25*0.25
initial_stocks = jnp.array([100.]*numstocks) # must be float

T = 1.0
K = 100.0

fast_simple = jax.jit(Brownian_motion, static_argnums=2)
#fast_simple(key, initial_stocks, numsteps, drift, cov)

keys = jax.random.split(key, numpaths)
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))

# option price
# 3 stocks basket
print(optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T, numpaths)) # here numsteps different from T

# # delta test
# # 3 stock basket
# goptionvalueavg = jax.grad(optionvalueavg,argnums=1)
# print(goptionvalueavg(keys, initial_stocks, numsteps, drift, r, cov, K, T, numpaths)) # here numsteps different from T

9.576359
