In [2]:
!pip install deep-kan

Collecting deep-kan
  Downloading Deep_KAN-0.0.2-py3-none-any.whl.metadata (3.8 kB)
Downloading Deep_KAN-0.0.2-py3-none-any.whl (4.5 kB)
Installing collected packages: deep-kan
Successfully installed deep-kan-0.0.2


# You can find this package on 
https://pypi.org/project/Deep-KAN/

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from deepkan import SplineLinearLayer

# Define the custom KAN layer
class KANLayer(nn.Module):
    def __init__(self, in_features, out_features, num_knots=5, spline_order=3, noise_scale=0.1, base_scale=1.0, spline_scale=1.0, activation=nn.SiLU, grid_epsilon=0.02, grid_range=[-1, 1]):
        super(KANLayer, self).__init__()
        self.linear = SplineLinearLayer(in_features, out_features, num_knots, spline_order, noise_scale, base_scale, spline_scale, activation, grid_epsilon, grid_range)

    def forward(self, x):
        return self.linear(x)

# Define the combined model
class KANModel(nn.Module):
    def __init__(self):
        super(KANModel, self).__init__()
        self.kan_layer = KANLayer(784, 128)  # KAN layer with input size 784 (28x28) and output size 128
        self.linear_layer = nn.Linear(128, 10)  # Standard linear layer with input size 128 and output size 10 (number of classes)

    def forward(self, x):
        x = x.view(-1, 784)  # Flatten the input tensor
        x = self.kan_layer(x)
        x = F.relu(x)
        x = self.linear_layer(x)
        return x

# Load the MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize the model
model = KANModel()

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader, 0):
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

# Evaluate the model on the test dataset
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy on the test set: %.2f %%' % (100 * correct / total))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 28716751.82it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1060621.42it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 9069676.56it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2548565.72it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






[1,   100] loss: 0.935
[1,   200] loss: 0.408
[1,   300] loss: 0.349
[1,   400] loss: 0.324
[1,   500] loss: 0.290
[1,   600] loss: 0.273
[1,   700] loss: 0.260
[1,   800] loss: 0.238
[1,   900] loss: 0.224
[2,   100] loss: 0.213
[2,   200] loss: 0.191
[2,   300] loss: 0.176
[2,   400] loss: 0.186
[2,   500] loss: 0.162
[2,   600] loss: 0.173
[2,   700] loss: 0.159
[2,   800] loss: 0.151
[2,   900] loss: 0.144
[3,   100] loss: 0.127
[3,   200] loss: 0.121
[3,   300] loss: 0.125
[3,   400] loss: 0.118
[3,   500] loss: 0.101
[3,   600] loss: 0.121
[3,   700] loss: 0.114
[3,   800] loss: 0.110
[3,   900] loss: 0.104
[4,   100] loss: 0.083
[4,   200] loss: 0.077
[4,   300] loss: 0.089
[4,   400] loss: 0.091
[4,   500] loss: 0.085
[4,   600] loss: 0.097
[4,   700] loss: 0.090
[4,   800] loss: 0.079
[4,   900] loss: 0.080
[5,   100] loss: 0.062
[5,   200] loss: 0.075
[5,   300] loss: 0.067
[5,   400] loss: 0.062
[5,   500] loss: 0.058
[5,   600] loss: 0.067
[5,   700] loss: 0.058
[5,   800] 

# The combination of KAN + linear layer works really well. I guess, for problems where accuracy is crucial, the combination of KAN and a linear layer will be a suitable choice.