# Mount the drive and install gym-0.21.0 and highway environment

In [5]:
import sys
import os
try:
  from google.colab import drive
  drive.mount('/content/gdrive/')
  project_path = 'ENPM690/Project/DQN_highway_env'
  sys.path.append(os.path.join('/content/gdrive/MyDrive', project_path))

  !pip uninstall gym
  !pip install gym==0.21.0
  !pip install highway-env
  
except:
  print("Run only for google colab")

Run only for google colab


# Imports

In [6]:
from common_utils import *
from models.dqn_conv_v1 import DQN as DQN
from train import *

In [7]:
def main():
    opt = parse_opts()
    print(opt)

    if not os.path.exists(opt.save_folder):
        os.mkdir(opt.save_folder)
    if not os.path.exists(os.path.join(opt.save_folder, opt.env)):
        os.mkdir(os.path.join(opt.save_folder, opt.env))

    timestamp = time.strftime('%b-%d-%Y_%H%M', time.localtime())
    csv_file_name  = os.path.join(opt.save_folder, f'{timestamp}_stats.csv')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    em = HighwayEnvManager(device)
    strategy = EpsilonGreedyStrategy(opt.eps_start, opt.eps_end, opt.eps_decay)
    agent = Agent(strategy, em.num_actions_available(), device)
    memory = ReplayMemory(opt.memory_size)
    
    policy_net = DQN(em.get_screen_height(), em.get_screen_width(), em.get_screen_stack(), em.num_actions_available()).to(device)
    target_net = DQN(em.get_screen_height(), em.get_screen_width(), em.get_screen_stack(), em.num_actions_available()).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(params=policy_net.parameters(), lr=opt.lr)
    criterion = nn.MSELoss()

    episode_durations = []
    episode_rewards = []
    best_reward = 0
    best_state = None

    for episode in range(opt.num_episodes):
        duration, reward, loss_epoch = train_epoch(opt, em, agent, policy_net, target_net, memory, device, optimizer, criterion)
        episode_durations.append(duration)
        episode_rewards.append(reward)

        write2csv(filename = csv_file_name, duration = duration, reward = reward, loss = loss_epoch)
        moving_avg_period = 50
        avg_reward = get_moving_average(moving_avg_period, episode_rewards)
        print("Episode", episode, "\n",
        moving_avg_period, "episode average reward: ", "{:.4f}".format(avg_reward[-1]), " | currect episode reward: ", "{:.4f}".format(reward), "| duration :", duration)

        state = {'epoch': episode, 'state_dict': policy_net.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}

        if avg_reward[-1] > best_reward:
            best_reward = avg_reward[-1]
            best_state = state

        if episode % opt.target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())
        
        if episode % opt.save_interval == opt.save_interval - 1:
            timestamp = time.strftime('%b-%d-%Y_%H%M', time.localtime())
            torch.save(best_state, os.path.join(os.path.join(opt.save_folder, opt.env),
                                          f'{opt.env}-Episode-{episode}-Reward-{ avg_reward[-1]}_{timestamp}.pth'))
            print("Model saved with average reward ",best_reward)
            best_reward = 0

In [8]:
main()

usage: ipykernel_launcher [-h] [-f F] [--env ENV] [--batch_size BATCH_SIZE]
                          [--gamma GAMMA] [--eps_start EPS_START]
                          [--eps_end EPS_END] [--eps_decay EPS_DECAY]
                          [--target_update TARGET_UPDATE]
                          [--memory_size MEMORY_SIZE] [--lr LR]
                          [--num_episodes NUM_EPISODES]
                          [--save_interval SAVE_INTERVAL]
                          [--save_folder SAVE_FOLDER]
                          [--model_name MODEL_NAME]
ipykernel_launcher: error: unrecognized arguments: --ip=127.0.0.1 --stdin=9003 --control=9001 --hb=9000 --Session.signature_scheme="hmac-sha256" --Session.key=b"0e4b4648-b306-496c-a95f-43771be84c78" --shell=9002 --transport="tcp" --iopub=9004 --f=/tmp/tmp-3369u4OQOOFBVICr.json


SystemExit: 2