In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

class KANLayer(nn.Module):
    """Two-stage feed-forward block: Linear -> ReLU -> Linear.

    Note: despite the name, the block as written is an ordinary stack of
    two affine maps with a ReLU in between; it holds no learnable
    per-edge or spline activation.

    Args:
        in_features: Size of the last dimension of the input tensor.
        out_features: Size of the last dimension of the output tensor.
        hidden_units: Width of the intermediate representation.
    """

    def __init__(self, in_features, out_features, hidden_units=128):
        super().__init__()
        # Creation order (fc1, then fc2) is kept so that seeded parameter
        # initialisation and state_dict key order stay the same.
        self.fc1 = nn.Linear(in_features, hidden_units)
        self.fc2 = nn.Linear(hidden_units, out_features)
        # Non-linearity applied between the two affine maps.
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return ``fc2(activation(fc1(x)))``."""
        hidden = self.fc1(x)
        hidden = self.activation(hidden)
        return self.fc2(hidden)
    
class KANNet(nn.Module):
    """Regression network: two ``KANLayer`` blocks followed by a linear head.

    Feature widths go ``input_size -> 256 -> 128 -> output_size``; each
    ``KANLayer`` is built with its own default hidden width.

    Args:
        input_size: Size of the last dimension of the input tensor.
        output_size: Size of the last dimension of the returned tensor.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Submodules are created in the same order as they are applied.
        self.kan1 = KANLayer(input_size, 256)
        self.kan2 = KANLayer(256, 128)
        self.fc = nn.Linear(128, output_size)

    def forward(self, x):
        """Apply ``kan1``, ``kan2`` and ``fc`` in sequence and return the result."""
        for stage in (self.kan1, self.kan2, self.fc):
            x = stage(x)
        return x
    
# Synthetic regression target: 1000 evenly spaced samples of sin(x) on
# the closed interval [-2*pi, 2*pi].
x = np.linspace(start=-2 * np.pi, stop=2 * np.pi, num=1000)
y = np.sin(x)

# Float32 column tensors of shape (1000, 1): one sample per row, a single
# feature per sample.
X, Y = (
    torch.tensor(values, dtype=torch.float32).unsqueeze(1)
    for values in (x, y)
)

def train(model, optimizer, criterion, X, Y, epochs=1000):
    """Run ``epochs`` full-batch optimisation steps of ``model`` on ``(X, Y)``.

    Every step clears the gradients, evaluates ``criterion(model(X), Y)``,
    back-propagates and calls ``optimizer.step()``. The loss is printed on
    epochs 0, 100, 200, ... and nothing is returned.

    Args:
        model: Module being optimised; switched to training mode each step.
        optimizer: Optimiser that owns the parameters of ``model``.
        criterion: Callable mapping ``(prediction, target)`` to a scalar loss.
        X: Inputs passed to ``model`` in a single batch.
        Y: Targets compared against the model output.
        epochs: Number of optimisation steps to perform.
    """
    report_every = 100  # logging cadence, in epochs

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        loss = criterion(model(X), Y)
        loss.backward()
        optimizer.step()

        # Skip logging unless this epoch is a multiple of the cadence.
        if epoch % report_every != 0:
            continue
        print(f'Epoch {epoch}, Loss: {loss.item()}')
            
# Build the scalar-in / scalar-out network, its mean-squared-error
# objective, and an Adam optimiser over all of its parameters.
model = KANNet(input_size=1, output_size=1)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Fit on the whole dataset, relying on train()'s default epoch count.
train(model=model, optimizer=optimizer, criterion=criterion, X=X, Y=Y)

# Switch to evaluation mode and predict without recording gradients so
# the output can be converted straight to a NumPy array.
model.eval()
with torch.no_grad():
    fitted = model(X)
    predictions = fitted.numpy()

# Overlay the reference curve and the network output on one figure.
plt.figure(figsize=(10, 5))
plt.plot(x, y, label='True Function')
plt.plot(x, predictions, label='KAN Approximation')
plt.legend()
plt.show()