## This notebook prepares the dataset.
The data comes from:
https://www.kaggle.com/datasets/jidhumohan/mnist-png/

For the demo, I already prepared the data and stored it in the same repo as these notebooks, so **you should not need to run this notebook**. I include it for completeness.

In [17]:
import pandas as pd
from pathlib import Path
from shutil import copy


Note: the raw data used below is not included in the zip file, so these cells won't run as-is.

If you really want to run it, download the data from the link above and point `data_root` at the extracted `mnist_png` folder.

In [3]:
# Location of the extracted Kaggle download; edit this if your copy lives elsewhere.
data_root = Path('/data') / 'mnist' / 'archive' / 'mnist_png'
data_root.exists()

True

In [4]:
def get_data_set(data_root: Path, subdir: str) -> pd.DataFrame:
    """
    Build a DataFrame describing the PNG files found under ``data_root / subdir``.

    Each PNG is expected to sit in a folder whose name is its integer label
    (e.g. ``training/7/123.png``) and to have a purely numeric file stem.

    Args:
        data_root (Path): The root directory containing the subdirectory.
        subdir (str): The name of the subdirectory containing the PNG files
            (e.g. 'training' or 'testing').

    Returns:
        pd.DataFrame: Columns ``png`` (file name) and ``label`` (int parsed from
        the parent folder name), ordered by the numeric value of the file stem,
        with a fresh 0..n-1 index.

    Raises:
        FileNotFoundError: If ``data_root / subdir`` is not a directory.
        ValueError: If two PNGs (in different label folders) share a file name.
    """
    source_dir = Path(data_root) / subdir
    if not source_dir.is_dir():
        # glob() on a missing directory silently yields nothing, which would
        # produce an empty DataFrame and hide a mistyped path.
        raise FileNotFoundError(f'{source_dir} is not a directory')

    pngs = list(source_dir.glob('**/*.png'))

    data_set = pd.DataFrame({
        'png': [f.name for f in pngs],
        'label': [int(f.parent.name) for f in pngs],
        # Path.stem drops the suffix without assuming its length.
        'numerical': [int(f.stem) for f in pngs],
    })

    # File names must be unique: the PNGs are later copied into one flat folder,
    # where a clash would silently drop an image.
    duplicated = data_set.loc[data_set['png'].duplicated(), 'png']
    if not duplicated.empty:
        examples = sorted(duplicated.unique())[:10]
        raise ValueError(f'Duplicate PNG file names under {source_dir}: {examples}')

    # Stable sort so the order is deterministic even if two stems parse to the same int.
    return (
        data_set
        .sort_values('numerical', kind='stable')
        .reset_index(drop=True)
        .drop(columns='numerical')
    )

In [5]:
test = get_data_set(data_root, 'testing')
# Sanity checks against the saved run: 10,000 test images, integer labels 0-9.
assert len(test) == 10_000, f'expected 10,000 test images, found {len(test)}'
assert test['label'].between(0, 9).all(), 'labels outside 0-9 found'
test.tail()

Unnamed: 0,png,label
9995,9995.png,2
9996,9996.png,3
9997,9997.png,4
9998,9998.png,5
9999,9999.png,6


In [6]:
train = get_data_set(data_root, 'training')
# Sanity checks against the saved run: 60,000 training images, integer labels 0-9.
assert len(train) == 60_000, f'expected 60,000 training images, found {len(train)}'
assert train['label'].between(0, 9).all(), 'labels outside 0-9 found'
train.tail()

Unnamed: 0,png,label
59995,59995.png,8
59996,59996.png,3
59997,59997.png,5
59998,59998.png,6
59999,59999.png,8


In [22]:
def copy_pngs(data_root: Path, subdir: str, output_dir: Path):
    """
    Copy every PNG under ``data_root / subdir`` into a single flat folder.

    The label sub-folders of the source are flattened away, so copies are
    identified only by file name. Files already present in the destination are
    skipped, which makes re-running the function cheap.

    Args:
        data_root (Path): The root directory containing the subdirectory.
        subdir (str): The name of the subdirectory containing the PNG files.
            Should be 'training' or 'testing'.
        output_dir (Path): Parent directory for the copies. A child folder named
            after ``subdir`` minus a trailing 'ing' ('train' / 'test') is created
            inside it; ``output_dir`` itself must already exist.

    Returns:
        Path: The folder the PNGs were copied into.
    """
    # 'training' -> 'train', 'testing' -> 'test'; other names are kept whole
    # rather than blindly losing their last three characters.
    short_name = subdir[:-3] if subdir.endswith('ing') else subdir
    output_dir = Path(output_dir) / short_name
    # parents=False: fail loudly if the output root itself is missing or mistyped.
    output_dir.mkdir(parents=False, exist_ok=True)

    pngs = list((Path(data_root) / subdir).glob('**/*.png'))

    copied = skipped = 0
    for index, png in enumerate(pngs, start=1):
        output_file = output_dir / png.name
        if output_file.exists():
            skipped += 1
        else:
            copy(png, output_file)
            copied += 1
        if index % 100 == 0:
            print(f'Processed {index} files', end='\r')

    # The in-loop update only fires every 100 files, so always report the final tally.
    print(f'Copied {copied} files, skipped {skipped} already present')

    return output_dir

In [28]:
def create_dataset_dir(data_root: Path, subdir: str, output_dir: Path, csv_df: pd.DataFrame):
    """
    Copy all PNGs of one split into a flat folder and write a CSV next to them.

    The PNGs are copied by ``copy_pngs`` into ``output_dir / <short name>``
    (e.g. 'train' or 'test'), and ``csv_df`` is written into that same folder as
    ``<short name>.csv`` without the index.

    Args:
        data_root (Path): The root directory containing the subdirectory.
        subdir (str): The name of the subdirectory containing the PNG files.
            Should be 'training' or 'testing'.
        output_dir (Path): The parent directory to copy the PNG files into.
        csv_df (pd.DataFrame): A DataFrame describing the PNG files (e.g. the
            output of ``get_data_set``).

    Returns:
        Path: The dataset folder holding the copied PNGs and the CSV.
    """
    dataset_dir = copy_pngs(data_root, subdir, output_dir)
    # Name the CSV after the folder copy_pngs actually created so the two can't drift apart.
    csv_path = dataset_dir / f'{dataset_dir.name}.csv'
    csv_df.to_csv(csv_path, index=False)
    print(f'Created {csv_path}')
    # Returning the folder lets the calling cell display where the dataset landed.
    return dataset_dir

In [27]:
# Write the flattened PNGs plus a labels CSV for each split under output_root.
output_root = Path('/data/mnist')
for split_name, split_df in [('testing', test), ('training', train)]:
    create_dataset_dir(data_root, split_name, output_root, split_df)

Created \data\mnist\test\test.csv file
Created \data\mnist\train\train.csv file


WindowsPath('/data/mnist/train')