<a href="https://colab.research.google.com/github/xinyanz-erin/Applied-Finance-Project/blob/Lilian/Barrier_Call_jax_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title
import jax.numpy as jnp
from jax import random
from jax import jit
from jax import device_put

def test():
  paths=100000
  steps=252
  stocks=3
  key = random.PRNGKey(10)
  x = random.normal(key, (paths,steps,stocks))
  s0=random.randint(key,(stocks,1),50,100)*1.
  s=s0
  s = device_put(s)
  for i in range(steps):
    for j in range(stocks):
      s+=x[:,i,j]
  return(s0,s)

a,b=test()

print(a,np.mean(b,axis=1))

jtest=jit(test)

c,d=jtest()

print(c,np.mean(d,axis=1))

[[86.]
 [56.]
 [71.]] [85.99678  55.996777 70.99677 ]
[[86.]
 [56.]
 [71.]] [85.99678  55.996777 70.99677 ]


In [None]:
#@title
paths=10
steps=5
stocks=3
key = random.PRNGKey(10)
x = random.normal(key, (paths,steps,stocks))
s0=random.randint(key,(stocks,1),50,100)*1.
s=s0
s = device_put(s)
print(s)
print(x[:,1,1])
print(s+x[:,1,1])

[[86.]
 [56.]
 [71.]]
[-0.32907748 -0.3529078  -0.5357338  -2.402825   -1.3187171  -0.4300647
 -0.42437345  0.22680947  0.3185347  -0.9541187 ]
[[85.67092  85.647095 85.464264 83.597176 84.68128  85.56994  85.57563
  86.22681  86.318535 85.04588 ]
 [55.67092  55.64709  55.464268 53.597176 54.681282 55.569935 55.575626
  56.22681  56.318535 55.045883]
 [70.67092  70.647095 70.464264 68.597176 69.68128  70.56994  70.57563
  71.22681  71.318535 70.04588 ]]


In [None]:
#@title
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np

def simple_process(key, initial_values, numsteps, drift, cov):
    stocks_init = jnp.zeros((numsteps + 1,initial_values.shape[0]))
    stocks_init=jax.ops.index_update(stocks_init,  # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
                            jax.ops.index[0],
                            initial_values)
    noise = jax.random.multivariate_normal(key, drift, cov, (numsteps,))
    #return(noise)
    def time_step(t, val):
        dx =  drift+noise[t,:]
        val = jax.ops.index_update(val,
                            jax.ops.index[t],
                            val[t-1] + dx)
        return val
    return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)

# def fori_loop(lower, upper, body_fun, init_val): # upper is exclusive
#   val = init_val
#   for i in range(lower, upper):
#     val = body_fun(i, val)
#   return val

numsteps=5
key = random.PRNGKey(10)
drift=jnp.array([0.0]*3)
cov=np.random.random((3,3))
cov=np.matmul(cov,cov.T)
initial_values=jnp.array([100.]*3)
fast_simple = jax.jit(simple_process, static_argnums=2)
init_stocks=jnp.array([100.]*3)
fast_simple(key,init_stocks,numsteps,drift,cov)
# Batch OU sample via vmap
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None)) 
# An integer or None indicates which array axis to map over for all arguments (with None indicating not to map any axis)

%timeit fast_simple(key, init_stocks, 12, drift, cov)

numsamples=100000 # num of paths
keys = jax.random.split(key, numsamples)
batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None))
%timeit batch_simple(keys, init_stocks, 12, drift, cov)

batch_simple(keys, init_stocks, numsteps, drift, cov).shape

The slowest run took 27769.64 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 16.1 µs per loop
1 loop, best of 5: 177 ms per loop


(100000, 5, 3)

In [None]:
#@title
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, sigma, T):
    stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))
    stocks_init=jax.ops.index_update(stocks_init,   # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
                            jax.ops.index[0],       # initialization of stock prices
                            initial_stocks)
    noise = jax.random.multivariate_normal(key, drift, cov, (numsteps,))
    dt = jnp.array(T[0]/numsteps)
    def time_step(t, val):
        dx =  drift * dt * val[t-1,:] + sigma * val[t-1,:] * jnp.sqrt(dt) * noise[t,:]
        val = jax.ops.index_update(val,
                            jax.ops.index[t],
                            val[t-1] + dx)
        return val
    return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)

# def fori_loop(lower, upper, body_fun, init_val): # upper is exclusive
#   val = init_val
#   for i in range(lower, upper):
#     val = body_fun(i, val)
#   return val

np.random.seed(0)
key = random.PRNGKey(0)
initial_stocks=jnp.array([100.]*3)
numsteps=10
drift=jnp.array([0]*3)
# cov=np.random.random((3,3))
# cov=np.matmul(cov,cov.T)
corr = jnp.diag(jnp.array([1]*3))
sigma = jnp.array([0.3]*3)
cov = (jnp.diag(sigma)).dot(corr).dot(jnp.diag(sigma))
#sigma = jnp.array(np.random.random(3))
#r = jnp.array([np.random.random(1)]*3)
T = jnp.array([1.]*3)

fast_simple = jax.jit(Brownian_motion, static_argnums=2)
fast_simple(key, initial_stocks, numsteps, drift, cov, sigma, T)

# numsamples=100000 # num of paths
# keys = jax.random.split(key, numsamples)

# batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None, None))
# %timeit batch_simple(keys, initial_stocks, numsteps, drift, cov, sigma, T)
# batch_simple(keys, initial_stocks, numsteps, drift, cov, sigma, T).shape

DeviceArray([[104.03143 , 103.00068 , 102.59222 ],
             [102.768456, 105.89164 , 100.96432 ],
             [102.81328 ,  99.627396, 102.55864 ],
             [105.48844 , 101.25624 , 104.66836 ],
             [105.8519  , 100.323296, 109.50462 ],
             [106.59    ,  96.38026 , 107.60002 ],
             [107.01583 ,  98.68716 , 104.24871 ],
             [103.734055,  95.48625 , 102.4883  ],
             [103.274185,  94.44483 ,  99.23759 ],
             [102.81635 ,  93.41477 ,  96.08998 ]], dtype=float32)

In [None]:
#@title
key = random.PRNGKey(10)
initial_values=jnp.array([100.]*3)
numsteps=5
drift=jnp.array([0.0]*3)
cov=np.random.random((3,3))
cov=np.matmul(cov, cov.T)

print(key)
print(initial_values)
print(numsteps)
print(drift)
print(cov)

[ 0 10]
[100. 100. 100.]
5
[0. 0. 0.]
[[0.74040995 0.50528954 0.38358708]
 [0.50528954 0.69334581 0.33222558]
 [0.38358708 0.33222558 0.21827158]]


In [None]:
#@title
stocks_init = jnp.zeros((numsteps + 1, initial_values.shape[0]))
print(stocks_init)
stocks_init=jax.ops.index_update(stocks_init,
                        jax.ops.index[0],
                        initial_values)
print(stocks_init)
noise = jax.random.multivariate_normal(key, drift, cov, (numsteps,))
print(noise)

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
[[100. 100. 100.]
 [  0.   0.   0.]
 [  0.   0.   0.]
 [  0.   0.   0.]
 [  0.   0.   0.]
 [  0.   0.   0.]]
[[-0.05614188 -0.93355596 -0.05524409]
 [ 0.70478475 -0.20070948  0.15282556]
 [ 0.11612918 -0.04694405 -0.03823702]
 [ 0.42249295 -0.36928132 -0.06133213]
 [-0.24204297 -0.2061234  -0.13489908]]


In [1]:
!curl https://colab.chainer.org/install |sh -
import cupy

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1580  100  1580    0     0  11791      0 --:--:-- --:--:-- --:--:-- 11791
+ apt -y -q install cuda-libraries-dev-10-0
Reading package lists...
Building dependency tree...
Reading state information...
cuda-libraries-dev-10-0 is already the newest version (10.0.130-1).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 40 not upgraded.
+ pip install -q cupy-cuda100  chainer 
[K     |████████████████████████████████| 58.9 MB 35 kB/s 
[K     |████████████████████████████████| 1.0 MB 31.4 MB/s 
[?25h  Building wheel for chainer (setup.py) ... [?25l[?25hdone
+ set +ex
Installation succeeded!


# Pui

In [41]:
# import jax
# import jax.numpy as jnp
# from jax import random
# from jax import jit
# import numpy as np
# from torch.utils.dlpack import from_dlpack

# def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
#     stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))
#     stocks_init = jax.ops.index_update(stocks_init,   # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
#                             jax.ops.index[0],         # initialization of stock prices
#                             initial_stocks)
#     noise = jax.random.multivariate_normal(key, jnp.array([0]*initial_stocks.shape[0]), cov, (numsteps,))
#     dt = jnp.array(T[0]/numsteps)
#     def time_step(t, val):
#         dx =  drift * dt * val[t-1,:] + val[t-1,:] * jnp.sqrt(dt) * noise[t,:] # no need to multiply by sigma here because noise generated by cov not corr
#         val = jax.ops.index_update(val,
#                             jax.ops.index[t],
#                             val[t-1] + dx)
#         return val
#     return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)

# # def fori_loop(lower, upper, body_fun, init_val): # upper is exclusive
# #   val = init_val
# #   for i in range(lower, upper):
# #     val = body_fun(i, val)
# #   return val

# numstocks = 3

# np.random.seed(15)
# key = random.PRNGKey(10)
# initial_stocks = jnp.array(np.random.random(numstocks) * 200)
# numsteps = 5
# # drift=jnp.array([0.05] * numstocks) 
# # cov=np.random.random((3,3))
# # cov=np.matmul(cov,cov.T)
# corr = jnp.diag(jnp.array([1]*numstocks)) # assume no correlation between stocks here
# sigma = jnp.array(np.random.random(numstocks) * 0.4)
# cov = (jnp.diag(sigma)).dot(corr).dot(jnp.diag(sigma))
# r = jnp.repeat(jnp.array(np.random.random(1) * 0.1), numstocks)
# drift = r # To match BS, use drift = r
# T = jnp.array([1.] * numstocks)
# K = np.random.random(1) * 200
# #B = jnp.repeat(jnp.array(np.random.random(1) * 200*0.9), numstocks)
# B = jnp.repeat(jnp.array(K * 0.9), numstocks)

# fast_simple = jax.jit(Brownian_motion, static_argnums=2)
# fast_simple(key, initial_stocks, numsteps, drift, cov, T)

# numsamples = 10 # num of paths
# keys = jax.random.split(key, numsamples)

# batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))
# %timeit batch_simple(keys, initial_stocks, numsteps, drift, cov, T)
# out = batch_simple(keys, initial_stocks, numsteps, drift, cov, T)

# print(B)
# avg_across_stocks = np.mean(out, axis=2)
# print(np.mean(out, axis=2))
# print(np.any(avg_across_stocks < 67, axis=1)) # test using Barrier = 67
# print(1 - np.any(avg_across_stocks < 67, axis=1).astype(int))
# print((1 - np.any(avg_across_stocks < 67, axis=1).astype(int))* avg_across_stocks[:,numsteps])

# Payoff = (1 - np.any(avg_across_stocks < B[0], axis=1).astype(int))* avg_across_stocks[:,numsteps]
# Barrier_Call_price = np.mean(np.maximum(Payoff - K, 0) * jnp.exp(-r[0] * T[0]))

# paras = (T, jnp.repeat(jnp.array(K), numstocks), B, initial_stocks, sigma, drift, r)
# paras = np.column_stack(paras).reshape(1,-1)[0]

# print(from_dlpack(cupy.array(paras).toDlpack()))
# print(Barrier_Call_price)

In [44]:
################################# TEST ########################################
#%%writefile cupy_dataset.py
import cupy
import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np
from torch.utils.dlpack import from_dlpack

def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))   # initial_stocks.shape[0] <-> Stocks number
    stocks_init = jax.ops.index_update(stocks_init,   # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
                            jax.ops.index[0],         # initialization of stock prices
                            initial_stocks)
    noise = jax.random.multivariate_normal(key, jnp.array([0]*initial_stocks.shape[0]), cov, (numsteps,)) #shape<->[numsteps,nstocks]
    dt = jnp.array(T[0]/numsteps)
    def time_step(t, val):
        dx =  drift * dt * val[t-1,:] + val[t-1,:] * jnp.sqrt(dt) * noise[t,:] # no need to multiply by sigma here because noise generated by cov not corr
        val = jax.ops.index_update(val,
                            jax.ops.index[t],
                            val[t-1] + dx)
        return val
    return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)


class NumbaOptionDataSet(object):
    
    def __init__(self, max_len=10, number_path = 1000, batch=2, seed=15, stocks=3):  # 3 stocks
        self.num = 0
        self.max_length = max_len
        self.N_PATHS = number_path
        self.N_STEPS = 365
        self.N_BATCH = batch
        self.N_STOCKS = stocks
        self.T = 1
        self.seed = seed
        np.random.seed(seed)
        
    def __len__(self):
        return self.max_length
        
    def __iter__(self):
        self.num = 0
        return self

    def __next__(self):
        if self.num >= self.max_length:
            raise StopIteration
        
        Y = cupy.zeros(self.N_BATCH, dtype=cupy.float32)
        X = cupy.zeros((self.N_BATCH, self.N_STOCKS * 7), dtype = cupy.float32)

        for op in range(self.N_BATCH):
          
          key = random.PRNGKey(self.seed)
          initial_stocks = jnp.array(np.random.random(self.N_STOCKS) * 200)
          # cov=np.random.random((3,3))
          # cov=np.matmul(cov,cov.T)
          corr = jnp.diag(jnp.array([1]*self.N_STOCKS)) # assume no correlation between stocks here
          sigma = jnp.array(np.random.random(self.N_STOCKS) * 0.4)
          cov = (jnp.diag(sigma)).dot(corr).dot(jnp.diag(sigma))
          r = jnp.repeat(jnp.array(np.random.random(1) * 0.1), self.N_STOCKS)
          drift = r # To match BS, use drift = r
          T = jnp.array([self.T] * self.N_STOCKS)
          K = np.random.random(1) * 200
          B = jnp.repeat(jnp.array(K * 0.9), self.N_STOCKS)

          fast_simple = jax.jit(Brownian_motion, static_argnums=2)
          fast_simple(key, initial_stocks, self.N_STEPS, drift, cov, T)

          numsamples = self.N_PATHS # num of paths
          keys = jax.random.split(key, numsamples)

          batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))
          out = batch_simple(keys, initial_stocks, self.N_STEPS, drift, cov, T)

          avg_across_stocks = np.mean(out, axis=2) # becomes paths * steps
          Payoff = (1 - np.any(avg_across_stocks < B[0], axis=1).astype(int))* avg_across_stocks[:,self.N_STEPS]
          Barrier_Call_price = np.mean(np.maximum(Payoff - K, 0) * jnp.exp(-r[0] * T[0]))
          Y[op] = Barrier_Call_price

          paras = (T, jnp.repeat(jnp.array(K), self.N_STOCKS), B, initial_stocks, sigma, drift, r)
          paras = np.column_stack(paras).reshape(1,-1)[0]
          X[op,] = cupy.array(paras)

        self.num += 1
        return (from_dlpack(X.toDlpack()), from_dlpack(Y.toDlpack()))
    #self._net_grads = jax.grad(self._net_apply)


ds = NumbaOptionDataSet(1, number_path = 100000, batch = 2, seed = 15, stocks=3)
for i in ds:
    print(i)

(tensor([[1.0000e+00, 6.0895e+01, 5.4805e+01, 1.6976e+02, 1.4462e-01, 3.0592e-02,
         3.0592e-02, 1.0000e+00, 6.0895e+01, 5.4805e+01, 3.5779e+01, 1.1016e-01,
         3.0592e-02, 3.0592e-02, 1.0000e+00, 6.0895e+01, 5.4805e+01, 1.0873e+01,
         2.1200e-01, 3.0592e-02, 3.0592e-02],
        [1.0000e+00, 4.2110e+01, 3.7899e+01, 2.2348e+01, 1.0566e-01, 8.0708e-02,
         8.0708e-02, 1.0000e+00, 4.2110e+01, 3.7899e+01, 4.9980e+01, 2.8711e-01,
         8.0708e-02, 8.0708e-02, 1.0000e+00, 4.2110e+01, 3.7899e+01, 1.8353e+02,
         3.4629e-01, 8.0708e-02, 8.0708e-02]], device='cuda:0'), tensor([13.2313, 46.5335], device='cuda:0'))


# Lilian

In [55]:
################################# TEST ########################################
#%%writefile cupy_dataset.py
import cupy
import jax
import jax.numpy as jnp
from jax import random
from jax import jit
import numpy as np
from torch.utils.dlpack import from_dlpack

In [4]:
N_STOCKS = 3
N_BATCH = 2
N_STEPS = 365
N_PATHS = 1000
S0 = 200
K_val = 200
B_val = 0.9
r_val = 0.1
sigma_val = 0.4

In [99]:
def Brownian_motion(key, initial_stocks, numsteps, drift, cov, T):
    stocks_init = jnp.zeros((numsteps + 1, initial_stocks.shape[0]))   # initial_stocks.shape[0] <-> Stocks number
    stocks_init = jax.ops.index_update(stocks_init,   # jax.ops.index_update(x, idx, y) <-> Pure equivalent of x[idx] = y
                            jax.ops.index[0],         # initialization of stock prices
                            initial_stocks)
    noise = jax.random.multivariate_normal(key, jnp.array([0]*initial_stocks.shape[0]), cov, (numsteps,)) #shape<->[numsteps,nstocks]
    dt = jnp.array(T[0]/numsteps)
    def time_step(t, val):
        dx =  drift * dt * val[t-1,:] + val[t-1,:] * jnp.sqrt(dt) * noise[t,:] # no need to multiply by sigma here because noise generated by cov not corr
        val = jax.ops.index_update(val,
                            jax.ops.index[t],
                            val[t-1] + dx)
        return val
    return jax.lax.fori_loop(1, numsteps+1, time_step, stocks_init)[1:] # jax.lax.fori_loop(lower, upper, body_fun, init_val)

   
def OptionPrice(seed, K_val, B_val, S0, sigma_val, r_val, T_val = 1):
  price = []
  Y = jnp.zeros(N_BATCH)
  #Y = cupy.zeros(N_BATCH, dtype=cupy.float32)
  #X = cupy.zeros((N_BATCH, N_STOCKS * 7), dtype = cupy.float32)
  for op in range(N_BATCH):    
    key = random.PRNGKey(seed)
    initial_stocks = jnp.array(np.random.random(N_STOCKS) * S0)
          # cov=np.random.random((3,3))
          # cov=np.matmul(cov,cov.T)
    corr = jnp.diag(jnp.array([1]*N_STOCKS)) # assume no correlation between stocks here
    sigma = jnp.array(np.random.random(N_STOCKS) * sigma_val)
    cov = (jnp.diag(sigma)).dot(corr).dot(jnp.diag(sigma))
    r = jnp.repeat(jnp.array(np.random.random(1) * r_val), N_STOCKS)
    drift = r # To match BS, use drift = r
    T = jnp.array([T_val] * N_STOCKS)
    K = np.random.random(1) * K_val
    B = jnp.repeat(jnp.array(K * B_val), N_STOCKS)

    fast_simple = jax.jit(Brownian_motion, static_argnums=2)
    fast_simple(key, initial_stocks, N_STEPS, drift, cov, T)

    numsamples = N_PATHS # num of paths
    keys = jax.random.split(key, numsamples)

    batch_simple = jax.vmap(fast_simple, in_axes=(0, None, None, None, None, None))
    out = batch_simple(keys, initial_stocks, N_STEPS, drift, cov, T)

    avg_across_stocks = np.mean(out, axis=2) # becomes paths * steps
  
    Payoff = (1 - np.any(avg_across_stocks < B[0], axis=1).astype(int))* avg_across_stocks[:,N_STEPS]
    
    Barrier_Call_price = np.mean(np.maximum(Payoff - K, 0) * jnp.exp(-r[0] * T[0]))
    print('printout',type(Barrier_Call_price))
    Y = jax.ops.index_update(Y,jax.ops.index[op],Barrier_Call_price)
    #Y[op] = Barrier_Call_price
    
    paras = (T, jnp.repeat(jnp.array(K), N_STOCKS), B, initial_stocks, sigma, drift, r)
    
    #paras = np.column_stack(paras).reshape(1,-1)[0]
    #paras = jnp.array(paras)
    #X[op,] = cupy.array(paras)
  #price = ([from_dlpack(X.toDlpack()), from_dlpack(Y.toDlpack())])
  price = Y
  return price

In [93]:
Option = OptionPrice(15, K_val = 200, B_val = 0.9, S0 = 200, sigma_val = 0.4, r_val = 0.1, T_val = 1)

printout <class 'jaxlib.xla_extension.DeviceArray'>
printout <class 'jaxlib.xla_extension.DeviceArray'>


In [76]:
Option

DeviceArray([44.05111 , 18.307188], dtype=float32)

In [101]:

jax.grad(OptionPrice, argnums=1)(15, 200.0, 0.9, 200.0, 0.4, 0.1, 1.0)

TracerArrayConversionError: ignored

In [None]:
class NumbaOptionDataSet(object):
    
    def __init__(self, max_len=10, number_path = 1000, batch=2, seed=15, stocks=3):  # 3 stocks
        self.num = 0
        self.max_length = max_len
        self.N_PATHS = number_path
        self.N_STEPS = 365
        self.N_BATCH = batch
        self.N_STOCKS = stocks
        self.T = 1
        self.seed = seed
        np.random.seed(seed)
        
    def __len__(self):
        return self.max_length

    def OptionGrad(self):
      Class = NumbaOptionDataSet()
      # grad(J, argnums=2)(X, w, b, y)
      self.grad = jax.grad(Class.OptionPrice,argnums=0)(S0)
      return self.grad

    def __iter__(self):
        self.num = 0
        return self

    def __next__(self):
        if self.num >= self.max_length:
            raise StopIteration
        Class = NumbaOptionDataSet()
        val = Class.OptionPrice(200)
        grad = Class.OptionGrad()
        self.num += 1
        #return (val,'here is grad',grad)
        return ('here is grad',grad)

ds = NumbaOptionDataSet(3, number_path = 100000, batch = 2, seed = 15, stocks=3)
for i in ds:
    print(i)

# Jax Example 1

In [None]:
loss_grad = jax.grad(loss)

# Stochastic gradient descent learning rate
learning_rate = 1.
# All possible inputs
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# Initialize parameters randomly
params = initial_params()

for n in itertools.count():
    # Grab a single random input
    x = inputs[onp.random.choice(inputs.shape[0])]
    # Compute the target output
    y = onp.bitwise_xor(*x)
    # Get the gradient of the loss for this input/output pair
    grads = loss_grad(params, x, y)
    # Update parameters via gradient descent
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    # Every 100 iterations, check whether we've solved XOR
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Hello my name is John


# Jax Example 2

In [47]:
import jax.numpy as np
from jax import grad, jit

In [48]:
def J(X, w, b, y):
    """Cost function for a linear regression. A forward pass of our model.

    Args:
        X: a features matrix.
        w: weights (a column vector).
        b: a bias.
        y: a target vector.

    Returns:
        scalar: a cost of this solution.    
    """
    y_hat = X.dot(w) + b # Predict values.
    return ((y_hat - y)**2).mean() # Return cost.

In [49]:
# A features matrix.
X = np.array([
                 [4., 7.],
                 [1., 8.],
                 [-5., -6.],
                 [3., -1.],
                 [0., 9.]
             ])

# A target column vector.
y = np.array([
                 [37.],
                 [24.],
                 [-34.], 
                 [16.],
                 [21.]
             ])

learning_rate = 0.01

Train without JIT

In [53]:
w = np.zeros((2, 1))
b = 0.

print(grad(J, argnums=0)(X, w, b , y)) #argnums (Union[int, Sequence[int]]) – Optional, 
                    #integer or sequence of integers. Specifies which positional 
                    #argument(s) to differentiate with respect to (default 0).

# %timeit grad(J, argnums=2)(X, w, b, y)

# for i in range(100):
#     w -= learning_rate * grad(J, argnums=1)(X, w, b, y)
#     b -= learning_rate * grad(J, argnums=2)(X, w, b, y)
    
#     if i % 10 == 0:
#         print(J(X, w, b, y))

SyntaxError: ignored

Train with JIT

In [None]:


w = np.zeros((2, 1))
b = 0.

grad_X = jit(grad(J, argnums=1))
grad_b = jit(grad(J, argnums=2))

# Run once to trigger JIT compilation.
grad_X(X, w, b, y)
grad_b(X, w, b, y)


%timeit grad_X(X, w, b, y)

%timeit grad_b(X, w, b, y)

for i in range(100):
    w -= learning_rate * grad_X(X, w, b, y)
    b -= learning_rate * grad_b(X, w, b, y)
    
    if i % 10 == 0:
        print(J(X, w, b, y))