### 加载数据集

In [9]:
import sys
sys.path.append('../../')


from datasets.datasets import DatasetManager

dataset_manager = DatasetManager(batch_size=128)
train_loader, test_loader = dataset_manager.mnist_dataset()

Using device: cpu
CIFAR-10 path: /home/shiroha/Code/Frontend/KAN/datasets/CIFAR10
MNIST path: /home/shiroha/Code/Frontend/KAN/datasets/mnist


In [10]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

class NaiveFourierKANLayer(nn.Module):
    def __init__(self, inputdim, outdim, initial_gridsize, addbias=True):
        super(NaiveFourierKANLayer, self).__init__()
        self.addbias = addbias
        self.inputdim = inputdim
        self.outdim = outdim

        # Learnable gridsize parameter
        self.gridsize_param = nn.Parameter(torch.tensor(initial_gridsize, dtype=torch.float32))

        # Fourier coefficients as a learnable parameter with Xavier initialization
        self.fouriercoeffs = nn.Parameter(torch.empty(2, outdim, inputdim, initial_gridsize))
        nn.init.xavier_uniform_(self.fouriercoeffs)

        if self.addbias:
            self.bias = nn.Parameter(torch.zeros(1, outdim))

    def forward(self, x):
        gridsize = torch.clamp(self.gridsize_param, min=1).round().int()
        xshp = x.shape
        outshape = xshp[:-1] + (self.outdim,)
        x = torch.reshape(x, (-1, self.inputdim))
        k = torch.reshape(torch.arange(1, gridsize + 1, device=x.device), (1, 1, 1, gridsize))
        xrshp = torch.reshape(x, (x.shape[0], 1, x.shape[1], 1))
        c = torch.cos(k * xrshp)
        s = torch.sin(k * xrshp)
        y = torch.sum(c * self.fouriercoeffs[0:1, :, :, :gridsize], (-2, -1))
        y += torch.sum(s * self.fouriercoeffs[1:2, :, :, :gridsize], (-2, -1))
        if self.addbias:
            y += self.bias
        y = torch.reshape(y, outshape)
        return y

class MNISTFourierKAN(nn.Module):
    def __init__(self):
        super(MNISTFourierKAN, self).__init__()
        self.fourierkan1 = NaiveFourierKANLayer(28*28, 128, initial_gridsize=28)
        self.fourierkan2 = NaiveFourierKANLayer(128, 10, initial_gridsize=4)

    def forward(self, x):
        x = x.view(-1, 28*28)  # Flatten the images
        x = self.fourierkan1(x)
        x = self.fourierkan2(x)
        return x

### 加载模型

In [13]:
# from models.models import ModelManager
# from torchinfo import summary

# Initialize the model and optimizer with a lower learning rate
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MNISTFourierKAN().to(device)  # Use 'cuda' for GPU
optimizer = optim.LBFGS(model.parameters(), lr=0.01)  # Reduced learning rate from 0.1 to 0.01

# 使用torchsummary输出模型结构
# summary(model, input_size=(64,))  # 假设输入特征为64维

### 训练(利用预训练模型可只执行第一步然后跳去评估部分)

In [14]:
from weights.weights import WeightManager
weight_manager = WeightManager()
weight_name = 'test'

In [15]:
# Define the training loop
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        def closure():
            optimizer.zero_grad()
            output = model(data)
            loss = nn.CrossEntropyLoss()(output, target)
            loss.backward()
            return loss
        data, target = data.to(device), target.to(device)
        optimizer.step(closure)
        if batch_idx % 10 == 0:
            loss = closure()
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
            weight_manager.save_model(model, optimizer, epoch=epoch, dir_name=weight_name, file_name=f'{weight_name}_{epoch}_checkpoint.pth')

# Train the model for only one epoch as per user request
for epoch in range(1, 2):
    train(model, device, train_loader, optimizer, epoch)

Model saved to /home/shiroha/Code/Frontend/KAN/weights/test/test_1_checkpoint.pth


KeyboardInterrupt: 

### 评估

#### 指定单个文件测试模式

In [None]:
# Evaluate the model
model, optimizer, start_epoch = weight_manager.load_model(model, optimizer, dir_name='test', file_name='test_1_checkpoint.pth', device=device)
def evaluate(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.CrossEntropyLoss()(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Evaluate the trained model
evaluate(model, device, test_loader)

  checkpoint = torch.load(load_path, map_location=device)


Model loaded from /home/shiroha/Code/Frontend/KAN/weights/test/test_1_checkpoint.pth, starting from epoch 1

Test set: Average loss: 0.0175, Accuracy: 2648/10000 (26%)



#### 指定文件夹全部权重文件测试

In [17]:
pth_files = weight_manager.list_pth_files(dir_name=weight_name)

if pth_files:
    print("Available .pth files:")
    for pth_file in pth_files:
        print(f"- {pth_file}")
    for model in pth_files:
        evaluate(model, device, test_loader)


AttributeError: 'WeightManager' object has no attribute 'list_pth_files'