## D4PG-QR

## Ray RLlib

In [None]:
pip install "ray[rllib]" torch torchvision


In [None]:
from ray.rllib.agents.ddpg import DDPGTrainer

# Trainer settings for a DDPG run whose model section points at the custom
# quantile critic.
# NOTE(review): "quantile_critic_model" must be registered with ModelCatalog
# before the trainer is constructed; in this notebook the registration call
# lives in a later cell -- confirm the execution order.
# NOTE(review): "critic_loss_fn" does not look like a standard DDPG config
# key -- confirm that RLlib accepts it and actually routes the critic loss
# through the quantile Huber loss.
config = dict(
    env="Pendulum-v1",  # Replace with hedging environment
    num_workers=4,  # Parallel training
    framework="torch",
    exploration_config=dict(type="OrnsteinUhlenbeckNoise"),  # Better exploration
    model=dict(
        custom_model="quantile_critic_model",  # Specify custom critic model
        custom_model_config=dict(num_quantiles=32),  # Number of quantiles
    ),
    critic_loss_fn="quantile_huber_loss",  # Use Quantile Regression Loss
)

trainer = DDPGTrainer(config=config)

# Run 100 training iterations, printing the result of each one.
for _ in range(100):
    result = trainer.train()
    print(result)


In [None]:
# making critic outputs quantiles instead of a single Q-value
import torch
import torch.nn as nn
import torch.nn.functional as F
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.utils.torch_ops import huber_loss
from ray.rllib.utils.annotations import override

class QuantileCritic(TorchModelV2, nn.Module):
    """Model that emits ``num_quantiles`` outputs instead of a single value.

    A ``FullyConnectedNetwork`` with ``num_quantiles`` output units is used
    as the underlying network. ``num_quantiles`` is read from
    ``model_config["custom_model_config"]["num_quantiles"]`` and defaults
    to 32.

    NOTE(review): the ``num_outputs`` argument supplied by the caller is
    forwarded to ``TorchModelV2.__init__`` but the network width is taken
    from ``num_quantiles`` -- confirm this is what the algorithm expects.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # Tolerate a missing or explicitly-None custom config section.
        custom_config = model_config.get("custom_model_config") or {}
        self.num_quantiles = custom_config.get("num_quantiles", 32)
        self.base_model = FullyConnectedNetwork(
            obs_space, action_space, self.num_quantiles, model_config, name + "_base"
        )

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Return ``(quantiles, state)`` for a batch of inputs.

        Calling an RLlib model returns an ``(outputs, state)`` tuple, so the
        sub-model's result is unpacked and only its output tensor is passed
        on as this model's features.
        """
        quantiles, _ = self.base_model(input_dict, state, seq_lens)
        return quantiles, state

    def get_q_values(self, obs):
        """Return the quantile estimates for ``obs``.

        Args:
            obs: Either an input dict (anything that is a ``dict``, with an
                ``"obs"`` entry) or a raw observation batch, which is wrapped
                as ``{"obs": obs}`` before being passed to the sub-model.

        Returns:
            The output tensor of the underlying network (the state element
            of the sub-model's return tuple is discarded).
        """
        input_dict = obs if isinstance(obs, dict) else {"obs": obs}
        quantiles, _ = self.base_model(input_dict)
        return quantiles


In [None]:
def quantile_huber_loss(predictions, targets, taus, kappa=1.0):
    """Element-wise quantile Huber loss for QR-DQN / D4PG-QR.

    With ``u = targets - predictions`` the loss is::

        |tau - 1{u < 0}| * L_kappa(u) / kappa

    where ``L_kappa`` is the Huber loss with threshold ``kappa``
    (``0.5 * u**2`` for ``|u| <= kappa``, otherwise
    ``kappa * (|u| - 0.5 * kappa)``). For ``kappa == 0`` the plain quantile
    (pinball) loss ``|tau - 1{u < 0}| * |u|`` is returned.

    Args:
        predictions: Tensor of predicted quantile values.
        targets: Tensor of target values, broadcastable against
            ``predictions``.
        taus: Quantile fractions (tensor or float), broadcastable against
            ``targets - predictions``.
        kappa: Non-negative Huber threshold. Defaults to 1.0.

    Returns:
        Tensor of non-negative per-element losses with the broadcast shape
        of the inputs. No reduction is applied.

    Raises:
        ValueError: If ``kappa`` is negative.
    """
    if kappa < 0:
        raise ValueError(f"kappa must be non-negative, got {kappa}")

    delta = targets - predictions
    abs_delta = delta.abs()

    # Asymmetric weight: tau for under-estimation (delta >= 0),
    # (1 - tau) for over-estimation (delta < 0).
    weight = torch.abs(taus - (delta < 0).to(delta.dtype))

    if kappa == 0:
        # Avoid dividing by zero; this is the limiting pinball loss.
        return weight * abs_delta

    huber = torch.where(
        abs_delta <= kappa,
        0.5 * delta.pow(2),
        kappa * (abs_delta - 0.5 * kappa),
    )
    return weight * huber / kappa




In [None]:

from ray.rllib.models import ModelCatalog

# Register QuantileCritic under the name used for "custom_model" in the
# trainer config. NOTE(review): this presumably has to run before the trainer
# is constructed, yet this cell comes after the training cell -- confirm the
# intended execution order.
ModelCatalog.register_custom_model("quantile_critic_model", QuantileCritic)