In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import gym
import time

In [3]:
# helper function to convert numpy arrays (or other array-likes) to float32 tensors
def t(x):
    """Convert ``x`` to a ``torch.float32`` tensor.

    Args:
        x: A numpy array, or any other input ``torch.as_tensor`` accepts
            (numpy scalar, Python number, nested list/tuple of numbers,
            existing tensor).

    Returns:
        torch.Tensor: ``x`` as a float32 tensor.
    """
    # as_tensor accepts more than ndarrays, so observations delivered as
    # lists or numpy scalars no longer raise a TypeError.
    return torch.as_tensor(x, dtype=torch.float32)

In [4]:
# Actor module, categorical actions only
class Actor(nn.Module):
    """Policy network: maps a state vector to probabilities over discrete actions.

    The network is a small MLP (state_dim -> 64 -> 32 -> n_actions) with Tanh
    activations and a softmax over the last axis.

    Args:
        state_dim: Size of the last axis of the input.
        n_actions: Number of discrete actions, i.e. size of the output's last axis.
    """

    def __init__(self, state_dim, n_actions):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, n_actions),
            # Normalise over the action axis explicitly. Without `dim`, PyTorch
            # warns about an implicit dimension and picks axis 0 for 3-D input,
            # which would normalise across the wrong axis.
            nn.Softmax(dim=-1)
        )

    def forward(self, X):
        """Return action probabilities for ``X``; they sum to 1 along the last axis."""
        return self.model(X)

    
    
# Critic module
class Critic(nn.Module):
    """Value network: maps a state vector to one scalar estimate.

    The network is an MLP (state_dim -> 64 -> 32 -> 1) with ReLU activations
    after the two hidden layers and a linear output.
    """

    def __init__(self, state_dim):
        super().__init__()
        # Hidden stack: each Linear is followed by a ReLU.
        sizes = (state_dim, 64, 32)
        hidden = []
        for n_in, n_out in zip(sizes, sizes[1:]):
            hidden += [nn.Linear(n_in, n_out), nn.ReLU()]
        # Linear head with a single output unit and no activation.
        self.model = nn.Sequential(*hidden, nn.Linear(32, 1))

    def forward(self, X):
        """Return the network output for ``X``; the last axis has size 1."""
        value = self.model(X)
        return value

In [5]:
# Environment to train on; swap in the commented-out id to train on that task instead.
env_id = "CartPole-v1"
# env_id = "LunarLander-v2"
env = gym.make(env_id)

In [6]:
# config
gamma = 0.99  # discount factor applied to the bootstrapped next-state value

# Network input/output sizes are read from the environment's spaces.
obs_shape = env.observation_space.shape
state_dim = obs_shape[0]
print('state space dim: {}'.format(state_dim))
n_actions = env.action_space.n
print('action space dim: {}'.format(n_actions))

# Policy and value networks (actor is built first), each with its own Adam optimizer.
actor, critic = Actor(state_dim, n_actions), Critic(state_dim)
adam_actor, adam_critic = (
    torch.optim.Adam(net.parameters(), lr=1e-3) for net in (actor, critic)
)

# Optional warm start from previously saved weights (disabled):
# actor.load_state_dict(torch.load('./actor_{}.pth'.format('LunarLander'), map_location=torch.device('cpu')))
# critic.load_state_dict(torch.load('./critic_{}.pth'.format('LunarLander'), map_location=torch.device('cpu')))

state space dim: 4
action space dim: 2


We can use the vanilla policy gradient method:

$$ \nabla_\theta J(\theta) = \frac{1}{N} \sum^{N}_{i=1} \sum^{T}_{t=1} \nabla_\theta\log\pi_\theta(a_{i,t}|s_{i,t})\Bigg(\sum^{T}_{t^\prime=1}r(s_{i,t^\prime}, a_{i,t^\prime})\Bigg) $$


By definition, $$ Q_w(s_t, a_t) = r_{t+1} + \gamma V(s_{t+1})$$
is the total (discounted) reward expected if action $a_t$ is taken in state $s_t$ at time $t$. The first term of this equation is already "known": given a state and an action at time $t$, the reward is determined by the environment.

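$V(s_{t+1})$, on the other hand, is uncertain: the future actions have not been chosen yet, so the future rewards are unknown. In other words, we can only _estimate_ $V(s)$, and that estimate is what the critic network learns. Subtracting the critic's estimate for the current state gives the advantage used in the training loop below: $A(s_t, a_t) = r_{t+1} + \gamma V(s_{t+1}) - V(s_t)$.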




In [7]:
# Per-episode history: total reward, summed actor loss (TAL), summed critic loss (TCL).
episode_rewards, TAL, TCL = [], [], []

n_episodes = 500  # number of training episodes

for i in range(n_episodes):
    done = False
    total_reward, total_critic_loss, total_actor_loss = 0, 0, 0

    state = env.reset()

    tt = time.time()  # start time of this episode, used for the elapsed-time report
    while not done:
        # Sample an action from the current policy.
        probs = actor(t(state))
        dist = torch.distributions.Categorical(probs=probs)
        action = dist.sample()

        # Hand the environment a plain Python int rather than a 0-d numpy
        # array obtained through the legacy `.data` attribute.
        next_state, reward, done, info = env.step(action.item())

        # One-step TD error, used as the advantage estimate (see the markdown
        # cell above): r + gamma * V(s') - V(s). The (1-done) factor drops the
        # bootstrap term on the step that ends the episode.
        # NOTE(review): V(s') is not detached, so the critic loss below
        # back-propagates through both value estimates; confirm whether a
        # detached (semi-gradient) target was intended.
        advantage = reward + (1-done)*gamma*critic(t(next_state)) - critic(t(state))

        total_reward += reward
        state = next_state

        # Critic update: minimise the squared TD error.
        critic_loss = advantage.pow(2).mean()
        total_critic_loss += critic_loss.item()
        adam_critic.zero_grad()
        critic_loss.backward()
        adam_critic.step()

        # Actor update: policy-gradient step weighted by the advantage. The
        # advantage is detached so no gradient reaches the critic from here.
        actor_loss = -dist.log_prob(action)*advantage.detach()
        total_actor_loss += actor_loss.item()
        adam_actor.zero_grad()
        actor_loss.backward()
        adam_actor.step()

    TAL.append(total_actor_loss)
    TCL.append(total_critic_loss)
    print('Episode: {}\t RWD: {:.2f}\t TCL: {:.2f}\t TAL: {:.2f}\t elapsed {:.2f} s.'.format(
        i,
        total_reward,
        total_critic_loss,
        total_actor_loss,
        time.time()-tt))
    episode_rewards.append(total_reward)

Episode: 0	 RWD: 21.00	 TCL: 20.91	 TAL: 14.74	 elapsed 0.06 s.
Episode: 1	 RWD: 16.00	 TCL: 15.72	 TAL: 10.95	 elapsed 0.03 s.
Episode: 2	 RWD: 20.00	 TCL: 19.70	 TAL: 13.21	 elapsed 0.03 s.
Episode: 3	 RWD: 21.00	 TCL: 20.61	 TAL: 15.39	 elapsed 0.03 s.
Episode: 4	 RWD: 43.00	 TCL: 42.44	 TAL: 30.87	 elapsed 0.07 s.
Episode: 5	 RWD: 16.00	 TCL: 15.46	 TAL: 10.77	 elapsed 0.02 s.
Episode: 6	 RWD: 12.00	 TCL: 11.43	 TAL: 8.16	 elapsed 0.02 s.
Episode: 7	 RWD: 13.00	 TCL: 12.34	 TAL: 8.03	 elapsed 0.03 s.
Episode: 8	 RWD: 20.00	 TCL: 19.15	 TAL: 13.24	 elapsed 0.04 s.
Episode: 9	 RWD: 14.00	 TCL: 13.11	 TAL: 9.04	 elapsed 0.02 s.
Episode: 10	 RWD: 16.00	 TCL: 14.97	 TAL: 10.24	 elapsed 0.03 s.
Episode: 11	 RWD: 15.00	 TCL: 13.89	 TAL: 9.25	 elapsed 0.02 s.
Episode: 12	 RWD: 21.00	 TCL: 19.66	 TAL: 13.31	 elapsed 0.03 s.
Episode: 13	 RWD: 27.00	 TCL: 25.42	 TAL: 17.51	 elapsed 0.04 s.
Episode: 14	 RWD: 22.00	 TCL: 20.48	 TAL: 13.39	 elapsed 0.04 s.
Episode: 15	 RWD: 13.00	 TCL: 11.29	 TA

KeyboardInterrupt: 

In [None]:
# Training curves: one panel each for reward, actor loss and critic loss per episode.
fig, (ax_reward, ax_actor, ax_critic) = plt.subplots(1, 3, figsize=(15, 5))

ax_reward.scatter(np.arange(len(episode_rewards)), episode_rewards, s=2)
ax_reward.set_title("Total reward per episode (online)")
ax_reward.set_ylabel("reward")
ax_reward.set_xlabel("episode")

ax_actor.plot(TAL)
ax_actor.set_title('Total Actor Loss')
ax_actor.set_xlabel("episode")

ax_critic.plot(TCL)
ax_critic.set_title('Total Critic Loss')
ax_critic.set_xlabel("episode")

# Tag the checkpoints with the environment that was actually trained on
# (e.g. "CartPole-v1" -> "CartPole", "LunarLander-v2" -> "LunarLander").
# The previous hard-coded 'LunarLander' tag mislabelled weights trained on
# any other environment.
env_tag = env.spec.id.rsplit('-', 1)[0]
torch.save(actor.state_dict(), './actor_{}.pth'.format(env_tag))
torch.save(critic.state_dict(), './critic_{}.pth'.format(env_tag))