In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import gymnasium as gym
import random
import matplotlib.pyplot as plt
from collections import deque

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Define Environment and Collect Initial Data

In [None]:
# Initialize CartPole environment
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]  # 4 for CartPole
action_dim = env.action_space.n  # 2 for CartPole

# Define a simple policy network for initial data collection
class InitialPolicy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, action_dim)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
    
    def act(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        logits = self.forward(state)
        dist = Categorical(logits=logits)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

# Function to collect trajectories
def collect_trajectories(policy, env, num_trajectories=10, max_steps=500):
    trajectories = []
    
    for _ in range(num_trajectories):
        states, actions, rewards, dones, s_next, log_probs = [], [], [], [], [], []
        state, _ = env.reset()
        done = False
        step = 0
        
        while not done and step < max_steps:
            action, log_prob = policy.act(state)
            dones.append(done)

            next_state, reward, done, _, _ = env.step(action)
            
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            s_next.append(next_state)
            log_probs.append(log_prob)
            
            state = next_state
            step += 1
        
        # Store the trajectory
        trajectories.append({
            'states': np.array(states),
            'actions': np.array(actions),
            'rewards': np.array(rewards),
            "dones": np.array(dones),
            "next_states": np.array(s_next),
            "log_probs": np.array(log_probs),
            'length': len(states),
            'return': sum(rewards)
        })
    
    return trajectories

# Initialize and train a basic policy to collect diverse trajectories
initial_policy = InitialPolicy(state_dim, action_dim)
optimizer = optim.Adam(initial_policy.parameters(), lr=0.01)
gamma = 0.99
# Train the initial policy for a few episodes to get somewhat reasonable behavior
for episode in range(500):
    state, _ = env.reset()
    done = False
    episode_reward = 0
    actions = []
    states = []
    rewards = []
    while not done:
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        logits = initial_policy(state_tensor)
        dist = Categorical(logits=logits)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        next_state, reward, done, truncated , _= env.step(action.item())
        episode_reward += reward
        states.append(state_tensor)
        actions.append(log_prob)
        rewards.append(reward)
        
        state = next_state
    returns = []
    R = 0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns)

    # Normalize returns for stability
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)

    # Update policy once per episode
    policy_loss = []
    for log_prob, R in zip(actions, returns):
        policy_loss.append(-log_prob * R)
        
    policy_loss = torch.cat(policy_loss).sum()
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()
    
    if (episode + 1) % 5 == 0:
        print(f"Episode {episode+1}, Reward: {episode_reward}")
    if episode_reward >= 1000: 
        break

Episode 5, Reward: 9.0
Episode 10, Reward: 30.0
Episode 15, Reward: 31.0
Episode 20, Reward: 32.0
Episode 25, Reward: 33.0
Episode 30, Reward: 39.0
Episode 35, Reward: 22.0
Episode 40, Reward: 25.0
Episode 45, Reward: 80.0
Episode 50, Reward: 85.0
Episode 55, Reward: 59.0
Episode 60, Reward: 46.0
Episode 65, Reward: 91.0
Episode 70, Reward: 345.0
Episode 75, Reward: 281.0
Episode 80, Reward: 258.0
Episode 85, Reward: 261.0
Episode 90, Reward: 354.0
Episode 95, Reward: 117.0
Episode 100, Reward: 179.0
Episode 105, Reward: 59.0
Episode 110, Reward: 721.0


In [14]:
trajectories = collect_trajectories(initial_policy, env, num_trajectories=100)
print(f"Collected {len(trajectories)} trajectories")

Collected 100 trajectories


### Generate Human Preference Labels

In [4]:
# Simulate human preferences based on various criteria
def generate_human_preferences(trajectories, num_pairs=200):
    """
    Generate synthetic human preferences based on multiple criteria:
    1. Higher return (total reward)
    2. Less extreme cart positions (closer to center)
    3. Less extreme pole angles
    4. Smoother actions (less switching)
    """
    pairs = []
    n = len(trajectories)
    
    for _ in range(num_pairs):
        # Randomly select two different trajectories
        i, j = random.sample(range(n), 2)
        traj1, traj2 = trajectories[i], trajectories[j]
        
        # Calculate preference metrics
        # 1. Return (higher is better)
        return1, return2 = traj1['return'], traj2['return']
        
        # 2. Cart position stability (lower absolute position is better)
        pos_stability1 = np.mean(np.abs([s[0] for s in traj1['states']]))
        pos_stability2 = np.mean(np.abs([s[0] for s in traj2['states']]))
        
        # 3. Pole angle stability (lower absolute angle is better)
        angle_stability1 = np.mean(np.abs([s[2] for s in traj1['states']]))
        angle_stability2 = np.mean(np.abs([s[2] for s in traj2['states']]))
        
        # 4. Action smoothness (fewer changes is better)
        action_changes1 = np.sum(np.abs(np.diff(traj1['actions'])))
        action_changes2 = np.sum(np.abs(np.diff(traj2['actions'])))
        
        # Normalize by trajectory length for fair comparison
        action_changes1 /= max(1, len(traj1['actions']) - 1)
        action_changes2 /= max(1, len(traj2['actions']) - 1)
        
        # Combine metrics with weights
        score1 = (0.5 * return1 - 
                  0.2 * pos_stability1 - 
                  0.2 * angle_stability1 - 
                  0.1 * action_changes1)
        
        score2 = (0.5 * return2 - 
                  0.2 * pos_stability2 - 
                  0.2 * angle_stability2 - 
                  0.1 * action_changes2)
        
        # Determine preference (1 means traj1 is preferred, 0 means traj2 is preferred)
        preference = 1 if score1 > score2 else 0
        
        # Store the pair with preference
        pairs.append((traj1, traj2, preference))
    
    return pairs

# Generate preference pairs
preference_pairs = generate_human_preferences(trajectories)
print(f"Generated {len(preference_pairs)} preference pairs")

Generated 200 preference pairs


### Train the Reward Model

In [5]:
# Define the reward model
class RewardModel(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32)
        )
        
        self.action_encoder = nn.Embedding(action_dim, 8)
        
        self.combined = nn.Sequential(
            nn.Linear(32 + 8, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )
        
    def forward(self, states, actions):
        # states: [batch_size, state_dim]
        # actions: [batch_size]
        state_features = self.state_encoder(states)
        action_features = self.action_encoder(actions)
        combined = torch.cat([state_features, action_features], dim=1)
        return self.combined(combined).squeeze(-1)
    
    def get_reward(self, state, action):
        # For single state-action pair
        state = torch.FloatTensor(state).unsqueeze(0)
        action = torch.LongTensor([action])
        return self.forward(state, action).item()

# Function to compute trajectory reward
def compute_trajectory_reward(reward_model, trajectory):
    states = torch.FloatTensor(trajectory['states'])
    actions = torch.LongTensor(trajectory['actions'])
    return reward_model(states, actions).sum().item()

# Train the reward model on preference pairs
def train_reward_model(reward_model, preference_pairs, num_epochs=100, batch_size=32):
    optimizer = optim.Adam(reward_model.parameters(), lr=0.001)
    bce_loss = nn.BCEWithLogitsLoss()

    for epoch in range(num_epochs):
        random.shuffle(preference_pairs)
        epoch_loss = 0.0
        
        for i in range(0, len(preference_pairs), batch_size):
            batch = preference_pairs[i:i+batch_size]
            batch_loss = 0.0
            
            for traj1, traj2, preference in batch:
                # Convert trajectories to tensors
                states1 = torch.FloatTensor(traj1['states'])
                actions1 = torch.LongTensor(traj1['actions'])
                states2 = torch.FloatTensor(traj2['states'])
                actions2 = torch.LongTensor(traj2['actions'])
                
                # Get rewards for each trajectory
                rewards1 = reward_model(states1, actions1).mean()
                rewards2 = reward_model(states2, actions2).mean()
                
                # Compute the Bradley-Terry preference loss
                logits = rewards1 - rewards2
                
                # Use binary cross-entropy loss with logits
                loss = bce_loss(logits, torch.tensor(float(preference)) )
                
                batch_loss += loss
            
            # Update the reward model
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            
            epoch_loss += batch_loss.item()
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")
    
    return reward_model

# Initialize and train the reward model
reward_model = RewardModel(state_dim, action_dim)
reward_model = train_reward_model(reward_model, preference_pairs)
print("Reward model training complete.")

Epoch 10, Loss: 77.3828
Epoch 20, Loss: 44.2642
Epoch 30, Loss: 34.1429
Epoch 40, Loss: 30.7154
Epoch 50, Loss: 29.3170
Epoch 60, Loss: 27.6882
Epoch 70, Loss: 26.8126
Epoch 80, Loss: 24.9261
Epoch 90, Loss: 23.4280
Epoch 100, Loss: 21.8602
Reward model training complete.


### Implement PPO and Fine-tune Policy

In [6]:
class ActorCriticNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(ActorCriticNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc_policy = nn.Linear(hidden_dim, action_dim)
        self.fc_value = nn.Linear(hidden_dim, 1)
    
    def forward(self, x):
        x = torch.torch.nn.ReLU(self.fc1(x))
        policy_logits = self.fc_policy(x)
        policy = torch.softmax(policy_logits, dim=-1)
        value = self.fc_value(x)
        return policy, value

class PPOAgent:
    def __init__(self, state_dim, action_dim, hidden_dim=128, lr=3e-4, gamma=0.99, clip_epsilon=0.2, update_epochs=4):
        self.ac_net = ActorCriticNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.optimizer = optim.Adam(self.ac_net.parameters(), lr=lr)
        self.gamma = gamma
        self.clip_epsilon = clip_epsilon
        self.update_epochs = update_epochs

    def select_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        policy, value = self.ac_net(state)
        m = Categorical(policy)
        action = m.sample()
        return action.item(), m.log_prob(action), value

    def compute_returns_and_advantages(self, trajectories):
        all_returns = []
        all_advantages = []
        for traj in trajectories:
            returns = []
            G = 0
            for (_, _, _, r, _, _, _) in reversed(traj):
                G = r + self.gamma * G
                returns.insert(0, G)
            returns = torch.tensor(returns, dtype=torch.float).to(device)
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)
            advantages = []
            for (i, (_, _, _, _, value, _, _)) in enumerate(traj):
                advantages.append(returns[i] - value.item())
            all_returns.append(returns)
            all_advantages.append(torch.tensor(advantages, dtype=torch.float).to(device))
        return all_returns, all_advantages

    def update_td(self, trajectories):
        # Extract all data from transitions into tensors
        states = trajectories["states"]
        actions = trajectories["actions"]
        rewards = trajectories["rewards"]
        dones = trajectories["dones"]
        next_states = trajectories["next_states"]

        old_log_probs = torch.stack([t[2] for t in trajectories]).to(device)
        rewards = torch.FloatTensor([t[3] for t in trajectories]).to(device)
        values = torch.stack([t[4] for t in trajectories]).to(device).squeeze()
        
        # Calculate TD targets and advantages
        with torch.no_grad():
            _, next_values = self.ac_net(next_states)
            next_values = next_values.squeeze()
            next_values = next_values * (1 - dones)
            td_targets = rewards + self.gamma * next_values
            advantages = td_targets - values.detach()
            # Optional: Normalize advantages
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        # Perform multiple epochs of updates (typical for PPO)
        for _ in range(self.update_epochs):
            # Get current policy and values
            policy, current_values = self.ac_net(states)
            m = Categorical(policy)
            new_log_probs = m.log_prob(actions)
            
            # Calculate ratios and PPO clipped objective
            ratios = torch.exp(new_log_probs - old_log_probs)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()
            
            # Value loss using TD targets
            value_loss = nn.functional.mse_loss(current_values.squeeze(), td_targets)
            
            # Optional: Add entropy bonus for exploration
            entropy_loss = -m.entropy().mean() * 0.01
            
            # Total loss
            loss = policy_loss + 0.5 * value_loss + entropy_loss
            
            # Update network
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

In [None]:
# Plot the learning curve
plt.figure(figsize=(10, 6))
plt.plot(rewards_history)
plt.title('PPO Training with Learned Reward Model')
plt.xlabel('Episode')
plt.ylabel('Episode Reward')
plt.grid(True)
plt.show()

# Evaluate the final policy
def evaluate_policy(env, policy, num_episodes=10):
    total_rewards = []
    
    for _ in range(num_episodes):
        state = env.reset()
        done = False
        episode_reward = 0
        
        while not done:
            action, _, _, _ = policy.get_action(state)
            next_state, reward, done, _ = env.step(action)
            episode_reward += reward
            state = next_state
        
        total_rewards.append(episode_reward)
    
    return np.mean(total_rewards)

avg_reward = evaluate_policy(env, ppo_policy)
print(f"Average reward over 10 episodes: {avg_reward:.2f}")

### Visualize and Compare Results

In [None]:
# Function to visualize a policy in action
def visualize_policy(env, policy, reward_model=None, num_steps=500):
    state = env.reset()
    done = False
    step = 0
    total_reward = 0
    model_rewards = 0
    
    while not done and step < num_steps:
        env.render()
        
        # Get action from policy
        if isinstance(policy, PPOPolicy):
            action, _, _, _ = policy.get_action(state)
        else:
            action = policy.act(state)
        
        # Take action
        next_state, env_reward, done, _ = env.step(action)
        total_reward += env_reward
        
        # Get reward from model if available
        if reward_model:
            model_reward = reward_model.get_reward(state, action)
            model_rewards += model_reward
            print(f"Step {step}, Env Reward: {env_reward}, Model Reward: {model_reward:.4f}")
        else:
            print(f"Step {step}, Env Reward: {env_reward}")
        
        state = next_state
        step += 1
    
    env.close()
    print(f"Total environment reward: {total_reward}")
    if reward_model:
        print(f"Total model reward: {model_rewards:.4f}")

# Compare initial policy vs RLHF-trained policy
print("Visualizing initial policy:")
visualize_policy(env, initial_policy)

print("\nVisualizing RLHF-trained policy:")
visualize_policy(env, ppo_policy, reward_model)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Set random seeds for reproducibility.
torch.manual_seed(42)

# Suppose we have a toy vocabulary of size 10, and we model sequences of length 5.
vocab_size = 10
embed_dim = 8
hidden_dim = 16
seq_len = 5
batch_size = 2

# Define a simple language model using an LSTM.
class ToyLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(ToyLanguageModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
    
    def forward(self, x):
        # x has shape: (batch_size, seq_len)
        emb = self.embedding(x)  # shape: (batch_size, seq_len, embed_dim)
        lstm_out, _ = self.lstm(emb)  # shape: (batch_size, seq_len, hidden_dim)
        logits = self.fc(lstm_out)  # shape: (batch_size, seq_len, vocab_size)
        return logits

# Initialize the model, loss function, and optimizer.
model = ToyLanguageModel(vocab_size, embed_dim, hidden_dim)
criterion = nn.CrossEntropyLoss()  # expects inputs of shape (N, vocab_size) and targets of shape (N,)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Generate a toy dataset.
# In SFT, x is the prompt and y is the target output sequence.
# For simplicity, we randomly generate x and y (in practice, these come from human demonstrations).
x = torch.randint(0, vocab_size, (batch_size, seq_len))  # shape: (batch_size, seq_len)
y = torch.randint(0, vocab_size, (batch_size, seq_len))  # target sequence

# Forward pass: obtain logits for each token position.
logits = model(x)  # shape: (batch_size, seq_len, vocab_size)
