In [1]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from  torch.optim import lr_scheduler
import numpy as np
import matplotlib.pyplot as plt
from math import *
import time

In [2]:
torch.set_default_tensor_type('torch.DoubleTensor')

In [3]:
# activation function
def activation(x):
    return x * torch.sigmoid(x) 

In [4]:
# build ResNet with one blocks
class Net(nn.Module):
    def __init__(self,input_size,width):
        super(Net,self).__init__()
        self.layer_in = nn.Linear(input_size,width)
        self.layer_1 = nn.Linear(width,width)
        self.layer_2 = nn.Linear(width,width)
        self.layer_out = nn.Linear(width,1)
    def forward(self,x):
        output = self.layer_in(x)
        output = activation(self.layer_2(activation(self.layer_1(output)))) # residual block 1
        output = self.layer_out(output)
        return output

In [5]:
input_size = 1
width = 4
net = Net(input_size,width)

In [6]:
def model(x):
    return x * (x - 1.0) * net(x)

In [7]:
# exact solution
def u_ex(x):  
    return torch.sin(pi*x)

In [8]:
# f(x)
def f(x):
    return pi**2 * torch.sin(pi*x)

In [9]:
grid_num = 100
x = torch.zeros(grid_num + 1, input_size)
for index in range(grid_num + 1):
    x[index] = index * 1 / grid_num

In [10]:
optimizer = optim.Adam(net.parameters(), lr = 0.05)
scheduler = lr_scheduler.StepLR(optimizer, 50, 0.9) # every 100 epoch, learning rate * 0.1

In [11]:
# Xavier normal initialization for weights:
#             mean = 0 std = gain * sqrt(2 / fan_in + fan_out)
# zero initialization for biases
def initialize_weights(self):
    for m in self.modules():
        if isinstance(m,nn.Linear):
            nn.init.xavier_normal(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()

In [12]:
# loss function to DGM by auto differential
def loss_function(x):
    h = 1 / grid_num
    sum_0 = 0.0
    sum_1 = 0.0
    sum_2 = 0.0
    sum_a = 0.0
    sum_b = 0.0
    for index in range(grid_num):
        x_temp = x[index] + h / 2 
        x_temp.requires_grad = True
#         grad_x_temp = torch.autograd.grad(model(x_temp), x_temp, create_graph = True)
        grad_x_temp = torch.autograd.grad(outputs = model(x_temp), inputs = x_temp, grad_outputs = torch.ones(model(x_temp).shape), create_graph = True)
        grad_grad_x_temp = torch.autograd.grad(outputs = grad_x_temp[0], inputs = x_temp, grad_outputs = torch.ones(model(x_temp).shape), create_graph = True)
        sum_1 += ((grad_grad_x_temp[0])[0] + f(x_temp)[0])**2
    
    for index in range(1, grid_num):
        x_temp = x[index]
        x_temp.requires_grad = True
#         grad_x_temp = torch.autograd.grad(model(x_temp), x_temp, create_graph = True)
        grad_x_temp = torch.autograd.grad(outputs = model(x_temp), inputs = x_temp, grad_outputs = torch.ones(model(x_temp).shape), create_graph = True)
        grad_grad_x_temp = torch.autograd.grad(outputs = grad_x_temp[0], inputs = x_temp, grad_outputs = torch.ones(model(x_temp).shape), create_graph = True)
        sum_2 += ((grad_grad_x_temp[0])[0] + f(x_temp)[0])**2
    
    x_temp = x[0]
    x_temp.requires_grad = True
#     grad_x_temp = torch.autograd.grad(model(x_temp), x_temp, create_graph = True)
    grad_x_temp = torch.autograd.grad(outputs = model(x_temp), inputs = x_temp, grad_outputs = torch.ones(model(x_temp).shape), create_graph = True)
    grad_grad_x_temp = torch.autograd.grad(outputs = grad_x_temp[0], inputs = x_temp, grad_outputs = torch.ones(model(x_temp).shape), create_graph = True)
    sum_a = ((grad_grad_x_temp[0])[0] + f(x_temp)[0])**2
    
    x_temp = x[grid_num]
    x_temp.requires_grad = True
#     grad_x_temp = torch.autograd.grad(model(x_temp), x_temp, create_graph = True)
    grad_x_temp = torch.autograd.grad(outputs = model(x_temp), inputs = x_temp, grad_outputs = torch.ones(model(x_temp).shape), create_graph = True)
    grad_grad_x_temp = torch.autograd.grad(outputs = grad_x_temp[0], inputs = x_temp, grad_outputs = torch.ones(model(x_temp).shape), create_graph = True)
    sum_b = ((grad_grad_x_temp[0])[0] + f(x_temp)[0])**2
    
    sum_0 = h / 6 * (sum_a + 4 * sum_1 + 2 * sum_2 + sum_b)
    return sum_0

In [13]:
def error_function(x):
    error = 0.0
    for index in range(len(x)):
        x_temp = x[index]
        error += (model(x_temp)[0] - u_ex(x_temp)[0])**2
    return error / len(x)

In [14]:
print("Total number of paramerters in networks is {}  ".format(sum(x.numel() for x in net.parameters())))
param_num = sum(x.numel() for x in net.parameters())

Total number of paramerters in networks is 53  


In [15]:
epoch = 500
loss_record = np.zeros(epoch)
error_record = np.zeros(epoch)
time_start = time.time()
for i in range(epoch):
    optimizer.zero_grad()
    loss = loss_function(x)
    loss_record[i] = float(loss)
    error = error_function(x)
    error_record[i] = float(error)
    print("current epoch is: ", i)
    print("current loss is: ", loss.detach())
    print("current error is: ", error.detach())
     
    loss.backward()
    optimizer.step() 
    np.save("loss_of_DGM_100.npy", loss_record)
    np.save("error_of_DGM_100.npy", error_record)
    
np.save("loss_of_DGM_100.npy", loss_record)
np.save("error_of_DGM_100.npy", error_record)

time_end = time.time()
print('total time is: ', time_end-time_start, 'seconds')

current epoch is:  0
current loss is:  tensor(46.2951)
current error is:  tensor(0.4705)
current epoch is:  1
current loss is:  tensor(42.8174)
current error is:  tensor(0.4348)
current epoch is:  2
current loss is:  tensor(39.7662)
current error is:  tensor(0.4031)
current epoch is:  3
current loss is:  tensor(36.9814)
current error is:  tensor(0.3740)
current epoch is:  4
current loss is:  tensor(34.4455)
current error is:  tensor(0.3472)
current epoch is:  5
current loss is:  tensor(32.1860)
current error is:  tensor(0.3230)
current epoch is:  6
current loss is:  tensor(30.2226)
current error is:  tensor(0.3018)
current epoch is:  7
current loss is:  tensor(28.5502)
current error is:  tensor(0.2836)
current epoch is:  8
current loss is:  tensor(27.1301)
current error is:  tensor(0.2680)
current epoch is:  9
current loss is:  tensor(25.8370)
current error is:  tensor(0.2535)
current epoch is:  10
current loss is:  tensor(24.5282)
current error is:  tensor(0.2388)
current epoch is:  1

current epoch is:  91
current loss is:  tensor(0.0427)
current error is:  tensor(2.7050e-06)
current epoch is:  92
current loss is:  tensor(0.0403)
current error is:  tensor(1.7624e-06)
current epoch is:  93
current loss is:  tensor(0.0384)
current error is:  tensor(1.6362e-06)
current epoch is:  94
current loss is:  tensor(0.0366)
current error is:  tensor(1.1408e-06)
current epoch is:  95
current loss is:  tensor(0.0358)
current error is:  tensor(4.1351e-07)
current epoch is:  96
current loss is:  tensor(0.0350)
current error is:  tensor(3.8486e-07)
current epoch is:  97
current loss is:  tensor(0.0329)
current error is:  tensor(7.7843e-07)
current epoch is:  98
current loss is:  tensor(0.0323)
current error is:  tensor(1.4135e-06)
current epoch is:  99
current loss is:  tensor(0.0319)
current error is:  tensor(1.4714e-06)
current epoch is:  100
current loss is:  tensor(0.0314)
current error is:  tensor(1.4548e-06)
current epoch is:  101
current loss is:  tensor(0.0305)
current error

current epoch is:  179
current loss is:  tensor(0.0127)
current error is:  tensor(2.0894e-07)
current epoch is:  180
current loss is:  tensor(0.0126)
current error is:  tensor(2.0983e-07)
current epoch is:  181
current loss is:  tensor(0.0125)
current error is:  tensor(2.0658e-07)
current epoch is:  182
current loss is:  tensor(0.0124)
current error is:  tensor(2.0701e-07)
current epoch is:  183
current loss is:  tensor(0.0123)
current error is:  tensor(2.1211e-07)
current epoch is:  184
current loss is:  tensor(0.0122)
current error is:  tensor(2.1337e-07)
current epoch is:  185
current loss is:  tensor(0.0121)
current error is:  tensor(2.0838e-07)
current epoch is:  186
current loss is:  tensor(0.0120)
current error is:  tensor(2.0441e-07)
current epoch is:  187
current loss is:  tensor(0.0119)
current error is:  tensor(2.0343e-07)
current epoch is:  188
current loss is:  tensor(0.0118)
current error is:  tensor(1.9891e-07)
current epoch is:  189
current loss is:  tensor(0.0117)
curr

current epoch is:  267
current loss is:  tensor(0.0069)
current error is:  tensor(8.3889e-08)
current epoch is:  268
current loss is:  tensor(0.0069)
current error is:  tensor(8.3155e-08)
current epoch is:  269
current loss is:  tensor(0.0068)
current error is:  tensor(8.2436e-08)
current epoch is:  270
current loss is:  tensor(0.0068)
current error is:  tensor(8.1687e-08)
current epoch is:  271
current loss is:  tensor(0.0067)
current error is:  tensor(8.0985e-08)
current epoch is:  272
current loss is:  tensor(0.0067)
current error is:  tensor(8.0332e-08)
current epoch is:  273
current loss is:  tensor(0.0067)
current error is:  tensor(7.9653e-08)
current epoch is:  274
current loss is:  tensor(0.0066)
current error is:  tensor(7.8974e-08)
current epoch is:  275
current loss is:  tensor(0.0066)
current error is:  tensor(7.8335e-08)
current epoch is:  276
current loss is:  tensor(0.0065)
current error is:  tensor(7.7679e-08)
current epoch is:  277
current loss is:  tensor(0.0065)
curr

current epoch is:  355
current loss is:  tensor(0.0039)
current error is:  tensor(4.1330e-08)
current epoch is:  356
current loss is:  tensor(0.0039)
current error is:  tensor(4.0988e-08)
current epoch is:  357
current loss is:  tensor(0.0038)
current error is:  tensor(4.0646e-08)
current epoch is:  358
current loss is:  tensor(0.0038)
current error is:  tensor(4.0306e-08)
current epoch is:  359
current loss is:  tensor(0.0038)
current error is:  tensor(3.9968e-08)
current epoch is:  360
current loss is:  tensor(0.0038)
current error is:  tensor(3.9630e-08)
current epoch is:  361
current loss is:  tensor(0.0037)
current error is:  tensor(3.9294e-08)
current epoch is:  362
current loss is:  tensor(0.0037)
current error is:  tensor(3.8960e-08)
current epoch is:  363
current loss is:  tensor(0.0037)
current error is:  tensor(3.8627e-08)
current epoch is:  364
current loss is:  tensor(0.0036)
current error is:  tensor(3.8295e-08)
current epoch is:  365
current loss is:  tensor(0.0036)
curr

current epoch is:  443
current loss is:  tensor(0.0018)
current error is:  tensor(1.5786e-08)
current epoch is:  444
current loss is:  tensor(0.0018)
current error is:  tensor(1.5563e-08)
current epoch is:  445
current loss is:  tensor(0.0018)
current error is:  tensor(1.5342e-08)
current epoch is:  446
current loss is:  tensor(0.0018)
current error is:  tensor(1.5123e-08)
current epoch is:  447
current loss is:  tensor(0.0018)
current error is:  tensor(1.4906e-08)
current epoch is:  448
current loss is:  tensor(0.0017)
current error is:  tensor(1.4692e-08)
current epoch is:  449
current loss is:  tensor(0.0017)
current error is:  tensor(1.4481e-08)
current epoch is:  450
current loss is:  tensor(0.0017)
current error is:  tensor(1.4272e-08)
current epoch is:  451
current loss is:  tensor(0.0017)
current error is:  tensor(1.4066e-08)
current epoch is:  452
current loss is:  tensor(0.0017)
current error is:  tensor(1.3862e-08)
current epoch is:  453
current loss is:  tensor(0.0017)
curr

In [16]:
torch.save(net.state_dict(), 'net_params_DGM.pkl')