diff --git a/src/guidellm/data/deserializers/__init__.py b/src/guidellm/data/deserializers/__init__.py
index 1062f2b75..e24bfd5d7 100644
--- a/src/guidellm/data/deserializers/__init__.py
+++ b/src/guidellm/data/deserializers/__init__.py
@@ -22,9 +22,9 @@
     InMemoryJsonStrDatasetDeserializer,
 )
 from .synthetic import (
+    SyntheticTextDataset,
     SyntheticTextDatasetConfig,
     SyntheticTextDatasetDeserializer,
-    SyntheticTextGenerator,
     SyntheticTextPrefixBucketConfig,
 )
 
@@ -44,9 +44,9 @@
     "InMemoryJsonStrDatasetDeserializer",
     "JSONFileDatasetDeserializer",
     "ParquetFileDatasetDeserializer",
+    "SyntheticTextDataset",
    "SyntheticTextDatasetConfig",
     "SyntheticTextDatasetDeserializer",
-    "SyntheticTextGenerator",
     "SyntheticTextPrefixBucketConfig",
     "TarFileDatasetDeserializer",
     "TextFileDatasetDeserializer",
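
For downstream code, the visible effect of this rename is the import surface: `SyntheticTextGenerator` is gone and `SyntheticTextDataset` replaces it. A minimal sketch of the updated import, mirroring the `__all__` list above (illustrative only):

```python
# Illustrative: the renamed export is imported alongside its config classes.
from guidellm.data.deserializers import (
    SyntheticTextDataset,
    SyntheticTextDatasetConfig,
    SyntheticTextPrefixBucketConfig,
)
```
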
diff --git a/src/guidellm/data/deserializers/synthetic.py b/src/guidellm/data/deserializers/synthetic.py
index 6e0984620..c8ec2831c 100644
--- a/src/guidellm/data/deserializers/synthetic.py
+++ b/src/guidellm/data/deserializers/synthetic.py
@@ -6,8 +6,10 @@
 from random import Random
 from typing import Any
 
+import numpy as np
 import yaml
-from datasets import Features, IterableDataset, Value
+from datasets import DatasetInfo, Features, IterableDataset, Value
+from datasets.iterable_dataset import _BaseExamplesIterable
 from faker import Faker
 from pydantic import ConfigDict, Field, ValidationError, model_validator
 from transformers import PreTrainedTokenizerBase
@@ -21,9 +23,9 @@
 from guidellm.utils import IntegerRangeSampler
 
 __all__ = [
+    "SyntheticTextDataset",
     "SyntheticTextDatasetConfig",
     "SyntheticTextDatasetDeserializer",
-    "SyntheticTextGenerator",
     "SyntheticTextPrefixBucketConfig",
 ]
@@ -121,29 +123,34 @@ def check_prefix_options(self) -> SyntheticTextDatasetConfig:
         return self
 
 
-class SyntheticTextGenerator:
+class _SyntheticTextExamplesIterable(_BaseExamplesIterable):
+    """Custom examples iterable for synthetic text generation."""
+
     def __init__(
         self,
         config: SyntheticTextDatasetConfig,
         processor: PreTrainedTokenizerBase,
-        random_seed: int = 42,
+        random_seed: int,
     ):
+        super().__init__()
         self.config = config
         self.processor = processor
         self.random_seed = random_seed
+        self.iteration_count = 0
 
-    def __iter__(self) -> Iterator[dict[str, Any]]:
-        samples_generated = 0
+    def __iter__(self) -> Iterator[tuple[int, dict[str, Any]]]:
+        iter_random_seed = self.random_seed + self.iteration_count
+        self.iteration_count += 1
         faker = Faker()
-        faker.seed_instance(self.random_seed)
+        faker.seed_instance(iter_random_seed)
         prompt_tokens_sampler = iter(
             IntegerRangeSampler(
                 average=self.config.prompt_tokens,
                 variance=self.config.prompt_tokens_stdev,
                 min_value=self.config.prompt_tokens_min,
                 max_value=self.config.prompt_tokens_max,
-                random_seed=self.random_seed,
+                random_seed=iter_random_seed,
             )
         )
         output_tokens_sampler = iter(
@@ -152,27 +159,77 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
                 variance=self.config.output_tokens_stdev,
                 min_value=self.config.output_tokens_min,
                 max_value=self.config.output_tokens_max,
-                random_seed=self.random_seed + 1,  # ensure diff dist from prompts
+                random_seed=iter_random_seed + 1,  # ensure diff dist from prompts
             )
         )
 
         # Create a shared prefix if specified
-        rand = Random(self.random_seed + 3)
+        rand = Random(iter_random_seed + 3)
         prefix_iter = self._create_prefix_iter(faker, rand)
 
+        samples_count = 0
         while True:
             prompt_tokens_count = next(prompt_tokens_sampler)
             output_tokens_count = next(output_tokens_sampler)
-            yield {
-                "prefix": next(prefix_iter),
-                "prompt": self._create_prompt(
-                    prompt_tokens_count, faker, f"{samples_generated} "
-                ),
-                "prompt_tokens_count": prompt_tokens_count,
-                "output_tokens_count": output_tokens_count,
+            yield (
+                samples_count,
+                {
+                    "prefix": next(prefix_iter),
+                    "prompt": self._create_prompt(
+                        prompt_tokens_count,
+                        faker,
+                        f"{self.iteration_count} {samples_count} ",
+                    ),
+                    "prompt_tokens_count": prompt_tokens_count,
+                    "output_tokens_count": output_tokens_count,
+                },
+            )
+            samples_count += 1
+
+    @property
+    def is_typed(self) -> bool:
+        return True
+
+    @property
+    def features(self) -> Features:
+        return Features(
+            {
+                "prefix": Value("string"),
+                "prompt": Value("string"),
+                "prompt_tokens_count": Value("int32"),
+                "output_tokens_count": Value("int32"),
             }
-            samples_generated += 1
+        )
+
+    @property
+    def num_shards(self) -> int:
+        return 1
+
+    def shuffle_data_sources(
+        self,
+        generator: np.random.Generator,  # noqa: ARG002
+    ) -> _SyntheticTextExamplesIterable:
+        """Return self since synthetic data doesn't have fixed sources to shuffle."""
+        return self
+
+    def shard_data_sources(
+        self,
+        num_shards: int,  # noqa: ARG002
+        index: int,  # noqa: ARG002
+        contiguous: bool = True,  # noqa: ARG002
+    ) -> _SyntheticTextExamplesIterable:
+        """Return self since synthetic data generation is infinite and stateless."""
+        return self
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        """Load the state from a state dict."""
+        self.iteration_count = state_dict.get("iteration_count", 0)
+
+    def _init_state_dict(self) -> dict:
+        """Initialize the state dict for the iterable."""
+        self._state_dict = {"iteration_count": self.iteration_count}
+        return self._state_dict
 
     def _create_prompt(
         self, prompt_tokens_count: int, faker: Faker, unique: str = ""
@@ -226,6 +283,39 @@ def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
                 yield rand.choice(prefixes)
 
 
+class SyntheticTextDataset(IterableDataset):
+    def __init__(
+        self,
+        config: SyntheticTextDatasetConfig,
+        processor: PreTrainedTokenizerBase,
+        random_seed: int = 42,
+    ):
+        self.config = config
+        self.processor = processor
+        self.random_seed = random_seed
+
+        # Create the examples iterable
+        ex_iterable = _SyntheticTextExamplesIterable(
+            config=config,
+            processor=processor,
+            random_seed=random_seed,
+        )
+
+        # Initialize parent with proper ex_iterable
+        super().__init__(
+            ex_iterable=ex_iterable,
+            info=DatasetInfo(
+                description="Synthetic text dataset generator",
+                features=ex_iterable.features,
+            ),
+        )
+
+    def set_epoch(self, epoch: int):
+        """Set the epoch for the dataset iteration."""
+        if isinstance(self._ex_iterable, _SyntheticTextExamplesIterable):
+            self._ex_iterable.iteration_count = epoch
+
+
 @DatasetDeserializerFactory.register("synthetic_text")
 class SyntheticTextDatasetDeserializer(DatasetDeserializer):
     def __call__(
@@ -254,21 +344,10 @@ def __call__(
             f"got {data}"
         )
 
-        return IterableDataset.from_generator(
-            SyntheticTextGenerator,
-            gen_kwargs={
-                "config": data,
-                "processor": processor_factory(),
-                "random_seed": random_seed,
-            },
-            features=Features(
-                {
-                    "prefix": Value("string"),
-                    "prompt": Value("string"),
-                    "prompt_tokens_count": Value("int32"),
-                    "output_tokens_count": Value("int32"),
-                }
-            ),
+        return SyntheticTextDataset(
+            config=data,
+            processor=processor_factory(),
+            random_seed=random_seed,
        )
 
     def _load_config_dict(self, data: Any) -> SyntheticTextDatasetConfig | None:
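
Since `SyntheticTextDataset` now subclasses `IterableDataset` directly, it can be consumed like any other Hugging Face iterable dataset. A minimal sketch under assumed inputs; the `"gpt2"` tokenizer and the token counts are placeholder choices, not defaults from this diff:

```python
# Sketch only: tokenizer and config values are illustrative assumptions.
from itertools import islice

from transformers import AutoTokenizer

from guidellm.data.deserializers import (
    SyntheticTextDataset,
    SyntheticTextDatasetConfig,
)

config = SyntheticTextDatasetConfig(prompt_tokens=32, output_tokens=16)
dataset = SyntheticTextDataset(
    config=config,
    processor=AutoTokenizer.from_pretrained("gpt2"),
    random_seed=42,
)

# Each pass derives its seed from random_seed + iteration_count, so a
# second iteration over the dataset yields different prompts than the first.
for sample in islice(dataset, 3):
    print(sample["prompt_tokens_count"], sample["output_tokens_count"])
```
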
diff --git a/src/guidellm/data/loaders.py b/src/guidellm/data/loaders.py
index 5a7c9d4d6..fbeaf0583 100644
--- a/src/guidellm/data/loaders.py
+++ b/src/guidellm/data/loaders.py
@@ -63,6 +63,7 @@ def __init__(
         self.precache: list[Any] | None = (
             list(self.generator(data_samples)) if data_samples else None
         )
+        self.epoch = 0
 
     def __iter__(self) -> Iterator[DataT]:
         worker_info = torch.utils.data.get_worker_info()
@@ -74,18 +75,29 @@ def __iter__(self) -> Iterator[DataT]:
                 if (index + worker_index) % worker_modulus == 0:
                     yield item
         else:
-            yield from self.generator(modulus=worker_modulus, offset=worker_index)
+            yield from self.generator(
+                modulus=worker_modulus, offset=worker_index, epoch=self.epoch
+            )
+
+    def set_epoch(self, epoch: int):
+        self.epoch = epoch
 
     def generator(
         self,
         max_items: int | None = None,
         modulus: int | None = None,
         offset: int | None = None,
+        epoch: int = 0,
     ) -> Iterator[DataT]:
         gen_count = 0
 
         with contextlib.suppress(StopIteration):
-            dataset_iters = [iter(dataset) for dataset in self.datasets]
+            dataset_iters = []
+            for dataset in self.datasets:
+                if hasattr(dataset, "set_epoch"):
+                    with contextlib.suppress(Exception):
+                        dataset.set_epoch(epoch)
+                dataset_iters.append(iter(dataset))
 
             while max_items is None or gen_count < max_items:
                 try:
@@ -152,6 +164,7 @@ def __init__(
             "num_workers": num_workers,
             "random_seed": random_seed,
         }
+        self.epoch = 0
 
         super().__init__(
             dataset=iterator,
@@ -163,6 +176,13 @@ def __init__(
             **kwargs,
         )
 
+    def __iter__(self):
+        if isinstance(self.dataset, DatasetsIterator):
+            self.dataset.set_epoch(self.epoch)
+            self.epoch += 1
+
+        return super().__iter__()
+
     @property
     def info(self) -> dict[str, Any]:
         return self._info
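
The loader changes introduce a small `set_epoch` handshake: `DataLoader.__iter__` bumps its own counter and pushes it into `DatasetsIterator`, which forwards it to any dataset exposing `set_epoch` before building iterators. A self-contained sketch of that protocol; `_EpochAwareSource` is a stand-in invented for illustration:

```python
# Stand-in source invented for illustration; in this PR the real
# implementation of set_epoch lives on SyntheticTextDataset.
class _EpochAwareSource:
    def __init__(self) -> None:
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __iter__(self):
        # Values depend on the current epoch, so each pass differs.
        yield from (self.epoch * 100 + i for i in range(3))


source = _EpochAwareSource()
for epoch in range(2):
    # Mirrors what DatasetsIterator.generator() now does before iterating.
    if hasattr(source, "set_epoch"):
        source.set_epoch(epoch)
    print(list(source))  # epoch 0 -> [0, 1, 2]; epoch 1 -> [100, 101, 102]
```
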
diff --git a/tests/unit/data/deserializers/test_synthetic.py b/tests/unit/data/deserializers/test_synthetic.py
index de95227a9..468c4c8e9 100644
--- a/tests/unit/data/deserializers/test_synthetic.py
+++ b/tests/unit/data/deserializers/test_synthetic.py
@@ -13,9 +13,9 @@
 from guidellm.data.deserializers.deserializer import DataNotSupportedError
 from guidellm.data.deserializers.synthetic import (
+    SyntheticTextDataset,
     SyntheticTextDatasetConfig,
     SyntheticTextDatasetDeserializer,
-    SyntheticTextGenerator,
     SyntheticTextPrefixBucketConfig,
 )
 
@@ -264,9 +264,7 @@ def test_generator_initialization(self, simple_config, mock_tokenizer):
         ### WRITTEN BY AI ###
         """
-        generator = SyntheticTextGenerator(
-            simple_config, mock_tokenizer, random_seed=42
-        )
+        generator = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)
 
         assert generator.config == simple_config
         assert generator.processor == mock_tokenizer
@@ -278,9 +276,7 @@ def test_basic_iteration(self, simple_config, mock_tokenizer):
         ### WRITTEN BY AI ###
         """
-        generator = SyntheticTextGenerator(
-            simple_config, mock_tokenizer, random_seed=42
-        )
+        generator = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)
 
         items = []
         for i, item in enumerate(generator):
@@ -310,20 +306,21 @@ def test_create_prompt_method(self, simple_config, mock_tokenizer):
         """
         from faker import Faker
 
-        generator = SyntheticTextGenerator(
-            simple_config, mock_tokenizer, random_seed=42
-        )
+        generator = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)
         faker = Faker()
         faker.seed_instance(42)
 
+        # Access the _create_prompt method through the examples iterable
+        ex_iterable = generator._ex_iterable
+
         # Test normal case
-        result = generator._create_prompt(5, faker, "unique_prefix ")
+        result = ex_iterable._create_prompt(5, faker, "unique_prefix ")
         assert isinstance(result, str)
         # The result should be the decoded tokens (token_0 token_1 etc.) due to our mock
         assert "token_" in result
 
         # Test zero tokens
-        result = generator._create_prompt(0, faker)
+        result = ex_iterable._create_prompt(0, faker)
         assert result == ""
 
     @pytest.mark.regression
@@ -332,7 +329,7 @@ def test_prefix_tokens_integration(self, config_with_prefix, mock_tokenizer):
         ### WRITTEN BY AI ###
         """
-        generator = SyntheticTextGenerator(
+        generator = SyntheticTextDataset(
             config_with_prefix, mock_tokenizer, random_seed=42
         )
 
@@ -353,12 +350,8 @@ def test_random_seeding_consistency(self, simple_config, mock_tokenizer):
         ### WRITTEN BY AI ###
         """
         # Create two generators with same seed
-        generator1 = SyntheticTextGenerator(
-            simple_config, mock_tokenizer, random_seed=42
-        )
-        generator2 = SyntheticTextGenerator(
-            simple_config, mock_tokenizer, random_seed=42
-        )
+        generator1 = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)
+        generator2 = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)
 
         items1 = []
         items2 = []
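
Because `_SyntheticTextExamplesIterable` implements `load_state_dict` and `_init_state_dict`, the dataset should also interoperate with the checkpointing hooks that recent `datasets` releases expose on `IterableDataset`. A rough, untested sketch of that round trip, reusing the placeholder `config` and tokenizer from the earlier example:

```python
# Rough sketch: assumes IterableDataset.state_dict()/load_state_dict()
# from recent `datasets` releases; only iteration_count is tracked here.
dataset = SyntheticTextDataset(
    config=config,
    processor=AutoTokenizer.from_pretrained("gpt2"),  # placeholder tokenizer
    random_seed=42,
)

_ = next(iter(dataset))            # a pass bumps iteration_count
checkpoint = dataset.state_dict()  # nested dict containing iteration_count

restored = SyntheticTextDataset(
    config=config,
    processor=AutoTokenizer.from_pretrained("gpt2"),
    random_seed=42,
)
restored.load_state_dict(checkpoint)  # next pass reuses the same epoch seed
```
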