## Test on a simple case
Consider the following Possion Equation
$$
\begin{cases}
    \Delta u = 1\qquad &u\in\Omega\\
    u = 0\qquad &u\in\partial\Omega.
\end{cases}$$
Here $\Omega = \{(x, y)|x^2+y^2 < 1\}$

The exact solution to this problem is $$u = \frac{1}{4}(x^2+y^2-1).$$

In [1]:
% matplotlib inline
import torch 
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR, MultiStepLR
import numpy as np
from math import *
import matplotlib.pyplot as plt
import matplotlib.cm as cm

torch.set_default_tensor_type('torch.FloatTensor')

class DeepRitzNet(torch.nn.Module):
    def __init__(self, m):
        super(DeepRitzNet, self).__init__()
        self.linear1 = torch.nn.Linear(m,m)
        self.linear2 = torch.nn.Linear(m,m)
        self.linear3 = torch.nn.Linear(m,m)
        self.linear4 = torch.nn.Linear(m,m)
        self.linear5 = torch.nn.Linear(m,m)
        self.linear6 = torch.nn.Linear(m,m)
        
        self.linear7 = torch.nn.Linear(m,1)
      
    def forward(self, x):
        y = x
        y = y + F.relu(self.linear2(F.relu(self.linear1(y))))
        y = y + F.relu(self.linear4(F.relu(self.linear3(y))))
        y = y + F.relu(self.linear6(F.relu(self.linear5(y))))
        output = F.relu(self.linear7(y))
        return output

In [2]:
def draw_graph(mod, m):
    points = np.arange(-1, 1, 0.01)
    xs, ys = np.meshgrid(points, points)
    xs = torch.tensor(xs)
    ys = torch.tensor(ys)
    xl, yl = xs.size()
    z = np.zeros((xl, yl))
    for i in range(xl):
        for j in range(yl):      
            re = np.zeros(m)
            re[0] = xs[i, j]
            re[1] = ys[i, j]
            re = torch.tensor(re)        
            z[i, j] = mod(re.float()).item() + U_groundtruth(re)
    
    plt.imshow(z, cmap=cm.hot)
    plt.colorbar()
    my_x_ticks = np.arange(-1, 1, 0.2)
    my_y_ticks = np.arange(-1, 1, 0.2)
    ax = plt.gca()
    ax.set_xticks(np.linspace(0,199,9))  
    ax.set_xticklabels(('-1', '-0.75', '-0.5', '-0.25', '0', '0.25', '0.5', '0.75', '1'))  
    ax.set_yticks(np.linspace(0,199,9))  
    ax.set_yticklabels( ('1', '0.75', '0.5', '0.25', '0','-0.25','-0.5','-0.75', '-1'))  
    
    plt.show()

In [3]:
def cal_loss(mod):
    points = np.arange(-1, 1, 0.1)
    xs, ys = np.meshgrid(points, points)
    xs = torch.tensor(xs)
    ys = torch.tensor(ys)
    xl, yl = xs.size()
    z = np.zeros((xl, yl))
    mmm = 0
    t = 0
    for i in range(xl):
        for j in range(yl):      
            re = np.zeros(m)
            re[0] = xs[i, j]
            re[1] = ys[i, j]
            re = torch.tensor(re)        
            z[i, j] = mod(re.float()).item() + U_groundtruth(re)
          
            if re[0] ** 2 + re[1] ** 2 < 1 : 
                mmm += abs(z[i, j])
                t += 1
    return mmm / t

In [4]:
def relative_err(mod):
    points = np.arange(-1, 1, 0.1)
    xs, ys = np.meshgrid(points, points)
    xs = torch.tensor(xs)
    ys = torch.tensor(ys)
    xl, yl = xs.size()
    z = np.zeros((xl, yl))
    w = np.zeros((xl, yl))
    t = 0
    for i in range(xl):
        for j in range(yl):      
            re = np.zeros(m)
            re[0] = xs[i, j]
            re[1] = ys[i, j]
            re = torch.tensor(re)
            if re[0] ** 2 + re[1] ** 2 < 1 :
                z[i, j] = mod(re.float()).item() + U_groundtruth(re)
                w[i, j] = U_groundtruth(re)
                t += 1
    z = z ** 2
    w = w ** 2
    return np.sum(z) / np.sum(w)

In [5]:
def U_groundtruth(t):
    re = (t[0] ** 2 + t[1] ** 2 - 1).item() / 4
    return re

In [6]:
def validate(mod):
    draw_graph(mod)
    print(cal_loss(mod))

In [7]:
m = 10
learning_rate = 0.01
iterations = 400  
print_every_iter = 100
beta = 500 #coefficient for the regularization term in the loss expression
n2 = 100  #number of points on the border of Omega
gamma = 10

In [8]:
"""
Train with the grid
从初始化模型开始训练
"""
model = DeepRitzNet(m)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
in_error_iter = [] 
on_error_iter = [] 
mm = 1
points = np.arange(-1, 1, 0.1)
xs, ys = np.meshgrid(points, points)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
xl, yl = xs.size()
                
for k in range(iterations):
    n1 = 0
    loss = torch.zeros(1)
    for i in range(xl):
        for j in range(yl):        
            x_input = np.zeros(m)
            x_input[0] = xs[i, j]
            x_input[1] = ys[i, j]
            if x_input[0] ** 2 + x_input[1] ** 2 < 1:
                n1 += 1
                x_input = torch.tensor(x_input).float()
                y = model(x_input)
                
                x1 = torch.zeros(m)
                x2 = torch.zeros(m)
                x1[0] = 0.0001
                x2[1] = 0.0001
                x_input_1 = x_input.float() + x1
                x_input_2 = x_input.float() + x2
                x_input_3 = x_input.float() - x1
                x_input_4 = x_input.float() - x2
                x_input_grad_1 = (model(x_input_1) - y) / 0.0001
                x_input_grad_2 = (model(x_input_2) - y) / 0.0001
                #x_input_2_grad_x = (model(x_input_1) + model(x_input_3) - 2 * y) / 0.0001**2
                #x_input_2_grad_y = (model(x_input_2) + model(x_input_4) - 2 * y) / 0.0001**2

                loss += 0.5 * ((x_input_grad_1) ** 2 + (x_input_grad_2) ** 2) - y 
    loss /= n1
    
    regularization = torch.zeros(1)
    for t in range(n2):
        theta = t / n2 * (2 * pi)
        x_input = np.zeros(m)
        x_input[0] = cos(theta)
        x_input[1] = sin(theta)
        x_input = torch.tensor(x_input).float()
        y = model(x_input)
        regularization += y**2 
    regularization *= mm / n2
    if gamma < 500:
        gamma = gamma * 1.01
    if mm < 500:
        mm = mm * 1.01
        
    #print loss
    print(k, " epoch, loss: ", loss.data[0].numpy())
    print(k, " epoch, regularization loss: ", regularization.data[0].numpy())
    print(k, " loss to real solution: ", cal_loss(model))
    if cal_loss(model) < 0.0001:
        break
    
    loss += regularization
    
    optimizer.zero_grad()
    loss.backward()
 
    optimizer.step()
    

(0, ' epoch, loss: ', array(-0.29168242, dtype=float32))
(0, ' epoch, regularization loss: ', array(0., dtype=float32))
(0, ' loss to real solution: ', 0.17055096900923078)
(1, ' epoch, loss: ', array(-0.36862096, dtype=float32))
(1, ' epoch, regularization loss: ', array(0.1791603, dtype=float32))
(1, ' loss to real solution: ', 0.24609436444147598)
(2, ' epoch, loss: ', array(-0.4216238, dtype=float32))
(2, ' epoch, regularization loss: ', array(0.20328452, dtype=float32))
(2, ' loss to real solution: ', 0.2978698494311698)
(3, ' epoch, loss: ', array(-0.46614802, dtype=float32))
(3, ' epoch, regularization loss: ', array(0.21573341, dtype=float32))
(3, ' loss to real solution: ', 0.34259766253627766)
(4, ' epoch, loss: ', array(-0.50315773, dtype=float32))
(4, ' epoch, regularization loss: ', array(0.22134157, dtype=float32))
(4, ' loss to real solution: ', 0.38142558749849)
(5, ' epoch, loss: ', array(-0.5357357, dtype=float32))
(5, ' epoch, regularization loss: ', array(0.22428359

(45, ' loss to real solution: ', 1.9364079941277323)
(46, ' epoch, loss: ', array(-1.2456374, dtype=float32))
(46, ' epoch, regularization loss: ', array(0.3325381, dtype=float32))
(46, ' loss to real solution: ', 1.8790879769877211)
(47, ' epoch, loss: ', array(-1.1813478, dtype=float32))
(47, ' epoch, regularization loss: ', array(0.2549832, dtype=float32))
(47, ' loss to real solution: ', 1.8006193380033844)
(48, ' epoch, loss: ', array(-1.1152852, dtype=float32))
(48, ' epoch, regularization loss: ', array(0.17873995, dtype=float32))
(48, ' loss to real solution: ', 1.7164430592374398)
(49, ' epoch, loss: ', array(-1.0149527, dtype=float32))
(49, ' epoch, regularization loss: ', array(0.1054661, dtype=float32))
(49, ' loss to real solution: ', 1.6163455158721212)
(50, ' epoch, loss: ', array(-0.9368763, dtype=float32))
(50, ' epoch, regularization loss: ', array(0.06610832, dtype=float32))
(50, ' loss to real solution: ', 1.5397809809005527)
(51, ' epoch, loss: ', array(-0.88548344

(91, ' epoch, loss: ', array(-1.0262425, dtype=float32))
(91, ' epoch, regularization loss: ', array(0.16117732, dtype=float32))
(91, ' loss to real solution: ', 1.7458025059860998)
(92, ' epoch, loss: ', array(-1.0351933, dtype=float32))
(92, ' epoch, regularization loss: ', array(0.14201707, dtype=float32))
(92, ' loss to real solution: ', 1.7360921124515043)
(93, ' epoch, loss: ', array(-1.0274839, dtype=float32))
(93, ' epoch, regularization loss: ', array(0.10830415, dtype=float32))
(93, ' loss to real solution: ', 1.7060215590046148)
(94, ' epoch, loss: ', array(-0.9951428, dtype=float32))
(94, ' epoch, regularization loss: ', array(0.07455225, dtype=float32))
(94, ' loss to real solution: ', 1.6613502590564286)
(95, ' epoch, loss: ', array(-0.96307635, dtype=float32))
(95, ' epoch, regularization loss: ', array(0.04323855, dtype=float32))
(95, ' loss to real solution: ', 1.6180594014354839)
(96, ' epoch, loss: ', array(-0.9288843, dtype=float32))
(96, ' epoch, regularization los

(135, ' loss to real solution: ', 1.6781469351570706)
(136, ' epoch, loss: ', array(-0.98478144, dtype=float32))
(136, ' epoch, regularization loss: ', array(0.02591628, dtype=float32))
(136, ' loss to real solution: ', 1.6633178200714074)
(137, ' epoch, loss: ', array(-1.0140753, dtype=float32))
(137, ' epoch, regularization loss: ', array(0.04399391, dtype=float32))
(137, ' loss to real solution: ', 1.6601606478031807)
(138, ' epoch, loss: ', array(-1.0378041, dtype=float32))
(138, ' epoch, regularization loss: ', array(0.0682447, dtype=float32))
(138, ' loss to real solution: ', 1.6607734294527978)
(139, ' epoch, loss: ', array(-1.0493684, dtype=float32))
(139, ' epoch, regularization loss: ', array(0.08181491, dtype=float32))
(139, ' loss to real solution: ', 1.657801106520404)
(140, ' epoch, loss: ', array(-1.0614793, dtype=float32))
(140, ' epoch, regularization loss: ', array(0.08204019, dtype=float32))
(140, ' loss to real solution: ', 1.6538413112339863)
(141, ' epoch, loss: '

(180, ' loss to real solution: ', 1.4775854820568843)
(181, ' epoch, loss: ', array(-0.7422626, dtype=float32))
(181, ' epoch, regularization loss: ', array(0.04801654, dtype=float32))
(181, ' loss to real solution: ', 1.4222683463794243)
(182, ' epoch, loss: ', array(-0.7747926, dtype=float32))
(182, ' epoch, regularization loss: ', array(0.09224334, dtype=float32))
(182, ' loss to real solution: ', 1.3712500398910312)
(183, ' epoch, loss: ', array(-0.7235521, dtype=float32))
(183, ' epoch, regularization loss: ', array(0.06545453, dtype=float32))
(183, ' loss to real solution: ', 1.2711229384673752)
(184, ' epoch, loss: ', array(-0.67779994, dtype=float32))
(184, ' epoch, regularization loss: ', array(0.01862003, dtype=float32))
(184, ' loss to real solution: ', 1.173525283282976)
(185, ' epoch, loss: ', array(-0.6272954, dtype=float32))
(185, ' epoch, regularization loss: ', array(0.00402478, dtype=float32))
(185, ' loss to real solution: ', 1.1008581864067206)
(186, ' epoch, loss: 

(225, ' epoch, loss: ', array(-0.7986981, dtype=float32))
(225, ' epoch, regularization loss: ', array(0.03842597, dtype=float32))
(225, ' loss to real solution: ', 1.3053167407688984)
(226, ' epoch, loss: ', array(-0.79543185, dtype=float32))
(226, ' epoch, regularization loss: ', array(0.03342449, dtype=float32))
(226, ' loss to real solution: ', 1.3155392878845187)
(227, ' epoch, loss: ', array(-0.77524346, dtype=float32))
(227, ' epoch, regularization loss: ', array(0.00930998, dtype=float32))
(227, ' loss to real solution: ', 1.3169264183895377)
(228, ' epoch, loss: ', array(-0.7256546, dtype=float32))
(228, ' epoch, regularization loss: ', array(0.00623149, dtype=float32))
(228, ' loss to real solution: ', 1.3325205779688918)
(229, ' epoch, loss: ', array(-0.7539542, dtype=float32))
(229, ' epoch, regularization loss: ', array(0.03239419, dtype=float32))
(229, ' loss to real solution: ', 1.3672041730060454)
(230, ' epoch, loss: ', array(-0.7527662, dtype=float32))
(230, ' epoch, 

(270, ' epoch, loss: ', array(-0.8366625, dtype=float32))
(270, ' epoch, regularization loss: ', array(0.02304845, dtype=float32))
(270, ' loss to real solution: ', 1.231481258313372)
(271, ' epoch, loss: ', array(-0.80028164, dtype=float32))
(271, ' epoch, regularization loss: ', array(0., dtype=float32))
(271, ' loss to real solution: ', 1.2258986771260043)
(272, ' epoch, loss: ', array(-1.0128347, dtype=float32))
(272, ' epoch, regularization loss: ', array(0.83980733, dtype=float32))
(272, ' loss to real solution: ', 1.3098067758129341)
(273, ' epoch, loss: ', array(-0.9339479, dtype=float32))
(273, ' epoch, regularization loss: ', array(0.47220874, dtype=float32))
(273, ' loss to real solution: ', 1.2459966147405923)
(274, ' epoch, loss: ', array(-0.71811813, dtype=float32))
(274, ' epoch, regularization loss: ', array(0., dtype=float32))
(274, ' loss to real solution: ', 1.1054645206537275)
(275, ' epoch, loss: ', array(-0.56678927, dtype=float32))
(275, ' epoch, regularization l

(316, ' epoch, loss: ', array(-0.7262386, dtype=float32))
(316, ' epoch, regularization loss: ', array(0., dtype=float32))
(316, ' loss to real solution: ', 1.0680515988425037)
(317, ' epoch, loss: ', array(-0.7324031, dtype=float32))
(317, ' epoch, regularization loss: ', array(0., dtype=float32))
(317, ' loss to real solution: ', 1.074569759736873)
(318, ' epoch, loss: ', array(-0.7510849, dtype=float32))
(318, ' epoch, regularization loss: ', array(0., dtype=float32))
(318, ' loss to real solution: ', 1.088378207856052)
(319, ' epoch, loss: ', array(-0.77522993, dtype=float32))
(319, ' epoch, regularization loss: ', array(0.01402318, dtype=float32))
(319, ' loss to real solution: ', 1.1087637730503384)
(320, ' epoch, loss: ', array(-0.8047321, dtype=float32))
(320, ' epoch, regularization loss: ', array(0.05333804, dtype=float32))
(320, ' loss to real solution: ', 1.1254754429614808)
(321, ' epoch, loss: ', array(-0.8085474, dtype=float32))
(321, ' epoch, regularization loss: ', arr

(360, ' loss to real solution: ', 1.4400833195390426)
(361, ' epoch, loss: ', array(-0.8645864, dtype=float32))
(361, ' epoch, regularization loss: ', array(0.01035445, dtype=float32))
(361, ' loss to real solution: ', 1.445635928789519)
(362, ' epoch, loss: ', array(-0.8569436, dtype=float32))
(362, ' epoch, regularization loss: ', array(0.00417053, dtype=float32))
(362, ' loss to real solution: ', 1.4470829373348948)
(363, ' epoch, loss: ', array(-0.85385424, dtype=float32))
(363, ' epoch, regularization loss: ', array(0.00399856, dtype=float32))
(363, ' loss to real solution: ', 1.4501787835148754)
(364, ' epoch, loss: ', array(-0.8525204, dtype=float32))
(364, ' epoch, regularization loss: ', array(0.00744982, dtype=float32))
(364, ' loss to real solution: ', 1.454697026888274)
(365, ' epoch, loss: ', array(-0.8431162, dtype=float32))
(365, ' epoch, regularization loss: ', array(0.00677718, dtype=float32))
(365, ' loss to real solution: ', 1.4566700903127432)
(366, ' epoch, loss: '

In [18]:
scheduler = MultiStepLR(optimizer, milestones=[400], gamma=0.1)
from time import *
start = time()              
for k in range(10):
    n1 = 0
    loss = torch.zeros(1)
    for i in range(xl):
        for j in range(yl):        
            x_input = np.zeros(m)
            x_input[0] = xs[i, j]
            x_input[1] = ys[i, j]
            if x_input[0] ** 2 + x_input[1] ** 2 < 1:
                n1 += 1
                x_input = torch.tensor(x_input).float()
                y = model(x_input)
                
                x1 = torch.zeros(m)
                x2 = torch.zeros(m)
                x1[0] = 0.0001
                x2[1] = 0.0001
                x_input_1 = x_input.float() + x1
                x_input_2 = x_input.float() + x2
                x_input_3 = x_input.float() - x1
                x_input_4 = x_input.float() - x2
                x_input_grad_1 = (model(x_input_1) - y) / 0.0001
                x_input_grad_2 = (model(x_input_2) - y) / 0.0001
                x_input_2_grad_x = (model(x_input_1) + model(x_input_3) - 2 * y) / 0.0001**2
                x_input_2_grad_y = (model(x_input_2) + model(x_input_4) - 2 * y) / 0.0001**2

                loss += 0.5 * ((x_input_grad_1) ** 2 + (x_input_grad_2) ** 2) - y 
    loss /= n1
    
    regularization = torch.zeros(1)
    for t in range(n2):
        theta = t / n2 * (2 * pi)
        x_input = np.zeros(m)
        x_input[0] = cos(theta)
        x_input[1] = sin(theta)
        x_input = torch.tensor(x_input).float()
        y = model(x_input)
        regularization += y**2 
    regularization *= mm / n2
    if gamma < 500:
        gamma = gamma * 1.01
    if mm < 500:
        mm = mm * 1.01
        
    print(k, " epoch, loss: ", loss.data[0].numpy())
    print(k, " epoch, regularization loss: ", regularization.data[0].numpy())
    print(k, " loss to real solution: ", cal_loss(model))
    if cal_loss(model) < 0.0001:
        break
    
    loss += regularization
    
    optimizer.zero_grad()
    loss.backward()
    
    scheduler.step()
    optimizer.step()
stop = time()
print(stop - start)

(0, ' epoch, loss: ', array(-0.7830908, dtype=float32))
(0, ' epoch, regularization loss: ', array(0.00349998, dtype=float32))
(0, ' loss to real solution: ', 1.4330135437223306)
(1, ' epoch, loss: ', array(-0.7879697, dtype=float32))
(1, ' epoch, regularization loss: ', array(0.00710402, dtype=float32))
(1, ' loss to real solution: ', 1.4351244984677367)
(2, ' epoch, loss: ', array(-0.7814055, dtype=float32))
(2, ' epoch, regularization loss: ', array(0.00064956, dtype=float32))
(2, ' loss to real solution: ', 1.4289390586541781)
(3, ' epoch, loss: ', array(-0.7718118, dtype=float32))
(3, ' epoch, regularization loss: ', array(0.00589034, dtype=float32))
(3, ' loss to real solution: ', 1.423790562283188)
(4, ' epoch, loss: ', array(-0.7633854, dtype=float32))
(4, ' epoch, regularization loss: ', array(0.00849598, dtype=float32))
(4, ' loss to real solution: ', 1.417541568585914)
(5, ' epoch, loss: ', array(-0.7479438, dtype=float32))
(5, ' epoch, regularization loss: ', array(0.001158

In [19]:
validate(the_model)

TypeError: draw_graph() takes exactly 2 arguments (1 given)

In [None]:
relative_err(the_model)

In [10]:
PATH = 'test_parameters.pkl'
torch.save(the_model.state_dict(), PATH)

NameError: name 'the_model' is not defined

In [13]:
m = 10
PATH = 'test_parameters.pkl'
the_model = DeepRitzNet(m)
the_model.load_state_dict(torch.load(PATH))

IOError: [Errno 2] No such file or directory: 'test_parameters.pkl'

In [14]:
train(the_model, initial_lr=0.001*learning_rate)

AttributeError: 'builtin_function_or_method' object has no attribute 'time'

In [15]:
def train(mod, initial_lr=learning_rate, milestones=[400], gamma=0.1, iterations=iterations, mm=1):
    optimizer = torch.optim.Adam(mod.parameters(), lr=initial_lr)
    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

    mm = 1
    points = np.arange(-1, 1, 0.1)
    xs, ys = np.meshgrid(points, points)
    xs = torch.tensor(xs)
    ys = torch.tensor(ys)
    xl, yl = xs.size()

    start = time.time()
    for k in range(iterations):
        n1 = 0
        loss = torch.zeros(1)
        for i in range(xl):
            for j in range(yl):        
                x_input = np.zeros(m)
                x_input[0] = xs[i, j]
                x_input[1] = ys[i, j]
                if x_input[0] ** 2 + x_input[1] ** 2 < 1:
                    n1 += 1
                    x_input = torch.tensor(x_input).float()
                    y = mod(x_input)

                    x1 = torch.zeros(m)
                    x2 = torch.zeros(m)
                    x1[0] = 0.0001
                    x2[1] = 0.0001
                    x_input_1 = x_input.float() + x1
                    x_input_2 = x_input.float() + x2
                    x_input_3 = x_input.float() - x1
                    x_input_4 = x_input.float() - x2
                    x_input_grad_1 = (mod(x_input_1) - y) / 0.0001
                    x_input_grad_2 = (mod(x_input_2) - y) / 0.0001
                    x_input_2_grad_x = (mod(x_input_1) + the_model(x_input_3) - 2 * y) / 0.0001**2
                    x_input_2_grad_y = (mod(x_input_2) + the_model(x_input_4) - 2 * y) / 0.0001**2

                    loss += 0.5 * ((x_input_grad_1) ** 2 + (x_input_grad_2) ** 2) + y
        loss /= n1

        regularization = torch.zeros(1)
        for t in range(n2):
            theta = t / n2 * (2 * pi)
            x_input = np.zeros(m)
            x_input[0] = cos(theta)
            x_input[1] = sin(theta)
            x_input = torch.tensor(x_input).float()
            y = mod(x_input)
            regularization += y**2 
        regularization *= mm / n2
        if gamma < 500:
            gamma = gamma * 1.01
        if mm < 500:
            mm = mm * 1.01

        #print loss
        print(k, " epoch, loss: ", loss.data[0].numpy())
        print(k, " epoch, regularization loss: ", regularization.data[0].numpy())
        print(k, " loss to real solution: ", cal_loss(mod))
        if cal_loss(the_model) < 0.0001:
            break

        loss += regularization

        optimizer.zero_grad()
        loss.backward()

        scheduler.step()
        optimizer.step()
    stop = time.time()
    print(stop - start)