In [106]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import alexnet
from torchvision import datasets
from collections import OrderedDict
from tqdm import tqdm

from fftKAN import NaiveFourierKANLayer

# Import the CIFAR-10 dataset
datasets.CIFAR10(root='./data', train=True, download=True)

# Create the AlexNet model
model = alexnet()

# Create the comvolutions and avgpool adapted for a 32*32 input size

features = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 192, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(192, 384, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(384, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2)
)
avgpool = nn.AdaptiveAvgPool2d((4, 4))
classifier_AlexNet = nn.Sequential(
    nn.Flatten(),
    nn.Dropout(),
    nn.Linear(256 * 4 * 4, 4096),
    nn.ReLU(inplace=True),
    nn.Dropout(),
    nn.Linear(4096, 4096),
    nn.ReLU(inplace=True),
    nn.Linear(4096, 10),
    nn.LogSoftmax(dim=1)
)
classifier_KAN = nn.Sequential(
    nn.Flatten(),
    nn.Dropout(),
    NaiveFourierKANLayer(256 * 4 * 4, 100, gridsize=50),
    nn.Dropout(),
    NaiveFourierKANLayer(100, 10, gridsize=100),
    nn.LogSoftmax(dim=1)
)
AlexNet_model = nn.Sequential(OrderedDict([
    ('features', features),
    ('avgpool', avgpool),
    ('classifier', classifier_AlexNet)
]))
KAN_model = nn.Sequential(OrderedDict([
    ('features', features),
    ('avgpool', avgpool),
    ('classifier', classifier_KAN)
]))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Files already downloaded and verified


In [109]:
import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [121]:
tmp = next(iter(trainloader))[0][0][0]
tmp.shape

torch.Size([32, 32])

In [102]:
size = 32
for layer in AlexNet_model.features:
    if isinstance(layer, nn.Conv2d):
        size = (size - layer.kernel_size[0] + 2 * layer.padding[0]) // layer.stride[0] + 1
        print("Conv to:", size)
    if isinstance(layer, nn.MaxPool2d):
        size = (size - layer.kernel_size) // layer.stride + 1
        print("Pool to:", size)
    

Conv to: 32
Pool to: 16
Conv to: 16
Pool to: 8
Conv to: 8
Conv to: 8
Conv to: 8
Pool to: 4


In [None]:
# Train the AlexNet model
model = AlexNet_model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
pbar = tqdm(range(10)) 
for epoch in pbar:
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        pbar.set_description(f'AlexNet Loss: {loss:02f}')

In [108]:
model = KAN_model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
pbar = tqdm(range(10))
for epoch in pbar:
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        pbar.set_description(f'KAN Loss: {loss:02f}')

  0%|          | 0/10 [00:00<?, ?it/s]

tensor([0.9922, 0.9608, 0.9529, 0.9451, 0.9843, 0.7412, 0.5451, 0.9373, 0.9608,
        0.9608, 0.9765, 0.9529, 0.9137, 0.9529, 0.9765, 0.9686, 0.9686, 0.9765,
        0.9765, 0.9765, 0.9843, 0.9843, 0.9843, 0.9843, 0.9843, 0.9843, 0.9843,
        0.9843, 0.9843, 0.9843, 0.9765, 1.0000])


KAN Loss: nan:   0%|          | 0/10 [00:00<?, ?it/s]

tensor([0.3412, 0.3412, 0.3569, 0.3725, 0.3804, 0.4039, 0.4196, 0.4196, 0.4353,
        0.4353, 0.4431, 0.4353, 0.5059, 0.4196, 0.3961, 0.5137, 0.5137, 0.5137,
        0.4980, 0.4902, 0.3961, 0.3490, 0.5529, 0.4980, 0.5451, 0.4039, 0.3255,
        0.4353, 0.4196, 0.4353, 0.4510, 0.4118])


KAN Loss: nan:   0%|          | 0/10 [00:01<?, ?it/s]

tensor([0.3020, 0.2941, 0.3020, 0.3098, 0.3098, 0.3098, 0.3098, 0.3098, 0.3176,
        0.3176, 0.3176, 0.3176, 0.3176, 0.3176, 0.3176, 0.3255, 0.3255, 0.3255,
        0.3255, 0.3255, 0.3255, 0.3255, 0.3333, 0.3333, 0.3333, 0.3255, 0.3176,
        0.3176, 0.3176, 0.3098, 0.3098, 0.3020])


KAN Loss: nan:   0%|          | 0/10 [00:02<?, ?it/s]

tensor([0.7804, 0.7255, 0.7255, 0.7176, 0.7255, 0.7255, 0.7255, 0.7412, 0.7490,
        0.7412, 0.7412, 0.7490, 0.7490, 0.7490, 0.7569, 0.7647, 0.7333, 0.5843,
        0.4353, 0.4353, 0.4353, 0.3569, 0.4039, 0.4431, 0.5686, 0.6157, 0.5843,
        0.5137, 0.4902, 0.4824, 0.4353, 0.3176])


KAN Loss: nan:   0%|          | 0/10 [00:02<?, ?it/s]

tensor([0.2000, 0.1843, 0.1922, 0.1843, 0.1843, 0.1922, 0.1843, 0.1922, 0.2000,
        0.2078, 0.2157, 0.2235, 0.2157, 0.2000, 0.1765, 0.1843, 0.1922, 0.2000,
        0.2078, 0.2078, 0.2235, 0.2235, 0.2235, 0.2235, 0.2078, 0.2078, 0.2235,
        0.2314, 0.2392, 0.2314, 0.2392, 0.1922])


KAN Loss: nan:   0%|          | 0/10 [00:03<?, ?it/s]


KeyboardInterrupt: 