In [5]:
import numpy as np
import numpy as np
from numpy.random import randn

μ, d, σ, S, K = 0.05, 0.03, 0.15, 100.0, 120.0

def create_model():
    return np.array([μ, d, σ, S, K])

default_M = 2 ** 25
T = 3.0

In [6]:
default_model = create_model()

## python
-----------

In [7]:
from math import log, sqrt, exp

def compute_call_price_python(model, T = 3.0, M=default_M):

    μ, d, σ, S, K = model

    current_sum = 0.0
    Z = np.random.randn(M)
    for m in range(M):
        s = log(S) + (μ - d - 0.5 * σ * σ) * T +  σ * sqrt(T) * Z[m]
        current_sum += max(np.exp(s) - K, 0)
    return exp(-μ * T) * current_sum / M

In [8]:
%%time 
compute_call_price_python(default_model)

CPU times: total: 1min 7s
Wall time: 1min 7s


5.202395417058214

## numpy
----------

In [9]:
def compute_call_price_numpy(model, T = 3.0, M=default_M):
    """
    Generate M observations of S_n and average to estimate
    the option price.
    """

    # Set up
    μ, d, σ, S, K = model
    s = np.full(M, np.log(S))
    Z = np.random.randn(M)
    s = s + (μ - d - 0.5 * σ * σ) * T +  σ * np.sqrt(T) * Z
    expectation = np.mean(np.maximum(np.exp(s) - K, 0))
    
    return np.exp(-μ * T) * expectation

In [11]:
%%time 
compute_call_price_numpy(default_model)

CPU times: total: 1.5 s
Wall time: 1.5 s


5.201171659167529

## Numba
------------

In [12]:
from numba import njit
from numba import prange

In [13]:
compute_call_price_numba = njit()(compute_call_price_numpy)

In [15]:
%%time 
compute_call_price_numba(default_model)

CPU times: total: 1.02 s
Wall time: 1.02 s


5.201710288210103

In [16]:
from numba import prange

In [19]:
@njit(parallel=True)
def compute_call_price_numba_parallel(model, T = 3.0, M=default_M):

    μ, d, σ, S, K = model

    current_sum = 0.0
    # For each sample path
    for m in prange(M):
        s = np.log(S) + (μ - d - 0.5 * σ * σ) * T +  σ * np.sqrt(T) * randn()
        current_sum += np.maximum(np.exp(s) - K, 0)
    return np.exp(-μ * T) * current_sum / M

In [21]:
%%time
compute_call_price_numba_parallel(default_model)

CPU times: total: 1.77 s
Wall time: 258 ms


5.1984890614983845

## JAX
----------------

In [22]:
import jax
import jax.numpy as jnp

In [23]:
def compute_call_price_jax(model, T = 3.0, M=default_M,
                           key=jax.random.PRNGKey(1)):
    """
    Estimate the price of the call option using Monte Carlo.
    """
    # Set up
    μ, d, σ, S, K = model
    s = jnp.full(M, jnp.log(S))
    Z = jax.random.normal(key, (1, M))
    s = s + (μ - d - 0.5 * σ * σ) * T +  σ * jnp.sqrt(T) * Z[0, :]
    expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))
    return jnp.exp(-μ * T) * expectation

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [26]:
%%time 
compute_call_price_jax(default_model).block_until_ready()

CPU times: total: 1.59 s
Wall time: 522 ms


Array(5.200003, dtype=float32)

In [27]:
compute_call_price_jax_jit = jax.jit(compute_call_price_jax)

In [29]:
%%time 
compute_call_price_jax_jit(default_model).block_until_ready()

CPU times: total: 1.27 s
Wall time: 259 ms


Array(5.200003, dtype=float32)

In [39]:
compute_call_grad = jax.grad(compute_call_price_jax)
compute_call_grad_jit = jax.jit(jax.grad(compute_call_price_jax))

In [40]:
%%time 
compute_call_grad(default_model).block_until_ready()

CPU times: total: 500 ms
Wall time: 491 ms


Array([ 83.81817   , -98.55397   ,  56.65969   ,   0.32851323,
        -0.23282823], dtype=float32)

In [41]:
%%time 
compute_call_grad_jit(default_model).block_until_ready()

CPU times: total: 234 ms
Wall time: 235 ms


Array([ 83.81816   , -98.55397   ,  56.65969   ,   0.32851323,
        -0.23282823], dtype=float32)