# Building a Custom Data Provider
A data provider takes raw data and returns it in a format the model can use. Here, the input is the IAM dataset, which we preprocess in ```/data/IAM```.

X-rayLaser gives [instructions](https://github.com/X-rayLaser/pytorch-handwriting-synthesis-toolkit?tab=readme-ov-file#implementing-custom-data-provider) for implementing a custom data provider. This is my attempt to implement one, along with what I learned along the way.

I set out to gain a better understanding of:
- Classes in python
- Generators
- Iterators
- Yield keyword
- Bounded Iterator
- Transformations to the data
- Format of the data for the model

# Learnings
## 1. Classes in python
Classes let you create complex data structures that bundle both data and functions.
- Contain an ```__init__``` method that initializes each new instance
- ```self``` refers to the instance of the class
- Methods belong to the class and take ```self``` as their first argument
- Three primary use-cases of classes:
    - Represent real-world objects that have attributes and behaviours
    - Bundling related functions and data together
    - Want modular reusable code
- An important difference between classes and functions is that classes are "stateful": an instance keeps its state across method calls through its instance attributes. Ordinary functions are "stateless": their local variables are discarded when the call returns, so nothing carries over to the next call.
- Static methods (denoted ```@staticmethod``` above the function) are methods that belong to the class and not the instance. They can be called without creating an instance of the class and do not have access to the instance attributes. ```self``` is not required as an argument. 
- A class can inherit from another class by putting the parent class in parentheses in the class definition:
    ```
    class ChildClass(ParentClass):
        # rest of the class
    ```
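To tie these points together, here is a small hypothetical `Counter` class (the names are mine, not from the toolkit). It shows `__init__`, `self`, state kept across method calls, a static method, and a child class that calls `super().__init__`, which is the same pattern `CustomDataProvider` uses further down.

```python
class Counter:
    """A small class that keeps state between method calls."""

    def __init__(self, start=0):
        # __init__ runs when a new instance is created; self is that instance
        self.count = start

    def increment(self):
        # Instance method: reads and updates state stored on self
        self.count += 1
        return self.count

    @staticmethod
    def describe():
        # Static method: no self, so no access to instance attributes
        return "Counts upwards from a starting value"


class StepCounter(Counter):
    """Child class that inherits from Counter and changes one behaviour."""

    def __init__(self, start=0, step=2):
        # Let the parent initialise its own attributes first
        super().__init__(start)
        self.step = step

    def increment(self):
        # Override the parent's method
        self.count += self.step
        return self.count


c = Counter()
c.increment()
c.increment()
print(c.count)                 # 2: the instance remembered its state
print(Counter.describe())      # called on the class, no instance needed

s = StepCounter(start=10)
print(s.increment())           # 12
print(isinstance(s, Counter))  # True
```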

## 2. Iterators, Generators, and the Yield Keyword
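An iterator is an object that returns one element at a time when you loop over it.

It must implement the ```__iter__``` method, which returns the iterator object itself.

It must also implement the ```__next__``` method, which returns the next value and raises ```StopIteration``` when no values are left. Alternatively, a class can be made iterable by writing its ```__iter__``` method with the ```yield``` keyword; ```__iter__``` then returns a new generator (which is itself an iterator) instead of ```self```.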

The ```yield``` keyword turns a function into a generator. When execution reaches a ```yield```, the function pauses, saves its state, and hands back the yielded value. When the next value is requested, execution resumes immediately after that ```yield```.
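A quick way to see this pause-and-resume behaviour is to drive a generator by hand with the built-in `next()`. The `traced` generator below is a toy example for illustration only; it records how far its body has run.

```python
def traced(log):
    # Each append records how far the generator body has run
    log.append("started")
    yield "first"
    log.append("resumed after first")
    yield "second"
    log.append("finished")


log = []
gen = traced(log)         # Calling the function creates a generator but runs none of its body
print(log)                # []
print(next(gen))          # Runs up to the first yield: prints first
print(log)                # ['started']
print(next(gen))          # Resumes right after the first yield: prints second
print(next(gen, "done"))  # The body finishes, so next() returns the default: prints done
print(log)                # ['started', 'resumed after first', 'finished']
```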

In practice, implement ```__next__``` when you need explicit control over the iteration process, and use ```yield``` for simple iteration, since Python then keeps track of the state and raises ```StopIteration``` for you.
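Here is the same counting behaviour written both ways. `CountUp` and `CountUpIterable` are made-up names for illustration. The second style is the one `CustomIterator` uses below: its `__iter__` contains `yield`, so every `for` loop over it gets a fresh generator, while a hand-written iterator is used up after one pass.

```python
class CountUp:
    """Iterator written with __iter__ and __next__."""

    def __init__(self, start, end):
        self.current = start
        self.end = end

    def __iter__(self):
        # An iterator returns itself from __iter__
        return self

    def __next__(self):
        # Signal the end of iteration explicitly
        if self.current >= self.end:
            raise StopIteration
        value = self.current
        self.current += 1
        return value


class CountUpIterable:
    """Iterable whose __iter__ is a generator, so yield does the bookkeeping."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        # Each call returns a new generator object,
        # so the iterable can be looped over more than once
        current = self.start
        while current < self.end:
            yield current
            current += 1


it = CountUp(1, 5)
print(list(it))        # [1, 2, 3, 4]
print(list(it))        # []: this iterator is already exhausted

iterable = CountUpIterable(1, 5)
print(list(iterable))  # [1, 2, 3, 4]
print(list(iterable))  # [1, 2, 3, 4]: a new generator each time
```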

Generators are useful for large datasets that can't fit into memory. They can be used to load data in chunks and process them one at a time.

```
# Example generator
def my_generator(start, end):
    current = start
    while current < end:
        yield current
        current += 1

# Using a generator
for num in my_generator(1, 5):
    print(num)  # Output: 1 2 3 4
```
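To make the point about processing large data in chunks concrete, here is a small batching generator. `batched` is an illustrative helper, not part of the toolkit: it pulls items from any iterable one at a time and hands them on in fixed-size groups, so only the current batch is held in memory.

```python
def batched(iterable, size):
    """Yield lists of up to `size` items, holding only one batch in memory."""
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == size:
            yield batch
            batch = []
    if batch:
        # Emit the final, possibly shorter, batch
        yield batch


# range() is itself lazy, so nothing here builds the full sequence up front
for chunk in batched(range(1, 8), 3):
    print(chunk)  # [1, 2, 3], then [4, 5, 6], then [7]
```

Python 3.12 also ships `itertools.batched`, which yields tuples instead of lists; the hand-written version works on older versions too.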

# Custom Data Provider

```python
import os
import pickle

# First, an iterator object that yields the strokes and transcript of each example


class CustomIterator:
    def __init__(self, data_dir):
        # Stores the directory containing the preprocessed pickle files.
        self.data_dir = data_dir

    def __iter__(self):
        # Yields (strokes, transcript) from each file in the directory.
        # Because this method contains yield, every loop gets a fresh generator.
        for file in self.file_names(self.data_dir):
            try:
                strokes, transcript = self.get_data(file)
                yield strokes, transcript
            except (FileNotFoundError, pickle.UnpicklingError) as e:
                # Skip unreadable files instead of stopping the whole iteration.
                print(f"Error loading {file}: {e}")

    @staticmethod
    def file_names(path):
        # Generates the file names in the given directory (in arbitrary order).
        try:
            for x in os.listdir(path):
                yield x
        except FileNotFoundError as e:
            print(f"Directory not found: {e}")
            return

    def get_data(self, file):
        # Loads one example from a pickle file.
        file_path = os.path.join(self.data_dir, file)
        with open(file_path, 'rb') as f:
            content = pickle.load(f)

        transcript = content['transcript']
        strokes = content['strokes']

        return strokes, transcript
```
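`get_data` expects each pickle file to hold a dictionary with `'strokes'` and `'transcript'` keys. The sketch below writes two such files to a temporary directory and reads them back the same way `get_data` does. The file names and stroke values are made-up placeholders; the real stroke format comes from the IAM preprocessing step and is not shown here.

```python
import os
import pickle
import tempfile

# Hypothetical placeholder examples; real strokes come from the /data/IAM preprocessing
examples = {
    "a01.pkl": {"strokes": [[(0, 0), (1, 2)]], "transcript": "hi"},
    "a02.pkl": {"strokes": [[(3, 1), (4, 4), (5, 2)]], "transcript": "ok"},
}

with tempfile.TemporaryDirectory() as data_dir:
    # Write one pickle file per example
    for name, content in examples.items():
        with open(os.path.join(data_dir, name), "wb") as f:
            pickle.dump(content, f)

    # Read them back like get_data: open in binary mode, unpickle, pull out the two keys
    loaded = []
    for name in sorted(os.listdir(data_dir)):
        with open(os.path.join(data_dir, name), "rb") as f:
            content = pickle.load(f)
        loaded.append((content["strokes"], content["transcript"]))

for strokes, transcript in loaded:
    print(transcript, strokes)
```

Sorting the names makes the order reproducible. `os.listdir` returns entries in arbitrary order, so `CustomIterator` as written may visit files in a different order on different systems.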

```python
# Second, a custom data provider that uses the iterator object.
class CustomDataProvider(DataSplittingProvider):
    # The class inherits from the DataSplittingProvider class.
    # More on the DataSplittingProvider below.

    # The name of the data provider, used to select it when preparing data with the CLI.
    name = 'custom_iam'

    def __init__(self, training_data_size, validation_data_size=0, data_dir=None):
        if data_dir is None:
            raise ValueError("No data directory specified. Please provide a valid data directory.")

        training_data_size, validation_data_size = self._parse_args(training_data_size,
                                                                    validation_data_size)

        # Builds a generator over the examples in data_dir.
        iterator = self.get_generator(training_data_size, validation_data_size, data_dir)

        # Initializes the DataSplittingProvider with that generator.
        super().__init__(iterator, training_data_size, validation_data_size)

    @staticmethod
    def _parse_args(training_data_size, validation_data_size):
        # The sizes may arrive as strings (e.g. from the CLI), so cast them to int.
        return int(training_data_size), int(validation_data_size)

    def get_generator(self, training_data_size, validation_data_size, data_dir):
        # Returns a generator that yields (strokes, transcript) pairs from the IAM dataset.
        # Because this method contains yield, its body (including creating the
        # CustomIterator) only runs once iteration starts.
        db = CustomIterator(data_dir)

        # With a validation split, stop after training_data_size + validation_data_size examples.
        if validation_data_size:
            num_examples = training_data_size + validation_data_size
            it = bounded_iterator(db, num_examples)
        else:
            # Otherwise, iterate over every file in the directory.
            it = iter(db)

        for strokes, text in it:
            yield strokes, text
```

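Running the cell above as-is fails with ```NameError: name 'DataSplittingProvider' is not defined```. Neither ```DataSplittingProvider``` nor ```bounded_iterator``` is defined in this notebook; both come from the toolkit and have to be imported before the class definition runs.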
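`get_generator` calls `bounded_iterator`, which is not shown in this notebook. Assuming it simply stops after a fixed number of items, a minimal stand-in can be built on `itertools.islice`. This is a sketch under that assumption, not the toolkit's implementation, and `source` is a hypothetical stand-in for `CustomIterator`.

```python
from itertools import islice


def bounded_iterator(iterable, max_items):
    """Hypothetical stand-in: yield at most `max_items` items from `iterable`."""
    # islice stops pulling from the source once max_items have been produced,
    # so later items are never requested from the source
    yield from islice(iterable, max_items)


def source():
    # Stand-in for CustomIterator: an endless stream of (strokes, transcript) pairs
    n = 0
    while True:
        yield [n], f"text {n}"
        n += 1


training_data_size, validation_data_size = 3, 2
examples = list(bounded_iterator(source(), training_data_size + validation_data_size))
print(len(examples))  # 5
print(examples[-1])   # ([4], 'text 4')
```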