## Imports

In [None]:
import inspect
import random
import re
import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from datasets import Dataset, DatasetDict, load_dataset
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from transformers import (
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)

# Seed every random number generator the notebook relies on (torch, numpy, stdlib)
# with the same value so shuffles, splits and weight init are repeatable.
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(42)

print("All libraries loaded!")

## Load the dataset
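Several Hugging Face sarcasm datasets have failed to load at times, so we try a few candidates in order of preference and use the first one that loads. If none of them load, we fall back to a tiny synthetic dataset (useful only as a smoke test).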

In [None]:
# Candidate Hub datasets in order of preference: (hub id, human-readable description).
datasets_to_try = [
    ("raquiba/Sarcasm_News_Headline", "News Headlines Sarcasm Dataset"),
    ("helinivan/english-sarcasm-detector", "English Sarcasm Detector Dataset"),
    ("siddhant4583agarwal/sarcasm-detection-dataset", "Sarcasm Detection Dataset")
]


def _first_loadable(candidates):
    """Return ``(dataset, description)`` for the first candidate that loads.

    Each failure is printed and skipped, so one broken Hub entry doesn't stop the
    search. Returns ``(None, None)`` when every candidate fails.
    """
    for hub_id, label in candidates:
        print(f"Trying to load {label}...")
        try:
            loaded = load_dataset(hub_id)
        except Exception as err:
            print(f"Failed to load {label}: {err}")
            continue
        print(f"{label} loaded successfully!")
        return loaded, label
    return None, None


dataset, dataset_name = _first_loadable(datasets_to_try)

# Last resort: a handful of hand-written headlines so later cells still have data.
if dataset is None:
    print("None of the datasets worked. Making a synthetic one.")
    synthetic_rows = [
        ("Area Man Constantly Mentioning He Doesn't Own a Television", 1),
        ("Local Woman Takes Up Jogging, Tells Everyone", 1),
        ("Breaking: Rain Causes Things to Get Wet", 1),
        ("Scientists Discover Fire Hot, Water Wet", 1),
        ("Man Who Plays Guitar at Parties Loses Guitar Privileges", 1),
        ("President Signs Bill Into Law", 0),
        ("New Study Shows Exercise Beneficial for Health", 0),
    ]
    synthetic_data = {
        "headline": [headline for headline, _flag in synthetic_rows],
        "is_sarcastic": [flag for _headline, flag in synthetic_rows],
    }
    dataset = {"train": Dataset.from_dict(synthetic_data)}
    dataset_name = "Synthetic Dataset"
    print("Synthetic dataset created.")

## Split the dataset
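The candidate datasets differ in which splits they ship with. To treat them all the same way, we take the `train` split if there is one (otherwise `test`), subsample it if it is very large, and then make our own 80/20 train/test split from it.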

In [None]:
# Work from a single split and carve our own train/test split out of it, so every
# candidate dataset goes through the same pipeline whatever splits it ships with.
if isinstance(dataset, dict):  # DatasetDict is a dict subclass; the synthetic fallback is a plain dict
    if not dataset:
        raise ValueError("Loaded dataset has no splits.")
    if 'train' in dataset:
        main_data = dataset['train']
    elif 'test' in dataset:
        main_data = dataset['test']
    else:
        # e.g. only a 'validation' split: use that single split rather than the whole
        # DatasetDict (whose len() counts splits, not rows, and which the subsample/split
        # code below expects to be a single Dataset).
        _split_name = next(iter(dataset))
        print(f"No 'train' or 'test' split found; using '{_split_name}'.")
        main_data = dataset[_split_name]
else:
    # Already a single Dataset.
    main_data = dataset

print(f"Using: {dataset_name}")
print(f"Dataset size: {len(main_data)}")

# Cap very large datasets to keep fine-tuning fast.
MAX_FULL_SIZE = 50000  # datasets up to this many rows are used in full
SAMPLE_SIZE = 20000    # larger ones are replaced by a seeded random sample of this size
if len(main_data) > MAX_FULL_SIZE:
    print(f"Using a {SAMPLE_SIZE // 1000}k sample for speed.")
    main_data = main_data.shuffle(seed=42).select(range(SAMPLE_SIZE))

# Seeded 80/20 train/test split.
# NOTE: this rebinds `dataset`; re-running this cell without re-running the loading
# cell would split the previous *train* split again and shrink the data.
dataset = main_data.train_test_split(test_size=0.2, seed=42)
print(f"Train size: {len(dataset['train'])}")
print(f"Test size: {len(dataset['test'])}")

## Preprocess the text
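Clean the text (lowercase, strip URLs, collapse whitespace) and map whichever text/label columns the chosen dataset uses onto a common `text`/`labels` schema.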

In [None]:
def clean_text(text):
    """Normalize one raw text value for the classifier.

    Lowercases, removes URL-like tokens (anything starting with ``http`` or a
    standalone ``www.`` token) and collapses every whitespace run to one space.

    Args:
        text: the raw value; non-strings are converted with ``str``.

    Returns:
        The cleaned string, or ``""`` for missing values (``None`` / ``NaN``).
    """
    if pd.isna(text):
        return ""
    text = str(text).lower()
    # The dot in "www." must be escaped and anchored at a word boundary: the unescaped
    # form matched "www" + any character inside ordinary words, turning e.g.
    # "awwwards" into "a" and "awww... so cute" into "a so cute".
    text = re.sub(r'http\S+|\bwww\.\S+', '', text)
    return ' '.join(text.split())

def preprocess_dataset(dataset):
    """Convert one dataset split into a clean two-column ``text``/``labels`` DataFrame.

    The text and label columns are detected from a list of common names. Rows with a
    missing text (and, when a label column exists, a missing label) are dropped, the
    text is normalized with ``clean_text``, and rows whose text ends up empty are removed.

    Args:
        dataset: an object with a ``to_pandas()`` method (e.g. ``datasets.Dataset``).

    Returns:
        DataFrame with columns ``text`` (str) and ``labels`` (int) and a fresh 0..n-1 index.

    Raises:
        ValueError: if none of the known text column names is present.
    """
    df = dataset.to_pandas()
    text_col = next((c for c in ['text', 'comment', 'headline', 'sentence', 'content'] if c in df.columns), None)
    label_col = next((c for c in ['label', 'labels', 'sarcastic', 'is_sarcastic', 'target'] if c in df.columns), None)

    if text_col is None:
        raise ValueError("Can't find a text column.")

    if label_col is None:
        print("No label col found — generating fake labels.")
        # Drop missing texts here as well: str.contains returns NaN for them and
        # astype(int) would then raise.
        # NOTE(review): these are keyword-heuristic labels, not real sarcasm labels.
        df = df.dropna(subset=[text_col]).assign(
            labels=lambda d: d[text_col].astype(str)
            .str.contains(r'\b(great|awesome|fantastic|perfect)\b', case=False)
            .astype(int)
        )
    else:
        df = df.dropna(subset=[text_col, label_col]).assign(
            labels=lambda d: d[label_col].astype(int)
        )

    df = df.assign(text=df[text_col].apply(clean_text))
    # clean_text always returns a str, so comparing with "" avoids relying on the
    # .str accessor (which fails on an empty, non-object column).
    df = df[df['text'].ne("")]
    return df[['text', 'labels']].reset_index(drop=True)

# Run the same cleaning/column mapping over both splits (train first, then test).
train_df, test_df = (preprocess_dataset(dataset[split]) for split in ("train", "test"))

## Tokenization and formatting
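Hold out a validation set from the training data, tokenize every split with the DistilBERT tokenizer, and set the datasets to return PyTorch tensors (`input_ids`, `attention_mask`, `labels`).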

In [None]:
from sklearn.model_selection import train_test_split

# Pretrained checkpoint shared by the tokenizer here and the classifier below.
model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained(model_name)

def tokenize_function(examples):
    """Tokenize a batch of examples for DistilBERT.

    Args:
        examples: batch dict passed by ``Dataset.map(batched=True)``;
            ``examples['text']`` is a list of strings.

    Returns:
        The tokenizer encoding (``input_ids``, ``attention_mask``), truncated to
        512 tokens (DistilBERT's maximum input length).

    Sequences are deliberately left unpadded here. ``DataCollatorWithPadding`` pads
    each training/eval batch to its own longest sequence, whereas padding at map time
    pads every 1000-row map chunk to its longest row and wastes compute and memory.
    """
    return tokenizer(
        examples['text'],
        truncation=True,
        max_length=512
    )

# Hold out 20% of the training data as a validation set (used for evaluation and
# best-checkpoint selection during training); the test split stays untouched.
_texts = train_df['text'].tolist()
_labels = train_df['labels'].tolist()
try:
    # Stratify so the validation set keeps the training label balance.
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        _texts,
        _labels,
        test_size=0.2,
        random_state=42,
        stratify=_labels
    )
except ValueError as err:
    # sklearn refuses to stratify when a class has fewer than 2 examples or a split
    # would be smaller than the number of classes (e.g. the tiny synthetic fallback).
    print(f"Stratified split not possible ({err}); using a random split instead.")
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        _texts,
        _labels,
        test_size=0.2,
        random_state=42
    )


def _to_tokenized_dataset(texts, labels):
    """Build a Dataset from parallel text/label lists, tokenize it, and expose torch tensors."""
    ds = Dataset.from_dict({'text': texts, 'labels': labels})
    ds = ds.map(tokenize_function, batched=True)
    # Only the model inputs and labels are returned as tensors; 'text' stays hidden.
    ds.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    return ds


train_dataset = _to_tokenized_dataset(train_texts, train_labels)
val_dataset = _to_tokenized_dataset(val_texts, val_labels)
test_dataset = _to_tokenized_dataset(test_df['text'].tolist(), test_df['labels'].tolist())

## Model setup with LoRA
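We use LoRA (Low-Rank Adaptation), a parameter-efficient fine-tuning (PEFT) method. Instead of updating every DistilBERT weight, we train small low-rank adapters attached to the attention query and value projections (`q_lin`, `v_lin`).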

In [None]:
# Binary classifier on top of DistilBERT; the id <-> name mappings are passed to
# from_pretrained so outputs can be reported with readable class names.
model = DistilBertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    id2label=dict(enumerate(("not_sarcastic", "sarcastic"))),
    label2id={name: idx for idx, name in enumerate(("not_sarcastic", "sarcastic"))},
)

# LoRA adapters on DistilBERT's attention query/value projection layers.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    target_modules=["q_lin", "v_lin"],
    r=16,             # adapter rank
    lora_alpha=32,    # LoRA scaling numerator (scale = alpha / r)
    lora_dropout=0.1,
    inference_mode=False,  # adapters are trainable
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

## Training the model

In [None]:
# Pads each batch to the longest sequence in that batch.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

def compute_metrics(eval_pred):
    """Score Trainer evaluation output.

    Args:
        eval_pred: ``(logits, label_ids)`` pair supplied by ``Trainer``.

    Returns:
        dict with ``accuracy`` plus support-weighted ``f1``, ``precision`` and ``recall``.
    """
    logits, label_ids = eval_pred
    predicted = np.argmax(logits, axis=1)
    prec, rec, f1_weighted, _support = precision_recall_fscore_support(
        label_ids, predicted, average='weighted'
    )
    return {
        'accuracy': accuracy_score(label_ids, predicted),
        'f1': f1_weighted,
        'precision': prec,
        'recall': rec,
    }

# Newer transformers releases renamed TrainingArguments' `evaluation_strategy` to
# `eval_strategy` (the old name was later removed) and replaced Trainer's `tokenizer=`
# with `processing_class=`. Use whichever keyword the installed version accepts.
_training_params = inspect.signature(TrainingArguments).parameters
_eval_strategy_kw = "eval_strategy" if "eval_strategy" in _training_params else "evaluation_strategy"
_trainer_params = inspect.signature(Trainer.__init__).parameters
_tokenizer_kw = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=100,
    eval_steps=500,
    save_strategy="steps",
    # With load_best_model_at_end=True, save_steps must be a round multiple of eval_steps.
    save_steps=1000,
    load_best_model_at_end=True,
    metric_for_best_model="f1",  # key produced by compute_metrics
    greater_is_better=True,
    report_to=[],  # no external experiment trackers
    # NOTE(review): presumably set so the Trainer doesn't drop columns based on the
    # PEFT-wrapped model's forward signature — confirm it is still needed.
    remove_unused_columns=False,
    **{_eval_strategy_kw: "steps"},
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{_tokenizer_kw: tokenizer},
)

trainer.train()

## Evaluate the model

In [None]:
# Score the held-out test split with the fine-tuned model.
predictions = trainer.predict(test_dataset)
y_pred = np.argmax(predictions.predictions, axis=1)  # logits -> class ids
y_true = predictions.label_ids

accuracy = accuracy_score(y_true, y_pred)
# `_support` (not `_`) so IPython's last-output variable isn't clobbered.
precision, recall, f1, _support = precision_recall_fscore_support(y_true, y_pred, average='weighted')

print(f"Accuracy: {accuracy:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
# Fix the label set and names so the report always lists both classes (matching the
# names used in the confusion matrix), even if one is absent from a small test split.
print(classification_report(
    y_true,
    y_pred,
    labels=[0, 1],
    target_names=['Not Sarcastic', 'Sarcastic'],
    digits=4,
    zero_division=0,
))

## Confusion Matrix

In [None]:
# Pin the label order so the matrix is always 2x2 and lines up with the tick labels,
# even when one class never occurs in y_true or y_pred.
class_names = ['Not Sarcastic', 'Sarcastic']
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names,
            ax=ax)
ax.set(xlabel="Predicted", ylabel="True", title="Confusion Matrix")
plt.show()

## Try some sample inputs

In [None]:
def predict_sarcasm(text, model, tokenizer):
    """Classify a single string as sarcastic (1) or not sarcastic (0).

    The text gets the same ``clean_text`` normalization used for the training data.

    Args:
        text: raw input string.
        model: a sequence-classification model whose output exposes ``.logits``.
        tokenizer: the tokenizer matching ``model``.

    Returns:
        ``(pred, conf)``: the argmax class id and its softmax probability.

    Side effect: switches ``model`` to eval mode so dropout is disabled and repeated
    calls on the same text give the same answer.
    """
    text = clean_text(text)
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    pred = torch.argmax(probs, dim=-1).item()
    conf = probs[0][pred].item()
    return pred, conf

# Hand-written probes to eyeball the model's behaviour on unseen text.
examples = [
    "How about that.",
    "The weather is really nice today.",
    "I love when my phone dies!"
]

for rank, sentence in enumerate(examples, start=1):
    label, conf = predict_sarcasm(sentence, model, tokenizer)
    verdict = 'Sarcastic' if label == 1 else 'Not Sarcastic'
    print(f"{rank}. {sentence} --> {verdict} ({conf:.2f})")

## Save the model

In [None]:
# Write the (PEFT-wrapped) model and its tokenizer into the same directory so they
# can be reloaded together.
# NOTE(review): for a PEFT model this presumably saves only the adapter weights,
# not a merged full model — confirm that matches how it is reloaded.
SAVE_DIR = "./sarcasm_detector_lora"
for artifact in (model, tokenizer):
    artifact.save_pretrained(SAVE_DIR)
print("Saved!")