In [1]:
# Pricing inputs. μ is the drift and also the discount rate (see exp(-μT) in
# the pricers); d is subtracted from the drift (presumably a dividend yield);
# σ is the volatility; S the initial price; K the call strike.
μ = 0.05
d = 0.03
σ = 0.15
S = 100.0
K = 120.0

In [4]:
import numpy as np

def create_model(mu=None, dividend_yield=None, sigma=None, spot=None, strike=None):
    """
    Pack the pricing parameters into a float array ordered (μ, d, σ, S, K).

    Every argument is optional and falls back to the matching notebook-level
    global (μ, d, σ, S, K), so ``create_model()`` reproduces the parameter
    vector from the first cell, while explicit arguments build an
    alternative model without mutating those globals.

    Parameters
    ----------
    mu : float, optional
        Drift, also used by the pricers as the discount rate.
    dividend_yield : float, optional
        Rate subtracted from the drift (``μ - d``), presumably a dividend yield.
    sigma : float, optional
        Volatility.
    spot : float, optional
        Initial price S.
    strike : float, optional
        Call strike K.

    Returns
    -------
    numpy.ndarray
        Float array of length 5 in the order the pricers unpack it.
    """
    # Conditional expressions are lazy, so globals are only looked up for
    # arguments that were not supplied.
    return np.array(
        [
            μ if mu is None else mu,
            d if dividend_yield is None else dividend_yield,
            σ if sigma is None else sigma,
            S if spot is None else spot,
            K if strike is None else strike,
        ],
        dtype=float,
    )

In [5]:
# Baseline parameter vector (μ, d, σ, S, K) shared by every pricer below;
# np.asarray is a no-op for the ndarray create_model already returns.
default_model = np.asarray(create_model())

## Numba
------------

In [6]:
import numpy as np
from numpy.random import randn
from numba import njit

# Default number of Monte Carlo draws for every pricer: 2**25 = 33,554,432.
default_M = 1 << 25
# Horizon. The pricers declare their own default T = 3.0 and do not read
# this global unless it is passed explicitly.
T = 3.0

In [7]:
@njit
def compute_call_price(model, T = 3.0, M=default_M):
    """
    Estimate the price of a European call option by Monte Carlo (Numba).

    Draws M standard normals Z, forms the log terminal price

        s_T = log(S) + (μ - d - σ²/2) T + σ √T Z,

    and returns exp(-μ T) times the sample mean of max(exp(s_T) - K, 0).

    Parameters
    ----------
    model : array-like of length 5
        Unpacked as (μ, d, σ, S, K); μ is also used as the discount rate.
    T : float
        Time horizon.
    M : int
        Number of draws.

    Returns
    -------
    float
        Monte Carlo estimate of the discounted expected payoff.
    """
    μ, d, σ, S, K = model

    # Keep the deterministic part of log S_T as a scalar instead of filling
    # an M-element array with log(S): at the default M that array alone is
    # 256 MB. The addition order (log S + drift) + σ√T·Z is unchanged.
    log_drift = np.log(S) + (μ - d - 0.5 * σ * σ) * T
    Z = np.random.randn(M)
    s = log_drift + σ * np.sqrt(T) * Z

    expectation = np.mean(np.maximum(np.exp(s) - K, 0.0))
    return np.exp(-μ * T) * expectation

In [8]:
%%time 
compute_call_price(default_model)

CPU times: total: 2.23 s
Wall time: 2.45 s


5.201341236129124

In [9]:
from numba import prange

In [10]:
@njit(parallel=True)
def compute_call_price_parallel(model, T = 3.0, M=default_M):
    """
    Estimate the price of a European call option by Monte Carlo, drawing
    one terminal price per iteration of a Numba ``prange`` loop.

    Each draw uses s_T = log(S) + (μ - d - σ²/2) T + σ √T Z with Z ~ N(0, 1),
    and the result is exp(-μ T) times the average of max(exp(s_T) - K, 0).
    Unlike ``compute_call_price``, no M-element arrays are allocated.

    Parameters
    ----------
    model : array-like of length 5
        Unpacked as (μ, d, σ, S, K); μ is also used as the discount rate.
    T : float
        Time horizon.
    M : int
        Number of draws.

    Returns
    -------
    float
        Monte Carlo estimate of the discounted expected payoff.
    """
    μ, d, σ, S, K = model

    # Loop-invariant parts of log S_T, computed once outside the loop.
    log_drift = np.log(S) + (μ - d - 0.5 * σ * σ) * T
    vol = σ * np.sqrt(T)

    payoff_sum = 0.0
    # One sample path per iteration; Numba treats the `+=` on payoff_sum
    # as a reduction across the parallel iterations.
    for m in prange(M):
        s = log_drift + vol * randn()
        payoff_sum += max(np.exp(s) - K, 0.0)
    return np.exp(-μ * T) * payoff_sum / M

In [11]:
%%time
compute_call_price_parallel(default_model)

CPU times: total: 2.61 s
Wall time: 1.22 s


5.203885237767771

## JAX
----------------

In [12]:
import jax
import jax.numpy as jnp

In [13]:
def compute_call_price_jax(model, T = 3.0, M=default_M,
                           key=jax.random.PRNGKey(1)):
    """
    Estimate the price of the call option using Monte Carlo (JAX).

    Written with jax.numpy so it can be jit-compiled and differentiated
    with respect to ``model``:

        s_T   = log(S) + (μ - d - σ²/2) T + σ √T Z,   Z ~ N(0, 1), M draws
        price ≈ exp(-μ T) * mean(max(exp(s_T) - K, 0))

    Parameters
    ----------
    model : array-like of length 5
        Unpacked as (μ, d, σ, S, K); μ is also used as the discount rate.
    T : float
        Time horizon.
    M : int
        Number of draws. It determines array shapes, so under ``jax.jit``
        it must be a concrete (static) value.
    key : JAX PRNG key
        Source of randomness. The default is created once, when this
        function is defined, so every call relying on it reuses the same
        draws and returns the same estimate. Pass a fresh key for an
        independent estimate.

    Returns
    -------
    JAX scalar array
        Monte Carlo estimate of the discounted expected payoff.
    """
    μ, d, σ, S, K = model

    # Deterministic part of log S_T as a scalar; no M-element array of
    # log(S) is materialised.
    log_drift = jnp.log(S) + (μ - d - 0.5 * σ * σ) * T
    Z = jax.random.normal(key, (M,))
    s = log_drift + σ * jnp.sqrt(T) * Z

    expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))
    return jnp.exp(-μ * T) * expectation

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [14]:
%%time 
compute_call_price_jax(default_model).block_until_ready()

CPU times: total: 1.95 s
Wall time: 969 ms


Array(5.200003, dtype=float32)

In [15]:
# XLA-compiled pricer. `M` determines array shapes inside
# compute_call_price_jax, so it must be a compile-time constant: marking it
# static lets callers pass a custom M (each distinct value compiles once)
# instead of failing on a traced shape. Default calls behave as before.
compute_call_price_jax_jit = jax.jit(compute_call_price_jax, static_argnames=("M",))

In [16]:
%%time 
compute_call_price_jax_jit(default_model).block_until_ready()

CPU times: total: 1.25 s
Wall time: 498 ms


Array(5.200003, dtype=float32)

In [17]:
# Gradient of the Monte Carlo price with respect to the model vector
# (argument 0: μ, d, σ, S, K).
compute_call_grad = jax.grad(compute_call_price_jax)
# Compiled gradient. `M` sets array shapes, so keep it static under jit to
# allow callers to override it; default calls behave as before.
compute_call_grad_jit = jax.jit(compute_call_grad, static_argnames=("M",))

In [18]:
%%time 
compute_call_grad(default_model).block_until_ready()

CPU times: total: 2.73 s
Wall time: 1.44 s


Array([  84.89636   , -100.49637   ,   59.568905  ,    0.3349879 ,
         -0.23582327], dtype=float32)

In [19]:
%%time 
compute_call_grad_jit(default_model).block_until_ready()

CPU times: total: 1.89 s
Wall time: 838 ms


Array([  84.89636   , -100.49637   ,   59.568913  ,    0.3349879 ,
         -0.23582327], dtype=float32)