In [124]:
import gym
import ptan
import numpy as np
import argparse
from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
import torch.nn.utils as nn_utils
import torch.nn.functional as F
import torch.optim as optim
from lib import common

In [139]:
# --- A2C hyperparameters for the Pong run ---------------------------------

# Discount factor: passed to the n-step experience source and used to
# discount the bootstrapped value in unpack_batch.
GAMMA = 0.99
# Length n of the n-step return collected by the experience source.
REWARD_STEPS = 4

# Optimiser settings (Adam) and gradient-norm clipping threshold.
LEARNING_RATE = 1e-3
CLIP_GRAD = 0.1

# Weight applied to the entropy term of the loss.
ENTROPY_BETA = 0.01

# Number of transitions accumulated before each optimisation step.
BATCH_SIZE = 128
# Number of environment instances handed to the experience source.
NUM_ENVS = 2

In [135]:
class AtariA2C(nn.Module):
    """Actor-critic network: a shared conv trunk feeding a policy head and a value head.

    Args:
        input_shape: shape of one observation; ``input_shape[0]`` is used as the
            number of input channels and the full shape sizes the linear layers.
        n_actions: number of logits produced by the policy head.

    ``forward`` returns ``(policy_logits, value)`` with shapes ``(B, n_actions)``
    and ``(B, 1)``.
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()

        in_channels = input_shape[0]
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
        )

        flat_size = self._get_conv_out(input_shape)

        def make_head(out_features):
            # Two-layer MLP on the flattened conv features; both heads share this layout.
            return nn.Sequential(
                nn.Linear(flat_size, 512),
                nn.ReLU(),
                nn.Linear(512, out_features),
            )

        # Created in the same order as before (policy first) so that random
        # initialisation consumes the RNG identically.
        self.policy = make_head(n_actions)
        self.value = make_head(1)

    def _get_conv_out(self, shape):
        """Return how many features ``self.conv`` yields for a single input of ``shape``."""
        probe = torch.zeros(1, *shape)
        return self.conv(probe).numel()

    def forward(self, x):
        """Compute (policy logits, state value) for a batch ``x``.

        ``x`` is cast to float and divided by 256 before entering the conv trunk.
        """
        features = self.conv(x.float() / 256)
        flat = features.view(features.size(0), -1)
        return self.policy(flat), self.value(flat)

In [142]:
def unpack_batch(batch, net, device='cpu', gamma=None, reward_steps=None):
    """Convert a batch of n-step transitions into training tensors.

    Each element of ``batch`` must expose ``state``, ``action``, ``reward`` and
    ``last_state``; ``last_state is None`` marks a transition whose episode
    ended, so no value is bootstrapped for it.

    Args:
        batch: sequence of transitions (e.g. from ExperienceSourceFirstLast).
        net: callable whose output ``[1]`` is a ``(N, 1)`` value tensor.
        device: torch device for the returned tensors.
        gamma: discount factor; defaults to the notebook-level ``GAMMA``.
        reward_steps: n of the n-step return; defaults to ``REWARD_STEPS``.

    Returns:
        ``(states_v, actions_t, ref_vals_v)``: float32 states, int64 actions and
        float32 value targets ``reward + gamma**reward_steps * V(last_state)``
        (just ``reward`` for finished episodes).
    """
    if gamma is None:
        gamma = GAMMA
    if reward_steps is None:
        reward_steps = REWARD_STEPS

    states, actions, rewards = [], [], []
    not_done_idx, last_states = [], []
    for idx, exp in enumerate(batch):
        # np.asarray rather than np.array(..., copy=False): on NumPy >= 2 the
        # latter raises ValueError whenever a copy is unavoidable.
        states.append(np.asarray(exp.state))
        actions.append(int(exp.action))
        rewards.append(exp.reward)
        if exp.last_state is not None:
            not_done_idx.append(idx)
            last_states.append(np.asarray(exp.last_state))

    states_v = torch.as_tensor(np.asarray(states), dtype=torch.float32, device=device)
    actions_t = torch.as_tensor(actions, dtype=torch.int64, device=device)

    rewards_np = np.array(rewards, dtype=np.float32)
    if not_done_idx:
        last_states_v = torch.as_tensor(
            np.asarray(last_states), dtype=torch.float32, device=device)
        # Bootstrap values are fixed targets: no autograd graph is needed.
        with torch.no_grad():
            last_vals_v = net(last_states_v)[1]
        last_vals_np = last_vals_v.detach().cpu().numpy()[:, 0]
        rewards_np[not_done_idx] += last_vals_np * (gamma ** reward_steps)

    ref_vals_v = torch.from_numpy(rewards_np).to(device)
    return states_v, actions_t, ref_vals_v

In [143]:
# Train A2C on Pong until the reward tracker reports the stop reward (18).
# Fall back to CPU instead of failing when CUDA is not available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
make_env = lambda: ptan.common.wrappers.wrap_dqn(gym.make('PongNoFrameskip-v4'))
envs = [make_env() for _ in range(NUM_ENVS)]
writer = SummaryWriter(comment='-pong-a2c')
net = AtariA2C(envs[0].observation_space.shape, envs[0].action_space.n).to(device)
print(net)
agent = ptan.agent.PolicyAgent(lambda x: net(x)[0], apply_softmax=True, device=device)
exp_source = ptan.experience.ExperienceSourceFirstLast(
    envs, agent, gamma=GAMMA, steps_count=REWARD_STEPS)
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE, eps=1e-3)
batch = []

with common.RewardTracker(writer, stop_reward=18) as tracker:
    with ptan.common.utils.TBMeanTracker(writer, batch_size=10) as tb_tracker:
        for step_idx, exp in enumerate(exp_source):
            batch.append(exp)
            new_rewards = exp_source.pop_total_rewards()
            if new_rewards:
                if tracker.reward(new_rewards[0], step_idx):
                    break

            if len(batch) < BATCH_SIZE:
                continue

            states_v, actions_t, vals_ref_v = unpack_batch(batch, net, device=device)
            batch.clear()

            optimizer.zero_grad()
            logits_v, value_v = net(states_v)
            # The value head outputs (B, 1); flatten to (B,) so it aligns
            # element-wise with vals_ref_v. Without this, vals_ref_v - value_v
            # broadcasts to a (B, B) matrix and the advantages mix samples.
            value_v = value_v.squeeze(-1)
            loss_value_v = F.mse_loss(value_v, vals_ref_v)

            log_prob_v = F.log_softmax(logits_v, dim=1)
            # Advantage is treated as a constant for the policy gradient.
            adv_v = vals_ref_v - value_v.detach()
            # Pick log pi(a_t|s_t) per row; size by the actual batch length.
            log_prob_actions_v = adv_v * log_prob_v[range(len(actions_t)), actions_t]
            loss_policy_v = -log_prob_actions_v.mean()

            prob_v = F.softmax(logits_v, dim=1)
            # sum(p * log p) is negative entropy, so minimising it raises entropy.
            entropy_loss_v = ENTROPY_BETA * (prob_v * log_prob_v).sum(dim=1).mean()

            # Back-propagate the policy loss alone first so its gradients can
            # be captured for logging; retain the graph for the second pass.
            loss_policy_v.backward(retain_graph=True)
            grads = np.concatenate([
                p.grad.data.cpu().numpy().flatten()
                for p in net.parameters()
                if p.grad is not None
            ])

            loss_v = entropy_loss_v + loss_value_v
            loss_v.backward()
            nn_utils.clip_grad_norm_(net.parameters(), CLIP_GRAD)
            optimizer.step()
            # Total loss, used only for logging below.
            loss_v += loss_policy_v

            tb_tracker.track("advantage",       adv_v, step_idx)
            tb_tracker.track("values",          value_v, step_idx)
            tb_tracker.track("batch_rewards",   vals_ref_v, step_idx)
            tb_tracker.track("loss_entropy",    entropy_loss_v, step_idx)
            tb_tracker.track("loss_policy",     loss_policy_v, step_idx)
            tb_tracker.track("loss_value",      loss_value_v, step_idx)
            tb_tracker.track("loss_total",      loss_v, step_idx)
            tb_tracker.track("grad_l2",         np.sqrt(np.mean(np.square(grads))), step_idx)
            tb_tracker.track("grad_max",        np.max(np.abs(grads)), step_idx)
            tb_tracker.track("grad_var",        np.var(grads), step_idx)


AtariA2C(
  (conv): Sequential(
    (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
  )
  (policy): Sequential(
    (0): Linear(in_features=3136, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=6, bias=True)
  )
  (value): Sequential(
    (0): Linear(in_features=3136, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=1, bias=True)
  )
)
2059: done 1 games, mean reward -21.000, speed 32.57 f/s
2396: done 2 games, mean reward -20.000, speed 36.87 f/s
4003: done 3 games, mean reward -19.667, speed 31.84 f/s


KeyboardInterrupt: 

In [146]:
# Inspect the current gradients of `net` as a single flat float array.
# Parameters whose .grad is None (e.g. before any backward pass has touched
# them) are skipped, matching the filter used in the training loop; if no
# parameter has a gradient yet, an empty float32 array is shown instead of
# np.concatenate raising on an empty list.
grad_arrays = [
    p.grad.detach().cpu().numpy().ravel()
    for p in net.parameters()
    if p.grad is not None
]
np.concatenate(grad_arrays) if grad_arrays else np.empty(0, dtype=np.float32)

array([-1.20283005e-06, -1.18845026e-06, -1.63693016e-06, ...,
        2.66320328e-03,  1.16737334e-04,  6.51394716e-03], dtype=float32)