In [1]:
# Quantile Regression Critic
import torch
import torch.nn as nn

In [2]:
def weight_init(layers):
    """Re-initialise the weights of the given layers with Kaiming-normal init.

    Args:
        layers: iterable of modules exposing a ``weight`` tensor
            (e.g. ``nn.Linear``). Only ``weight`` is touched; any bias is
            left as it was.

    Returns:
        None. The weights are modified in place.
    """
    he_normal_ = torch.nn.init.kaiming_normal_
    for module in layers:
        # fan_in is the library default mode; the gain is the one for ReLU.
        he_normal_(module.weight, mode='fan_in', nonlinearity='relu')

In [3]:
class QR_SafetyCritic(nn.Module):
    """Quantile-regression safety critic.

    A small MLP (two ReLU hidden layers plus a linear output layer) that maps
    a concatenated (state, action) pair to ``num_q`` scalar outputs.

    Args:
        obs_dim: state dimension.
        action_dim: action dimension (continuous action space).
        hidden_dim: width of both hidden layers.
        num_quantiles: number of quantiles used to approximate the quantile
            distribution (32 in the paper).
        risk_level: optional. If given (and non-zero), the network gets a
            single output that is meant to approximate the CVaR directly by a
            single quantile sampled from U(1 - risk_level, 1). The sampling
            itself is not done inside this class.

    Attributes:
        num_q: number of network outputs (``num_quantiles``, or 1 when a
            risk level is given).
        risk_level: ``1 - risk_level`` (the value-at-risk level) when a risk
            level is given, otherwise ``None``.
    """
    def __init__(self, obs_dim, action_dim, hidden_dim, num_quantiles, risk_level=None):
        super().__init__()

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.num_q = num_quantiles
        # Always define the attribute so that `critic.risk_level` can be read
        # on every instance (previously it only existed when a risk level was
        # passed, and reading it otherwise raised AttributeError).
        self.risk_level = None

        if risk_level:
            # approximate cvar quantile: a single output replaces the
            # full set of quantiles
            self.risk_level = 1 - risk_level  # ValueAtRisk
            self.num_q = 1

        self.head = nn.Linear(self.obs_dim + self.action_dim, hidden_dim)
        self.lin1 = nn.Linear(hidden_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, self.num_q)
        # Only the layers followed by a ReLU get the ReLU-specific init;
        # the output layer keeps the nn.Linear default.
        weight_init([self.head, self.lin1])

    def forward(self, s, a):
        """Forward pass of the QR safety critic.

        Args:
            s: batch of states, concatenated with ``a`` along the last dim.
            a: batch of actions.

        Returns:
            Tensor of shape (batch, num_q, 1): the quantile distribution of
            the cost, or the approximated CVaR when ``num_q == 1``.
        """
        obs_action = torch.cat([s, a], dim=-1)
        x = torch.relu(self.head(obs_action))
        x = torch.relu(self.lin1(x))
        out = self.lin2(x)
        # Append a trailing singleton dim: (batch, num_q) -> (batch, num_q, 1)
        return out.view(obs_action.shape[0], self.num_q, 1)

In [4]:
# Smoke test: build a critic with 32 quantile outputs and draw one random
# (state, action) pair whose sizes match obs_dim / action_dim.
critic = QR_SafetyCritic(
    obs_dim=10,
    action_dim=10,
    hidden_dim=64,
    num_quantiles=32,
)
s, a = torch.randn(1, 10), torch.randn(1, 10)

In [5]:
# Output layout is (batch, num_quantiles, 1) -- open question: do we want to keep this shape?
critic(s, a).size()

torch.Size([1, 32, 1])

In [11]:
# QUANTILE REGRESSION LOSS + HUBER LOSS
def calculate_huber_loss_and_quantile_loss(td_errors, k=1.0, n=32):
    """Element-wise quantile Huber loss for pairwise TD errors.

    Args:
        td_errors: tensor of shape (batch, n, n) laid out as
            ``td_errors[b, i, j] = target_quantile_j - predicted_quantile_i``,
            i.e. dim 1 indexes the predicted quantiles and dim 2 the target
            quantiles.
        k: Huber threshold kappa. Errors with ``|u| <= k`` are penalised
            quadratically, larger ones linearly.
        n: number of quantiles; must match dims 1 and 2 of ``td_errors``.

    Returns:
        Tensor of shape (batch, n, n) holding
        ``|tau_i - 1{u < 0}| * huber_k(u)`` for every entry ``u``, where
        ``tau_i = (2i - 1) / (2n)`` is the quantile midpoint belonging to the
        i-th predicted quantile. No reduction and no division by ``k`` is
        applied.

    Raises:
        AssertionError: if ``td_errors`` does not have shape (batch, n, n).
    """
    abs_err = td_errors.abs()
    huber_l = torch.where(abs_err <= k, 0.5 * td_errors.pow(2), k * (abs_err - 0.5 * k))
    assert huber_l.shape == (td_errors.shape[0], n, n), "huber loss has wrong shape"

    # Quantile midpoints (i + 0.5) / n for i = 0..n-1. Using i/n for i = 1..n
    # would include tau = 1.0, whose weight for negative errors is exactly 0,
    # so the top quantile could never be pushed down.
    # Created on the input's device so GPU tensors work as well.
    quantile_tau = (torch.arange(n, device=td_errors.device, dtype=torch.float32) + 0.5) / n
    # Shape (1, n, 1): each tau belongs to one *predicted* quantile (dim 1).
    # A plain (n,) tensor would broadcast along the target axis (dim 2).
    quantile_tau = quantile_tau.to(huber_l.dtype).view(1, n, 1)

    # The indicator carries no gradient; detach() keeps that explicit.
    below_zero = (td_errors.detach() < 0).to(huber_l.dtype)
    quantil_l = (quantile_tau - below_zero).abs() * huber_l
    return quantil_l

In [12]:
# Demo configuration: batch size, input dimensions, number of quantiles
# and the discount factor used in the TD target.
batch_size = 12
obs_dim = 10
action_dim = 10
num_quantiles = 32
gamma = 0.99

# initialize qr-networks (online critic and target critic); built from the
# configuration above so changing obs_dim / action_dim keeps everything in sync
critic = QR_SafetyCritic(obs_dim=obs_dim, action_dim=action_dim, hidden_dim=64, num_quantiles=num_quantiles)
critic_target = QR_SafetyCritic(obs_dim=obs_dim, action_dim=action_dim, hidden_dim=64, num_quantiles=num_quantiles)

# assume some batch_sample
states = torch.randn(batch_size, obs_dim)
actions = torch.randn(batch_size, action_dim)
next_states = torch.randn(batch_size, obs_dim)
rewards = torch.randn(batch_size, 1)
dones = torch.zeros(batch_size, 1)
# simulate policy
next_action_new = torch.randn(batch_size, action_dim)

# calculate q-targets
# td-target: r + (gamma)*(1-d)*Q(next_state,a_new')
# The target is a fixed regression label, so it is built without autograd;
# otherwise a backward pass would also push gradients into critic_target.
with torch.no_grad():
    q_target_next = critic_target(next_states, next_action_new).transpose(1, 2)  # (bs,N,1) -> (bs,1,N)
    assert q_target_next.shape == (batch_size, 1, num_quantiles)
    q_target = rewards.unsqueeze(-1) + gamma * (1 - dones.unsqueeze(-1)) * q_target_next
q = critic(states, actions)  # (bs,N,1)
# Broadcasting (bs,1,N) - (bs,N,1) gives all pairwise errors:
# td[b, i, j] = target quantile j - predicted quantile i
td = q_target - q
assert td.shape == (batch_size, num_quantiles, num_quantiles)
loss = calculate_huber_loss_and_quantile_loss(td, n=num_quantiles)
# formula 8.5: sum over predicted quantiles (dim 1), then mean over target
# quantiles. The two reductions act on different axes, so their order does
# not change the result.
loss = loss.sum(dim=1).mean(dim=1)
loss = loss.mean()  # take mean over batches

In [26]:
# Scratch check on a random (2, 5, 1) tensor: drop the trailing singleton
# dim, keep the last two entries of each row, and average them.
# NOTE(review): presumably a prototype for averaging the top quantiles of the
# critic output -- confirm.
x = torch.randn(2, 5, 1)
print(x[..., 0])
print(x[:, -2:, 0])
print(x[:, -2:, 0].mean(dim=-1, keepdim=True))

tensor([[ 1.6506, -0.3148,  1.2832,  1.9226, -0.0061],
        [-2.9070,  1.6208, -2.3225,  1.3767, -0.1619]])
tensor([[ 1.9226, -0.0061],
        [ 1.3767, -0.1619]])
tensor([[0.9582],
        [0.6074]])
