In [3]:
import matplotlib.pyplot as plt

from mnist import MODEL_STATE_PATH, Model

In [4]:
def prompt_yes_no(message: str, default: bool = False) -> bool:
    """Ask a yes/no question on stdin and return the answer as a boolean.

    Args:
        message: Prompt text passed verbatim to ``input``.
        default: Value returned when the user submits an empty
            (or whitespace-only) reply.

    Returns:
        ``default`` for an empty reply; otherwise ``True`` only when the
        trimmed, lower-cased reply is ``"y"`` or ``"yes"``.
    """
    affirmative = {"y", "yes"}
    answer = input(message).strip().lower()
    # An empty reply means "accept the default"; anything else is matched
    # against the accepted affirmative spellings.
    return default if answer == "" else answer in affirmative


def prompt_sample_index(max_index: int, default: int = 0) -> int:
    """Prompt on stdin until the user supplies a usable sample index.

    Args:
        max_index: Upper bound shown in the prompt and used for clamping.
        default: Value returned as-is when the reply is empty.

    Returns:
        ``default`` for an empty reply; otherwise the entered integer
        clamped to at most ``max_index`` and then to at least ``0``.
        Non-integer replies print a message and re-prompt.
    """
    prompt = f"Enter test sample index (0-{max_index}, default {default}): "

    while True:
        raw = input(prompt).strip()
        if raw == "":
            return default

        try:
            requested = int(raw)
        except ValueError:
            print("Invalid input, please enter an integer.")
            continue

        # Clamp from above first, then from below, so out-of-range entries
        # are silently pulled into range instead of being rejected.
        capped = max_index if max_index < requested else requested
        return capped if capped > 0 else 0


def show_sample_image(image) -> None:
    """Draw ``image`` with matplotlib's ``"gray"`` colormap and show it.

    Args:
        image: Data handed straight to ``plt.imshow``; presumably a 2-D
            array-like of pixel intensities — TODO confirm against
            ``Model.classify``.
    """
    # Uses the pyplot state machine, so the image is drawn on whatever
    # figure/axes pyplot currently considers active.
    plt.imshow(image, cmap="gray")
    plt.show()


def maybe_retrain_model(epochs: int = 10) -> Model:
    """Return a model, either loaded from disk or freshly trained.

    If ``MODEL_STATE_PATH`` exists the user is asked whether to reuse it;
    answering ``r`` (case-insensitive, surrounding whitespace ignored)
    forces a new training run, any other reply loads the saved state via
    ``Model.load``. If the path does not exist, training starts without
    prompting.

    Args:
        epochs: Forwarded to ``Model.train`` when a training run happens.

    Returns:
        The loaded or newly trained ``Model`` instance.
    """
    if not MODEL_STATE_PATH.exists():
        print("No trained weights found, starting a new training run.")
    else:
        answer = input(
            f"Existing weights found at {MODEL_STATE_PATH}. "
            "Press Enter to use them, or type 'r' to retrain: "
        )
        if answer.strip().lower() != "r":
            return Model.load(MODEL_STATE_PATH)

    # Reached when no weights exist or the user explicitly asked to retrain.
    fresh_model = Model()
    fresh_model.train(epochs)
    return fresh_model


def inference_loop(model: Model) -> None:
    """Interactively classify test samples until the user declines to continue.

    Each round asks for a sample index (default ``0``, upper bound taken
    from ``model.get_max_test_index()``), prints the prediction returned by
    ``model.classify``, optionally shows the sample image, and then asks
    whether to go again.

    Args:
        model: Object providing ``get_max_test_index()`` and
            ``classify(index)``; the latter's result is read through the
            ``"index"``, ``"prediction"`` and ``"image"`` keys.
    """
    upper_bound = model.get_max_test_index()
    fallback_index = 0

    keep_going = True
    while keep_going:
        chosen = prompt_sample_index(upper_bound, fallback_index)
        outcome = model.classify(chosen)
        print(f"Prediction for sample {outcome['index']}: {outcome['prediction']}")

        wants_plot = prompt_yes_no("Display the digit with matplotlib? [y/N]: ")
        if wants_plot:
            show_sample_image(outcome["image"])

        # An empty or negative reply ends the session.
        keep_going = prompt_yes_no("Classify another test sample? [y/N]: ")


def main():
    """Obtain a model (loading or training it) and start the interactive loop."""
    inference_loop(maybe_retrain_model())


# Entry point: start the interactive session when this code is executed
# directly. The output recorded after this cell shows the guard also passes
# when the cell is run in the notebook, so the prompts begin immediately.
if __name__ == "__main__":
    main()

Existing weights found at mnist_cnn.pt. Press Enter to use them, or type 'r' to retrain:  


Loaded trained weights from mnist_cnn.pt


Enter test sample index (0-9999, default 0):  50


Prediction for sample 50: 6


Display the digit with matplotlib? [y/N]:  n
Classify another test sample? [y/N]:  n
