In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from Load import load_cifar10, load_model
from Search import hyperparameter_tuning
from Model import ThreeLayerNN 
from Train import train
from Evaluate import evaluate

# 数据加载

In [None]:
# Seed NumPy's global RNG up front, in case load_cifar10 shuffles or splits
# using it — keeps the train/val/test partition reproducible across runs.
np.random.seed(10)
data_dir = './cifar-10-batches-py'  # directory holding the CIFAR-10 python batches
cifar_splits = load_cifar10(data_dir)
X_train, y_train, X_val, y_val, X_test, y_test = cifar_splits

# 超参数搜索

In [None]:
# Re-seed so this cell yields the same search result when re-run on its own
# (in case the search draws from NumPy's global RNG), matching the explicit
# seeding done in the data-loading and training cells.
np.random.seed(10)
best_params = hyperparameter_tuning(X_train, y_train, X_val, y_val)
# Display the selected hyperparameters; later cells read the keys
# 'hidden_size', 'reg' and 'lr' from this mapping.
best_params

# 模型训练 + 保存(使用最优超参数)

In [None]:
# Retrain from scratch with the hyperparameters picked by the search,
# then persist the learned weights to disk.
np.random.seed(10)
final_model = ThreeLayerNN(
    3072,                        # input size (presumably a flattened 32x32x3 CIFAR-10 image)
    best_params['hidden_size'],
    10,                          # output size: one score per CIFAR-10 class
    activation='relu',
    reg=best_params['reg'],
)
final_model = train(
    final_model,
    X_train, y_train,
    X_val, y_val,
    epochs=100,
    learning_rate=best_params['lr'],
)
final_model.save('my_model.npz')

In [None]:
# 3. Model training + saving with MANUALLY FIXED hyperparameters.
# NOTE(review): these values do NOT come from `best_params`. This cell also
# overwrites `final_model` and 'my_model.npz' produced by the previous cell,
# so under Restart & Run All the tuned model is discarded and the test cell
# evaluates this one. Keep only one of the two training cells, or save this
# model to a different path.
MANUAL_HIDDEN_SIZE = 256
MANUAL_REG = 1e-4
MANUAL_LEARNING_RATE = 0.03
MANUAL_EPOCHS = 50

np.random.seed(10)
final_model = ThreeLayerNN(3072, MANUAL_HIDDEN_SIZE, 10,
                           activation='relu', reg=MANUAL_REG)
final_model = train(final_model, X_train, y_train, X_val, y_val,
                    epochs=MANUAL_EPOCHS, learning_rate=MANUAL_LEARNING_RATE)
final_model.save('my_model.npz')

# 测试

In [None]:
# 4. 测试集
# 4. Test set: evaluate the weights reloaded from disk rather than the
# in-memory `final_model`.
model_path = 'my_model.npz'
my_model = load_model(model_path)
test_acc = evaluate(my_model, X_test, y_test)
print(f'Test Accuracy: {test_acc:.4f}')

# 参数分布

In [None]:
def visualize_parameters(model, figsize=(15, 10)):
    """Plot a network's learned parameters on a 2x2 grid.

    Panels: heatmaps of ``W1`` and ``W2`` (``coolwarm`` colormap centred on 0)
    in the top row, and line plots of the bias vectors ``b1`` and ``b2``
    against unit index in the bottom row.

    Args:
        model: Object exposing a ``params`` mapping with keys ``'W1'``,
            ``'W2'``, ``'b1'`` and ``'b2'`` (array-like values).
        figsize: Figure size in inches, forwarded to ``plt.subplots``.
            Defaults to ``(15, 10)``.
    """
    params = model.params
    # Explicit Figure/Axes instead of the pyplot state machine, so each
    # seaborn call draws into a known panel.
    fig, axes = plt.subplots(2, 2, figsize=figsize)
    ax_w1, ax_w2, ax_b1, ax_b2 = axes.ravel()

    # Axis labels assume rows index a layer's inputs and columns its outputs
    # (e.g. W1 as (input_dim, hidden_size)) — TODO confirm against Model.py.
    sns.heatmap(params['W1'], cmap='coolwarm', center=0, ax=ax_w1)
    ax_w1.set(title='W1 Weight Matrix', xlabel='Hidden Units', ylabel='Input Features')

    sns.heatmap(params['W2'], cmap='coolwarm', center=0, ax=ax_w2)
    ax_w2.set(title='W2 Weight Matrix', xlabel='Output Classes', ylabel='Hidden Units')

    # Flatten biases so a (1, n) row-vector bias is drawn as one n-point line
    # instead of n single-point lines; a no-op for 1-D biases.
    bias_panels = ((ax_b1, 'b1', 'Hidden Units'), (ax_b2, 'b2', 'Output Classes'))
    for ax, key, xlabel in bias_panels:
        bias = np.ravel(params[key])
        ax.plot(np.arange(bias.size), bias)
        ax.set(title=f'{key} Bias Vector', xlabel=xlabel, ylabel='Bias Value')

    fig.tight_layout()
    plt.show()

In [None]:
# Inspect the weights and biases of the model reloaded for testing.
visualize_parameters(model=my_model)