In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

In [None]:
# Dataset location: the image files sit in a nested data/data/ folder under basedir.
basedir = "../input/chinese-mnist"
image_path = os.path.join(basedir, "data")
# Glob pattern (every file in <basedir>/data/data/) consumed by Dataset.list_files below.
image_files = os.path.join(basedir, "data", "data", "*")

In [None]:
# Fix the random seeds so the shuffled file order -- and therefore the train/val/test
# split made with take()/skip() further down -- is the same on every kernel restart.
# The global seed also makes weight initialisation, per-epoch shuffling and dropout
# repeatable (some GPU kernels may still be nondeterministic).
SEED = 42
tf.random.set_seed(SEED)

ds_files = tf.data.Dataset.list_files(image_files, shuffle=False)
# Buffer = whole dataset for a uniform shuffle of the (cheap) path strings.
# reshuffle_each_iteration=False keeps one fixed order for every pass, so datasets
# derived from ds_files with take()/skip() do not overlap.
ds_files = ds_files.shuffle(len(ds_files), seed=SEED, reshuffle_each_iteration=False)

In [None]:
# Sanity check: print the first few file paths in their shuffled order.
for file_path in ds_files.take(5):
    print(file_path)

In [None]:
def get_image_label(file_path):
    """Return the zero-based class index encoded in an image's file name.

    The class code is the last "_"-separated token of ``file_path`` with its
    extension stripped (e.g. a path ending in ``_10.jpg`` gives code ``10``). The
    code is shifted down by one, which presumes 1-based codes in the file names,
    so labels match the 0-based indices SparseCategoricalCrossentropy expects.

    Args:
        file_path: scalar string tensor holding an image file path.

    Returns:
        Scalar int32 tensor with the class index.
    """
    # Index from the end: the last "_" is always inside the file name, so an
    # underscore in any directory of the path cannot shift the token (a fixed
    # index such as [3] would silently pick the wrong token and mislabel data).
    code_token = tf.strings.split(file_path, "_")[-1]
    code_str = tf.strings.split(code_token, ".")[0]
    return tf.strings.to_number(code_str, out_type=tf.dtypes.int32) - 1

def get_image(file_path):
    """Load one JPEG as a single-channel float32 image scaled to [0, 1].

    Args:
        file_path: scalar string tensor holding an image file path.

    Returns:
        float32 tensor of shape (height, width, 1).
    """
    raw_bytes = tf.io.read_file(file_path)
    # decode_jpeg yields uint8 in [0, 255]; channels=1 forces grayscale.
    pixels = tf.image.decode_jpeg(raw_bytes, channels=1)
    return tf.cast(pixels, tf.float32) / 255.0

def ds_map_fn(ds):
    """Map one element of ``ds_files`` to an ``(image, label)`` pair.

    Despite its name, ``ds`` is a single scalar string tensor holding a file
    path, not a dataset.
    """
    return get_image(ds), get_image_label(ds)

# Decode every file into an (image, label) pair in parallel; tf.data keeps the
# output order of ds_files by default (deterministic map).
ds_all = ds_files.map(map_func=ds_map_fn, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
cardinality_all = ds_all.cardinality()

# Hold out the first fifth of the (already shuffled) files; train on the rest.
ds_test_all = ds_all.take(cardinality_all // 5)
ds_train = ds_all.skip(cardinality_all // 5)

# Split the held-out fifth in half: first half validation, second half test.
cardinality_test = ds_test_all.cardinality()
ds_val, ds_test = (
    ds_test_all.take(cardinality_test // 2),
    ds_test_all.skip(cardinality_test // 2),
)

print(cardinality_all)
print(cardinality_test)

In [None]:
BATCH_SIZE = 32
SHUFFLE_BUFFER = 10000  # per-epoch reshuffle buffer for the training set

# NOTE: this cell rebinds the split datasets in place; re-run the split cell
# before re-running it, otherwise the data gets batched twice.
# Training: cache decoded images first, then reshuffle every epoch (shuffle after
# cache() so each epoch sees a new order without re-reading the files).
ds_train = (
    ds_train.cache()
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
# Evaluation sets keep a fixed order, so they can be cached after batching.
ds_test = ds_test.batch(BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)
ds_val = ds_val.batch(BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)

In [None]:
# Show up to 9 images from one training batch with their (0-based) class index.
fig, axes = plt.subplots(3, 3, figsize=(7, 7))
for ax in axes.flat:
    ax.axis("off")
for image_batch, label_batch in ds_train.take(1):
    # zip stops at the shorter input, so a batch smaller than 9 cannot raise.
    for ax, image, label in zip(axes.flat, image_batch, label_batch):
        ax.set_title(str(int(label)))
        # Drop the trailing channel axis: imshow expects (H, W) for a colormapped
        # image, and older matplotlib versions reject (H, W, 1).
        ax.imshow(image.numpy().squeeze(axis=-1), cmap="gray")
plt.show()

In [None]:
# Small CNN classifier for 64x64 single-channel images with 15 output classes.
model = tf.keras.models.Sequential(
    [
        # Declare the input with tf.keras.Input rather than input_shape= on the
        # first layer; passing input_shape to a layer is deprecated in current Keras.
        tf.keras.Input(shape=(64, 64, 1)),
        tf.keras.layers.Conv2D(32, (3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dropout(0.3),
        # No softmax: the layer emits raw logits, matched by from_logits=True below.
        tf.keras.layers.Dense(15),
    ]
)

# Integer (sparse) labels from get_image_label pair with the Sparse* loss/metric.
model.compile(
    optimizer=tf.keras.optimizers.Adam(3e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.summary()

In [None]:
epochs = 7  # number of full passes over ds_train
# Validation metrics (val_*) are computed on ds_val at the end of each epoch.
history = model.fit(
    ds_train,
    validation_data=ds_val,
    epochs=epochs,
)

In [None]:
# Per-epoch metrics recorded by model.fit (keys follow the compiled metric names).
hist = history.history
acc = hist["sparse_categorical_accuracy"]
val_acc = hist["val_sparse_categorical_accuracy"]
loss = hist["loss"]
val_loss = hist["val_loss"]

# Take the x axis from the recorded history rather than the `epochs` variable, so
# the plot stays correct if training stopped early or `epochs` was edited after
# fitting. Epochs are numbered from 1.
epochs_range = range(1, len(acc) + 1)

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))

ax_acc.plot(epochs_range, acc, label="Training Accuracy")
ax_acc.plot(epochs_range, val_acc, label="Validation Accuracy")
ax_acc.legend(loc="lower right")
ax_acc.set(title="Training and Validation Accuracy", xlabel="Epoch", ylabel="Accuracy")

ax_loss.plot(epochs_range, loss, label="Training Loss")
ax_loss.plot(epochs_range, val_loss, label="Validation Loss")
ax_loss.legend(loc="upper right")
ax_loss.set(title="Training and Validation Loss", xlabel="Epoch", ylabel="Loss")

plt.show()

In [None]:
# Final held-out score; return_dict=True labels each value with its metric name
# instead of returning a bare [loss, accuracy] list.
model.evaluate(ds_test, return_dict=True)