<a href="https://colab.research.google.com/github/sngo/llms-practice/blob/main/synthectic_data/synthetic_data_generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q gradio>=4.0.0 anthropic>=0.18.0 pandas>=2.0.0

In [7]:
import os
import json
import csv
import time
import tempfile
import sqlite3
from typing import Dict, List, Any, Optional, Tuple, Union

import gradio as gr
import anthropic
import pandas as pd
from openai import OpenAI
from google.colab import userdata

In [8]:
# Read the OpenAI key from Colab's userdata store under 'OPENAI_API_KEY'.
# NOTE(review): this variable is not referenced again in the visible cells;
# SyntheticDataGenerator() performs the same lookup itself when no key is passed.
api_key = userdata.get('OPENAI_API_KEY')

In [18]:
class SyntheticDataGenerator:
    """Generate schema-driven synthetic test records with an OpenAI chat model.

    A schema is a dict with a ``name``, a non-empty ``fields`` list (each field
    a dict with at least ``name`` and ``type``) and optional ``constraints`` /
    ``additional_instructions``. Most public methods return a
    ``(result, status_message)`` tuple so UI callbacks can show the message
    directly instead of handling exceptions.
    """

    # Field attributes already rendered explicitly in a prompt line. Anything
    # else on a field (min, max, values, ...) is appended as an extra hint.
    _CORE_FIELD_KEYS = ("name", "type", "description")

    # File suffix used when a temporary output file has to be created.
    _FILE_SUFFIXES = {"json": ".json", "csv": ".csv", "sqlite": ".db"}

    def __init__(self, api_key: Optional[str] = None):
        """Create the OpenAI client.

        Args:
            api_key: OpenAI API key. When omitted, the key is read from Colab
                userdata under ``OPENAI_API_KEY``.

        Raises:
            ValueError: If no key is supplied and none is found in userdata.
        """
        self.api_key = api_key or userdata.get('OPENAI_API_KEY')
        if not self.api_key:
            raise ValueError("OPENAI API key is required")

        self.client = OpenAI(api_key=self.api_key)
        self.model = "gpt-4o-mini"  # Default model

    def set_model(self, model_name: str) -> str:
        """Set the model used for generation and return a confirmation message."""
        self.model = model_name
        return f"Model set to: {self.model}"

    def validate_schema(self, schema: Dict[str, Any]) -> Tuple[bool, str]:
        """Check that ``schema`` has the structure the generator relies on.

        Returns:
            ``(is_valid, message)``; ``message`` explains the first problem
            found, or confirms validity.
        """
        required_keys = ["name", "fields"]

        # Valid JSON can still be a list, string or number; reject those up
        # front instead of failing later on schema["fields"].
        if not isinstance(schema, dict):
            return False, "Schema must be a JSON object"

        if not all(key in schema for key in required_keys):
            return False, f"Schema missing required keys: {required_keys}"

        if not isinstance(schema["fields"], list) or len(schema["fields"]) == 0:
            return False, "Schema must contain a non-empty 'fields' list"

        for field in schema["fields"]:
            # The isinstance check matters: `"name" in "some name string"` is a
            # substring test and would let a plain string pass as a field.
            if not isinstance(field, dict) or "name" not in field or "type" not in field:
                return False, f"Each field must have 'name' and 'type' attributes: {field}"

        return True, "Schema is valid"

    def _describe_field(self, field: Dict[str, Any]) -> str:
        """Render one schema field as a prompt bullet, including extra hints."""
        line = f"- {field['name']} ({field['type']}): {field.get('description', '')}"
        extras = {key: value for key, value in field.items()
                  if key not in self._CORE_FIELD_KEYS}
        if extras:
            # Without these the model never sees enum values or numeric ranges
            # and is free to invent values outside them.
            hints = ", ".join(f"{key}={json.dumps(value, default=str)}"
                              for key, value in extras.items())
            line = f"{line.rstrip()} [{hints}]"
        return line

    def generate_prompt(self, schema: Dict[str, Any], num_records: int, special_instructions: Optional[str] = None) -> str:
        """Build the prompt asking the model for ``num_records`` records.

        Args:
            schema: A schema that has passed :meth:`validate_schema`.
            num_records: Number of records to request.
            special_instructions: Free-text instructions; when empty, the
                schema's ``additional_instructions`` (if any) are used instead.

        Returns:
            The prompt text.
        """
        fields_info = "\n".join(self._describe_field(field) for field in schema["fields"])

        constraints = schema.get("constraints", [])
        constraints_text = ""
        if constraints:
            constraints_text = "Apply these constraints:\n" + "\n".join([f"- {c}" for c in constraints])

        additional_instructions = special_instructions or schema.get("additional_instructions", "")

        prompt = f"""
        I need you to generate {num_records} records of synthetic test data according to this schema:

        Dataset: {schema["name"]}

        Fields:
        {fields_info}

        {constraints_text}

        {additional_instructions}

        Important guidelines:
        1. Generate realistic but completely fictional data
        2. Ensure internal consistency across fields when logical relationships exist
        3. Provide appropriate variety and distribution of values
        4. Format the output as a JSON array of objects with each field name as the key
        5. Only return the raw JSON data with no additional explanations

        Output the synthetic data as valid JSON that I can parse:
        """

        return prompt

    def generate_data(self, schema: Dict[str, Any], num_records: int = 10,
                     special_instructions: Optional[str] = None, max_retries: int = 3) -> Tuple[Optional[List[Dict[str, Any]]], str]:
        """Generate synthetic records through the OpenAI chat completions API.

        The request is retried up to ``max_retries`` times when the call fails
        or the reply does not contain a parseable JSON array of objects.

        Returns:
            ``(records, message)``; ``records`` is ``None`` on failure and
            ``message`` describes the outcome either way.
        """
        prompt = self.generate_prompt(schema, num_records, special_instructions)

        for attempt in range(max_retries):
            is_last_attempt = attempt == max_retries - 1
            content = ""
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    max_tokens=4000,
                    messages=[
                        {"role": "user", "content": prompt}
                    ],
                    temperature=0.7,
                )

                # The message content can be None; normalise it so this is
                # reported as "no JSON found" rather than as an API error.
                content = response.choices[0].message.content or ""

                # Take the outermost [...] span, which tolerates code fences
                # or stray prose around the array.
                json_start = content.find('[')
                json_end = content.rfind(']') + 1

                if json_start >= 0 and json_end > json_start:
                    data = json.loads(content[json_start:json_end])
                    if isinstance(data, list) and all(isinstance(record, dict) for record in data):
                        return data, f"Successfully generated {len(data)} records"
                    # Parsed, but not records (e.g. a list nested in an object).
                    error_msg = "Response JSON is not an array of objects. Retrying..."
                else:
                    error_msg = "No valid JSON array found in response. Retrying..."

                if is_last_attempt:
                    return None, f"Failed: {error_msg} Response content: {content[:100]}..."
            except json.JSONDecodeError:
                error_msg = f"Failed to parse JSON (Attempt {attempt+1}/{max_retries})"
                if is_last_attempt:
                    return None, f"Failed: {error_msg} - Raw response: {content[:500]}..."
            except Exception as e:
                error_msg = f"API error (Attempt {attempt+1}/{max_retries}): {str(e)}"
                if is_last_attempt:
                    return None, f"Failed: {error_msg}"
                time.sleep(2)  # Wait before retrying

        return None, f"Failed to generate valid data after {max_retries} attempts"

    def save_data(self, data: List[Dict[str, Any]], output_format: str, file_path: str = None, table_name: str = "synthetic_data"):
        """Save ``data`` in the requested format.

        Args:
            data: Records to save.
            output_format: ``"json"``, ``"csv"``, ``"sqlite"`` or ``"dataframe"``.
            file_path: Destination path; a temporary file is created when omitted.
            table_name: Table to (re)create for the ``"sqlite"`` format.

        Returns:
            ``(result, message)``; ``result`` is the written file path, a
            pandas DataFrame for ``"dataframe"``, or ``None`` when there is
            nothing to save or the format is unsupported.
        """
        if not data:
            return None, "No data to save"

        if output_format == "dataframe":
            # Return as pandas DataFrame for display in Gradio
            return pd.DataFrame(data), f"Generated {len(data)} records"

        if output_format not in self._FILE_SUFFIXES:
            return None, f"Unsupported format: {output_format}"

        if not file_path:
            # Only reserve a path here and reopen it below in the right mode:
            # NamedTemporaryFile opens in binary mode by default, so writing
            # JSON text straight into it raises TypeError.
            with tempfile.NamedTemporaryFile(delete=False, suffix=self._FILE_SUFFIXES[output_format]) as tmp:
                file_path = tmp.name

        if output_format == "json":
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            return file_path, f"Saved {len(data)} records as JSON"

        if output_format == "csv":
            # Union of keys across all records, in first-seen order, so a
            # record with an extra key does not make DictWriter raise.
            fieldnames = list(dict.fromkeys(key for record in data for key in record))
            with open(file_path, 'w', newline='', encoding='utf-8') as output_file:
                dict_writer = csv.DictWriter(output_file, fieldnames=fieldnames)
                dict_writer.writeheader()
                dict_writer.writerows(data)
            return file_path, f"Saved {len(data)} records as CSV"

        # Remaining supported format: sqlite
        df = pd.DataFrame(data)
        conn = sqlite3.connect(file_path)
        try:
            df.to_sql(table_name, conn, if_exists='replace', index=False)
        finally:
            conn.close()  # release the file handle even if to_sql fails
        return file_path, f"Saved {len(data)} records to SQLite database, table {table_name}"





In [26]:
# Sample schema pre-filled in the UI's schema textbox.
# It is kept as a JSON string rather than a dict because the textbox and the
# validate/generate callbacks work on raw text and parse it with json.loads.
# Top-level keys: "name", "fields" (each with "name" and "type", plus optional
# "description" and hints such as min/max/values), "constraints", and
# "additional_instructions".
SAMPLE_SCHEMA = """{
  "name": "Users Test Dataset",
  "description": "Synthetic user data for application testing",
  "fields": [
    {
      "name": "id",
      "type": "uuid",
      "description": "Unique identifier for each user"
    },
    {
      "name": "first_name",
      "type": "string",
      "description": "User's first name"
    },
    {
      "name": "last_name",
      "type": "string",
      "description": "User's last name"
    },
    {
      "name": "email",
      "type": "email",
      "description": "User's email address"
    },
    {
      "name": "age",
      "type": "integer",
      "description": "User's age in years",
      "min": 18,
      "max": 85
    },
    {
      "name": "subscription_tier",
      "type": "enum",
      "description": "User's subscription level",
      "values": ["free", "basic", "premium", "enterprise"]
    },
    {
      "name": "is_active",
      "type": "boolean",
      "description": "Whether the user account is active"
    }
  ],
  "constraints": [
    "If age is under 25, subscription_tier should be more likely to be 'free' or 'basic'",
    "80% of users should have is_active set to true"
  ],
  "additional_instructions": "Create diverse users with varied names and realistic data patterns."
}"""

In [20]:
# Module-level generator shared by the callback functions defined below.
# With no argument, the constructor reads the API key from Colab userdata.
generator = SyntheticDataGenerator()

In [21]:
# Smoke test against the sample schema. generate_data returns a
# (records, status_message) tuple, so `generated_data` holds both elements.
generated_data = generator.generate_data(json.loads(SAMPLE_SCHEMA), num_records=10)

In [22]:
# Prints the raw (records, status_message) tuple from the previous cell.
print(generated_data)

([{'id': 'd9e0b1f9-efb3-4c8e-8cfd-8b3a2e1f1b9a', 'first_name': 'Liam', 'last_name': 'Smith', 'email': 'liam.smith@example.com', 'age': 22, 'subscription_tier': 'free', 'is_active': True}, {'id': '5c5c3b3e-5da2-4f49-9f38-08de5a9b2b59', 'first_name': 'Emma', 'last_name': 'Johnson', 'email': 'emma.johnson@example.com', 'age': 28, 'subscription_tier': 'premium', 'is_active': True}, {'id': '1c4c4c3e-18db-4b58-a4e5-2d3c9f3e7a2d', 'first_name': 'Olivia', 'last_name': 'Williams', 'email': 'olivia.williams@example.com', 'age': 20, 'subscription_tier': 'basic', 'is_active': True}, {'id': 'b8c1f3a1-5d4f-4b6b-b3a8-e3d1a1f3b9b2', 'first_name': 'Noah', 'last_name': 'Jones', 'email': 'noah.jones@example.com', 'age': 35, 'subscription_tier': 'standard', 'is_active': False}, {'id': 'a5f1b2d9-1d39-4b4e-b4a9-6d5c5a7c2c3d', 'first_name': 'Ava', 'last_name': 'Brown', 'email': 'ava.brown@example.com', 'age': 24, 'subscription_tier': 'free', 'is_active': True}, {'id': '1a4e1b0b-5c3f-4a8a-b9c9-8b0c8d1e4a6f', 

In [23]:
def validate_schema_input(schema_text):
    """Parse the schema text and report whether it is a usable schema.

    Returns the validator's message, or a description of the JSON or
    unexpected error. Always a string, so it can be shown in the status box.
    """
    try:
        parsed_schema = json.loads(schema_text)
        _, verdict = generator.validate_schema(parsed_schema)
    except json.JSONDecodeError as err:
        return f"Invalid JSON: {str(err)}"
    except Exception as err:
        return f"Error: {str(err)}"
    return verdict

In [25]:
validate_schema_input(SAMPLE_SCHEMA)

"Each field must have 'name' and 'type' attributes: {'type': 'boolean', 'description': 'Whether the user account is active'}"

In [27]:
def generate_synthetic_data(schema_text, num_records, special_instructions):
    """Gradio callback: validate the schema text, then generate records.

    On success the records are stored in the module-level
    ``last_generated_data`` and returned as a DataFrame with the status
    message. On any failure the result is ``(None, message)``.
    """
    global last_generated_data

    try:
        # Parse and validate before spending an API call.
        parsed_schema = json.loads(schema_text)
        schema_ok, feedback = generator.validate_schema(parsed_schema)
        if not schema_ok:
            return None, feedback

        records, feedback = generator.generate_data(
            parsed_schema,
            num_records=int(num_records),
            special_instructions=special_instructions,
        )
        if not records:
            return None, feedback

        # Keep the raw records for the download tab; show them as a table.
        last_generated_data = records
        return pd.DataFrame(records), feedback

    except json.JSONDecodeError as err:
        return None, f"Invalid schema JSON: {str(err)}"
    except Exception as err:
        return None, f"Error generating data: {str(err)}"

In [28]:
def save_and_download_data(output_format, table_name):
    """Gradio callback: write the last generated records to a downloadable file.

    Args:
        output_format: ``"json"``, ``"csv"`` or ``"sqlite"``.
        table_name: Table name for the SQLite format; a blank value falls back
            to ``"synthetic_data"``.

    Returns:
        ``(file_path, message)``; ``file_path`` is ``None`` when nothing has
        been generated yet or saving failed.
    """
    # `last_generated_data` only exists after a successful generation run.
    # Look it up defensively so that clicking "download" first yields the
    # friendly message below instead of a NameError.
    records = globals().get("last_generated_data")
    if records is None:
        return None, "No data has been generated yet"

    # The UI textbox can be cleared; an empty table name is not usable.
    safe_table_name = (table_name or "").strip() or "synthetic_data"

    try:
        file_path, message = generator.save_data(
            records,
            output_format,
            table_name=safe_table_name
        )
        return file_path, message
    except Exception as e:
        return None, f"Error saving data: {str(e)}"

In [29]:
# Define the Gradio interface.
# Two tabs: "Generate Data" (schema and options in, table and status out) and
# "Download Data" (export the most recently generated records to a file).
# Components are laid out in the order they are created inside the nested
# context managers, so statement order here determines the page layout.
with gr.Blocks(title="Synthetic Data Generator") as app:
    gr.Markdown("# Synthetic Testing Data Generator")
    gr.Markdown("Generate realistic synthetic data using OpenAI model")

    with gr.Tab("Generate Data"):
        with gr.Row():
            # Left column: record count and action buttons.
            with gr.Column(scale=1):
                num_records_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=10,
                    step=1,
                    label="Number of Records"
                )

                schema_validate_btn = gr.Button("Validate Schema")
                generate_btn = gr.Button("Generate Data", variant="primary")

            # Right column: schema JSON (pre-filled with the sample) and
            # optional free-text instructions.
            with gr.Column(scale=2):
                schema_input = gr.Textbox(
                    label="Data Schema (JSON)",
                    placeholder="Paste your schema here...",
                    value=SAMPLE_SCHEMA,
                    lines=20
                )

                special_instructions = gr.Textbox(
                    label="Additional Instructions (Optional)",
                    placeholder="Any special requirements for data generation",
                    lines=2
                )

        # Status line shared by the validate and generate actions.
        with gr.Row():
            status_output = gr.Textbox(label="Status", interactive=False)

        with gr.Row():
            data_output = gr.DataFrame(label="Generated Data")

    with gr.Tab("Download Data"):
        with gr.Row():
            with gr.Column():
                output_format = gr.Radio(
                    choices=["json", "csv", "sqlite"],
                    label="Output Format",
                    value="json"
                )

                table_name = gr.Textbox(
                    label="Table Name (for SQLite)",
                    value="synthetic_data"
                )

                download_btn = gr.Button("Prepare for Download", variant="primary")

            with gr.Column():
                download_status = gr.Textbox(label="Download Status", interactive=False)
                download_file = gr.File(label="Download File")

    # Wire up the components
    # Validation returns a single message string -> status box.
    schema_validate_btn.click(validate_schema_input, inputs=[schema_input], outputs=[status_output])

    # Generation returns (DataFrame or None, message) -> table and status box.
    generate_btn.click(
        generate_synthetic_data,
        inputs=[schema_input, num_records_slider, special_instructions],
        outputs=[data_output, status_output]
    )

    # Export returns (file path or None, message) -> file widget and status box.
    download_btn.click(
        save_and_download_data,
        inputs=[output_format, table_name],
        outputs=[download_file, download_status]
    )

In [30]:
# Launch the app. share=True publishes it at a public gradio.live URL and
# debug=True keeps this cell running so errors and logs stay visible.
# NOTE(review): anyone with the public link can trigger generation using this
# notebook's API key; set share=False if that is not intended.
view = app.launch(share=True, debug=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://9009225dd55f7eb32c.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://9009225dd55f7eb32c.gradio.live
