In [None]:
%cd ../
## Auto-reload all imported modules before each cell runs
%load_ext autoreload
%autoreload 2

from src.spline import *

print("Device cuda: ", torch.cuda.is_available())

f:\MonoKAN
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Device cuda:  True


In [None]:
import torch
import numpy as np
# import matplotlib
# matplotlib.use("QT5Agg")  # use the Qt5 backend
import matplotlib.pyplot as plt

from kan import KAN


def create_dcm_swissmetro_dataset(
    train_num=1000,
    test_num=1000,
    ranges=(-2, 2),    # e.g. widen this to cover a larger feature range
    noise_std=0.1,     # noise strength
    normalize_input=False,
    normalize_label=False,
    device='cpu',
    seed=0
):
    '''
    Generate a synthetic three-alternative SwissMetro DCM dataset whose
    utilities contain linear terms and pairwise interaction terms.

    Features are drawn uniformly from ``ranges``; each label row holds the
    three utilities (train, SM, car) plus independent Gaussian noise.

    Args:
    -----
        train_num : int
            number of training samples
        test_num : int
            number of test samples
        ranges : array-like, (2,) or (n_var, 2)
            (min, max) of the features; either one pair shared by all 9
            features or one pair per feature. Lists, tuples and numpy
            arrays are accepted.
        noise_std : float
            standard deviation of the Gaussian noise added to the utilities
        normalize_input : bool
            if True, standardize inputs with the training mean/std
        normalize_label : bool
            if True, standardize labels with the training mean/std
        device : str
            device the returned tensors are moved to
        seed : int
            seed applied to both the numpy and torch global generators

    Returns:
    -------
        dataset : dict
            'train_input':  (train_num, 9)
            'test_input':   (test_num, 9)
            'train_label':  (train_num, 3)
            'test_label':   (test_num, 3)

    Raises:
    ------
        ValueError
            if ``ranges`` has neither shape (2,) nor shape (9, 2)
    '''

    # NOTE: this reseeds the *global* numpy and torch generators.
    np.random.seed(seed)
    torch.manual_seed(seed)
    n_var = 9

    # Resolve a shared range / per-feature ranges into an (n_var, 2) array.
    # np.asarray + np.tile (rather than ``ranges * n_var``) keeps this correct
    # for numpy arrays, where ``*`` is element-wise instead of repetition.
    bounds = np.asarray(ranges, dtype=float)
    if bounds.ndim == 1:
        if bounds.shape != (2,):
            raise ValueError(
                f"ranges must have shape (2,) or ({n_var}, 2), got {bounds.shape}"
            )
        bounds = np.tile(bounds, (n_var, 1))
    elif bounds.shape != (n_var, 2):
        raise ValueError(
            f"ranges must have shape (2,) or ({n_var}, 2), got {bounds.shape}"
        )

    def sample_input(num):
        # One torch.rand(num) draw per feature, in feature order, so the
        # random stream for a given seed stays stable.
        columns = [
            torch.rand(num) * (hi - lo) + lo
            for lo, hi in bounds.tolist()
        ]
        return torch.stack(columns, dim=1)

    train_input = sample_input(train_num)
    test_input = sample_input(test_num)

    # Column aliases: x[0~2] = train, x[3~6] = SM, x[7~8] = car
    def utility_func(x):
        # x: [batch, 9]
        # The beta/gamma coefficients below are arbitrary and can be tuned.
        # train
        u_train = (
            -2.0 * x[:, 0]    # train_tt
            -1.5 * x[:, 1]    # train_co
            -0.5 * x[:, 2]    # train_he
            + 1.0 * x[:, 0] * x[:, 2]  # interaction: train_tt * train_he
        )
        # SM
        u_sm = (
            -2.2 * x[:, 3]    # SM_tt
            -1.4 * x[:, 4]    # SM_co
            -0.8 * x[:, 5]    # SM_he
            + 1.2 * x[:, 3] * x[:, 5]  # interaction: SM_tt * SM_he
            + 0.6 * x[:, 6]   # SM_seats
        )
        # car
        u_car = (
            -1.8 * x[:, 7]    # car_TT
            -2.1 * x[:, 8]    # car_CO
            + 0.7 * x[:, 7] * x[:, 8]  # interaction: car_TT * car_CO
        )
        # Add Gaussian noise independently to every utility.
        noise = noise_std * torch.randn(x.shape[0], 3)
        return torch.stack([u_train, u_sm, u_car], dim=1) + noise

    train_label = utility_func(train_input)
    test_label = utility_func(test_input)

    def standardize(train, test):
        # Statistics come from the training split only and are applied to both.
        mean = torch.mean(train, dim=0, keepdim=True)
        std = torch.std(train, dim=0, keepdim=True)
        # A constant column (std == 0) or a single sample (std is NaN) would
        # turn the division into NaN/inf; fall back to a unit scale there.
        std = torch.where(std > 0, std, torch.ones_like(std))
        return (train - mean) / std, (test - mean) / std

    if normalize_input:
        train_input, test_input = standardize(train_input, test_input)
    if normalize_label:
        train_label, test_label = standardize(train_label, test_label)

    dataset = dict(
        train_input = train_input.to(device),
        test_input  = test_input.to(device),
        train_label = train_label.to(device),
        test_label  = test_label.to(device),
    )
    return dataset


def test_multikan(device=None):
    '''
    Demo: build a KAN with multiplication nodes, plot it untrained, fit it on
    the synthetic SwissMetro DCM dataset, and plot it again.

    Args:
    -----
        device : str or None
            device for the model and the data; when None, 'cuda' is used if
            torch reports it as available, otherwise 'cpu'

    Returns:
    -------
        None
    '''
    if device is None:
        # Fall back to CPU so the demo also runs on machines without a GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Names shown on the plots, in the column order of the dataset.
    in_vars = ['train_tt', 'train_co', 'train_he', 'SM_tt', 'SM_co', 'SM_he', 'SM_seats', 'car_TT', 'car_CO']
    out_vars = ['train', 'SM', 'car']

    width = [
        [9, 0],    # 9 input features
        [6, 3],    # 6 sum nodes, 3 mult nodes (for 3 pairwise interactions)
        [3, 0]     # 3 outputs (utility for each alternative)
    ]
    # mult_arity=2: every multiplication node performs a two-way interaction.
    kan = KAN(width, mult_arity=2, device=device)

    dataset = create_dcm_swissmetro_dataset(train_num=1000, test_num=1000, device=device)
    print(dataset['train_input'].shape, dataset['train_label'].shape)

    # Run one forward pass before plotting.
    # NOTE(review): presumably plot() relies on activations cached by a forward pass; confirm against pykan.
    kan(dataset['train_input'])
    kan.plot(in_vars=in_vars,
             out_vars=out_vars,
             title='SwissMetro DCM KAN (untrained)',)
    # plt.show()

    # Train the KAN model.
    kan.fit(dataset, opt="LBFGS", steps=20, lamb=0.01)
    kan.plot(in_vars=in_vars,
             out_vars=out_vars,
             title='SwissMetro DCM KAN (trained)',)
    # plt.show()
    
# Script entry point: run the KAN demo only when this file is executed directly.
if __name__ == '__main__':
    test_multikan()
    # print("Test completed successfully.")