In [4]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import torch.nn.functional as F
from tqdm import tqdm

CACHE_DIR = None



In [9]:
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
train_dataset[0]

{'chosen': [{'content': 'Use the pygame library to write a version of the classic game Snake, with a unique twist',
   'role': 'user'},
  {'content': "Sure, I'd be happy to help you write a version of the classic game Snake using the pygame library! Here's a basic outline of how we can approach this:\n\n1. First, we'll need to set up the game display and create a game object that we can use to handle the game's state.\n2. Next, we'll create the game's grid, which will be used to represent the game board. We'll need to define the size of the grid and the spaces within it.\n3. After that, we'll create the snake object, which will be used to represent the player's movement. We'll need to define the size of the snake and the speed at which it moves.\n4. We'll also need to create a food object, which will be used to represent the food that the player must collect to score points. We'll need to define the location of the food and the speed at which it moves.\n5. Once we have these objects se

In [8]:

dataset = load_dataset("Dahoas/rm-static", cache_dir=CACHE_DIR)
dataset



DatasetDict({
    train: Dataset({
        features: ['prompt', 'response', 'chosen', 'rejected'],
        num_rows: 76256
    })
    test: Dataset({
        features: ['prompt', 'response', 'chosen', 'rejected'],
        num_rows: 5103
    })
})

In [4]:
import math

math.average([1, 2, 3])

AttributeError: module 'math' has no attribute 'average'

In [2]:

MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
BATCH_SIZE = 20
LR = 5e-5
EPOCHS = 1
BETA = 0.1
MAX_PROMPT_LEN = 512
MAX_RESP_LEN = 512
PAD_TO_MULTIPLE_OF = 8  # 便于 tensor core 加速
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR, padding_side="right", truncation=True, max_length=MAX_PROMPT_LEN)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR).to(DEVICE)

# ref_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR).to(DEVICE)
# ref_model.eval()  # 不训练


In [2]:
dataset = load_dataset("Dahoas/rm-static", cache_dir=CACHE_DIR)


In [None]:
dataset

In [4]:
# -----------------------
# DataLoader
# -----------------------
def apply_chat_template(prompt, response, tokenizer):
    messages = [
        {"role": "user", "content": prompt.strip()},
        {"role": "assistant", "content": response.strip()}
    ]
    msg = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)#.strip()  # There's an automatic /n at the end
    return msg


def collate_fn(batch):
    prompt_chosen_texts = []
    prompt_rejected_texts = []
    prompt_only_texts = []
    prompt_offset_lens = []
    for item in batch:
        prompt_chosen_texts.append(apply_chat_template(item['prompt'], item['chosen'], tokenizer))
        prompt_rejected_texts.append(apply_chat_template(item['prompt'], item['rejected'], tokenizer))
        prompt_only = apply_chat_template(item['prompt'], '', tokenizer)
        prompt_only_len = tokenizer(prompt_only, return_tensors="pt", padding=False, truncation=True).input_ids.shape[1]
        prompt_offset_lens.append(prompt_only_len-2)

    prompt_chosen_ids = tokenizer(prompt_chosen_texts, return_tensors="pt", padding=True, truncation=True).input_ids
    prompt_rejected_ids = tokenizer(prompt_rejected_texts, return_tensors="pt", padding=True, truncation=True).input_ids

    return {
        'prompt_chosen_ids': prompt_chosen_ids,
        'prompt_rejected_ids': prompt_rejected_ids,
        'prompt_offset_lens': torch.tensor(prompt_offset_lens, dtype=torch.int32)}


loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)


In [None]:
for batch in loader:
    prompt_chosen_batch = batch['prompt_chosen_ids'].to(DEVICE)
    prompt_rejected_batch = batch['prompt_rejected_ids'].to(DEVICE)
    prompt_offset_lens_batch = batch['prompt_offset_lens'].to(DEVICE)
    
    # print(tokenizer.decode(prompt_chosen_batch[1]))
    # print(tokenizer.decode(prompt_chosen_batch[1][prompt_offset_lens_batch[1]:]))

    # 计算损失
    outputs_chosen = model(prompt_chosen_batch)
    # outputs_rejected = model(prompt_rejected_batch)



    chosen_logprobs = F.log_softmax(outputs_chosen.logits[:, :-1, :], dim=-1)
    chosen_logprobs_selected = torch.gather(chosen_logprobs, dim=-1, index=prompt_chosen_batch[:, 1:].unsqueeze(-1)).squeeze(-1)
    print(chosen_lorgprobs_selected)

    
    assert False

    loss_chosen = outputs_chosen.loss
    loss_rejected = outputs_rejected.loss

    # 计算 DPO 损失
    logits_chosen = outputs_chosen.logits[:, :-1, :].contiguous()
    logits_rejected = outputs_rejected.logits[:, :-1, :].contiguous()

    chosen_log_probs = F.log_softmax(logits_chosen, dim=-1)
    rejected_log_probs = F.log_softmax(logits_rejected, dim=-1)

    chosen_log_probs_selected = torch.gather(chosen_log_probs, -1, prompt_chosen_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    rejected_log_probs_selected = torch.gather(rejected_log_probs, -1, prompt_rejected_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

    dpo_loss = (chosen_log_probs_selected - rejected_log_probs_selected).mean()

    # 反向传播
    loss = loss_chosen + loss_rejected + BETA * dpo_loss
    loss.backward()

    # 优化器更新等操作...

AssertionError: 

In [12]:
outputs_chosen.logits.shape
chosen_logprobs = F.log_softmax(outputs_chosen.logits[:, :-1, :], dim=-1)

chosen_logprobs_selected = torch.gather(chosen_logprobs, dim=-1, index=prompt_chosen_batch[:, 1:].unsqueeze(-1)).squeeze(-1)


: 

In [9]:
outputs_chosen.logits.shape, prompt_chosen_batch.unsqueeze(-1).shape



(torch.Size([20, 459, 151936]), torch.Size([20, 459, 1]))