# **一个Asynchronous Advantage Actor-Critic算法的简单实现**
## **算法概述**
- 基础AC算法的分布式版本
- 分为行为策略和目标策略，行为策略用于与环境进行交互进行环境探索，目标策略用于根据行为策略的探索结果进行模型学习迭代;
- 一个比较典型的A3C范式是，单个目标策略+多个行为策略;
- 行为策略会根据收集到的数据进行loss和梯度的计算，不同行为策略计算得到的梯度会异步的传输给目标策略;
- 目标策略利用得到的梯度进行单步梯度下降，训练后的目标策略模型会定期同步给行为策略，进行新的探索。

论文链接：*https://docs.popo.netease.com/docs/26e9fa439e2b46dcad6374a043725c4a*# (AC算法的理论基础)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
import platform
if platform.system == "Darwin":
    PYTORCH_ENABLE_MPS_FALLBACK = 1
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
class A3CTrainer(nn.Module):
    def __init__(self, env):
        super(A3CTrainer, self).__init__()
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        
        self.create_training_network()
        self.create_training_method()
        
        self.to(device)

    def create_training_network(self):
        self.fc = nn.Linear(self.state_dim, 20)
        self.value = nn.Sequential(self.fc, nn.ReLU(), nn.Linear(20,1))
        self.actor = nn.Sequential(self.fc, nn.ReLU(), nn.Linear(20, self.action_dim))

    def create_training_method(self):
        self.optim = optim.Adam(self.parameters(), lr=0.001)
    
    def choose_action(self, state):
        with torch.no_grad():
            action_logits = self.actor(state)
            action_probs = F.softmax(action_logits, dim=-1)
            action = torch.multinomial(action_probs,1).item()
            return action, action_probs[action]
    
    def sync_grad(self, policy_agent:nn.Module):
        for behavior_param, target_param in zip(policy_agent.parameters(), self.parameters()):
            if target_param.grad is not None:
                break
            target_param.grad = behavior_param.grad
    
    def update_model(self):
        self.optim.step()
        self.optim.zero_grad()
        

In [4]:
class PolicyAgent(nn.Module):
    GAMMA = 0.9
    
    def __init__(self, env):
        super(PolicyAgent, self).__init__()

        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.create_policy_network()
        self.create_policy_method()
        
        self.state_batch = []
        self.action_batch = []
        self.reward_batch = []
        
        self.to(device)
        
    def create_policy_network(self):
        self.fc = nn.Linear(self.state_dim, 20)
        self.critic = nn.Sequential(self.fc, nn.ReLU(), nn.Linear(20,1))
        self.actor = nn.Sequential(self.fc, nn.ReLU(), nn.Linear(20, self.action_dim))
        
    def create_policy_method(self):
        self.optim = optim.Adam(self.parameters(), lr = 0.001)
    
    def choose_action(self, state):
        with torch.no_grad():
            state = torch.tensor(state, device=device)
            action_logits = self.actor(state)
            action_probs = F.softmax(action_logits,dim=-1)
            action = torch.multinomial(action_probs, 1)
            return action.item()

    def calculate_batch_td_error(self, state_batch:torch.Tensor, reward_batch:torch.Tensor, done:bool):
        value_batch = self.critic(state_batch)
        values = value_batch[:-1,:].squeeze(-1)
        next_values = value_batch[1:,:].squeeze(-1)
        if done:
            next_values[-1] = 0.
        td_errors = reward_batch + self.GAMMA * next_values - values
        return td_errors
    
    def calculate_policy_loss(self, state_batch, action_batch, td_errors):
        action_logits_batch = self.actor(state_batch)
        log_probs = torch.log(F.softmax(action_logits_batch, dim=-1))
        action_log_probs = torch.gather(log_probs,1,action_batch.unsqueeze(-1)).squeeze(-1)
        policy_loss = action_log_probs * td_errors
        return policy_loss
        
    def perceive(self, state, action, reward, next_state):
        if len(self.state_batch) != 0:
            self.state_batch.pop(-1)
        self.state_batch.append(state)
        self.action_batch.append(action)
        self.reward_batch.append(reward)
        self.state_batch.append(next_state)
    
    def calculate_grad(self, done):
        self.optim.zero_grad()
        state_batch = torch.tensor(self.state_batch, device=device)
        action_batch = torch.tensor(self.action_batch, device=device)
        reward_batch = torch.tensor(self.reward_batch, device=device)
        
        td_errors = self.calculate_batch_td_error(state_batch, reward_batch, done)
        
        value_loss = torch.square(td_errors).mean()
        policy_loss = self.calculate_policy_loss(state_batch, action_batch, td_errors.detach()).mean()
        loss = value_loss - policy_loss
        loss.backward()
        
    
    def sync_model(self,model_dict):
        self.load_state_dict(model_dict, strict=False)
        
    def clear_list(self):
        self.state_batch.clear()
        self.action_batch.clear()
        self.reward_batch.clear()
        

In [5]:
import gym
env_name = "CartPole-v1"
env = gym.make(env_name)

policy_agent = PolicyAgent(env)
target_agent = A3CTrainer(env)

In [6]:
def main():
    for episode in range(3000*2):
        state, _ = env.reset()
        for step in range(300):
            action = policy_agent.choose_action(state)
            next_state, reward, done, _, _ = env.step(action)
            reward = -1 if done else 0.01
            policy_agent.perceive(state, action, reward, next_state)
            state = next_state
            if done:
                break
        policy_agent.calculate_grad(done)
        target_agent.sync_grad(policy_agent)
        target_agent.update_model()
        policy_agent.sync_model(target_agent.state_dict())
        policy_agent.clear_list()
        if episode % 100 == 0 and episode !=0:
            total_reward = 0
            for i in range(10):
                state, _ = env.reset()
                for step in range(300):
                    action = policy_agent.choose_action(state)
                    next_state, reward, done, _, _ = env.step(action)
                    total_reward += reward
                    state = next_state
                    if done:
                        break
            print(f"episode {episode} total reward is {total_reward/10}")

In [7]:
if __name__ == "__main__":
    main()

  if not isinstance(terminated, (bool, np.bool8)):
  state_batch = torch.tensor(self.state_batch, device=device)


episode 100 total reward is 23.0
episode 200 total reward is 23.4
episode 300 total reward is 27.3
episode 400 total reward is 33.1
episode 500 total reward is 27.0
episode 600 total reward is 33.7
episode 700 total reward is 31.5
episode 800 total reward is 43.5
episode 900 total reward is 50.8
episode 1000 total reward is 45.6
episode 1100 total reward is 52.5
episode 1200 total reward is 57.6
episode 1300 total reward is 69.5
episode 1400 total reward is 67.0
episode 1500 total reward is 61.2
episode 1600 total reward is 69.7
episode 1700 total reward is 66.3
episode 1800 total reward is 76.1
episode 1900 total reward is 64.8
episode 2000 total reward is 83.0
episode 2100 total reward is 81.0
episode 2200 total reward is 104.3
episode 2300 total reward is 79.9
episode 2400 total reward is 97.3
episode 2500 total reward is 98.2
episode 2600 total reward is 91.9
episode 2700 total reward is 115.5
episode 2800 total reward is 124.6
episode 2900 total reward is 97.9
episode 3000 total r

## 实验记录 
2023-3-8
1. 实现了部分A3C算法，还有获取行为策略的梯度部分以及梯度同步、模型同步部分。 \

2023-3-9
1. 实现了剩余的部分。 \
结果：一个行为策略效果还不错，但是发现m1芯片在小模型训练上表现奇差。在mac上要换成cpu训练。 