In [1]:
from __future__ import annotations

import gymnasium as gym
from DoubleDQN import DoubleDQNAgent
from DQNAgent import DQNAgent
from DuelingDQNAgent import DuelingDQNAgent
from DuelingDoubleDQN import DuelingDoubleDQNAgent
import matplotlib.pyplot as plt
import numpy as np
import wandb

In [2]:
# --- Hyperparameters ---------------------------------------------------------
lr = 1e-3                  # learning rate handed to the agent
gamma = 0.99               # discount factor
eps_decay = 0.99           # rate at which the chance of taking a random action is reduced
update_target_net = 1000   # number of frames between target-network updates

# --- Training budget / stopping criterion ------------------------------------
n_episodes = 1500            # maximum number of episodes to run
running_reward_limit = 400   # training stops once the running mean reward exceeds this

# --- Mutable run state (updated by the training loop) ------------------------
frame_count = 0
episode_count = 0
episode_reward_history = []
done = False

# Start the Weights & Biases run. The hyperparameters are attached as the run
# config so that logged curves can be traced back to the settings that
# produced them.
wandb.init(
    name="DoubleDQNAgent_Seaquest",
    project="deep_rl",
    config={
        "lr": lr,
        "gamma": gamma,
        "eps_decay": eps_decay,
        "update_target_net": update_target_net,
        "n_episodes": n_episodes,
        "running_reward_limit": running_reward_limit,
    },
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrullofederico16[0m ([33mfede-[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Build the Seaquest environment as a single wrapper chain.
# Wrappers are applied from the innermost call outwards:
#   1. AtariPreprocessing  - frame preprocessing, resizes frames to 84x84
#                            (following "Play Atari with DeepRL")
#   2. FrameStack          - stacks frames into groups of 4
#   3. RecordEpisodeStatistics - exposes return_queue / length_queue, which the
#                            training loop reads when logging
env = gym.wrappers.RecordEpisodeStatistics(
    gym.wrappers.FrameStack(
        gym.wrappers.AtariPreprocessing(
            gym.make("ALE/Seaquest-v5"),
            frame_skip=1,
        ),
        4,
    ),
    deque_size=5,
)

In [4]:
# Instantiate the Double-DQN agent from the environment's action count and the
# hyperparameters defined in the configuration cell.
agent = DoubleDQNAgent(
    env.action_space.n,  # number of discrete actions exposed by the environment
    lr,
    gamma,
    eps_decay,
)

In [5]:
# Main training loop.
#
# Runs at most `n_episodes` episodes. After each episode the reward is appended
# to a sliding window (last 100 episodes), metrics are sent to wandb, and the
# loop stops early once the windowed mean exceeds `running_reward_limit`.
# `frame_count` and `episode_count` are module-level counters that keep
# increasing across episodes.
for _ in range(n_episodes):
    episode_reward = 0
    obs, info = env.reset()
    obs = np.array(obs)  # convert the wrapped env's observation to an ndarray

    done = False
    while not done:
        frame_count += 1
        action = agent.action_selection(obs, frame_count)
        next_obs, reward, terminated, truncated, info = env.step(action)
        next_obs = np.array(next_obs)
        done = terminated or truncated
        # Hand the transition to the agent (presumably stored in agent.buffer).
        agent.remember(obs, action, reward, next_obs, done)

        episode_reward += reward
        obs = next_obs

        # Call replay() only once the buffer holds more than one batch, and
        # only on every 4th frame.
        if len(agent.buffer) > agent.batch_size and frame_count % 4 == 0:
            agent.replay()

        if frame_count % update_target_net == 0:
            agent.update_target_net()

    # Sliding window over the most recent 100 episode rewards.
    episode_reward_history.append(episode_reward)
    if len(episode_reward_history) > 100:
        del episode_reward_history[:1]

    running_reward = np.mean(episode_reward_history)

    episode_count += 1

    # Periodic checkpoint every 100 episodes.
    if episode_count % 100 == 0:
        agent.net.save_model("./saved_models/DoubleDQN_seaquest.h5")

    # Log the latest per-episode statistics from the env's queues when they
    # are populated; otherwise log only the running reward.
    if len(env.return_queue) != 0 and len(env.length_queue) != 0:
        wandb.log({
            'Episode Rewards:': np.array(env.return_queue).flatten()[-1],
            'Episode Lengths': np.array(env.length_queue).flatten()[-1],
            'Episode Running Rewards:': running_reward,
        })
    else:
        wandb.log({'Episode Running Rewards:': running_reward})

    if running_reward > running_reward_limit:
        # Save before leaving: the periodic checkpoint above only fires every
        # 100 episodes, so without this the solving weights could be lost.
        agent.net.save_model("./saved_models/DoubleDQN_seaquest.h5")
        print('Solved at episode{}'.format(episode_count))
        break
else:
    # for/else: reached only when the loop was NOT left via `break`, so this
    # message is no longer printed after a successful early stop.
    print('Reached Limit of episodes')



KeyboardInterrupt: 