# SFW Classifier - Content Moderation Pipeline

In [1]:
!pip install pymongo

import os
from pymongo import MongoClient
import pandas as pd
from google import genai
from google.genai import types
import json
from datetime import datetime
import time

Collecting pymongo
  Downloading pymongo-4.15.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (22 kB)
Collecting dnspython<3.0.0,>=1.16.0 (from pymongo)
  Downloading dnspython-2.8.0-py3-none-any.whl.metadata (5.7 kB)
Downloading pymongo-4.15.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (1.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 40.4 MB/s eta 0:00:00
Downloading dnspython-2.8.0-py3-none-any.whl (331 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 331.1/331.1 kB 11.3 MB/s eta 0:00:00
Installing collected packages: dnspython, pymongo
Successfully installed dnspython-2.8.0 pymongo-4.15.2


## Configuration

In [None]:
# MongoDB Configuration
# Each value is read from the environment, with a placeholder default.
MONGO_URI = os.getenv('MONGO_URI', 'mongodb://localhost:27017/')
DATABASE_NAME = os.getenv('MONGO_DATABASE', 'your_database')
COLLECTION_NAME = os.getenv('MONGO_COLLECTION', 'your_collection')

# Gemini Configuration
# Defaults to an empty string when GEMINI_API_KEY is not set.
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', '')

# Content Moderation System Prompt
SYSTEM_PROMPT = "You are an expert data labeler specializing in content classification. Your task is to analyze conversation snippets and assign soft labels based on the level of sexual or intimate content, providing a probability distribution across a defined set of categories."

# Classification Prompt Template
# This template is rendered with str.format(conversation_history=...), so
# {conversation_history} is its only replacement field. The literal braces of
# the JSON example are doubled ({{ and }}): str.format turns them back into
# single braces, whereas unescaped braces make str.format raise KeyError.
PROMPT_TEMPLATE = """You will be provided with a conversation history. Your task is to analyze the entire conversation for context but focus primarily on the **last message** to classify its content.

**User Input:**
Here is the conversation history:
{conversation_history}

**Classification Scale:**
1.  **SFW:** No sexual or intimate content.
2.  **Mildly Suggestive:** Light romance or flirtation; non-graphic.
3.  **Moderately Intimate:** Erotic situations without explicit sexual acts.
4.  **Explicit NSFW:** Graphic depiction of sexual acts or anatomy.
5.  **Illegal / Extreme NSFW:** Prohibited content (e.g., involving minors, non-consensual acts).

**Output Instructions:**
1.  Your response must be a single JSON object representing a probability distribution (soft labels) of which category the content belongs to.
2.  The keys of the JSON object must be the following strings: "SFW", "Mildly Suggestive", "Moderately Intimate", "Explicit NSFW", "Illegal / Extreme NSFW".
3.  The value for each key must be a floating-point number between 0.0 and 1.0, representing the probability for that category.
4.  The sum of all probability values in the dictionary must equal 1.0.
5.  Your final output must ONLY be the JSON object. Do not include any other text, explanations, or markdown formatting like ```json.

**Example Output Format:**
```json
{{
  "SFW": 0.1,
  "Mildly Suggestive": 0.7,
  "Moderately Intimate": 0.2,
  "Explicit NSFW": 0.0,
  "Illegal / Extreme NSFW": 0.0
}}
```"""

# Processing Configuration
NUM_ROWS = 100000  # Number of rows to fetch from MongoDB
# The output name embeds the timestamp at which this cell was executed.
OUTPUT_FILE = f'sfw_classification_{datetime.now().strftime("%Y%m%d_%H%M%S")}.parquet'

## 1. Connect to MongoDB and Fetch Data

In [None]:
def fetch_data_from_mongodb(uri, db_name, collection_name, limit=100000):
    """Fetch documents from a MongoDB collection into a DataFrame.

    Args:
        uri: Connection string passed to ``MongoClient``.
        db_name: Name of the database to read from.
        collection_name: Name of the collection to read from.
        limit: Value passed to the cursor's ``limit()``.

    Returns:
        A pandas DataFrame built from the fetched documents.
    """
    client = MongoClient(uri)
    try:
        # Materialise the cursor while the connection is still open.
        documents = list(client[db_name][collection_name].find().limit(limit))
    finally:
        # Close the connection even when the query raises.
        client.close()

    # Convert to DataFrame
    df = pd.DataFrame(documents)
    print(f"Fetched {len(df)} rows from MongoDB")
    return df

# Pull the configured number of rows, then preview the first few.
df = fetch_data_from_mongodb(
    uri=MONGO_URI,
    db_name=DATABASE_NAME,
    collection_name=COLLECTION_NAME,
    limit=NUM_ROWS,
)
df.head()

## 2. Prepare Batch Request File (JSONL)
Create a JSONL file containing every request for the Gemini Batch API.

In [None]:
def format_conversation_history(conversation_array, num_messages=6):
    """Render the tail of a conversation as newline-separated ``ROLE: content`` lines.

    Args:
        conversation_array: Sequence of message dicts, each read through its
            ``'role'`` and ``'content'`` keys. ``None`` or a scalar (for
            example the NaN pandas stores for a missing cell) is treated as
            an empty conversation.
        num_messages: Number of trailing messages to keep (default 6).
            Values of zero or less yield an empty string.

    Returns:
        The formatted messages joined with ``"\\n"``, or ``""`` when there is
        nothing to format.
    """
    if num_messages <= 0:
        return ""
    if conversation_array is None or isinstance(conversation_array, (str, bytes, int, float)):
        return ""

    formatted_messages = []
    # A negative slice already copes with conversations shorter than the window.
    for msg in conversation_array[-num_messages:]:
        role = msg.get('role')
        if role is None:
            # Missing key and explicit None are both labelled 'unknown'.
            role = 'unknown'
        content = msg.get('content')
        if content is None:
            content = ''
        formatted_messages.append(f"{str(role).upper()}: {content}")

    # Use newline separator for readability
    return "\n".join(formatted_messages)

def create_batch_requests_file(df, conversation_column, prompt_template, system_prompt, output_file='batch_requests.jsonl'):
    """Write one Gemini batch request per DataFrame row to a JSONL file.

    Args:
        df: DataFrame iterated with ``iterrows()``; each row's index label is
            used in the request key and in the request metadata.
        conversation_column: Column holding each row's conversation messages.
        prompt_template: User-prompt template containing a
            ``{conversation_history}`` placeholder.
        system_prompt: Text sent as the system instruction of every request.
        output_file: Path of the JSONL file to write (opened in ``'w'`` mode).

    Returns:
        ``output_file``.
    """
    # Decide once how to fill the template. str.format raises on templates
    # that contain unescaped literal braces (such as a raw JSON example), so
    # those fall back to a plain placeholder substitution. Templates that
    # already format cleanly keep their exact str.format behaviour.
    try:
        prompt_template.format(conversation_history='')
    except (KeyError, IndexError, ValueError):
        def render_prompt(history):
            return prompt_template.replace('{conversation_history}', history)
    else:
        def render_prompt(history):
            return prompt_template.format(conversation_history=history)

    # Every harm category is sent with its threshold set to "OFF".
    safety_settings = [
        {"category": category, "threshold": "OFF"}
        for category in (
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_HARASSMENT",
        )
    ]

    with open(output_file, 'w', encoding='utf-8') as f:
        for idx, row in df.iterrows():
            conversation_history = format_conversation_history(row[conversation_column])
            user_prompt = render_prompt(conversation_history)

            request = {
                "key": f"request-{idx}",
                "request": {
                    "system_instruction": {
                        "parts": [{"text": system_prompt}]
                    },
                    "contents": [{
                        "parts": [{"text": user_prompt}],
                        "role": "user"
                    }],
                    "generation_config": {
                        "temperature": 0.0,
                        "max_output_tokens": 200,
                        "response_modalities": ["TEXT"],
                        "response_mime_type": "application/json",
                        # thinking_config is a field of the generation config
                        # in the Gemini API, so it is nested here instead of
                        # being sent as a sibling of generation_config.
                        "thinking_config": {
                            "thinking_budget": 0
                        }
                    },
                    "safety_settings": safety_settings
                },
                "metadata": {
                    "original_id": str(row.get('_id', idx)),
                    "index": idx
                }
            }
            f.write(json.dumps(request) + '\n')

    print(f"Created batch requests file: {output_file} with {len(df)} requests")
    return output_file

# Create batch requests (update 'conversation_column' to your actual column name)
# batch_file = create_batch_requests_file(df, conversation_column='conversation_history', prompt_template=PROMPT_TEMPLATE, system_prompt=SYSTEM_PROMPT)

## 3. Submit Batch Job to Gemini API

In [None]:
# Initialize Gemini client
# Module-level client authenticated with GEMINI_API_KEY; the commented-out
# usage further down passes it to submit_batch_job.
client = genai.Client(api_key=GEMINI_API_KEY)

def submit_batch_job(client, batch_file, model="gemini-2.5-flash-preview-09-2025"):
    """Upload a batch request file and create a batch job that reads it.

    Args:
        client: Gemini client exposing ``files.upload`` and ``batches.create``.
        batch_file: Path of the JSONL request file to upload.
        model: Model name, with or without the ``models/`` prefix.

    Returns:
        The job object returned by ``client.batches.create``.
    """
    # Upload the batch requests file
    print(f"Uploading {batch_file}...")
    uploaded_file = client.files.upload(
        file=batch_file,
        config={'display_name': 'sfw-classification-batch', 'mime_type': 'application/jsonl'}
    )
    print(f"Uploaded file: {uploaded_file.name}")

    # Accept an already-qualified name so callers passing "models/<name>" do
    # not end up with a doubled "models/models/<name>" prefix.
    model_name = model if model.startswith("models/") else f"models/{model}"

    # Create batch job
    print("Creating batch job...")
    batch_job = client.batches.create(
        model=model_name,
        src=uploaded_file.name,
        config={
            'display_name': f"sfw-classifier-{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        }
    )

    print(f"Batch job created: {batch_job.name}")
    print(f"Job state: {batch_job.state.name}")
    print(f"\n⚠️ SAVE THIS JOB ID: {batch_job.name}")

    return batch_job

# Submit batch job
# batch_job = submit_batch_job(client, batch_file)
# JOB_ID = batch_job.name  # Save this for later!

## 4. Monitor Batch Job Status

In [None]:
# If you need to reconnect, reinitialize the client
# client = genai.Client(api_key=GEMINI_API_KEY)

def check_batch_status(job_id, api_key=None):
    """Fetch a batch job by ID, print a short summary of it, and return it.

    ``api_key`` falls back to the module-level ``GEMINI_API_KEY`` when omitted.
    """
    key = GEMINI_API_KEY if api_key is None else api_key
    job = genai.Client(api_key=key).batches.get(name=job_id)

    print(f"Job: {job.name}")
    print(f"State: {job.state.name}")
    print(f"Create time: {job.create_time}")

    # update_time is only reported when the job object carries that attribute.
    if not hasattr(job, 'update_time'):
        return job
    print(f"Update time: {job.update_time}")
    return job

def wait_for_batch_completion(job_id, api_key=None, check_interval=60):
    """Poll a batch job until it reaches a terminal state.

    Args:
        job_id: Name of the batch job to poll.
        api_key: Gemini API key; defaults to the module-level ``GEMINI_API_KEY``.
        check_interval: Seconds to sleep between two status checks.

    Returns:
        The job object as last fetched, whether it succeeded or not; callers
        should inspect ``job.state.name``.
    """
    if api_key is None:
        api_key = GEMINI_API_KEY

    client = genai.Client(api_key=api_key)

    # Terminal states other than success. JOB_STATE_EXPIRED is included so an
    # expired job ends the loop instead of being polled forever.
    unsuccessful_states = {
        'JOB_STATE_FAILED',
        'JOB_STATE_CANCELLED',
        'JOB_STATE_EXPIRED',
    }

    while True:
        job = client.batches.get(name=job_id)
        state = job.state.name
        print(f"Current state: {state}")

        if state == 'JOB_STATE_SUCCEEDED':
            print("Batch job completed successfully!")
            return job
        if state in unsuccessful_states:
            print(f"Batch job {state}")
            return job

        print(f"Waiting {check_interval} seconds before next check...")
        time.sleep(check_interval)

# Check status using saved job ID
# JOB_ID = "your-job-id-here"  # Paste your job ID here
# check_batch_status(JOB_ID)

# Or wait for completion
# completed_job = wait_for_batch_completion(JOB_ID)

## 5. Retrieve and Process Results

In [None]:
def download_batch_results(job, api_key=None):
    """Download the result file of a succeeded batch job and parse it as JSONL.

    ``job`` may be a job object or a job ID string; ``api_key`` falls back to
    the module-level ``GEMINI_API_KEY``. Returns a list with one parsed object
    per non-blank line, or ``None`` when the job has not succeeded.
    """
    key = GEMINI_API_KEY if api_key is None else api_key
    client = genai.Client(api_key=key)

    # A bare job ID is resolved to the job object first.
    if isinstance(job, str):
        job = client.batches.get(name=job)

    if job.state.name != 'JOB_STATE_SUCCEEDED':
        print(f"Job not successful. Current state: {job.state.name}")
        return None

    # The result file name is read from job.output when that attribute
    # exists, otherwise from job.dest.
    if hasattr(job, 'output'):
        result_file_name = job.output.file_name
    else:
        result_file_name = job.dest.file_name

    print(f"Downloading results from: {result_file_name}")
    file_content = client.files.download(file=result_file_name).decode('utf-8')

    # One JSON document per line; blank lines are ignored.
    results = [
        json.loads(line)
        for line in file_content.splitlines()
        if line.strip()
    ]

    print(f"Downloaded {len(results)} results")
    return results

def extract_last_messages(conversation_array, num_messages=6):
    """Render the last ``num_messages`` messages as ``ROLE: content`` blocks.

    Args:
        conversation_array: Sequence of message dicts, each read through its
            ``'role'`` and ``'content'`` keys. ``None`` or a scalar (for
            example the NaN pandas stores for a missing cell) is treated as
            an empty conversation.
        num_messages: Number of trailing messages to keep. Values of zero or
            less yield an empty string.

    Returns:
        The formatted messages joined with ``"\\n---\\n"``, or ``""`` when
        there is nothing to format.
    """
    # Guard first: seq[-0:] is the whole sequence, so a zero count would
    # otherwise return every message instead of none.
    if num_messages <= 0:
        return ""
    if conversation_array is None or isinstance(conversation_array, (str, bytes, int, float)):
        return ""

    formatted = []
    # A negative slice already copes with conversations shorter than the window.
    for msg in conversation_array[-num_messages:]:
        role = msg.get('role')
        if role is None:
            # Missing key and explicit None are both labelled 'unknown'.
            role = 'unknown'
        content = msg.get('content')
        if content is None:
            content = ''
        formatted.append(f"{str(role).upper()}: {content}")
    return "\n---\n".join(formatted)

# Maps each soft-label key of the model's JSON answer to its output column.
_LABEL_COLUMNS = {
    'SFW': 'prob_sfw',
    'Mildly Suggestive': 'prob_mildly_suggestive',
    'Moderately Intimate': 'prob_moderately_intimate',
    'Explicit NSFW': 'prob_explicit_nsfw',
    'Illegal / Extreme NSFW': 'prob_illegal_extreme',
}

# Column order of the returned DataFrame; also used when there are no rows.
_RESULT_COLUMNS = [
    'request_key',
    'original_id',
    'index',
    'last_6_messages',
    *_LABEL_COLUMNS.values(),
    'raw_response',
    'status',
]


def _first_candidate_text(result):
    """Return the stripped text of the first part of the first candidate, or None."""
    response = result.get('response') or {}
    candidates = response.get('candidates') or []
    if not candidates:
        return None
    parts = (candidates[0].get('content') or {}).get('parts') or []
    if not parts:
        return None
    return (parts[0].get('text') or '').strip()


def _parse_soft_labels(raw_response):
    """Return the JSON object encoded in ``raw_response``, or None if there is none."""
    if not raw_response:
        return None
    text = raw_response
    if text.startswith("```"):
        # Drop a markdown fence: the opening ``` / ```json line and the closing ```.
        text = text.split("\n", 1)[1] if "\n" in text else ""
        text = text.rsplit("```", 1)[0]
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return None
    # Valid JSON that is not an object (a list, a number, ...) carries no labels.
    return parsed if isinstance(parsed, dict) else None


def _index_from_key(request_key):
    """Recover the row index from a ``request-<idx>`` key, or None if it does not match."""
    prefix = "request-"
    if not isinstance(request_key, str) or not request_key.startswith(prefix):
        return None
    try:
        return int(request_key[len(prefix):])
    except ValueError:
        return None


def parse_results_to_dataframe(results, original_df, conversation_column='conversation_history'):
    """Convert batch results to a DataFrame of soft-label probabilities.

    Args:
        results: Iterable of parsed result objects (one per request), or
            ``None``/empty, in which case an empty DataFrame is returned.
        original_df: DataFrame the requests were built from; the row index is
            used positionally to look up the conversation.
        conversation_column: Column of ``original_df`` holding the messages.

    Returns:
        A DataFrame sorted by ``index`` with the request key, original id,
        row index, the last 6 messages, one probability column per label
        (``None`` when the answer could not be parsed), the raw response text
        and the result status.
    """
    parsed_results = []

    for result in results or []:
        request_key = result.get('key', '')
        metadata = result.get('metadata') or {}

        # Prefer the index carried in the metadata; when it is absent, fall
        # back to the index embedded in the "request-<idx>" key.
        idx = metadata.get('index')
        if idx is None:
            idx = _index_from_key(request_key)

        # Extract response text (JSON soft labels)
        raw_response = _first_candidate_text(result)
        soft_labels = _parse_soft_labels(raw_response) or {}

        # Get last 6 messages from original data
        last_6_messages = ""
        if isinstance(idx, int) and 0 <= idx < len(original_df):
            conversation_array = original_df[conversation_column].iloc[idx]
            last_6_messages = extract_last_messages(conversation_array, num_messages=6)

        record = {
            'request_key': request_key,
            'original_id': metadata.get('original_id'),
            'index': idx,
            'last_6_messages': last_6_messages,
            'raw_response': raw_response,
            'status': result.get('status', {}),
        }
        # Extract individual probabilities
        for label, column in _LABEL_COLUMNS.items():
            record[column] = soft_labels.get(label)
        parsed_results.append(record)

    # Passing the columns explicitly keeps the schema stable for empty input,
    # so the sort below cannot fail on a missing 'index' column.
    results_df = pd.DataFrame(parsed_results, columns=_RESULT_COLUMNS)
    results_df = results_df.sort_values('index').reset_index(drop=True)

    return results_df

# Download and parse results using job ID
# JOB_ID = "your-job-id-here"
# results = download_batch_results(JOB_ID)
# results_df = parse_results_to_dataframe(results, df, conversation_column='conversation_history')
# results_df.head()

# View probability distributions
# print(results_df[['prob_sfw', 'prob_mildly_suggestive', 'prob_moderately_intimate', 'prob_explicit_nsfw', 'prob_illegal_extreme']].describe())

## 6. Save Results to a Flat File

In [None]:
# Save to Parquet (recommended - handles text data well)
# columns_to_save = ['last_6_messages', 'prob_sfw', 'prob_mildly_suggestive', 'prob_moderately_intimate', 'prob_explicit_nsfw', 'prob_illegal_extreme']
# results_df[columns_to_save].to_parquet(OUTPUT_FILE, index=False, compression='snappy')
# print(f"Results saved to {OUTPUT_FILE}")

# Alternative: Save to JSONL (preserves structure better)
# results_df[columns_to_save].to_json(OUTPUT_FILE.replace('.parquet', '.jsonl'), orient='records', lines=True)

# View sample output
# print(results_df[columns_to_save].head())