# Replay Buffer

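> A fixed-capacity circular replay buffer that stores transitions (state, action, next state, reward, done flag) in preallocated PyTorch tensors and samples random mini-batches from them.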

In [None]:
#| default_exp replay_buffer

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
# Run nbdev's export step for this notebook.
import nbdev
nbdev.nbdev_export()

In [None]:
#| export
from typing import Tuple
import torch

In [None]:
#| export
class ReplayBuffer:
    """Fixed-capacity circular buffer of transitions stored in preallocated tensors.

    Each transition is a ``(state, action, new_state, reward, done)`` tuple.
    Once ``mem_size`` transitions have been stored, every further write
    overwrites the oldest slot.
    """

    def __init__(self, mem_size: int, input_shape, n_actions: int):
        """Allocate zero-filled storage for ``mem_size`` transitions.

        Args:
            mem_size: Maximum number of transitions kept; must be positive.
            input_shape: Iterable of ints giving the shape of one state.
            n_actions: Length of the action vector stored per transition.

        Raises:
            ValueError: If ``mem_size`` is not positive (the write index is
                computed modulo ``mem_size``, so zero capacity can never work).
        """
        if mem_size <= 0:
            raise ValueError(f"mem_size must be positive, got {mem_size}")

        self.mem_size = mem_size
        # Total number of transitions ever stored. It is never reset, so the
        # next write slot is always `mem_counter % mem_size`.
        self.mem_counter = 0

        # Preallocated storage: row i of every tensor belongs to transition i.
        self.state_memory = torch.zeros((self.mem_size, *input_shape))
        self.new_state_memory = torch.zeros((self.mem_size, *input_shape))
        self.action_memory = torch.zeros((self.mem_size, n_actions))
        self.reward_memory = torch.zeros(self.mem_size)
        self.terminal_memory = torch.zeros(self.mem_size, dtype=torch.bool)

    @staticmethod
    def _like(value, storage: torch.Tensor) -> torch.Tensor:
        """Convert ``value`` to a tensor with the dtype and device of ``storage``."""
        # Tensor item assignment only accepts tensors and Python scalars;
        # converting first lets callers pass lists, NumPy arrays or NumPy
        # scalars as well.
        return torch.as_tensor(value, dtype=storage.dtype, device=storage.device)

    def store_transition(self, state, action, new_state, reward, done) -> None:
        """Write one transition into the next slot, overwriting the oldest if full.

        Every argument may be a tensor, a Python scalar, or anything
        ``torch.as_tensor`` accepts (nested lists, NumPy arrays and scalars).
        Values are cast to the dtype of the corresponding storage tensor.
        """
        index = self.mem_counter % self.mem_size

        self.state_memory[index] = self._like(state, self.state_memory)
        self.new_state_memory[index] = self._like(new_state, self.new_state_memory)
        self.action_memory[index] = self._like(action, self.action_memory)
        self.reward_memory[index] = self._like(reward, self.reward_memory)
        self.terminal_memory[index] = self._like(done, self.terminal_memory)

        self.mem_counter += 1

    def sample_buffer(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return a random batch of stored transitions.

        Indices are drawn without replacement from the slots filled so far.
        If fewer than ``batch_size`` transitions are stored, all of them are
        returned (in random order), so the batch can be smaller than requested.

        Returns:
            ``(states, actions, new_states, rewards, dones)``, each indexed
            along its first dimension by the same sampled slots.
        """
        # Only the first `filled` slots hold real data until the buffer wraps.
        filled = min(self.mem_counter, self.mem_size)

        batch = torch.randperm(filled)[:batch_size]

        states = self.state_memory[batch]
        actions = self.action_memory[batch]
        new_states = self.new_state_memory[batch]
        rewards = self.reward_memory[batch]
        dones = self.terminal_memory[batch]

        return states, actions, new_states, rewards, dones