In [1]:
import jax
import jax.numpy as jnp
from dal_utils import print_result
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)

In [2]:
jax.devices()

W1214 16:03:49.344173    5644 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1214 16:03:49.346749    5510 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.


[CpuDevice(id=0)]

In [3]:
import numpy as np

rate, div, vol, spot, strike, barrier = 0.05, 0.03, 0.15, 100.0, 120.0, 150

def create_model():
    return np.array([rate, div, vol, spot, strike, barrier])

n_paths = 2 ** 20
T = 3.0

In [4]:
default_model = create_model()

## 0. DAL
------------------

In [5]:
# !pip install dal-python -U

from dal import *
import dal
dal.__version__

starting DAL with: 32 threads.
use AAD framework: AADET
starting initialization global data ...
stating initialization global tape ...
finished initialization all the global information.


'2025.12.7'

In [6]:
today = Date_(2022, 9, 15)
EvaluationDate_Set(today)

maturity = today.AddDays(int(365 * T))
maturity
freq = "1M"

In [7]:
%%time

event_dates = ["STRIKE", "BARRIER", today]
events = [f"{strike:.2f}", f"{barrier:.2f}", "alive = 1"]
event_dates.append(f"START: {today} END: {maturity} FREQ: {freq}")
events.append("if spot() >= BARRIER:0.1 then alive = 0 end")
event_dates.append(maturity)
events.append(f"if spot() >= BARRIER:0.1 then alive = 0 end\ncall pays alive * MAX(spot() - STRIKE, 0.0)")

CPU times: user 52 μs, sys: 11 μs, total: 63 μs
Wall time: 67.9 μs


In [8]:
%%time

product = Product_New(event_dates, events)
model = BSModelData_New(spot, vol, rate, div)

# only price
res = MonteCarlo_Value(product, model, n_paths, "sobol", False, False)
dict(res)

CPU times: user 2.76 s, sys: 22.5 ms, total: 2.78 s
Wall time: 108 ms


{'PV': 1.5057628229021944}

In [9]:
%%time

# price with derivatives
res = MonteCarlo_Value(product, model, n_paths, "sobol", False, True)
dict(res)

CPU times: user 7.53 s, sys: 77.7 ms, total: 7.61 s
Wall time: 313 ms


{'PV': 1.5052428979007992,
 'd_BARRIER': 0.08925165480978217,
 'd_STRIKE': -0.14047464486290206,
 'd_div': -19.381229382029755,
 'd_rate': 14.865500688327364,
 'd_spot': 0.049749022166538585,
 'd_vol': -7.227015534889115}

## 1. DAL.JAX
----------------

In [10]:
if freq == "1M":
    dt, NT = 1.0 / 12, 3 * 12
elif freq == "1W":
    dt, NT = 1.0 / 51, 3 * 51

In [11]:
@jax.jit
def update_state(Z, s, alive, not_alive, M, μ, d, σ, delta_t, log_b):
    s = s + (μ - d - 0.5 * σ * σ) * delta_t +  σ * jnp.sqrt(delta_t) * Z[0, :]
    alive = jnp.where(s >= log_b, not_alive, alive)
    return alive, s


def compute_barrier_price_jax(model, delta_t=dt, NT = NT, M=n_paths, key=jax.random.PRNGKey(1)):
    """
    Estimate the price of the up barrieroption using Monte Carlo.
    """
    # Set up
    μ, d, σ, S, K, b = model
    log_b = jnp.log(b)
    s = jnp.full(M, jnp.log(S))
    alive = jnp.full(M, 1.0)
    not_alive = jnp.full(M, 0.0)
    subkey = key
    for _ in range(NT):
        key, subkey = jax.random.split(subkey)
        Z = jax.random.normal(key, (1, M))
        alive, s = update_state(Z, s, alive, not_alive, M, μ, d, σ, delta_t, log_b)
    expectation = jnp.mean(alive * jnp.maximum(jnp.exp(s) - K, 0))
    return jnp.exp(-μ * delta_t * NT) * expectation

_ = compute_barrier_price_jax(default_model).block_until_ready()

In [12]:
%%time

# only price without jit
res = compute_barrier_price_jax(default_model).block_until_ready()
print_result(res)

CPU times: user 1.51 s, sys: 267 ms, total: 1.78 s
Wall time: 218 ms


{'PV': 1.4977221797797577}

In [13]:
compute_barrier_value_and_grad = jax.value_and_grad(compute_barrier_price_jax)
_ = compute_barrier_value_and_grad(default_model)[0].block_until_ready()

In [14]:
%%time

# price and derivatives without jit
# TODO: the output numbers of greeks are not correct
result = compute_barrier_value_and_grad(default_model, 1.0/ 12, 36, n_paths)
result[0].block_until_ready()
result[1].block_until_ready()
print_result(result, names=("rate", "div", "vol", "spot", "STRIKE", "barrier"))

CPU times: user 1.6 s, sys: 707 ms, total: 2.31 s
Wall time: 336 ms


{'PV': 1.4977221797797577,
 'd_rate': 50.275886659128524,
 'd_div': -54.7690531984678,
 'd_vol': 21.403390906231706,
 'd_spot': 0.18256351066155926,
 'd_STRIKE': -0.1396552407198014,
 'd_barrier': 0.0}