In [None]:
# %%

import torch
import torch.nn as nn
from torch.func import jacrev, functional_call, vmap
import numpy as np
from pyDOE import lhs
from lm_train.network import DNN
from lm_train.training_module import training_LM
import matplotlib.pyplot as plt

In [None]:
# %%

class Net(nn.Module):
    """Single-hidden-layer MLP: Linear -> Sigmoid -> Linear (output layer has no bias).

    Args:
        input_size: number of input features.
        hidden_size: width of the hidden layer.
        output_size: number of output features.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Keep this attribute order: it fixes the parameter names/order
        # ('fc1.weight', 'fc1.bias', 'fc2.weight') seen by named_parameters().
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.act = nn.Sigmoid()
        self.fc2 = nn.Linear(hidden_size, output_size, bias=False)

    def forward(self, x):
        """Return fc2(sigmoid(fc1(x)))."""
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)


def model_u(data, params, module=None):
    """Evaluate a network statelessly with an explicit parameter dict.

    Args:
        data: input tensor, passed as the module's single positional argument.
        params: mapping of parameter name -> tensor, e.g.
            ``dict(model.named_parameters())``.
        module: the ``nn.Module`` to evaluate. Defaults to the notebook-level
            global ``model``, so existing two-argument calls keep working.

    Returns:
        The module's output computed with ``params`` substituted for the
        module's own parameters (via ``torch.func.functional_call``).
    """
    # Resolve the global lazily: ``model`` is only created in a later cell.
    if module is None:
        module = model
    return functional_call(module, params, (data, ))


def loss_target(params, *args, **kwargs):
    """Residual between the model prediction and a target.

    Args:
        params: parameter dict forwarded to ``model_u``.
        *args: exactly two positional items, ``(data, target)``.
        **kwargs: accepted but unused by this loss.

    Returns:
        ``model_u(data, params) - target``; same shape as ``target``.

    Raises:
        AssertionError: if the prediction and target shapes differ.
    """
    data, target = args
    prediction = model_u(data, params)
    assert prediction.shape == target.shape, 'The shape of output and target should match'
    return prediction - target


def exact(x):
    """Reference solution u(x) = sin(2*pi*x), evaluated elementwise on a tensor."""
    angular_freq = 2 * np.pi
    return torch.sin(angular_freq * x)

In [None]:
# %%

# Global configuration: numeric precision, device, problem/training sizes, RNG seed.
torch.set_default_dtype(torch.float64)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_points = 500
n_epoch = 1000

# Seed every RNG the notebook draws from so Restart & Run All is reproducible.
# The sampling cell presumably draws from NumPy's global RNG (pyDOE's lhs) —
# verify; nn.Linear initialization draws from torch's RNG.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# %%

# One-dimensional training points from lhs; overwrite the first and last
# samples with the end points 0 and 1 so both boundaries are included.
x = lhs(1, n_points)
x[0], x[-1] = 0, 1
x = torch.tensor(x, dtype=torch.float64, device=device)
y = exact(x)  # x is already on `device`, so y is created there too

model = Net(1, 100, 1).to(device)
params = dict(model.named_parameters())

# One (loss_fn, inputs, kwargs) triple per loss term, zipped for training_LM.
losses = [loss_target]
inputs = [[x, y]]
kwargs = [{} for _ in losses]
args = tuple(zip(losses, inputs, kwargs))

In [None]:
# %%

# Fit the network with lm_train's training_LM (presumably Levenberg–Marquardt,
# given the name and the `mu` value it logs). It returns the updated parameter
# dict plus three loss records — names suggest full/running/test histories; not verified here.
(params, lossval_all, loss_running, lossval_test) = training_LM(params, device, args, steps=n_epoch)

Step: 100. loss: 2.0473e-07. mu: 1.6678e-15.
Step: 200. loss: 8.2370e-11. mu: 5.6936e-15.
Step: 300. loss: 1.5429e-11. mu: 3.2396e-15.
Step: 400. loss: 9.1208e-13. mu: 1.8433e-15.
Step: 500. loss: 6.3698e-15. mu: 6.2928e-15.
Step: 600. loss: 3.6242e-16. mu: 3.5805e-15.
Step: 700. loss: 1.4452e-16. mu: 1.2224e-14.
Step: 800. loss: 7.3198e-17. mu: 6.9550e-15.
Step: 900. loss: 4.3752e-17. mu: 3.9573e-15.
Step: 1000. loss: 3.0130e-17. mu: 2.2516e-15.
training time: 3.1312501430511475 (s).


In [None]:
# %%

# Max-norm (L_inf) error of the trained network against the exact solution
# on a dense uniform test grid over [0, 1].
data_test = torch.linspace(0, 1, 10000, device=device).reshape(-1, 1)
# Evaluation only: no_grad avoids building an autograd graph through params
# (relevant if they still require grad after training).
with torch.no_grad():
    output = model_u(data_test, params)
target = exact(data_test)
# vector_norm flattens its input, so this is max|output - target| for any
# output width. torch.linalg.norm on this 2-D tensor would compute the matrix
# inf-norm (max row sum), which only coincides when there is a single column.
error = torch.linalg.vector_norm(output - target, ord=float('inf'))
print(f'The L_inf error is: {error:.4e}')

The L_inf error is: 3.0643e-08
