In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import gym
import random
from collections import deque
from tqdm import tqdm
import rl_utils

In [2]:
class QNet(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x

In [3]:
class DQN:
    def __init__(self, state_dim, action_dim, target_update, device, gamma=0.99, lr=1e-3):
        self.device = device
        self.action_dim = action_dim
        self.gamma = gamma

        self.q_net = QNet(state_dim, action_dim).to(device)
        self.target_q_net = QNet(state_dim, action_dim).to(device)
        self.target_q_net.load_state_dict(self.q_net.state_dict())

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)

        self.target_update = target_update # Target network update interval. If target_update = t,then when q_net is updated t times, target_q_net is updated once
        self.count = 0 # Counter. record q_net update times


    def take_action(self, state, epsilon):
        # epsilon-greedy
        ############################
        # YOUR IMPLEMENTATION HERE #
        ############################
        # pass
        
        # epsilon-greedy policy
        if np.random.random() < epsilon:
            action = np.random.randint(self.action_dim)
        else:
            # state = torch.tensor([state], dtype=torch.float).to(self.device)
            state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
            
            action = self.q_net(state).argmax().item()
        return action







    def update(self, replay_buffer, batch_size):
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

        states = torch.tensor(states, dtype=torch.float32).to(self.device)
        actions = torch.tensor(actions, dtype=torch.int64).to(self.device)
        rewards = torch.tensor(rewards, dtype=torch.float32).to(self.device)
        next_states = torch.tensor(next_states, dtype=torch.float32).to(self.device)
        dones = torch.tensor(dones, dtype=torch.float32).to(self.device)

        ############################
        # YOUR IMPLEMENTATION HERE #
        ############################
        # pass
        actions = actions.view(-1, 1)
        rewards = actions.view(-1, 1)
        dones = actions.view(-1, 1)
        
        q_values = self.q_net(states).gather(1, actions)  # Q值
        # 下个状态的最大Q值
        max_next_q_values = self.target_q_net(next_states).max(1)[0].view(
            -1, 1)
        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones
                                                                )  # TD误差目标
        import torch.nn.functional as F
        loss = torch.mean(F.mse_loss(q_values, q_targets))  # 均方误差损失函数
        self.optimizer.zero_grad()  # PyTorch中默认梯度会累积,这里需要显式将梯度置为0
        loss.backward()  # 反向传播更新参数
        self.optimizer.step()

        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(
                self.q_net.state_dict())  # 更新目标网络








        self.count += 1
        return loss.item()

In [4]:
# One-hot
def one_hot(state, state_dim):
    vec = np.zeros(state_dim, dtype=np.float32)
    vec[state] = 1.0
    return vec

In [5]:
env_name = "Taxi-v2"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
env = gym.make(env_name)
env.seed(0)
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


obs_space = env.observation_space.n
action_space = env.action_space.n

# parameters
total_episodes = 5000
episodes_per_iteration = 100
iterations = total_episodes // episodes_per_iteration

batch_size = 64
buffer_size = 20000
min_buffer_size = 5000
epsilon_start = 1.0
epsilon_end = 0.1
epsilon_decay = 0.999
update_freq = 1
target_update = 10

replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = DQN(obs_space, action_space, target_update, device, gamma=0.99, lr=0.0001)

return_list = []

epsilon = epsilon_start
total_steps = 0

# tqdm
for i_iter in range(iterations):
    with tqdm(range(episodes_per_iteration), desc=f"Iteration {i_iter}", ncols=100) as pbar:
        for i_episode_in_iter in pbar:
            i_episode = i_iter * episodes_per_iteration + i_episode_in_iter
            state = env.reset()
            state_vec = one_hot(state, obs_space)
            done = False
            episode_reward = 0
            episode_length = 0

            while not done:
                total_steps += 1
                episode_length += 1
                action = agent.take_action(state_vec, epsilon)
                next_state, reward, done, info = env.step(action)
                next_state_vec = one_hot(next_state, obs_space)

                replay_buffer.add(state_vec, action, reward, next_state_vec, done)
                state_vec = next_state_vec
                episode_reward += reward

                if replay_buffer.size() > min_buffer_size and total_steps % update_freq == 0:
                    loss = agent.update(replay_buffer, batch_size)

            epsilon = max(epsilon_end, epsilon * epsilon_decay)
            return_list.append(episode_reward)

    avg_return = np.mean(return_list[-episodes_per_iteration:])
    print(f"Episode: {(i_iter + 1) * episodes_per_iteration}, Average Return: {avg_return:.2f}")


torch.save(agent.q_net.state_dict(), "dqn_taxi.pth")

  result = entry_point.load(False)
Iteration 0: 100%|████████████████████████████████████████████████| 100/100 [00:53<00:00,  1.87it/s]


Episode: 100, Average Return: -794.63


Iteration 1: 100%|████████████████████████████████████████████████| 100/100 [01:52<00:00,  1.12s/it]


Episode: 200, Average Return: -812.69


Iteration 2: 100%|████████████████████████████████████████████████| 100/100 [01:45<00:00,  1.05s/it]


Episode: 300, Average Return: -901.93


Iteration 3: 100%|████████████████████████████████████████████████| 100/100 [01:54<00:00,  1.14s/it]


Episode: 400, Average Return: -1004.21


Iteration 4: 100%|████████████████████████████████████████████████| 100/100 [01:55<00:00,  1.16s/it]


Episode: 500, Average Return: -1096.78


Iteration 5: 100%|████████████████████████████████████████████████| 100/100 [01:44<00:00,  1.04s/it]


Episode: 600, Average Return: -1202.25


Iteration 6: 100%|████████████████████████████████████████████████| 100/100 [02:04<00:00,  1.25s/it]


Episode: 700, Average Return: -1299.98


Iteration 7: 100%|████████████████████████████████████████████████| 100/100 [02:00<00:00,  1.20s/it]


Episode: 800, Average Return: -1369.78


Iteration 8: 100%|████████████████████████████████████████████████| 100/100 [01:42<00:00,  1.03s/it]


Episode: 900, Average Return: -1448.30


Iteration 9: 100%|████████████████████████████████████████████████| 100/100 [01:32<00:00,  1.08it/s]


Episode: 1000, Average Return: -1487.38


Iteration 10: 100%|███████████████████████████████████████████████| 100/100 [01:34<00:00,  1.06it/s]


Episode: 1100, Average Return: -1541.12


Iteration 11: 100%|███████████████████████████████████████████████| 100/100 [01:39<00:00,  1.00it/s]


Episode: 1200, Average Return: -1583.57


Iteration 12: 100%|███████████████████████████████████████████████| 100/100 [01:36<00:00,  1.03it/s]


Episode: 1300, Average Return: -1623.08


Iteration 13: 100%|███████████████████████████████████████████████| 100/100 [01:49<00:00,  1.10s/it]


Episode: 1400, Average Return: -1667.47


Iteration 14: 100%|███████████████████████████████████████████████| 100/100 [01:58<00:00,  1.19s/it]


Episode: 1500, Average Return: -1702.37


Iteration 15: 100%|███████████████████████████████████████████████| 100/100 [01:52<00:00,  1.13s/it]


Episode: 1600, Average Return: -1727.21


Iteration 16: 100%|███████████████████████████████████████████████| 100/100 [02:02<00:00,  1.22s/it]


Episode: 1700, Average Return: -1757.54


Iteration 17: 100%|███████████████████████████████████████████████| 100/100 [01:50<00:00,  1.11s/it]


Episode: 1800, Average Return: -1785.26


Iteration 18: 100%|███████████████████████████████████████████████| 100/100 [02:02<00:00,  1.22s/it]


Episode: 1900, Average Return: -1791.38


Iteration 19: 100%|███████████████████████████████████████████████| 100/100 [02:17<00:00,  1.37s/it]


Episode: 2000, Average Return: -1814.78


Iteration 20: 100%|███████████████████████████████████████████████| 100/100 [02:20<00:00,  1.41s/it]


Episode: 2100, Average Return: -1823.60


Iteration 21: 100%|███████████████████████████████████████████████| 100/100 [02:24<00:00,  1.44s/it]


Episode: 2200, Average Return: -1833.41


Iteration 22:  40%|███████████████████▏                            | 40/100 [01:20<02:00,  2.00s/it]


KeyboardInterrupt: 

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(5, 5))

# rewards
mv_return = rl_utils.moving_average(return_list, 9)
axes.plot(mv_return)
axes.set_title("Episode Rewards")
axes.set_xlabel("Episode")
axes.set_ylabel("Reward")
axes.grid(True)

plt.tight_layout()
plt.show()