In [1]:
import tensorflow as tf
import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

class KANLayer(tf.keras.layers.Layer):
    """Layer whose output is a learned combination of B-spline bases of its inputs.

    Each input feature is expanded into ``grid_size + spline_order`` B-spline
    basis values on a fixed, uniformly spaced knot vector; the output is the
    sum over features and bases weighted by ``spline_weight``.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        grid_size: Number of knot intervals inside ``grid_range``.
        spline_order: Order of the B-splines; the knot vector is extended by
            ``spline_order`` intervals on each side of ``grid_range``.
        base_activation: Stored on the instance; no method of this class
            uses it.
        grid_eps: Stored on the instance; no method of this class uses it.
        grid_range: ``(low, high)`` pair bounding the un-extended knots.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``grid_size`` is smaller than 1 or ``grid_range`` is
            not strictly increasing.

    Note:
        Inputs outside the extended knot span
        ``[low - spline_order * h, high + spline_order * h)`` fall in no knot
        interval, so all of their bases are zero and they contribute nothing
        to the output.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        base_activation=tf.keras.activations.silu,
        grid_eps=0.02,
        grid_range=(-1, 1),
        **kwargs,
    ):
        super().__init__(**kwargs)

        low, high = grid_range[0], grid_range[1]
        if grid_size < 1:
            raise ValueError(f"grid_size must be >= 1, got {grid_size}")
        if not low < high:
            raise ValueError(f"grid_range must satisfy low < high, got {grid_range}")

        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector with grid_size + 2 * spline_order + 1 knots,
        # stored with a leading axis of size 1 so it broadcasts against inputs.
        h = (high - low) / grid_size
        grid = (
            tf.range(-spline_order, grid_size + spline_order + 1, dtype=tf.float32) * h
            + low
        )
        self.grid = tf.Variable(grid[None, :], trainable=False, dtype=tf.float32)

        # spline_weight[i, o, c]: coefficient of basis c on the edge from
        # input feature i to output feature o.
        self.spline_weight = self.add_weight(
            shape=(in_features, out_features, grid_size + spline_order),
            initializer=tf.keras.initializers.HeUniform(),
            trainable=True,
            name="spline_weight"
        )

        self.base_activation = base_activation
        self.grid_eps = grid_eps

    def b_splines(self, x):
        """Evaluate the B-spline bases of every input value.

        Args:
            x: Tensor indexed as ``(batch, in_features)``.

        Returns:
            Tensor of shape ``(batch, in_features, grid_size + spline_order)``.
        """
        x = tf.expand_dims(x, axis=-1)
        # Order-0 bases: indicator of the half-open knot interval holding x.
        bases = tf.cast((x >= self.grid[:, :-1]) & (x < self.grid[:, 1:]), x.dtype)
        # Each pass raises the order by one and drops one basis function.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - self.grid[:, : -(k + 1)])
                / (self.grid[:, k:-1] - self.grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (self.grid[:, k + 1 :] - x)
                / (self.grid[:, k + 1 :] - self.grid[:, 1 : (-k)])
                * bases[:, :, 1:]
            )
        return bases

    def curve2coeff(self, x, y):
        """Least-squares fit of spline coefficients to sampled curve values.

        Args:
            x: Sample locations; assumed ``(batch, in_features)`` as in
                ``b_splines``.
            y: Target values; assumed ``(batch, in_features, out_features)``
                given the transposes below -- TODO confirm against callers.

        Returns:
            Coefficients laid out as ``(out_features, in_features, coeff)``.

        NOTE(review): this layout differs from ``spline_weight``, which is
        ``(in_features, out_features, coeff)``; transpose before assigning.
        """
        A = tf.transpose(self.b_splines(x), perm=[1, 0, 2])  # (in, batch, coeff)
        B = tf.transpose(y, perm=[1, 0, 2])  # (in, batch, out)
        solution = tf.linalg.lstsq(A, B)  # (in, coeff, out)
        result = tf.transpose(solution, perm=[2, 0, 1])
        return result

    def call(self, x):
        """Apply the spline transformation.

        Args:
            x: Tensor of shape ``(batch, in_features)``.

        Returns:
            Tensor of shape ``(batch, out_features)``.
        """
        bases = self.b_splines(x)  # (batch, in, coeff)
        batch_size = tf.shape(x)[0]

        # The flattened bases are ordered (in, coeff). Move the coefficient
        # axis next to the input axis before flattening the weight so that row
        # i * coeff + c of the matrix really is spline_weight[i, :, c];
        # reshaping the (in, out, coeff) tensor directly would pair each basis
        # with an unrelated entry.
        weight = tf.transpose(self.spline_weight, perm=[0, 2, 1])  # (in, coeff, out)

        return tf.matmul(
            tf.reshape(bases, (batch_size, -1)),
            tf.reshape(weight, (-1, self.out_features)),
        )



class KAN(tf.keras.Model):
    """Feed-forward stack of ``KANLayer`` blocks.

    Args:
        layer_sizes: Sequence of layer widths; consecutive pairs give the
            ``(in_features, out_features)`` of each ``KANLayer``.
        **kwargs: Forwarded unchanged to every ``KANLayer``.
    """

    def __init__(self, layer_sizes, **kwargs):
        super().__init__()
        self.kan_layers = [
            KANLayer(n_in, n_out, **kwargs)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def call(self, x, update_grid=False):
        """Run ``x`` through every layer in order.

        Args:
            x: Input batch.
            update_grid: When true, ``layer.update_grid(x)`` is called on each
                layer before it is applied.
                NOTE(review): ``KANLayer`` as defined in this notebook has no
                ``update_grid`` method, so passing ``True`` raises
                ``AttributeError``.

        Returns:
            Output of the last layer.
        """
        for layer in self.kan_layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def train(
        self,
        X_train,
        y_train,
        epochs=10,
        batch_size=16,
        shuffle_buffer=100,
        optimizer=None,
        loss_fn=None,
    ):
        """Train with a manual gradient-tape loop, printing accuracy per epoch.

        Args:
            X_train: Training features.
            y_train: Integer class labels.
            epochs: Number of passes over the training data.
            batch_size: Mini-batch size.
            shuffle_buffer: Buffer size given to ``Dataset.shuffle``.
            optimizer: Keras optimizer; defaults to ``Adam(learning_rate=1e-3)``.
            loss_fn: Callable ``loss_fn(y_true, logits)``; defaults to
                ``SparseCategoricalCrossentropy(from_logits=True)``.
        """
        # Explicit arguments replace the lookup of notebook-level globals, so
        # the method no longer depends on which cells ran before it.
        if optimizer is None:
            optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
        if loss_fn is None:
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        train_data = (
            tf.data.Dataset.from_tensor_slices((X_train, y_train))
            .shuffle(shuffle_buffer)
            .batch(batch_size)
        )

        for epoch in range(epochs):
            # Fresh metric each epoch so the printed value covers one pass only.
            train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
            for x_batch_train, y_batch_train in train_data:
                with tf.GradientTape() as tape:
                    logits = self(x_batch_train, training=True)
                    loss_value = loss_fn(y_batch_train, logits)
                grads = tape.gradient(loss_value, self.trainable_weights)
                optimizer.apply_gradients(zip(grads, self.trainable_weights))
                train_acc.update_state(y_batch_train, logits)
            print(f"Epoch {epoch + 1},  Accuracy: {train_acc.result().numpy():.4f}")

    def accuracy(self, x_val, y_val, batch_size=16):
        """Return the sparse categorical accuracy on ``(x_val, y_val)``.

        Args:
            x_val: Features to evaluate on.
            y_val: Integer class labels.
            batch_size: Evaluation batch size.

        Returns:
            Accuracy as a NumPy scalar.
        """
        # Use the x_val argument; the previous body read a global named X_val
        # and silently ignored what the caller passed in.
        val_data = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)
        val_acc = tf.keras.metrics.SparseCategoricalAccuracy()
        for x_batch_val, y_batch_val in val_data:
            logits = self(x_batch_val, training=False)
            val_acc.update_state(y_batch_val, logits)
        return val_acc.result().numpy()




In [3]:
import pandas as pd
from sklearn.model_selection import train_test_split
# Load the toxicity sheet, drop the SMILES column, then drop every row that
# still has a missing value (chained so re-running the cell is idempotent).
data = (
    pd.read_excel('sheet.xlsx', sheet_name='Toxicity')
    .drop(['SMILES'], axis=1)
    .dropna(axis=0)
)
y = data['CYTOTOXIC_TO_ALL'].values
x = data.drop(['CYTOTOXIC_TO_ALL'], axis=1).values

# TODO(review): features are fed to the model unscaled. With its default
# arguments KANLayer's knots only span about [-2.2, 2.2); values outside that
# range produce all-zero spline bases. Consider standardizing as the iris cell does.

# Split data into train and validation sets
X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)

# Size the input layer from the data instead of hard-coding the column count.
model = KAN([x.shape[1], 4, 2])

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.train(X_train, y_train)

# Keep predictions under their own name so the label array `y` is not clobbered.
y_pred = model.predict(X_val)
y_pred_labels = np.argmax(y_pred, axis=1)
print(y_pred_labels)
print(y_val)
acc = model.accuracy(X_val, y_val)
print("Test Accuracy : ", acc)

Epoch 1,  Accuracy: 0.5132
Epoch 2,  Accuracy: 0.8289
Epoch 3,  Accuracy: 0.8618
Epoch 4,  Accuracy: 0.8684
Epoch 5,  Accuracy: 0.8750
Epoch 6,  Accuracy: 0.8882
Epoch 7,  Accuracy: 0.9079
Epoch 8,  Accuracy: 0.9276
Epoch 9,  Accuracy: 0.9276
Epoch 10,  Accuracy: 0.9342
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 319ms/step
[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0
 0]
[1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 1 0 0 1 0 0 1 0 1 0 0 0 0 0 0 0 0 0
 0]
Test Accuracy :  0.8684211


In [4]:
import pandas as pd

iris = load_iris()
X, y = iris.data, iris.target

# Split first, then fit the scaler on the training rows only, so that no
# validation-set statistics leak into preprocessing.
X_train_raw, X_val_raw, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42
)
scaler = StandardScaler().fit(X_train_raw)
X_train = scaler.transform(X_train_raw)
X_val = scaler.transform(X_val_raw)

# Input and output widths come from the data set rather than literals.
m = KAN([X_train.shape[1], 64, len(iris.target_names)])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

m.train(X_train, y_train)

# Keep predictions under their own name so the label array `y` is not clobbered.
y_pred = m.predict(X_val)
y_pred_labels = np.argmax(y_pred, axis=1)
print(y_pred_labels)
print(y_val)
acc = m.accuracy(X_val, y_val)
print("Test Accuracy : ", acc)

Epoch 1,  Accuracy: 0.2583
Epoch 2,  Accuracy: 0.7500
Epoch 3,  Accuracy: 0.9250
Epoch 4,  Accuracy: 0.9500
Epoch 5,  Accuracy: 0.9583
Epoch 6,  Accuracy: 0.9583
Epoch 7,  Accuracy: 0.9583
Epoch 8,  Accuracy: 0.9667
Epoch 9,  Accuracy: 0.9667
Epoch 10,  Accuracy: 0.9667
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 314ms/step
[1 0 2 1 1 0 1 2 1 1 2 0 0 0 0 1 2 1 1 2 0 2 0 2 2 2 2 2 0 0]
[1 0 2 1 1 0 1 2 1 1 2 0 0 0 0 1 2 1 1 2 0 2 0 2 2 2 2 2 0 0]
Test Accuracy :  1.0
