<a href="https://colab.research.google.com/github/xinyanz-erin/Applied-Finance-Project/blob/Judy/American_Call_jax_Judy_Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Switch from a cash-flow (CF) matrix to a CF array (testing the correct jnp lstsq function)


In [None]:
##### Test(Judy)

In [1]:
# Regression step uses np.linalg.lstsq
# Cash flows are stored in a 1-D CF array instead of a CF matrix
# Prices a put option
# r is held constant here for testing
# Large numpaths & numsteps are used here to test the price


import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np
from torch.utils.dlpack import from_dlpack

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    """Simulate one path of correlated geometric Brownian motion.

    Each step applies the exact GBM update
    S_t = S_{t-1} * exp((drift - sigma**2 / 2) * dt + sqrt(dt) * Z_t),
    where Z_t ~ N(0, cov) and sigma = sqrt(diag(cov)).

    Args:
        key: JAX PRNG key for this path.
        initial_stocks: 1-D array of starting prices, one per stock (must be float).
        numsteps: number of time steps (Python int; it fixes array shapes and the
            loop bounds, so mark it static when jitting).
        drift: per-stock drift, broadcast against the stock axis.
        cov: covariance matrix of the shocks, shape (n_stocks, n_stocks).
        T: time horizon; the step size is dt = T / numsteps.

    Returns:
        Array of shape (numsteps, n_stocks) with the prices at dt, 2*dt, ..., T.
        The initial prices (time 0) are not included.
    """
    num_stocks = initial_stocks.shape[0]
    # Row 0 holds the initial prices; rows 1..numsteps are filled by the loop below.
    stocks_init = jnp.zeros((numsteps + 1, num_stocks)).at[0].set(initial_stocks)
    # Zero-mean correlated shocks drawn from cov directly, so no extra sigma factor
    # is needed. Row 0 is drawn but never used. Changing the shape would generally
    # change the sampled values.
    noise = jax.random.multivariate_normal(key, jnp.zeros(num_stocks), cov, (numsteps + 1,))
    sigma = jnp.diag(cov) ** 0.5
    dt = T / numsteps

    def time_step(t, val):
        growth = jnp.exp((drift - sigma ** 2. / 2.) * dt + jnp.sqrt(dt) * noise[t, :])
        # Functional equivalent of val[t] = val[t - 1] * growth.
        return val.at[t].set(val[t - 1] * growth)

    # fori_loop(lower, upper, body_fun, init_val); drop the time-0 row.
    return jax.lax.fori_loop(1, numsteps + 1, time_step, stocks_init)[1:]

def optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T, numpaths):
  """Price a put on the cross-stock average price with Longstaff-Schwartz regression.

  Paths come from the module-level ``batch_simple`` simulator. The function walks
  backward over the exercise dates. At each date, the discounted future cash flow
  of the in-the-money paths is regressed on [1, S, S**2] with ``np.linalg.lstsq``
  to estimate the expected continuation value (ECV). A path exercises early when
  K - S exceeds its ECV.

  NOTE: ``np.linalg.lstsq`` needs concrete NumPy arrays, so this version cannot be
  differentiated with ``jax.grad``. Use ``jnp.linalg.lstsq`` if gradients are needed.

  Args:
    key: PRNG key. A single key (e.g. from ``jax.random.PRNGKey``, shape (2,)) is
      split into ``numpaths`` per-path keys. A pre-split batch of keys (ndim > 1)
      is used as-is and then determines the number of paths.
    initial_stocks: initial price of each stock.
    numsteps: number of time steps between 0 and T.
    drift: drift passed to the path simulator.
    r: risk-free rate used for discounting.
    cov: covariance matrix passed to the path simulator.
    K: strike price.
    T: maturity.
    numpaths: number of paths to simulate when ``key`` is a single key.

  Returns:
    Monte Carlo estimate of the option price at time 0 (a scalar).
  """
  # Accept either a single key or an already-split batch of per-path keys.
  path_keys = key if jnp.ndim(key) > 1 else jax.random.split(key, numpaths)
  # Simulate once, then average across stocks. This assumes batch_simple returns
  # (paths, steps, stocks) with prices at dt, ..., T. Result: (paths, numsteps).
  out_avg = jnp.mean(batch_simple(path_keys, initial_stocks, numsteps, drift, cov, T), axis=2)

  dt = T / numsteps
  discount = jnp.exp(-r * dt)  # one-step discount factor
  ITM_matrix = out_avg <= K  # True where the put is in (or at) the money
  # Cash flow of each path, valued at the current step; starts as the payoff at T.
  CF_array = jnp.maximum(K - out_avg[:, -1], 0)

  # Backward induction over the earlier exercise dates (columns -2 ... -numsteps).
  for i in range(2, numsteps + 1):
    CF_array = CF_array * discount  # discount CF_array by one step
    ITM_index = jnp.nonzero(ITM_matrix[:, -i])[0]
    if ITM_index.size == 0:
      continue  # nothing in the money: no regression and no exercise at this date
    X = out_avg[ITM_index, -i]
    Y = CF_array[ITM_index]  # discounted future cash flows of the ITM paths
    X_matrix = jnp.stack([jnp.ones_like(X), X, X ** 2], axis=1)
    coeffs = jnp.array(np.linalg.lstsq(X_matrix, Y, rcond=None)[0])
    ECV = X_matrix.dot(coeffs)  # expected continuation value per ITM path
    exercise_value = K - X
    exercise = exercise_value > ECV
    # Paths that exercise now replace their future cash flow with the payoff.
    CF_array = CF_array.at[ITM_index[exercise]].set(exercise_value[exercise])

  # Discount from the first exercise date back to time 0 and average over paths.
  return (CF_array * discount).mean()

# Driver: single-stock test case for the put pricer defined above.
# Large numpaths & numsteps so the Monte Carlo price can be compared to reference values.
numstocks = 1
numpaths = 5000
numsteps = 50

rng = jax.random.PRNGKey(1)
rng, key = jax.random.split(rng)
drift = jnp.array([0.06]*numstocks)
r = drift # let r = drift (risk-neutral drift) so results match Black-Scholes-style benchmarks

# Independent stocks, each with volatility 0.2 (variance 0.2 * 0.2 on the diagonal).
cov = jnp.identity(numstocks)*0.2*0.2
initial_stocks = jnp.array([40.]*numstocks) # must be float

T = 1.0  # maturity
K = 40.0  # strike

# numsteps (argument index 2) is static: it fixes array shapes and loop bounds.
fast_simple = jax.jit(Brownian_motion, static_argnums=2)
# Single-path example: fast_simple(key, initial_stocks, numsteps, drift, cov, T)

# One PRNG key per path.
keys = jax.random.split(key, numpaths)
# Vectorize over the key axis only; every other argument is shared by all paths.
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))

#################################################################################### reference values for checking
# Columns: S, K, r, drift, sigma, T
# 40, 40, 0.06, 0.06, 0.2 (1 stock) / 0.34641016151 = 0.2*sqrt(3) (3 stocks), 1
# option price should be around 2.33
# np price: 2.3384

# Columns: S, K, r, drift, sigma, T
# 100, 100, 0.01, 0.01, 0.25, 1
# option price should be around 9.51
# np price: 9.5610
####################################################################################

# option price
print(optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T, numpaths))

# Delta test (disabled): gradient of the price w.r.t. initial_stocks (argnums=1).
# NOTE(review): expected to fail with the np.linalg.lstsq version above, because
# NumPy cannot consume JAX tracers.
# goptionvalueavg = jax.grad(optionvalueavg,argnums=1)
# print(goptionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T, numpaths))

2.3384314


In [3]:
# Regression step uses jnp.linalg.lstsq
# Cash flows are stored in a 1-D CF array instead of a CF matrix
# Prices a put option
# r is held constant here for testing
# Large numpaths & numsteps are used here to test the price

import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np
from torch.utils.dlpack import from_dlpack

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    """Simulate one path of correlated geometric Brownian motion.

    Each step applies the exact GBM update
    S_t = S_{t-1} * exp((drift - sigma**2 / 2) * dt + sqrt(dt) * Z_t),
    where Z_t ~ N(0, cov) and sigma = sqrt(diag(cov)).

    Args:
        key: JAX PRNG key for this path.
        initial_stocks: 1-D array of starting prices, one per stock (must be float).
        numsteps: number of time steps (Python int; it fixes array shapes and the
            loop bounds, so mark it static when jitting).
        drift: per-stock drift, broadcast against the stock axis.
        cov: covariance matrix of the shocks, shape (n_stocks, n_stocks).
        T: time horizon; the step size is dt = T / numsteps.

    Returns:
        Array of shape (numsteps, n_stocks) with the prices at dt, 2*dt, ..., T.
        The initial prices (time 0) are not included.
    """
    num_stocks = initial_stocks.shape[0]
    # Row 0 holds the initial prices; rows 1..numsteps are filled by the loop below.
    stocks_init = jnp.zeros((numsteps + 1, num_stocks)).at[0].set(initial_stocks)
    # Zero-mean correlated shocks drawn from cov directly, so no extra sigma factor
    # is needed. Row 0 is drawn but never used. Changing the shape would generally
    # change the sampled values.
    noise = jax.random.multivariate_normal(key, jnp.zeros(num_stocks), cov, (numsteps + 1,))
    sigma = jnp.diag(cov) ** 0.5
    dt = T / numsteps

    def time_step(t, val):
        growth = jnp.exp((drift - sigma ** 2. / 2.) * dt + jnp.sqrt(dt) * noise[t, :])
        # Functional equivalent of val[t] = val[t - 1] * growth.
        return val.at[t].set(val[t - 1] * growth)

    # fori_loop(lower, upper, body_fun, init_val); drop the time-0 row.
    return jax.lax.fori_loop(1, numsteps + 1, time_step, stocks_init)[1:]

def optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T, numpaths):
  """Price a put on the cross-stock average price with Longstaff-Schwartz regression.

  Paths come from the module-level ``batch_simple`` simulator. The function walks
  backward over the exercise dates. At each date, the discounted future cash flow
  of the in-the-money paths is regressed on [1, S, S**2] with ``jnp.linalg.lstsq``
  to estimate the expected continuation value (ECV). A path exercises early when
  K - S exceeds its ECV. Because the regression stays in JAX, the price can be
  differentiated with ``jax.grad`` (e.g. w.r.t. ``initial_stocks`` for delta).

  Args:
    key: PRNG key. A single key (e.g. from ``jax.random.PRNGKey``, shape (2,)) is
      split into ``numpaths`` per-path keys. A pre-split batch of keys (ndim > 1)
      is used as-is and then determines the number of paths.
    initial_stocks: initial price of each stock.
    numsteps: number of time steps between 0 and T.
    drift: drift passed to the path simulator.
    r: risk-free rate used for discounting.
    cov: covariance matrix passed to the path simulator.
    K: strike price.
    T: maturity.
    numpaths: number of paths to simulate when ``key`` is a single key.

  Returns:
    Monte Carlo estimate of the option price at time 0 (a scalar).
  """
  # Accept either a single key or an already-split batch of per-path keys.
  path_keys = key if jnp.ndim(key) > 1 else jax.random.split(key, numpaths)
  # Simulate once, then average across stocks. This assumes batch_simple returns
  # (paths, steps, stocks) with prices at dt, ..., T. Result: (paths, numsteps).
  out_avg = jnp.mean(batch_simple(path_keys, initial_stocks, numsteps, drift, cov, T), axis=2)

  dt = T / numsteps
  discount = jnp.exp(-r * dt)  # one-step discount factor
  ITM_matrix = out_avg <= K  # True where the put is in (or at) the money
  # Cash flow of each path, valued at the current step; starts as the payoff at T.
  CF_array = jnp.maximum(K - out_avg[:, -1], 0)

  # Backward induction over the earlier exercise dates (columns -2 ... -numsteps).
  for i in range(2, numsteps + 1):
    CF_array = CF_array * discount  # discount CF_array by one step
    ITM_index = jnp.nonzero(ITM_matrix[:, -i])[0]
    if ITM_index.size == 0:
      continue  # nothing in the money: no regression and no exercise at this date
    X = out_avg[ITM_index, -i]
    Y = CF_array[ITM_index]  # discounted future cash flows of the ITM paths
    X_matrix = jnp.stack([jnp.ones_like(X), X, X ** 2], axis=1)
    # rcond=-1 as in the original call (presumably a machine-epsilon cutoff; see
    # the jnp.linalg.lstsq docs).
    coeffs = jnp.linalg.lstsq(X_matrix, Y, rcond=-1)[0]
    ECV = X_matrix.dot(coeffs)  # expected continuation value per ITM path
    exercise_value = K - X
    exercise = exercise_value > ECV
    # Paths that exercise now replace their future cash flow with the payoff.
    CF_array = CF_array.at[ITM_index[exercise]].set(exercise_value[exercise])

  # Discount from the first exercise date back to time 0 and average over paths.
  return (CF_array * discount).mean()

# Driver: single-stock test case for the put pricer defined above.
# Large numpaths & numsteps so the Monte Carlo price can be compared to reference values.
numstocks = 1
numpaths = 5000
numsteps = 50

rng = jax.random.PRNGKey(1)
rng, key = jax.random.split(rng)
drift = jnp.array([0.06]*numstocks)
r = drift # let r = drift (risk-neutral drift) so results match Black-Scholes-style benchmarks

# Independent stocks, each with volatility 0.2 (variance 0.2 * 0.2 on the diagonal).
cov = jnp.identity(numstocks)*0.2*0.2
initial_stocks = jnp.array([40.]*numstocks) # must be float

T = 1.0  # maturity
K = 40.0  # strike

# numsteps (argument index 2) is static: it fixes array shapes and loop bounds.
fast_simple = jax.jit(Brownian_motion, static_argnums=2)
# Single-path example: fast_simple(key, initial_stocks, numsteps, drift, cov, T)

# One PRNG key per path.
keys = jax.random.split(key, numpaths)
# Vectorize over the key axis only; every other argument is shared by all paths.
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))

#################################################################################### reference values for checking
# Columns: S, K, r, drift, sigma, T
# 40, 40, 0.06, 0.06, 0.2 (1 stock) / 0.34641016151 = 0.2*sqrt(3) (3 stocks), 1
# option price should be around 2.33
# np price: 2.3384

# Columns: S, K, r, drift, sigma, T
# 100, 100, 0.01, 0.01, 0.25, 1
# option price should be around 9.51
# np price: 9.5610
####################################################################################

# option price
print(optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T, numpaths))

# Delta test (disabled here): gradient of the price w.r.t. initial_stocks (argnums=1).
# goptionvalueavg = jax.grad(optionvalueavg,argnums=1)
# print(goptionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T, numpaths))

2.3400316


In [4]:
# Delta test: gradient of the option price w.r.t. initial_stocks (argnums=1).
# Pass the single PRNG `key`, matching the price call above.
goptionvalueavg = jax.grad(optionvalueavg, argnums=1)
print(goptionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T, numpaths))

[-0.40192717]
