# Playing CartPole with DQN

In [49]:
import os
import gym
import yaml
import torch
import random
import platform
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import itertools
import matplotlib.pyplot as plt
from datetime import datetime, timedelta

from memory import ReplayMemory


In [50]:
# Select the torch device once for the whole notebook:
#   * anywhere except an ARM-based macOS machine -> CUDA if available, else CPU
#   * on an ARM-based macOS machine              -> MPS if available, else CPU
if platform.system() != "Darwin" or not platform.machine().startswith("arm"):
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
else:
    DEVICE = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

In [51]:
class DQN(nn.Module):
    """Small fully connected Q-network.

    One hidden linear layer with a ReLU activation, followed by a linear
    output layer that produces one value per action.
    """

    def __init__(self, input_channels, action_count, hidden_size=256):
        """Create the two linear layers.

        Args:
            input_channels: size of the last dimension of the input tensor.
            action_count: number of output values (one per action).
            hidden_size: width of the hidden layer.
        """
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_channels, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=action_count)

    def forward(self, input_tensor):
        """Return the output-layer values computed from ``input_tensor``."""
        hidden = self.fc1(input_tensor)
        activated = F.relu(hidden)
        q_values = self.fc2(activated)
        return q_values

In [52]:
# --- Training configuration -------------------------------------------------
# Loss applied in optimize() between predicted and target Q-values.
LOSS_FN = nn.MSELoss()
# Module-level optimizer handle. The (misspelled) name is kept because both
# optimize() and train() refer to it; train() assigns the real optimizer.
opimizer = None
# NOTE(review): EPISODES is not referenced by the visible code; train() loops
# until STOP_REWARD is reached.
EPISODES = 20000
# Capacity handed to ReplayMemory.
MEMORY_SIZE = 100000

# --- Output locations -------------------------------------------------------
RUNS_DIR = 'cartpole/runs'
# exist_ok=True makes directory creation idempotent and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(RUNS_DIR, exist_ok=True)
GRAPH_FILE = os.path.join(RUNS_DIR, 'cartpole.png')
MODEL_FILE = os.path.join(RUNS_DIR, 'cartpole.pth')
LOG_FILE = os.path.join(RUNS_DIR, 'cartpole.log')

# --- Hyperparameters --------------------------------------------------------
# Learning rate of the Adam optimizer created in train().
LEARNING_RATE = 0.001
# Number of transitions sampled from replay memory per optimisation step.
BATCH_SIZE = 32
# Minimum number of environment steps between target-network syncs.
SYNCH_RATE = 10
# Epsilon-greedy exploration: start value, multiplicative decay, and floor.
EPSILON_INIT = 1
EPSILON_DECAY = 0.9995
EPSILON_MIN = 0.05
# An episode is cut off, and training stops, once its reward reaches this value.
STOP_REWARD = 100000
# Discount applied to the next-state value in the target computed by optimize().
DISCOUNT_FACTOR = 0.99
# strftime format for timestamps in log messages.
DATE_FORMAT = "%m-%d %H:%M:%S"

In [53]:
def optimize(batch, policy_net, target_net):
    """Run one gradient step of ``policy_net`` on a batch of transitions.

    Args:
        batch: iterable of ``(state, action, new_state, reward, done)`` tuples.
            ``state``, ``action``, ``new_state`` and ``reward`` must be tensors
            (they are combined with ``torch.stack``); ``done`` is a bool/number.
        policy_net: network being trained; its value for the taken action is
            the prediction.
        target_net: network evaluated without gradients to build the target
            ``reward + (1 - done) * DISCOUNT_FACTOR * max_a target_net(new_state)``.

    Relies on the module-level ``opimizer`` (created by ``train()``),
    ``LOSS_FN``, ``DISCOUNT_FACTOR`` and ``DEVICE``. Returns ``None``.
    """
    states, actions, new_states, rewards, dones = zip(*batch)

    states = torch.stack(states)
    actions = torch.stack(actions)
    new_states = torch.stack(new_states)
    rewards = torch.stack(rewards)
    # 1.0 where the transition ended the episode, so the bootstrap term is masked out.
    dones = torch.tensor(dones, dtype=torch.float32, device=DEVICE)

    with torch.no_grad():
        best_next_q = target_net(new_states).max(dim=1).values
        target_q_values = rewards + (1 - dones) * DISCOUNT_FACTOR * best_next_q

    # Pick the Q-value of the action actually taken. Squeeze only the gathered
    # column (dim=1): a bare squeeze() would also drop the batch dimension for a
    # batch of one and no longer match the shape of target_q_values.
    current_q_values = policy_net(states).gather(dim=1, index=actions.unsqueeze(dim=1)).squeeze(dim=1)

    loss = LOSS_FN(current_q_values, target_q_values)
    opimizer.zero_grad()
    loss.backward()
    opimizer.step()

In [54]:
def save_graph(rewards_per_episode, epsilon_history):
    """Write a two-panel progress figure to ``GRAPH_FILE``.

    Left panel: running mean of ``rewards_per_episode`` over a trailing window
    of up to 100 entries. Right panel: the values in ``epsilon_history``.
    The figure is closed after saving.
    """
    figure = plt.figure(1)

    # Trailing mean: entry i averages items max(0, i-99) .. i inclusive.
    window = 100
    mean_rewards = np.array([
        np.mean(rewards_per_episode[max(0, end - window):end])
        for end in range(1, len(rewards_per_episode) + 1)
    ])

    # Left cell of a 1 row x 2 column grid: mean rewards.
    rewards_axes = plt.subplot(121)
    rewards_axes.set_ylabel('Mean Rewards')
    rewards_axes.plot(mean_rewards)

    # Right cell of the grid: epsilon values.
    epsilon_axes = plt.subplot(122)
    epsilon_axes.set_ylabel('Epsilon Decay')
    epsilon_axes.plot(epsilon_history)

    figure.subplots_adjust(wspace=1.0, hspace=1.0)

    # Persist to disk and release the figure.
    figure.savefig(GRAPH_FILE)
    plt.close(figure)

In [55]:
def train():
    """Train the DQN policy on CartPole-v1 until one episode reaches ``STOP_REWARD``.

    Resumes from ``MODEL_FILE`` when that file exists. While running it

    * writes progress messages to stdout and to ``LOG_FILE`` (the log file is
      truncated when training starts),
    * saves the policy weights to ``MODEL_FILE`` whenever an episode matches or
      beats the best episode reward seen so far in this run,
    * refreshes ``GRAPH_FILE`` at most once every 10 seconds.

    The module-level ``opimizer`` is (re)created here because ``optimize()``
    reads it. One optimisation step, one epsilon decay and (when due) one
    target-network sync are performed per episode, after the episode ends.
    The environment is always closed, even if training is interrupted.
    """
    global opimizer
    env = gym.make('CartPole-v1', render_mode=None)

    try:
        num_actions = env.action_space.n
        num_states = env.observation_space.shape[0]
        rewards_per_episode = []
        best_reward = -float('inf')

        policy_net = DQN(num_states, num_actions).to(DEVICE)
        if os.path.exists(MODEL_FILE):
            # Resume from the previous checkpoint. map_location lets a checkpoint
            # saved on another device load here; weights_only restricts
            # unpickling to tensor data.
            policy_net.load_state_dict(
                torch.load(MODEL_FILE, map_location=DEVICE, weights_only=True)
            )

        target_net = DQN(num_states, num_actions).to(DEVICE)
        target_net.load_state_dict(policy_net.state_dict())

        memory = ReplayMemory(MEMORY_SIZE)
        epsilon_history = []
        epsilon = EPSILON_INIT
        step_count = 0  # environment steps since the last target-network sync

        opimizer = torch.optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)

        start_time = datetime.now()
        last_graph_update_time = start_time
        start_message = f'Started Training {start_time.strftime(DATE_FORMAT)}'
        with open(LOG_FILE, 'w') as file:
            file.write(start_message + '\n')
        print(start_message + '\n')

        for episode in itertools.count():
            state, _ = env.reset()
            state = torch.tensor(state, dtype=torch.float32, device=DEVICE)
            terminated = False
            episode_reward = 0.0

            while not terminated and episode_reward < STOP_REWARD:
                # Epsilon-greedy action selection.
                if random.random() < epsilon:
                    action = env.action_space.sample()
                    action = torch.tensor(action, dtype=torch.int64, device=DEVICE)
                else:
                    with torch.no_grad():
                        action = policy_net(state.unsqueeze(0)).squeeze().argmax()

                # NOTE(review): `truncated` is ignored, so an episode only ends on
                # termination or on reaching STOP_REWARD. This looks deliberate.
                new_state, reward, terminated, _truncated, _info = env.step(action.item())

                episode_reward += reward

                new_state = torch.tensor(new_state, dtype=torch.float32, device=DEVICE)
                reward = torch.tensor(reward, dtype=torch.float32, device=DEVICE)

                # Store (state, action, next state, reward, terminated) in replay memory.
                memory.record((state, action, new_state, reward, terminated))

                step_count += 1
                state = new_state

            # One entry per finished episode, so the 100-entry running mean in
            # save_graph() is a mean over episodes rather than over steps.
            rewards_per_episode.append(episode_reward)

            if episode_reward >= best_reward:
                torch.save(policy_net.state_dict(), MODEL_FILE)
                # No meaningful percentage exists before the first best reward
                # (-inf) or relative to a best reward of 0.
                if best_reward == -float('inf') or best_reward == 0:
                    improvement = "n/a"
                else:
                    improvement = f"{(episode_reward - best_reward) / best_reward * 100:+.1f}%"
                log_message = (
                    f"{datetime.now().strftime(DATE_FORMAT)}: New best reward "
                    f"{episode_reward:0.1f} ({improvement}) at episode {episode}, saving model..."
                )
                best_reward = episode_reward
                print(log_message)
                with open(LOG_FILE, 'a') as file:
                    file.write(log_message + '\n')

            # Refresh the graph at most every 10 seconds.
            current_time = datetime.now()
            if current_time - last_graph_update_time > timedelta(seconds=10):
                save_graph(rewards_per_episode, epsilon_history)
                last_graph_update_time = current_time

            if len(memory) > BATCH_SIZE:
                batch = memory.sample(BATCH_SIZE)
                optimize(batch, policy_net, target_net)

                epsilon_history.append(epsilon)
                epsilon = max(EPSILON_MIN, epsilon * EPSILON_DECAY)

                # Copy the policy weights into the target network once enough
                # environment steps have accumulated.
                if step_count >= SYNCH_RATE:
                    target_net.load_state_dict(policy_net.state_dict())
                    step_count = 0

            # The step loop exits as soon as episode_reward >= STOP_REWARD, so the
            # stop test must use >= as well. With unit rewards the total lands
            # exactly on STOP_REWARD and a strict > would never end training.
            if episode_reward >= STOP_REWARD:
                break
    finally:
        env.close()

In [56]:
# Training is disabled by default: train() loops until an episode reaches
# STOP_REWARD. Uncomment the next line to (re)train.
# train()

In [62]:
def test(episodes=4):
    """Run the greedy policy for ``episodes`` episodes and print each episode's reward.

    The environment is created with ``render_mode='human'``. Weights are loaded
    from ``MODEL_FILE`` when it exists; otherwise a message is printed and the
    freshly initialised network is used. Each episode runs until the environment
    reports termination or the accumulated reward reaches ``STOP_REWARD``.

    Args:
        episodes: number of episodes to play.
    """
    env = gym.make('CartPole-v1', render_mode='human')

    try:
        num_actions = env.action_space.n
        num_states = env.observation_space.shape[0]

        policy_net = DQN(num_states, num_actions).to(DEVICE)
        if os.path.exists(MODEL_FILE):
            # map_location lets a checkpoint saved on another device load here;
            # weights_only restricts unpickling to tensor data.
            policy_net.load_state_dict(
                torch.load(MODEL_FILE, map_location=DEVICE, weights_only=True)
            )
        else:
            # Make the fallback explicit instead of silently evaluating random weights.
            print(f'No saved model found at {MODEL_FILE}; running with untrained weights')

        # Switch the model to evaluation mode.
        policy_net.eval()

        for episode in range(episodes):
            state, _ = env.reset()
            state = torch.tensor(state, dtype=torch.float32, device=DEVICE)
            terminated = False
            episode_reward = 0.0

            while not terminated and episode_reward < STOP_REWARD:
                with torch.no_grad():
                    action = policy_net(state.unsqueeze(0)).squeeze().argmax()

                # NOTE(review): `truncated` is ignored, as in train(); the episode
                # only ends on termination or on reaching STOP_REWARD.
                new_state, reward, terminated, _truncated, _info = env.step(action.item())

                episode_reward += reward
                state = torch.tensor(new_state, dtype=torch.float32, device=DEVICE)

            print(f'Done episode {episode} with reward {episode_reward}')
    finally:
        env.close()

In [None]:
# Play the default number of evaluation episodes with the saved policy
# (test() creates the environment with render_mode='human').
test()

  policy_net.load_state_dict(torch.load(MODEL_FILE))


Done episode 0 with reward 2700.0
Done episode 1 with reward 100000.0
Done episode 2 with reward 100000.0
