In [16]:
from torch import nn
import torch
import torch.nn.functional as F
import gym
import torch.optim as optim
from tensorboardX import SummaryWriter
from lib import common,dqn_model
import ptan
import numpy as np

# --- n-step returns ---------------------------------------------------------
# Passed as `steps_count` to the experience source; the training loop also
# raises gamma to this power before handing it to the loss.
REWARD_STEPS = 2

# --- prioritized replay -----------------------------------------------------
# Alpha is handed to the prioritized replay buffer; beta is annealed linearly
# from BETA_START up to 1.0 over the first BETA_FRAMES frames.
PRIO_REPLAY_ALPHA = 0.6
BETA_START = 0.4
BETA_FRAMES = 100000

# --- categorical (C51) value distribution -----------------------------------
# N_ATOMS evenly spaced support points spanning [Vmin, Vmax];
# DELTA_Z is the distance between two neighbouring points.
Vmax = 10
Vmin = -10
N_ATOMS = 51
DELTA_Z = (Vmax - Vmin) / (N_ATOMS - 1)

class RainbowDQN(nn.Module):
    """Convolutional network with dueling value/advantage heads over a
    categorical value distribution.

    A three-layer convolutional trunk feeds two fully connected heads built
    from ``dqn_model.NoisyLinear`` (defined outside this file; presumably a
    noisy-network linear layer):

    * ``fc_val`` emits ``N_ATOMS`` logits (one distribution per state);
    * ``fc_adv`` emits ``n_actions * N_ATOMS`` logits (one per action).

    ``forward`` combines them as ``val + adv - mean(adv over actions)`` and
    returns raw, pre-softmax logits of shape ``(batch, n_actions, N_ATOMS)``.

    Args:
        input_shape: shape of a single observation, channels first;
            ``input_shape[0]`` is used as the number of input channels.
        n_actions: number of discrete actions.
    """

    def __init__(self, input_shape, n_actions):
        super(RainbowDQN, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)

        # State-value stream: one categorical distribution per state.
        self.fc_val = nn.Sequential(
            dqn_model.NoisyLinear(conv_out_size, 512),
            nn.ReLU(),
            dqn_model.NoisyLinear(512, N_ATOMS)
        )

        # Advantage stream: one categorical distribution per action.
        self.fc_adv = nn.Sequential(
            dqn_model.NoisyLinear(conv_out_size, 512),
            nn.ReLU(),
            dqn_model.NoisyLinear(512, n_actions * N_ATOMS)
        )

        # Support points of the distribution. linspace guarantees exactly
        # N_ATOMS elements ending exactly at Vmax; arange with a fractional
        # step derives its length from a rounded float division and can come
        # out one element long or short, which would break broadcasting
        # against the (batch, actions, N_ATOMS) probabilities in `both`.
        self.register_buffer("supports",
                             torch.linspace(Vmin, Vmax, N_ATOMS))
        self.softmax = nn.Softmax(dim=1)

    def _get_conv_out(self, shape):
        """Return the flattened size of the conv trunk output for one
        observation of the given shape."""
        # Only the output shape is needed, so skip building an autograd graph.
        with torch.no_grad():
            o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        """Return pre-softmax atom logits, shape (batch, n_actions, N_ATOMS).

        The input is cast to float and divided by 256 before the convolutions
        (presumably to scale byte-valued image frames - confirm with caller).
        """
        batch_size = x.size()[0]
        fx = x.float() / 256
        conv_out = self.conv(fx).view(batch_size, -1)
        val_out = self.fc_val(conv_out).view(batch_size, 1, N_ATOMS)
        adv_out = self.fc_adv(conv_out).view(batch_size, -1, N_ATOMS)
        # Dueling combination; the value stream broadcasts across actions.
        adv_mean = adv_out.mean(dim=1, keepdim=True)
        return val_out + adv_out - adv_mean

    def both(self, x):
        """Return ``(logits, q_values)`` for a batch of observations.

        ``logits`` is the raw output of ``forward``; ``q_values`` has shape
        (batch, n_actions) and holds the expectation of each action's
        distribution, i.e. sum over atoms of probability * support value.
        """
        cat_out = self(x)
        probs = self.apply_softmax(cat_out)
        weights = probs * self.supports
        res = weights.sum(dim=2)
        return cat_out, res

    def qvals(self, x):
        """Return only the expected Q-values, shape (batch, n_actions)."""
        return self.both(x)[1]

    def apply_softmax(self, t):
        """Apply softmax over the atom (last) dimension, keeping ``t``'s shape.

        The tensor is flattened to rows of N_ATOMS so the dim=1 softmax module
        can be used regardless of how many leading dimensions ``t`` has.
        """
        return self.softmax(t.view(-1, N_ATOMS)).view(t.size())
    
def calc_loss(batch, batch_weights, net, tgt_net, gamma, device="cpu"):
    """Compute the importance-weighted categorical loss for one batch.

    The next action is chosen by the online network ``net`` (argmax of its
    expected Q-values on the next states) and its distribution is taken from
    ``tgt_net``. That distribution is projected onto the fixed support by
    ``common.distr_projection`` and compared against the online network's
    log-softmax for the taken action via cross-entropy.

    Args:
        batch: sampled transitions, unpacked by ``common.unpack_batch``.
        batch_weights: per-sample weights multiplied into each sample's loss.
        net: online network exposing ``both`` (see ``RainbowDQN``).
        tgt_net: target network; must be callable and expose ``apply_softmax``.
        gamma: discount factor forwarded to the projection (the training loop
            in this file passes ``gamma ** REWARD_STEPS``).
        device: device on which the batch tensors are placed.

    Returns:
        Tuple ``(mean_loss, per_sample_loss + 1e-5)``. The second element is a
        tensor with one entry per sample; the training loop in this file uses
        it as the new replay priorities.
    """
    states, actions, rewards, dones, next_states = \
        common.unpack_batch(batch)
    batch_size = len(batch)

    states_v = torch.tensor(states).to(device)
    actions_v = torch.tensor(actions).to(device)
    next_states_v = torch.tensor(next_states).to(device)
    batch_weights_v = torch.tensor(batch_weights).to(device)

    # One forward pass of the online net over states and next states together:
    # the first half gives the distributions to train, the second half the
    # Q-values used to pick the next action.
    distr_v, qvals_v = net.both(torch.cat((states_v, next_states_v)))
    next_qvals_v = qvals_v[batch_size:]
    distr_v = distr_v[:batch_size]

    next_actions_v = next_qvals_v.max(1)[1]

    # The target distribution is treated as a constant (it is converted to
    # numpy below), so no autograd graph is needed for the target network.
    with torch.no_grad():
        next_distr_v = tgt_net(next_states_v)
        next_best_distr_v = next_distr_v[range(batch_size), next_actions_v.data]
        next_best_distr_v = tgt_net.apply_softmax(next_best_distr_v)
    next_best_distr = next_best_distr_v.cpu().numpy()

    # Builtin `bool` instead of `np.bool`: the alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    dones = dones.astype(bool)
    proj_distr = common.distr_projection(next_best_distr, rewards,
                                         dones, Vmin, Vmax, N_ATOMS, gamma)

    state_action_values = distr_v[range(batch_size), actions_v.data]
    state_log_sm_v = F.log_softmax(state_action_values, dim=1)
    proj_distr_v = torch.tensor(proj_distr).to(device)

    # Cross-entropy between projected target and predicted distribution,
    # summed over atoms and scaled by the per-sample weights.
    loss_v = -state_log_sm_v * proj_distr_v
    loss_v = batch_weights_v * loss_v.sum(dim=1)
    # Small constant keeps every returned priority strictly positive.
    return loss_v.mean(), loss_v + 1e-5

In [18]:
params = common.HYPERPARAMS['pong']
# Use the GPU when present, otherwise fall back to the CPU so the cell does
# not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make(params["env_name"])
env = ptan.common.wrappers.wrap_dqn(env)

writer = SummaryWriter(comment="-" + params["run_name"] + "-rainbow")
net = RainbowDQN(env.observation_space.shape, env.action_space.n).to(device)
tgt_net = ptan.agent.TargetNet(net)
# Actions are always the argmax of the network's expected Q-values; no
# separate epsilon-style action selector is configured here.
agent = ptan.agent.DQNAgent(lambda x: net.qvals(x),
                            ptan.actions.ArgmaxActionSelector(),
                            device=device)

exp_source = \
    ptan.experience.ExperienceSourceFirstLast(env, agent,
                                              gamma=params["gamma"],
                                              steps_count=REWARD_STEPS)
buffer = ptan.experience.PrioritizedReplayBuffer(exp_source,
                                                 params["replay_size"],
                                                 PRIO_REPLAY_ALPHA)
optimizer = optim.Adam(net.parameters(),
                       lr=params["learning_rate"])


frame_idx = 0
beta = BETA_START

with common.RewardTracker(writer, params["stop_reward"]) as reward_tracker:
    while True:
        frame_idx += 1
        buffer.populate(1)
        # Linear anneal of beta from BETA_START to 1.0 over BETA_FRAMES frames.
        beta = min(1.0, BETA_START + frame_idx * (1.0 - BETA_START) / BETA_FRAMES)

        new_rewards = exp_source.pop_total_rewards()
        if new_rewards:
            # The tracker returns a truthy value once the stop condition is met.
            if reward_tracker.reward(new_rewards[0], frame_idx):
                break
        # Do not train until the replay buffer holds enough transitions.
        if len(buffer) < params["replay_initial"]:
            continue
        optimizer.zero_grad()
        batch, batch_indices, batch_weights = \
            buffer.sample(params["batch_size"], beta)
        # Transitions span REWARD_STEPS steps, so the discount applied to the
        # bootstrapped distribution is gamma ** REWARD_STEPS.
        loss_v, sample_prios_v = calc_loss(batch, batch_weights, net,
                                           tgt_net.target_model,
                                           params["gamma"] ** REWARD_STEPS,
                                           device=device)
        loss_v.backward()
        optimizer.step()
        # Feed the per-sample losses back as the new priorities.
        buffer.update_priorities(batch_indices,
                                 sample_prios_v.data.cpu().numpy())

        if frame_idx % params["target_net_sync"] == 0:
            tgt_net.sync()

820: done 1 games, mean reward -21.000, speed 503.34 f/s
1716: done 2 games, mean reward -20.500, speed 487.10 f/s
2533: done 3 games, mean reward -20.667, speed 502.03 f/s
3356: done 4 games, mean reward -20.750, speed 493.04 f/s
4236: done 5 games, mean reward -20.800, speed 501.95 f/s
5056: done 6 games, mean reward -20.833, speed 508.88 f/s
5834: done 7 games, mean reward -20.857, speed 507.33 f/s
6656: done 8 games, mean reward -20.875, speed 504.67 f/s
7568: done 9 games, mean reward -20.778, speed 511.83 f/s
8389: done 10 games, mean reward -20.800, speed 505.01 f/s
9225: done 11 games, mean reward -20.818, speed 502.87 f/s
10042: done 12 games, mean reward -20.833, speed 378.83 f/s
10864: done 13 games, mean reward -20.846, speed 68.79 f/s
11748: done 14 games, mean reward -20.857, speed 68.61 f/s
12692: done 15 games, mean reward -20.867, speed 68.29 f/s
13650: done 16 games, mean reward -20.812, speed 68.59 f/s
14473: done 17 games, mean reward -20.824, speed 68.56 f/s
15295: