From 9bbe3e65c3773f9e160e9bfc25879a747e529b13 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Thu, 21 Nov 2024 17:51:07 -0800 Subject: [PATCH] Style fix for Dataloader --- dspy/datasets/dataloader.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/dspy/datasets/dataloader.py b/dspy/datasets/dataloader.py index 87d534f9d5..351b844189 100644 --- a/dspy/datasets/dataloader.py +++ b/dspy/datasets/dataloader.py @@ -10,9 +10,7 @@ class DataLoader(Dataset): - def __init__( - self, - ): + def __init__(self): pass def from_huggingface( @@ -97,8 +95,7 @@ def from_parquet(self, file_path: str, fields: List[str] = None, input_keys: Tup return [dspy.Example({field: row[field] for field in fields}).with_inputs(input_keys) for row in dataset] - def from_rm(self, num_samples: int, - fields: List[str], input_keys: List[str]) -> List[dspy.Example]: + def from_rm(self, num_samples: int, fields: List[str], input_keys: List[str]) -> List[dspy.Example]: try: rm = dspy.settings.rm try: @@ -107,9 +104,13 @@ def from_rm(self, num_samples: int, for row in rm.get_objects(num_samples=num_samples, fields=fields) ] except AttributeError: - raise ValueError("Retrieval module does not support `get_objects`. Please use a different retrieval module.") + raise ValueError( + "Retrieval module does not support `get_objects`. Please use a different retrieval module." + ) except AttributeError: - raise ValueError("Retrieval module not found. Please set a retrieval module using `dspy.settings.configure`.") + raise ValueError( + "Retrieval module not found. Please set a retrieval module using `dspy.settings.configure`." + ) def sample( self, @@ -119,7 +120,9 @@ def sample( **kwargs, ) -> List[dspy.Example]: if not isinstance(dataset, list): - raise ValueError(f"Invalid dataset provided of type {type(dataset)}. Please provide a list of examples.") + raise ValueError( + f"Invalid dataset provided of type {type(dataset)}. Please provide a list of `dspy.Example`s." + ) return random.sample(dataset, n, *args, **kwargs) @@ -141,7 +144,11 @@ def train_test_split( elif train_size is not None and isinstance(train_size, int): train_end = train_size else: - raise ValueError("Invalid train_size. Please provide a float between 0 and 1 or an int.") + raise ValueError( + "Invalid `train_size`. Please provide a float between 0 and 1 to represent the proportion of the " + "dataset to include in the train split or an int to represent the absolute number of samples to " + f"include in the train split. Received `train_size`: {train_size}." + ) if test_size is not None: if isinstance(test_size, float) and (0 < test_size < 1): @@ -149,9 +156,16 @@ def train_test_split( elif isinstance(test_size, int): test_end = test_size else: - raise ValueError("Invalid test_size. Please provide a float between 0 and 1 or an int.") + raise ValueError( + "Invalid `test_size`. Please provide a float between 0 and 1 to represent the proportion of the " + "dataset to include in the test split or an int to represent the absolute number of samples to " + f"include in the test split. Received `test_size`: {test_size}." + ) if train_end + test_end > len(dataset_shuffled): - raise ValueError("train_size + test_size cannot exceed the total number of samples.") + raise ValueError( + "`train_size` + `test_size` cannot exceed the total number of samples. Received " + f"`train_size`: {train_end}, `test_size`: {test_end}, and `dataset_size`: {len(dataset_shuffled)}." + ) else: test_end = len(dataset_shuffled) - train_end