In [1]:
import sys
sys.path.append('../src')
import sys
import os
import argparse
import warnings
import numpy
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from fvcore.common.timer import Timer

from utils import *
from models.kan.LBFGS import *

In [2]:
def train(args, model, device, train_loader, optimizer, epoch, logger, start_index):
    model.train()
    fvctimer = Timer()
    closure_count_list = []
    for batch_idx, (data, target) in enumerate(train_loader, start_index):
        closure_count = 0

        data, target = todevice(data, device), todevice(target, device)

        if fvctimer.is_paused():
            fvctimer.resume()
        else:
            fvctimer.reset()

        if args.optimizer == "adam":
            # print("adam")
            optimizer.zero_grad()
            output = model(data)

            if args.loss == "cross_entropy":
                losses = [F.cross_entropy(output, target)]
            elif args.loss == "mse":
                losses = [F.mse_loss(output, target)]
            else:
                raise NotImplementedError
            
            loss = 0
            for l in losses:
                loss = loss + l
            loss.backward()
            optimizer.step()

        elif args.optimizer == "lbfgs":
            # print("lbfgs")

            def closure():
                nonlocal closure_count
                closure_count += 1
                optimizer.zero_grad()
                output = model(data)
                if args.loss == "cross_entropy":
                    losses = [F.cross_entropy(output, target)]
                elif args.loss == "mse":
                    losses = [F.mse_loss(output, target)]
                else:
                    raise NotImplementedError

                loss = 0
                for l in losses:
                    loss = loss + l

                loss.backward()
                return loss

            optimizer.step(closure)

        fvctimer.pause()
        closure_count_list.append(closure_count)

    total_training_time = fvctimer.seconds()
    average_training_time_per_iteration = fvctimer.avg_seconds()
    total_closure_count = sum(closure_count_list)
    average_closure_count_per_iteration = numpy.mean(closure_count_list)
    print(f"Epoch: {epoch}; total training time: {total_training_time:,} seconds; average training time per iteration: {average_training_time_per_iteration:,} seconds")
    print(f"Epoch: {epoch}; total closure count: {total_closure_count}; average closure count per iteration: {average_closure_count_per_iteration}")
    return model

In [3]:
def main(args):

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if use_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    randomness_control(args.seed)

    train_loader, test_loader, num_classes, input_size = get_loader(args, use_cuda = use_cuda)

    args.output_size = num_classes
    args.input_size = input_size

    args.activation = get_activation(args)
    args.kan_shortcut_function = get_shortcut_function(args)

    model = get_model(args)
    model = model.to(device)
    num_parameters, flops = get_model_complexity(model, None, args)

    if args.optimizer == "adam":
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    elif args.optimizer == "lbfgs":
        optimizer = LBFGS(
            filter(lambda p: p.requires_grad, model.parameters()), 
            lr=args.lr, 
            history_size=10, 
            line_search_fn="strong_wolfe",
            tolerance_grad=1e-32, 
            tolerance_change=1e-32, 
            tolerance_ys=1e-32)
    else:
        raise NotImplementedError

    # scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    print(f"Model: {args.model}, Number of parameters: {num_parameters:,}, FLOPs: {flops:,}")
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch, logger = None, start_index = (epoch - 1) *len(train_loader))


In [4]:
def forward_once(args, model, device, train_loader, optimizer, epoch, logger, start_index):
    model.train()
    fvctimer = Timer()
    for batch_idx, (data, target) in enumerate(train_loader, start_index):
        data, target = todevice(data, device), todevice(target, device)

        if fvctimer.is_paused():
            fvctimer.resume()
        else:
            fvctimer.reset()

        output = model(data)
        
        fvctimer.pause()

    total_forward_time = fvctimer.seconds()
    average_forward_time_per_iteration = fvctimer.avg_seconds()
    print(f"Epoch: {epoch}; total forward time: {total_forward_time:,} seconds; average forward time per iteration: {average_forward_time_per_iteration:,} seconds")
    return model

In [5]:
def main_forward(args):

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if use_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    randomness_control(args.seed)

    train_loader, test_loader, num_classes, input_size = get_loader(args, use_cuda = use_cuda)

    args.output_size = num_classes
    args.input_size = input_size

    args.activation = get_activation(args)
    args.kan_shortcut_function = get_shortcut_function(args)

    model = get_model(args)
    model = model.to(device)
    num_parameters, flops = get_model_complexity(model, None, args)

    if args.optimizer == "adam":
        optimizer = None
        pass
    elif args.optimizer == "lbfgs":
        optimizer = LBFGS(
            filter(lambda p: p.requires_grad, model.parameters()), 
            lr=args.lr, 
            history_size=10, 
            line_search_fn="strong_wolfe",
            tolerance_grad=1e-32, 
            tolerance_change=1e-32, 
            tolerance_ys=1e-32)
    else:
        raise NotImplementedError

    # scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    print(f"Model: {args.model}, Number of parameters: {num_parameters:,}, FLOPs: {flops:,}")
    for epoch in range(1, args.epochs + 1):
        forward_once(args, model, device, train_loader, optimizer, epoch, logger = None, start_index = (epoch - 1) *len(train_loader))


# LBFGS

## KAN
Restart the notebook before evaluation

In [5]:
def get_kan_args():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Training')

    parser.add_argument('--model', type=str, default="KAN", #required=True,
                        help='network structure')
    parser.add_argument('--layers_width', type=int, default=[8, 8], nargs='+', #required=True,
                        help='the width of each hidden layer')
    parser.add_argument('--batch_norm', action='store_true', default=False,
                        help='whether use batch normalization')
    parser.add_argument('--activation_name', type=str, default="relu", 
                        help='activation function')
    parser.add_argument('--pre_train_ckpt', type=str, default="", 
                        help='path of the pretrained model')

    parser.add_argument('--dataset', type=str, default="MNIST", #required=True,
                        help='dataset')

    parser.add_argument('--batch-size', type=int, default=1024,
                        help='input batch size for training (default: 1024)')
    parser.add_argument('--test-batch-size', type=int, default=128,
                        help='input batch size for testing (default: 128)')
    parser.add_argument('--epochs', type=int, default=3, # 100 MNIST pretrain, 5 Finetune
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--optimizer', type=str, default="lbfgs",
                        help='supported optimizer: adam, lbfgs')
    # parser.add_argument('--gamma', type=float, default=0.7,
    #                     help='Learning rate step gamma (default: 0.7, 1.0 for fewshot)')
    parser.add_argument('--loss', type=str, default="cross_entropy",
                        help='loss function')

    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1314,
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=100000,
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--save-model-interval', type = int, default=-1, 
                        help='whether save model along training')
    parser.add_argument('--evaluateion_interval', type = int, default=100000, 
                        help='interval between two evaluations')
    ################# Parameters for KAN #################
    parser.add_argument('--kan_bspline_grid', type=int, default=15, 
                        help='the grid size of the bspline in the KAN layer')
    parser.add_argument('--kan_bspline_order', type=int, default=3, 
                        help='the order of the bspline in the KAN layer')
    parser.add_argument('--kan_shortcut_name', type=str, default="silu", 
                        help='the shortcut(base) function in the KAN layer: zero, identity, silu')
    parser.add_argument('--kan_grid_range', type=float, default=[-4, 4], nargs=2,
                        help='the range of the grid in the KAN layer. default is [-1, 1]. but for general normalized data, it can be larger.')
    ################# Parameters for KAN #################
    ################# Parameters for MLP #################
    ## pass ##
    ################# Parameters for MLP #################
    args = parser.parse_args([])

    args.save_model_along = args.save_model_interval > 0

    return args

In [5]:
main(get_kan_args())

seed 1314
Model: KAN, Number of parameters: 134,762, FLOPs: 3,545,632.0
Epoch: 1; total training time: 77.31417033076286 seconds; average training time per iteration: 1.3104096666230993 seconds
Epoch: 1; total closure count: 1535; average closure count per iteration: 26.016949152542374
Epoch: 2; total training time: 75.79276225715876 seconds; average training time per iteration: 1.2846230891043857 seconds
Epoch: 2; total closure count: 1534; average closure count per iteration: 26.0
Epoch: 3; total training time: 75.65926801413298 seconds; average training time per iteration: 1.282360474815813 seconds
Epoch: 3; total closure count: 1530; average closure count per iteration: 25.93220338983051


## MLP
Restart the notebook before evaluation

In [4]:
def get_mlp_args():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Training')

    parser.add_argument('--model', type=str, default="MLP", #required=True,
                        help='network structure')
    parser.add_argument('--layers_width', type=int, default=[1024, 1024], nargs='+', #required=True,
                        help='the width of each hidden layer')
    parser.add_argument('--batch_norm', action='store_true', default=False,
                        help='whether use batch normalization')
    parser.add_argument('--activation_name', type=str, default="relu", 
                        help='activation function')
    parser.add_argument('--pre_train_ckpt', type=str, default="", 
                        help='path of the pretrained model')

    parser.add_argument('--dataset', type=str, default="MNIST", #required=True,
                        help='dataset')

    parser.add_argument('--batch-size', type=int, default=1024,
                        help='input batch size for training (default: 1024)')
    parser.add_argument('--test-batch-size', type=int, default=128,
                        help='input batch size for testing (default: 128)')
    parser.add_argument('--epochs', type=int, default=3, # 100 MNIST pretrain, 5 Finetune
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--optimizer', type=str, default="lbfgs",
                        help='supported optimizer: adam, lbfgs')
    # parser.add_argument('--gamma', type=float, default=0.7,
    #                     help='Learning rate step gamma (default: 0.7, 1.0 for fewshot)')
    parser.add_argument('--loss', type=str, default="cross_entropy",
                        help='loss function')

    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1314,
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=100000,
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--save-model-interval', type = int, default=-1, 
                        help='whether save model along training')
    parser.add_argument('--evaluateion_interval', type = int, default=100000, 
                        help='interval between two evaluations')
    ################# Parameters for KAN #################
    parser.add_argument('--kan_bspline_grid', type=int, default=20, 
                        help='the grid size of the bspline in the KAN layer')
    parser.add_argument('--kan_bspline_order', type=int, default=5, 
                        help='the order of the bspline in the KAN layer')
    parser.add_argument('--kan_shortcut_name', type=str, default="silu", 
                        help='the shortcut(base) function in the KAN layer: zero, identity, silu')
    parser.add_argument('--kan_grid_range', type=float, default=[-4, 4], nargs=2,
                        help='the range of the grid in the KAN layer. default is [-1, 1]. but for general normalized data, it can be larger.')
    ################# Parameters for KAN #################
    ################# Parameters for MLP #################
    ## pass ##
    ################# Parameters for MLP #################
    args = parser.parse_args([])

    args.save_model_along = args.save_model_interval > 0

    return args

In [5]:
main(get_mlp_args())

seed 1314
Model: MLP, Number of parameters: 1,863,690, FLOPs: 3,725,322
Epoch: 1; total training time: 5.6864582523703575 seconds; average training time per iteration: 0.09638064834526029 seconds
Epoch: 1; total closure count: 1525; average closure count per iteration: 25.847457627118644
Epoch: 2; total training time: 4.047977037727833 seconds; average training time per iteration: 0.06860978030047174 seconds
Epoch: 2; total closure count: 1522; average closure count per iteration: 25.796610169491526
Epoch: 3; total training time: 4.682358644902706 seconds; average training time per iteration: 0.07936201093055434 seconds
Epoch: 3; total closure count: 1535; average closure count per iteration: 26.016949152542374


# Forward Time Evaluation

## KAN

In [9]:
main_forward(get_kan_args())

seed 1314
Model: KAN, Number of parameters: 134,762, FLOPs: 3,545,632.0
Epoch: 1; total forward time: 4.250635847449303 seconds; average forward time per iteration: 0.07204467538049665 seconds
Epoch: 2; total forward time: 2.2687958478927612 seconds; average forward time per iteration: 0.03845416691343663 seconds
Epoch: 3; total forward time: 2.29159764200449 seconds; average forward time per iteration: 0.0388406380000761 seconds


## MLP

In [7]:
main_forward(get_mlp_args())

seed 1314
Model: MLP, Number of parameters: 1,863,690, FLOPs: 3,725,322
Epoch: 1; total forward time: 1.757744014263153 seconds; average forward time per iteration: 0.029792271428189035 seconds
Epoch: 2; total forward time: 0.024965673685073853 seconds; average forward time per iteration: 0.0004231470116114212 seconds
Epoch: 3; total forward time: 0.02934861183166504 seconds; average forward time per iteration: 0.0004974340988417803 seconds


# Adam

## KAN

In [6]:
kan_args = get_kan_args()
kan_args.optimizer = "adam"
main(kan_args)

seed 1314


Model: KAN, Number of parameters: 134,762, FLOPs: 3,545,632.0
Epoch: 1; total training time: 4.801859557628632 seconds; average training time per iteration: 0.08138745012929884 seconds
Epoch: 1; total closure count: 0; average closure count per iteration: 0.0
Epoch: 2; total training time: 2.7768291011452675 seconds; average training time per iteration: 0.04706490001941131 seconds
Epoch: 2; total closure count: 0; average closure count per iteration: 0.0
Epoch: 3; total training time: 2.7728042006492615 seconds; average training time per iteration: 0.04699668136693663 seconds
Epoch: 3; total closure count: 0; average closure count per iteration: 0.0


## MLP

In [5]:
mlp_args = get_mlp_args()
mlp_args.optimizer = "adam"
main(mlp_args)

seed 1314
Model: MLP, Number of parameters: 1,863,690, FLOPs: 3,725,322
Epoch: 1; total training time: 1.297071099281311 seconds; average training time per iteration: 0.02198425592002222 seconds
Epoch: 1; total closure count: 0; average closure count per iteration: 0.0
Epoch: 2; total training time: 0.15516074001789093 seconds; average training time per iteration: 0.0026298430511506937 seconds
Epoch: 2; total closure count: 0; average closure count per iteration: 0.0
Epoch: 3; total training time: 0.16446852684020996 seconds; average training time per iteration: 0.002787602149834067 seconds
Epoch: 3; total closure count: 0; average closure count per iteration: 0.0
