In [108]:
import sys
from zmq import device
print(sys.executable)
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import random
from torch.optim import LBFGS, Adam
from tqdm import tqdm
from scipy.stats import norm
from utils import *
from pinn import PINNs

/data/vinay_2421ma05/VINAY/BS_PF/.conda/bin/python


In [109]:
# BS params
K = 4
sigma = 0.3
r = 0.03
T = 1
L = 10
N_x = 101
N_t = 101
seed = 0
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [110]:
res, b_left, b_right, b_upper, b_lower = get_data([0,10], [0, 1], N_x, N_t)
res_test, _, _, _, _ = get_data([0,10], [0,1], N_x, N_t)
# print(res.shape, b_left.shape, b_right.shape, b_upper.shape, b_lower.shape)
res = torch.tensor(res, dtype=torch.float32, requires_grad=True).to(device)
b_left = torch.tensor(b_left, dtype=torch.float32, requires_grad=True).to(device)
b_right = torch.tensor(b_right, dtype=torch.float32, requires_grad=True).to(device)
b_upper = torch.tensor(b_upper, dtype=torch.float32, requires_grad=True).to(device)
b_lower = torch.tensor(b_lower, dtype=torch.float32, requires_grad=True).to(device)
x_res, t_res = res[:,0:1], res[:,1:2]
x_left, t_left = b_left[:,0:1], b_left[:,1:2]
x_right, t_right = b_right[:,0:1], b_right[:,1:2]
x_upper, t_upper = b_upper[:,0:1], b_upper[:,1:2]
x_lower, t_lower = b_lower[:,0:1], b_lower[:,1:2]

print(x_res.shape, x_left.shape, x_right.shape, x_upper.shape, x_lower.shape)
print(t_res.shape, t_left.shape, t_right.shape, t_upper.shape, t_lower.shape)

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

torch.Size([10000, 1]) torch.Size([101, 1]) torch.Size([101, 1]) torch.Size([101, 1]) torch.Size([101, 1])
torch.Size([10000, 1]) torch.Size([101, 1]) torch.Size([101, 1]) torch.Size([101, 1]) torch.Size([101, 1])


In [111]:
pinn = PINNs(in_dim=2, hidden_dim=32, out_dim=1, num_layer=4).to(device)
pinn.apply(init_weights)
optim = LBFGS(pinn.parameters(), line_search_fn='strong_wolfe')
# optim = Adam(model.parameters(), lr=1e-4)
n_params = get_n_params(pinn)
print(pinn)
print('No of Parameters',n_params)

PINNs(
  (linear): Sequential(
    (0): Linear(in_features=2, out_features=32, bias=True)
    (1): Tanh()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): Tanh()
    (4): Linear(in_features=32, out_features=32, bias=True)
    (5): Tanh()
    (6): Linear(in_features=32, out_features=1, bias=True)
  )
)
No of Parameters 2241


  torch.nn.init.xavier_uniform(m.weight)


In [112]:
kernel_size = 200
D1 = kernel_size
D2 = len(x_left)
D3 = len(x_lower)
def compute_ntk(J1, J2):
    Ker = torch.matmul(J1, torch.transpose(J2, 0, 1))
    return Ker

In [113]:
print(D1)
print(D2)
print(D3)

200
101
101


In [115]:
print(J1)

tensor([[ 0.,  0.,  0.,  ...,  1., -1.,  1.],
        [ 0.,  0.,  0.,  ...,  1., -1.,  1.],
        [ 0.,  0.,  0.,  ...,  1., -1.,  1.],
        ...,
        [ 0.,  0.,  0.,  ...,  1., -1.,  1.],
        [ 0.,  0.,  0.,  ...,  1., -1.,  1.],
        [ 0.,  0.,  0.,  ...,  1., -1.,  1.]])


In [114]:
loss_track = []
w1, w2, w3 = 1, 1, 1
n_epochs = 120
for i in tqdm(range(n_epochs)):
    if i % 50 == 0:
        J1 = torch.zeros((D1, n_params))
        J2 = torch.zeros((D2, n_params))
        J3 = torch.zeros((D3, n_params))

        batch_ind = np.random.choice(len(x_res), kernel_size, replace=False)
        x_train, t_train = x_res[batch_ind], t_res[batch_ind]

        pred_res = pinn(x_train, t_train)
        pred_left = pinn(x_left, t_left)
        pred_upper = pinn(x_upper, t_upper)
        pred_lower = pinn(x_lower, t_lower)

        for j in range(len(x_train)):
            pinn.zero_grad()
            pred_res[j].backward(retain_graph=True)
            J1[j, :] = torch.cat([p.grad.view(-1) for p in pinn.parameters()])
        for j in range(len(x_left)):
            pinn.zero_grad()
            pred_left[j].backward(retain_graph=True)
            J2[j, :] = torch.cat([p.grad.view(-1) for p in pinn.parameters()])
        for j in range(len(x_lower)):
            pinn.zero_grad()
            pred_lower[j].backward(retain_graph=True)
            pred_upper[j].backward(retain_graph=True)
            J3[j, :] = torch.cat([p.grad.view(-1) for p in pinn.parameters()])
        K1 = torch.trace(compute_ntk(J1, J1))
        K2 = torch.trace(compute_ntk(J2, J2))
        K3 = torch.trace(compute_ntk(J3, J3))
        eps=1e-10
        K = K1+K2+K3
        w1 = K.item() / (K1.item()+eps)
        w2 = K.item() / (K2.item()+eps)
        w3 = K.item() / (K3.item()+eps)
    def closure():
        pred_res = pinn(x_res, t_res)
        pred_left = pinn(x_left, t_left)
        pred_right = pinn(x_right, t_right)
        pred_upper = pinn(x_upper, t_upper)
        pred_lower = pinn(x_lower, t_lower)

        u_x = torch.autograd.grad(pred_res, x_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True, create_graph=True)[0]
        u_xx = torch.autograd.grad(u_x, x_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True, create_graph=True)[0]
        u_t = torch.autograd.grad(pred_res, t_res, grad_outputs=torch.ones_like(pred_res), retain_graph=True, create_graph=True)[0]

        loss_res = torch.mean((u_t - ((sigma**2 * x_res**2) / 2) * u_xx - (r * x_res) * u_x + (r * pred_res)) ** 2)
        loss_bc = torch.mean((pred_upper) ** 2) + torch.mean((pred_lower - (K * torch.exp(-r * t_lower))) ** 2)
        loss_ic = torch.mean((pred_left[:,0] - torch.max(K - x_left[:,0], torch.zeros(x_left[:,0].shape).to(device))) ** 2)

        loss_track.append([loss_res.item(), loss_ic.item(), loss_bc.item()])
        loss = w1*loss_res + w2*loss_ic + w3*loss_bc
        optim.zero_grad()
        loss.backward()
        return loss

    optim.step(closure)
if i % 100 == 0:
        print(f'{i}/{n_epochs} PDE Loss: {loss_track[-1][0]:.9f}, BVP Loss: {loss_track[-1][1]:.9f}, IC Loss: {loss_track[-1][2]:.9f},')

print('Loss Res: {:9f}, Loss_BC: {:9f}, Loss_IC: {:9f}'.format(loss_track[-1][0], loss_track[-1][1], loss_track[-1][2]))
print('Train Loss: {:4f}'.format(np.sum(loss_track[-1])))

100%|██████████| 120/120 [00:11<00:00, 10.76it/s]

Loss Res: 295742.718750, Loss_BC: 22138758.000000, Loss_IC: 328641984.000000
Train Loss: 351076484.718750





In [105]:
print(J1)

tensor([[ 0.,  0.,  0.,  ...,  1., -1.,  1.],
        [ 0.,  0.,  0.,  ...,  1., -1.,  1.],
        [ 0.,  0.,  0.,  ...,  1., -1.,  1.],
        ...,
        [ 0.,  0.,  0.,  ...,  1., -1.,  1.],
        [ 0.,  0.,  0.,  ...,  1., -1.,  1.],
        [ 0.,  0.,  0.,  ...,  1., -1.,  1.]])


In [95]:
print(K1,K2,K3)
print(pred_res)
print(pred_left)
print(pred_lower)
print(pred_upper)

tensor(6600.) tensor(3333.) tensor(12928.)
tensor([[18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18127.4121],
        [18

In [33]:
print(K)

tensor(nan)


In [106]:
print(w1,w2,w3)

print(loss_track)

3.4637878787878265 6.858985898589653 1.7683323019801844
[[295742.71875, 22138758.0, 328641984.0], [295742.90625, 22138704.0, 328642240.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], [295742.90625, 22138704.0, 328642240.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], [295742.90625, 22138704.0, 328642240.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], [295742.90625, 22138704.0, 328642240.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], [295742.71875, 22138758.0, 328641984.0], 

In [17]:
print('NTK weight: w_res: {:4f} w_ic: {:4f}, w_bc: {:4f}'.format(w1, w2, w3))
print('Train Loss: {:4f}'.format(np.sum(loss_track[-1])))
torch.save(pinn.state_dict(), './1dwave_pinns_ntk.pt')

NTK weight: w_res: 3.525000 w_ic: 6.980198, w_bc: 1.745050
Train Loss: 97207286064.000000


In [18]:
state = {
    'epoch': n_epochs,
    'state_dict': pinn.state_dict(),
    'optimizer': optim.state_dict(),
    'loss_hist': loss_track
}

torch.save(state, './BS_Put_PINNs_101')
# Testing
N_x=101
N_t=101
res_test, _, b_right_test, _, _ = get_test_data([0,10], [0,1], N_x, N_t)
# step_size = 1e-4

N = norm.cdf

# res_test = make_time_sequence(res_test, num_step=5, step=step_size)
res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)
b_right_test = torch.tensor(b_right_test, dtype=torch.float32, requires_grad=True).to(device)
x_test, t_test = res_test[:,0:1], res_test[:,1:2]
x_right_test, t_right_test = b_right_test[:,0:1], b_right_test[:,1:2]

with torch.no_grad():
    pred = pinn(x_test, t_test)[:,0:1]
    pred_right = pinn(x_right_test, t_right_test)[:,0:1]
    pred = pred.cpu().detach().numpy()
    pred_right = pred_right.cpu().detach().numpy()
pred = pred.reshape(N_x,N_t)

# Closed form solution 
def BS_CALL(S, T):
    d1 = (torch.log(S/K) + (r + sigma**2 / 2)*T) / (sigma*np.sqrt(T))
    d2 = d1 - sigma * torch.sqrt(T)
    return S * N(d1) - K * torch.exp(-r*T)* N(d2)

def BS_PUT(S, T):
    d1 = (np.log(np.where(S/K > 1e-8, S/K, 1e-8)) + (r + sigma**2/2)*T) / (sigma*np.sqrt(np.where(T > 1e-8, T, 1e-8)))
    d2 = d1 - sigma* np.sqrt(T)
    return K * np.exp(-r*T) * (1 - N(d2)) + S * (N(d1) - 1)

res_test, _, b_right_test, _, _ = get_test_data([0,10], [0,1], N_x, N_t)
u = BS_PUT(res_test[:,0], res_test[:,1]).reshape(N_x,N_t)
u_right = BS_PUT(b_right_test[:,0], b_right_test[:,1])

# Relative l1 and l2 errors at full grid
rl1 = np.sum(np.abs(u-pred)) / np.sum(np.abs(u))
rl2 = np.sqrt(np.sum((u-pred)**2) / np.sum(u**2))
print('relative L1 error: {:4f}'.format(rl1))
print('relative L2 error: {:4f}'.format(rl2))

rl1_right = np.sum(np.abs(u_right-pred_right[:,0])) / np.sum(np.abs(u_right))
rl2_right = np.sqrt(np.sum((u_right-pred_right[:,0])**2) / np.sum(u_right**2))
print('relative L1 error (At Final Time) :{:4f}'.format(rl1_right))
print('relative L2 error (At Final Time) :{:4f}'.format(rl2_right))

plt.figure(figsize=(4,3))
plt.imshow(pred, extent=[0,10,1,0], aspect='auto')
plt.xlabel('x')
plt.ylabel('t')
plt.title('Predicted u(x,t)')
plt.colorbar()
plt.tight_layout()
plt.savefig('./1dBS_Put_pinns_pred.png')

plt.figure(figsize=(4,3))
plt.imshow(u, extent=[0,10,1,0], aspect='auto')
plt.xlabel('x')
plt.ylabel('t')
plt.title('Exact u(x,t)')
plt.colorbar()
plt.tight_layout()
plt.savefig('./1dBS_Put_exact.png')

# plt.plot(X[final_index, 0], y_pred[final_index], '--', color="r")
plt.figure()
plt.plot(x_right_test.cpu().detach().numpy(), pred_right, '--', color="r")
plt.xlabel('S')
plt.ylabel('V(S, T)')
plt.title('Predicted u(x,t) (Final Time)')
# set the limits
plt.xlim([0, 10])
plt.ylim([0, 4])
plt.savefig('./1dBS_Put_pinns_pred(Final Time).png')

# Pointwise Error at final time
plt.figure()
plt.plot(x_right_test.cpu().detach().numpy()[:,0], u_right - pred_right[:,0], '--', color="r")
plt.xlabel('S')
plt.ylabel('V(S, T)')
plt.title('Pointwise Error (Final Time)')
# set the limits
plt.xlim([0, 10])
plt.ylim([0, 0.002])
plt.savefig('./1dBS_Put_PINNs_pointwise_error(Final Time).png')
plt.show()

print("Maximum Poinwise error (At Final Fime): {:4f}".format(np.max(np.abs(u_right - pred_right[:,0]))))



  return K * np.exp(-r*T) * (1 - N(d2)) + S * (N(d1) - 1)
  rl1 = np.sum(np.abs(u-pred)) / np.sum(np.abs(u))


TypeError: sum() received an invalid combination of arguments - got (out=NoneType, axis=NoneType, ), but expected one of:
 * (*, torch.dtype dtype = None)
      didn't match because some of the keywords were incorrect: out, axis
 * (tuple of ints dim, bool keepdim = False, *, torch.dtype dtype = None)
 * (tuple of names dim, bool keepdim = False, *, torch.dtype dtype = None)
