Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions dspy/datasets/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@


class DataLoader(Dataset):
def __init__(
self,
):
def __init__(self):
pass

def from_huggingface(
Expand Down Expand Up @@ -97,8 +95,7 @@ def from_parquet(self, file_path: str, fields: List[str] = None, input_keys: Tup

return [dspy.Example({field: row[field] for field in fields}).with_inputs(input_keys) for row in dataset]

def from_rm(self, num_samples: int,
fields: List[str], input_keys: List[str]) -> List[dspy.Example]:
def from_rm(self, num_samples: int, fields: List[str], input_keys: List[str]) -> List[dspy.Example]:
try:
rm = dspy.settings.rm
try:
Expand All @@ -107,9 +104,13 @@ def from_rm(self, num_samples: int,
for row in rm.get_objects(num_samples=num_samples, fields=fields)
]
except AttributeError:
raise ValueError("Retrieval module does not support `get_objects`. Please use a different retrieval module.")
raise ValueError(
"Retrieval module does not support `get_objects`. Please use a different retrieval module."
)
except AttributeError:
raise ValueError("Retrieval module not found. Please set a retrieval module using `dspy.settings.configure`.")
raise ValueError(
"Retrieval module not found. Please set a retrieval module using `dspy.settings.configure`."
)

def sample(
self,
Expand All @@ -119,7 +120,9 @@ def sample(
**kwargs,
) -> List[dspy.Example]:
if not isinstance(dataset, list):
raise ValueError(f"Invalid dataset provided of type {type(dataset)}. Please provide a list of examples.")
raise ValueError(
f"Invalid dataset provided of type {type(dataset)}. Please provide a list of `dspy.Example`s."
)

return random.sample(dataset, n, *args, **kwargs)

Expand All @@ -141,17 +144,28 @@ def train_test_split(
elif train_size is not None and isinstance(train_size, int):
train_end = train_size
else:
raise ValueError("Invalid train_size. Please provide a float between 0 and 1 or an int.")
raise ValueError(
"Invalid `train_size`. Please provide a float between 0 and 1 to represent the proportion of the "
"dataset to include in the train split or an int to represent the absolute number of samples to "
f"include in the train split. Received `train_size`: {train_size}."
)

if test_size is not None:
if isinstance(test_size, float) and (0 < test_size < 1):
test_end = int(len(dataset_shuffled) * test_size)
elif isinstance(test_size, int):
test_end = test_size
else:
raise ValueError("Invalid test_size. Please provide a float between 0 and 1 or an int.")
raise ValueError(
"Invalid `test_size`. Please provide a float between 0 and 1 to represent the proportion of the "
"dataset to include in the test split or an int to represent the absolute number of samples to "
f"include in the test split. Received `test_size`: {test_size}."
)
if train_end + test_end > len(dataset_shuffled):
raise ValueError("train_size + test_size cannot exceed the total number of samples.")
raise ValueError(
"`train_size` + `test_size` cannot exceed the total number of samples. Received "
f"`train_size`: {train_end}, `test_size`: {test_end}, and `dataset_size`: {len(dataset_shuffled)}."
)
else:
test_end = len(dataset_shuffled) - train_end

Expand Down
Loading