<a href="https://colab.research.google.com/github/xinyanz-erin/Applied-Finance-Project/blob/Lilian/European_Call_jax_v5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!curl https://colab.chainer.org/install |sh -
import cupy

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  1580  100  1580    0     0   7939      0 --:--:-- --:--:-- --:--:--  7939
+ apt -y -q install cuda-libraries-dev-10-0
Reading package lists...
Building dependency tree...
Reading state information...
cuda-libraries-dev-10-0 is already the newest version (10.0.130-1).
0 upgraded, 0 newly installed, 0 to remove and 40 not upgraded.
+ pip install -q cupy-cuda100  chainer 
[K     |████████████████████████████████| 58.9 MB 32 kB/s 
[K     |████████████████████████████████| 1.0 MB 68.0 MB/s 
[?25h  Building wheel for chainer (setup.py) ... [?25l[?25hdone
+ set +ex
Installation succeeded!


In [7]:
import jax
import jax.numpy as jnp
from jax import random
from jax import jit
from jax import grad
import numpy as np
from torch.utils.dlpack import from_dlpack

In [8]:
N_STOCKS = 1
N_BATCH = 2
N_STEPS = 365
N_PATHS = 1000

In [17]:
def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))   # initial_stocks.shape[0] <-> Stocks number
    stocks_init = jax.ops.index_update(stocks_init,   # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
                            jax.ops.index[0],         # initialization of stock prices
                            initial_stocks)
    noise = jax.random.multivariate_normal(key, jnp.array([0]*initial_stocks.shape[0]), cov, (numsteps,)) #shape<->[numsteps,nstocks]
    dt = jnp.array(T[0]/numsteps)
    def time_step(t, val):
        dx =  drift * dt * val[t-1,:] + val[t-1,:] * jnp.sqrt(dt) * noise[t,:] # no need to multiply by sigma here because noise generated by cov not corr
        val = jax.ops.index_update(val,
                            jax.ops.index[t],
                            val[t-1] + dx)
        return val
    return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)

   
def OptionPrice(seed, K_val, B_val, S0, sigma_val, r_val, T_val = 1):
  # price = []
  # Y = jnp.zeros(N_BATCH)
  # Y = cupy.zeros(N_BATCH, dtype=cupy.float32)
  # X = cupy.zeros((N_BATCH, N_STOCKS * 7), dtype = cupy.float32)
  # for op in range(N_BATCH):    
  key = random.PRNGKey(seed)
  initial_stocks = jnp.array(N_STOCKS * [S0])
          # cov=np.random.random((3,3))
          # cov=np.matmul(cov,cov.T)
  corr = jnp.diag(jnp.array([1]*N_STOCKS)) # assume no correlation between stocks here
  sigma = jnp.array(N_STOCKS * [sigma_val])
  cov = (jnp.diag(sigma)).dot(corr).dot(jnp.diag(sigma))
  r = jnp.array([r_val] * N_STOCKS)
  drift = r # To match BS, use drift = r
  T = jnp.array([T_val] * N_STOCKS)
  K = K_val
  B = jnp.array([K * B_val] * N_STOCKS)

  fast_simple = jax.jit(Brownian_motion, static_argnums=2)
  fast_simple(key, initial_stocks, N_STEPS, drift, cov, T)

  numsamples = N_PATHS # num of paths
  keys = jax.random.split(key, numsamples)
  # print('keys = ', keys)

  batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))
  avg_across_stocks = batch_simple(keys, initial_stocks, N_STEPS, drift, cov, T)

  # avg_across_stocks = np.mean(out, axis=2) # becomes paths * steps
  
  eff_price = (1 - np.any(avg_across_stocks < B[0], axis=1).astype(int))* avg_across_stocks[:,N_STEPS]  
  Barrier_Call_price = jnp.mean(jnp.maximum(eff_price - K, 0) * jnp.exp(-r[0] * T[0]))
  
  # Y = jax.ops.index_update(Y,jax.ops.index[op],Barrier_Call_price)
    #Y[op] = Barrier_Call_price
    
    #paras = (T, jnp.repeat(jnp.array(K), N_STOCKS), B, initial_stocks, sigma, drift, r)
    
    #paras = np.column_stack(paras).reshape(1,-1)[0]
    #paras = jnp.array(paras)
    #X[op,] = cupy.array(paras)
  #price = ([from_dlpack(X.toDlpack()), from_dlpack(Y.toDlpack())])
  price = Barrier_Call_price
  return price

In [18]:
gradient_func = grad(OptionPrice, (3, 5))
# delta, rho = gradient_func(15, K_val = 200, B_val = -1, S0 = 100, sigma_val = 0.4, r_val = 0.1, T_val = 1)
delta, rho = gradient_func(15, 200.0, -1.0, 100.0, 0.4, 0.1, 1.0)

keys =  [[  89299191 2420819866]
 [ 628360391 1163436746]
 [ 893678722 2970809962]
 ...
 [1275522704   96506818]
 [ 314043706 1814581361]
 [1261880031 3640555791]]


In [16]:
delta

DeviceArray(0., dtype=float32)