<a href="https://colab.research.google.com/github/vincenzodentamaro/keras-FastKAN/blob/main/Fast_Kan_Keras.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras import optimizers, losses

class RadialBasisFunction(layers.Layer):
    """Expand every input value into Gaussian bumps centred on a fixed 1-D grid.

    The output has one extra trailing axis of length ``num_grids``::

        out[..., i] = exp(-((x - grid[i]) / denominator) ** 2)

    Args:
        grid_min: Centre of the first Gaussian.
        grid_max: Centre of the last Gaussian.
        num_grids: Number of Gaussian centres, evenly spaced with ``tf.linspace``.
        denominator: Optional common width of the Gaussians. Defaults to the
            spacing between adjacent centres,
            ``(grid_max - grid_min) / (num_grids - 1)``.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, grid_min, grid_max, num_grids, denominator=None, **kwargs):
        super(RadialBasisFunction, self).__init__(**kwargs)
        # Constructor arguments, kept so get_config can rebuild the layer.
        self.grid_min = grid_min
        self.grid_max = grid_max
        self.num_grids = num_grids
        self._denominator_arg = denominator

        # Cast the endpoints first so linspace always runs in float32, even
        # when the caller passes Python ints (FastKAN's defaults are -1 and 1).
        self.grid = tf.linspace(
            tf.cast(grid_min, tf.float32),
            tf.cast(grid_max, tf.float32),
            num_grids,
        )
        if denominator is None:
            # linspace puts num_grids centres num_grids - 1 intervals apart.
            # Using that interval as the width matches the reference FastKAN.
            # max(..., 1) keeps a single-centre grid from dividing by zero.
            denominator = (grid_max - grid_min) / max(num_grids - 1, 1)
        self.denominator = tf.cast(denominator, dtype=tf.float32)

    def call(self, x):
        """Return Gaussian activations of shape ``x.shape + (num_grids,)``."""
        # x[..., None] broadcasts each value against every grid centre.
        return tf.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(RadialBasisFunction, self).get_config()
        config.update({
            "grid_min": self.grid_min,
            "grid_max": self.grid_max,
            "num_grids": self.num_grids,
            "denominator": self._denominator_arg,
        })
        return config

class FastKANLayer(layers.Layer):
    """One FastKAN layer: normalise, expand with Gaussian RBFs, mix linearly.

    ``call`` computes ``Dense(flatten(RBF(LayerNorm(x))))``. When
    ``use_base_update`` is true it also adds ``Dense(base_activation(x))``.

    Args:
        input_dim: Expected width of the input's last axis. Accepted for API
            parity only; no sub-layer is built from it (Dense infers its input
            width on first call).
        output_dim: Width of the output's last axis.
        grid_min: Passed to ``RadialBasisFunction``.
        grid_max: Passed to ``RadialBasisFunction``.
        num_grids: Passed to ``RadialBasisFunction``.
        use_base_update: Whether to add the activated linear "base" branch.
        base_activation: Activation for the base branch. Either a callable or a
            name accepted by ``tf.keras.activations.get`` (e.g. ``'relu'``).
            Only read when ``use_base_update`` is true.
        spline_weight_init_scale: Accepted but NOT currently applied; the spline
            Dense layer uses Keras' default kernel initializer.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, input_dim, output_dim, grid_min, grid_max, num_grids,
                 use_base_update, base_activation, spline_weight_init_scale,
                 **kwargs):
        super(FastKANLayer, self).__init__(**kwargs)
        self.norm = layers.LayerNormalization(axis=-1)
        self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
        self.spline_linear = layers.Dense(output_dim)
        self.use_base_update = use_base_update
        if use_base_update:
            # call() applies base_activation(x) as a function. Resolve names
            # such as FastKAN's default 'relu' into callables; callables pass
            # through unchanged.
            self.base_activation = tf.keras.activations.get(base_activation)
            self.base_linear = layers.Dense(output_dim)

    def call(self, x):
        """Apply the layer to ``x``; only the last axis is transformed."""
        x_norm = self.norm(x)
        # RBF adds a trailing axis: (..., features) -> (..., features, num_grids).
        spline_basis = self.rbf(x_norm)
        # Merge the last two axes and keep every leading axis. The merged width
        # comes from static shapes when known, so the Dense layer always sees a
        # concrete input width, including inside tf.function.
        features, grids = spline_basis.shape[-2], spline_basis.shape[-1]
        flat_width = -1 if features is None or grids is None else features * grids
        flat_shape = tf.concat([tf.shape(spline_basis)[:-2], [flat_width]], axis=0)
        spline_basis_flat = tf.reshape(spline_basis, flat_shape)
        ret = self.spline_linear(spline_basis_flat)
        if self.use_base_update:
            ret = ret + self.base_linear(self.base_activation(x))
        return ret
class FastKAN(tf.keras.Model):
    """A stack of ``FastKANLayer``s applied in sequence.

    Args:
        layers_hidden: Layer widths. For example ``[784, 64, 10]`` builds two
            layers (784 -> 64 and 64 -> 10). With fewer than two entries no
            layers are built and ``call`` returns its input unchanged.
        grid_min: Forwarded to every ``FastKANLayer``.
        grid_max: Forwarded to every ``FastKANLayer``.
        num_grids: Forwarded to every ``FastKANLayer``.
        use_base_update: Forwarded to every ``FastKANLayer``.
        base_activation: Callable or activation name (e.g. ``'relu'``). Names are
            resolved with ``tf.keras.activations.get`` when ``use_base_update``
            is true.
        spline_weight_init_scale: Forwarded to every ``FastKANLayer``.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, layers_hidden, grid_min=-1, grid_max=1, num_grids=10,
                 use_base_update=False, base_activation='relu',
                 spline_weight_init_scale=1, **kwargs):
        super(FastKAN, self).__init__(**kwargs)
        if use_base_update:
            # The layers apply base_activation(x) as a function, so a name such
            # as the default 'relu' must be turned into a callable first.
            base_activation = tf.keras.activations.get(base_activation)
        # One layer per consecutive (in_dim, out_dim) pair of widths.
        self.layers_list = [
            FastKANLayer(in_dim, out_dim, grid_min, grid_max, num_grids,
                         use_base_update, base_activation,
                         spline_weight_init_scale)
            for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ]

    def call(self, inputs):
        """Feed ``inputs`` through every layer in order."""
        x = inputs
        for layer in self.layers_list:
            x = layer(x)
        return x

# Load MNIST and scale pixel values to [0, 1].
(train_images, train_labels), (val_images, val_labels) = mnist.load_data()
train_images, val_images = train_images / 255.0, val_images / 255.0

# Flatten each 28x28 image into a 784-vector and convert to float32.
train_images = train_images.reshape((-1, 28 * 28)).astype('float32')
val_images = val_images.reshape((-1, 28 * 28)).astype('float32')

# Convert the labels to numpy arrays
train_labels = np.array(train_labels)
val_labels = np.array(val_labels)

# Training hyper-parameters.
EPOCHS = 20
BATCH_SIZE = 64
LOG_EVERY = 200  # log once every this many training batches

# Shuffled mini-batches for training; fixed-order batches for validation.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    .shuffle(buffer_size=1024)
    .batch(BATCH_SIZE)
)
val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels)).batch(BATCH_SIZE)

# Define model
model = FastKAN([28 * 28, 64, 10])
# Build every layer's weights eagerly once, so the traced step functions
# below never have to create model variables themselves.
model(train_images[:1])

# Define optimizer, loss and metrics. Training and validation use separate
# metric objects so one loop's state never leaks into the other's.
optimizer = optimizers.Adam(learning_rate=1e-3)
loss_fn = losses.SparseCategoricalCrossentropy(from_logits=True)
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.Mean()


@tf.function
def train_step(x, y):
    """Run one optimisation step on a batch and return the batch's mean loss."""
    with tf.GradientTape() as tape:
        logits = model(x)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    accuracy_metric.update_state(y, logits)
    return loss_value


@tf.function
def val_step(x, y):
    """Accumulate validation loss and accuracy for one batch."""
    val_logits = model(x)
    # Weight each batch's mean loss by its size so the smaller final batch
    # does not skew the epoch average.
    batch_size = tf.cast(tf.shape(y)[0], tf.float32)
    val_loss_metric.update_state(loss_fn(y, val_logits), sample_weight=batch_size)
    val_accuracy_metric.update_state(y, val_logits)


# Train and validate
for epoch in range(EPOCHS):
    print(f'Start of epoch {epoch+1}')

    seen = 0
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)
        # Count the real batch size; the last batch of an epoch may be smaller.
        seen += int(x_batch_train.shape[0])

        if step % LOG_EVERY == 0:
            print(f'Training loss (for one batch) at step {step}: {float(loss_value)}')
            print(f'Seen so far: {seen} samples')

    # Display metrics at the end of each epoch, then reset them.
    train_acc = accuracy_metric.result()
    print(f'Training acc over epoch: {float(train_acc)}')
    accuracy_metric.reset_state()

    # Run a validation pass at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_step(x_batch_val, y_batch_val)
    val_acc = val_accuracy_metric.result()
    print(f'Validation acc: {float(val_acc)}')
    print(f'Validation loss: {float(val_loss_metric.result())}')
    val_accuracy_metric.reset_state()
    val_loss_metric.reset_state()


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Start of epoch 1




Training loss (for one batch) at step 0: 2.5073800086975098
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.3673475980758667
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.6204124689102173
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 1.4972189664840698
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 0.8662940859794617
Seen so far: 51264 samples
Training loss (for one batch) at step 1000: 0.758224606513977
Seen so far: 64064 samples
Training loss (for one batch) at step 1200: 0.033305682241916656
Seen so far: 76864 samples
Training loss (for one batch) at step 1400: 1.0611724853515625
Seen so far: 89664 samples
Training loss (for one batch) at step 1600: 2.479414463043213
Seen so far: 102464 samples
Training loss (for one batch) at step 1800: 0.510697603225708
Seen so far: 115264 samples
Training loss (for one batch) at step 2000: 3.2079367637634277
Seen so far: 128064 samples
Training loss (for