# A.6 Setting up efficient data loaders (using PyTorch)

In [3]:
# List the installed packages whose requirement line contains "torch", with their versions.
pip freeze | grep torch

torch==2.3.1
Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch

In [4]:
# Toy training split: five 2-D feature rows; the first three carry label 0,
# the last two carry label 1.
train_rows = [
    [-1.2, 3.1],
    [-0.9, 2.9],
    [-0.5, 2.6],
    [2.3, -1.1],
    [2.7, -1.5],
]
train_targets = [0, 0, 0, 1, 1]
X_train = torch.tensor(train_rows)
y_train = torch.tensor(train_targets)

# Toy test split: one row per label.
test_rows = [[-0.8, 2.8], [2.6, -1.6]]
test_targets = [0, 1]
X_test = torch.tensor(test_rows)
y_test = torch.tensor(test_targets)

In [5]:
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    """Map-style dataset that pairs each feature row with its label.

    Args:
        x: Indexable collection of features; ``x[i]`` is the i-th example.
        y: Indexable collection of labels exposing ``shape`` (e.g. a 1-D
            tensor); ``y[i]`` is the label that belongs to ``x[i]``.

    Raises:
        ValueError: If ``x`` and ``y`` do not hold the same number of examples.
    """

    def __init__(self, x, y):
        # Fail fast on misaligned inputs. Without this check, surplus feature
        # rows would be silently ignored (the length comes from the labels
        # only) and surplus labels would surface later as an IndexError in the
        # middle of iteration.
        if len(x) != len(y):
            raise ValueError(
                f"features and labels must have the same length, "
                f"got {len(x)} and {len(y)}"
            )
        self.features = x
        self.labels = y

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        one_x = self.features[index]
        one_y = self.labels[index]
        return one_x, one_y

    def __len__(self):
        """Return the number of examples, i.e. the size of the labels' first dimension."""
        return self.labels.shape[0]
    
# Wrap each split's features and labels in its own dataset object.
train_ds = ToyDataset(x=X_train, y=y_train)
test_ds = ToyDataset(x=X_test, y=y_test)

In [6]:
# Number of training examples, as reported by the dataset's __len__.
n_train = len(train_ds)
print(n_train)

5


In [10]:
from torch.utils.data import DataLoader

# Seed the global RNG so the shuffled batch order is repeatable.
torch.manual_seed(123)

# Training loader: batches of 3, shuffled; num_workers=0 loads data in the
# main process.
train_loader = DataLoader(train_ds, batch_size=3, shuffle=True, num_workers=0)

# Test loader: both test examples in a single batch, original order kept.
test_loader = DataLoader(test_ds, batch_size=2, shuffle=False, num_workers=0)

In [11]:
# Walk the training loader once and show the features and labels of each batch.
for idx, batch in enumerate(train_loader):
    x, y = batch
    print(f"Batch {idx+1}:", x, y)

Batch 1: tensor([[ 2.3000, -1.1000],
        [-0.9000,  2.9000],
        [-1.2000,  3.1000]]) tensor([1, 0, 0])
Batch 2: tensor([[-0.5000,  2.6000],
        [ 2.7000, -1.5000]]) tensor([0, 1])


In [12]:
# Walk the test loader once and show the features and labels of each batch.
for idx, batch in enumerate(test_loader):
    x, y = batch
    print(f"Batch {idx+1}:", x, y)

Batch 1: tensor([[-0.8000,  2.8000],
        [ 2.6000, -1.6000]]) tensor([0, 1])


In [13]:
# Rebuild the training loader with batches of 2. drop_last=True discards the
# final incomplete batch, so every batch that is yielded has the same size.
train_loader = DataLoader(
    train_ds, batch_size=2, shuffle=True, drop_last=True, num_workers=0
)

In [14]:
# Iterate the training loader again; with five samples, batch_size=2 and
# drop_last=True only two full batches are produced.
for idx, batch in enumerate(train_loader):
    x, y = batch
    print(f"Batch {idx+1}:", x, y)

Batch 1: tensor([[-0.5000,  2.6000],
        [-0.9000,  2.9000]]) tensor([0, 0])
Batch 2: tensor([[-1.2000,  3.1000],
        [ 2.3000, -1.1000]]) tensor([0, 1])
