In [1]:
import math
import random
import time
import torch
import torch.optim as optim


from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.base_env import ActionTuple

import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from models import DQN, DuelingDQN, ReplayMemory, optimize_model

def select_action(state_in):
    """Choose an action for ``state_in`` with an epsilon-greedy policy.

    Epsilon decays exponentially from ``EPS_START`` towards ``EPS_END`` as the
    module-level ``steps_done`` counter grows; the counter is incremented on
    every call, whichever branch is taken.

    Args:
        state_in: Observation tensor fed to ``policy_net`` (presumably a batch
            of one observation -- confirm against the caller).

    Returns:
        A ``torch.long`` tensor on ``device``: either the greedy action viewed
        as shape ``(1, 1)``, or the discrete part of one random action drawn
        from ``spec.action_spec``.
    """
    global steps_done

    draw = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    epsilon = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if draw <= epsilon:
        # Explore: let the environment's action spec pick a random action.
        return torch.tensor(spec.action_spec.random_action(1).discrete, device=device, dtype=torch.long)

    # Exploit: max(1) reduces over the action dimension; ``indices`` holds the
    # position of the largest value, i.e. the action with the best estimate.
    with torch.no_grad():
        q_values = policy_net(state_in)
        return q_values.max(1).indices.view(1, 1)


In [3]:

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

timer_start = time.perf_counter()

# Seed every Python-side random number generator so a Restart & Run All
# reproduces the same exploration, replay sampling and weight initialisation.
# (NumPy is seeded as well because the environment's random actions are
# presumably drawn through it -- confirm against the mlagents_envs version.)
SEED = 13
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor applied to future rewards
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.1
EPS_DECAY = 50_000
TAU = 0.005
LR = 1e-4

# SAVE_WEIGHTS writes the policy weights and the reward/loss history to disk
# LOAD_WEIGHTS initialises the policy network from 'weights/policy_net.pth'
# steps_done counts calls to ``select_action`` and drives the epsilon decay
# STEPS is the maximum number of environment steps per episode
# DOUBLE is forwarded to ``optimize_model`` as its ``double`` argument
# DUELING selects ``DuelingDQN`` instead of ``DQN`` for both networks
# GRAPHICS enables rendering of the Unity build during training
SAVE_WEIGHTS = True
LOAD_WEIGHTS = False
steps_done = 0
STEPS = 500
DOUBLE = True
DUELING = False
GRAPHICS = False

# Train for far fewer episodes when only a CPU is available.
if torch.cuda.is_available():
    num_episodes = 5000
else:
    num_episodes = 50


In [3]:

# Launch the Unity build; rendering stays off unless GRAPHICS is set.
env = UnityEnvironment(file_name="unity_builds/snake", seed=13, side_channels=[], no_graphics=not GRAPHICS)
env.reset()

# Use the first behaviour exposed by the build.
behaviour_name = list(env.behavior_specs)[0]
spec = env.behavior_specs[behaviour_name]

# Network input/output sizes come from the behaviour spec: the size of the
# first discrete action branch and the first dimension of the first observation.
n_actions = spec.action_spec.discrete_branches[0]
n_observations = spec.observation_specs[0].shape[0]

# Policy and target networks share one architecture, chosen by DUELING.
network_cls = DuelingDQN if DUELING else DQN
policy_net = network_cls(n_observations, n_actions).to(device)
target_net = network_cls(n_observations, n_actions).to(device)

if LOAD_WEIGHTS:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    # (and vice versa) instead of failing while deserialising CUDA tensors.
    policy_net.load_state_dict(torch.load('weights/policy_net.pth', map_location=device))
    print("Loaded weights from file")

# The target network starts as an exact copy of the policy network.
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(1000)
print(f"Initalized DQN with {n_observations} observations and {n_actions} actions")

# Per-episode history, filled in by the training loop.
rewards = []
losses = []

Loaded weights from file
Initalized DQN with 10 observations and 4 actions


In [4]:

# Start the clock in this cell so the reported duration covers the training
# loop only and stays correct when cells are re-run individually.
train_timer_start = time.perf_counter()

pbar = tqdm(range(num_episodes))
try:
    for i_episode in pbar:
        # Every 100 episodes: report progress and, if enabled, checkpoint.
        if i_episode % 100 == 0 and i_episode != 0:
            print(
                f"Episode {i_episode}, avg reward: {np.mean(rewards[-100:]):.2f}, "
                f"epsilon: {EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY):.2f}")
            # Honour SAVE_WEIGHTS so a run with saving disabled never
            # overwrites an existing checkpoint.
            if SAVE_WEIGHTS:
                torch.save(policy_net.state_dict(), 'weights/policy_net.pth')
                print("Checkpoint: Saved weights to file")
        step_rewards = []
        step_losses = []
        # Initialize the environment and get its state
        env.reset()
        decision_steps, terminal_steps = env.get_steps(behaviour_name)
        state = decision_steps.obs[0]
        state = torch.tensor(state, dtype=torch.float32, device=device)
        for t in range(STEPS):
            action = select_action(state)
            action_tuple = ActionTuple()
            action_tuple.add_discrete(action.cpu().numpy())
            env.set_actions(behaviour_name, action_tuple)
            env.step()

            assert action.shape[0] == state.shape[0]

            decision_steps, terminal_steps = env.get_steps(behaviour_name)
            observation = decision_steps.obs[0]
            # Add up whatever reward was reported through the decision steps
            # and the terminal steps for this environment step.
            reward = np.zeros(state.shape[0])
            if len(decision_steps.reward) > 0:
                reward += decision_steps.reward
            if len(terminal_steps.reward) > 0:
                reward += terminal_steps.reward
            done = len(decision_steps) == 0
            terminated = len(terminal_steps) > 0
            assert len(reward) == state.shape[0] == action.shape[0]

            reward = torch.tensor(reward, device=device)
            step_rewards.append(reward.item())

            # A finished episode has no successor state.
            if done or terminated:
                next_state = None
            else:
                next_state = torch.tensor(observation, dtype=torch.float32, device=device)

            # Store the transition in memory
            memory.push(state, action, next_state, reward)

            # Move to the next state
            state = next_state

            # Perform one step of the optimization (on the policy network)
            loss = optimize_model(memory, policy_net, target_net, optimizer, device, double=DOUBLE, BATCH_SIZE=BATCH_SIZE, GAMMA=GAMMA)
            if loss is not None:
                step_losses.append(loss)

            # Soft update of the target network's weights
            # θ′ ← τ θ + (1 −τ )θ′
            target_net_state_dict = target_net.state_dict()
            policy_net_state_dict = policy_net.state_dict()
            for key in policy_net_state_dict:
                target_net_state_dict[key] = policy_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
            target_net.load_state_dict(target_net_state_dict)

            if terminated or done:
                break

        # Record a loss of 0 for episodes in which no optimisation step
        # returned a loss, so the history keeps one entry per episode.
        if len(step_losses) == 0:
            step_losses.append(0)
        mean_loss = np.mean(step_losses)
        losses.append(mean_loss)
        ep_rewards = sum(step_rewards)
        pbar.set_description(f"E {i_episode} done after {t + 1} t, with r: {ep_rewards:.2f} and l: {mean_loss:.2f}")
        rewards.append(ep_rewards)
finally:
    # Always shut the Unity process down, even when training is interrupted
    # or raises, so it is not left running in the background.
    env.close()
print(f"Finished training in {(time.perf_counter() - train_timer_start)/60 :.3} minutes")


E 100 done after 7 t, with r: -1.00 and l: 0.18:  10%|█         | 101/1000 [02:11<16:13,  1.08s/it]

Episode 100, avg reward: 0.10, epsilon: 0.63
Checkpoint: Saved weights to file


E 199 done after 35 t, with r: -1.00 and l: 0.15:  20%|██        | 200/1000 [04:27<20:03,  1.50s/it] 

Episode 200, avg reward: -0.58, epsilon: 0.56
Checkpoint: Saved weights to file


E 299 done after 1 t, with r: -1.00 and l: 0.16:  30%|██▉       | 299/1000 [06:22<15:35,  1.33s/it]  

Episode 300, avg reward: -0.59, epsilon: 0.51
Checkpoint: Saved weights to file


E 399 done after 60 t, with r: 0.00 and l: 0.11:  40%|████      | 400/1000 [08:55<14:43,  1.47s/it]  

Episode 400, avg reward: -0.56, epsilon: 0.45
Checkpoint: Saved weights to file


E 499 done after 94 t, with r: 0.00 and l: 0.11:  50%|█████     | 500/1000 [10:49<09:54,  1.19s/it]  

Episode 500, avg reward: -0.64, epsilon: 0.42
Checkpoint: Saved weights to file


E 599 done after 3 t, with r: -1.00 and l: 0.14:  60%|█████▉    | 599/1000 [12:50<07:08,  1.07s/it]  

Episode 600, avg reward: -0.52, epsilon: 0.38
Checkpoint: Saved weights to file


E 699 done after 3 t, with r: -1.00 and l: 0.03:  70%|██████▉   | 699/1000 [15:32<11:29,  2.29s/it]  

Episode 700, avg reward: -0.41, epsilon: 0.34
Checkpoint: Saved weights to file


E 799 done after 23 t, with r: -1.00 and l: 0.04:  80%|████████  | 800/1000 [18:48<07:55,  2.38s/it] 

Episode 800, avg reward: -0.32, epsilon: 0.30
Checkpoint: Saved weights to file


E 899 done after 4 t, with r: -1.00 and l: 0.06:  90%|█████████ | 900/1000 [20:50<01:11,  1.40it/s]  

Episode 900, avg reward: -0.69, epsilon: 0.27
Checkpoint: Saved weights to file


E 999 done after 147 t, with r: 0.00 and l: 0.05: 100%|██████████| 1000/1000 [23:05<00:00,  1.39s/it]

Finished training in 23.2 minutes





In [None]:

if SAVE_WEIGHTS:
    # Persist the final policy weights first, then the per-episode histories.
    torch.save(policy_net.state_dict(), 'weights/policy_net.pth')

    rewards_df = pd.DataFrame(rewards, columns=['reward'])
    losses_df = pd.DataFrame(losses, columns=['loss'])
    # One single-column CSV per history, written without the index column.
    for frame, csv_name in ((rewards_df, 'rewards.csv'), (losses_df, 'losses.csv')):
        frame.to_csv(csv_name, index=False)



In [None]:

# One wide figure per training curve: episode reward first, then mean loss.
for series, y_label, heading in (
    (rewards, 'Total Reward', 'Rewards over Episodes'),
    (losses, 'Loss', 'Loss over Episodes'),
):
    fig, ax = plt.subplots(figsize=(16, 5))
    ax.plot(series)
    ax.set_xlabel('Episode')
    ax.set_ylabel(y_label)
    ax.set_title(heading)
    plt.show()


Inference

In [6]:
# Replay the current ``policy_net`` greedily (no exploration) with graphics on.
env = UnityEnvironment(file_name="unity_builds/snake", seed=1, side_channels=[], no_graphics=False)
try:
    env.reset()

    behaviour_name = list(env.behavior_specs)[0]
    spec = env.behavior_specs[behaviour_name]

    n_actions = spec.action_spec.discrete_branches[0]
    n_observations = spec.observation_specs[0].shape[0]

    # By default the network left in memory by the training cells is used.
    # To evaluate saved weights instead, build a network and load a checkpoint:
    # policy_net = DQN(n_observations, n_actions).to(device)
    # policy_net = DuelingDQN(n_observations, n_actions).to(device)
    # policy_net.load_state_dict(torch.load('weights/large_observations/4k/policy_net.pth'))
    # policy_net.load_state_dict(torch.load('weights/policy_net.pth'))
    for episode in range(5):
        env.reset()
        decision_steps, terminal_steps = env.get_steps(behaviour_name)
        state = decision_steps.obs[0]
        state = torch.tensor(state, dtype=torch.float32, device=device)
        # Cap each demonstration episode at 250 steps.
        for _ in range(250):
            # Always take the action with the highest predicted value.
            with torch.no_grad():
                action = policy_net(state).max(1).indices.view(1, 1)

            action_tuple = ActionTuple()
            action_tuple.add_discrete(action.cpu().numpy())
            env.set_actions(behaviour_name, action_tuple)
            env.step()

            decision_steps, terminal_steps = env.get_steps(behaviour_name)
            observation = decision_steps.obs[0]
            done = len(decision_steps) == 0
            terminated = len(terminal_steps) > 0

            if done or terminated:
                break
            state = torch.tensor(observation, dtype=torch.float32, device=device)
finally:
    # Close the Unity process even when the demo is interrupted (for example
    # with a KeyboardInterrupt), so it is not left running in the background.
    env.close()

KeyboardInterrupt: 