In [1]:
import torch
import numpy as np

from data import Interval, Function, Data
from model import FCNet
from solver import FunctionApproximator
from plotting import BasePlotter

# Run on the CPU; this device object is handed to FunctionApproximator below.
device = torch.device('cpu')

Function to approximate

$y(x) = \sin(x) + \cos(x^2)$

In [2]:
def f(x):
    """Target function y(x) = sin(x) + cos(x^2), applied elementwise.

    Args:
        x: torch tensor of sample points.

    Returns:
        Tensor of the same shape as ``x`` holding sin(x) + cos(x^2).
    """
    sine_term = torch.sin(x)
    cosine_term = torch.cos(torch.pow(x, 2))
    return sine_term + cosine_term

# Wrap the analytic target in the project's Function container (data.Function);
# it is later passed to Data as the reference `solution`.
func = Function(function=f)

In [None]:
# Seed the global RNGs before the model is built so that a fresh
# Restart & Run All reproduces the same run.
# NOTE(review): presumably FCNet falls back to PyTorch's default (random)
# layer initialisation when init_weights=False — confirm in model.FCNet.
# NumPy is seeded as well in case the project helpers draw from its global RNG.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Network dimensions: one input feature, one output feature.
# `hidden` is forwarded to FCNet as its second argument
# (presumably the hidden-layer widths — confirm in model.FCNet).
features_in = 1
features_out = 1
hidden = [100, 100, 100, 100, 100]

# activation='cas' is a project-defined option; see model.FCNet for its meaning.
model = FCNet(features_in, hidden, features_out, activation='cas', init_weights=False)

# Sample the interval [-2*pi, 2*pi] with steps=1000 and pair it with the
# analytic target `func` defined in the previous cell.
domain = Interval(-2*np.pi, 2*np.pi, steps=1000)
data = Data(domain, solution=func)

solver = FunctionApproximator(model, data, device=device)
solver.compile('adam', lr=1e-3)
# Alternative configuration (kept for reference, not executed):
# solver.compile('adam', 'onecycle', sch={'max_lr': 5e-4, 'total_steps': 10000})

# Train, keeping the returned loss history for the plotting cell below.
# save_gif=True plus create_gif() write animation output under 'gif_test4'
# (exact file layout is decided by the solver — confirm in solver.py).
losses = solver.train(num_epochs=10000, atol=1e-5, save_gif=True)
solver.create_gif(gif_save_path='gif_test4')

In [None]:
# Build a plotter bound to the trained solver, then draw the loss history
# returned by solver.train() followed by the numerical solution
# (presumably the network's fit over the domain — confirm in plotting.BasePlotter).
plot = BasePlotter(solver)
plot.losses(losses)
plot.numerical_solution()

In [None]:
# Spot-check at x = pi/2 (presumably the trained model's prediction there —
# confirm in solver.FunctionApproximator.evaluate).
# Analytic reference: sin(pi/2) + cos(pi^2 / 4) ~= 0.2188.
solver.evaluate(value=np.pi/2)