In [None]:
import json
import os
import time
from dataclasses import dataclass
from getpass import getpass
from typing import List, Dict, Any

import pandas as pd
from datasets import Dataset
from google import genai
from google.genai import types

In [None]:
# Upload the CSV of your assigned chunk (all chunks are available on the Google Drive).

In [None]:
# Never paste the API key into the notebook source: read it from the
# GEMINI_API_KEY environment variable, and fall back to a hidden prompt so the
# secret is not saved (or shared) together with the file.
KEY = os.environ.get("GEMINI_API_KEY") or getpass("Gemini API key: ")
client = genai.Client(api_key=KEY)

In [None]:
# Load the assigned chunk; change the filename to match your own chunk.
df = pd.read_csv(filepath_or_buffer='reddit_posts_chunk_3.csv')

In [None]:
# Sanity check: dimensions of the loaded chunk as (rows, columns).
df.shape

(2500, 1)

In [None]:
# Keep only the first 1000 posts of the chunk, as a plain Python list.
reddit_posts = df['reddit_posts'].iloc[:1000].tolist()

In [None]:
# Confirm how many posts were selected for classification.
len(reddit_posts)

1000

In [None]:
# Label set offered to the model in the classification prompt.
categories = [
    'Security & Fraud',
    'Online Banking',
    'ATM Service',
    'Account Opening',
    'Rewards & Offers',
    'Website Experience',
    'Fees & Charges',
    'International Transfers',
    'Interest Rates',
    'App Performance',
    'Customer Service',
    'Loan Process',
    'Notifications',
    'Credit Card',
    'Branch Experience',
]

In [None]:
@dataclass
class RedditClassification:
    """One classified Reddit post: a category name plus the post's text."""
    category: str   # category name assigned to the post
    post_text: str  # text of the post that was classified

def get_reddit_schema() -> Dict[str, Any]:
    """Build the JSON Schema for the model's response.

    The schema describes an array of objects, each carrying two required
    string fields, ``category`` and ``post_text`` (the same two fields as
    ``RedditClassification``).

    Returns:
        A newly built schema dictionary; every call returns a fresh object.
    """
    category_field = {
        "type": "string",
        "description": "The classified category name for the Reddit post.",
    }
    post_text_field = {
        "type": "string",
        "description": "The original text of the classified post.",
    }

    # One array element == one classified post.
    classification_object = {
        "type": "object",
        "properties": {
            "category": category_field,
            "post_text": post_text_field,
        },
        "required": ["category", "post_text"],
    }

    return {"type": "array", "items": classification_object}

In [None]:
def classify_posts_batch(posts: List[str]) -> List[RedditClassification]:
    """
    Classifies a batch of Reddit posts using a single API call.

    Args:
        posts: Post texts to classify together in one request.

    Returns:
        One ``RedditClassification`` per input post, in the same order as
        ``posts``. An empty list is returned when ``posts`` is empty, when the
        API call or JSON parsing fails, or when the model does not return
        exactly one item per post. A caller can therefore detect a failed
        batch with ``len(result) != len(posts)``.

    Errors are printed rather than raised, so one bad batch does not abort a
    long run.
    """
    # Nothing to classify: skip the API round-trip entirely.
    if not posts:
        return []

    # 1. Number each post and fence it in triple quotes so the model can tell
    #    where one post ends and the next begins.
    posts_formatted = '\n'.join(
        f"- Post {i+1}: \"\"\"{post}\"\"\"" for i, post in enumerate(posts)
    )

    # 2. Ask for a list of classifications. Callers pair results with posts by
    #    position, so the prompt asks for one object per post, in order.
    prompt = f"""
You are an expert Reddit post classifier.

Classify EACH of the following Reddit posts into one of these categories.
Return a JSON list where each object in the list contains the 'category' and the original 'post_text'.
Return exactly one object per post, in the same order as the posts are listed below.

Available Categories:
{categories}

Reddit Posts to Classify:
{posts_formatted}
"""

    try:
        # 3. Start from the explicit dictionary schema and restrict 'category'
        #    to the known labels so the model cannot invent new ones.
        response_schema = get_reddit_schema()
        response_schema["items"]["properties"]["category"]["enum"] = list(categories)

        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=prompt,
            config={
                "response_mime_type": "application/json",
                "response_schema": response_schema,
            },
        )

        # The schema is a plain dictionary, so the JSON text is parsed by hand:
        # the response text should be a JSON list of dictionaries.
        json_output = json.loads(response.text)

        # Convert the list of dicts into dataclass objects. Fields are read by
        # name so an unexpected extra key does not discard the whole batch.
        parsed_results = [
            RedditClassification(
                category=item["category"],
                post_text=item["post_text"],
            )
            for item in json_output
        ]

    except Exception as e:
        print(f"Error classifying batch: {e}")
        return []

    # A wrong item count cannot be paired with the input posts by position,
    # so report it as a failed batch instead of returning shifted labels.
    if len(parsed_results) != len(posts):
        print(
            f"Error classifying batch: expected {len(posts)} classifications, "
            f"got {len(parsed_results)}"
        )
        return []

    return parsed_results

In [None]:
# Classify all posts batch by batch. `results` must end up with exactly one
# entry per post, in order, because the labeled table is built by position.
results: List[RedditClassification] = []
batch_size = 50
max_attempts = 3      # tries per batch before giving up on it
failed_batches = []   # 1-based numbers of batches that never classified cleanly

print(f"Starting classification of {len(reddit_posts)} posts in batches of {batch_size}...")

for i in range(0, len(reddit_posts), batch_size):
    batch = reddit_posts[i:i+batch_size]
    batch_number = i//batch_size + 1
    print(f"Processing batch {batch_number}: posts {i+1} to {i + len(batch)}")

    # Process the entire batch in a single API call, retrying when the call
    # fails or returns the wrong number of items.
    batch_results = []
    for attempt in range(1, max_attempts + 1):
        batch_results = classify_posts_batch(batch)
        if len(batch_results) == len(batch):
            break
        print(
            f"  Attempt {attempt}/{max_attempts}: got {len(batch_results)} "
            f"results for {len(batch)} posts"
        )
        if attempt < max_attempts:
            time.sleep(2 ** attempt)  # back off before retrying
    else:
        # Every attempt failed: insert aligned placeholders so later posts
        # keep their own labels, and remember the batch for the final report.
        failed_batches.append(batch_number)
        batch_results = [
            RedditClassification(category="UNCLASSIFIED", post_text=post)
            for post in batch
        ]

    results.extend(batch_results)

    # A small delay is still good practice for very large workloads
    time.sleep(0.5)

if failed_batches:
    print(f"Classification finished, but these batches are UNCLASSIFIED: {failed_batches}")
else:
    print("Classification complete!")

Starting classification of 1000 posts in batches of 50...
Processing batch 1: posts 1 to 50
Processing batch 2: posts 51 to 100
Processing batch 3: posts 101 to 150
Processing batch 4: posts 151 to 200
Processing batch 5: posts 201 to 250
Processing batch 6: posts 251 to 300
Processing batch 7: posts 301 to 350
Processing batch 8: posts 351 to 400
Processing batch 9: posts 401 to 450
Processing batch 10: posts 451 to 500
Processing batch 11: posts 501 to 550
Processing batch 12: posts 551 to 600
Processing batch 13: posts 601 to 650
Processing batch 14: posts 651 to 700
Processing batch 15: posts 701 to 750
Processing batch 16: posts 751 to 800
Processing batch 17: posts 801 to 850
Processing batch 18: posts 851 to 900
Processing batch 19: posts 901 to 950
Processing batch 20: posts 951 to 1000
Classification complete!


In [None]:
# Should equal len(reddit_posts): one classification per post.
len(results)

1000

In [None]:
# Peek at the raw chunk before `df` is replaced by the labeled table below.
df.head()

Unnamed: 0,reddit_posts
0,My bills/expenses amount to around $1400/mo. I...
1,My wife and I have recently transitioned to an...
2,I recently got an alert from my bank warning m...
3,I was talking to someone about this and I'm no...
4,im planning to have a credit card (i have debi...


In [None]:
# Posts and predicted categories are paired purely by position, so the two
# lists must have the same length. Fail with an actionable message if a batch
# was lost instead of relying on a generic pandas length error.
if len(results) != len(reddit_posts):
    raise ValueError(
        f"Expected {len(reddit_posts)} classifications but got {len(results)}; "
        "at least one batch failed or returned the wrong number of items. "
        "Re-run the classification cell before building the labeled table."
    )

results2 = [r.category for r in results]

# NOTE: this rebinds `df` from the raw chunk to the labeled table.
df = pd.DataFrame({
    "post": reddit_posts,
    "category": results2
})

In [None]:
# Dimensions of the labeled table as (rows, columns).
df.shape

(1000, 2)

In [None]:
# Colab only: mount Google Drive at /content/drive so the labeled CSV can be
# written there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Write the labeled table to Drive without the index column.
# Rename the file to match your own chunk.
df.to_csv(path_or_buf='/content/drive/MyDrive/reddit_labeled_chunk3.csv', index=False)