<a href="https://colab.research.google.com/github/xinyanz-erin/Applied-Finance-Project/blob/Judy/For%20Lilian.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Number of underlying stocks in the basket; read further down as `numstocks`.
nstock = 3

In [2]:
# Mount Google Drive inside the Colab runtime; the generated dataset CSV is
# later written to a path under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import itertools

import cupy
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import jit
from jax import random
from torch.utils.dlpack import from_dlpack

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    """Simulate one geometric-Brownian-motion path for a set of stocks.

    Args:
        key: JAX PRNG key used to draw the Gaussian increments.
        initial_stocks: 1-D array of starting prices, one entry per stock.
        numsteps: Number of time steps. It determines array shapes, so it
            must be a static Python int when this function is jitted.
        drift: Per-stock drift, broadcast against ``initial_stocks``.
        cov: Covariance matrix of the increments; its diagonal supplies the
            per-stock variances used in the Ito correction.
        T: Time horizon; the step size is ``T / numsteps``.

    Returns:
        Array of shape ``(numsteps, n_stocks)`` holding the prices after each
        step. The row with the initial prices is dropped.
    """
    num_assets = initial_stocks.shape[0]

    # Row 0 holds the starting prices; rows 1..numsteps are filled by the loop.
    # `x.at[idx].set(y)` is the pure equivalent of `x[idx] = y` and replaces
    # `jax.ops.index_update`, which no longer exists in current JAX releases.
    stocks_init = jnp.zeros((numsteps + 1, num_assets)).at[0].set(initial_stocks)

    # Zero-mean correlated noise. numsteps + 1 rows are drawn so that row t
    # drives step t; row 0 is never read. The sample shape is kept unchanged so
    # a given key produces the same draws as before.
    noise = jax.random.multivariate_normal(
        key, jnp.zeros(num_assets, dtype=cov.dtype), cov, (numsteps + 1,))
    sigma = jnp.diag(cov) ** 0.5
    dt = T / numsteps

    def time_step(t, val):
        # Exact log-normal step: S_t = S_{t-1} * exp((mu - sigma^2/2) dt + sqrt(dt) * Z_t)
        dx = jnp.exp((drift - sigma ** 2. / 2.) * dt + jnp.sqrt(dt) * noise[t, :])
        return val.at[t].set(val[t - 1] * dx)

    # jax.lax.fori_loop(lower, upper, body_fun, init_val)
    return jax.lax.fori_loop(1, numsteps + 1, time_step, stocks_init)[1:]

def optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, B, T): # down-and-out call
    """Monte Carlo price of a down-and-out call on an equally weighted basket.

    Paths come from the module-level ``batch_simple``. Its result is averaged
    over axis 2 to obtain the basket level at every step of every path. A path
    is knocked out if the basket is below ``B`` at any step; surviving paths
    pay ``max(basket_T - K, 0)``. The payoff is discounted with ``r[0]`` and
    averaged over the paths.

    Args:
        key: Batch of legacy PRNG keys of shape ``(numpaths, 2)``, one per
            path. For backward compatibility a single key (ndim < 2) is
            accepted, in which case the module-level ``keys`` batch is used,
            which is what this function previously did for every call.
        initial_stocks: 1-D array of starting prices.
        numsteps: Number of time steps per path.
        drift: Per-stock drift.
        r: Array of risk-free rates; only ``r[0]`` is used for discounting.
        cov: Covariance matrix of the increments.
        K: Strike.
        B: Knock-out barrier on the basket level.
        T: Maturity.

    Returns:
        Scalar discounted expected payoff.
    """
    # Use the caller's key batch when one is supplied; otherwise fall back to
    # the global batch so existing single-key callers keep their old results.
    path_keys = key if jnp.ndim(key) >= 2 else keys

    # Simulate once and reuse the result for both the barrier check and the
    # terminal payoff (the paths used to be simulated twice per call).
    basket = jnp.mean(
        batch_simple(path_keys, initial_stocks, numsteps, drift, cov, T), axis=2)

    # 1 for paths that never fell below the barrier, 0 for knocked-out paths.
    alive = 1 - jnp.any(basket < B, axis=1).astype(int)
    payoff = jnp.maximum(alive * basket[:, -1] - K, 0)
    return jnp.mean(payoff * jnp.exp(-r[0] * T))

# Gradient of the option price with respect to argument 1 (the initial stock
# prices), i.e. the per-stock deltas.
goptionvalueavg = jax.grad(optionvalueavg, argnums=1)

#################################################################### Adjust all parameters here (not inside class)
numstocks = nstock
numsteps = 50
numpaths = 2000000

# Randomly seeded master key, then one independent sub-key per simulated path.
rng = jax.random.PRNGKey(np.random.randint(10000))
rng, key = random.split(rng)
keys = random.split(key, numpaths)

# Parameter grids for the sweep. S1 only covers the last two of its six grid
# points (slice [4:6]).
S1_range = jnp.linspace(0.75, 1.25, num=6)[4:6]
S2_range = jnp.linspace(0.75, 1.25, num=6)
S3_range = jnp.linspace(0.75, 1.25, num=6)
K_range = jnp.linspace(0.75, 1.25, num=5)
B_range = jnp.linspace(0.5, 1.0, num=6)
sigma_range = jnp.linspace(0.15, 0.45, num=3)
r_range = jnp.linspace(0.01, 0.04, num=3)
T = 1.0

# numsteps (positional argument 2) is marked static for the jit compile.
fast_simple = jit(Brownian_motion, static_argnums=(2,))
# Map over the leading axis of the key batch only; every other argument is
# shared across paths.
batch_simple = jax.vmap(fast_simple, in_axes=(0,) + (None,) * 5)
####################################################################

# Sweep the full parameter grid, pricing the down-and-out basket call and its
# deltas at each grid point, then export every row to a headerless CSV.
call = []
count = 0

# A single pass returns both the price and its gradient with respect to the
# initial stock prices, instead of running the Monte Carlo once for the price
# and again for the deltas.
price_and_deltas = jax.value_and_grad(optionvalueavg, argnums=1)

# itertools.product varies the right-most range fastest, which matches the
# order of the original nested loops (sigma innermost, S1 outermost).
parameter_grid = itertools.product(
    S1_range, S2_range, S3_range, K_range, B_range, r_range, sigma_range)

for S1, S2, S3, K, B, r, sigma in parameter_grid:
    initial_stocks = jnp.array([S1, S2, S3]) # must be float
    r_tmp = jnp.array([r]*numstocks)
    drift = r_tmp  # drift is set equal to the rate r
    cov = jnp.identity(numstocks)*sigma*sigma  # uncorrelated stocks, common volatility

    # Pass the per-path key batch `keys` for both price and deltas, so the two
    # numbers come from the same set of simulated paths.
    Knock_Out_Call_price, Deltas = price_and_deltas(
        keys, initial_stocks, numsteps, drift, r_tmp, cov, K, B, T)

    # Row layout: (T, K, B, S_i, sigma, mu, r) for each of the three stocks,
    # followed by the price and then one delta per stock.
    call.append([T, K, B, S1, sigma, r, r,
                 T, K, B, S2, sigma, r, r,
                 T, K, B, S3, sigma, r, r, Knock_Out_Call_price] + list(Deltas))

    count += 1
    print(count)

Thedataset_3stock = pd.DataFrame(call)
Thedataset_3stock.to_csv('/content/drive/MyDrive/AFP Project/Lilian/Knock_Out_Call_3stock_MC_Datset.csv', index=False, header=False)

1
2
3
4
5
6
7
8
9
10
11


KeyboardInterrupt: ignored