# 一个A3C算法的简单实现

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## Creating ActorCritic model

In [None]:
class ActorCritic(nn.Module):
    """Convolutional actor-critic network with a recurrent (LSTM cell) core.

    Four stride-2 convolutions feed an ``nn.LSTMCell``. Two linear heads read
    the LSTM hidden state: one produces a scalar state value, the other
    produces one logit per action.
    """

    def __init__(self, num_inputs, action_space):
        """Build the layers and put the module in training mode.

        Args:
            num_inputs: Number of channels of the input observation.
            action_space: Object whose ``n`` attribute is the number of
                discrete actions.
        """
        super(ActorCritic, self).__init__()
        self.conv1 = nn.Conv2d(num_inputs, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)

        # 32 channels * 3 * 3: a 42x42 input shrinks 42 -> 21 -> 11 -> 6 -> 3
        # through the four stride-2 convolutions above.
        self.lstm = nn.LSTMCell(32 * 3 * 3, 256)

        num_outputs = action_space.n
        self.critical_linear = nn.Linear(256, 1)
        self.actor_linear = nn.Linear(256, num_outputs)

        self.train()

    def forward(self, inputs):
        """Run one step of the network.

        Args:
            inputs: A pair ``(observation, (hx, cx))`` where ``observation`` is
                the image batch and ``(hx, cx)`` is the previous LSTM state.

        Returns:
            A tuple ``(value, logits, (hx, cx))``: the critic output, the actor
            logits, and the updated LSTM state to feed into the next step.
        """
        inputs, (hx, cx) = inputs
        x = F.elu(self.conv1(inputs))
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))
        x = F.elu(self.conv4(x))

        # Flatten the conv feature map into the vector size the LSTM expects.
        x = x.view(-1, 32 * 3 * 3)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx
        # The new recurrent state is returned so the caller can carry it over.
        return self.critical_linear(x), self.actor_linear(x), (hx, cx)

## Creating Atari environments

In [None]:
import cv2
import gym
import numpy as np
from gym.spaces.box import Box

def create_atari_env(env_id):
    """Create a gym environment wrapped with this file's observation wrappers.

    Args:
        env_id: Environment id passed to ``gym.make``.

    Returns:
        The environment wrapped in ``AtariRescale42x42`` and then in
        ``NormalizedEnv``.
    """
    env = gym.make(env_id)
    # Rescale first so the normalizer sees the processed frames.
    env = AtariRescale42x42(env)
    env = NormalizedEnv(env)
    return env

    
def _process_frame42(frame):
    """Convert a frame into a single-channel 42x42 float32 array.

    The frame is cropped to rows 34..193 and the first 160 columns, resized to
    80x80 and then to 42x42, averaged over its last axis, scaled by 1/255 and
    rearranged so the channel axis comes first.

    Args:
        frame: Image array indexed as (row, column, channel).

    Returns:
        A float32 array with the channel axis moved to the front.
    """
    cropped = frame[34:194, :160]
    # Two-stage downscale: 80x80 first, then the final 42x42.
    half_size = cv2.resize(cropped, (80, 80))
    small = cv2.resize(half_size, (42, 42))
    # Average over the channel axis, keeping it as a length-1 axis.
    gray = small.mean(2, keepdims=True).astype(np.float32)
    gray *= (1.0 / 255.0)
    # Channel-last -> channel-first.
    return gray.transpose(2, 0, 1)


class AtariRescale42x42(gym.ObservationWrapper):
    """Observation wrapper that converts every frame with ``_process_frame42``."""

    def __init__(self, env=None):
        super(AtariRescale42x42, self).__init__(env)
        # Matches the channel-first (1, 42, 42) array built by _process_frame42.
        self.observation_space = Box(0.0, 1.0, [1, 42, 42])

    def observation(self, observation):
        """Hook name used by newer gym releases; delegates to ``_observation``."""
        return self._observation(observation)

    def _observation(self, observation):
        """Return the processed version of ``observation``."""
        return _process_frame42(observation)
    

class NormalizedEnv(gym.ObservationWrapper):
    """Observation wrapper that standardizes observations with running statistics.

    It keeps exponentially weighted running averages of each observation's mean
    and standard deviation (decay ``alpha``), and uses their bias-corrected
    values to shift and scale every observation.
    """

    def __init__(self, env=None):
        super(NormalizedEnv, self).__init__(env)
        self.state_mean = 0  # running average of observation.mean()
        self.state_std = 0  # running average of observation.std()
        self.alpha = 0.9999  # decay factor of the running averages
        self.num_steps = 0  # number of observations seen so far

    def observation(self, observation):
        """Hook name used by newer gym releases; delegates to ``_observation``."""
        return self._observation(observation)

    def _observation(self, observation):
        """Update the running statistics and return the normalized observation."""
        self.num_steps += 1
        self.state_mean = self.state_mean * self.alpha + observation.mean() * (1 - self.alpha)
        self.state_std = self.state_std * self.alpha + observation.std() * (1 - self.alpha)

        # The averages start at 0; dividing by (1 - alpha**t) removes that bias.
        unbiased_mean = self.state_mean / (1 - pow(self.alpha, self.num_steps))
        unbiased_std = self.state_std / (1 - pow(self.alpha, self.num_steps))

        # The small epsilon guards against division by a zero std estimate.
        return (observation - unbiased_mean) / (unbiased_std + 1e-8)

## Train

In [None]:
def ensure_shared_grads(model, shared_model):
    """Attach the local model's gradients to the shared model's parameters.

    Parameters of the two models are paired in iteration order. The function
    returns as soon as it meets a shared parameter that already has a gradient,
    leaving that parameter and all later ones untouched.

    Args:
        model: Module whose parameter gradients are the source.
        shared_model: Module whose parameters receive those gradients.
    """
    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
        # Check the parameter, not the module: nn.Module has no ``grad``.
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad
        
def train(rank, args, shared_model:nn.Module, counter, lock, optimizer=None):
    torch.manual_seed(args.seed + rank)
    
    env = create_atari_env(args.env_name)
    env.seed(args.seed + rank)
    
    model = ActorCritic(env.observation_space.shape[0], env.action_space)
    if optimizer is None:
        optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)
    
    state = env.reset()
    state = torch.from_numpy(state)
    done = True
    
    episode_length = 0
    while True:
        model.load_state_dict(shared_model.state_dict())
        if done:
            cx = torch.zeros(1, 256)
            hx = torch.zeros(1, 256)
        else:
            cx = cx.detach()
            hx = hx.detach()
        
        values = []
        log_probs = []
        rewards = []
        entropies = []
        
        for step in range(args.num_steps):
            episode_length += 1
            value, logit, (hx, cx) = model((state.unsqueeze(0),(hx, cx)))
            prob = F.softmax(logit, dim=-1)
            log_prob = F.log_softmax(logit, dim=-1)
            entropy = -(log_prob * prob).sum(1, keepdim=True)
            entropies.append(entropy)
            
            action = prob.multionmial(num_samples=1).detach()
            log_prob = log_prob.gather(1, action)
            
            state, reward, done, _ = env.step(action.numpy())
            done = done or episode_length >= args.max_episode_length
            reward = max(min(reward, 1), -1) # 把reward限制在[-1,1]内
            
            with lock:
                counter.value += 1
            
            if done:
                episode_length = 0
                state = env.reset()
            
            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)
            
            if done:
                break
                
        R = torch.zeros(1, 1)
        if not done:
            value, _, _ = model((state.unsqueeze(0), (hx, cx)))
            R = value.detach()
        
        values.append(R)
        policy_loss = 0
        value_loss = 0
        gae = torch.zeros(1, 1)
        for i in reversed(range(rewards)):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)
            
            delta_t = rewards[i] + args.gamma * values[i+1]
        