In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim, autograd
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Sizes passed to the DataSets.Square2D samplers on every training epoch in main().
# NOTE(review): presumably num_int_pts / num_bndry_pts are point counts for
# SmpPts_Interior_Square2D / SmpPts_Boundary_Square2D -- confirm against that module.
num_int_pts = 16
num_bndry_pts = 4
# Problem dimension passed to the samplers (main() builds the model with in_N = 2 inputs).
dim_prob = 2


class FcBlock(nn.Module):
    """Two-layer fully connected block with a tanh after each layer.

    Args:
        dim_in: feature size of the block's input.
        width: number of units in the block's hidden layer.
        dim_out: feature size of the block's output (not necessarily the
            size of the network's final output).
    """

    def __init__(self, dim_in, width, dim_out):
        super().__init__()
        self.dim_in, self.width, self.dim_out = dim_in, width, dim_out
        # Attribute names are kept exactly: they become state_dict keys.
        self.Linear1 = nn.Linear(dim_in, width)
        self.Linear2 = nn.Linear(width, dim_out)

    def forward(self, x):
        """Return tanh(Linear2(tanh(Linear1(x))))."""
        hidden = torch.tanh(self.Linear1(x))
        return torch.tanh(self.Linear2(hidden))

class FcNet(nn.Module):
    """Fully connected network: input Linear -> `depth` FcBlocks -> output Linear.

    Args:
        dim_in: feature size of the network input.
        width: feature size used by every hidden FcBlock.
        dim_out: feature size of the network output.
        depth: number of FcBlock(width, width, width) blocks placed between
            the input and output layers (default 1).

    The input and output Linear layers carry no activation; the only
    nonlinearity is the tanh inside each FcBlock.
    """

    def __init__(self, dim_in, width, dim_out, depth=1):
        super().__init__()
        self.dim_in = dim_in
        self.width = width
        self.dim_out = dim_out
        self.depth = depth

        # Layers are created in forward order, which also fixes the order in
        # which their default initialisation consumes the RNG.
        layers = [nn.Linear(dim_in, width)]
        layers.extend(FcBlock(width, width, width) for _ in range(depth))
        layers.append(nn.Linear(width, dim_out))
        # Stored as `self.stack` so state_dict keys stay `stack.<idx>.*`.
        self.stack = nn.ModuleList(layers)

    def forward(self, x):
        """Feed x through every layer in self.stack, in order (no skip connections)."""
        out = x
        for layer in self.stack:
            out = layer(out)
        return out


## Data Generation

from DataSets.Square2D import Sample_Points, Exact_Solution


# def get_interior_points(N=128,d=2):
#     """
#     randomly sample N points from interior of [-1,1]^d
#     """
#     return torch.rand(N,d) * 2 - 1

# def get_boundary_points(N=33):
#     index = torch.rand(N, 1)
#     index1 = torch.rand(N,1) * 2 - 1
#     xb1 = torch.cat((index, torch.zeros_like(index)), dim=1)
#     xb2 = torch.cat((index1, torch.ones_like(index1)), dim=1)
#     xb3 = torch.cat((index1, torch.full_like(index1, -1)), dim=1)
#     xb4 = torch.cat((torch.ones_like(index1), index1), dim=1)
#     xb5 = torch.cat((torch.full_like(index1, -1), index1), dim=1)
#     xb = torch.cat((xb1, xb2, xb3, xb4, xb5), dim=0)

#     return xb

def weights_init(m):
    """Re-initialise Conv2d / Linear layers in place; leave other modules untouched.

    Intended for ``model.apply(weights_init)``. Weights get Xavier-normal
    initialisation; biases, when the layer has one, are set to zero.

    Args:
        m: any ``nn.Module``; only ``nn.Conv2d`` and ``nn.Linear`` instances
            are modified.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.xavier_normal_(m.weight)
    # Layers built with bias=False have `bias = None`; skip them instead of
    # crashing inside nn.init.constant_.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)

def _plot_solution(model, device, n=1001):
    """Evaluate ``model`` on an ``n`` x ``n`` grid over [-1, 1]^2 and show a heatmap.

    Args:
        model: network mapping (num_points, 2) coordinates to one value per point.
        device: device the grid is moved to before evaluation.
        n: number of grid points per axis.
    """
    xs = torch.linspace(-1, 1, n)
    # Flattened row-major grid with x varying fastest: point k = i * n + j is
    # (x, y) = (xs[j], xs[i]), so after reshape(n, n) row i holds y = xs[i],
    # which matches imshow(origin='lower') below.
    grid = torch.stack((xs.repeat(n), xs.repeat_interleave(n)), dim=1).to(device)
    with torch.no_grad():
        pred = model(grid).cpu().numpy().reshape(n, n)

    plt.figure()
    ax = plt.subplot(1, 1, 1)
    h = plt.imshow(pred, interpolation='nearest', cmap='rainbow',
                   extent=[-1, 1, -1, 1],
                   origin='lower', aspect='auto')
    # Colour bar in its own axis to the right of the plot.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(h, cax=cax)
    plt.show()


def main(epochs=100, log_every=100, ckpt_path='new_best_deep_ritz1.mdl'):
    """Train an FcNet with a Ritz-style energy loss on a 2-D square, then plot it.

    Each epoch draws fresh interior / boundary points from ``Sample_Points``,
    minimises ``4 * mean(0.5 * |grad u|^2 - u) + 4500 * mean(u_b^2)`` and takes
    one Adam step. On logged epochs in the final fifth of training, the model
    with the smallest ``|loss|`` so far is saved to ``ckpt_path``; that
    checkpoint is reloaded for plotting.

    Args:
        epochs: index of the last epoch (epochs + 1 Adam steps are taken).
        log_every: positive interval at which losses are printed and
            checkpointing is considered; the final epoch is always included.
        ckpt_path: file the best model's state_dict is written to / read from.
    """
    in_N = 2
    m = 10
    out_N = 1
    # Loss weights from the original script. NOTE(review): 4 is presumably the
    # area of [-1, 1]^2 and 9 * 500 a boundary-penalty weight -- confirm.
    area_weight = 4
    bndry_weight = 9 * 500

    print(torch.cuda.is_available())
    # CPU is forced on purpose: on the author's machine CUDA ran slower than
    # CPU for this small problem. Switch to
    # torch.device('cuda' if torch.cuda.is_available() else 'cpu') to use a GPU.
    device = torch.device('cpu')
    model = FcNet(in_N, m, out_N).to(device)
    model.apply(weights_init)
    optimizer = optim.Adam(model.parameters(), lr=3e-3)
    print(model)

    # best_epoch stays None until a checkpoint has been written by this run.
    best_loss, best_epoch = float('inf'), None
    for epoch in range(epochs + 1):
        # Fresh collocation points every epoch.
        xr = Sample_Points.SmpPts_Interior_Square2D(num_int_pts, dim_prob).to(device)
        xb = Sample_Points.SmpPts_Boundary_Square2D(num_bndry_pts, dim_prob).to(device)
        xr.requires_grad_()

        output_r = model(xr)
        output_b = model(xb)
        # Gradient of u w.r.t. the interior inputs; create_graph so the loss
        # can be back-propagated through it.
        grads = autograd.grad(outputs=output_r, inputs=xr,
                              grad_outputs=torch.ones_like(output_r),
                              create_graph=True, retain_graph=True)[0]

        # Squeeze the trailing output dim so both terms have shape (N,);
        # (N,) - (N, 1) would silently broadcast to an (N, N) matrix.
        loss_r = torch.mean(0.5 * torch.sum(grads ** 2, dim=1) - output_r.squeeze(-1))
        loss_b = torch.mean(output_b ** 2)
        loss = area_weight * loss_r + bndry_weight * loss_b

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % log_every == 0 or epoch == epochs:
            print('epoch:', epoch, 'loss:', loss.item(),
                  'loss_r:', (area_weight * loss_r).item(),
                  'loss_b:', (bndry_weight * loss_b).item())
            # NOTE(review): the -u term lets the loss go negative, so ranking
            # by |loss| may not keep the lowest-energy model -- confirm intent.
            abs_loss = abs(loss.item())
            if epoch > int(4 * epochs / 5) and abs_loss < best_loss:
                best_loss = abs_loss
                best_epoch = epoch
                torch.save(model.state_dict(), ckpt_path)
    print('best epoch:', best_epoch, 'best loss:', best_loss)

    if best_epoch is None:
        # Nothing saved this run (e.g. epochs == 0 or a NaN loss): plot the
        # final model rather than loading a missing or stale checkpoint file.
        print('no checkpoint saved; plotting the final model')
    else:
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print('load from ckpt!')

    _plot_solution(model, device)


# Run training and plotting only when executed as a script, not on import.
if __name__ == '__main__':
    main()

ModuleNotFoundError: No module named 'DataSets'