In [16]:
from collections import namedtuple
import random
from tqdm import tqdm_notebook
import torch

In [2]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [5]:
memory = ReplayMemory(10)

In [20]:
for i in tqdm_notebook(range(20)):
    memory.push(torch.Tensor(i), torch.Tensor(i), torch.Tensor(i), torch.Tensor(i))

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




In [21]:
transitions = memory.sample(10)

In [22]:
transitions

[Transition(state=tensor([1.3119e+24, 3.0621e-41, 1.2266e+24, 3.0621e-41, 0.0000e+00, 0.0000e+00,
        1.3114e+24, 3.0621e-41, 1.2267e+24, 3.0621e-41, 1.4013e-45, 0.0000e+00,
               nan, 0.0000e+00]), action=tensor([ 1.3122e+24,  3.0621e-41,  0.0000e+00,  1.4013e-45,         nan,
         3.0621e-41,  1.0461e+24,  3.0621e-41,  1.3122e+24,  3.0621e-41,
        -5.7649e-27,  4.5832e-41,  0.0000e+00,  0.0000e+00]), next_state=tensor([-2.9468e+01,  4.5832e-41,  1.2536e+24,  3.0621e-41,  1.2111e+24,
         3.0621e-41,  1.2390e+24,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  1.4013e-45,  0.0000e+00]), reward=tensor([ 0.0000e+00,  0.0000e+00, -2.9467e+01,  4.5832e-41,  6.0547e+23,
         3.0621e-41,  4.6243e-44,  0.0000e+00,  1.2383e+24,  3.0621e-41,
        -2.9467e+01,  4.5832e-41,  8.9683e-44,  0.0000e+00])),
 Transition(state=tensor([-2.9467e+01,  4.5832e-41, -2.9467e+01,  4.5832e-41,  1.0374e-08,
         2.6370e-09,  1.2823e+16,  1.2849e+31, 

In [23]:
batch = Transition(*zip(*transitions))

In [24]:
batch.state

(tensor([1.3119e+24, 3.0621e-41, 1.2266e+24, 3.0621e-41, 0.0000e+00, 0.0000e+00,
         1.3114e+24, 3.0621e-41, 1.2267e+24, 3.0621e-41, 1.4013e-45, 0.0000e+00,
                nan, 0.0000e+00]),
 tensor([-2.9467e+01,  4.5832e-41, -2.9467e+01,  4.5832e-41,  1.0374e-08,
          2.6370e-09,  1.2823e+16,  1.2849e+31,  1.8395e+25,  6.1963e-04]),
 tensor([1.2442e+24, 3.0621e-41, 1.2426e+24, 3.0621e-41, 0.0000e+00, 0.0000e+00,
         1.2939e+24, 3.0621e-41, 1.2939e+24, 3.0621e-41, 1.4013e-45, 0.0000e+00,
         1.6109e-19, 1.8888e+31, 4.1051e-41, 0.0000e+00, 1.1210e-43]),
 tensor([1.6767e+24, 3.0621e-41, 1.8074e+24, 3.0621e-41, 0.0000e+00, 0.0000e+00,
         1.3252e+24, 3.0621e-41, 1.3252e+24, 3.0621e-41, 1.4013e-45, 0.0000e+00,
         8.9683e-44]),
 tensor([ 1.7063e+24,  3.0621e-41,  0.0000e+00,  1.4013e-45,         nan,
          3.0621e-41,  1.0461e+24,  3.0621e-41,  1.7063e+24,  3.0621e-41,
         -5.7649e-27,  4.5832e-41]),
 tensor([ 1.7083e+24,  3.0621e-41,  0.0000e+00,  1

In [25]:
batch.action

(tensor([ 1.3122e+24,  3.0621e-41,  0.0000e+00,  1.4013e-45,         nan,
          3.0621e-41,  1.0461e+24,  3.0621e-41,  1.3122e+24,  3.0621e-41,
         -5.7649e-27,  4.5832e-41,  0.0000e+00,  0.0000e+00]),
 tensor([-2.9467e+01,  4.5832e-41, -2.9467e+01,  4.5832e-41,  1.4013e-45,
          0.0000e+00,  1.2461e+24,  3.0621e-41,  1.2111e+24,  3.0621e-41]),
 tensor([9.5800e+23, 3.0621e-41, 1.2442e+24, 3.0621e-41, 0.0000e+00, 0.0000e+00,
         1.2940e+24, 3.0621e-41, 7.8309e+23, 3.0621e-41, 1.4013e-45, 0.0000e+00,
         1.4013e-45, 0.0000e+00, 1.2439e+24, 3.0621e-41, 1.1210e-43]),
 tensor([1.2489e+24, 3.0621e-41, 1.5999e+24, 3.0621e-41, 0.0000e+00, 0.0000e+00,
         1.3111e+24, 3.0621e-41, 1.3113e+24, 3.0621e-41, 1.4013e-45, 0.0000e+00,
                nan]),
 tensor([1.7160e+24, 3.0621e-41, 1.8043e+24, 3.0621e-41, 0.0000e+00, 0.0000e+00,
         1.7005e+24, 3.0621e-41, 1.6897e+24, 3.0621e-41, 1.4013e-45, 0.0000e+00]),
 tensor([1.7986e+24, 3.0621e-41, 1.2876e+24, 3.0621e-41, 

In [26]:
batch.next_state

(tensor([-2.9468e+01,  4.5832e-41,  1.2536e+24,  3.0621e-41,  1.2111e+24,
          3.0621e-41,  1.2390e+24,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  1.4013e-45,  0.0000e+00]),
 tensor([-2.9468e+01,  4.5832e-41,  1.3112e+24,  3.0621e-41,  1.2111e+24,
          3.0621e-41,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]),
 tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.4013e-45, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.0956e+24, 3.0621e-41, 1.1210e-43]),
 tensor([-29.4672,   0.0000, -29.4672,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000]),
 tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.4013e-45,         nan,
          3.0621e-41,  1.0461e+24,  3.0621e-41,  1.3289e+24,  3.0621e-41,
         -5.7649e-27,  4.5832e-41]),
 tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.4013e-45,  0.0000e+0

In [28]:
state_batch = torch.cat(batch.state)