In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import gym
import random
from collections import deque
from tqdm import tqdm
import rl_utils

In [None]:
class QNet(nn.Module):
    """Fully connected Q-network: state vector in, one Q-value per action out.

    Architecture: state_dim -> 1024 -> 512 -> 256 -> action_dim, with ReLU after
    each hidden layer and a linear (unactivated) output layer.
    """

    def __init__(self, state_dim, action_dim):
        super(QNet, self).__init__()
        layer_widths = (state_dim, 1024, 512, 256, action_dim)
        # Registers fc1..fc4 in order, so parameter initialisation consumes the
        # RNG in the same sequence and state_dict keys stay "fc1.*".."fc4.*".
        for position, (fan_in, fan_out) in enumerate(zip(layer_widths, layer_widths[1:]), start=1):
            setattr(self, f"fc{position}", nn.Linear(fan_in, fan_out))

    def forward(self, x):
        """Map input ``x`` (last dim = state_dim) to raw Q-values (last dim = action_dim)."""
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = torch.relu(hidden_layer(x))
        return self.fc4(x)

In [None]:
class DQN:
    """DQN agent with an online Q-network and a periodically synced target network.

    Args:
        state_dim: length of the state vector fed to the Q-network.
        action_dim: number of discrete actions.
        target_update: target_q_net is synced from q_net once every
            `target_update` calls to `update`.
        device: torch device that holds the networks and the sampled batches.
        gamma: discount factor used in the TD target.
        lr: Adam learning rate for q_net.
    """

    def __init__(self, state_dim, action_dim, target_update, device, gamma=0.99, lr=1e-3):
        self.device = device
        self.action_dim = action_dim
        self.gamma = gamma

        self.q_net = QNet(state_dim, action_dim).to(device)
        self.target_q_net = QNet(state_dim, action_dim).to(device)
        # Start the target network as an exact copy of the online network.
        self.target_q_net.load_state_dict(self.q_net.state_dict())

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)

        # Target network update interval: if target_update = t, target_q_net is
        # synced once for every t updates of q_net.
        self.target_update = target_update
        self.count = 0  # number of q_net updates performed so far

    def take_action(self, state, epsilon):
        """Choose an action epsilon-greedily.

        With probability `epsilon` a uniformly random action is returned;
        otherwise the action with the largest Q-value under q_net.

        Args:
            state: 1-D array-like state vector (e.g. a one-hot encoding).
            epsilon: exploration probability in [0, 1].

        Returns:
            int: chosen action index in [0, action_dim).
        """
        if np.random.random() < epsilon:
            return int(np.random.randint(self.action_dim))
        # Add a batch dimension of 1 so the network sees shape (1, state_dim).
        state_t = torch.as_tensor(np.asarray(state), dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_net(state_t)
        return int(q_values.argmax(dim=1).item())

    def update(self, replay_buffer, batch_size):
        """Take one gradient step on a sampled minibatch and periodically sync the target net.

        Minimizes the mean squared TD error
            (Q(s, a) - (r + gamma * max_a' Q_target(s', a') * (1 - done)))^2
        over `batch_size` transitions drawn from `replay_buffer`.

        Returns:
            float: the minibatch loss computed before the optimizer step.
        """
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

        states = torch.tensor(states, dtype=torch.float32).to(self.device)
        # Column vectors of shape (batch, 1) so they line up with gathered Q-values.
        actions = torch.tensor(actions, dtype=torch.int64).view(-1, 1).to(self.device)
        rewards = torch.tensor(rewards, dtype=torch.float32).view(-1, 1).to(self.device)
        next_states = torch.tensor(next_states, dtype=torch.float32).to(self.device)
        dones = torch.tensor(dones, dtype=torch.float32).view(-1, 1).to(self.device)

        # Q(s, a) for the actions that were actually taken.
        q_values = self.q_net(states).gather(1, actions)
        # The TD target is a constant for this step; no gradient flows into target_q_net.
        with torch.no_grad():
            max_next_q = self.target_q_net(next_states).max(dim=1, keepdim=True)[0]
            # (1 - dones) drops the bootstrap term on terminal transitions.
            q_targets = rewards + self.gamma * max_next_q * (1 - dones)

        loss = nn.functional.mse_loss(q_values, q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Sync on the 1st, (t+1)-th, (2t+1)-th, ... update (count checked before increment).
        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.count += 1
        return loss.item()
    

In [None]:
# One-hot
def one_hot(state, state_dim):
    """Encode the discrete index ``state`` as a float32 one-hot vector of length ``state_dim``."""
    encoded = np.zeros(state_dim, dtype=np.float32)
    encoded[state] = 1.0  # index errors propagate for out-of-range states
    return encoded

In [None]:
# Taxi-v2 is no longer registered in gym (superseded by Taxi-v3), so
# gym.make("Taxi-v2") fails on releases that ship Taxi-v3.
# NOTE: this cell uses the pre-0.26 gym API: env.seed(), env.reset() returning
# only the observation, and env.step() returning (obs, reward, done, info).
env_name = "Taxi-v3"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
env = gym.make(env_name)
env.seed(0)
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


# Taxi has discrete observation/action spaces; states are one-hot encoded below.
obs_space = env.observation_space.n
action_space = env.action_space.n

# parameters
total_episodes = 5000
episodes_per_iteration = 100
iterations = total_episodes // episodes_per_iteration

batch_size = 64
buffer_size = 20000
min_buffer_size = 5000  # no gradient updates until the buffer holds more than this
epsilon_start = 1.0
epsilon_end = 0.1
epsilon_decay = 0.999  # multiplicative decay applied once per episode
update_freq = 1  # run agent.update every `update_freq` environment steps
target_update = 10

replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = DQN(obs_space, action_space, target_update, device, gamma=0.99, lr=0.0001)

return_list = []

epsilon = epsilon_start
total_steps = 0

# tqdm
for i_iter in range(iterations):
    with tqdm(range(episodes_per_iteration), desc=f"Iteration {i_iter}", ncols=100) as pbar:
        for i_episode_in_iter in pbar:
            i_episode = i_iter * episodes_per_iteration + i_episode_in_iter
            state = env.reset()
            state_vec = one_hot(state, obs_space)
            done = False
            episode_reward = 0
            episode_length = 0

            while not done:
                total_steps += 1
                episode_length += 1
                action = agent.take_action(state_vec, epsilon)
                next_state, reward, done, info = env.step(action)
                next_state_vec = one_hot(next_state, obs_space)

                replay_buffer.add(state_vec, action, reward, next_state_vec, done)
                state_vec = next_state_vec
                episode_reward += reward

                # Warm-up: only learn once enough transitions have been collected.
                if replay_buffer.size() > min_buffer_size and total_steps % update_freq == 0:
                    loss = agent.update(replay_buffer, batch_size)

            # Decay exploration once per episode, floored at epsilon_end.
            epsilon = max(epsilon_end, epsilon * epsilon_decay)
            return_list.append(episode_reward)

    avg_return = np.mean(return_list[-episodes_per_iteration:])
    print(f"Episode: {(i_iter+1)*episodes_per_iteration}, Average Return: {avg_return:.2f}")


torch.save(agent.q_net.state_dict(), "dqn_taxi.pth")

In [None]:
# Plot per-episode returns smoothed by rl_utils.moving_average (window 9).
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))

mv_return = rl_utils.moving_average(return_list, 9)
axes.plot(mv_return)
axes.set(title="Episode Rewards", xlabel="Episode", ylabel="Reward")
axes.grid(True)

fig.tight_layout()
plt.show()