In [1]:
# Implicit Quantile Network Critic

In [8]:
import torch
import torch.nn as nn
import numpy as np

In [9]:
def weight_init(layers):
    """
    He/Kaiming-normal initialisation (gain for ReLU) of every layer's weight, in place.

    Only ``.weight`` is re-initialised; biases keep their current values.
    Returns None.
    """
    init_fn = torch.nn.init.kaiming_normal_
    for module in layers:
        init_fn(module.weight, nonlinearity='relu')

In [10]:
class IQN_SafetyCritic(nn.Module):
    """
    Implicit Quantile Network (IQN) critic over state-action pairs.

    For every (obs, action) pair it samples ``num_tau`` quantile fractions
    tau ~ U[0, 1), embeds them with a cosine basis, multiplies that embedding
    element-wise with the state-action embedding and maps the result to one
    scalar quantile value per tau.
    """
    def __init__(self, obs_dim, action_dim, hidden_dim):
        super(IQN_SafetyCritic, self).__init__()

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.K = 32  # number of samples for policy (not used inside this class)
        self.N = 8  # number of quantile samples (forward() takes num_tau instead)
        self.n_cos = 64
        self.hidden_dim = hidden_dim

        # Frequencies pi*i for i = 0..n_cos-1 (starting from 0, according to the paper).
        # Registered as a (non-persistent) buffer so that `.to(device)`, `.cuda()` and
        # `.double()` move/cast it together with the weights, without adding a new
        # entry to state_dict().
        pis = torch.FloatTensor([np.pi*i for i in range(self.n_cos)]).view(1, 1, self.n_cos)
        self.register_buffer("pis", pis, persistent=False)

        self.head = nn.Linear(obs_dim+action_dim, hidden_dim)
        self.cos_embedding = nn.Linear(self.n_cos, hidden_dim)
        self.lin1 = nn.Linear(hidden_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, 1)

    def calc_cos(self, batch_size, n_tau=8):
        """
        Sample quantile fractions and compute their cosine features.

        Return:
        cos  [shape of (batch_size, n_tau, n_cos)] - cos(pi * i * tau)
        taus [shape of (batch_size, n_tau, 1)]     - tau ~ U[0, 1)
        """
        # Sample on the same device/dtype as the module (tracked by the `pis` buffer);
        # otherwise a GPU or float64 module would receive CPU float32 features.
        taus = torch.rand(batch_size, n_tau, device=self.pis.device, dtype=self.pis.dtype).unsqueeze(-1)
        cos = torch.cos(taus*self.pis)
        assert cos.shape == (batch_size, n_tau, self.n_cos), "cos shape is incorrect"
        return cos, taus

    def forward(self, obs, action, num_tau=8):
        """
        Quantile calculation depending on the number of sampled taus.

        Args:
        obs    [shape of (batch_size, obs_dim)]
        action [shape of (batch_size, action_dim)]
        num_tau: number of quantile fractions to sample per pair

        Return:
        quantiles [shape of (batch_size, num_tau, 1)] - one value per sampled tau
        taus      [shape of (batch_size, num_tau, 1)]
        """
        batch_size = obs.shape[0]
        obs_action = torch.cat([obs, action], dim=-1)

        # cosine embedding
        cos, taus = self.calc_cos(batch_size, num_tau)  # cos (bs, n_tau, n_cos), tau (bs, n_tau, 1)
        cos = cos.view(batch_size*num_tau, self.n_cos)  # cos (bs*n_tau, n_cos)
        cos_x = torch.relu(self.cos_embedding(cos)).view(batch_size, num_tau, self.hidden_dim)  # (bs, n_tau, hidden_dim)
        # state-action embedding
        x = torch.relu(self.head(obs_action))  # (bs, hidden_dim)
        # combine embeddings: broadcast (bs, 1, hidden_dim) against (bs, n_tau, hidden_dim)
        x = (x.unsqueeze(1)*cos_x).view(batch_size*num_tau, self.hidden_dim)
        x = torch.relu(self.lin1(x))
        out = self.lin2(x)

        return out.view(batch_size, num_tau, 1), taus

In [18]:
# Smoke test: one (state, action) pair -> sampled quantile values and their taus.
torch.manual_seed(0)  # reproducible weights and tau samples
critic = IQN_SafetyCritic(obs_dim=10, action_dim=10, hidden_dim=64)
s = torch.randn(1, 10)
a = torch.randn(1, 10)
# forward() returns (quantile values, sampled quantile fractions tau)
quantile_values, taus = critic(s, a)
assert quantile_values.shape == (1, 8, 1)
assert taus.shape == (1, 8, 1)
print(quantile_values.shape)
print(taus.shape)

torch.Size([1, 8, 1])
torch.Size([1, 8, 1])


In [24]:
# QUANTILE REGRESSION LOSS + HUBER LOSS
def quantile_huber_loss(td_errors, taus, k=1.0, n=8):
    """
    Element-wise quantile Huber loss: rho^k_tau(u) = |tau - 1{u < 0}| * L_k(u) / k.

    Args:
        td_errors: pairwise TD errors, expected shape (batch_size, N, N'), where
            entry [b, i, j] is target_j - q_i.
        taus: quantile fractions of the online quantiles, shape (batch_size, N, 1);
            broadcast along the last (target) axis.
        k: Huber threshold kappa (>= 0). k == 0 yields the plain quantile-regression
            loss |tau - 1{u < 0}| * |u|, which is the limit of L_k(u) / k as k -> 0.
        n: expected N == N'; pass None to skip the shape check (e.g. when N != N').

    Returns:
        Tensor with the same shape as td_errors (no reduction applied).

    Raises:
        ValueError: if k is negative.
    """
    if k < 0:
        raise ValueError("kappa k must be non-negative")
    if n is not None:
        assert td_errors.shape == (td_errors.shape[0], n, n), "huber loss has wrong shape"

    abs_td = td_errors.abs()
    if k == 0:
        # kappa -> 0 limit of L_k(u) / k
        scaled_huber = abs_td
    else:
        huber_l = torch.where(abs_td <= k, 0.5 * td_errors.pow(2), k * (abs_td - 0.5 * k))
        # Normalise by kappa as in the quantile Huber loss definition (was hard-coded / 1.0).
        scaled_huber = huber_l / k

    # The asymmetric quantile weight must not receive gradients through the indicator.
    indicator = (td_errors.detach() < 0).to(td_errors.dtype)
    return (taus - indicator).abs() * scaled_huber

In [25]:
# Sizes / hyper-parameters for this sanity check of one IQN critic update
batch_size = 12
obs_dim = 10
action_dim = 10
num_quantiles = 8  # N = N': taus sampled for both the online and the target critic
gamma = 0.99  # discount factor

# initialize qr-networks; the target network starts as an exact copy of the online one
critic = IQN_SafetyCritic(obs_dim=obs_dim, action_dim=action_dim, hidden_dim=64)
critic_target = IQN_SafetyCritic(obs_dim=obs_dim, action_dim=action_dim, hidden_dim=64)
critic_target.load_state_dict(critic.state_dict())

# assume some batch_sample
states = torch.randn(batch_size, obs_dim)
actions = torch.randn(batch_size, action_dim)
next_states = torch.randn(batch_size, obs_dim)
rewards = torch.randn(batch_size, 1)
dones = torch.zeros(batch_size, 1)
# simulate policy
next_action_new = torch.randn(batch_size, action_dim)

# td-target: r + gamma*(1-d)*Z(next_state, a_new'; tau')
# The target is a fixed regression label: build it without autograd so no
# gradient flows into the target network.
with torch.no_grad():
    q_target_next, _ = critic_target(next_states, next_action_new, num_tau=num_quantiles)  # (bs, N', 1)
    q_target_next = q_target_next.transpose(1, 2)  # (bs, 1, N')
    assert q_target_next.shape == (batch_size, 1, num_quantiles)
    q_target = rewards.unsqueeze(-1) + gamma * (1 - dones.unsqueeze(-1)) * q_target_next  # (bs, 1, N')

q, taus = critic(states, actions, num_tau=num_quantiles)  # q (bs, N, 1), taus (bs, N, 1)
# pairwise TD errors: td[b, i, j] = target_j - q_i
td = q_target - q  # (bs, N, N')
assert td.shape == (batch_size, num_quantiles, num_quantiles)

loss = quantile_huber_loss(td, taus, n=num_quantiles)  # (bs, N, N')
# IQN loss: sum over the online quantiles i (dim 1) and average over the
# target samples j (dim 2), i.e. (1/N') * sum_i sum_j rho(td_ij) -> shape (bs,)
loss = loss.sum(dim=1).mean(dim=1)
loss = loss.mean()  # take mean over batches

In [26]:
# Display the batch-averaged quantile Huber loss computed in the previous cell.
loss

tensor(2.1040, grad_fn=<MeanBackward0>)