<a href="https://colab.research.google.com/github/shatakshidata/optionpricingjax/blob/main/LSMC_with_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Due date: October 31, 2024

# Description
In this problem we apply the least-squares Monte Carlo (LSMC) method to price American put options. Specifically, we replicate the result in the first row, sixth column of Table 1 in [Longstaff and Schwartz 2001](https://www.anderson.ucla.edu/documents/areas/fac/finance/least_squares.pdf).




# Part 1
The code below uses JAX to simulate the evolution of a stock price that follows geometric Brownian motion.


Solution

In [None]:
import jax
import jax.numpy as jnp
from jax import jit, grad
import optax
from jax import random

# Model parameters
σ, r, K = 0.04, 0.01, 35  # volatility, drift rate, strike (not used by simulate)

# Discretisation choices
dt, m = 0.01, 100  # Euler step size, number of time steps

def simulate():
  """Simulate 20000 geometric-Brownian-motion paths with m Euler steps.

  Every path starts at 1.0 and is advanced with
  dS = r S dt + σ S dZ, where dZ = sqrt(dt) * N(0, 1).

  Returns:
    Array of shape (m, 20000); row t holds the prices after step t + 1
    (the initial price 1.0 is not included).
  """
  # JAX has no global RNG state (there is no `jnp.random`), so use an
  # explicit PRNG key and split it to get fresh noise for every step.
  key = random.PRNGKey(0)

  def step(S, subkey):
    dZ = random.normal(subkey, shape=S.shape) * jnp.sqrt(dt)
    dS = r * S * dt + σ * S * dZ
    return S + dS

  S = jnp.ones(20000)
  S_list = []
  for subkey in random.split(key, m):
    S = step(S, subkey)
    S_list.append(S)

  return jnp.stack(S_list)

# Part 2
A JIT-compiled version of the Part 1 simulation code.


Solution

In [None]:
import jax
import jax.numpy as jnp
from jax import jit, grad
import optax
from jax import random

# Data (same values as Part 1)
σ, r, K = 0.04, 0.01, 35  # volatility, drift rate, strike

# Design choice: Euler step size and number of steps
dt, m = 0.01, 100

@jax.jit
def simulate():
  """JIT-compiled simulation of 20000 geometric-Brownian-motion paths.

  Every path starts at 1.0 and takes m Euler steps of
  dS = r S dt + σ S dZ, with dZ = sqrt(dt) * N(0, 1).

  Returns:
    Array of shape (m, 20000); row t holds the prices after step t + 1.
  """
  key = random.PRNGKey(0)

  def body_fun(S, subkey):
    # lax.scan expects (carry, per_step_output). Emitting S as the output
    # makes scan stack the whole trajectory for us.
    dZ = random.normal(subkey, shape=S.shape) * jnp.sqrt(dt)
    S = S + r * S * dt + σ * S * dZ
    return S, S

  S0 = jnp.ones(20000)
  # One fresh subkey per step; scan iterates over the leading axis.
  _, S_array = jax.lax.scan(body_fun, S0, random.split(key, m))
  return S_array


# Part 3
The code below computes the price of an American put option using least-squares Monte Carlo (LSMC) with the JAX library.

Solution

In [None]:
import jax.numpy as jnp
from jax import random, jit

# Option contract and market data
Spot = 36         # initial stock price
K = 40            # strike price
T = 1             # maturity
σ = 0.2           # stock volatility
r = 0.06          # risk-free rate

# Simulation and regression settings
n = 100_000       # number of simulated paths
m = 50            # number of exercise dates
order = 12        # Chebyshev polynomial order
Δt = T / m        # interval between two consecutive exercise dates


# Construct Chebyshev polynomial features of order up to K using the
# three-term recurrence T_n(x) = 2x*T_{n-1}(x) - T_{n-2}(x).


def chebyshev_basis(x, K):
    """Return the Chebyshev features T_0(x), ..., T_K(x) as columns.

    Built with the three-term recurrence
    T_n(x) = 2 x T_{n-1}(x) - T_{n-2}(x), starting from T_0 = 1, T_1 = x.

    Args:
        x: 1-D array of evaluation points (in this notebook, the output of
            `scale`, i.e. values in [-1, 1]).
        K: highest order to include; a non-negative Python int.

    Returns:
        Array of shape (len(x), K + 1) whose column j is T_j(x).

    Raises:
        ValueError: if K is negative.
    """
    if K < 0:
        raise ValueError(f"order K must be non-negative, got {K}")
    B = [jnp.ones(len(x)), x]
    for _ in range(2, K + 1):
        B.append(2 * x * B[-1] - B[-2])
    # The seed list always holds T_0 and T_1; slice so that K == 0 returns
    # only the constant column instead of two columns.
    return jnp.column_stack(B[:K + 1])


# Affinely rescale x so that its minimum maps to -1 and its maximum to 1.
def scale(x):
    """Affinely map x so that min(x) -> -1 and max(x) -> +1.

    Computes (2x - (xmax + xmin)) / (xmax - xmin), which equals the original
    a*x + b form with a = 2 / (xmax - xmin) and b = 1 - a * xmax.

    If every entry of x is equal, the range is zero. Instead of dividing by
    zero, which would produce NaN/inf and poison a downstream regression, all
    entries map to 0, the centre of [-1, 1].
    """
    xmin = x.min()
    xmax = x.max()
    span = xmax - xmin
    # jnp.where rather than a Python `if`, so this stays traceable under jit.
    safe_span = jnp.where(span > 0, span, 1.0)
    return (2 * x - (xmax + xmin)) / safe_span


# simulates one step of the stock price evolution
def step(S, key):
    """Return S advanced by one Euler step of length Δt.

    Applies S_next = S + r S Δt + σ S ΔB, where
    ΔB = sqrt(Δt) * N(0, 1) is drawn from ``key`` (one draw per element of S).
    """
    increments = jnp.sqrt(Δt) * random.normal(key, shape=S.shape)
    drift = r * S * Δt
    diffusion = σ * S * increments
    return S + drift + diffusion


def payoff_put(S, strike=None):
    """Return the intrinsic value of a put, max(strike - S, 0), elementwise.

    Args:
        S: stock price(s); scalar or array.
        strike: exercise price. Defaults to the module-level strike ``K``,
            so existing one-argument calls behave exactly as before.
    """
    if strike is None:
        strike = K
    return jnp.maximum(strike - S, 0.)


# LSMC algorithm
@jit
def compute_price():
    """Price the American put with Longstaff-Schwartz least-squares Monte Carlo.

    Simulates ``n`` stock-price paths over ``m`` exercise dates, then steps
    backwards in time. At each date, the discounted future cash flows are
    regressed on Chebyshev features of the scaled stock price to estimate the
    value of waiting. A path is exercised when its immediate payoff is
    positive and at least that estimate.

    Returns:
        Scalar price estimate today: the mean of the discounted path cash
        flows.

    NOTE(review): Longstaff & Schwartz (2001) fit the regression on
    in-the-money paths only; this implementation fits it on all paths.
    """
    S = _simulate_paths(random.PRNGKey(0))
    discount = jnp.exp(-r * Δt)

    # At maturity the holder takes the payoff; discount it back one period.
    discounted_future_cashflows = payoff_put(S[-1]) * discount

    # Backward induction over exercise dates m-1, ..., 1. S[0] is the common
    # spot price, so no regression is run there.
    for t in range(m - 1, 0, -1):
        value_if_exercise = payoff_put(S[t])
        value_if_wait = _fitted_continuation(S[t], discounted_future_cashflows)
        # An out-of-the-money path pays 0 if exercised. "Exercising" it just
        # because the fitted continuation came out negative would discard its
        # genuine (non-negative) future cash flows, so require a positive payoff.
        in_the_money = value_if_exercise > 0
        exercise = in_the_money & (value_if_exercise >= value_if_wait)
        discounted_future_cashflows = discount * jnp.where(
            exercise,
            value_if_exercise,
            discounted_future_cashflows)

    return discounted_future_cashflows.mean()


# Simulate n price paths from Spot; returns [S_0, S_1, ..., S_m] (m + 1 arrays).
def _simulate_paths(key):
    S = [Spot * jnp.ones(n)]
    for subkey in random.split(key, m):
        S.append(step(S[-1], subkey))
    return S


# Least-squares fit of Y on Chebyshev features of S_t; returns fitted values.
def _fitted_continuation(S_t, Y):
    X = chebyshev_basis(scale(S_t), order)
    Θ = jnp.linalg.solve(X.T @ X, X.T @ Y)
    return X @ Θ

# Run the LSMC pricer and print the Monte Carlo price estimate.
price = compute_price()
print(price)


4.4808407
