In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# ========== PLOTTING ==========
# Overview figure: one panel per benchmark function, each drawn from QMC samples
# of that function's domain. The two 1D functions are shown as f(x) scatter plots,
# the two 2D functions as scatter plots coloured by function value.


def _lookup_dataset(key, position):
    """Return one entry of `datasets`, by name if it is a dict, else by position.

    The training cell addresses `datasets` by name (e.g. `datasets['rosenbrock']`),
    whereas this cell historically used positions 0-3. Supporting both keeps this
    cell runnable whichever container `datasets` actually is.
    """
    if isinstance(datasets, dict):
        return datasets[key]
    return datasets[position]


def _plot_1d_panel(ax, data, func, n_samples=100):
    """Scatter `func(x)` against x for `n_samples` QMC points in `data['domain']`."""
    xs = generate_qmc_samples(1, n_samples, data['domain'])[:, 0]
    ax.scatter(xs, func(xs), c='b', s=6, alpha=0.8)
    ax.set_title(data['name'])
    ax.set_xlabel('x')
    ax.set_ylabel('f(x)')
    ax.grid(True, alpha=0.3)


def _plot_2d_panel(ax, data, func, cmap, n_samples=10000, log_scale=False):
    """Scatter QMC points in the (x, y) plane, coloured by `func(x, y)`.

    With `log_scale=True` the colour is log10(f + 1) and the title says so; the
    +1 shift keeps the logarithm finite where f is zero.
    """
    samples = generate_qmc_samples(2, n_samples, data['domain'])
    xs, ys = samples[:, 0], samples[:, 1]
    values = func(xs, ys)
    title = data['name']
    if log_scale:
        values = jnp.log10(values + 1)
        title = title + ' (log scale)'
    points = ax.scatter(xs, ys, c=values, cmap=cmap, s=5)
    ax.set_title(title)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.figure.colorbar(points, ax=ax)


fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.flatten()

# 1. Heaviside step (1D)
_plot_1d_panel(axes[0], _lookup_dataset('heaviside', 0), f_heaviside)

# 2. sqrt(2-x) (1D)
_plot_1d_panel(axes[1], _lookup_dataset('gamma', 1), f_gamma)

# 3. 2D Gaussians
_plot_2d_panel(axes[2], _lookup_dataset('gaussians', 2), f_gaussians, cmap='viridis')

# 4. Rosenbrock (log-scaled colours)
_plot_2d_panel(axes[3], _lookup_dataset('rosenbrock', 3), f_rosenbrock,
               cmap='plasma', log_scale=True)

plt.tight_layout()
plt.show()

In [None]:
# ========== Main Execution ==========
# Trains every (architecture, activation) combination on one benchmark function,
# scores each on a separately generated QMC test set, reports the best setting,
# and pickles the trained parameters together with the data.
import jax  # local import so the activation table below does not depend on the `nn` alias

start_time = time.time()  # Start timing

architectures = {
    'shallow': lambda dim: [dim, 40, 1],  # Shallow network with one hidden layer
    'deep': lambda dim: [dim, 20, 20, 20, 1],  # Deep network with three hidden layers
}
# Referenced through jax explicitly rather than the bare `nn` alias: the plotting
# cell runs `import torch.nn as nn`, and torch.nn has no lowercase `relu`/`tanh`,
# so `nn.relu` would raise AttributeError on a fresh Restart & Run All.
# Only the keys are used below; train_model/predict receive the activation name.
activations = {
    'relu': jax.nn.relu,
    'tanh': jnp.tanh,
}

# Choose dataset. Other keys used in this notebook: 'heaviside', 'gamma', 'gaussians'.
selected_dataset = datasets['rosenbrock']

dim = selected_dataset['dim']
domain = selected_dataset['domain']
func = selected_dataset['func']
name = selected_dataset['name']

data_points = 100 if dim == 1 else 10000
test_points = 100

# Different seeds so the training and test point sets are distinct.
x_data = generate_qmc_samples(dim, data_points, domain, seed=0)
x_test = generate_qmc_samples(dim, test_points, domain, seed=1)
x_data_norm = normalize(x_data, domain)
x_test_norm = normalize(x_test, domain)

# 1D functions take a single coordinate array; higher-dimensional ones take one
# array per coordinate.
y_data = func(x_data[:, 0]) if dim == 1 else func(*x_data.T[:dim])
y_test = func(x_test[:, 0]) if dim == 1 else func(*x_test.T[:dim])

trained_models = {}
best_err = float('inf')
best_setting = None
for arch_name, arch_func in architectures.items():
    sizes = arch_func(dim)
    for act_name in activations:
        print(f"Training {name} with {arch_name} {act_name}")
        # Fit on the training set; the test set is used only for evaluation below.
        # (Previously this trained on the test points and passed the undefined
        # name `size`, so the reported errors were not held-out errors.)
        params, loss_history = train_model(x_data_norm, y_data, sizes, act_name, update_freq=1000)
        y_pred = predict(params, x_test_norm, act_name)

        # NOTE(review): assumes predict returns an array with the same shape as
        # y_test; if it returns (N, 1) the subtraction would broadcast to (N, N)
        # - confirm against predict's definition.
        abs_err = jnp.abs(y_pred - y_test)
        rel_err = abs_err / (jnp.abs(y_test) + 1e-10)  # epsilon guards division by zero
        mean_abs = float(jnp.mean(abs_err))
        mean_rel = float(jnp.mean(rel_err))
        print(f"{name} {arch_name} {act_name} - Mean Abs Err: {mean_abs:.6f}, Mean Rel Err: {mean_rel:.6f}")
        if mean_abs < best_err:
            best_err = mean_abs
            best_setting = (arch_name, act_name)
        trained_models[f"{arch_name}_{act_name}"] = {'params': params, 'y_pred': y_pred, 'x_data': x_data, 'y_data': y_data}

# Save trained models and data for visualization
with open(f"{name}_trained_models.pkl", 'wb') as f:
    pickle.dump(trained_models, f)

print(f"Best setting for {name}: {best_setting} with Mean Abs Err: {best_err:.6f}")
print("Comments: Shallow ReLU suits Heaviside's piecewise nature. Deep Tanh excels with smooth functions like Gaussians and Rosenbrock due to better non-linearity capture.")

end_time = time.time()
print(f"Training time: {end_time - start_time} seconds")