# In [None]:
import os
import cv2
import time
import pickle
import numpy as np
from torchvision.datasets import MNIST


def get_time():
    """Return the current local time as a ``YYYY-MM-DD HH:MM:SS`` string."""
    timestamp_format = "%Y-%m-%d %H:%M:%S"
    # localtime() with no argument converts the current time.time() value.
    now = time.localtime()
    return time.strftime(timestamp_format, now)


def unpickle(file):
    """Load and return the object stored in the pickle file at path *file*.

    The file is opened in binary mode and decoded with ``encoding="bytes"``,
    which makes 8-bit string instances pickled by Python 2 come back as
    ``bytes`` objects.

    NOTE: unpickling can execute arbitrary code; only call this on files
    from a trusted source.
    """
    with open(file, "rb") as stream:
        payload = pickle.load(stream, encoding="bytes")
    return payload


# Export MNIST to per-class folders of JPEG files laid out as
#   ./data/mnist/<split>/<label:05d>/image<k>.jpg
# where k is a 1-based running index per label within each split.
trainset = MNIST(root="./data", train=True, download=True)
testset = MNIST(root="./data", train=False, download=True)
print(f"{get_time()} Trainset size: {len(trainset)}")
print(f"{get_time()} Testset size: {len(testset)}")


# Next image index to use for each digit label, tracked separately per split.
train_count = {i: 1 for i in range(10)}
test_count = {i: 1 for i in range(10)}


def _to_cv2_pixels(img):
    """Return *img* as an array laid out the way ``cv2.imwrite`` expects.

    torchvision's MNIST yields single-channel (grayscale) PIL images, so
    ``np.array(img)`` is 2-D. ``cv2.cvtColor(..., cv2.COLOR_RGB2BGR)`` rejects
    single-channel input, so the channel swap is applied only when the array
    really has 3 or 4 colour channels; grayscale arrays are returned as-is.
    """
    pixels = np.array(img)
    if pixels.ndim == 3 and pixels.shape[2] in (3, 4):
        return cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
    return pixels


def _export_split(dataset, out_root, counters):
    """Write every sample of *dataset* below *out_root* as a JPEG file.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs.
        out_root: Directory that receives one sub-folder per label.
        counters: Dict mapping label -> next image index; updated in place.

    Raises:
        OSError: If ``cv2.imwrite`` reports that a file could not be written.
    """
    # Index access is kept on purpose: map-style datasets expose
    # __len__/__getitem__ rather than a dedicated __iter__.
    for index in range(len(dataset)):
        img, label = dataset[index]
        class_dir = os.path.join(out_root, f"{label:05d}")
        os.makedirs(class_dir, exist_ok=True)
        target = os.path.join(class_dir, f"image{counters[label]}.jpg")
        # cv2.imwrite signals failure through its return value, not an
        # exception, so check it instead of silently losing images.
        if not cv2.imwrite(target, _to_cv2_pixels(img)):
            raise OSError(f"cv2.imwrite failed for {target}")
        counters[label] += 1


# trainset
_export_split(trainset, "./data/mnist/train", train_count)

# testset
_export_split(testset, "./data/mnist/test", test_count)

print(f"{get_time()} Done!")