<a href="https://colab.research.google.com/github/xinyanz-erin/Applied-Finance-Project/blob/Pui/European_Call_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title
import jax.numpy as jnp
from jax import random
from jax import jit
from jax import device_put

def test():
  paths=100000
  steps=252
  stocks=3
  key = random.PRNGKey(10)
  x = random.normal(key, (paths,steps,stocks))
  s0=random.randint(key,(stocks,1),50,100)*1.
  s=s0
  s = device_put(s)
  for i in range(steps):
    for j in range(stocks):
      s+=x[:,i,j]
  return(s0,s)

a,b=test()

print(a,np.mean(b,axis=1))

jtest=jit(test)

c,d=jtest()

print(c,np.mean(d,axis=1))

[[86.]
 [56.]
 [71.]] [85.99678  55.996777 70.99677 ]
[[86.]
 [56.]
 [71.]] [85.99678  55.996777 70.99677 ]


In [None]:
#@title
paths=10
steps=5
stocks=3
key = random.PRNGKey(10)
x = random.normal(key, (paths,steps,stocks))
s0=random.randint(key,(stocks,1),50,100)*1.
s=s0
s = device_put(s)
print(s)
print(x[:,1,1])
print(s+x[:,1,1])

[[86.]
 [56.]
 [71.]]
[-0.32907748 -0.3529078  -0.5357338  -2.402825   -1.3187171  -0.4300647
 -0.42437345  0.22680947  0.3185347  -0.9541187 ]
[[85.67092  85.647095 85.464264 83.597176 84.68128  85.56994  85.57563
  86.22681  86.318535 85.04588 ]
 [55.67092  55.64709  55.464268 53.597176 54.681282 55.569935 55.575626
  56.22681  56.318535 55.045883]
 [70.67092  70.647095 70.464264 68.597176 69.68128  70.56994  70.57563
  71.22681  71.318535 70.04588 ]]


In [None]:
#@title
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np

def simple_process(key, initial_values, numsteps, drift, cov):
    stocks_init = jnp.zeros((numsteps + 1,initial_values.shape[0]))
    stocks_init=jax.ops.index_update(stocks_init,  # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
                            jax.ops.index[0],
                            initial_values)
    noise = jax.random.multivariate_normal(key, drift, cov, (numsteps,))
    #return(noise)
    def time_step(t, val):
        dx =  drift+noise[t,:]
        val = jax.ops.index_update(val,
                            jax.ops.index[t],
                            val[t-1] + dx)
        return val
    return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)

# def fori_loop(lower, upper, body_fun, init_val): # upper is exclusive
#   val = init_val
#   for i in range(lower, upper):
#     val = body_fun(i, val)
#   return val

numsteps=5
key = random.PRNGKey(10)
drift=jnp.array([0.0]*3)
cov=np.random.random((3,3))
cov=np.matmul(cov,cov.T)
initial_values=jnp.array([100.]*3)
fast_simple = jax.jit(simple_process, static_argnums=2)
init_stocks=jnp.array([100.]*3)
fast_simple(key,init_stocks,numsteps,drift,cov)
# Batch OU sample via vmap
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None)) 
# An integer or None indicates which array axis to map over for all arguments (with None indicating not to map any axis)

%timeit fast_simple(key, init_stocks, 12, drift, cov)

numsamples=100000 # num of paths
keys = jax.random.split(key, numsamples)
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None))
%timeit batch_simple(keys, init_stocks, 12, drift, cov)

batch_simple(keys, init_stocks, numsteps, drift, cov).shape

The slowest run took 27769.64 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 16.1 µs per loop
1 loop, best of 5: 177 ms per loop


(100000, 5, 3)

In [None]:
#@title
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, sigma, T):
    stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))
    stocks_init=jax.ops.index_update(stocks_init,   # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
                            jax.ops.index[0],       # initialization of stock prices
                            initial_stocks)
    noise = jax.random.multivariate_normal(key, drift, cov, (numsteps,))
    dt = jnp.array(T[0]/numsteps)
    def time_step(t, val):
        dx =  drift * dt * val[t-1,:] + sigma * val[t-1,:] * jnp.sqrt(dt) * noise[t,:]
        val = jax.ops.index_update(val,
                            jax.ops.index[t],
                            val[t-1] + dx)
        return val
    return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)

# def fori_loop(lower, upper, body_fun, init_val): # upper is exclusive
#   val = init_val
#   for i in range(lower, upper):
#     val = body_fun(i, val)
#   return val

np.random.seed(0)
key = random.PRNGKey(0)
initial_stocks=jnp.array([100.]*3)
numsteps=10
drift=jnp.array([0]*3)
# cov=np.random.random((3,3))
# cov=np.matmul(cov,cov.T)
corr = jnp.diag(jnp.array([1]*3))
sigma = jnp.array([0.3]*3)
cov = (jnp.diag(sigma)).dot(corr).dot(jnp.diag(sigma))
#sigma = jnp.array(np.random.random(3))
#r = jnp.array([np.random.random(1)]*3)
T = jnp.array([1.]*3)

fast_simple = jax.jit(Brownian_motion, static_argnums=2)
fast_simple(key, initial_stocks, numsteps, drift, cov, sigma, T)

# numsamples=100000 # num of paths
# keys = jax.random.split(key, numsamples)

# batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None, None))
# %timeit batch_simple(keys, initial_stocks, numsteps, drift, cov, sigma, T)
# batch_simple(keys, initial_stocks, numsteps, drift, cov, sigma, T).shape

DeviceArray([[104.03143 , 103.00068 , 102.59222 ],
             [102.768456, 105.89164 , 100.96432 ],
             [102.81328 ,  99.627396, 102.55864 ],
             [105.48844 , 101.25624 , 104.66836 ],
             [105.8519  , 100.323296, 109.50462 ],
             [106.59    ,  96.38026 , 107.60002 ],
             [107.01583 ,  98.68716 , 104.24871 ],
             [103.734055,  95.48625 , 102.4883  ],
             [103.274185,  94.44483 ,  99.23759 ],
             [102.81635 ,  93.41477 ,  96.08998 ]], dtype=float32)

In [None]:
#@title
key = random.PRNGKey(10)
initial_values=jnp.array([100.]*3)
numsteps=5
drift=jnp.array([0.0]*3)
cov=np.random.random((3,3))
cov=np.matmul(cov, cov.T)

print(key)
print(initial_values)
print(numsteps)
print(drift)
print(cov)

[ 0 10]
[100. 100. 100.]
5
[0. 0. 0.]
[[0.74040995 0.50528954 0.38358708]
 [0.50528954 0.69334581 0.33222558]
 [0.38358708 0.33222558 0.21827158]]


In [None]:
#@title
stocks_init = jnp.zeros((numsteps + 1, initial_values.shape[0]))
print(stocks_init)
stocks_init=jax.ops.index_update(stocks_init,
                        jax.ops.index[0],
                        initial_values)
print(stocks_init)
noise = jax.random.multivariate_normal(key, drift, cov, (numsteps,))
print(noise)

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
[[100. 100. 100.]
 [  0.   0.   0.]
 [  0.   0.   0.]
 [  0.   0.   0.]
 [  0.   0.   0.]
 [  0.   0.   0.]]
[[-0.05614188 -0.93355596 -0.05524409]
 [ 0.70478475 -0.20070948  0.15282556]
 [ 0.11612918 -0.04694405 -0.03823702]
 [ 0.42249295 -0.36928132 -0.06133213]
 [-0.24204297 -0.2061234  -0.13489908]]


In [1]:
!curl https://colab.chainer.org/install |sh -
import cupy

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  1580  100  1580    0     0   9239      0 --:--:-- --:--:-- --:--:--  9294
+ apt -y -q install cuda-libraries-dev-10-0
Reading package lists...
Building dependency tree...
Reading state information...
cuda-libraries-dev-10-0 is already the newest version (10.0.130-1).
0 upgraded, 0 newly installed, 0 to remove and 40 not upgraded.
+ pip install -q cupy-cuda100  chainer 
+ set +ex
Installation succeeded!


In [None]:
# import jax
# import jax.numpy as jnp
# from jax import random
# from jax import jit
# import numpy as np
# from torch.utils.dlpack import from_dlpack

# def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
#     stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))
#     stocks_init = jax.ops.index_update(stocks_init,   # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
#                             jax.ops.index[0],         # initialization of stock prices
#                             initial_stocks)
#     noise = jax.random.multivariate_normal(key, jnp.array([0]*initial_stocks.shape[0]), cov, (numsteps,))
#     dt = jnp.array(T[0]/numsteps)
#     def time_step(t, val):
#         dx =  drift * dt * val[t-1,:] + val[t-1,:] * jnp.sqrt(dt) * noise[t,:] # no need to multiply by sigma here because noise generated by cov not corr
#         val = jax.ops.index_update(val,
#                             jax.ops.index[t],
#                             val[t-1] + dx)
#         return val
#     return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)

# # def fori_loop(lower, upper, body_fun, init_val): # upper is exclusive
# #   val = init_val
# #   for i in range(lower, upper):
# #     val = body_fun(i, val)
# #   return val

# numstocks = 3

# #np.random.seed(15)
# key = random.PRNGKey(10)
# initial_stocks = jnp.array(np.random.random(numstocks) * 200)
# numsteps = 100
# # drift=jnp.array([0.05] * numstocks) 
# # cov=np.random.random((3,3))
# # cov=np.matmul(cov,cov.T)
# corr = jnp.diag(jnp.array([1]*numstocks)) # assume no correlation between stocks here
# sigma = jnp.array(np.random.random(numstocks) * 0.4)
# cov = (jnp.diag(sigma)).dot(corr).dot(jnp.diag(sigma))
# r = jnp.repeat(jnp.array(np.random.random(1) * 0.1), numstocks)
# drift = r # To match BS, use drift = r
# T = jnp.array([1.] * numstocks)

# fast_simple = jax.jit(Brownian_motion, static_argnums=2)
# fast_simple(key, initial_stocks, numsteps, drift, cov, T)

# numsamples = 100000 # num of paths
# keys = jax.random.split(key, numsamples)

# batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))
# %timeit batch_simple(keys, initial_stocks, numsteps, drift, cov, T)
# out = batch_simple(keys, initial_stocks, numsteps, drift, cov, T)

# K = np.random.random(1) * 200
# European_Call_price = np.mean(np.maximum((np.mean(out[:,numsteps,:], axis=1)) - K, 0) * jnp.exp(-r[0] * T[0]))

# paras = (T, jnp.repeat(jnp.array(K), numstocks), initial_stocks, sigma, drift, r)
# paras = np.column_stack(paras).reshape(1,-1)[0]

# print(from_dlpack(cupy.array(paras).toDlpack()))
# print(European_Call_price)

The slowest run took 187.96 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 2.32 ms per loop
tensor([1.0000e+00, 5.8479e+01, 1.2630e+02, 1.3925e-02, 9.5619e-03, 9.5619e-03,
        1.0000e+00, 5.8479e+01, 1.8647e+02, 3.9301e-01, 9.5619e-03, 9.5619e-03,
        1.0000e+00, 5.8479e+01, 7.6436e+01, 1.8355e-01, 9.5619e-03, 9.5619e-03],
       device='cuda:0')
71.90888


In [7]:
################################# TEST ########################################
%%writefile cupy_dataset.py
import cupy
import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np
from torch.utils.dlpack import from_dlpack

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))
    stocks_init = jax.ops.index_update(stocks_init,   # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
                            jax.ops.index[0],         # initialization of stock prices
                            initial_stocks)
    noise = jax.random.multivariate_normal(key, jnp.array([0]*initial_stocks.shape[0]), cov, (numsteps,))
    dt = jnp.array(T[0]/numsteps)
    def time_step(t, val):
        dx =  drift * dt * val[t-1,:] + val[t-1,:] * jnp.sqrt(dt) * noise[t,:] # no need to multiply by sigma here because noise generated by cov not corr
        val = jax.ops.index_update(val,
                            jax.ops.index[t],
                            val[t-1] + dx)
        return val
    return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)


class NumbaOptionDataSet(object):
    
    def __init__(self, max_len=10, number_path = 1000, batch=2, seed=15, stocks=3):  # 3 stocks
        self.num = 0
        self.max_length = max_len
        self.N_PATHS = number_path
        self.N_STEPS = 365
        self.N_BATCH = batch
        self.N_STOCKS = stocks
        self.T = 1
        self.seed = seed
        np.random.seed(seed)
        
    def __len__(self):
        return self.max_length
        
    def __iter__(self):
        self.num = 0
        return self
    
    def __next__(self):
        if self.num >= self.max_length:
            raise StopIteration
        
        Y = cupy.zeros((self.N_BATCH, 1 + self.N_STOCKS * 2), dtype=cupy.float32) # output: price, delta1, delta2, delta3, gamma1, gamma2, gamma3
        X = cupy.zeros((self.N_BATCH, self.N_STOCKS * 6), dtype = cupy.float32)

        for op in range(self.N_BATCH):
          
          key = random.PRNGKey(self.seed)
          initial_stocks = jnp.array(np.random.random(self.N_STOCKS) * 200)
          # cov=np.random.random((3,3))
          # cov=np.matmul(cov,cov.T)
          corr = jnp.diag(jnp.array([1]*self.N_STOCKS)) # assume no correlation between stocks here
          sigma = jnp.array(np.random.random(self.N_STOCKS) * 0.4)
          cov = (jnp.diag(sigma)).dot(corr).dot(jnp.diag(sigma))
          r = jnp.repeat(jnp.array(np.random.random(1) * 0.1), self.N_STOCKS)
          drift = r # To match BS, use drift = r
          T = jnp.array([self.T] * self.N_STOCKS)
          K = np.random.random(1) * 200

          fast_simple = jax.jit(Brownian_motion, static_argnums=2)
          fast_simple(key, initial_stocks, self.N_STEPS, drift, cov, T)

          numsamples = self.N_PATHS # num of paths
          keys = jax.random.split(key, numsamples)

          batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))
          out = batch_simple(keys, initial_stocks, self.N_STEPS, drift, cov, T)

          European_Call_price = np.mean(np.maximum((np.mean(out[:,self.N_STEPS,:], axis=1)) - K, 0) * jnp.exp(-r[0] * T[0]))
          Y[op, 0] = European_Call_price
          Y[op, 1:7] = cupy.array(np.arange(0.4, 1.0, 0.1), dtype=cupy.float32) # test: assume some number for greeks

          paras = (T, jnp.repeat(jnp.array(K), self.N_STOCKS), initial_stocks, sigma, drift, r)
          paras = np.column_stack(paras).reshape(1,-1)[0]
          X[op,] = cupy.array(paras)

        self.num += 1
        return (from_dlpack(X.toDlpack()), from_dlpack(Y.toDlpack()))


# ds = NumbaOptionDataSet(2, number_path = 10000, batch = 2, seed = 15, stocks=3)
# for i in ds:
#     print(i)

Overwriting cupy_dataset.py


In [4]:
%%writefile model.py
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np

class Net(nn.Module):

    def __init__(self, hidden=1024):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(6*3, hidden) # remember to change this!
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden, hidden)
        self.fc5 = nn.Linear(hidden, hidden)
        self.fc6 = nn.Linear(hidden, 7) # 7 outputs: price, delta1, delta2, delta3, gamma1, gamma2, gamma3
        self.register_buffer('norm',
                             torch.tensor([1, 200.0, 200.0, 0.4, 0.1, 0.1]*3)) # don't use numpy here - will give error later

    def forward(self, x):
        # normalize the parameter to range [0-1] 
        x = x / self.norm
        x = F.elu(self.fc1(x))
        x = F.elu(self.fc2(x))
        x = F.elu(self.fc3(x))
        x = F.elu(self.fc4(x))
        x = F.elu(self.fc5(x))
        return self.fc6(x)

Overwriting model.py


In [5]:
!pip install pytorch-ignite



In [11]:
from ignite.engine import Engine, Events
from ignite.handlers import Timer
from torch.nn import MSELoss
from torch.optim import Adam
from ignite.contrib.handlers.param_scheduler import CosineAnnealingScheduler
from ignite.handlers import ModelCheckpoint
from model import Net
from cupy_dataset import NumbaOptionDataSet
timer = Timer(average=True)
model = Net().cuda()
loss_fn = MSELoss()
optimizer = Adam(model.parameters(), lr=1e-3)
dataset = NumbaOptionDataSet(max_len = 100, number_path = 1024, batch = 32, stocks = 3)


def train_update(engine, batch):
    model.train()
    optimizer.zero_grad()
    x = batch[0]
    y = batch[1]
    #print(y)
    y_pred = model(x)
    #print(y_pred)
    loss = loss_fn(y_pred[:,0], y[:,0]) # compute MSE between the 2 arrays
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_update)
log_interval = 20

scheduler = CosineAnnealingScheduler(optimizer, 'lr', 1e-4, 1e-6, len(dataset))
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
timer.attach(trainer,
             start=Events.EPOCH_STARTED,
             resume=Events.ITERATION_STARTED,
             pause=Events.ITERATION_COMPLETED,
             step=Events.ITERATION_COMPLETED)    
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    iter = (engine.state.iteration - 1) % len(dataset) + 1
    if iter % log_interval == 0:
        print('loss', engine.state.output, 'average time', timer.value(), 'iter num', iter)
        
trainer.run(dataset, max_epochs = 100)

loss 1191.623046875 average time 0.1692088015000081 iter num 20
loss 1496.97216796875 average time 0.08837160692497718 iter num 40


ERROR:ignite.engine.engine.Engine:Engine run is terminating due to exception: 


KeyboardInterrupt: ignored