In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
import torch.optim as optim
from sklearn.model_selection import train_test_split

In [None]:
class PostActivationSet(nn.Module):
    """Learnable activations on all edges that feed one output node of a KAN layer.

    Input ``j`` is transformed by ``w_b[j] * b(x_j) + w_s[j] * spline_j(x_j)``, where
    ``spline_j`` is a degree-``k`` B-spline with coefficients ``c[j]``. ``forward``
    adds the transformed inputs up into a single value.

    Args:
        input_size: Number of incoming edges.
        grid_size: Number of grid intervals; each spline has ``grid_size + k`` bases.
        k: Degree of the B-spline bases.
        b: Zero-argument factory for the base activation module (e.g. ``nn.SiLU``).
        grid: 1-D knot vector with ``grid_size + 2 * k + 1`` entries, shared by all inputs.
    """

    def __init__(self, input_size, grid_size, k, b, grid):
        super().__init__()
        self.grid_size = grid_size
        self.k = k
        self.b = b()
        self.c = nn.Parameter(torch.normal(0, 0.1, size=(input_size, grid_size + k)))
        self.w_b = nn.Parameter(torch.ones(input_size))
        self.w_s = nn.Parameter(torch.rand(input_size))
        # A buffer (rather than a plain attribute) follows the module through
        # ``.to(device)``; non-persistent keeps the state_dict layout unchanged.
        self.register_buffer("grid", torch.as_tensor(grid), persistent=False)

    def b_spline(self, t, i, p=0):
        """Evaluate the ``i``-th B-spline basis of degree ``p`` at ``t`` (Cox-de Boor recursion).

        Scalar reference implementation; ``b_splines`` computes the same values
        for all inputs and bases at once.
        """
        if p == 0:
            return ((self.grid[i] <= t) & (t < self.grid[i + 1])).int()
        left = (t - self.grid[i]) / (self.grid[i + p] - self.grid[i])
        right = (self.grid[i + p + 1] - t) / (self.grid[i + p + 1] - self.grid[i + 1])
        return left * self.b_spline(t, i, p - 1) + right * self.b_spline(t, i + 1, p - 1)

    def b_splines(self, x):
        """Return the degree-``k`` bases for every input.

        Args:
            x: 1-D tensor with one value per input.

        Returns:
            Tensor of shape ``(len(x), grid_size + k)``.
        """
        grid = self.grid
        t = x.reshape(-1, 1)
        # Degree 0: indicator of the knot interval that contains each input.
        bases = ((grid[:-1] <= t) & (t < grid[1:])).to(t.dtype)
        # Raise the degree one step at a time; each step drops one basis.
        for p in range(1, self.k + 1):
            left = (t - grid[: -(p + 1)]) / (grid[p:-1] - grid[: -(p + 1)])
            right = (grid[p + 1 :] - t) / (grid[p + 1 :] - grid[1:-p])
            bases = left * bases[:, :-1] + right * bases[:, 1:]
        return bases

    def spline(self, x):
        """Evaluate each input's spline; returns a tensor of shape ``(1, len(x))``."""
        return torch.sum(self.c * self.b_splines(x), dim=1).view(1, -1)

    def forward(self, x):
        """Sum the activations of all inputs of one sample.

        Returns a 0-d tensor that stays attached to the autograd graph, so a loss
        built from it already requires grad and must not have ``requires_grad``
        forced on it.
        """
        return torch.sum(self.w_b * self.b(x) + self.w_s * self.spline(x))

In [None]:
class Layer(nn.Module):
    """One KAN layer: ``num_output`` nodes, each summing learnable activations of all inputs.

    Args:
        num_input: Number of input features.
        num_output: Number of output nodes.
        grid_size: Number of grid intervals inside ``grid_range``.
        k: B-spline degree; the knot vector is padded by ``k`` intervals on each side.
        b: Zero-argument factory for the base activation module.
        grid_range: ``(low, high)`` interval covered by the ``grid_size`` intervals.
    """

    def __init__(self, num_input, num_output, grid_size, k=3, b=nn.SiLU, grid_range=(-1, 1)):
        super().__init__()
        self.num_input = num_input
        self.num_output = num_output
        # Knot spacing is the covered width divided by the number of intervals.
        # (Adding grid_range[0] here would make the step negative for the default
        # range and yield a decreasing knot vector with no active basis.)
        grid_interval = (grid_range[1] - grid_range[0]) / grid_size
        grid = torch.arange(-k, grid_size + k + 1) * grid_interval + grid_range[0]
        self.post_activation_sets = nn.ModuleList(
            PostActivationSet(num_input, grid_size, k, b, grid) for _ in range(num_output)
        )

    def forward(self, x):
        """Map one sample ``x`` to a tensor with ``num_output`` entries."""
        # Allocate on the input's device so the layer also works off-CPU.
        out = torch.zeros(self.num_output, device=x.device)
        for i, unit in enumerate(self.post_activation_sets):
            # Plain assignment accepts both Python floats and 0-d tensors.
            out[i] = unit(x)
        return out

In [None]:
class KAN(nn.Module):
    """Stack of KAN layers followed by a sigmoid.

    ``layer_sizes`` lists the width of every stage; one ``Layer`` is created for
    each pair of consecutive widths. The remaining arguments are forwarded to
    every ``Layer`` unchanged.
    """

    def __init__(self, layer_sizes, grid_size, k=3, b=nn.SiLU, grid_range=(-1, 1)):
        super().__init__()
        widths = zip(layer_sizes[:-1], layer_sizes[1:])
        self.layers = nn.ModuleList(
            Layer(fan_in, fan_out, grid_size, k, b, grid_range) for fan_in, fan_out in widths
        )

    def forward(self, X):
        """Run every row of ``X`` through all layers and squash the result with a sigmoid."""
        activations = X
        for layer in self.layers:
            # Layers work on one sample at a time, so map over the batch dimension.
            per_sample = [layer(sample) for sample in activations]
            activations = torch.stack(per_sample, dim=0)
        return torch.sigmoid(activations)

In [None]:
# Fetch and unpack the UCI "Connectionist Bench (Sonar, Mines vs. Rocks)" archive.
# -nc skips the download when the archive already exists (no ".zip.1" duplicates),
# -n never overwrites extracted files (no interactive "replace?" prompt on re-run),
# and -P / -d pin both steps to /content regardless of the working directory.
!wget -nc -P /content https://archive.ics.uci.edu/static/public/151/connectionist+bench+sonar+mines+vs+rocks.zip
!unzip -n /content/connectionist+bench+sonar+mines+vs+rocks.zip -d /content

--2024-06-19 20:16:42--  https://archive.ics.uci.edu/static/public/151/connectionist+bench+sonar+mines+vs+rocks.zip
Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252
Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified
Saving to: ‘connectionist+bench+sonar+mines+vs+rocks.zip.1’

          connectio     [<=>                 ]       0  --.-KB/s               connectionist+bench     [ <=>                ]  63.88K  --.-KB/s    in 0.02s   

2024-06-19 20:16:42 (3.10 MB/s) - ‘connectionist+bench+sonar+mines+vs+rocks.zip.1’ saved [65413]

Archive:  /content/connectionist+bench+sonar+mines+vs+rocks.zip
replace sonar.all-data? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace sonar.mines? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace sonar.rocks? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace Index? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace sonar.names? [y]es, [n]

In [None]:
# Read the sonar table (no header row): columns 0-59 are the features, column 60 the label.
data = pd.read_csv("/content/sonar.all-data", header=None)

# Encode the labels as integer class ids.
encoder = LabelEncoder()
y = encoder.fit_transform(data.iloc[:, 60].values)

# Convert to float32 tensors; the labels become a column vector.
X = torch.tensor(data.iloc[:, 0:60].values, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)

# Show the first shuffled mini-batch as a sanity check.
loader = DataLoader(list(zip(X, y)), shuffle=True, batch_size=16)
for X_batch, y_batch in loader:
    print(X_batch, y_batch)
    break

tensor([[1.3500e-02, 4.5000e-03, 5.1000e-03, 2.8900e-02, 5.6100e-02, 9.2900e-02,
         1.0310e-01, 8.8300e-02, 1.5960e-01, 1.9080e-01, 1.5760e-01, 1.1120e-01,
         1.1970e-01, 1.1740e-01, 1.4150e-01, 2.2150e-01, 2.6580e-01, 2.7130e-01,
         3.8620e-01, 5.7170e-01, 6.7970e-01, 8.7470e-01, 1.0000e+00, 8.9480e-01,
         8.4200e-01, 9.1740e-01, 9.3070e-01, 9.0500e-01, 8.2280e-01, 6.9860e-01,
         5.8310e-01, 4.9240e-01, 4.5630e-01, 5.1590e-01, 5.6700e-01, 5.2840e-01,
         5.1440e-01, 3.7420e-01, 2.2820e-01, 1.1930e-01, 1.0880e-01, 4.3100e-02,
         1.0700e-01, 5.8300e-02, 4.6000e-03, 4.7300e-02, 4.0800e-02, 2.9000e-02,
         1.9200e-02, 9.4000e-03, 2.5000e-03, 3.7000e-03, 8.4000e-03, 1.0200e-02,
         9.6000e-03, 2.4000e-03, 3.7000e-03, 2.8000e-03, 3.0000e-03, 3.0000e-03],
        [2.6500e-02, 4.4000e-02, 1.3700e-02, 8.4000e-03, 3.0500e-02, 4.3800e-02,
         3.4100e-02, 7.8000e-02, 8.4400e-02, 7.7900e-02, 3.2700e-02, 2.0600e-01,
         1.9080e-01, 1.0650

In [None]:
# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 60 input features -> 60 -> 30 -> 1 output, with 5 grid intervals per spline.
model = KAN([60, 60, 30, 1], 5).to(device)
model

KAN(
  (layers): ModuleList(
    (0): Layer(
      (post_activation_sets): ModuleList(
        (0-59): 60 x PostActivationSet(
          (b): SiLU()
        )
      )
    )
    (1): Layer(
      (post_activation_sets): ModuleList(
        (0-29): 30 x PostActivationSet(
          (b): SiLU()
        )
      )
    )
    (2): Layer(
      (post_activation_sets): ModuleList(
        (0): PostActivationSet(
          (b): SiLU()
        )
      )
    )
  )
)

In [None]:
# Hold out 30% of the samples for evaluation.
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, shuffle=True)

loader = DataLoader(list(zip(X_train, y_train)), shuffle=True, batch_size=16)

n_epochs = 10
loss_fn = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.0001)
warned_detached = False
model.train()
for epoch in range(n_epochs):
    for X_batch, y_batch in loader:
        # Keep the data on the same device as the model.
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        if not loss.requires_grad:
            # The loss is not connected to any parameter, so backward() cannot
            # produce gradients. Keep the loop runnable, but say so instead of
            # silently forcing the flag (forcing it on an attached loss raises).
            if not warned_detached:
                print("Warning: loss is detached from the model parameters; no weights will be updated.")
                warned_detached = True
            loss.requires_grad_(True)
        loss.backward()
        optimizer.step()

# evaluate accuracy after training
model.eval()
with torch.no_grad():  # no graph is needed for inference
    y_pred = model(X_test.to(device))
acc = (y_pred.round() == y_test.to(device)).float().mean()
acc = float(acc)
print("Model accuracy: %.2f%%" % (acc*100))

Model accuracy: 49.21%


In [None]:
class PostActivationSet(nn.Module):
    """Batched learnable activations on all edges that feed one output node of a KAN layer.

    Input ``j`` is transformed by ``w_b[j] * b(x_j) + w_s[j] * spline_j(x_j)``, where
    ``spline_j`` is a degree-``k`` B-spline with coefficients ``c[j]`` on a uniform
    knot grid. ``forward`` sums the transformed inputs per sample.

    Args:
        input_size: Number of incoming edges.
        grid_size: Number of grid intervals inside ``grid_range``.
        k: B-spline degree.
        b: Zero-argument factory for the base activation module (e.g. ``nn.SiLU``).
        grid_range: Two bounds of the interval covered by the grid; they may be
            given in either order.
    """

    def __init__(self, input_size, grid_size, k, b, grid_range):
        super().__init__()
        self.input_size = input_size
        self.grid_size = grid_size
        self.k = k
        self.b = b()
        self.c = nn.Parameter(torch.normal(0, 0.1, size=(input_size, grid_size + k)))
        self.w_b = nn.Parameter(torch.ones(input_size))
        self.w_s = nn.Parameter(torch.rand(input_size))
        self.grid_range = grid_range
        # Registered as a buffer so the knots follow ``.to(device)``; non-persistent
        # keeps them out of the state_dict.
        self.register_buffer("grid", None, persistent=False)
        self.create_grid()

    def create_grid(self):
        """(Re)build the uniform knot grid of shape ``(input_size, grid_size + 2 * k + 1)``."""
        low, high = sorted((float(self.grid_range[0]), float(self.grid_range[1])))
        if low == high:
            raise ValueError("grid_range must span a non-empty interval")
        # Knot spacing is the covered width divided by the number of intervals.
        step = (high - low) / self.grid_size
        knots = torch.arange(
            -self.k, self.grid_size + self.k + 1, dtype=self.c.dtype, device=self.c.device
        ) * step + low
        self.grid = knots.expand(self.input_size, -1).contiguous()

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, input_size).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, input_size, grid_size + k).
        """
        assert x.dim() == 2 and x.size(1) == self.input_size

        grid: torch.Tensor = self.grid  # (input_size, grid_size + 2 * k + 1)
        x = x.unsqueeze(-1)
        # Degree 0: indicator of the knot interval that contains each value.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Cox-de Boor recursion; every step raises the degree and drops one basis.
        for p in range(1, self.k + 1):
            bases = (
                (x - grid[:, : -(p + 1)])
                / (grid[:, p:-1] - grid[:, : -(p + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, p + 1 :] - x)
                / (grid[:, p + 1 :] - grid[:, 1:(-p)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.input_size,
            self.grid_size + self.k,
        )
        return bases.contiguous()

    def spline(self, X):
        """Evaluate every input's spline; returns a tensor of shape ``(batch_size, input_size)``."""
        return torch.sum(self.c * self.b_splines(X), dim=2).reshape(X.shape[0], -1)

    def _extend_grid(self, X):
        """Double ``grid_size`` and refit ``c`` so each spline keeps its values on ``X``."""
        with torch.no_grad():
            target = self.spline(X)  # (batch_size, input_size), values on the old grid
            self.grid_size *= 2
            self.create_grid()
            bases = self.b_splines(X)  # (batch_size, input_size, grid_size + k)
            # One least-squares problem per input: that input's bases against that
            # input's old spline values. Solved on the CPU, whose default driver
            # tolerates the rank-deficient systems that arise when a basis has no
            # sample inside its support.
            design = bases.permute(1, 0, 2).cpu()
            rhs = target.t().unsqueeze(-1).cpu()
            fitted = torch.linalg.lstsq(design, rhs).solution.squeeze(-1)
            # The coefficient shape changes, so the parameter has to be replaced;
            # any optimizer that holds the old one must be rebuilt by the caller.
            self.c = nn.Parameter(fitted.to(self.grid.device).contiguous())

    def forward(self, X, grid_extension=False):
        """Return the summed activation of every sample.

        Args:
            X: Tensor of shape ``(batch_size, input_size)``.
            grid_extension: When true, refine the grid on ``X`` before evaluating.

        Returns:
            Tensor of shape ``(batch_size,)`` that stays attached to the autograd
            graph, so a loss built from it already requires grad and must not
            have ``requires_grad`` forced on it.
        """
        if grid_extension:
            self._extend_grid(X)
        return torch.sum(self.w_b * self.b(X) + self.w_s * self.spline(X), dim=1)

In [None]:
class Layer(nn.Module):
    """One batched KAN layer: ``num_output`` nodes, each summing learnable activations of all inputs.

    Args:
        num_input: Number of input features.
        num_output: Number of output nodes.
        grid_size: Number of grid intervals inside ``grid_range``.
        k: B-spline degree.
        b: Zero-argument factory for the base activation module.
        grid_range: ``(low, high)`` interval handed to every ``PostActivationSet``.
    """

    def __init__(self, num_input, num_output, grid_size, k=3, b=nn.SiLU, grid_range=(-1, 1)):
        super().__init__()
        self.num_input = num_input
        self.num_output = num_output
        # Each PostActivationSet builds its own knot grid from the range, so the
        # range itself is passed on, not a precomputed knot vector.
        self.post_activation_sets = nn.ModuleList(
            PostActivationSet(num_input, grid_size, k, b, grid_range) for _ in range(num_output)
        )

    def forward(self, X, grid_extension=False):
        """Map a batch ``X`` of shape ``(batch_size, num_input)`` to ``(batch_size, num_output)``.

        ``grid_extension`` is forwarded unchanged to every ``PostActivationSet``.
        """
        # Allocate on the input's device so the layer also works off-CPU.
        out = torch.zeros(X.shape[0], self.num_output, device=X.device)
        for i, unit in enumerate(self.post_activation_sets):
            # Plain assignment accepts a scalar as well as one value per sample.
            out[:, i] = unit(X, grid_extension)
        return out

In [None]:
class KAN(nn.Module):
    """Stack of batched KAN layers followed by a sigmoid.

    ``layer_sizes`` lists the width of every stage; one ``Layer`` is created for
    each pair of consecutive widths. The remaining arguments are forwarded to
    every ``Layer`` unchanged.
    """

    def __init__(self, layer_sizes, grid_size, k=3, b=nn.SiLU, grid_range=(-1, 1)):
        super().__init__()
        widths = zip(layer_sizes[:-1], layer_sizes[1:])
        self.layers = nn.ModuleList(
            Layer(fan_in, fan_out, grid_size, k, b, grid_range) for fan_in, fan_out in widths
        )

    def forward(self, X, grid_extension=False):
        """Feed the batch ``X`` through all layers, passing ``grid_extension`` to each, then apply a sigmoid."""
        activations = X
        for layer in self.layers:
            activations = layer(activations, grid_extension)
        return torch.sigmoid(activations)

In [None]:
# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 60 input features -> 60 -> 30 -> 1 output, with 5 grid intervals per spline.
model = KAN([60, 60, 30, 1], 5).to(device)
model

KAN(
  (layers): ModuleList(
    (0): Layer(
      (post_activation_sets): ModuleList(
        (0-59): 60 x PostActivationSet(
          (b): SiLU()
        )
      )
    )
    (1): Layer(
      (post_activation_sets): ModuleList(
        (0-29): 30 x PostActivationSet(
          (b): SiLU()
        )
      )
    )
    (2): Layer(
      (post_activation_sets): ModuleList(
        (0): PostActivationSet(
          (b): SiLU()
        )
      )
    )
  )
)

In [None]:
# Hold out 30% of the samples for evaluation.
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, shuffle=True)

loader = DataLoader(list(zip(X_train, y_train)), shuffle=True, batch_size=16)

n_epochs = 10
grid_extension_every = 4  # refine the spline grids every this many epochs
learning_rate = 0.0001
loss_fn = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
warned_detached = False
model.train()
for epoch in range(n_epochs):
    # Refine once per scheduled epoch, not once per batch: every extension doubles
    # the grid size, so per-batch extension grows the grids exponentially.
    # Epoch 0 is skipped because nothing has been trained yet.
    if epoch > 0 and epoch % grid_extension_every == 0:
        with torch.no_grad():
            # Refit on the whole training set rather than on a single mini-batch.
            model(X_train.to(device), True)
        # Extension replaces the spline coefficient parameters, so the optimizer
        # has to be rebuilt to track the new ones.
        optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    for X_batch, y_batch in loader:
        # Keep the data on the same device as the model.
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        if not loss.requires_grad:
            # The loss is not connected to any parameter, so backward() cannot
            # produce gradients. Keep the loop runnable, but say so instead of
            # silently forcing the flag (forcing it on an attached loss raises).
            if not warned_detached:
                print("Warning: loss is detached from the model parameters; no weights will be updated.")
                warned_detached = True
            loss.requires_grad_(True)
        loss.backward()
        optimizer.step()

model.eval()
with torch.no_grad():  # no graph is needed for inference
    y_pred = model(X_test.to(device))
acc = (y_pred.round() == y_test.to(device)).float().mean()
acc = float(acc)
print("Model accuracy: %.2f%%" % (acc*100))