In [1]:
import copy
from typing import List, Tuple

import numpy as np
import torch
import tqdm

from envs import connectFourEnv, fourRoomsEnv, mazeEnv
from networks import regular_net, rot_equi_net
from utils import dqn_utils, plotting, schedules
from utils.schedules import LinearSchedule
from utils import dqn_utils
from utils import plotting
import memory_connectFour

In [2]:
def _build_envs_and_networks(params: dict, env_params: dict):
    """Create the training env, the evaluation envs and the online/target networks.

    Returns:
        (env, eval_envs, dqn_model, dqn_target)
    Raises:
        NotImplementedError: for "connectfour", which has no network configured here.
        ValueError: for any other unrecognised environment name.
    """
    name = env_params["name"]

    # Validate the name before touching any environment so a typo fails fast
    # with a readable message instead of an unbound-variable error later on.
    if name == "connectfour":
        raise NotImplementedError(
            'env "connectfour" has no network configured in train_and_eval; '
            'use "maze" or "fourrooms"'
        )
    if name not in ("maze", "fourrooms"):
        raise ValueError(
            f'unknown environment name {name!r}; expected "maze", "fourrooms" or "connectfour"'
        )

    if name == "maze":
        env = mazeEnv.MazeEnv(env_params["dim"], env_params["seed"])
        eval_envs = [
            # copy of the training env (shares its grid)
            mazeEnv.MazeEnv(env_params["dim"], env_params["eval_seed"], env.grid),
            # same grid, randomly rotated (last argument is drawn from {1, 2, 3})
            mazeEnv.MazeEnv(env_params["dim"], env_params["eval_seed"], env.grid, np.random.choice([1, 2, 3])),
        ]
        net_args = ()
    else:  # "fourrooms"
        env = fourRoomsEnv.FourRoomsEnv(env_params["seed"])
        eval_envs = [fourRoomsEnv.FourRoomsEnv(env_params["eval_seed"])]
        net_args = (2,)

    # Any value other than "regular" selects the rotation-equivariant network.
    net_cls = regular_net.CNN if params["model"] == "regular" else rot_equi_net.EquiCNN
    dqn_model = net_cls(*net_args)
    dqn_target = net_cls(*net_args)
    return env, eval_envs, dqn_model, dqn_target


def train_and_eval(params: dict, env_params: dict) -> Tuple[List[List[float]], List[float], List[float], List[int], dict]:
    """Run DQN training with periodic evaluation.

    Args:
        params: dict of training parameters. Required keys: "num_steps",
            "num_saves", "replay_size", "replay_prepopulate_steps",
            "batch_size", "gamma", "model" ("regular" or anything else for the
            equivariant net). Optional keys (defaults in parentheses):
            "lr" (1e-3), "train_every" (4), "target_update_every" (10_000),
            "eval_every" (1_000).
        env_params: dict of environment params. Required keys: "name"
            ("maze" or "fourrooms"), "seed", "eval_seed"; "dim" is read for the
            maze only.
    Returns:
        evals: list of lists of evaluation returns collected during training,
            one inner entry per evaluation environment. Four rooms has a single
            evaluation environment; the maze has two, a copy of the training
            env and a randomly rotated one.
        returns: discounted return of every finished training episode
        losses: loss of every training batch
        lengths: number of steps of every finished training episode
        saved_models: dict mapping a progress label (e.g. "25_0", "100_0") to a
            deep copy of the model at that point
    Raises:
        NotImplementedError: if env_params["name"] is "connectfour" (no network
            is configured for it).
        ValueError: if env_params["name"] is not a known environment.
    """
    env, eval_envs, dqn_model, dqn_target = _build_envs_and_networks(params, env_params)

    # optional knobs; the defaults reproduce the previously hard-coded values
    learning_rate = params.get("lr", 1e-3)
    train_every = params.get("train_every", 4)
    target_update_every = params.get("target_update_every", 10_000)
    eval_every = params.get("eval_every", 1_000)
    num_steps = params["num_steps"]
    gamma = params["gamma"]

    # train on gpu if available; the networks must live on the same device as
    # the observations and replay batches that are fed to them
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dqn_model.to(device)
    dqn_target.to(device)

    # start with identical online/target weights so the very first update does
    # not bootstrap from an unrelated random network
    dqn_target.load_state_dict(dqn_model.state_dict())

    # create optimizer (after the model has been moved to its device)
    optimizer = torch.optim.Adam(dqn_model.parameters(), lr=learning_rate)

    # epsilon function
    exploration = LinearSchedule(1.0, 0.1, num_steps)

    # create and prepopulate memory
    memory = memory_connectFour.ReplayMemory(params["replay_size"], env.grid.shape, device)
    memory.populate(env, params["replay_prepopulate_steps"])

    # collect results
    rewards = []  # rewards of the episode currently in progress
    returns = []  # returns from training
    lengths = []
    losses = []
    evals = []  # returns from evaluation runs

    # Steps at which a checkpoint is stored. linspace yields floats, so round
    # them to ints; otherwise a non-divisible num_steps would never match t_total.
    num_checkpoints = max(params["num_saves"] - 1, 0)
    save_steps = {
        int(round(float(t)))
        for t in np.linspace(0, num_steps, num_checkpoints, endpoint=False)
    }
    saved_models = {}

    i_episode = 0  # index of current episode
    t_episode = 0  # zero-based step index inside the current episode

    # reset everything
    obs = env.reset()

    # run training loop
    pbar = tqdm.trange(num_steps)
    for t_total in pbar:

        if t_total in save_steps:
            model_name = f'{100 * t_total / params["num_steps"]:04.1f}'.replace('.', '_')
            saved_models[model_name] = copy.deepcopy(dqn_model)

        # get action using e-greedy
        eps = exploration.value(t_total)  # get current epsilon value
        if np.random.rand() < eps:
            action = np.random.choice(env.action_space)
        else:
            with torch.no_grad():
                # add batch and channel dims to the observation
                obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
                # Flatten so the indices below are action indices even when the
                # network returns a (1, n_actions) tensor; without this,
                # torch.where(...)[0] would be the batch index (always 0).
                q_values = dqn_model(obs_tensor).flatten()
                max_q_idx = torch.where(q_values == q_values.max())[0]
                # break ties between equally good actions at random
                action = np.random.choice(max_q_idx.tolist())

        # step forward env
        next_obs, reward, done = env.step(action)
        rewards.append(reward)

        # add transition to memory
        memory.add(obs, action, reward, next_obs, done)

        # get batch and train every `train_every` time steps
        if t_total % train_every == 0:
            batch = memory.sample(params["batch_size"])
            loss = dqn_utils.train_dqn_batch(optimizer, batch, dqn_model, dqn_target, gamma)
            losses.append(loss)

        # update target every `target_update_every` time steps
        if t_total % target_update_every == 0:
            dqn_target.load_state_dict(dqn_model.state_dict())

        # evaluate model every `eval_every` steps
        if t_total % eval_every == 0:
            evals.append([dqn_utils.evaluate(dqn_model, eval_env, device) for eval_env in eval_envs])

        if done:
            # discounted return of the finished episode, accumulated backwards
            G = 0
            for r in rewards[::-1]:
                G = gamma * G + r

            # t_episode is zero-based, so the episode took t_episode + 1 steps
            episode_steps = t_episode + 1

            returns.append(G)
            lengths.append(episode_steps)

            pbar.set_description(
                f'Episode: {i_episode} | Steps: {episode_steps} | Return: {G:5.2f} | Epsilon: {eps:4.2f}'
            )

            # reset
            t_episode = 0
            i_episode += 1
            rewards = []
            obs = env.reset()

        else:
            obs = np.copy(next_obs)
            t_episode += 1

    # always keep the final model
    saved_models['100_0'] = copy.deepcopy(dqn_model)

    return evals, returns, losses, lengths, saved_models

In [6]:
# Training hyper-parameters handed to train_and_eval.
params = dict(
    num_steps=1_000,
    num_saves=5,
    replay_size=200_000,
    replay_prepopulate_steps=50_000,
    batch_size=64,
    gamma=0.99,
    model="regular",  # "regular" picks the plain CNN, any other value the equivariant net
)

# Environment selection and seeding.
env_params = dict(
    name="fourrooms",  # one of "maze", "fourrooms", "connectfour"
    dim=3,  # only read for the maze
    seed=None,
    eval_seed=None,
)

# Run training; the five results are unpacked in the order train_and_eval returns them.
evals, returns, losses, lengths, saved_models = train_and_eval(params, env_params)

Episode: 1 | Steps: 297 | Return: -94.38 | Epsilon: 0.67: 100%|█| 1000/1000 [00:
