# Tweet Extractor by Keywords

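This notebook samples up to N tweets per stance class (`pro ruling` and `pro opposition`) from the main dataset for each of:
1. Keywords auto-extracted from the finetuning JSON filenames (`kyra_<keyword>_stance.json`)
2. Custom keywords you specify manually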

Features:
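- Primary search: exact (case-insensitive) match on the `keyword` column
- Fallback search: substring search in the tweet text (lowercased, non-alphanumeric characters removed), used only when the primary search yields too few tweets for a class
- No duplicate tweets across all extractions (each tweet is assigned to at most one keyword)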

In [None]:
import pandas as pd
import re
import os
from pathlib import Path
from functools import reduce
import operator

# ============================================================================
# CONFIGURATION - Edit these values as needed
# ============================================================================

# Data paths. The defaults point at the original author's machine; each one can
# be overridden with an environment variable so the notebook runs elsewhere
# without editing this cell.
CSV_PATH = os.environ.get(
    "TWEETS_CSV_PATH",
    "/Users/ziv/Desktop/Partisan Discourse Documentation/final_data/tweets_exploded_by_keyword.csv",
)
JSON_DIR = os.environ.get(
    "FINETUNE_JSON_DIR",
    "/Users/ziv/Desktop/Partisan Discourse Documentation/codes/4_finetuning/data_formatting/jsons",
)
OUT_DIR = Path(os.environ.get("EXTRACT_OUT_DIR", "extracted_by_keyword"))

# Default N per class (used for JSON keywords)
DEFAULT_N_PER_CLASS = 600
SEED = 42

# ============================================================================
# CUSTOM KEYWORDS - Add your own keywords here!
# Format: {"keyword": n_per_class} or just "keyword" (uses DEFAULT_N_PER_CLASS)
# Examples:
#   CUSTOM_KEYWORDS = {"demonetization": 500, "article 370": 400, "nrc": 300}
#   CUSTOM_KEYWORDS = ["demonetization", "article 370"]  # uses default count
# Set to empty {} or [] to skip custom keywords
# ============================================================================
CUSTOM_KEYWORDS = {
    # Add your custom keywords here:
    # "your keyword": 500,
    # "another keyword": 300,
}

# Set to True to include keywords from JSON files, False to skip them
USE_JSON_KEYWORDS = True

# ============================================================================
# Column configuration (usually don't need to change)
# ============================================================================
# Candidate names for the tweet-text column; the first one present is used.
POSSIBLE_TWEET_COLS = ("tweet", "text", "full_text", "content", "body")
KEYWORD_COL = "keyword"
LABEL_COL = "tweet_label"
# Stance classes that are sampled; every other label is discarded.
TARGETS = ["pro ruling", "pro opposition"]

print(f"CSV Path: {CSV_PATH}")
print(f"JSON Dir: {JSON_DIR}")
print(f"Output Dir: {OUT_DIR}")
print(f"Default N per class: {DEFAULT_N_PER_CLASS}")

In [None]:
# Build the final keyword list with counts
# Format: [(keyword, n_per_class), ...]
#
# Every keyword is stored lowercased and stripped, so JSON-derived and custom
# keywords de-duplicate against each other consistently. Empty keywords are
# skipped: an empty phrase would match every tweet in the fallback text search.

KEYWORDS_WITH_COUNTS = []

# 1. Add keywords from JSON files if enabled
if USE_JSON_KEYWORDS and os.path.isdir(JSON_DIR):
    json_files = sorted(f for f in os.listdir(JSON_DIR) if f.endswith(".json"))
    print(f"Found {len(json_files)} JSON files:")

    for jf in json_files:
        # Only files named exactly "kyra_<keyword>_stance.json" contribute a
        # keyword; underscores inside <keyword> stand for spaces.
        match = re.fullmatch(r"kyra_(.+)_stance\.json", jf)
        if not match:
            print(f"  [SKIP] '{jf}' does not match 'kyra_<keyword>_stance.json'")
            continue
        kw = match.group(1).replace("_", " ").strip().lower()
        if not kw:
            continue
        KEYWORDS_WITH_COUNTS.append((kw, DEFAULT_N_PER_CLASS))
        print(f"  [JSON] '{kw}' -> {DEFAULT_N_PER_CLASS} per class")
else:
    print("Skipping JSON keywords (USE_JSON_KEYWORDS=False or dir not found)")

# 2. Add custom keywords
if CUSTOM_KEYWORDS:
    print(f"\nAdding {len(CUSTOM_KEYWORDS)} custom keywords:")
    if isinstance(CUSTOM_KEYWORDS, dict):
        custom_items = [(kw, n, "") for kw, n in CUSTOM_KEYWORDS.items()]
    elif isinstance(CUSTOM_KEYWORDS, (list, tuple)):
        custom_items = [(kw, DEFAULT_N_PER_CLASS, " (default)") for kw in CUSTOM_KEYWORDS]
    else:
        # e.g. a bare string: previously ignored silently after announcing it
        raise TypeError(
            "CUSTOM_KEYWORDS must be a dict {keyword: n_per_class} or a list of "
            f"keywords, got {type(CUSTOM_KEYWORDS).__name__}"
        )

    for kw, n, note in custom_items:
        kw_clean = str(kw).strip().lower()
        if not kw_clean:
            print(f"  [SKIP] empty custom keyword {kw!r}")
            continue
        KEYWORDS_WITH_COUNTS.append((kw_clean, n))
        print(f"  [CUSTOM] '{kw}' -> {n} per class{note}")

# Remove duplicate keywords (keep first occurrence with its count; JSON
# keywords come first, so a custom count for the same keyword is ignored)
seen = set()
unique_keywords = []
for kw, n in KEYWORDS_WITH_COUNTS:
    if kw in seen:
        print(f"  [DUP] '{kw}' already listed; ignoring later count {n}")
        continue
    seen.add(kw)
    unique_keywords.append((kw, n))

KEYWORDS_WITH_COUNTS = unique_keywords

print(f"\n{'='*60}")
print(f"TOTAL KEYWORDS TO EXTRACT: {len(KEYWORDS_WITH_COUNTS)}")
print(f"{'='*60}")
for kw, n in KEYWORDS_WITH_COUNTS:
    print(f"  - '{kw}': {n} per class")

In [None]:
# ---------- Helper functions ----------
def _norm_nospace(x):
    """Lowercase + drop all non-alphanumerics (incl. spaces)."""
    if isinstance(x, pd.Series):
        return (
            x.fillna("")
             .astype(str)
             .str.lower()
             .str.replace(r"[^a-z0-9]+", "", regex=True)
        )
    return re.sub(r"[^a-z0-9]+", "", str(x).lower())

def _phrase_variants(s: str) -> list:
    """
    Support ' or ' and '|' as OR separators inside a keyword/phrase.
    Returns the ORIGINAL (lowercased/trimmed) variants.
    """
    raw = str(s).strip()
    parts = re.split(r"\s+or\s+|\|", raw, flags=re.IGNORECASE)
    parts = [p.strip().lower() for p in parts if p.strip()]
    return parts if parts else [raw.lower().strip()]

def _any_contains_norm(tw_norm_series: pd.Series, raw_phrase: str) -> pd.Series:
    """
    Build a boolean mask: tweet contains ANY normalized variant of raw_phrase.
    This is the FALLBACK search - searches in tweet text.

    ``tw_norm_series`` is expected to be normalized with ``_norm_nospace``
    already. Each OR-variant of ``raw_phrase`` (see ``_phrase_variants``) is
    normalized the same way and matched as a plain substring.

    Variants that normalize to the empty string are ignored. This covers
    punctuation-only phrases and text with no ASCII letters or digits. An
    empty string is a substring of every tweet and would otherwise select the
    whole dataset. If no usable variant remains, the mask is all False.
    """
    variants_norm = [_norm_nospace(v) for v in _phrase_variants(raw_phrase)]
    variants_norm = [v for v in variants_norm if v]
    if not variants_norm:
        return pd.Series(False, index=tw_norm_series.index)
    # Normalized variants contain only [a-z0-9], so a literal substring test
    # is equivalent to the previous escaped regex and skips the regex engine.
    masks = [tw_norm_series.str.contains(v, regex=False, na=False) for v in variants_norm]
    return reduce(operator.or_, masks)

In [None]:
# ---------- Load & prep ----------
print("Loading CSV... (this may take a while for large files)")
df = pd.read_csv(CSV_PATH, low_memory=False)
print(f"Loaded {len(df):,} rows")
print(f"Columns: {df.columns.tolist()}")

# choose tweet column (first candidate present in the CSV wins)
tweet_col = next((c for c in POSSIBLE_TWEET_COLS if c in df.columns), None)
if tweet_col is None:
    raise ValueError(f"Couldn't find a tweet/text column. Tried: {POSSIBLE_TWEET_COLS}.")
print(f"Tweet column: {tweet_col}")

# stable id: reuse an existing "source_row" column, otherwise use the row's
# index as read from the CSV (assigned before any filtering below)
id_col = "source_row"
if id_col not in df.columns:
    df[id_col] = df.index

# drop rows without usable tweet text: they carry nothing to label and can
# never match the fallback text search
before_drop = len(df)
has_text = df[tweet_col].notna() & df[tweet_col].astype(str).str.strip().ne("")
df = df[has_text]
print(f"Dropped {before_drop - len(df):,} rows with missing/blank tweet text")

# de-dup by tweet text (keeps the first row for each distinct text)
before_dedup = len(df)
df = df.drop_duplicates(subset=[tweet_col]).copy()
print(f"After dedup: {len(df):,} rows (removed {before_dedup - len(df):,} duplicates)")

In [None]:
# normalize labels to TARGETS
def normalize_label(x: str) -> str:
    """Map a raw stance label onto one of TARGETS, or return "other".

    Matching is case-insensitive and allows hyphens, underscores or whitespace
    between "pro" and the side (e.g. "Pro-Ruling", "pro_opp"). The ruling
    pattern is checked first. Non-string inputs (None, NaN, numbers) map to
    "other".
    """
    if not isinstance(x, str):
        return "other"
    text = x.strip().lower()
    rules = (
        (r"\bpro[-_\s]*rul(?:ing)?\b", "pro ruling"),
        (r"\bpro[-_\s]*(opp|opposition)\b", "pro opposition"),
    )
    for pattern, canonical in rules:
        if re.search(pattern, text):
            return canonical
    return "other"

# Validate required columns up front, so a misconfiguration fails fast with a
# readable message instead of a bare KeyError.
for required_col in (LABEL_COL, KEYWORD_COL):
    if required_col not in df.columns:
        raise ValueError(f"Column '{required_col}' not found. Available: {list(df.columns)[:25]}")

df["_label_norm"] = df[LABEL_COL].apply(normalize_label)
print("Label distribution (before filtering):")
print(df["_label_norm"].value_counts())

df = df[df["_label_norm"].isin(TARGETS)].copy()
print(f"\nAfter filtering to TARGETS: {len(df):,} rows")

# lowercase keyword col for primary (exact) match; missing keywords become ""
# rather than the literal string "nan"
df["_kw_lc"] = df[KEYWORD_COL].fillna("").astype(str).str.strip().str.lower()

# normalized tweet text for fallback search (index-aligned with df)
print("Normalizing tweet text for fallback search...")
tw_norm = _norm_nospace(df[tweet_col])
print("Done.")

In [None]:
# Global set to track all used tweet IDs across keywords (no duplicates)
GLOBAL_USED_IDS = set()

def sample_for_keyword(kw_raw: str, n_per_class: int, seed_base: int = 0) -> tuple:
    """
    Extract up to ``n_per_class`` tweets per label in TARGETS for one keyword.

    Search strategy:
    1. PRIMARY: the lowercased keyword column equals any OR-variant of
       ``kw_raw`` (see ``_phrase_variants``).
    2. FALLBACK: substring search in the normalized tweet text. It is used only
       to top up a label when the primary pool is too small, and it never
       draws from rows that are primary candidates for that label.

    Tweets already returned by an earlier call (tracked in GLOBAL_USED_IDS)
    are excluded, and the ids returned here are added to that set, so no tweet
    is extracted twice across calls.

    Args:
        kw_raw: keyword spec; may contain " or " / "|" alternatives.
        n_per_class: number of tweets requested for each label.
        seed_base: offset added to SEED, so each keyword gets its own
            reproducible draw.

    Returns:
        (out_kw, stats):
        - out_kw: the shuffled sample, with KEYWORD_COL overwritten by the
          canonical (first) variant.
        - stats: maps each label to a dict with keys "requested",
          "picked_total", "from_keyword_col", "from_fallback_text" and
          "short_by".
    """
    global GLOBAL_USED_IDS

    # variants for this bucket (never empty, see _phrase_variants)
    kw_variants = _phrase_variants(kw_raw)

    # Exclude already-used tweets globally. The mask is computed once and
    # applied to both df and the index-aligned normalized text.
    unused = ~df[id_col].isin(GLOBAL_USED_IDS)
    available_df = df[unused]
    available_tw_norm = tw_norm[unused]

    # PRIMARY pool = keyword column equals any variant
    pool_kw = available_df[available_df["_kw_lc"].isin(kw_variants)]

    # FALLBACK pool = tweet text contains ANY normalized variant
    pool_fb = available_df[_any_contains_norm(available_tw_norm, kw_raw)]

    taken_ids = set()
    parts, stats = [], {}

    for i, label in enumerate(TARGETS):
        # exclude what's already taken in this keyword
        kw_cls = pool_kw[(pool_kw["_label_norm"] == label) & ~pool_kw[id_col].isin(taken_ids)]
        # fallback candidates exclude every primary candidate, not only the sampled ones
        fb_cls = pool_fb[
            (pool_fb["_label_norm"] == label)
            & ~pool_fb[id_col].isin(taken_ids | set(kw_cls[id_col]))
        ]

        need = n_per_class
        seed_i = SEED + seed_base + i * 13

        # PRIMARY: take from keyword column match
        got_kw = min(need, len(kw_cls))
        part_kw = kw_cls.sample(n=got_kw, random_state=seed_i, replace=False) if got_kw > 0 else kw_cls.head(0)

        # FALLBACK: take remaining from text search
        need_more = need - got_kw
        got_fb = min(need_more, len(fb_cls))
        part_fb = fb_cls.sample(n=got_fb, random_state=seed_i + 1, replace=False) if got_fb > 0 else fb_cls.head(0)

        pick = pd.concat([part_kw, part_fb], axis=0)

        parts.append(pick)
        taken_ids |= set(pick[id_col])

        stats[label] = {
            "requested": need,
            "picked_total": int(len(pick)),
            "from_keyword_col": int(len(part_kw)),
            "from_fallback_text": int(len(part_fb)),
            "short_by": int(max(0, need - len(pick))),
        }

    out_kw = pd.concat(parts, axis=0)
    if not out_kw.empty:
        out_kw = out_kw.sample(frac=1.0, random_state=SEED + seed_base).reset_index(drop=True)

    # integrity checks. dropna=False so a missing value still counts as a value;
    # the default nunique() ignores NaN and would fail spuriously.
    assert out_kw[id_col].nunique(dropna=False) == len(out_kw), f"[{kw_raw}] duplicate IDs"
    assert out_kw[tweet_col].nunique(dropna=False) == len(out_kw), f"[{kw_raw}] duplicate tweets"

    # Add to global used set
    GLOBAL_USED_IDS |= set(out_kw[id_col])

    # Overwrite keyword column with canonical keyword
    out_kw[KEYWORD_COL] = kw_variants[0]

    return out_kw, stats

In [None]:
# --- Run extraction for all keywords ---
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Reset global tracking so re-running this cell starts from a clean slate
GLOBAL_USED_IDS = set()

combined = []
reports = {}
used_names = set()  # file-name stems already written in this run

# Columns written to each per-keyword CSV (only those present are kept)
cols_out = [id_col, tweet_col, LABEL_COL, "_label_norm", KEYWORD_COL, "subjects_scored"]

print("=" * 80)
print("EXTRACTING TWEETS BY KEYWORD")
print("=" * 80)

for idx, (kw, n_per_class) in enumerate(KEYWORDS_WITH_COUNTS):
    # seed_base spaces keywords apart so each gets its own reproducible draw
    out_kw, stat_kw = sample_for_keyword(kw, n_per_class, seed_base=idx * 101)

    combined.append(out_kw)
    reports[kw] = stat_kw

    # write per-keyword files
    cols_present = [c for c in cols_out if c in out_kw.columns]

    # File name comes from the canonical (first) variant. Whitespace becomes "_"
    # as before, and path separators and other filename-unsafe characters are
    # replaced too, so a keyword like "article 370/35a" cannot point into a
    # nonexistent subdirectory. A numeric suffix keeps two keywords with the
    # same canonical variant (e.g. "nrc" and "nrc or caa") from overwriting
    # each other's files.
    base_name = re.sub(r'[\s\\/:*?"<>|]', "_", _phrase_variants(kw)[0])
    canonical_name = base_name
    dup_no = 2
    while canonical_name in used_names:
        canonical_name = f"{base_name}_{dup_no}"
        dup_no += 1
    used_names.add(canonical_name)

    out_csv = OUT_DIR / f"extracted_{canonical_name}.csv"
    out_ids = OUT_DIR / f"extracted_{canonical_name}_ids.txt"

    out_kw[cols_present].to_csv(out_csv, index=False)
    with open(out_ids, "w", encoding="utf-8") as f:
        f.writelines(f"{v}\n" for v in out_kw[id_col].tolist())

    # Status with breakdown
    picked = {lbl: stat_kw[lbl]["picked_total"] for lbl in TARGETS}
    from_kw = {lbl: stat_kw[lbl]["from_keyword_col"] for lbl in TARGETS}
    from_fb = {lbl: stat_kw[lbl]["from_fallback_text"] for lbl in TARGETS}
    print(f"[OK] '{kw}' (req: {n_per_class}/class)")
    print(f"     -> picked: {picked}, from_keyword_col: {from_kw}, from_fallback: {from_fb}")

print("\n" + "=" * 80)

In [None]:
# Combined outputs: all per-keyword samples stacked into one file.
# With no keywords processed, use an empty frame that still has df's columns,
# so the column selection and id export below do not raise KeyError.
all_out = pd.concat(combined, axis=0).reset_index(drop=True) if combined else df.head(0).copy()

# Cross-keyword integrity: GLOBAL_USED_IDS should make every id unique here
assert all_out[id_col].is_unique, "duplicate tweet ids across keywords"

cols_out_all = [id_col, tweet_col, LABEL_COL, "_label_norm", KEYWORD_COL, "subjects_scored"]
cols_out_all = [c for c in cols_out_all if c in all_out.columns]

total_rows = len(all_out)
all_csv = OUT_DIR / f"extracted_ALL_{len(KEYWORDS_WITH_COUNTS)}keywords_{total_rows}rows.csv"
all_ids = OUT_DIR / f"extracted_ALL_{len(KEYWORDS_WITH_COUNTS)}keywords_ids.txt"

all_out[cols_out_all].to_csv(all_csv, index=False)
with open(all_ids, "w", encoding="utf-8") as f:
    f.writelines(f"{v}\n" for v in all_out[id_col].tolist())

print(f"[OK] Combined: {all_csv} (rows={total_rows})")

In [None]:
# Summary report: per-keyword totals, split by source (keyword column vs.
# fallback text search), plus how far each class fell short of its request.
print("\n" + "=" * 80)
print("SUMMARY")
print("=" * 80)

for kw, stat in reports.items():
    short = {lbl: s["short_by"] for lbl, s in stat.items()}
    total_picked = sum(s["picked_total"] for s in stat.values())
    from_kw_total = sum(s["from_keyword_col"] for s in stat.values())
    from_fb_total = sum(s["from_fallback_text"] for s in stat.values())
    print(f"  '{kw}': total={total_picked} (keyword_col={from_kw_total}, fallback={from_fb_total}), short_by={short}")

# The check mark was previously mis-encoded (UTF-8 bytes read as cp1252: "âœ…")
print(f"\n✅ All files saved to: {OUT_DIR}/")
print(f"   Total unique tweets extracted: {total_rows:,}")
print(f"   Keywords processed: {len(KEYWORDS_WITH_COUNTS)}")