<a href="https://colab.research.google.com/github/varunraom91/stock-application/blob/main/simple_grpo_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import numpy as np

# =============================================
# Environment: Math Equation Solver
# =============================================
class MathEnv:
    """Toy single-equation environment that scores proposed solution steps.

    Attributes:
        equation: The equation text being "solved" (informational only).
        solution: Dict with the ground-truth ``"steps"`` list and a
            ``"reward"`` entry of 1.0.
    """

    def __init__(self):
        reference_steps = ["2x = 4", "x = 2"]
        self.equation = "2x + 3 = 7"
        self.solution = {"steps": reference_steps, "reward": 1.0}  # Ground truth

    def get_reward(self, generated_steps):
        """Score ``generated_steps`` against the ground-truth steps.

        Returns:
            1.0 when ``generated_steps`` compares equal to the reference list,
            0.5 when at least one reference step is contained in it,
            0.0 otherwise.
        """
        reference = self.solution["steps"]
        if generated_steps == reference:
            return 1.0
        has_any_step = any(expected in generated_steps for expected in reference)
        return 0.5 if has_any_step else 0.0

# =============================================
# Neural Networks
# =============================================
class PolicyModel(nn.Module):
    """Small MLP that maps a state vector to unnormalised action logits.

    Args:
        in_features: Size of the last dimension of the input state (default 128).
        hidden_features: Width of the hidden layer (default 64).
        num_actions: Number of output logits (default 4; the demo labels them
            "add", "subtract", "multiply", "divide").
    """

    def __init__(self, in_features=128, hidden_features=64, num_actions=4):
        super().__init__()
        # Kept as a single ``fc`` Sequential so parameter names (fc.0.*, fc.2.*)
        # and layer construction order match the original hard-coded version.
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden_features),  # Simplified "understanding" of equation
            nn.ReLU(),
            nn.Linear(hidden_features, num_actions),
        )

    def forward(self, state):
        """Return logits of shape ``(..., num_actions)`` for ``state`` of shape ``(..., in_features)``."""
        return self.fc(state)

class ValueModel(nn.Module):  # Used only by PPO
    """Small MLP critic that maps a state vector to a single scalar estimate.

    Args:
        in_features: Size of the last dimension of the input state (default 128).
        hidden_features: Width of the hidden layer (default 64).
    """

    def __init__(self, in_features=128, hidden_features=64):
        super().__init__()
        # Same ``fc`` layout as the original so state_dict keys are unchanged.
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, 1),
        )

    def forward(self, state):
        """Return a tensor of shape ``(..., 1)`` for ``state`` of shape ``(..., in_features)``."""
        return self.fc(state)

# =============================================
# PPO Training (with Value Model)
# =============================================
def run_ppo(env, num_iterations=3, clip_eps=0.2, ppo_epochs=4, value_coef=0.5, lr=0.001):
    """Run a toy PPO (actor-critic, clipped surrogate) loop on ``env``.

    Each iteration samples one action from the current policy, scores a fixed
    "perfect" generation with ``env.get_reward``, and computes the advantage
    against the critic's estimate. It then performs ``ppo_epochs`` optimizer
    steps on the clipped policy loss plus a squared-error value loss.

    Args:
        env: Object exposing ``get_reward(steps) -> float``.
        num_iterations: Number of rollout/update iterations.
        clip_eps: PPO clipping range; ratios are clamped to ``[1-eps, 1+eps]``.
        ppo_epochs: Optimizer steps per rollout (must be >= 1).
        value_coef: Weight of the value (critic) loss in the total loss.
        lr: Adam learning rate.

    Raises:
        ValueError: If ``ppo_epochs`` is less than 1.
    """
    if ppo_epochs < 1:
        raise ValueError(f"ppo_epochs must be >= 1, got {ppo_epochs}")

    policy = PolicyModel()
    value_model = ValueModel()
    # One optimizer over actor and critic; the critic must be trained too,
    # otherwise the advantage baseline never improves.
    optimizer = torch.optim.Adam(
        list(policy.parameters()) + list(value_model.parameters()), lr=lr
    )

    # Dummy "state" representing equation (random for illustration)
    state = torch.randn(128)

    for _ in range(num_iterations):
        # Rollout with the current (behaviour) policy; its log-prob is frozen
        # so the ratio below measures how far the updates move the policy.
        with torch.no_grad():
            old_dist = torch.distributions.Categorical(logits=policy(state))
            action = old_dist.sample()
            old_log_prob = old_dist.log_prob(action)
            value_pred = value_model(state).squeeze(-1)

        # Simulate generating steps (e.g., ["subtract 3", "divide by 2"])
        generated_steps = ["2x = 4", "x = 2"]  # Assume perfect generation for demo
        reward = env.get_reward(generated_steps)

        # Advantage = observed reward minus the critic's baseline (held fixed
        # across the epochs of this rollout).
        advantage = reward - value_pred

        for _ in range(ppo_epochs):
            new_log_prob = torch.distributions.Categorical(
                logits=policy(state)
            ).log_prob(action)
            ratio = torch.exp(new_log_prob - old_log_prob)
            clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
            policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
            value_loss = (value_model(state).squeeze(-1) - reward) ** 2
            ppo_loss = policy_loss + value_coef * value_loss

            optimizer.zero_grad()
            ppo_loss.backward()
            optimizer.step()

        print(f"PPO Advantage: {advantage.item():.2f}, Loss: {ppo_loss.item():.2f}")

# =============================================
# GRPO Training (Group-Based)
# =============================================
def run_grpo(env, group_size=3, num_iterations=3, kl_coef=0.1, lr=0.001):
    """Run a toy GRPO (group-relative policy optimization) loop on ``env``.

    Each iteration works as follows:
      1. Sample ``group_size`` actions from the current policy.
      2. Map each action to a toy candidate solution and score it with
         ``env.get_reward``.
      3. Normalize the rewards within the group (mean/std) to get advantages.
      4. Take one optimizer step on
         ``-mean(A * log pi(a)) + kl_coef * KL(pi || pi_ref)``.
    No value model is used.

    Args:
        env: Object exposing ``get_reward(steps) -> float``.
        group_size: Number of sampled solutions per group (must be >= 1).
        num_iterations: Number of update iterations.
        kl_coef: Weight of the KL penalty towards the frozen reference policy.
        lr: Adam learning rate.

    Raises:
        ValueError: If ``group_size`` is less than 1.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be >= 1, got {group_size}")

    policy = PolicyModel()
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)

    state = torch.randn(128)  # Same "equation" representation

    # Toy "generations": the steps produced by each of the policy's 4 actions.
    # Must have one entry per policy output logit.
    candidate_steps = (
        ["2x = 4", "x = 2"],  # Correct
        ["2x = 4", "x = 2"],  # Correct
        ["x = 7"],            # Incorrect
        ["2x = 4"],           # Partially correct
    )

    # Frozen reference distribution for the KL penalty, computed without a
    # graph so it never receives gradients.
    with torch.no_grad():
        ref_log_probs = torch.log_softmax(policy(state), dim=-1)

    for _ in range(num_iterations):
        log_probs = torch.log_softmax(policy(state), dim=-1)

        # Sample the group from the *current* policy so rewards depend on it.
        actions = torch.distributions.Categorical(logits=log_probs).sample((group_size,))
        group_rewards = [env.get_reward(candidate_steps[a]) for a in actions.tolist()]

        # GRPO Advantage Calculation: normalize within the group
        group_mean = np.mean(group_rewards)
        group_std = np.std(group_rewards) + 1e-8
        normalized_rewards = [float((r - group_mean) / group_std) for r in group_rewards]
        advantages = torch.tensor(normalized_rewards, dtype=torch.float32)

        # Policy-gradient term: the advantages must multiply the log-probs of
        # the sampled actions, otherwise no gradient reaches the policy.
        pg_loss = -(advantages * log_probs[actions]).mean()
        # Exact KL(pi || pi_ref) over the action set, in log space for stability.
        kl_penalty = kl_coef * torch.sum(log_probs.exp() * (log_probs - ref_log_probs))
        grpo_loss = pg_loss + kl_penalty

        optimizer.zero_grad()
        grpo_loss.backward()
        optimizer.step()

        print(f"GRPO Group Rewards: {group_rewards}")
        print(f"GRPO Normalized: {[round(r,2) for r in normalized_rewards]}")
        print(f"GRPO Loss: {grpo_loss.item():.2f}\n")

# =============================================
# Run Both Algorithms
# =============================================
if __name__ == "__main__":
    # Run both demos against one shared environment, each under its banner.
    env = MathEnv()
    for banner, trainer in (
        ("===== PPO Training =====", run_ppo),
        ("\n===== GRPO Training =====", run_grpo),
    ):
        print(banner)
        trainer(env)

===== PPO Training =====
PPO Advantage: 1.40, Loss: -1.41
PPO Advantage: 1.40, Loss: -1.41
PPO Advantage: 1.40, Loss: -1.41

===== GRPO Training =====
GRPO Group Rewards: [0.0, 1.0, 1.0]
GRPO Normalized: [-1.41, 0.71, 0.71]
GRPO Loss: -0.00

GRPO Group Rewards: [1.0, 0.0, 1.0]
GRPO Normalized: [0.71, -1.41, 0.71]
GRPO Loss: -0.00

GRPO Group Rewards: [0.0, 1.0, 1.0]
GRPO Normalized: [-1.41, 0.71, 0.71]
GRPO Loss: -0.00

