In [1]:
import datasets
import torch
import transformers
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoTokenizer, set_seed, AutoModelForCausalLM, HfArgumentParser, AutoModelForSequenceClassification
from dataclasses import dataclass, field
from typing import Optional

import trl
from trl import (
    ModelConfig,
    PPOConfig, 
    PPOTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

  from .autonotebook import tqdm as notebook_tqdm


[2025-03-21 16:46:53,492] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/data/gwy/miniconda3/envs/openr1/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/data/gwy/miniconda3/envs/openr1/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/data/gwy/miniconda3/envs/openr1/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/data/gwy/miniconda3/envs/openr1/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/data/gwy/miniconda3/envs/openr1/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::substr(unsigned long, unsigned long) const@GLIBCXX_3.4'
/data/gwy/miniconda3/envs/openr1/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)@GLIBCXX_3.4'
/data/gwy/minic

In [2]:
@dataclass
class ScriptArguments(trl.ScriptArguments):
    """
    args for callbacks, benchmarks etc
    """
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": ("The project to store runs under.")},
    )

script_args, training_args, model_args = ScriptArguments(dataset_name='.sft'), PPOConfig(), ModelConfig()

In [3]:
script_args.dataset_name = '/data4/gwy/datasets/tldr-preference-sft-trl-style'
script_args.dataset_train_split = 'validation'
model_args.model_name_or_path = '/data4/gwy/model/EleutherAI_pythia-1b-deduped__sft__tldr'
training_args.sft_model_path = '/data4/gwy/model/EleutherAI_pythia-1b-deduped__sft__tldr'
training_args.reward_model_path = '/data4/gwy/model/EleutherAI_pythia-1b-deduped__reward__tldr'
training_args.learning_rate = 3e-6
training_args.per_device_train_batch_size = 8
training_args.gradient_accumulation_steps = 4
training_args.total_episodes = 1000000 
training_args.local_rollout_forward_batch_size  = 8
training_args.missing_eos_penalty = 1.0
training_args.stop_token = 'eos'

In [4]:
from easyalign.utils.load_model_or_tokenizer import load_model, load_tokenizer
from easyalign.utils.build_datasets import build_sft_dataset
from easyalign.utils.utils import setup_logging, init_wandb_training

################
# Model & Tokenizer
################
tokenizer = load_tokenizer(model_args.model_name_or_path, padding_side='left')


In [5]:
value_model = AutoModelForSequenceClassification.from_pretrained(
    training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
)
reward_model = AutoModelForSequenceClassification.from_pretrained(
    training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
)
policy = AutoModelForCausalLM.from_pretrained(
    training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code
)

In [6]:
peft_config = get_peft_config(model_args)
if peft_config is None:
    ref_policy = load_model(tokenizer , model_args, training_args,  AutoModelForCausalLM)
else:
    ref_policy = None


In [7]:
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name)
train_dataset = dataset[script_args.dataset_train_split]
eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None

def prepare_dataset(dataset, tokenizer):
    """pre-tokenize the dataset before training; only collate during training"""

    def tokenize(element):
        input_ids = tokenizer.apply_chat_template(
            element["messages"][:1],
            padding=False,
            add_generation_prompt=True,
        )
        return {"input_ids": input_ids, "lengths": len(input_ids)}

    return dataset.map(
        tokenize,
        remove_columns=dataset.column_names,
        num_proc=training_args.dataset_num_proc,
    )

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
    train_dataset = prepare_dataset(train_dataset, tokenizer)
    if eval_dataset is not None:
        eval_dataset = prepare_dataset(eval_dataset, tokenizer)
    # filtering
    train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=training_args.dataset_num_proc)
    if eval_dataset is not None:
        eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=training_args.dataset_num_proc)

In [8]:
assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token"

In [9]:
################
# Training
################
trainer = PPOTrainer(
    args=training_args,
    processing_class=tokenizer,
    model=policy,
    ref_model=ref_policy,
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)

In [10]:
args = trainer.args
accelerator = trainer.accelerator
optimizer = trainer.optimizer
model = trainer.model
ref_policy = trainer.ref_model
reward_model = trainer.reward_model
processing_class = trainer.processing_class
dataloader = trainer.dataloader
device = accelerator.device

In [11]:
from transformers import GenerationConfig

def repeat_generator():
    while True:
        yield from dataloader

iter_dataloader = iter(repeat_generator())
generation_config = GenerationConfig(
    max_new_tokens=args.response_length,
    temperature=(args.temperature + 1e-7),
    top_k=0.0,
    top_p=1.0,
    do_sample=True,
)


In [12]:
for update in range(1, args.num_total_batches + 1):
    trainer.state.episode += 1 * args.batch_size
    data = next(iter_dataloader)

    break

In [19]:
from trl.models.utils import unwrap_model_for_generation
from trl.trainer.utils import batch_generation

with torch.no_grad():
    queries = data["input_ids"].to(device)
    context_length = queries.shape[1]
    responses = []
    postprocessed_responses = []
    logprobs = []
    ref_logprobs = []
    scores = []
    sequence_lengths = []
    values = []
    with unwrap_model_for_generation(
        trainer.model, trainer.accelerator, gather_deepspeed3_params=trainer.args.ds3_gather_for_generation
    ) as unwrapped_model:
        query_responses, logitss = batch_generation(
            unwrapped_model.policy,
            queries,
            args.local_rollout_forward_batch_size,
            processing_class.pad_token_id,
            generation_config,
        )

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


In [20]:
print(queries.shape, query_responses.shape, logitss.shape)

torch.Size([32, 508]) torch.Size([32, 561]) torch.Size([32, 53, 50304])


In [21]:
from trl.trainer.utils import selective_log_softmax, forward, truncate_response, first_true_indices, get_reward
trainer.stop_token_id = None

with torch.no_grad():
    for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
        query = queries[i : i + args.local_rollout_forward_batch_size]
        query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
        response = query_response[:, context_length:]
        logits = logitss[i : i + args.local_rollout_forward_batch_size]
        logprob = selective_log_softmax(logits, response)
        del logits
        torch.cuda.empty_cache()

        if ref_policy is None:
            with trainer.null_ref_context():
                ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
        else:
            ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
        ref_logits = ref_output.logits[:, context_length - 1 : -1]
        ref_logits /= args.temperature + 1e-7
        ref_logprob = selective_log_softmax(ref_logits, response)
        del ref_output, ref_logits
        torch.cuda.empty_cache()

        # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
        postprocessed_response = response
        if trainer.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
            postprocessed_response = truncate_response(
                trainer.stop_token_id, processing_class.pad_token_id, response
            )

        # Response Processing 2. run reward model on the truncated responses
        postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
        sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
        unwrapped_value_model = accelerator.unwrap_model(model).value_model
        full_value, _, _ = get_reward(
            unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
        )
        value = full_value[:, context_length - 1 : -1].squeeze(-1)
        _, score, _ = get_reward(
            reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
        )

        responses.append(response)
        postprocessed_responses.append(postprocessed_response)
        logprobs.append(logprob)
        ref_logprobs.append(ref_logprob)
        sequence_lengths.append(sequence_length)
        scores.append(score)
        values.append(value)

In [22]:
print(response.shape, logprob.shape, ref_logprob.shape)

torch.Size([8, 53]) torch.Size([8, 53]) torch.Size([8, 53])


In [23]:
responses = torch.cat(responses, 0)
postprocessed_responses = torch.cat(postprocessed_responses, 0)
logprobs = torch.cat(logprobs, 0)
ref_logprobs = torch.cat(ref_logprobs, 0)
sequence_lengths = torch.cat(sequence_lengths, 0)
scores = torch.cat(scores, 0)
values = torch.cat(values, 0)
del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
torch.cuda.empty_cache()

In [24]:
print(scores.shape, values.shape)

torch.Size([32]) torch.Size([32, 53])


In [27]:
INVALID_LOGPROB = 1.0
# Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
# Completions not passing that filter will receive a lower score.
contain_eos_token = torch.any(postprocessed_responses == trainer.processing_class.eos_token_id, dim=-1)
if trainer.args.missing_eos_penalty is not None:
    scores[~contain_eos_token] -= trainer.args.missing_eos_penalty
# accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")

# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
sequence_lengths_p1 = sequence_lengths + 1
padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
values = torch.masked_fill(values, padding_mask_p1, 0)

In [32]:
print(logprobs.shape, ref_logprobs.shape , values.shape)

torch.Size([32, 53]) torch.Size([32, 53]) torch.Size([32, 53])


In [34]:
from trl.core import masked_whiten

# 4. compute rewards
kl = logprobs - ref_logprobs
non_score_reward = -args.kl_coef * kl
rewards = non_score_reward.clone()
actual_start = torch.arange(rewards.size(0), device=rewards.device)
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
rewards[[actual_start, actual_end]] += scores

# 5. whiten rewards
if args.whiten_rewards:
    rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
    rewards = torch.masked_fill(rewards, padding_mask_p1, 0)

# 6. compute advantages and returns
lastgaelam = 0
advantages_reversed = []
gen_length = responses.shape[1]
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
    lastgaelam = delta + args.gamma * args.lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], axis=1)
returns = advantages + values
advantages = masked_whiten(advantages, ~padding_mask)
advantages = torch.masked_fill(advantages, padding_mask, 0)
torch.cuda.empty_cache()

In [1]:
from transformers.trainer import get_scheduler

  from .autonotebook import tqdm as notebook_tqdm


In [38]:
mean_kl = kl.sum(1).mean()

In [39]:
mean_kl

tensor(-1.9521e-05, device='cuda:0')

In [2]:
import subprocess
class MissingCUDAException(Exception):
    pass
class CUDAMismatchException(Exception):
    pass
def installed_cuda_version(name=""):
    import torch.utils.cpp_extension
    cuda_home = torch.utils.cpp_extension.CUDA_HOME
    if cuda_home is None:
        raise MissingCUDAException("CUDA_HOME does not exist, unable to compile CUDA op(s)")
    # Ensure there is not a cuda version mismatch between torch and nvcc compiler
    output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True)
    output_split = output.split()
    release_idx = output_split.index("release")
    release = output_split[release_idx + 1].replace(',', '').split(".")
    # Ignore patch versions, only look at major + minor
    cuda_major, cuda_minor = release[:2]
    return int(cuda_major), int(cuda_minor)
print(installed_cuda_version(name=""))

(11, 8)
