# LLM-based Filtering for Reddit Stories

Uses the Groq API with Llama 3.1-8B to classify stories as KEEP (ridiculous, humorous, absurd) or REMOVE (depressing, serious, harmful).

Features:
- Resumable (saves progress every 100 rows)
- Cloud-based inference (won't crash your computer!)
- Progress tracking
- Free tier: 1,000 requests/day, 30 requests/minute
- Estimated time: ~3 days for 2,611 stories (bounded by the 1,000 requests/day cap, not by the per-minute rate)

In [None]:
# Install groq (uncomment if needed)
# !pip install groq

In [None]:
# Imports
import pandas as pd
from pathlib import Path
import sys
from tqdm import tqdm
import json
import time
from groq import Groq

# Add parent directory to path
sys.path.append(str(Path.cwd().parent))
from config import RAW_DATA_DIR

print(f"Data directory: {RAW_DATA_DIR}")

# Initialize Groq client
# Get your free API key from: https://console.groq.com/keys
# The key is read with getpass rather than input() so that the secret is not
# echoed into the cell output as it is typed.
from getpass import getpass

# Strip stray whitespace/newlines that often come along with a pasted key.
GROQ_API_KEY = getpass("Enter your Groq API key: ").strip()
client = Groq(api_key=GROQ_API_KEY)

## Setup Groq API

1. Get your free API key: https://console.groq.com/keys
2. Enter it when prompted by the code cell above
3. Free tier: 1,000 requests/day, 30 requests/minute

In [None]:
# Test Groq connection: send one tiny prompt and report whether a reply came back.
try:
    response = client.chat.completions.create(
        model="llama-3.1-8b-instant",
        messages=[{"role": "user", "content": 'Say "OK" if you can read this.'}],
        max_tokens=10
    )
    print("✅ Groq is working!")
    print(f"Response: {response.choices[0].message.content}")
except Exception as e:
    # Any failure here is reported instead of raised so the cell always
    # finishes with a readable hint for the user.
    print(f"❌ Error: {e}")
    print("Make sure you entered a valid API key")

## Load Data

In [None]:
# Load the confession dataset
confession_file = RAW_DATA_DIR / "reddit_confession.parquet"

# Fail early with the full path so a misconfigured RAW_DATA_DIR is obvious.
if not confession_file.exists():
    raise FileNotFoundError(f"Cannot find confession data. Looking for: {confession_file}")

df = pd.read_parquet(confession_file)
print(f"Loaded {len(df)} rows")
print(f"Columns: {list(df.columns)}")

# Preview the first row only when there is one: df.iloc[0] raises IndexError
# on an empty frame.
if df.empty:
    print("\nDataset is empty - nothing to preview")
else:
    print("\nFirst row:")
    print(df.iloc[0])

## Define Classification Function

In [None]:
def classify_story(text, title=""):
    """
    Classify a story as KEEP or REMOVE using Groq API.

    Builds a few-shot prompt around the story, sends it to the
    ``llama-3.1-8b-instant`` model through the module-level ``client`` and
    reduces the reply to one of two labels.

    Args:
        text: Story text
        title: Story title (optional)

    Returns:
        'KEEP' only when the reply contains KEEP and does not contain REMOVE.
        'REMOVE' in every other case: the reply contains REMOVE, the reply is
        empty or unrecognised, or the API call raised. In the last case the
        error is printed and not re-raised, so a caller cannot tell a failed
        request apart from a genuine REMOVE.
    """

    prompt = f"""You are a content filter. Classify this Reddit story as KEEP or REMOVE.

REMOVE stories that contain:
- Death, dying, serious illness, hospitals, medical emergencies
- Violence, abuse, assault, harm to people or animals
- Suicide, self-harm, depression, serious mental health issues
- Drug addiction, alcoholism, substance abuse, DUI, rehab
- Guilt, regret, shame, "still think about it", feeling terrible
- Cheating, divorce, serious relationship problems
- Getting people in trouble, suspended, fired, arrested
- Trauma, PTSD, haunting memories
- Anything genuinely sad, upsetting, or disturbing
- Serious consequences or people getting hurt

KEEP stories that are:
- Genuinely funny, silly, or absurdly ridiculous
- Harmless embarrassing moments
- Lighthearted chaos with no real harm
- Weird situations that are entertaining
- No guilt, no regret, no one gets hurt

Examples:

Story: "I shot up in the hospital while my mom was on life support"
Classification: REMOVE (death, hospital, serious)

Story: "I played pedestrian chicken while driving drugged"
Classification: REMOVE (drugs, DUI, could have killed people)

Story: "I outed a teacher, got someone else suspended, still think about it and feel terrible"
Classification: REMOVE (guilt, regret, got someone suspended, serious consequences)

Story: "I accidentally walked into the wrong apartment and sat on someone's couch before realizing"
Classification: KEEP (harmless funny mistake)

Now classify this story:

Title: {title}
Story: {text}

Think step by step:
1. Does anyone get seriously hurt or could have been hurt?
2. Is there guilt, regret, or serious consequences?
3. Is it genuinely funny and lighthearted, or is it sad/disturbing?

Answer with ONLY the word: KEEP or REMOVE"""

    try:
        response = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
            max_tokens=20
        )

        # content may be None; treat that like an empty reply so it falls
        # through to the REMOVE default below.
        raw_reply = response.choices[0].message.content or ""
        result = raw_reply.strip().upper()

        # REMOVE is checked first so that a reply mentioning both words
        # resolves to the conservative label, in line with the REMOVE default.
        if 'REMOVE' in result:
            return 'REMOVE'
        if 'KEEP' in result:
            return 'KEEP'
        return 'REMOVE'  # Default to REMOVE

    except Exception as e:
        # Best-effort: report the failure and fall back to the conservative label.
        print(f"Error classifying story: {e}")
        return 'REMOVE'

## Test Classification on Sample

In [None]:
# Spot-check the classifier on rows 20-39. The upper bound is clamped to the
# frame length so a dataset with fewer than 40 rows does not raise IndexError
# (with fewer than 21 rows the range is simply empty).
for idx in range(20, min(40, len(df))):
    row = df.iloc[idx]
    title = row.get('title', '')
    text = row.get('text', '')

    result = classify_story(text, title)

    print(f"Row {idx}:")
    print(f"Title: {title}")
    print(f"Text: {text}")
    print(f"Result: {result}")
    print("-" * 80 + "\n")

## Run Classification on Full Dataset (Resumable)

In [None]:
# Setup for resumable processing
checkpoint_file = RAW_DATA_DIR / "llm_filter_checkpoint.json"
checkpoint_interval = 100  # Save every 100 rows


def _save_checkpoint(last_processed):
    """Write the labels collected so far and the last finished row position to checkpoint_file."""
    checkpoint_data = {
        'last_processed': last_processed,
        'classifications': classifications
    }
    with open(checkpoint_file, 'w') as f:
        json.dump(checkpoint_data, f)


# Add classification column if not exists
if 'llm_label' not in df.columns:
    df['llm_label'] = None
# Rows are addressed by position (df.iloc) throughout, so labels are written
# by position too; df.at would treat idx as an index *label* instead.
label_col = df.columns.get_loc('llm_label')

# Load checkpoint if exists
if checkpoint_file.exists():
    with open(checkpoint_file, 'r') as f:
        checkpoint_data = json.load(f)
    # JSON stores object keys as strings; turn them back into integer positions
    # so resumed and newly added entries share one key type.
    classifications = {
        int(pos): saved_label
        for pos, saved_label in checkpoint_data['classifications'].items()
    }
    start_idx = checkpoint_data['last_processed'] + 1
    # Put the saved labels back into the frame; without this a resumed run
    # would leave every previously classified row unlabeled.
    for pos, saved_label in classifications.items():
        if 0 <= pos < len(df):
            df.iat[pos, label_col] = saved_label
    print(f"Resuming from row {start_idx} (found checkpoint)")
else:
    classifications = {}
    start_idx = 0
    print("Starting from beginning")

# Process all rows with rate limiting
# NOTE: the estimate below only accounts for the 2-second delay per request.
print(f"\nProcessing {len(df) - start_idx} rows...")
print(f"Rate limit: 30 requests/minute (~2 second delay between requests)")
print(f"Estimated time: ~{((len(df) - start_idx) * 2) / 60:.1f} minutes\n")

# Position of the last row whose label has actually been recorded. Tracked
# explicitly because the loop variable is undefined (or stale from an earlier
# cell) if the interrupt arrives before the first iteration.
last_done = start_idx - 1

try:
    for idx in tqdm(range(start_idx, len(df)), desc="Classifying stories"):
        row = df.iloc[idx]
        title = row.get('title', '')
        text = row.get('text', '')

        # Classify
        label = classify_story(text, title)
        classifications[idx] = label
        df.iat[idx, label_col] = label
        last_done = idx

        # Rate limiting: 30 req/min = 2 seconds between requests
        time.sleep(2)

        # Save checkpoint every N rows
        if (idx + 1) % checkpoint_interval == 0:
            _save_checkpoint(last_done)

except KeyboardInterrupt:
    print("\n\nInterrupted! Saving checkpoint...")
    _save_checkpoint(last_done)
    print(f"Checkpoint saved. Processed up to row {last_done}")
else:
    # Persist the tail of the run (rows after the last interval checkpoint)
    # and announce completion only when the loop really finished.
    _save_checkpoint(last_done)
    print("\n✅ Classification complete!")

## Show Results

In [None]:
# Count results
keep_count = (df['llm_label'] == 'KEEP').sum()
remove_count = (df['llm_label'] == 'REMOVE').sum()
# Rows that never received a label (e.g. the run was interrupted).
unlabeled_count = df['llm_label'].isna().sum()

# Guard the percentage maths against an empty frame.
total_rows = len(df)
keep_pct = keep_count / total_rows * 100 if total_rows else 0.0
remove_pct = remove_count / total_rows * 100 if total_rows else 0.0

print("\n=== Classification Results ===")
print(f"Total stories: {total_rows}")
print(f"KEEP: {keep_count} ({keep_pct:.1f}%)")
print(f"REMOVE: {remove_count} ({remove_pct:.1f}%)")
if unlabeled_count:
    print(f"Unlabeled: {unlabeled_count}")

# Show some examples of KEEP stories
print("\n=== Sample KEEP Stories ===")
keep_stories = df[df['llm_label'] == 'KEEP'].head(5)
for _, row in keep_stories.iterrows():
    # A missing text value (None/NaN) cannot be sliced; show it as empty.
    story_text = row.get('text', '')
    if not isinstance(story_text, str):
        story_text = ''
    print(f"\nTitle: {row.get('title', 'N/A')}")
    print(f"Text: {story_text[:200]}...")
    print("-" * 80)

## Save Filtered Dataset

In [None]:
# Filter to only KEEP stories
filtered_df = df[df['llm_label'] == 'KEEP'].copy()

# Drop the classification column (optional)
filtered_df = filtered_df.drop(columns=['llm_label'])

# Save
output_file = RAW_DATA_DIR / "reddit_confession_filtered.parquet"
filtered_df.to_parquet(output_file, index=False)

print(f"\n✅ Saved {len(filtered_df)} filtered stories to:")
print(f"   {output_file}")

# Clean up checkpoint - but only once every row carries a label. Deleting it
# while rows are still unlabeled would throw away the ability to resume.
unlabeled_rows = int(df['llm_label'].isna().sum())
if unlabeled_rows:
    print(f"\n{unlabeled_rows} rows are still unlabeled - keeping checkpoint file so the run can be resumed")
elif checkpoint_file.exists():
    checkpoint_file.unlink()
    print("\n🧹 Cleaned up checkpoint file")