<a href="https://colab.research.google.com/github/wqiu96/summer_project/blob/master/DeepBSDE_pytorch/solver_pytorch_v02.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/wqiu96/summer_project.git

Cloning into 'summer_project'...
remote: Enumerating objects: 81, done.[K
remote: Counting objects: 100% (81/81), done.[K
remote: Compressing objects: 100% (81/81), done.[K
remote: Total 549 (delta 35), reused 0 (delta 0), pack-reused 468[K
Receiving objects: 100% (549/549), 2.52 MiB | 7.08 MiB/s, done.
Resolving deltas: 100% (283/283), done.


In [2]:
cd summer_project/DeepBSDE_pytorch/

/content/summer_project/DeepBSDE_pytorch


In [0]:
import logging
import time
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.autograd import Variable
from torch.distributions import uniform
import torchvision
from equation_pytorch import get_equation
from config_pytorch import get_config


MOMENTUM = 0.99
EPSILON = 1e-6
DELTA_CLIP = 50.0


class Net(nn.Module):
    def __init__(self,num_hiddens):
      super(Net, self).__init__()
      self.num_hiddens = num_hiddens
      
      self.fc1 = nn.Linear(num_hiddens[0], num_hiddens[1], bias=False)
      self.fc2 = nn.Linear(num_hiddens[1], num_hiddens[2], bias=False)
      self.fc3 = nn.Linear(num_hiddens[2], num_hiddens[3], bias=False)
    
    def forward(self, x):
      # h1 = relu(xw1)
      x = F.relu(self.fc1(x))
      # h2 = relu(h1w2)
      x = F.relu(self.fc2(x))
      # h3 = h2w3
      x = self.fc3(x)
      #termin time
      return x
      

class DeepNet(nn.Module):
    def __init__(self,num_hiddens,config,bsde):
      super(DeepNet, self).__init__()
      self.num_hiddens = num_hiddens
      self._config = config
      self._bsde = bsde
      
      # make sure consistent with FBSDE equation
      self._dim = bsde.dim
      self._num_time_interval = bsde.num_time_interval
      # ops for statistics update of batch normalization
      self.linears = nn.ModuleList([Net(num_hiddens) for i in range(bsde.num_time_interval - 1)])
    
    def forward(self,x):
      #dw_train= torch.from_numpy(self._bsde.sample()[0])
      dw_train= self._bsde.sample()[0].astype(np.float32)
      time_stamp = np.arange(0, self._bsde.num_time_interval) * self._bsde.delta_t
      m = uniform.Uniform(self._config.y_init_range[0],self._config.y_init_range[1])
      y = Variable(m.sample()) #initial
      z = 2*torch.rand([self._dim,1],dtype=torch.float32) - 1 #same as the original
      for t in range(0,bsde.num_time_interval - 1):
        dw = torch.from_numpy(dw_train[:, t]).view(1,100)
        torch.mm(dw.float(), z.float()) # torch.mm have bug use x.float()
        y = y - self._bsde.delta_t* (self._bsde.f_tf(time_stamp[t], x[:, t], y, z)) + torch.mm(dw, z)
        z = (self.linears[t](x[:,t]) / self._dim).view(100,1)
      #terminal condition
      dw = torch.from_numpy(dw_train[:, -1]).view(1,100)
      y = y - self._bsde.delta_t * (self._bsde.f_tf(time_stamp[-1], x[:, -2], y, z)) + torch.mm(dw, z)
      return y


In [21]:
config = get_config('HJB')
bsde = get_equation('HJB', config.dim, config.total_time, config.num_time_interval)

deepNet = DeepNet(config.num_hiddens,config,bsde)
optimizer = optim.SGD(deepNet.parameters(), lr=0.001, momentum=MOMENTUM) # lr have some different wiht the original
train_loss = []
for epoch in range(config.num_iterations):
  x_ = bsde.sample()[1].astype(np.float32)
  x = torch.from_numpy(x_)
  out = deepNet(x)
  delta = out - bsde.g_tf(bsde.total_time, x[:, -1])
  loss = torch.mean(torch.where(torch.abs(delta) < DELTA_CLIP, torch.pow(delta,2),2 * DELTA_CLIP * torch.abs(delta) - DELTA_CLIP ** 2))
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  train_loss.append(loss.item())
  if epoch % 100 == 0 :
    print(epoch, loss.item(), out)
  if loss <= 0.001:
    print(epoch, loss.item(), out)
    break
    

0 10.143587112426758 tensor([[1.5323]], grad_fn=<AddBackward0>)
100 17.113849639892578 tensor([[0.5224]], grad_fn=<AddBackward0>)
200 1.6890392303466797 tensor([[3.3956]], grad_fn=<AddBackward0>)
300 3.0794756412506104 tensor([[2.7902]], grad_fn=<AddBackward0>)
400 7.320667743682861 tensor([[1.8213]], grad_fn=<AddBackward0>)
494 0.0006086149951443076 tensor([[4.6106]], grad_fn=<AddBackward0>)
