In [1]:
import numpy as np
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.sampler import SubsetRandomSampler

In [2]:
class MyCustomPixelDataset(Dataset):
    """
    Toy dataset that acts like a 2x2 version of MNIST, with the trace as label.

    Every sample is an ``(image, label)`` tuple. ``image`` starts out as a
    ``(2, 2, 1)`` integer ndarray with values in [0, 255] (the trailing axis
    is there so that ``transforms.ToTensor()`` gets the shape it expects) and
    ``label`` is the trace of the underlying 2x2 matrix. All samples are
    generated -- and, if requested, transformed -- once, at construction time.

    Args:
        size: number of samples to generate; must be a non-negative integer.
        transform: optional callable applied to each image (never the label).
        seed: optional seed. When given, samples are drawn from a private
            ``np.random.RandomState(seed)`` so the dataset is reproducible;
            when ``None`` the global ``np.random`` state is used.

    Raises:
        TypeError: if ``size`` is not an integer.
        ValueError: if ``size`` is negative.
    """
    def __init__(self, size=1, transform=None, seed=None):
        super().__init__()

        # Fail here, with a clear message, instead of later inside len()/range().
        if not isinstance(size, (int, np.integer)):
            raise TypeError(f"size must be an integer, got {type(size).__name__}")
        if size < 0:
            raise ValueError(f"size must be non-negative, got {size}")

        # size determines the length of the dataset
        self.size = int(size)

        # Use the global NumPy state unless the caller asked for a seeded one.
        rng = np.random if seed is None else np.random.RandomState(seed)

        # generate random 2x2 matrices with values in [0, 255]
        raw_matrices = [rng.randint(256, size=(2, 2)) for _ in range(self.size)]

        # Store (ndarray, label) tuples: the matrix reshaped to (2, 2, 1) and
        # its trace as the label.
        self.matrices = [
            (matrix.reshape(2, 2, 1), np.trace(matrix)) for matrix in raw_matrices
        ]

        # Compare against None rather than truthiness so that a callable which
        # happens to be falsy (e.g. defines __len__ == 0) is still applied.
        if transform is not None:
            self.matrices = [
                (transform(image), label) for image, label in self.matrices
            ]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.size

    def __getitem__(self, index):
        """Return the ``(image, label)`` tuple stored at ``index``."""
        return self.matrices[index]

In [13]:
# Build 12 random samples; each image is passed through ToTensor() up front.
my_dataset = MyCustomPixelDataset(size=12, transform=transforms.ToTensor())

In [14]:
# Peek at the first (image, label) pair via the dataset's own indexing.
my_dataset[0]

(tensor([[[ 51,  89],
          [ 30, 206]]]), 257)

In [15]:
# Serve the dataset in shuffled mini-batches of four samples.
my_dataset_loader = torch.utils.data.DataLoader(
    my_dataset,
    batch_size=4,
    shuffle=True,
)

In [19]:
# Walk the loader twice and print the labels of every batch; with
# shuffle=True the batch composition differs between the two epochs.
for epoch in range(2):
    numbered_batches = enumerate(my_dataset_loader)
    for batch_index, (inputs, labels) in numbered_batches:
        print(epoch, batch_index, labels)

0 0 tensor([288, 231, 358, 328])
0 1 tensor([161, 336, 242, 257])
0 2 tensor([148, 312, 330, 482])
1 0 tensor([242, 358, 312, 328])
1 1 tensor([231, 288, 257, 336])
1 2 tensor([330, 148, 161, 482])


In [22]:
# enumerate's second argument is the starting index, so counting begins at 4.
for i, a in enumerate([1, 2, 3], start=4):
    print(i, a)

4 1
5 2
6 3
