# Implementing Dueling Deep and Double Deep Q-Learning
See paper details in [Arxiv](https://arxiv.org/abs/1511.06581).


In [2]:
import torch
import torch.nn.functional as F
import torch.optim as optim
import os
from memory_utils import ReplayBuffer

## Dueling Deep Q-learning 

Same experience replay, memory buffer as seen in normal Deep Q-Learning

![img](./images/dueling_qlearning.png)

In [10]:

DEVICE="cuda:0" if torch.cuda.is_available() else "cpu"

class DuelingDeepQNetwork(torch.nn.Module):
    def __init__(self,lr,n_actions,input_dims,name,chkpt_dir):
        super(DuelingDeepQNetwork,self).__init__()
        self.checkpoint_dir=chkpt_dir
        self.checkpoint_file=os.path.join(self.checkpoint_dir,name)
        self.cnn1=torch.nn.Conv2d(input_dims[0],32,8,stride=4)
        self.cnn2=torch.nn.Conv2d(32,64,4,stride=2)
        self.cnn3=torch.nn.Conv2d(64,64,3,stride=1)
        fc_input_dims=self.calculate_conv_output_dims(input_dims)
        print(f"CNN outputs after flattened: {fc_input_dims}")
        self.fc1=torch.nn.Linear(fc_input_dims,512)
        self.V=torch.nn.Linear(512,1)
        self.A=torch.nn.Linear(512,n_actions)
        self.optimizer=optim.RMSprop(self.parameters(),lr=lr)
        self.criterion=torch.nn.MSELoss()

    def calculate_conv_output_dims(self,input_dims):
        state=torch.zeros(1,*input_dims)
        dims=self.cnn1(state)
        dims=self.cnn2(dims)
        dims=self.cnn3(dims)
        return int(np.prod(dims.size()))

    def forward(self,state):
        out=F.relu(self.cnn1(state))
        out=F.relu(self.cnn2(out))
        out=F.relu(self.cnn3(out)) 
#        out=out.view(out.size()[0],-1)
        out=torch.flatten(out,1)
        out=F.relu(self.fc1(out))
        V=self.V(out)
        A=self.A(out)
        #actions=self.fc2(out)
        return V,A

    def save_checkpoint(self):
        print("...saving checkpoint...")
        torch.save(self.state_dict(),self.checkpoint_file)

    def load_checkpoint(self):
        print("...loading checkpoint...")
        self.load_state_dict(torch.load(self.checkpoint_file))


class DuelingDQNAgent():
    def __init__(self,gamma,epsilon,lr,n_actions,input_dims,mem_size,batch_size,eps_min=0.01,eps_dec=5e-7,replace=1000,algo=None,env_name=None,chkpt_dir='tmp/dqn'):
        self.lr=lr
        self.input_dims=input_dims
        self.n_actions=n_actions
        self.gamma=gamma
        self.batch_size=batch_size
        self.epsilon=epsilon
        self.eps_dec=eps_dec
        self.eps_min=eps_min
        self.replace_target_cnt=replace
        self.env_name=env_name
        self.algo=algo
        self.chkpt_dir=chkpt_dir
        self.action_space=[i for i in range(self.n_actions)]
        self.learn_step_counter=0
        self.q_eval=DuelingDeepQNetwork(self.lr,self.n_actions,self.input_dims,name=f"{self.env_name}_{self.algo}_q_eval",chkpt_dir=self.chkpt_dir)
        self.q_next=DuelingDeepQNetwork(self.lr,self.n_actions,self.input_dims,name=f"{self.env_name}_{self.algo}_q_next",chkpt_dir=self.chkpt_dir)
        self.memory=ReplayBuffer(mem_size,input_dims,n_actions)
        
    def choose_action(self,observation):
        if np.random.random() > self.epsilon:
            state=torch.tensor([observation],dtype=torch.float).to(DEVICE)
            _,advantage=self.q_eval(state)
            actions=self.q_eval(state)
            action=torch.argmax(advantage).item()
        else:
            action=np.random.choice(self.action_space)
        return action


    def store_transition(self,state,action,reward,state_,done):
        self.memory.store_transition(state,action,reward,state_,done)

    def sample_memory(self):
        state,action, reward,new_state,done= self.memory.sample_buffer(self.batch_size)
        states=torch.tensor(state).to(DEVICE)
        rewards=torch.tensor(reward).to(DEVICE)
        dones=torch.tensor(done).to(DEVICE)
        actions=torch.tensor(action).to(DEVICE)
        states_=torch.tensor(new_state).to(DEVICE)
        return states,actions,rewards,states_,dones
    
    def replace_target_network(self):
        if self.learn_step_counter % self.replace_target_cnt==0:
            self.q_next.load_state_dict(self.q_eval.state_dict())

    def save_models(self):
        self.q_eval.save_checkpoint()
        self.q_next.save_checkpoint()

    def load_models(self):
        self.q_eval.load_checkpoint()
        self.q_next.load_checkpoint()

    def decrease_epsilon(self):
        self.epsilon=self.epsilon-self.eps_dec if self.epsilon > self.eps_min else self.eps_min
        
    def learn(self,state,action,reward,new_state):

        if self.memory.mem_cntr < self.batch_size:
            return

        self.q_eval.optimizer.zero_grad()
        self.replace_target_network()

        states,actions,rewards,states_,dones=self.sample_memory()
        #indices=np.arange(self.batch_size)
        
        V,A=self.q_eval(states)
        V_s,A_s=self.q_next(states_)
        
        q_pred=torch.add(V,(A[:,actions]-A.mean(axis=1,keepdim=True)))[:,actions]
        q_next=torch.add(V_s,(A_s[:,actions]-A_s.mean(axis=1,keepdim=True))).max(dim=1)[0]
        
#        max_actions=torch.argmax(q_eval,dim=1)
        
        q_next[dones]=0.0 ## to account for games where a terminal state was reached
        q_target=rewards+ self.gamma*q_next

        loss= self.q_eval.criterion(q_target,q_pred).to(DEVICE)
        loss.backward()
        self.q_eval.optimizer.step()
        self.learn_step_counter += 1
        self.decrease_epsilon()

In [4]:
from utils import make_env
import numpy as np
import os

env=make_env("PongNoFrameskip-v4")
best_score=-np.inf
load_checkpoint=False
n_games=100
CHECK_DIR="./models/"

if not os.path.isdir(CHECK_DIR):
    os.makedirs(CHECK_DIR)



In [None]:
## Learn Dueling DQN LEARNING

agent=DuelingDQNAgent(gamma=0.99,epsilon=1.0,lr=1e-04,input_dims=(env.observation_space.shape),n_actions=env.action_space.n,
                mem_size=20000,eps_min=0.1,batch_size=32,replace=1000,eps_dec=1e-05,chkpt_dir=CHECK_DIR,algo="DuelingDQNAgent",env_name="PongNoFrameskip-v4")

agent.q_eval.to(DEVICE)
agent.q_next.to(DEVICE)

if load_checkpoint:
    agent.load_models()
    
scores,eps_history,steps_array=[],[],[]
n_steps=0

for i in range(n_games):
    done=False
    score=0
    observation=env.reset()
    
    
    while not done:
        action=agent.choose_action(observation)
        observation_,reward,done,info=env.step(action)
        score+=reward
        
        if not load_checkpoint:
            agent.store_transition(observation,action,reward,observation_,int(done))
            agent.learn(observation,action,reward,observation_)
        observation=observation_
        n_steps+=1
    scores.append(score)
    steps_array.append(n_steps)
    avg_score=np.mean(scores[-100:])
    if (i+1) % 10==0:
        print(f'Episode {i+1} score: {score:.1f} | avg score {avg_score:.1f} | epsilon {agent.epsilon:.3f}')
    
    if avg_score > best_score:
        if not load_checkpoint:
            agent.save_models()
        best_score=avg_score

    eps_history.append(agent.epsilon)

CNN outputs after flattened: 3136
CNN outputs after flattened: 3136


  q_next[dones]=0.0 ## to account for games where a terminal state was reached
  return F.mse_loss(input, target, reduction=self.reduction)


...saving checkpoint...
...saving checkpoint...


## Dueling Double Deep Q- Learning

In [None]:

DEVICE="cuda:0" if torch.cuda.is_available() else "cpu"

class DuelingDoubleDeepQNetwork(torch.nn.Module):
    def __init__(self,lr,n_actions,input_dims,name,chkpt_dir):
        super(DuelingDoubleDeepQNetwork,self).__init__()
        self.checkpoint_dir=chkpt_dir
        self.checkpoint_file=os.path.join(self.checkpoint_dir,name)
        self.cnn1=torch.nn.Conv2d(input_dims[0],32,8,stride=4)
        self.cnn2=torch.nn.Conv2d(32,64,4,stride=2)
        self.cnn3=torch.nn.Conv2d(64,64,3,stride=1)
        fc_input_dims=self.calculate_conv_output_dims(input_dims)
        print(f"CNN outputs after flattened: {fc_input_dims}")
        self.fc1=torch.nn.Linear(fc_input_dims,512)
        self.V=torch.nn.Linear(512,1)
        self.A=torch.nn.Linear(512,n_actions)
        self.optimizer=optim.RMSprop(self.parameters(),lr=lr)
        self.criterion=torch.nn.MSELoss()

    def calculate_conv_output_dims(self,input_dims):
        state=torch.zeros(1,*input_dims)
        dims=self.cnn1(state)
        dims=self.cnn2(dims)
        dims=self.cnn3(dims)
        return int(np.prod(dims.size()))

    def forward(self,state):
        out=F.relu(self.cnn1(state))
        out=F.relu(self.cnn2(out))
        out=F.relu(self.cnn3(out)) 
#        out=out.view(out.size()[0],-1)
        out=torch.flatten(out,1)
        out=F.relu(self.fc1(out))
        V=self.V(out)
        A=self.A(out)
        #actions=self.fc2(out)
        return V,A

    def save_checkpoint(self):
        print("...saving checkpoint...")
        torch.save(self.state_dict(),self.checkpoint_file)

    def load_checkpoint(self):
        print("...loading checkpoint...")
        self.load_state_dict(torch.load(self.checkpoint_file))


class DuelingDDQNAgent():
    def __init__(self,gamma,epsilon,lr,n_actions,input_dims,mem_size,batch_size,eps_min=0.01,eps_dec=5e-7,replace=1000,algo=None,env_name=None,chkpt_dir='tmp/dqn'):
        self.lr=lr
        self.input_dims=input_dims
        self.n_actions=n_actions
        self.gamma=gamma
        self.batch_size=batch_size
        self.epsilon=epsilon
        self.eps_dec=eps_dec
        self.eps_min=eps_min
        self.replace_target_cnt=replace
        self.env_name=env_name
        self.algo=algo
        self.chkpt_dir=chkpt_dir
        self.action_space=[i for i in range(self.n_actions)]
        self.learn_step_counter=0
        self.q_eval=DuelingDoubleDeepQNetwork(self.lr,self.n_actions,self.input_dims,name=f"{self.env_name}_{self.algo}_q_eval",chkpt_dir=self.chkpt_dir)
        self.q_next=DuelingDoubleDeepQNetwork(self.lr,self.n_actions,self.input_dims,name=f"{self.env_name}_{self.algo}_q_next",chkpt_dir=self.chkpt_dir)
        self.memory=ReplayBuffer(mem_size,input_dims,n_actions)
        
    def choose_action(self,observation):
        if np.random.random() > self.epsilon:
            state=torch.tensor([observation],dtype=torch.float).to(DEVICE)
            _,advantage=self.q_eval(state)
            actions=self.q_eval(state)
            action=torch.argmax(advantage).item()
        else:
            action=np.random.choice(self.action_space)
        return action


    def store_transition(self,state,action,reward,state_,done):
        self.memory.store_transition(state,action,reward,state_,done)

    def sample_memory(self):
        state,action, reward,new_state,done= self.memory.sample_buffer(self.batch_size)
        states=torch.tensor(state).to(DEVICE)
        rewards=torch.tensor(reward).to(DEVICE)
        dones=torch.tensor(done).to(DEVICE)
        actions=torch.tensor(action).to(DEVICE)
        states_=torch.tensor(new_state).to(DEVICE)
        return states,actions,rewards,states_,dones
    
    def replace_target_network(self):
        if self.learn_step_counter % self.replace_target_cnt==0:
            self.q_next.load_state_dict(self.q_eval.state_dict())

    def save_models(self):
        self.q_eval.save_checkpoint()
        self.q_next.save_checkpoint()

    def load_models(self):
        self.q_eval.load_checkpoint()
        self.q_next.load_checkpoint()

    def decrease_epsilon(self):
        self.epsilon=self.epsilon-self.eps_dec if self.epsilon > self.eps_min else self.eps_min
        
    def learn(self,state,action,reward,new_state):

        if self.memory.mem_cntr < self.batch_size:
            return

        self.q_eval.optimizer.zero_grad()
        self.replace_target_network()

        states,actions,rewards,states_,dones=self.sample_memory()
        #indices=np.arange(self.batch_size)
        
        V,A=self.q_eval(states)
        V_s,A_s=self.q_next(states_)
        
        q_pred=torch.add(V,(A[:,actions]-A.mean(axis=0,keepdim=True)))[:,actions]
        q_next=torch.add(V_s,(A_s[:,actions]-A_s.mean(axis=0,keepdim=True)))
        
        
        V_p,A_p=self.q_eval(states_)
        q_eval=torch.add(V_p+(A_p[:,actions]-A_p.mean(axis=0,keepdim=True)))
        
        
        max_actions=torch.argmax(q_eval,dim=1)
        
        q_next[dones]=0.0 ## to account for games where a terminal state was reached
        q_target=rewards+ self.gamma*q_next[:,max_actions]

        loss= self.q_eval.criterion(q_target,q_pred).to(DEVICE)
        loss.backward()
        self.q_eval.optimizer.step()
        self.learn_step_counter += 1
        self.decrease_epsilon()



In [None]:
## Learn

agent=DuelingDDQNAgent(gamma=0.99,epsilon=1.0,lr=1e-04,input_dims=(env.observation_space.shape),n_actions=env.action_space.n,
                mem_size=20000,eps_min=0.1,batch_size=32,replace=1000,eps_dec=1e-05,chkpt_dir=CHECK_DIR,algo="DuelingDDQNAgent",env_name="PongNoFrameskip-v4")

agent.q_eval.to(DEVICE)
agent.q_next.to(DEVICE)

if load_checkpoint:
    agent.load_models()
    
scores,eps_history,steps_array=[],[],[]
n_steps=0

for i in range(n_games):
    done=False
    score=0
    observation=env.reset()
    
    
    while not done:
        action=agent.choose_action(observation)
        observation_,reward,done,info=env.step(action)
        score+=reward
        
        if not load_checkpoint:
            agent.store_transition(observation,action,reward,observation_,int(done))
            agent.learn(observation,action,reward,observation_)
        observation=observation_
        n_steps+=1
    scores.append(score)
    steps_array.append(n_steps)
    avg_score=np.mean(scores[-100:])
    if (i+1) % 10==0:
        print(f'Episode {i+1} score: {score:.1f} | avg score {avg_score:.1f} | epsilon {agent.epsilon:.3f}')
    
    if avg_score > best_score:
        if not load_checkpoint:
            agent.save_models()
        best_score=avg_score

    eps_history.append(agent.epsilon)
