In [1]:
import jax
import jax.numpy as jnp
# Project-local helper; from its outputs below it renders a PV (and optional gradient) as a dict.
from dal_utils import print_result
# Pin JAX to the CPU backend even when a GPU is visible to the process.
jax.config.update('jax_platform_name', 'cpu')
# JAX defaults to float32; enable 64-bit floats so the JAX prices and greeks are double precision.
jax.config.update("jax_enable_x64", True)

In [2]:
# Sanity check: with jax_platform_name set to 'cpu' in the first cell, only CPU devices should be listed.
jax.devices()

W1211 02:25:04.170344   28608 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1211 02:25:04.172702   28513 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.


[CpuDevice(id=0)]

In [3]:
import numpy as np
# NOTE(review): randn appears unused in the cells shown — remove if nothing else needs it.
from numpy.random import randn

# Black–Scholes inputs shared by the DAL and JAX sections:
# risk-free rate, continuous dividend yield, volatility, spot price, strike.
rate, div, vol, spot, strike = 0.05, 0.03, 0.15, 100.0, 120.0

def create_model():
    """Return the model parameters as a float64 array ``[rate, div, vol, spot, strike]``.

    The order is the one ``compute_call_price_jax`` unpacks, so gradients taken
    with respect to this vector come back in the same order.
    The module-level parameters are read at call time, so re-running this after
    editing them picks up the new values.
    """
    return np.array([rate, div, vol, spot, strike])

# Number of Monte Carlo paths; powers of two are the usual choice for Sobol sequences.
n_paths = 2 ** 20
# Option maturity in years.
T = 3.0

In [4]:
# Parameter vector [rate, div, vol, spot, strike] fed to the JAX pricers below.
default_model = create_model()

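## 0. DAL
------------------

Price the European call defined above (strike `strike`, maturity `T` years) with DAL's scripted-product Monte Carlo on Sobol paths: first the PV alone, then the PV together with its AAD sensitivities.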

In [5]:
# Install/upgrade DAL if needed (%pip targets this kernel's environment, unlike !pip).
# The outputs in this notebook were produced with dal.__version__ == '2025.12.7'.
# %pip install dal-python -U

# Wildcard import: presumably supplies Date_, EvaluationDate_Set, Product_New,
# BSModelData_New and MonteCarlo_Value used below (they are not defined elsewhere here).
# Running this cell prints DAL's start-up log (thread count, AAD framework, global tape init).
from dal import *
import dal
dal.__version__

starting DAL with: 32 threads.
use AAD framework: AADET
starting initialization global data ...
stating initialization global tape ...
finished initialization all the global information.


'2025.12.7'

In [6]:
# Valuation date; presumably DAL prices everything below relative to it.
today = Date_(2022, 9, 15)
EvaluationDate_Set(today)

# Maturity: T years counted as 365-day years, truncated to whole days by int()
# (1095 days for T = 3.0).
maturity = today.AddDays(int(T * 365))

In [7]:
%%time
event_dates = ["STRIKE", maturity]
events = [f"{strike}", f"call pays MAX(spot() - STRIKE, 0.0)"]

product = Product_New(event_dates, events)
model = BSModelData_New(spot, vol, rate, div)

# only price
res = MonteCarlo_Value(product, model, n_paths, "sobol", False, False)
dict(res)

CPU times: user 391 ms, sys: 5.07 ms, total: 396 ms
Wall time: 18.3 ms


{'PV': 5.2017688332108225}

In [8]:
%%time

# price with derivatives: same call as the previous cell with the last flag set to True,
# which (per the output) adds a d_<name> sensitivity for every model input and for STRIKE.
res = MonteCarlo_Value(product, model, n_paths, "sobol", False, True)
dict(res)

CPU times: user 246 ms, sys: 14.2 ms, total: 260 ms
Wall time: 29.9 ms


{'PV': 5.201768833617432,
 'd_STRIKE': -0.23584480086364598,
 'd_div': -100.50943481175742,
 'd_rate': 84.9041283109051,
 'd_spot': 0.33503144937252505,
 'd_vol': 59.585032759891895}

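## 1. DAL.JAX
----------------

The same call priced with a hand-written JAX Monte Carlo that uses pseudo-random normal draws rather than Sobol points. It is run both eagerly and under `jax.jit`, with sensitivities from `jax.value_and_grad`.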

In [9]:
def compute_call_price_jax(model, T=T, M=n_paths, key=jax.random.PRNGKey(1)):
    """
    Estimate the price of a European call under Black–Scholes using Monte Carlo.

    The terminal log-spot is sampled exactly,
        log S_T = log S_0 + (r - q - sigma^2 / 2) T + sigma sqrt(T) Z,   Z ~ N(0, 1),
    and the discounted mean payoff exp(-r T) * mean(max(S_T - K, 0)) is returned.

    Args:
        model: length-5 vector ``[rate, div, vol, spot, strike]`` (as built by ``create_model``).
        T: maturity in years. Defaults to the notebook-level ``T`` (captured when this
            cell runs), so the JAX and DAL sections price the same contract.
        M: number of simulated paths. It fixes an array shape, so it must be a concrete
            Python int (mark it static when jitting).
        key: JAX PRNG key for the pseudo-random normal draws. DAL above uses Sobol
            points instead, which is one reason the two estimates differ slightly.

    Returns:
        Scalar JAX array with the Monte Carlo price; differentiable w.r.t. ``model``.
    """
    r, q, sigma, s0, k = model
    # Draw with shape (1, M) as before so the sampled values are unchanged; take the single row.
    z = jax.random.normal(key, (1, M))[0]
    # Scalar log(S0) broadcasts against the M draws; no need to materialize a full array first.
    log_st = jnp.log(s0) + (r - q - 0.5 * sigma * sigma) * T + sigma * jnp.sqrt(T) * z
    expectation = jnp.mean(jnp.maximum(jnp.exp(log_st) - k, 0))
    return jnp.exp(-r * T) * expectation

# M (number of paths) fixes array shapes, so it must be static under jit; otherwise
# passing M explicitly to the jitted function raises an error while tracing.
compute_call_price_jax_jit = jax.jit(compute_call_price_jax, static_argnames=("M",))

# Warm-up: run each variant once so the timed cells below exclude tracing/compilation.
_ = compute_call_price_jax(default_model).block_until_ready()
_ = compute_call_price_jax_jit(default_model).block_until_ready()

In [10]:
%%time

# only price without jit
res = compute_call_price_jax(default_model).block_until_ready()
print_result(res)

CPU times: user 60.1 ms, sys: 14.7 ms, total: 74.8 ms
Wall time: 14.5 ms


{'PV': 5.209482799644583}

In [11]:
%%time

# only price with jit
compute_call_price_jax_jit(default_model).block_until_ready()

CPU times: user 35.5 ms, sys: 23.2 ms, total: 58.7 ms
Wall time: 7.23 ms


Array(5.2094828, dtype=float64)

In [12]:
# PV and gradient w.r.t. the first argument, the model vector [rate, div, vol, spot, strike].
compute_call_value_and_grad = jax.value_and_grad(compute_call_price_jax)
# Apply jit to the *whole* value-and-gradient computation. value_and_grad(jit(f)) only
# compiles the forward function and leaves the differentiation transform outside jit,
# which would not be a fair "with jit" benchmark. M is static because it fixes shapes.
compute_call_value_and_grad_jit = jax.jit(
    jax.value_and_grad(compute_call_price_jax), static_argnames=("M",)
)

# Warm-up both variants and wait for PV *and* gradient, so no pending async work
# spills into the timed cells below.
for _fn in (compute_call_value_and_grad, compute_call_value_and_grad_jit):
    for _out in _fn(default_model):
        _out.block_until_ready()

In [13]:
%%time

# price and derivatives without jit
result = compute_call_value_and_grad(default_model)
result[0].block_until_ready()
result[1].block_until_ready()
print_result(result)

CPU times: user 130 ms, sys: 5.96 ms, total: 136 ms
Wall time: 47.5 ms


{'PV': 5.209482799644583,
 'd_rate': 84.85389320333334,
 'd_div': -100.4823416022671,
 'd_vol': 59.64136171711852,
 'd_spot': 0.33494113867422365,
 'd_STRIKE': -0.23570525889814817}

In [14]:
%%time

# price and derivatives with jit
result = compute_call_value_and_grad_jit(default_model)
result[0].block_until_ready()
result[1].block_until_ready()
print_result(result)

CPU times: user 74.4 ms, sys: 34 ms, total: 108 ms
Wall time: 16.7 ms


{'PV': 5.209482799644583,
 'd_rate': 84.85389320333334,
 'd_div': -100.4823416022671,
 'd_vol': 59.64136171711852,
 'd_spot': 0.33494113867422365,
 'd_STRIKE': -0.23570525889814817}