# **一个Actor-Critic算法的（批处理方式）简单实现**
## **算法概述**
- 采用批输入数据的方式加速训练;
- 批输入的数据包括TD误差以及log概率;
- 采用TD误差进行批处理训练过程及其不稳定，很难训起来。

In [58]:
import copy
import platform

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [59]:
if platform.system() == "Darwin":
    PYTORCH_ENABLE_MPS_FALLBACK=1
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
value_method = "TD_error"
all_value_method = ["Q_value","V_value","TD_error","Advantage"]
assert value_method in all_value_method, "You choose a wrong value_method!"

In [60]:
class ActorCriticTrainer(nn.Module):
    def __init__(self, env):
        super(ActorCriticTrainer, self).__init__()
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.create_training_network()
        self.create_training_method()
        self.GAMMA = 0.9
        self.to(device)

        self.state_batch = []
        self.action_batch = []
        self.reward_batch = []
        self.next_state_batch = []

    def create_training_network(self):
        self.fc = nn.Linear(self.state_dim, 20)
        self.critic = nn.Sequential(self.fc,nn.ReLU(),nn.Linear(20,1))
        self.actor = nn.Sequential(self.fc,nn.ReLU(),nn.Linear(20, self.action_dim))

    def create_training_method(self):
        self.optim = optim.Adam(self.parameters(),lr=0.001)
        # self.value_loss = nn.MSELoss()
        # self.actor_loss = nn.LogSoftmax(dim=-1)
    
    def choose_action(self, state):
        with torch.no_grad():
            state = torch.tensor(state, device=device)
            action_probs = F.softmax(self.actor(state), dim=-1)
            action = torch.multinomial(action_probs, 1).item()
            return action

    def calculate_batch_td_error(self, state_batch:torch.Tensor, reward_batch:torch.Tensor, done:bool):
        value_batch = self.critic(state_batch)
        values = value_batch[:-1,:].squeeze(-1)
        next_values = value_batch[1:,:].squeeze(-1)
        if done:
            next_values[-1] = 0.
        td_errors = reward_batch + self.GAMMA * next_values - values
        return td_errors

    def calculate_policy_loss(self, state_batch, action_batch, td_errors):
        action_logits_batch = self.actor(state_batch)
        log_probs = torch.log(F.softmax(action_logits_batch, dim=-1))
        action_log_probs = torch.gather(log_probs,1,action_batch.unsqueeze(-1)).squeeze(-1)
        policy_loss = action_log_probs * torch.abs(td_errors)
        return policy_loss
    
    def calculate_batch_advantage(self, state_batch:torch.Tensor, reward_batch:torch.Tensor):
        value_batch = self.critic(state_batch)
        pass

    def perceive(self, state, action, reward, next_state):
        if len(self.state_batch) != 0:
            self.state_batch.pop(-1)
        self.state_batch.append(state)
        self.action_batch.append(action)
        self.reward_batch.append(reward)
        self.state_batch.append(next_state)
        
    def train_loop(self, done):
        state_batch = torch.tensor(self.state_batch, device=device)
        action_batch = torch.tensor(self.action_batch, device=device)
        reward_batch = torch.tensor(self.reward_batch, device=device)

        td_errors = self.calculate_batch_td_error(state_batch, reward_batch, done)
        value_loss = torch.square(td_errors).mean()
        policy_loss = self.calculate_policy_loss(state_batch, action_batch, td_errors.detach()).mean()
        loss = value_loss - policy_loss
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        
    def clear_list(self):
        self.state_batch.clear()
        self.action_batch.clear()
        self.reward_batch.clear()
        self.next_state_batch.clear()

In [61]:
import gym
env_name = "CartPole-v1"
env = gym.make(env_name)
agent = ActorCriticTrainer(env)

In [62]:
import time
def main():
    start_time = time.time()
    for episode in range(3000):
        state, _ = env.reset()
        for step in range(300):
            action = agent.choose_action(state)
            next_state, reward, done, _, _ = env.step(action)
            reward = -1 if done else 0.01
            agent.perceive(state, action, reward, next_state)
            state = next_state
            if done:
                break
        agent.train_loop(done)
        agent.clear_list()
        if episode % 100 == 0 and episode != 0:
            total_reward = 0
            for i in range(10):
                state, _ = env.reset()
                for step in range(300):
                    action = agent.choose_action(state)
                    next_state, reward, done, _, _ = env.step(action)
                    total_reward += reward
                    state = next_state
                    if done:
                        break
            print(f"episode {episode} total reward is {total_reward/10}")
    end_time = time.time()
    print(f"total time is {end_time - start_time}")

In [63]:
if __name__ == "__main__":
    main()

episode 100 total reward is 22.0
episode 200 total reward is 26.4
episode 300 total reward is 40.4
episode 400 total reward is 42.7
episode 500 total reward is 54.3
episode 600 total reward is 71.8
episode 700 total reward is 56.1
episode 800 total reward is 61.2
episode 900 total reward is 56.8
episode 1000 total reward is 72.9
episode 1100 total reward is 62.8
episode 1200 total reward is 73.8
episode 1300 total reward is 74.8
episode 1400 total reward is 86.3
episode 1500 total reward is 130.1
episode 1600 total reward is 140.4
episode 1700 total reward is 102.8
episode 1800 total reward is 227.9
episode 1900 total reward is 173.8
episode 2000 total reward is 181.9
episode 2100 total reward is 155.2
episode 2200 total reward is 180.2
episode 2300 total reward is 219.4
episode 2400 total reward is 235.9
episode 2500 total reward is 250.9
episode 2600 total reward is 249.8
episode 2700 total reward is 272.9
episode 2800 total reward is 245.1
episode 2900 total reward is 169.6
total ti

## 实验记录
~~2023-3-6 \
1、将td_errors取了绝对值。 \
结果：效果好了一些，因为td_errors有正有负，如果直接乘log_probs再取绝对值的话loss值会抵消，造成训练缓慢的问题。 \
假设：因为变成batch后，损失是取了平均值，但是损失大小没有改变，所以在这里取batch是无效的。 \
2023-3-7 \
2、在以上基础上将epoch调大了一些。 \
结果：确实能取得更好的效果，但是训练时间也随着变成，感觉使用TD_errors进行批处理效果不太行。 (调学习率有一点点作用)\
假设：在其他部分（模型、参数等）不变的情况下，换一种价值计算方式。~~\
3、以上实验都是将train_loop放在了if done中，相当于agent只学到了以done方式结束的trajectories，因此效果不好。\
结果：将train_loop放在step循坏外，训练效果有了大幅度提升。 \
假设：在其他部分（模型、参数等）不变的情况下，换一种价值计算方式。