In [16]:
import random
from collections import namedtuple

import torch
from tqdm import tqdm_notebook
from tqdm.auto import tqdm

In [2]:
# One step of experience; the four fields are stored exactly as handed to
# ReplayMemory.push, in this order.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-capacity circular buffer of ``Transition`` records.

    Records are appended until ``capacity`` is reached; after that each new
    record overwrites the oldest one.

    Attributes:
        capacity: Maximum number of records kept.
        memory: List holding the stored ``Transition`` objects.
        position: Index of the slot the next ``push`` will write to.
    """

    def __init__(self, capacity):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions to keep; must be >= 1.

        Raises:
            ValueError: If ``capacity`` is smaller than 1.
        """
        # A capacity below 1 can never hold a record: push would fail with an
        # IndexError on the first call, so reject it up front.
        if capacity < 1:
            raise ValueError("capacity must be at least 1, got %r" % (capacity,))
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition.

        Args:
            *args: The four ``Transition`` fields, in order
                (state, action, next_state, reward).

        Raises:
            TypeError: If the number of fields is wrong. The buffer is left
                untouched in that case.
        """
        # Build the record first so a bad argument list cannot leave a
        # placeholder entry behind in the buffer.
        record = Transition(*args)
        if len(self.memory) < self.capacity:
            # Still growing: the write position equals the current length.
            self.memory.append(record)
        else:
            # Full: overwrite the oldest record.
            self.memory[self.position] = record
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions in random order.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [49]:
# Small buffer so that overwriting of old entries is easy to observe.
memory = ReplayMemory(capacity=10)

In [56]:
# Push 20 synthetic transitions into the buffer. Every field is the step index
# scaled by a different power of ten, so a value seen later in a sampled batch
# shows both which field and which step it came from.
# `tqdm` (from tqdm.auto) replaces the deprecated `tqdm_notebook`, and
# `torch.tensor` replaces the legacy `torch.Tensor` constructor; the float
# literals keep the tensors floating-point as before.
for i in tqdm(range(20)):
    memory.push(
        torch.tensor([1.0 * i]),     # state
        torch.tensor([10.0 * i]),    # action
        torch.tensor([100.0 * i]),   # next_state
        torch.tensor([1000.0 * i]),  # reward
    )

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




In [57]:
# Draw 10 stored transitions at random, without repeats.
transitions = memory.sample(batch_size=10)

In [58]:
# Transpose the list of Transition records into one Transition whose fields
# are tuples collecting that field across the whole sample.
batch = Transition._make(zip(*transitions))

In [59]:
# Per-transition flag: 1 where next_state is present, 0 where it is None.
# NOTE(review): the mask is uint8; recent PyTorch releases expect torch.bool
# when a mask is used for indexing - confirm against the installed version.
non_final_mask = torch.tensor(
    tuple(state is not None for state in batch.next_state),
    dtype=torch.uint8,
)
# Concatenate every present next_state into one tensor, keeping sample order.
non_final_next_states = torch.cat(
    tuple(filter(lambda state: state is not None, batch.next_state))
)

In [60]:
# Join the per-transition tensors of each field along the first dimension,
# giving one tensor per field for the whole sample.
state_batch = torch.cat(batch.state, dim=0)
action_batch = torch.cat(batch.action, dim=0)
reward_batch = torch.cat(batch.reward, dim=0)

In [61]:
# Display the mask: one entry per sampled transition, nonzero where its
# next_state is not None.
non_final_mask

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.uint8)

In [63]:
# Display the concatenated next_state values of the transitions whose
# next_state is not None.
non_final_next_states

tensor([1400., 1100., 1600., 1800., 1000., 1200., 1300., 1900., 1700., 1500.])

In [66]:
# One zero-initialised slot per sampled transition. Sized from the mask that
# indexes it rather than a hard-coded 10, so the two stay aligned if the
# sample size changes.
next_state_values = torch.zeros(len(non_final_mask))

In [67]:
# Select the slots that belong to transitions with a next_state.
# Indexing with a uint8 mask is deprecated in PyTorch, so convert the mask to
# bool at the point of use; the selected elements are the same.
next_state_values[non_final_mask.bool()]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])