In [None]:
import sys 
sys.path.append(r'C:\Users\scsar\OneDrive - UNSW\Jupyter Projects\efficient-kan\src')

from efficient_kan import KAN

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
# EXAMPLE: train a small KAN classifier on MNIST.

# --- Data -------------------------------------------------------------------
# ToTensor followed by Normalize with mean 0.5 / std 0.5 on the single channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# --- Model ------------------------------------------------------------------
# Layer widths: 784 flattened pixels -> 64 hidden -> 10 outputs.
model = KAN([28 * 28, 64, 10])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# --- Optimisation -----------------------------------------------------------
# AdamW plus an exponential schedule (learning rate x0.8 per scheduler.step()).
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

# --- Loss -------------------------------------------------------------------
criterion = nn.CrossEntropyLoss()
for epoch in range(10):
    # Train
    model.train()
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            # Flatten each image to a 784-vector and move the batch to the device once.
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            # Per-batch training accuracy, shown on the progress bar only.
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])

    # Validation
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_seen = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            batch_size = labels.size(0)
            # The criterion yields a per-batch mean; weight it by the batch size
            # so a smaller final batch does not count as much as a full one.
            val_loss += criterion(output, labels).item() * batch_size
            val_correct += (output.argmax(dim=1) == labels).sum().item()
            val_seen += batch_size
    # Per-sample averages over the whole validation set.
    val_loss /= val_seen
    val_accuracy = val_correct / val_seen

    # Update learning rate
    scheduler.step()

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}"
    )

In [None]:
import torch
import torch.nn as nn

from efficient_kan import KANLinear

# Default width of the autoencoder bottleneck (size of the encoded vector).
n_components = 2

class KAN_AutoEncoder(nn.Module):
    """Autoencoder made of KANLinear layers separated by ReLU activations.

    The encoder narrows ``input_dim -> 128 -> 64 -> latent_dim`` and the
    decoder mirrors it back to ``input_dim``. The decoder ends in a Sigmoid,
    so reconstructions lie in (0, 1) and inputs are expected to be scaled to
    [0, 1].

    Args:
        latent_dim: Width of the bottleneck. ``None`` (the default) reads the
            module-level ``n_components`` when the model is constructed.
        input_dim: Number of input and output features; 784 matches the
            flattened 28x28 images used elsewhere in this notebook.
    """

    def __init__(self, latent_dim=None, input_dim=784):
        super().__init__()
        if latent_dim is None:
            # Resolved at construction time so changing n_components and
            # re-instantiating still takes effect.
            latent_dim = n_components
        self.encoder = nn.Sequential(
            KANLinear(in_features=input_dim, out_features=128),
            nn.ReLU(),
            KANLinear(in_features=128, out_features=64),
            nn.ReLU(),
            KANLinear(in_features=64, out_features=latent_dim),
        )
        self.decoder = nn.Sequential(
            KANLinear(in_features=latent_dim, out_features=64),
            nn.ReLU(),
            KANLinear(in_features=64, out_features=128),
            nn.ReLU(),
            KANLinear(in_features=128, out_features=input_dim),
            nn.Sigmoid(),  # Assuming inputs are normalized [0,1]
        )

    def forward(self, x):
        """Encode ``x`` and decode it again; returns ``(encoded, decoded)``."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded



In [None]:
# Load MNIST
# Keep pixels in [0, 1] (the range ToTensor produces): KAN_AutoEncoder's decoder
# ends in a Sigmoid, so it can only emit values in (0, 1). Adding
# Normalize((0.5,), (0.5,)) here would move the reconstruction targets to
# [-1, 1], which the decoder can never reproduce.
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
valset = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = KAN_AutoEncoder()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Define learning rate scheduler (only has an effect where scheduler.step() is called)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

# Reconstruction loss: mean squared error between decoded output and input pixels
criterion = nn.MSELoss()

In [None]:
def train(model, dataloader, criterion, optimizer, num_epochs, device, scheduler=None):
    """Train a reconstruction model and print per-epoch averages.

    Each batch is flattened to ``(-1, 784)``, passed through ``model`` (which
    must return ``(encoded, decoded)``), and ``criterion(decoded, images)`` is
    minimised with ``optimizer``. Labels from the dataloader are ignored.

    Epoch averages are per sample: every batch contributes in proportion to
    its size, so a smaller final batch is not over-weighted. This assumes
    ``criterion`` returns a per-batch mean (the default reduction of
    ``nn.MSELoss``).

    Args:
        model: Module returning ``(encoded, decoded)`` for a ``(batch, 784)`` input.
        dataloader: Iterable of ``(images, labels)`` batches; ``len()`` is not required.
        criterion: Loss called as ``criterion(decoded, images)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``dataloader``.
        device: Device the image batches are moved to.
        scheduler: Optional learning-rate scheduler; if given, ``step()`` is
            called once at the end of every epoch. Defaults to ``None``.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples in an epoch.
    """
    for epoch in range(num_epochs):
        # Train
        model.train()
        loss_sum = 0.0
        mse_sum = 0.0
        samples_seen = 0
        for images, _ in tqdm(dataloader):
            images = images.view(-1, 784).to(device)
            optimizer.zero_grad()
            _, decoded = model(images)
            loss = criterion(decoded, images)
            loss.backward()
            optimizer.step()

            batch_size = images.size(0)
            loss_sum += loss.item() * batch_size
            # Metric only: no need to record it for autograd.
            with torch.no_grad():
                mse_sum += ((decoded - images) ** 2).mean().item() * batch_size
            samples_seen += batch_size

        if scheduler is not None:
            scheduler.step()

        average_loss = loss_sum / samples_seen
        average_mse = mse_sum / samples_seen
        print(f'Epoch {epoch+1}, Loss: {average_loss:.4f}, MSE: {average_mse:.4f}')

# Fit the autoencoder for 10 epochs on the training split.
train(
    model=model,
    dataloader=trainloader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
    device=device,
)

In [None]:
import matplotlib.pyplot as plt
import torch

def visualize_reconstructions(model, dataloader, device, num_images=10):
    """Show original images (top row) above their reconstructions (bottom row).

    Takes the first batch from ``dataloader``, flattens it to ``(-1, 784)``,
    runs it through ``model`` (which must return ``(encoded, decoded)``) in
    eval mode without gradients, and draws each pair as a 28x28 grayscale
    image. Both rows use a fixed 0-1 gray scale, so values outside [0, 1] are
    clipped in the display.

    Args:
        model: Module returning ``(encoded, decoded)`` for a ``(batch, 784)`` input.
        dataloader: Iterable of ``(images, labels)`` batches; only the first is used.
        device: Device the batch is moved to before the forward pass.
        num_images: Maximum number of image pairs to draw. If the first batch
            holds fewer images, only those are drawn.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError("num_images must be at least 1")

    model.eval()
    with torch.no_grad():
        images, _ = next(iter(dataloader))
        images = images.view(-1, 784).to(device)
        _, decoded = model(images)

    # Move the tensors back to CPU for visualization
    originals = images.cpu().numpy()
    reconstructions = decoded.cpu().numpy()

    # Never index past the end of the batch.
    n_cols = min(num_images, originals.shape[0])

    # squeeze=False keeps axs two-dimensional even when n_cols == 1, so
    # axs[row, col] is always valid. One inch per column gives the same
    # (10, 2.5) figure as before for the default of 10 images.
    fig, axs = plt.subplots(2, n_cols, figsize=(n_cols, 2.5), squeeze=False)
    for row, batch in enumerate((originals, reconstructions)):
        for col in range(n_cols):
            ax = axs[row, col]
            ax.imshow(batch[col].reshape(28, 28), cmap='gray', interpolation='nearest', vmin=0, vmax=1)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# Plot reconstructions for the first batch of the validation DataLoader ('valloader').
visualize_reconstructions(
    model=model,
    dataloader=valloader,
    device=device,
)
