In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from models.kan.KAN import KAN
from utils.utils import *
import copy



Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython


In [16]:
# Creates a dataset according to the specified function
def create_dataset(fn, device='cuda'):
    x1 = torch.linspace(-1, 1, steps=40)
    x2 = torch.linspace(-1, 1, steps=40)
    x1, x2 = torch.meshgrid(x1, x2, indexing='ij')
    x_eval = torch.stack([x1.flatten(), x2.flatten()], dim=1)
    y_target = fn(x_eval) #+ torch.randn((x_eval.shape[0], 1))*0.1

    # test and train split
    n_training = int(x_eval.shape[0] * 0.8)
    training_idxs = np.random.randint(x_eval.shape[0], size=n_training)
    test_idxs = [i for i in range(x_eval.shape[0]) if i not in training_idxs]

    x_training = x_eval[training_idxs]
    x_test = x_eval[test_idxs]
    y_training = y_target[training_idxs]
    y_test = y_target[test_idxs]

    shuffled_idxs = torch.randperm(x_training.shape[0])
    x_training = x_training[shuffled_idxs]
    y_training = y_training[shuffled_idxs]

    return x_training.to(device), x_test.to(device), y_training.to(device), y_test.to(device)


# Evaluates a model on the validation/test set
def eval_model(model, x_val, y_val, loss):
    model.eval()
    with torch.no_grad():
        val_out = model(x_val)
        val_loss = loss(val_out, y_val)
    return val_loss


# Trains the model
def training(model:KAN, x_training, x_test, y_training, y_test, epochs=1000, patience=50, log=100, lr=0.001, lamb=0.01):
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)

    best_val_loss = float('inf')
    best_epoch = 0
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        y_pred = model(x_training)
        train_loss = criterion(y_pred, y_training)
        reg, _, _ = model.regularization_loss()
        loss = train_loss + lamb * reg
        loss.backward()
        optimizer.step()

        val_loss = eval_model(model, x_test, y_test, criterion)

        if epoch % log == 0:
            print(f"Epoch: {epoch} \t Train Loss: {train_loss: 0.5f} \t Val Loss: {val_loss: 0.5f} \t Best Val Loss: {best_val_loss: 0.5f}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            best_model_state = copy.deepcopy(model.state_dict())
        elif epoch - best_epoch >= patience:
            print(f"Early stopping at epoch {epoch}")
            break
    return best_model_state

In [17]:
model = KAN(
    [2, 5, 1],
    mu_1=1.,
    mu_2=1.,
    use_orig_reg=True,
    store_act=True
)

fn = lambda x: (x[:, 0]**2).unsqueeze(-1) + (x[:, 1]).unsqueeze(-1)
# fn = lambda x: (x[:, 0]**2 + x[:, 1]**3).unsqueeze(-1)

# fn = lambda x: torch.exp(torch.sin(torch.pi * x[:, :1]) + torch.pow(x[:, 1:] + 0.5, 2))
x_training, x_test, y_training, y_test = create_dataset(fn)

# Training without fixing symbolic functions
best_model = training(model, x_training, x_test, y_training, y_test, epochs=3000, patience=500, lr=0.01, lamb=0.01)

Epoch: 0 	 Train Loss:  0.66161 	 Val Loss:  0.61285 	 Best Val Loss:  inf
Epoch: 100 	 Train Loss:  0.00040 	 Val Loss:  0.00041 	 Best Val Loss:  0.00043
Epoch: 200 	 Train Loss:  0.00016 	 Val Loss:  0.00016 	 Best Val Loss:  0.00014
Epoch: 300 	 Train Loss:  0.00011 	 Val Loss:  0.00010 	 Best Val Loss:  0.00010
Epoch: 400 	 Train Loss:  0.00010 	 Val Loss:  0.00009 	 Best Val Loss:  0.00008
Epoch: 500 	 Train Loss:  0.00010 	 Val Loss:  0.00006 	 Best Val Loss:  0.00004
Epoch: 600 	 Train Loss:  0.00005 	 Val Loss:  0.00007 	 Best Val Loss:  0.00003
Epoch: 700 	 Train Loss:  0.00007 	 Val Loss:  0.00008 	 Best Val Loss:  0.00002
Epoch: 800 	 Train Loss:  0.00006 	 Val Loss:  0.00005 	 Best Val Loss:  0.00002
Epoch: 900 	 Train Loss:  0.00005 	 Val Loss:  0.00006 	 Best Val Loss:  0.00002
Epoch: 1000 	 Train Loss:  0.00003 	 Val Loss:  0.00003 	 Best Val Loss:  0.00002
Epoch: 1100 	 Train Loss:  0.00006 	 Val Loss:  0.00004 	 Best Val Loss:  0.00002
Early stopping at epoch 1196


In [18]:
plot(
    folder_path='./saved_models_optuna/example_KAN/figures',
    layers=model.layers,
    show_plots=False
)

In [19]:
save_acts(
    layers=model.layers,
    folder_path='./saved_models_optuna/example_KAN/cached_acts',
)

In [20]:
cache_acts, cache_preacts = get_kan_arch(n_layers=2, model_path='./saved_models_optuna/example_KAN')
pruned_acts, pruned_preacts = pruning(cache_acts, cache_preacts, theta=0.01)  

Pruning node (0,0)
Pruning node (0,2)
Pruning node (0,3)
Pruning node (0,4)


In [21]:
symb_g = fit_kan(
    pruned_acts,
    pruned_preacts,
    symb_xs=[sp.Symbol('x_1'), sp.Symbol('x_2')],
    model_path='./saved_models_optuna/example_KAN'
)

In [22]:
quantise(symb_g[0])

0.99*x_1**2 + 0.99*x_2