# **一个Advantage Actor-Critic算法的简单实现**
## **算法概述**
- 采用批输入数据的方式加速训练;
- 批输入的数据包括TD误差以及log概率;
- 采用TD误差进行批处理训练时，训练过程极其不稳定，很难训练起来。


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [4]:
import os
import platform

# Pick the compute device once for the whole notebook:
#   * Apple Silicon with a usable MPS backend -> "mps"
#   * otherwise CUDA when available, else CPU.
# `platform.system` is a function and must be called; comparing the function
# object itself to "Darwin" is always False.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on old PyTorch builds
if (
    platform.system() == "Darwin"
    and _mps_backend is not None
    and _mps_backend.is_available()
):
    # The CPU-fallback switch is an *environment variable*; binding a Python
    # name does not enable it. `setdefault` keeps an explicit user setting.
    # NOTE(review): PyTorch may read this flag while it is being imported; if
    # the fallback does not take effect, export it before starting Python.
    os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
    PYTORCH_ENABLE_MPS_FALLBACK = 1  # kept so existing references to the name still resolve
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
class A3CTrainer(nn.Module):
    """Global actor-critic network that owns the optimizer.

    A worker computes gradients on its own copy of the network. This class
    receives them through :meth:`sync_grad` and applies them to its own
    parameters with :meth:`update_model`.
    """

    def __init__(self, env):
        """Build the network and the optimizer for ``env``.

        Args:
            env: Environment exposing ``observation_space.shape`` and a
                discrete ``action_space.n``.
        """
        super().__init__()
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        # Build eagerly so a freshly constructed trainer is immediately usable
        # (it has parameters, an actor head and an optimizer).
        self.create_training_network()
        self.create_training_method()

    def create_training_network(self):
        """Create the shared trunk plus the value and actor heads."""
        # `fc` is the same module object inside both heads, so the first layer
        # is shared between the value estimate and the policy.
        self.fc = nn.Linear(self.state_dim, 20)
        # NOTE(review): PolicyAgent names this head `critic`, so the
        # state_dict keys differ (`value.*` here vs `critic.*` there); a
        # strict load_state_dict between the two needs a key rename.
        self.value = nn.Sequential(self.fc, nn.ReLU(), nn.Linear(20, 1))
        self.actor = nn.Sequential(self.fc, nn.ReLU(), nn.Linear(20, self.action_dim))

    def create_training_method(self):
        """Create the optimizer over all parameters of this module."""
        # `optim` is the torch.optim *module* and is not callable; an
        # optimizer class has to be instantiated.
        self.optim = optim.Adam(self.parameters(), lr=0.001)

    def choose_action(self, state):
        """Sample an action from the current policy.

        Args:
            state: A single observation (tensor or array-like of length
                ``state_dim``).

        Returns:
            Tuple ``(action, prob)``: the sampled action index as ``int`` and
            the probability the policy assigned to it as a 0-dim tensor.
        """
        with torch.no_grad():
            reference = next(self.parameters())
            # Accept array-likes and match the network's dtype/device.
            state = torch.as_tensor(state, dtype=reference.dtype, device=reference.device)
            action_logits = self.actor(state)
            action_probs = F.softmax(action_logits, dim=-1)
            action = torch.multinomial(action_probs, 1).item()
        return action, action_probs[action]

    def sync_grad(self, policy_agent: nn.Module):
        """Copy the gradients of ``policy_agent`` onto this module's parameters.

        Parameters are paired positionally, so both modules must register
        their parameters in the same order.

        Args:
            policy_agent: Module whose ``.grad`` values are mirrored here.
        """
        for behavior_param, target_param in zip(policy_agent.parameters(), self.parameters()):
            if behavior_param.grad is None:
                # Mirror "no gradient" as well, so the optimizer skips it.
                target_param.grad = None
                continue
            # Copy rather than alias: a later backward() or zero_grad() on the
            # worker must not silently change what the optimizer applies here.
            target_param.grad = behavior_param.grad.detach().clone().to(target_param.device)

    def update_model(self):
        """Apply the currently stored gradients, then clear them."""
        self.optim.step()
        # Clear so the same gradients are never applied twice if this is
        # called again before the next sync_grad().
        self.optim.zero_grad()
        

In [10]:
class PolicyAgent(nn.Module):
    """Worker actor-critic that acts, stores transitions and computes gradients.

    Transitions are collected with :meth:`perceive`; :meth:`calculate_grad`
    turns the stored batch into gradients on this module's parameters.
    """

    GAMMA = 0.9  # discount factor used in the one-step TD target

    def __init__(self, env):
        """Build the network for ``env`` and start with an empty batch.

        Args:
            env: Environment exposing ``observation_space.shape`` and a
                discrete ``action_space.n``.
        """
        super().__init__()
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.create_policy_network()
        # Inputs are created on the module-level `device`, so the weights
        # have to live there as well.
        self.to(device)

        self.state_batch = []
        self.action_batch = []
        self.reward_batch = []
        self.next_state_batch = []
        self.done_batch = []

    def create_policy_network(self):
        """Create the shared trunk plus the critic and actor heads."""
        # `fc` is the same module object inside both heads (shared first layer).
        self.fc = nn.Linear(self.state_dim, 20)
        self.critic = nn.Sequential(self.fc, nn.ReLU(), nn.Linear(20, 1))
        self.actor = nn.Sequential(self.fc, nn.ReLU(), nn.Linear(20, self.action_dim))

    def choose_action(self, state):
        """Sample an action index (``int``) for a single observation."""
        with torch.no_grad():
            # Force float32 so float64 observations match the layer weights.
            state = torch.as_tensor(state, dtype=torch.float32, device=device)
            action_logits = self.actor(state)
            action_probs = F.softmax(action_logits, dim=-1)
            return torch.multinomial(action_probs, 1).item()

    def calculate_batch_td_error(self, state_batch, reward_batch, next_batch, done_batch=None):
        """Return the one-step TD error for every transition in the batch.

        Args:
            state_batch: States, shape ``(B, state_dim)``.
            reward_batch: Rewards, shape ``(B,)``.
            next_batch: Next states, shape ``(B, state_dim)``.
            done_batch: Optional terminal flags, shape ``(B,)``; where set,
                the next-state value is dropped from the target.

        Returns:
            Tensor of shape ``(B,)``: ``r + GAMMA * V(s') - V(s)``.
        """
        # The bootstrapped target is a constant w.r.t. the parameters
        # (semi-gradient TD), so no gradient flows through V(s').
        with torch.no_grad():
            # The critic outputs (B, 1). Without the squeeze, adding the (B,)
            # rewards would broadcast to a (B, B) matrix.
            next_value = self.critic(next_batch).squeeze(-1)
            if done_batch is not None:
                next_value = next_value * (1.0 - done_batch.to(next_value.dtype))
            td_target = reward_batch + self.GAMMA * next_value
        value = self.critic(state_batch).squeeze(-1)
        return td_target - value

    def calculate_policy_loss(self, state_batch, action_batch, td_errors):
        """Return ``log pi(a|s) * td_error`` per transition, shape ``(B,)``.

        This is the quantity to *maximise*; :meth:`calculate_grad` subtracts
        its mean from the value loss.
        """
        # log_softmax is the numerically stable form of log(softmax(x)).
        log_probs = F.log_softmax(self.actor(state_batch), dim=-1)
        action_log_probs = torch.gather(log_probs, 1, action_batch.unsqueeze(-1)).squeeze(-1)
        return action_log_probs * td_errors

    def perceive(self, state, action, reward, next_state, done=False):
        """Append one transition to the batch.

        Args:
            state: Observation before the action.
            action: Action index that was taken.
            reward: Reward received.
            next_state: Observation after the action.
            done: True when ``next_state`` is terminal.
        """
        self.state_batch.append(state)
        self.action_batch.append(action)
        self.reward_batch.append(reward)
        self.next_state_batch.append(next_state)
        self.done_batch.append(done)

    def _stack_states(self, states):
        """Stack a list of observations into a float32 ``(B, state_dim)`` tensor."""
        rows = [torch.as_tensor(s, dtype=torch.float32) for s in states]
        return torch.stack(rows).to(device)

    def calculate_grad(self):
        """Compute fresh gradients from the stored batch.

        Gradients from earlier calls are cleared first. With an empty batch
        the gradients are cleared and nothing else happens.
        """
        self.zero_grad()
        if not self.state_batch:
            return

        state_batch = self._stack_states(self.state_batch)
        next_state_batch = self._stack_states(self.next_state_batch)
        action_batch = torch.as_tensor(self.action_batch, dtype=torch.long, device=device)
        reward_batch = torch.as_tensor(self.reward_batch, dtype=torch.float32, device=device)
        done_batch = torch.as_tensor(self.done_batch, dtype=torch.float32, device=device)

        td_errors = self.calculate_batch_td_error(
            state_batch, reward_batch, next_state_batch, done_batch
        )

        value_loss = torch.square(td_errors).mean()
        # The TD error is used as a fixed advantage for the actor, so it is
        # detached to keep the policy term out of the critic's gradient.
        policy_objective = self.calculate_policy_loss(
            state_batch, action_batch, td_errors.detach()
        ).mean()
        # Minimise the value loss while maximising the policy objective.
        loss = value_loss - policy_objective
        loss.backward()

    def sync_model(self, model_dict):
        """Load weights from ``model_dict`` (a ``state_dict``).

        Keys beginning with ``value.`` are mapped to ``critic.`` because
        A3CTrainer registers the same head under the name ``value``.
        """
        prefix = "value."
        remapped = {
            ("critic." + key[len(prefix):] if key.startswith(prefix) else key): tensor
            for key, tensor in model_dict.items()
        }
        self.load_state_dict(remapped)

    def clear_list(self):
        """Discard all stored transitions."""
        self.state_batch.clear()
        self.action_batch.clear()
        self.reward_batch.clear()
        self.next_state_batch.clear()
        self.done_batch.clear()
        

In [8]:
# Environment and the two agents used by main().
import gym
env_name = "CartPole-v1"
env = gym.make(env_name)

# Agent that main() asks for actions.
policy_agent = PolicyAgent(env)
# Trainer instance built for the same environment.
target_agent = A3CTrainer(env)

In [1]:
def main():
    """Run the training loop on the module-level ``env``.

    Each epoch the worker (``policy_agent``) plays one episode of at most 300
    steps and stores the transitions. It then computes gradients on that
    batch; ``target_agent`` applies them and the worker reloads the updated
    weights.
    """
    # The two networks are initialised independently; start the worker from
    # the trainer's weights so its gradients refer to the same parameters.
    policy_agent.sync_model(target_agent.state_dict())

    for epoch in range(3000):
        # reset() returns (observation, info) in the gym API whose step()
        # returns the 5-tuple unpacked below.
        state, _ = env.reset()
        for step in range(300):
            action = policy_agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            # NOTE(review): the terminal flag is not forwarded to perceive();
            # if perceive() accepts one, pass `terminated` so the critic does
            # not bootstrap from terminal states.
            policy_agent.perceive(state, action, reward, next_state)
            state = next_state
            if terminated or truncated:
                break

        # Gradients from the previous episode must not leak into this one.
        policy_agent.zero_grad()
        policy_agent.calculate_grad()
        target_agent.sync_grad(policy_agent)
        target_agent.update_model()
        policy_agent.sync_model(target_agent.state_dict())
        policy_agent.clear_list()
        
            
            

In [2]:
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

## 实验记录 
2023-3-8
1. 实现了部分A3C算法，还有获取行为策略的梯度部分以及