## Download the base model and test inference

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Instruction-tuned base model used for the baseline and as the starting
# point for PPO fine-tuning.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
# device_map="auto" lets transformers/accelerate place the weights on the
# available device(s).
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
# device_map is a model-loading option; the tokenizer has no device placement
# (its outputs are moved with .to(model.device) where needed).
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Sanity-check the base model with a single chat-formatted prompt.
prompt = "What is real-time machine learning?"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
# Render the conversation as text with the model's chat template;
# add_generation_prompt=True asks the template to end with the assistant-turn
# prefix so the model continues as the assistant.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512
)
# generate() returns prompt + completion for each row; drop the prompt tokens
# so only the newly generated text is decoded.
generated_ids = [
    output_ids[len(input_ids):]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)

## Load the training and evaluation dataset

In [None]:
from datasets import load_dataset

# Wikipedia summaries dataset; later cells use its "train" and "eval" splits
# and its "summary" column.
wiki_data = load_dataset("pdeziel/wikipedia_summaries")
# Bare expression so the notebook displays the DatasetDict overview.
wiki_data

## Tokenize training and evaluation samples with the model's chat template and tokenizer

In [None]:
def tokenize(sample):
    """Build a chat-formatted summarization prompt for one dataset row.

    Args:
        sample: A single (non-batched) row with a ``"summary"`` text field.

    Returns:
        The same row with two new fields: ``"templated_text"``, the prompt
        rendered through the chat template, and ``"input_ids"``, its token ids
        as a tensor with a leading batch dimension of 1.
    """
    # Truncate (by characters, not tokens) for faster training.
    summary = sample["summary"][:512]
    user_prompt = (
        "Below is a description of a topic. Summarize the description."
        "\n\n" + summary
    )
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt},
    ]
    sample["templated_text"] = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Keep the ids on the CPU: Dataset.map serializes whatever is returned into
    # Arrow (plain lists), so a per-row GPU transfer here is wasted work.
    # Device placement is done afterwards via set_format(device=...).
    sample["input_ids"] = tokenizer.encode(sample["templated_text"], return_tensors="pt")
    return sample

# Tokenize each split row by row and drop the raw columns, leaving only the
# fields added by tokenize().
train_dataset = wiki_data["train"]
train_dataset = train_dataset.map(tokenize, batched=False, remove_columns=train_dataset.column_names)
# Serve "input_ids" as torch tensors already placed on the model's device.
train_dataset.set_format(type="torch", columns=["input_ids"], device=model.device)

eval_dataset = wiki_data["eval"]
eval_dataset = eval_dataset.map(tokenize, batched=False, remove_columns=eval_dataset.column_names)
eval_dataset.set_format(type="torch", columns=["input_ids"], device=model.device)

In [None]:
import pandas as pd
from transformers import set_seed

def generate_response(sample):
    """Sample a short completion from the base model for one tokenized row.

    The RNG is reseeded to 42 on every call, so each row is generated
    reproducibly regardless of the order rows are processed in.

    Args:
        sample: A row whose ``"input_ids"`` holds the prompt ids with a leading
            batch dimension; only the first sequence is used.

    Returns:
        The row with a ``"response"`` field containing the decoded completion
        (prompt tokens excluded, special tokens skipped).
    """
    set_seed(42)
    sampling_options = {
        "min_length": -1,
        "top_p": 1.0,
        "top_k": 0.0,
        "do_sample": True,
        "max_new_tokens": 32,
        "pad_token_id": tokenizer.eos_token_id,
    }
    prompt_ids = sample["input_ids"]
    prompt_length = len(prompt_ids[0])
    first_sequence = model.generate(prompt_ids.to(model.device), **sampling_options)[0]
    completion_ids = first_sequence[prompt_length:]
    sample["response"] = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return sample

# Generate a baseline completion for every eval row (adds a "response" column).
eval_dataset = eval_dataset.map(generate_response, batched=False)
# Show full cell contents so long prompts/responses are not truncated.
pd.set_option("display.max_colwidth", None)
eval_dataset.to_pandas().head()

## Evaluating readability

In [None]:
import textstat
from matplotlib import pyplot as plt

# Baseline readability of the base-model responses generated above. Score the
# "response" column directly: the "ref_model_reading_ease" column is only
# created by the evaluation section at the end of the notebook, so it does not
# exist yet at this point.
df = eval_dataset.to_pandas()
df["reading_ease"] = df["response"].apply(textstat.flesch_reading_ease)
plt.hist(df["reading_ease"], bins=20)
plt.title("Flesch Reading Ease Distribution")
plt.xlabel("Reading Ease Score")
plt.ylabel("Frequency")
plt.savefig("reading_ease_distribution.svg", format="svg")
plt.savefig("reading_ease_distribution.png", format="png")
print("Average Reading Ease: {:.2f}".format(df["reading_ease"].mean()))
print("Median Reading Ease: {:.2f}".format(df["reading_ease"].median()))
print("Standard Deviation of Reading Ease: {:.2f}".format(df["reading_ease"].std()))

In [None]:
import textstat
from matplotlib import pyplot as plt

def get_reading_ease(sample):
    """Attach the Flesch reading-ease score of ``sample["response"]``.

    Returns the same row with a new ``"reading_ease"`` field.
    """
    score = textstat.flesch_reading_ease(sample["response"])
    sample["reading_ease"] = score
    return sample

# Score every baseline response and plot the distribution of scores.
eval_dataset = eval_dataset.map(get_reading_ease, batched=False)
df = eval_dataset.to_pandas()
plt.hist(df["reading_ease"], bins=20)
plt.title("Flesch Reading Ease Distribution")
plt.xlabel("Reading Ease Score")
plt.ylabel("Frequency")
plt.savefig("reading_ease_distribution.svg", format="svg")
plt.savefig("reading_ease_distribution.png", format="png")
# Summary statistics of the baseline readability.
print("Average Reading Ease: {:.2f}".format(df["reading_ease"].mean()))
print("Median Reading Ease: {:.2f}".format(df["reading_ease"].median()))
print("Standard Deviation of Reading Ease: {:.2f}".format(df["reading_ease"].std()))

## Define a reward function

In [None]:
import math
import numpy as np
from matplotlib import pyplot as plt

def get_reward(reading_ease, target=80, stddev=10):
    """Gaussian reward centred on a target Flesch reading-ease score.

    Args:
        reading_ease: Reading-ease score to evaluate.
        target: Score that earns the maximum reward of 1.0.
        stddev: Spread of the Gaussian; larger values tolerate scores farther
            from ``target``.

    Returns:
        ``exp(-(reading_ease - target)**2 / (2 * stddev**2))``.
    """
    squared_gap = (reading_ease - target) ** 2
    twice_variance = 2 * (stddev ** 2)
    return math.exp(-squared_gap / twice_variance)

# Plot the reward curve for three Gaussian widths to show how strictly each
# penalises reading-ease scores away from the target of 80.
x = np.linspace(20, 120, 100)
rewards_a = [get_reward(i, stddev=5) for i in x]
rewards_b = [get_reward(i, stddev=10) for i in x]
rewards_c = [get_reward(i, stddev=15) for i in x]
plt.plot(x, rewards_a, label="Reward A (stddev=5)")
plt.plot(x, rewards_b, label="Reward B (stddev=10)", linestyle="dashed")
plt.plot(x, rewards_c, label="Reward C (stddev=15)", linestyle="dotted")
plt.title("Readability Reward Functions")
plt.xlabel("Flesch-Kincaid Reading Ease")
plt.ylabel("Reward")
# Vertical marker at the target score, where every curve peaks at 1.0.
plt.axvline(80, label="Target=80", linestyle="dashed")
plt.legend()
plt.savefig("readability_reward.svg", format="svg")
plt.savefig("readability_reward.png", format="png")

In [None]:
# Reward each baseline response with the numeric Gaussian get_reward
# (defaults target=80, stddev=10).
# NOTE(review): the training cell later redefines get_reward to take response
# text instead of a score; re-running this cell after that one will not work
# as intended.
df['reward'] = df['reading_ease'].apply(get_reward)
# Preview rows whose reading ease is above 40.
df[df['reading_ease'] > 40].head()

## Configure the model

In [None]:
from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, create_reference_model

# LoRA adapter settings: rank 8, lora_alpha 16, dropout 0.1, biases left
# untrained (bias="none"), configured for causal language modelling.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

# Reload the pretrained Qwen base model so PPO starts from the original weights
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Wrap the base model with LoRA adapters.
# NOTE(review): get_peft_model injects the adapter layers into `model` itself,
# so after training `model` carries the fine-tuned LoRA weights and is no
# longer an untouched baseline.
peft_model = get_peft_model(model, lora_config)
# Add a value head on top of the LoRA model for PPO.
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(peft_model, device_map="auto")
# Copy of the policy, passed to PPOTrainer as the KL reference model.
ref_model = create_reference_model(ppo_model)
# Reports LoRA trainable parameters; presumably excludes the value head that
# TRL adds on ppo_model (peft_model does not contain it).
peft_model.print_trainable_parameters()

## Train the model

In [None]:
# Setup tensorboard
# IPython magics: install TensorBoard, load its notebook extension, and serve
# the ./logs directory (the logging_dir given to PPOConfig below).
%pip install tensorboard
%load_ext tensorboard
%tensorboard --logdir logs

In [None]:
# math is imported here as well so this cell does not depend on the earlier
# reward-function cell having been run in the same kernel session.
import math
import torch
import textstat
from tqdm import tqdm
from transformers import set_seed
from trl import PPOTrainer, PPOConfig

def get_reward(response, target=80, stddev=10):
    """Score generated text by how close its readability is to ``target``.

    This shadows the numeric get_reward defined earlier: it takes the response
    text, computes its Flesch reading ease, and applies the same Gaussian.

    Args:
        response: Decoded model output to score.
        target: Desired Flesch reading-ease score.
        stddev: Width of the Gaussian; larger values reward a wider band.

    Returns:
        A float in [0, 1]; 1.0 when the reading ease equals ``target``.
    """
    ease = textstat.flesch_reading_ease(response)
    return math.exp(-((target - ease) ** 2) / (2 * (stddev ** 2)))

# PPO hyperparameters: one prompt per step (batch_size=1) and one sample per
# optimisation mini-batch; stats are logged to TensorBoard under ./logs.
# NOTE(review): model_name / log_with / project_kwargs belong to the older
# (pre-0.12) TRL PPOConfig API — confirm against the installed TRL version.
ppo_config = PPOConfig(
    model_name=model_name,
    learning_rate=1.41e-5,
    batch_size=1,
    mini_batch_size=1,
    log_with="tensorboard",
    project_kwargs={"logging_dir": "logs"},
)

ppo_trainer = PPOTrainer(config=ppo_config,
                         model=ppo_model,
                         ref_model=ref_model,
                         tokenizer=tokenizer,
                         dataset=train_dataset)
# Force the trainer onto the device chosen for the model by device_map="auto".
ppo_trainer.current_device = model.device

# Rollout sampling settings: top_p=1.0 and top_k=0.0 leave the full
# distribution unfiltered, up to 32 new tokens, EOS id used as padding.
generate_kwargs = {
    "min_length": -1,
    "top_p": 1.0,
    "top_k": 0.0,
    "do_sample": True,
    "max_new_tokens": 32,
    "pad_token_id": tokenizer.eos_token_id,
}

set_seed(42)
epochs = 1
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    for batch in tqdm(ppo_trainer.dataloader):
        # Stored input_ids carry a leading batch dim of 1 (return_tensors="pt"
        # in tokenize); squeeze each to a 1-D query tensor.
        query_tensors = [q.squeeze() for q in batch["input_ids"]]

        # Get the responses from the model
        response_tensors = []
        for query in query_tensors:
            # The generated sequence starts with the prompt; keep only the completion.
            query_response = ppo_trainer.generate(query, **generate_kwargs).squeeze()[len(query):]
            response_tensors.append(query_response)
        batch["response"] = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]

        # Compute rewards
        scores = [get_reward(response, target=80, stddev=10) for response in batch["response"]]
        # PPOTrainer.step takes one scalar tensor per sample.
        rewards = [torch.tensor(float(score)) for score in scores]
        # Log the first sample of the batch for inspection.
        print("query:", tokenizer.decode(query_tensors[0], skip_special_tokens=True))
        print("response:", batch["response"][0])
        print("reward:", rewards[0].item())
        print("-" * 50)

        # Run PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)
# Persist the trained policy to ./my_ppo_model.
ppo_trainer.save_pretrained("my_ppo_model")

## Plot training progress

Note: This data was exported from TensorBoard as CSV files and is included here for reference.

In [None]:
import pandas as pd
from matplotlib import pyplot as plt

# Mean reward per PPO step, exported from TensorBoard as CSV
# (expects "Step" and "Value" columns).
df = pd.read_csv('results/reward_mean.csv')
# 10-step rolling mean to smooth the noisy per-step values; the first 9 rows are NaN.
df['smoothed'] = df['Value'].rolling(window=10).mean()
plt.title("PPO Training - Mean Reward")
plt.xlabel("Step")
plt.ylabel("Mean Reward")
plt.plot(df['Step'], df['Value'], label='Mean Reward')
plt.plot(df['Step'], df['smoothed'], label='Mean Reward (rolling=10)', linestyle='dashed')
plt.legend()
plt.savefig('ppo_training_mean_reward.svg', format='svg')
plt.savefig('ppo_training_mean_reward.png', format='png')

In [None]:
import pandas as pd
from matplotlib import pyplot as plt

# KL divergence between policy and reference per PPO step, exported from
# TensorBoard as CSV (expects "Step" and "Value" columns).
df = pd.read_csv('results/kl.csv')
# 10-step rolling mean to smooth the noisy per-step values; the first 9 rows are NaN.
df['smoothed'] = df['Value'].rolling(window=10).mean()
plt.title("PPO Training - KL Divergence")
plt.xlabel("Step")
plt.ylabel("KL Divergence")
# Zero line: the policy has not drifted from the reference model.
plt.axhline(y=0, color='black', linestyle='--')
plt.plot(df['Step'], df['Value'], label='KL Divergence')
plt.plot(df['Step'], df['smoothed'], label='KL Divergence (rolling=10)', linestyle='dashed')
plt.legend()
plt.savefig("ppo_training_kl_divergence.svg", format="svg")
plt.savefig("ppo_training_kl_divergence.png", format="png")

## Evaluate the trained model against the baseline

In [None]:
# When loading a previously trained model, you can use code similar to this

from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

model_name = "Qwen/Qwen2.5-0.5B-Instruct"
# Must match the directory passed to ppo_trainer.save_pretrained() during training.
ppo_model_dir = "my_ppo_model"
ppo_config = PPOConfig(
    model_name=model_name,
    learning_rate=1.41e-5,
    batch_size=1,
    mini_batch_size=1,
    log_with="tensorboard",
    project_kwargs={"logging_dir": "logs"},
)
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_model_dir, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load an untouched copy of the base model as the baseline. Reusing the global
# `model` is unsafe: get_peft_model() injected the LoRA layers into it in place,
# so after training it already carries the fine-tuned adapter weights.
ref_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
# PPOTrainer only accepts value-head wrappers (or None) as ref_model, so the
# plain base model is attached after construction.
ppo_trainer = PPOTrainer(config=ppo_config, model=ppo_model, ref_model=None, tokenizer=tokenizer)
ppo_trainer.ref_model = ref_model
ppo_trainer.current_device = ref_model.device

In [None]:
import textstat
from transformers import set_seed

def generate(model, input_ids, device="cpu"):
    """Sample up to 32 new tokens from ``model`` and decode only the completion.

    The RNG is reseeded to 42 on every call, so different models sampled on
    the same prompt start from the same seed.

    Args:
        model: Object exposing a Hugging Face-style ``generate`` method.
        input_ids: Prompt ids with a leading batch dimension; only the first
            sequence is decoded.
        device: Device the prompt is moved to before generation.

    Returns:
        The decoded completion text with special tokens skipped.
    """
    set_seed(42)
    options = {
        "min_length": -1,
        "top_p": 1.0,
        "top_k": 0.0,
        "do_sample": True,
        "max_new_tokens": 32,
        "pad_token_id": tokenizer.eos_token_id,
    }
    prompt_length = len(input_ids[0])
    sequences = model.generate(input_ids.to(device), **options)
    completion_ids = sequences[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)

def generate_responses(sample):
    """Generate and score completions from the reference and trained models.

    Adds ``"ref_model_response"`` / ``"ref_model_reading_ease"`` for the
    trainer's reference model and ``"model_response"`` / ``"model_reading_ease"``
    for its policy model. The reference model is always sampled first.
    """
    target_device = ppo_trainer.current_device
    candidates = (
        ("ref_model_response", "ref_model_reading_ease", ppo_trainer.ref_model),
        ("model_response", "model_reading_ease", ppo_trainer.model),
    )
    for response_key, ease_key, candidate in candidates:
        text = generate(candidate, sample["input_ids"], device=target_device)
        sample[response_key] = text
        sample[ease_key] = textstat.flesch_reading_ease(text)
    return sample

# Sample both models on every eval row and preview the scored results.
eval_dataset = eval_dataset.map(generate_responses, batched=False)
eval_dataset.to_pandas().head()

In [None]:
from matplotlib import pyplot as plt

# Overlay the reading-ease histograms of the trained and reference models,
# semi-transparent and hatched differently so both remain visible.
df = eval_dataset.to_pandas()
hatches = ["//", "\\\\"]
labels = ["Trained Model", "Reference Model"]
for i, col in enumerate(["model_reading_ease", "ref_model_reading_ease"]):
    values = df[col].to_list()
    plt.hist(values, label=labels[i], alpha=0.5, hatch=hatches[i])
plt.legend(loc="best")
plt.title("Reading Ease Distribution")
plt.xlabel("Reading Ease Score")
plt.ylabel("Frequency")
plt.savefig("reading_ease_comparison.svg", format="svg")
plt.savefig("reading_ease_comparison.png", format="png")
plt.show()
# Side-by-side summary statistics for both models.
print("Mean Trained Model Reading Ease: {:.2f}".format(df["model_reading_ease"].mean()))
print("Mean Reference Model Reading Ease: {:.2f}".format(df["ref_model_reading_ease"].mean()))
print("Median Trained Model Reading Ease: {:.2f}".format(df["model_reading_ease"].median()))
print("Median Reference Model Reading Ease: {:.2f}".format(df["ref_model_reading_ease"].median()))
print("Standard Deviation of Trained Model Reading Ease: {:.2f}".format(df["model_reading_ease"].std()))
print("Standard Deviation of Reference Model Reading Ease: {:.2f}".format(df["ref_model_reading_ease"].std()))