In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk the Kaggle input directory and print the full path of
# every file found, so the available datasets can be seen at a glance.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        # Join the containing directory with the bare file name to get a usable path.
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
# Copy the cleaned Overture places CSV from the read-only input dataset into the writable working directory.
!cp /kaggle/input/overture-qwen-qlora-assets/overture_cleaned_places.csv /kaggle/working/

In [2]:
# Copy the SFT data-generation script next to the CSV in the working directory.
!cp /kaggle/input/overture-qwen-qlora-assets/make_sft_jsonl.py /kaggle/working/

In [3]:
# Run the generator script; its output reports the train/val/test split sizes and two example records.
# NOTE(review): presumably this writes sft_data/*.jsonl into the current directory (the evaluation
# cell later reads sft_data/test.jsonl) - confirm against the script.
!python /kaggle/working/make_sft_jsonl.py

✅ Done. Samples: 6000  → train/val/test = 4200/600/1200
Example:
{
  "instruction": "Decide whether the two place records refer to the same real-world place. Answer ONLY 'YES' or 'NO'.",
  "input": "Record A: Condomínio Edifício Gramado | landmark_and_historical_building | Rua Antônio de Barros, 2526 | 551129419896.0\nRecord B: Condomínio Edifício Antúrio | Community and Government > Residential Building > Apartment or Condo | Rua Reboujo, 250 | (11) 2091-3330",
  "output": "NO"
}
{
  "instruction": "Decide whether the two place records refer to the same real-world place. Answer ONLY 'YES' or 'NO'.",
  "input": "Record A: Havanna Inh. Hüseyen Özcelik Cocktailbar | Dining and Drinking > Bar | Holwedestr. 1 | 04953 18868947 | https://www.havannarestaurant.de\nRecord B: Havanna | restaurant | Holwedestraße 1 | 495318868947.0 | http://www.havanna-restaurant.de",
  "output": "YES"
}


In [4]:
# Install into the running kernel's environment via the explicit %pip magic (no reliance on
# automagic), pinned to the version this notebook was originally run with for reproducibility.
%pip install -q bitsandbytes==0.48.2

Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-c

In [1]:
import re
import time
from collections import Counter
import transformers
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
    roc_auc_score,
    average_precision_score,
    roc_curve,
    precision_recall_curve,
)
import matplotlib.pyplot as plt

# Only show error-level messages from transformers (hides warnings and info logs).
transformers.logging.set_verbosity_error()

# ---------------- Basic configuration ----------------
MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"  # Model

# Load the weights in 4-bit NF4 quantization with float16 as the compute dtype.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
# Reuse the EOS token for padding and pad on the left side.
tok.pad_token = tok.eos_token
tok.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb,
    device_map="auto",
    trust_remote_code=True
)

# Inference only: switch to eval mode and disable gradient tracking globally.
model.eval()
torch.set_grad_enabled(False)

# To keep the format consistent with the 7B script, define the YES / NO token ids here as well.
# NOTE(review): [-1] keeps the LAST sub-token of " YES" / " NO". This is only the next-token id
# to score if each string encodes to a single token - confirm for this tokenizer.
YES_TOK_ID = tok(" YES", add_special_tokens=False)["input_ids"][-1]
NO_TOK_ID  = tok(" NO",  add_special_tokens=False)["input_ids"][-1]


# ---------------- Text generation + parsing (for printing samples) ----------------
def predict_yesno_text(prompt, max_new_tokens=2):
    """
    Greedily generate a short answer for `prompt` and parse it as YES/NO.

    Args:
        prompt: Full prompt text to feed to the model.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        (hard_label_str, full_text): the parsed label ("YES" or "NO") and the
        decoded prompt + generated continuation. If no YES/NO-like word can be
        found in the generated answer, the label falls back to "NO".
    """
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # Greedy decoding: sampling knobs (temperature / top_p) are unused when
        # do_sample=False, so they are not passed (temperature=0.0 is not a valid value anyway).
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tok.pad_token_id,
        )
    # Full decoded text (prompt + continuation) is still returned for display.
    text = tok.decode(out[0], skip_special_tokens=True)

    # Parse ONLY the newly generated tokens. Parsing the full text relied on the
    # prompt ending with "Answer:"; otherwise the instruction's own 'YES' / 'NO'
    # wording could be picked up as the model's answer.
    prompt_len = inputs["input_ids"].shape[1]
    continuation = tok.decode(out[0][prompt_len:], skip_special_tokens=True)
    post = continuation.split("Answer:")[-1].strip()
    first = post.split()[0].strip(",.?!:;").upper() if post else ""

    mapping_yes = {"YES", "1", "TRUE"}
    mapping_no  = {"NO", "0", "FALSE"}

    if first in mapping_yes:
        hard = "YES"
    elif first in mapping_no:
        hard = "NO"
    else:
        # Fallback: search the whole generated answer for any YES/NO-like word.
        m = re.search(r"\b(YES|NO|1|0|TRUE|FALSE)\b", post.upper())
        if m:
            hard = "YES" if m.group(1) in mapping_yes else "NO"
        else:
            # Nothing recognizable was generated: default to the negative class.
            hard = "NO"

    return hard, text


# ---------------- Probability scoring (for AUC / PR-AUC) ----------------
def score_yes_probability(prompt):
    """
    Score a prompt with a single forward pass (no generation).

    Takes the next-token logits at the last prompt position, keeps only the
    entries for the ' YES' and ' NO' token ids, and renormalizes over those two.

    Args:
        prompt: Full prompt text to feed to the model.

    Returns:
        (pred_label_int, yes_prob_float): 1 if the normalized P(YES) >= 0.5
        else 0, and the normalized P(YES) as a Python float.
    """
    enc = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model(**enc)
        # Upcast to float32 so the normalization below does not run in half precision.
        logits = out.logits[0, -1, :].float()  # [vocab]

    # A softmax over just the two candidate logits equals p_yes / (p_yes + p_no),
    # but stays exact when both full-vocabulary probabilities are tiny. The
    # previous "+ 1e-9" in the denominator pushed such cases toward NO.
    pair_logits = torch.stack([logits[YES_TOK_ID], logits[NO_TOK_ID]])
    p_yes_norm = torch.softmax(pair_logits, dim=-1)[0].item()

    pred = 1 if p_yes_norm >= 0.5 else 0
    return pred, p_yes_norm


# ---------------- Read the test set ----------------
# load_dataset exposes a single unnamed data file under the "train" split key.
ds = load_dataset("json", data_files="sft_data/test.jsonl")["train"]
y_true, y_pred, y_score = [], [], []

# Number of leading samples to print in full for manual inspection.
PRINT_N = 10
printed = 0

# total_time accumulates ONLY the scoring calls, so the reported latency is not
# inflated by the extra generate() call and printing done for the first PRINT_N samples.
# start_time / end_time still bracket the whole loop (wall clock).
total_time = 0.0
start_time = time.perf_counter()

for idx, ex in enumerate(ds):
    prompt = f"{ex['instruction']}\n{ex['input']}\nAnswer:"

    # Probability / hard prediction (used for all metrics); this is the timed part.
    score_start = time.perf_counter()
    pred_label_int, p_yes = score_yes_probability(prompt)
    total_time += time.perf_counter() - score_start
    y_pred.append(pred_label_int)
    y_score.append(p_yes)

    gold_label = 1 if ex["output"].strip().upper() == "YES" else 0
    y_true.append(gold_label)

    # Print the first few samples for manual inspection (reuses the text-based parser).
    if printed < PRINT_N:
        hard_text_label, raw = predict_yesno_text(prompt)
        print(f"\n===== SAMPLE #{idx} =====")
        print("PROMPT:\n", prompt)
        print("RAW GENERATION:\n", raw)
        print(
            "PARSED(from text):", hard_text_label,
            "   PROB(YES):", p_yes,
            "   GOLD:", "YES" if gold_label == 1 else "NO"
        )
        printed += 1

end_time = time.perf_counter()
num_samples = len(y_true)
# Guard against an empty test set instead of raising ZeroDivisionError.
avg_latency = total_time / num_samples if num_samples else float("nan")
time_per_1000 = avg_latency * 1000.0

# ---------------- Basic classification indicators ----------------
# Hard-label metrics computed from the thresholded predictions in y_pred.
print("\n========== BASIC METRICS ==========")
acc = accuracy_score(y_true, y_pred)
# Binary averaging reports precision / recall / F1 for the positive class (label 1).
p, r, f1, _ = precision_recall_fscore_support(
    y_true,
    y_pred,
    zero_division=0,
    average="binary",
)

print("Label balance (gold):", Counter(y_true))
print("Pred distribution    :", Counter(y_pred))
print("\nConfusion matrix:", confusion_matrix(y_true, y_pred), sep="\n")
print(
    "\nClassification report:",
    classification_report(y_true, y_pred, digits=4),
    sep="\n",
)
print(f"\nAcc={acc:.4f}  P={p:.4f}  R={r:.4f}  F1={f1:.4f}")

# ---------------- AUC / PR-AUC ----------------
# Threshold-free metrics computed from the continuous scores in y_score.
print("\n========== AUC / PR-AUC ==========")
try:
    roc_auc, pr_auc = (
        roc_auc_score(y_true, y_score),
        average_precision_score(y_true, y_score),
    )
except ValueError as e:
    # Either metric failing leaves both unset so later code can test for None.
    roc_auc = pr_auc = None
    print("Cannot compute AUC/PR-AUC:", e)
else:
    print(f"ROC-AUC = {roc_auc:.4f}")
    print(f"PR-AUC  = {pr_auc:.4f}")

# ---------------- Curve drawing ----------------
# Both plots are best-effort: a failure is reported and the script continues.
# Each figure is closed in `finally` so an error midway does not leak an open figure.

# ROC curve
fig = None
try:
    fpr, tpr, _ = roc_curve(y_true, y_score)
    # roc_auc is None when the AUC computation failed; formatting None with
    # ':.4f' would raise TypeError and the plot would never be saved.
    roc_label = (
        f"ROC (AUC = {roc_auc:.4f})" if roc_auc is not None else "ROC (AUC = n/a)"
    )
    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, label=roc_label)
    # Diagonal reference line (equal TPR and FPR).
    ax.plot([0, 1], [0, 1], linestyle="--")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("Qwen3-4B – ROC Curve")
    ax.legend(loc="lower right")
    fig.tight_layout()
    fig.savefig("qwen3_4b_roc_curve.png")
    print("Saved ROC curve to qwen3_4b_roc_curve.png")
except Exception as e:
    print("Error plotting ROC curve:", e)
finally:
    if fig is not None:
        plt.close(fig)

# PR Curve
fig = None
try:
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    # Same None guard as for the ROC label.
    pr_label = (
        f"PR (AP = {pr_auc:.4f})" if pr_auc is not None else "PR (AP = n/a)"
    )
    fig, ax = plt.subplots()
    ax.plot(recall, precision, label=pr_label)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title("Qwen3-4B – Precision-Recall Curve")
    ax.legend(loc="lower left")
    fig.tight_layout()
    fig.savefig("qwen3_4b_pr_curve.png")
    print("Saved PR curve to qwen3_4b_pr_curve.png")
except Exception as e:
    print("Error plotting PR curve:", e)
finally:
    if fig is not None:
        plt.close(fig)

# ---------------- Latency ----------------
# Emit the timing summary as one print call, one field per line.
print(
    "\n========== LATENCY ==========",
    f"Total samples               : {num_samples}",
    f"Total time (s)              : {total_time:.4f}",
    f"Avg latency (s/sample)      : {avg_latency:.6f}",
    f"Time per 1000 samples (s)   : {time_per_1000:.4f}",
    sep="\n",
)

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

2025-11-14 07:25:33.716365: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763105133.949377     112 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763105134.010137     112 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/99.6M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/238 [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]


===== SAMPLE #0 =====
PROMPT:
 Decide whether the two place records refer to the same real-world place. Answer ONLY 'YES' or 'NO'.
Record A: The Bristol Hotel | resort | 115 Country Music Way | 2766963535.0 | https://www.bristolhotelva.com/?utm_medium=organic&utm_source=yext&utm_campaign=website
Record B: The Bristol Hotel | resort | 510 Birthplace Of Country Music Way | 2766963535 | https://www.bristolhotelva.com/
Answer:
RAW GENERATION:
 Decide whether the two place records refer to the same real-world place. Answer ONLY 'YES' or 'NO'.
Record A: The Bristol Hotel | resort | 115 Country Music Way | 2766963535.0 | https://www.bristolhotelva.com/?utm_medium=organic&utm_source=yext&utm_campaign=website
Record B: The Bristol Hotel | resort | 510 Birthplace Of Country Music Way | 2766963535 | https://www.bristolhotelva.com/
Answer: NO.
PARSED(from text): NO    PROB(YES): 0.06562847601809066    GOLD: YES

===== SAMPLE #1 =====
PROMPT:
 Decide whether the two place records refer to the same