# PyTorch : Reading data

In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd

In this scenario the data is read in as a pandas DataFrame. However, as explained below, a custom Dataset is not limited to wrapping a DataFrame.

In [2]:
# Relative path to the CSV file that is handed to CustomData further down.
csv_file = '../data/mnist.csv'

In [3]:
class CustomData(Dataset):
    """Map-style dataset backed by a CSV file loaded into a pandas DataFrame.

    Every row of the DataFrame is one sample: the first column is treated
    as the label and all remaining columns as the features.
    """

    def __init__(self, file_name):
        """Read the CSV at ``file_name`` into ``self.df``.

        ``file_name`` is not required by the ``Dataset`` interface; it is an
        extra constructor argument specific to this class.
        """
        self.df = pd.read_csv(file_name)

    def __len__(self):
        """Return the number of samples, i.e. the number of DataFrame rows."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the sample stored at integer position ``index``.

        Args:
            index: Zero-based row position; negative values count from the
                end, as with ordinary Python sequences.

        Returns:
            A dict with two entries:
            ``'label'``    - array of shape ``(1,)`` (first column value),
            ``'features'`` - array of shape ``(n_columns - 1, 1)``.

        Raises:
            IndexError: If ``index`` is outside the range of rows.
        """
        # Use positional (iloc) rather than label-based (loc) lookup: the
        # index handed to __getitem__ is a position. This makes negative
        # indices work and raises IndexError when out of range, which is
        # what Python's sequence-iteration protocol relies on to stop.
        row = self.df.iloc[index]

        # Turn the 1-D row into a column vector of shape (n_columns, 1).
        column = row.values.reshape(-1, 1)

        return {
            'label': column[0],
            'features': column[1:],
        }

## \_\_len\_\_(self)
This method must return the size of the dataset. 

## \_\_getitem\_\_(self, index):
The index that is passed in is usually the index of the data object that is needed by the larger program.
It is better to return a list, dict, tuple, or tensor of tensors. Still need to check whether images can be returned; it should be possible to read them in on each call and parse them into tensors.

In [4]:
# Build the dataset from the CSV path defined earlier.
dataset = CustomData(file_name=csv_file)

In [5]:
# Wrap the dataset so that iterating over it yields batches of four samples.
dataloader = DataLoader(dataset, batch_size=4)

In [6]:
# Look at the first batch only: its container type, its keys, the number of
# keys, and the shape of each batched tensor.
for obj in dataloader:
    print(type(obj))
    print(obj.keys())
    print(len(obj))
    print(f"Labels : {obj['label'].shape!s}")
    print(f"Features : {obj['features'].shape!s}")
    break

<class 'dict'>
dict_keys(['label', 'features'])
2
Labels : torch.Size([4, 1])
Features : torch.Size([4, 784, 1])


## Interpreting 'output' of dataloader
1. Each obj that is returned from the dataloader is a batch of size 4.
1. The obj is a dict in this case, with the keys as given in the \_\_getitem\_\_() method.
1. Since these are batches, the dict values are tensors themselves.
1. The 1st dimension of the tensor gives the position of each data object within the batch. Example: index 2 along the 1st dimension selects the 3rd object in the given batch (indices are 0-based).