# Generate Safety Classification Dataset from Web Sources

**Tags:** `dataset` `finetuning` | `intermediate` | `safety`

## Prerequisites

- Prem API key exported as the `API_KEY` environment variable
- Basic understanding of content moderation and safety classification


In [None]:
import os
import json
import time
import requests

API_KEY = os.getenv("API_KEY")

# Fail fast: nothing below is useful without credentials, so check before
# touching the filesystem.
if not API_KEY:
    raise ValueError("API_KEY environment variable is required")

# Load templates from JSON file.
# First candidate: <parent of the current working directory>/resources/qa_templates.json
# (os.path.abspath("") resolves to the current working directory).
TEMPLATES_PATH = os.path.join(os.path.dirname(os.path.abspath("")), "resources", "qa_templates.json")
if not os.path.exists(TEMPLATES_PATH):
    # Fallback: resolve relative to the current working directory instead.
    TEMPLATES_PATH = "resources/qa_templates.json"

# Explicit UTF-8 so decoding does not depend on the platform's locale default.
with open(TEMPLATES_PATH, "r", encoding="utf-8") as f:
    TEMPLATES = json.load(f)

# Define web sources (10 pages from Anthropic HH-RLHF dataset)
WEB_URLS = [
    f"https://huggingface.co/datasets/Anthropic/hh-rlhf/viewer/default/train?p={i}"
    for i in range(1, 11)
]


def api(endpoint: str, method: str = "GET", **kwargs):
    """Send a request to the Prem Studio API and return the decoded JSON body.

    Args:
        endpoint: Path appended to ``https://studio.premai.io``.
        method: HTTP method name handed to ``requests.request``.
        **kwargs: Extra keyword arguments forwarded to ``requests.request``.
            A ``headers`` mapping is merged on top of the ``Authorization``
            header. If no ``timeout`` is supplied, a default is applied so a
            stalled connection cannot block forever.

    Returns:
        The parsed JSON body of a successful response.

    Raises:
        Exception: If the response status is not OK. The message has the
            form ``"<status code>: <error detail>"``.
    """
    # Pop caller headers before forwarding kwargs so "headers" is not passed twice.
    extra_headers = kwargs.pop("headers", None) or {}
    headers = {"Authorization": f"Bearer {API_KEY}", **extra_headers}
    # requests waits indefinitely unless a timeout is given.
    kwargs.setdefault("timeout", 120)

    response = requests.request(
        method=method,
        url=f"https://studio.premai.io{endpoint}",
        headers=headers,
        **kwargs
    )
    if not response.ok:
        try:
            err = response.json() if response.content else {}
        except ValueError:
            # Non-JSON error body (e.g. an HTML gateway page): report the raw
            # text rather than masking the HTTP status with a decode error.
            err = response.text
        error_msg = err.get("error", str(err)) if isinstance(err, dict) else str(err)
        raise Exception(f"{response.status_code}: {error_msg}")
    return response.json()


## Step 1: Create Project


In [None]:
# Create the project that will own the datasets generated below.
project_payload = {
    "name": "Safety Classification Project",
    "goal": "Generate safety classification datasets for content moderation"
}
result = api(
    "/api/v1/public/projects/create",
    method="POST",
    headers={"Content-Type": "application/json"},
    json=project_payload,
)

# Keep the project id; later cells attach datasets to it.
project_id = result["project_id"]
print(f"✓ Project created: {project_id}")


## Step 2: Generate Response Safety Dataset


In [None]:
# Pull the response-safety template and its question/answer formats.
template = TEMPLATES["response_safety"]
question_format, answer_format = template["question_format"], template["answer_format"]

# Form fields for the synthetic dataset request; each web source is sent
# as an indexed "web_urls[<n>]" field after the fixed fields.
form_data = {
    "project_id": project_id,
    "name": "Response Safety Classification Dataset",
    "pairs_to_generate": "50",
    "pair_type": "qa",
    "temperature": "0.3",
    "rules[]": template["rules"],
    "question_format": question_format,
    "answer_format": answer_format,
    **{f"web_urls[{idx}]": page_url for idx, page_url in enumerate(WEB_URLS)},
}

result = api(
    "/api/v1/public/datasets/create-synthetic",
    method="POST",
    data=form_data,
)
response_dataset_id = result["dataset_id"]
print(f"✓ Dataset created: {response_dataset_id}")


## Step 3: Wait for Dataset Generation


In [None]:
# Poll the dataset until it leaves the "processing" state, giving up after
# a fixed deadline so a stuck job cannot hang the notebook forever.
POLL_INTERVAL_SECONDS = 5
MAX_WAIT_SECONDS = 30 * 60

deadline = time.monotonic() + MAX_WAIT_SECONDS
checks = 0
while True:
    time.sleep(POLL_INTERVAL_SECONDS)
    dataset = api(f"/api/v1/public/datasets/{response_dataset_id}")
    # Report progress on every 6th poll (roughly every 30 seconds).
    if checks % 6 == 0:
        print(f"Status: {dataset['status']}, {dataset['datapoints_count']} datapoints")
    checks += 1
    if dataset["status"] != "processing":
        break
    if time.monotonic() >= deadline:
        raise TimeoutError(
            f"Dataset {response_dataset_id} still processing after {MAX_WAIT_SECONDS} seconds"
        )

# NOTE(review): any status other than "processing" is treated as ready here;
# confirm against the API which terminal statuses indicate failure.
print(f"✓ Ready: {dataset['datapoints_count']} datapoints")


## Step 4: Generate User Prompt Safety Dataset


In [None]:
# Pull the user-prompt-safety template and its question/answer formats.
template = TEMPLATES["user_prompt_safety"]
question_format, answer_format = template["question_format"], template["answer_format"]

# Form fields for the synthetic dataset request; each web source is sent
# as an indexed "web_urls[<n>]" field after the fixed fields.
form_data = {
    "project_id": project_id,
    "name": "User Prompt Safety Classification Dataset",
    "pairs_to_generate": "50",
    "pair_type": "qa",
    "temperature": "0.3",
    "rules[]": template["rules"],
    "question_format": question_format,
    "answer_format": answer_format,
    **{f"web_urls[{idx}]": page_url for idx, page_url in enumerate(WEB_URLS)},
}

result = api(
    "/api/v1/public/datasets/create-synthetic",
    method="POST",
    data=form_data,
)
user_dataset_id = result["dataset_id"]
print(f"✓ Dataset created: {user_dataset_id}")


## Step 5: Create Snapshots and Fine-tune

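Once the user prompt dataset has also finished generating (poll it the same way as in Step 3), create snapshots of both datasets and proceed with fine-tuning. See `script.py` or `script.ts` for the complete implementation.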
