# In [7]:
import torch.optim as optim
from torch.distributions import Categorical

class PPOAgent:
    """Proximal Policy Optimization agent using the clipped surrogate loss.

    Two copies of the network are kept: ``policy`` is the one being
    optimised, while ``policy_old`` is a frozen snapshot used to collect
    rollouts. The snapshot is refreshed at the end of every ``update``.

    Both networks are called as ``logits, value = net(states)``.

    Args:
        obs_shape: Observation size forwarded to ``WorkerNetwork``.
        action_dim: Number of discrete actions forwarded to ``WorkerNetwork``.
        lr: Learning rate of the Adam optimiser.
        gamma: Discount factor used for the rewards-to-go.
        eps_clip: Clipping range of the probability ratio.
        k_epochs: Number of gradient steps taken per ``update`` call.
    """

    def __init__(self, obs_shape=71, action_dim=5, lr=3e-4, gamma=0.99, eps_clip=0.2, k_epochs=10):
        self.policy = WorkerNetwork(obs_shape, action_dim)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.policy_old = WorkerNetwork(obs_shape, action_dim)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.gamma = gamma
        self.eps_clip = eps_clip
        self.k_epochs = k_epochs
        self.mse_loss = nn.MSELoss()

    def select_action(self, state, storage):
        """Sample an action from the rollout policy and record the step.

        Args:
            state: A single observation accepted by ``torch.FloatTensor``.
            storage: Rollout buffer; the batched state, sampled action, its
                log-probability and the value estimate are appended to
                ``states``, ``actions``, ``logprobs`` and ``values``.

        Returns:
            The sampled action as a Python ``int``.
        """
        state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            logits, val = self.policy_old(state)

        # Building the distribution from logits avoids the precision loss of
        # an explicit softmax followed by a log inside ``log_prob``.
        dist = Categorical(logits=logits)
        action = dist.sample()

        storage.states.append(state)
        storage.actions.append(action)
        storage.logprobs.append(dist.log_prob(action))
        storage.values.append(val)

        return action.item()

    def _discounted_returns(self, rewards, is_terminals):
        """Return the discounted rewards-to-go as a 1-D float32 tensor."""
        returns = []
        running = 0.0
        for reward, is_terminal in zip(reversed(rewards), reversed(is_terminals)):
            if is_terminal:
                # Do not let reward from the next episode leak backwards.
                running = 0.0
            running = reward + self.gamma * running
            returns.append(running)
        # Appending and reversing once is O(n); insert(0, ...) is O(n^2).
        returns.reverse()
        return torch.tensor(returns, dtype=torch.float32)

    def update(self, storage):
        """Run the PPO optimisation on the collected rollout, then clear it.

        Args:
            storage: Rollout buffer filled by ``select_action`` plus one
                entry in ``rewards`` and ``is_terminals`` per stored step.

        Raises:
            ValueError: If the number of rewards or terminal flags differs
                from the number of stored states.
        """
        if not storage.states:
            # Nothing was collected; there is nothing to learn from.
            return

        n_steps = len(storage.states)
        if len(storage.rewards) != n_steps or len(storage.is_terminals) != n_steps:
            raise ValueError(
                f"rollout is inconsistent: {n_steps} states, "
                f"{len(storage.rewards)} rewards, "
                f"{len(storage.is_terminals)} terminal flags"
            )

        old_states = torch.cat(storage.states).detach()
        old_actions = torch.cat(storage.actions).detach()
        old_logprobs = torch.cat(storage.logprobs).detach()
        # reshape(-1) keeps a 1-D tensor even for a one-step rollout, where a
        # bare squeeze() would collapse it to a scalar.
        old_values = torch.cat(storage.values).detach().reshape(-1)

        returns = self._discounted_returns(storage.rewards, storage.is_terminals)
        advantages = returns - old_values

        for _ in range(self.k_epochs):
            logits, state_values = self.policy(old_states)
            dist = Categorical(logits=logits)
            logprobs = dist.log_prob(old_actions)

            # Probability ratio between the current and the rollout policy.
            ratios = torch.exp(logprobs - old_logprobs)

            # Clipped surrogate objective.
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = 0.5 * self.mse_loss(state_values.reshape(-1), returns)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        self.policy_old.load_state_dict(self.policy.state_dict())
        storage.clear()

class RolloutStorage:
    """Plain list-backed buffer holding one rollout worth of transitions."""

    def __init__(self):
        # One independent list per recorded quantity.
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []
        self.values = []

    def clear(self):
        """Empty every buffer in place, keeping the list objects themselves."""
        buffers = (
            self.states,
            self.actions,
            self.logprobs,
            self.rewards,
            self.is_terminals,
            self.values,
        )
        for buffer in buffers:
            del buffer[:]
