# **一个Advantage Actor-Critic算法的简单实现**
## **算法概述**
- AC算法的改进版，将TD_error改成了Advantage;

In [24]:
import copy
import platform

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [25]:
if platform.system() == "Darwin":
    PYTORCH_ENABLE_MPS_FALLBACK=1
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
value_method = "Advantage"
all_value_method = ["Q_value","V_value","TD_error","Advantage"]
assert value_method in all_value_method, "You choose a wrong value_method!"

In [26]:
class A2CTrainer(nn.Module):
    def __init__(self, env):
        super(A2CTrainer, self).__init__()
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.create_training_network()
        self.create_training_method()
        self.GAMMA = 0.9
        self.to(device)

        self.state_batch = []
        self.action_batch = []
        self.reward_batch = []
        self.next_state_batch = []

    def create_training_network(self):
        self.fc = nn.Linear(self.state_dim, 20)
        self.critic = nn.Sequential(self.fc,nn.ReLU(),nn.Linear(20,1))
        self.actor = nn.Sequential(self.fc,nn.ReLU(),nn.Linear(20, self.action_dim))

    def create_training_method(self):
        self.optim = optim.Adam(self.parameters(),lr=0.01)
    
    def choose_action(self, state):
        with torch.no_grad():
            state = torch.tensor(state, device=device)
            action_probs = F.softmax(self.actor(state), dim=-1)
            action = torch.multinomial(action_probs, 1).item()
            return action

    def calculate_batch_advantage(self, state_batch:torch.Tensor, reward_batch:torch.Tensor, done:bool):
        value_batch = self.critic(state_batch)
        values = value_batch[:-1,:].squeeze(-1)
        next_values = value_batch[1:,:].squeeze(-1)
        if done:
            next_values[-1] = 0.

        count = next_values.shape[0]
        q_eval_list = [next_values[-1]]
        for idx in range(count):
            q_eval = reward_batch[-(idx+1)] + self.GAMMA * q_eval_list[idx]
            q_eval_list.append(q_eval)
        q_eval_list.reverse()
        q_evals = torch.tensor(q_eval_list[1:], device=device)
        advantage = q_evals - values
        return advantage

    def calculate_policy_loss(self, state_batch, action_batch, advantages):
        action_logits_batch = self.actor(state_batch)
        log_probs = torch.log(F.softmax(action_logits_batch, dim=-1))
        action_log_probs = torch.gather(log_probs,1,action_batch.unsqueeze(-1)).squeeze(-1)
        policy_loss = action_log_probs * advantages
        return policy_loss

    def perceive(self, state, action, reward, next_state):
        if len(self.state_batch) != 0:
            self.state_batch.pop(-1)
        self.state_batch.append(state)
        self.action_batch.append(action)
        self.reward_batch.append(reward)
        self.state_batch.append(next_state)
        
    def train_loop(self, done):
        state_batch = torch.tensor(self.state_batch, device=device)
        action_batch = torch.tensor(self.action_batch, device=device)
        reward_batch = torch.tensor(self.reward_batch, device=device)

        advantage = self.calculate_batch_advantage(state_batch, reward_batch, done)
        value_loss = torch.square(advantage).mean()
        policy_loss = self.calculate_policy_loss(state_batch, action_batch, advantage.detach()).mean()
        loss = value_loss - policy_loss
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        
    def clear_list(self):
        self.state_batch.clear()
        self.action_batch.clear()
        self.reward_batch.clear()
        self.next_state_batch.clear()

In [27]:
import gym
env_name = "CartPole-v1"
env = gym.make(env_name)
agent = A2CTrainer(env)

In [28]:
import time
def main():
    start_time = time.time()
    print("Start training...")
    for episode in range(3000):
        state, _ = env.reset()
        for step in range(300):
            action = agent.choose_action(state)
            next_state, reward, done, _, _ = env.step(action)
            reward = -1 if done else 0.01
            agent.perceive(state, action, reward, next_state)
            state = next_state
            if done:
                break
        agent.train_loop(done)
        agent.clear_list()
        if episode % 100 == 0 and episode != 0:
            total_reward = 0
            for i in range(10):
                state, _ = env.reset()
                for step in range(300):
                    action = agent.choose_action(state)
                    next_state, reward, done, _, _ = env.step(action)
                    total_reward += reward
                    state = next_state
                    if done:
                        break
            print(f"episode {episode} total reward is {total_reward/10}")
    end_time = time.time()
    print(f"total time is {end_time - start_time}")
    print("End training.")

In [29]:
if __name__ == "__main__":
    main()

Start training...
episode 100 total reward is 78.4
episode 200 total reward is 264.4
episode 300 total reward is 241.2
episode 400 total reward is 245.6
episode 500 total reward is 263.5
episode 600 total reward is 242.6
episode 700 total reward is 244.4
episode 800 total reward is 256.9
episode 900 total reward is 263.1
episode 1000 total reward is 239.0
episode 1100 total reward is 226.3
episode 1200 total reward is 226.8
episode 1300 total reward is 243.4
episode 1400 total reward is 300.0
episode 1500 total reward is 290.8
episode 1600 total reward is 283.4
episode 1700 total reward is 291.6
episode 1800 total reward is 280.3
episode 1900 total reward is 292.7
episode 2000 total reward is 300.0
episode 2100 total reward is 298.3
episode 2200 total reward is 281.4
episode 2300 total reward is 295.4
episode 2400 total reward is 300.0
episode 2500 total reward is 300.0
episode 2600 total reward is 256.3
episode 2700 total reward is 192.9
episode 2800 total reward is 278.8
episode 2900

## 实验记录
2023-3-7 \
~~1、将TD_errors改成多步的advantage。 \
结果：效果不太行，训不起来。~~ \
2023-3-8 \
2、上述情况是因为advantage取了绝对值导致的，把绝对值取消。 \
结果：效果有了大幅度提升，但稳定性相对较差。