In [2]:
import torch
import torch.nn as nn

In [4]:
class QuantileNet(nn.Module):
    """MLP that predicts ``num_quants`` return quantiles for every action.

    Architecture: Linear(state_dim -> 256) followed by two Linear(256 -> 256)
    layers, each with a ReLU, then a final Linear to
    ``action_dim * num_quants`` outputs.
    """

    def __init__(self, state_dim, action_dim, num_quants):
        super().__init__()
        self.action_dim = action_dim
        self.num_quants = num_quants

        width = 256
        stack = [nn.Linear(state_dim, width), nn.ReLU()]
        for _ in range(2):  # two more hidden layers of the same width
            stack += [nn.Linear(width, width), nn.ReLU()]
        stack.append(nn.Linear(width, action_dim * num_quants))
        self.all_layers = nn.Sequential(*stack)

    def forward(self, states):
        """Map states to quantiles shaped (N, action_dim, num_quants).

        N is the number of input states; a single 1-D state yields N == 1
        because of the ``view(-1, ...)`` reshape.
        """
        flat = self.all_layers(states)
        return flat.view(-1, self.action_dim, self.num_quants)

In [8]:
class ReplayBuffer:
    """Placeholder for the buffer that stores data to train on (not yet implemented)."""

In [26]:
class ParamsPool:
    """Own the online and target quantile networks and update them from batches.

    Attributes:
        quantile_net: online network mapping states to (bs, action_dim, num_quants).
        quantile_net_target: target network used to build bootstrapped targets.
        gamma: discount factor applied to the next-state quantiles.
        num_quants: number of quantiles predicted per action.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        num_quants,
        gamma=0.99
    ):
        self.quantile_net = QuantileNet(state_dim, action_dim, num_quants)
        self.quantile_net_target = QuantileNet(state_dim, action_dim, num_quants)
        # The two networks are initialised independently; copy the online
        # weights so the target starts out identical to the online network.
        self.quantile_net_target.load_state_dict(self.quantile_net.state_dict())
        self.gamma = gamma
        self.num_quants = num_quants

    def _compute_target(self, batch):
        """Return the distributional Bellman target, shape (bs, num_quants).

        The greedy next action is the one with the highest mean of the target
        network's quantiles. Its quantiles are discounted by ``gamma``, scaled
        by ``(1 - batch.d)`` and added to ``batch.r``. This runs without
        gradient tracking: the target is a fixed regression label and must
        not backpropagate into ``quantile_net_target``.
        """
        with torch.no_grad():
            # step 1 in algorithm
            next_quantiles = self.quantile_net_target(batch.next_states)  # (bs, action_dim, num_quants)
            next_qvalues = next_quantiles.mean(dim=2)  # average over quantiles but not actions; (bs, action_dim)

            # step 2 in algorithm
            next_greedy_actions = next_qvalues.argmax(dim=1, keepdim=True)  # (bs, 1)
            # Repeat the action index along the quantile axis so that gather
            # selects every quantile of the greedy action.
            next_greedy_actions = next_greedy_actions.unsqueeze(-1).repeat(1, 1, self.num_quants)  # (bs, 1, num_quants)

            # step 3 in algorithm
            # gather requires an explicit dim; squeeze only the action axis so a
            # batch of size 1 (or num_quants == 1) keeps its 2-D shape.
            next_greedy_quantiles = next_quantiles.gather(dim=1, index=next_greedy_actions).squeeze(1)  # (bs, num_quants)
            return batch.r + self.gamma * (1 - batch.d) * next_greedy_quantiles  # (bs, num_quants)

    def update_networks(self, batch):
        """Run one update step on a batch of transitions (work in progress).

        Only the target is computed so far. Prediction, loss, backprop and the
        target-network update are still to be written.

        Args:
            batch: object exposing
                states, next_states: (bs, state_dim)
                actions: (bs, action_dim) per the original notes.
                    NOTE(review): selecting per-action quantiles presumably
                    needs (bs, 1) integer action indices -- confirm.
                r: (bs, 1); added to the discounted next quantiles
                    (presumably rewards).
                d: (bs, 1); the next-state term is scaled by (1 - d)
                    (presumably 1 marks terminal transitions).
        """
        # ==============================
        # compute prediction
        # ==============================

        # part of step 4 in algorithm

        # ==============================
        # compute target
        # ==============================

        # steps 1-3 in algorithm; intended as the label for the loss below
        samples_from_sample_bellman_updates = self._compute_target(batch)  # (bs, num_quants)

        # ==============================
        # compute loss
        # ==============================

        # part of step 4 in algorithm

        # ==============================
        # do backprop and grad descent
        # ==============================

        # ==============================
        # update target networks
        # ==============================

- ReplayBuffer: stores the data to train on.
- ParamsPool: contains the network parameters, and updates and uses them. It needs samples from the replay buffer; the `update_networks` step corresponds to what happens to one batch of data.
- TrainingLoop: runs episodes in the environment, stores the resulting data in the replay buffer, and samples data from the replay buffer to train the params pool.