<a href="https://colab.research.google.com/github/tnfru/PPO/blob/ppg/PPG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from collections import deque
from torch.utils.data import DataLoader
from tqdm.notebook import trange
import wandb
from torch.distributions.kl import kl_divergence

In [2]:
# Register a CartPole environment id whose max_episode_steps is set to 2000.
gym.envs.register(
    max_episode_steps=2000,
    entry_point='gym.envs.classic_control:CartPoleEnv',
    id='CartPole-v2000',
)

In [3]:
def normalize(x):
    """Standardize ``x`` along dim 0 to zero mean and unit standard deviation.

    With fewer than two entries along dim 0 the (unbiased) standard deviation
    is undefined and evaluates to NaN, so in that case the input is only
    mean-centered instead of being divided by it.

    Args:
        x: floating-point tensor to standardize.

    Returns:
        A tensor with the same shape as ``x``.
    """
    centered = x - x.mean(0)

    # Guard on the length of dim 0 -- the axis that std(0) reduces over --
    # rather than on the global x.std(): the global value is finite for a
    # (1, k) input, which would let an all-NaN result through.
    if x.dim() == 0 or x.shape[0] < 2:
        return centered

    return centered / (x.std(0) + 1e-8)

In [4]:
class Trajectory(torch.utils.data.Dataset):
    """Rollout buffer exposed as a torch ``Dataset`` of per-step samples.

    Six parallel sequences are stored; item ``i`` is the tuple
    ``(state, action, expected_return, done, log_prob, advantage)`` taken
    from position ``i`` of each sequence.
    """

    def __init__(self):
        # Single source of truth for the set of fields: see clear_memory().
        self.clear_memory()

    def __len__(self):
        """Return the number of stored steps (length of ``states``)."""
        return len(self.states)

    def append_trajectory(self, states, actions, expected_returns, dones, log_probs, advantages):
        """Append the per-step data of one episode to the buffer.

        Every argument is an iterable with one entry per step; each is
        appended element-wise to the field of the same name.
        """
        self.states.extend(states)
        self.actions.extend(actions)
        self.expected_returns.extend(expected_returns)
        self.dones.extend(dones)
        self.log_probs.extend(log_probs)
        self.advantages.extend(advantages)

    def convert_to_advantages(self, rewards, state_vals, discount_factor, gae_lambda=0.95):
        """Compute lambda-weighted advantages for one episode.

        Walks the episode backwards, accumulating
        ``td_error + discount_factor * gae_lambda * advantage``.  The value
        following the final step is taken to be 0.

        Args:
            rewards: sequence of per-step rewards.
            state_vals: sequence of per-step value estimates, same length.
            discount_factor: reward discount.
            gae_lambda: decay applied to the accumulated advantage.

        Returns:
            1-D float tensor of advantages in forward (time) order.
        """
        advantages = []
        advantage = 0
        next_state_value = 0

        for reward, state_val in zip(reversed(rewards), reversed(state_vals)):
            td_error = reward + discount_factor * next_state_value - state_val
            advantage = td_error + discount_factor * gae_lambda * advantage
            next_state_value = state_val
            advantages.append(advantage)

        # Collected back-to-front; one reverse is O(n) whereas inserting at
        # index 0 on every step is O(n^2).
        advantages.reverse()

        return torch.tensor(advantages, dtype=torch.float)

    def fix_datatypes(self):
        """Convert states, actions, dones and log_probs from lists to tensors.

        ``expected_returns`` and ``advantages`` are left as they were
        appended.
        """
        self.states = torch.stack(self.states)
        self.actions = torch.tensor(self.actions, dtype=torch.long)
        self.dones = torch.tensor(self.dones, dtype=torch.int)
        self.log_probs = torch.tensor(self.log_probs, dtype=torch.float)

    def clear_memory(self):
        """Reset every stored field to an empty list.

        All six fields are reset together so that no stale returns or
        advantages remain paired with newly appended states.
        """
        self.states = []
        self.log_probs = []
        self.actions = []
        self.expected_returns = []
        self.dones = []
        self.advantages = []

    def __getitem__(self, index):
        """Return ``(state, action, expected_return, done, log_prob, advantage)`` at ``index``."""
        state = self.states[index]
        action = self.actions[index]
        expected_return = self.expected_returns[index]
        done = self.dones[index]
        log_prob = self.log_probs[index]
        advantage = self.advantages[index]

        return state, action, expected_return, done, log_prob, advantage

In [5]:
# Scratch check: a Categorical built from logits exposes the normalized probabilities.
a = Categorical(logits=torch.tensor([1.0, 2.0]))
a.probs

tensor([0.2689, 0.7311])

In [37]:
class Agent:
    """PPG-style actor-critic agent for a discrete-action gym environment.

    The actor is a ``PPG`` network (action scores plus a value head) and the
    critic is a separate ``CriticNet``.  ``train`` runs a clipped-ratio policy
    update together with a critic update on the current rollout and, once
    ``aux_freq`` policy iterations have accumulated, an auxiliary phase that
    fits the actor's value head while penalising KL divergence from a
    snapshot of the policy taken at the start of that phase.

    Metrics are logged through ``wandb``; constructing an agent calls
    ``wandb.init``.
    """

    def __init__(self, env, action_dim, state_dim, batch_size, clip_ratio=0.25, entropy_coeff=0.01,
                 kl_max=0.01):
        """Build networks, optimizers and logging.

        Args:
            env: environment offering ``reset()``, ``step(action)`` and
                ``render()`` as used by ``run_episode``.
            action_dim: number of discrete actions.
            state_dim: size of the observation vector.
            batch_size: minibatch size for every DataLoader.
            clip_ratio: half-width of the probability-ratio clipping interval.
            entropy_coeff: weight of the entropy term in the policy objective.
            kl_max: policy/auxiliary updates are skipped for a minibatch
                whose mean KL divergence is not below this value.
        """
        self.env = env
        self.actor = PPG(action_dim, state_dim)
        self.actor_old = PPG(action_dim, state_dim)
        self.critic = CriticNet(state_dim)
        self.discount_factor = 0.99
        self.batch_size = batch_size
        self.clip_ratio = clip_ratio
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=3e-4)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=1e-3)
        self.critic_loss_fun = nn.MSELoss()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.train_iterations = 1
        self.aux_iterations = 6
        self.aux_freq = 32
        self.steps = 0
        self.rollout_length = 256
        self.entropy_coeff = entropy_coeff
        self.kl_max = kl_max

        # Create the rollout buffer up front so self.trajectory always exists.
        self.forget()

        # initialize logging
        wandb.init(project="cartpole")
        wandb.watch(self.actor, log="all")
        wandb.watch(self.critic, log="all")
        wandb.run.name = 'PPG_' + wandb.run.name

    def get_action(self, obs):
        """Sample an action for one observation.

        Args:
            obs: observation tensor (``run_episode`` passes it on
                ``self.device``).

        Returns:
            Tuple ``(action, log_prob, state_value)`` of Python scalars.
        """
        # Only Python scalars leave this method, so no autograd graph is needed.
        with torch.no_grad():
            logits, _ = self.actor(obs)
            action_dist = Categorical(logits=logits)
            action = action_dist.sample()
            state_val = self.critic(obs)

        return action.item(), action_dist.log_prob(action).item(), state_val.squeeze().item()

    def train(self):
        """Train on ``self.trajectory``: policy phase, then auxiliary phase when due.

        ``self.trajectory`` must already have had ``fix_datatypes`` applied
        (``run_timesteps`` does this).
        """
        for _ in range(self.train_iterations):
            loader = DataLoader(self.trajectory, batch_size=self.batch_size, shuffle=True, pin_memory=True)
            self.steps += 1
            # Snapshot the policy; the KL gate below measures drift from it.
            self.actor_old.load_state_dict(self.actor.state_dict())

            for states, actions, expected_returns, _, old_log_probs, advantages in loader:
                states = states.to(self.device)
                actions = actions.to(self.device)
                expected_returns = expected_returns.unsqueeze(1).to(self.device)
                advantages = advantages.to(self.device)
                old_log_probs = old_log_probs.to(self.device)

                logits, _ = self.actor(states)
                action_dist = Categorical(logits=logits)
                with torch.no_grad():
                    old_logits, _ = self.actor_old(states)
                    old_action_dist = Categorical(logits=old_logits)
                    kl_div = kl_divergence(old_action_dist, action_dist)

                mean_kl = kl_div.mean()
                wandb.log({'kl div': mean_kl})
                if mean_kl < self.kl_max:
                    log_probs = action_dist.log_prob(actions)

                    # log trick for efficient computational graph during backprop
                    ratio = torch.exp(log_probs - old_log_probs)

                    weighted_objective = advantages * ratio
                    clipped_objective = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * advantages
                    entropy_loss = action_dist.entropy().mean() * self.entropy_coeff

                    objective = -torch.min(weighted_objective, clipped_objective).mean() - entropy_loss
                    wandb.log({'entropy': entropy_loss})

                    self.actor_opt.zero_grad()
                    # The graph is rebuilt for every minibatch and the critic
                    # is a separate network, so it need not be retained.
                    objective.backward()
                    self.actor_opt.step()

                self.train_critic(states, expected_returns)

        # ">=" rather than "==": with train_iterations > 1 the counter can
        # step past aux_freq without ever being equal to it.
        if self.steps >= self.aux_freq:
            self.steps = 0
            self.actor_old.load_state_dict(self.actor.state_dict())

            for _ in range(self.aux_iterations):
                loader = DataLoader(self.trajectory, batch_size=self.batch_size, shuffle=True, pin_memory=True)
                for states, _, expected_returns, _, _, _ in loader:
                    expected_returns = expected_returns.unsqueeze(1).to(self.device)
                    states = states.to(self.device)

                    # The snapshot only supplies fixed targets; keep it out
                    # of autograd so no gradients accumulate on actor_old.
                    with torch.no_grad():
                        action_probs_old, values_old = self.actor_old(states)
                        action_dist_old = Categorical(logits=action_probs_old)

                    self.train_aux(states, expected_returns, action_dist_old, values_old)
                    self.train_critic(states, expected_returns)

    def clipped_value_loss(self, values, rewards, old_values, clip=0.4):
        """Mean of the element-wise max of clipped and unclipped squared errors.

        Args:
            values: current value predictions.
            rewards: regression targets.
            old_values: earlier predictions that ``values`` is clipped around.
            clip: maximum allowed change of ``values`` relative to
                ``old_values`` in the clipped term.

        Returns:
            Scalar tensor.
        """
        value_clipped = old_values + (values - old_values).clamp(-clip, clip)
        value_loss_1 = (value_clipped.flatten() - rewards.flatten()) ** 2
        value_loss_2 = (values.flatten() - rewards.flatten()) ** 2
        return torch.mean(torch.max(value_loss_1, value_loss_2))

    def train_aux(self, states, expected_returns, action_dist_old, values_old, beta=1, val_clip=1e-5):
        """Run one auxiliary update of the actor on a minibatch.

        The loss is ``val_clip * clipped_value_loss + beta * KL(old || new)``.
        Despite its name, ``val_clip`` is used as the multiplicative weight of
        the value loss.  The update is skipped when the (beta-scaled) KL term
        is not below ``self.kl_max``.

        Args:
            states: minibatch of states on ``self.device``.
            expected_returns: value targets for the actor's value head.
            action_dist_old: ``Categorical`` from the policy snapshot.
            values_old: value-head outputs of the policy snapshot.
            beta: weight of the KL term.
            val_clip: weight of the clipped value loss.
        """
        action_probs, state_values = self.actor(states)
        action_dist = Categorical(logits=action_probs)

        kl_div = beta * kl_divergence(action_dist_old, action_dist).mean()
        if kl_div < self.kl_max:
            clipped_val_loss = self.clipped_value_loss(state_values, expected_returns, values_old).mean() * val_clip
            actor_val_loss = clipped_val_loss + kl_div

            self.actor_opt.zero_grad()
            actor_val_loss.backward()
            self.actor_opt.step()

            wandb.log({'aux state value': state_values.mean()})
            wandb.log({'aux loss': actor_val_loss.mean()})
            wandb.log({'aux kl_div': kl_div})

    def train_critic(self, states, expected_returns):
        """Take one MSE regression step of the critic towards ``expected_returns``."""
        state_values = self.critic(states)
        critic_loss = self.critic_loss_fun(state_values, expected_returns)

        self.critic_opt.zero_grad()
        critic_loss.backward()
        self.critic_opt.step()

        wandb.log({'critic loss': critic_loss.mean()})
        wandb.log({'critic state value': state_values.mean()})

    def run_episode(self, trajectory=None, render=False):
        """Play one episode with the current policy.

        Args:
            trajectory: optional ``Trajectory``; when given, the episode's
                steps, returns and normalized advantages are appended to it.
            render: call ``env.render()`` before every step.

        Returns:
            ``trajectory`` when one was supplied, otherwise ``None``.
        """
        observations = []
        actions = []
        rewards = []
        dones = []
        log_probs = []
        state_vals = []

        done = False
        obs = self.env.reset()

        while not done:
            if render:
                self.env.render()

            obs = torch.tensor(obs, dtype=torch.float, device=self.device)
            # Use this instance, not a module-level variable named "agent".
            action, log_prob, state_val = self.get_action(obs)
            obs = obs.cpu()
            next_obs, reward, done, _ = self.env.step(action)

            observations.append(obs)
            state_vals.append(state_val)
            actions.append(action)
            rewards.append(reward)
            dones.append(done)
            log_probs.append(log_prob)

            obs = next_obs

        wandb.log({'reward': np.sum(rewards)})

        if trajectory is not None:
            advantages = trajectory.convert_to_advantages(rewards, state_vals, self.discount_factor)
            # Value targets use the raw advantages; only the policy sees the normalized ones.
            expected_returns = torch.tensor(state_vals, dtype=torch.float) + advantages
            advantages = normalize(advantages)
            trajectory.append_trajectory(observations, actions, expected_returns, dones, log_probs, advantages)
            return trajectory

    def forget(self):
        """Replace the rollout buffer with a fresh, empty ``Trajectory``."""
        self.trajectory = Trajectory()

    def run_timesteps(self, num_timesteps):
        """Collect episodes and train until ``num_timesteps`` steps were trained on.

        Whole episodes are gathered until the buffer holds at least
        ``rollout_length`` steps; the buffer is then converted to tensors,
        trained on, and discarded.
        """
        timesteps = 0
        self.forget()

        while timesteps < num_timesteps:
            self.trajectory = self.run_episode(self.trajectory)

            if len(self.trajectory) >= self.rollout_length:
                timesteps += len(self.trajectory)

                self.trajectory.fix_datatypes()
                self.train()
                self.forget()

In [38]:
class Policy(nn.Module):
    """Three-layer MLP mapping a state vector to one score per action.

    The module moves itself to ``cuda:0`` when CUDA is available, otherwise
    it stays on the CPU; the chosen device is kept in ``self.device``.

    Args:
        action_dim: number of outputs.
        state_dim: number of input features.
        hidden_dim: width of both hidden layers (default 64).
    """

    def __init__(self, action_dim, state_dim, hidden_dim=64):
        super().__init__()
        # Attribute names and creation order are kept stable so state_dict
        # keys and seeded initialisation do not change.
        self.fc1 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        self.fc3 = nn.Sequential(
            nn.Linear(hidden_dim, action_dim)
        )

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, x):
        """Return the output of the final linear layer for input ``x``."""
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        return x

In [39]:
class PPG(nn.Module):
    """Two-headed MLP: a shared two-layer trunk feeding an action head and a value head.

    The module moves itself to ``cuda:0`` when CUDA is available, otherwise
    it stays on the CPU; the chosen device is kept in ``self.device``.

    Args:
        action_dim: number of outputs of the action head.
        state_dim: number of input features.
        hidden_dim: width of both trunk layers (default 64).
    """

    def __init__(self, action_dim, state_dim, hidden_dim=64):
        super().__init__()
        # Attribute names and creation order are kept stable so state_dict
        # keys and seeded initialisation do not change.
        self.fc1 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        self.action_head = nn.Sequential(
            nn.Linear(hidden_dim, action_dim)
        )

        self.val_head = nn.Sequential(
            nn.Linear(hidden_dim, 1)
        )

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, x):
        """Return ``(acts, vals)``: action-head and value-head outputs for ``x``.

        Both heads read the same trunk features.
        """
        x = self.fc1(x)
        x = self.fc2(x)
        acts = self.action_head(x)
        vals = self.val_head(x)

        return acts, vals

In [40]:
class CriticNet(nn.Module):
    """Three-layer MLP mapping a state vector to a single scalar output.

    The module moves itself to ``cuda:0`` when CUDA is available, otherwise
    it stays on the CPU; the chosen device is kept in ``self.device``.

    Args:
        state_dim: number of input features.
        hidden_dim: width of both hidden layers (default 64).
    """

    def __init__(self, state_dim, hidden_dim=64):
        super().__init__()
        # Attribute names and creation order are kept stable so state_dict
        # keys and seeded initialisation do not change.
        self.fc1 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        self.fc3 = nn.Sequential(
            nn.Linear(hidden_dim, 1)
        )

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, x):
        """Return the output of the final linear layer (last dim of size 1) for ``x``."""
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        return x

In [None]:
# Instantiate the registered environment and an agent sized for it, then train.
env = gym.make("CartPole-v2000")
agent = Agent(env, action_dim=2, state_dim=4, batch_size=64)

# Train until this many environment steps have been consumed.
agent.run_timesteps(1e7)
        

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
reward,2000.0
_runtime,316.0
_timestamp,1628595526.0
_step,13623.0
kl div,0.01371
entropy,0.00487
critic loss,1.349
critic state value,92.45853
aux state value,3.05204
aux loss,0.0879


0,1
reward,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂██
_runtime,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇██
_timestamp,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇██
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
kl div,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▃▃▁▃▁▁▁█▁▄▁▁▃▁▁▂▆▇▁▁▁
entropy,████▇▇▇▆▆▅▆▆▅▄▄▅▄▄▄▃▃▃▃▄▂▃▂▂▃▂▃▂▂▃▂▁▂▁▂▁
critic loss,▂▂▂▂▂▃▂▂▂▂▁▁▁▁▁▁▂▂▁▁▂▆▁▁▁▁▁█▁▁▁▁▁▁▄▁▁▁▂▁
critic state value,▁▁▂▂▂▃▃▄▄▅▅▅▆▆▅▆▇▇▇▇▇▇█▇██▇█████████████
aux state value,▁▁▁▁▂▂▂▄▄▄▅▅▆▆▆▇▆▇▆▇▇▇▇▇▇▇▇▇████████████
aux loss,▁▁▁▁▁▂▂▃▃▅▄▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████
