Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 119 additions & 1 deletion datafast/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,130 @@ class JudgeLLMOutput(BaseModel):


class DatasetBase(ABC):
"""Abstract base class for all dataset generators."""
"""
Abstract base class for all dataset generators.

Methods
-------
inspect():
Launch a Gradio app to visually browse the dataset (self.data_rows).
Requires gradio to be installed (pip install gradio).
Provides Next/Previous navigation through dataset examples.
"""

def __init__(self, config):
self.config = config
self.data_rows = []

def inspect(self, random: bool = False) -> None:
"""
Launch an interactive Gradio app to visually inspect the generated dataset.

This method redirects to specialized inspectors in datafast.inspectors module,
which provide tailored visualization for each dataset type.

Args:
random: If True, examples will be shown in random order instead of sequential order.
Default is False (sequential order).

Raises:
ImportError: If gradio is not installed.
ValueError: If the dataset type is not supported by any specialized inspector.
"""
import warnings
from importlib import import_module

try:
# Test if Gradio is installed
import gradio as gr
except ImportError as e:
raise ImportError("Gradio is required for .inspect(). Install with 'pip install gradio'.") from e

if not self.data_rows:
raise ValueError("No data rows to inspect. Generate or load data first.")

try:
# Import inspectors dynamically to prevent import cycles
inspectors = import_module('datafast.inspectors')

# Get the class name without module prefix and convert CamelCase to snake_case
class_name = self.__class__.__name__

# Convert CamelCase to snake_case (e.g., ClassificationDataset -> classification_dataset)
import re
snake_case = re.sub(r'(?<!^)(?=[A-Z])', '_', class_name).lower()
inspector_name = f"inspect_{snake_case}"

if hasattr(inspectors, inspector_name):
# Call the appropriate specialized inspector
inspector_func = getattr(inspectors, inspector_name)
inspector_func(self, random=random)
else:
# Fall back to generic JSON display with a warning
warnings.warn(
f"No specialized inspector found for {class_name}. "
"Using generic inspector. Consider adding a specialized inspector "
"in datafast.inspectors module.",
UserWarning
)
self._generic_inspect(random=random)
except Exception as e:
# If there's any error in the process, fall back to generic inspector
warnings.warn(f"Error using specialized inspector: {e}. Falling back to generic inspector.")
self._generic_inspect(random=random)

def _generic_inspect(self, random: bool = False) -> None:
"""Generic inspector that displays dataset rows as JSON.

Args:
random: If True, examples will be shown in random order instead of sequential. Default is False.
"""
import gradio as gr
import numpy as np

# Convert data rows to dicts for display
examples = [row.model_dump() if hasattr(row, 'model_dump') else row.dict() if hasattr(row, 'dict') else row for row in self.data_rows]
total = len(examples)

# Generate random order indices if random is True
if random and total > 1:
import numpy as np
# Create a permutation of indices
random_indices = np.random.permutation(total)
display_order = list(random_indices)
ordering_label = "(Random Order)"
else:
# Sequential order
display_order = list(range(total))
ordering_label = ""

def show_example(idx: int) -> tuple[str, dict]:
idx = max(0, min(idx, total - 1))
# Get the actual example based on the display order
example_idx = display_order[idx]
return f"Example {idx+1} / {total} {ordering_label}", examples[example_idx]

with gr.Blocks() as demo:
idx_state = gr.State(0)
gr.Markdown("# Dataset Inspector (Generic)")
idx_label = gr.Markdown()
data_view = gr.JSON()
with gr.Row():
prev_btn = gr.Button("Previous")
next_btn = gr.Button("Next")

def update_example(idx):
label, example = show_example(idx)
return label, example, idx

prev_btn.click(lambda idx: max(0, idx-1), idx_state, idx_state).then(update_example, idx_state, [idx_label, data_view, idx_state])
next_btn.click(lambda idx: min(total-1, idx+1), idx_state, idx_state).then(update_example, idx_state, [idx_label, data_view, idx_state])

# Initial display
demo.load(update_example, idx_state, [idx_label, data_view, idx_state])

demo.launch()

@abstractmethod
def generate(self, llms=None):
"""Main method to generate the dataset."""
Expand Down
73 changes: 73 additions & 0 deletions datafast/examples/inspect_dataset_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Example script showing how to generate a dataset and launch the visual inspector.

Run with:
python -m datafast.examples.inspect_dataset_example

Requires:
- OpenAI API key in secrets.env or environment
- gradio package (pip install gradio)
"""
from datafast.datasets import ClassificationDataset
from datafast.schema.config import ClassificationDatasetConfig, PromptExpansionConfig
from dotenv import load_dotenv

# Load API keys from environment or secrets.env
load_dotenv("secrets.env")

# Configure the dataset generation
config = ClassificationDatasetConfig(
classes=[
{"name": "positive", "description": "Text expressing positive emotions or approval"},
{"name": "negative", "description": "Text expressing negative emotions or criticism"},
],
num_samples_per_prompt=2, # Small number for quick demo
output_file="outdoor_activities_sentiments.jsonl", # Optional, will save generated data
languages={
"en": "English",
},
prompts=[
(
"Generate {num_samples} reviews in {language_name} which are diverse "
"and representative of a '{label_name}' sentiment class. "
"{label_description}. The reviews should be brief and in the "
"context of {{context}}."
)
],
expansion=PromptExpansionConfig(
placeholders={
"context": ["hiking trail review", "kayaking trip review"],
},
combinatorial=True # Will generate combinations of all placeholders
)
)

# Set up LLM providers
from datafast.llms import OpenAIProvider, AnthropicProvider, GeminiProvider

providers = [
OpenAIProvider(model_id="gpt-4.1-nano"),
# Uncomment to use additional providers
# AnthropicProvider(model_id="claude-3-5-haiku-latest"),
# GeminiProvider(model_id="gemini-2.0-flash"),
]

def main():
# Generate the dataset
dataset = ClassificationDataset(config)
num_expected_rows = dataset.get_num_expected_rows(providers)
print(f"Expected number of rows to generate: {num_expected_rows}")

# Generate data (comment out if loading existing data)
print("Generating dataset...")
dataset.generate(providers)
print(f"Generated {len(dataset.data_rows)} examples")

# Launch the interactive inspector
print("\nLaunching dataset inspector...")
print("(Close the browser window or press Ctrl+C to exit)")
print("Showing examples in random order")
dataset.inspect(random=True)

if __name__ == "__main__":
main()
160 changes: 160 additions & 0 deletions datafast/examples/show_dataset_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""
Demo: Show an example for each dataset type using the Gradio inspectors.

Run with:
python show_dataset_examples.py

Requires: gradio
"""
from datafast.inspectors import (
inspect_classification_dataset,
inspect_mcq_dataset,
inspect_preference_dataset,
inspect_raw_dataset,
inspect_ultrachat_dataset,
)
from datafast.schema.data_rows import (
TextClassificationRow,
MCQRow,
PreferenceRow,
TextRow,
ChatRow,
)
from datafast.datasets import (
ClassificationDataset,
MCQDataset,
PreferenceDataset,
RawDataset,
UltrachatDataset,
)
from datafast.schema.config import (
ClassificationDatasetConfig,
MCQDatasetConfig,
PreferenceDatasetConfig,
)

# --- Classification Example ---
classification_row = TextClassificationRow(
text="The trail is blocked by a fallen tree.",
label="trail_obstruction",
model_id="gpt-4.1-nano",
metadata={"language": "en"},
)
classification_dataset = ClassificationDataset(
ClassificationDatasetConfig(classes=[{"name": "trail_obstruction", "description": "Obstruction on the trail."}])
)
classification_row2 = TextClassificationRow(
text="The trail is well maintained and easy to follow.",
label="positive_conditions",
model_id="claude-3-5-haiku-latest",
metadata={"language": "en"},
)
classification_dataset.data_rows = [classification_row, classification_row2]

# --- MCQ Example ---
mcq_row = MCQRow(
source_document="The Eiffel Tower is in Paris.",
question="Where is the Eiffel Tower located?",
correct_answer="Paris",
incorrect_answer_1="London",
incorrect_answer_2="Berlin",
incorrect_answer_3="Rome",
model_id="gemini-2.0-flash",
metadata={"language": "en"},
)
mcq_config = MCQDatasetConfig(
text_column="source_document",
local_file_path="dummy.jsonl", # Required by config, not used in this demo
)
mcq_dataset = MCQDataset(mcq_config)
mcq_row2 = MCQRow(
source_document="The Amazon River is the second longest river in the world.",
question="Which river is the second longest in the world?",
correct_answer="Amazon River",
incorrect_answer_1="Nile River",
incorrect_answer_2="Yangtze River",
incorrect_answer_3="Mississippi River",
model_id="gpt-4.1-nano",
metadata={"language": "en"},
)
mcq_dataset.data_rows = [mcq_row, mcq_row2]

# --- Preference Example ---
preference_row = PreferenceRow(
input_document="Describe a recent Mars mission.",
question="What was the main goal of the Mars 2020 mission?",
chosen_response="To search for signs of ancient life and collect samples.",
rejected_response="To launch a satellite.",
chosen_model_id="claude-3-5-haiku-latest",
rejected_model_id="gpt-4.1-nano",
chosen_response_score=9,
rejected_response_score=3,
chosen_response_assessment="Accurate and detailed.",
rejected_response_assessment="Too generic.",
metadata={"language": "en"},
)
preference_dataset = PreferenceDataset(PreferenceDatasetConfig(input_documents=["Describe a recent Mars mission."]))
preference_row2 = PreferenceRow(
input_document="Describe the Voyager 1 mission.",
question="What is Voyager 1 known for?",
chosen_response="It is the farthest human-made object from Earth, exploring interstellar space.",
rejected_response="It is a Mars rover.",
chosen_model_id="gemini-2.0-flash",
rejected_model_id="gpt-4.1-nano",
chosen_response_score=10,
rejected_response_score=2,
chosen_response_assessment="Factually correct and detailed.",
rejected_response_assessment="Incorrect mission.",
metadata={"language": "en"},
)
preference_dataset.data_rows = [preference_row, preference_row2]

# --- RawDataset Example ---
from datafast.schema.data_rows import TextRow
from datafast.schema.config import RawDatasetConfig, UltrachatDatasetConfig

raw_row1 = TextRow(
text="SpaceX launched a new batch of Starlink satellites.",
text_source="human",
metadata={"date": "2025-06-30", "topic": "space"}
)
raw_row2 = TextRow(
text="The James Webb Space Telescope captured new images of a distant galaxy.",
text_source="synthetic",
metadata={"date": "2025-06-29", "topic": "astronomy"}
)
raw_config = RawDatasetConfig(document_types=["news_article", "science_report"], topics=["space", "astronomy"])
raw_dataset = RawDataset(raw_config)
raw_dataset.data_rows = [raw_row1, raw_row2]

# --- UltrachatDataset Example ---
from datafast.schema.data_rows import ChatRow
ultrachat_row1 = ChatRow(
opening_question="How can we reduce space debris?",
persona="space policy expert",
messages=[{"role": "user", "content": "What are current efforts to clean up space debris?"}, {"role": "assistant", "content": "There are several ongoing projects, such as RemoveDEBRIS and ClearSpace-1."}],
model_id="gpt-4.1-nano",
metadata={"language": "en"}
)
ultrachat_row2 = ChatRow(
opening_question="What is the importance of the Moon missions?",
persona="lunar geologist",
messages=[{"role": "user", "content": "Why do we keep returning to the Moon?"}, {"role": "assistant", "content": "The Moon offers scientific insights and is a stepping stone for Mars exploration."}],
model_id="gemini-2.0-flash",
metadata={"language": "en"}
)
ultrachat_config = UltrachatDatasetConfig()
ultrachat_dataset = UltrachatDataset(ultrachat_config)
ultrachat_dataset.data_rows = [ultrachat_row1, ultrachat_row2]

if __name__ == "__main__":
# print("Showing ClassificationDataset example...")
# inspect_classification_dataset(classification_dataset)
# print("Showing MCQDataset example...")
# inspect_mcq_dataset(mcq_dataset)
# print("Showing PreferenceDataset example...")
# inspect_preference_dataset(preference_dataset)
# print("Showing RawDataset example...")
# inspect_raw_dataset(raw_dataset)
# print("Showing UltrachatDataset example...")
# inspect_ultrachat_dataset(ultrachat_dataset)
Loading
Loading