In [46]:
import os
from tqdm import tqdm
import torch
import numpy as np
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModelForSequenceClassification
from datasets import load_dataset
from trl import PPOTrainer, PPOConfig
from transformers import GenerationConfig

In [None]:
# ========= Fixed end-to-end PPO script =========

import os
from tqdm import tqdm
import torch
import numpy as np
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModelForSequenceClassification
from datasets import load_dataset
from trl import PPOTrainer, PPOConfig


try:
    from trl.models import AutoModelForCausalLMWithValueHead
except Exception:
    try:
        from trl import AutoModelForCausalLMWithValueHead
    except Exception:
        raise ImportError("AutoModelForCausalLMWithValueHead not found. Please upgrade trl (`pip install -U trl`).")

# ----------------- Config -----------------
MODEL_NAME = "distilgpt2"
DATA_JSON = "synthetic_data.jsonl"
MAX_LEN = 128
BATCH_SIZE = 32      
MINI_BATCH = 8
EPOCHS = 2
GEN_MAX_NEW_TOKENS = 48
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ---------- Tokenizer & dataset ----------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


raw_ds = load_dataset("json", data_files=DATA_JSON)["train"]


raw_ds = raw_ds.map(lambda ex: {"prompt": str(ex["prompt"])}, remove_columns=[c for c in raw_ds.column_names if c != "prompt"])


def tokenize_and_index(batch, idxs):
    toks = tokenizer(batch["prompt"], padding="max_length", truncation=True, max_length=MAX_LEN)
    return {"input_ids": toks["input_ids"], "attention_mask": toks["attention_mask"], "idx": idxs}


tokenized = raw_ds.map(lambda examples, idx: tokenize_and_index(examples, idx), with_indices=True, batched=True)


tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "idx"])

print("Columns after tokenization:", tokenized.column_names)
print("Sample tensorized item:", {k: tokenized[0][k] for k in ["input_ids","attention_mask","idx"]})
print("Original prompt:", raw_ds[0]["prompt"])


reward_model_path= "reward_model_checkpoints/checkpoint-150"


# ------------- Load models ---------------
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME)
value_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME)
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_path)

reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    reward_model.config.pad_token_id = tokenizer.pad_token


    
ppo_model.config.pad_token_id = tokenizer.pad_token_id
value_model.config.pad_token_id = tokenizer.pad_token_id
reward_model.config.pad_token_id = tokenizer.pad_token_id


if not hasattr(ppo_model, "is_gradient_checkpointing"):
    ppo_model.is_gradient_checkpointing = False
if not hasattr(ppo_model, "generation_config"):
    ppo_model.generation_config = AutoConfig.from_pretrained(MODEL_NAME)
if not hasattr(value_model, "is_gradient_checkpointing"):
    value_model.is_gradient_checkpointing = False

# Move models to device
ppo_model.to(DEVICE)
value_model.to(DEVICE)
reward_model.to(DEVICE)
ppo_model.eval()
value_model.eval()
reward_model.eval()

# -------- PPO config & trainer -----------
ppo_config = PPOConfig(
    learning_rate=1.41e-5,
    batch_size=BATCH_SIZE,
    mini_batch_size=MINI_BATCH,
    gradient_accumulation_steps=1,
    fp16=False,
    bf16=False,
    logging_dir="./ppo_logs"
)

ppo_trainer = PPOTrainer(
    args=ppo_config,
    processing_class=tokenizer,
    model=ppo_model,
    ref_model=ref_model,
    reward_model=reward_model,
    train_dataset=tokenized,
    value_model=value_model
)

print("PPO Trainer initialized")


def find_policy_model(obj):
   
    if hasattr(obj, "generate"):
        return obj
    
    for attr in ["model", "policy_model", "policy", "module"]:
        maybe = getattr(obj, attr, None)
        if maybe is None:
            continue
        if hasattr(maybe, "generate"):
            return maybe
        # nested search
        for subattr in ["model", "policy_model", "policy", "module", "transformer", "base_model"]:
            sub = getattr(maybe, subattr, None)
            if sub and hasattr(sub, "generate"):
                return sub
    raise RuntimeError("Could not find underlying policy model with .generate()")

def logits_to_rewards(logits):
    
    if isinstance(logits, torch.Tensor):
        arr = logits.detach().cpu().numpy()
    else:
        arr = np.array(logits)
    if arr.ndim == 0:
        return [float(arr)]
    if arr.ndim == 1:
        return arr.tolist()
    if arr.ndim == 2:
        if arr.shape[1] == 1:
            return arr[:,0].tolist()
        if arr.shape[1] == 2:
            return arr[:,1].tolist()
        return arr.mean(axis=1).tolist()
    return arr.reshape(arr.shape[0], -1).mean(axis=1).tolist()

policy_model = find_policy_model(ppo_trainer)
policy_model.to(DEVICE)
policy_model.eval()

# ---------- generation kwargs -------------
generation_kwargs = {
    "max_new_tokens": GEN_MAX_NEW_TOKENS,
    "do_sample": True,
    "top_k": 0,
    "top_p": 1.0,
    "pad_token_id": tokenizer.eos_token_id,
}

# ---------- PPO loop ----------------------
print("Starting PPO loop...")
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    for batch in tqdm(ppo_trainer.dataloader):
       
        device = DEVICE

        input_ids = batch["input_ids"].to(device)         
        attention_mask = batch["attention_mask"].to(device) 
        idxs = batch["idx"].tolist()                       

       
        with torch.no_grad():
            gen_out = policy_model.generate(input_ids, **generation_kwargs)

       
        gen_ids = gen_out.detach().cpu().tolist()
        responses = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

      
        prompts = [raw_ds[int(i)]["prompt"] for i in idxs]
        texts = [p + r for p, r in zip(prompts, responses)]

     
        reward_inputs = tokenizer(texts, return_tensors="pt", padding=True,padding_side="left", truncation=True, max_length=MAX_LEN)
        reward_inputs = {k: v.to(device) for k, v in reward_inputs.items()}

        with torch.no_grad():
            reward_out = reward_model(**reward_inputs)

        logits = getattr(reward_out, "logits", None)
        if logits is None and hasattr(reward_out, "scores"):
            logits = reward_out.scores
        if logits is None:
            raise RuntimeError("Reward model returned no logits/scores.")

        rewards = logits_to_rewards(logits)
       
        rewards_for_step = [torch.tensor(float(x), device=device) for x in rewards]
        
       
        prompt_tensors = input_ids 
        response_tensors = gen_out[:, input_ids.shape[1]:]  
        reward_list = rewards_for_step 
        

print("PPO run finished.")


Columns after tokenization: ['prompt', 'input_ids', 'attention_mask', 'idx']
Sample tensorized item: {'input_ids': tensor([ 5122,  6060,  9253,  1080,   318, 13456, 10059, 15536,    13,  1867,
          318,   262,  6808,  2728,    30, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256

  0%|          | 0/125 [00:00<?, ?it/s]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  1%|          | 1/125 [00:04<09:30,  4.60s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  2%|▏         | 2/125 [00:09<09:11,  4.48s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  2%|▏         | 3/125 [00:13<09:10,  4.51s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  3%|▎         | 4/125 [00:17<08:59,  4.46s/it]A decoder-only architecture is being used, but right-padding was detected! For co

Epoch 2/2


  0%|          | 0/125 [00:00<?, ?it/s]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  1%|          | 1/125 [00:06<13:09,  6.36s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  2%|▏         | 2/125 [00:12<13:18,  6.50s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  2%|▏         | 3/125 [00:19<13:15,  6.52s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
  3%|▎         | 4/125 [00:26<13:19,  6.61s/it]A decoder-only architecture is being used, but right-padding was detected! For co

AttributeError: 'PolicyAndValueWrapper' object has no attribute 'save_pretrained'

In [None]:
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification

# ---- Test prompts ----
test_prompts = [
    {"prompt": "Provide a Root Cause Analysis for the recent API latency spike."},
    {"prompt": "Provide a Root Cause Analysis for the recent database connection failure."},
    {"prompt": "Provide a Root Cause Analysis for the recent user login outage."},
]


r_model_path = "reward_model_checkpoints/checkpoint-150"
r_model = AutoModelForSequenceClassification.from_pretrained(r_model_path)
r_tokenizer = AutoTokenizer.from_pretrained(r_model_path)


base_model_path = "distilgpt2"  # Original LM checkpoint before RLHF
orig_model = AutoModelForCausalLM.from_pretrained(base_model_path)
orig_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
orig_model.config.pad_token_id = orig_tokenizer.pad_token_id


# p_model_path = "ppo_trained_rca_model"  
# ppomodel = AutoModelForCausalLM.from_pretrained(p_model_path)
# p_tokenizer = AutoTokenizer.from_pretrained(p_model_path)  





original_results = []
for item in test_prompts:
    input_ids = orig_tokenizer(item["prompt"], return_tensors="pt").input_ids
    generation_output = orig_model.generate(input_ids=input_ids, **generation_kwargs)
    response_text = orig_tokenizer.decode(generation_output[0], skip_special_tokens=True)

   
    score_text = item["prompt"] + response_text
    inputs = r_tokenizer(score_text, return_tensors="pt").to(r_model.device)
    output = r_model(**inputs)
    score = output.logits[0].item()

    original_results.append({"prompt": item["prompt"], "response": response_text, "score": score})


ppo_results = []
for item in test_prompts:
    input_ids = tokenizer(item["prompt"], return_tensors="pt").input_ids
    generation_output = ppo_model.generate(input_ids=input_ids, **generation_kwargs)
    response_text = tokenizer.decode(generation_output[0], skip_special_tokens=True)

  
    score_text = item["prompt"] + response_text
    inputs = r_tokenizer(score_text, return_tensors="pt").to(r_model.device)
    output = r_model(**inputs)
    score = output.logits[0].item()

    ppo_results.append({"prompt": item["prompt"], "response": response_text, "score": score})


df_original = pd.DataFrame(original_results)
df_ppo = pd.DataFrame(ppo_results)

df_comparison = pd.DataFrame({
    'Prompt': df_original['prompt'],
    'Response (Before RLHF)': df_original['response'],
    'Score (Before)': df_original['score'].round(2),
    'Response (After RLHF)': df_ppo['response'],
    'Score (After)': df_ppo['score'].round(2)
})

pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)

display(df_comparison)


avg_score_original = df_original['score'].mean()
avg_score_ppo = df_ppo['score'].mean()

print(f"\nAverage Reward Score (Before RLHF): {avg_score_original:.2f}")
print(f"Average Reward Score (After RLHF):  {avg_score_ppo:.2f}")


Unnamed: 0,Prompt,Response (Before RLHF),Score (Before),Response (After RLHF),Score (After)
0,Provide a Root Cause Analysis for the recent API latency spike.,"Provide a Root Cause Analysis for the recent API latency spike. * Normalize clusters by ingesting the whole dataset. * Gradually update data asynchronously. This means constant release times on each cluster, or a small and perhaps constant autogenous cluster. * Drop clusters that introduce too many org security",0.25,Provide a Root Cause Analysis for the recent API latency spike. This report reveals that many possible changes to the NVMe bus would be made even more important if the bus’s current latencies were even higher.\n\n\n\nWe found that up to 6.4GB was the 16GB,0.08
1,Provide a Root Cause Analysis for the recent database connection failure.,"Provide a Root Cause Analysis for the recent database connection failure. The officer in charge of investigating the traffic accident has confirmed it has.\n\n\n\n\nThe error occurred in Lexington. The Glass Home, south of Boston, which is slightly bigger than some commercial production, has been shut down for weeks",-0.19,"Provide a Root Cause Analysis for the recent database connection failure. This database connection failure will result in validation errors of this xmanifest, logging a esp1.2 release based on the old apparently last xmanifest IDAs.4.1 issue string. Importers of the gpp LD",1.49
2,Provide a Root Cause Analysis for the recent user login outage.,Provide a Root Cause Analysis for the recent user login outage. You might voluntarily turn off the service to your vehicle so you can access the right packages to make sure your vehicle never gets to your home.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n,-0.87,"Provide a Root Cause Analysis for the recent user login outage. The maximum number of problems has been 14,000 users. The root cause has been identified by users gathered by the experts.\n\n\n\nThe root cause has been identified by users gathered by the experts. If you are affected by both",-0.55



Average Reward Score (Before RLHF): -0.27
Average Reward Score (After RLHF):  0.34
