In [None]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import datasets, transforms

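# Squares

A synthetic dataset: each sample is 9 random integer intensities forming a 3×3 grid, labelled with a one-hot vector that marks the row whose pixels sum highest.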

In [None]:
class SquareDataset(Dataset):
    """Synthetic dataset of flattened 3x3 grids labelled by their brightest row.

    Each sample ``X[i]`` is a float vector of 9 random integer intensities in
    ``[0, max_value]``. Read as a 3x3 grid in row-major order, its label
    ``Y[i]`` is a one-hot float vector of length 3 marking the row with the
    largest sum (ties go to the lowest row index, as ``torch.argmax`` does).

    Args:
        size: number of samples to generate.
        seed: optional seed for a private ``torch.Generator``. ``None`` (the
            default) draws from torch's global RNG.
        max_value: largest intensity that can be drawn (inclusive).
    """

    def __init__(self, size, seed=None, max_value=255):
        self.size = size
        generator = None if seed is None else torch.Generator().manual_seed(seed)
        # torch.randint's `high` is exclusive; +1 makes max_value itself
        # reachable (a bare `randint(255, ...)` can never produce 255).
        self.X = torch.randint(max_value + 1, (size, 9),
                               generator=generator, dtype=torch.float)

        # "True" weights: row r of real_w sums the three pixels of grid row r,
        # so X @ real_w.T yields the per-row sums with shape (size, 3).
        real_w = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0],
                               [0, 0, 0, 1, 1, 1, 0, 0, 0],
                               [0, 0, 0, 0, 0, 0, 1, 1, 1]],
                              dtype=torch.float)
        y = torch.argmax(self.X.mm(real_w.t()), dim=1)

        self.Y = torch.nn.functional.one_hot(y, num_classes=3).to(torch.float)

    def __getitem__(self, index):
        """Return the ``(X, Y)`` pair for sample ``index``."""
        return self.X[index], self.Y[index]

    def __len__(self):
        """Return the number of samples."""
        return self.size

In [None]:
# Seed torch's global RNG so the random squares below (and any later shuffling
# that draws from the global RNG) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

squares = SquareDataset(256)
# Inspect a few samples; index 254 requires the dataset to hold at least 255 items.
for index in (34, 254, 25):
    print(squares[index])

In [None]:
dataloader = DataLoader(squares, batch_size=5)

# Pull only the first batch for inspection; next(iter(...)) replaces an
# enumerate-and-break loop whose counter was never used.
X, Y = next(iter(dataloader))
print(X, '\n\n', Y)

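# Digits

MNIST handwritten digits, with each image flattened to a 784-element vector and each label one-hot encoded over the 10 classes.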

In [None]:
def flatten_image(x):
    """Flatten a ``ToTensor``-converted MNIST image to a 784-element vector.

    ``view(28 * 28)`` raises if the tensor does not hold exactly 28*28
    values (after ToTensor it is expected to be (1, 28, 28)).
    """
    return x.view(28 * 28)


def one_hot_digit(y):
    """Encode an integer class label ``y`` (0-9) as a length-10 float one-hot vector."""
    return torch.nn.functional.one_hot(torch.tensor(y), num_classes=10).to(torch.float)


# Named functions instead of lambdas: they can be pickled by reference
# (lambdas cannot), which matters for multi-process DataLoader workers if
# this code is moved into an importable module.
digits = datasets.MNIST('data', train=True, download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            flatten_image,
                        ]),
                        target_transform=one_hot_digit)

In [None]:
dataloader = DataLoader(digits, batch_size=10, shuffle=True)

# Fetch a single shuffled batch for inspection; next(iter(...)) replaces an
# enumerate-and-break loop whose counter was never used.
X, Y = next(iter(dataloader))
print(X, '\n\n', Y)