In [39]:
import os
import torch
import pandas as pd
from skimage import io, transform
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import torchvision

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

seed = 69

torch.manual_seed(seed)
np.random.seed(seed)

In [50]:
class MnistDataset(Dataset):
    def __init__(self, csv_file, transform=None, train=True):
        '''
        Args:
            csv_file (string)
            transform (callable, optional)
        '''
        
        self.mnist = pd.read_csv(csv_file)
        self.train = train
        self.transform = transform
        
    def __len__(self):
        return len(self.mnist)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        if self.train:
            label = self.mnist.iloc[idx][0]
            label = np.array(label)

            img = self.mnist.iloc[idx].values[1:]
            img = np.reshape(img, (28,28))
            img = Image.fromarray(img)
            if self.transform:
                img = self.transform(img)
            
            sample = {"label": torch.from_numpy(label), "image":img}
                
        else:
            img = self.mnist.iloc[idx].values
            img = np.reshape(img, (28,28))
            sample = {"image": torch.from_numpy(img)}
            

            
        return sample

class ToTensor(object):
    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        return {'image': torch.from_numpy(image),
                'label': torch.from_numpy(label)}

In [56]:
training_split = 0.8
batch_size = 32

mnist_train_dataset = MnistDataset("mnist/train.csv", transform=transforms.Compose([
                                                                                    transforms.RandomRotation(2.8),
    transforms.Normalize((0.1307,), (0.3081,))
                                                                                    transforms.ToTensor(),
                                                                                     ]))
mnist_test_dataset = MnistDataset("mnist/test.csv", train=False)

train_length = int(training_split * len(mnist_train_dataset))
validation_length = len(mnist_train_dataset) - train_length

train_dataset, validation_dataset = torch.utils.data.random_split(mnist_train_dataset, (train_length, validation_length))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test_dataset, batch_size=1, shuffle=False)

In [57]:
mnist_train_dataset[0]["image"].shape

TypeError: Cannot handle this data type: (1, 1), <i8

In [None]:
imgs_ = 6
fig = plt.figure()
plt.figure(figsize=(15,imgs_))
for i in range(imgs_):
    ax = plt.subplot(1, imgs_, i+1)
    ax.set_title('sample #{}'.format(i))
    plt.imshow(np.reshape(mnist_train_dataset[i]["image"], (28,28)), cmap='gray')
    
plt.show()