# Train Model

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv1D, MaxPooling1D, GlobalAveragePooling1D

## Load data

In [None]:
# Alias for the Keras MNIST dataset module; the data itself is only fetched
# when `mnist.load_data()` is called in the following cell.
mnist = tf.keras.datasets.mnist

In [None]:
# `load_data()` yields a (train, test) pair, each an (images, labels) tuple.
_mnist_train, _mnist_test = mnist.load_data()
mnist_x_train, mnist_y_train = _mnist_train
mnist_x_test, mnist_y_test = _mnist_test

# Sanity check: preview the first training image.
plt.imshow(mnist_x_train[0], cmap=plt.cm.binary)

## Neural Network Model
### Train Neural Network Model

In [None]:
# Apply the same preprocessing to both splits. `tf.keras.utils.normalize`
# rescales the values along the given axis (an L2 norm by default); with the
# (N, 28, 28) image arrays, axis 1 is the first image axis of every sample.
# Any image fed to the model later must be normalised along that same axis.
mnist_x_train_norm, mnist_x_test_norm = (
    tf.keras.utils.normalize(images, axis=1)
    for images in (mnist_x_train, mnist_x_test)
)

In [None]:
# Functional-API model. Conv1D convolves over axis 1 of each sample and treats
# axis 2 as channels, so for 28x28 MNIST images every image row is one step
# with 28 features.
inputs = tf.keras.Input(shape=mnist_x_train_norm.shape[1:], name="mnist_image")

# Block 1: two 64-filter convolutions (kernel size 4), then halve the length.
x = Conv1D(64, 4, activation='relu')(inputs)
x = Conv1D(64, 4, activation='relu')(x)
x = MaxPooling1D(pool_size=2)(x)

# Block 2: two 128-filter convolutions (kernel size 4).
x = Conv1D(128, 4, activation='relu')(x)
x = Conv1D(128, 4, activation='relu')(x)

# Collapse the remaining steps into one 128-value vector, then regularise.
x = GlobalAveragePooling1D()(x)
x = Dropout(0.5)(x)

# Flatten is optional after GAP; kept for consistency
x = Flatten()(x)
outputs = Dense(10, activation='softmax', name="class_probs")(x)

nn_model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Labels are integer class ids (not one-hot), hence the sparse loss.
nn_model.compile(optimizer='adam',
                 loss='sparse_categorical_crossentropy',
                 metrics=['accuracy'])

nn_model.summary()

In [None]:
%%time
nn_model.fit(mnist_x_train_norm, mnist_y_train, epochs = 13)
mnist_val_loss, mnist_val_acc = nn_model.evaluate(mnist_x_test_norm, mnist_y_test)
print(mnist_val_loss, mnist_val_acc)

In [None]:
# Persist the trained model; the "Use model" section reloads this exact file.
# NOTE(review): "9917" in the file name presumably records a 99.17% test
# accuracy from an earlier run — confirm/rename after retraining.
nn_model.save('mnist_9917.keras')

### Use model

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import cv2
import math

In [None]:
# Reload the model written by the training section (this rebinds `nn_model`)
# and print its layer summary to confirm the architecture that was restored.
nn_model = tf.keras.models.load_model('mnist_9917.keras')
nn_model.summary()

In [None]:
def wrong(data, model=None, images=None):
    """Show every misclassified image and print how many there were.

    Args:
        data: Sequence of true integer labels, aligned with `images`.
        model: Object with a Keras-style `predict(batch)` method returning one
            row of class scores per image. Required.
        images: Normalised images to classify. Defaults to the notebook-level
            `mnist_x_test_norm`, which is what this function always used
            before the parameter existed.

    Returns:
        None. Each misclassified image is displayed with matplotlib and the
        number of misclassifications is printed.

    Raises:
        ValueError: If `model` is None or there are more labels than images.
    """
    if model is None:
        raise ValueError("wrong() needs a trained model: pass model=<your model>")
    if images is None:
        # Backward-compatible default: the normalised MNIST test split.
        images = mnist_x_test_norm

    labels = np.asarray(data)
    if len(labels) > len(images):
        raise ValueError(
            f"got {len(labels)} labels but only {len(images)} images"
        )
    # Only the images that actually have a label need to be classified.
    images = images[:len(labels)]

    predicted = np.argmax(model.predict(images), axis=1)
    wrong_indices = np.flatnonzero(predicted != labels)

    for i in wrong_indices:
        plt.imshow(images[i], cmap = plt.cm.binary)
        plt.show()
        # Close each figure so a long list of errors does not pile up in memory.
        plt.close()
    print(len(wrong_indices))
    
def predict(filepath, model=None, invert=False):
    """Classify the digit in an image file and plot the class scores.

    The image is read as grayscale, resized to 28x28 and normalised the same
    way as the training data before being passed to `model`.

    Args:
        filepath: Path of the image file to classify.
        model: Object with a Keras-style `predict(batch)` method. Required.
        invert: If True, invert the gray levels (255 - pixel) before
            normalising, as the drawing-canvas cell does for dark-on-light
            input. Defaults to False, i.e. the file is used as it is.

    Returns:
        None. The predicted digit is printed and a figure with the
        preprocessed image and a bar chart of the model output is drawn.

    Raises:
        ValueError: If `model` is None.
        FileNotFoundError: If OpenCV cannot read an image from `filepath`.
    """
    if model is None:
        raise ValueError("predict() needs a trained model: pass model=<your model>")

    IMG_SIZE = 28
    img_array = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if img_array is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(
            f"Could not read an image from {filepath!r} (missing or unreadable file)"
        )
    new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
    if invert:
        new_array = 255 - new_array

    # Add the batch axis BEFORE normalising. Training normalised a
    # (N, 28, 28) array along axis 1; normalising a bare (28, 28) image along
    # axis 1 would rescale along the other image axis and feed the model
    # data preprocessed differently from what it was trained on.
    batch = new_array.reshape(-1, IMG_SIZE, IMG_SIZE)
    last_array = tf.keras.utils.normalize(batch, axis = 1)
    prediction = model.predict(last_array)
    print('The number you wrote is: ', np.argmax(prediction))

    fig, (ax_image, ax_scores) = plt.subplots(1, 2, figsize=(15, 5))
    ax_image.imshow(last_array[0], cmap='gray_r')
    digits = np.arange(10)
    ax_scores.set_xticks(digits)
    ax_scores.bar(digits, prediction[0])
    
def plot_conv_weights(weights):
    """Plot every filter of a convolution kernel on a square grid of images.

    Args:
        weights: 3-D array whose last axis indexes the filters; each
            `weights[:, :, i]` slice is drawn as one image.
            # NOTE(review): presumably a Conv1D kernel laid out as
            # (kernel_size, in_channels, filters) — confirm against the caller.

    Returns:
        None. The figure is shown with `plt.show()`.

    Raises:
        ValueError: If `weights` is empty or not 3-dimensional.
    """
    weights = np.asarray(weights)
    if weights.ndim != 3 or weights.size == 0:
        raise ValueError(
            f"expected a non-empty 3-D weight array, got shape {weights.shape}"
        )

    # A shared colour scale keeps the filters visually comparable.
    w_min = np.min(weights)
    w_max = np.max(weights)
    num_filters = weights.shape[2]
    num_grids = math.ceil(np.sqrt(num_filters))
    # squeeze=False always returns a 2-D array of axes, even for a 1x1 grid.
    fig, axes = plt.subplots(num_grids, num_grids, squeeze=False)

    for i, ax in enumerate(axes.flat):
        # The grid can have more cells than filters; surplus cells stay empty.
        if i < num_filters:
            img = weights[:, :, i]
            ax.imshow(img, vmin=w_min, vmax=w_max, interpolation='nearest', cmap='seismic')
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

In [None]:
# Shows every misclassified MNIST test image and prints how many there were.
# Relies on `mnist_y_test` and `mnist_x_test_norm` from the training section.
wrong(mnist_y_test, model=nn_model)

In [None]:
# Classifies the digit drawn in number_3.png (the file is expected next to the notebook).
predict('number_3.png', model=nn_model)

In [21]:
from ipycanvas import Canvas
from ipywidgets import VBox, Button, HBox, Label, HTML
from IPython.display import display
import numpy as np
import cv2
import tensorflow as tf

# Assumes a trained model has already been loaded into `nn_model` elsewhere, e.g.:
# nn_model = tf.keras.models.load_model("your_mnist_model.h5")

# 1. Drawing canvas
canvas = Canvas(width=200, height=200, sync_image_data=True)  # key setting: sync_image_data=True, the Predict callback reads pixels back with get_image_data()
canvas.line_width = 25

# Start from a white background
canvas.fill_style = 'white'
canvas.fill_rect(0, 0, 200, 200)
canvas.stroke_style = 'black'

# True while the mouse button is held down over the canvas.
drawing = False

def handle_mouse_down(x, y):
    """Start a new stroke at the position where the button was pressed."""
    global drawing
    drawing = True
    canvas.begin_path()
    canvas.move_to(x, y)

def handle_mouse_move(x, y):
    """Extend the current stroke to (x, y) while the button is held down."""
    if drawing:
        canvas.line_to(x, y)
        canvas.stroke()

def handle_mouse_up(x, y):
    """End the current stroke."""
    global drawing
    drawing = False

canvas.on_mouse_down(handle_mouse_down)
canvas.on_mouse_move(handle_mouse_move)
canvas.on_mouse_up(handle_mouse_up)
# Also end the stroke when the pointer leaves the canvas: a button released
# outside the canvas never triggers on_mouse_up, which would leave `drawing`
# stuck at True and keep drawing once the pointer comes back.
canvas.on_mouse_out(handle_mouse_up)

# 2. Clear button
def clear_canvas(_):
    """Button callback: repaint the whole canvas white, erasing the drawing."""
    canvas.fill_style = 'white'
    canvas.fill_rect(0, 0, canvas.width, canvas.height)

clear_button = Button(description="Clear Canvas")
clear_button.on_click(clear_canvas)

# 3. Predict button + result label
result_label = HTML(value="<h1 style='color:blue; text-align:center;'>Prediction: ?</h1>")

def preprocess_and_predict(_):
    """Button callback: classify the digit drawn on the canvas.

    Reads the canvas pixels, converts them to a 28x28 image preprocessed like
    the training data, runs `nn_model` on it and writes the predicted digit
    to `result_label`. A blank canvas resets the label to "?" without
    querying the model.
    """
    # 1) Canvas pixels as RGBA; the canvas is created 200x200, so (200, 200, 4).
    img = np.asarray(canvas.get_image_data(), dtype=np.uint8)

    # 2) Convert to grayscale -> (200, 200)
    img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)

    # 3) Shrink to 28x28. INTER_AREA is the interpolation OpenCV recommends
    #    for downscaling; it averages pixels instead of sampling a few.
    img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)

    # 4) Invert: the canvas is black ink on white, the model was fed the
    #    opposite polarity (ink = high values).
    img = 255 - img

    # Nothing drawn: every pixel is zero after inversion, so there is no
    # digit to classify.
    if not img.any():
        result_label.value = "<h1 style='color:blue; text-align:center;'>Prediction: ?</h1>"
        return

    # 5) Add the batch axis BEFORE normalising. Training normalised a
    #    (N, 28, 28) array along axis 1; normalising a bare (28, 28) image
    #    along axis 1 would rescale along the other image axis.
    # If the model had been trained on (28, 28, 1) inputs, append a channel
    # axis here as well.
    X = tf.keras.utils.normalize(img.reshape(-1, 28, 28), axis=1)

    prediction = nn_model.predict(X)
    digit = int(np.argmax(prediction))

    print('The number you wrote is: ', digit)
    result_label.value = f"<h1 style='color:blue; text-align:center;'>Prediction: {digit}</h1>"

predict_button = Button(description="Predict")
predict_button.on_click(preprocess_and_predict)

# 4. Show the whole UI: canvas on top, the two buttons in a row, then the result label
display(VBox([canvas,
              HBox([predict_button, clear_button]),
              result_label]))



VBox(children=(Canvas(height=200, sync_image_data=True, width=200), HBox(children=(Button(description='Predict…