In [1]:
import torch

In [2]:
from fastai.vision import *

In [26]:
import struct

class MnistDataset(torch.utils.data.Dataset):
    """MNIST dataset held in memory, parsed from the raw (uncompressed) IDX files.

    Items are ``(image, label)`` pairs. Images are stored as ``(1, rows, cols)``
    arrays -- a single channel axis is prepended so that a DataLoader batch comes
    out as ``(batch, 1, rows, cols)``. If a ``transform`` is given it is applied to
    the image in ``__getitem__``; labels are returned unchanged.
    """

    # IDX magic numbers: 0x00000803 = unsigned-byte data with 3 dimensions
    # (images), 0x00000801 = unsigned-byte data with 1 dimension (labels).
    IMAGES_MAGIC = 0x00000803
    LABELS_MAGIC = 0x00000801

    def __init__(self, images, labels, transform=None):
        """Wrap pre-loaded data.

        Args:
            images: indexable collection of images, one per sample.
            labels: indexable collection of labels aligned with ``images``.
            transform: optional callable applied to each image on access.
        """
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Samplers may hand over tensor indices; convert to plain Python indices.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

    @staticmethod
    def _read_file(path):
        """Return the full contents of ``path`` as bytes."""
        with open(path, 'rb') as f:
            return f.read()

    @classmethod
    def _parse_images(cls, data, dtype):
        """Decode IDX3 image bytes into a ``(N, 1, rows, cols)`` array of ``dtype``."""
        fmt_header = '>IIII'  # magic, count, rows, cols as big-endian unsigned int32
        magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, data, 0)
        if magic_number != cls.IMAGES_MAGIC:
            raise ValueError(
                f'not an IDX image file: magic number {magic_number:#010x}, '
                f'expected {cls.IMAGES_MAGIC:#010x}')
        # All pixels are one contiguous run of unsigned bytes after the header, so
        # decode them in a single vectorized call instead of one unpack per image.
        # np.frombuffer raises ValueError if the file is truncated.
        pixels = np.frombuffer(data, dtype=np.uint8,
                               count=num_images * num_rows * num_cols,
                               offset=struct.calcsize(fmt_header))
        # astype copies, so the result is writable (frombuffer over bytes is read-only).
        return pixels.reshape(num_images, 1, num_rows, num_cols).astype(dtype)

    @classmethod
    def _parse_labels(cls, data):
        """Decode IDX1 label bytes into a 1-D int64 array."""
        fmt_header = '>II'  # magic, count as big-endian unsigned int32
        magic_number, num_labels = struct.unpack_from(fmt_header, data, 0)
        if magic_number != cls.LABELS_MAGIC:
            raise ValueError(
                f'not an IDX label file: magic number {magic_number:#010x}, '
                f'expected {cls.LABELS_MAGIC:#010x}')
        labels = np.frombuffer(data, dtype=np.uint8, count=num_labels,
                               offset=struct.calcsize(fmt_header))
        # int64 so default DataLoader collation produces int64 (Long) targets.
        return labels.astype(np.int64)

    @classmethod
    def load(cls, images_path, labels_path, valid_percent=0.25, transform=None,
             dtype=np.float32):
        """Parse raw MNIST IDX files and split them into train/valid datasets.

        The split is deterministic and in file order: the first
        ``1 - valid_percent`` fraction of samples becomes the training set and
        the remainder the validation set.

        Args:
            images_path: path to an uncompressed ``*-images-idx3-ubyte`` file.
            labels_path: path to the matching uncompressed ``*-labels-idx1-ubyte`` file.
            valid_percent: fraction of samples (in ``[0, 1]``) held out for validation.
            transform: optional callable passed to both datasets.
            dtype: numpy dtype of the image arrays. Defaults to float32 so that
                collated batches match PyTorch's default parameter dtype.

        Returns:
            ``(train_ds, valid_ds)`` -- a tuple of two ``MnistDataset`` instances.

        Raises:
            ValueError: for an out-of-range ``valid_percent``, a wrong magic
                number, a truncated file, or mismatched image/label counts.
        """
        if not 0 <= valid_percent <= 1:
            raise ValueError(f'valid_percent must be in [0, 1], got {valid_percent}')
        images = cls._parse_images(cls._read_file(images_path), dtype)
        labels = cls._parse_labels(cls._read_file(labels_path))
        if len(labels) != len(images):
            raise ValueError(
                f'{len(images)} images but {len(labels)} labels; '
                'are the image and label files from the same split?')

        num_of_train = round(len(images) * (1 - valid_percent))
        train_ds = cls(images[:num_of_train], labels[:num_of_train], transform)
        valid_ds = cls(images[num_of_train:], labels[num_of_train:], transform)
        return train_ds, valid_ds

In [27]:
# Directory holding the uncompressed MNIST IDX files. The path is relative to
# the notebook's working directory, so adjust it if the notebook is moved.
path = Path('../../data/mnist-raw/')

In [28]:
list(path.iterdir())

[PosixPath('../../data/mnist-raw/train-images-idx3-ubyte'),
 PosixPath('../../data/mnist-raw/train-images-idx3-ubyte.gz'),
 PosixPath('../../data/mnist-raw/train-labels-idx1-ubyte.gz'),
 PosixPath('../../data/mnist-raw/train-labels-idx1-ubyte')]

In [29]:
dataset = MnistDataset.load(images_path=path / 'train-images-idx3-ubyte',
                            labels_path=path / 'train-labels-idx1-ubyte')

In [30]:
dataset  # (train_ds, valid_ds) tuple returned by MnistDataset.load

(<__main__.MnistDataset at 0x7febf0f77fd0>,
 <__main__.MnistDataset at 0x7febf0f771d0>)

In [31]:
# One DataLoader per split. Only the training split is shuffled; the validation
# split is read in file order so that evaluation batches are reproducible.
dataloaders = {phase: torch.utils.data.DataLoader(split_ds, batch_size=4,
                                                  shuffle=(phase == 'train'),
                                                  num_workers=0)
               for phase, split_ds in zip(['train', 'val'], dataset)}

In [32]:
# Pull a single batch from the training loader to sanity-check the tensor shapes.
inputs, classes = next(iter(dataloaders['train']))

In [33]:
inputs.size()

torch.Size([4, 1, 28, 28])

In [34]:
classes.size()

torch.Size([4])

# 总结

下一步：实现自定义 transform（custom transform），并通过 `transform` 参数传给 `MnistDataset.load`。