In [1]:
from Visualization import DataVisual
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import traceback
import struct
import matplotlib.colors as mcolors
from kan import KAN

torch.set_default_dtype(torch.float32)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((14, 14), antialias=True),  # Antialiasing to preserve information
    transforms.Normalize((0.5,), (0.5,))
])

full_trainset = torchvision.datasets.MNIST(
    root="./Dataset", train=True, download=True, transform=transform
)
full_valset = torchvision.datasets.MNIST(
    root="./Dataset", train=False, download=True, transform=transform
)

In [3]:
def evaluate(testloader, model, criterion):
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in testloader:
            images = images.view(images.size(0), -1).to(device)
            output = model(images)
            val_loss += criterion(output, labels.to(device)).item()
            val_accuracy += (
                (output.argmax(dim=1) == labels.to(device)).float().mean().item()
            )
    val_loss /= len(testloader)
    val_accuracy /= len(testloader)
    return val_loss, val_accuracy

In [4]:
num_train = len(full_trainset)
indices = np.random.permutation(num_train)[:int(num_train / 100)]
trainset = Subset(full_trainset, indices)

num_val = len(full_valset)
val_indices = np.random.permutation(num_val)[:int(num_val / 100)]
valset = Subset(full_valset, val_indices)

print(f"Using {len(trainset)} training samples out of {num_train} total")
print(f"Using {len(valset)} validation samples out of {num_val} total")

#Variables
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

Using 600 training samples out of 60000 total
Using 100 validation samples out of 10000 total


In [None]:
input_size = 14*14  
model = KAN([input_size, 64, 10], ckpt_path="./efficient_model_checkpoint")
model = model.loadckpt("./efficient_model_checkpoint/2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)
criterion = nn.CrossEntropyLoss()

checkpoint directory created: ./efficient_model
saving model version 0.0


In [6]:
for epoch in range(10):
    model.train()
    with tqdm(trainloader) as pbar:
        for i, (images, labels) in enumerate(pbar):
            images = images.view(images.size(0), -1).to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels.to(device))
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])

    model.eval()
    val_loss, val_accuracy = evaluate(valloader, model, criterion)
    scheduler.step()
    print(f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")
    

100%|██████████| 10/10 [01:26<00:00,  8.60s/it, accuracy=0.917, loss=0.491, lr=0.001]


Epoch 1, Val Loss: 0.3205646425485611, Val Accuracy: 0.8819444477558136


100%|██████████| 10/10 [01:28<00:00,  8.88s/it, accuracy=0.875, loss=0.314, lr=0.0008]


Epoch 2, Val Loss: 0.3078242093324661, Val Accuracy: 0.8819444477558136


100%|██████████| 10/10 [01:28<00:00,  8.81s/it, accuracy=0.875, loss=0.31, lr=0.00064]


Epoch 3, Val Loss: 0.31274087727069855, Val Accuracy: 0.8819444477558136


100%|██████████| 10/10 [00:58<00:00,  5.85s/it, accuracy=1, loss=0.0936, lr=0.000512]  


Epoch 4, Val Loss: 0.2984686344861984, Val Accuracy: 0.8880208432674408


100%|██████████| 10/10 [00:42<00:00,  4.23s/it, accuracy=0.958, loss=0.149, lr=0.00041]


Epoch 5, Val Loss: 0.2950650379061699, Val Accuracy: 0.8880208432674408


100%|██████████| 10/10 [00:43<00:00,  4.33s/it, accuracy=0.958, loss=0.271, lr=0.000328]


Epoch 6, Val Loss: 0.30411046743392944, Val Accuracy: 0.8958333432674408


100%|██████████| 10/10 [00:42<00:00,  4.26s/it, accuracy=1, loss=0.176, lr=0.000262]   


Epoch 7, Val Loss: 0.3033341020345688, Val Accuracy: 0.9036458432674408


100%|██████████| 10/10 [00:42<00:00,  4.27s/it, accuracy=1, loss=0.104, lr=0.00021]   


Epoch 8, Val Loss: 0.2966225743293762, Val Accuracy: 0.8958333432674408


100%|██████████| 10/10 [00:43<00:00,  4.35s/it, accuracy=0.958, loss=0.21, lr=0.000168]


Epoch 9, Val Loss: 0.2981526404619217, Val Accuracy: 0.8880208432674408


100%|██████████| 10/10 [00:42<00:00,  4.28s/it, accuracy=0.958, loss=0.188, lr=0.000134]


Epoch 10, Val Loss: 0.29689842462539673, Val Accuracy: 0.8880208432674408


In [8]:
model.saveckpt(path='./efficient_model_checkpoint/3')
fullloader = DataLoader(full_valset, batch_size=64, shuffle=False)
total_loss, total_accuracy = evaluate(fullloader, model, criterion)
print(f"Total Loss: {total_loss}, Total Accuracy: {total_accuracy}")

Total Loss: 0.3215303277822247, Total Accuracy: 0.9054538216560509
