## Video 16 - PyTorch Datasets and DataLoaders - Training Set Exploration for Deep Learning and AI

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

In [2]:
# Fashion-MNIST training split, stored under ./data/FashionMNIST and
# downloaded on first use. ToTensor turns every image into a float tensor
# (the sample printed below shows pixel values in the 0.0-1.0 range).
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

In [3]:
# Look at one raw sample (index 3) using ordinary subscript syntax,
# which dispatches to the dataset's __getitem__.
train_set[3]

(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.1294, 0.3765, 0.6863, 0.6118, 0.2510, 0.0549, 0.2118, 0.5373,
           0.8000, 0.7608, 0.4000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2863, 0.7294,
           0.6941, 0.7176, 0.6863, 0.7373, 0.9098, 1.0000, 0.8745, 0.8588,
           0.7608, 0.7020, 0.7294, 0.8353, 0.5725, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.6392, 0.5490,
           0.5882, 0.5961, 0.5882, 0.5725, 0.6863, 0.6863, 0.6784, 0.6706,
           0.6118, 0.5961, 0.5804, 0.5059, 0.6118, 0.5490, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5882, 0.5569, 0.5490,
           0.5961, 0.6275, 0.6118, 0.5725, 0.5569, 0.4980, 0.5294, 0.5216,
           0.5490, 0.5490, 0.5373, 0.5216, 

In [4]:
# Each dataset item comes back as a tuple (presumably image tensor + label;
# the truncated output above only shows the image part).
type(train_set[0])

tuple

In [5]:
# Wrap the dataset in a DataLoader that serves it in batches of 10 samples.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=10)

In [6]:
# Confirm which class the loader object is an instance of.
type(train_loader)

torch.utils.data.dataloader.DataLoader

### Make our own dataset 

In [62]:
from torch.utils.data import Dataset, TensorDataset

class DatasetAttempt(Dataset):
    """Minimal map-style dataset pairing each sample with its label.

    Both inputs are converted to tensors once, in ``__init__``, so indexing
    returns tensor elements rather than plain Python values.
    """

    def __init__(self, samples, labels):
        """Store ``samples`` and ``labels`` as tensors.

        Args:
            samples: Anything ``torch.tensor`` accepts (e.g. a list of numbers).
            labels: Anything ``torch.tensor`` accepts (e.g. a list or a NumPy
                array of integer-encoded classes), one entry per sample.

        Raises:
            ValueError: If ``samples`` and ``labels`` differ in length.
        """
        self.X = torch.tensor(samples)
        self.y = torch.tensor(labels)
        # Fail fast on mismatched inputs: __len__ only looks at X, so shorter
        # labels would raise IndexError mid-iteration and longer labels would
        # be silently ignored.
        if len(self.X) != len(self.y):
            raise ValueError(
                f"samples and labels must have the same length, "
                f"got {len(self.X)} and {len(self.y)}"
            )

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        return (self.X[index], self.y[index])

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.X)

In [63]:
# Toy data: nine integer samples, each labelled with its spelled-out name.
data_samples = [2,3,4,3,4,2,1,5,7]
data_labels = ['two', 'three', 'four', 'three', 'four', 'two', 'one', 'five', 'seven']

In [64]:
from sklearn import preprocessing
import torch

# Encode the string labels as integer class ids so they can go into a tensor.
# Gotcha: the ids follow the alphabetical order of the label strings
# ('five'->0, 'four'->1, 'one'->2, 'seven'->3, 'three'->4, 'two'->5), so they
# are NOT the numeric values the words spell out.
le = preprocessing.LabelEncoder()
targets = le.fit_transform(data_labels)

In [65]:
# Sanity check: encoded values, their container type, and that the encoding
# kept one target per original label.
print(targets, type(targets), len(targets), len(data_labels))

[5 4 1 4 1 5 2 0 3] <class 'numpy.ndarray'> 9 9


In [66]:
# Build the custom dataset from the raw samples and the integer-encoded labels.
t_data = DatasetAttempt(data_samples, targets)

In [67]:
# Explicit dunder call: index 7 holds sample 5, whose label 'five' is encoded as 0.
t_data.__getitem__(7)

(tensor(5), tensor(0))

In [68]:
# Index 5 holds sample 2, whose label 'two' is encoded as 5.
t_data.__getitem__(5)

(tensor(2), tensor(5))

In [70]:
# Last element (index 8): sample 7, whose label 'seven' is encoded as 3.
t_data.__getitem__(8)

(tensor(7), tensor(3))

In [71]:
# Or, equivalently, use subscript syntax, which goes through __getitem__.
t_data[2]

(tensor(4), tensor(1))

In [72]:
# len() and the explicit __len__ call report the same dataset size.
print(len(t_data)) 
t_data.__len__()

9


9

In [80]:
# Inspect the full sample and label tensors stored on the dataset.
print(t_data.X, '\n', t_data.y)

tensor([2, 3, 4, 3, 4, 2, 1, 5, 7]) 
 tensor([5, 4, 1, 4, 1, 5, 2, 0, 3])


In [85]:
# bincount tallies how often each integer label occurs; position i of the
# result is the number of samples with encoded class i (class balance check).
t_data.y.bincount()

tensor([1, 2, 1, 1, 2, 2])

#### Alternatively, we could have just used the "TensorDataset" class: 

In [91]:
# Built-in alternative to the hand-written class: wrap the sample tensor and
# the label tensor directly in a TensorDataset.
t_data_from_tensor = TensorDataset(torch.tensor(data_samples), torch.tensor(targets))

### DataLoader

In [110]:
# Serve the 9-item toy dataset in batches of 5 samples.
t_loader = torch.utils.data.DataLoader(t_data, batch_size=5)

In [111]:
# Dump the loader's instance attributes to see its configuration
# (batch size, sampler, collate function, ...).
vars(t_loader)

{'dataset': <__main__.DatasetAttempt at 0x1a35718a50>,
 'num_workers': 0,
 'pin_memory': False,
 'timeout': 0,
 'worker_init_fn': None,
 '_DataLoader__multiprocessing_context': None,
 '_dataset_kind': 0,
 'batch_size': 5,
 'drop_last': False,
 'sampler': <torch.utils.data.sampler.SequentialSampler at 0x1a362dc750>,
 'batch_sampler': <torch.utils.data.sampler.BatchSampler at 0x1a362dc850>,
 'generator': None,
 'collate_fn': <function torch.utils.data._utils.collate.default_collate(batch)>,
 '_DataLoader__initialized': True,
 '_IterableDataset_len_called': None}

In [112]:
# Length of a loader is the number of batches: 9 samples at batch_size=5 with
# drop_last=False gives 2 (a full batch plus a partial one).
len(t_loader)

2

In [113]:
# Create a fresh iterator over the loader and pull only its first batch.
batch = next(iter(t_loader))

In [114]:
# The collated batch is a list, even though each dataset item is a tuple.
type(batch)

list

In [118]:
# Two tensors: the first stacks the batch's samples, the second its labels.
batch

[tensor([2, 3, 4, 3, 4]), tensor([5, 4, 1, 4, 1])]

More conveniently, use unpacking: 

In [119]:
# Split the two-element batch into its sample tensor and its label tensor.
samples, labels = batch

In [120]:
# Show the unpacked sample and label tensors on separate lines.
print(samples, '\n', labels)

tensor([2, 3, 4, 3, 4]) 
 tensor([5, 4, 1, 4, 1])
