<a href="https://colab.research.google.com/github/skywalker0803r/deep-learning-ian-goodfellow/blob/master/reinforce/PPG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [135]:
from collections import namedtuple,deque
from torch import nn
import torch.nn.functional as F
import torch
import gym
from tqdm import tqdm_notebook as tqdm
from torch.distributions import Categorical
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import TensorDataset
# One environment step recorded during a rollout (policy phase input).
Memory = namedtuple('Memory', ['state', 'action', 'action_log_prob', 'reward', 'done', 'value'])
# One rollout's (states, target returns, old critic values) kept for the auxiliary phase.
# The typename matches the variable name so repr/pickling resolve to this class, not Memory.
AuxMemory = namedtuple('AuxMemory', ['state', 'target_value', 'old_values'])
# Module-level buffers; main() creates and uses its own local deques of the same names.
memories = deque([])
aux_memories = deque([])

def init_(m):
    """Initialise a single module in place; intended for ``nn.Module.apply``.

    ``nn.Linear`` layers get an orthogonal weight matrix scaled by the tanh
    gain, and their bias (when present) is zeroed. Every other module type is
    ignored.
    """
    if not isinstance(m, nn.Linear):
        return
    tanh_gain = torch.nn.init.calculate_gain('tanh')
    torch.nn.init.orthogonal_(m.weight, tanh_gain)
    if m.bias is None:
        return
    torch.nn.init.zeros_(m.bias)

class ExperienceDataset(Dataset):
    """Map-style dataset over several parallel, index-aligned sequences.

    ``data`` is a sequence of columns (e.g. tensors of states, actions, ...).
    Item ``ind`` is the tuple of every column's element at ``ind``; the dataset
    length is the length of the first column.
    """

    def __init__(self, data):
        super().__init__()
        self.data = data

    def __len__(self):
        first_column = self.data[0]
        return len(first_column)

    def __getitem__(self, ind):
        return tuple(column[ind] for column in self.data)

def create_shuffled_dataloader(data, batch_size):
    """Return a DataLoader over parallel columns ``data`` that reshuffles each epoch.

    :param data: list of index-aligned columns, wrapped in ``ExperienceDataset``
    :param batch_size: number of samples per yielded minibatch
    """
    return DataLoader(ExperienceDataset(data), batch_size=batch_size, shuffle=True)

def clipped_value_loss(values, rewards, old_values, clip):
    """PPO-style clipped value loss.

    The new prediction is kept within ``clip`` of the old one. The loss is the
    element-wise max of the clipped and unclipped squared errors against the
    targets, averaged over the batch.

    All three inputs are flattened first. Otherwise a mix of ``(B, 1)`` and
    ``(B,)`` shapes would silently broadcast to ``(B, B)`` and compare every
    prediction against every target.

    :param values: current value predictions, ``B`` elements in any shape
    :param rewards: value targets (returns), ``B`` elements
    :param old_values: value predictions recorded at collection time, ``B`` elements
    :param clip: maximum allowed change from ``old_values``
    :return: scalar tensor
    """
    values = values.flatten()
    rewards = rewards.flatten()
    old_values = old_values.flatten()
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    value_loss_1 = (value_clipped - rewards) ** 2
    value_loss_2 = (values - rewards) ** 2
    return torch.mean(torch.max(value_loss_1, value_loss_2))

class Actor(nn.Module):
    """Policy network with an extra value head (used by PPG's auxiliary phase).

    A three-layer tanh MLP trunk feeds two heads. ``action_head`` produces a
    softmax distribution over ``num_actions``. ``value_head`` produces a
    single scalar estimate. All Linear layers are re-initialised with ``init_``.
    """

    def __init__(self, state_dim, hidden_dim, num_actions):
        super().__init__()
        # Trunk: (Linear -> Tanh) x 3, built in the same order as before so
        # parameter names (net.0, net.2, net.4) and init order are unchanged.
        trunk_layers = []
        fan_in = state_dim
        for _ in range(3):
            trunk_layers.append(nn.Linear(fan_in, hidden_dim))
            trunk_layers.append(nn.Tanh())
            fan_in = hidden_dim
        self.net = nn.Sequential(*trunk_layers)

        self.action_head = nn.Sequential(
            nn.Linear(hidden_dim, num_actions),
            nn.Softmax(dim=-1),
        )
        self.value_head = nn.Linear(hidden_dim, 1)

        self.apply(init_)

    def forward(self, x):
        """Return ``(action_probs, value)`` computed from the shared trunk."""
        features = self.net(x)
        probs = self.action_head(features)
        value = self.value_head(features)
        return probs, value

class Critic(nn.Module):
    """State-value network: a two-hidden-layer tanh MLP ending in one output unit.

    All Linear layers are re-initialised with ``init_``.
    """

    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )
        self.apply(init_)

    def forward(self, x):
        """Return the value estimate(s) for ``x``; the last dimension has size 1."""
        estimate = self.net(x)
        return estimate
  
def normalize(t, eps = 1e-5):
    """Standardise ``t`` to zero mean and (approximately) unit std.

    With fewer than two elements the unbiased std is undefined (NaN). A NaN here
    would propagate into the policy loss and the weights, so such inputs map to
    zeros instead.

    :param t: tensor of any shape
    :param eps: added to the std to avoid division by zero
    :return: tensor shaped like ``t``
    """
    if t.numel() < 2:
        return torch.zeros_like(t)
    return (t - t.mean()) / (t.std() + eps)

def update_network_(loss, optimizer):
    """Perform one optimiser step that minimises the mean of ``loss``.

    Existing gradients are cleared first, so nothing accumulates across calls.
    """
    optimizer.zero_grad()
    scalar_loss = loss.mean()
    scalar_loss.backward()
    optimizer.step()

In [136]:
class PPG(nn.Module):
    """Phasic Policy Gradient agent with a separate actor and critic.

    The actor also carries an auxiliary value head. ``learn`` runs the
    PPO-style policy phase on one rollout and appends
    ``AuxMemory(states, returns, old_values)`` to the auxiliary buffer.
    ``learn_aux`` then distils those value targets into the actor's value head
    under a KL penalty that keeps its action distribution close to the
    pre-phase one. It also gives the critic extra training.
    """

    def __init__(self, state_dim, num_actions, actor_hidden_dim, critic_hidden_dim, epochs,
                 epochs_aux, minibatch_size, lr, betas, lam, gamma, beta_s, eps_clip, value_clip, device):
        super().__init__()
        self.actor = Actor(state_dim, actor_hidden_dim, num_actions)
        self.critic = Critic(state_dim, critic_hidden_dim)
        self.device = device
        # Rollout tensors are moved to `device`, so the networks must live there too.
        # This is done before building the optimizers so they track the moved parameters.
        self.to(device)
        self.gamma = gamma
        self.lam = lam
        self.minibatch_size = minibatch_size
        self.eps_clip = eps_clip
        self.beta_s = beta_s
        # `betas` was previously accepted but ignored; it now reaches Adam.
        self.opt_actor = torch.optim.Adam(self.actor.parameters(), lr=lr, betas=betas)
        self.opt_critic = torch.optim.Adam(self.critic.parameters(), lr=lr, betas=betas)
        self.value_clip = value_clip
        self.epochs = epochs
        self.epochs_aux = epochs_aux

    def load(self):
        """Load parameters from ``model.pt``, mapping tensors onto ``self.device``."""
        self.load_state_dict(torch.load('model.pt', map_location=self.device))

    def save(self):
        """Save parameters to ``model.pt``."""
        torch.save(self.state_dict(), 'model.pt')

    def learn(self, memories, aux_memories, next_state):
        """Policy phase: compute GAE returns for one rollout, then run PPO updates.

        :param memories: ``Memory`` records of the latest rollout, in time order.
            Does nothing if empty.
        :param aux_memories: buffer that receives ``AuxMemory(states, returns, old_values)``
        :param next_state: numpy observation following the last memory, used to
            bootstrap the value of the unfinished tail
        """
        memories = list(memories)
        if not memories:
            return

        states = [mem.state for mem in memories]
        actions = [torch.tensor(mem.action) for mem in memories]
        old_log_probs = [mem.action_log_prob for mem in memories]
        rewards = [mem.reward for mem in memories]
        # Inverted done flag: 0 stops bootstrapping/GAE across an episode end.
        masks = [1.0 - float(mem.done) for mem in memories]
        values = [mem.value for mem in memories]

        with torch.no_grad():
            next_state = torch.from_numpy(next_state).float().to(self.device)
            next_value = self.critic(next_state)

        # Generalised advantage estimate, done on plain floats so no autograd
        # graph is built through the stored value tensors.
        value_floats = [v.item() for v in values] + [next_value.item()]
        returns = []
        gae = 0.0
        for i in reversed(range(len(rewards))):
            delta = rewards[i] + self.gamma * value_floats[i + 1] * masks[i] - value_floats[i]
            gae = delta + self.gamma * self.lam * masks[i] * gae
            returns.append(gae + value_floats[i])
        returns.reverse()

        def to_torch_tensor(items):
            return torch.stack(items).to(self.device).detach()

        states = to_torch_tensor(states)
        actions = to_torch_tensor(actions)
        old_values = to_torch_tensor(values)
        old_log_probs = to_torch_tensor(old_log_probs)
        returns = torch.tensor(returns, dtype=torch.float32, device=self.device)

        # Keep states and value targets for the later auxiliary phase.
        aux_memories.append(AuxMemory(states, returns, old_values))

        dl = create_shuffled_dataloader(
            [states, actions, old_log_probs, returns, old_values], self.minibatch_size)

        for _ in range(self.epochs):
            for b_states, b_actions, b_old_log_probs, b_returns, b_old_values in dl:
                action_probs, _ = self.actor(b_states)
                predicted_values = self.critic(b_states)
                dist = Categorical(action_probs)
                action_log_probs = dist.log_prob(b_actions)
                entropy = dist.entropy()

                # Clipped surrogate objective (classic PPO).
                ratios = (action_log_probs - b_old_log_probs).exp()
                # Flatten the stored (B, 1) values so each sample uses its own
                # baseline. Without it, (B,) - (B, 1) broadcasts to (B, B).
                advantages = normalize(b_returns - b_old_values.flatten())
                surr1 = ratios * advantages
                surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * advantages
                policy_loss = -torch.min(surr1, surr2) - self.beta_s * entropy
                update_network_(policy_loss, self.opt_actor)

                # The critic is updated separately from the policy network.
                value_loss = clipped_value_loss(predicted_values, b_returns, b_old_values, self.value_clip)
                update_network_(value_loss, self.opt_critic)

    def learn_aux(self, aux_memories):
        """Auxiliary phase over every stored rollout.

        Trains the actor's value head on the stored returns. A KL term keeps the
        action distribution close to the one the actor produced before the
        phase. The critic also gets extra value training.

        :param aux_memories: ``AuxMemory`` records appended by ``learn``.
            Does nothing if empty.
        """
        aux_memories = list(aux_memories)
        if not aux_memories:
            return

        state_parts, return_parts, old_value_parts = zip(*aux_memories)
        states = torch.cat(state_parts)
        returns = torch.cat(return_parts)
        old_values = torch.cat(old_value_parts)

        # Reference action distribution for the KL penalty; no graph needed.
        with torch.no_grad():
            old_action_probs, _ = self.actor(states)

        dl = create_shuffled_dataloader(
            [states, old_action_probs, returns, old_values], self.minibatch_size)

        for _ in range(self.epochs_aux):
            for b_states, b_old_action_probs, b_returns, b_old_values in dl:
                action_probs, policy_values = self.actor(b_states)
                action_logprobs = action_probs.log()

                # Actor loss: auxiliary value loss plus KL(old || new) on actions.
                aux_loss = clipped_value_loss(policy_values, b_returns, b_old_values, self.value_clip)
                loss_kl = F.kl_div(action_logprobs, b_old_action_probs, reduction='batchmean')
                update_network_(aux_loss + loss_kl, self.opt_actor)

                # Extra value-network training during the auxiliary phase.
                predicted_values = self.critic(b_states)
                value_loss = clipped_value_loss(predicted_values, b_returns, b_old_values, self.value_clip)
                update_network_(value_loss, self.opt_critic)

In [137]:
def main(
    env_name = 'LunarLander-v2',
    num_episodes = 5000,
    max_timesteps = 500,
    actor_hidden_dim = 32,
    critic_hidden_dim = 256,
    minibatch_size = 16,
    lr = 0.0005,
    betas = (0.9, 0.999),
    lam = 0.95,
    gamma = 0.99,
    eps_clip = 0.2,
    value_clip = 0.4,
    beta_s = .01,
    update_timesteps = 64,
    num_policy_updates_per_aux = 32,
    epochs = 1,
    epochs_aux = 6,
    render = False,
    render_every_eps = 250,
    save_every = 1000,
    load = True,
    print_reward_interval = 10,
):
    """Train a PPG agent on a discrete-action gym environment.

    Uses the gym API where ``reset()`` returns an observation and ``step()``
    returns ``(obs, reward, done, info)``.

    :param env_name: OpenAI gym environment name
    :param num_episodes: number of episodes to train
    :param max_timesteps: max timesteps per episode
    :param actor_hidden_dim: actor network hidden layer size
    :param critic_hidden_dim: critic network hidden layer size
    :param minibatch_size: minibatch size for training
    :param lr: learning rate for optimizers
    :param betas: betas for Adam optimizer
    :param lam: GAE lambda (exponential discount)
    :param gamma: GAE gamma (future discount)
    :param eps_clip: PPO policy loss clip coefficient
    :param value_clip: value loss clip coefficient
    :param beta_s: entropy loss coefficient
    :param update_timesteps: number of timesteps to run before each policy-phase update
    :param num_policy_updates_per_aux: policy-phase updates between auxiliary phases
    :param epochs: policy phase epochs
    :param epochs_aux: auxiliary phase epochs
    :param render: toggle render environment
    :param render_every_eps: if render, how often to render
    :param save_every: how often (in episodes) to save networks (saving is currently disabled)
    :param load: toggle loading a previously trained network (loading is currently disabled)
    :param print_reward_interval: print the episode reward every this many episodes
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    env = gym.make(env_name)

    state_dim = env.observation_space.shape[0]
    num_actions = env.action_space.n

    # Local rollout buffers (they shadow the module-level deques of the same name).
    memories = deque([])
    aux_memories = deque([])

    agent = PPG(
        state_dim,
        num_actions,
        actor_hidden_dim,
        critic_hidden_dim,
        epochs,
        epochs_aux,
        minibatch_size,
        lr,
        betas,
        lam,
        gamma,
        beta_s,
        eps_clip,
        value_clip,
        device
    )

    if load:
        pass  # NOTE: checkpoint loading (agent.load()) is disabled in this notebook

    global_step = 0
    updated = False
    num_policy_updates = 0

    try:
        for eps in range(num_episodes):
            render_eps = render and eps % render_every_eps == 0
            state = env.reset()
            total_reward = 0
            for timestep in range(max_timesteps):
                global_step += 1

                if updated and render_eps:
                    env.render()

                state = torch.from_numpy(state).float().to(device)
                # Acting needs no autograd graph; learn() works on detached copies.
                with torch.no_grad():
                    action_probs, _ = agent.actor(state)
                    value = agent.critic(state)
                    dist = Categorical(action_probs)
                    action = dist.sample()
                    action_log_prob = dist.log_prob(action)
                action = action.item()

                next_state, reward, done, _ = env.step(action)
                total_reward += reward

                memories.append(Memory(state, action, action_log_prob, reward, done, value))

                state = next_state

                if global_step % update_timesteps == 0:
                    agent.learn(memories, aux_memories, next_state)
                    num_policy_updates += 1
                    memories.clear()

                    if num_policy_updates % num_policy_updates_per_aux == 0:
                        agent.learn_aux(aux_memories)
                        aux_memories.clear()

                    updated = True

                # End the episode on `done` regardless of the print interval.
                # Previously `break` sat inside the print check, so most episodes
                # kept stepping a terminated environment.
                if done:
                    if eps % print_reward_interval == 0:
                        print(f'eps:{eps},total_reward:{total_reward}')
                    if render_eps:
                        updated = False
                    break

            if render_eps:
                env.close()

            if eps % save_every == 0:
                pass  # NOTE: checkpoint saving (agent.save()) is disabled in this notebook
    finally:
        # Release the environment even if training is interrupted.
        env.close()

In [138]:
# Run training with the default hyperparameters (env_name defaults to 'LunarLander-v2').
main()

eps:0,total_reward:-457.2370217378628
eps:10,total_reward:-265.042451966741
eps:20,total_reward:-192.21099659424584
eps:30,total_reward:-321.47092543806605
eps:40,total_reward:-548.6910798843332
eps:50,total_reward:-314.5578492951288
eps:60,total_reward:-561.3409298841557
eps:70,total_reward:-96.72455662811723
eps:80,total_reward:-120.39051650666418
eps:90,total_reward:-62.14781111608708
eps:100,total_reward:-435.59147498118887
eps:110,total_reward:-252.59948826203248
eps:120,total_reward:-87.01209656226551
eps:130,total_reward:-176.4662142309512
eps:140,total_reward:-342.8118258754097
eps:150,total_reward:-126.2154492097739
eps:160,total_reward:-336.39268060387974
eps:170,total_reward:-68.57622235573072
eps:180,total_reward:10.616886590529504
eps:190,total_reward:-335.5122101800707
eps:200,total_reward:-452.6678194020224
eps:210,total_reward:-404.0946033801291
eps:220,total_reward:-103.24625469870553
eps:230,total_reward:-89.47769911975224
eps:240,total_reward:-155.5899793045744
eps:2

KeyboardInterrupt: ignored