In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import KAN

# Load MNIST
# Each image is converted to a tensor, then Normalize((0.5,), (0.5,)) subtracts
# 0.5 and divides by 0.5 on the single channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Arguments shared by both dataset splits; only `train` differs between them.
mnist_kwargs = dict(root="./data", download=True, transform=transform)
trainset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
# The train=False split is used as the validation set in this script.
valset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

# Only the training loader shuffles; validation keeps the dataset order.
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# The list is presumably the layer widths (784 flattened pixels -> 64 -> 10);
# confirm against the KAN constructor, which is defined outside this file.
model = KAN([28 * 28, 64, 10])
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
# The `verbose` keyword is deprecated in recent PyTorch releases, so it is not
# passed here; learning-rate reductions are reported explicitly after each
# scheduler step in the epoch loop instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3)

for epoch in range(10):
    # Train
    model.train()
    # Metrics are accumulated per sample rather than per batch: batches can
    # differ in size (the final one is smaller when the dataset size is not a
    # multiple of the batch size), so averaging per-batch means would
    # over-weight the samples in the short batch.
    train_loss_sum = 0.0
    train_correct = 0
    train_seen = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            # Flatten each 28x28 image into a 784-element vector.
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            # Read each scalar back from the device once per step.
            batch_size = labels.size(0)
            batch_loss = loss.item()
            batch_correct = (output.argmax(dim=1) == labels).sum().item()

            # `criterion` returns the batch-mean loss, so scale it back to a
            # per-sample sum before accumulating.
            train_loss_sum += batch_loss * batch_size
            train_correct += batch_correct
            train_seen += batch_size
            pbar.set_postfix(loss=batch_loss, accuracy=batch_correct / batch_size)
    total_loss = train_loss_sum / train_seen
    total_accuracy = train_correct / train_seen

    # Validation
    model.eval()
    val_loss_sum = 0.0
    val_correct = 0
    val_seen = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            batch_size = labels.size(0)
            val_loss_sum += criterion(output, labels).item() * batch_size
            val_correct += (output.argmax(dim=1) == labels).sum().item()
            val_seen += batch_size
    val_loss = val_loss_sum / val_seen
    val_accuracy = val_correct / val_seen

    # Step the scheduler based on validation loss
    previous_lrs = [group["lr"] for group in optimizer.param_groups]
    scheduler.step(val_loss)
    # Report any learning-rate change made by the scheduler.
    for group_index, (old_lr, group) in enumerate(zip(previous_lrs, optimizer.param_groups)):
        if group["lr"] != old_lr:
            print(f"Epoch {epoch + 1}: reducing learning rate of group {group_index} from {old_lr:.4e} to {group['lr']:.4e}.")

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:28<00:00, 32.74it/s, accuracy=0.969, loss=0.187]


Epoch 1, Train Loss: 0.39985033546461224, Train Accuracy: 0.8797641257995735, Val Loss: 0.254256592203335, Val Accuracy: 0.9242635350318471


100%|██████████| 938/938 [00:28<00:00, 32.97it/s, accuracy=0.969, loss=0.0618]


Epoch 2, Train Loss: 0.21326004391683062, Train Accuracy: 0.937450026652452, Val Loss: 0.1866328773907368, Val Accuracy: 0.9432722929936306


100%|██████████| 938/938 [00:28<00:00, 32.61it/s, accuracy=1, loss=0.0466]    


Epoch 3, Train Loss: 0.1537492910165713, Train Accuracy: 0.9541244669509595, Val Loss: 0.1872036379767926, Val Accuracy: 0.9430732484076433


100%|██████████| 938/938 [00:28<00:00, 33.34it/s, accuracy=0.938, loss=0.361] 


Epoch 4, Train Loss: 0.12218124231101195, Train Accuracy: 0.9638359541577826, Val Loss: 0.1298399512117705, Val Accuracy: 0.9590963375796179


100%|██████████| 938/938 [00:28<00:00, 33.18it/s, accuracy=0.906, loss=0.283] 


Epoch 5, Train Loss: 0.1010753822310377, Train Accuracy: 0.9693663379530917, Val Loss: 0.1297462544593794, Val Accuracy: 0.9600915605095541


100%|██████████| 938/938 [00:28<00:00, 32.97it/s, accuracy=1, loss=0.0078]    


Epoch 6, Train Loss: 0.08415830911530345, Train Accuracy: 0.9740471748400853, Val Loss: 0.1182966953511261, Val Accuracy: 0.9639729299363057


100%|██████████| 938/938 [00:28<00:00, 32.45it/s, accuracy=1, loss=0.0133]    


Epoch 7, Train Loss: 0.06898970465495516, Train Accuracy: 0.9791277985074627, Val Loss: 0.11010273604836926, Val Accuracy: 0.9659633757961783


100%|██████████| 938/938 [00:29<00:00, 31.67it/s, accuracy=0.969, loss=0.0766]


Epoch 8, Train Loss: 0.05821317545737007, Train Accuracy: 0.9815598347547975, Val Loss: 0.11406798961729547, Val Accuracy: 0.9656648089171974


100%|██████████| 938/938 [00:28<00:00, 32.68it/s, accuracy=0.969, loss=0.115] 


Epoch 9, Train Loss: 0.0482356712479752, Train Accuracy: 0.9848414179104478, Val Loss: 0.10855315757302159, Val Accuracy: 0.9675557324840764


100%|██████████| 938/938 [00:28<00:00, 32.86it/s, accuracy=1, loss=0.04]      


Epoch 10, Train Loss: 0.04032294703637566, Train Accuracy: 0.9876898987206824, Val Loss: 0.11066264958843103, Val Accuracy: 0.9679538216560509
