# PyTorch custom dataset class

https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

https://www.youtube.com/watch?v=oWq6aVv5mC8&list=PL98nY_tJQXZln8spB5uTZdKN08mYGkOf2&index=3

Why do we need one?

- easy to access dataset values
- easy to batch
- easy to apply transformations
- easy to read, and so on

# for tabular data

In [None]:
import torch

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs feature rows with class labels.

    Each item is returned as a dict with two tensors:
    ``"sample"`` (float32 features) and ``"target"`` (int64 label).

    Args:
        data: Indexable collection of samples with one entry per row,
            e.g. a 2-D NumPy array of shape ``(n_samples, n_features)``
            or a nested list.
        targets: Indexable collection of labels, one per sample.

    Raises:
        ValueError: If ``data`` and ``targets`` differ in length.
    """

    def __init__(self, data, targets):
        # Fail fast: a length mismatch would otherwise only show up as an
        # IndexError (or silently ignored labels) deep inside a training loop.
        if len(data) != len(targets):
            raise ValueError(
                "data and targets must have the same length, "
                f"got {len(data)} and {len(targets)}"
            )
        self.data = data
        self.targets = targets

    def __len__(self):
        """Return the number of samples."""
        # len() equals shape[0] for arrays and also works for plain lists.
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample and target at ``idx`` as a dict of tensors."""
        # Indexing the first axis selects one row (1-D for 2-D input).
        sample = self.data[idx]
        target = self.targets[idx]

        # torch.tensor copies, so callers can modify the returned tensors
        # without touching the underlying data.
        return {
            "sample": torch.tensor(sample, dtype=torch.float),
            "target": torch.tensor(target, dtype=torch.long),
        }

In [None]:
# Generate a synthetic classification dataset with scikit-learn; it will be
# wrapped in the custom PyTorch dataset class.

from sklearn.datasets import make_classification
# Unpack the generated features and labels (800 samples requested).
data, targets = make_classification(n_samples=800)
# Show both shapes as the cell output.
data.shape, targets.shape

((800, 20), (800,))

In [None]:
# Wrap the feature and label arrays in the custom dataset.
my_dataset = CustomDataset(data=data, targets=targets)

In [None]:
# Number of samples in the dataset (calls CustomDataset.__len__).

len(my_dataset)

800

In [None]:
# Fetch a single item by index (calls CustomDataset.__getitem__).

my_dataset[0]

{'sample': tensor([ 1.3503,  0.8401,  1.0510, -1.1730, -1.7619,  0.2561,  0.5544, -1.1724,
          1.4327,  0.8183,  0.4204,  0.8143, -0.9895,  0.5636,  1.0187,  1.3005,
         -0.1188, -0.5897,  0.4319, -0.6092]), 'target': tensor(0)}

In [None]:
# Tip: for faster indexing, convert the arrays to tensors once up front and index the tensors, rather than converting on every access.

In [None]:
# Print the first three items of the dataset.
for sample_idx in range(3):
    print(my_dataset[sample_idx])

{'sample': tensor([ 1.3503,  0.8401,  1.0510, -1.1730, -1.7619,  0.2561,  0.5544, -1.1724,
         1.4327,  0.8183,  0.4204,  0.8143, -0.9895,  0.5636,  1.0187,  1.3005,
        -0.1188, -0.5897,  0.4319, -0.6092]), 'target': tensor(0)}
{'sample': tensor([ 0.3977,  1.5995,  0.3835,  1.8373, -0.0302, -0.1476, -1.4509, -0.3353,
        -0.6139,  0.1969,  0.0725, -2.1643, -0.7953, -1.4174, -0.8418,  0.4276,
        -0.8497, -1.0180, -1.0387, -1.8173]), 'target': tensor(1)}
{'sample': tensor([ 0.6842, -0.7820,  0.6804, -2.0236, -0.0048, -0.6655, -0.0208,  0.6825,
         4.0973,  0.0487,  0.2305,  0.5186,  0.4740,  1.1554, -0.2019, -1.2453,
        -0.2971, -2.9534,  1.9299,  0.2426]), 'target': tensor(0)}
