In [1]:
from abc import ABC, abstractmethod

import numpy as np
from torch.utils.data import DataLoader

In [2]:
class Dataset(ABC):
    """Abstract dataset - used for both Keras and PyTorch.

    Subclasses must implement ``__getitem__`` and ``__len__``; iteration and
    the ``on_epoch_end`` hook are provided here. Because the class derives
    from ``ABC``, instantiating it directly (or instantiating a subclass that
    leaves an abstract method unimplemented) raises ``TypeError`` up front
    instead of failing later with ``NotImplementedError``.
    """

    @abstractmethod
    def __getitem__(self, idx):
        """Get the item at position ``idx``.

        Parameters
        ----------
            idx: index position of the item in the data.

        Returns
        -------
            The item stored at ``idx``. Whether an item is a single sample
            or a whole batch is decided by the subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        """Length of the dataset.

        Returns
        -------
            The number of items addressable through ``__getitem__``;
            ``__iter__`` relies on this value to know when to stop.
        """
        raise NotImplementedError

    def on_epoch_end(self):
        """Keras method called at the end of every epoch; a no-op by default."""

    def __iter__(self):
        """Create a generator yielding ``self[0]`` ... ``self[len(self) - 1]`` in order."""
        for position in range(len(self)):
            yield self[position]

In [3]:
class RandomData(Dataset):
    """Toy dataset of random scalar values paired with random integer labels.

    Parameters
    ----------
        num_samples: number of (value, label) pairs to generate.
        num_classes: exclusive upper bound for the integer labels, which are
            drawn from ``[0, num_classes)``.
        seed: optional seed for reproducible data. When ``None`` (the
            default) the global NumPy random state is used.
    """

    def __init__(self, num_samples: int, num_classes: int, seed=None):
        # A private RandomState makes seeded datasets reproducible without
        # reseeding (and thereby disturbing) NumPy's global generator.
        rng = np.random if seed is None else np.random.RandomState(seed)
        self.data = rng.randn(num_samples)
        self.label = rng.randint(num_classes, size=num_samples)

    def __len__(self):
        """Return the number of generated samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(value, label)`` pair stored at ``idx``."""
        return self.data[idx], self.label[idx]

In [12]:
# Build a small toy dataset: 10 random samples with integer labels in [0, 10).
data = RandomData(num_samples=10, num_classes=10)

In [13]:
# Using the PyTorch DataLoader: wrap the dataset with batch_size=1 and
# display the first batch it produces.
dataloader = DataLoader(data, batch_size=1)
next(iter(dataloader))

[tensor([1.1348], dtype=torch.float64), tensor([5])]

In [14]:
# Using the Dataset generator: iterate the dataset directly (via its own
# __iter__) and display the first (data, label) item.
next(iter(data))

(1.134762453800256, 5)