# Prioritized Experience Replay

Baseado no artigo [Prioritized Experience Replay, Schaul et al., 2015](https://deepmind.com/research/publications/prioritized-experience-replay).

In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt
import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class Memory:
    """Fixed-capacity replay memory with proportional priority sampling.

    Transitions are kept in preallocated tensors on ``self.device`` and are
    written in ring-buffer order: once ``maxlen`` transitions are stored, the
    oldest slot is overwritten. Each slot carries a priority. New transitions
    receive the largest priority currently stored (1 for an empty memory).
    ``update_priority`` later sets it to ``(|delta| + offset) ** alpha``.

    Args:
        sdim: shape of a single state (the agent passes
            ``env.observation_space.shape``).
        maxlen: capacity of the memory.
        alpha: exponent applied to priorities (0 gives uniform sampling).
        offset: constant added before exponentiation so that a priority of
            zero does not make a transition unsampleable.
    """
    def __init__(self, sdim, maxlen, alpha, offset):
        self.device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._maxlen = maxlen
        self._alpha  = alpha
        self._offset = offset
        # Unused slots keep priority 0; sampling only looks at the filled prefix.
        self._memory = {
            "states":   torch.zeros((maxlen, *sdim), dtype=torch.float64, device=self.device),
            "actions":  torch.zeros(maxlen,          dtype=torch.int64,   device=self.device),
            "rewards":  torch.zeros(maxlen,          dtype=torch.float64, device=self.device),
            "nstates":  torch.zeros((maxlen, *sdim), dtype=torch.float64, device=self.device),
            "dones":    torch.zeros(maxlen,          dtype=torch.float64, device=self.device),
            "priority": torch.zeros(maxlen,          dtype=torch.float64, device=self.device),
        }
        self._index = 0  # slot written by the next push
        self._len = 0    # number of filled slots, never above maxlen

    def _iindex(self):
        """Advance the write index, wrapping to 0 after the last slot."""
        self._index = (self._index + 1) % self._maxlen

    def push(self, s, a, r, s2, d):
        """Store one transition ``(s, a, r, s2, d)`` at the current write slot.

        The transition is given the maximum priority currently stored (or 1
        when the memory is empty), so it starts with the highest sampling
        weight in the buffer.
        """
        n = len(self)
        mx = self._memory["priority"][:n].max().item() if n > 0 else 1.0
        if n < self._maxlen:
            self._len += 1
        self._memory["states"][self._index]   = torch.tensor(s, dtype=torch.float64)
        self._memory["actions"][self._index]  = a
        self._memory["rewards"][self._index]  = r
        self._memory["nstates"][self._index]  = torch.tensor(s2, dtype=torch.float64)
        self._memory["dones"][self._index]    = d
        self._memory["priority"][self._index] = mx
        self._iindex()

    def sample(self, batch):
        """Draw ``batch`` slots with replacement, proportionally to priority.

        Returns:
            ``(states, actions, rewards, next_states, dones, probs, indexes)``
            where ``probs`` holds the sampling probability of each drawn slot
            and ``indexes`` is the list of drawn slot positions (the value to
            pass back to ``update_priority``).
        """
        n = len(self)
        priority = self._memory["priority"][:n]
        # Convert only the filled prefix to Python, not all maxlen slots.
        indexes = random.choices(range(n), weights=priority.tolist(), k=batch)
        return (
            self._memory["states"  ][indexes],
            self._memory["actions" ][indexes],
            self._memory["rewards" ][indexes],
            self._memory["nstates" ][indexes],
            self._memory["dones"   ][indexes],
            priority[indexes] / priority.sum(),
            indexes)

    def update_priority(self, index, newp):
        """Set the priorities at ``index`` to ``(|newp| + offset) ** alpha``.

        ``newp`` is detached first. Writing a tensor that still requires grad
        into the buffer would attach it to the autograd graph and grow that
        history across every training step.
        """
        q = (newp.detach().abs().to(torch.float64) + self._offset) ** self._alpha
        self._memory["priority"][index] = q

    def __len__(self):
        """Number of transitions currently stored."""
        return self._len

In [None]:
class Agent:
    """DQN agent with a soft-updated target network and prioritized replay.

    Required ``hyperparameters`` keys:
        ``"a"``          priority exponent alpha passed to ``Memory``;
        ``"b"``          initial importance-sampling exponent beta;
        ``"beta_decay"`` relative increase of beta per training step (capped at 1);
        ``"lr"``         Adam learning rate;
        ``"batch"``      minibatch size;
        ``"e0"``/``"ef"`` initial / minimum epsilon;
        ``"decay"``      epsilon is multiplied by ``ef ** (1 / decay)`` per
                         ``apply_decay`` call;
        ``"gamma"``      discount factor;
        ``"tau"``        soft target-update rate.
    Optional keys: ``"memory_size"`` (default 100000) and
    ``"priority_offset"`` (default 0.0001).

    ``load`` may be ``True`` (weights are read from ``file_path``) or, for
    backward compatibility, a path string that is used directly.
    """
    def __init__(self,
                 env,
                 build_net,
                 hyperparameters,
                 load = False, file_path = "dqnsavep.pth"):

        self._alpha       = hyperparameters["a"]
        self._loss_param  = hyperparameters["b"]
        self._lp_increase = hyperparameters["beta_decay"]
        self._lr          = hyperparameters["lr"]
        self._batch       = hyperparameters["batch"]
        self._epsilon     = hyperparameters["e0"]
        self._epsilon_min = hyperparameters["ef"]
        self._gamma       = hyperparameters["gamma"]
        self._tau         = hyperparameters["tau"]
        self._steps       = 0

        self.device        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._memory       = Memory(env.observation_space.shape,
                                    hyperparameters.get("memory_size", 100000),
                                    self._alpha,
                                    hyperparameters.get("priority_offset", 0.0001))
        self._action_space = env.action_space

        self._decay = self._epsilon_min**(1/hyperparameters["decay"])
        self._file = file_path

        self.dqn        = build_net().to(self.device)
        self.target_dqn = build_net().to(self.device)
        # Start the target as an exact copy of the online network instead of
        # a second, unrelated random initialization.
        self.target_dqn.load_state_dict(self.dqn.state_dict())
        self._opt       = optim.Adam(self.dqn.parameters(), lr=self._lr)

        if load:
            # `load` is a flag; accept a path string too for old callers.
            path = load if isinstance(load, str) else self._file
            state_dict = torch.load(path, map_location=self.device)
            self.dqn.load_state_dict(state_dict)
            self.target_dqn.load_state_dict(state_dict)

    def apply_decay(self):
        """Decay epsilon multiplicatively, never going below its minimum."""
        self._epsilon = max(self._epsilon*self._decay, self._epsilon_min)

    def __call__(self, state):
        """Pick an action epsilon-greedily.

        Returns ``(value, action)``. ``value`` is the greedy Q-value, or 0 for
        a random action.
        """
        if random.random() < self._epsilon:
            return 0, self._action_space.sample()
        state = torch.tensor(state).unsqueeze(0).to(self.device, dtype=torch.float64)
        with torch.no_grad():  # inference only; no graph needed
            pred = self.dqn(state)[0]
        value, index = pred.max(0)
        return value.item(), index.item()

    def save_on_memory(self, s, a, r, s2, d):
        """Store a transition in the replay memory."""
        self._memory.push(s, a, r, s2, d)

    def train(self):
        """Run one prioritized-replay optimization step.

        Returns ``-inf`` while the memory holds fewer than ``batch``
        transitions. Otherwise it returns the importance-weighted loss as a
        detached 0-d tensor.
        """
        self._steps += 1
        if len(self._memory) < self._batch:
            return -float("inf")

        s, a, r, s2, d, p, i = self._memory.sample(self._batch)

        q_eval = self.dqn(s).gather(1, a.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            q_next = self.target_dqn(s2).max(1)[0]
        target = r + self._gamma*q_next*(1 - d)

        # Importance-sampling weights (N * P(i)) ** -beta, normalized by the max.
        N = len(self._memory)
        w = (N * p)**(-self._loss_param)
        w = (w / w.max()).detach()
        if self._loss_param < 1:
            self._loss_param = min(1.0, self._loss_param * (1 + self._lp_increase))

        loss = F.smooth_l1_loss(q_eval, target, reduction="none")
        # Priorities follow the paper: |TD error|, detached from the graph.
        self._memory.update_priority(i, (target - q_eval).detach().abs())

        final_loss = torch.mean(loss * w)
        self._opt.zero_grad()
        final_loss.backward()
        nn.utils.clip_grad_value_(self.dqn.parameters(), 100)
        self._opt.step()

        # Soft update: target <- tau * online + (1 - tau) * target.
        with torch.no_grad():
            for target_param, param in zip(self.target_dqn.parameters(), self.dqn.parameters()):
                target_param.copy_(self._tau * param + (1 - self._tau) * target_param)

        # Detach so callers accumulating the loss do not keep graphs alive.
        return final_loss.detach()

    @property
    def epsilon(self):
        """Current exploration rate."""
        return self._epsilon

    @property
    def beta(self):
        """Current importance-sampling exponent."""
        return self._loss_param

    @property
    def memory(self):
        """The agent's replay ``Memory``."""
        return self._memory


In [None]:
def build_net(in_features=4, out_features=2, hidden=32):
    """Build a float64 MLP Q-network with two ReLU hidden layers.

    The defaults (4 inputs, 2 outputs, 32 hidden units) reproduce the
    original CartPole-sized network. Pass other sizes to reuse it for
    environments with different observation/action dimensions.

    Args:
        in_features: size of the observation vector.
        out_features: number of discrete actions (one Q-value each).
        hidden: width of both hidden layers.
    """
    net = nn.Sequential(nn.Linear(in_features, hidden),
                        nn.ReLU(),
                        nn.Linear(hidden, hidden),
                        nn.ReLU(),
                        nn.Linear(hidden, out_features))
    return net.type(torch.float64)

In [None]:
hparams = {
    "a":0.6,
    "b":0.4,
    "beta_decay": 2.5e-5,
    "lr":1e-3,
    "gamma":0.999,
    "e0":1,
    "ef":0.01,
    "decay":150,
    "batch":64,
    "tau": 0.01,
}

env = gym.make("CartPole-v1")
agent = Agent(env, build_net, hparams)
print(agent.dqn)
print(hparams)

episodes = 400
history = []
tts = 0
for i in range(1, episodes + 1):
    steps = 0
    rewards = 0
    total_loss = 0.0
    done = False
    state = env.reset()
    while not done:
        _, action = agent(state)
        sstate, reward, done, _ = env.step(action)
        agent.save_on_memory(state, action, reward, sstate, done)
        # float() drops any autograd graph a tensor loss might still carry.
        loss = float(agent.train())
        # train() returns -inf as a "not trained yet" sentinel; skip it so
        # early episodes do not report a total loss of -inf.
        if loss != -float("inf"):
            total_loss += loss
        state = sstate
        rewards += reward
        steps += 1
    # if i%100 == 0:
    #     torch.save(agent.dqn.state_dict(), f"{str(i)}.pth")
    print(f"episode [{i:04d}] - "
          f"rewards [{rewards}] - "
          f"duration [{steps:04d}] - "
          f"epsilon [{100*agent.epsilon:.2f}%] - "
          f"beta [{agent.beta:.2f}] - "
          f"len [{len(agent.memory):06d}] - "
          f"loss [{total_loss:.2f}]")
    history.append(rewards)
    agent.apply_decay()

In [None]:
# Moving average (window of up to 41 episodes centred on each point) of the
# episode rewards, with a +/- one standard deviation band.
n_points = len(history)
means = [0.0] * n_points
high = [0.0] * n_points
low  = [0.0] * n_points
for i in range(n_points):
    start = max(0, i - 20)
    end = min(n_points, i + 21)
    window = history[start:end]
    mean = sum(window) / len(window)
    std = np.std(window)
    means[i] = mean
    high[i] = mean + std
    low[i] = mean - std

# matplotlib >= 3.6 renamed the seaborn styles to "seaborn-v0_8-*" and 3.8
# removed the old names; pick whichever exists, else the default style.
style = next((name for name in ("seaborn-v0_8-pastel", "seaborn-pastel")
              if name in plt.style.available), "default")
with plt.style.context(style):
    # Create the figure inside the context so the style applies to it.
    plt.figure(figsize=(12,7))
    x = list(range(n_points))
    plt.fill_between(x, low, high, alpha=0.2)
    plt.plot(x, means, linewidth=4)
    plt.plot(x, history)
    plt.show()