In [None]:
from torch.utils.data import Dataset, DataLoader
import numpy as np

# Synthetic example inputs: one 3-value row and one binary label per item.
n_data = 10
raw_data = np.random.rand(n_data, 3)
labels = np.random.randint(2, size=n_data)  # integers drawn from {0, 1}

class CustomDataset(Dataset):
    """Dataset whose samples are groups of consecutive items.

    ``data`` and ``labels`` are each split into consecutive groups of
    ``n_items`` elements; sample ``idx`` pairs the ``idx``-th data group
    with the ``idx``-th label group.

    Args:
        data: Sliceable sequence of inputs (e.g. a list or NumPy array).
        labels: Sliceable sequence of labels, the same length as ``data``.
        n_items: Number of consecutive items per sample. Defaults to 2.
        drop_last: If True (default), discard a trailing group that has
            fewer than ``n_items`` elements. If False, keep it; note that
            the final sample is then shorter than the others, which
            batching code that stacks samples may not accept.

    Raises:
        ValueError: If ``data`` and ``labels`` differ in length, or if
            ``n_items`` is smaller than 1.
    """

    def __init__(self, data, labels, *, n_items=2, drop_last=True):
        # Grouping both sequences independently only keeps inputs and labels
        # aligned when they start out with the same number of items.
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(data)} and {len(labels)}"
            )
        super().__init__()
        self.data = self._group_items(data, n_items=n_items, drop_last=drop_last)
        self.labels = self._group_items(labels, n_items=n_items, drop_last=drop_last)

    def __len__(self):
        """Return the number of grouped samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{"x": data_group, "y": label_group}``."""
        return {"x": self.data[idx],
                "y": self.labels[idx]}

    def _group_items(self, data, n_items, drop_last):
        """
        Groups a sequence into consecutive slices of size n_items.

        Args:
            data: The sliceable sequence to group (list, NumPy array, ...).
            n_items (int): Size of each group; must be at least 1.
            drop_last (bool): If True, discard the final group if it's smaller than n_items.
                            If False, include the final smaller group.

        Returns:
            list: The groups, in order. Each group is a slice of ``data`` and
            therefore has the same type slicing ``data`` produces.

        Raises:
            ValueError: If ``n_items`` is smaller than 1.
        """
        if n_items < 1:
            raise ValueError(f"n_items must be a positive integer, got {n_items}")

        total = len(data)
        # With drop_last, stop at the last multiple of n_items so that an
        # incomplete trailing group is never produced in the first place.
        stop = total - total % n_items if drop_last else total
        return [data[start:start + n_items] for start in range(0, stop, n_items)]

# Wrap the example arrays; each sample holds a group of consecutive items.
ds = CustomDataset(data=raw_data, labels=labels)

# Show one raw sample before any batching is applied.
print(ds[0])

dl = DataLoader(ds, batch_size=5, shuffle=True)

# Only the first batch is inspected, to show the batched tensor shapes.
for batch in dl:
    x, y = batch["x"], batch["y"]
    print(x.shape)
    print(y.shape)
    break

{'x': array([[0.10844381, 0.51910783, 0.06583172],
       [0.93830873, 0.34260713, 0.44766689]]), 'y': array([0, 0])}
torch.Size([5, 2, 3])
torch.Size([5, 2])
