In [None]:
import torch
# Print the GPU's CUDA compute capability (relevant to the float16/bfloat16 choice below).
capability = torch.cuda.get_device_capability()
major_version, minor_version = capability
print("Major: {}, Minor: {}".format(major_version, minor_version))
from datasets import load_dataset
import datasets
from trl import SFTTrainer
import pandas as pd
import numpy as np
import os
import pandas as pd
import numpy as np
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
from transformers import TrainingArguments, Trainer
from typing import Tuple
import warnings
from typing import Any, Dict, List, Union
from transformers import DataCollatorForLanguageModeling
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Number of label classes in the CSV; class k is predicted as the digit token str(k).
NUM_CLASSES = 3

# Context window for training/inference (3 * 1024 = 3072 tokens). Any value works;
# the original note says RoPE scaling is handled internally.
max_seq_length = 3 * 1024
# None = auto-detect; float16 for Tesla T4/V100, bfloat16 for Ampere and newer.
dtype = None

model_name = "unsloth/Qwen3-4B-Base"
load_in_4bit = False  # no 4-bit quantisation

# Load the base model and its tokenizer through Unsloth.
model_load_kwargs = {
    "model_name": model_name,
    "load_in_4bit": load_in_4bit,
    "max_seq_length": max_seq_length,
    "dtype": dtype,
    "gpu_memory_utilization": 0.9,
}
model, tokenizer = FastLanguageModel.from_pretrained(**model_load_kwargs)

In [None]:
# Token id of each class digit "0".."NUM_CLASSES-1"; position k in this list is class k.
number_token_ids = []
for i in range(NUM_CLASSES):
    digit_ids = tokenizer.encode(str(i), add_special_tokens=False)
    # The last-token collator and the head trimming below assume every class label is
    # exactly one token; a multi-token label (e.g. "10") would silently collide with "1".
    if len(digit_ids) != 1:
        raise ValueError(
            f"Class label {str(i)!r} encodes to {len(digit_ids)} tokens {digit_ids}; expected exactly 1."
        )
    number_token_ids.append(digit_ids[0])
if len(set(number_token_ids)) != len(number_token_ids):
    raise ValueError(f"Class label tokens are not distinct: {number_token_ids}")

old_shape = model.lm_head.weight.shape
old_size = old_shape[0]
# Re-running this cell on an already-trimmed head would index rows that no longer exist
# and overwrite old_size with the trimmed size; fail loudly instead.
if old_size <= max(number_token_ids):
    raise RuntimeError(
        f"lm_head has only {old_size} rows (already trimmed?); reload the model before re-running this cell."
    )

# keep only the number tokens from lm_head (row k <- token number_token_ids[k])
par = torch.nn.Parameter(model.lm_head.weight[number_token_ids, :])
print(par.shape)
print(old_shape)

model.lm_head.weight = par

# Original tokenizer id -> row of the trimmed lm_head; used later to convert a label token
# id from the old tokenizer into an index of the new lm_head.
reverse_map = {value: idx for idx, value in enumerate(number_token_ids)}
reverse_map

In [None]:
from peft import LoftQConfig

lora_rank = 16

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        # lm_head is trimmed to NUM_CLASSES rows before this cell, so training it is cheap.
        # NOTE(review): embed_tokens is not trimmed anywhere in this notebook; listing it
        # presumably makes the whole vocabulary embedding trainable (large memory cost) —
        # confirm this is intended.
        # The comma after "embed_tokens" is required: without it Python concatenates the
        # adjacent literals into "embed_tokensq_proj", silently dropping both embed_tokens
        # and q_proj from the targets.
        "lm_head", "embed_tokens",
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",],
    lora_alpha = lora_rank,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = True,  # We support rank stabilized LoRA
    # init_lora_weights = 'loftq',
    # loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1), # And LoftQ
)
print("trainable parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

In [None]:
# Training prompt template. formatting_prompts_func fills {context}, {response} and {label}
# (the class digit). The evaluation cell derives its inference template by cutting this
# string at the final "class {label}", so that suffix must stay intact.
# NOTE(review): the instruction mentions a "Question", but no question is rendered — the
# template has no {question} placeholder and formatting_prompts_func does not pass one.
# Confirm whether the question should be included.
prompt = """### INSTRUCTION ###
You are a fact-checking system.  
Classify the "Response" to the "Question" into one of three classes:  
- class 0: no → Correct answer, consistent with context, no extra info.  
- class 1: intrinsic → Contradicts or distorts context info.  
- class 2: extrinsic → Adds info not in context.  
### INPUT ###
Context: {context}  
Response: {response}  
### OUTPUT ###
The correct answer is: class {label}"""


In [None]:
# Textual label (as stored in the CSV) -> integer class id.
label2id = {"no": 0, "intrinsic": 1, "extrinsic": 2}
# Textual label -> digit string written after "class " in the prompt. Derived from
# label2id so the two mappings cannot drift apart.
label_map = {name: str(idx) for name, idx in label2id.items()}


def _clean_field(value) -> str:
    """Return `value` as a stripped string; missing values (None / NaN / pd.NA) become ''.

    Without this, a NaN cell read from CSV would be rendered literally as "nan".
    """
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ''
    return str(value).strip()


def formatting_prompts_func(dataset_, prompt_template):
    """Render one prompt string per row of `dataset_`.

    Args:
        dataset_: pandas DataFrame. The 'context', 'response' and 'label' columns are read;
            a missing column or a NaN cell is treated as ''. The 'prompt' (question) column
            is not passed to the template.
        prompt_template: str.format template using {context}, {response} and/or {label};
            values the template does not reference are ignored.

    Returns:
        list[str], one rendered prompt per row, in row order.

    Labels not in `label_map` are rendered as '' (best effort, as before), but a non-empty
    unknown label now triggers a warning instead of failing silently.
    """
    texts = []
    for _, row in dataset_.iterrows():
        context_ = _clean_field(row.get('context', ''))
        response_ = _clean_field(row.get('response', ''))
        label_text = _clean_field(row.get('label', ''))

        # Map the textual label to its class digit.
        label_ = label_map.get(label_text, '')
        if label_text and not label_:
            warnings.warn(
                f"Unknown label {label_text!r}; rendering an empty class label.",
                stacklevel=2,
            )

        texts.append(prompt_template.format(context=context_, response=response_, label=label_))
    return texts


def encode_label(example):
    """Replace the textual `example["label"]` with its integer id from label2id.

    The label is stripped first, matching formatting_prompts_func. Raises KeyError for
    unknown labels.
    """
    example["label"] = label2id[str(example["label"]).strip()]
    return example

In [None]:
import pandas as pd 
from sklearn.model_selection import train_test_split
# Training CSV; this notebook reads its 'context', 'response' and 'label' columns, and the
# eval cell also uses 'prompt' and 'id'. The location can be overridden with the
# VIHALLU_TRAIN_CSV environment variable; the default is the original hardcoded path.
train_csv_path = os.environ.get(
    "VIHALLU_TRAIN_CSV", "/home/tamanh/dsc2025/dsc2025/data/vihallu-train.csv"
)
train_df_original = pd.read_csv(train_csv_path)
# 90/10 train/eval split; the fixed random_state keeps the eval set identical across runs.
train_df, eval_df = train_test_split(train_df_original, test_size=0.1, random_state=36)

# Render the full training prompt (including the class digit) for every row.
train_df["text"] = formatting_prompts_func(train_df, prompt)
eval_df["text"] = formatting_prompts_func(eval_df, prompt)
train_dataset = datasets.Dataset.from_pandas(train_df, preserve_index=True)
eval_dataset = datasets.Dataset.from_pandas(eval_df, preserve_index=True)
# Replace textual labels with integer ids in the HF datasets.
train_dataset = train_dataset.map(encode_label)
eval_dataset = eval_dataset.map(encode_label)

In [None]:
class DataCollatorForLastTokenLM(DataCollatorForLanguageModeling):
    """Causal-LM collator that trains only on the final (class-label) token of each example.

    The parent collator produces `labels` (with mlm=False, padding positions are expected to
    already hold `ignore_index`). This subclass additionally masks every position before the
    last supervised token. It then remaps that token from its id in the original tokenizer
    to its row in the trimmed lm_head (e.g. the id of "2" -> 2); the original id could be
    outside the trimmed head's range.

    Args:
        mlm: forwarded to DataCollatorForLanguageModeling; keep False for causal LM.
        ignore_index: label value excluded from the loss (default -100).
        label_token_map: optional {original token id: trimmed lm_head row}. When None, the
            notebook-level `reverse_map` is looked up at call time, as before.
    """

    def __init__(
        self,
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        label_token_map: Union[Dict[int, int], None] = None,
        **kwargs,
    ):
        super().__init__(*args, mlm=mlm, **kwargs)
        self.ignore_index = ignore_index
        self.label_token_map = label_token_map

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """Collate `examples` and keep only the remapped last-token label per row.

        Raises:
            ValueError: a row has no supervised tokens at all.
            KeyError: a row's last supervised token is not a class-label token.
        """
        batch = super().torch_call(examples)
        labels = batch["labels"]  # modified in place below
        token_map = self.label_token_map if self.label_token_map is not None else reverse_map

        for i in range(len(examples)):
            supervised = (labels[i] != self.ignore_index).nonzero()
            if supervised.numel() == 0:
                raise ValueError(f"Example {i} in the batch has no supervised tokens (empty text?).")
            # Last non-ignored position; expected to hold the class digit.
            last_token_idx = supervised[-1].item()
            # Only the class token contributes to the loss.
            labels[i, :last_token_idx] = self.ignore_index
            token_id = labels[i, last_token_idx].item()
            if token_id not in token_map:
                raise KeyError(
                    f"Example {i}: last supervised token id {token_id} is not a class-label token "
                    f"{sorted(token_map)}. The label was probably truncated away (increase "
                    f"max_seq_length) or is missing/unknown in the data."
                )
            labels[i, last_token_idx] = token_map[token_id]

        return batch


collator = DataCollatorForLastTokenLM(tokenizer=tokenizer)

In [None]:
# SFT hyper-parameters. bf16 is used when the GPU supports it, otherwise fp16.
sft_config = SFTConfig(
    output_dir = "sft-qwen3-4b",
    seed = 0,
    # batching: 8 per device x 4 gradient-accumulation steps
    per_device_train_batch_size = 8,
    gradient_accumulation_steps = 4,
    group_by_length = True,
    num_train_epochs = 3,
    # optimiser and learning-rate schedule
    optim = "adamw_8bit",
    learning_rate = 1e-4,
    weight_decay = 0.01,
    warmup_steps = 10,
    lr_scheduler_type = "cosine_with_restarts",
    # mixed precision
    bf16 = torch.cuda.is_bf16_supported(),
    fp16 = not torch.cuda.is_bf16_supported(),
    # log / evaluate / checkpoint once per epoch; keep the best checkpoint by eval loss
    logging_strategy = "epoch",
    eval_strategy = "epoch",
    save_strategy = "epoch",
    load_best_model_at_end = True,
    metric_for_best_model = "eval_loss",
    report_to = "wandb",
)

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    args = sft_config,
    train_dataset = train_dataset,
    eval_dataset = eval_dataset,
    data_collator = collator,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 1,
    packing = False,
)

In [None]:
# Run fine-tuning. The trainer's SFTConfig sets load_best_model_at_end=True with
# metric_for_best_model="eval_loss", so after this call the model should hold the
# lowest-eval-loss checkpoint.
trainer_stats = trainer.train()

In [None]:
# Rebuild a full-vocabulary lm_head from the trained, trimmed one so the model works with
# the original tokenizer again. Rows of the class-digit tokens get the trained weights and
# bias; every other row gets zero weights and a -1000 bias, so those tokens receive
# ~0 probability.
with torch.no_grad():
    # Trained (trimmed) head; row k corresponds to token number_token_ids[k].
    trimmed_lm_head = model.lm_head.weight.data.clone()
    head_dtype, head_device = trimmed_lm_head.dtype, trimmed_lm_head.device
    if getattr(model.lm_head, "bias", None) is not None:
        trimmed_lm_head_bias = model.lm_head.bias.data.clone()
    else:
        # No trained bias: use zeros in the head's own dtype (previously float32).
        trimmed_lm_head_bias = torch.zeros(len(number_token_ids), dtype=head_dtype, device=head_device)

    # Full-size buffers with shape [old_size, hidden_dim].
    hidden_dim = trimmed_lm_head.shape[1]
    new_lm_head = torch.zeros((old_size, hidden_dim), dtype=head_dtype, device=head_device)
    new_lm_head_bias = torch.full(
        (old_size,), -1000.0, dtype=trimmed_lm_head_bias.dtype, device=trimmed_lm_head_bias.device
    )

    # Scatter the trained rows back to their original token positions (vectorised).
    token_idx = torch.tensor(number_token_ids, dtype=torch.long, device=head_device)
    new_lm_head[token_idx] = trimmed_lm_head
    new_lm_head_bias[token_idx.to(new_lm_head_bias.device)] = trimmed_lm_head_bias

    # Build the replacement in the trained head's dtype rather than the float32 default.
    # It then matches the rest of the model, and for a bf16/fp16 model this
    # old_size x hidden_dim matrix takes half the memory.
    new_lm_head_module = torch.nn.Linear(
        hidden_dim, old_size, bias=True, device=model.device, dtype=head_dtype
    )
    new_lm_head_module.weight.copy_(new_lm_head)
    new_lm_head_module.bias.copy_(new_lm_head_bias)
    # Swap it into the PEFT wrapper's trainable copy of lm_head for adapter "default".
    model.lm_head.modules_to_save["default"] = new_lm_head_module

print(f"Remade lm_head: shape = {model.lm_head.weight.shape}. Allowed tokens: {number_token_ids}")


In [None]:
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
import random
import pandas as pd

# --- Step 1: Create mapping dictionaries ---
# Map from string labels (original data) to numeric IDs (for computation)
label_to_id = {
    'no': 0,
    'intrinsic': 1,
    'extrinsic': 2
}
# Reverse mapping from numeric IDs (predictions) to string labels (for display)
id_to_label = {v: k for k, v in label_to_id.items()}


# --- Step 2: Prepare data and prompt ---
# Inference template = training template up to and including "class ", so the model's
# next-token prediction is the class digit. rsplit cuts at the last occurrence only.
inference_prompt_template = prompt.rsplit("class {label}", 1)[0] + "class "

# Render eval prompts with the same cleaning (strip, NaN handling) as the training texts.
# These rendered strings are exactly what is fed to the model below.
eval_df["text"] = formatting_prompts_func(eval_df, inference_prompt_template)

# Sort validation set by token length so each batch needs little padding.
eval_df['token_length'] = eval_df['text'].apply(lambda x: len(tokenizer.encode(x, add_special_tokens=False)))
val_df_sorted = eval_df.sort_values(by='token_length').reset_index(drop=True)


# --- Step 3: Initialize variables for evaluation ---
# Number of examples to print. Deliberately not named `display`, which would shadow
# IPython's built-in display() in every later cell.
num_display = 20
batch_size = 1  # Can increase to 8 or 16 if GPU VRAM allows
device = model.device
correct_predictions = 0
results = []

predicts = []  # predicted class ids, in val_df_sorted order
gt = []        # ground-truth class ids, in val_df_sorted order

model.eval()

# --- Step 4: Evaluation loop ---
with torch.inference_mode():
    for start in tqdm(range(0, len(val_df_sorted), batch_size), desc="Evaluating"):
        batch_df = val_df_sorted.iloc[start:start + batch_size]

        # Reuse the pre-rendered prompts so training and inference formatting cannot diverge.
        prompts = batch_df["text"].tolist()

        # Tokenize and move inputs to the model's device
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_seq_length).to(device)

        # Forward pass: get logits from model
        logits = model(**inputs).logits

        # Position of the last real (attention_mask == 1) token in each row. Taking the
        # largest masked position is correct for both right and left padding, whereas
        # `mask.sum(1) - 1` is only correct for right padding.
        mask = inputs.attention_mask
        positions = torch.arange(mask.shape[1], device=mask.device).expand_as(mask)
        last_idxs = (positions * mask).max(dim=1).values.to(logits.device)
        rows = torch.arange(mask.shape[0], device=logits.device)
        last_logits = logits[rows, last_idxs, :]

        # Probabilities of the class-digit tokens ("0", "1", "2", ...), in class order
        probs_all = F.softmax(last_logits, dim=-1)
        probs = probs_all[:, number_token_ids]

        # Predicted class ids: argmax over the class-token probabilities
        predicted_ids = torch.argmax(probs, dim=-1).cpu().tolist()
        # True labels as ids; stripped the same way formatting_prompts_func strips them
        true_labels_ids = [label_to_id[str(label).strip()] for label in batch_df['label'].tolist()]
        batch_probs = probs.float().cpu().numpy()

        # Record per-example results for inspection
        for j, (pred_id, true_id) in enumerate(zip(predicted_ids, true_labels_ids)):
            is_correct = pred_id == true_id
            correct_predictions += int(is_correct)

            predicts.append(pred_id)
            gt.append(true_id)

            results.append({
                "id": batch_df["id"].iloc[j],
                "text": batch_df['text'].iloc[j],
                "true_label": id_to_label.get(true_id, "N/A"),
                "pred_label": id_to_label.get(pred_id, "N/A"),
                "probs": batch_probs[j],
                "is_correct": is_correct,
            })

# --- Step 5: Final evaluation results ---
n_eval = len(val_df_sorted)
accuracy = 100 * correct_predictions / n_eval if n_eval else 0.0
print(f"\nValidation Accuracy: {accuracy:.2f}% ({correct_predictions}/{n_eval})")

# Over-long prompts are truncated by the tokenizer call above and may lose the trailing
# "class " cue, making their predictions unreliable.
n_long = int((val_df_sorted['token_length'] > max_seq_length).sum())
if n_long:
    print(f"⚠️ {n_long} prompts exceed max_seq_length={max_seq_length} tokens and were truncated.")

print("\n--- Random Samples ---")
wrong_samples = [s for s in results if not s['is_correct']]
correct_samples = [s for s in results if s['is_correct']]

# Prioritize wrong samples for debugging: show up to num_display of them, and top up with
# correct ones only when there are fewer wrong samples than that.
sample_rng = random.Random(0)  # fixed seed so the printed samples are reproducible
n_wrong_shown = min(num_display, len(wrong_samples))
n_correct_shown = min(num_display - n_wrong_shown, len(correct_samples))
samples_to_show = (
    sample_rng.sample(wrong_samples, n_wrong_shown)
    + sample_rng.sample(correct_samples, n_correct_shown)
)


def _format_probs(class_probs):
    """Format per-class probabilities as 'label: p, ...' in class-id order."""
    return ", ".join(f"{id_to_label[k]}: {v:.3f}" for k, v in enumerate(class_probs))


for s in samples_to_show:
    status_icon = '✅' if s['is_correct'] else '❌'
    print(f"\nText: {s['text']}")
    print(f"True: {s['true_label']:<10} | Pred: {s['pred_label']:<10} {status_icon}")
    print(f"Probs: [{_format_probs(s['probs'])}]")

# Save wrong predictions to file
with open("wrong_predictions.txt", "w", encoding="utf-8") as f:
    for s in wrong_samples:
        f.write("============================================\n")
        f.write(f"Text: {s['text']}\n")
        f.write(f"True: {s['true_label']} | Pred: {s['pred_label']}\n")
        f.write(f"Probs: [{_format_probs(s['probs'])}]\n\n")

print(f"\n❌ Saved {len(wrong_samples)} wrong samples into 'wrong_predictions.txt'")


# Save correct predictions to CSV
correct_df = pd.DataFrame(correct_samples)
correct_df.to_csv("correct_predictions.csv", index=False, encoding="utf-8-sig")

print(f"✅ Saved {len(correct_samples)} correct samples into 'correct_predictions.csv'")

# Clean up the helper column added above
if 'token_length' in eval_df:
    del eval_df['token_length']
