In [None]:
import struct
import numpy as np
from google.colab import drive

# Mount Google Drive at /content/drive so the MNIST files referenced below
# are reachable through the local filesystem; a remount is explicitly requested.
drive.mount('/content/drive', force_remount=True)

class MNISTDataset:
    """Loader for MNIST image and label files stored in the IDX format.

    The instance remembers one image file and one label file.  Call
    ``load_dataset()`` to read them into ``images`` / ``labels``; until then
    both attributes are ``None``.
    """

    def __init__(self, image_path, label_path):
        """Store the file locations; no file is opened here.

        Args:
            image_path: Path of the IDX file holding the images.
            label_path: Path of the IDX file holding the labels.
        """
        self.image_path = image_path
        self.label_path = label_path
        self.images = None
        self.labels = None

    def read_idx(self, filename):
        """Read an unsigned-byte IDX file and return its contents as an array.

        An IDX file starts with a 4-byte big-endian magic number whose lowest
        byte is the number of dimensions (2051 = 0x0803 -> 3-D images,
        2049 = 0x0801 -> 1-D labels).  It is followed by one 4-byte
        big-endian size per dimension and then the raw data bytes.

        Args:
            filename: Path of the IDX file to read.

        Returns:
            A ``numpy.uint8`` array shaped according to the file header
            (1-D for label files).

        Raises:
            ValueError: If the magic number is neither 2051 nor 2049, the
                header is truncated, or the payload size does not match the
                shape declared in the header.
        """
        with open(filename, 'rb') as f:
            # Read magic number and check validity
            header = f.read(4)
            if len(header) != 4:
                raise ValueError(
                    f'IDX file {filename!r} is too short to contain a magic number'
                )
            magic_number = struct.unpack('>I', header)[0]
            if magic_number not in (2051, 2049):
                raise ValueError(f'Invalid magic number {magic_number} in IDX file')

            # The dimension count is the low byte of the magic number itself;
            # there is no separate count byte after it.  Reading one would
            # consume the first byte of the first dimension size instead.
            dims = magic_number & 0xFF

            # Read all dimension sizes in one go (big-endian uint32 each)
            dim_bytes = f.read(4 * dims)
            if len(dim_bytes) != 4 * dims:
                raise ValueError(
                    f'IDX file {filename!r} ends before its {dims} dimension sizes'
                )
            shape = struct.unpack(f'>{dims}I', dim_bytes)

            # Everything after the header is payload
            data = np.frombuffer(f.read(), dtype=np.uint8)

        # Raise explicitly rather than assert so the check survives `python -O`
        if data.size != np.prod(shape):
            raise ValueError(
                f"Expected shape {shape} does not match data size {data.size}"
            )

        # A 1-D header (labels) already yields a flat array after reshape
        return data.reshape(shape)

    def load_dataset(self):
        """Read the image and label files into ``self.images`` / ``self.labels``."""
        self.images = self.read_idx(self.image_path)
        self.labels = self.read_idx(self.label_path)

    def get_train_data(self):
        """Return ``(images, labels)`` as populated by ``load_dataset()``.

        Both elements are ``None`` if ``load_dataset()`` has not been called.
        """
        return self.images, self.labels

    def get_test_data(self):
        """Read and return the test split as ``(images, labels)``.

        The test file names are derived by replacing every occurrence of
        ``'train'`` in the stored paths with ``'t10k'``; paths that do not
        contain ``'train'`` are read unchanged.  The files are read on every
        call and the result is not cached on the instance.
        """
        test_images = self.read_idx(self.image_path.replace('train', 't10k'))
        test_labels = self.read_idx(self.label_path.replace('train', 't10k'))
        return test_images, test_labels

# Locations of the four MNIST IDX files inside the mounted Google Drive.
# Adjust these if the files live in a different Drive folder.
image_path_train = '/content/drive/MyDrive/MNIST/train-images.idx3-ubyte'
label_path_train = '/content/drive/MyDrive/MNIST/train-labels.idx1-ubyte'
image_path_test = '/content/drive/MyDrive/MNIST/t10k-images.idx3-ubyte'
label_path_test = '/content/drive/MyDrive/MNIST/t10k-labels.idx1-ubyte'

# One dataset object per split; constructing them does not read any file.
mnist_train = MNISTDataset(image_path_train, label_path_train)
mnist_test = MNISTDataset(image_path_test, label_path_test)

# Read the training files from disk, then fetch the loaded arrays.
mnist_train.load_dataset()
train_images, train_labels = mnist_train.get_train_data()

# get_test_data() reads its files directly, so no load_dataset() call is
# needed here; the paths given above are already the t10k files.
test_images, test_labels = mnist_test.get_test_data()


Mounted at /content/drive


AssertionError: Expected shape () does not match data size 47040011