In [6]:
import jsonlines

def normalize_text(s):
    """Normalize an answer string for token-level comparison.

    Applied in this order: lower-case the text, delete every character in
    ``string.punctuation``, replace the standalone words "a", "an" and "the"
    with a space, then collapse whitespace runs to single spaces and trim
    both ends.
    """
    import re
    import string

    article_pattern = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    punctuation_table = str.maketrans("", "", string.punctuation)

    lowered = s.lower()
    without_punct = lowered.translate(punctuation_table)
    without_articles = article_pattern.sub(" ", without_punct)
    return " ".join(without_articles.split())


def compute_f1(prediction, truth):
    """Token-level F1 between a predicted and a reference answer (SQuAD-style).

    Both strings go through ``normalize_text`` and are split on whitespace.
    Overlap is counted with multiplicity (multiset intersection). A token
    repeated in both strings is therefore credited as many times as it
    co-occurs, and two identical answers always score 1.0.

    Args:
        prediction: Model-generated answer text.
        truth: Reference answer text.

    Returns:
        F1 as a float in [0.0, 1.0]. If either side normalizes to no tokens
        ("no answer"), returns 1.0 when both are empty and 0.0 otherwise.
    """
    from collections import Counter

    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # Either side is "no answer": full credit only when both agree.
    if not pred_tokens or not truth_tokens:
        return float(pred_tokens == truth_tokens)

    # Counter & Counter keeps the minimum count of each shared token; a plain
    # set intersection would undercount repeated tokens.
    num_same = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if num_same == 0:
        return 0.0

    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


def format_reward_func(completions, target, *, log_sample_prob=0.1, **kwargs):
    """Reward completions that follow the decomposition output format.

    After stripping surrounding whitespace and lower-casing, a completion
    earns 1.0 in either of two cases:
      * it is exactly ``<decomposition>false</decomposition>``, or
      * it starts with ``<decomposition>true</decomposition>`` and the rest
        (stripped) starts with ``<sub question>`` and ends with
        ``</sub question>``.
    Every other completion, including non-string ones, earns 0.0.

    Args:
        completions: Generated completion strings.
        target: Reference answers. They are only zipped for alignment and
            not scored here.
        log_sample_prob: Probability (per completion) of appending it to
            ``logs/completion_samples/completion_samples.txt`` for inspection.
        **kwargs: Extra columns passed by the caller; ignored.

    Returns:
        A list with exactly one float reward per (completion, target) pair.
    """
    # Imported locally: the notebook never imports ``os``, and ``random`` only
    # appears in a later cell, so relying on globals raised NameError.
    import os
    import random

    no_decomp_tag = "<decomposition>false</decomposition>"
    decomp_prefix = "<decomposition>true</decomposition>"

    rewards = []
    for completion, _gt in zip(completions, target):
        if not isinstance(completion, str):
            # Unscorable completion: zero reward, keeping the list aligned.
            rewards.append(0.0)
            continue

        # Sample logging is best-effort; a filesystem error must not change
        # the reward.
        if random.random() < log_sample_prob:
            try:
                os.makedirs("logs/completion_samples", exist_ok=True)
                log_file = os.path.join("logs/completion_samples", "completion_samples.txt")
                with open(log_file, "a") as f:
                    f.write("\n\n==============\n")
                    f.write(completion)
            except OSError:
                pass

        text = completion.strip().lower()
        reward = 0.0
        if text == no_decomp_tag:
            reward = 1.0
        elif text.startswith(decomp_prefix):
            # Strip only the leading tag, not every occurrence of it.
            sub_completion = text[len(decomp_prefix):].strip()
            if sub_completion.startswith("<sub question>") and sub_completion.endswith("</sub question>"):
                reward = 1.0
        # Always append exactly one reward so the list stays aligned.
        rewards.append(reward)
    return rewards

def equation_reward_func(completions, target, nums=None, *, log_sample_prob=0.1, success_threshold=0.8, **kwargs):
    """Reward each completion by its token-level F1 against the reference answer.

    Despite the name, no equation is evaluated. The reward is
    ``compute_f1(completion, gt)``, or 0.0 if scoring raises (e.g. the
    completion is not a string).

    Args:
        completions: Generated completion strings.
        target: Reference answers, aligned with ``completions``.
        nums: Unused. Kept (now optional) so existing callers that pass it
            still work.
        log_sample_prob: Probability of appending a completion whose F1
            exceeds ``success_threshold`` to
            ``logs/completion_samples/success_completion_samples.txt``.
        success_threshold: F1 above which a completion counts as successful
            for logging purposes.
        **kwargs: Extra columns passed by the caller; ignored.

    Returns:
        A list with exactly one float reward per (completion, target) pair.
    """
    # Imported locally: the notebook never imports ``os``, and ``random`` only
    # appears in a later cell.
    import os
    import random

    rewards = []
    for completion, gt in zip(completions, target):
        try:
            f1 = float(compute_f1(completion, gt))
        except Exception:
            f1 = 0.0
        # Append exactly once per completion; logging below can no longer
        # add a second entry.
        rewards.append(f1)

        if f1 > success_threshold and random.random() < log_sample_prob:
            # Best-effort logging; failures must not affect rewards.
            try:
                os.makedirs("logs/completion_samples", exist_ok=True)
                log_file = os.path.join("logs/completion_samples", "success_completion_samples.txt")
                with open(log_file, "a") as f:
                    f.write("\n\n==============\n")
                    f.write(completion)
            except OSError:
                pass
    return rewards


def generate_r1_prompt(obj):
    """Build a chat-formatted prompt and its target from one data record.

    ``obj["chosen"]`` is read as a list of message dicts. The first message's
    content, with the literal "Please give your answer: " removed, becomes
    the user turn. The second message's content is returned as ``target``.

    The conversation (system, user, and an assistant turn pre-filled with
    "Answer: ") is rendered with the notebook-level ``tokenizer`` using
    ``continue_final_message=True``.

    Returns:
        dict with keys "prompt" (the rendered string) and "target".
    """
    chosen_turns = obj["chosen"]
    user_text = chosen_turns[0]['content'].replace("Please give your answer: ", "")
    answer_text = chosen_turns[1]['content']

    system_turn = {
        "role": "system",
        "content": "You are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer.",
    }
    user_turn = {"role": "user", "content": user_text}
    # Pre-filled assistant turn that the rendered prompt ends on.
    prefill_turn = {"role": "assistant", "content": "Answer: "}

    rendered = tokenizer.apply_chat_template(
        [system_turn, user_turn, prefill_turn],
        tokenize=False,
        continue_final_message=True,
    )
    return {"prompt": rendered, "target": answer_text}

# Peek at the first record's "chosen" response to check the target format.
# Uses the same relative path as the later cell that reads this file, instead
# of a hard-coded absolute home-directory path.
with jsonlines.open("data/cpo_data.json") as f:
    first_record = next(iter(f), None)
    if first_record is None:
        print("data/cpo_data.json is empty")
    else:
        print(first_record["chosen"][1]['content'])

<decomposition>true</decomposition> 
    <sub question>grant green featuring performances.</sub question>


In [1]:
# Load the causal LM and its matching tokenizer.
# Another checkpoint tried previously: "Qwen/Qwen2.5-3B-Instruct".
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "microsoft/Phi-4"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# dtype and device placement are both left as "auto" for transformers to pick.
llm_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)

print(llm_model)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 6/6 [00:09<00:00,  1.56s/it]


Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(100352, 5120, padding_idx=100257)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-39): 40 x Phi3DecoderLayer(
        (self_attn): Phi3SdpaAttention(
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (qkv_proj): Linear(in_features=5120, out_features=7680, bias=False)
          (rotary_emb): Phi3RotaryEmbedding()
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=5120, out_features=35840, bias=False)
          (down_proj): Linear(in_features=17920, out_features=5120, bias=False)
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm((5120,), eps=1e-05)
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
        (post_attention_layernorm): Phi3RMSNorm((5120,), eps=1e-05)
      )
    )
    (norm): Phi3RMSNorm((5120,), eps=1e-05

In [4]:

def get_response(prompt, llm_model, tokenizer, *, max_new_tokens=256,
                 system_prompt="You are a helpful AI assistant."):
    """Generate one chat completion for ``prompt`` and return only the new text.

    The prompt is wrapped in a system + user conversation and rendered with
    the tokenizer's chat template (``add_generation_prompt=True``). It is then
    tokenized, moved to ``llm_model.device``, and passed to
    ``llm_model.generate``. The prompt tokens are sliced off each returned
    sequence, so only the continuation is decoded.

    Args:
        prompt: User message text.
        llm_model: Model exposing ``.device`` and ``.generate``.
        tokenizer: Tokenizer matching ``llm_model``.
        max_new_tokens: Generation budget (previously hard-coded to 256).
        system_prompt: System message content (previously hard-coded).

    Returns:
        The decoded completion, with special tokens skipped.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(llm_model.device)

    generated_ids = llm_model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens
    )
    # Each output sequence starts with the prompt tokens; keep only the
    # newly generated continuation.
    new_token_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]

import jsonlines
import random


# Longest rejected example, measured in characters (prompt + response),
# not tokens.
with jsonlines.open("data/cpo_data.json") as reader:
    max_rejected_chars = max(
        (len(obj["rejected"][0]["content"]) + len(obj["rejected"][1]["content"]) for obj in reader),
        default=0,  # an empty file yields 0 instead of raising ValueError
    )
print(max_rejected_chars)
# To spot-check generations: sample ~5% of records (random.random() <= 0.05)
# and compare get_response(obj["chosen"][0]["content"], llm_model, tokenizer)
# with obj["chosen"][1]["content"].

3307
