 # Step 2C: Refine Labels with LLM Classifier



 Validate naive model predictions using GPT-4.1-mini:



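 1. Select: Load the highest-confidence examples per class, ranked by predicted probability (up to 5000 per class for a full run; fewer for a quick, low-cost test run)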

 2. Classify: Use GPT-4.1-mini to reclassify each tweet (single-word response)

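 3. Compare: Treating the LLM label as the reference and 'sarcastic' as the positive class, mark each example as a true/false positive/negative (examples without a usable LLM label are marked 'Other')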

 4. Save: Store all examples with their agreement status (TP/TN/FP/FN/Other) to 'llm_refined_labels.csv'



 This step helps validate the quality of our naive classifier's predictions.

In [1]:
# Import necessary libraries
import os
import json
import pandas as pd
import openai
from tqdm import tqdm
import time


In [2]:
from google.colab import userdata # Colab-only: gives access to the notebook's stored secrets

# Load OpenAI API key from Colab secrets (secret name 'OpenAI_API_Key'),
# so the key is never hard-coded in the notebook.
openai.api_key = userdata.get('OpenAI_API_Key')

# Initialize OpenAI client with the same key; `client` is the global used by classify_with_llm.
client = openai.OpenAI(api_key=openai.api_key)


In [3]:
# Load high-confidence predictions
HIGH_CONF_CSV = 'high_confidence_real_examples.csv'  # naive-classifier high-confidence predictions
high_conf = pd.read_csv(HIGH_CONF_CSV)


In [4]:
# Select top 5000 for each class by probability
# Number of most-confident examples kept per predicted class. The intended full
# run uses 5000 per class; the outputs recorded in this notebook were produced
# with 50 per class (100 tweets total) to keep the API cost low while testing.
N_PER_CLASS = 50

# Rank each predicted class by its own probability column (prob_0 for 'literal',
# prob_1 for 'sarcastic') and keep the top N_PER_CLASS rows of each.
literal_top = high_conf[high_conf['pred_label'] == 'literal'].nlargest(N_PER_CLASS, 'prob_0')
sarcastic_top = high_conf[high_conf['pred_label'] == 'sarcastic'].nlargest(N_PER_CLASS, 'prob_1')
top_examples = pd.concat([literal_top, sarcastic_top], ignore_index=True)


In [5]:
# Function to classify with LLM (GPT-4.1-mini)
def classify_with_llm(text, max_retries=3, wait_time=0.5, max_tokens=1):
    """Classify one tweet as 'literal' or 'sarcastic' with GPT-4.1-mini.

    Uses the module-level OpenAI ``client``. Exceptions raised while calling
    the API or reading its response are retried with exponential backoff.

    Args:
        text: Tweet text. Missing values (None or float NaN, which is what
            ``pd.read_csv`` yields for empty cells) return None without
            making an API call.
        max_retries: Total number of attempts before giving up.
        wait_time: Seconds to wait after the first failed attempt; doubled
            after every further failure.
        max_tokens: Completion-token cap sent to the API. The default of 1
            keeps cost minimal; the substring matching below is what lets a
            truncated reply (e.g. 'sarc') still map to a label.

    Returns:
        'literal' or 'sarcastic' when the reply can be mapped to a label;
        otherwise the raw stripped, lower-cased reply (possibly ''); or None
        when ``text`` is missing or every attempt raised an exception.
    """
    import math  # local import keeps this cell self-contained

    # Don't pay for a request about the literal string "nan".
    if text is None or (isinstance(text, float) and math.isnan(text)):
        return None

    prompt = (
        "Classify the following tweet as either 'literal' or 'sarcastic'. "
        "Respond with only one word: literal or sarcastic.\n"
        f"Tweet: {text}\nLabel:"
    )
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model="gpt-4.1-mini",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant that classifies tweets as literal or sarcastic. Respond with only one word: literal or sarcastic."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0,
                max_tokens=max_tokens
            )
            # `content` may be null in the API response (e.g. a refusal). Treat it
            # as an empty reply instead of raising AttributeError, which would
            # trigger pointless retries of a non-transient condition.
            label = (response.choices[0].message.content or '').strip().lower()
            # Map the (possibly truncated) reply onto one of the two labels.
            if 'lit' in label:
                return 'literal'
            elif 'sar' in label:
                return 'sarcastic'
            else:
                return label
        except Exception as e:
            if attempt < max_retries - 1:
                print(f"Attempt {attempt + 1} failed: {str(e)}. Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
                wait_time *= 2
            else:
                print(f"All {max_retries} attempts failed. Last error: {str(e)}")
                return None
    # Only reached when max_retries <= 0 (no attempt was made).
    return None


In [6]:
# Apply LLM classification (this may take a while and cost tokens)
# One LLM label per tweet (None when every API attempt failed). Each call is a
# paid request; the recorded run processed roughly two tweets per second.
llm_labels = [
    classify_with_llm(tweet)
    for tweet in tqdm(top_examples['text'], desc='Classifying with LLM')
]
top_examples['llm_label'] = llm_labels



Classifying with LLM: 100%|██████████| 100/100 [00:51<00:00,  1.94it/s]


In [7]:
# Identify TP, TN, FP, FN, Other
def get_tp_tn(row):
    """Tag one row by how the naive prediction compares with the LLM label.

    The LLM label is treated as the reference and 'sarcastic' as the positive
    class:
      TP    both say 'sarcastic'
      TN    both say 'literal'
      FP    naive 'sarcastic', LLM 'literal'
      FN    naive 'literal',   LLM 'sarcastic'
      Other any other combination (e.g. the LLM label is None or unrecognised)
    """
    naive = row['pred_label']
    if naive == 'sarcastic':
        reference = row['llm_label']
        if reference == 'sarcastic':
            return 'TP'
        if reference == 'literal':
            return 'FP'
    elif naive == 'literal':
        reference = row['llm_label']
        if reference == 'literal':
            return 'TN'
        if reference == 'sarcastic':
            return 'FN'
    return 'Other'

# Tag every row (TP/TN/FP/FN/Other) and store the result as a new 'llm_agreement' column of top_examples.
top_examples['llm_agreement'] = top_examples.apply(get_tp_tn, axis=1)


Uncomment the lines in the cell below to write all labelled examples (including their agreement status) to `llm_refined_labels.csv`.

In [8]:
# Save all examples (not just TP/TN)
#output_path = 'llm_refined_labels.csv'
#top_examples.to_csv(output_path, index=False)
#print(f'All LLM-refined labels saved to {output_path}')


All LLM-refined labels saved to llm_refined_labels.csv


In [9]:
# Print summary: true/false sarcastic and literal
num_true_sarcastic = (top_examples['llm_agreement'] == 'TP').sum()
num_false_sarcastic = (top_examples['llm_agreement'] == 'FP').sum()
num_true_literal = (top_examples['llm_agreement'] == 'TN').sum()
num_false_literal = (top_examples['llm_agreement'] == 'FN').sum()
# 'Other' = the LLM call failed (None) or returned an unrecognised label.
# Report it so the four counts above visibly account for every example.
num_other = (top_examples['llm_agreement'] == 'Other').sum()
print(f"True sarcastic: {num_true_sarcastic} (false sarcastic: {num_false_sarcastic})")
print(f"True literal: {num_true_literal} (false literal: {num_false_literal})")
print(f"Unresolved (no usable LLM label): {num_other} of {len(top_examples)}")


True sarcastic: 45 (false sarcastic: 5)
True literal: 40 (false literal: 10)
