# 一个A3C算法的简单实现

In [7]:
import os
import platform

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [8]:
# Pick the compute device once for the whole file.
if platform.system() == "Darwin":
    # A bare `PYTORCH_ENABLE_MPS_FALLBACK=1` only binds a Python variable; the
    # switch PyTorch looks for is an environment variable, so export it here.
    # NOTE(review): PyTorch may read this variable only while it is being
    # imported -- if the CPU fallback does not kick in, set it in the shell
    # (or before `import torch`) instead.
    os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
    PYTORCH_ENABLE_MPS_FALLBACK = 1  # kept so existing references still resolve
    # Not every Mac / PyTorch build has a usable MPS backend; fall back to CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

In [9]:
class ActorCriticTrainer(nn.Module):
    """One-step (online) actor-critic agent for a discrete-action environment.

    The actor and the critic share their first linear layer (``self.fc``); each
    puts its own output head on top of it.  Every call to :meth:`train_loop`
    performs one gradient step on a single transition.

    Args:
        env: Environment object.  Only ``env.observation_space.shape[0]`` and
            ``env.action_space.n`` are read.
        gamma: Discount factor used in the TD target (stored as ``GAMMA``).
        lr: Learning rate handed to Adam.
        hidden_dim: Width of the shared hidden layer.
    """

    def __init__(self, env, gamma=0.9, lr=0.001, hidden_dim=20):
        super().__init__()
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.GAMMA = gamma
        self.lr = lr
        self.hidden_dim = hidden_dim
        self.create_training_network()
        # Place the parameters on the module-level `device` first, then build
        # the optimizer over the parameters in their final location.
        self.to(device)
        self.create_training_method()

    def create_training_network(self):
        """Build the shared layer plus the critic (value) and actor (logit) heads."""
        self.fc = nn.Linear(self.state_dim, self.hidden_dim)
        self.critic = nn.Sequential(
            self.fc, nn.ReLU(), nn.Linear(self.hidden_dim, 1)
        )
        self.actor = nn.Sequential(
            self.fc, nn.ReLU(), nn.Linear(self.hidden_dim, self.action_dim)
        )

    def create_training_method(self):
        """Create the Adam optimizer over all parameters of this module."""
        self.optim = optim.Adam(self.parameters(), lr=self.lr)
        # Not used by the methods of this class; kept as public attributes.
        self.value_loss = nn.MSELoss()
        self.actor_loss = nn.LogSoftmax(dim=-1)

    def _to_tensor(self, data, dtype=torch.float32):
        """Convert ``data`` to a ``dtype`` tensor on the device of the network."""
        # Using the parameters' own device keeps inputs and weights together
        # even if the module is moved after construction.
        return torch.as_tensor(data, dtype=dtype, device=self.fc.weight.device)

    def choose_action(self, state):
        """Sample an action index from the current policy.

        Args:
            state: Observation (array-like of ``state_dim`` numbers).

        Returns:
            The sampled action as a Python ``int``.
        """
        with torch.no_grad():
            logits = self.actor(self._to_tensor(state))
            action_probs = F.softmax(logits, dim=-1)
            return torch.multinomial(action_probs, 1).item()

    def calculate_td_error(self, state, reward, next_state, done=False):
        """Return the one-step TD error ``r + GAMMA * V(s') - V(s)``.

        The bootstrap value ``V(s')`` is treated as a constant target, so
        gradients flow only through ``V(s)``.

        Args:
            state: Tensor holding the current observation.
            reward: Reward tensor (or number) for the transition.
            next_state: Tensor holding the next observation.
            done: When true, ``V(s')`` is taken to be zero (no bootstrapping).

        Returns:
            Tensor with the TD error, shaped like the critic output.
        """
        value = self.critic(state)
        if done:
            next_value = torch.zeros_like(value)
        else:
            with torch.no_grad():
                next_value = self.critic(next_state)
        return reward + self.GAMMA * next_value - value

    def calculate_policy_loss(self, state, action, td_error):
        """Return ``log pi(action | state) * td_error``.

        Despite the name this is the quantity to *maximize*; ``train_loop``
        subtracts it from the value loss.

        Args:
            state: Tensor holding the observation.
            action: Integer (long) tensor with the index of the taken action.
            td_error: Weight for the log-probability (number or tensor that
                should not carry gradients).

        Returns:
            Tensor with the weighted log-probability of ``action``.
        """
        # log_softmax avoids the log(0) that log(softmax(x)) can produce.
        log_probs = F.log_softmax(self.actor(state), dim=-1)
        action_log_probs = log_probs.gather(-1, action.unsqueeze(-1)).squeeze(-1)
        return action_log_probs * td_error

    def train_loop(self, state, action, reward, state_next, done=False):
        """Run one optimization step on a single transition.

        Args:
            state: Observation before the action (array-like).
            action: Index of the action that was taken.
            reward: Reward received for the transition.
            state_next: Observation after the action (array-like).
            done: Whether ``state_next`` is terminal; disables bootstrapping.

        Returns:
            The combined loss of this step as a Python ``float``.
        """
        state = self._to_tensor(state)
        next_state = self._to_tensor(state_next)
        reward = self._to_tensor(reward)
        action = self._to_tensor(action, dtype=torch.long)

        td_error = self.calculate_td_error(state, reward, next_state, done)
        value_loss = torch.square(td_error)
        # The TD error only weights the policy term; it must not receive
        # gradients through that path.
        policy_loss = self.calculate_policy_loss(state, action, td_error.detach())
        loss = (value_loss - policy_loss).sum()

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return loss.item()


In [10]:
import gym

# Build the CartPole environment and the agent that will be trained on it.
env_name = "CartPole-v1"
env = gym.make(env_name)
agent = ActorCriticTrainer(env=env)

In [11]:
def _evaluate(max_steps, eval_episodes):
    """Return the mean environment reward per episode over ``eval_episodes`` runs."""
    total_reward = 0
    for _ in range(eval_episodes):
        state, _ = env.reset()
        for _ in range(max_steps):
            action = agent.choose_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            if terminated or truncated:
                break
    return total_reward / eval_episodes


def main(episodes=3000, max_steps=300, eval_every=100, eval_episodes=10):
    """Train the module-level ``agent`` on the module-level ``env``.

    Each environment step is followed immediately by one ``agent.train_loop``
    update.  Every ``eval_every`` episodes the current policy is run for
    ``eval_episodes`` episodes and the mean environment reward is printed.

    Args:
        episodes: Number of training episodes.
        max_steps: Maximum number of steps per episode (training and evaluation).
        eval_every: Evaluation period in episodes; a value <= 0 disables it.
        eval_episodes: Episodes averaged per evaluation; a value <= 0 disables it.
    """
    run_eval = eval_every > 0 and eval_episodes > 0
    for episode in range(episodes):
        state, _ = env.reset()
        for _ in range(max_steps):
            action = agent.choose_action(state)
            next_state, _, terminated, truncated, _ = env.step(action)
            # Shaped training reward: penalize failure, small bonus otherwise.
            # The environment's own reward is only used during evaluation.
            reward = -1 if terminated else 0.01
            agent.train_loop(state, action, reward, next_state)
            state = next_state
            # A truncated (time-limited) episode is over as well and must not
            # be stepped further, but it is not penalized as a failure.
            if terminated or truncated:
                break
        if run_eval and episode % eval_every == 0:
            mean_reward = _evaluate(max_steps, eval_episodes)
            print(f"episode {episode} total reward is {mean_reward}")

    

In [12]:
# Start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()

episode 0 total reward is 20.0
episode 100 total reward is 26.9
episode 200 total reward is 16.5
episode 300 total reward is 36.0
episode 400 total reward is 39.4
episode 500 total reward is 39.0
episode 600 total reward is 48.7
episode 700 total reward is 34.9
episode 800 total reward is 79.7


KeyboardInterrupt: 