<a href="https://colab.research.google.com/github/wqiu96/summer_project/blob/master/DeepBSDE_pytorch/solver_pytorch_v02.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/wqiu96/summer_project.git

Cloning into 'summer_project'...
remote: Enumerating objects: 49, done.[K
remote: Counting objects: 100% (49/49), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 517 (delta 20), reused 0 (delta 0), pack-reused 468[K
Receiving objects: 100% (517/517), 2.51 MiB | 23.79 MiB/s, done.
Resolving deltas: 100% (268/268), done.


In [2]:
cd summer_project/DeepBSDE_pytorch/

/content/summer_project/DeepBSDE_pytorch


In [0]:
import logging
import time
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.autograd import Variable
from torch.distributions import uniform
import torchvision
from equation_pytorch import get_equation
from config_pytorch import get_config


MOMENTUM = 0.99
EPSILON = 1e-6
DELTA_CLIP = 50.0


class Net(nn.Module):
    def __init__(self,num_hiddens):
      super(Net, self).__init__()
      self.num_hiddens = num_hiddens
      
      self.fc1 = nn.Linear(num_hiddens[0], num_hiddens[1])
      self.fc2 = nn.Linear(num_hiddens[1], num_hiddens[2])
      self.fc3 = nn.Linear(num_hiddens[2], num_hiddens[3])
    
    def forward(self, x):
      # h1 = relu(xw1)
      x = F.relu(self.fc1(x))
      # h2 = relu(h1w2)
      x = F.relu(self.fc2(x))
      # h3 = h2w3
      x = self.fc3(x)
      #termin time
      return x
      

class DeepNet(nn.Module):
    def __init__(self,num_hiddens,config,bsde):
      super(DeepNet, self).__init__()
      self.num_hiddens = num_hiddens
      self._config = config
      self._bsde = bsde
      
      # make sure consistent with FBSDE equation
      self._dim = bsde.dim
      self._num_time_interval = bsde.num_time_interval
      # ops for statistics update of batch normalization
      self.linears = nn.ModuleList([Net(num_hiddens) for i in range(bsde.num_time_interval - 1)])
    
    def forward(self,x):
      #dw_train= torch.from_numpy(self._bsde.sample()[0])
      dw_train= self._bsde.sample()[0]
      time_stamp = np.arange(0, self._bsde.num_time_interval) * self._bsde.delta_t
      m = uniform.Uniform(self._config.y_init_range[0],self._config.y_init_range[1])
      y = Variable(m.sample()) #initial
      z = torch.rand([self._dim,1],dtype=torch.float64)
      for t in range(0,bsde.num_time_interval - 1):
        dw = torch.tensor([dw_train[:, t]]) 
        self._bsde.delta_t* (self._bsde.f_tf(time_stamp[t], x[:, t], y, z))
        y = y - self._bsde.delta_t* (self._bsde.f_tf(time_stamp[t], x[:, t], y, z)) + torch.mm(dw, z)
        z = (self.linears[t + 1](x[:,t + 1]) / self._dim)
        #print(dw)
        print(z)
      #terminal condition
      dw = torch.tensor([dw_train[:, -1]])
      y = y - self._bsde.delta_t * (self._bsde.f_tf(time_stamp[-1], x[:, -2], y, z)) + torch.mm(dw, z)
      return y


In [5]:
config = get_config('HJB')
bsde = get_equation('HJB', config.dim, config.total_time, config.num_time_interval)

deepNet = DeepNet(config.num_hiddens,config,bsde)
optimizer = optim.SGD(deepNet.parameters(), lr=config.num_hiddens[0], momentum=MOMENTUM) # lr have some different wiht the original

for epoch in range(config.num_iterations):
  x_ = bsde.sample()[1].astype(np.float32)
  x = torch.from_numpy(x_)
  out = deepNet(x)
  delta = out - bsde.g_tf(bsde.total_time, x[:, -1])
  loss = torch.mean(torch.where(torch.abs(delta) < DELTA_CLIP, torch.pow(delta,2),2 * DELTA_CLIP * torch.abs(delta) - DELTA_CLIP ** 2))
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  train_loss.append(loss.item())
  if epoch % 100 == 0 :
    print(epoch, loss.item(), out)

tensor([ 5.0821e-05,  6.8210e-06, -1.2216e-04,  3.7109e-04, -1.0829e-03,
         5.9184e-04,  9.8436e-04, -1.1474e-03, -1.0682e-04,  9.9749e-04,
         8.1411e-04, -1.0944e-03, -5.0104e-04, -3.1971e-04, -5.6897e-04,
        -3.6075e-05,  7.4160e-04,  1.2134e-04, -1.6969e-04, -4.7370e-04,
        -1.5632e-04,  9.1288e-05,  3.7102e-04, -6.4551e-04, -5.8356e-04,
        -6.2464e-04,  6.8228e-04,  2.6126e-04,  1.8897e-04, -7.7021e-04,
         5.2659e-04, -2.4590e-04, -1.4023e-04,  1.0969e-03, -6.9383e-06,
         1.8265e-04,  8.9250e-04, -3.6833e-04, -5.2728e-04, -1.0659e-03,
        -2.1065e-04, -6.7078e-05, -7.3900e-04,  9.9913e-04,  4.3640e-04,
         9.0035e-04, -9.1536e-04,  6.9352e-04,  7.7909e-05, -8.0191e-04,
         4.2797e-05,  7.2556e-04, -1.1524e-03,  4.8242e-04, -1.7515e-05,
        -6.4546e-04, -5.4911e-04, -5.1688e-04, -2.6023e-04,  9.0084e-04,
         5.1417e-04,  4.5579e-04, -9.9694e-04,  7.0968e-05,  7.3168e-04,
         4.5954e-04,  6.4145e-05,  1.0809e-03,  1.1

IndexError: ignored