In [5]:
from collections import deque

import gym
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

In [6]:
# --- Network shape ----------------------------------------------------------
I = 4   # input dimensions: length of one observation vector fed to DQN
H = 64  # width of both hidden layers
O = 2   # output dimensions: one Q-value per discrete action

# --- Optimisation -----------------------------------------------------------
LEARNING_RATE = 1e-5   # step size of the manual SGD update inside fit()
EPOCHS = 1             # passes fit() makes over each sampled minibatch
MINIBATCH_SIZE = 1000  # transitions drawn from replay per training step (also fit()'s batch size)
GAMMA = 1              # discount applied to max Q(obs2) in the target; 1 means undiscounted

# --- Data collection --------------------------------------------------------
REPLAY_LENGTH = 1000   # capacity of the replay buffer
EPISODE_NUM = 1000     # number of episodes to run
EPISODE_LENGTH = 200   # maximum environment steps per episode
WARMUP_LENGTH = 10     # leading episodes that act purely at random and skip training
EPSILON = 0.1          # probability of a random action once warm-up is over

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
class DQN(nn.Module):
    """Three-layer MLP mapping an ``I``-dim input to ``O`` outputs.

    The outputs are used as Q-values, one per action (the training loop
    indexes them by action). Layer sizes come from the module-level
    constants ``I``, ``H`` and ``O``.
    """

    def __init__(self):
        super().__init__()
        # Creation order matters: it fixes parameter initialisation and
        # state_dict ordering.
        self.lin1 = nn.Linear(I, H)
        self.lin2 = nn.Linear(H, H)
        self.lin3 = nn.Linear(H, O)

    def forward(self, xb):
        """Return the raw (un-activated) output of the last layer for batch ``xb``."""
        hidden = xb
        for layer in (self.lin1, self.lin2):
            hidden = F.relu(layer(hidden))
        return self.lin3(hidden)

def fit(epochs, batch_size, lr, model, loss_func, x, y):
    """Train ``model`` on ``(x, y)`` with plain mini-batch gradient descent.

    Args:
        epochs: number of passes over ``x``.
        batch_size: rows per gradient step. The last batch may be smaller so
            that no rows are silently skipped.
        lr: learning rate of the in-place update ``p -= p.grad * lr``.
        model: module whose parameters are updated in place.
        loss_func: callable ``loss_func(prediction, target)`` returning a
            scalar tensor.
        x, y: inputs and regression targets, row-aligned along dim 0. Each
            slice of ``y`` is detached, so targets are treated as constants
            even if they were produced by an autograd graph.

    Prints the last batch's loss after each epoch (when any batch ran).

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    n_rows = x.size(0)
    for epoch in range(epochs):
        model.train()
        loss = None
        for start in range(0, n_rows, batch_size):
            xb = x[start:start + batch_size]
            # Detach so backward never walks into (and frees) the graph that
            # produced the targets.
            yb = y[start:start + batch_size].detach()
            loss = loss_func(model(xb), yb)
            loss.backward()
            with torch.no_grad():
                for p in model.parameters():
                    if p.grad is not None:  # parameters unused by forward have no grad
                        p -= p.grad * lr
                model.zero_grad()
        if loss is not None:
            print(f'Epoch {epoch}: Loss: {loss.item()}')

def loss_func(y1, y2):
    """Sum (not mean) of squared element-wise differences between ``y1`` and ``y2``."""
    residual = y1 - y2
    return residual.pow(2).sum()

class Replay:
    """Fixed-capacity FIFO buffer of transitions for experience replay.

    Items are stored exactly as passed to ``add_one``. ``sample`` expects each
    item to be indexable as ``(obs, action, reward, obs2)``.

    Attributes:
        rng: NumPy ``Generator`` used for sampling (unseeded).
        n: maximum number of stored items.
        arr: the stored items, oldest first.
    """

    def __init__(self, n):
        self.rng = np.random.default_rng()
        self.n = n
        # A bounded deque evicts the oldest item in O(1) once it holds `n`
        # items, replacing the list + pop(0) eviction.
        self.arr = deque(maxlen=max(n, 0))

    def debug(self, n):
        """Print every stored item, then ``n`` items drawn with replacement."""
        items = np.array(list(self.arr), dtype='object')
        print(items)
        print(self.rng.choice(items, (n,)))

    def sample(self, n):
        """Draw ``n`` stored transitions uniformly at random, with replacement.

        Returns:
            ``(obs_arr, action_arr, reward_arr, obs2_arr)``: ``obs_arr`` and
            ``obs2_arr`` are float64 tensors with one stacked row per sampled
            observation; ``action_arr`` is a 1-D int64 tensor; ``reward_arr``
            is a 1-D float64 tensor.

        Raises:
            ValueError: if the buffer is empty.
        """
        items = list(self.arr)  # one copy gives O(1) random access below
        if not items:
            raise ValueError('cannot sample from an empty replay buffer')
        picks = self.rng.integers(0, len(items), size=n)
        batch = [items[i] for i in picks]
        # np.float / np.int were removed in NumPy 1.24; use explicit widths.
        obs_arr = torch.from_numpy(np.vstack([t[0] for t in batch]).astype(np.float64))
        action_arr = torch.from_numpy(np.vstack([t[1] for t in batch]).astype(np.int64).flatten())
        reward_arr = torch.from_numpy(np.vstack([t[2] for t in batch]).astype(np.float64).flatten())
        obs2_arr = torch.from_numpy(np.vstack([t[3] for t in batch]).astype(np.float64))
        return obs_arr, action_arr, reward_arr, obs2_arr

    def add_one(self, v):
        """Store ``v``; once ``n`` items are held, the oldest one is evicted."""
        self.arr.append(v)

In [8]:
# Q-network on the selected device (nn.Module.to moves parameters in place
# and returns the same module).
model = DQN()
model.to(dev)
# Replay buffer holding at most REPLAY_LENGTH transitions.
replay = Replay(REPLAY_LENGTH)
env = gym.make('CartPole-v0')

In [9]:
# Collect one episode at a time with an epsilon-greedy policy, then (after
# warm-up) regress the Q-network onto one-step TD targets from the replay
# buffer. Uses the classic gym API: reset() returns the observation and
# step() returns (obs, reward, done, info).
for episode in range(EPISODE_NUM):
    obs = env.reset()
    for step in range(EPISODE_LENGTH):
        env.render()
        # Random action during warm-up and with probability EPSILON after it.
        if np.random.rand() < EPSILON or episode < WARMUP_LENGTH:
            next_action = env.action_space.sample()
        else:
            # Inference only: no autograd graph needed to pick an action.
            with torch.no_grad():
                obs_t = torch.as_tensor(obs, dtype=torch.float32, device=dev).unsqueeze(0)
                next_action = torch.argmax(model(obs_t)[0]).item()
        obs2, reward, done, info = env.step(next_action)
        # TODO(review): `done` is not stored, so the target below still
        # bootstraps from terminal next-states; masking them requires Replay
        # to keep a done flag per transition.
        replay.add_one((obs, next_action, reward, obs2))
        obs = obs2
        if done:
            break  # episode finished; stepping a done env is invalid
    if episode >= WARMUP_LENGTH:
        obs_arr, action_arr, reward_arr, obs2_arr = replay.sample(MINIBATCH_SIZE)
        x = obs_arr.to(device=dev, dtype=torch.float32)
        # Targets are constants for the regression: build them without
        # autograd so no gradient flows through Q(obs2) or the copied Q(obs).
        with torch.no_grad():
            y = model(x)
            next_q = model(obs2_arr.to(device=dev, dtype=torch.float32)).max(dim=1).values
            rows = torch.arange(x.size(0), device=dev)
            cols = action_arr.to(device=dev, dtype=torch.long)
            y[rows, cols] = GAMMA * next_q + reward_arr.to(device=dev, dtype=torch.float32)
        fit(EPOCHS, MINIBATCH_SIZE, LEARNING_RATE, model, loss_func, x, y)
env.close()
torch.save(model, 'deep_gym_model.pt')



Epoch 0: Loss: 570.0047607421875
Epoch 0: Loss: 177.67340087890625
Epoch 0: Loss: 350.0753173828125
Epoch 0: Loss: 315.1231994628906
Epoch 0: Loss: 241.1131591796875
Epoch 0: Loss: 101.14593505859375
Epoch 0: Loss: 264.685302734375
Epoch 0: Loss: 1692.475341796875
Epoch 0: Loss: 57394.3828125
Epoch 0: Loss: 37974.67578125
Epoch 0: Loss: 111.63449096679688
Epoch 0: Loss: 111.12545776367188
Epoch 0: Loss: 111.80415344238281
Epoch 0: Loss: 136.90884399414062
Epoch 0: Loss: 166.64620971679688
Epoch 0: Loss: 150.34725952148438
Epoch 0: Loss: 168.26280212402344
Epoch 0: Loss: 192.6641845703125
Epoch 0: Loss: 198.60406494140625
Epoch 0: Loss: 168.82664489746094
Epoch 0: Loss: 137.48690795898438
Epoch 0: Loss: 130.38595581054688
Epoch 0: Loss: 116.85819244384766
Epoch 0: Loss: 126.87364196777344
Epoch 0: Loss: 151.2644500732422
Epoch 0: Loss: 158.30499267578125
Epoch 0: Loss: 161.955078125
Epoch 0: Loss: 137.8616943359375
Epoch 0: Loss: 147.82733154296875
Epoch 0: Loss: 167.44342041015625
Epoc

KeyboardInterrupt: 