In [1]:
import os
import json
import math
import itertools
from collections import defaultdict
from tqdm import tqdm

import numpy as np
import networkx as nx
from tenacity import retry, stop_after_attempt, wait_random_exponential
import time
from sentence_transformers import SentenceTransformer
from sklearn.cluster import AgglomerativeClustering
from sentence_transformers import CrossEncoder
from openai import OpenAI, BadRequestError
from openai.types.chat import ChatCompletion
import yaml
import dataclasses
class MinimumDelay:
    def __init__(self, delay: float | int):
        self.delay = delay
        self.start = None

    def __enter__(self):
        self.start = time.time()

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        seconds = end - self.start
        if self.delay > seconds:
            time.sleep(self.delay - seconds)

@retry(wait=wait_random_exponential(min=1, max=90), stop=stop_after_attempt(3))
def chat(client: OpenAI, delay: float | int, **kwargs) -> ChatCompletion | None:
    try:
        with MinimumDelay(delay):
            return client.chat.completions.create(**kwargs)
    except BadRequestError as e:
        print(f"Bad Request: {e}")
        if "safety" in e.message:
            return None
        raise e
    except Exception as e:
        print(f"Exception: {e}")
        raise e

  from .autonotebook import tqdm as notebook_tqdm


In [43]:
UNIQUE_FRAMES_OUT = "unique_frames.jsonl"      # intermediate
MERGED_FRAMES_OUT = "merged_frames.jsonl"      # final output
FINAL_FRAMES_IN = "final_frames.jsonl"         # your input file (per example)
EMBED_MODEL_NAME = "all-mpnet-base-v2"         # stronger embedder
CROSS_ENCODER_MODEL = "cross-encoder/stsb-roberta-large"  # optional
EMBED_BATCH = 64
CROSS_ENCODER_AVAILABLE = True
# Clustering params
DISTANCE_THRESHOLD = 0.28   # cosine distance threshold for AgglomerativeClustering
MIN_CLUSTER_SIZE = 2        # ignore clusters of size 1 for pairwise LLM checking

# Cross-encoder prefilter threshold (higher is more strict)
CROSS_ENCODER_THRESHOLD = 0.75   # scores in [0,1] roughly for semantic similarity

# LLM verification: number of pairs to verify per cluster (if cluster big, we will prune)
MAX_PAIRS_TO_VERIFY = 400   # safety cap (choose top by cross-encoder score if available)

In [2]:
# -----------------------------
# Utilities: JSONL read/write
# -----------------------------
def read_jsonl(path):
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            yield json.loads(line)

def write_jsonl(path, items):
    with open(path, "w", encoding="utf-8") as f:
        for it in items:
            f.write(json.dumps(it, ensure_ascii=False) + "\n")

In [67]:
def build_unique_frames(final_frames_path: str,
                        problems_path: str,
                        out_path: str):
    """
    Create a JSONL file of unique frames with counts, linked targets/intents/actions
    from problems.jsonl, and a list of source_ids (image ids that produced the frame).
    """
    import json

    # --- Load problems.jsonl into a dict keyed by id ---
    id2problems = {}
    with open(problems_path, "r", encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            pid = rec["id"]
            resp = json.loads(rec["response"])
            id2problems[pid] = {
                "targets":  [t["category"] for t in resp["hate_problems"].get("targets", [])],
                "intents":  [i["category"] for i in resp["hate_problems"].get("intents", [])],
                "actions":  [a["category"] for a in resp["hate_problems"].get("actions", [])]
            }

    # --- Collect and count unique frame texts ---
    frame_map = {}  # frame_text -> {frame_id, count, targets, intents, actions, source_ids}
    counter = 1

    with open(final_frames_path, "r", encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            img_id = rec["id"]
            resp = json.loads(rec["response"])
            for fr in resp.get("hate_frames", []):
                text = fr["frame"].strip()
                if text not in frame_map:
                    frame_map[text] = {
                        "frame_id": f"f{counter}",
                        "count": 0,
                        "targets": id2problems.get(img_id, {}).get("targets", []),
                        "intents": id2problems.get(img_id, {}).get("intents", []),
                        "actions": id2problems.get(img_id, {}).get("actions", []),
                        "source_ids": set()      # <-- track contributing image IDs
                    }
                    counter += 1
                frame_map[text]["count"] += 1
                frame_map[text]["source_ids"].add(img_id)

    # --- Write unique frames & return list ---
    unique_list = []
    with open(out_path, "w", encoding="utf-8") as out_f:
        for text, data in frame_map.items():
            rec = {
                "text": text,
                "frame_id": data["frame_id"],
                "count": data["count"],
                "targets": data["targets"],
                "intents": data["intents"],
                "actions": data["actions"],
                "source_ids": sorted(list(data["source_ids"]))  # convert set to sorted list
            }
            out_f.write(json.dumps(rec) + "\n")
            unique_list.append(rec)

    return unique_list


In [68]:
# -----------------------------
# Step 2: Embeddings + clustering grouped by target
# -----------------------------
def embed_texts(model, texts, batch_size=EMBED_BATCH):
    embs = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        e = model.encode(batch, convert_to_numpy=True, show_progress_bar=False)
        embs.append(e)
    return np.vstack(embs)

def cluster_within_target(unique_frames, embed_model_name=EMBED_MODEL_NAME,
                          distance_threshold=DISTANCE_THRESHOLD):
    model = SentenceTransformer(embed_model_name)
    # map target -> list of indices in unique_frames
    target_to_idxs = defaultdict(list)
    for idx, uf in enumerate(unique_frames):
        # If a frame has multiple targets, we duplicate its index into each target group.
        # This means it can be considered for paraphrase merging within each relevant target.
        for t in uf.get("targets", []) or ["__no_target__"]:
            target_to_idxs[t].append(idx)

    # For each target group, embed its frames and cluster
    results = {}  # target -> { 'indices': [...], 'labels': array }
    for target, idxs in tqdm(target_to_idxs.items(), desc="Clustering targets"):
        if len(idxs) == 0:
            continue
        texts = [unique_frames[i]["text"] for i in idxs]
        embeddings = embed_texts(model, texts)
        # normalize for cosine distance
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms==0] = 1.0
        embeddings = embeddings / norms

        # If only one element, give it its own label 0
        if len(idxs) == 1:
            labels = np.array([0], dtype=int)
        else:
            clustering = AgglomerativeClustering(
                n_clusters=None,
                distance_threshold=distance_threshold,
                metric='cosine',
                linkage='average'
            )
            labels = clustering.fit_predict(embeddings)
        results[target] = {
            "indices": idxs,
            "labels": labels,
            "embeddings": embeddings,   # parallel to texts
            "texts": texts
        }
    return results

In [69]:
# -----------------------------
# Step 3: Hybrid verification (Cross-encoder prefilter -> LLM verify)
# -----------------------------
def make_pair_list_from_cluster(idxs, labels):
    """
    Given indices (global indices into unique_frames) and cluster labels per element,
    yield lists of clusters: each cluster -> list of global indices.
    """
    clusters = defaultdict(list)
    for i_local, label in enumerate(labels):
        clusters[label].append(idxs[i_local])
    # return only clusters of size >= 2
    return [c for c in clusters.values() if len(c) >= 2]

# CROSS-ENCODER helper (optional)
def init_cross_encoder(model_name=CROSS_ENCODER_MODEL):
    if not CROSS_ENCODER_AVAILABLE:
        print("[init_cross_encoder] cross-encoder not available (pip install cross-encoder). Skipping.")
        return None
    print(f"[init_cross_encoder] loading {model_name}")
    return CrossEncoder(model_name)

def score_pairs_with_crossencoder(cross_encoder, pairs):
    # pairs: list of (text1, text2)
    if cross_encoder is None:
        return [1.0] * len(pairs)  # neutral score (force LLM check if desired)
    texts = pairs
    scores = cross_encoder.predict(texts, show_progress_bar=False)
    # normalize to 0-1 if needed (most models give cosine-like 0-1)
    return scores

In [88]:
import json
import time
from collections import defaultdict, deque

def call_llm_verify_relation(
    frames=None,                # preferred (matches caller)
    frames_list=None,           # back-compat alias
    client=None,
    config=None,
    system_prompt=None,         # optional override; otherwise use config.system_prompt
    user_prompt=None,           # optional extra line(s) before the listing
    timeout=60,
):
    """
    Uses YAML-provided response_format and system_prompt to get pairwise
    paraphrases/contradictions, then converts to grouped form:

    Returns:
        {
          "paraphrase_groups": [[ids...], ...],
          "contradiction_groups": [["idA","idB"], ...]
        }
    """
    # -------- Resolve inputs --------
    frames_data = frames or frames_list
    if not frames_data:
        return {"paraphrase_groups": [], "contradiction_groups": []}

    # Keep a mapping so we can normalize IDs the model returns (string vs int)
    id_map = {}
    for fid, _text in frames_data:
        id_map[str(fid)] = fid
        id_map[fid] = fid

    def _normalize_id(x):
        return id_map.get(x, id_map.get(str(x), x))

    # -------- Build messages --------
    # Use the system_prompt exactly as provided by YAML unless an explicit override is passed in.
    sys_msg = (system_prompt or getattr(config, "system_prompt", "") or "").strip()

    # The YAML system prompt already contains all instructions and JSON schema.
    # So we only pass the optional user_prompt and the list of frames.
    listing = "\n".join(f"{fid}: {text}" for fid, text in frames_data)
    user_msg = ""
    if user_prompt:
        user_msg += user_prompt.strip() + "\n"
    user_msg += listing

    # -------- Call LLM (honor YAML-configured params) --------
    model        = getattr(config, "model",        "gpt-5-mini-2025-08-07")
    temperature  = getattr(config, "temperature",  1.0)
    max_tokens   = getattr(config, "max_tokens",   1024)
    seed         = getattr(config, "seed",         None)
    response_fmt = getattr(config, "response_format", None)
    call_delay   = getattr(config, "delay",        1)

    completion = chat(client,
        model=model,
        messages=[
            {"role": "system", "content": sys_msg},
            {"role": "user", "content": user_msg},
        ],
        delay=call_delay,
        temperature=temperature,
        seed=seed,
        timeout=timeout,
        **({"response_format": response_fmt} if response_fmt else {}),
    )

    # -------- Parse response robustly --------
    # Some SDKs expose structured JSON at choices[0].message.parsed
    msg = completion.choices[0].message
    parsed = getattr(msg, "parsed", None)

    if parsed is None:
        # Fall back to JSON string content
        raw = (msg.content or "").strip()

        def _safe_load(s):
            try:
                return json.loads(s)
            except Exception:
                # try brace slicing
                try:
                    i, j = s.find("{"), s.rfind("}")
                    if i != -1 and j != -1 and j > i:
                        return json.loads(s[i:j+1])
                except Exception:
                    return None
            return None

        parsed = _safe_load(raw)

    if not isinstance(parsed, dict):
        return {"paraphrase_groups": [], "contradiction_groups": []}

    # Expect schema:
    # {
    #   "paraphrases": [["id1","id2"], ...],
    #   "contradictions": [["id1","id2"], ...]
    # }
    paraphrases = parsed.get("paraphrases", []) or []
    contradictions = parsed.get("contradictions", []) or []

    # Normalize pairs: tolerate objects with id1/id2 (in case schema varies)
    def _iter_pairs(seq):
        for item in seq:
            if isinstance(item, (list, tuple)) and len(item) == 2:
                yield _normalize_id(item[0]), _normalize_id(item[1])
            elif isinstance(item, dict) and "id1" in item and "id2" in item:
                yield _normalize_id(item["id1"]), _normalize_id(item["id2"])

    para_pairs = [(a, b) for a, b in _iter_pairs(paraphrases) if a in id_map.values() and b in id_map.values() and a != b]
    contra_pairs = [(a, b) for a, b in _iter_pairs(contradictions) if a in id_map.values() and b in id_map.values() and a != b]

    # -------- Convert to groups --------
    # Paraphrases: merge transitively (connected components)
    from collections import defaultdict, deque
    graph = defaultdict(set)
    for a, b in para_pairs:
        graph[a].add(b); graph[b].add(a)

    seen, para_groups = set(), []
    for node in list(graph.keys()):
        if node in seen:
            continue
        comp, dq = [], deque([node])
        seen.add(node)
        while dq:
            cur = dq.popleft()
            comp.append(cur)
            for nb in graph[cur]:
                if nb not in seen:
                    seen.add(nb)
                    dq.append(nb)
        if len(comp) >= 2:
            comp_sorted = sorted(comp, key=lambda z: str(z))
            para_groups.append(comp_sorted)

    # Contradictions: keep as 2-item groups (no transitive closure)
    # Dedup by sorted tuple
    contra_set = set()
    for a, b in contra_pairs:
        key = tuple(sorted((a, b), key=lambda z: str(z)))
        contra_set.add(key)
    contradiction_groups = [list(p) for p in sorted(contra_set, key=lambda t: (str(t[0]), str(t[1])))]

    # Dedup paraphrase groups by sorted tuple of IDs
    def _dedup_groups(groups):
        seen_keys, out = set(), []
        for g in groups:
            key = tuple(sorted(g, key=lambda z: str(z)))
            if key not in seen_keys:
                seen_keys.add(key)
                out.append(list(key))
        return out

    para_groups = _dedup_groups(para_groups)

    return {
        "paraphrase_groups": para_groups,
        "contradiction_groups": contradiction_groups,
    }



In [89]:
# -----------------------------
# Step 4: Build paraphrase graph and merge
# -----------------------------
from collections import defaultdict
import itertools
import networkx as nx
from tqdm import tqdm


def perform_hybrid_paradetection(
    unique_frames,
    clusters_by_target,
    client=None,
    config=None,
    cross_encoder=None,
    cross_encoder_threshold=0.75,
    batch_size=100,
):
    """
    Hybrid paraphrase/contradiction detection with batched LLM calls.

    * High cross-encoder scores (>= threshold) are accepted immediately as paraphrases.
    * Remaining low-confidence frames are grouped into batches of size `batch_size`
      and sent once per batch to the LLM. The LLM is expected to return *groups* of
      paraphrase IDs and contradiction IDs.

    Returns
    -------
    relations : list of dicts
        Each dict:
        { "type": "paraphrase"/"contradiction",
          "x": global_index,
          "y": global_index,
          "reasoning": str (optional),
          "score": float or None }
    """
    relations = []

    for target, data in tqdm(clusters_by_target.items(),
                             desc="Hybrid detection per target"):
        idxs = data["indices"]
        labels = data["labels"]

        cluster_lists = make_pair_list_from_cluster(idxs, labels)

        for cluster in cluster_lists:
            pairs = list(itertools.combinations(cluster, 2))
            if not pairs:
                continue

            # 1) Cross-encoder filtering
            low_pairs = []
            if cross_encoder is not None:
                pair_texts = [
                    (unique_frames[i]["text"], unique_frames[j]["text"])
                    for i, j in pairs
                ]
                scores = score_pairs_with_crossencoder(cross_encoder, pair_texts)
                for (i, j), s in zip(pairs, scores):
                    if s >= cross_encoder_threshold:
                        relations.append({
                            "type": "paraphrase",
                            "x": i, "y": j,
                            "reasoning": f"auto-accepted (score={s:.3f})",
                            "score": float(s),
                        })
                    else:
                        low_pairs.append((i, j))
            else:
                low_pairs = pairs

            if not low_pairs:
                continue

            # 2) Collect unique frame IDs from remaining pairs
            unique_ids = sorted(set(itertools.chain.from_iterable(low_pairs)))

            # 3) Batch call to LLM, which returns *groups*
            for start in range(0, len(unique_ids), batch_size):
                batch_ids = unique_ids[start:start + batch_size]
                frames_for_batch = [
                    (fid, unique_frames[fid]["text"]) for fid in batch_ids
                ]

                # call_llm_verify_relation must now return a dict with
                #   {"paraphrase_groups": [[ids...], ...],
                #    "contradiction_groups": [[ids...], ...]}
                batch_groups = call_llm_verify_relation(
                    client=client,
                    config=config,
                    frames=frames_for_batch
                )

                # convert groups into pairwise edges
                for group in batch_groups.get("paraphrase_groups", []):
                    for i, j in itertools.combinations(group, 2):
                        relations.append({
                            "type": "paraphrase",
                            "x": i, "y": j,
                            "reasoning": "grouped by LLM",
                            "score": None,
                        })

                for group in batch_groups.get("contradiction_groups", []):
                    for i, j in itertools.combinations(group, 2):
                        relations.append({
                            "type": "contradiction",
                            "x": i, "y": j,
                            "reasoning": "grouped by LLM",
                            "score": None,
                        })

    return relations



def merge_paraphrase_components(unique_frames, relations):
    """
    Merge all frames that are in the same paraphrase *group* (connected component).
    The `relations` list can contain many edges produced either
    by the cross-encoder or by grouped LLM output.

    Returns
    -------
    final_items : list of dict
        Each dict represents a merged cluster with combined metadata.
    """
    g = nx.Graph()
    for idx in range(len(unique_frames)):
        g.add_node(idx)

    # only paraphrase edges define components
    for r in relations:
        if r["type"] == "paraphrase":
            g.add_edge(r["x"], r["y"])

    node_to_rep = {}
    for comp in nx.connected_components(g):
        comp = list(comp)
        if len(comp) == 1:
            n = comp[0]
            node_to_rep[n] = n
            continue
        # Representative = shortest text (break ties by largest count)
        comp_frames = [(unique_frames[n]["text"],
                        unique_frames[n]["count"],
                        n) for n in comp]
        comp_frames.sort(key=lambda x: (len(x[0]), -x[1]))
        rep = comp_frames[0][2]
        for n in comp:
            node_to_rep[n] = rep

    # rep -> list of member nodes
    rep_map = defaultdict(list)
    for n, rep in node_to_rep.items():
        rep_map[rep].append(n)

    final_items = []
    seen_reps = set()
    for rep, members in rep_map.items():
        if rep in seen_reps:
            continue
        seen_reps.add(rep)

        merged_count = sum(unique_frames[m]["count"] for m in members)
        targets, intents, actions = set(), set(), set()
        for m in members:
            targets.update(unique_frames[m].get("targets", []))
            intents.update(unique_frames[m].get("intents", []))
            actions.update(unique_frames[m].get("actions", []))

        final_items.append({
            "frame_id": unique_frames[rep]["frame_id"],
            "text": unique_frames[rep]["text"],
            "merged_ids": [unique_frames[m]["frame_id"] for m in members if m != rep],
            "members": [unique_frames[m]["frame_id"] for m in members],
            "targets": sorted(targets),
            "intents": sorted(intents),
            "actions": sorted(actions),
            "count": merged_count,
        })

    # add singletons not touched by any paraphrase edge
    untouched = set(range(len(unique_frames))) - set(node_to_rep.keys())
    for idx in untouched:
        uf = unique_frames[idx]
        final_items.append({
            "frame_id": uf["frame_id"],
            "text": uf["text"],
            "merged_ids": [],
            "members": [uf["frame_id"]],
            "targets": uf.get("targets", []),
            "intents": uf.get("intents", []),
            "actions": uf.get("actions", []),
            "count": uf.get("count", 0),
        })

    return final_items


In [None]:
@dataclasses.dataclass
class ChatCompletionConfig:
    seed: int
    delay: int
    model: str
    max_tokens: int
    temperature: float
    system_prompt: str
    response_format: dict | None = None

config_file_path = 'prompts/paraphrase.yaml'
with open(config_file_path, 'r') as f:
    config = yaml.safe_load(f)
    config = ChatCompletionConfig(**config)  # Assuming you have this class defined


In [None]:
# Step 1: build unique frames
unique_frames = build_unique_frames(FINAL_FRAMES_IN, 'problems.jsonl', out_path=UNIQUE_FRAMES_OUT)
# Step 2: clustering per target
clusters_by_target = cluster_within_target(unique_frames, embed_model_name=EMBED_MODEL_NAME, distance_threshold=DISTANCE_THRESHOLD)

# Optional init cross-encoder
cross_encoder = None
if CROSS_ENCODER_AVAILABLE:
    cross_encoder = init_cross_encoder(CROSS_ENCODER_MODEL)

In [97]:
import os, json, time, itertools, math, random
from pathlib import Path
from tqdm import tqdm

# --- You already have these helpers ---
# - make_pair_list_from_cluster(idxs, labels) -> list[list[int]]
# - score_pairs_with_crossencoder(cross_encoder, pair_texts) -> list[float]
# - call_llm_verify_relation(client, config, frames=[(fid, text), ...]) -> {
#       "paraphrase_groups":[[ids...],...], "contradiction_groups":[[ids...],...]
#   }

# ----------------- CONFIG -----------------
OUTPUT_DIR = Path("hybrid_ckpt")
REL_FPATH  = OUTPUT_DIR / "relations.jsonl"
FRM_FPATH  = OUTPUT_DIR / "frames_seen.jsonl"
STATE_FPATH= OUTPUT_DIR / "state.json"

CROSS_THRESHOLD = 0.75
BATCH_SIZE = 100
MAX_LLM_RETRIES = 5
MAX_CE_RETRIES  = 3
# -----------------------------------------

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

def load_state():
    if STATE_FPATH.exists():
        with open(STATE_FPATH, "r") as f:
            return json.load(f)
    return {
        "target_pos": 0,
        "cluster_pos": 0,
        "batch_pos": 0,  # index of batch within the unique_ids list
        "timestamp": time.time()
    }

def save_state(state):
    state["timestamp"] = time.time()
    tmp = STATE_FPATH.with_suffix(".json.tmp")
    with open(tmp, "w") as f:
        json.dump(state, f)
    os.replace(tmp, STATE_FPATH)

def append_jsonl(path: Path, records):
    with open(path, "a") as f:
        for r in records:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

def retry(fn, *, tries, base_delay=1.5, max_delay=30, jitter=True):
    for attempt in range(tries):
        try:
            return fn()
        except Exception as e:
            if attempt == tries - 1:
                raise
            sleep_s = min(max_delay, base_delay * (2 ** attempt))
            if jitter:
                sleep_s *= (0.5 + random.random())
            time.sleep(sleep_s)

def list_targets(clusters_by_target):
    # deterministic order for resume
    return list(sorted(clusters_by_target.items(), key=lambda kv: str(kv[0])))

def ensure_frames_logged(frames_for_batch):
    # frames_for_batch: list[(fid, text)]
    append_jsonl(FRM_FPATH, [{"id": fid, "text": text} for fid, text in frames_for_batch])

def process_cluster(unique_frames, cluster, *, cross_encoder, client, config,
                    relations_sink, ce_threshold, batch_size, state):
    pairs = list(itertools.combinations(cluster, 2))
    if not pairs:
        return

    # 1) Cross-encoder filtering (unchanged)
    low_pairs = []
    if cross_encoder is not None:
        def ce_call():
            pair_texts = [(unique_frames[i]["text"], unique_frames[j]["text"]) for i, j in pairs]
            return score_pairs_with_crossencoder(cross_encoder, pair_texts)
        scores = retry(ce_call, tries=MAX_CE_RETRIES)
        to_append = []
        for (i, j), s in zip(pairs, scores):
            if s >= ce_threshold:
                to_append.append({
                    "type": "paraphrase",
                    "x": i, "y": j,
                    "reasoning": f"auto-accepted (score={s:.3f})",
                    "score": float(s),
                })
            else:
                low_pairs.append((i, j))
        if to_append:
            append_jsonl(relations_sink, to_append)
    else:
        low_pairs = pairs

    if not low_pairs:
        return

    # 2) Unique IDs that need LLM grouping
    unique_ids = sorted(set(itertools.chain.from_iterable(low_pairs)))
    total_batches = math.ceil(len(unique_ids) / batch_size)
    start_batch = state.get("batch_pos", 0)

    # ---- NEW: per-cluster progress bar for LLM frames ----
    already_done = min(start_batch * batch_size, len(unique_ids))
    desc = f"Cluster frames → LLM ({len(cluster)} in cluster, {len(unique_ids)} low)"
    with tqdm(total=len(unique_ids), desc=desc, initial=already_done, leave=False) as pbar:
        # 3) Batched LLM grouping
        for b in range(start_batch, total_batches):
            start = b * batch_size
            batch_ids = unique_ids[start:start + batch_size]
            frames_for_batch = [(fid, unique_frames[fid]["text"]) for fid in batch_ids]

            # Log frames we send to LLM
            ensure_frames_logged(frames_for_batch)

            def llm_call():
                return call_llm_verify_relation(client=client, config=config, frames=frames_for_batch)

            groups = retry(llm_call, tries=MAX_LLM_RETRIES)

            out = []
            for group in groups.get("paraphrase_groups", []):
                for i, j in itertools.combinations(group, 2):
                    out.append({
                        "type": "paraphrase",
                        "x": int(i), "y": int(j),
                        "reasoning": "grouped by LLM",
                        "score": None,
                    })
            for group in groups.get("contradiction_groups", []):
                for i, j in itertools.combinations(group, 2):
                    out.append({
                        "type": "contradiction",
                        "x": int(i), "y": int(j),
                        "reasoning": "grouped by LLM",
                        "score": None,
                    })

            if out:
                append_jsonl(relations_sink, out)

            # checkpoint after each batch
            state["batch_pos"] = b + 1
            save_state(state)

            # ---- NEW: advance per-cluster progress ----
            pbar.update(len(batch_ids))

    # reset batch cursor for next cluster
    state["batch_pos"] = 0
    save_state(state)

def run_hybrid_incremental(
    unique_frames,
    clusters_by_target,
    *,
    client=None,
    config=None,
    cross_encoder=None,
    cross_encoder_threshold=CROSS_THRESHOLD,
    batch_size=BATCH_SIZE
):
    state = load_state()
    targets = list_targets(clusters_by_target)

    with tqdm(total=len(targets), desc="Hybrid detection per target", initial=state["target_pos"]) as pbar:
        for tpos, (target, data) in enumerate(targets):
            if tpos < state["target_pos"]:
                continue

            idxs = data["indices"]
            labels = data["labels"]
            cluster_lists = make_pair_list_from_cluster(idxs, labels)

            # store current target in state
            state["target_pos"] = tpos
            state["cluster_pos"] = state.get("cluster_pos", 0)
            save_state(state)

            for cpos, cluster in enumerate(cluster_lists):
                if cpos < state["cluster_pos"]:
                    continue

                try:
                    process_cluster(
                        unique_frames,
                        cluster,
                        cross_encoder=cross_encoder,
                        client=client,
                        config=config,
                        relations_sink=REL_FPATH,
                        ce_threshold=cross_encoder_threshold,
                        batch_size=batch_size,
                        state=state
                    )
                except KeyboardInterrupt:
                    print("\nInterrupted. Progress saved.")
                    return
                except Exception as e:
                    # Save and re-raise so you see the error; files already have progress.
                    save_state(state)
                    raise

                # move to next cluster
                state["cluster_pos"] = cpos + 1
                state["batch_pos"] = 0
                save_state(state)

            # move to next target
            state["target_pos"] = tpos + 1
            state["cluster_pos"] = 0
            state["batch_pos"] = 0
            save_state(state)
            pbar.update(1)

# ----------------- USAGE -----------------
# run_hybrid_incremental(
#     unique_frames=unique_frames,                  # list[{"text": str, ...}]
#     clusters_by_target=clusters_by_target,        # {target: {"indices": [...], "labels": [...]}}
#     client=client,
#     config=config,
#     cross_encoder=cross_encoder,                  # or None
#     cross_encoder_threshold=0.75,
#     batch_size=100,
# )
#
# Outputs:
#   hybrid_ckpt/relations.jsonl     (append-only)
#   hybrid_ckpt/frames_seen.jsonl   (frames sent to LLM)
#   hybrid_ckpt/state.json          (resume cursor)


In [None]:
# ----------------- USAGE -----------------
run_hybrid_incremental(
unique_frames=unique_frames,                  # list[{"text": str, ...}]
     clusters_by_target=clusters_by_target,        # {target: {"indices": [...], "labels": [...]}}
     client=client,
     config=config,
     cross_encoder=cross_encoder,                  # or None
     cross_encoder_threshold=0.75,
     batch_size=100,
 )

Hybrid detection per target: 100%|██████████| 12/12 [4:15:07<00:00, 1275.62s/it] 


In [99]:
REL_FPATH = Path("hybrid_ckpt/relations.jsonl")

def load_paraphrase_relations(rel_path: str | Path):
    """
    Read append-only relations.jsonl, keep unique undirected paraphrase edges.
    If the same pair appears multiple times, we keep the first occurrence
    (or you can change the policy below to keep the highest score).
    """
    seen = set()  # (min(x,y), max(x,y))
    uniq = []
    with open(rel_path, "r") as f:
        for line in f:
            if not line.strip():
                continue
            r = json.loads(line)
            if r.get("type") != "paraphrase":
                continue
            x, y = int(r["x"]), int(r["y"])
            if x == y:
                continue
            key = (min(x, y), max(x, y))
            if key in seen:
                continue
            seen.add(key)
            # keep original structure merge() expects
            uniq.append({
                "type": "paraphrase",
                "x": key[0],
                "y": key[1],
                "reasoning": r.get("reasoning"),
                "score": r.get("score"),
            })
    return uniq

In [100]:
relations = load_paraphrase_relations(REL_FPATH)
final_items = merge_paraphrase_components(unique_frames, relations)

In [101]:
len(final_items)

4502

In [104]:
relevant =0
for i in final_items:
    if i['count']>1:
        relevant+=1
print(relevant)

1085


In [102]:
final_items

[{'frame_id': 'f6197',
  'text': 'Muslims are terrorists — even Muslim children are potential bombers.',
  'merged_ids': ['f1', 'f676', 'f432', 'f2336'],
  'members': ['f1', 'f676', 'f432', 'f6197', 'f2336'],
  'targets': ['Age (Demographics)',
   'Race/Ethnicity (Demographics)',
   'Religion/Belief (Demographics)'],
  'intents': ['Hate (Hostility/Aggression)',
   'Humiliation (Discrimination/Prejudice)',
   'Prejudice (Discrimination/Prejudice)',
   'Stereotyping (Social/Cultural Control)',
   'Violence (Hostility/Aggression)'],
  'actions': ['Insults (Verbal/Written Expressions)',
   'Provocation (Verbal/Written Expressions)',
   'Violence (Physical Action)'],
  'count': 5},
 {'frame_id': 'f2',
  'text': 'Bullying or attacking Muslim kids is acceptable because they’re dangerous.',
  'merged_ids': [],
  'members': ['f2'],
  'targets': ['Age (Demographics)', 'Religion/Belief (Demographics)'],
  'intents': ['Hate (Hostility/Aggression)',
   'Humiliation (Discrimination/Prejudice)',
   '

In [103]:
write_jsonl('merged_frames_hateful.jsonl',final_items)

In [None]:
len(clusters_by_target)

12

In [65]:
total_frames = sum(len(data["indices"]) for data in clusters_by_target.values())
print(total_frames)

11659


MMHS150K

In [3]:
import json
import re

def safe_json_loads(s):
    """Try to safely load JSON even if it has invalid control characters."""
    try:
        return json.loads(s)
    except json.JSONDecodeError:
        # Remove unescaped control characters
        s = re.sub(r"[\x00-\x1f\x7f]", "", s)
        try:
            return json.loads(s)
        except Exception:
            return None  # still invalid

def build_unique_frames(final_frames_path: str,
                        problems_path: str,
                        out_path: str):
  id2problems = {}

  with open(problems_path, "r", encoding="utf-8") as f:
      for line_num, line in enumerate(f, 1):
          rec = safe_json_loads(line)
          if not rec:
              print(f"⚠️ Skipping malformed line {line_num} in problems file.")
              continue

          pid = rec.get("id")
          resp = safe_json_loads(rec.get("response", "{}"))
          if not resp:
              continue

          hate_probs = resp.get("hate_problems", {})
          if isinstance(hate_probs, str):
              hate_probs = safe_json_loads(hate_probs) or {}

          def extract_categories(items):
              result = []
              for x in items or []:
                  if isinstance(x, dict):
                      result.append(x.get("category"))
                  elif isinstance(x, str):
                      result.append(x)
              return [v for v in result if v]  # remove None or empty

          id2problems[pid] = {
              "targets": extract_categories(hate_probs.get("targets")),
              "intents": extract_categories(hate_probs.get("intents")),
              "actions": extract_categories(hate_probs.get("actions")),
          }


  frame_map = {}
  counter = 1

  with open(final_frames_path, "r", encoding="utf-8") as f:
      for line_num, line in enumerate(f, 1):
          rec = safe_json_loads(line)
          if not rec:
              print(f"⚠️ Skipping malformed line {line_num} in frames file.")
              continue

          img_id = rec.get("id")
          img_id = img_id.split(".")[0]
          #print(img_id)
          resp = safe_json_loads(rec.get("response", "{}"))
          if not resp:
              continue

          frames = resp.get("hate_frames", [])
          if len(frames) <= 1:
              continue

          for fr in frames[1:]:
              text = fr.get("frame", "").strip()
              if not text:
                  continue
              if text not in frame_map:
                  frame_map[text] = {
                      "frame_id": f"f{counter}",
                      "count": 0,
                      "targets": id2problems.get(img_id, {}).get("targets", []),
                      "intents": id2problems.get(img_id, {}).get("intents", []),
                      "actions": id2problems.get(img_id, {}).get("actions", []),
                      "source_ids": set(),
                  }
                  counter += 1
              frame_map[text]["count"] += 1
              frame_map[text]["source_ids"].add(img_id)

  unique_list = []
  with open(out_path, "w", encoding="utf-8") as out_f:
      for text, data in frame_map.items():
          rec = {
              "text": text,
              "frame_id": data["frame_id"],
              "count": data["count"],
              "targets": data["targets"],
              "intents": data["intents"],
              "actions": data["actions"],
              "source_ids": sorted(list(data["source_ids"])),
          }
          out_f.write(json.dumps(rec) + "\n")
          unique_list.append(rec)

  return unique_list


In [4]:
FINAL_FRAMES_IN = "final_frames_mmhs.jsonl"         # your input file (per example)
UNIQUE_FRAMES_OUT="unique_frames_mmhs.jsonl"
unique_frames = build_unique_frames(FINAL_FRAMES_IN, 'problems_mmhs.jsonl', out_path=UNIQUE_FRAMES_OUT)

⚠️ Skipping malformed line 170 in frames file.
⚠️ Skipping malformed line 171 in frames file.


In [5]:
import numpy as np
import itertools
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
from transformers import pipeline

# -----------------------------
# CONFIG
# -----------------------------
EMBED_MODEL = "intfloat/e5-large-v2"  # strong similarity model
NLI_MODEL = "facebook/bart-large-mnli"  # for contradiction detection
PARAPHRASE_THRESHOLD = 0.70             # lower = more aggressive merging
CONTRADICTION_THRESHOLD = 0.85

# -----------------------------
# STEP 1: EMBEDDING
# -----------------------------
def embed_texts(model, texts, batch_size=128):
    embs = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        e = model.encode(batch, convert_to_numpy=True, show_progress_bar=False, normalize_embeddings=True)
        embs.append(e)
    return np.vstack(embs)

# -----------------------------
# STEP 2: PARAPHRASE DETECTION
# -----------------------------
def detect_paraphrases(unique_frames, model=None, threshold=PARAPHRASE_THRESHOLD):
    if model is None:
        model = SentenceTransformer(EMBED_MODEL)
    texts = [uf["text"] for uf in unique_frames]
    embeddings = embed_texts(model, texts)
    
    # Compute cosine similarities
    sim_matrix = util.cos_sim(embeddings, embeddings).cpu().numpy()

    paraphrases = []
    n = len(unique_frames)
    for i in tqdm(range(n), desc="Detecting paraphrases"):
        for j in range(i+1, n):
            score = sim_matrix[i, j]
            if score >= threshold:
                paraphrases.append({
                    "type": "paraphrase",
                    "x": i,
                    "y": j,
                    "score": float(score),
                    "reasoning": f"Cosine similarity {score:.3f} ≥ {threshold}"
                })
    return paraphrases

# -----------------------------
# STEP 3: CONTRADICTION DETECTION (OPTIONAL)
# -----------------------------
def detect_contradictions(unique_frames, candidate_pairs, nli_model=None, threshold=CONTRADICTION_THRESHOLD):
    if nli_model is None:
        nli_model = pipeline("text-classification", model=NLI_MODEL, truncation=True)
    
    contradictions = []
    texts = [uf["text"] for uf in unique_frames]
    for (i, j) in tqdm(candidate_pairs, desc="Detecting contradictions"):
        premise, hypothesis = texts[i], texts[j]
        result = nli_model(f"{premise} </s> {hypothesis}", return_all_scores=True)[0]
        result_dict = {r["label"]: r["score"] for r in result}
        if result_dict.get("CONTRADICTION", 0) >= threshold:
            contradictions.append({
                "type": "contradiction",
                "x": i,
                "y": j,
                "score": result_dict["CONTRADICTION"],
                "reasoning": f"NLI contradiction score {result_dict['CONTRADICTION']:.3f} ≥ {threshold}"
            })
    return contradictions

# -----------------------------
# STEP 4: COMBINE RESULTS
# -----------------------------
def detect_paraphrases_and_contradictions(unique_frames,
                                          paraphrase_threshold=PARAPHRASE_THRESHOLD,
                                          contradiction_threshold=CONTRADICTION_THRESHOLD):
    model = SentenceTransformer(EMBED_MODEL)
    paraphrases = detect_paraphrases(unique_frames, model, paraphrase_threshold)

    # Only check contradictions among similar pairs to save time
    candidate_pairs = [(p["x"], p["y"]) for p in paraphrases]
    contradictions = detect_contradictions(unique_frames, candidate_pairs, threshold=contradiction_threshold)

    all_relations = paraphrases + contradictions
    print(f"✅ Found {len(paraphrases)} paraphrases and {len(contradictions)} contradictions.")
    return all_relations


In [None]:
relations = detect_paraphrases_and_contradictions(unique_frames,
                                                  paraphrase_threshold=0.68,
                                                  contradiction_threshold=0.85)

In [None]:
#finding number of paraphrase and contradictions
import json

# Path to your JSONL file
file_path = "relations.jsonl"

# Initialize counters
paraphrase_count = 0
contradiction_count = 0

# Read line by line
with open(file_path, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        data = json.loads(line)
        if data.get("type") == "paraphrase":
            paraphrase_count += 1
        elif data.get("type") == "contradiction":
            contradiction_count += 1

print(f"Paraphrases: {paraphrase_count}")
print(f"Contradictions: {contradiction_count}")


In [1]:
#finding number of paraphrase and contradictions
import json

# Path to your JSONL file
file_path = "hybrid_ckpt_hateful/relations.jsonl"

# Initialize counters
paraphrase_count = 0
contradiction_count = 0

# Read line by line
with open(file_path, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        try:
          data = json.loads(line)
          if data.get("type") == "paraphrase":
              paraphrase_count += 1
          elif data.get("type") == "contradiction":
              contradiction_count += 1
        except json.JSONDecodeError:
            print(f"Skipping malformed line: {line}")

print(f"Paraphrases: {paraphrase_count}")
print(f"Contradictions: {contradiction_count}")


Paraphrases: 16572
Contradictions: 96


In [2]:
import json
from collections import Counter

file_path = "unique_frames_mmhs.jsonl"

target_counter = Counter()

with open(file_path, "r", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        targets = record.get("targets", [])
        target_counter.update(targets)

# Display results sorted by count
for target, count in target_counter.most_common():
    print(f"{target}: {count}")

print("\nTotal unique target categories:", len(target_counter))
print("Total frames:", sum(target_counter.values()))


Race/Ethnicity (Demographics): 5451
Gender/Sex/Sexual Orientation (Demographics): 3463
Disability (Demographics): 2143
Socio-economic Status (Socio-economic role): 1139
Other: 1138
Nationality/Region/Citizenship (Demographics): 1024
Appearance (Demographics): 449
Occupation/Profession (Socio-economic role): 415
Religion/Belief (Demographics): 317
Age (Demographics): 171
Legal Status/Discrimination (Socio-economic role): 157
Family Status (Socio-economic role): 34

Total unique target categories: 12
Total frames: 15901


In [1]:
import json

file_path = "merged_frames_hateful.jsonl"

total_targets = 0
total_intents = 0
total_actions = 0
frames_with_count_2plus = 0

with open(file_path, "r", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line.strip())
        if record.get("count", 0) >= 2:
            frames_with_count_2plus += 1
            total_targets += len(record.get("targets", []))
            total_intents += len(record.get("intents", []))
            total_actions += len(record.get("actions", []))

print(f"Frames with count >= 2: {frames_with_count_2plus}")
print(f"Total targets annotated: {total_targets}")
print(f"Total intents annotated: {total_intents}")
print(f"Total actions annotated: {total_actions}")


Frames with count >= 2: 1085
Total targets annotated: 2117
Total intents annotated: 6709
Total actions annotated: 2726
