Work in progress
Fast Kolmogorov-Arnold Network in JAX, based on fast-kan and using equinox.
The original implementation of KAN is pykan.
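fast-kan speeds up KAN by replacing the B-spline basis with Gaussian radial basis functions, so each layer becomes a normalisation, a fixed RBF expansion and a linear map. Below is a minimal sketch of one such layer in plain JAX, intended only to illustrate the idea. It is not KANX's code or API and does not use equinox. The names `rbf_basis`, `init_fastkan_layer` and `fastkan_layer` are hypothetical, and the defaults (8 centres on [-2, 2], a SiLU "base" branch, no learned normalisation parameters) are assumptions loosely modeled on fast-kan.

```python
import jax
import jax.numpy as jnp


def rbf_basis(x, grid_min=-2.0, grid_max=2.0, num_grids=8):
    """Gaussian RBF expansion: (..., d) -> (..., d, num_grids)."""
    grid = jnp.linspace(grid_min, grid_max, num_grids)  # evenly spaced centres
    h = (grid_max - grid_min) / (num_grids - 1)          # width = centre spacing
    return jnp.exp(-(((x[..., None] - grid) / h) ** 2))


def init_fastkan_layer(key, in_dim, out_dim, num_grids=8, spline_scale=0.1):
    """Hypothetical parameter layout for one FastKAN-style layer."""
    k1, k2 = jax.random.split(key)
    return {
        # One weight per (input feature, basis function) pair, per output.
        "spline_w": spline_scale * jax.random.normal(k1, (in_dim * num_grids, out_dim)),
        # Residual "base" branch: a plain linear map applied to silu(x).
        "base_w": jax.random.normal(k2, (in_dim, out_dim)) / jnp.sqrt(in_dim),
        "base_b": jnp.zeros(out_dim),
    }


def fastkan_layer(params, x, num_grids=8):
    """x: (batch, in_dim) -> (batch, out_dim)."""
    # Normalise each sample so its features fall roughly inside the RBF grid
    # (a layer norm without learned scale/shift in this sketch).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    xn = (x - mean) / jnp.sqrt(var + 1e-5)
    basis = rbf_basis(xn, num_grids=num_grids)   # (batch, in_dim, num_grids)
    flat = basis.reshape(basis.shape[0], -1)     # (batch, in_dim * num_grids)
    spline_out = flat @ params["spline_w"]
    base_out = jax.nn.silu(x) @ params["base_w"] + params["base_b"]
    return spline_out + base_out


# Usage: one layer mapping flattened 28x28 images to 10 outputs.
k_param, k_data = jax.random.split(jax.random.PRNGKey(0))
params = init_fastkan_layer(k_param, in_dim=784, out_dim=10)
x = jax.random.normal(k_data, (32, 784))
y = jax.jit(fastkan_layer)(params, x)
print(y.shape)  # (32, 10)
```

Because the Gaussian basis needs only an `exp` over a fixed grid instead of the recursive B-spline evaluation, the whole layer reduces to elementwise ops and matrix multiplies, which `jax.jit` compiles efficiently.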
pip install .
pip install -r requirements.txt
KANX comes with an example on MNIST:
python examples/train_mnist.py
We tested the implementation on MNIST and report the following wall-clock times for 3000 epochs:
Hardware | Wall time (s)
---|---
CPU (i5-1135G7) | 130.51
CPU (i9-12900K) | 67.85
GPU (RTX 3070 Ti) | 13.55
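Wall-clock numbers in JAX depend on how the timing is done: the first call to a jitted function includes compilation, and JAX dispatches work asynchronously, so the clock must wait for results to be ready. The README does not say how the figures above were measured; the following is only a minimal sketch of one common approach. `time_steps` and `toy_step` are hypothetical stand-ins, not functions from `examples/train_mnist.py`.

```python
import time

import jax
import jax.numpy as jnp


def time_steps(step_fn, state, batch, num_steps):
    """Return (final_state, seconds) for num_steps calls of a jitted step."""
    # Warm-up call triggers compilation so it is excluded from the timing.
    state = step_fn(state, batch)
    jax.block_until_ready(state)
    start = time.perf_counter()
    for _ in range(num_steps):
        state = step_fn(state, batch)
    # Work is dispatched asynchronously; wait for the final result before stopping the clock.
    jax.block_until_ready(state)
    return state, time.perf_counter() - start


@jax.jit
def toy_step(w, x):
    """Hypothetical training step: one gradient-descent step on a quadratic loss."""
    def loss(w):
        return jnp.mean((x @ w) ** 2)
    return w - 1e-3 * jax.grad(loss)(w)


w0 = jnp.ones((784, 10))
x = jnp.ones((128, 784))
w_final, seconds = time_steps(toy_step, w0, x, num_steps=100)
print(f"{seconds:.3f} s")
```

Without the `block_until_ready` calls the loop can return before the device has finished, which tends to make GPU timings look artificially fast.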
Plots from the GPU experiment:
More experiments to come...