In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import gym
from collections import namedtuple
import random

env = gym.make('CartPole-v0')
env.reset()

env._max_episode_steps = 500

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'done'))

class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, **kargs):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(**kargs)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [4]:
class QNet(nn.Module):
    def __init__(self):
        super(QNet, self).__init__()
        hidden = 10
        self.fc1 = nn.Linear(env.observation_space.shape[0], hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, env.action_space.n)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [5]:
num_ep = 500
num_batch = 50
num_train = 50

# reward list
rList = []
dis = 0.9

model = QNet()
target_model = QNet()
target_model.load_state_dict(model.state_dict())
model.train()
target_model.eval()

lr = 1e-3
criterion = nn.MSELoss()
optim = torch.optim.Adam([p for p in model.parameters()], lr=lr)
memory = ReplayMemory(1000000)

#     e = exploration_max
max_step = 0
for i in range(num_ep):
    state = env.reset()
    done = False
    step = 0
    rAll = 0
    e = 1. / (i + 10)
    while not done:
        env.render()
        out = target_model(torch.tensor(state).float())
        if np.random.rand(1) < e:
            action = env.action_space.sample()
        else:
            action = torch.argmax(out).item()

        new_state, reward, done, _ = env.step(action)
        # 'Transition', ('state', 'action', 'next_state', 'reward', 'done')
        t = {
            'state': state,
            'action': action,
            'next_state': new_state,
            'reward': reward if not done else -100,
            'done': done,
        }
        memory.push(**t)
        state = new_state
        rAll += reward
        step += 1

        if step >= 10000:
            break
            
    if step >= max_step:
        best_model = model
        max_step = step
        print(f"ep: {i}, step: {step}")

    rList.append(rAll)
    if (i + 1) % 100 == 0:
        print("lr:{}, ep:{}, max: {}".format(lr, i, np.max(rList)))

    # update every 10 episodes
    if i % 10 == 0:
        if len(memory) < num_batch:
            pass
        else:
            for i in range(num_train):
                # sampling from replay memory
                batches = memory.sample(num_batch)
                
                batch_state = torch.tensor([b.state for b in batches]).float()
                batch_next_state = torch.tensor([b.next_state for b in batches]).float()
                batch_action = torch.tensor([b.action for b in batches]).float()
                batch_reward = torch.tensor([b.reward for b in batches]).float()
                batch_done = torch.tensor([b.done for b in batches]).float()
                
                with torch.no_grad():
                    target = batch_reward + dis * torch.max(target_model(batch_next_state), axis=1)[0]
                target[batch_done == 1] = batch_reward[batch_done == 1]

                pred = model(batch_state)
                pred = pred.gather(1, batch_action.long().reshape(-1, 1))
                loss = criterion(pred.flatten(), target)
                optim.zero_grad()
                loss.backward()
                optim.step()
                
            target_model.load_state_dict(model.state_dict())

ep: 0, step: 10
ep: 2, step: 11
ep: 17, step: 11
ep: 25, step: 11
ep: 61, step: 18
ep: 63, step: 25
ep: 67, step: 27
ep: 71, step: 51
ep: 72, step: 176
ep: 80, step: 223
lr:0.001, ep:99, max: 223.0
ep: 197, step: 492
lr:0.001, ep:199, max: 492.0
ep: 213, step: 500
ep: 249, step: 500
lr:0.001, ep:299, max: 500.0
ep: 314, step: 500
ep: 388, step: 500
lr:0.001, ep:399, max: 500.0
ep: 401, step: 500
ep: 408, step: 500
ep: 413, step: 500
ep: 442, step: 500
ep: 463, step: 500
ep: 496, step: 500
lr:0.001, ep:499, max: 500.0


In [None]:
best_model

In [17]:
# inference
state = env.reset()
reward_sum = 0
done = False
while not done:
    env.render()
    pred = best_model(torch.tensor(state).float())
    action = torch.argmax(pred).item()
    new_state, reward, done, _ = env.step(action)
    
    reward_sum += reward
    state = new_state
    
    if reward_sum % 100 == 0:
        print(reward_sum)
    

100.0
200.0
300.0
400.0
500.0
600.0
700.0
800.0
900.0
1000.0
1100.0
1200.0
1300.0
1400.0
1500.0
1600.0
1700.0
1800.0
1900.0
2000.0
2100.0
2200.0
2300.0
2400.0
2500.0
2600.0
2700.0
2800.0
2900.0
3000.0
3100.0
3200.0
3300.0
3400.0
3500.0
3600.0
3700.0
3800.0
3900.0
4000.0
4100.0
4200.0
4300.0
4400.0
4500.0
4600.0
4700.0
4800.0
4900.0
5000.0
5100.0
5200.0
5300.0
5400.0
5500.0
5600.0
5700.0
5800.0
5900.0
6000.0
6100.0
6200.0
6300.0
6400.0
6500.0
6600.0
6700.0
6800.0
6900.0
7000.0
7100.0
7200.0
7300.0
7400.0
7500.0
7600.0
7700.0
7800.0
7900.0
8000.0
8100.0
8200.0
8300.0
8400.0
8500.0
8600.0
8700.0
8800.0
8900.0
9000.0
9100.0
9200.0
9300.0
9400.0
9500.0
9600.0
9700.0
9800.0
9900.0
10000.0
