In [None]:
# !pip install torch transformers datasets peft bitsandbytes accelerate

In [None]:
# !pip install flash-attn --no-build-isolation

In [1]:
import torch
import random
import bitsandbytes
from datasets import load_dataset, load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    BitsAndBytesConfig,
    AutoModelForSequenceClassification,
)
from peft import PeftModel, LoraConfig, get_peft_model, prepare_model_for_kbit_training
from typing import Dict, Any, List, Union
from tqdm import tqdm
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---------------------------------------------------------------------------
# CustomSymPOTrainer 类 non-unsloth
# ---------------------------------------------------------------------------

class CustomSymPOTrainer(Trainer):
    def __init__(
        self,
        model: Union[PeftModel, torch.nn.Module],
        ref_model: Union[PeftModel, torch.nn.Module],
        reward_model: Union[PeftModel, torch.nn.Module],
        args: TrainingArguments,
        beta_kl: float,
        log_ratio_clip_min: float,
        log_ratio_clip_max: float,
        **kwargs,
    ):
        """
        用于 SymPO 算法的自定义 Trainer。

        Args:
            model (PeftModel): 策略模型 (policy_model)。
            ref_model (PeftModel): 参考模型 (ref_model)。
            reward_model (torch.nn.Module): 奖励模型 (reward_model)。
            args (TrainingArguments): 训练参数。
            beta_kl (float): KL散度惩罚项的系数。
            log_ratio_clip_min (float): 对数概率比的最小裁剪值。
            log_ratio_clip_max (float): 对数概率比的最大裁剪值。
        """
        super().__init__(model=model, args=args, **kwargs)

        if not hasattr(self, 'processor') or self.processor is None:
            self.processor = self.tokenizer

        self.ref_model = ref_model
        self.reward_model = reward_model
        self.beta_kl = beta_kl
        self.log_ratio_clip_min = log_ratio_clip_min
        self.log_ratio_clip_max = log_ratio_clip_max

    def _get_log_probs(self, model: AutoModelForCausalLM, prompts: List[str], responses: List[str]) -> torch.Tensor:
        full_texts = [p + r for p, r in zip(prompts, responses)]

        original_padding_side = self.processor.padding_side
        self.processor.padding_side = 'left'

        prompt_tokens = self.processor(prompts, padding=False, truncation=False)
        prompt_lengths = [len(p) for p in prompt_tokens['input_ids']]

        max_len = model.config.max_position_embeddings

        full_tokens = self.processor(
            full_texts, padding=True, truncation=True,
            max_length=max_len,
            return_tensors="pt"
        )
        self.processor.padding_side = original_padding_side

        input_ids = full_tokens['input_ids'].to(self.args.device)
        attention_mask = full_tokens['attention_mask'].to(self.args.device)

        outputs = model(input_ids, attention_mask=attention_mask, return_dict=True)
        logits = outputs.logits

        shifted_logits = logits[..., :-1, :]
        shifted_labels = input_ids[..., 1:]
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        nll_per_token = loss_fct(shifted_logits.reshape(-1, shifted_logits.size(-1)), shifted_labels.reshape(-1))
        nll_per_token = nll_per_token.view(input_ids.size(0), -1)
        log_probs_per_token = -nll_per_token

        # response_mask = torch.zeros_like(shifted_labels, dtype=torch.bool)
        # for i, prompt_len in enumerate(prompt_lengths):
        #     actual_prompt_len = min(prompt_len, max_len)
        #     response_start_index = actual_prompt_len - 1

        #     actual_seq_len = attention_mask[i].sum()
        #     response_end_index = actual_seq_len -1

        #     if response_start_index < response_end_index:
        #          response_mask[i, response_start_index:response_end_index] = True

        seq_len = shifted_labels.size(1)
        position_ids = torch.arange(seq_len, device=self.args.device).expand_as(shifted_labels)
        prompt_lengths_tensor = torch.tensor(prompt_lengths, device=self.args.device).unsqueeze(1)
        response_start_index = prompt_lengths_tensor - 1
        mask = position_ids >= response_start_index
        attention_mask_shifted = attention_mask[:, 1:].to(torch.bool)
        response_mask = mask & attention_mask_shifted

        masked_log_probs = log_probs_per_token * response_mask
        total_log_probs = masked_log_probs.sum(dim=1)
        del logits, shifted_logits, shifted_labels, nll_per_token, log_probs_per_token, response_mask, masked_log_probs

        return total_log_probs

    def _get_rewards(self, prompts: List[str], responses: List[str]) -> torch.Tensor:
        texts = [p + r for p, r in zip(prompts, responses)]
        original_padding_side = self.processor.padding_side
        self.processor.padding_side = 'right'
        reward_max_len = self.reward_model.config.max_position_embeddings
        inputs = self.processor(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=reward_max_len
        ).to(self.args.device)
        self.processor.padding_side = original_padding_side
        with torch.no_grad():
            outputs = self.reward_model(**inputs)
            return torch.sigmoid(outputs.logits.squeeze(-1))

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        prompts = inputs.pop("prompt")
        chosen_responses = inputs.pop("chosen")
        rejected_responses = inputs.pop("rejected")

        # y1s = chosen_responses
        # y2s = rejected_responses
        y1s, y2s, zs = [], [], []
        for j in range(len(prompts)):
            if random.random() < 0.5:
                y1s.append(chosen_responses[j]); y2s.append(rejected_responses[j]); zs.append(1.0)
            else:
                y1s.append(rejected_responses[j]); y2s.append(chosen_responses[j]); zs.append(0.0)
        zs = torch.tensor(zs).to(self.args.device)

        log_probs_pi_y1 = self._get_log_probs(model, prompts, y1s)
        log_probs_pi_y2 = self._get_log_probs(model, prompts, y2s)

        with torch.no_grad():
            with self.model.disable_adapter():
                log_probs_ref_y1 = self._get_log_probs(self.model, prompts, y1s)
                log_probs_ref_y2 = self._get_log_probs(self.model, prompts, y2s)

        log_ratio_y1 = log_probs_pi_y1 - log_probs_ref_y1
        log_ratio_y2 = log_probs_pi_y2 - log_probs_ref_y2

        with torch.no_grad():
            # f_y1 = self._get_rewards(prompts, y1s)
            # f_y2 = self._get_rewards(prompts, y2s)

            log_ratio_weight_y1 = log_ratio_y1.detach()
            log_ratio_weight_y2 = log_ratio_y2.detach()

            kl_term1 = self.beta_kl * log_ratio_weight_y1
            kl_term2 = self.beta_kl * log_ratio_weight_y2

            # weight1 = 1 - f_y2 - kl_term1
            # weight2 = f_y1 + kl_term2
            weight1 = zs - kl_term1
            weight2 = zs + kl_term2

        clamped_log_ratio_y1 = torch.clamp(log_ratio_y1, min=self.log_ratio_clip_min, max=self.log_ratio_clip_max)
        ratio_y1 = torch.exp(clamped_log_ratio_y1)

        clamped_log_ratio_y2 = torch.clamp(log_ratio_y2, min=self.log_ratio_clip_min, max=self.log_ratio_clip_max)
        ratio_y2 = torch.exp(clamped_log_ratio_y2)

        J_sym_objective = ratio_y1 * weight1 - ratio_y2 * weight2
        total_loss = -J_sym_objective.mean()

        logs = {
            "mean_ratio_chosen": ratio_y1.detach().mean().item(),
            "mean_ratio_rejected": ratio_y2.detach().mean().item(),
            "mean_weight_chosen": weight1.detach().mean().item(),
            "mean_weight_rejected": weight2.detach().mean().item(),
            "mean_logprob_pi_chosen": log_probs_pi_y1.detach().mean().item(),
            "mean_logprob_ref_chosen": log_probs_ref_y1.detach().mean().item(),
            # "f_chosen": f_y1.detach().mean().item(),
            # "f_rejected": f_y2.detach().mean().item(),
            "debug_ratio_chosen_max": ratio_y1.detach().max().item(),
            "debug_ratio_rejected_max": ratio_y2.detach().max().item(),
            "debug_weight_chosen_max": weight1.detach().max().item(),
            "debug_weight_rejected_max": weight2.detach().max().item(),
            "debug_logprob_pi_chosen_max": log_probs_pi_y1.detach().max().item(),
            "debug_logprob_ref_chosen_max": log_probs_ref_y1.detach().max().item(),
            "debug_objective_mean": J_sym_objective.detach().mean().item()
        }
        self._current_logs = logs

        if self.args.device == "cuda":
            torch.cuda.empty_cache()

        return total_loss

    def log(self, logs: dict, *args, **kwargs) -> None:
        if hasattr(self, "_current_logs") and self._current_logs is not None:
            logs.update(self._current_logs)
            self._current_logs = None

        super().log(logs, *args, **kwargs)

    def evaluate(
        self,
        eval_dataset=None,
        ignore_keys=None,
        metric_key_prefix: str = "eval",
    ):
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        if eval_dataset is None:
            raise ValueError("Evaluation requires an eval_dataset.")

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        self.model.eval()

        # 【优化】不再使用列表存储所有结果，而是使用累加器
        total_samples = 0
        total_accuracy_count = 0
        total_policy_margins = 0
        total_rewards_chosen = 0
        total_rewards_rejected = 0

        with torch.no_grad():
            for batch in tqdm(eval_dataloader, desc="Evaluating"):
                prompts = batch["prompt"]
                chosen_responses = batch["chosen"]
                rejected_responses = batch["rejected"]

                batch_size = len(prompts)
                total_samples += batch_size

                # 计算对数概率和奖励
                policy_chosen_logps = self._get_log_probs(self.model, prompts, chosen_responses)
                policy_rejected_logps = self._get_log_probs(self.model, prompts, rejected_responses)
                rewards_chosen = self._get_rewards(prompts, chosen_responses)
                rewards_rejected = self._get_rewards(prompts, rejected_responses)

                # 【优化】在每个批次上直接计算指标，然后累加
                # 计算本批次的准确样本数
                correct_predictions = (policy_chosen_logps > policy_rejected_logps).sum().item()
                total_accuracy_count += correct_predictions

                # 累加其他指标的总和
                total_policy_margins += (policy_chosen_logps - policy_rejected_logps).sum().item()
                total_rewards_chosen += rewards_chosen.sum().item()
                total_rewards_rejected += rewards_rejected.sum().item()

        # 【优化】在循环结束后，根据累加结果计算最终的平均指标
        accuracy = total_accuracy_count / total_samples if total_samples > 0 else 0
        policy_margins = total_policy_margins / total_samples if total_samples > 0 else 0
        rewards_chosen_mean = total_rewards_chosen / total_samples if total_samples > 0 else 0
        rewards_rejected_mean = total_rewards_rejected / total_samples if total_samples > 0 else 0
        reward_margins = rewards_chosen_mean - rewards_rejected_mean

        metrics = {
            f"{metric_key_prefix}_accuracy": accuracy,
            f"{metric_key_prefix}_policy_margins": policy_margins,
            f"{metric_key_prefix}_rewards_chosen_mean": rewards_chosen_mean,
            f"{metric_key_prefix}_rewards_rejected_mean": rewards_rejected_mean,
            f"{metric_key_prefix}_rewards_margins": reward_margins,
        }

        self.model.train()
        self.log(metrics)
        return metrics



In [3]:
from dataclasses import dataclass
from typing import Dict, List, Any

@dataclass
class DataCollatorForCustomSymPO:
    """
    一个简单的 Data Collator，它将一批字典（每个包含字符串）
    转换为一个包含字符串列表的字典。
    """
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # 'features' 是一个列表，列表中的每个元素都是一个数据集样本，
        # 例如: [{'prompt': 'p1', 'chosen': 'c1', 'rejected': 'r1'}, {'prompt': 'p2', ...}]

        # 初始化一个空字典来存放批次数据
        batch = {}

        # 遍历第一个样本的所有键
        for key in features[0].keys():
            # 为每个键创建一个列表，包含该批次所有样本中该键对应的值
            batch[key] = [feature[key] for feature in features]

        # 返回的结果
        # 例如: {'prompt': ['p1', 'p2'], 'chosen': ['c1', 'c2'], 'rejected': ['r1', 'r2']}
        return batch

In [4]:
from transformers import TrainerCallback

class PrintingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_local_process_zero and logs:
            print(logs)

In [5]:
# 1. 定义路径
sft_model_path = "/train/Llama-3-8B-Instruct"
# sft_model_path = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
# reward_model_path = "/content/drive/My Drive/Colab_outputs/f_model_llama3_8b_lr2e-4_ultrafeedback_armorm"
reward_model_path ="/train/f_model"
dataset_train_name = "/train/traindataset"
# dataset_test_name = "/content/drive/MyDrive/Policy_Dataset/Data_Test_generation/datasets/llama3_ultrafeedback_armorm"
output_dir  = "/train/output_model/llama3-8b-sympo-5e-5_0.1"
device_map = "auto"
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
lr = 5e-5

In [6]:
# 4-bit 量化配置
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# LoRA 配置
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(sft_model_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading policy model (for training)...")
policy_model = AutoModelForCausalLM.from_pretrained(
    sft_model_path,
    quantization_config=quantization_config,
    device_map=device_map,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
policy_model = prepare_model_for_kbit_training(policy_model)
policy_model = get_peft_model(policy_model, lora_config)

# print("Loading reference model (frozen)...")
# ref_model = AutoModelForCausalLM.from_pretrained(
#     sft_model_path,
#     quantization_config=quantization_config,
#     device_map=device_map,
#     attn_implementation="flash_attention_2",
#     torch_dtype=torch.bfloat16,
# )
print("Loading reward model (frozen)...")
reward_model = AutoModelForSequenceClassification.from_pretrained(
    reward_model_path,
    quantization_config=quantization_config,
    device_map=device_map,
    num_labels=1,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
if reward_model.config.pad_token_id is None:
    reward_model.config.pad_token_id = tokenizer.pad_token_id

# reward_model = AutoModelForSequenceClassification.from_pretrained(
#     sft_model_path,
#     num_labels=1,
#     torch_dtype=torch.bfloat16,
#     device_map="auto"
# )
#
# # 设置填充符号
# if reward_model.config.pad_token_id is None:
#     reward_model.config.pad_token_id = tokenizer.eos_token_id

# reward_model = PeftModel.from_pretrained(reward_model, reward_model_path)

# for param in ref_model.parameters():
#     param.requires_grad = False
for param in reward_model.parameters():
    param.requires_grad = False

Loading tokenizer...


`torch_dtype` is deprecated! Use `dtype` instead!


Loading policy model (for training)...


Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.18s/it]


Loading reward model (frozen)...


Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.07s/it]
Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at /train/f_model and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
dataset_train = load_from_disk(dataset_train_name)


In [8]:
def preprocess_ultrafeedback(example: dict) -> dict:
    """
    Args:
        example: 数据集中的一个样本字典。

    Returns:
        一个包含 'prompt', 'chosen', 'rejected' 键的新字典。
    """

    chosen_response = example['chosen'][-1]['content']+ "<|eot_id|>\n"
    rejected_response = example['rejected'][-1]['content']+ "<|eot_id|>\n"
    conversation_history = example['chosen'][:-1]

    full_prompt = tokenizer.apply_chat_template(
        conversation_history,
        tokenize=False,
        add_generation_prompt=True
    )
    if not full_prompt or not chosen_response or not rejected_response:
        return {"prompt": "", "chosen": "", "rejected": ""}

    return {
        "prompt": full_prompt,
        "chosen": chosen_response,
        "rejected": rejected_response
    }


dataset_train = load_from_disk(dataset_train_name)

print("Preprocessing dataset...")
processed_dataset_train = dataset_train.map(
    preprocess_ultrafeedback,
    num_proc=4,
    remove_columns=dataset_train.column_names
)

processed_dataset_train = processed_dataset_train.filter(
    lambda x: x['prompt'] and x['chosen'] and x['rejected']
)

# dataset_test = load_from_disk(dataset_test_name)

# print("Preprocessing dataset...")
# processed_dataset_test = dataset_test.map(
#     preprocess_ultrafeedback,
#     num_proc=4,
#     remove_columns=dataset_test.column_names
# )

# processed_dataset_test = processed_dataset_test.filter(
#     lambda x: x['prompt'] and x['chosen'] and x['rejected']
# )

Preprocessing dataset...


In [9]:
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    gradient_checkpointing=False,
    # group_by_length=True, # 启用长度分组
    # length_column_name="prompt_length",
    optim="adamw_8bit",
    learning_rate=lr,
    max_grad_norm=1.0,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    bf16=True,
    save_strategy="steps",
    save_steps=300,
    # eval_strategy="no",
    eval_strategy="no",
    remove_unused_columns=False,
    save_total_limit=20,
    # metric_for_best_model="eval_accuracy",
    # greater_is_better=True,
    # remove_unused_columns=False,
)
data_collator = DataCollatorForCustomSymPO()
printing_callback = PrintingCallback()

print("Initializing CustomSymPOTrainer...")
trainer = CustomSymPOTrainer(
    model=policy_model,
    ref_model=policy_model,
    reward_model=reward_model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=processed_dataset_train,
    # eval_dataset=processed_dataset_test,
    data_collator=data_collator,
    callbacks=[PrintingCallback()],
    # 传递自定义参数
    beta_kl=0.1,
    log_ratio_clip_min=-2.3,
    log_ratio_clip_max=2.3,
)

  super().__init__(model=model, args=args, **kwargs)


Initializing CustomSymPOTrainer...


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


In [None]:
print("Starting training...")
# trainer.train(resume_from_checkpoint=True)
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 128009}.


Starting training...


[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mzmisource[0m ([33mzmisource-zhejiang-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)
Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss
1,0.0


{'loss': 0.0, 'grad_norm': 69.2486572265625, 'learning_rate': 0.0, 'mean_ratio_chosen': 1.0, 'mean_ratio_rejected': 1.0, 'mean_weight_chosen': 0.0, 'mean_weight_rejected': 0.0, 'mean_logprob_pi_chosen': -280.10357666015625, 'mean_logprob_ref_chosen': -280.10357666015625, 'debug_ratio_chosen_max': 1.0, 'debug_ratio_rejected_max': 1.0, 'debug_weight_chosen_max': 0.0, 'debug_weight_rejected_max': 0.0, 'debug_logprob_pi_chosen_max': -280.10357666015625, 'debug_logprob_ref_chosen_max': -280.10357666015625, 'debug_objective_mean': 0.0, 'epoch': 0.0002664978846730404}


In [None]:
# import torch
# from transformers import AutoModelForCausalLM, AutoTokenizer
# from peft import PeftModel
# import os

# # --- 1. 请在这里配置您的路径 ---

# # 您微调时使用的原始基础模型的Hugging Face ID
# # (根据您的目录名称，很可能就是这个)
# base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# # 您的LoRA适配器所在的目录
# adapter_path = "/content/drive/My Drive/Colab_outputs/f_model_llama3_8b_lr2e-4_ultrafeedback_armorm"

# # 您想将合并后的完整模型保存到的【新】目录
# merged_model_path = "/content/drive/My Drive/Colab_outputs/merged_model_llama3_8b_ultrafeedback_new"


In [None]:
# # --- 2. 执行合并的代码 ---

# # 确保目标目录存在
# os.makedirs(merged_model_path, exist_ok=True)

# print(f"正在加载基础模型: {base_model_id}")
# # 加载基础模型，请确保您的Colab有足够的内存和显存
# base_model = AutoModelForCausalLM.from_pretrained(
#     base_model_id,
#     torch_dtype=torch.bfloat16, # 在T4/A100 GPU上建议使用 bfloat16
#     device_map="auto"
# )

# print(f"正在加载分词器: {base_model_id}")
# tokenizer = AutoTokenizer.from_pretrained(base_model_id)

In [None]:



# print(f"正在加载LoRA适配器: {adapter_path}")
# # 将LoRA适配器加载到基础模型上
# merged_model = PeftModel.from_pretrained(base_model, adapter_path)

# print("正在合并权重...")
# # 合并权重，然后卸载适配器，得到一个独立的、完整的模型
# merged_model = merged_model.merge_and_unload()

# print(f"正在将合并后的完整模型保存到: {merged_model_path}")
# # 将这个新模型和分词器一起保存到新目录
# merged_model.save_pretrained(merged_model_path)
# tokenizer.save_pretrained(merged_model_path)

# print("\n模型合并并保存完毕！")
