In this notebook we implement a DQN agent that uses the duelling DQN network defined in [DuellingDoubleDQN.py](DuellingDoubleDQN.py) and the prioritized replay buffer defined in [PrioritizedReplayBuffer.py](PrioritizedReplayBuffer.py). Note that we also need to implement importance sampling in order to use the prioritized replay buffer properly.

In [1]:
import torch 
import torch.nn as nn
from torch.nn import functional as F
import numpy as np

from DuellingDoubleDQN import DuellingDQN
from PrioritizedReplayBuffer import PrioritizedReplay

In [4]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [None]:
class DuellingDoubleDQNAgent():
    """
    device: device on which the networks are placed (cpu or gpu)
    num_agents: number of agents
    im_height: height of the input image
    im_width: width of the input image
    obs_in_channels: number of input channels of the grid image
    conv_dim: number of channels after passing through the 1st conv. layer of the DQN
    kernel_size: kernel size of the conv. layers in the DQN
    n_actions: number of discrete actions
    buffer_size: replay buffer size
    roll_out: length of the roll out for n-step bootstrap
    replay_batch_size: batch size of replay
    lr: learning rate of the Adam optimizer
    epsilon: exploration rate
    epsilon_decay_rate: rate by which to scale down epsilon after every few steps
    tau: parameter for the soft update of the target network
    gamma: discount factor for discounted rewards
    update_interval: interval after which to update the network with new parameters
    """
    
    def __init__(self, device=device, num_agents=1, im_height=464, im_width=464,
                 obs_in_channels=3, conv_dim=32, kernel_size=6, n_actions=5,
                 buffer_size=10**6, roll_out=4, replay_batch_size=32, lr=1e-4,
                 epsilon=0.3, epsilon_decay_rate=0.999, tau=1e-3, gamma=1,
                 update_interval=4):
        """Store the hyper-parameters and build the networks, optimizers and replay memory.

        For every agent a local network, a target network initialised with
        the local network's weights, and an Adam optimizer over the local
        network's parameters are created. All three are kept in per-agent
        lists (``self.local_net``, ``self.target_net``, ``self.optimizer``)
        indexed by agent.

        Args:
            device: torch device the networks are moved to.
            num_agents: number of agents (one local/target/optimizer triple each).
            im_height: height of the input image (stored only).
            im_width: width of the input image (stored only).
            obs_in_channels: number of input channels, forwarded to the network.
            conv_dim: forwarded to the network constructor.
            kernel_size: forwarded to the network constructor.
            n_actions: number of discrete actions.
            buffer_size: capacity handed to the replay buffer.
            roll_out: roll-out length handed to the replay buffer.
            replay_batch_size: batch size of replay (stored only).
            lr: learning rate of the Adam optimizers.
            epsilon: exploration rate (stored only).
            epsilon_decay_rate: decay factor for epsilon (stored only).
            tau: soft-update parameter for the target network (stored only).
            gamma: discount factor (stored only).
            update_interval: stored as ``self.update_every``.
        """
        super().__init__()
        self.device = device
        self.num_agents = num_agents
        self.im_height = im_height
        self.im_width = im_width
        self.in_channels = obs_in_channels
        self.conv_dim = conv_dim
        self.kernel_size = kernel_size
        self.n_actions = n_actions
        self.buffer_size = buffer_size
        self.roll_out = roll_out
        self.replay_batch_size = replay_batch_size
        self.lr = lr
        self.epsilon = epsilon
        self.epsilon_decay_rate = epsilon_decay_rate
        self.tau = tau
        # We want the train to find the shortest possible path to its
        # destination; for every time step it gets a reward of -1, so it
        # makes sense to keep gamma = 1.
        self.gamma = gamma
        self.update_every = update_interval

        # NOTE(review): the class imported from DuellingDoubleDQN.py is
        # `DuellingDQN`; its constructor is assumed to accept
        # (in_channels, conv_dim, kernel_size, n_actions) - confirm.
        self.local_net = [
            DuellingDQN(obs_in_channels, conv_dim, kernel_size, n_actions)
            for _ in range(num_agents)
        ]
        self.target_net = []
        self.optimizer = []

        for local in self.local_net:
            target = DuellingDQN(obs_in_channels, conv_dim, kernel_size, n_actions)

            # Move both networks to the device before wiring anything to
            # their parameters (the optimizer below must see the final ones).
            local.to(device)
            target.to(device)

            # Start the target network from exactly the local network's weights.
            target.load_state_dict(local.state_dict())
            self.target_net.append(target)

            # One optimizer per agent, updating the local network only.
            self.optimizer.append(torch.optim.Adam(local.parameters(), lr=lr))

        # Loss function to compare the Q-values of the local and target networks.
        self.criterion = nn.MSELoss()

        # TODO: store images as memories. For the time being a dummy value
        # is used for n_states.
        n_states = 264
        self.memory = PrioritizedReplay(buffer_size, n_states, n_actions, roll_out, num_agents)
        
        def act(self, observations):
            # function to produce an action from the DQN
            
                
        
                
                
                
        