In [None]:
import torch as torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from kan import KAN, create_dataset

In [None]:
# Model configuration: numeric precision, device, input geometry and the KAN architecture.

# Use double precision by default, so model parameters (and tensors created without an
# explicit dtype) are float64.
torch.set_default_dtype(torch.float64)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

IMAGE_SIDE = 20  # MNIST's 28x28 images are downscaled to IMAGE_SIDE x IMAGE_SIDE when loaded (must match read_idx_images)
input_size = IMAGE_SIDE * IMAGE_SIDE  # number of input features once an image is flattened
N_CLASSES = 10  # MNIST digits 0-9
HIDDEN_WIDTH = 5  # width of the single hidden layer (value from the documentation example)
SEED = 0

# Initialize the model (see MultKAN.py in pykan for the full argument list):
#   width: number of neurons in each layer, in order from input to output. The input layer
#          must equal the flattened image size and the output layer the number of classes;
#          the documentation example's [2, 5, 1] only fits its 2-variable toy regression.
#   grid:  number of grid intervals of each spline (affects the accuracy of the learnable functions)
#   k:     order of the spline
#   seed:  random seed
model = KAN(width=[input_size, HIDDEN_WIDTH, N_CLASSES], grid=5, k=3, seed=SEED, device=device)

In [None]:
import torchvision.transforms as transforms

def read_idx_images(file_path, size=20):
    """Read an MNIST-style IDX image file; return the images downscaled with pixels in [0, 1].

    Args:
        file_path: path to an IDX3 image file (magic number 2051, big-endian header holding
            the image count, rows and cols, followed by one unsigned byte per pixel).
        size: target length of each image's shorter edge; the longer edge is scaled to keep
            the aspect ratio (the same rule as ``transforms.Resize(size)``). For square
            28x28 MNIST digits the default produces 20x20 images.

    Returns:
        Tensor of shape (N, 1, H, W) in ``torch.get_default_dtype()`` with values in [0, 1].

    Raises:
        ValueError: if the magic number is not 2051, an image dimension is zero, or the
            pixel payload does not match the size declared in the header.
    """
    with open(file_path, 'rb') as f:
        magic = int.from_bytes(f.read(4), 'big')
        num_images = int.from_bytes(f.read(4), 'big')
        rows = int.from_bytes(f.read(4), 'big')
        cols = int.from_bytes(f.read(4), 'big')
        payload = f.read()

    if magic != 2051:
        raise ValueError(f"{file_path}: magic number {magic} is not 2051 (IDX unsigned-byte image file)")
    if rows == 0 or cols == 0:
        raise ValueError(f"{file_path}: invalid image dimensions {rows}x{cols}")
    expected_bytes = num_images * rows * cols
    if len(payload) != expected_bytes:
        raise ValueError(
            f"{file_path}: expected {expected_bytes} pixel bytes for {num_images} images "
            f"of {rows}x{cols}, found {len(payload)}"
        )

    # Output size: shorter edge -> size, longer edge scaled proportionally (truncated).
    if rows <= cols:
        out_h, out_w = size, int(size * cols / rows)
    else:
        out_h, out_w = int(size * rows / cols), size

    dtype = torch.get_default_dtype()
    if num_images == 0:
        return torch.zeros((0, 1, out_h, out_w), dtype=dtype)

    pixels = np.frombuffer(payload, dtype=np.uint8).reshape(num_images, 1, rows, cols)
    # Scale to [0, 1] explicitly. Handing raw 0..255 floats to transforms.ToPILImage is wrong:
    # it assumes float tensors are already in [0, 1] and multiplies by 255, which overflows
    # the byte conversion and corrupts every non-zero pixel.
    images = torch.from_numpy(pixels.astype(np.float32)).to(dtype) / 255.0

    # Batched bilinear + antialias resize: the tensor-path equivalent of transforms.Resize,
    # closely matching its PIL-based result without a per-image Python loop.
    return torch.nn.functional.interpolate(
        images, size=(out_h, out_w), mode='bilinear', align_corners=False, antialias=True
    )

def read_idx_labels(file_path):
    """Read an MNIST-style IDX label file.

    Args:
        file_path: path to an IDX1 label file (magic number 2049, big-endian label count,
            followed by one unsigned byte per label).

    Returns:
        Tensor of shape (N,) and dtype torch.long holding the class labels.

    Raises:
        ValueError: if the magic number is not 2049 or the file holds a different number
            of labels than its header declares (e.g. a truncated download).
    """
    with open(file_path, 'rb') as f:
        magic = int.from_bytes(f.read(4), 'big')
        num_labels = int.from_bytes(f.read(4), 'big')
        payload = f.read()

    if magic != 2049:
        raise ValueError(f"{file_path}: magic number {magic} is not 2049 (IDX unsigned-byte label file)")
    if len(payload) != num_labels:
        raise ValueError(f"{file_path}: header declares {num_labels} labels, found {len(payload)}")

    # astype copies into a writable int64 array, so from_numpy yields a torch.long tensor
    # without touching the read-only buffer returned by frombuffer.
    return torch.from_numpy(np.frombuffer(payload, dtype=np.uint8).astype(np.int64))

class OurData:
    """MNIST train/test splits loaded from IDX files, accessed by key like a dict.

    Keys:
        'train_input', 'test_input': images as returned by read_idx_images.
        'train_label', 'test_label': class labels as returned by read_idx_labels.
    """

    def __init__(self, data_dir='/workspaces/KAN-Network/Dataset'):
        """Load the four MNIST IDX files found in ``data_dir``.

        Args:
            data_dir: directory holding the MNIST files under their standard names
                (train-images.idx3-ubyte, train-labels.idx1-ubyte,
                t10k-images.idx3-ubyte, t10k-labels.idx1-ubyte).
                Defaults to the project's dataset directory.

        Raises:
            ValueError: if a split has a different number of images than labels.
        """
        import os

        def load(reader, file_name):
            # Read one dataset file from data_dir with the given IDX reader.
            return reader(os.path.join(data_dir, file_name))

        self.ourdataset = {}
        # Training inputs must come from the *training* image file; loading t10k-images here
        # paired the test images with the training-set labels.
        self.ourdataset['train_input'] = load(read_idx_images, 'train-images.idx3-ubyte')
        self.ourdataset['test_input'] = load(read_idx_images, 't10k-images.idx3-ubyte')
        self.ourdataset['train_label'] = load(read_idx_labels, 'train-labels.idx1-ubyte')
        self.ourdataset['test_label'] = load(read_idx_labels, 't10k-labels.idx1-ubyte')

        # Guard against mismatched files: every image needs exactly one label.
        for split in ('train', 'test'):
            n_inputs = len(self.ourdataset[f'{split}_input'])
            n_labels = len(self.ourdataset[f'{split}_label'])
            if n_inputs != n_labels:
                raise ValueError(f"{split} split has {n_inputs} images but {n_labels} labels")

    def __getitem__(self, key):
        """Return the tensor stored under ``key`` (KeyError for unknown keys)."""
        return self.ourdataset[key]

In [None]:
from kan.utils import create_dataset


def f(x):
    """Toy target from the pykan documentation: f(x, y) = exp(sin(pi*x) + y^2).

    ``x`` is a 2-D tensor whose columns 0 and 1 hold the two variables; the list indexing
    keeps a trailing axis, so the result has shape (n, 1).
    """
    first = x[:, [0]]
    second = x[:, [1]]
    return torch.exp(torch.sin(torch.pi * first) + second ** 2)


# Build the 2-variable toy dataset with pykan's helper, on the configured device.
dataset = create_dataset(f, n_var=2, device=device)

In [None]:
# Load MNIST and flatten each image into a row of `input_size` features. The tensors are
# moved to the model's device, and the inputs cast to the default (model) dtype, so the
# forward pass works on both CPU and CUDA.
data = OurData()
ourdata = {
    'train_input': data['train_input'].reshape(-1, input_size).to(device=device, dtype=torch.get_default_dtype()),
    'train_label': data['train_label'].to(device),
    'test_input': data['test_input'].reshape(-1, input_size).to(device=device, dtype=torch.get_default_dtype()),
    'test_label': data['test_label'].to(device),
}

# model.plot() draws the activations from the most recent forward pass, so run one first.
# A subset of images is enough for the plot. NOTE(review): assumes pykan's forward memory
# grows with batch size (a full-batch pass over all flattened images is heavy); raise
# PLOT_SAMPLES if more coverage is wanted.
PLOT_SAMPLES = 1000
model(ourdata['train_input'][:PLOT_SAMPLES])  # forward pass of the model
model.plot()  # plots the model

In [None]:
# Train the model on the MNIST splits, then prune it. Kept off by default so that
# "Run All" stays fast; set RUN_TRAINING = True to train.
#   opt:     optimization method (LBFGS)
#   steps:   number of training steps
#   lamb:    penalty (regularization) parameter
#   loss_fn: the labels are class indices (torch.long), so use cross-entropy; pykan's
#            default (loss_fn=None) is mean squared error, which does not fit class targets.
# Other fit() options include lr (learning rate, 1 by default).
# TODO(review): consider fit()'s batch option if full-batch training runs out of memory.
RUN_TRAINING = False

if RUN_TRAINING:
    # steps/lamb are the values from the basic example in the documentation
    model.fit(ourdata, opt="LBFGS", steps=50, lamb=0.001, loss_fn=torch.nn.CrossEntropyLoss())
    model.plot()  # plots the trained model

    model = model.prune()
    model.plot()  # plots the pruned model

In [None]:
# `ourdata` is a dict of tensors (it has no .shape), so report the shape of each entry.
print({name: tuple(tensor.shape) for name, tensor in ourdata.items()})

In [None]:
# `dataset` (returned by pykan's create_dataset) is a dict of tensors with no .shape of its
# own, so report the shape of each entry.
print({name: tuple(tensor.shape) for name, tensor in dataset.items()})