src/llmcompressor/datasets/utils.py (37 changes: 21 additions & 16 deletions)
@@ -9,7 +9,7 @@
 import multiprocessing
 import re
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable
 
 import torch
 from datasets import Dataset
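
The typing imports dropped above are superseded by PEP 585 builtin generics (dict[...], list[...], Python 3.9+) and PEP 604 union syntax (X | None, Python 3.10+), which the rest of the diff adopts. A minimal sketch of the equivalence; the function name here is illustrative, not from the library:

def lookup(table: dict[str, int], key: str) -> int | None:
    # dict[str, int] replaces typing.Dict[str, int], and "int | None"
    # replaces typing.Optional[int]; no typing import is required.
    return table.get(key)

assert lookup({"a": 1}, "a") == 1
assert lookup({"a": 1}, "b") is None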
@@ -24,10 +24,10 @@
 
 def get_processed_dataset(
     dataset_args: DatasetArguments,
-    processor: Optional[Processor] = None,
+    processor: Processor | None = None,
     do_oneshot: bool = False,
     do_train: bool = True,
-) -> Optional[Dict[str, Dataset]]:
+) -> dict[str, Dataset] | None:
     """
     Loads datasets for each flow based on dataset_args, stores a Dataset for each
     enabled flow in datasets
@@ -50,17 +50,22 @@ def get_processed_dataset(
 
     def _get_split_name(inp_str):
         # strip out split name, for ex train[60%:] -> train
-        match = re.match(r"(\w*)\[.*\]", inp_str)
-        if match is not None:
-            return match.group(1)
+        split_name_match = re.match(r"(\w*)\[.*\]", inp_str)
+        if split_name_match is not None:
+            return split_name_match.group(1)
         return inp_str
 
-    if splits is None:
-        splits = {"all": None}
-    elif isinstance(splits, str):
-        splits = {_get_split_name(splits): splits}
-    elif isinstance(splits, List):
-        splits = {_get_split_name(s): s for s in splits}
+    match splits:
+        case None:
+            splits = {"all": None}
+        case str():
+            splits = {_get_split_name(splits): splits}
+        case list():
+            splits = {_get_split_name(s): s for s in splits}
+        case dict():
+            pass
+        case _:
+            raise ValueError(f"Invalid splits type: {type(splits)}")
 
     # default to custom dataset if dataset provided isn't a string
     registry_id = (
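
Two details in this hunk are worth calling out. The local variable match is renamed to split_name_match: match is only a soft keyword in Python 3.10+, so the old name would still parse, but the rename avoids confusion now that a match statement appears in the same function. The match/case rewrite is also not a pure refactor: the old if/elif chain silently passed through any unrecognized splits value, while the new case _ arm raises ValueError. A standalone sketch of the new normalization logic (it mirrors the diff but is not the library's public API):

import re

def _get_split_name(inp_str: str) -> str:
    # strip out the split selector, e.g. "train[60%:]" -> "train"
    split_name_match = re.match(r"(\w*)\[.*\]", inp_str)
    if split_name_match is not None:
        return split_name_match.group(1)
    return inp_str

def normalize_splits(splits):
    # mirrors the match statement introduced in the diff
    match splits:
        case None:
            return {"all": None}
        case str():
            return {_get_split_name(splits): splits}
        case list():
            return {_get_split_name(s): s for s in splits}
        case dict():
            return splits  # already a name -> split mapping
        case _:
            raise ValueError(f"Invalid splits type: {type(splits)}")

assert normalize_splits(None) == {"all": None}
assert normalize_splits("train[60%:]") == {"train": "train[60%:]"}
assert normalize_splits(["train", "test[:128]"]) == {"train": "train", "test": "test[:128]"}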
@@ -121,10 +126,10 @@ def get_calibration_dataloader(
 
 def format_calibration_data(
     tokenized_dataset: Dataset,
-    num_calibration_samples: Optional[int] = None,
+    num_calibration_samples: int | None = None,
     do_shuffle: bool = True,
     collate_fn: Callable = default_data_collator,
-) -> List[torch.Tensor]:
+) -> list[torch.Tensor]:
     """
     Creates a dataloader out of the calibration dataset split, trimming it to
     the desired number of calibration samples
@@ -172,10 +177,10 @@ def format_calibration_data(
 
 
 def make_dataset_splits(
-    tokenized_datasets: Dict[str, Any],
+    tokenized_datasets: dict[str, Any],
     do_oneshot: bool = True,
     do_train: bool = False,
-) -> Dict[str, Dataset]:
+) -> dict[str, Dataset]:
     """
     Restructures the datasets dictionary based on what tasks will be run
     train