In [None]:
!pip install deep-kan

In [None]:
from deepkan import RBFKAN

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = RBFKAN(layers_hidden=[28 * 28, 64, 10])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(20):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:16<00:00, 57.15it/s, accuracy=0.906, loss=0.149] 


Epoch 1, Train Loss: 0.4723579772213883, Train Accuracy: 0.8596748400852878, Val Loss: 0.28278024633456567, Val Accuracy: 0.9156050955414012


100%|██████████| 938/938 [00:15<00:00, 58.91it/s, accuracy=0.969, loss=0.096] 


Epoch 2, Train Loss: 0.27193666877808853, Train Accuracy: 0.9211587153518124, Val Loss: 0.23299388688318667, Val Accuracy: 0.9289410828025477


100%|██████████| 938/938 [00:15<00:00, 59.63it/s, accuracy=0.906, loss=0.405] 


Epoch 3, Train Loss: 0.23325185359382172, Train Accuracy: 0.9313033049040512, Val Loss: 0.2762829776805867, Val Accuracy: 0.9084394904458599


100%|██████████| 938/938 [00:15<00:00, 58.76it/s, accuracy=0.938, loss=0.246] 


Epoch 4, Train Loss: 0.2116192927174985, Train Accuracy: 0.937450026652452, Val Loss: 0.19961319333476246, Val Accuracy: 0.9429737261146497


100%|██████████| 938/938 [00:15<00:00, 59.43it/s, accuracy=0.938, loss=0.189] 


Epoch 5, Train Loss: 0.18794896839849795, Train Accuracy: 0.9441131396588486, Val Loss: 0.19532995339435566, Val Accuracy: 0.9402866242038217


100%|██████████| 938/938 [00:15<00:00, 59.33it/s, accuracy=0.969, loss=0.114] 


Epoch 6, Train Loss: 0.16509816502489005, Train Accuracy: 0.9504097814498934, Val Loss: 0.19193759916801076, Val Accuracy: 0.9436703821656051


100%|██████████| 938/938 [00:15<00:00, 59.84it/s, accuracy=1, loss=0.0301]    


Epoch 7, Train Loss: 0.1511251903019909, Train Accuracy: 0.9539578891257996, Val Loss: 0.17961289003123618, Val Accuracy: 0.945859872611465


100%|██████████| 938/938 [00:15<00:00, 59.08it/s, accuracy=0.938, loss=0.281] 


Epoch 8, Train Loss: 0.14235043051495735, Train Accuracy: 0.9563732675906184, Val Loss: 0.1634675961100514, Val Accuracy: 0.9515326433121019


100%|██████████| 938/938 [00:15<00:00, 59.06it/s, accuracy=0.938, loss=0.249] 


Epoch 9, Train Loss: 0.12932713463433834, Train Accuracy: 0.9604044509594882, Val Loss: 0.15618378133809632, Val Accuracy: 0.9536226114649682


100%|██████████| 938/938 [00:15<00:00, 58.73it/s, accuracy=1, loss=0.00812]   


Epoch 10, Train Loss: 0.11820037258483136, Train Accuracy: 0.9632029584221748, Val Loss: 0.17273120564307756, Val Accuracy: 0.9499402866242038


100%|██████████| 938/938 [00:15<00:00, 59.18it/s, accuracy=0.969, loss=0.238] 


Epoch 11, Train Loss: 0.10913117716351012, Train Accuracy: 0.9660014658848614, Val Loss: 0.14859289467182, Val Accuracy: 0.9562101910828026


100%|██████████| 938/938 [00:15<00:00, 58.93it/s, accuracy=1, loss=0.0388]    


Epoch 12, Train Loss: 0.10387628554586949, Train Accuracy: 0.9674007196162047, Val Loss: 0.15212755397653244, Val Accuracy: 0.9561106687898089


100%|██████████| 938/938 [00:16<00:00, 58.62it/s, accuracy=0.938, loss=0.275] 


Epoch 13, Train Loss: 0.09536701055573248, Train Accuracy: 0.9703491471215352, Val Loss: 0.1631168847470574, Val Accuracy: 0.9522292993630573


100%|██████████| 938/938 [00:15<00:00, 58.92it/s, accuracy=0.938, loss=0.346] 


Epoch 14, Train Loss: 0.09163396457718538, Train Accuracy: 0.9711820362473348, Val Loss: 0.14687175662434737, Val Accuracy: 0.9588972929936306


100%|██████████| 938/938 [00:15<00:00, 59.28it/s, accuracy=0.969, loss=0.204] 


Epoch 15, Train Loss: 0.08519867767563967, Train Accuracy: 0.9723647388059702, Val Loss: 0.15936849887815013, Val Accuracy: 0.95203025477707


100%|██████████| 938/938 [00:16<00:00, 57.16it/s, accuracy=1, loss=0.0205]    


Epoch 16, Train Loss: 0.07962533615247559, Train Accuracy: 0.974680170575693, Val Loss: 0.14413821905892887, Val Accuracy: 0.959593949044586


100%|██████████| 938/938 [00:17<00:00, 55.13it/s, accuracy=0.969, loss=0.0896]


Epoch 17, Train Loss: 0.07356030021065366, Train Accuracy: 0.9765625, Val Loss: 0.16096238490523307, Val Accuracy: 0.9554140127388535


100%|██████████| 938/938 [00:17<00:00, 53.41it/s, accuracy=0.969, loss=0.191] 


Epoch 18, Train Loss: 0.069909372185764, Train Accuracy: 0.9778784648187633, Val Loss: 0.1489242538342632, Val Accuracy: 0.9583996815286624


100%|██████████| 938/938 [00:17<00:00, 53.90it/s, accuracy=1, loss=0.0293]    


Epoch 19, Train Loss: 0.06813243783931973, Train Accuracy: 0.9772454690831557, Val Loss: 0.14044109393900356, Val Accuracy: 0.9610867834394905


100%|██████████| 938/938 [00:16<00:00, 57.02it/s, accuracy=0.938, loss=0.121] 


Epoch 20, Train Loss: 0.06659237002563685, Train Accuracy: 0.9781783049040512, Val Loss: 0.15983872259499884, Val Accuracy: 0.9579020700636943


# It is faster than normal KANs with a speed of 60 it/s