In [20]:
from efficient_kan import KAN

# Train on MNIST
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [21]:
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
trainset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
valset = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)


In [22]:
teacher_model = KAN([28 * 28, 64, 10])
teacher_model.load_state_dict(torch.load('teacher.pth'))
teacher_model.to(device)
teacher_model.eval()

KAN(
  (layers): ModuleList(
    (0-1): 2 x KANLinear(
      (base_activation): SiLU()
    )
  )
)

In [58]:
student_model = KAN([28 * 28, 10])  # Smaller intermediate layer
student_model.to(device)

KAN(
  (layers): ModuleList(
    (0): KANLinear(
      (base_activation): SiLU()
    )
  )
)

In [59]:
def distillation_loss(y_student, y_teacher, y_true, T, alpha):
    soft_loss = nn.KLDivLoss()(F.log_softmax(y_student/T, dim=1),
                                F.softmax(y_teacher/T, dim=1))
    hard_loss = nn.CrossEntropyLoss()(y_student, y_true)
    return alpha * soft_loss * T * T + (1 - alpha) * hard_loss

In [60]:
optimizer = optim.AdamW(student_model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
temperature = 1.5
alpha = 0.5
max_val_acc = 0

In [61]:
for epoch in range(10):
    student_model.train()
    with tqdm(trainloader) as pbar:
        for i, (images, labels) in enumerate(pbar):
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            # Get teacher and student outputs
            with torch.no_grad():
                teacher_outputs = teacher_model(images)
            student_outputs = student_model(images)

            # Calculate combined loss
            loss = distillation_loss(student_outputs, teacher_outputs, labels, temperature, alpha)
            loss.backward()
            optimizer.step()

            pbar.set_postfix(loss=loss.item())

    # Update the learning rate
    student_model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            student_outputs = student_model(images)
            with torch.no_grad():
                teacher_outputs = teacher_model(images)
            val_loss += distillation_loss(student_outputs, teacher_outputs, labels, temperature, alpha)
            val_accuracy += (
                (student_outputs.argmax(dim=1) == labels.to(device)).float().mean().item()
            )
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Update learning rate
    scheduler.step()
    if val_accuracy > max_val_acc:
        torch.save(student_model.state_dict(), 'student_model.pth')
        max_val_acc = val_accuracy

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}"
    )
    


100%|██████████| 938/938 [00:16<00:00, 56.46it/s, loss=0.133] 


Epoch 1, Val Loss: 0.16550178825855255, Val Accuracy: 0.9187898089171974


100%|██████████| 938/938 [00:16<00:00, 56.57it/s, loss=0.0927]


Epoch 2, Val Loss: 0.14655467867851257, Val Accuracy: 0.9290406050955414


100%|██████████| 938/938 [00:16<00:00, 56.54it/s, loss=0.0621]


Epoch 3, Val Loss: 0.13744647800922394, Val Accuracy: 0.9322253184713376


100%|██████████| 938/938 [00:16<00:00, 56.93it/s, loss=0.0623]


Epoch 4, Val Loss: 0.1336313784122467, Val Accuracy: 0.9324243630573248


100%|██████████| 938/938 [00:16<00:00, 56.98it/s, loss=0.259] 


Epoch 5, Val Loss: 0.13036027550697327, Val Accuracy: 0.9344148089171974


100%|██████████| 938/938 [00:16<00:00, 57.08it/s, loss=0.0737]


Epoch 6, Val Loss: 0.1313268542289734, Val Accuracy: 0.9342157643312102


100%|██████████| 938/938 [00:16<00:00, 56.30it/s, loss=0.145] 


Epoch 7, Val Loss: 0.12771864235401154, Val Accuracy: 0.9345143312101911


100%|██████████| 938/938 [00:16<00:00, 58.17it/s, loss=0.185] 


Epoch 8, Val Loss: 0.1260286271572113, Val Accuracy: 0.9375


100%|██████████| 938/938 [00:16<00:00, 58.22it/s, loss=0.104] 


Epoch 9, Val Loss: 0.1250699907541275, Val Accuracy: 0.9358081210191083


100%|██████████| 938/938 [00:16<00:00, 58.27it/s, loss=0.139] 


Epoch 10, Val Loss: 0.12472017854452133, Val Accuracy: 0.9366042993630573


In [62]:
def testing(model, valloader):
    model.eval()

    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            outputs = model(images)
            test_loss += nn.CrossEntropyLoss()(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct / total


In [63]:
# Load and evaluate the model

student_model.load_state_dict(torch.load('student_not_distilled.pth'))
acc = testing(student_model, valloader)
print(f'Accuracy of the NOT distilled student model on the test set: {acc}%')

student_model.load_state_dict(torch.load('student_model.pth'))
acc = testing(student_model, valloader)
print(f'Accuracy of the DISTILLED student model on the test set: {acc}%')

teacher_model.load_state_dict(torch.load('teacher.pth'))
acc = testing(teacher_model, valloader)
print(f'Accuracy of the teacher model on the test set: {acc}%')

Accuracy of the NOT distilled student model on the test set: 93.65%
Accuracy of the DISTILLED student model on the test set: 93.75%
Accuracy of the teacher model on the test set: 97.23%
