In [1]:
import neura
import neura.nn as nn
import neura.optim as optim
import numpy as np
from sklearn.datasets import fetch_openml

ModuleNotFoundError: No module named 'neura'

In [None]:
class MNISTDataset(Dataset):
    """MNIST digits exposed as ``(image, label)`` pairs.

    On construction the dataset is fetched from OpenML with
    ``fetch_openml('mnist_784', version=1)``, pixel values are scaled by
    ``1 / 255`` and stored as ``float32``, labels are converted to integers,
    and either the first 60,000 rows (train) or the remaining rows (test)
    are kept on the instance as ``self.images`` / ``self.labels``.
    """

    # NOTE(review): `Dataset` is not imported in the visible part of this
    # file; confirm which module is expected to provide it.

    def __init__(self, train=True, transform=None):
        """
        Initializes the MNIST dataset.

        Args:
            train (bool): If True, loads the training data, otherwise loads test data.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.transform = transform

        # Fetch the data. It's a large download the first time.
        print("Fetching MNIST dataset...")
        # fetch_openml is a reliable way to get the original MNIST data
        mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
        print("Dataset fetched.")

        # Images are 784-dimensional vectors (28*28).
        # Cast to float32 *before* scaling: dividing first would build a
        # float64 temporary of the whole array only to discard it afterwards.
        # `astype` returns a copy, so the in-place division below never
        # touches the array owned by `mnist`.
        images = mnist.data.astype(np.float32)
        # Normalize pixel values from [0, 255] to [0, 1.0]
        images /= 255.0

        # Labels are strings '0', '1', ... -> convert them to integers
        labels = mnist.target.astype(int)

        # Split into training and testing sets (standard MNIST split is 60k/10k)
        train_size = 60000
        rows = slice(None, train_size) if train else slice(train_size, None)
        self.images = images[rows]
        self.labels = labels[rows]

    def __len__(self):
        """Returns the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, index):
        """
        Returns a tuple of (image, label) for a given index.
        The image is a flattened vector of 784 pixels, passed through
        ``self.transform`` when one was supplied.
        """
        image = self.images[index]
        label = self.labels[index]

        # Compare against None rather than relying on truthiness: a callable
        # that happens to evaluate falsy (e.g. an empty pipeline object that
        # defines __len__) must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, label
