<a href="https://colab.research.google.com/github/xinyanz-erin/Applied-Finance-Project/blob/Judy/Barrier_Call_jax_v2_Vega_Single.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [25]:
!curl https://colab.chainer.org/install |sh -
import cupy

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  1580  100  1580    0     0   6076      0 --:--:-- --:--:-- --:--:--  6053100  1580  100  1580    0     0   6076      0 --:--:-- --:--:-- --:--:--  6053
+ apt -y -q install cuda-libraries-dev-10-0
Reading package lists...
Building dependency tree...
Reading state information...
cuda-libraries-dev-10-0 is already the newest version (10.0.130-1).
0 upgraded, 0 newly installed, 0 to remove and 37 not upgraded.
+ pip install -q cupy-cuda100  chainer 
+ set +ex
Installation succeeded!


# Test (Skip this if not trying to test, to make sure that functions are defined correctly in cells below without running this cell)

In [26]:
# Knock out call

# now change code such that 'numsteps' does not represent year
# make dt = year / numsteps
# Add r, and notice that noise must have mean 0, not drift, or else it'll give large option prices
# (done)
# after making the changes, the values are still correct

import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np
from torch.utils.dlpack import from_dlpack

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))
    stocks_init = jax.ops.index_update(stocks_init,   # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
                            jax.ops.index[0],         # initialization of stock prices
                            initial_stocks)
    noise = jax.random.multivariate_normal(key,  jnp.array([0]*initial_stocks.shape[0]), cov, (numsteps+1,)) # noise must have mean 0
    sigma = jnp.diag(cov) ** 0.5
    dt = T / numsteps
    def time_step(t, val):
        #dx =  drift + noise[t,:] # no need to multiply by sigma here because noise generated by cov not corr
        dx2 = jnp.exp((drift - sigma ** 2. / 2.) * dt + jnp.sqrt(dt) * noise[t,:])
        val = jax.ops.index_update(val,
                            jax.ops.index[t],
                            val[t-1] * dx2)
        return val
    return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)

def optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, B, T): # down-and-out call
    # print(batch_simple(keys, initial_stocks, numsteps, drift, cov, T))
    # print(jnp.mean(batch_simple(keys, initial_stocks, numsteps, drift, cov, T), axis=2))
    # print((1 - jnp.any(jnp.mean(batch_simple(keys, initial_stocks, numsteps, drift, cov, T), axis=2) < B, axis=1).astype(int)))
    return jnp.mean(jnp.maximum((1 - jnp.any(jnp.mean(batch_simple(keys, initial_stocks, numsteps, drift, cov, T), axis=2) < B, axis=1).astype(int))* 
                                (jnp.mean(batch_simple(keys, initial_stocks, numsteps, drift, cov, T), axis=2))[:,-1]-K, 0) *
                    jnp.exp(-r[0] * T))
    # must use '-1' not 'numsteps', or else grad will be 0

numstocks = 3

rng = jax.random.PRNGKey(1)
rng, key = jax.random.split(rng)
numsteps = 50
drift = jnp.array([0.]*numstocks)
r = drift # let r = drift to match B-S

cov = jnp.identity(numstocks)*0.5*0.5
initial_stocks = jnp.array([100.]*numstocks) # must be float

T = 1.0
K = 60.0
B = 78.4 # if B is set to 0, equivalent to European call

fast_simple = jax.jit(Brownian_motion, static_argnums=2)
#fast_simple(key, initial_stocks, numsteps, drift, cov)

keys = jax.random.split(key, 100000)
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))

# print(initial_stocks) #S
# print(K) #K
# print(B) #B
# print(cov) #sigma
# print(drift) #drift

# option price
# 3 stocks basket
print(optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, B, T)) # here numsteps different from T

# delta
# 3 stock basket
goptionvalueavg = jax.grad(optionvalueavg,argnums=1)
print(goptionvalueavg(keys, initial_stocks, numsteps, drift, r, cov, K, B, T)) # here numsteps different from T

32.85719
[0.22481976 0.22449198 0.22490814]


In [27]:
# Knock in call
# B>K
# Once price reaches B, get immediate payoff (S[knock-in]-K)
# change r to an array of length 'numsteps'

import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np
from torch.utils.dlpack import from_dlpack

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))
    stocks_init = jax.ops.index_update(stocks_init,   # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
                            jax.ops.index[0],         # initialization of stock prices
                            initial_stocks)
    noise = jax.random.multivariate_normal(key,  jnp.array([0]*initial_stocks.shape[0]), cov, (numsteps+1,)) # noise must have mean 0
    sigma = jnp.diag(cov) ** 0.5
    dt = T / numsteps
    def time_step(t, val):
        #dx =  drift + noise[t,:] # no need to multiply by sigma here because noise generated by cov not corr
        dx2 = jnp.exp((drift - sigma ** 2. / 2.) * dt + jnp.sqrt(dt) * noise[t,:])
        val = jax.ops.index_update(val,
                            jax.ops.index[t],
                            val[t-1] * dx2)
        return val
    return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)

def optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, B, T, numpaths): # up-and-in call
    out = batch_simple(keys, initial_stocks, numsteps, drift, cov, T)
    knock_out_index = jnp.argmax(jnp.mean(out, axis=2) >= B, axis=1)
    
    # print(knock_out_index)
    # print(r)
    # print(jnp.cumsum(r)) # cumsum of r
    # print(jnp.tile(jnp.cumsum(r), (numpaths, 1))) # repeat cumsum r array 'numpaths' times as a matrix
    # print(jnp.tile(jnp.cumsum(r), (numpaths, 1))[jnp.arange(numpaths), knock_out_index]) # index cumsum matrix using knock out index
    # print(jnp.tile(jnp.cumsum(r), (numpaths, 1))[jnp.arange(numpaths), knock_out_index]/(knock_out_index+1)) # divided by (knock out index + 1) to get mean

    r_mean_array = jnp.tile(jnp.cumsum(r), (numpaths, 1))[jnp.arange(numpaths), knock_out_index]/(knock_out_index+1)

############################# Too many lines
    # count_matrix = np.repeat([np.arange(numsteps)], numpaths, axis=0) # matrix [[0,1,2,...],[0,1,2,...],...]
    # bool_matrix = count_matrix <= knock_out_index[:,None] # matrix: 1: index <= knock_out_index for each row, else 0
    # r_matrix = jnp.tile(r, (numpaths, 1)) * bool_matrix.astype(int) # repeat r array 'numpaths' times as a matrix, then multiply by bool_matrix
    # masked = np.ma.masked_equal(r_matrix, 0) # mask so that 0 becomes invalid (for computing mean later)
    # r_mean_array = masked.mean(axis=1)

    # print(out)
    # print(jnp.mean(out, axis=2))
    # print(jnp.all(jnp.mean(out, axis=2) < B, axis=1))
    # print(knock_out_index)
    # print((1 - jnp.all(jnp.mean(out, axis=2) < B, axis=1).astype(int)) *
    #       (jnp.mean(out, axis=2)[jnp.arange(numpaths), knock_out_index] - K))    # (1 or 0) * (S[knock-in]-K)
    # print(count_matrix)
    # print(bool_matrix)
    # print(r_matrix)
    # print(masked)
    # print(r_mean_array)
#############################

############################# This method is slow
    # print(jnp.nan_to_num(jnp.array([r_matrix[i,:(index+1)].mean() for i,index in enumerate(knock_out_index)])))
    # print(T * (knock_out_index+1) / numsteps)
    # print((1 - jnp.all(jnp.mean(out, axis=2) < B, axis=1).astype(int)) *  # knock in: 1, else: 0
    #                 (jnp.mean(out, axis=2)[jnp.arange(numpaths), knock_out_index] - K) *   # (S[knock-in]-K)
    #                 jnp.exp(-jnp.nan_to_num(jnp.array([r_matrix[i,:(index+1)].mean() for i,index in enumerate(knock_out_index)])) * (T * (knock_out_index+1) / numsteps)))

    # return jnp.mean((1 - jnp.all(jnp.mean(out, axis=2) < B, axis=1).astype(int)) *  # knock in: 1, else: 0
    #                 (jnp.mean(out, axis=2)[jnp.arange(numpaths), knock_out_index] - K) *   # (S[knock-in]-K)
    #                 jnp.exp(-jnp.nan_to_num(jnp.array([r_matrix[i,:(index+1)].mean() for i,index in enumerate(knock_out_index)])) * (T * (knock_out_index+1) / numsteps))) # (exp(-mean(r until payoff) * (t until payoff)))
#############################

    return jnp.mean((1 - jnp.all(jnp.mean(out, axis=2) < B, axis=1).astype(int)) *  # knock in: 1, else: 0
                    (jnp.mean(out, axis=2)[jnp.arange(numpaths), knock_out_index] - K) *   # (S[knock-in]-K)
                    jnp.exp(- r_mean_array * (T * (knock_out_index+1) / numsteps))) # (exp(-mean(r until payoff) * (t until payoff)))

numstocks = 3

np.random.seed(np.random.randint(10000))

rng = jax.random.PRNGKey(np.random.randint(10000))
rng, key = jax.random.split(rng)
numsteps = 500
numpaths = 100000
drift = jnp.array([0.]*numstocks)
r = jnp.array(np.random.random(numsteps) * 0.1) # r is an array now

cov = jnp.identity(numstocks)*0.5*0.5
initial_stocks = jnp.array([100.]*numstocks) # must be float

T = 1.0
K = 100.0
B = 120.0

fast_simple = jax.jit(Brownian_motion, static_argnums=2)
#fast_simple(key, initial_stocks, numsteps, drift, cov)

keys = jax.random.split(key, numpaths)
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))

# print(initial_stocks) #S
# print(K) #K
# print(B) #B
# print(cov) #sigma
# print(drift) #drift
# print(r) #r
 

# option price
# 3 stocks basket
print(optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, B, T, numpaths)) # here numsteps different from T
%timeit optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, B, T, numpaths)


# delta
# 3 stock basket
goptionvalueavg = jax.grad(optionvalueavg,argnums=1)
print(goptionvalueavg(keys, initial_stocks, numsteps, drift, r, cov, K, B, T, numpaths)) # here numsteps different from T
%timeit goptionvalueavg(keys, initial_stocks, numsteps, drift, r, cov, K, B, T, numpaths)

9.77392
10 loops, best of 5: 77.8 ms per loop
[0.188328   0.18788245 0.1881196 ]
1 loop, best of 5: 3.43 s per loop


In [None]:
# Knock in call
# B>K
# Once price reaches B, get immediate payoff (S[knock-in]-K)
# change r to an array of length 'numsteps'

import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np
from torch.utils.dlpack import from_dlpack

def Brownian_motion_single(key, initial_stocks, numsteps, drift,sigma, T):
    stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))
    stocks_init = jax.ops.index_update(stocks_init,   # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
                            jax.ops.index[0],         # initialization of stock prices
                            initial_stocks)
    noise = jax.random.multivariate_normal(key,  jnp.array([0]*initial_stocks.shape[0]), cov, (numsteps+1,)) # noise must have mean 0
    dt = T / numsteps
    def time_step(t, val):
        #dx =  drift + noise[t,:] # no need to multiply by sigma here because noise generated by cov not corr
        dx2 = jnp.exp((drift - sigma ** 2. / 2.) * dt + jnp.sqrt(dt) * noise[t,:])
        val = jax.ops.index_update(val,
                            jax.ops.index[t],
                            val[t-1] * dx2)
        return val
    return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)

def optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, B, T, numpaths): # up-and-in call
    out = batch_simple(keys, initial_stocks, numsteps, drift, cov, T)
    knock_out_index = jnp.argmax(jnp.mean(out, axis=2) >= B, axis=1)

    r_mean_array = jnp.tile(jnp.cumsum(r), (numpaths, 1))[jnp.arange(numpaths), knock_out_index]/(knock_out_index+1)

    return jnp.mean((1 - jnp.all(jnp.mean(out, axis=2) < B, axis=1).astype(int)) *  # knock in: 1, else: 0
                    (jnp.mean(out, axis=2)[jnp.arange(numpaths), knock_out_index] - K) *   # (S[knock-in]-K)
                    jnp.exp(- r_mean_array * (T * (knock_out_index+1) / numsteps))) # (exp(-mean(r until payoff) * (t until payoff)))

numstocks = 3

np.random.seed(np.random.randint(10000))

rng = jax.random.PRNGKey(np.random.randint(10000))
rng, key = jax.random.split(rng)
numsteps = 500
numpaths = 100000
drift = jnp.array([0.]*numstocks)
r = jnp.array(np.random.random(numsteps) * 0.1) # r is an array now

cov = jnp.identity(numstocks)*0.5*0.5
initial_stocks = jnp.array([100.]*numstocks) # must be float

T = 1.0
K = 100.0
B = 120.0

fast_simple = jax.jit(Brownian_motion, static_argnums=2)
#fast_simple(key, initial_stocks, numsteps, drift, cov)

keys = jax.random.split(key, numpaths)
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))

 

# option price
# 3 stocks basket
print(optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, B, T, numpaths)) # here numsteps different from T


# delta
# 3 stock basket
goptionvalueavg = jax.grad(optionvalueavg,argnums=1)
print(goptionvalueavg(keys, initial_stocks, numsteps, drift, r, cov, K, B, T, numpaths)) # here numsteps different from T

In [4]:
Vega_optionvalueavg = jax.grad(optionvalueavg,argnums=5)
print(Vega_optionvalueavg(keys, initial_stocks, numsteps, drift, r, cov, K, B, T, numpaths))

[[6.821512  5.661367  5.6682487]
 [5.661367  6.8694696 5.7544746]
 [5.6682487 5.7544746 6.9244223]]


In [7]:
Theta_optionvalueavg = jax.grad(optionvalueavg,argnums=8)
print(Theta_optionvalueavg(keys, initial_stocks, numsteps, drift, r, cov, K, B, T, numpaths))

4.955046


In [6]:
std = cov ** 0.5
std

DeviceArray([[0.5, 0. , 0. ],
             [0. , 0.5, 0. ],
             [0. , 0. , 0.5]], dtype=float32)

In [21]:
numstocks = 1

rng = jax.random.PRNGKey(1)
rng, key = jax.random.split(rng)
numsteps = 50
drift = jnp.array([0.]*numstocks)
r = drift # let r = drift to match B-S

cov = jnp.array([0.5*0.5])
initial_stocks = jnp.array([100.]*numstocks) # must be float

T = 1.0
K = 60.0
B = 78.4 # if B is set to 0, equivalent to European call

fast_simple = jax.jit(Brownian_motion, static_argnums=2)
#fast_simple(key, initial_stocks, numsteps, drift, cov)

keys = jax.random.split(key, 100000)
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))
std = cov**0.5

In [22]:
cov

DeviceArray([0.25], dtype=float32)

In [23]:
def optionvalueavg2(key, initial_stocks, numsteps, drift, r, std, K, B, T, numpaths):
  cov= std**2 # up-and-in call
  out = batch_simple(keys, initial_stocks, numsteps, drift, cov, T)
  knock_out_index = jnp.argmax(jnp.mean(out, axis=2) >= B, axis=1)
    
  r_mean_array = jnp.tile(jnp.cumsum(r), (numpaths, 1))[jnp.arange(numpaths), knock_out_index]/(knock_out_index+1)

  return jnp.mean((1 - jnp.all(jnp.mean(out, axis=2) < B, axis=1).astype(int)) *  # knock in: 1, else: 0
                    (jnp.mean(out, axis=2)[jnp.arange(numpaths), knock_out_index] - K) *   # (S[knock-in]-K)
                    jnp.exp(- r_mean_array * (T * (knock_out_index+1) / numsteps))) # (exp(-mean(r until payoff) * (t until payoff)))

In [24]:
print(optionvalueavg2(key, initial_stocks, numsteps, drift, r, std, K, B, T, numpaths))

ValueError: ignored

In [None]:
print(optionvalueavg2(key, [], numsteps, drift, r, std, K, B, T, numpaths))

In [14]:
Vega_optionvalueavg = jax.grad(optionvalueavg2,argnums=5)
print(Vega_optionvalueavg(keys, initial_stocks, numsteps, drift, r, std, K, B, T, numpaths))

[[6.821512  0.        0.       ]
 [0.        6.8694696 0.       ]
 [0.        0.        6.9244223]]


In [None]:
# test prices SD
numsteps = 500
numpaths = 100000

prices = np.zeros(100)

for i in range(100):
  rng = jax.random.PRNGKey(np.random.randint(10000))
  rng, key = jax.random.split(rng)
  fast_simple = jax.jit(Brownian_motion, static_argnums=2)
  #fast_simple(key, initial_stocks, numsteps, drift, cov)
  keys = jax.random.split(key, numpaths)
  batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))

  prices[i] = optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, B, T, numpaths)
  #print(i)

prices.std() 

##################### Case 1
# SD = 0.065 (set to half vega)
# antithetic needs 10000 paths: 0.76s
# no antithetic needs 25000 paths: 0.63s
# no antithetic better in this case

##################### Case 2
# SD = 0.030
# antithetic needs 50000 paths: 1.07s
# no antithetic needs 100000 paths: 1.23s
# antithetic better in this case

0.03209044747132164

# Construct Neural Net

In [None]:
%%writefile cupy_dataset.py
import cupy
import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np
from torch.utils.dlpack import from_dlpack

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))
    stocks_init = jax.ops.index_update(stocks_init,   # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
                            jax.ops.index[0],         # initialization of stock prices
                            initial_stocks)
    noise = jax.random.multivariate_normal(key,  jnp.array([0]*initial_stocks.shape[0]), cov, (numsteps+1,)) # noise must have mean 0
    sigma = jnp.diag(cov) ** 0.5
    dt = T / numsteps
    def time_step(t, val):
        #dx =  drift + noise[t,:] # no need to multiply by sigma here because noise generated by cov not corr
        dx2 = jnp.exp((drift - sigma ** 2. / 2.) * dt + jnp.sqrt(dt) * noise[t,:])
        val = jax.ops.index_update(val,
                            jax.ops.index[t],
                            val[t-1] * dx2)
        return val
    return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)

def optionvalueavg(key, initial_stocks, numsteps, drift, r, cov, K, B, T, keys): # down-and-out call
    return jnp.mean(jnp.maximum((1 - jnp.any(jnp.mean(batch_simple(keys, initial_stocks, numsteps, drift, cov, T), axis=2) < B, axis=1).astype(int))* 
                                (jnp.mean(batch_simple(keys, initial_stocks, numsteps, drift, cov, T), axis=2))[:,-1]-K, 0) *
                    jnp.exp(-r[0] * T))
    # must use '-1' not 'numsteps', or else grad will be 0

###################################################################################################
# these 2 functions must be defined outside class in order to be used in 'optionvalueavg' function
fast_simple = jax.jit(Brownian_motion, static_argnums=2)
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))
###################################################################################################

class OptionDataSet(object):
    
    def __init__(self, max_len, number_path, batch, seed, stocks):
        self.num = 0
        self.max_length = max_len
        self.N_PATHS = number_path
        self.N_STEPS = 50
        self.N_BATCH = batch
        self.N_STOCKS = stocks
        self.T = 1.0 # assume T = 1, use float here
        self.seed = seed
        np.random.seed(seed)
        
    def __len__(self):
        return self.max_length
        
    def __iter__(self):
        self.num = 0
        return self
    
    def __next__(self):
        if self.num >= self.max_length:
            raise StopIteration
        
        Y = cupy.zeros((self.N_BATCH, 1 + self.N_STOCKS), dtype=cupy.float32) # output: price, delta1, delta2, delta3
        X = cupy.zeros((self.N_BATCH, self.N_STOCKS * 7), dtype = cupy.float32)

        for op in range(self.N_BATCH):
          
          rng = jax.random.PRNGKey(self.seed)
          rng, key = jax.random.split(rng)

          ################################################################################################### generate random input numbers

          initial_stocks = jnp.array(np.random.random(self.N_STOCKS) * 200.0)

          corr = jnp.diag(jnp.array([1]*self.N_STOCKS)) # assume no correlation between stocks here
          sigma = jnp.array(np.random.random(self.N_STOCKS) * 0.4)
          cov = (jnp.diag(sigma)).dot(corr).dot(jnp.diag(sigma))

          r = jnp.repeat(jnp.array(np.random.random(1) * 0.1), self.N_STOCKS)
          drift = r # To match BS, use drift = r

          T = self.T
          K = np.random.random(1) * 200.0
          B = np.random.random(1) * 200.0 * 0.6 # B can't be too large

          ###################################################################################################
          ################################################################################################### apply functions to compute price and deltas
          
          keys = jax.random.split(key, self.N_PATHS)

          European_Call_price = optionvalueavg(key, initial_stocks, self.N_STEPS, drift, r, cov, K, B, T, keys) # need to pass 'keys'
          gooptionvalue = jax.grad(optionvalueavg, argnums=1)
          Deltas = gooptionvalue(keys, initial_stocks, self.N_STEPS, drift, r, cov, K, B, T, keys) # need to pass 'keys'

          ###################################################################################################
          ################################################################################################### store input and output numbers in X and Y

          Y[op, 0] = European_Call_price
          Y[op, 1:4] = cupy.array(Deltas, dtype=cupy.float32)

          # T, B, K, S, sigma, mu, r
          paras = (jnp.repeat(jnp.array(T), self.N_STOCKS), jnp.repeat(jnp.array(B), self.N_STOCKS), jnp.repeat(jnp.array(K), self.N_STOCKS), initial_stocks, sigma, drift, r)
          paras = np.column_stack(paras).reshape(1,-1)[0]
          X[op,] = cupy.array(paras)

          ###################################################################################################

        self.num += 1
        return (from_dlpack(X.toDlpack()), from_dlpack(Y.toDlpack()))


# ds = OptionDataSet(max_len = 2, number_path = 10000, batch = 2, seed = 15, stocks=3) # for testing purpose, use constant seed. When training, change to random seed
# for i in ds:
#     print(i)

Writing cupy_dataset.py


In [None]:
%%writefile model.py
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np

class Net(nn.Module):

    def __init__(self, hidden=1024):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(7*3, hidden) # remember to change this!
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden, hidden)
        self.fc5 = nn.Linear(hidden, hidden)
        self.fc6 = nn.Linear(hidden, 4) # 4 outputs: price, delta1, delta2, delta3
        self.register_buffer('norm',
                             torch.tensor([1, 200.0*0.6, 200.0, 200.0, 0.4, 0.1, 0.1]*3)) # don't use numpy here - will give error later
                             # T, B, K, S, sigma, mu, r

    def forward(self, x):
        # normalize the parameter to range [0-1] 
        x = x / self.norm
        x = F.elu(self.fc1(x))
        x = F.elu(self.fc2(x))
        x = F.elu(self.fc3(x))
        x = F.elu(self.fc4(x))
        x = F.elu(self.fc5(x))
        return self.fc6(x)

Writing model.py


# Train Neural Net

In [None]:
!pip install pytorch-ignite

Collecting pytorch-ignite
  Downloading pytorch_ignite-0.4.6-py3-none-any.whl (232 kB)
[?25l[K     |█▍                              | 10 kB 29.2 MB/s eta 0:00:01[K     |██▉                             | 20 kB 33.1 MB/s eta 0:00:01[K     |████▎                           | 30 kB 25.3 MB/s eta 0:00:01[K     |█████▋                          | 40 kB 19.8 MB/s eta 0:00:01[K     |███████                         | 51 kB 10.7 MB/s eta 0:00:01[K     |████████▌                       | 61 kB 11.2 MB/s eta 0:00:01[K     |█████████▉                      | 71 kB 10.5 MB/s eta 0:00:01[K     |███████████▎                    | 81 kB 11.6 MB/s eta 0:00:01[K     |████████████▊                   | 92 kB 9.2 MB/s eta 0:00:01[K     |██████████████                  | 102 kB 9.9 MB/s eta 0:00:01[K     |███████████████▌                | 112 kB 9.9 MB/s eta 0:00:01[K     |█████████████████               | 122 kB 9.9 MB/s eta 0:00:01[K     |██████████████████▎             | 133 kB 9.9 M

In [None]:
# If memory is not enough, try changing parameters and restarting session
# loss will converge

from ignite.engine import Engine, Events
from ignite.handlers import Timer
from torch.nn import MSELoss
from torch.optim import Adam
from ignite.contrib.handlers.param_scheduler import CosineAnnealingScheduler
from ignite.handlers import ModelCheckpoint
from model import Net
from cupy_dataset import OptionDataSet
import numpy as np
timer = Timer(average=True)
model = Net().cuda()
loss_fn = MSELoss()
optimizer = Adam(model.parameters(), lr=1e-3)
dataset = OptionDataSet(max_len = 100, number_path = 1024, batch = 32, seed = np.random.randint(10000), stocks = 3) # must have random seed


def train_update(engine, batch):
    model.train()
    optimizer.zero_grad()
    x = batch[0]
    y = batch[1]
    #print(y)
    y_pred = model(x)
    #print(y_pred)
    loss = loss_fn(y_pred[:,:], y[:,:]) # compute MSE between the 2 arrays
    #print(loss)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_update)
log_interval = 20

scheduler = CosineAnnealingScheduler(optimizer, 'lr', 1e-4, 1e-6, len(dataset))
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
timer.attach(trainer,
             start=Events.EPOCH_STARTED,
             resume=Events.ITERATION_STARTED,
             pause=Events.ITERATION_COMPLETED,
             step=Events.ITERATION_COMPLETED)    
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    iter = (engine.state.iteration - 1) % len(dataset) + 1
    if iter % log_interval == 0:
        print('loss', engine.state.output, 'average time', timer.value(), 'iter num', iter)
        
trainer.run(dataset, max_epochs = 100)

 Please refer to the documentation for more details.
  


loss 231.91018676757812 average time 0.6379526790499881 iter num 20
loss 410.88385009765625 average time 0.3209585229249797 iter num 40
loss 325.4310607910156 average time 0.21532507674998744 iter num 60
loss 104.61619567871094 average time 0.16247739054998647 iter num 80
loss 132.3387908935547 average time 0.13080099881999785 iter num 100
loss 66.06425476074219 average time 0.09281976219999706 iter num 20
loss 74.7149887084961 average time 0.04830877532498903 iter num 40
loss 21.255390167236328 average time 0.03353121276665737 iter num 60
loss 57.7420539855957 average time 0.02618046737499924 iter num 80
loss 28.545286178588867 average time 0.021730287400002907 iter num 100
loss 6.927199363708496 average time 0.09218964010002537 iter num 20
loss 49.87449645996094 average time 0.04803577580001388 iter num 40
loss 20.671751022338867 average time 0.03331680518334679 iter num 60
loss 5.022592544555664 average time 0.025946390587520796 iter num 80
loss 9.819489479064941 average time 0.0215

ERROR:ignite.engine.engine.Engine:Engine run is terminating due to exception: 


KeyboardInterrupt: ignored

**Save Model**

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
model_save_name = 'jax_barrier_test_1.pth'
path = F"/content/drive/MyDrive/AFP Project/PUI/{model_save_name}" 
torch.save(model.state_dict(), path)

**Load Model**

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
model_save_name = 'Sobolev_test_2.pth'
path = F"/content/drive/MyDrive/AFP Project/PUI/{model_save_name}" 
state_dict = torch.load(path)
print(state_dict.keys())

In [None]:
# need to run 'Writing cupy_dataset.py' and 'Writing model.py' above before this
from model import Net
model = Net().cuda()

model.load_state_dict(state_dict)
print(model)

**Continue to train model**

In [None]:
# If memory is not enough, try changing parameters and restarting session
# loss will converge

from ignite.engine import Engine, Events
from ignite.handlers import Timer
from torch.nn import MSELoss
from torch.optim import Adam
from ignite.contrib.handlers.param_scheduler import CosineAnnealingScheduler
from ignite.handlers import ModelCheckpoint
from model import Net
from cupy_dataset import OptionDataSet
import numpy as np
timer = Timer(average=True)
#model = Net().cuda()
loss_fn = MSELoss()
optimizer = Adam(model.parameters(), lr=1e-3)
#dataset = OptionDataSet(max_len = 100, number_path = 1024, batch = 32, seed = np.random.randint(10000), stocks = 3) # must have random seed
dataset = OptionDataSet(max_len = 100, number_path = 10000, batch = 8, seed = np.random.randint(10000), stocks = 3) # must have random seed


def train_update(engine, batch):
    model.train()
    optimizer.zero_grad()
    x = batch[0]
    y = batch[1]
    #print(y)
    y_pred = model(x)
    #print(y_pred)
    loss = loss_fn(y_pred[:,:], y[:,:]) # compute MSE between the 2 arrays
    #print(loss)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_update)
log_interval = 20

scheduler = CosineAnnealingScheduler(optimizer, 'lr', 1e-4, 1e-6, len(dataset))
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
timer.attach(trainer,
             start=Events.EPOCH_STARTED,
             resume=Events.ITERATION_STARTED,
             pause=Events.ITERATION_COMPLETED,
             step=Events.ITERATION_COMPLETED)    
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    iter = (engine.state.iteration - 1) % len(dataset) + 1
    if iter % log_interval == 0:
        print('loss', engine.state.output, 'average time', timer.value(), 'iter num', iter)
        
trainer.run(dataset, max_epochs = 10)

model_save_name = 'jax_barrier_test_2.pth'
path = F"/content/drive/MyDrive/AFP Project/PUI/{model_save_name}" 
torch.save(model.state_dict(), path)

#Results

In [None]:
import torch
inputs = torch.tensor([[1, 70, 110.0, 100.0, 0.25, 0., 0.]*3]).cuda()
model(inputs.float())