In [None]:
%%capture
# Installs Unsloth, Xformers, and TRL
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps xformers trl peft accelerate bitsandbytes datasets

In [None]:
# ==========================================
# 0. Environment Setup
# ==========================================
import os
# Bootstrap: install Unsloth and its companion packages only when the import
# fails, so re-running the cell in a prepared runtime installs nothing.
try:
    import unsloth
except ImportError:
    print("Installing Unsloth...")
    # The exit status of os.system is not checked here; a failed install only
    # surfaces when `unsloth` is imported further down in this cell.
    os.system("pip install -q 'unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git'")
    os.system("pip install -q --no-deps xformers trl peft accelerate bitsandbytes datasets")

# ==========================================
# 1. Configuration
# ==========================================
import torch
import json
import glob
import gc
import re
from unsloth import FastLanguageModel, PatchDPOTrainer
from trl import SFTTrainer, DPOTrainer, DPOConfig
from transformers import TrainingArguments
from datasets import Dataset
from unsloth.chat_templates import get_chat_template

# Paths
# DATA_DIR is searched recursively for *_dpo_v2.jsonl files by load_datasets();
# the two output directories receive the SFT and DPO stage results.
DATA_DIR = "/content/drive/MyDrive/ETSP"
OUTPUT_DIR_SFT = "/content/drive/MyDrive/ETSP/qwen_sft_v2"
OUTPUT_DIR_DPO = "/content/drive/MyDrive/ETSP/qwen_dpo_v2"

# Model Configuration
MODEL_NAME = "unsloth/Qwen2.5-3B-Instruct-bnb-4bit"
MAX_SEQ_LENGTH = 8192  # Extended context for 30 reviews + prompt
DTYPE = None  # forwarded unchanged as `dtype` to FastLanguageModel.from_pretrained
LOAD_IN_4BIT = True

# Mount Google Drive
# Skipped when /content/drive already exists (e.g. the cell is re-run).
from google.colab import drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# ==========================================
# 2. Data Processing
# ==========================================
def clean_text(text: str) -> str:
    """Strip ``<...>`` tag-like spans and collapse whitespace.

    Args:
        text: Raw review text. Anything that is not a ``str`` is treated as
            missing.

    Returns:
        The cleaned text with every whitespace run replaced by one space and
        no leading/trailing whitespace; ``""`` for non-string input.
    """
    if isinstance(text, str):
        # The tag regex only runs when both angle brackets are present.
        looks_like_markup = '<' in text and '>' in text
        without_tags = re.sub(r'<[^>]+>', '', text) if looks_like_markup else text
        return re.sub(r"\s+", " ", without_tags).strip()
    return ""

def format_review_context(reviews_list):
    """Build the numbered review block that is inserted into the prompts.

    Each review's ``'text'`` field is passed through ``clean_text``; reviews
    whose cleaned text is shorter than 20 or longer than 1500 characters are
    dropped. At most 30 reviews are kept, numbered consecutively from 1 in
    the form ``Review N: {text}`` and separated by blank lines.

    Args:
        reviews_list: Iterable of review mappings (each must support ``.get``).

    Returns:
        The joined review block, or ``""`` when no review qualifies.
    """
    kept = []

    for review in reviews_list:
        if len(kept) >= 30:
            break

        body = clean_text(review.get('text', ''))

        if 20 <= len(body) <= 1500:
            # Numbering follows the kept reviews, not the input position.
            kept.append(f'Review {len(kept) + 1}: {body}')

    return '\n\n'.join(kept)

# Prompt templates live at module level so they are created once instead of on
# every JSONL line. Their text must stay identical to the prompts used when the
# reference summaries were generated.
_A1_PROMPT_TEMPLATE = """Summarize these reviews for a beginner (CEFR A1).

Requirements:
- Use simple present tense, basic vocabulary.
- Write 1 paragraph of 3-4 short sentences.
- NO bullet points.
- Structure: [Overall] + [Feature] + [Conclusion].
- Reflect what most people say, but mention important issues if some people have them.

Reviews:
{context}

Output ONLY the summary."""

_C1_PROMPT_TEMPLATE = """Summarize these reviews in a professional, analytical style (CEFR C1).

Requirements:
1. **Style**: Use sophisticated vocabulary and phrasing, identical to a high-quality expert review.

2. **Format**: Use a bulleted list with EXACTLY 3-6 points total.
   - **CRITICAL**: EVERY point MUST start with `(+)`, `(-)`, or `(~)`.
   - Use `(+)` for consensus strengths.
   - Use `(-)` for consensus weaknesses.
   - Use `(~)` for mixed/controversial opinions (CRITICAL).
   - **Order**: List all `(+)` first, then `(-)`, then `(~)`. Do NOT mix them randomly.
   - **Compact output**: No blank lines between points. Each point on a new line immediately after the previous one.

3. **Handling Contradictions**:
   - If User A says "great battery" but User B says "battery died", you MUST report this as an inconsistency.
   - Use phrases like "Polarized feedback regarding...", "Inconsistent reports on...", or "While most praise X, some users note Y...".

4. **Length**: Total summary under 180 words. Each point 18-30 words.

Reviews:
{context}

Output ONLY the structured summary."""


def load_datasets(data_dir):
    """
    Load *_dpo_v2.jsonl files and construct SFT/DPO datasets.

    SFT: Learns readability control (A1/C1) and format adherence
    DPO: Learns hallucination rejection via preference optimization

    Every usable JSONL row yields two SFT samples (A1 and C1 prompt) and one
    DPO pair (C1 prompt, ``summary_complex`` chosen, ``summary_hallucinated``
    rejected). Rows without reviews, without a non-empty review context, or
    missing any of the three summaries are ignored.

    Args:
        data_dir: Directory searched recursively for ``*_dpo_v2.jsonl``.

    Returns:
        ``(ds_sft, ds_dpo)``: two ``DatasetDict`` objects, each with a 90/10
        ``train``/``test`` split (seed 42).

    Raises:
        ValueError: If no usable sample was found.

    NOTE(review): the SFT and DPO lists are split independently, so a product
    in the DPO ``test`` split may have appeared in the SFT ``train`` split.
    """
    # glob returns files in arbitrary order; sorting keeps the row order, and
    # therefore the seeded split below, reproducible between runs.
    files = sorted(glob.glob(os.path.join(data_dir, "**/*_dpo_v2.jsonl"), recursive=True))
    print(f"Found {len(files)} data files.")

    sft_data = []
    dpo_data = []
    skipped = 0

    for file_path in files:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    item = json.loads(line)
                    reviews = item.get('reviews', [])
                    if not reviews:
                        continue

                    context = format_review_context(reviews)
                    if not context.strip():
                        continue

                    simple = item.get('summary_simple')
                    complex_ = item.get('summary_complex')
                    hallucinated = item.get('summary_hallucinated')
                    if not (simple and complex_ and hallucinated):
                        continue
                except (json.JSONDecodeError, AttributeError, TypeError) as err:
                    # Malformed JSON or an unexpected record shape: skip the
                    # row, but report it instead of dropping it silently.
                    skipped += 1
                    if skipped <= 5:
                        print(f"Skipping {file_path}:{line_no}: {err!r}")
                    continue

                a1_instruction = _A1_PROMPT_TEMPLATE.format(context=context)
                c1_instruction = _C1_PROMPT_TEMPLATE.format(context=context)

                # SFT samples (both readability levels)
                sft_data.append({"instruction": a1_instruction, "output": simple})
                sft_data.append({"instruction": c1_instruction, "output": complex_})

                # DPO samples (C1 only - hallucination detection)
                dpo_data.append({
                    "question": c1_instruction,
                    "chosen": complex_,
                    "rejected": hallucinated,
                })

    print(f"Loaded: {len(sft_data)} SFT samples | {len(dpo_data)} DPO pairs")
    if skipped:
        print(f"Skipped {skipped} malformed lines.")

    if not sft_data:
        # Fail with a clear message instead of an opaque split error.
        raise ValueError(f"No usable samples found in *_dpo_v2.jsonl files under {data_dir!r}")

    # 90/10 train/val split
    ds_sft = Dataset.from_list(sft_data).train_test_split(test_size=0.1, seed=42)
    ds_dpo = Dataset.from_list(dpo_data).train_test_split(test_size=0.1, seed=42)

    return ds_sft, ds_dpo

# Load datasets
# Both results are split dicts with 'train' and 'test' entries; the sizes are
# printed as a sanity check before any model is loaded.
dataset_sft, dataset_dpo = load_datasets(DATA_DIR)
print(f"SFT Train: {len(dataset_sft['train'])}, Val: {len(dataset_sft['test'])}")
print(f"DPO Train: {len(dataset_dpo['train'])}, Val: {len(dataset_dpo['test'])}")

# ==========================================
# 3. Stage 1: Supervised Fine-Tuning (SFT)
# ==========================================
print("\n🚀 [Stage 1] SFT Training")

# Load base model
# Uses the 4-bit checkpoint named in MODEL_NAME with the configured context length.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_NAME,
    max_seq_length = MAX_SEQ_LENGTH,
    dtype = DTYPE,
    load_in_4bit = LOAD_IN_4BIT,
)

# Add LoRA adapters
# Adapters target the attention projections (q/k/v/o) and the MLP projections
# (gate/up/down); rank and alpha are both 16, with no dropout and no bias terms.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)

# Apply Qwen 2.5 chat template
# The mapping is an identity mapping: samples already use role/content keys
# with "user"/"assistant" roles (see format_sft_func).
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "qwen-2.5",
    mapping = {"role": "role", "content": "content", "user": "user", "assistant": "assistant"}
)

def format_sft_func(examples):
    """Render a batch of SFT samples as chat-formatted training strings.

    Args:
        examples: Batched mapping with parallel ``"instruction"`` and
            ``"output"`` lists (as produced by ``Dataset.map(batched=True)``).

    Returns:
        ``{"text": [...]}`` with one rendered user/assistant conversation per
        sample, produced by the module-level ``tokenizer``'s chat template
        without a trailing generation prompt.
    """
    def _render(instruction, answer):
        conversation = [
            {"role": "user", "content": instruction},
            {"role": "assistant", "content": answer},
        ]
        return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)

    pairs = zip(examples["instruction"], examples["output"])
    return {"text": [_render(instruction, answer) for instruction, answer in pairs]}

# Add the chat-formatted "text" column to both splits.
dataset_sft_fmt = dataset_sft.map(format_sft_func, batched=True)

# SFT Trainer
# Effective batch size is 2 x 4 = 8 sequences per optimizer step.
trainer_sft = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset_sft_fmt['train'],
    eval_dataset = dataset_sft_fmt['test'],
    dataset_text_field = "text",
    max_seq_length = MAX_SEQ_LENGTH,
    dataset_num_proc = 2,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_ratio = 0.1,
        num_train_epochs = 2,
        learning_rate = 5e-5,
        # Use bf16 when the GPU supports it, otherwise fall back to fp16.
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 10,
        eval_strategy = "steps",
        eval_steps = 50,
        # Checkpoints are written every 100 steps, i.e. on every second evaluation.
        save_strategy = "steps",
        save_steps = 100,
        save_total_limit = 3,
        load_best_model_at_end = True,
        metric_for_best_model = "eval_loss",
        optim = "adamw_8bit",
        output_dir = OUTPUT_DIR_SFT,
        seed = 3407,
    ),
)

trainer_sft.train()
print("✅ SFT complete")

# Persist model and tokenizer; the DPO stage reloads them from this directory.
model.save_pretrained(OUTPUT_DIR_SFT)
tokenizer.save_pretrained(OUTPUT_DIR_SFT)

# ==========================================
# 4. Memory Cleanup
# ==========================================
# Drop every reference to the SFT objects before collecting, so the GPU memory
# can be released before the model is loaded again for the DPO stage.
del model, tokenizer, trainer_sft
gc.collect()
torch.cuda.empty_cache()
print("🧹 Memory cleared")

# ==========================================
# 5. Stage 2: Direct Preference Optimization (DPO)
# ==========================================
print("\n🚀 [Stage 2] DPO Training")

# Reload model with SFT weights
# Loads from the directory written at the end of the SFT stage, with the same
# sequence length, dtype and 4-bit settings as the first stage.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = OUTPUT_DIR_SFT,
    max_seq_length = MAX_SEQ_LENGTH,
    dtype = DTYPE,
    load_in_4bit = LOAD_IN_4BIT,
)

# Enable DPO training
# Must run before the DPOTrainer below is constructed.
PatchDPOTrainer()

def format_dpo_func(example):
    """Convert one raw DPO record into prompt/chosen/rejected form.

    Args:
        example: Mapping with ``"question"``, ``"chosen"`` and ``"rejected"``.

    Returns:
        A dict whose ``"prompt"`` is the question rendered as a single user
        turn through the module-level ``tokenizer``'s chat template, ending
        with the generation prompt; ``"chosen"`` and ``"rejected"`` are passed
        through unchanged.
    """
    user_turn = [{"role": "user", "content": example["question"]}]
    rendered_prompt = tokenizer.apply_chat_template(
        user_turn,
        tokenize=False,
        add_generation_prompt=True,
    )
    return {
        "prompt": rendered_prompt,
        "chosen": example["chosen"],
        "rejected": example["rejected"],
    }

# Add the chat-formatted "prompt" column to both splits (one record at a time).
dataset_dpo_fmt = dataset_dpo.map(format_dpo_func)

# DPO Trainer
# Effective batch size is 1 x 8 = 8 preference pairs per optimizer step.
dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,  # Unsloth handles reference model internally
    tokenizer = tokenizer,
    train_dataset = dataset_dpo_fmt['train'],
    eval_dataset = dataset_dpo_fmt['test'],
    max_length = MAX_SEQ_LENGTH,
    max_prompt_length = MAX_SEQ_LENGTH - 512,  # Reserve 512 tokens for output
    beta = 0.1,  # Strong preference signal for hallucination rejection
    args = DPOConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 8,
        warmup_ratio = 0.1,
        num_train_epochs = 1,
        learning_rate = 5e-6,
        # Use bf16 when the GPU supports it, otherwise fall back to fp16.
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 10,
        eval_strategy = "steps",
        eval_steps = 20,
        # No intermediate checkpoints; only the final model is saved below.
        save_strategy = "no",
        optim = "adamw_8bit",
        output_dir = OUTPUT_DIR_DPO,
        seed = 3407,
    ),
)

dpo_trainer.train()
print("✅ DPO complete")

# ==========================================
# 6. Final Export
# ==========================================
# The inference and evaluation cells load the model from OUTPUT_DIR_DPO.
print(f"💾 Saving to {OUTPUT_DIR_DPO}")
model.save_pretrained(OUTPUT_DIR_DPO)
tokenizer.save_pretrained(OUTPUT_DIR_DPO)
print("🎉 Training complete!")

Installing Unsloth...
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
Found 9 data files.
Loaded: 3554 SFT samples | 1777 DPO pairs
SFT Train: 3198, Val: 356
DPO Train: 1599, Val: 178

🚀 [Stage 1] SFT Training
==((====))==  Unsloth 2025.12.5: Fast Qwen2 patching. Transformers: 4.57.3.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/2.05G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/271 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

Unsloth 2025.12.5 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.


Map:   0%|          | 0/3198 [00:00<?, ? examples/s]

Map:   0%|          | 0/356 [00:00<?, ? examples/s]

Unsloth: Tokenizing ["text"] (num_proc=16):   0%|          | 0/3198 [00:00<?, ? examples/s]

Unsloth: Tokenizing ["text"] (num_proc=16):   0%|          | 0/356 [00:00<?, ? examples/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 3,198 | Num Epochs = 2 | Total steps = 800
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 29,933,568 of 3,115,872,256 (0.96% trained)
  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 3


[34m[1mwandb[0m: You chose "Don't visualize my results"


[34m[1mwandb[0m: Detected [huggingface_hub.inference, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,Validation Loss
50,2.5754,2.519796
100,2.026,1.907364
150,1.8008,1.881023
200,1.8816,1.867517
250,1.9119,1.858518
300,1.8324,1.851551
350,1.8278,1.846871
400,1.8051,1.841844
450,1.9178,1.83786
500,1.8241,1.835466


Unsloth: Not an error, but Qwen2ForCausalLM does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient


0,1
eval/loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/runtime,█▂▂▂▂▂▁▂▂▂▁▂▂▂▂▂
eval/samples_per_second,▁▇▇▇▇▇█▇▇▇█▇▇▇▇▇
eval/steps_per_second,▁▇▇▇▇▇█▇▇▇█▇▇▆▇▇
train/epoch,▁▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇█████
train/global_step,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇████
train/grad_norm,█▆▅▄▂▁▂▂▁▂▁▂▂▂▂▃▂▃▂▂▂▂▃▃▄▃▃▃▃▄▃▃▃▃▃▃▃▃▃▃
train/learning_rate,▄▄▅█████▇▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▇▇▆▅▃▂▂▂▂▂▂▂▁▂▁▁▁▁▁▂▂▂▂▁▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁

0,1
eval/loss,1.8255
eval/runtime,58.7558
eval/samples_per_second,6.059
eval/steps_per_second,3.029
total_flos,9.902415246178714e+16
train/epoch,2
train/global_step,800
train/grad_norm,0.28802
train/learning_rate,0.0
train/loss,1.8445


✅ SFT complete
🧹 Memory cleared

🚀 [Stage 2] DPO Training
==((====))==  Unsloth 2025.12.5: Fast Qwen2 patching. Transformers: 4.57.3.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Map:   0%|          | 0/1599 [00:00<?, ? examples/s]

Map:   0%|          | 0/178 [00:00<?, ? examples/s]

Extracting prompt in train dataset (num_proc=16):   0%|          | 0/1599 [00:00<?, ? examples/s]

Applying chat template to train dataset (num_proc=16):   0%|          | 0/1599 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=16):   0%|          | 0/1599 [00:00<?, ? examples/s]

Extracting prompt in eval dataset (num_proc=16):   0%|          | 0/178 [00:00<?, ? examples/s]

Applying chat template to eval dataset (num_proc=16):   0%|          | 0/178 [00:00<?, ? examples/s]

Tokenizing eval dataset (num_proc=16):   0%|          | 0/178 [00:00<?, ? examples/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,599 | Num Epochs = 1 | Total steps = 200
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 8 x 1) = 8
 "-____-"     Trainable parameters = 29,933,568 of 3,115,872,256 (0.96% trained)


Step,Training Loss,Validation Loss,rewards / chosen,rewards / rejected,rewards / accuracies,rewards / margins,logps / chosen,logps / rejected,logits / chosen,logits / rejected,eval_logits / chosen,eval_logits / rejected,nll_loss
20,0.8522,0.592342,13.255346,11.899546,0.711111,1.355799,-379.436462,-404.503113,-0.622687,-0.590356,0,0,0
40,0.5304,0.323342,12.162534,9.678,0.844444,2.484532,-390.364594,-426.718628,-0.771685,-0.724624,No Log,No Log,No Log
60,0.2797,0.203208,10.942546,7.57139,0.922222,3.371155,-402.564484,-447.784668,-0.927748,-0.868945,No Log,No Log,No Log
80,0.155,0.137324,9.493854,5.340331,0.938889,4.153523,-417.051392,-470.095276,-1.059625,-0.990322,No Log,No Log,No Log
100,0.2151,0.106404,8.069264,3.231753,0.955556,4.837512,-431.297302,-491.181122,-1.188316,-1.110057,No Log,No Log,No Log
120,0.1757,0.089143,6.999538,1.613979,0.961111,5.38556,-441.994537,-507.358826,-1.277033,-1.191818,No Log,No Log,No Log
140,0.1297,0.078113,6.081824,0.248831,0.966667,5.832993,-451.171661,-521.010254,-1.358885,-1.266958,No Log,No Log,No Log
160,0.0378,0.072495,5.693827,-0.391511,0.966667,6.085338,-455.051697,-527.413696,-1.390653,-1.296496,No Log,No Log,No Log
180,0.0689,0.070677,5.400839,-0.835406,0.966667,6.236245,-457.981567,-531.852661,-1.414287,-1.317963,No Log,No Log,No Log
200,0.1203,0.069946,5.35225,-0.913489,0.966667,6.265739,-458.467438,-532.633484,-1.418352,-1.321566,No Log,No Log,No Log


0,1
eval/logits/chosen,█▇▅▄▃▂▂▁▁▁
eval/logits/rejected,█▇▅▄▃▂▂▁▁▁
eval/logps/chosen,█▇▆▅▃▂▂▁▁▁
eval/logps/rejected,█▇▆▄▃▂▂▁▁▁
eval/loss,█▄▃▂▁▁▁▁▁▁
eval/rewards/accuracies,▁▅▇▇██████
eval/rewards/chosen,█▇▆▅▃▂▂▁▁▁
eval/rewards/margins,▁▃▄▅▆▇▇███
eval/rewards/rejected,█▇▆▄▃▂▂▁▁▁
eval/runtime,▆▁▃▃▂▂▂▃█▅

0,1
eval/logits/chosen,-1.41835
eval/logits/rejected,-1.32157
eval/logps/chosen,-458.46744
eval/logps/rejected,-532.63348
eval/loss,0.06995
eval/rewards/accuracies,0.96667
eval/rewards/chosen,5.35225
eval/rewards/margins,6.26574
eval/rewards/rejected,-0.91349
eval/runtime,93.6106


✅ DPO complete
💾 Saving to /content/drive/MyDrive/ETSP/qwen_dpo_v2
🎉 Training complete!


In [None]:
# ==========================================
# 0. 环境与依赖
# ==========================================
import torch
from unsloth import FastLanguageModel
from transformers import TextStreamer
import json
import re
import random
import os
import glob

# ANSI colour escape codes (purely cosmetic, for readable console output)
class Colors:
    """Namespace of ANSI escape sequences used to colour printed messages."""

    # Text attributes
    BOLD = '\033[1m'
    RESET = '\033[0m'

    # Foreground colours
    GRAY = '\033[90m'
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    CYAN = '\033[96m'

# ==========================================
# 1. 配置项 (Configuration)
# ==========================================
# 🔥 自动指向你刚刚训练好的 DPO 模型目录
MODEL_PATH = "/content/drive/MyDrive/ETSP/qwen_dpo_v2"
# 自动寻找任意一个生成的 v2 数据文件用于测试
DATA_DIR = "/content/drive/MyDrive/ETSP"

# ==========================================
# 2. 模型加载
# ==========================================
print(f"{Colors.CYAN}🔄 正在加载训练好的 DPO 模型: {MODEL_PATH}...{Colors.RESET}")
try:
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = MODEL_PATH,
        # Same context length as training (MAX_SEQ_LENGTH = 8192 in the
        # training cell); the previous value 6144 did not match it.
        max_seq_length = 8192,
        dtype = None,
        load_in_4bit = True,
    )
    # Switch the model to Unsloth's inference mode before generating.
    FastLanguageModel.for_inference(model)
    print(f"{Colors.GREEN}✅ 模型加载成功!{Colors.RESET}\n")
except Exception as e:
    print(f"{Colors.RED}❌ 模型加载失败 (可能是路径不对): {e}{Colors.RESET}")
    # Stop with a non-zero status and keep the cause attached; the bare
    # exit() builtin reports success and is not available in every runtime.
    raise SystemExit(1) from e

# ==========================================
# 3. 数据处理函数 (必须与训练完全一致)
# ==========================================
def clean_text(text: str) -> str:
    """Remove ``<...>`` tag-like spans and normalise whitespace.

    Must stay behaviourally identical to the training-time helper of the same
    name. Non-string input yields ``""``.
    """
    if not isinstance(text, str):
        return ""

    # Only run the tag regex when both angle brackets occur in the text.
    if '<' in text and '>' in text:
        text = re.sub(r'<[^>]+>', '', text)

    single_spaced = re.sub(r"\s+", " ", text)
    return single_spaced.strip()

def format_review_context(reviews_list):
    """Build the numbered ``Review N: ...`` block fed to the model.

    Keeps at most 30 reviews whose cleaned ``'text'`` is 20-1500 characters
    long, numbers them consecutively from 1 and joins them with blank lines.
    """
    numbered = []
    next_number = 1
    for review in reviews_list:
        if next_number > 30:
            break
        body = clean_text(review.get('text', ''))
        too_short = len(body) < 20
        too_long = len(body) > 1500
        if too_short or too_long:
            continue
        numbered.append(f'Review {next_number}: {body}')
        next_number += 1
    return '\n\n'.join(numbered)

# ==========================================
# 4. 随机加载一条测试数据
# ==========================================
def load_random_sample():
    """Pick one random record from a randomly chosen ``*_dpo_v2.jsonl`` file.

    Only the first 500 lines of the chosen file are considered. Blank lines,
    lines that are not valid JSON, and records without a non-empty
    ``'reviews'`` entry are ignored, because the caller indexes
    ``sample['reviews']`` directly.

    Returns:
        ``(sample, file_path)`` on success, or ``(None, None)`` when no data
        file exists or the chosen file contains no usable record.
    """
    files = glob.glob(os.path.join(DATA_DIR, "**/*_dpo_v2.jsonl"), recursive=True)
    if not files:
        print(f"{Colors.RED}❌ 找不到测试数据文件 (*_dpo_v2.jsonl){Colors.RESET}")
        return None, None

    target_file = random.choice(files)
    print(f"📂 使用数据文件: {target_file}")

    samples = []
    with open(target_file, 'r', encoding='utf-8') as f:
        for line_index, line in enumerate(f):
            if line_index >= 500:  # the first 500 lines form the sampling pool
                break
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(record, dict) and record.get('reviews'):
                samples.append(record)

    if not samples:
        # random.choice([]) would raise IndexError; report the failure the
        # same way as the "no files" case so the caller's check handles it.
        print(f"{Colors.RED}❌ No usable record found in {target_file}{Colors.RESET}")
        return None, None

    return random.choice(samples), target_file

# Pick the record to test with; stop when no sample could be loaded.
sample_data, filename = load_random_sample()
if not sample_data: exit()

# Build the same review block the training prompts used.
context_input = format_review_context(sample_data['reviews'])
asin = sample_data.get('parent_asin', 'Unknown')
category = sample_data.get('category', 'Unknown')

# Print the raw review input for reference (only the first 1000 characters)
print(f"{'='*60}")
print(f"📖 产品: {category} | ASIN: {asin}")
print(f"📖 原始评论输入 (共 {len(sample_data['reviews'])} 条, 截取前30条)")
print(f"{'='*60}")
print(f"{Colors.GRAY}{context_input[:1000]}... [由于太长只显示前1000字符]{Colors.RESET}")
print(f"{'='*60}\n")

# ==========================================
# 5. 定义 Prompt (必须与训练代码 100% 一致)
# ==========================================
# (style, template) pairs; the template text must match the training prompts
# character for character.
_PROMPT_TEMPLATES = (
    ("A1", """Summarize these reviews for a beginner (CEFR A1).

Requirements:
- Use simple present tense, basic vocabulary.
- Write 1 paragraph of 3-4 short sentences.
- NO bullet points.
- Structure: [Overall] + [Feature] + [Conclusion].
- Reflect what most people say, but mention important issues if some people have them.

Reviews:
{context}

Output ONLY the summary."""),
    ("C1", """Summarize these reviews in a professional, analytical style (CEFR C1).

Requirements:
1. **Style**: Use sophisticated vocabulary and phrasing, identical to a high-quality expert review.

2. **Format**: Use a bulleted list with EXACTLY 3-6 points total.
   - **CRITICAL**: EVERY point MUST start with `(+)`, `(-)`, or `(~)`.
   - Use `(+)` for consensus strengths.
   - Use `(-)` for consensus weaknesses.
   - Use `(~)` for mixed/controversial opinions (CRITICAL).
   - **Order**: List all `(+)` first, then `(-)`, then `(~)`. Do NOT mix them randomly.
   - **Compact output**: No blank lines between points. Each point on a new line immediately after the previous one.

3. **Handling Contradictions**:
   - If User A says "great battery" but User B says "battery died", you MUST report this as an inconsistency.
   - Use phrases like "Polarized feedback regarding...", "Inconsistent reports on...", or "While most praise X, some users note Y...".

4. **Length**: Total summary under 180 words. Each point 18-30 words.

Reviews:
{context}

Output ONLY the structured summary."""),
)


def get_prompt(style, context):
    """Return the summarisation prompt for ``style`` with ``context`` inserted.

    Args:
        style: ``"A1"`` (simple paragraph) or ``"C1"`` (structured bullet list).
        context: Formatted review block placed under ``Reviews:``.

    Returns:
        The filled-in prompt, or ``None`` when ``style`` is neither value.
    """
    for known_style, template in _PROMPT_TEMPLATES:
        if style == known_style:
            return template.format(context=context)
    return None

# ==========================================
# 6. 运行推理测试
# ==========================================
def run_test(style):
    """Generate and stream one summary for ``style`` on the selected sample.

    Builds the prompt for the module-level ``context_input``, renders it
    through the chat template and streams the model's answer to stdout.

    Args:
        style: ``"A1"`` or ``"C1"``.

    Raises:
        ValueError: If ``style`` has no prompt template.
    """
    print(f"\n>>> 🧪 测试模式: {Colors.BOLD}{style} (SFT+DPO 效果验证){Colors.RESET}")

    user_prompt = get_prompt(style, context_input)
    if user_prompt is None:
        # get_prompt returns None for unknown styles; fail here rather than
        # sending a None message through the chat template.
        raise ValueError(f"Unknown style: {style!r}")

    messages = [{"role": "user", "content": user_prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")
    # A single unpadded sequence: every position is a real token. Passing the
    # mask explicitly avoids the "attention mask is not set" warning that is
    # emitted because the pad token equals the EOS token.
    attention_mask = input_ids.new_ones(input_ids.shape)

    streamer = TextStreamer(tokenizer, skip_prompt=True)

    print(f"{Colors.CYAN}🤖 模型生成中...{Colors.RESET}")
    print(f"{Colors.YELLOW}{'-'*40}{Colors.RESET}")

    _ = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=512,
        temperature=0.1,  # low temperature, to check output stability
        use_cache=True
    )
    print(f"{Colors.YELLOW}{'-'*40}{Colors.RESET}")

# Run both styles on the same sample for comparison
run_test("A1") # expected: a simple paragraph
run_test("C1") # expected: a list whose points start with (+) / (-)
print(f"\n{Colors.GREEN}🎉 测试完成！请检查 C1 模式是否包含 (+) (-) 符号。{Colors.RESET}")

[96m🔄 正在加载训练好的 DPO 模型: /content/drive/MyDrive/ETSP/qwen_dpo_v2...[0m
==((====))==  Unsloth 2025.12.5: Fast Qwen2 patching. Transformers: 4.57.3.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


[92m✅ 模型加载成功![0m

📂 使用数据文件: /content/drive/MyDrive/ETSP/Clothing_Shoes_and_Jewelry/Clothing_Shoes_and_Jewelry_dpo_v2.jsonl
📖 产品: Clothing_Shoes_and_Jewelry | ASIN: B07L36NT2C
📖 原始评论输入 (共 11 条, 截取前30条)
[90mReview 1: Very warm in a chilly work environment. Washes well.

Review 2: I love this jacket, but I ordered my usual scrub size and it was too small. It holds up under the harsh washing needed to clean the hospital germs off of it (after the initial shrinkage from washing). Super soft and warm. I will be using this for as long as it holds up. The large fit me best.

Review 3: It's very soft, buy the material was a little thinner than I was expecting. It looks great and is very well constructed, though so I'll keep it. If you work in veterinary medicine just know it attracts hair like crazy!

Review 4: This is a lighter-weight fleece jacket that I have been wearing in the office this spring and summer. Just enough warmth to keep the air conditioned office chill away. The material ha

In [None]:
import os
import json
import random
import re
from tqdm.auto import tqdm
from unsloth import FastLanguageModel
from datasets import Dataset
from openai import OpenAI

# ==========================================
# 1. 配置区域 (Configuration)
# ==========================================
# Must point at the directory where the DPO stage saved its final model
MODEL_PATH = "/content/drive/MyDrive/ETSP/qwen_dpo_v2"
DATA_ROOT = "/content/drive/MyDrive/ETSP"

CATEGORIES = [
    "Electronics", "Books", "Home_and_Kitchen", "Beauty_and_Personal_Care",
    "Clothing_Shoes_and_Jewelry", "Toys_and_Games", "Sports_and_Outdoors",
    "Pet_Supplies", "Automotive", "Office_Products"
]

SAMPLES_PER_CATEGORY = 2  # number of products sampled per category

# Judge configuration (Qwen-Plus)
# SECURITY: the API key is read from the environment instead of being stored
# in the notebook. A key that was previously committed here must be treated
# as leaked and rotated.
API_KEY = os.environ.get("DASHSCOPE_API_KEY")
if not API_KEY:
    print("⚠️ DASHSCOPE_API_KEY is not set; judge API calls will fail.")
BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
JUDGE_MODEL = "qwen-plus"

# ==========================================
# 2. 模型加载 (Student)
# ==========================================
print(f"🔄 正在加载微调模型: {MODEL_PATH}...")
try:
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = MODEL_PATH,
        # Same context length as training (MAX_SEQ_LENGTH = 8192 in the
        # training cell); the previous value 6144 did not match it.
        max_seq_length = 8192,
        dtype = None,
        load_in_4bit = True,
    )
    # Switch the model to Unsloth's inference mode before generating.
    FastLanguageModel.for_inference(model)
    print("✅ Student 模型加载成功！")
except Exception as e:
    print(f"❌ 模型加载失败，请检查路径: {e}")
    # Stop with a non-zero status and keep the cause attached; the bare
    # exit() builtin reports success and is not available in every runtime.
    raise SystemExit(1) from e

# ==========================================
# 3. 数据提取 (Unseen Validation Set)
# ==========================================
def clean_text(text: str) -> str:
    """Drop ``<...>`` tag-like spans, then squeeze whitespace to single spaces.

    Returns ``""`` for anything that is not a ``str``.
    """
    if not isinstance(text, str):
        return ""
    contains_both_brackets = '<' in text and '>' in text
    if contains_both_brackets:
        text = re.sub(r'<[^>]+>', '', text)
    return re.sub(r"\s+", " ", text).strip()

def format_context(reviews):
    """Format up to 30 usable reviews as a numbered block.

    A review is usable when its cleaned ``'text'`` has between 20 and 1500
    characters. Kept reviews are labelled ``Review 1:``, ``Review 2:`` ... in
    input order and separated by blank lines.
    """
    blocks = []
    for entry in reviews:
        if len(blocks) >= 30:
            break
        cleaned = clean_text(entry.get('text', ''))
        if not (20 <= len(cleaned) <= 1500):
            continue
        label = len(blocks) + 1
        blocks.append(f'Review {label}: {cleaned}')
    return '\n\n'.join(blocks)

def get_unseen_samples(category, count=2):
    """Sample evaluation products for one category.

    Reads ``{category}_dpo_v2.jsonl`` (falling back to ``{category}_dpo.jsonl``),
    keeps ``parent_asin`` and ``reviews`` of every parseable row, takes the
    10% ``test`` part of a seed-42 split and randomly picks ``count`` rows
    from it. Rows whose formatted review context is 200 characters or shorter
    are dropped.

    Args:
        category: Category name; also the sub-directory and file-name prefix.
        count: Maximum number of samples to draw.

    Returns:
        A list of dicts with ``parent_asin``, ``reviews``,
        ``formatted_context`` and ``category``; empty when the file is
        missing, has fewer than 10 rows, or processing fails.

    NOTE(review): this split is computed per category over every parsed row,
    while training split the pooled and filtered rows of all ``*_dpo_v2.jsonl``
    files. The same seed applied to a different dataset gives a different
    partition, so these samples are not guaranteed to be unseen by the model.
    The pick within the test part uses the unseeded global ``random`` module.
    """
    # Prefer the v2 data file; fall back to the legacy file name.
    file_path = os.path.join(DATA_ROOT, category, f"{category}_dpo_v2.jsonl")
    if not os.path.exists(file_path):
        file_path = os.path.join(DATA_ROOT, category, f"{category}_dpo.jsonl")
        if not os.path.exists(file_path):
            print(f"⚠️ 跳过 {category}: 文件不存在")
            return []

    clean_data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                item = json.loads(line)
                safe_item = {
                    "parent_asin": item.get("parent_asin", "unknown"),
                    "reviews": item.get("reviews", []),
                }
            except (json.JSONDecodeError, AttributeError):
                # Malformed JSON or a non-object row. A bare `except:` here
                # would also swallow KeyboardInterrupt.
                continue
            clean_data.append(safe_item)

    if len(clean_data) < 10: return []

    try:
        ds = Dataset.from_list(clean_data)
        ds_split = ds.train_test_split(test_size=0.1, seed=42)
        test_set = ds_split['test']

        if len(test_set) > count:
            indices = random.sample(range(len(test_set)), count)
            samples = [test_set[i] for i in indices]
        else:
            samples = [item for item in test_set]

        results = []
        for item in samples:
            context = format_context(item.get('reviews', []))
            if len(context) > 200:
                item['formatted_context'] = context
                item['category'] = category
                results.append(item)

        return results

    except Exception as e:
        # Best effort per category: report the problem and continue with the
        # remaining categories.
        print(f"⚠️ {category} 处理出错: {e}")
        return []

# ==========================================
# 4. 裁判逻辑 (Judge) - 纯净事实核查版
# ==========================================
# OpenAI-compatible client pointed at BASE_URL; used by evaluate_factuality().
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

def evaluate_factuality(original_reviews, generated_summary):
    """Ask the judge model whether a summary is supported by its reviews.

    The prompt deliberately contains no mention of training or DPO; it asks
    only for a fact check of the summary against the source reviews.

    Args:
        original_reviews: Review text shown to the judge as the source.
        generated_summary: Model-written summary to be checked.

    Returns:
        The judge's JSON reply parsed into a dict (the prompt requests
        ``score``, ``is_factual``, ``hallucinations`` and ``reasoning``).
        If the request or the JSON parsing fails, a fallback dict with
        ``score`` 0 and ``is_factual`` False is returned instead.
    """
    judge_prompt = f"""
你是一位公平、客观的事实核查员。请基于下方的评论原文，检查 AI 生成的摘要内容是否属实。

[输入数据]
评论原文 (Source):
{original_reviews}

AI 生成摘要 (Summary):
{generated_summary}

[判罚标准]

✅ PASS (合格):
1. **有据可依**: 摘要中提到的观点（无论是优点还是缺点），只要在评论原文中真实出现过（哪怕只有少数人提到），就算正确。
2. **合理归纳**: 允许模型对相似的观点进行总结概括。

❌ FAIL (不合格 - 幻觉或错误):
1. **无中生有 (Fabrication)**: 提到了评论中根本不存在的参数、功能、配件或具体场景。
2. **情感反转 (Sentiment Error)**: 评论明明说是“缺点”，摘要却写成了“优点”（或反之）。
3. **掩盖矛盾 (Misrepresentation)**: 如果评论中对某一点存在明显分歧（有人说好，有人说坏），但摘要却只单方面说“大家都喜欢”，忽略了反面声音。

[输出格式 - JSON]
{{
    "score": <0-10>,
    // 10: 完美，完全符合事实。
    // 8-9: 事实准确，措辞可能有轻微偏差。
    // 6-7: 有轻微的夸大，或者忽略了明显的争议点。
    // 0-5: 出现了明显的功能捏造（幻觉）或严重的情感错误。

    "is_factual": <boolean>, // 分数 >= 7 为 true
    "hallucinations": [
        "摘要原句 -> 错误原因 (无中生有 / 情感反转 / 掩盖矛盾)"
    ],
    "reasoning": "简短评价。"
}}
"""
    # The full review text is sent without truncation.
    request_messages = [{"role": "user", "content": judge_prompt}]
    try:
        completion = client.chat.completions.create(
            model=JUDGE_MODEL,
            messages=request_messages,
            temperature=0.0,
            response_format={"type": "json_object"},
        )
        reply_text = completion.choices[0].message.content
        return json.loads(reply_text)
    except Exception as err:
        print(f"Judge API Error: {err}")
        fallback_verdict = {
            "score": 0,
            "is_factual": False,
            "hallucinations": ["API Error"],
            "reasoning": str(err),
        }
        return fallback_verdict

def generate_student_summary(context):
    """Generate a structured review summary with the fine-tuned model.

    Fills the C1-style summarisation prompt with ``context``, applies the
    chat template of the module-level ``tokenizer``, runs ``model.generate``
    on CUDA and returns the decoded completion.

    Args:
        context: Formatted review text inserted into the prompt.

    Returns:
        The model's completion as a stripped string, without the prompt.
    """
    # Per the original author's note: the standard C1 prompt used at training.
    user_prompt = """Summarize these reviews in a professional, analytical style (CEFR C1).

Requirements:
1. **Style**: Use sophisticated vocabulary and phrasing.
2. **Format**: Use a bulleted list with EXACTLY 3-6 points total.
   - **CRITICAL**: EVERY point MUST start with `(+)`, `(-)`, or `(~)`.
   - Use `(+)` for consensus strengths.
   - Use `(-)` for consensus weaknesses.
   - Use `(~)` for mixed/controversial opinions.
   - **Order**: List all `(+)` first, then `(-)`, then `(~)`.

3. **Handling Contradictions**:
   - Report inconsistencies using phrases like "Polarized feedback...".

4. **Length**: Total summary under 180 words.

Reviews:
{context}

Output ONLY the structured summary.""".format(context=context)

    messages = [{"role": "user", "content": user_prompt}]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to("cuda")

    outputs = model.generate(
        inputs,
        max_new_tokens=512,
        temperature=0.1,
        use_cache=True
    )

    # generate() returns prompt + completion ids. Decode only the completion:
    # splitting the decoded text on the word "assistant" would truncate any
    # summary that itself contains that word (e.g. "Google Assistant").
    prompt_length = inputs.shape[-1]
    completion_ids = outputs[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

# ==========================================
# 5. 主程序
# ==========================================
def main():
    """Run the factuality evaluation end to end and save a JSON report.

    For every category in ``CATEGORIES`` it collects samples via
    ``get_unseen_samples``, generates a summary for each sample, has the
    judge score it, prints per-sample and per-category results, and writes
    all records to ``hallucination_test_results.json`` in the current
    working directory. Returns early (writing nothing) when no samples are
    found.
    """
    all_test_samples = []
    print(f"\n🧩 正在从 {len(CATEGORIES)} 个品类中抽取未见验证集 (seed=42)...")

    for cat in tqdm(CATEGORIES, desc="Sampling"):
        all_test_samples.extend(get_unseen_samples(cat, count=SAMPLES_PER_CATEGORY))

    if not all_test_samples:
        print("❌ 未找到任何测试样本，请检查数据路径！")
        return

    print(f"\n🚀 准备就绪！共 {len(all_test_samples)} 个样本。开始全品类真实性测试...\n")

    results = []

    # ANSI colour codes for console output.
    GREEN = '\033[92m'
    RED = '\033[91m'
    YELLOW = '\033[93m'
    RESET = '\033[0m'

    total = len(all_test_samples)
    for idx, item in enumerate(all_test_samples, start=1):
        cat = item['category']
        asin = item['parent_asin']
        context = item['formatted_context']

        print(f"[{idx}/{total}] 正在测试: {cat} (ASIN: {asin})")

        # 1. Generate (student model).
        summary = generate_student_summary(context)

        # 2. Grade (judge model).
        eval_res = evaluate_factuality(context, summary)

        # Read every judge field with a default: the judge is an LLM and may
        # omit a key, and one incomplete reply must not abort the whole run
        # (results are only written to disk after the loop).
        score = eval_res.get('score', 0)
        is_factual = eval_res.get('is_factual', False)
        hallucinations = eval_res.get('hallucinations', [])

        # 3. Record.
        results.append({
            "category": cat,
            "asin": asin,
            "summary": summary,
            "score": score,
            "is_factual": is_factual,
            "hallucinations": hallucinations,
            "reasoning": eval_res.get('reasoning', ''),
        })

        # 4. Live console output.
        color = GREEN if score >= 8 else (YELLOW if score >= 5 else RED)

        # Show only the first 100 characters of the summary to limit noise.
        print(f"   生成摘要: {summary[:100].replace(chr(10), ' ')}...")
        print(f"   裁判打分: {color}{score}/10{RESET}")

        if hallucinations:
            print(f"   🚨 幻觉/错误: {RED}{hallucinations}{RESET}")
        else:
            print("   ✅ 事实核查通过")

        print("-" * 50)

    # Final statistics report. `results` has one entry per sample and the
    # sample list was checked to be non-empty above, so it is never empty here.
    print("\n" + "="*60)
    print("📊 幻觉检测最终报告 (Final Factuality Report)")
    print("="*60)

    cat_stats = {}
    for r in results:
        cat_stats.setdefault(r['category'], []).append(r['score'])

    print(f"{'Category':<30} | {'Avg Score':<10}")
    print("-" * 45)
    for cat, scores in cat_stats.items():
        avg = sum(scores)/len(scores)
        print(f"{cat:<30} | {avg:.1f}")

    total_avg = sum(r['score'] for r in results) / len(results)
    pass_count = sum(1 for r in results if r['is_factual'])
    pass_rate = (pass_count / len(results)) * 100

    print("="*60)
    print(f"🏆 总平均分: {total_avg:.2f} / 10")
    print(f"🛡️  无幻觉通过率: {pass_rate:.1f}%")
    print("="*60)

    with open("hallucination_test_results.json", "w", encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print("💾 详细结果已保存至 hallucination_test_results.json")

# Entry point: run the evaluation only when this module is executed directly.
if __name__ == "__main__":
    main()

🔄 正在加载微调模型: /content/drive/MyDrive/ETSP/qwen_dpo_v2...
==((====))==  Unsloth 2025.12.5: Fast Qwen2 patching. Transformers: 4.57.3.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
✅ Student 模型加载成功！

🧩 正在从 10 个品类中抽取未见验证集 (seed=42)...


Sampling:   0%|          | 0/10 [00:00<?, ?it/s]


🚀 准备就绪！共 20 个样本。开始全品类真实性测试...

[1/20] 正在测试: Electronics (ASIN: B01K53FS12)
   生成摘要: (+) Widely praised for effective, sturdy, and well-crafted design, particularly for Samsung monitors...
   裁判打分: [93m7/10[0m
   🚨 幻觉/错误: [91m['“Quality concerns noted, including loose components, poor finish, and inadequate packaging” -> 错误原因: 无中生有 (包装问题未提及)', '“leading to dissatisfaction with build integrity and perceived value, especially for critical applications” -> 错误原因: 掩盖矛盾 / 夸大负面情绪'][0m
--------------------------------------------------
[2/20] 正在测试: Electronics (ASIN: B089Q5MJ2K)
   生成摘要: (+) Exceptional performance, rapid startup, lightweight design, strong value for money, and high rel...
   裁判打分: [92m8/10[0m
   🚨 幻觉/错误: [91m["摘要中提到 'modern features like touchpad' -> 错误原因 (无中生有): 所有笔记本电脑都有触控板，评论原文中并未提及‘触控板’作为现代功能或被用户特别称赞，因此将其列为一项‘现代功能’属于不当添加，易误导为该产品特有或突出设计。"][0m
--------------------------------------------------
[3/20] 正在测试: Books (ASIN: 142152077X)
   生成摘要: (+) Exceptional narrative 