# Config

In [None]:
# REQUIRED
ALGO = "a2c"  # training algorithm; the train cell accepts "a2c" or "ppo"
ENV = "MiniGrid-Fetch-8x8-N3-v0"  # environment id handed to utils.make_env

# OPTIONAL
MODEL_NAME = "EXPERIMENTAL"  # run name; a falsy value falls back to an auto-generated one
SEED = 1  # base seed; environment i is created with SEED + 10000 * i
LOG_INTERVAL = 1  # number of updates between log entries
SAVE_INTERVAL = 1  # number of updates between status saves (0 turns saving off)
PROCS = 16  # how many environment instances are created
FRAMES = 100_000  # training continues while fewer frames than this were collected

# HYPERPARAMETERS
EPOCHS = 4  # only passed to the PPO algorithm
BATCH_SIZE = 256  # only passed to the PPO algorithm
FRAMES_PER_PROC = None  # forwarded to the torch_ac algorithm unchanged
DISCOUNT = 0.99
LR = 1e-3
GAE_LAMBDA = 0.95
ENTROPY_COEF = 0.01
VALUE_LOSS_COEF = 0.5
MAX_GRAD_NORM = 0.5
OPTIM_EPSILON = 1e-8
OPTIM_ALPHA = 0.99  # only passed to the A2C algorithm
CLIP_EPSILON = 0.2  # only passed to the PPO algorithm
RECURRENCE = 1  # ACModel receives the flag `RECURRENCE > 1` as its third argument
TEXT = False  # forwarded as the fourth ACModel argument

# Reward Shaper

In [None]:
def reshape_reward(obs, action, reward, done):
    """Reward-shaping hook handed to the training algorithm.

    This is a stub: whatever the inputs, the incoming ``reward`` is returned
    unchanged. The only place reserved for custom shaping is the branch taken
    when ``done`` is truthy.

    Args:
        obs: Observation for the step (unused by the stub).
        action: Action taken for the step (unused by the stub).
        reward: Reward for the step.
        done: Truthy when the shaping branch should apply.

    Returns:
        ``reward``, unmodified.
    """
    if not done:
        return reward
    # implement reward shaping here
    return reward

# Train

In [None]:
import datetime
import tensorboardX
import time
import datetime
import torch_ac

import utils
from utils import device
from model import ACModel

# Set run dir: an explicit MODEL_NAME wins; otherwise the name is built from
# the environment, algorithm, seed and the current timestamp.

date = f"{datetime.datetime.now():%y-%m-%d-%H-%M-%S}"
default_model_name = "{}_{}_seed{}_{}".format(ENV, ALGO, SEED, date)

model_name = MODEL_NAME if MODEL_NAME else default_model_name
model_dir = utils.get_model_dir(model_name)

# Load loggers and Tensorboard writer (all of them target the run directory)

txt_logger = utils.get_txt_logger(model_dir)
csv_file, csv_logger = utils.get_csv_logger(model_dir)
tb_writer = tensorboardX.SummaryWriter(model_dir)

# Set seed for all randomness sources

utils.seed(SEED)

# Set device (imported from utils; only reported here)

txt_logger.info("Device: {}\n".format(device))

# Load environments: one per process, each with its own seed offset

envs = [utils.make_env(ENV, SEED + 10000 * proc_index) for proc_index in range(PROCS)]
txt_logger.info("Environments loaded\n")

# Load training status, falling back to a fresh one when get_status raises
# OSError (presumably because nothing was saved for this run yet)

try:
    status = utils.get_status(model_dir)
except OSError:
    status = dict(num_frames=0, update=0)
txt_logger.info("Training status loaded\n")

# Load observations preprocessor and restore its vocabulary if one was saved

obs_space, preprocess_obss = utils.get_obss_preprocessor(envs[0].observation_space)
if "vocab" in status:
    preprocess_obss.vocab.load_vocab(status["vocab"])
txt_logger.info("Observations preprocessor loaded")

# Load model, restoring saved weights before moving it to the device

acmodel = ACModel(obs_space, envs[0].action_space, RECURRENCE > 1, TEXT)
if "model_state" in status:
    acmodel.load_state_dict(status["model_state"])
acmodel.to(device)
txt_logger.info("Model loaded\n")
txt_logger.info(f"{acmodel}\n")

# Load algo: both algorithms share the leading positional arguments and
# differ only in their optimizer / algorithm-specific tail.

shared_algo_args = (envs, acmodel, device, FRAMES_PER_PROC, DISCOUNT, LR, GAE_LAMBDA,
                    ENTROPY_COEF, VALUE_LOSS_COEF, MAX_GRAD_NORM, RECURRENCE)

if ALGO == "ppo":
    algo = torch_ac.PPOAlgo(*shared_algo_args, OPTIM_EPSILON, CLIP_EPSILON, EPOCHS,
                            BATCH_SIZE, preprocess_obss, reshape_reward)
elif ALGO == "a2c":
    algo = torch_ac.A2CAlgo(*shared_algo_args, OPTIM_ALPHA, OPTIM_EPSILON,
                            preprocess_obss, reshape_reward)
else:
    raise ValueError(f"Incorrect algorithm name: {ALGO}")

if "optimizer_state" in status:
    algo.optimizer.load_state_dict(status["optimizer_state"])
txt_logger.info("Optimizer loaded\n")

# Train model
#
# Repeatedly collects experiences, updates the model, and periodically logs
# (text, CSV, Tensorboard) and saves the training status, until FRAMES frames
# have been collected in total (including frames from a resumed status).

num_frames = status["num_frames"]
update = status["update"]
start_time = time.time()

# The CSV header must be written exactly once per run. A run that starts at
# zero frames still needs it; a resumed run (num_frames > 0) skips it, which
# is the same rule the original `status["num_frames"] == 0` check expressed.
# A dedicated flag is used because `status` is only refreshed when a save
# happens, so testing it directly duplicated or dropped the header whenever
# LOG_INTERVAL and SAVE_INTERVAL differed.
csv_header_written = num_frames > 0

while num_frames < FRAMES:
    # Update model parameters

    update_start_time = time.time()
    exps, logs1 = algo.collect_experiences()
    logs2 = algo.update_parameters(exps)
    logs = {**logs1, **logs2}
    update_end_time = time.time()

    num_frames += logs["num_frames"]
    update += 1

    # Print logs (LOG_INTERVAL <= 0 disables logging, mirroring SAVE_INTERVAL,
    # instead of raising ZeroDivisionError on the modulo)

    if LOG_INTERVAL > 0 and update % LOG_INTERVAL == 0:
        fps = logs["num_frames"] / (update_end_time - update_start_time)
        duration = int(time.time() - start_time)
        return_per_episode = utils.synthesize(logs["return_per_episode"])
        rreturn_per_episode = utils.synthesize(logs["reshaped_return_per_episode"])
        num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])

        header = ["update", "frames", "FPS", "duration"]
        data = [update, num_frames, fps, duration]
        header += ["rreturn_" + key for key in rreturn_per_episode.keys()]
        data += rreturn_per_episode.values()
        header += ["num_frames_" + key for key in num_frames_per_episode.keys()]
        data += num_frames_per_episode.values()
        header += ["entropy", "value", "policy_loss", "value_loss", "grad_norm"]
        data += [logs["entropy"], logs["value"], logs["policy_loss"], logs["value_loss"], logs["grad_norm"]]

        # The text line is formatted before the raw-return columns are
        # appended, so its positional placeholders line up with `data`.
        txt_logger.info(
            "U {} | F {:06} | FPS {:04.0f} | D {} | rR:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f}"
            .format(*data))

        header += ["return_" + key for key in return_per_episode.keys()]
        data += return_per_episode.values()

        if not csv_header_written:
            csv_logger.writerow(header)
            csv_header_written = True
        csv_logger.writerow(data)
        csv_file.flush()

        # Every CSV column is also emitted as a Tensorboard scalar, indexed by
        # the cumulative frame count.
        for field, value in zip(header, data):
            tb_writer.add_scalar(field, value, num_frames)

    # Save status

    if SAVE_INTERVAL > 0 and update % SAVE_INTERVAL == 0:
        status = {"num_frames": num_frames, "update": update,
                  "model_state": acmodel.state_dict(), "optimizer_state": algo.optimizer.state_dict()}
        if hasattr(preprocess_obss, "vocab"):
            status["vocab"] = preprocess_obss.vocab.vocab
        utils.save_status(status, model_dir)
        txt_logger.info("Status saved")
