In [1]:
from torchvision.datasets import mnist
from pathlib import Path
import torch

In [2]:
# Relative path (resolved against the kernel's working directory) used both
# as the torchvision download root and as the folder for the exported .pt file.
datadir = "../../data/external/"

In [3]:
# Download (if not already present under `datadir`) and open both splits.
# NOTE: despite the `MNIST_` prefix, these objects hold FashionMNIST.
MNIST_train = mnist.FashionMNIST(datadir, train=True, download=True)
MNIST_test = mnist.FashionMNIST(datadir, train=False, download=True)

In [4]:
# Bundle the images and labels of both splits into one dict and serialize it
# to a single file next to the raw download.
filepath = Path(datadir).joinpath("fashionmnist.pt")

dataset = {}
dataset["traindata"] = MNIST_train.data
dataset["trainlabels"] = MNIST_train.targets
dataset["testdata"] = MNIST_test.data
dataset["testlabels"] = MNIST_test.targets

torch.save(dataset, filepath)

In [21]:
import hashlib

def calculate_md5(file_path, block_size=2**16):
    """Return the hexadecimal MD5 digest of a file's contents.

    The file is opened in binary mode and fed to the hash in chunks obtained
    from successive ``read(block_size)`` calls, until an empty read signals
    end of file.

    Args:
        file_path: path of the file to hash (anything ``open`` accepts).
        block_size: size argument passed to each ``read`` call.

    Returns:
        str: the MD5 digest as a hex string.
    """
    hasher = hashlib.md5()
    with open(file_path, 'rb') as stream:
        while True:
            chunk = stream.read(block_size)
            # An empty bytes object means the end of the file was reached.
            if chunk == b'':
                break
            hasher.update(chunk)
    return hasher.hexdigest()

In [47]:
# Fingerprint the exported FashionMNIST bundle; the bare name on the last
# line makes the notebook display the digest.
filepath = Path(datadir).joinpath("fashionmnist.pt")
digest = calculate_md5(filepath)
digest

'c4f1c3f76673fe3802f579773267163a'

In [50]:
# Same fingerprint for the plain MNIST bundle.
# NOTE(review): no cell visible in this notebook writes "mnist.pt";
# presumably it was produced elsewhere, so confirm it exists before running.
# This also rebinds `filepath` away from the FashionMNIST file.
filepath = Path(datadir, "mnist.pt")
digest = calculate_md5(filepath)
digest

'52cae56d5bc7c427e8a2393dd1f8df29'

In [5]:
from pathlib import Path
# Re-establish the data location and point `filepath` at the FashionMNIST
# bundle again (an earlier cell rebinds it to a different file).
datadir = "../../data/external/"
filepath = Path(datadir).joinpath("fashionmnist.pt")

In [8]:
# Display the absolute location the relative path points at (sanity check).
# NOTE(review): the rendered output exposes a user-specific absolute path;
# consider clearing it before sharing the notebook.
filepath.resolve()

PosixPath('/Users/rgrouls/code/ML22/data/external/fashionmnist.pt')

In [6]:
import torch
# Read the serialized bundle at `filepath` back into memory.
# NOTE(review): torch.load is pickle-based, so loading an untrusted file can
# execute arbitrary code. Only load files you created yourself. Newer PyTorch
# releases accept `weights_only=True` to restrict what may be unpickled;
# confirm the installed version supports it before adding the argument.
dataset = torch.load(filepath)

In [15]:
class MNISTDataset:
    """Map-style dataset over pre-loaded image and label tensors.

    Implements ``__len__`` and ``__getitem__`` so it can be indexed directly
    or wrapped by a data loader.

    Args:
        data (torch.Tensor): images, indexed along the first dimension
        labels (torch.Tensor): labels, one per image
        transform (callable, optional): Optional transform to be applied
            on a sample.

    Raises:
        ValueError: if ``data`` and ``labels`` do not have the same length.

    Indexing with an integer returns a tuple ``(image, label)``:
        torch.Tensor: image with a singleton dimension prepended
        torch.Tensor: label
    """
    def __init__(self, data, labels, transform=None):
        # Fail early: a mismatch would otherwise only surface later as an
        # IndexError (or silently unused samples) while iterating.
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(data)} and {len(labels)}"
            )
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at integer index ``idx``."""
        # Prepend a singleton dimension, e.g. (28, 28) -> (1, 28, 28). Each
        # item is a single sample, so this serves as the channel axis rather
        # than a batch axis.
        image = self.data[idx].unsqueeze(0)
        label = self.labels[idx]
        # Compare against None instead of relying on truthiness: a callable
        # may evaluate as falsy (e.g. one defining __bool__ or __len__) and
        # would otherwise be skipped silently.
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [16]:
# Wrap the training split of the loaded bundle; no transform is applied.
traindata = MNISTDataset(
    data=dataset["traindata"],
    labels=dataset["trainlabels"],
)

In [17]:
# Fetch the first (image, label) pair through MNISTDataset.__getitem__.
x, y = traindata[0]

In [18]:
# Display the image shape and the label of that sample as a quick check.
x.shape, y

(torch.Size([1, 28, 28]), tensor(9))