In [1]:
import gym
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim

In [2]:
I = 4  # Input dimensions
H = 64  # Hidden layer dimensions
O = 2  # output dimensions (one-hot encoding)
LEARNING_RATE = 0.00001
MOMENTUM = 0.9
REPLAY_LENGTH = 1000
EPISODE_NUM = 1
EPISODE_LENGTH = 200
EPSILON = 0.1
WARMUP_LENGTH = 10
MINIBATCH_SIZE = 1000
# WARMUP_LENGTH = 1
# MINIBATCH_SIZE = 100
EPOCHS = 1
GAMMA = 1
# dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
class Replay:
    def __init__(self, n):
        self.n = n
        self.obs_arr = []
        self.action_arr = []
        self.reward_arr = []
        self.obs2_arr = []
        self.arr_arr = [self.obs_arr, self.action_arr, self.reward_arr, self.obs2_arr]

    def add(self, t):
        assert len(t) == 4
        flag = len(self.obs_arr) > self.n
        for i, v in enumerate(t):
            self.arr_arr[i].append(v)
            if flag: self.arr_arr[i].pop(0)

    def __str__(self):
        return f'n: {self.n}\n' \
               f'obs_arr: {self.obs_arr}\n' \
               f'action_arr: {self.action_arr}\n' \
               f'reward_arr: {self.reward_arr}\n' \
               f'obs2_arr: {self.obs2_arr}'

    def sample(self, n):
        idx = torch.randperm(min(len(self.obs_arr), n))
        obs_tsr = torch.cat(self.obs_arr).view((-1, 4))
        action_tsr = torch.LongTensor(self.action_arr)
        reward_tsr = torch.FloatTensor(self.reward_arr)
        obs2_tsr = torch.cat(self.obs2_arr).view((-1, 4))
        return tuple(map(lambda x:x[idx], (obs_tsr, action_tsr, reward_tsr, obs2_tsr)))



In [4]:
# model = DQN()
# optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
# optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
replay = Replay(REPLAY_LENGTH)
env = gym.make('CartPole-v0')

In [5]:
for episode in range(EPISODE_NUM):
    env.reset()
    obs, reward, done, info = env.step(env.action_space.sample())
    obs = torch.from_numpy(obs)
    for time in range(EPISODE_LENGTH):
        # env.render()
        next_action = env.action_space.sample()
        obs2, reward, done, info = env.step(next_action)
        obs2 = torch.from_numpy(obs2)
        replay.add((obs, next_action, reward, obs2))
        obs = obs2
print(replay)

n: 1000
obs_arr: [tensor([ 0.0452,  0.2367, -0.0093, -0.2746], dtype=torch.float64), tensor([ 0.0499,  0.4319, -0.0148, -0.5702], dtype=torch.float64), tensor([ 0.0585,  0.2370, -0.0262, -0.2823], dtype=torch.float64), tensor([ 0.0633,  0.4325, -0.0319, -0.5831], dtype=torch.float64), tensor([ 0.0719,  0.6281, -0.0436, -0.8857], dtype=torch.float64), tensor([ 0.0845,  0.4336, -0.0613, -0.6070], dtype=torch.float64), tensor([ 0.0932,  0.2393, -0.0734, -0.3342], dtype=torch.float64), tensor([ 0.0980,  0.4354, -0.0801, -0.6491], dtype=torch.float64), tensor([ 0.1067,  0.6316, -0.0931, -0.9659], dtype=torch.float64), tensor([ 0.1193,  0.8278, -0.1124, -1.2863], dtype=torch.float64), tensor([ 0.1359,  1.0242, -0.1381, -1.6120], dtype=torch.float64), tensor([ 0.1563,  1.2206, -0.1704, -1.9443], dtype=torch.float64), tensor([ 0.1808,  1.0277, -0.2092, -1.7089], dtype=torch.float64), tensor([ 0.2013,  0.8355, -0.2434, -1.4880], dtype=torch.float64), tensor([ 0.2180,  1.0326, -0.2732, -1.8458],



In [6]:
print(replay.sample(10))


(tensor([[ 0.0585,  0.2370, -0.0262, -0.2823],
        [ 0.0452,  0.2367, -0.0093, -0.2746],
        [ 0.1193,  0.8278, -0.1124, -1.2863],
        [ 0.1067,  0.6316, -0.0931, -0.9659],
        [ 0.0499,  0.4319, -0.0148, -0.5702],
        [ 0.0980,  0.4354, -0.0801, -0.6491],
        [ 0.0719,  0.6281, -0.0436, -0.8857],
        [ 0.0633,  0.4325, -0.0319, -0.5831],
        [ 0.0932,  0.2393, -0.0734, -0.3342],
        [ 0.0845,  0.4336, -0.0613, -0.6070]], dtype=torch.float64), tensor([1, 1, 1, 1, 0, 1, 0, 1, 1, 0]), tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), tensor([[ 0.0633,  0.4325, -0.0319, -0.5831],
        [ 0.0499,  0.4319, -0.0148, -0.5702],
        [ 0.1359,  1.0242, -0.1381, -1.6120],
        [ 0.1193,  0.8278, -0.1124, -1.2863],
        [ 0.0585,  0.2370, -0.0262, -0.2823],
        [ 0.1067,  0.6316, -0.0931, -0.9659],
        [ 0.0845,  0.4336, -0.0613, -0.6070],
        [ 0.0719,  0.6281, -0.0436, -0.8857],
        [ 0.0980,  0.4354, -0.0801, -0.6491],
        [ 0.