In [None]:
import numpy as np
import torch
from kan import KAN, create_dataset
import matplotlib.pyplot as plt


# Target function for the regression demo.
def target_function(x, y):
    """Evaluate ``sin(x) * cos(y) + log(1 + x**2 + y**2)`` element-wise.

    A smooth synthetic 2-D test surface (it is not derived from the
    Navier-Stokes equations). It accepts NumPy arrays/scalars (as used for
    the plotting grid) as well as PyTorch tensors (as passed in by
    ``create_dataset``). When either argument is a tensor, the computation
    uses torch ops, so the result is a tensor with the tensor input's dtype
    and device.

    Args:
        x: First coordinate; scalar, ``np.ndarray`` or ``torch.Tensor``.
        y: Second coordinate; must broadcast against ``x``.

    Returns:
        Function values with the broadcast shape of ``x`` and ``y``: a
        ``torch.Tensor`` if either input is a tensor, otherwise NumPy output.
    """
    if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor):
        # Use torch ops directly rather than NumPy ufuncs, whose implicit
        # tensor->array conversion does not work for tensors that require
        # grad or live on a non-CPU device.
        ref = x if isinstance(x, torch.Tensor) else y
        x = torch.as_tensor(x, dtype=ref.dtype, device=ref.device)
        y = torch.as_tensor(y, dtype=ref.dtype, device=ref.device)
        return torch.sin(x) * torch.cos(y) + torch.log(1 + x**2 + y**2)
    return np.sin(x) * np.cos(y) + np.log(1 + x**2 + y**2)

# Dense 100x100 evaluation grid over [-3, 3] x [-3, 3]. The KAN itself is
# fitted on the random samples produced by `create_dataset`; this grid is
# used to evaluate and plot the learned surface.
x, y = np.linspace(-3, 3, 100), np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)
Z = target_function(X, Y)

# One (x, y) row per grid point, in the same row-major order as Z.ravel(),
# so that model outputs can be reshaped back onto the grid.
X_train = np.column_stack((X.ravel(), Y.ravel()))
Y_train = Z.ravel()

# float32 tensors for the model; targets get a trailing axis -> (N, 1).
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
Y_train_tensor = torch.tensor(Y_train, dtype=torch.float32)[:, None]

# Training/test data for the KAN: `create_dataset` samples `train_num` /
# `test_num` random points inside `ranges` and labels them with `_kan_target`.
def _kan_target(inputs):
    """Label a batch of KAN inputs (rows of (x, y)) with `target_function`.

    The columns are sliced with a list index (``[:, [0]]``) so the labels
    have shape (N, 1), matching the single-output KAN. With ``[:, 0]`` the
    labels would be (N,), and ``(N, 1) - (N,)`` broadcasts to (N, N) in
    torch. An element-wise loss would then compare every prediction with
    every label, which is wrong and also very expensive.
    """
    labels = target_function(inputs[:, [0]], inputs[:, [1]])
    # Guarantee a float32 tensor regardless of which backend produced labels.
    return torch.as_tensor(labels, dtype=torch.float32)


dataset = create_dataset(
    _kan_target,
    n_var=2,
    train_num=10000,
    test_num=1000,
    ranges=[(-3, 3), (-3, 3)],
)

# Initialize the KAN: 2 inputs -> hidden widths 20, 40, 10 -> 1 output.
# NOTE(review): this is a large KAN for a 2-variable target and is costly to
# train with full-batch LBFGS; a much narrower net may suffice — left as-is.
kan = KAN(width=[2, 20, 40, 10, 1], grid=5, k=3, noise_scale=0.1, seed=0)

# Train with LBFGS for 100 steps; `lamb` is pykan's regularization strength.
kan.train(dataset, opt='LBFGS', steps=100, lamb=0.01)

# Evaluate the trained KAN on the evaluation grid (no autograd needed).
# `kan.eval()` is deliberately not called: torch's `Module.eval()` is just
# `self.train(False)`, and on this model `train` is the KAN training routine
# (it is invoked as `kan.train(dataset, opt=..., steps=..., lamb=...)`), so
# `eval()` would not behave like it does on an ordinary `nn.Module`.
with torch.no_grad():
    predictions = kan(X_train_tensor).cpu().numpy()

# Back onto the 100x100 grid; rows of X_train follow Z.ravel() order.
Z_pred = predictions.reshape(X.shape)

# Quantitative check to go with the plots.
rmse = np.sqrt(np.mean((Z_pred - Z) ** 2))
print(f"Grid RMSE (KAN vs. target): {rmse:.4e}")

# Plot target and approximation side by side with a shared colour scale so
# the two surfaces are directly comparable.
vmin, vmax = Z.min(), Z.max()
fig = plt.figure(figsize=(12, 6))
ax1 = fig.add_subplot(121, projection='3d')
ax1.plot_surface(X, Y, Z, cmap='viridis', vmin=vmin, vmax=vmax)
ax1.set_title('Original Function')

ax2 = fig.add_subplot(122, projection='3d')
ax2.plot_surface(X, Y, Z_pred, cmap='viridis', vmin=vmin, vmax=vmax)
ax2.set_title('KAN Approximation')

for ax in (ax1, ax2):
    ax.set_xlabel('x')
    ax.set_ylabel('y')

plt.show()

train loss: 8.35e-01 | test loss: 8.31e-01 | reg: 4.93e+01 :   2%| | 2/100 [04:08<3:22:24, 123.93s/i