# 一个A3C算法的简单实现

In [1]:
import copy
import platform

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
if platform.system() == "Darwin":
    PYTORCH_ENABLE_MPS_FALLBACK=1
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
class ActorCriticTrainer(nn.Module):
    def __init__(self, env):
        super(ActorCriticTrainer, self).__init__()
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.create_training_network()
        self.create_training_method()
        self.GAMMA = 0.9
        self.to(device)

        self.state_batch = []
        self.action_batch = []
        self.reward_batch = []
        self.next_state_batch = []

    def create_training_network(self):
        self.fc = nn.Linear(self.state_dim, 20)
        self.critic = nn.Sequential(self.fc,nn.ReLU(),nn.Linear(20,1))
        self.actor = nn.Sequential(self.fc,nn.ReLU(),nn.Linear(20, self.action_dim))

    def create_training_method(self):
        self.optim = optim.Adam(self.parameters(),lr=0.001)
        self.value_loss = nn.MSELoss()
        self.actor_loss = nn.LogSoftmax(dim=-1)
    
    def choose_action(self, state):
        with torch.no_grad():
            state = torch.tensor(state, device=device)
            action_probs = F.softmax(self.actor(state), dim=-1)
            action = torch.multinomial(action_probs, 1).item()
            return action

    def calculate_batch_td_error(self, state_batch, reward_batch):
        value_batch = self.critic(state_batch)
        values = copy.copy(value_batch).squeeze(-1)
        next_values = torch.cat((copy.copy(value_batch[1:]).squeeze(-1), torch.tensor([0.], device=device)))
        td_errors = reward_batch + self.GAMMA * next_values - values
        return td_errors

    def calculate_policy_loss(self, state_batch, action_batch, td_errors):
        action_logits_batch = self.actor(state_batch)
        log_probs = torch.log(F.softmax(action_logits_batch, dim=-1))
        action_log_probs = torch.gather(log_probs,1,action_batch.unsqueeze(-1)).squeeze(-1)
        return action_log_probs * td_errors

    def perceive(self, state, action, reward, next_state):
        self.state_batch.append(state)
        self.action_batch.append(action)
        self.reward_batch.append(reward)
        self.next_state_batch.append(next_state)

    def train_loop(self):
        state_batch = torch.tensor(self.state_batch, device=device)
        action_batch = torch.tensor(self.action_batch, device=device)
        reward_batch = torch.tensor(self.reward_batch, device=device)

        td_errors = self.calculate_batch_td_error(state_batch, reward_batch)
        value_loss = torch.square(td_errors).mean()
        policy_loss = self.calculate_policy_loss(state_batch, action_batch, td_errors.detach()).mean()
        loss = value_loss - policy_loss
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        
    def clear_list(self):
        self.state_batch.clear()
        self.action_batch.clear()
        self.reward_batch.clear()
        self.next_state_batch.clear()
        


In [4]:
import gym
env_name = "CartPole-v1"
env = gym.make(env_name)
agent = ActorCriticTrainer(env)

In [5]:
import time
def main():
    start_time = time.time()
    for episode in range(3000):
        state, _ = env.reset()
        for step in range(300):
            action = agent.choose_action(state)
            next_state, reward, done, _, _ = env.step(action)
            reward = -1 if done else 0.01
            # agent.train_loop(state, action, reward, next_state)
            agent.perceive(state,action,reward,next_state)
            state = next_state
            if done:
                agent.train_loop()
                agent.clear_list()
                break
        if episode % 100 == 0:
            total_reward = 0
            for i in range(10):
                state, _ = env.reset()
                for step in range(300):
                    action = agent.choose_action(state)
                    next_state, reward, done, _, _ = env.step(action)
                    total_reward += reward
                    state = next_state
                    if done:
                        break
            print(f"episode {episode} total reward is {total_reward/10}")
    end_time = time.time()
    print(f"total time is {end_time - start_time}")

In [6]:
if __name__ == "__main__":
    main()

  state_batch = torch.tensor(self.state_batch, device=device)


episode 0 total reward is 22.7
episode 100 total reward is 23.6
episode 200 total reward is 17.9
episode 300 total reward is 14.1
episode 400 total reward is 12.9
episode 500 total reward is 11.4
episode 600 total reward is 11.2
episode 700 total reward is 10.1
episode 800 total reward is 10.2
episode 900 total reward is 9.7
episode 1000 total reward is 9.7
episode 1100 total reward is 9.3
episode 1200 total reward is 9.7
episode 1300 total reward is 9.2
episode 1400 total reward is 9.7
episode 1500 total reward is 9.2
episode 1600 total reward is 9.4
episode 1700 total reward is 9.4
episode 1800 total reward is 9.4
episode 1900 total reward is 9.4
episode 2000 total reward is 9.4
episode 2100 total reward is 9.1
episode 2200 total reward is 9.3
episode 2300 total reward is 9.8
episode 2400 total reward is 10.1
episode 2500 total reward is 8.8
episode 2600 total reward is 9.2
episode 2700 total reward is 9.2
episode 2800 total reward is 9.2
episode 2900 total reward is 9.2
total time i