In [None]:
# Count truncated initial_response per file using TOKEN-COUNT-ONLY detection
# (No "unfinished ending" checks; max_new_tokens in JSONL is ignored.)
#
# Truncated criterion:
#   tokenize(initial_response) minus tokenize(prefill_text if it is a true prefix)
#   => gen_token_count
#   truncated if gen_token_count is within MARGIN of any value in LIMIT_GRID.

from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, Any, Iterable, Tuple
from functools import lru_cache

from tqdm.auto import tqdm
from transformers import AutoTokenizer


# -----------------------
# CONFIG
# -----------------------
# Directory containing the generation JSONL files listed below (relative to the CWD).
DATA_DIR = Path("rq1_runs")

# File names inside DATA_DIR to scan. Files that do not exist are collected and
# reported as "Missing files" instead of raising.
FILES_TO_PROCESS = [
    "gen_google__gemma-3-4b-it_harmbench.jsonl",
    "gen_google__gemma-3-4b-it.jsonl",
    "gen_google__gemma-3-12b-it_harmbench.jsonl",
    "gen_google__gemma-3-12b-it.jsonl",
    "gen_google__gemma-3-27b-it_harmbench.jsonl",
    "gen_google__gemma-3-27b-it.jsonl",
    "gen_meta-llama__Llama-3.1-8B-Instruct_harmbench.jsonl",
    "gen_meta-llama__Llama-3.1-8B-Instruct.jsonl",
    "gen_meta-llama__llama-3.3-70b-instruct_harmbench.jsonl",
    "gen_meta-llama__llama-3.3-70b-instruct.jsonl",
    "gen_Qwen__Qwen3-8B_harmbench.jsonl",
    "gen_Qwen__Qwen3-8B.jsonl",
    "gen_Qwen__Qwen3-14B_harmbench.jsonl",
    "gen_Qwen__Qwen3-14B.jsonl",
    "gen_qwen__qwen3-32b_harmbench.jsonl",
    "gen_qwen__qwen3-32b.jsonl",
]

# The real caps you believe were enforced by the generator.
# A record is flagged as truncated when its generated-token count lands within
# MARGIN tokens of ANY value in this list.
LIMIT_GRID = [512, 1024, 1280, 1536, 1792]

# Error margin (tokens) around each cap
MARGIN = 4  # tune: 1–2 = stricter, 4–8 = more recall

# Record field whose generated-token count is checked.
FIELD = "initial_response"

# Optional: if your JSONL model_name isn't a valid HF tokenizer id, map it here.
# Keys must equal the record's `model_name` value exactly; values are the ids passed
# to AutoTokenizer.from_pretrained. Names without an entry are used unchanged.
# Example (fill only if needed; the entries below are illustrative only):
MODEL_NAME_OVERRIDES = {
    # "google/gemma-3-12b-it": "google/gemma-3-12b-it",
    # "gen_meta-llama__Llama-3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
}


# -----------------------
# IO
# -----------------------
def iter_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
    """Yield one parsed JSON object per non-blank line of ``path``.

    Malformed lines are skipped (best effort, so one bad line does not abort a whole
    file), but each skip is reported with ``warnings.warn`` including the file and
    1-based line number so it no longer disappears silently.
    """
    import warnings

    with path.open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError as err:
                # Skip malformed lines; keep going, but say so.
                warnings.warn(
                    f"{path}:{line_no}: skipping malformed JSON line ({err.msg})",
                    stacklevel=2,
                )
                continue
            # Yield outside the try so consumer-side errors are never swallowed here.
            yield rec


# -----------------------
# Tokenizer cache
# -----------------------
@lru_cache(maxsize=64)
def get_tokenizer(model_name: str):
    """Load and memoize the HF tokenizer for ``model_name``.

    If ``model_name`` has an entry in MODEL_NAME_OVERRIDES, the mapped id is loaded
    instead; otherwise the name is passed to ``from_pretrained`` as-is.
    """
    if model_name in MODEL_NAME_OVERRIDES:
        hf_id = MODEL_NAME_OVERRIDES[model_name]
    else:
        hf_id = model_name
    return AutoTokenizer.from_pretrained(hf_id, trust_remote_code=True)


# -----------------------
# Token counting (prefill-safe)
# -----------------------
def generated_token_count(rec: Dict[str, Any], tokenizer, field: str = FIELD) -> int:
    """
    Return how many tokens of ``rec[field]`` the model itself produced.

    A missing or empty field counts as 0. When ``rec["prefill_text"]`` is a literal
    string prefix of the field, its tokens are excluded. Token-prefix subtraction is
    used when the two tokenizations line up; otherwise only the text after the prefix
    is tokenized.
    """
    def encode(text: str):
        return tokenizer.encode(text, add_special_tokens=False)

    response = rec.get(field) or ""
    if not response:
        return 0

    prefix = rec.get("prefill_text") or ""
    response_ids = encode(response)
    if not (prefix and response.startswith(prefix)):
        # No usable prefill to subtract.
        return len(response_ids)

    prefix_ids = encode(prefix)
    if response_ids[: len(prefix_ids)] != prefix_ids:
        # Tokens merge across the prefill boundary: count the string suffix on its own.
        return len(encode(response[len(prefix):]))
    return max(0, len(response_ids) - len(prefix_ids))


# -----------------------
# Token-only truncation decision
# -----------------------
def is_truncated_token_only(rec: Dict[str, Any], tokenizer) -> bool:
    """True when the generated-token count of FIELD is within MARGIN of some cap in LIMIT_GRID."""
    n_generated = generated_token_count(rec, tokenizer, field=FIELD)
    for cap in LIMIT_GRID:
        if cap - MARGIN <= n_generated <= cap + MARGIN:
            return True
    return False


# -----------------------
# Per-file counting
# -----------------------
def count_truncations_for_file(path: Path) -> Tuple[int, int]:
    """Return ``(records_seen, records_flagged_truncated)`` for one JSONL file.

    The tokenizer is resolved per record (cached by get_tokenizer) in case a file
    mixes models. Records without a ``model_name`` still count toward the first number
    but are never flagged.
    """
    n_total = n_trunc = 0
    for record in iter_jsonl(path):
        n_total += 1
        name = record.get("model_name")
        if name and is_truncated_token_only(record, get_tokenizer(name)):
            n_trunc += 1
    return n_total, n_trunc


In [4]:

# Summary rows are (file name, total records, flagged records, percent flagged).
rows = []
missing = []

for fname in FILES_TO_PROCESS:
    candidate = DATA_DIR / fname
    if candidate.exists():
        n_total, n_trunc = count_truncations_for_file(candidate)
        rows.append((fname, n_total, n_trunc, (100.0 * n_trunc / n_total) if n_total else 0.0))
    else:
        missing.append(str(candidate))

if missing:
    print("\nMissing files:")
    print("\n".join(f"  - {m}" for m in missing))

if not rows:
    print("\nNo files processed.")

# Highest truncation rate first; the sort is stable, so ties keep FILES_TO_PROCESS order.
rows.sort(key=lambda row: row[3], reverse=True)

print("\nToken-only truncation summary (initial_response):")
print(f"{'file':70}  {'total':>8}  {'trunc':>8}  {'%':>7}")
print("-" * 100)
for row in rows:
    print("{:70}  {:8d}  {:8d}  {:6.2f}%".format(*row))

print("\nSettings used:")
for label, value in (("LIMIT_GRID", LIMIT_GRID), ("MARGIN", MARGIN), ("FIELD", FIELD)):
    print(f"  {label:<10} =", value)
print("  FIELD      =", FIELD)




Token-only truncation summary (initial_response):
file                                                                       total     trunc        %
----------------------------------------------------------------------------------------------------
gen_google__gemma-3-4b-it.jsonl                                             2920      2915   99.83%
gen_google__gemma-3-4b-it_harmbench.jsonl                                   2000      1853   92.65%
gen_google__gemma-3-12b-it.jsonl                                            2920      2647   90.65%
gen_google__gemma-3-27b-it.jsonl                                            2920      2469   84.55%
gen_google__gemma-3-12b-it_harmbench.jsonl                                  2000      1676   83.80%
gen_google__gemma-3-27b-it_harmbench.jsonl                                  2000      1342   67.10%
gen_meta-llama__Llama-3.1-8B-Instruct.jsonl                                 2920      1493   51.13%
gen_qwen__qwen3-32b.jsonl                       

# Fix: continue truncated `initial_response`s and regenerate follow-ups

In [1]:
# =========================
# CELL 1: Setup & Config
# =========================
import json
import time
import random
import inspect
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, Iterable
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED

import requests
import httpx
from tqdm.auto import tqdm
from openai import OpenAI
from transformers import AutoTokenizer

# -------------------------
# Config
# -------------------------
# Directory holding the generation JSONL files (relative to the CWD).
RUN_DIR = Path("rq1_runs")

# Two files pertaining to the same model to run back-to-back.
# Each file is skipped if its first record's model_name differs from the model the
# SGLang server reports, so all entries should match the currently served model.
FILES_TO_PROCESS = [
    RUN_DIR / "gen_Qwen__Qwen3-14B_harmbench.jsonl",
    RUN_DIR / "gen_Qwen__Qwen3-14B.jsonl"
]

# SGLang OpenAI-compatible endpoint
SGLANG_BASE_URL = "http://127.0.0.1:30000/v1"
SGLANG_API_KEY = "EMPTY"  # placeholder; presumably the local server does not check it

# Global continuation limits.
# MAX_TOTAL_TOKENS: cap on generated tokens of initial_response after continuation
#   (prefill excluded). A truncated record is continued with a budget of
#   MAX_TOTAL_TOKENS minus the tokens it already generated.
# FOLLOWUP_MAX_TOKENS: max new tokens for the regenerated followup_response.
MAX_TOTAL_TOKENS = 2048
FOLLOWUP_MAX_TOKENS = 512

# Truncation detection limits (from your detection script).
# Keep these identical to the detection script's LIMIT_GRID / MARGIN so the same
# records are selected for continuation.
LIMIT_GRID = [512, 1024, 1280, 1536, 1792]
MARGIN = 4

# Keep-alive / connection pooling for faster repeated requests.
# The single `sg` client is shared by all worker threads.
# NOTE(review): the 60 s timeout applies per request; long continuations on a busy
# server may need more. Verify against observed latencies.
_http_client = httpx.Client(
    limits=httpx.Limits(max_connections=64, max_keepalive_connections=16),
    timeout=httpx.Timeout(60.0),
)
sg = OpenAI(base_url=SGLANG_BASE_URL, api_key=SGLANG_API_KEY, http_client=_http_client)

# Parallelism settings.
# MAX_WORKERS: thread-pool size issuing concurrent SGLang requests.
# MAX_IN_FLIGHT: back-pressure bound; the reader stops submitting while this many
#   futures are still pending.
# DO_MERGE: after patching, also write <name>_continued_merged.jsonl with patched
#   records substituted into the original file.
MAX_WORKERS = 8
MAX_IN_FLIGHT = MAX_WORKERS * 4
DO_MERGE = True

def is_blank(x: Optional[str]) -> bool:
    """Return True for ``None`` or a string that is empty after stripping whitespace.

    Non-string, non-None values are never considered blank.
    """
    if x is None:
        return True
    return isinstance(x, str) and not x.strip()

In [2]:
# =========================
# CELL 2: Detection Logic
# =========================

@lru_cache(maxsize=8)
def get_tokenizer(model_name: str):
    """Load the HF tokenizer for ``model_name`` (memoized; up to 8 kept)."""
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    return tok

def generated_token_count(rec: Dict[str, Any], tokenizer, field: str = "initial_response") -> int:
    """
    Count the tokens of ``rec[field]`` that the model generated.

    Returns 0 for a missing or empty field. If ``rec["prefill_text"]`` is a literal
    string prefix of the field, its tokens are excluded. Tokens are subtracted directly
    when the prefix tokenization matches; otherwise the remaining string is re-tokenized.
    """
    text = rec.get(field) or ""
    if not text:
        return 0

    kw = {"add_special_tokens": False}
    ids = tokenizer.encode(text, **kw)

    prefill = rec.get("prefill_text") or ""
    if not prefill or not text.startswith(prefill):
        return len(ids)

    prefill_ids = tokenizer.encode(prefill, **kw)
    n_pre = len(prefill_ids)
    if ids[:n_pre] == prefill_ids:
        return max(0, len(ids) - n_pre)
    # Token boundaries differ around the prefill: count the suffix string instead.
    return len(tokenizer.encode(text[len(prefill):], **kw))

def is_truncated_token_only(rec: Dict[str, Any], tokenizer) -> bool:
    """Return True if the generated token count sits within MARGIN of a cap in LIMIT_GRID.

    Records whose ``condition`` is ``"error"`` are never flagged (presumably these mark
    failed generations; confirm against the generator).
    """
    if rec.get("condition") == "error":
        return False
    n_generated = generated_token_count(rec, tokenizer)
    windows = ((cap - MARGIN, cap + MARGIN) for cap in LIMIT_GRID)
    return any(lo <= n_generated <= hi for lo, hi in windows)

In [3]:
# =========================
# CELL 3: SGLang API & Prompting
# =========================
_SERVER_MODEL_ID = None  # cached id of the model served by SGLang (filled on first call)

def get_server_model_id() -> str:
    """Return the id of the first model listed at ``SGLANG_BASE_URL/models`` (cached).

    Raises:
        requests.HTTPError: if the /models endpoint returns an error status.
        RuntimeError: if the endpoint lists no models.
    """
    global _SERVER_MODEL_ID
    if _SERVER_MODEL_ID is not None:
        return _SERVER_MODEL_ID
    resp = requests.get(SGLANG_BASE_URL + "/models", timeout=10)
    # Surface HTTP failures directly instead of a confusing KeyError on an error body.
    resp.raise_for_status()
    data = resp.json().get("data") or []
    if not data:
        raise RuntimeError(f"No models listed at {SGLANG_BASE_URL}/models")
    _SERVER_MODEL_ID = data[0]["id"]
    return _SERVER_MODEL_ID

def apply_chat_template_no_think_if_supported(tokenizer, messages):
    """Render ``messages`` to prompt text with a generation prompt, requesting no thinking.

    ``enable_thinking=False`` is passed first. If the tokenizer's ``apply_chat_template``
    rejects it with a TypeError, the call is repeated without it.
    NOTE(review): HF tokenizers usually forward unknown kwargs to the template, so the
    fallback may rarely trigger; confirm for the tokenizers in use.
    """
    common = dict(add_generation_prompt=True, tokenize=False)
    try:
        return tokenizer.apply_chat_template(messages, enable_thinking=False, **common)
    except TypeError:
        return tokenizer.apply_chat_template(messages, **common)

def build_base_prompt(tokenizer, user_text: str) -> str:
    """Render a single user turn (plus generation prompt) as raw prompt text."""
    return apply_chat_template_no_think_if_supported(
        tokenizer,
        [{"role": "user", "content": user_text}],
    )

def build_followup_prompt(tokenizer, user_text: str, assistant_text: str, followup_text: str) -> str:
    """Render user -> assistant -> user (plus generation prompt) as raw prompt text."""
    turns = (
        ("user", user_text),
        ("assistant", assistant_text),
        ("user", followup_text),
    )
    messages = [{"role": role, "content": content} for role, content in turns]
    return apply_chat_template_no_think_if_supported(tokenizer, messages)

def sglang_complete(
    prompt: str,
    max_new_tokens: int,
    temperature: float = 0.0,
    top_p: float = 1.0,
    max_attempts: int = 6,
) -> str:
    """Run a raw text completion on the SGLang server, retrying with exponential backoff.

    Args:
        prompt: Fully rendered prompt text (chat template already applied).
        max_new_tokens: Sent as the OpenAI ``max_tokens`` parameter.
        temperature: Sampling temperature, passed through unchanged.
        top_p: Nucleus-sampling parameter, passed through unchanged.
        max_attempts: Total number of tries before giving up (default 6).

    Returns:
        The text of the first returned choice.

    Raises:
        RuntimeError: if every attempt failed; the last error is chained as ``__cause__``.
    """
    last_err = None
    for attempt in range(max_attempts):
        try:
            resp = sg.completions.create(
                model=get_server_model_id(),
                prompt=prompt,
                max_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
            )
            return resp.choices[0].text
        except Exception as e:  # network boundary: any failure is retried
            last_err = e
            if attempt + 1 < max_attempts:
                # Backoff 1s, 2s, 4s, ... plus jitter. No pointless sleep after the final try.
                time.sleep(1.0 * (2 ** attempt) + random.random() * 0.25)
    raise RuntimeError(f"SGLang completion failed. last_err={repr(last_err)}") from last_err

In [4]:
# =========================
# CELL 4: The Core Worker
# =========================

def _float_or(value: Any, default: float) -> float:
    """Return ``float(value)``, or ``default`` when the value is missing/null (None)."""
    return default if value is None else float(value)

def process_record(rec: Dict[str, Any], tokenizer) -> Tuple[Dict[str, Any], bool]:
    """
    Continue a truncated ``initial_response`` and regenerate ``followup_response``.

    A record is returned unchanged as ``(rec, False)`` in any of these cases:
      * ``is_truncated_token_only`` does not flag it;
      * its prompt, follow-up, or initial text is blank;
      * it has already generated MAX_TOTAL_TOKENS or more tokens (prefill excluded).

    Otherwise:
      1. ``initial_response`` is continued with a budget of MAX_TOTAL_TOKENS minus
         the tokens it already generated.
      2. ``followup_response`` is regenerated against the extended answer, capped at
         FOLLOWUP_MAX_TOKENS.
      3. Both fields are overwritten in ``rec`` in place, and ``(rec, True)`` is returned.

    Sampling parameters come from ``rec["gen_kwargs"]``. A missing or null value falls
    back to temperature 0.0 / top_p 1.0.

    Raises:
        RuntimeError: propagated from sglang_complete when the server keeps failing.
    """
    if not is_truncated_token_only(rec, tokenizer):
        return rec, False

    user_text = rec.get("prompt_text") or ""
    followup_text = rec.get("followup_text") or ""
    initial_text = rec.get("initial_response") or ""

    if is_blank(user_text) or is_blank(followup_text) or is_blank(initial_text):
        return rec, False

    # 1. Remaining budget for the continuation.
    gen_ct = generated_token_count(rec, tokenizer)
    remaining_tokens = MAX_TOTAL_TOKENS - gen_ct

    # Already at or over the total cap: nothing left to continue.
    if remaining_tokens <= 0:
        return rec, False

    # gen_kwargs (or individual values in it) may be absent OR explicitly null in the JSONL.
    base_kwargs = rec.get("gen_kwargs") or {}
    temperature = _float_or(base_kwargs.get("temperature"), 0.0)
    top_p = _float_or(base_kwargs.get("top_p"), 1.0)

    # 2. Continue the initial response. Appending the raw initial_response (which
    # includes any prefill) to the rendered prompt makes the server resume generation
    # at the very next token.
    base_prompt = build_base_prompt(tokenizer, user_text)
    continuation_prompt = base_prompt + initial_text

    continuation_suffix = sglang_complete(
        prompt=continuation_prompt,
        max_new_tokens=remaining_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    new_initial = initial_text + continuation_suffix

    # 3. Regenerate the follow-up response against the completed answer.
    followup_prompt = build_followup_prompt(tokenizer, user_text, new_initial, followup_text)
    new_followup = sglang_complete(
        prompt=followup_prompt,
        max_new_tokens=FOLLOWUP_MAX_TOKENS,
        temperature=temperature,
        top_p=top_p,
    )

    # 4. Patch the record in place.
    rec["initial_response"] = new_initial
    rec["followup_response"] = new_followup

    return rec, True

In [5]:
# =========================
# CELL 5: Execution & Merging
# =========================
def iter_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
    """Yield the parsed JSON object on every non-blank line of ``path``.

    Unlike the detection script's reader, malformed lines are NOT skipped:
    ``json.JSONDecodeError`` propagates, so a corrupt input cannot silently drop
    records from the merged output.
    """
    with path.open("r", encoding="utf-8") as fh:
        yield from (json.loads(raw) for raw in fh if raw.strip())

def record_key(r: Dict[str, Any]) -> Tuple[Any, ...]:
    """Identity tuple used to match patched records back to originals during the merge."""
    fields = ("model_name", "prompt_id", "condition", "prefill_id", "followup_id")
    return tuple(r.get(name) for name in fields)

def peek_first_record(path: Path) -> Dict[str, Any]:
    """Return the first parsed record of ``path``.

    Raises:
        ValueError: if the file contains no non-blank lines.
    """
    sentinel = object()
    first = next(iter(iter_jsonl(path)), sentinel)
    if first is sentinel:
        raise ValueError(f"Empty JSONL: {path}")
    return first

# Run back-to-back

def _derived_path(in_path: Path, tag: str) -> Path:
    """Return the sibling file ``<stem><tag><suffix>``, e.g. x.jsonl -> x_continued_patch.jsonl.

    Only the file name is changed, so directory names are never rewritten. The result can
    never equal ``in_path`` (which the old string replace could, and the following
    ``unlink()`` would then delete the input).
    """
    return in_path.with_name(f"{in_path.stem}{tag}{in_path.suffix}")


def _collect_done(futures_by_idx: Dict[int, Any], results_by_idx: Dict[int, Any]) -> None:
    """Block until at least one in-flight future finishes, then move every finished one
    from ``futures_by_idx`` into ``results_by_idx``.

    A worker exception is stored as the result, so it is counted rather than raised.
    """
    done_set, _ = wait(list(futures_by_idx.values()), return_when=FIRST_COMPLETED)
    for idx in [i for i, fut in futures_by_idx.items() if fut in done_set]:
        fut = futures_by_idx.pop(idx)
        try:
            results_by_idx[idx] = fut.result()
        except Exception as e:
            results_by_idx[idx] = e


def _merge_patch(in_path: Path, patch_path: Path, merged_path: Path) -> None:
    """Copy ``in_path`` to ``merged_path``, substituting patched records matched by record_key."""
    patch_records = list(iter_jsonl(patch_path))
    patch_map = {record_key(r): r for r in patch_records}
    if len(patch_map) != len(patch_records):
        # Colliding keys mean some patched records cannot be placed unambiguously.
        print(
            f"WARNING: {len(patch_records) - len(patch_map)} patched records share a "
            "record_key with another; only the last one per key is used in the merge."
        )

    if merged_path.exists():
        merged_path.unlink()

    replaced = 0
    with merged_path.open("a", encoding="utf-8") as fmerged:
        for r in iter_jsonl(in_path):
            k = record_key(r)
            if k in patch_map:
                r = patch_map[k]
                replaced += 1
            fmerged.write(json.dumps(r, ensure_ascii=False) + "\n")

    print(f"Merged output saved to: {merged_path}")
    print(f"Total replaced records in merge: {replaced}")


def _process_file(in_path: Path) -> None:
    """Continue every truncated record of one JSONL file.

    Steps:
      * skip the file if it is missing or its model differs from the served one;
      * write patched records (in input order) to ``<stem>_continued_patch.jsonl``;
      * optionally merge them into ``<stem>_continued_merged.jsonl``.
    """
    if not in_path.exists():
        print(f"Skipping {in_path} (File not found)")
        return

    print(f"\n{'='*50}\nProcessing File: {in_path.name}\n{'='*50}")

    patch_path = _derived_path(in_path, "_continued_patch")
    merged_path = _derived_path(in_path, "_continued_merged")

    # Refuse to continue generations with a different model than the one that made them.
    file_model = peek_first_record(in_path).get("model_name")
    server_model = get_server_model_id()

    if file_model != server_model:
        print(f"ERROR: Model mismatch for {in_path.name}.")
        print(f"  File model:   {file_model}\n  Server model: {server_model}")
        print("Skipping to next file...")
        return

    tokenizer = get_tokenizer(file_model)

    # Single detection pass: remember candidate row positions so the submission pass
    # does not tokenize every record a second time.
    total_records = 0
    cand_rows = set()
    for row_idx, rec in enumerate(iter_jsonl(in_path)):
        total_records += 1
        if is_truncated_token_only(rec, tokenizer):
            cand_rows.add(row_idx)
    trunc_candidates = len(cand_rows)

    print(f"Total records: {total_records}")
    print(f"Truncated candidates to continue: {trunc_candidates}")

    if patch_path.exists():
        patch_path.unlink()

    patched, errors = 0, 0
    first_error = None
    cand_idx, next_flush = 0, 0
    futures_by_idx: Dict[int, Any] = {}
    results_by_idx: Dict[int, Any] = {}

    pbar = tqdm(total=total_records, desc="Processing", unit="rec")
    pbar.set_postfix(cands=trunc_candidates, patched=patched, errs=errors)
    try:
        with patch_path.open("a", encoding="utf-8") as fpatch:
            with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:

                def flush_ready() -> None:
                    """Write finished results in candidate order, stopping at the first gap."""
                    nonlocal next_flush, patched, errors, first_error
                    while next_flush in results_by_idx:
                        out = results_by_idx.pop(next_flush)
                        if isinstance(out, Exception):
                            errors += 1
                            if first_error is None:
                                first_error = out
                        else:
                            rec2, did_patch = out
                            if did_patch:
                                patched += 1
                                fpatch.write(json.dumps(rec2, ensure_ascii=False) + "\n")
                        next_flush += 1
                    pbar.set_postfix(cands=trunc_candidates, patched=patched, errs=errors)

                for row_idx, rec in enumerate(iter_jsonl(in_path)):
                    # Only submit truncated records.
                    if row_idx in cand_rows:
                        futures_by_idx[cand_idx] = ex.submit(process_record, rec, tokenizer)
                        cand_idx += 1

                    # Back-pressure: never keep more than MAX_IN_FLIGHT pending futures.
                    while len(futures_by_idx) >= MAX_IN_FLIGHT:
                        _collect_done(futures_by_idx, results_by_idx)
                        flush_ready()

                    pbar.update(1)

                # Drain remaining work; the bar stays open so postfix keeps updating.
                while futures_by_idx:
                    _collect_done(futures_by_idx, results_by_idx)
                    flush_ready()
    finally:
        pbar.close()

    print(f"Done processing {in_path.name}.")
    print(f"Patched records written: {patched} | Errors: {errors}")
    if first_error is not None:
        print(f"  First error: {first_error!r}")

    if DO_MERGE:
        _merge_patch(in_path, patch_path, merged_path)


for in_path in FILES_TO_PROCESS:
    _process_file(in_path)


Processing File: gen_Qwen__Qwen3-14B_harmbench.jsonl
Total records: 2000
Truncated candidates to continue: 575


Processing:   0%|          | 0/2000 [00:00<?, ?rec/s]

Done processing gen_Qwen__Qwen3-14B_harmbench.jsonl.
Patched records written: 575 | Errors: 0
Merged output saved to: rq1_runs/gen_Qwen__Qwen3-14B_harmbench_continued_merged.jsonl
Total replaced records in merge: 575

Processing File: gen_Qwen__Qwen3-14B.jsonl
Total records: 2920
Truncated candidates to continue: 1052


Processing:   0%|          | 0/2920 [00:00<?, ?rec/s]

Done processing gen_Qwen__Qwen3-14B.jsonl.
Patched records written: 1052 | Errors: 0
Merged output saved to: rq1_runs/gen_Qwen__Qwen3-14B_continued_merged.jsonl
Total replaced records in merge: 1052
