In [2]:
import torch
import torch.nn.functional as F
import math

class SplineLinearLayer(torch.nn.Module):
    def __init__(self, input_dim, output_dim, num_knots=5, spline_order=3, 
                 noise_scale=0.1, base_scale=1.0, spline_scale=1.0, 
                 activation=torch.nn.SiLU, grid_epsilon=0.02, grid_range=[-1, 1], 
                 standalone_spline_scaling=True):
        super(SplineLinearLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_knots = num_knots
        self.spline_order = spline_order
        self.grid_epsilon = grid_epsilon
        self.grid_range = grid_range
        self.standalone_spline_scaling = standalone_spline_scaling

        self.knots = self._calculate_knots(grid_range, num_knots, spline_order)
        self.base_weights = torch.nn.Parameter(torch.Tensor(output_dim, input_dim))
        self.spline_weights = torch.nn.Parameter(torch.Tensor(output_dim, input_dim, num_knots + spline_order))
        if standalone_spline_scaling:
            self.spline_scales = torch.nn.Parameter(torch.Tensor(output_dim, input_dim))

        self.noise_scale = noise_scale
        self.base_scale = base_scale
        self.spline_scale = spline_scale
        self.activation = activation()

        self._initialize_parameters()

    def _initialize_parameters(self):
        """
        Initializes the parameters of the layer.
        """
        torch.nn.init.xavier_uniform_(self.base_weights, gain=math.sqrt(2))
        noise = torch.rand(self.num_knots + 1, self.input_dim, self.output_dim) - 0.5
        self.spline_weights.data.copy_(self.spline_scale * self._initialize_spline_weights(noise))
        if self.standalone_spline_scaling:
            torch.nn.init.xavier_uniform_(self.spline_scales, gain=math.sqrt(2))

    def _calculate_knots(self, grid_range, num_knots, spline_order):
        """
        Calculates the knots for the spline.

        Args:
            grid_range (list): Range of the grid.
            num_knots (int): Number of knots for the spline.
            spline_order (int): Order of the spline.

        Returns:
            torch.Tensor: Calculated knots.
        """

        h = (grid_range[1] - grid_range[0]) / num_knots
        knots = torch.arange(-spline_order, num_knots + spline_order + 1) * h + grid_range[0]
        return knots.expand(self.input_dim, -1).contiguous()

    def _initialize_spline_weights(self, noise):
        """
        Initializes spline weights.

        Args:
            noise (torch.Tensor): Noise tensor.

        Returns:
            torch.Tensor: Initialized spline weights.
        """
        return self._fit_curve_to_coefficients(self.knots.T[self.spline_order : -self.spline_order], noise)

    def _compute_b_splines(self, x):
        """
        Computes the B-spline basis functions.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Computed B-spline basis functions.
        """
        x = x.unsqueeze(-1)
        bases = ((x >= self.knots[:, :-1]) & (x < self.knots[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = ((x - self.knots[:, : -(k + 1)]) / (self.knots[:, k:-1] - self.knots[:, : -(k + 1)]) * bases[:, :, :-1] + 
                     (self.knots[:, k + 1 :] - x) / (self.knots[:, k + 1 :] - self.knots[:, 1:(-k)]) * bases[:, :, 1:])
        return bases.contiguous()

    def _fit_curve_to_coefficients(self, x, y):
        A = self._compute_b_splines(x).transpose(0, 1)
        B = y.transpose(0, 1)
        solution = torch.linalg.lstsq(A, B).solution
        return solution.permute(2, 0, 1).contiguous()

    @property
    def _scaled_spline_weights(self):
        return self.spline_weights * (self.spline_scales.unsqueeze(-1) if self.standalone_spline_scaling else 1.0)

    def forward(self, x):
        base_output = F.linear(self.activation(x), self.base_weights)
        spline_output = F.linear(self._compute_b_splines(x).view(x.size(0), -1), 
                                 self._scaled_spline_weights.view(self.output_dim, -1))
        return base_output + spline_output

    @torch.no_grad()
    def _update_knots(self, x, margin=0.01):
        """
        Updates the knots based on the input data.

        Args:
            x (torch.Tensor): Input tensor.
            margin (float): Margin value.

        Returns:
            None
        """
        batch = x.size(0)
        splines = self._compute_b_splines(x).permute(1, 0, 2)
        orig_coeff = self._scaled_spline_weights.permute(1, 2, 0)
        unreduced_spline_output = torch.bmm(splines, orig_coeff).permute(1, 0, 2)

        x_sorted = torch.sort(x, dim=0)[0]
        adaptive_knots = x_sorted[torch.linspace(0, batch - 1, self.num_knots + 1, dtype=torch.int64, device=x.device)]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.num_knots
        uniform_knots = torch.arange(self.num_knots + 1, dtype=torch.float32, device=x.device).unsqueeze(1) * uniform_step + x_sorted[0] - margin

        knots = self.grid_epsilon * uniform_knots + (1 - self.grid_epsilon) * adaptive_knots
        knots = torch.cat([
            knots[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
            knots,
            knots[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
        ], dim=0)

        self.knots.copy_(knots.T)
        self.spline_weights.data.copy_(self._fit_curve_to_coefficients(x, unreduced_spline_output))


class KAN(torch.nn.Module):
    """
    Initializes the KAN.

    Args:
        hidden_layers (list): List of hidden layer dimensions.
        num_knots (int): Number of knots for the spline.
        spline_order (int): Order of the spline.
        noise_scale (float): Scale of the noise.
        base_scale (float): Scale of the base weights.
        spline_scale (float): Scale of the spline weights.
        activation (torch.nn.Module): Activation function to use.
        grid_epsilon (float): Epsilon value for the grid.
        grid_range (list): Range of the grid.
    """
    def __init__(self, hidden_layers, num_knots=5, spline_order=3, 
                 noise_scale=0.1, base_scale=1.0, spline_scale=1.0, 
                 activation=torch.nn.SiLU, grid_epsilon=0.02, grid_range=[-1, 1]):
        super(KAN, self).__init__()
        self.layers = torch.nn.ModuleList()
        for in_dim, out_dim in zip(hidden_layers, hidden_layers[1:]):
            self.layers.append(SplineLinearLayer(in_dim, out_dim, num_knots, spline_order, 
                                                 noise_scale, base_scale, spline_scale, 
                                                 activation, grid_epsilon, grid_range))

    def forward(self, x, update_knots=False):
        """
        Forward pass of the KAN.

        Args:
            x (torch.Tensor): Input tensor.
            update_knots (bool): Whether to update knots during forward pass.

        Returns:
            torch.Tensor: Output tensor.
        """
        for layer in self.layers:
            if update_knots:
                layer._update_knots(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Computes the regularization loss of the KAN.

        Args:
            regularize_activation (float): Regularization strength for activation.
            regularize_entropy (float): Regularization strength for entropy.

        Returns:
            torch.Tensor: Regularization loss.
        """
        return sum(layer._regularization_loss(regularize_activation, regularize_entropy) for layer in self.layers)


In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Load MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
valset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

# Define model
model = KAN([28 * 28, 64, 10])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Define loss
criterion = nn.CrossEntropyLoss()

# Define ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)

for epoch in range(10):
    # Train
    model.train()
    total_loss = 0
    total_accuracy = 0
    with tqdm(trainloader) as pbar:
        for images, labels in pbar:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels).float().mean()
            total_loss += loss.item()
            total_accuracy += accuracy.item()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
    total_loss /= len(trainloader)
    total_accuracy /= len(trainloader)

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            val_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Step the scheduler based on validation loss
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Train Loss: {total_loss}, Train Accuracy: {total_accuracy}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}")

100%|██████████| 938/938 [00:28<00:00, 32.74it/s, accuracy=0.969, loss=0.187]


Epoch 1, Train Loss: 0.39985033546461224, Train Accuracy: 0.8797641257995735, Val Loss: 0.254256592203335, Val Accuracy: 0.9242635350318471


100%|██████████| 938/938 [00:28<00:00, 32.97it/s, accuracy=0.969, loss=0.0618]


Epoch 2, Train Loss: 0.21326004391683062, Train Accuracy: 0.937450026652452, Val Loss: 0.1866328773907368, Val Accuracy: 0.9432722929936306


100%|██████████| 938/938 [00:28<00:00, 32.61it/s, accuracy=1, loss=0.0466]    


Epoch 3, Train Loss: 0.1537492910165713, Train Accuracy: 0.9541244669509595, Val Loss: 0.1872036379767926, Val Accuracy: 0.9430732484076433


100%|██████████| 938/938 [00:28<00:00, 33.34it/s, accuracy=0.938, loss=0.361] 


Epoch 4, Train Loss: 0.12218124231101195, Train Accuracy: 0.9638359541577826, Val Loss: 0.1298399512117705, Val Accuracy: 0.9590963375796179


100%|██████████| 938/938 [00:28<00:00, 33.18it/s, accuracy=0.906, loss=0.283] 


Epoch 5, Train Loss: 0.1010753822310377, Train Accuracy: 0.9693663379530917, Val Loss: 0.1297462544593794, Val Accuracy: 0.9600915605095541


100%|██████████| 938/938 [00:28<00:00, 32.97it/s, accuracy=1, loss=0.0078]    


Epoch 6, Train Loss: 0.08415830911530345, Train Accuracy: 0.9740471748400853, Val Loss: 0.1182966953511261, Val Accuracy: 0.9639729299363057


100%|██████████| 938/938 [00:28<00:00, 32.45it/s, accuracy=1, loss=0.0133]    


Epoch 7, Train Loss: 0.06898970465495516, Train Accuracy: 0.9791277985074627, Val Loss: 0.11010273604836926, Val Accuracy: 0.9659633757961783


100%|██████████| 938/938 [00:29<00:00, 31.67it/s, accuracy=0.969, loss=0.0766]


Epoch 8, Train Loss: 0.05821317545737007, Train Accuracy: 0.9815598347547975, Val Loss: 0.11406798961729547, Val Accuracy: 0.9656648089171974


100%|██████████| 938/938 [00:28<00:00, 32.68it/s, accuracy=0.969, loss=0.115] 


Epoch 9, Train Loss: 0.0482356712479752, Train Accuracy: 0.9848414179104478, Val Loss: 0.10855315757302159, Val Accuracy: 0.9675557324840764


100%|██████████| 938/938 [00:28<00:00, 32.86it/s, accuracy=1, loss=0.04]      


Epoch 10, Train Loss: 0.04032294703637566, Train Accuracy: 0.9876898987206824, Val Loss: 0.11066264958843103, Val Accuracy: 0.9679538216560509
