In [None]:
# -*-coding:utf-8-*-
'''
@ author: Deepseek
'''

import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
from MLP import MLP
from tensorflow.keras.datasets import mnist

# Load the MNIST data and the pre-trained MLP model.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
mlp = MLP(784, 10, [128, 64], ['relu', 'relu', 'linear', 'softmax'])
mlp.load('mlp_model.npz')
# Flatten each test image into a 784-element row and scale pixel values by
# 1/255 (MNIST pixels are 0-255, so inputs end up in [0, 1]).
x_test_flattened = x_test.reshape(-1, 784) / 255.0

# Pre-compute the test-set indices of misclassified samples.
# One batched forward pass replaces len(x_test) single-sample calls.
# NOTE(review): assumes MLP.forward treats rows independently and returns a
# (batch, classes) array -- it already accepts a batch of one elsewhere in
# this notebook; confirm if the MLP implementation changes.
test_probs = mlp.forward(x_test_flattened)
# flatnonzero always returns an integer index array, even when there are no
# errors (np.array([]) would default to float64 and break indexing later).
errors = np.flatnonzero(np.argmax(test_probs, axis=1) != y_test)

def visualize_digit(error_index):
    """Plot one misclassified test digit next to the model's class probabilities.

    Parameters
    ----------
    error_index : int
        Position in the module-level ``errors`` array (not a raw test-set
        index); ``errors[error_index]`` selects the test sample to show.

    Reads the module-level ``errors``, ``x_test``, ``x_test_flattened``,
    ``y_test`` and ``mlp``. Draws a two-panel figure (image on the left,
    per-class probability bars on the right) and displays it with
    ``plt.show()``; returns None.
    """
    actual_index = errors[error_index]
    img = x_test[actual_index]
    # Slice rather than index so forward() receives a batch of one; take row 0.
    y_pred = mlp.forward(x_test_flattened[actual_index:actual_index + 1])[0]
    true_label = y_test[actual_index]
    pred_label = np.argmax(y_pred)
    percents = y_pred * 100

    _, (ax_img, ax_bar) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: the input image.
    ax_img.imshow(img, cmap='gray')
    ax_img.axis('off')
    title = f'Misclassified Digit (True: {true_label}, Pred: {pred_label})'
    ax_img.set_title(title, fontsize=14, color='red')

    # Right panel: per-class probability bar chart, true class in orange.
    classes = np.arange(10)
    colors = ['#1f77b4' if c != true_label else '#ff7f0e' for c in classes]
    ax_bar.bar(classes, percents, color=colors)
    ax_bar.set_xticks(classes)
    ax_bar.set_xlabel('Class', fontsize=12)
    ax_bar.set_ylabel('Probability (%)', fontsize=12)
    # Leave headroom above 100% so the value label of a near-certain bar
    # (drawn at pct + 2) stays inside the axes; ticks still stop at 100.
    ax_bar.set_ylim(0, 112)
    ax_bar.set_yticks(np.arange(0, 101, 20))

    # Value label on every bar, plus a star above the true class.
    for cls, pct in zip(classes, percents):
        ax_bar.text(cls, pct + 2, f'{pct:.1f}%',
                    ha='center', va='bottom', fontsize=10)
        if cls == true_label:
            ax_bar.text(cls, pct + 8, '★',
                        ha='center', va='bottom', color='red', fontsize=20)

    plt.tight_layout()
    plt.show()

# Interactive browser over the misclassified samples, using a wide slider.
# IntSlider rejects max < min, so with zero errors max=-1 would raise a
# TraitError; report that case instead of crashing.
if len(errors) == 0:
    print('No misclassified test samples to display.')
else:
    interact(visualize_digit,
             error_index=IntSlider(min=0,
                                   max=len(errors) - 1,
                                   step=1,
                                   value=0,
                                   description='Error Index',
                                   layout={'width': '1000px'},
                                   style={'description_width': 'initial'}))

interactive(children=(IntSlider(value=0, description='Error Index', layout=Layout(width='1000px'), max=260, st…

<function __main__.visualize_digit(error_index)>