In [1]:
import random
import gym
import torch.nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm

import rl_utils
import numpy as np

In [2]:
class Qnet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(Qnet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


In [3]:
class DQN:
    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, epsilon, target_update, device):
        self.action_dim = action_dim
        self.q_net = Qnet(state_dim, hidden_dim, self.action_dim).to(device)
        self.target_q_net = Qnet(state_dim, hidden_dim, self.action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.epsilon = epsilon
        self.targe_update = target_update  # the frequency of the target network update
        self.count = 0
        self.device = device

    def take_action(self, state):  # epsilon strategy
        if np.random.random() < self.epsilon:  # if epsion < threshold -> exploration
            action = np.random.randint(self.action_dim)
        else:  # -> exploitation
            state = torch.tensor([state], dtype=torch.float).to(self.device)
            action = self.q_net(state).argmax().item()
        return action

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)

        q_values = self.q_net(states).gather(1, actions)

        max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)
        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)
        dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))
        self.optimizer.zero_grad()
        dqn_loss.backward()
        self.optimizer.step()

        if self.count % self.targe_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())

        self.count += 1

In [4]:
def show(return_list, env_name):
    episodes_list = list(range(len(return_list)))
    plt.plot(episodes_list, return_list)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('DQN on {}'.format(env_name))
    plt.show()

    mv_return = rl_utils.moving_average(return_list, 9)
    plt.plot(episodes_list, mv_return)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('DQN on {}'.format(env_name))
    plt.show()

In [5]:
def train():
    lr = 2e-3
    num_episodes = 500
    hidden_dim = 128
    gamma = 0.99
    epsilon = 0.01
    target_update = 10
    buffer_size = 10000
    minimal_size = 500
    batch_size = 64
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    env_name = 'CartPole-v0'
    env = gym.make(env_name, render_mode="human")
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    replay_buffer = rl_utils.ReplayBuffer(buffer_size)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon, target_update, device)

    return_list = []
    for i in range(10):
        with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes / 10)):
                epsilon_return = 0
                state, _ = env.reset()
                done = False
                while not done:
                    env.render()
                    action = agent.take_action(state)
                    next_state, reward, done, _, _ = env.step(action)
                    replay_buffer.add(state, action, reward, next_state, done)
                    state = next_state
                    epsilon_return += reward
                    if replay_buffer.size() > minimal_size:
                        b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                        transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r,
                                           'dones': b_d}
                        agent.update(transition_dict)
                return_list.append(epsilon_return)
                if (i_episode + 1) % 10 == 0:
                    pbar.set_postfix({
                        'episode':
                            '%d' % (num_episodes / 10 * i + i_episode + 1),
                        'return':
                            '%.3f' % np.mean(return_list[-10:])
                    })
                pbar.update(1)
    show(return_list, env_name)
    env.close()

In [6]:
train()

  logger.warn(
  state = torch.tensor([state], dtype=torch.float).to(self.device)
  if not isinstance(terminated, (bool, np.bool8)):
Iteration 0: 100%|██████████| 50/50 [00:21<00:00,  2.36it/s, episode=50, return=9.100]
Iteration 1: 100%|██████████| 50/50 [00:22<00:00,  2.20it/s, episode=100, return=12.700]
Iteration 2: 100%|██████████| 50/50 [01:20<00:00,  1.61s/it, episode=150, return=68.500]
Iteration 3: 100%|██████████| 50/50 [04:57<00:00,  5.96s/it, episode=200, return=177.200]
Iteration 4: 100%|██████████| 50/50 [06:34<00:00,  7.88s/it, episode=250, return=172.400]
Iteration 5: 100%|██████████| 50/50 [09:03<00:00, 10.88s/it, episode=300, return=307.000]
Iteration 6:   8%|▊         | 4/50 [01:33<17:50, 23.27s/it]


KeyboardInterrupt: 