In [1]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from  torch.optim import lr_scheduler
import numpy as np
import matplotlib.pyplot as plt
from math import *
import time

In [2]:
torch.set_default_tensor_type('torch.DoubleTensor')

In [3]:
# activation function
def activation(x):
    return x * torch.sigmoid(x) 

In [4]:
# build ResNet with one blocks
class Net(nn.Module):
    def __init__(self,input_size,width):
        super(Net,self).__init__()
        self.layer_in = nn.Linear(input_size,width)
        self.layer_1 = nn.Linear(width,width)
        self.layer_2 = nn.Linear(width,width)
        self.layer_out = nn.Linear(width,1)
    def forward(self,x):
        output = self.layer_in(x)
        output = activation(self.layer_2(activation(self.layer_1(output)))) # residual block 1
        output = self.layer_out(output)
        return output

In [5]:
input_size = 1
width = 4
net = Net(input_size,width)

In [6]:
def model(x):
    return x * (x - 1.0) * net(x)

In [7]:
# exact solution
def u_ex(x):  
    return torch.sin(pi*x)

In [8]:
# f(x)
def f(x):
    return pi**2 * torch.sin(pi*x)

In [9]:
grid_num = 100
x = torch.zeros(grid_num + 1, input_size)
for index in range(grid_num + 1):
    x[index] = index * 1 / grid_num

In [10]:
optimizer = optim.Adam(net.parameters(), lr = 0.05)
scheduler = lr_scheduler.StepLR(optimizer, 50, 0.9) # every 100 epoch, learning rate * 0.1

In [11]:
# Xavier normal initialization for weights:
#             mean = 0 std = gain * sqrt(2 / fan_in + fan_out)
# zero initialization for biases
def initialize_weights(self):
    for m in self.modules():
        if isinstance(m,nn.Linear):
            nn.init.xavier_normal(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()

In [12]:
# loss function to DRM by auto differential
def loss_function(x):
    h = 1 / grid_num
    sum_0 = 0.0
    sum_1 = 0.0
    sum_2 = 0.0
    sum_a = 0.0
    sum_b = 0.0
    for index in range(grid_num):
        x_temp = x[index] + h / 2 
        x_temp.requires_grad = True
        grad_x_temp = torch.autograd.grad(outputs = model(x_temp), inputs = x_temp, grad_outputs = torch.ones(model(x_temp).shape), create_graph = True)
        sum_1 += (0.5*grad_x_temp[0]**2 - f(x_temp)[0]*model(x_temp)[0])
        
    for index in range(1, grid_num):
        x_temp = x[index]
        x_temp.requires_grad = True
        grad_x_temp = torch.autograd.grad(outputs = model(x_temp), inputs = x_temp, grad_outputs = torch.ones(model(x_temp).shape), create_graph = True)
        sum_2 += (0.5*grad_x_temp[0]**2 - f(x_temp)[0]*model(x_temp)[0])
    
    x_temp = x[0]
    x_temp.requires_grad = True
    grad_x_temp = torch.autograd.grad(outputs = model(x_temp), inputs = x_temp, grad_outputs = torch.ones(model(x_temp).shape), create_graph = True)
    sum_a = 0.5*grad_x_temp[0]**2 - f(x_temp)[0]*model(x_temp)[0]
    
    x_temp = x[grid_num]
    x_temp.requires_grad = True
    grad_x_temp = torch.autograd.grad(outputs = model(x_temp), inputs = x_temp, grad_outputs = torch.ones(model(x_temp).shape), create_graph = True)
    sum_a = 0.5*grad_x_temp[0]**2 - f(x_temp)[0]*model(x_temp)[0]
    
    sum_0 = h / 6 * (sum_a + 4 * sum_1 + 2 * sum_2 + sum_b)
    return sum_0

In [13]:
def error_function(x):
    error = 0.0
    for index in range(len(x)):
        x_temp = x[index]
        error += (model(x_temp)[0] - u_ex(x_temp)[0])**2
    return error / len(x)

In [14]:
print("Total number of paramerters in networks is {}  ".format(sum(x.numel() for x in net.parameters())))
param_num = sum(x.numel() for x in net.parameters())

Total number of paramerters in networks is 53  


In [15]:
epoch = 500
loss_record = np.zeros(epoch)
error_record = np.zeros(epoch)
time_start = time.time()
for i in range(epoch):
    optimizer.zero_grad()
    loss = loss_function(x)
    loss_record[i] = float(loss)
    error = error_function(x)
    error_record[i] = float(error)
    print("current epoch is: ", i)
    print("current loss is: ", loss.detach())
    print("current error is: ", error.detach())
     
    loss.backward()
    optimizer.step() 
    np.save("loss_of_DRM_100.npy", loss_record)
    np.save("error_of_DRM_100.npy", error_record)
    
np.save("loss_of_DRM_100.npy", loss_record)
np.save("error_of_DRM_100.npy", error_record)

time_end = time.time()
print('total time is: ', time_end-time_start, 'seconds')

current epoch is:  0
current loss is:  tensor([0.1107])
current error is:  tensor(0.5172)
current epoch is:  1
current loss is:  tensor([-0.1292])
current error is:  tensor(0.4691)
current epoch is:  2
current loss is:  tensor([-0.3930])
current error is:  tensor(0.4160)
current epoch is:  3
current loss is:  tensor([-0.6895])
current error is:  tensor(0.3562)
current epoch is:  4
current loss is:  tensor([-1.0245])
current error is:  tensor(0.2883)
current epoch is:  5
current loss is:  tensor([-1.3985])
current error is:  tensor(0.2120)
current epoch is:  6
current loss is:  tensor([-1.7958])
current error is:  tensor(0.1295)
current epoch is:  7
current loss is:  tensor([-2.1534])
current error is:  tensor(0.0522)
current epoch is:  8
current loss is:  tensor([-2.3235])
current error is:  tensor(0.0077)
current epoch is:  9
current loss is:  tensor([-2.1420])
current error is:  tensor(0.0310)
current epoch is:  10
current loss is:  tensor([-1.9792])
current error is:  tensor(0.0622)

current epoch is:  89
current loss is:  tensor([-2.4737])
current error is:  tensor(4.4538e-05)
current epoch is:  90
current loss is:  tensor([-2.4738])
current error is:  tensor(3.9802e-05)
current epoch is:  91
current loss is:  tensor([-2.4739])
current error is:  tensor(3.9812e-05)
current epoch is:  92
current loss is:  tensor([-2.4740])
current error is:  tensor(4.1115e-05)
current epoch is:  93
current loss is:  tensor([-2.4741])
current error is:  tensor(4.1811e-05)
current epoch is:  94
current loss is:  tensor([-2.4742])
current error is:  tensor(4.0972e-05)
current epoch is:  95
current loss is:  tensor([-2.4743])
current error is:  tensor(3.8580e-05)
current epoch is:  96
current loss is:  tensor([-2.4744])
current error is:  tensor(3.5609e-05)
current epoch is:  97
current loss is:  tensor([-2.4745])
current error is:  tensor(3.3546e-05)
current epoch is:  98
current loss is:  tensor([-2.4746])
current error is:  tensor(3.3555e-05)
current epoch is:  99
current loss is:  

current epoch is:  174
current loss is:  tensor([-2.4759])
current error is:  tensor(1.6052e-05)
current epoch is:  175
current loss is:  tensor([-2.4759])
current error is:  tensor(1.5756e-05)
current epoch is:  176
current loss is:  tensor([-2.4759])
current error is:  tensor(1.5498e-05)
current epoch is:  177
current loss is:  tensor([-2.4759])
current error is:  tensor(1.5311e-05)
current epoch is:  178
current loss is:  tensor([-2.4759])
current error is:  tensor(1.5197e-05)
current epoch is:  179
current loss is:  tensor([-2.4759])
current error is:  tensor(1.5136e-05)
current epoch is:  180
current loss is:  tensor([-2.4759])
current error is:  tensor(1.5100e-05)
current epoch is:  181
current loss is:  tensor([-2.4759])
current error is:  tensor(1.5064e-05)
current epoch is:  182
current loss is:  tensor([-2.4759])
current error is:  tensor(1.5027e-05)
current epoch is:  183
current loss is:  tensor([-2.4759])
current error is:  tensor(1.4995e-05)
current epoch is:  184
current

current epoch is:  259
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2799e-05)
current epoch is:  260
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2791e-05)
current epoch is:  261
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2782e-05)
current epoch is:  262
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2773e-05)
current epoch is:  263
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2765e-05)
current epoch is:  264
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2756e-05)
current epoch is:  265
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2747e-05)
current epoch is:  266
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2738e-05)
current epoch is:  267
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2729e-05)
current epoch is:  268
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2720e-05)
current epoch is:  269
current

current epoch is:  344
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2300e-05)
current epoch is:  345
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2295e-05)
current epoch is:  346
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2291e-05)
current epoch is:  347
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2286e-05)
current epoch is:  348
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2282e-05)
current epoch is:  349
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2277e-05)
current epoch is:  350
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2273e-05)
current epoch is:  351
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2268e-05)
current epoch is:  352
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2264e-05)
current epoch is:  353
current loss is:  tensor([-2.4760])
current error is:  tensor(1.2259e-05)
current epoch is:  354
current

current epoch is:  429
current loss is:  tensor([-2.4760])
current error is:  tensor(1.1960e-05)
current epoch is:  430
current loss is:  tensor([-2.4760])
current error is:  tensor(1.1957e-05)
current epoch is:  431
current loss is:  tensor([-2.4760])
current error is:  tensor(1.1953e-05)
current epoch is:  432
current loss is:  tensor([-2.4760])
current error is:  tensor(1.1950e-05)
current epoch is:  433
current loss is:  tensor([-2.4760])
current error is:  tensor(1.1946e-05)
current epoch is:  434
current loss is:  tensor([-2.4760])
current error is:  tensor(1.1943e-05)
current epoch is:  435
current loss is:  tensor([-2.4760])
current error is:  tensor(1.1939e-05)
current epoch is:  436
current loss is:  tensor([-2.4760])
current error is:  tensor(1.1936e-05)
current epoch is:  437
current loss is:  tensor([-2.4760])
current error is:  tensor(1.1932e-05)
current epoch is:  438
current loss is:  tensor([-2.4760])
current error is:  tensor(1.1929e-05)
current epoch is:  439
current

In [16]:
torch.save(net.state_dict(), 'net_params_DRM.pkl')