In [None]:
# !pip install pykan



In [3]:
import math

import torch
import matplotlib.pyplot as plt
from kan import KAN, create_dataset
from kan.MLP import MLP as kan_MLP
from kan.feynman import get_feynman_dataset

In [5]:
class MLP(kan_MLP):
    """pykan ``MLP`` baseline whose hidden activation is chosen by ``act``.

    The parent constructor is called with the same arguments. Afterwards
    ``self.act_fun`` is overwritten with the torch activation selected by
    ``act``, so the chosen activation is the one in effect regardless of what
    the parent installed.

    Args:
        width: layer widths, forwarded unchanged to the parent constructor.
        act: one of 'silu', 'relu', 'tanh'.
        save_act, seed, device: forwarded unchanged to the parent constructor.

    Raises:
        ValueError: if ``act`` is not a supported name. Previously an unknown
            name silently kept the parent's activation, so a typo produced
            results labelled with the wrong activation.
    """

    # Public ``act`` names -> torch activation module classes.
    _ACTIVATIONS = {
        'silu': torch.nn.SiLU,
        'relu': torch.nn.ReLU,
        'tanh': torch.nn.Tanh,
    }

    def __init__(self, width, act='silu', save_act=True, seed=0, device='cpu'):
        # Validate before building any layers so a bad name fails fast.
        if act not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {act!r}; expected one of {sorted(self._ACTIVATIONS)}"
            )
        super(MLP, self).__init__(width, act, save_act, seed, device)
        # Set after the parent __init__ so this choice is the final one.
        self.act_fun = self._ACTIVATIONS[act]()

## Experiment Functions

In [6]:
def run_kan_feynman(dataset: str,
                          shape: list[int],
                          samples: int, # samples for both train and test (separated) sets
                          start_grid: int = 3,
                          k: int = 3,
                          device='cuda' if torch.cuda.is_available() else 'cpu',
                          seed:int=42,
                          grids=[3, 5, 10, 20, 50, 100, 200],
                          steps_per_grid=200,
                          plot_model=False,
                          prune=False,
                          prune_threshold=1e-2,
                          steps_after_prune=200,
                          lamb=0.00,
                          opt="LBFGS"):
    
    assert device in ['cpu', 'cuda']
    try:
        torch.manual_seed(seed)
    except Exception as e:
        print('Warning: could not set torch.manual_seed')
    
    try:
        symbol, expr, f, ranges = get_feynman_dataset(dataset)
    except ValueError:
        raise ValueError("Invalid dataset name")

    dataset = create_dataset(f, n_var=len(ranges), ranges=ranges, train_num=samples, test_num=samples, seed=seed)


    dataset['train_input'] = dataset['train_input'].to(device)
    dataset['train_label'] = dataset['train_label'].to(device)
    dataset['test_input'] = dataset['test_input'].to(device)
    dataset['test_label'] = dataset['test_label'].to(device)

    # assert shape[0] == len(ranges), "Input dimension does not match number of variables in dataset"
    # print(shape)
    if not shape[-1] == 1:
        print(f"Warning: Output dimension {shape[-1]} is not 1. Forcing condition.")
        shape[-1] = 1
    if not shape[0] == len(ranges):
        print(f"Warning: Input dimension {shape[0]} does not match number of variables in dataset {len(ranges)}. Forcing condition.")
        shape[0] = len(ranges)

    # Create a copy of shape before passing to KAN to avoid mutation
    # KAN's __init__ converts [2,5,1] to [[2,0],[5,0],[1,0]] in-place
    model = KAN(width=shape.copy(), grid=start_grid, k=k, seed=seed).to(device)

    per_grid_results = {}

    for i, g in enumerate(grids):
        if i > 0: #skip first iteration
            model = model.refine(g)
        
        print(f"Training grid: {g}")
        results = model.fit(dataset, opt=opt, steps=steps_per_grid, lamb=lamb) #, verbose=verbose)
        per_grid_results[g] = results

    if prune:
        print("Pruning model...")
        model = model.prune(prune_threshold)
        # Disable grid updates after pruning to avoid numerical instability with smaller network
        results = model.fit(dataset, opt=opt, steps=steps_after_prune, lamb=0.0, update_grid=False) #, verbose=verbose)
        per_grid_results['prunning'] = results
        

    if plot_model:
        model.plot()
        plt.show()

    best_test_rmse = min(per_grid_results.values(), key=lambda x: x['test_loss'][-1])['test_loss'][-1]
    
    return model, per_grid_results, best_test_rmse

In [None]:
def sweep_kan_feynman(dataset:str, width: int=5, depths=[2,3,4,5,6], seeds: list[int] = [42, 171, 3], lambs: list[float] = [0.0], **kwargs):

    all_results = {}
    best_test_rmse = float('inf')
    
    # Get dataset info to determine input dimension
    try:
        symbol, expr, f, ranges = get_feynman_dataset(dataset)
        n_inputs = len(ranges)
    except Exception as e:
        raise ValueError(f"Could not load dataset {dataset}")

    for depth in depths:
        all_results[depth] = {}
        for lamb in lambs:
            all_results[depth][lamb] = {}
            for seed in seeds:
                all_results[depth][lamb][seed] = {}
                # Build shape with correct input dimension and output dimension of 1
                # depth includes input and output layers, so we need depth-2 hidden layers
                if depth == 1:
                    shape = [n_inputs, 1]
                else:
                    shape = [n_inputs] + [width for _ in range(depth - 1)] + [1]


                model, per_grid_results, test_rmse = run_kan_feynman(dataset=dataset, shape=shape, lamb=lamb, seed=seed, **kwargs)
                    
                all_results[depth][lamb][seed] = {
                    'model': model,
                    'per_grid_results': per_grid_results,
                    'test_rmse': test_rmse
                }

                if test_rmse < best_test_rmse:
                    best_test_rmse = test_rmse

    return all_results, best_test_rmse

In [8]:
def feynman_human_kan_experiment(dataset: str, shape: list[int], seeds: int| list[int] = [42, 171, 3], samples: int = 1000, **kwargs):
    """
    Run feynman dataset experiment with human-constructed KAN
    """
    all_results = []
    for seed in seeds:
        results = run_kan_feynman(dataset=dataset,
                                seed=seed,
                                samples=samples,
                                shape=shape,
                                **kwargs)
        all_results.append(results)
    
    best_test_rmse = min([res[2] for res in all_results])

    print('==='*20)
    print(f'Best Test RMSE for human-constructed KAN on dataset {dataset}: {best_test_rmse:.4e}')

    return best_test_rmse, all_results
    

def feynman_not_pruned_kan_experiment(dataset: str, seeds: int | list[int], samples=1000, **kwargs):

    results = sweep_kan_feynman(dataset=dataset,
                                seeds=seeds,
                                samples=samples,
                                **kwargs)

    # print('==='*20)
    # print(f'Best Test RMSE for not pruned KAN on dataset {dataset}: {results[1]:.4e}')
    best_test_rmse = results[1]

    return best_test_rmse, results


def feynman_pruned_kan_experiment(dataset: str, seeds: int|list[int] = [42, 171, 3], samples: int = 1000, steps_after_prune: int = 100,  lambs=[1e-2, 1e-3], **kwargs):
    """
    Run feynman dataset experiment with pruned KAN
    """
    results, best_test_rmse = sweep_kan_feynman(dataset=dataset,
                                seeds=seeds,
                                samples=samples,
                                prune=True,
                                prune_threshold=0.01,
                                steps_after_prune=steps_after_prune,
                                lambs=lambs,
                                **kwargs)
    

    # get smallest shape with loss < 1e-2
    # get shape of best loss

    smallest_shape = None
    smallest_params = float('inf')
    shape_best_loss = None
    smallest_loss = float('inf')

    for depth, v in results.items():
        for lamb, vv in v.items():
            for seed, res in vv.items():
                test_rmse = res['test_rmse']
                model = res['model']
                shape = [p[0] if isinstance(p, list) else p for p in model.width]
                num_params = sum(p.numel() for p in model.parameters())


                if test_rmse < 1e-2:
                    if smallest_shape is None or sum(shape) < sum(smallest_shape):
                        if num_params > smallest_params and smallest_shape is not None:
                            print(f"Warning: Found smaller shape {shape} with more parameters {num_params} > {smallest_params}. Still taking smallest shape though.")
                        smallest_shape = shape
                        smallest_params = num_params

                if test_rmse < smallest_loss:
                    smallest_loss = test_rmse
                    shape_best_loss = shape


    
    # If no model achieved RMSE < 1e-2, print warning
    if smallest_shape is None:

        print(f"⚠️ WARNING: No model achieved RMSE < 1e-2 for dataset {dataset}")


    return smallest_shape, shape_best_loss, best_test_rmse, results

In [9]:
def feynman_mlp_experiment(dataset: str, seeds: int|list[int] = [42, 171, 3], samples: int = 1000, device='cuda' if torch.cuda.is_available() else 'cpu',
                           activations=['silu', 'relu', 'tanh'], depths = [2,3,4,5,6], steps=1400, lr=1, **kwargs):
    """
    Run feynman dataset experiment with MLP
    """

    assert device in ['cpu', 'cuda']

    all_results = {}
    best_test_rmse = float('inf')

    try:
        symbol, expr, f, ranges = get_feynman_dataset(dataset)
    except ValueError:
        raise ValueError("Invalid dataset name")
    
    for seed in seeds:
        try:
            torch.manual_seed(seed)
        except Exception as e:
            print(f"Warning: could not set seed due to error: {e}")

        dataset = create_dataset(f, n_var=len(ranges), ranges=ranges, train_num=samples, test_num=samples, seed=seed)
        dataset['train_input'] = dataset['train_input'].to(device)
        dataset['train_label'] = dataset['train_label'].to(device)
        dataset['test_input'] = dataset['test_input'].to(device)
        dataset['test_label'] = dataset['test_label'].to(device)

        all_results[seed] = {}
        
        for depth in depths:
            shape = [len(ranges)] + [5 for _ in range(depth-1)] + [1]

            all_results[seed][depth] = {}

            for act in activations:
                model = MLP(width=shape, act=act, seed=seed, device=device)

                results = model.fit(dataset=dataset, steps=steps, opt="LBFGS", lr=lr)
                all_results[seed][depth][act] = results

                test_rmse = min(results['test_loss'])

                if test_rmse < best_test_rmse:
                    best_test_rmse = test_rmse
                
    return all_results, best_test_rmse

In [10]:
def full_feynman_experiment(dataset: str, shape_human: list[int], seeds: list[int] = [42, 171, 3], samples: int = 1000, skip = [], **kwargs):
    """
    Run full feynman dataset experiment with human-constructed KAN, not pruned KAN, pruned KAN, and MLP
    """

    pruned_smallest_shape, pruned_shape_best_loss, pruned_best_rmse = None, None, None
    h_best_test_rmse, not_pruned_best_rmse, mlp_best_rmse = None, None, None

    if 'human' not in skip:
        print('Running Human-constructed KAN Experiment...')
        h_best_test_rmse, h_results = feynman_human_kan_experiment(dataset=dataset, shape=shape_human, seeds=seeds, samples=samples, **kwargs)
        print(f'Best Test RMSE for human-constructed KAN on dataset {dataset}: {h_best_test_rmse:.4e}')

    if 'unpruned' not in skip:
        print('--------------------------------') 
        print('Running Not Pruned KAN Experiment...')
        not_pruned_best_rmse, not_pruned_results = feynman_not_pruned_kan_experiment(dataset=dataset, seeds=seeds, samples=samples, **kwargs)
        print(f'Best Test RMSE for not pruned KAN on dataset {dataset}: {not_pruned_best_rmse:.4e}')

    if 'pruned' not in skip:
        print('--------------------------------') 
        print('Running Pruned KAN Experiment...')
        pruned_smallest_shape, pruned_shape_best_loss, pruned_best_rmse, pruned_results = feynman_pruned_kan_experiment(dataset=dataset, seeds=seeds, samples=samples, **kwargs)
        print(f'Pruned KAN Smallest Shape: {pruned_smallest_shape}, ')
        print(f'Pruned KAN lowest loss Shape: {pruned_shape_best_loss}, ')
        print(f'Pruned KAN lowest loss: {pruned_best_rmse:.4e}')

    if 'mlp' not in skip:
        print('--------------------------------') 
        print('Running MLP Experiment...')
        mlp_results, mlp_best_rmse = feynman_mlp_experiment(dataset=dataset, seeds=seeds, samples=samples, **kwargs)
        print(f'MLP lowest loss: {mlp_best_rmse:.4e}')

    print('--------------------------------') 
    print('==='*40)
    print('==='*40)
    print(f'Final Results for dataset {dataset}:')

    if 'pruned' not in skip:
        print(f'Pruned KAN Smallest Shape: {pruned_smallest_shape}, ')
        print(f'Pruned KAN lowest loss Shape: {pruned_shape_best_loss}, ')
        print(f'Pruned KAN lowest loss: {pruned_best_rmse:.4e}')
    if 'human' not in skip:
        print(f'Human-constructed KAN loss: {h_best_test_rmse:.4e}')
    if 'unpruned' not in skip:
        print(f'UnPruned KAN lowest loss: {not_pruned_best_rmse:.4e}')
    if 'mlp' not in skip:
        print(f'MLP lowest loss: {mlp_best_rmse:.4e}')
    print('==='*40)
    print('==='*40)

    # save results
    full_results = {
        'human': {
            'best_rmse': h_best_test_rmse,
            'results': h_results
        } if 'human' not in skip else None,
        'not_pruned': {
            'best_rmse': not_pruned_best_rmse,
            'results': not_pruned_results
        } if 'unpruned' not in skip else None,
        'pruned': {
            'smallest_shape': pruned_smallest_shape,
            'shape_best_loss': pruned_shape_best_loss,
            'best_rmse': pruned_best_rmse,
            'results': pruned_results
        } if 'pruned' not in skip else None,
        'mlp': {
            'best_rmse': mlp_best_rmse,
            'results': mlp_results
        } if 'mlp' not in skip else None    
    }

    return full_results

## Experiments

In [11]:
# I.6.20 has two input variables (the pruned runs below report 2-input shapes),
# so the human-constructed KAN needs a 2-unit input layer. The previous
# [6, 4, 2, 1, 1] was the I.9.18 shape copied over. 'human' is skipped here anyway.
results_I_6_20 = full_feynman_experiment(
    dataset='I.6.20',
    shape_human=[2, 2, 1, 1],
    seeds=[171],
    skip=['unpruned', 'human', 'mlp'],
    depths=[2, 3],
)

--------------------------------
Running Pruned KAN Experiment...
checkpoint directory created: ./model
saving model version 0.0
Training grid: 3


| train_loss: 2.92e-04 | test_loss: 3.25e-04 | reg: 1.43e+01 | : 100%|█| 200/200 [01:05<00:00,  3.03


saving model version 0.1
saving model version 0.2
Training grid: 5


| train_loss: 1.04e-04 | test_loss: 1.30e-04 | reg: 1.38e+01 | : 100%|█| 200/200 [01:06<00:00,  3.01


saving model version 0.3
saving model version 0.4
Training grid: 10


| train_loss: 2.74e-05 | test_loss: 4.56e-05 | reg: 1.38e+01 | : 100%|█| 200/200 [00:41<00:00,  4.84


saving model version 0.5
saving model version 0.6
Training grid: 20


| train_loss: 1.03e-05 | test_loss: 3.45e-05 | reg: 1.38e+01 | : 100%|█| 200/200 [00:42<00:00,  4.66


saving model version 0.7
saving model version 0.8
Training grid: 50


| train_loss: 4.87e-06 | test_loss: 4.11e-05 | reg: 1.38e+01 | : 100%|█| 200/200 [00:52<00:00,  3.81


saving model version 0.9
saving model version 0.10
Training grid: 100


| train_loss: 2.15e-06 | test_loss: 2.11e-04 | reg: 1.38e+01 | : 100%|█| 200/200 [01:40<00:00,  1.99


saving model version 0.11
saving model version 0.12
Training grid: 200


| train_loss: 1.49e-06 | test_loss: 1.21e-02 | reg: 1.25e+01 | : 100%|█| 200/200 [03:34<00:00,  1.07


saving model version 0.13
Pruning model...
saving model version 0.14


| train_loss: 1.49e-06 | test_loss: 1.21e-02 | reg: 1.25e+01 | : 100%|█| 100/100 [00:57<00:00,  1.74


saving model version 0.15
checkpoint directory created: ./model
saving model version 0.0
Training grid: 3


| train_loss: 3.25e-04 | test_loss: 3.91e-04 | reg: 1.40e+01 | : 100%|█| 200/200 [01:03<00:00,  3.16


saving model version 0.1
saving model version 0.2
Training grid: 5


| train_loss: 1.16e-04 | test_loss: 1.78e-04 | reg: 1.35e+01 | : 100%|█| 200/200 [00:57<00:00,  3.46


saving model version 0.3
saving model version 0.4
Training grid: 10


| train_loss: 3.55e-05 | test_loss: 7.23e-05 | reg: 1.35e+01 | : 100%|█| 200/200 [00:46<00:00,  4.26


saving model version 0.5
saving model version 0.6
Training grid: 20


| train_loss: 1.24e-05 | test_loss: 4.78e-05 | reg: 1.35e+01 | : 100%|█| 200/200 [00:43<00:00,  4.65


saving model version 0.7
saving model version 0.8
Training grid: 50


| train_loss: 5.25e-06 | test_loss: 3.64e-05 | reg: 1.35e+01 | : 100%|█| 200/200 [00:57<00:00,  3.46


saving model version 0.9
saving model version 0.10
Training grid: 100


| train_loss: 1.40e-06 | test_loss: 1.01e-04 | reg: 1.35e+01 | : 100%|█| 200/200 [01:37<00:00,  2.05


saving model version 0.11
saving model version 0.12
Training grid: 200


| train_loss: 9.95e-07 | test_loss: 1.55e-02 | reg: 1.23e+01 | : 100%|█| 200/200 [02:55<00:00,  1.14


saving model version 0.13
Pruning model...
saving model version 0.14


| train_loss: 9.94e-07 | test_loss: 1.55e-02 | reg: 1.23e+01 | : 100%|█| 100/100 [01:09<00:00,  1.44


saving model version 0.15
checkpoint directory created: ./model
saving model version 0.0
Training grid: 3


| train_loss: 8.11e-05 | test_loss: 9.94e-05 | reg: 1.93e+01 | : 100%|█| 200/200 [01:29<00:00,  2.24


saving model version 0.1
saving model version 0.2
Training grid: 5


| train_loss: 2.68e-05 | test_loss: 3.31e-05 | reg: 1.93e+01 | : 100%|█| 200/200 [01:12<00:00,  2.76


saving model version 0.3
saving model version 0.4
Training grid: 10


| train_loss: 9.20e-06 | test_loss: 1.75e-05 | reg: 1.93e+01 | : 100%|█| 200/200 [01:07<00:00,  2.97


saving model version 0.5
saving model version 0.6
Training grid: 20


| train_loss: 2.99e-06 | test_loss: 3.17e-05 | reg: 1.93e+01 | : 100%|█| 200/200 [01:06<00:00,  3.01


saving model version 0.7
saving model version 0.8
Training grid: 50


| train_loss: 1.46e-06 | test_loss: 3.82e-04 | reg: 1.93e+01 | : 100%|█| 200/200 [01:38<00:00,  2.03


saving model version 0.9
saving model version 0.10
Training grid: 100


| train_loss: 3.64e-07 | test_loss: 4.33e-03 | reg: 1.93e+01 | : 100%|█| 200/200 [02:47<00:00,  1.19


saving model version 0.11
saving model version 0.12
Training grid: 200


| train_loss: 2.92e-06 | test_loss: 1.17e-02 | reg: 1.83e+01 | : 100%|█| 200/200 [06:54<00:00,  2.07


saving model version 0.13
Pruning model...
saving model version 0.14


| train_loss: 2.28e-06 | test_loss: 1.43e-02 | reg: 1.64e+01 | : 100%|█| 100/100 [04:50<00:00,  2.90


saving model version 0.15
checkpoint directory created: ./model
saving model version 0.0
Training grid: 3


| train_loss: 6.83e-05 | test_loss: 8.78e-05 | reg: 1.94e+01 | : 100%|█| 200/200 [01:41<00:00,  1.96


saving model version 0.1
saving model version 0.2
Training grid: 5


| train_loss: 2.43e-05 | test_loss: 3.19e-05 | reg: 1.94e+01 | : 100%|█| 200/200 [01:16<00:00,  2.60


saving model version 0.3
saving model version 0.4
Training grid: 10


| train_loss: 7.74e-06 | test_loss: 1.51e-05 | reg: 1.94e+01 | : 100%|█| 200/200 [00:52<00:00,  3.83


saving model version 0.5
saving model version 0.6
Training grid: 20


| train_loss: 2.39e-06 | test_loss: 2.08e-05 | reg: 1.94e+01 | : 100%|█| 200/200 [01:01<00:00,  3.25


saving model version 0.7
saving model version 0.8
Training grid: 50


| train_loss: 1.58e-06 | test_loss: 2.30e-04 | reg: 1.94e+01 | : 100%|█| 200/200 [01:57<00:00,  1.70


saving model version 0.9
saving model version 0.10
Training grid: 100


| train_loss: 4.99e-07 | test_loss: 3.85e-03 | reg: 1.94e+01 | : 100%|█| 200/200 [03:13<00:00,  1.03


saving model version 0.11
saving model version 0.12
Training grid: 200


| train_loss: 1.30e-06 | test_loss: 1.36e-02 | reg: 1.85e+01 | : 100%|█| 200/200 [05:56<00:00,  1.78


saving model version 0.13
Pruning model...
saving model version 0.14


| train_loss: 3.18e-07 | test_loss: 7.32e-03 | reg: 1.70e+01 | : 100%|█| 100/100 [04:41<00:00,  2.82

saving model version 0.15
Pruned KAN Smallest Shape: [2, 5, 1], 
Pruned KAN lowest loss Shape: [2, 5, 5, 1], 
Pruned KAN lowest loss: 1.5057e-05
--------------------------------
Final Results for dataset I.6.20:
Pruned KAN Smallest Shape: [2, 5, 1], 
Pruned KAN lowest loss Shape: [2, 5, 5, 1], 
Pruned KAN lowest loss: 1.5057e-05





In [12]:
# I.6.20b: three input variables. Only the pruned-KAN sweep runs, over depths 2-4 with one seed.
results_I_6_20b = full_feynman_experiment(
    dataset='I.6.20b',
    shape_human=[3, 2, 2, 1, 1],
    seeds=[171],
    skip=['unpruned', 'human', 'mlp'],
    depths=[2, 3, 4],
)

--------------------------------
Running Pruned KAN Experiment...
checkpoint directory created: ./model
saving model version 0.0
Training grid: 3


description:   0%|                                                          | 0/200 [00:00<?, ?it/s]

| train_loss: 9.01e-03 | test_loss: 1.20e-02 | reg: 1.88e+01 | : 100%|█| 200/200 [01:07<00:00,  2.95


saving model version 0.1
saving model version 0.2
Training grid: 5


| train_loss: 5.83e-03 | test_loss: 9.86e-03 | reg: 1.93e+01 | : 100%|█| 200/200 [01:09<00:00,  2.86


saving model version 0.3
saving model version 0.4
Training grid: 10


| train_loss: 3.14e-03 | test_loss: 6.98e-03 | reg: 1.87e+01 | : 100%|█| 200/200 [01:22<00:00,  2.44


saving model version 0.5
saving model version 0.6
Training grid: 20


| train_loss: 1.84e-03 | test_loss: 6.92e-03 | reg: 1.80e+01 | : 100%|█| 200/200 [01:39<00:00,  2.01


saving model version 0.7
saving model version 0.8
Training grid: 50


| train_loss: 1.95e-03 | test_loss: 1.50e-02 | reg: 1.63e+01 | : 100%|█| 200/200 [02:41<00:00,  1.24


saving model version 0.9
saving model version 0.10
Training grid: 100


| train_loss: 3.33e-05 | test_loss: 1.67e-02 | reg: 1.63e+01 | : 100%|█| 200/200 [04:35<00:00,  1.38


saving model version 0.11
saving model version 0.12
Training grid: 200


| train_loss: 3.80e-06 | test_loss: 1.69e-02 | reg: 1.63e+01 | : 100%|█| 200/200 [05:43<00:00,  1.72


saving model version 0.13
Pruning model...
saving model version 0.14


| train_loss: 3.80e-06 | test_loss: 1.69e-02 | reg: 1.63e+01 | : 100%|█| 100/100 [00:50<00:00,  1.97


saving model version 0.15
checkpoint directory created: ./model
saving model version 0.0
Training grid: 3


| train_loss: 1.04e-02 | test_loss: 1.04e-02 | reg: 2.28e+01 | : 100%|█| 200/200 [01:10<00:00,  2.82


saving model version 0.1
saving model version 0.2
Training grid: 5


| train_loss: 7.02e-03 | test_loss: 1.17e-02 | reg: 2.28e+01 | : 100%|█| 200/200 [00:57<00:00,  3.47


saving model version 0.3
saving model version 0.4
Training grid: 10


| train_loss: 3.49e-03 | test_loss: 8.78e-03 | reg: 2.30e+01 | : 100%|█| 200/200 [01:07<00:00,  2.97


saving model version 0.5
saving model version 0.6
Training grid: 20


| train_loss: 2.65e-03 | test_loss: 8.74e-03 | reg: 2.27e+01 | : 100%|█| 200/200 [01:10<00:00,  2.82


saving model version 0.7
saving model version 0.8
Training grid: 50


| train_loss: 1.41e-03 | test_loss: 9.08e-03 | reg: 2.27e+01 | : 100%|█| 200/200 [01:49<00:00,  1.83


saving model version 0.9
saving model version 0.10
Training grid: 100


| train_loss: 2.52e-04 | test_loss: 1.41e-02 | reg: 2.26e+01 | : 100%|█| 200/200 [04:39<00:00,  1.40


saving model version 0.11
saving model version 0.12
Training grid: 200


| train_loss: 1.05e-05 | test_loss: 4.63e-02 | reg: 1.74e+01 | : 100%|█| 200/200 [08:29<00:00,  2.55


saving model version 0.13
Pruning model...
saving model version 0.14


| train_loss: 4.60e-06 | test_loss: 4.67e-02 | reg: 1.67e+01 | : 100%|█| 100/100 [04:14<00:00,  2.54


saving model version 0.15
checkpoint directory created: ./model
saving model version 0.0
Training grid: 3


| train_loss: 8.86e-04 | test_loss: 1.67e-03 | reg: 2.27e+01 | : 100%|█| 200/200 [01:45<00:00,  1.89


saving model version 0.1
saving model version 0.2
Training grid: 5


| train_loss: 2.21e-04 | test_loss: 7.02e-04 | reg: 2.25e+01 | : 100%|█| 200/200 [01:49<00:00,  1.82


saving model version 0.3
saving model version 0.4
Training grid: 10


| train_loss: 8.81e-05 | test_loss: 2.65e-03 | reg: 2.25e+01 | : 100%|█| 200/200 [01:10<00:00,  2.83


saving model version 0.5
saving model version 0.6
Training grid: 20


| train_loss: 3.40e-05 | test_loss: 7.62e-03 | reg: 2.25e+01 | : 100%|█| 200/200 [01:34<00:00,  2.12


saving model version 0.7
saving model version 0.8
Training grid: 50


| train_loss: 9.63e-06 | test_loss: 1.00e-02 | reg: 2.25e+01 | : 100%|█| 200/200 [02:33<00:00,  1.31


saving model version 0.9
saving model version 0.10
Training grid: 100


| train_loss: 2.52e-06 | test_loss: 1.49e-02 | reg: 2.24e+01 | : 100%|█| 200/200 [03:11<00:00,  1.04


saving model version 0.11
saving model version 0.12
Training grid: 200


| train_loss: 5.20e-07 | test_loss: 2.20e-02 | reg: 1.92e+01 | : 100%|█| 200/200 [06:33<00:00,  1.97


saving model version 0.13
Pruning model...
saving model version 0.14


| train_loss: 8.79e-07 | test_loss: 2.36e-02 | reg: 1.85e+01 | : 100%|█| 100/100 [03:50<00:00,  2.31


saving model version 0.15
checkpoint directory created: ./model
saving model version 0.0
Training grid: 3


| train_loss: 7.56e-04 | test_loss: 2.69e-03 | reg: 2.14e+01 | : 100%|█| 200/200 [01:46<00:00,  1.88


saving model version 0.1
saving model version 0.2
Training grid: 5


| train_loss: 2.28e-04 | test_loss: 1.16e-03 | reg: 2.10e+01 | : 100%|█| 200/200 [01:51<00:00,  1.79


saving model version 0.3
saving model version 0.4
Training grid: 10


| train_loss: 5.61e-05 | test_loss: 1.48e-03 | reg: 2.10e+01 | : 100%|█| 200/200 [02:02<00:00,  1.64


saving model version 0.5
saving model version 0.6
Training grid: 20


| train_loss: 3.15e-05 | test_loss: 2.01e-03 | reg: 2.10e+01 | : 100%|█| 200/200 [01:34<00:00,  2.11


saving model version 0.7
saving model version 0.8
Training grid: 50


| train_loss: 3.95e-06 | test_loss: 1.00e-02 | reg: 2.08e+01 | : 100%|█| 200/200 [02:47<00:00,  1.19


saving model version 0.9
saving model version 0.10
Training grid: 100


| train_loss: 7.49e-07 | test_loss: 1.40e-02 | reg: 2.08e+01 | : 100%|█| 200/200 [03:04<00:00,  1.09


saving model version 0.11
saving model version 0.12
Training grid: 200


| train_loss: 1.54e-05 | test_loss: 1.50e-02 | reg: 1.98e+01 | : 100%|█| 200/200 [06:00<00:00,  1.80


saving model version 0.13
Pruning model...
saving model version 0.14


| train_loss: 1.43e-05 | test_loss: 1.45e-02 | reg: 1.89e+01 | : 100%|█| 100/100 [04:34<00:00,  2.74


saving model version 0.15
checkpoint directory created: ./model
saving model version 0.0
Training grid: 3


| train_loss: 4.46e-04 | test_loss: 8.48e-04 | reg: 2.36e+01 | : 100%|█| 200/200 [02:24<00:00,  1.39


saving model version 0.1
saving model version 0.2
Training grid: 5


| train_loss: 1.33e-04 | test_loss: 4.10e-04 | reg: 2.35e+01 | : 100%|█| 200/200 [02:21<00:00,  1.41


saving model version 0.3
saving model version 0.4
Training grid: 10


| train_loss: 3.73e-05 | test_loss: 3.76e-04 | reg: 2.34e+01 | : 100%|█| 200/200 [02:11<00:00,  1.53


saving model version 0.5
saving model version 0.6
Training grid: 20


| train_loss: 1.11e-05 | test_loss: 8.64e-04 | reg: 2.34e+01 | : 100%|█| 200/200 [01:39<00:00,  2.01


saving model version 0.7
saving model version 0.8
Training grid: 50


| train_loss: 1.50e-06 | test_loss: 1.72e-02 | reg: 2.32e+01 | : 100%|█| 200/200 [03:47<00:00,  1.14


saving model version 0.9
saving model version 0.10
Training grid: 100


| train_loss: 6.94e-06 | test_loss: 3.58e-02 | reg: 2.29e+01 | : 100%|█| 200/200 [07:19<00:00,  2.20


saving model version 0.11
saving model version 0.12
Training grid: 200


| train_loss: 4.06e-02 | test_loss: 9.22e-02 | reg: 2.03e+01 | : 100%|█| 200/200 [21:21<00:00,  6.41


saving model version 0.13
Pruning model...
saving model version 0.14


| train_loss: 5.09e-02 | test_loss: 1.00e-01 | reg: 1.85e+01 | : 100%|█| 100/100 [10:25<00:00,  6.26


saving model version 0.15
checkpoint directory created: ./model
saving model version 0.0
Training grid: 3


| train_loss: 4.32e-04 | test_loss: 7.87e-04 | reg: 2.37e+01 | : 100%|█| 200/200 [02:24<00:00,  1.39


saving model version 0.1
saving model version 0.2
Training grid: 5


| train_loss: 1.07e-04 | test_loss: 3.95e-04 | reg: 2.31e+01 | : 100%|█| 200/200 [02:31<00:00,  1.32


saving model version 0.3
saving model version 0.4
Training grid: 10


| train_loss: 3.88e-05 | test_loss: 3.01e-04 | reg: 2.31e+01 | : 100%|█| 200/200 [01:47<00:00,  1.87


saving model version 0.5
saving model version 0.6
Training grid: 20


| train_loss: 8.93e-06 | test_loss: 5.22e-04 | reg: 2.31e+01 | : 100%|█| 200/200 [02:04<00:00,  1.61


saving model version 0.7
saving model version 0.8
Training grid: 50


| train_loss: 7.83e-07 | test_loss: 2.59e-03 | reg: 2.31e+01 | : 100%|█| 200/200 [02:46<00:00,  1.20


saving model version 0.9
saving model version 0.10
Training grid: 100


| train_loss: 1.01e-06 | test_loss: 1.19e-02 | reg: 2.26e+01 | : 100%|█| 200/200 [04:24<00:00,  1.32


saving model version 0.11
saving model version 0.12
Training grid: 200


| train_loss: 1.18e-02 | test_loss: 4.80e-02 | reg: 2.25e+01 | : 100%|█| 200/200 [22:47<00:00,  6.84


saving model version 0.13
Pruning model...
saving model version 0.14


| train_loss: 1.41e-02 | test_loss: 5.23e-02 | reg: 2.14e+01 | : 100%|█| 100/100 [10:50<00:00,  6.50

saving model version 0.15
Pruned KAN Smallest Shape: [3, 5, 1], 
Pruned KAN lowest loss Shape: [3, 5, 5, 4, 1], 
Pruned KAN lowest loss: 3.0089e-04
--------------------------------
Final Results for dataset I.6.20b:
Pruned KAN Smallest Shape: [3, 5, 1], 
Pruned KAN lowest loss Shape: [3, 5, 5, 4, 1], 
Pruned KAN lowest loss: 3.0089e-04





In [None]:
# I.9.18: nothing is skipped, so all four experiment families run, over three seeds.
# NOTE(review): the saved output below shows only the pruned sweep (with NaN
# losses from grid 100 onward), which looks stale relative to this call. Re-run to refresh.
results_I_9_18 = full_feynman_experiment(
    dataset='I.9.18',
    shape_human=[6, 4, 2, 1, 1],
    seeds=[3, 42, 171],
)


Running Pruned KAN Experiment...
checkpoint directory created: ./model
saving model version 0.0
Training grid: 3


| train_loss: 2.83e-02 | test_loss: 2.89e-02 | reg: 9.50e-01 | : 100%|█| 200/200 [00:06<00:00, 28.76


saving model version 0.1
saving model version 0.2
Training grid: 5


| train_loss: 2.81e-02 | test_loss: 2.91e-02 | reg: 1.07e+00 | : 100%|█| 200/200 [00:07<00:00, 26.08


saving model version 0.3
saving model version 0.4
Training grid: 10


| train_loss: 2.75e-02 | test_loss: 2.95e-02 | reg: 1.31e+00 | : 100%|█| 200/200 [00:09<00:00, 20.58


saving model version 0.5
saving model version 0.6
Training grid: 20


| train_loss: 2.59e-02 | test_loss: 3.22e-02 | reg: 1.76e+00 | : 100%|█| 200/200 [00:17<00:00, 11.55


saving model version 0.7
saving model version 0.8
Training grid: 50


| train_loss: 2.06e-02 | test_loss: 4.15e-02 | reg: 2.66e+00 | : 100%|█| 200/200 [00:17<00:00, 11.38


saving model version 0.9
saving model version 0.10
Training grid: 100


| train_loss: nan | test_loss: nan | reg: nan | : 100%|███████████| 200/200 [00:24<00:00,  8.09it/s]


saving model version 0.11
saving model version 0.12
Training grid: 200


| train_loss: nan | test_loss: nan | reg: nan | : 100%|███████████| 200/200 [00:25<00:00,  7.72it/s]


saving model version 0.13
Pruning model...
saving model version 0.14


| train_loss: nan | test_loss: nan | reg: nan | : 100%|███████████| 100/100 [00:12<00:00,  7.77it/s]


saving model version 0.15
checkpoint directory created: ./model
saving model version 0.0
Training grid: 3


| train_loss: 2.83e-02 | test_loss: 2.89e-02 | reg: 9.49e-01 | : 100%|█| 200/200 [00:08<00:00, 22.51


saving model version 0.1
saving model version 0.2
Training grid: 5


| train_loss: 2.81e-02 | test_loss: 2.91e-02 | reg: 1.07e+00 | : 100%|█| 200/200 [00:08<00:00, 24.09


saving model version 0.3
saving model version 0.4
Training grid: 10


| train_loss: 2.75e-02 | test_loss: 2.95e-02 | reg: 1.31e+00 | : 100%|█| 200/200 [00:09<00:00, 20.26


saving model version 0.5
saving model version 0.6
Training grid: 20


| train_loss: 2.59e-02 | test_loss: 3.22e-02 | reg: 1.76e+00 | : 100%|█| 200/200 [00:17<00:00, 11.44


saving model version 0.7
saving model version 0.8
Training grid: 50


| train_loss: 2.06e-02 | test_loss: 4.15e-02 | reg: 2.66e+00 | : 100%|█| 200/200 [00:17<00:00, 11.71


saving model version 0.9
saving model version 0.10
Training grid: 100


| train_loss: nan | test_loss: nan | reg: nan | : 100%|███████████| 200/200 [00:24<00:00,  8.01it/s]


saving model version 0.11
saving model version 0.12
Training grid: 200


| train_loss: nan | test_loss: nan | reg: nan | : 100%|███████████| 200/200 [00:25<00:00,  7.69it/s]


saving model version 0.13
Pruning model...
saving model version 0.14


| train_loss: nan | test_loss: nan | reg: nan | : 100%|███████████| 100/100 [00:13<00:00,  7.55it/s]


saving model version 0.15
checkpoint directory created: ./model
saving model version 0.0
Training grid: 3


| train_loss: 2.83e-02 | test_loss: 2.89e-02 | reg: 9.50e-01 | : 100%|█| 200/200 [00:08<00:00, 24.23


saving model version 0.1
saving model version 0.2
Training grid: 5


| train_loss: 2.81e-02 | test_loss: 2.91e-02 | reg: 1.07e+00 | : 100%|█| 200/200 [00:08<00:00, 24.50


saving model version 0.3
saving model version 0.4
Training grid: 10


| train_loss: 2.75e-02 | test_loss: 2.95e-02 | reg: 1.31e+00 | : 100%|█| 200/200 [00:09<00:00, 20.98


saving model version 0.5
saving model version 0.6
Training grid: 20


| train_loss: 2.59e-02 | test_loss: 3.22e-02 | reg: 1.76e+00 | : 100%|█| 200/200 [00:15<00:00, 13.01


saving model version 0.7
saving model version 0.8
Training grid: 50


| train_loss: 2.06e-02 | test_loss: 4.15e-02 | reg: 2.66e+00 | : 100%|█| 200/200 [00:16<00:00, 11.82


saving model version 0.9
saving model version 0.10
Training grid: 100


| train_loss: nan | test_loss: nan | reg: nan | : 100%|███████████| 200/200 [00:25<00:00,  7.96it/s]


saving model version 0.11
saving model version 0.12
Training grid: 200


| train_loss: nan | test_loss: nan | reg: nan | : 100%|███████████| 200/200 [00:25<00:00,  7.72it/s]


saving model version 0.13
Pruning model...
saving model version 0.14


| train_loss: nan | test_loss: nan | reg: nan | : 100%|███████████| 100/100 [00:12<00:00,  7.78it/s]

saving model version 0.15

Final Results for dataset I.9.18:
Pruned KAN Smallest Shape: [5, 5, 5, 5, 5, 5, 5, 5, 5, 5], 
Pruned KAN lowest loss Shape: [9, 1], 
Pruned KAN lowest loss: 2.8923e-02



