In [None]:
from collections import defaultdict
import json

def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is parsed with ``json.loads``. Blank or
    whitespace-only lines (for example a trailing empty line at the end of
    the file) are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        path: Location of the ``.jsonl`` file; opened as UTF-8 text.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        # json.loads tolerates surrounding whitespace, so the raw line is fine;
        # only lines that are empty after stripping have to be filtered out.
        return [json.loads(line) for line in f if line.strip()]

# Evaluation examples, one parsed JSON object per line of the file.
eval_data = load_jsonl("/home/user/llm/dataset_xsshield/eval.jsonl")
# Prediction rows; each row is read below through its "_idx" and "pred" keys.
pred_rows = load_jsonl("/home/user/llm/artifacts_gemma3_7/predictions_eval.jsonl")

# Lookup from a row's "_idx" to its prediction text, stripped of surrounding
# whitespace and lower-cased. If an "_idx" repeats, the last row wins.
idx2pred = dict(
    (row["_idx"], row["pred"].strip().lower())
    for row in pred_rows
)

def get_group_from_idx(idx, beef_end=6907, beef_obf_end=20089):
    """Map a row position in the evaluation set to its dataset group.

    Groups are assigned purely by position: indices below ``beef_end`` are
    "beef", indices from ``beef_end`` up to (but not including)
    ``beef_obf_end`` are "beef_obf", and every remaining index is "benign".

    Args:
        idx: Row index to classify.
        beef_end: First index that no longer belongs to "beef". The default
            matches the boundary previously hard-coded in this function.
        beef_obf_end: First index that no longer belongs to "beef_obf". The
            default matches the boundary previously hard-coded here.

    Returns:
        One of the strings "beef", "beef_obf" or "benign".
    """
    if idx < beef_end:
        return "beef"
    if idx < beef_obf_end:
        return "beef_obf"
    return "benign"

# File-level confusion-matrix counters per group. "overall" accumulates every
# scored file; the "beef_obf" entry stays at zero because that group is
# skipped in the scoring loop below.
stats = {
    "overall": {"tp":0,"fp":0,"tn":0,"fn":0},
    "beef": {"tp":0,"fp":0,"tn":0,"fn":0},
    "beef_obf": {"tp":0,"fp":0,"tn":0,"fn":0},
    "benign": {"tp":0,"fp":0,"tn":0,"fn":0},
}

# (base_id, group) -> slice-level predictions and gold labels for that file.
file_level_data = defaultdict(lambda: {"preds": [], "golds": []})

# Eval rows that have no entry in idx2pred; they are scored as "no" below.
missing_pred_count = 0

for idx, ex in enumerate(eval_data):
    # Remove only the LAST "-p..." part of the id. split('-p')[0] would cut at
    # the first "-p", so an id like "jquery-plugin-p3" would collapse to
    # "jquery" and unrelated files could end up sharing one vote.
    # NOTE(review): assumes ids have the form "<file>-p<slice>" — confirm
    # against the dataset.
    base_id = ex["id"].rsplit('-p', 1)[0]

    grp = get_group_from_idx(idx)

    # Anything other than a (case-insensitive) "yes" label is treated as "no".
    gold = "yes" if ex["label"].lower() == "yes" else "no"

    if idx not in idx2pred:
        missing_pred_count += 1
    pred = idx2pred.get(idx, "no")

    # The group is part of the key, so equal base ids in different groups are
    # kept as separate files.
    file_key = (base_id, grp)

    file_level_data[file_key]["preds"].append(pred)
    file_level_data[file_key]["golds"].append(gold)

# Make the silent "no" fallback visible: missing predictions bias the scores.
if missing_pred_count:
    print(f"WARNING: {missing_pred_count} of {len(eval_data)} eval rows have "
          f"no prediction and were counted as 'no'.")

false_positives = []
false_negatives = []
for file_key, data in file_level_data.items():
    base_id, grp = file_key

    # Files in the "beef_obf" group are excluded from file-level scoring.
    if grp == "beef_obf":
        continue

    # The first slice's gold label is used as the label of the whole file.
    file_gold = data["golds"][0]

    total_slices = len(data["preds"])
    yes_count = data["preds"].count("yes")
    # Majority vote across slices; an exact tie resolves to "yes".
    file_pred = "yes" if yes_count >= (total_slices / 2) else "no"

    details = {
        "base_id": base_id,
        "group": grp,
        "gold": file_gold,
        "pred": file_pred,
        "yes_count": yes_count,
        "total_slices": total_slices
    }

    # Both file_gold and file_pred are always exactly "yes" or "no" here, so
    # the confusion cell can be decided once and reused for every stats key.
    if file_gold == "yes":
        outcome = "tp" if file_pred == "yes" else "fn"
    else:
        outcome = "fp" if file_pred == "yes" else "tn"

    for key in ("overall", grp):
        stats[key][outcome] += 1

    # Record each misclassified file once, for the error listing.
    if outcome == "fp":
        false_positives.append(details)
    elif outcome == "fn":
        false_negatives.append(details)

def metrics(cm):
    """Derive classification metrics from a confusion-matrix dict.

    Args:
        cm: Mapping with the integer counts "tp", "fp", "tn" and "fn".

    Returns:
        A dict with "acc", "precision", "recall", "f1" and "mcc", followed by
        the four input counts and their sum under "total". Any ratio whose
        denominator is not positive is reported as 0.0.
    """
    def _ratio(numerator, denominator):
        # Guarded division: 0.0 whenever the denominator is not positive.
        return numerator / denominator if denominator > 0 else 0.0

    tp = cm["tp"]
    fp = cm["fp"]
    tn = cm["tn"]
    fn = cm["fn"]
    total = tp + fp + tn + fn

    acc = _ratio(tp + tn, total)
    prec = _ratio(tp, tp + fp)
    rec = _ratio(tp, tp + fn)
    f1 = _ratio(2 * (prec * rec), prec + rec)

    # Matthews correlation coefficient; the denominator is the square root of
    # the product of the four marginal sums.
    denom = ((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)) ** 0.5
    mcc = _ratio(tp*tn - fp*fn, denom)

    return {
        "acc": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "mcc": mcc,
        "tp": tp,
        "fp": fp,
        "tn": tn,
        "fn": fn,
        "total": total,
    }

# One metrics row per entry of `stats`, each tagged with its group name.
results = [dict(metrics(cm), group=name) for name, cm in stats.items()]

# Fixed-width table: header, separator rule, then one line per group.
header = f"{'group':<10} | {'acc':>9} | {'precision':>9} | {'recall':>9} | {'f1':>9} | {'mcc':>9} | {'tp':>6} | {'fp':>6} | {'tn':>6} | {'fn':>6} | {'total':>8}"
print(header)
print("-" * 112)

for row in results:
    ratios = f"{row['group']:<10} | {row['acc']:>9.4f} | {row['precision']:>9.4f} | {row['recall']:>9.4f} | {row['f1']:>9.4f} | {row['mcc']:>9.4f} | "
    counts = f"{row['tp']:>6} | {row['fp']:>6} | {row['tn']:>6} | {row['fn']:>6} | {row['total']:>8}"
    print(ratios + counts)


In [None]:
def _print_misclassified(heading, items):
    """Print a banner with `heading`, then one line per entry of `items`.

    Each entry is a dict with "group", "base_id", "yes_count" and
    "total_slices". When `items` is empty, "  None." is printed instead.
    """
    print("\n" + "="*50)
    print(heading)
    print("="*50)
    if not items:
        print("  None.")
        return
    for item in items:
        print(f"  [Group: {item['group']:<8}] [File: {item['base_id']:<40}] "
              f"(Votes: {item['yes_count']}/{item['total_slices']} 'yes')")


# File-level errors collected by the scoring cell above.
_print_misclassified("False Positives (FP) [Gold: no, Pred: yes]", false_positives)
_print_misclassified("False Negatives (FN) [Gold: yes, Pred: no]", false_negatives)
print("\n")