### 1. 환경

In [1]:
import gym

# Exploration settings: CartPole with no rendering (render_mode=None).
env_name = "CartPole-v1"
render_mode = None

# `id` is the name of gym.make's first parameter.
env = gym.make(id=env_name, render_mode=render_mode)

In [2]:
# reset() returns a (observation, info) pair in this gym API version.
observation, info = env.reset()
print(observation, info, sep="\n")

[ 0.02752061  0.0366879  -0.00013642  0.02821334]
{}


In [3]:
# Box of 4 float32 values with per-dimension bounds (see output below).
# Per the CartPole docs these are presumably cart position/velocity and
# pole angle/angular velocity — confirm against the gym documentation.
env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

In [4]:
# Discrete(2): actions are the integers 0 and 1 (push left/right per the
# CartPole docs — confirm).
env.action_space

Discrete(2)

In [5]:
# Draw a few random actions; the last one stays bound to `action_sample`
# and is used by the step() cell below.
for action_sample in (env.action_space.sample() for _ in range(3)):
    print(action_sample)

1
1
1


In [6]:
# One environment step with the last sampled action; step() returns a
# 5-tuple (observation, reward, terminated, truncated, info).
(observation, reward,
 terminated, truncated, info) = env.step(action_sample)

for line in (
    f"observation: {observation}",
    f"reward: {reward}",
    f"terminated: {terminated}",
    f"truncated: {truncated}",
    f"info: {info}",
):
    print(line)

observation: [ 0.02825437  0.2318118   0.00042785 -0.26451263]
reward: 1.0
terminated: False
truncated: False
info: {}


  if not isinstance(terminated, (bool, np.bool8)):


In [7]:
# Done exploring; close the environment before the PPO section creates its own.
env.close()

### 2. PPO

In [8]:
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical

class RolloutBuffer:
    """Per-step storage for one on-policy rollout.

    ``PPO.select_action`` appends to ``states``, ``actions``, ``log_probs`` and
    ``state_values``; the training loop appends to ``rewards`` and
    ``is_terminals``. ``PPO.update`` reads all six lists and then calls
    :meth:`clear`.
    """

    def __init__(self):
        # Filled by PPO.select_action (one tensor per call).
        self.states = []
        self.actions = []
        self.log_probs = []
        self.state_values = []
        # Filled by the environment loop (one value per step).
        self.rewards = []
        self.is_terminals = []

    def _storages(self):
        """Return every per-step list held by the buffer."""
        return (
            self.actions, self.states, self.log_probs,
            self.rewards, self.state_values, self.is_terminals,
        )

    def clear(self):
        """Empty every list in place, keeping the same list objects."""
        for storage in self._storages():
            storage.clear()

class ActorCritic(nn.Module):
    """Actor and critic networks used by PPO.

    The actor maps a batch of states to action-distribution parameters
    (logits for a discrete space, means for a continuous one); the critic maps
    the same states to a scalar state-value estimate.

    Args:
        state_dim: size of the last (feature) dimension of a state.
        action_dim: number of discrete actions, or the length of a continuous
            action vector.
        d_model: width of the two hidden layers of both networks.
        is_continuous_action_space: if True, actions come from a
            ``MultivariateNormal`` with diagonal covariance; otherwise from a
            ``Categorical``.
        action_std_init: initial std of every continuous action dimension
            (ignored for discrete spaces).
        device: device the networks and ``action_var`` are created on.
    """

    def __init__(
            self, 
            state_dim = 4, 
            action_dim = 2, 
            d_model = 64,
            is_continuous_action_space = False, 
            action_std_init = 0.6,
            device = "cpu"
        ):
        super().__init__()
        self.is_continuous_action_space = is_continuous_action_space
        self.device = device

        if is_continuous_action_space:
            self.action_dim = action_dim
            # Plain attribute, not a registered buffer: it is not part of
            # state_dict(), so after loading weights it keeps whatever value
            # __init__ / set_action_std() gave it.
            self.action_var = torch.full(
                size = (action_dim, ),
                fill_value = action_std_init ** 2,
                device = device
            )

        # Built actor first, then critic, so parameter init order (and the
        # state_dict keys actor.{0,2,4}.* / critic.{0,2,4}.*) are unchanged.
        self.actor = self._mlp(state_dim, d_model, action_dim).to(device)
        self.critic = self._mlp(state_dim, d_model, 1).to(device)

    @staticmethod
    def _mlp(in_dim, hidden_dim, out_dim):
        """Two-hidden-layer Tanh MLP: in_dim -> hidden_dim -> hidden_dim -> out_dim."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, out_dim)
        )

    def set_action_std(self, new_action_std):
        """Set the std of every continuous action dimension (no-op for discrete)."""
        if self.is_continuous_action_space:
            self.action_var = torch.full(
                size = (self.action_dim, ), 
                fill_value = new_action_std ** 2,
                device = self.device
            )

    def forward(self, *args, **kwargs):
        # Not used: call act() for rollouts and evaluate() for updates.
        raise NotImplementedError

    def _distribution(self, state):
        """Build the policy's action distribution for a batch of states."""
        if self.is_continuous_action_space:
            return MultivariateNormal(
                loc = self.actor(state), 
                covariance_matrix = torch.diag(self.action_var)
            )
        return Categorical(logits = self.actor(state))

    def act(self, state):
        """Sample one action per state.

        Args:
            state: tensor whose first dimension is the batch (assumed
                (batch, state_dim) — a 1-D state is not supported here).

        Returns:
            (action, log_prob, state_value), detached and reshaped to
            (batch, -1): action is (batch, 1) for discrete spaces and
            (batch, action_dim) for continuous ones; log_prob and
            state_value are (batch, 1).
        """
        batch_size = len(state)
        dist = self._distribution(state)

        action = dist.sample()
        log_prob = dist.log_prob(action)
        state_value = self.critic(state)

        action = action.reshape(batch_size, -1)
        log_prob = log_prob.reshape(batch_size, -1)
        state_value = state_value.reshape(batch_size, -1)
        return action.detach(), log_prob.detach(), state_value.detach()

    def evaluate(self, state, action):
        """Score previously taken actions under the current policy.

        Args:
            state: (batch, state_dim) tensor.
            action: actions as returned by :meth:`act`. For discrete spaces
                either (batch, 1) or (batch,) is accepted.

        Returns:
            (log_prob, state_value, dist_entropy), each (batch, 1) and
            attached to the autograd graph.
        """
        batch_size = len(state)
        dist = self._distribution(state)

        if self.is_continuous_action_space:
            log_prob = dist.log_prob(action)
        else:
            # Categorical.log_prob wants one index per sample.
            log_prob = dist.log_prob(action.reshape(batch_size))

        dist_entropy = dist.entropy()
        state_value = self.critic(state)

        log_prob = log_prob.reshape(batch_size, -1)
        state_value = state_value.reshape(batch_size, -1)
        dist_entropy = dist_entropy.reshape(batch_size, -1)
        return log_prob, state_value, dist_entropy
    
class PPO:
    """PPO agent with a clipped surrogate objective and Monte-Carlo returns.

    Rollouts are collected with ``policy_old`` through :meth:`select_action`,
    which records states, actions, log-probs and values in ``self.buffer``.
    The caller appends ``rewards`` and ``is_terminals`` to the same buffer.
    Each :meth:`update` optimizes ``policy`` for ``K_epochs`` passes over the
    whole buffer, copies its weights into ``policy_old`` and clears the buffer.

    Args:
        state_dim, action_dim, d_model, is_continuous_action_space,
        action_std_init, device: forwarded to both ``ActorCritic`` networks.
        K_epochs: optimization passes over the buffer per update.
        eps_clip: clipping range of the probability ratio.
        gamma: discount factor for the returns.
        lr_actor, lr_critic: Adam learning rates of the actor / critic.
        value_coef: weight of the squared value error in the loss.
        entropy_coef: weight of the entropy bonus in the loss.
    """

    def __init__(
            self,
            state_dim = 4, 
            action_dim = 2, 
            d_model = 64,
            is_continuous_action_space = False,
            action_std_init = 0.6,
            K_epochs = 80,
            eps_clip = 0.2,
            gamma = 0.99,
            lr_actor = 0.0003,
            lr_critic = 0.001,
            device = "cpu",
            value_coef = 0.5,
            entropy_coef = 0.01
    ):
        if is_continuous_action_space:
            self.action_std = action_std_init

        self.is_continuous_action_space = is_continuous_action_space
        self.K_epochs = K_epochs
        self.eps_clip = eps_clip
        self.gamma = gamma
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.device = device

        self.buffer = RolloutBuffer()

        policy_kwargs = dict(
            state_dim = state_dim, 
            action_dim = action_dim, 
            d_model = d_model,
            is_continuous_action_space = is_continuous_action_space, 
            action_std_init = action_std_init,
            device = device
        )
        # policy is trained; policy_old collects rollouts and provides the
        # "old" log-probs for the ratio. They start with identical weights.
        self.policy = ActorCritic(**policy_kwargs)
        self.policy_old = ActorCritic(**policy_kwargs)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.optimizer = torch.optim.Adam([
            {"params": self.policy.actor.parameters(), "lr": lr_actor},
            {"params": self.policy.critic.parameters(), "lr": lr_critic}
        ])

    def set_action_std(self, new_action_std):
        """Set the continuous action std on both networks (no-op for discrete)."""
        if self.is_continuous_action_space:
            self.action_std = new_action_std
            self.policy.set_action_std(new_action_std)
            self.policy_old.set_action_std(new_action_std)

    def decay_action_std(self, action_std_decay_rate, min_action_std):
        """Linearly decay the continuous action std, clamped at ``min_action_std``."""
        if self.is_continuous_action_space:
            new_action_std = round(self.action_std - action_std_decay_rate, 4)
            self.set_action_std(max(new_action_std, min_action_std))

    @torch.no_grad()
    def select_action(self, state):
        """Sample an action with ``policy_old`` and record the step in the buffer.

        Args:
            state: array-like of shape (batch, state_dim), or a single
                (state_dim,) state, which is treated as a batch of one.

        Returns:
            numpy array of actions with shape (batch, -1) (see ActorCritic.act).
        """
        state = torch.as_tensor(state, dtype = torch.float32, device = self.device)
        if state.dim() == 1:
            # ActorCritic.act takes the batch size from dim 0; a bare state
            # would otherwise be read as a batch of state_dim scalars.
            state = state.unsqueeze(0)
        action, log_prob, state_value = self.policy_old.act(state)

        self.buffer.states.append(state)
        self.buffer.actions.append(action)
        self.buffer.log_probs.append(log_prob)
        self.buffer.state_values.append(state_value)

        return action.cpu().numpy()
    
    def save(self, path):
        """Save the weights of ``policy_old`` to ``path``."""
        torch.save(self.policy_old.state_dict(), path)

    def update(self):
        """Run ``K_epochs`` PPO optimization steps on the buffered rollout.

        Returns:
            Mean, over the epochs, of the scalar training loss.
        """
        # Monte-Carlo discounted returns, computed backwards. A terminal flag
        # resets the running return so it does not leak across episodes.
        returns = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + self.gamma * discounted_reward
            returns.append(discounted_reward)
        returns.reverse()  # append + reverse instead of O(n^2) insert(0, ...)

        # Normalize the returns. The (unbiased) std of a single sample is NaN,
        # which would turn the whole loss and every weight into NaN, so a
        # one-step buffer is only centered.
        rewards = torch.tensor(returns, dtype = torch.float32, device = self.device)
        if rewards.numel() > 1:
            rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)
        else:
            rewards = rewards - rewards.mean()
        rewards = rewards.unsqueeze(1)

        # Stack the per-step tensors recorded by select_action().
        old_states = torch.cat(self.buffer.states)
        old_actions = torch.cat(self.buffer.actions)
        old_log_probs = torch.cat(self.buffer.log_probs)
        old_state_values = torch.cat(self.buffer.state_values)

        # Advantages use the critic estimates from rollout time (no gradient).
        advantages = rewards - old_state_values

        epoch_losses = []
        for _ in range(self.K_epochs):
            log_prob, state_value, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Probability ratio pi_theta / pi_theta_old.
            ratios = torch.exp(log_prob - old_log_probs)

            # Clipped surrogate objective.
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            loss = (
                -torch.min(surr1, surr2)
                + self.value_coef * (state_value - rewards) ** 2
                - self.entropy_coef * dist_entropy
            ).mean()
            epoch_losses.append(loss.item())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # The trained policy becomes the rollout policy for the next batch.
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.buffer.clear()

        return np.mean(epoch_losses)

    def load(self, path):
        """Load weights saved by :meth:`save` into both networks.

        ``map_location`` lets a checkpoint saved on another device (e.g. GPU)
        load onto ``self.device``.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(path, map_location = self.device)
        self.policy_old.load_state_dict(state_dict)
        self.policy.load_state_dict(state_dict)

In [9]:
# Smoke test: push one batch through act()/evaluate() and check the shapes.
# Then record single-state rollouts with select_action() and consume them
# with update(), so the buffer is not left full.
ppo = PPO(is_continuous_action_space = False, device = "cpu")

state = torch.randn(64, 4)
print(f"state: {state.shape}")

action, log_prob, state_value = ppo.policy.act(state)
print(f"action: {action.shape}")
print(f"log_prob: {log_prob.shape}")
print(f"state_value: {state_value.shape}")
assert action.shape == log_prob.shape == state_value.shape == (64, 1)

log_prob, state_value, dist_entropy = ppo.policy.evaluate(state, action)
print(f"log_prob: {log_prob.shape}")
print(f"state_value: {state_value.shape}")
print(f"dist_entropy: {dist_entropy.shape}")
assert log_prob.shape == state_value.shape == dist_entropy.shape == (64, 1)

for i in range(len(state)):
    s = state[i:i+1]
    ppo.select_action(s)
    ppo.buffer.rewards.append(1)
    ppo.buffer.is_terminals.append(0)
assert len(ppo.buffer.states) == len(ppo.buffer.rewards) == len(state)

# update() must return a finite mean loss and empty the buffer.
smoke_loss = ppo.update()
assert np.isfinite(smoke_loss)
assert not ppo.buffer.states and not ppo.buffer.rewards

state: torch.Size([64, 4])
action: torch.Size([64, 1])
log_prob: torch.Size([64, 1])
state_value: torch.Size([64, 1])
log_prob: torch.Size([64, 1])
state_value: torch.Size([64, 1])
dist_entropy: torch.Size([64, 1])


### 3. Train PPO

In [10]:
env_name = "CartPole-v1"
is_continuous_action_space = False

# Continuous-action settings (unused for CartPole's discrete action space).
action_std = None
min_action_std = None
action_std_decay_rate = None
action_std_decay_freq = None

max_ep_len = 400                    # episodes are cut off after this many steps
max_training_timesteps = 40000

update_timestep = max_ep_len * 4    # environment steps collected per PPO update
K_epochs = 40
eps_clip = 0.2
gamma = 0.99
lr_actor = 0.0003
lr_critic = 0.001
device = "cpu"

random_seed = 0
# Seed torch (network init, action sampling) and numpy. The environment is
# seeded on its first reset below.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

env = gym.make(env_name)
state_dim = env.observation_space.shape[0]

if is_continuous_action_space:
    action_dim = env.action_space.shape[0]
else:
    action_dim = env.action_space.n

d_model = 64

ppo = PPO(
    state_dim = state_dim, 
    action_dim = action_dim, 
    d_model = d_model,
    is_continuous_action_space = is_continuous_action_space,
    action_std_init = action_std,
    K_epochs = K_epochs,
    eps_clip = eps_clip,
    gamma = gamma,
    lr_actor = lr_actor,
    lr_critic = lr_critic,
    device = device
)

time_step = 0
print_running_reward = 0
print_running_episodes = 0
reset_seed = random_seed  # only the first reset is seeded; later ones continue its RNG stream

while time_step <= max_training_timesteps:    
    state, _ = env.reset(seed = reset_seed)
    reset_seed = None
    current_ep_reward = 0

    for t in range(1, max_ep_len + 1):
        # select action with policy
        action = ppo.select_action(state[None, :])
        state, reward, terminated, truncated, info = env.step(action.item())

        # An episode cut off at max_ep_len must also be marked as ended.
        # Otherwise PPO.update() keeps accumulating the NEXT episode's rewards
        # into this episode's final returns, because it only resets the
        # running return on a terminal flag.
        episode_over = terminated or truncated or t == max_ep_len

        # saving reward and is_terminals
        ppo.buffer.rewards.append(reward)
        ppo.buffer.is_terminals.append(episode_over)
    
        time_step += 1
        current_ep_reward += reward

        # update PPO agent
        if time_step % update_timestep == 0:
            loss = ppo.update()

            # Average only over episodes finished since the last update.
            # Print nan rather than crash if none finished.
            print_avg_reward = (
                print_running_reward / print_running_episodes
                if print_running_episodes else float("nan")
            )
            print(f"time step: {time_step}, loss: {loss:.3f}, reward: {print_avg_reward:.2f}")

            print_running_reward = 0
            print_running_episodes = 0

        # if continuous action space; then decay action std of output action distribution
        if is_continuous_action_space and time_step % action_std_decay_freq == 0:
            ppo.decay_action_std(action_std_decay_rate, min_action_std)

        # break; if the episode is over
        if episode_over:
            break

    print_running_reward += current_ep_reward
    print_running_episodes += 1

env.close()
ppo.save("model.pt")

  if not isinstance(terminated, (bool, np.bool8)):


time step: 1600, loss: 0.700, reward: 20.57
time step: 3200, loss: 0.501, reward: 28.33
time step: 4800, loss: 0.415, reward: 31.22
time step: 6400, loss: 0.580, reward: 48.68
time step: 8000, loss: 0.400, reward: 61.84
time step: 9600, loss: 0.321, reward: 86.06
time step: 11200, loss: 0.430, reward: 130.69
time step: 12800, loss: 0.463, reward: 167.44
time step: 14400, loss: 0.325, reward: 228.71
time step: 16000, loss: 0.150, reward: 201.88
time step: 17600, loss: 0.017, reward: 180.75
time step: 19200, loss: 0.414, reward: 300.40
time step: 20800, loss: 0.097, reward: 300.33
time step: 22400, loss: -0.940, reward: 348.20
time step: 24000, loss: 0.771, reward: 321.25
time step: 25600, loss: 0.361, reward: 387.25
time step: 27200, loss: 0.729, reward: 400.00
time step: 28800, loss: 0.738, reward: 393.60
time step: 30400, loss: 0.442, reward: 367.75
time step: 32000, loss: 0.174, reward: 385.50
time step: 33600, loss: 0.565, reward: 347.00
time step: 35200, loss: 0.087, reward: 384.75

### 4. Test the trained policy

In [11]:
# Watch the trained policy. Rebuild the agent with the training
# architecture, load the saved weights and play one rendered episode.
env = gym.make(env_name, render_mode = "human")

ppo = PPO(
    state_dim = state_dim,
    action_dim = action_dim,
    d_model = d_model,
    is_continuous_action_space = is_continuous_action_space,
    device = device
)
ppo.load("model.pt")

try:
    state, _ = env.reset()
    for _ in range(max_ep_len):
        action = ppo.select_action(state[None, :])
        state, reward, terminated, truncated, info = env.step(action.item())
        # Stop on either ending. A terminated env must be reset before it is
        # stepped again.
        if terminated or truncated:
            break
finally:
    # select_action() records every step in the buffer; nothing is trained here.
    ppo.buffer.clear()
    env.close()