In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
import time
import pickle
from networks import *
from fbsde import *

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
h = 1
epsilon = 0.01
print(device)
print(torch.__version__)

cuda
2.0.0


In [3]:
print(pickle.format_version)

4.0


In [4]:
sup_per_dim = 5
batch_size = 100
num_iterations = 66000
mse = nn.MSELoss(reduction="sum")
runs = 10
r = 0.05
volatility = 0.4
T = 1
fbsde = BS_Barenblatt(volatility, r, 1)
#optimizer = "LBFGS" # Can use LBFGS or Adam
#learning_rate = 1
optimizer = "Adam"
learning_rate = 1e-3

In [5]:
def loss_diff(pde, u, t, x):
  u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]
  Du = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
  #Hessian H[i][j] is derivative with respect to jth variable then with respect to ith variable or is it the other way around
  I_N = torch.eye(x.shape[-1], device=device)
  def get_vjp(v):
    return torch.autograd.grad(Du, x, grad_outputs=v.repeat(x.shape[0], 1), create_graph=True)
  D2u = torch.vmap(get_vjp)(I_N)[0]
  if len(x.shape) > 1:
    D2u = D2u.swapaxes(0, 1)
  A = D2u @ pde.sigma(t,x,u) @ pde.sigma(t, x, u).transpose(-2, -1)
  trace = torch.diagonal(A, dim1=-2, dim2=-1).sum(dim=-1, keepdim=True)
  #trace = torch.vmap(torch.trace)(A)
  # in the code D2u[sample][i][j] is the derivative with respect to ith variable then jth variable
  f = pde.phi(t, x, u, Du) - torch.sum(Du * pde.mu(t, x, u, Du), dim=-1, keepdim=True) - 1/2 * trace
  return mse(u_t, f)

def loss_bc(pde, u, x):
  return mse(pde.g(x), u)

The choice of interior training points is specific to the Black-Scholes-Barenblatt PDE as there is a simple closed form solution for the forward process in the associated PDE and so we can train on points chosen in a similar way to the FBSDE loss case.

In [6]:
#for activation in [torch.sin, F.relu]:
for activation in [torch.sin]:
    activation_str = None
    if activation == torch.sin:
        activation_str = "sin"
    elif activation == F.relu:
        activation_str = "relu"
    for d in [100, 3]:
        losses = []
        boundary_losses = []
        times = []
        zeta = torch.tensor(int(d / 2) * [1., 0.5] + (d % 2) * [1.], device=device)
        for run in range(runs):
            run_losses = []
            run_boundary_losses = []
            run_times = []
            pinn_network = NAIS_Net_Untied(d+1, 256, 4, 1, activation, epsilon, h).to(device)
            pinn_optimizer = torch.optim.LBFGS(pinn_network.parameters(), lr=learning_rate) if optimizer == "LBFGS" else torch.optim.Adam(pinn_network.parameters(), lr=learning_rate)
            for iteration in range(num_iterations):
                start_time = time.time()
                interior_ts = torch.rand((batch_size, 1), requires_grad=True, device=device)
                # This is where the markdown remark comes in
                interior_xs = zeta.detach().clone().repeat(batch_size, 1) * torch.exp(-((fbsde.volatility)**2 / 2) * interior_ts + fbsde.volatility * torch.sqrt(interior_ts) * torch.randn((batch_size, d), device=device))
                #interior_xs = sup_per_dim * torch.rand((batch_size, d), device=device)
                interior_xs.requires_grad_(True)

                boundary_ts = fbsde.T * torch.ones((batch_size, 1), device=device)
                boundary_ts.requires_grad_(True)
                #boundary_xs = sup_per_dim * torch.rand((batch_size, d), device=device)
                boundary_xs = zeta.detach().clone().repeat(batch_size, 1) * torch.exp(-((fbsde.volatility)**2 / 2) * boundary_ts + fbsde.volatility * torch.sqrt(boundary_ts) * torch.randn((batch_size, d), device=device))
                boundary_xs.requires_grad_(True)

                if optimizer == "LBFGS":
                    def closure():
                        pinn_optimizer.zero_grad()
                        xs_i = interior_xs.detach().clone().requires_grad_(True)
                        ts_i = interior_ts.detach().clone().requires_grad_(True)
                        sample_i = torch.cat((ts_i, xs_i), dim=-1)
                        u_i = pinn_network(sample_i)
                        loss_interior = loss_diff(fbsde, u_i, ts_i, xs_i)/batch_size
                        xs_b = boundary_xs.detach().clone().requires_grad_(True)
                        ts_b = boundary_ts.detach().clone().requires_grad_(True)
                        sample_b = torch.cat((ts_b, xs_b), dim=-1)
                        u_b = pinn_network(sample_b)
                        loss_boundary = loss_bc(fbsde, u_b, xs_b)/batch_size
                        loss = loss_interior + loss_boundary
                        loss.backward()
                        print(loss)
                        return loss
                    loss = pinn_optimizer.step(closure)
                elif optimizer == "Adam":
                    pinn_optimizer.zero_grad()
                    interior_sample = torch.cat((interior_ts, interior_xs), dim=-1)
                    u = pinn_network(interior_sample)
                    loss_interior = loss_diff(fbsde, u, interior_ts, interior_xs)/batch_size
                    boundary_sample = torch.cat((boundary_ts, boundary_xs), dim=-1)
                    u = pinn_network(boundary_sample)
                    loss_boundary = loss_bc(fbsde, u, boundary_xs)/batch_size
                    loss = loss_interior + loss_boundary
                    loss.backward()
                    pinn_optimizer.step()
                    run_boundary_losses.append(loss_boundary.item())
                if iteration % 1000 == 0:
                    print("Iteration: %d, Loss: %.3e" % (iteration, loss.item()))
                    torch.save(pinn_network.state_dict(), "PINN_%s/pde_pinn_dimensions_%d_run_%d_iteration_%d.pt" % (activation_str, d, run, iteration))
                run_losses.append(loss.item())
                run_times.append(time.time() - start_time)
            losses.append(run_losses)
            boundary_losses.append(run_boundary_losses)
            times.append(run_times)
            torch.save(pinn_network.state_dict(), "PINN_%s/pde_pinn_dimensions_%d_run_%d_trained.pt" % (activation_str, d, run))
        with open("PINN_%s/losses_dimensions_%d.pkl" % (activation_str, d), "wb") as f:
            pickle.dump(losses, f)
        with open("PINN_%s/terminal_losses_dimensions_%d.pkl" % (activation_str, d), "wb") as f:
            pickle.dump(boundary_losses, f)
        with open("PINN_%s/times_dimensions_%d.pkl" % (activation_str, d), "wb") as f:
            pickle.dump(run_times, f)

Iteration: 0, Loss: 5.548e+03
Iteration: 1000, Loss: 5.736e+01
Iteration: 2000, Loss: 4.259e+01
Iteration: 3000, Loss: 3.520e+01
Iteration: 4000, Loss: 3.900e+01
Iteration: 5000, Loss: 5.461e+01
Iteration: 6000, Loss: 2.995e+01
Iteration: 7000, Loss: 6.977e+01
Iteration: 8000, Loss: 2.105e+01
Iteration: 9000, Loss: 2.933e+01
Iteration: 10000, Loss: 2.945e+01
Iteration: 11000, Loss: 3.522e+01
Iteration: 12000, Loss: 2.073e+01
Iteration: 13000, Loss: 2.818e+01
Iteration: 14000, Loss: 1.924e+01
Iteration: 15000, Loss: 2.404e+01
Iteration: 16000, Loss: 4.114e+01
Iteration: 17000, Loss: 2.426e+01
Iteration: 18000, Loss: 2.158e+01
Iteration: 19000, Loss: 3.909e+01
Iteration: 20000, Loss: 5.680e+01
Iteration: 21000, Loss: 2.824e+01
Iteration: 22000, Loss: 3.592e+01
Iteration: 23000, Loss: 2.802e+01
Iteration: 24000, Loss: 1.922e+01
Iteration: 25000, Loss: 1.850e+01
Iteration: 26000, Loss: 4.137e+01
Iteration: 27000, Loss: 2.572e+01
Iteration: 28000, Loss: 4.837e+01
Iteration: 29000, Loss: 2.2

KeyboardInterrupt: 