### Imports

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

### Define the Runge function

The goal is to approximate the Runge function on $[-1, 1]$ with a small neural network:
$$ f(x) = \frac{1}{1+25x^2} $$

### Data preparation

In [None]:
def runge(x):
    """Evaluate the Runge function f(x) = 1 / (1 + 25 x^2).

    Uses only arithmetic operators, so it accepts plain Python numbers as
    well as tensors (applied element-wise).
    """
    denominator = 1 + 25 * x**2
    return 1 / denominator

# Training and validation sets both sample the Runge function on evenly
# spaced grids over [-1, 1]. Inputs are column vectors of shape (N, 1),
# i.e. one feature per sample.
N_train = 10000
N_val = 1500

x_train = torch.linspace(-1, 1, N_train)[:, None]
x_val = torch.linspace(-1, 1, N_val)[:, None]

y_train, y_val = runge(x_train), runge(x_val)

### Network preparation

In [None]:
class Net(nn.Module):
    """Small fully connected tanh network mapping one input feature to one output.

    Architecture: Linear(1, hidden) -> Tanh -> Linear(hidden, hidden) -> Tanh
    -> Linear(hidden, 1).

    Args:
        hidden: Width of both hidden layers. Defaults to 6, the original
            hard-coded width, so ``Net()`` builds the same model as before.
    """

    def __init__(self, hidden=6):
        super().__init__()
        # Kept as a single Sequential named `layers` so state_dict keys
        # (layers.0.weight, layers.2.weight, ...) stay the same as before.
        self.layers = nn.Sequential(
            nn.Linear(1, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        """Apply the network; ``x``'s last dimension must be 1, output's is 1."""
        return self.layers(x)


model = Net()


### Training

In [None]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train_losses, val_losses = [], []
epochs = 7000
log_every = 500  # print progress every `log_every` epochs

# Full-batch training: each epoch is one Adam step on the whole training set.
# Note: the recorded train loss is measured *before* this epoch's update,
# while the validation loss is measured *after* it.
for epoch in range(epochs):
    optimizer.zero_grad()
    pred_train = model(x_train)
    loss_train = criterion(pred_train, y_train)
    loss_train.backward()
    optimizer.step()

    # Validation needs no gradients; no_grad avoids building (and then
    # discarding) an autograd graph for this forward pass every epoch.
    with torch.no_grad():
        pred_val = model(x_val)
        loss_val = criterion(pred_val, y_val)

    train_losses.append(loss_train.item())
    val_losses.append(loss_val.item())

    if epoch % log_every == 0:
        print(f"Epoch {epoch:4d} | Train Loss: {train_losses[-1]:.6f} | Val Loss: {val_losses[-1]:.6f}")

# Final summary, labelled with the actual epoch count rather than a literal.
# Guarded so that epochs == 0 does not reference undefined losses.
if train_losses:
    print(f"Epoch {epochs:4d} | Train Loss: {train_losses[-1]:.6f} | Val Loss: {val_losses[-1]:.6f}")

### Loss

In [None]:
# Plot per-epoch training and validation MSE on a log-scaled y-axis so that
# small late-training improvements remain visible.
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(val_losses, label="Validation Loss")
ax.plot(train_losses, label="Training Loss", linestyle="--")

ax.set_yscale("log")
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE loss")
ax.legend()
ax.set_title("Loss Curves")
plt.show()

### Result

In [None]:
# Evaluate the trained model on the validation grid. Inference runs under
# no_grad, so no autograd graph is built and y_pred needs no detach().
with torch.no_grad():
    y_pred = model(x_val)

fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(x_val, y_val, label="True function")
ax.plot(x_val, y_pred, label="Approximation", linestyle="--")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
ax.set_title("Runge Function Approximation")
plt.show()

# Summary metrics on the validation grid: mean squared error and the
# worst-case (max absolute) pointwise error.
mse = criterion(y_pred, y_val).item()
max_err = torch.max(torch.abs(y_pred - y_val)).item()

print(f"MSE: {mse:.6f}, Max Error: {max_err:.6f}")