In [11]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import plotly.graph_objects as go
import scipy.stats as st

import warnings
warnings.filterwarnings('ignore', category=UserWarning)

In [2]:
# Seed every RNG the notebook draws from (the loss terms sample collocation
# points with torch.rand) so that training runs are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
# Number of random collocation points drawn per loss term (also the plot resolution).
N_points = 100

# Contract and market parameters
strike = 100                  # strike price K
S_max = strike * 3            # truncated upper bound of the asset-price domain
T, r, sigma = 1, 0.05, 0.2    # maturity, risk-free rate, volatility

In [3]:
# Loss components

def pde_dynamic(pinn, sigma, r, S_max, T, n_points=None):
    """Mean squared Black-Scholes PDE residual at random interior collocation points.

    Written in time-to-maturity tau, the equation enforced is
        V_tau = 0.5 * sigma^2 * S^2 * V_SS + r * S * V_S - r * V.

    Args:
        pinn: callable ``pinn(S, tau)``; assumed to map two (n, 1) tensors to an
            (n, 1) tensor of prices (true for ``PINN`` in this notebook).
        sigma: volatility.
        r: risk-free rate.
        S_max: asset prices are sampled uniformly from [0, S_max).
        T: times to maturity are sampled uniformly from [0, T).
        n_points: number of collocation points; defaults to the module-level
            ``N_points``.

    Returns:
        Scalar tensor: mean of the squared residual (differentiable w.r.t. the
        network parameters).
    """
    if n_points is None:
        n_points = N_points

    # Leaf tensors that require grad, so autograd can differentiate V w.r.t. S and tau.
    S = (S_max * torch.rand(n_points, 1)).requires_grad_(True)
    tau = (T * torch.rand(n_points, 1)).requires_grad_(True)

    V = pinn(S, tau)
    # create_graph=True keeps these derivatives differentiable, both for V_SS
    # and for backpropagating the residual into the network weights.
    V_tau = torch.autograd.grad(V.sum(), tau, create_graph=True)[0]
    V_S = torch.autograd.grad(V.sum(), S, create_graph=True)[0]
    V_SS = torch.autograd.grad(V_S.sum(), S, create_graph=True)[0]

    residual = V_tau - (0.5 * (sigma * S) ** 2 * V_SS + r * S * V_S - r * V)

    return torch.mean(residual ** 2)


def boundary_condition(pinn, strike, S_max, T, r, call_put = 'Call', n_points=None):
    """Mean squared error of the PINN against the Dirichlet price boundaries.

    For random tau in [0, T), the network is compared at S = 0 and at
    S = S_max (a truncation of S -> infinity) with:
        Call: V(0, tau) = 0,                 V(S_max, tau) = S_max - K * exp(-r * tau)
        Put:  V(0, tau) = K * exp(-r * tau), V(S_max, tau) = 0

    Args:
        pinn: callable ``pinn(S, tau)`` on (n, 1) tensors.
        strike: strike price K.
        S_max: upper edge of the asset-price domain.
        T: maturity; boundary times are sampled from [0, T).
        r: risk-free rate used for discounting.
        call_put: 'Call' or 'Put'.
        n_points: points per boundary; defaults to the module-level ``N_points``.

    Returns:
        Scalar tensor: MSE at S = 0 plus MSE at S = S_max.

    Raises:
        ValueError: if ``call_put`` is neither 'Call' nor 'Put'.
    """
    if n_points is None:
        n_points = N_points

    boundary_S0 = torch.zeros(n_points, 1)
    boundary_Smax = S_max * torch.ones(n_points, 1)
    # No requires_grad: only the network weights are trained, and the targets
    # below must not carry an autograd graph.
    boundary_tau = T * torch.rand(n_points, 1)
    discount = torch.exp(-r * boundary_tau)

    if call_put == 'Call':
        # boundary at S = 0
        boundary_V0 = torch.zeros(n_points, 1)
        # boundary at S -> infinity (truncated at S_max)
        boundary_V1 = S_max - strike * discount
    elif call_put == 'Put':
        # boundary at S = 0
        boundary_V0 = strike * discount
        # boundary at S -> infinity (truncated at S_max)
        boundary_V1 = torch.zeros(n_points, 1)
    else:
        raise ValueError("Unsupported kind of option")

    pred_boundary0 = pinn(boundary_S0, boundary_tau)
    pred_boundary1 = pinn(boundary_Smax, boundary_tau)
    loss_boundary = torch.mean((pred_boundary0 - boundary_V0) ** 2) + \
                    torch.mean((pred_boundary1 - boundary_V1) ** 2)

    return loss_boundary


def terminal_condition(pinn, strike, S_max, call_put, n_points=None):
    """Mean squared error of the PINN against the option payoff at tau = 0.

    Args:
        pinn: callable ``pinn(S, tau)`` on (n, 1) tensors.
        strike: strike price K.
        S_max: asset prices are sampled uniformly from [0, S_max).
        call_put: 'Call' (payoff max(S - K, 0)) or 'Put' (payoff max(K - S, 0)).
        n_points: number of sampled prices; defaults to the module-level ``N_points``.

    Returns:
        Scalar tensor: MSE between prediction and payoff.

    Raises:
        ValueError: if ``call_put`` is neither 'Call' nor 'Put'.
    """
    if n_points is None:
        n_points = N_points

    # No requires_grad: the payoff is a fixed target and must not carry a graph.
    terminal_S = S_max * torch.rand(n_points, 1)
    terminal_tau = torch.zeros(n_points, 1)

    if call_put == 'Call':
        terminal_V = torch.relu(terminal_S - strike)
    elif call_put == 'Put':
        terminal_V = torch.relu(strike - terminal_S)
    else:
        raise ValueError('Unsupported kind of option')

    pred_terminal = pinn(terminal_S, terminal_tau)
    loss_terminal = torch.mean((pred_terminal - terminal_V) ** 2)
    return loss_terminal

In [None]:
# PINN architecture

class PINN(nn.Module):
    """Fully connected network approximating the option price V(S, tau).

    Inputs are rescaled to S / S_max and tau / T before entering the network.
    The final Softplus keeps the predicted price non-negative.
    """

    def __init__(self, S_max, T, hidden_dim=20, n_hidden_layers=3):
        """
        Args:
            S_max: scale used to normalise the asset price input.
            T: scale used to normalise the time-to-maturity input.
            hidden_dim: width of each hidden layer (default 20).
            n_hidden_layers: number of Tanh hidden layers (default 3, >= 1).
        """
        # nn.Module must be initialised before any attribute is assigned.
        super().__init__()
        if n_hidden_layers < 1:
            raise ValueError("n_hidden_layers must be >= 1")

        self.S_max = S_max
        self.T = T

        layers = [nn.Linear(2, hidden_dim), nn.Tanh()]   # Input: (S, tau)
        for _ in range(n_hidden_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.Tanh()]
        layers += [nn.Linear(hidden_dim, 1), nn.Softplus()]   # Output: V(S, tau)
        self.net = nn.Sequential(*layers)

    def forward(self, S, tau):
        """Evaluate V at column tensors S and tau of shape (n, 1); returns (n, 1)."""
        # normalize the inputs to roughly [0, 1]
        S_norm = S / self.S_max
        tau_norm = tau / self.T

        Stau = torch.cat([S_norm, tau_norm], dim=1)
        return self.net(Stau)

def train_network(strike = strike, rate = r, volatility = sigma, S_max = S_max, T = T, call_put = 'Call',
                  epochs = 8000, lr = 0.005):
    """Train a PINN for a European call or put under Black-Scholes dynamics.

    The loss is the unweighted sum of the terminal (payoff) loss, the
    S = 0 / S = S_max boundary loss and the interior PDE residual. Training runs
    for at most ``epochs`` Adam steps. After ``min_epochs`` it stops early once
    the loss has failed, for ``patience`` consecutive epochs, to beat its
    exponential moving average (EMA) by ``min_delta``.

    Args:
        strike: strike price K.
        rate: risk-free rate (used both in the PDE and the boundary discounting).
        volatility: volatility sigma.
        S_max: upper edge of the asset-price domain.
        T: maturity.
        call_put: 'Call' or 'Put'.
        epochs: maximum number of optimisation steps (default 8000).
        lr: Adam learning rate (default 0.005).

    Returns:
        The trained PINN from the last epoch run (not necessarily the best one).
        The per-epoch loss history is attached as ``pinn.train_losses``.
    """
    # ================== Training Setup ==================
    pinn = PINN(S_max=S_max, T=T)
    optimizer = torch.optim.Adam(pinn.parameters(), lr=lr)

    # ================== Training Loop with EMA Early Stopping ==================
    train_losses = []

    # EMA parameters
    ema_loss = None
    alpha = 0.1        # Smoothing factor
    patience = 200     # Epochs to wait before stopping
    min_delta = 1e-5   # Minimum improvement threshold
    wait = 0           # Epochs since last improvement
    min_epochs = 100   # Minimum training epochs before checking

    for epoch in range(epochs):
        optimizer.zero_grad()

        # Loss terms: payoff at tau = 0, price boundaries at S = 0 / S_max, PDE in the interior.
        loss_terminal = terminal_condition(pinn, strike=strike, S_max=S_max, call_put=call_put)
        # Boundary targets must be discounted with the same `rate` the PDE residual uses.
        loss_boundary = boundary_condition(pinn, strike=strike, S_max=S_max, T=T, r=rate, call_put=call_put)
        loss_physics = pde_dynamic(pinn, sigma=volatility, r=rate, S_max=S_max, T=T)
        loss = loss_terminal + loss_boundary + loss_physics

        # Backpropagation
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        train_losses.append(loss_value)

        # EMA update; `is None` so an EMA of exactly 0.0 is not mistaken for "unset".
        ema_loss = loss_value if ema_loss is None else alpha * loss_value + (1 - alpha) * ema_loss

        if epoch > min_epochs:
            if loss_value < ema_loss - min_delta:
                wait = 0  # Reset counter if improving
            else:
                wait += 1
                if wait >= patience:
                    print(f"\nEarly stopping at epoch {epoch}")
                    print(f"Final loss: {loss_value:.6f} (EMA: {ema_loss:.6f})")
                    break

        if epoch % 500 == 0:
            print(f"Epoch {epoch:4d} | Loss: {loss_value:.6f} | EMA: {ema_loss:.6f}")

    # Expose the per-epoch loss history on the returned model.
    pinn.train_losses = train_losses
    return pinn

In [13]:
# solution visualization

def visualize(pinn, S_max, T, n_points=None):
    """Plot the PINN price surface V(S, tau) as an interactive plotly 3D surface.

    Args:
        pinn: trained model, called as ``pinn(S, tau)`` with (n, 1) tensors.
        S_max: upper end of the asset-price axis (the lower end is 0).
        T: upper end of the time-to-maturity axis (the lower end is 0).
        n_points: grid resolution per axis; defaults to the module-level ``N_points``.
    """
    if n_points is None:
        n_points = N_points

    # ================== Visualization ==================
    # Regular n_points x n_points evaluation grid over [0, S_max] x [0, T].
    S_test = torch.linspace(0, S_max, n_points)
    tau_test = torch.linspace(0, T, n_points)
    # 'ij' is the implicit default; passing it explicitly avoids PyTorch's meshgrid UserWarning.
    S_grid, Tau_grid = torch.meshgrid(S_test, tau_test, indexing='ij')

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        V_pred = pinn(S_grid.reshape(-1, 1), Tau_grid.reshape(-1, 1)).reshape(S_grid.shape).numpy()

    # Create interactive 3D plot with plotly
    fig = go.Figure(data=[go.Surface(
        x=S_grid.numpy(),
        y=Tau_grid.numpy(),
        z=V_pred,
        colorscale='jet',
        opacity=0.8,
    )])

    # Update layout
    fig.update_layout(
        title='Interactive Black-Scholes PDE solution',
        scene=dict(
            xaxis_title='Underlying price (S)',
            yaxis_title='Time to maturity (tau)',
            zaxis_title='V(S,tau)',
        ),
        autosize=True,
        width=900,
        height=700,
    )

    # Show the plot
    fig.show()

In [12]:
# Train a call-option PINN on the notebook's default market parameters, then plot V(S, tau).
pinn = train_network(
    strike=strike,
    rate=r,
    volatility=sigma,
    S_max=S_max,
    T=T,
    call_put='Call',
)
visualize(pinn, S_max, T)

Epoch    0 | Loss: 47964.472656 | EMA: 47964.472656
Epoch  500 | Loss: 25695.238281 | EMA: 26251.563639
Epoch 1000 | Loss: 12756.144531 | EMA: 12915.053721
Epoch 1500 | Loss: 5772.044922 | EMA: 5806.268243
Epoch 2000 | Loss: 2280.480469 | EMA: 2341.711950
Epoch 2500 | Loss: 870.575989 | EMA: 867.562750
Epoch 3000 | Loss: 300.365173 | EMA: 315.848048
Epoch 3500 | Loss: 106.062141 | EMA: 123.049853
Epoch 4000 | Loss: 45.291740 | EMA: 53.385114
Epoch 4500 | Loss: 36.726612 | EMA: 21.177278
Epoch 5000 | Loss: 1.872965 | EMA: 4.260065
Epoch 5500 | Loss: 0.659998 | EMA: 1.604078
Epoch 6000 | Loss: 0.600211 | EMA: 0.813075
Epoch 6500 | Loss: 0.970432 | EMA: 1.082508
Epoch 7000 | Loss: 0.927620 | EMA: 0.788786
Epoch 7500 | Loss: 0.243422 | EMA: 0.336131


In [14]:
# The Greeks are a key notion in mathematical finance: they are the sensitivities of the option price to its inputs.
# For a PINN they are simply partial derivatives of the network output, obtained via autodiff.

def compute_greeks(pinn, S_range=(50, 150), tau_range=(0.1, 1.0),
                       grid_points=(50, 50)):
    """Evaluate the price and its Greeks on a regular (S, tau) grid via autodiff.

    Args:
        pinn: callable ``pinn(S, tau)`` on (n, 1) tensors.
        S_range: (S_min, S_max) of the asset-price axis.
        tau_range: (tau_min, tau_max) of the time-to-maturity axis.
        grid_points: (number of S points, number of tau points).

    Returns:
        (S_grid, tau_grid, V_base, delta, gamma, theta). The grids and the
        Greeks have shape (tau_points, S_points) ('xy' meshgrid layout: rows vary
        in tau, columns in S). V_base is the raw network output of shape
        (tau_points * S_points, 1). delta = dV/dS, gamma = d2V/dS2,
        theta = dV/dt = -dV/dtau (tau is time to maturity).
    """
    S_min, S_max = S_range
    tau_min, tau_max = tau_range
    S_points, tau_points = grid_points

    # Create 2D grid; with 'xy' indexing both grids have shape (tau_points, S_points).
    S_vals = torch.linspace(S_min, S_max, S_points)
    tau_vals = torch.linspace(tau_min, tau_max, tau_points)
    S_grid, tau_grid = torch.meshgrid(S_vals, tau_vals, indexing='xy')
    grid_shape = S_grid.shape

    # Fresh leaf inputs so derivatives are taken w.r.t. exactly these tensors.
    S_flat = S_grid.reshape(-1, 1).clone().requires_grad_(True)
    tau_flat = tau_grid.reshape(-1, 1).clone().requires_grad_(True)

    V_base = pinn(S_flat, tau_flat)

    # create_graph=True so dV/dS can be differentiated again for gamma.
    dV_dS = torch.autograd.grad(
        outputs=V_base, inputs=S_flat,
        grad_outputs=torch.ones_like(V_base),
        create_graph=True,
        retain_graph=True,
    )[0]

    d2V_dS2 = torch.autograd.grad(
        outputs=dV_dS, inputs=S_flat,
        grad_outputs=torch.ones_like(dV_dS),
        create_graph=False,
        retain_graph=True,
    )[0]

    dV_dtau = torch.autograd.grad(
        outputs=V_base, inputs=tau_flat,
        grad_outputs=torch.ones_like(V_base),
        create_graph=False,
        retain_graph=False,
    )[0]

    # Reshape to the meshgrid's own shape (not `grid_points`, which is transposed
    # relative to the 'xy' layout and would scramble non-square grids).
    delta = dV_dS.reshape(grid_shape)
    gamma = d2V_dS2.reshape(grid_shape)
    theta = -dV_dtau.reshape(grid_shape)

    return S_grid, tau_grid, V_base, delta, gamma, theta

def visualize_greeks(S_grid, tau_grid, greek, greek_name='Greek'):
    """Plot one Greek surface (e.g. delta, gamma or theta) as an interactive plotly 3D surface.

    Args:
        S_grid: asset-price coordinates, same shape as ``greek``.
        tau_grid: time-to-maturity coordinates, same shape as ``greek``.
        greek: Greek values on the grid; may still carry an autograd graph.
        greek_name: label used in the title and z-axis (default 'Greek').
    """
    # detach() because the inputs (e.g. from compute_greeks) may still require grad.
    fig = go.Figure(data=[go.Surface(
        x=S_grid.detach().cpu().numpy(),
        y=tau_grid.detach().cpu().numpy(),
        z=greek.detach().cpu().numpy(),
        colorscale='jet',
        opacity=0.8,
    )])

    # Update layout
    fig.update_layout(
        title=f'Interactive Black-Scholes {greek_name} surface',
        scene=dict(
            xaxis_title='Underlying price (S)',
            yaxis_title='Time to maturity (tau)',
            zaxis_title=f'{greek_name}(S,tau)',
        ),
        autosize=True,
        width=900,
        height=700,
    )

    # Show the plot
    fig.show()

In [16]:
# Greeks must generally be smooth! Compute them once and plot theta (the last returned Greek).
S_greeks, tau_greeks, V_greeks, delta, gamma, theta = compute_greeks(pinn)
visualize_greeks(S_greeks, tau_greeks, theta)