In [None]:
from model import NeuralNetwork
from utils import load_mnist_data, CrossEntropyLoss
from activation import ReLU,LeakyReLU, Tanh
from optimizer import MiniBatchGradientDescent
from visualize import plot_training_history, plot_confusion_matrix, plot_misclassified_examples
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Load the train / validation / test splits from the project helper.
X_train, X_val, X_test, y_train, y_val, y_test = load_mnist_data()

# Network input and output widths shared by every candidate below.
# NOTE(review): 784 presumably matches flattened 28x28 MNIST images and 10 the
# digit classes -- confirm against what load_mnist_data() actually returns.
input_dim = 784
output_dim = 10

# Three candidate architectures. They are constructed in this fixed order on
# purpose: if NeuralNetwork draws its initial weights from a shared RNG, the
# order affects the starting point of each model.
# NOTE(review): `lammbda` (spelled as the NeuralNetwork keyword expects) is
# presumably the regularisation strength -- confirm in model.NeuralNetwork.
baseline = NeuralNetwork(
    input_dim,
    output_dim,
    hidden_dims=[256, 128],
    activation=ReLU(),
    lammbda=1e-4,
)
baseline1 = NeuralNetwork(
    input_dim,
    output_dim,
    hidden_dims=[512, 256],
    activation=LeakyReLU(),
    lammbda=1e-3,
)
baseline2 = NeuralNetwork(
    input_dim,
    output_dim,
    hidden_dims=[128],
    activation=Tanh(),
    lammbda=0,
)

# Training configuration: mini-batch gradient descent with a cross-entropy loss.
optimizer = MiniBatchGradientDescent(
    learning_rate=0.02,
    batch_size=32,
)
loss = CrossEntropyLoss()

# Only `baseline2` is compiled in this cell; `baseline` and `baseline1` are
# built but not compiled here.
baseline2.compile(loss, optimizer)

In [None]:
import os
import time
from datetime import datetime

# Train the model. `baseline2` is the model fitted, plotted, evaluated and
# reported on throughout this cell.
history = baseline2.fit(X_train, y_train, epochs=100, X_val=X_val, y_val=y_val, verbose=1)

# Create a timestamped directory for this run so artefacts from separate runs
# do not overwrite each other.
run_time = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = os.path.join('runs', f'run_{run_time}')
os.makedirs(run_dir, exist_ok=True)

# Plot and save the training history.
history_path = os.path.join(run_dir, 'training_history.png')
plot_training_history(baseline2.history, save_path=history_path)
print(f"Training history plot saved to {history_path}")

# Evaluate the trained model on the held-out test set.
test_loss, test_acc = baseline2.evaluate(X_test, y_test)
print(f"Test loss: {test_loss:.4f}")
print(f"Test accuracy: {test_acc:.4f}")

# Also save the final metrics as a text file next to the plot.
# The history must come from `baseline2` (the model trained above); reading it
# from a different, untrained model would record the wrong numbers or fail.
with open(os.path.join(run_dir, 'training_metrics.txt'), 'w') as f:
    f.write(f"Final training accuracy: {baseline2.history['train_acc'][-1]:.4f}\n")
    f.write(f"Final validation accuracy: {baseline2.history['val_acc'][-1]:.4f}\n")
    f.write(f"Final training loss: {baseline2.history['train_loss'][-1]:.4f}\n")
    f.write(f"Final validation loss: {baseline2.history['val_loss'][-1]:.4f}\n")
    f.write(f"Test loss: {test_loss:.4f}\n")
    f.write(f"Test accuracy: {test_acc:.4f}\n")


# 1. Turn model outputs and targets into integer class labels.
y_pred_probs = baseline2.predict(X_test)  # predicted per-class scores/probabilities
y_pred = np.argmax(y_pred_probs, axis=1)  # most likely class per sample
# NOTE(review): assumes y_test is one-hot encoded (per the original author's
# comment); argmax over axis 1 recovers the class index in that case.
y_true = np.argmax(y_test, axis=1)

# 2. Visualise the confusion matrix, first as raw counts, then normalised.
plot_confusion_matrix(y_true, y_pred, class_names=list(range(10)), normalize=False)
plot_confusion_matrix(y_true, y_pred, class_names=list(range(10)), normalize=True)

# 3. Visualise misclassified samples.
plot_misclassified_examples(X_test, y_true, y_pred, class_names=list(range(10)))