In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import alexnet
from torchvision import datasets
from collections import OrderedDict
from tqdm import tqdm

from fftKAN import NaiveFourierKANLayer

# Import the CIFAR-10 dataset
datasets.CIFAR10(root='./data', train=True, download=True)

# Create the comvolutions and avgpool adapted for a 32*32 input size

def get_model():
    features = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(64, 192, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(192, 384, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(384, 256, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2, stride=2)
    )
    avgpool = nn.AdaptiveAvgPool2d((4, 4))
    classifier_AlexNet = nn.Sequential(
        nn.Flatten(),
        nn.Dropout(),
        nn.Linear(256 * 4 * 4, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(),
        nn.Linear(4096, 4096),
        nn.ReLU(inplace=True),
        nn.Linear(4096, 10),
        nn.LogSoftmax(dim=1)
    )
    classifier_KAN = nn.Sequential(
        nn.Flatten(),
        nn.Dropout(),
        NaiveFourierKANLayer(256 * 4 * 4, 100, gridsize=7, smooth_initialization=True),
        nn.Dropout(),
        NaiveFourierKANLayer(100, 10, gridsize=5, smooth_initialization=True),
        nn.LogSoftmax(dim=1)
    )

    AlexNet_model = nn.Sequential(OrderedDict([
        ('features', features),
        ('avgpool', avgpool),
        ('classifier', classifier_AlexNet)
    ]))
    KAN_model = nn.Sequential(OrderedDict([
        ('features', features),
        ('avgpool', avgpool),
        ('classifier', classifier_KAN)
    ]))

    return AlexNet_model, KAN_model

Files already downloaded and verified


In [2]:
import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [3]:
AlexNet_model, KAN_model = get_model()

# Print the number of parameters of the full models
print(sum(p.numel() for p in AlexNet_model.parameters()), "\t AlexNet parameters")
print(sum(p.numel() for p in KAN_model.parameters()), "\t KAN parameters")

# Print the number of parameters of the classifiers
print(sum(p.numel() for p in AlexNet_model.classifier.parameters()), "\t AlexNet classifier parameters")
print(sum(p.numel() for p in KAN_model.classifier.parameters()), "\t KAN classifier parameters")

# Print the number of parameters of the features
print(sum(p.numel() for p in AlexNet_model.features.parameters()), "\t AlexNet features parameters")
print(sum(p.numel() for p in KAN_model.features.parameters()), "\t KAN features parameters")

35855178 	 AlexNet parameters
7996094 	 KAN parameters
33603594 	 AlexNet classifier parameters
5744510 	 KAN classifier parameters
2251584 	 AlexNet features parameters
2251584 	 KAN features parameters


In [43]:
AlexNet_model, _ = get_model()

size = 32
for layer in AlexNet_model.features:
    if isinstance(layer, nn.Conv2d):
        size = (size - layer.kernel_size[0] + 2 * layer.padding[0]) // layer.stride[0] + 1
        print("Conv to:", size)
    if isinstance(layer, nn.MaxPool2d):
        size = (size - layer.kernel_size) // layer.stride + 1
        print("Pool to:", size)
    

Conv to: 32
Pool to: 16
Conv to: 16
Pool to: 8
Conv to: 8
Conv to: 8
Conv to: 8
Pool to: 4


In [None]:
# Train the AlexNet model
model, _ = get_model()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
pbar = tqdm(range(10)) 
for epoch in pbar:
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        pbar.set_description(f'AlexNet Loss: {loss:02f}')
    torch.save(model.state_dict(), "alexnet_cifar10.pth")

torch.save(model.state_dict(), "alexnet_cifar10.pth")

In [21]:
# Load model
model, _ = get_model()
model.load_state_dict(torch.load("alexnet_cifar10.pth"))

# Accuracy of the AlexNet model
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')

# Accuracy of the AlexNet model per class
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print(f'Accuracy of {classes[i]}: {100 * class_correct[i] / class_total[i]}%')

Accuracy of the network on the 10000 test images: 79.92%
Accuracy of plane: 82.2%
Accuracy of car: 84.3%
Accuracy of bird: 75.3%
Accuracy of cat: 60.1%
Accuracy of deer: 75.3%
Accuracy of dog: 70.3%
Accuracy of frog: 89.2%
Accuracy of horse: 85.8%
Accuracy of ship: 88.1%
Accuracy of truck: 91.2%


In [None]:
alexnet, model = get_model()

# Use pretrained AlexNet model to initialize the KAN model
alexnet.load_state_dict(torch.load("alexnet_cifar10.pth"))
model.features = alexnet.features
model.avgpool = alexnet.avgpool

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
pbar = tqdm(range(10))
for epoch in pbar:
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        pbar.set_description(f'KAN Loss: {loss:02f}, weight mean: {model.classifier[2].fouriercoeffs.mean():02f}')
    torch.save(model.state_dict(), "kan_cifar10.pth")

torch.save(model.state_dict(), "kan_cifar10.pth")

In [None]:
torch.save(model.state_dict(), "kan_cifar10.pth")

In [None]:
# Load model
_, model = get_model()
model.load_state_dict(torch.load("kan_cifar10.pth"))

# Accuracy of the KAN model
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')