In [None]:
# default_exp rl.agents.ddpg

# DDPG
> An implementation of DDPG, Deep Deterministic Policy Gradient.

References:
1. https://github.com/massquantity/DBRL/blob/master/dbrl/models/ddpg.py
2. https://www.cnblogs.com/massquantity/p/13842139.html

**Deterministic Policy Gradient (DPG)** is a type of Actor-Critic RL algorithm that uses two neural networks: a critic that estimates the action-value function, and an actor that represents the deterministic target policy. The **Deep Deterministic Policy Gradient** (**DDPG**) agent builds upon the idea of DPG and is quite efficient compared to vanilla Actor-Critic agents due to the use of deterministic action policies.

DDPG, or Deep Deterministic Policy Gradient, is an actor-critic, model-free algorithm based on the deterministic policy gradient that can operate over continuous action spaces. It combines the actor-critic approach with insights from DQNs: in particular, the insights that 1) the network is trained off-policy with samples from a replay buffer to minimize correlations between samples, and 2) the network is trained with a target Q network to give consistent targets during temporal difference backups. DDPG makes use of the same ideas along with batch normalization.

It combines ideas from DPG (Deterministic Policy Gradient) and DQN (Deep Q-Network). It uses Experience Replay and slow-learning target networks from DQN, and it is based on DPG, which can operate over continuous action spaces.

<img src='https://camo.githubusercontent.com/401c65749e015a97c17d7145daef95c88fd7c3affb829e643b84cd0c865e1b18/68747470733a2f2f6769746875622e636f6d2f5265636f4875742d5374616e7a61732f533735383133392f7261772f6d61696e2f696d616765732f646470675f616c676f2e706e67'>

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import torch
from torch import nn, optim
import torch.nn.functional as functional

from typing import List
from copy import deepcopy
import numpy as np

from recohut.models.layers.ou_noise import OUNoise
from recohut.models.actor_critic import Actor, Critic
from recohut.models.embedding import GroupEmbedding
from recohut.rl.memory import ReplayMemory

In [None]:
#export
class DDPGAgent(object):
    """
    DDPG (Deep Deterministic Policy Gradient) Agent

    Owns a shared embedding network, an actor/critic pair together with their
    target copies, a replay memory and one Adam optimizer per trainable network.
    """

    def __init__(self, config, noise: OUNoise, group2members_dict: dict, verbose=False):
        """
        Initialize DDPGAgent
        :param config: configurations (network sizes, learning rates, weight decays,
            tau, gamma, device, buffer_size and batch_size are read from it)
        :param noise: exploration noise source; ``get_ou_noise()`` is called on it
            when an action is requested with ``with_noise=True``
        :param group2members_dict: group members data, indexed by group id
        :param verbose: True to print networks
        """
        self.config = config
        self.noise = noise
        self.group2members_dict = group2members_dict
        self.tau = config.tau
        self.gamma = config.gamma
        self.device = config.device

        self.embedding = GroupEmbedding(embedding_size=config.embedding_size,
                                         user_num=config.user_num,
                                         item_num=config.item_num).to(config.device)
        self.actor = Actor(embedded_state_size=config.embedded_state_size,
                                 action_weight_size=config.embedded_action_size,
                                 hidden_sizes=config.actor_hidden_sizes).to(config.device)
        self.actor_target = Actor(embedded_state_size=config.embedded_state_size,
                                        action_weight_size=config.embedded_action_size,
                                        hidden_sizes=config.actor_hidden_sizes).to(config.device)
        self.critic = Critic(embedded_state_size=config.embedded_state_size,
                                   embedded_action_size=config.embedded_action_size,
                                   hidden_sizes=config.critic_hidden_sizes).to(config.device)
        self.critic_target = Critic(embedded_state_size=config.embedded_state_size,
                                          embedded_action_size=config.embedded_action_size,
                                          hidden_sizes=config.critic_hidden_sizes).to(config.device)

        if verbose:
            print(self.embedding)
            print(self.actor)
            print(self.critic)

        # Target networks start as exact copies of their online counterparts.
        self.copy_network(self.actor, self.actor_target)
        self.copy_network(self.critic, self.critic_target)

        self.replay_memory = ReplayMemory(buffer_size=config.buffer_size)
        self.critic_criterion = nn.MSELoss()
        self.embedding_optimizer = optim.Adam(self.embedding.parameters(), lr=config.embedding_learning_rate,
                                              weight_decay=config.embedding_weight_decay)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=config.actor_learning_rate,
                                          weight_decay=config.actor_weight_decay)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=config.critic_learning_rate,
                                           weight_decay=config.critic_weight_decay)

    def copy_network(self, network: nn.Module, network_target: nn.Module):
        """
        Copy one network to its target network
        :param network: the original network to be copied
        :param network_target: the target network
        """
        for parameters, target_parameters in zip(network.parameters(), network_target.parameters()):
            target_parameters.data.copy_(parameters.data)

    def sync_network(self, network: nn.Module, network_target: nn.Module):
        """
        Synchronize one network to its target network (soft update):
        ``target = tau * network + (1 - tau) * target``
        :param network: the original network to be synchronized
        :param network_target: the target network
        :return:
        """
        for parameters, target_parameters in zip(network.parameters(), network_target.parameters()):
            target_parameters.data.copy_(parameters.data * self.tau + target_parameters.data * (1 - self.tau))

    def get_action(self, state: list, item_candidates: list = None, top_K: int = 1, with_noise=False):
        """
        Get one action
        :param state: one environment state
        :param item_candidates: item candidates; when None every item in the
            item embedding table is scored
        :param top_K: top K items
        :param with_noise: True to add exploration noise to the action weight
        :return: action - a Python scalar when ``top_K == 1``, otherwise a numpy
            array of item ids ordered from the highest score down
        """
        with torch.no_grad():
            embedded_states = self.embed_states([state])
            action_weight = torch.squeeze(self.actor(embedded_states))
            if with_noise:
                # Convert explicitly so that a numpy array or a tensor returned by the
                # noise object is added on the right device with the right dtype.
                ou_noise = torch.as_tensor(self.noise.get_ou_noise(),
                                           dtype=action_weight.dtype,
                                           device=action_weight.device)
                action_weight = action_weight + ou_noise

            if item_candidates is None:
                candidates = None
                item_embedding_weight = self.embedding.item_embedding.weight.clone()
            else:
                candidates = np.array(item_candidates)
                candidates_tensor = torch.tensor(candidates, dtype=torch.int).to(self.device)
                item_embedding_weight = self.embedding.item_embedding(candidates_tensor)

            scores = torch.inner(action_weight, item_embedding_weight).detach().cpu().numpy()
            # np.argsort sorts ascending; negate the scores so that the first
            # top_K indices are the best-scoring items rather than the worst.
            top_indices = np.argsort(-scores)[:top_K]

            if candidates is None:
                action = top_indices
            else:
                # Map positions within the candidate list back to item ids.
                action = candidates[top_indices]
            action = np.squeeze(action)
            if top_K == 1:
                action = action.item()
        return action

    def get_embedded_actions(self, embedded_states: torch.Tensor, target=False):
        """
        Get embedded actions
        :param embedded_states: embedded states
        :param target: True for target network
        :return: embedded_actions
        """
        policy = self.actor_target if target else self.actor
        action_weights = policy(embedded_states)

        item_embedding_weight = self.embedding.item_embedding.weight.clone()
        scores = torch.inner(action_weights, item_embedding_weight)
        # hard=True yields a one-hot selection per row; the inner product with the
        # transposed embedding table then picks the embedding of the selected item.
        one_hot_items = functional.gumbel_softmax(scores, hard=True)
        embedded_actions = torch.inner(one_hot_items, item_embedding_weight.t())
        return embedded_actions

    def embed_state(self, state: list):
        """
        Embed one state
        :param state: state; ``state[0]`` is the group id and ``state[1:]`` the history
        :return: embedded_state
        """
        group_id = state[0]
        group_members = torch.tensor(self.group2members_dict[group_id], dtype=torch.int).to(self.device)
        history = torch.tensor(state[1:], dtype=torch.int).to(self.device)
        embedded_state = self.embedding(group_members, history)
        return embedded_state

    def embed_states(self, states: List[list]):
        """
        Embed states
        :param states: states
        :return: embedded_states, the per-state embeddings stacked along dim 0
        """
        embedded_states = torch.stack([self.embed_state(state) for state in states], dim=0)
        return embedded_states

    def embed_actions(self, actions: list):
        """
        Embed actions
        :param actions: actions (item ids)
        :return: embedded_actions
        """
        actions = torch.tensor(actions, dtype=torch.int).to(self.device)
        embedded_actions = self.embedding.item_embedding(actions)
        return embedded_actions

    def update(self):
        """
        Update the networks with one batch sampled from the replay memory
        :return: actor loss and critic loss (detached numpy values)
        """
        batch = self.replay_memory.sample(self.config.batch_size)
        states, actions, rewards, next_states = list(zip(*batch))

        # ---- critic step -------------------------------------------------
        # Embedding gradients are cleared once here and accumulated over both the
        # critic and the actor backward passes; the embedding steps at the very end.
        self.embedding_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        embedded_states = self.embed_states(states)
        embedded_actions = self.embed_actions(actions)
        # Rewards are kept as floats: an integer dtype would truncate fractional rewards.
        rewards = torch.unsqueeze(torch.tensor(rewards, dtype=torch.float).to(self.device), dim=-1)
        embedded_next_states = self.embed_states(next_states)
        q_values = self.critic(embedded_states, embedded_actions)

        # The TD target is built from the target networks and is not back-propagated.
        with torch.no_grad():
            embedded_next_actions = self.get_embedded_actions(embedded_next_states, target=True)
            next_q_values = self.critic_target(embedded_next_states, embedded_next_actions)
            q_values_target = rewards + self.gamma * next_q_values

        critic_loss = self.critic_criterion(q_values, q_values_target)
        critic_loss.backward()
        self.critic_optimizer.step()

        # ---- actor step --------------------------------------------------
        self.actor_optimizer.zero_grad()
        # Re-embed the states: the graph built above was consumed by critic_loss.backward().
        embedded_states = self.embed_states(states)
        actor_loss = -self.critic(embedded_states, self.get_embedded_actions(embedded_states)).mean()
        actor_loss.backward()
        self.actor_optimizer.step()
        self.embedding_optimizer.step()

        # ---- soft target update -----------------------------------------
        self.sync_network(self.actor, self.actor_target)
        self.sync_network(self.critic, self.critic_target)

        return actor_loss.detach().cpu().numpy(), critic_loss.detach().cpu().numpy()

In [None]:
class Config(object):
    """Small set of hyper-parameters used to instantiate the demo agent in this notebook."""

    # --- execution ---
    device = torch.device("cpu")

    # --- problem dimensions ---
    user_num = 5
    item_num = 5
    history_length = 5
    action_size = 1
    embedding_size = 32
    # presumably one slot for the group id plus the history items (verify against the environment)
    state_size = 1 + history_length
    embedded_state_size = embedding_size * state_size
    embedded_action_size = embedding_size * action_size

    # --- network shapes ---
    actor_hidden_sizes = (128, 64)
    critic_hidden_sizes = (32, 16)

    # --- reinforcement-learning settings ---
    gamma = 0.9
    tau = 0.001
    batch_size = 64
    buffer_size = 100

    # --- optimizers ---
    embedding_learning_rate = 0.0001
    actor_learning_rate = 0.0001
    critic_learning_rate = 0.0001
    embedding_weight_decay = 0.000001
    actor_weight_decay = 0.000001
    critic_weight_decay = 0.000001

In [None]:
# Instantiate the demo configuration; it is passed to DDPGAgent in the next cell.
config = Config()

In [None]:
# The noise is added to the actor's action weight inside DDPGAgent.get_action, so its
# size is taken from the config rather than repeating the literal 32, which would
# silently drift if Config.embedding_size or Config.action_size were changed.
noise = OUNoise(
    embedded_action_size=config.embedded_action_size,
    ou_mu=0.0,
    ou_theta=0.15,
    ou_sigma=0.2,
    ou_epsilon=1.0,
)

# Toy mapping from group id to the user ids of its members.
group2members_dict = {'0': [1, 2, 3], '1': [1, 4, 5]}

agent = DDPGAgent(config=config, noise=noise, group2members_dict=group2members_dict, verbose=True)

GroupEmbedding(
  (user_embedding): Embedding(6, 32)
  (item_embedding): Embedding(6, 32)
  (user_attention): Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=1, bias=True)
  )
  (user_softmax): Softmax(dim=-1)
)
Actor(
  (net): Sequential(
    (0): Linear(in_features=192, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=32, bias=True)
  )
)
Critic(
  (net): Sequential(
    (0): Linear(in_features=224, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=16, bias=True)
    (3): ReLU()
    (4): Linear(in_features=16, out_features=1, bias=True)
  )
)


In [None]:
#export
class DDPG(nn.Module):
    """
    DDPG learner built around externally supplied actor/critic networks.

    The actor is expected to expose ``get_state(data, next_state=False)``,
    ``get_action(state)`` and a forward pass returning ``(state, action)``;
    the critic is called as ``critic(state, action)``. Frozen deep copies of
    both networks serve as target networks and are soft-updated with ``tau``.
    """

    def __init__(
            self,
            actor,
            actor_optim,
            critic,
            critic_optim,
            tau=0.001,
            gamma=0.99,
            policy_delay=1,
            item_embeds=None,
            device=torch.device("cpu")
    ):
        """
        :param actor: policy network
        :param actor_optim: optimizer stepping the actor parameters
        :param critic: Q network
        :param critic_optim: optimizer stepping the critic parameters
        :param tau: soft-update coefficient for the target networks
        :param gamma: discount factor
        :param policy_delay: the actor and both targets are updated on every
            ``policy_delay``-th call of :meth:`update` (every call when <= 1)
        :param item_embeds: item embedding table; it is converted with
            ``torch.as_tensor``, so a value must be supplied in practice
        :param device: device the item embeddings are moved to
        """
        super(DDPG, self).__init__()
        self.actor = actor
        self.actor_optim = actor_optim
        self.critic = critic
        self.critic_optim = critic_optim
        self.tau = tau
        self.gamma = gamma
        self.step = 1
        self.policy_delay = policy_delay
        self.actor_targ = deepcopy(actor)
        self.critic_targ = deepcopy(critic)
        # Target networks are only changed through soft_update, never by gradients.
        for p in self.actor_targ.parameters():
            p.requires_grad = False
        for p in self.critic_targ.parameters():
            p.requires_grad = False
        self.item_embeds = torch.as_tensor(item_embeds).to(device)

    @staticmethod
    def _build_info(actor_loss, critic_loss, y, q, action):
        """Pack losses and diagnostics into the dict returned by update/compute_loss."""
        return {
            "actor_loss": (
                actor_loss.cpu().detach().item()
                if actor_loss is not None
                else None
            ),
            "critic_loss": critic_loss.cpu().detach().item(),
            "y": y, "q": q,
            "action": action
        }

    def update(self, data):
        """
        Run one training step on a batch.

        The critic is always updated; the actor and the two target networks are
        updated only when the policy-delay schedule allows it.
        :param data: batch passed through to the actor; ``"reward"``, ``"done"``
            and ``"action"`` entries are read directly
        :return: dict with ``actor_loss`` (None when the actor step was skipped),
            ``critic_loss``, ``y`` (TD target), ``q`` and ``action``
        """
        critic_loss, y, q = self._compute_critic_loss(data)
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        if self.policy_delay <= 1 or self.step % self.policy_delay == 0:
            actor_loss, action = self._compute_actor_loss(data)
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            with torch.no_grad():
                self.soft_update(self.actor, self.actor_targ)
                self.soft_update(self.critic, self.critic_targ)
        else:
            actor_loss = action = None

        self.step += 1
        return self._build_info(actor_loss, critic_loss, y, q, action)

    def compute_loss(self, data):
        """
        Compute actor and critic losses for a batch without any optimizer step.
        :param data: batch in the same format as for :meth:`update`
        :return: dict with the same keys as the one returned by :meth:`update`
        """
        actor_loss, action = self._compute_actor_loss(data)
        critic_loss, y, q = self._compute_critic_loss(data)
        return self._build_info(actor_loss, critic_loss, y, q, action)

    def _compute_actor_loss(self, data):
        """Actor loss: negative mean critic value of the actor's own actions."""
        state, action = self.actor(data)
        actor_loss = -self.critic(state, action).mean()
        return actor_loss, action

    def _compute_critic_loss(self, data):
        """Critic loss: MSE between Q(s, a) and the bootstrapped TD target ``y``."""
        # The target is computed from the frozen target networks, without gradients.
        with torch.no_grad():
            r, done = data["reward"], data["done"]
            next_s = self.actor_targ.get_state(data, next_state=True)
            next_a = self.actor_targ.get_action(next_s)
            q_targ = self.critic_targ(next_s, next_a)
            # (1 - done) removes the bootstrap term for terminal transitions.
            y = r + self.gamma * (1. - done) * q_targ

        s = self.actor.get_state(data)
        a = self.item_embeds[data["action"]]
        q = self.critic(s, a)
        # The module is imported as `functional` in this file; `F` is not defined.
        critic_loss = functional.mse_loss(q, y)
        return critic_loss, y, q

    def soft_update(self, net, target_net):
        """Soft update: ``target = (1 - tau) * target + tau * net``."""
        for targ_param, param in zip(target_net.parameters(), net.parameters()):
            targ_param.data.copy_(
                targ_param.data * (1. - self.tau) + param.data * self.tau
            )

    def select_action(self, data, *args):
        """Return the actor's action for ``data`` without tracking gradients."""
        with torch.no_grad():
            _, action = self.actor(data)
        return action

    def forward(self, state, top_k=10):
        """
        Recommend items by cosine similarity between the action and item embeddings.
        :param state: state passed to ``actor.get_action``
        :param top_k: number of items to return per row (default 10, as before);
            capped at the number of available items so small catalogues do not fail
        :return: indices of the highest-scoring items, best first
        """
        action = self.actor.get_action(state)
        action = action / torch.norm(action, dim=1, keepdim=True)
        item_embeds = self.item_embeds / torch.norm(
            self.item_embeds, dim=1, keepdim=True
        )
        scores = torch.matmul(action, item_embeds.T)
        k = min(top_k, scores.shape[1])
        _, rec_idxs = torch.topk(scores, k, dim=1)
        return rec_idxs

In [None]:
#hide
!pip install -q watermark
%reload_ext watermark
%watermark -a "Sparsh A." -m -iv -u -t -d

Author: Sparsh A.

Last updated: 2021-12-19 10:30:22

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.4.104+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit

torch  : 1.10.0+cu111
IPython: 5.5.0

