# Proximal Policy Optimization with LSTM

[Reference code](https://github.com/seungeunrho/minimalRL/blob/master/ppo-lstm.py)

## Points
+ Model predicts values and probabilities based on hidden states and environment states

+ The time sequence if LSTM corresponds to a complete interation sequence

+ Actor Network handles 1 timestep, while Critic Network handles B timesteps, where B means not only the batch size, but also the length of rollouts.

+ The initial lstm state in Critic Network is the first output lstm state for 1 rollout. (Not necessarily)

## Puzzles
+ Why does the result become poor if we use empty lstm_size in the training procedure?

## 1. Import packages

In [0]:
import gym
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Categorical
import torch.nn.functional as F

## 2. Define constants

In [0]:
gamma = 0.98
lmbda = 0.95
num_epochs = 1000
num_rollouts = 20
reward_div = 100
k_epoch = 3
eps = 0.1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 3. Prepare data

In [0]:
env = gym.make("CartPole-v0")

def get_sample(env, policy):
    lstm_state = (torch.zeros((1, 1, 32), dtype=torch.float).to(device), torch.zeros((1, 1, 32), dtype=torch.float).to(device))
    done = False
    s = env.reset() # (state_size, )
    while not done:
        ss, aa, rr, s_primes, done_masks = list(), list(), list(), list(), list()
        probs = list()
        first_lstm_state = None
        for t in range(num_rollouts):
#             if first_lstm_state is None: # Also work!
#                 first_lstm_state = lstm_state
            a, lstm_state = policy.sample_action(torch.Tensor(s).to(device), lstm_state) # !!! old states should be replaced by new states
            s_prime, r, done, _ = env.step(a) # a is 0 or 1
            ss.append(s)
            aa.append(a)
            rr.append(r)
            s_primes.append(s_prime)
            done_mask = 0.0 if done else 1.0
            done_masks.append(done_mask)
            probs.append(policy.policy(torch.Tensor(s).to(device), lstm_state)[0][0][a]) # policy.policy output: (probs: (1, 2), lstm_states)
            if first_lstm_state is None:
                first_lstm_state = lstm_state
            s = s_prime
            if done:
                break
                
        sample = (torch.Tensor(ss).to(device), torch.LongTensor(aa).to(device), \
                  torch.Tensor(rr).to(device), torch.Tensor(s_primes).to(device), \
                  torch.Tensor(done_masks).to(device), torch.Tensor(probs).to(device), first_lstm_state) # !!! return the first lstm states
        yield sample

## 4. Build model

In [0]:

def GAE(advantages, gamma, lmbda):
    gae_advantages = torch.zeros_like(advantages)
    gae = 0

    for ri in reversed(range(len(advantages))):
        gae = gae * gamma * lmbda + advantages[ri]
        gae_advantages[ri] = gae
    return gae_advantages


class PPO(nn.Module):
    def __init__(self):
        super(PPO, self).__init__()
        
        self.fc = nn.Linear(4, 64)
        self.lstm = nn.LSTM(64, 32, batch_first=True)
        self.fc_pi = nn.Linear(32, 2)
        self.fc_v = nn.Linear(32, 1)
        
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.0005, betas=(0.9, 0.99))

    def policy(self, states, lstm_states, softmax_dim=1):
        '''
            Input:
                states: (B, state_size=4)
                lstm_states: (B, 1, hidden_size=32)
        '''
        net = F.relu(self.fc(states)) # (B, 64)
        net = net.view(1, -1, 64) # (1, timestep=B, 64)
        net, lstm_states = self.lstm(net, lstm_states) # (1, timestep=B, 32)], ((1, 1, hidden_size=32), (1, 1, hidden_size=32))
        net = self.fc_pi(net.squeeze(0)) # (B, 2)
        probs = F.softmax(net, dim=softmax_dim)
        return probs, lstm_states
        
    def sample_action(self, state, lstm_states): # state: (4,) => indicates that the fully-connected layer in PyTorch can receive inputs without batch_size
        probs, lstm_states = self.policy(state, lstm_states) # (1, 2), lstm_states
        prbos = probs.view(2) # (2, )
        m = Categorical(probs)
        a_pred = m.sample().item()
        return a_pred, lstm_states # (predicted action: 0 or 1, log of probability of current action)

    def value(self, states, lstm_states):
        '''
            Input:
                states: (B, state_size=4)
                lstm_states: (B, 1, hidden_size=32)
        '''
        net = F.relu(self.fc(states)) # (B, state_size=4)
        net = net.view(1, -1, 64) # (1, timestep=B, state_size=4)
        net, lstm_states = self.lstm(net, lstm_states) # (B, 1, hidden_size=32), lstm_states
        net = net.view(-1, 1, 32) # debug
        return self.fc_v(net.squeeze(1))
      
    def fit(self, sample): # samples: [(s1, a1, r1), (s2, a2, r2), ...]
        (s, a, r, ns, done_mask, old_probs, (h, c)) = sample
        empty_lstm_state = (torch.zeros((1, 1, 32), dtype=torch.float).to(device), torch.zeros((1, 1, 32), dtype=torch.float).to(device))
        lstm_state = (h.detach(), c.detach())
        rewards = r / reward_div # (B, num_rollouts)
        
        for i in range(k_epoch):
            td_target = (rewards + gamma * self.value(ns, lstm_state).squeeze() * done_mask).unsqueeze(1) # (num_rollouts, 1)
            vs = self.value(s, lstm_state) # (num_rollouts, 1)
            advantages = td_target - vs # (num_rollouts, 1)

            advantages = GAE(advantages, gamma, lmbda).detach() # !!! detach the advantages
            
            
#             probs, _ = self.policy(s, lstm_state, softmax_dim=1) # (num_rollouts, action_size=2) # Also work!
            probs, _ = self.policy(s, empty_lstm_state, softmax_dim=1) # (num_rollouts, action_size=2)
  
            probs = probs.gather(1, a.unsqueeze(1)) # (num_rollouts, 1)
            
            ratio = torch.exp(torch.log(probs) - torch.log(old_probs.unsqueeze(1))) # (num_rollouts, 1) !!! tensor with size of (20 ,1) minus that of (20,) will produce (20, 20) tensor
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
            
            loss = torch.mean(-torch.min(surr1, surr2) +  F.smooth_l1_loss(vs, td_target.detach()))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        
ppo = PPO().to(device)

## 5. Train

In [5]:
score = 0.0

for epoch in range(num_epochs):
    sample_iter = get_sample(env, ppo)
    for sample in sample_iter:
        ppo.fit(sample)
        rewards = sample[2]
        score += sum(rewards)
        
    if epoch % 20 == 0:
        print('Epoch %d || Average Score: %.6f'%(epoch, score / (epoch + 1)))

Epoch 0 || Average Score: 15.000000
Epoch 20 || Average Score: 27.904762
Epoch 40 || Average Score: 29.999998
Epoch 60 || Average Score: 38.983604
Epoch 80 || Average Score: 46.135803
Epoch 100 || Average Score: 65.970299
Epoch 120 || Average Score: 72.512390
Epoch 140 || Average Score: 83.588654
Epoch 160 || Average Score: 96.093170
Epoch 180 || Average Score: 98.922653
Epoch 200 || Average Score: 108.064674
Epoch 220 || Average Score: 116.058830
Epoch 240 || Average Score: 121.356850
Epoch 260 || Average Score: 125.390800
Epoch 280 || Average Score: 127.925262
Epoch 300 || Average Score: 129.617935
Epoch 320 || Average Score: 132.392517
Epoch 340 || Average Score: 133.835770
Epoch 360 || Average Score: 134.318558
Epoch 380 || Average Score: 135.763779
Epoch 400 || Average Score: 138.456360
Epoch 420 || Average Score: 136.372925
Epoch 440 || Average Score: 137.934250
Epoch 460 || Average Score: 140.121475
Epoch 480 || Average Score: 142.465698
Epoch 500 || Average Score: 144.762482
Ep