Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/guidellm/data/deserializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
InMemoryJsonStrDatasetDeserializer,
)
from .synthetic import (
SyntheticTextDataset,
SyntheticTextDatasetConfig,
SyntheticTextDatasetDeserializer,
SyntheticTextGenerator,
SyntheticTextPrefixBucketConfig,
)

Expand All @@ -44,9 +44,9 @@
"InMemoryJsonStrDatasetDeserializer",
"JSONFileDatasetDeserializer",
"ParquetFileDatasetDeserializer",
"SyntheticTextDataset",
"SyntheticTextDatasetConfig",
"SyntheticTextDatasetDeserializer",
"SyntheticTextGenerator",
"SyntheticTextPrefixBucketConfig",
"TarFileDatasetDeserializer",
"TextFileDatasetDeserializer",
Expand Down
145 changes: 112 additions & 33 deletions src/guidellm/data/deserializers/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from random import Random
from typing import Any

import numpy as np
import yaml
from datasets import Features, IterableDataset, Value
from datasets import DatasetInfo, Features, IterableDataset, Value
from datasets.iterable_dataset import _BaseExamplesIterable
from faker import Faker
from pydantic import ConfigDict, Field, ValidationError, model_validator
from transformers import PreTrainedTokenizerBase
Expand All @@ -21,9 +23,9 @@
from guidellm.utils import IntegerRangeSampler

__all__ = [
"SyntheticTextDataset",
"SyntheticTextDatasetConfig",
"SyntheticTextDatasetDeserializer",
"SyntheticTextGenerator",
"SyntheticTextPrefixBucketConfig",
]

Expand Down Expand Up @@ -121,29 +123,34 @@ def check_prefix_options(self) -> SyntheticTextDatasetConfig:
return self


class SyntheticTextGenerator:
class _SyntheticTextExamplesIterable(_BaseExamplesIterable):
"""Custom examples iterable for synthetic text generation."""

def __init__(
    self,
    config: SyntheticTextDatasetConfig,
    processor: PreTrainedTokenizerBase,
    random_seed: int,
):
    """Initialize the synthetic examples iterable.

    :param config: Generation parameters (token-count distributions, prefixes).
    :param processor: Tokenizer used when building prompts.
    :param random_seed: Base seed; each pass adds the current iteration count
        to it so successive epochs draw from distinct, reproducible streams.
    """
    super().__init__()
    # Incremented once per __iter__ call; acts as the epoch counter that
    # perturbs the base seed for every fresh iteration over the data.
    self.iteration_count = 0
    self.config = config
    self.processor = processor
    self.random_seed = random_seed

def __iter__(self) -> Iterator[dict[str, Any]]:
samples_generated = 0
def __iter__(self) -> Iterator[tuple[int, dict[str, Any]]]:
iter_random_seed = self.random_seed + self.iteration_count
self.iteration_count += 1

faker = Faker()
faker.seed_instance(self.random_seed)
faker.seed_instance(iter_random_seed)
prompt_tokens_sampler = iter(
IntegerRangeSampler(
average=self.config.prompt_tokens,
variance=self.config.prompt_tokens_stdev,
min_value=self.config.prompt_tokens_min,
max_value=self.config.prompt_tokens_max,
random_seed=self.random_seed,
random_seed=iter_random_seed,
)
)
output_tokens_sampler = iter(
Expand All @@ -152,27 +159,77 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
variance=self.config.output_tokens_stdev,
min_value=self.config.output_tokens_min,
max_value=self.config.output_tokens_max,
random_seed=self.random_seed + 1, # ensure diff dist from prompts
random_seed=iter_random_seed + 1, # ensure diff dist from prompts
)
)

# Create a shared prefix if specified
rand = Random(self.random_seed + 3)
rand = Random(iter_random_seed + 3)
prefix_iter = self._create_prefix_iter(faker, rand)
samples_count = 0

while True:
prompt_tokens_count = next(prompt_tokens_sampler)
output_tokens_count = next(output_tokens_sampler)

yield {
"prefix": next(prefix_iter),
"prompt": self._create_prompt(
prompt_tokens_count, faker, f"{samples_generated} "
),
"prompt_tokens_count": prompt_tokens_count,
"output_tokens_count": output_tokens_count,
yield (
samples_count,
{
"prefix": next(prefix_iter),
"prompt": self._create_prompt(
prompt_tokens_count,
faker,
f"{self.iteration_count} {samples_count} ",
),
"prompt_tokens_count": prompt_tokens_count,
"output_tokens_count": output_tokens_count,
},
)
samples_count += 1

@property
def is_typed(self) -> bool:
    # Examples carry an explicit schema — see the `features` property —
    # so downstream consumers need not infer types from generated rows.
    return True

@property
def features(self) -> Features:
return Features(
{
"prefix": Value("string"),
"prompt": Value("string"),
"prompt_tokens_count": Value("int32"),
"output_tokens_count": Value("int32"),
}
samples_generated += 1
)

@property
def num_shards(self) -> int:
    # One logical shard: the data is generated on the fly, so there are
    # no underlying files to split across workers.
    return 1

def shuffle_data_sources(
    self,
    generator: np.random.Generator,  # noqa: ARG002
) -> _SyntheticTextExamplesIterable:
    """Return self since synthetic data doesn't have fixed sources to shuffle.

    :param generator: Unused; accepted to satisfy the
        ``_BaseExamplesIterable`` interface.
    """
    return self

def shard_data_sources(
    self,
    num_shards: int,  # noqa: ARG002
    index: int,  # noqa: ARG002
    contiguous: bool = True,  # noqa: ARG002
) -> _SyntheticTextExamplesIterable:
    """Return self since synthetic data generation is infinite and stateless.

    All parameters are unused; they are accepted to satisfy the
    ``_BaseExamplesIterable`` interface.
    """
    return self

def load_state_dict(self, state_dict: dict) -> None:
    """Load the state from a state dict.

    Restores only ``iteration_count``; defaults to 0 when the key is
    absent from the given state dict.
    """
    self.iteration_count = state_dict.get("iteration_count", 0)

def _init_state_dict(self) -> dict:
    """Initialize the state dict for the iterable."""
    # NOTE(review): this snapshots iteration_count at call time; later
    # changes to self.iteration_count are not reflected in the returned
    # dict (int values are copied, not shared) — confirm this matches
    # what the base class expects from _init_state_dict.
    self._state_dict = {"iteration_count": self.iteration_count}
    return self._state_dict

def _create_prompt(
self, prompt_tokens_count: int, faker: Faker, unique: str = ""
Expand Down Expand Up @@ -226,6 +283,39 @@ def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
yield rand.choice(prefixes)


class SyntheticTextDataset(IterableDataset):
    """Hugging Face ``IterableDataset`` producing synthetic text examples.

    Wraps a ``_SyntheticTextExamplesIterable`` so the generated stream plugs
    into the standard ``datasets`` machinery (features, sharding, epochs).
    """

    def __init__(
        self,
        config: SyntheticTextDatasetConfig,
        processor: PreTrainedTokenizerBase,
        random_seed: int = 42,
    ):
        """
        :param config: Generation parameters for prompts and outputs.
        :param processor: Tokenizer used when building prompts.
        :param random_seed: Base seed for reproducible generation.
        """
        self.config = config
        self.processor = processor
        self.random_seed = random_seed

        examples = _SyntheticTextExamplesIterable(
            config=config,
            processor=processor,
            random_seed=random_seed,
        )

        # Hand the typed iterable to the parent along with matching metadata.
        super().__init__(
            ex_iterable=examples,
            info=DatasetInfo(
                description="Synthetic text dataset generator",
                features=examples.features,
            ),
        )

    def set_epoch(self, epoch: int):
        """Set the epoch for the dataset iteration."""
        iterable = self._ex_iterable
        if isinstance(iterable, _SyntheticTextExamplesIterable):
            # The iteration count perturbs the seeds used on the next pass.
            iterable.iteration_count = epoch


@DatasetDeserializerFactory.register("synthetic_text")
class SyntheticTextDatasetDeserializer(DatasetDeserializer):
def __call__(
Expand Down Expand Up @@ -254,21 +344,10 @@ def __call__(
f"got {data}"
)

return IterableDataset.from_generator(
SyntheticTextGenerator,
gen_kwargs={
"config": data,
"processor": processor_factory(),
"random_seed": random_seed,
},
features=Features(
{
"prefix": Value("string"),
"prompt": Value("string"),
"prompt_tokens_count": Value("int32"),
"output_tokens_count": Value("int32"),
}
),
return SyntheticTextDataset(
config=data,
processor=processor_factory(),
random_seed=random_seed,
)

def _load_config_dict(self, data: Any) -> SyntheticTextDatasetConfig | None:
Expand Down
24 changes: 22 additions & 2 deletions src/guidellm/data/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
self.precache: list[Any] | None = (
list(self.generator(data_samples)) if data_samples else None
)
self.epoch = 0

def __iter__(self) -> Iterator[DataT]:
worker_info = torch.utils.data.get_worker_info()
Expand All @@ -74,18 +75,29 @@ def __iter__(self) -> Iterator[DataT]:
if (index + worker_index) % worker_modulus == 0:
yield item
else:
yield from self.generator(modulus=worker_modulus, offset=worker_index)
yield from self.generator(
modulus=worker_modulus, offset=worker_index, epoch=self.epoch
)

def set_epoch(self, epoch: int):
    # Record the epoch; generator() forwards it to each underlying
    # dataset's own set_epoch (when the dataset defines one).
    self.epoch = epoch

def generator(
self,
max_items: int | None = None,
modulus: int | None = None,
offset: int | None = None,
epoch: int = 0,
) -> Iterator[DataT]:
gen_count = 0

with contextlib.suppress(StopIteration):
dataset_iters = [iter(dataset) for dataset in self.datasets]
dataset_iters = []
for dataset in self.datasets:
if hasattr(dataset, "set_epoch"):
with contextlib.suppress(Exception):
dataset.set_epoch(epoch)
dataset_iters.append(iter(dataset))

while max_items is None or gen_count < max_items:
try:
Expand Down Expand Up @@ -152,6 +164,7 @@ def __init__(
"num_workers": num_workers,
"random_seed": random_seed,
}
self.epoch = 0

super().__init__(
dataset=iterator,
Expand All @@ -163,6 +176,13 @@ def __init__(
**kwargs,
)

def __iter__(self):
    # Push the current epoch into the DatasetsIterator before iterating,
    # then advance it so each full pass gets a distinct epoch number.
    # Note: the counter advances even when the dataset is not a
    # DatasetsIterator, keeping the numbering consistent across types.
    if isinstance(self.dataset, DatasetsIterator):
        self.dataset.set_epoch(self.epoch)
    self.epoch += 1

    return super().__iter__()

@property
def info(self) -> dict[str, Any]:
    # Read-only view of the construction parameters captured in __init__
    # (e.g. num_workers, random_seed).
    return self._info
31 changes: 12 additions & 19 deletions tests/unit/data/deserializers/test_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

from guidellm.data.deserializers.deserializer import DataNotSupportedError
from guidellm.data.deserializers.synthetic import (
SyntheticTextDataset,
SyntheticTextDatasetConfig,
SyntheticTextDatasetDeserializer,
SyntheticTextGenerator,
SyntheticTextPrefixBucketConfig,
)

Expand Down Expand Up @@ -264,9 +264,7 @@ def test_generator_initialization(self, simple_config, mock_tokenizer):

### WRITTEN BY AI ###
"""
generator = SyntheticTextGenerator(
simple_config, mock_tokenizer, random_seed=42
)
generator = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)

assert generator.config == simple_config
assert generator.processor == mock_tokenizer
Expand All @@ -278,9 +276,7 @@ def test_basic_iteration(self, simple_config, mock_tokenizer):

### WRITTEN BY AI ###
"""
generator = SyntheticTextGenerator(
simple_config, mock_tokenizer, random_seed=42
)
generator = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)

items = []
for i, item in enumerate(generator):
Expand Down Expand Up @@ -310,20 +306,21 @@ def test_create_prompt_method(self, simple_config, mock_tokenizer):
"""
from faker import Faker

generator = SyntheticTextGenerator(
simple_config, mock_tokenizer, random_seed=42
)
generator = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)
faker = Faker()
faker.seed_instance(42)

# Access the _create_prompt method through the examples iterable
ex_iterable = generator._ex_iterable

# Test normal case
result = generator._create_prompt(5, faker, "unique_prefix ")
result = ex_iterable._create_prompt(5, faker, "unique_prefix ")
assert isinstance(result, str)
# The result should be the decoded tokens (token_0 token_1 etc.) due to our mock
assert "token_" in result

# Test zero tokens
result = generator._create_prompt(0, faker)
result = ex_iterable._create_prompt(0, faker)
assert result == ""

@pytest.mark.regression
Expand All @@ -332,7 +329,7 @@ def test_prefix_tokens_integration(self, config_with_prefix, mock_tokenizer):

### WRITTEN BY AI ###
"""
generator = SyntheticTextGenerator(
generator = SyntheticTextDataset(
config_with_prefix, mock_tokenizer, random_seed=42
)

Expand All @@ -353,12 +350,8 @@ def test_random_seeding_consistency(self, simple_config, mock_tokenizer):
### WRITTEN BY AI ###
"""
# Create two generators with same seed
generator1 = SyntheticTextGenerator(
simple_config, mock_tokenizer, random_seed=42
)
generator2 = SyntheticTextGenerator(
simple_config, mock_tokenizer, random_seed=42
)
generator1 = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)
generator2 = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)

items1 = []
items2 = []
Expand Down
Loading