<a href="https://colab.research.google.com/github/xinyanz-erin/Applied-Finance-Project/blob/Lilian/European_jax_random.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab setup: runs the installer script served from colab.chainer.org
# (presumably installs Chainer/CuPy for the GPU runtime -- confirm) and imports CuPy.
# `!` is an IPython shell escape, so this cell only works inside a notebook.
# NOTE(review): this pipes a remote script straight into `sh`; review it before running.
!curl https://colab.chainer.org/install |sh -
# NOTE(review): `cupy` is not referenced by any cell shown here; confirm it is still needed.
import cupy

In [None]:
# Peter's code

def simple_process(key, initial_stocks, numsteps, drift, cov):
    """Simulate one path of correlated geometric Brownian motion with unit steps.

    Each step multiplies the previous prices by
    ``exp((drift - sigma**2 / 2) + Z)`` with ``Z ~ N(0, cov)`` and
    ``sigma = sqrt(diag(cov))``. Every step is one unit of time; this version
    has no ``dt`` scaling, so ``numsteps`` is also the horizon.

    Args:
        key: a single JAX PRNG key used for this path's shocks.
        initial_stocks: 1-D array of starting prices, one per asset.
        numsteps: number of steps (Python int; must be static under ``jit``).
        drift: per-asset drift per step, same length as ``initial_stocks``.
        cov: per-step covariance matrix of the log-return shocks.

    Returns:
        Array of shape ``(numsteps, n_assets)`` with the prices after each
        step. The initial prices are not included.
    """
    drift = jnp.asarray(drift)
    stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))
    # Row 0 holds the initial prices; rows 1..numsteps are filled by the loop.
    stocks_init = stocks_init.at[0].set(initial_stocks)
    # Zero-mean shocks: the drift is added explicitly in the exponent below, so
    # sampling with mean=drift would count it twice. The draw shape stays
    # (numsteps + 1,) (row 0 unused) so the random stream matches earlier runs.
    noise = jax.random.multivariate_normal(
        key, jnp.zeros_like(drift), cov, (numsteps + 1,))
    sigma = jnp.diag(cov) ** 0.5
    # Ito correction: with Z ~ N(0, cov), E[exp(log_mean + Z)] = exp(drift).
    log_mean = drift - sigma ** 2. / 2.

    def time_step(t, val):
        # No sigma factor here: the shocks were drawn from cov, not a correlation matrix.
        growth = jnp.exp(log_mean + noise[t, :])
        return val.at[t].set(val[t - 1] * growth)

    # fori_loop(lower, upper, body_fun, init_val); drop row 0 (the initial prices).
    return jax.lax.fori_loop(1, numsteps + 1, time_step, stocks_init)[1:]

def optionvalue(key, initial_stocks, numsteps, drift, cov, strike, num_paths=1000000):
    """Monte Carlo estimate of a European call payoff, averaged over paths and assets.

    ``max(S_T - strike, 0)`` is computed for every asset on every simulated
    path, and the mean is taken over both axes. With identical, independently
    simulated assets this estimates the single-asset call price. No discounting
    is applied.

    Args:
        key: either a single legacy PRNG key (as from ``jax.random.PRNGKey``),
            which is split into ``num_paths`` per-path keys, or an already
            split batch of keys (one row per path), used as-is.
        initial_stocks: 1-D array of starting prices.
        numsteps: number of simulation steps (unit-length steps).
        drift: per-asset drift.
        cov: covariance matrix of the per-step log-return shocks.
        strike: option strike.
        num_paths: path count used only when ``key`` is a single key.

    Returns:
        Scalar mean payoff.

    Relies on the module-level ``batch_simple`` (vmapped path simulator).
    """
    # Use the caller's key(s) instead of a module-level key batch.
    path_keys = jax.random.split(key, num_paths) if jnp.ndim(key) == 1 else key
    terminal = batch_simple(path_keys, initial_stocks, numsteps, drift, cov)[:, -1, :]
    return jnp.mean(jnp.maximum(terminal - strike, 0))

def optionvalueavg(key, initial_stocks, numsteps, drift, cov, strike, num_paths=1000000):
    """Monte Carlo estimate of a basket call payoff on the equally weighted average.

    On each path the terminal prices are averaged across assets, the call
    payoff ``max(basket - strike, 0)`` is taken, and the result is averaged
    over paths. No discounting is applied.

    Args:
        key: either a single legacy PRNG key (as from ``jax.random.PRNGKey``),
            which is split into ``num_paths`` per-path keys, or an already
            split batch of keys (one row per path), used as-is.
        initial_stocks: 1-D array of starting prices.
        numsteps: number of simulation steps (unit-length steps).
        drift: per-asset drift.
        cov: covariance matrix of the per-step log-return shocks.
        strike: option strike.
        num_paths: path count used only when ``key`` is a single key.

    Returns:
        Scalar mean payoff.

    Relies on the module-level ``batch_simple`` (vmapped path simulator).
    """
    # Use the caller's key(s) instead of a module-level key batch.
    path_keys = jax.random.split(key, num_paths) if jnp.ndim(key) == 1 else key
    terminal = batch_simple(path_keys, initial_stocks, numsteps, drift, cov)[:, -1, :]
    basket = jnp.mean(terminal, axis=1)
    return jnp.mean(jnp.maximum(basket - strike, 0))

In [None]:
# Peter's code

import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np
from torch.utils.dlpack import from_dlpack

# Driver cell for the unit-step model: builds the test market, jit/vmaps the
# single-path simulator, then prints call prices and a pathwise delta.
numstocks = 3

rng = jax.random.PRNGKey(1)
rng, key = jax.random.split(rng)
#numsteps = 10
drift = jnp.array([0.0]*numstocks)

# Diagonal covariance 0.25**2: independent assets, each with sigma = 0.25.
cov = jnp.identity(numstocks)*.25*.25
initial_stocks = jnp.array([100.]*numstocks)

K = 110.0

# numsteps (argument index 2) sizes arrays inside simple_process, so it must be static.
fast_simple = jax.jit(simple_process, static_argnums=2)
#fast_simple(key, initial_stocks, numsteps, drift, cov)

# One PRNG key per simulated path (1,000,000 paths).
keys = jax.random.split(key, 1000000)
# Map over the key axis only; market parameters are shared by every path.
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None))

print(initial_stocks) #S = 100
print(K) #K = 110
print(cov) #per-step variance 0.0625 on the diagonal, i.e. sigma = 0.25
print(drift) #drift = 0

#################################################################################### values for checking
#S, K, r, sigma, T
# 100, 110, 0, 0.25, 1
# 1 stock price should be around 6.1904
# 3 stock price should be around 2.3767
# delta should be around (0.39888 / numstocks)
#   (optionvalue averages the payoff over the numstocks identical assets, so
#    each asset's gradient is the single-asset N(d1) ~= 0.39888 divided by numstocks)
####################################################################################

# option price
# 1 stock
print(optionvalue(key, initial_stocks, 1, drift, cov, K)) # numsteps here = years = 1 (each step is one unit of time)
# 3 stocks basket
print(optionvalueavg(key, initial_stocks, 1, drift, cov, K)) # numsteps here = years = 1

# delta test: gradient of the Monte Carlo price w.r.t. initial_stocks (argument 1)
gooptionvalue = jax.grad(optionvalue,argnums=1)
gooptionvalue(keys, initial_stocks, 1, drift, cov, K) # numsteps here = years = 1

In [None]:
def simple_process(key, initial_stocks, numsteps, drift, cov, T):
    """Simulate one correlated GBM path over horizon ``T`` in ``numsteps`` equal steps.

    With ``dt = T / numsteps`` and ``Z ~ N(0, cov)``, each step multiplies the
    previous prices by ``exp((drift - sigma**2 / 2) * dt + sqrt(dt) * Z)``,
    where ``sigma = sqrt(diag(cov))``. ``drift`` and ``cov`` are therefore per
    unit of time.

    Args:
        key: a single JAX PRNG key used for this path's shocks.
        initial_stocks: 1-D array of starting prices, one per asset.
        numsteps: number of steps (Python int; must be static under ``jit``).
        drift: per-asset drift per unit time, same length as ``initial_stocks``.
        cov: covariance matrix of the log-return shocks per unit time.
        T: total horizon.

    Returns:
        Array of shape ``(numsteps, n_assets)`` with the prices after each
        step. The initial prices are not included.
    """
    drift = jnp.asarray(drift)
    stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))
    # Row 0 holds the initial prices; rows 1..numsteps are filled by the loop.
    stocks_init = stocks_init.at[0].set(initial_stocks)
    # Zero-mean shocks: the drift is added explicitly in the exponent below, so
    # sampling with mean=drift would count it twice. The draw shape stays
    # (numsteps + 1,) (row 0 unused) so the random stream matches earlier runs.
    noise = jax.random.multivariate_normal(
        key, jnp.zeros_like(drift), cov, (numsteps + 1,))
    sigma = jnp.diag(cov) ** 0.5
    dt = T / numsteps
    # Loop-invariant pieces of the exponent (Ito-corrected drift and shock scale).
    log_mean = (drift - sigma ** 2. / 2.) * dt
    shock_scale = jnp.sqrt(dt)

    def time_step(t, val):
        # No sigma factor here: the shocks were drawn from cov, not a correlation matrix.
        growth = jnp.exp(log_mean + shock_scale * noise[t, :])
        return val.at[t].set(val[t - 1] * growth)

    # fori_loop(lower, upper, body_fun, init_val); drop row 0 (the initial prices).
    return jax.lax.fori_loop(1, numsteps + 1, time_step, stocks_init)[1:]

def optionvalue(key, initial_stocks, numsteps, drift, cov, strike, T, num_paths=1000000):
    """Monte Carlo estimate of a European call payoff at ``T``, averaged over paths and assets.

    ``max(S_T - strike, 0)`` is computed for every asset on every simulated
    path, and the mean is taken over both axes. With identical, independently
    simulated assets this estimates the single-asset call price. No discounting
    is applied.

    Args:
        key: either a single legacy PRNG key (as from ``jax.random.PRNGKey``),
            which is split into ``num_paths`` per-path keys, or an already
            split batch of keys (one row per path), used as-is.
        initial_stocks: 1-D array of starting prices.
        numsteps: number of time steps over the horizon.
        drift: per-asset drift per unit time.
        cov: covariance matrix of the log-return shocks per unit time.
        strike: option strike.
        T: option maturity (simulation horizon).
        num_paths: path count used only when ``key`` is a single key.

    Returns:
        Scalar mean payoff.

    Relies on the module-level ``batch_simple`` (vmapped path simulator).
    """
    # Use the caller's key(s) instead of a module-level key batch.
    path_keys = jax.random.split(key, num_paths) if jnp.ndim(key) == 1 else key
    terminal = batch_simple(path_keys, initial_stocks, numsteps, drift, cov, T)[:, -1, :]
    return jnp.mean(jnp.maximum(terminal - strike, 0))

def optionvalueavg(key, initial_stocks, numsteps, drift, cov, strike, T, num_paths=1000000):
    """Monte Carlo estimate of a basket call payoff at ``T`` on the equally weighted average.

    On each path the terminal prices are averaged across assets, the call
    payoff ``max(basket - strike, 0)`` is taken, and the result is averaged
    over paths. No discounting is applied.

    Args:
        key: either a single legacy PRNG key (as from ``jax.random.PRNGKey``),
            which is split into ``num_paths`` per-path keys, or an already
            split batch of keys (one row per path), used as-is.
        initial_stocks: 1-D array of starting prices.
        numsteps: number of time steps over the horizon.
        drift: per-asset drift per unit time.
        cov: covariance matrix of the log-return shocks per unit time.
        strike: option strike.
        T: option maturity (simulation horizon).
        num_paths: path count used only when ``key`` is a single key.

    Returns:
        Scalar mean payoff.

    Relies on the module-level ``batch_simple`` (vmapped path simulator).
    """
    # Use the caller's key(s) instead of a module-level key batch.
    path_keys = jax.random.split(key, num_paths) if jnp.ndim(key) == 1 else key
    terminal = batch_simple(path_keys, initial_stocks, numsteps, drift, cov, T)[:, -1, :]
    basket = jnp.mean(terminal, axis=1)
    return jnp.mean(jnp.maximum(basket - strike, 0))

# Driver cell for the dt-scaled model: same market as the unit-step cell, but
# maturity T is split into numsteps steps of length dt = T / numsteps.
numstocks = 3

rng = jax.random.PRNGKey(1)
rng, key = jax.random.split(rng)
numsteps = 50
drift = jnp.array([0.0]*numstocks)

# Diagonal covariance 0.25**2 per unit time: independent assets, each with sigma = 0.25.
cov = jnp.identity(numstocks)*.25*.25
initial_stocks = jnp.array([100.]*numstocks)

T = 1.0
K = 110.0

# numsteps (argument index 2) sizes arrays inside simple_process, so it must be static.
fast_simple = jax.jit(simple_process, static_argnums=2)
#fast_simple(key, initial_stocks, numsteps, drift, cov, T)  # single-path smoke test

# One PRNG key per simulated path (1,000,000 paths). Each batched path array is
# roughly 1e6 x (numsteps + 1) x numstocks floats, so memory use is significant.
keys = jax.random.split(key, 1000000)
# Map over the key axis only; market parameters and T are shared by every path.
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))

print(initial_stocks) #S = 100
print(K) #K = 110
print(cov) #variance 0.0625 per unit time on the diagonal, i.e. sigma = 0.25
print(drift) #drift = 0

#################################################################################### values for checking
#S, K, r, sigma, T
# 100, 110, 0, 0.25, 1
# 1 stock price should be around 6.1904
# 3 stock price should be around 2.3767
# delta should be around (0.39888 / numstocks)
#   (optionvalue averages the payoff over the numstocks identical assets, so
#    each asset's gradient is the single-asset N(d1) ~= 0.39888 divided by numstocks)
####################################################################################

# option price
# 1 stock
print(optionvalue(key, initial_stocks, numsteps, drift, cov, K, T)) # numsteps here = 50, but T = year = 1
# 3 stocks basket
print(optionvalueavg(key, initial_stocks, numsteps, drift, cov, K, T)) # numsteps here = 50, but T = year = 1

# delta test: gradient of the Monte Carlo price w.r.t. initial_stocks (argument 1)
gooptionvalue = jax.grad(optionvalue,argnums=1)
gooptionvalue(keys, initial_stocks, numsteps, drift, cov, K, T) # numsteps here = 50, but T = year = 1