<a href="https://colab.research.google.com/github/xinyanz-erin/Applied-Finance-Project/blob/Pui/American_Call_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!curl https://colab.chainer.org/install |sh -
import cupy

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  1580  100  1580    0     0  13057      0 --:--:-- --:--:-- --:--:-- 13057
+ apt -y -q install cuda-libraries-dev-10-0
Reading package lists...
Building dependency tree...
Reading state information...
cuda-libraries-dev-10-0 is already the newest version (10.0.130-1).
0 upgraded, 0 newly installed, 0 to remove and 37 not upgraded.
+ pip install -q cupy-cuda100  chainer 
[K     |████████████████████████████████| 58.9 MB 31 kB/s 
[K     |████████████████████████████████| 1.0 MB 31.4 MB/s 
[?25h  Building wheel for chainer (setup.py) ... [?25l[?25hdone
+ set +ex
Installation succeeded!


# Test (optional: skip this cell unless you are testing; skipping it also checks that the cells below define everything they need without relying on this cell)

In [None]:
# Changes made in this version (all done; the check values below are still correct):
# - 'numsteps' no longer represents the number of years: dt = T / numsteps
# - added r; note that the noise must have mean 0 (not mean = drift),
#   or else the option prices come out far too large

import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np
from torch.utils.dlpack import from_dlpack

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    """Simulate one geometric-Brownian-motion path for a set of stocks.

    Every step multiplies the previous prices by
    ``exp((drift - sigma**2 / 2) * dt + sqrt(dt) * noise[t])`` with
    ``dt = T / numsteps`` and ``sigma = diag(cov) ** 0.5``; ``noise`` is drawn
    from a zero-mean multivariate normal with covariance ``cov``.

    Args:
        key: JAX PRNG key used to draw the noise.
        initial_stocks: 1-D float array of starting prices, one per stock.
        numsteps: Number of time steps. It sizes the arrays, so it has to be
            a concrete Python int (it is passed as a static argument to
            ``jax.jit`` elsewhere in this cell).
        drift: Per-stock drift, broadcast against the prices.
        cov: Covariance matrix used to generate the noise.
        T: Time horizon covered by the ``numsteps`` steps.

    Returns:
        Array of shape ``(numsteps, n_stocks)`` holding the prices at steps
        1..numsteps; the row of initial prices is dropped.
    """
    n_stocks = initial_stocks.shape[0]
    # Row 0 holds the starting prices; the loop fills rows 1..numsteps.
    # `x.at[idx].set(y)` is the pure equivalent of x[idx] = y and replaces
    # `jax.ops.index_update`, which no longer exists in current JAX releases.
    paths = jnp.zeros((numsteps + 1, n_stocks)).at[0].set(initial_stocks)
    # The noise must have mean 0: the drift only enters through the exponent.
    # numsteps + 1 rows are drawn so noise[t] lines up with paths[t]; row 0 is unused.
    noise = jax.random.multivariate_normal(
        key, jnp.array([0] * n_stocks), cov, (numsteps + 1,))
    sigma = jnp.diag(cov) ** 0.5
    dt = T / numsteps

    def time_step(t, val):
        # No multiplication by sigma here: the noise was generated from the
        # covariance matrix (not a correlation matrix), so it is already scaled.
        growth = jnp.exp((drift - sigma ** 2. / 2.) * dt + jnp.sqrt(dt) * noise[t, :])
        return val.at[t].set(val[t - 1] * growth)

    # jax.lax.fori_loop(lower, upper, body_fun, init_val)
    return jax.lax.fori_loop(1, numsteps + 1, time_step, paths)[1:]

def optionvalue(key, initial_stocks, numsteps, drift, r, cov, K, T):
  """Discounted Monte Carlo call payoff, averaged over all paths and all stocks.

  The payoff is taken per stock (no basket average), which is how the
  single-stock price is checked in this cell.

  NOTE: the `key` argument is not used; the paths are generated from the
  module-level `keys` through the module-level `batch_simple`.
  """
  simulated = batch_simple(keys, initial_stocks, numsteps, drift, cov, T)
  # Take the last step with '-1', not 'numsteps', or else the grad will be 0.
  terminal_prices = simulated[:, -1, :]
  payoff = jnp.maximum(terminal_prices - K, 0)
  discount = jnp.exp(-r[0] * T)
  return jnp.mean(payoff * discount)

def optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T):
  """Discounted Monte Carlo price of a call on the equally weighted basket.

  The terminal prices are averaged across stocks before the payoff is taken
  (the check values in this cell use a 3-stock basket).

  NOTE: the `key` argument is not used; the paths are generated from the
  module-level `keys` through the module-level `batch_simple`.
  """
  simulated = batch_simple(keys, initial_stocks, numsteps, drift, cov, T)
  # Take the last step with '-1', not 'numsteps', or else the grad will be 0.
  terminal_prices = simulated[:, -1, :]
  basket_level = jnp.mean(terminal_prices, axis=1)
  payoff = jnp.maximum(basket_level - K, 0)
  discount = jnp.exp(-r[0] * T)
  return jnp.mean(payoff * discount)

# ---- Test driver: prices and deltas for a fixed parameter set ----
numstocks = 3

rng = jax.random.PRNGKey(1)
rng, key = jax.random.split(rng)
numsteps = 50
drift = jnp.array([0.]*numstocks)
r = drift # let r = drift to match B-S

# Diagonal covariance: sigma = 0.25 for every stock, no correlation.
cov = jnp.identity(numstocks)*0.25*0.25
initial_stocks = jnp.array([100.]*numstocks) # must be float

T = 1.0
K = 110.0

# jit-compiled single-path simulator; argument 2 (numsteps) is static.
fast_simple = jax.jit(Brownian_motion, static_argnums=2)
#fast_simple(key, initial_stocks, numsteps, drift, cov)

# One PRNG key per Monte Carlo path; 'keys' and 'batch_simple' are the globals
# that optionvalue / optionvalueavg read.
keys = jax.random.split(key, 1000000)
# Map the simulator over axis 0 of the keys; every other argument is shared.
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))

print(initial_stocks) # S
print(K) # K
print(cov) # covariance matrix (sigma**2 on the diagonal)
print(drift) # drift

#################################################################################### values for checking
# Parameter order: S, K, r (= drift), sigma, T
# 100, 110, 0, 0.25, 1
# 1 stock price should be around 6.1904
# 3 stock price should be around 2.3767
# 1 stock case (not basket) delta should be around (0.39888 / numstocks)

# Parameter order: S, K, r (= drift), sigma, T
# 200, 180, 0.1, 0.4, 2
# 1 stock price should be around 70.3005
# 3 stock price should be around 58.7488
# 1 stock case (not basket) delta should be around (0.7946 / numstocks)
####################################################################################

# option price
# (the first argument is ignored by both pricing functions, which read the
#  global 'keys', so passing 'key' here or 'keys' below makes no difference)
# 1 stock
print(optionvalue(key, initial_stocks, numsteps, drift, r, cov, K, T)) # here numsteps different from T
# 3 stocks basket
print(optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T)) # here numsteps different from T

# delta test: gradient with respect to argument 1 (initial_stocks)
# 1 stock
goptionvalue = jax.grad(optionvalue,argnums=1)
print(goptionvalue(keys, initial_stocks, numsteps, drift, r, cov, K, T)) # here numsteps different from T

# 3 stock basket
goptionvalueavg = jax.grad(optionvalueavg,argnums=1)
print(goptionvalueavg(keys, initial_stocks, numsteps, drift, r, cov, K, T)) # here numsteps different from T

[100. 100. 100.]
110.0
[[0.0625 0.     0.    ]
 [0.     0.0625 0.    ]
 [0.     0.     0.0625]]
[0. 0. 0.]
6.197539
2.365369
[0.13310367 0.13305798 0.13296412]
[0.09370443 0.09377381 0.09371734]


# Construct Neural Net

In [None]:
%%writefile cupy_dataset.py
import cupy
import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np
from torch.utils.dlpack import from_dlpack

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    """Simulate one geometric-Brownian-motion path for a set of stocks.

    Every step multiplies the previous prices by
    ``exp((drift - sigma**2 / 2) * dt + sqrt(dt) * noise[t])`` with
    ``dt = T / numsteps`` and ``sigma = diag(cov) ** 0.5``; ``noise`` is drawn
    from a zero-mean multivariate normal with covariance ``cov``.

    Args:
        key: JAX PRNG key used to draw the noise.
        initial_stocks: 1-D float array of starting prices, one per stock.
        numsteps: Number of time steps. It sizes the arrays, so it has to be
            a concrete Python int (it is passed as a static argument to
            ``jax.jit`` elsewhere in this module).
        drift: Per-stock drift, broadcast against the prices.
        cov: Covariance matrix used to generate the noise.
        T: Time horizon covered by the ``numsteps`` steps.

    Returns:
        Array of shape ``(numsteps, n_stocks)`` holding the prices at steps
        1..numsteps; the row of initial prices is dropped.
    """
    n_stocks = initial_stocks.shape[0]
    # Row 0 holds the starting prices; the loop fills rows 1..numsteps.
    # `x.at[idx].set(y)` is the pure equivalent of x[idx] = y and replaces
    # `jax.ops.index_update`, which no longer exists in current JAX releases.
    paths = jnp.zeros((numsteps + 1, n_stocks)).at[0].set(initial_stocks)
    # The noise must have mean 0: the drift only enters through the exponent.
    # numsteps + 1 rows are drawn so noise[t] lines up with paths[t]; row 0 is unused.
    noise = jax.random.multivariate_normal(
        key, jnp.array([0] * n_stocks), cov, (numsteps + 1,))
    sigma = jnp.diag(cov) ** 0.5
    dt = T / numsteps

    def time_step(t, val):
        # No multiplication by sigma here: the noise was generated from the
        # covariance matrix (not a correlation matrix), so it is already scaled.
        growth = jnp.exp((drift - sigma ** 2. / 2.) * dt + jnp.sqrt(dt) * noise[t, :])
        return val.at[t].set(val[t - 1] * growth)

    # jax.lax.fori_loop(lower, upper, body_fun, init_val)
    return jax.lax.fori_loop(1, numsteps + 1, time_step, paths)[1:]

def optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, T, keys): # need to pass 'keys'
    """Discounted Monte Carlo price of a call on the equally weighted basket.

    One path is simulated per entry of `keys` via the module-level
    `batch_simple`; the terminal prices are averaged across stocks, the call
    payoff against `K` is taken, discounted with `r[0]` over `T`, and
    averaged over the paths.

    NOTE: the `key` argument is not used; only `keys` drives the simulation.
    """
    simulated = batch_simple(keys, initial_stocks, numsteps, drift, cov, T)
    # Take the last step with '-1', not 'numsteps', or else the grad will be 0.
    terminal_prices = simulated[:, -1, :]
    basket_level = jnp.mean(terminal_prices, axis=1)
    payoff = jnp.maximum(basket_level - K, 0)
    discount = jnp.exp(-r[0] * T)
    return jnp.mean(payoff * discount)

###################################################################################################
# These two callables must be defined at module level (outside the class) so
# that 'optionvalueavg' can reference them as globals.
# fast_simple: jit-compiled Brownian_motion; argument 2 (numsteps) is static.
fast_simple = jax.jit(Brownian_motion, static_argnums=2)
# batch_simple: maps fast_simple over axis 0 of the first argument (the PRNG
# keys); all other arguments are shared across the batch.
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))
###################################################################################################

class OptionDataSet(object):
    """Iterable that generates batches of Monte Carlo basket-call training data.

    Each iteration yields a pair of torch tensors ``(X, Y)``:

    * ``X`` has shape ``(batch, 6 * stocks)``; for every stock it holds
      ``[T, K, initial price, sigma, drift, r]`` in that order.
    * ``Y`` has shape ``(batch, 1 + stocks)``: the basket call price followed
      by the delta with respect to each initial stock price.

    Args:
        max_len: Number of batches produced per pass over the dataset.
        number_path: Number of Monte Carlo paths per option.
        batch: Number of options per batch.
        seed: Seed for both the NumPy parameter sampler and the JAX path key.
        stocks: Number of stocks in the basket.
    """

    def __init__(self, max_len, number_path, batch, seed, stocks):
        self.num = 0
        self.max_length = max_len
        self.N_PATHS = number_path
        self.N_STEPS = 50
        self.N_BATCH = batch
        self.N_STOCKS = stocks
        self.T = 1.0  # assume T = 1, use float here
        self.seed = seed
        np.random.seed(seed)
        # Persistent JAX key, split once per generated option. Rebuilding the
        # key from `seed` for every sample would price every option on exactly
        # the same Brownian draws.
        self._key = jax.random.PRNGKey(seed)

    def __len__(self):
        return self.max_length

    def __iter__(self):
        self.num = 0
        return self

    def __next__(self):
        if self.num >= self.max_length:
            raise StopIteration

        n_stocks = self.N_STOCKS
        # output: price, then one delta per stock
        Y = cupy.zeros((self.N_BATCH, 1 + n_stocks), dtype=cupy.float32)
        X = cupy.zeros((self.N_BATCH, n_stocks * 6), dtype=cupy.float32)

        # Price and deltas come out of a single simulation pass; the transformed
        # function is built once per batch instead of once per sample.
        price_and_deltas = jax.value_and_grad(optionvalueavg, argnums=1)
        corr = jnp.diag(jnp.array([1] * n_stocks))  # assume no correlation between stocks here

        for op in range(self.N_BATCH):
            self._key, key = jax.random.split(self._key)

            # ---- generate random input numbers ----
            initial_stocks = jnp.array(np.random.random(n_stocks) * 200.0)

            sigma = jnp.array(np.random.random(n_stocks) * 0.4)
            cov = (jnp.diag(sigma)).dot(corr).dot(jnp.diag(sigma))

            r = jnp.repeat(jnp.array(np.random.random(1) * 0.1), n_stocks)
            drift = r  # To match BS, use drift = r

            T = self.T
            K = np.random.random(1) * 200.0

            # ---- compute price and deltas ----
            keys = jax.random.split(key, self.N_PATHS)  # one key per path
            European_Call_price, Deltas = price_and_deltas(
                key, initial_stocks, self.N_STEPS, drift, r, cov, K, T, keys)  # need to pass 'keys'

            # ---- store input and output numbers in X and Y ----
            Y[op, 0] = European_Call_price
            # Open-ended slice: works for any number of stocks, not only 3.
            Y[op, 1:] = cupy.array(Deltas, dtype=cupy.float32)

            paras = (jnp.repeat(jnp.array(T), n_stocks), jnp.repeat(jnp.array(K), n_stocks),
                     initial_stocks, sigma, drift, r)
            # column_stack + row-major flatten interleaves the six values per stock.
            paras = np.column_stack(paras).reshape(1, -1)[0]
            X[op,] = cupy.array(paras)

        self.num += 1
        return (from_dlpack(X.toDlpack()), from_dlpack(Y.toDlpack()))


# ds = OptionDataSet(max_len = 2, number_path = 10000, batch = 2, seed = 15, stocks=3) # for testing purpose, use constant seed. When training, change to random seed
# for i in ds:
#     print(i)

Writing cupy_dataset.py


In [None]:
%%writefile model.py
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np

class Net(nn.Module):
    """Fully connected network mapping option parameters to price and deltas.

    The input holds six features per stock, laid out to match the ``norm``
    buffer ``[1, 200.0, 200.0, 0.4, 0.1, 0.1]`` repeated per stock; the output
    holds the price followed by one delta per stock.

    Args:
        hidden: Width of each of the five hidden layers.
        n_stocks: Number of stocks in the basket. The default of 3 reproduces
            the original 18-input / 4-output network.
    """

    def __init__(self, hidden=1024, n_stocks=3):
        super().__init__()
        self.fc1 = nn.Linear(6 * n_stocks, hidden)  # 6 features per stock
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden, hidden)
        self.fc5 = nn.Linear(hidden, hidden)
        self.fc6 = nn.Linear(hidden, 1 + n_stocks)  # outputs: price, then one delta per stock
        # Registered as a buffer so it follows the module across devices and is
        # stored in the state dict. Don't use numpy here - will give error later.
        self.register_buffer('norm',
                             torch.tensor([1, 200.0, 200.0, 0.4, 0.1, 0.1] * n_stocks))

    def forward(self, x):
        # normalize the parameters to the range [0, 1]
        x = x / self.norm
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            x = F.elu(layer(x))
        return self.fc6(x)

Writing model.py


# Train Neural Net

In [None]:
!pip install pytorch-ignite

Collecting pytorch-ignite
  Downloading pytorch_ignite-0.4.6-py3-none-any.whl (232 kB)
[?25l[K     |█▍                              | 10 kB 26.2 MB/s eta 0:00:01[K     |██▉                             | 20 kB 28.3 MB/s eta 0:00:01[K     |████▎                           | 30 kB 19.1 MB/s eta 0:00:01[K     |█████▋                          | 40 kB 16.0 MB/s eta 0:00:01[K     |███████                         | 51 kB 10.6 MB/s eta 0:00:01[K     |████████▌                       | 61 kB 9.8 MB/s eta 0:00:01[K     |█████████▉                      | 71 kB 10.1 MB/s eta 0:00:01[K     |███████████▎                    | 81 kB 11.2 MB/s eta 0:00:01[K     |████████████▊                   | 92 kB 9.2 MB/s eta 0:00:01[K     |██████████████                  | 102 kB 10.0 MB/s eta 0:00:01[K     |███████████████▌                | 112 kB 10.0 MB/s eta 0:00:01[K     |█████████████████               | 122 kB 10.0 MB/s eta 0:00:01[K     |██████████████████▎             | 133 kB 10.

In [None]:
# If the session runs out of memory, reduce the dataset parameters (e.g. number_path or batch) and restart the session.
# The loss will converge.

from ignite.engine import Engine, Events
from ignite.handlers import Timer
from torch.nn import MSELoss
from torch.optim import Adam
from ignite.contrib.handlers.param_scheduler import CosineAnnealingScheduler
from ignite.handlers import ModelCheckpoint
from model import Net
from cupy_dataset import OptionDataSet
import numpy as np
# Timer reporting an average (attached to the trainer events further down).
timer = Timer(average=True)
# Fresh network with default sizes, moved to the GPU.
model = Net().cuda()
loss_fn = MSELoss()
# NOTE(review): the cosine scheduler attached further down drives 'lr' between
# 1e-4 and 1e-6, so this initial 1e-3 presumably never takes effect - confirm.
optimizer = Adam(model.parameters(), lr=1e-3)
# 100 batches per epoch, 32 options per batch, 1024 Monte Carlo paths per option.
dataset = OptionDataSet(max_len = 100, number_path = 1024, batch = 32, seed = np.random.randint(10000), stocks = 3) # must have random seed


def train_update(engine, batch):
    """Run one optimisation step on `batch` and return the loss as a float.

    `batch` is indexed as (inputs, targets); the step uses the module-level
    `model`, `optimizer` and `loss_fn`.
    """
    features, targets = batch[0], batch[1]
    model.train()
    optimizer.zero_grad()
    predictions = model(features)
    # compute MSE between the 2 arrays (all rows, all output columns)
    batch_loss = loss_fn(predictions[:, :], targets[:, :])
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

# Ignite engine that calls train_update once per batch.
trainer = Engine(train_update)
log_interval = 20  # print the loss every 20 iterations

# Cosine-anneal the optimizer's 'lr' between 1e-4 and 1e-6 with a cycle length
# of one epoch (len(dataset) iterations); stepped at the start of each iteration.
scheduler = CosineAnnealingScheduler(optimizer, 'lr', 1e-4, 1e-6, len(dataset))
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
# Time only the iterations themselves: (re)start at each epoch, resume/pause
# around every iteration, and count one step per completed iteration.
timer.attach(trainer,
             start=Events.EPOCH_STARTED,
             resume=Events.ITERATION_STARTED,
             pause=Events.ITERATION_COMPLETED,
             step=Events.ITERATION_COMPLETED)    
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    """Print loss, average iteration time and position every `log_interval` iterations."""
    # 1-based position of the finished iteration inside the current epoch.
    step_in_epoch = (engine.state.iteration - 1) % len(dataset) + 1
    if step_in_epoch % log_interval != 0:
        return
    print('loss', engine.state.output, 'average time', timer.value(), 'iter num', step_in_epoch)
        
# Train for up to 100 epochs; one epoch is one pass over `dataset`.
trainer.run(dataset, max_epochs = 100)

 Please refer to the documentation for more details.
  """


loss 308.3535461425781 average time 0.5890158134499984 iter num 20
loss 309.395263671875 average time 0.29650749102499674 iter num 40
loss 319.1919250488281 average time 0.1990294852333278 iter num 60
loss 232.41702270507812 average time 0.15041632617499587 iter num 80
loss 169.25584411621094 average time 0.12112373191999723 iter num 100
loss 41.82872009277344 average time 0.0727944900500006 iter num 20
loss 8.219586372375488 average time 0.038342809124999634 iter num 40
loss 2.8115668296813965 average time 0.026894211649999042 iter num 60
loss 2.778562545776367 average time 0.021186888737497613 iter num 80
loss 2.3859236240386963 average time 0.017774565589996884 iter num 100
loss 1.3037325143814087 average time 0.07461170490000768 iter num 20
loss 1.4311182498931885 average time 0.039529715224995245 iter num 40
loss 0.608491837978363 average time 0.027656891716664706 iter num 60
loss 0.5415693521499634 average time 0.02177082573749658 iter num 80
loss 0.3868729770183563 average time 

ERROR:ignite.engine.engine.Engine:Engine run is terminating due to exception: 


KeyboardInterrupt: ignored

**Save Model**

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
# Save only the weights (state dict) of the trained model to the mounted Drive.
model_save_name = 'jax_european_test_2.pth'
path = F"/content/drive/MyDrive/AFP Project/PUI/{model_save_name}" 
torch.save(model.state_dict(), path)

**Load Model**

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
# Load a previously saved state dict from the mounted Drive.
# NOTE(review): this file name differs from the one written in the save cell
# ('jax_european_test_2.pth') - confirm it is the intended checkpoint.
model_save_name = 'Sobolev_test_2.pth'
path = F"/content/drive/MyDrive/AFP Project/PUI/{model_save_name}" 
state_dict = torch.load(path)
# List the stored tensors to check they match the layers of Net.
print(state_dict.keys())

odict_keys(['norm', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias', 'fc4.weight', 'fc4.bias', 'fc5.weight', 'fc5.bias', 'fc6.weight', 'fc6.bias'])


In [None]:
# The 'Writing cupy_dataset.py' and 'Writing model.py' cells above must be run
# before this one so that model.py exists and can be imported.
from model import Net
model = Net().cuda()

# Copy the loaded weights into the freshly built network; the layer shapes in
# the checkpoint have to match those of Net.
model.load_state_dict(state_dict)
print(model)

Net(
  (fc1): Linear(in_features=18, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=1024, bias=True)
  (fc3): Linear(in_features=1024, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=1024, bias=True)
  (fc5): Linear(in_features=1024, out_features=1024, bias=True)
  (fc6): Linear(in_features=1024, out_features=7, bias=True)
)


**Continue to train model**

In [None]:
# If the session runs out of memory, reduce the dataset parameters (e.g. number_path or batch) and restart the session.
# The loss will converge.

from ignite.engine import Engine, Events
from ignite.handlers import Timer
from torch.nn import MSELoss
from torch.optim import Adam
from ignite.contrib.handlers.param_scheduler import CosineAnnealingScheduler
from ignite.handlers import ModelCheckpoint
from model import Net
from cupy_dataset import OptionDataSet
import numpy as np
# Timer reporting an average (attached to the trainer events further down).
timer = Timer(average=True)
# The model is NOT re-created here: training continues from the loaded weights.
#model = Net().cuda()
loss_fn = MSELoss()
# NOTE(review): the cosine scheduler attached further down drives 'lr' between
# 1e-4 and 1e-6, so this initial 1e-3 presumably never takes effect - confirm.
optimizer = Adam(model.parameters(), lr=1e-3)
#dataset = OptionDataSet(max_len = 100, number_path = 1024, batch = 32, seed = np.random.randint(10000), stocks = 3) # must have random seed
# More paths per option (10000) and smaller batches (8) than in the first training cell.
dataset = OptionDataSet(max_len = 100, number_path = 10000, batch = 8, seed = np.random.randint(10000), stocks = 3) # must have random seed


def train_update(engine, batch):
    """Run one optimisation step on `batch` and return the loss as a float.

    `batch` is indexed as (inputs, targets); the step uses the module-level
    `model`, `optimizer` and `loss_fn`.
    """
    features, targets = batch[0], batch[1]
    model.train()
    optimizer.zero_grad()
    predictions = model(features)
    # compute MSE between the 2 arrays (all rows, all output columns)
    batch_loss = loss_fn(predictions[:, :], targets[:, :])
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

# Ignite engine that calls train_update once per batch.
trainer = Engine(train_update)
log_interval = 20  # print the loss every 20 iterations

# Cosine-anneal the optimizer's 'lr' between 1e-4 and 1e-6 with a cycle length
# of one epoch (len(dataset) iterations); stepped at the start of each iteration.
scheduler = CosineAnnealingScheduler(optimizer, 'lr', 1e-4, 1e-6, len(dataset))
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
# Time only the iterations themselves: (re)start at each epoch, resume/pause
# around every iteration, and count one step per completed iteration.
timer.attach(trainer,
             start=Events.EPOCH_STARTED,
             resume=Events.ITERATION_STARTED,
             pause=Events.ITERATION_COMPLETED,
             step=Events.ITERATION_COMPLETED)    
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    """Print loss, average iteration time and position every `log_interval` iterations."""
    # 1-based position of the finished iteration inside the current epoch.
    step_in_epoch = (engine.state.iteration - 1) % len(dataset) + 1
    if step_in_epoch % log_interval != 0:
        return
    print('loss', engine.state.output, 'average time', timer.value(), 'iter num', step_in_epoch)
        
# Continue training for up to 10 epochs; one epoch is one pass over `dataset`.
trainer.run(dataset, max_epochs = 10)

# Save the updated weights under a new file name.
# `torch` is not imported in this cell; it relies on an earlier cell having imported it.
model_save_name = 'jax_european_test_3.pth'
path = F"/content/drive/MyDrive/AFP Project/PUI/{model_save_name}" 
torch.save(model.state_dict(), path)

loss 3.847933769226074 average time 0.024525942649995615 iter num 20
loss 0.11467143893241882 average time 0.014268842749987697 iter num 40
loss 0.7498358488082886 average time 0.010797686849988017 iter num 60
loss 0.5012774467468262 average time 0.00905447164998634 iter num 80
loss 0.08808335661888123 average time 0.008048051719988507 iter num 100
loss 0.5424646139144897 average time 0.023962988050016065 iter num 20
loss 0.09199616312980652 average time 0.013907441924999375 iter num 40
loss 0.41971147060394287 average time 0.010556156083328005 iter num 60
loss 0.9499955773353577 average time 0.008943276199994443 iter num 80
loss 0.039512861520051956 average time 0.007942020249997768 iter num 100
loss 0.038578908890485764 average time 0.023953661050018126 iter num 20
loss 0.7367726564407349 average time 0.013887499250017754 iter num 40
loss 0.3312565088272095 average time 0.010584145183342268 iter num 60
loss 0.16353926062583923 average time 0.008966675425003245 iter num 80
loss 0.0240

# Results

In [None]:
import torch
# One 3-stock input row; per stock the features are
# [T, K, initial price, sigma, drift, r] = [1, 110, 100, 0.25, 0, 0],
# i.e. the first parameter set used in the Monte Carlo test cell.
inputs = torch.tensor([[1, 110.0, 100.0, 0.25, 0., 0.]*3]).cuda()
model(inputs.float())

# Output order: price, delta1, delta2, delta3
# Should be around (2.3654, 0.0937, 0.0937, 0.0937), the Monte Carlo values printed by the test cell.

tensor([[2.3494, 0.0840, 0.0863, 0.0899]], device='cuda:0',
       grad_fn=<AddmmBackward>)