In [1]:
#Install a pip package in the current Jupyter kernel
import sys
!{sys.executable} -m pip install torch
!{sys.executable} -m pip install gym
!{sys.executable} -m pip install gym[classic_control]



In [60]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import gym

class ActorCritic(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=16):
        super(ActorCritic, self).__init__()
        self.actor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Softmax(dim=-1)
        )
        self.critic = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state):
        action_probs = self.actor(state)
        state_value = self.critic(state)
        return action_probs, state_value

def compute_returns(rewards, gamma):
    R = 0
    returns = []
    for step in reversed(range(len(rewards))):
        R = rewards[step] + gamma * R
        returns.insert(0, R)
    return returns



In [61]:

def rollout(steps_per_batch, model, env, gamma):
    states, actions, log_probs, returns, rewards = [], [], [], [], []
    
    t = 0
    n_eps = 0
    while t < steps_per_batch:
        state, _ = env.reset()
        done = False
        eps_rewards = []

        while not done:
            state = torch.FloatTensor(state).unsqueeze(0)
            action_probs, _ = model(state)
            distribution = Categorical(action_probs)

            action = distribution.sample()
            log_prob = distribution.log_prob(action)

            next_state, reward, done, _, _ = env.step(action.item())

            eps_rewards.append(reward)
            log_probs.append(log_prob)

            states.append(state)
            actions.append(action)
            
            state = next_state

        returns.extend(compute_returns(eps_rewards, gamma))
        rewards.extend(eps_rewards)
        t += len(returns)
        n_eps += 1

    return states, actions, log_probs, returns, rewards, n_eps

"""
env = gym.make("CartPole-v1")
input_dim = env.observation_space.shape[0]
output_dim = env.action_space.n
model = ActorCritic(input_dim, output_dim)
steps_per_batch = 50

states, actions, log_probs, returns, rewards = rollout(steps_per_batch, model, env, gamma=0.99)
"""

'\nenv = gym.make("CartPole-v1")\ninput_dim = env.observation_space.shape[0]\noutput_dim = env.action_space.n\nmodel = ActorCritic(input_dim, output_dim)\nsteps_per_batch = 50\n\nstates, actions, log_probs, returns, rewards = rollout(steps_per_batch, model, env, gamma=0.99)\n'

In [63]:


def train_ppo(
        env_name="CartPole-v1", gamma=0.7, lr=1e-6,
        clip_epsilon=0.2,epochs=400, num_updates=10,
        steps_per_batch=2048):
    env = gym.make(env_name)
    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n

    model = ActorCritic(input_dim, output_dim)
    policy_optimizer = optim.Adam(model.actor.parameters(), lr=lr)
    value_optimizer = optim.Adam(model.critic.parameters(), lr=lr)
        
    for epoch in range(epochs):
        states, actions, log_probs, returns, rewards, n_eps = rollout(steps_per_batch, model, env, gamma)
    
        states = torch.stack(states) 
        actions = torch.stack(actions)
        returns = torch.FloatTensor(returns).unsqueeze(0)
        log_probs = torch.FloatTensor(log_probs).unsqueeze(0)
        for _ in range(5):
            values = model.critic(states).squeeze().unsqueeze(0)

            cur_action_probs = model.actor(states)
            dist = Categorical(cur_action_probs)
            cur_log_probs = dist.log_prob(actions)

            advantage = returns - values.detach()
            advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-05)

            ratio = (cur_log_probs - log_probs.detach()).exp()
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
            policy_loss = -torch.min(surr1, surr2).mean()

            policy_optimizer.zero_grad()
            policy_loss.backward(retain_graph=True)
            policy_optimizer.step()

            value_loss = nn.MSELoss(reduction="mean")(returns, values)

            value_optimizer.zero_grad()
            value_loss.backward()
            value_optimizer.step()

        if epoch % num_updates == 0:
            print("Epoch: {}, Eps count: {} Reawrds: {:.4f} Policy Loss: {:.4f}, Value Loss: {:.4f}".format(epoch, n_eps, sum(rewards), policy_loss.item(), value_loss.item()))

    env.close()

train_ppo()

Epoch: 0, Eps count: 13 Reawrds: 297.0000 Policy Loss: -0.0032, Value Loss: 10.5298
Epoch: 10, Eps count: 13 Reawrds: 366.0000 Policy Loss: -0.0034, Value Loss: 10.8348
Epoch: 20, Eps count: 11 Reawrds: 343.0000 Policy Loss: 0.0002, Value Loss: 10.9143
Epoch: 30, Eps count: 15 Reawrds: 318.0000 Policy Loss: -0.0025, Value Loss: 10.1740
Epoch: 40, Eps count: 14 Reawrds: 332.0000 Policy Loss: -0.0049, Value Loss: 10.3904
Epoch: 50, Eps count: 14 Reawrds: 277.0000 Policy Loss: -0.0011, Value Loss: 10.0351
Epoch: 60, Eps count: 14 Reawrds: 294.0000 Policy Loss: -0.0024, Value Loss: 10.0615
Epoch: 70, Eps count: 13 Reawrds: 332.0000 Policy Loss: -0.0013, Value Loss: 10.6247
Epoch: 80, Eps count: 14 Reawrds: 282.0000 Policy Loss: -0.0070, Value Loss: 10.1615
Epoch: 90, Eps count: 14 Reawrds: 302.0000 Policy Loss: 0.0026, Value Loss: 10.2234
Epoch: 100, Eps count: 14 Reawrds: 394.0000 Policy Loss: -0.0057, Value Loss: 10.8770
Epoch: 110, Eps count: 15 Reawrds: 251.0000 Policy Loss: -0.0036, V

In [8]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np

env = gym.make("CartPole-v1", render_mode="rgb_array")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

print(env.observation_space.shape)
print(env.action_space)

(4,)
Discrete(2)


In [10]:
import math
math.log(0.5631)

-0.5742980467215641