In [None]:
import torch
import random
from torch import optim, nn
import math
import matplotlib.pyplot as plt
from torchviz import make_dot

In [None]:
from res.plot_lib import plot_data, plot_model, set_default

In [None]:
set_default()

We generate sample paths of the following SDE:

$$
\text{d}X_t = r X_t \text{d}t + \sigma X_t \text{d}W_t, \quad X_0 = x_0 \in \mathbb{R}.
$$

We know that $$X_t = x_0\exp((r - \frac{1}{2}\sigma^2) t + \sigma W_t).$$

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [151]:
# black-scholes inputs
S0 = torch.tensor(100.0, requires_grad = False, dtype=torch.float, device=device)
K = torch.tensor(100.0, requires_grad = False, dtype=torch.float, device=device)
r = torch.tensor(0.01, requires_grad = False, dtype=torch.float, device=device)
sigma = torch.tensor(0.4, requires_grad = True, dtype=torch.float, device=device)
N = 1000 # number of MC paths
M = 1000 # number of time steps
T = 0.2
dt = T / M
tspace = torch.linspace(0, T, M + 1)
torch.manual_seed(42)

S = S0 * torch.ones(N, M + 1).to(device)
tspace_mat = torch.ones(N, M + 1) * tspace
dW = sigma * dt**0.5 * torch.randn(size=(N, M))
Wt = torch.zeros(N, M + 1)
Wt[:,1:] = torch.cumsum(dW, dim = 1)
S = S0 * torch.exp((r - sigma**2 / 2) * tspace_mat + Wt)

In [None]:
# for i in range(N):
#     plt.plot(tspace, S[i,:].detach().numpy())

In [None]:
# utility functions
cdf = torch.distributions.Normal(0,1).cdf
pdf = lambda x: torch.distributions.Normal(0,1).log_prob(x).exp()

d1 = (torch.log(S0/K) + (r + sigma**2 / 2) * T) / (sigma * T**0.5)
d2 = (torch.log(S0/K) + (r - sigma**2 / 2) * T) / (sigma * T**0.5)
callprice = S0* cdf(d1) - K*torch.exp(-r * T)*cdf(d2)

vega = S0 * pdf(d1) * T**0.5

callprice.backward() # backward pass computes the gradients

print("Call price: %.6f" % callprice)
print("Vega from grad: %.6f" % sigma.grad)
print("Vega from formula: %.6f" % vega)

### Newton's method for estimating implied volatility

In [149]:
vol = torch.tensor(1.0, requires_grad=True, dtype=torch.float, device=device)
callprice_true = torch.tensor(7.220219)
n_iteration = 20
t = 0

d1 = (torch.log(S0/K) + (r + vol**2 / 2) * T) / (vol * T**0.5)
d2 = (torch.log(S0/K) + (r - vol**2 / 2) * T) / (vol * T**0.5)
callprice_pred = S0* cdf(d1) - K*torch.exp(-r * T)*cdf(d2)

error = torch.abs(callprice_pred - callprice_true).detach().numpy()

print("Error in call price: %.6f" % error)

while t < n_iteration and error > 0.00001 :
    d1 = (torch.log(S0/K) + (r + vol**2 / 2) * T) / (vol * T**0.5)
    d2 = (torch.log(S0/K) + (r - vol**2 / 2) * T) / (vol * T**0.5)
    callprice_pred = S0* cdf(d1) - K*torch.exp(-r * T)*cdf(d2)
    callprice_pred.backward() # backward pass computes the gradients
    
#     price_error = callprice_pred - callprice_true
#     price_loss = price_error ** 2
    
#     price_loss.backward()
    
    with torch.no_grad():
        vol -= (callprice_pred - callprice_true) / vol.grad
    
#     with torch.no_grad():
#         vol -= 0.001 * vol.grad

    vol.grad.zero_()
    t += 1
    error = torch.abs(callprice_pred - callprice_true).detach().numpy()
    print("No. of iteration: %i" % t)
    print("Estimated call price: %.6f" % callprice_pred)
    print("Error in call price: %.6f" % error)


Error in call price: 10.555855
No. of iteration: 1
Estimated call price: 17.776073
Error in call price: 10.555855
No. of iteration: 2
Estimated call price: 7.091591
Error in call price: 0.128628
No. of iteration: 3
Estimated call price: 7.220211
Error in call price: 0.000008


### Implied volatility using SGD

In [150]:
vol = torch.tensor(1.0, requires_grad=True, dtype=torch.float, device=device)
callprice_true = torch.tensor(7.220219)
n_iteration = 20
lr = 0.001
d1 = (torch.log(S0/K) + (r + vol**2 / 2) * T) / (vol * T**0.5)
d2 = (torch.log(S0/K) + (r - vol**2 / 2) * T) / (vol * T**0.5)
callprice_pred = S0* cdf(d1) - K*torch.exp(-r * T)*cdf(d2)

error = torch.abs(callprice_pred - callprice_true).detach().numpy()
print("Error in call price: %.6f" % error)

loss_fn = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.SGD([vol],lr=lr)

t = 0
while t < n_iteration and error > 0.00001 :
    d1 = (torch.log(S0/K) + (r + vol**2 / 2) * T) / (vol * T**0.5)
    d2 = (torch.log(S0/K) + (r - vol**2 / 2) * T) / (vol * T**0.5)
    callprice_pred = S0* cdf(d1) - K*torch.exp(-r * T)*cdf(d2)
    
#     price_error = callprice_pred - callprice_pred
#     price_loss = price_error ** 2   
#     price_loss.backward()
    price_loss = loss_fn(callprice_pred, callprice_true)
  
    price_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    error = torch.abs(callprice_pred - callprice_true).detach().numpy()
    print("No. of iteration: %i" % t)
    print("Estimated call price: %.6f" % callprice_pred)
    print("Error in call price: %.6f" % error)
    t += 1
    
print("Implied vol: %.6f" % vol.detach().numpy())


Error in call price: 10.555855
No. of iteration: 0
Estimated call price: 17.776073
Error in call price: 10.555855
No. of iteration: 1
Estimated call price: 11.345043
Error in call price: 4.124824
No. of iteration: 2
Estimated call price: 8.771000
Error in call price: 1.550781
No. of iteration: 3
Estimated call price: 7.796814
Error in call price: 0.576595
No. of iteration: 4
Estimated call price: 7.433849
Error in call price: 0.213630
No. of iteration: 5
Estimated call price: 7.299274
Error in call price: 0.079055
No. of iteration: 6
Estimated call price: 7.249458
Error in call price: 0.029239
No. of iteration: 7
Estimated call price: 7.231037
Error in call price: 0.010818
No. of iteration: 8
Estimated call price: 7.224216
Error in call price: 0.003997
No. of iteration: 9
Estimated call price: 7.221699
Error in call price: 0.001480
No. of iteration: 10
Estimated call price: 7.220764
Error in call price: 0.000545
No. of iteration: 11
Estimated call price: 7.220421
Error in call price: 0

### Implied volatility SGD with model framework

In [154]:
class BScallprice(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # wrap vol with nn.Parameter
        self.vol = nn.Parameter(torch.tensor(1, requires_grad=True, dtype=torch.float))
       
    def forward(self, S0, K, r, T):
        d1 = (torch.log(S0/K) + (r + self.vol**2 / 2) * T) / (self.vol * T**0.5)
        d2 = (torch.log(S0/K) + (r - self.vol**2 / 2) * T) / (self.vol * T**0.5)
        callprice = S0* cdf(d1) - K*torch.exp(-r * T)*cdf(d2)
        return callprice


torch.manual_seed(42)
callprice_true = torch.tensor(7.220219)
n_iteration = 20
lr = 0.001

model = BScallprice().to(device)
print(model.state_dict())

loss_fn = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(),lr=lr)

t = 0

while t < n_iteration:
    
    model.train()
    callprice_pred = model(S0, K, r, T)
    price_loss = loss_fn(callprice_pred, callprice_true)
  
    price_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    error = torch.abs(callprice_pred - callprice_true).detach().numpy()
    print("No. of iteration: %i" % t)
    print("Estimated call price: %.6f" % callprice_pred)
    print("Error in call price: %.6f" % error)
    t += 1
    
print(model.state_dict())

OrderedDict([('vol', tensor(1.))])
No. of iteration: 0
Estimated call price: 17.776073
Error in call price: 10.555855
No. of iteration: 1
Estimated call price: 11.345043
Error in call price: 4.124824
No. of iteration: 2
Estimated call price: 8.771000
Error in call price: 1.550781
No. of iteration: 3
Estimated call price: 7.796814
Error in call price: 0.576595
No. of iteration: 4
Estimated call price: 7.433849
Error in call price: 0.213630
No. of iteration: 5
Estimated call price: 7.299274
Error in call price: 0.079055
No. of iteration: 6
Estimated call price: 7.249458
Error in call price: 0.029239
No. of iteration: 7
Estimated call price: 7.231037
Error in call price: 0.010818
No. of iteration: 8
Estimated call price: 7.224216
Error in call price: 0.003997
No. of iteration: 9
Estimated call price: 7.221699
Error in call price: 0.001480
No. of iteration: 10
Estimated call price: 7.220764
Error in call price: 0.000545
No. of iteration: 11
Estimated call price: 7.220421
Error in call pric