# Config

In [None]:
# Required
ENV: str = "MiniGrid-Empty-Random-6x6-v0"  # environment id passed to utils.make_env
MODEL: str = "EXP"  # model name, resolved to a directory by utils.get_model_dir

# Optional
EPISODES: int = 100  # evaluation stops once this many episodes have finished
SEED: int = 0  # global seed; env i is seeded with SEED + 10000 * i
PROCS: int = 16  # number of environments run side by side
ARGMAX: bool = False  # forwarded as Agent(argmax=...); presumably greedy vs. sampled actions
WORST_EPS_TO_SHOW: int = 10  # how many lowest-return episodes to print (0 disables)
MEMORY: bool = False  # forwarded as Agent(use_memory=...)
TEXT: bool = False  # forwarded as Agent(use_text=...)

# Eval

In [None]:
import time
import torch
from torch_ac.utils.penv import ParallelEnv

import utils
from utils import device

# Seed the random number generators (via utils.seed) before creating anything.

utils.seed(SEED)

# Report which torch device is in use.

print(f"Device: {device}\n")

# Create PROCS environments, each with a distinct seed offset, and wrap
# them in a single ParallelEnv.

envs = [utils.make_env(ENV, SEED + 10000 * i) for i in range(PROCS)]
env = ParallelEnv(envs)
print("Environments loaded\n")

# Load the agent from the directory associated with MODEL.

model_dir = utils.get_model_dir(MODEL)
agent = utils.Agent(
    env.observation_space,
    env.action_space,
    model_dir,
    argmax=ARGMAX,
    num_envs=PROCS,
    use_memory=MEMORY,
    use_text=TEXT,
)
print("Agent loaded\n")

# Initialize logs: one entry per finished episode, in completion order.

logs = {"num_frames_per_episode": [], "return_per_episode": []}

# Run the agent until exactly EPISODES episodes have been recorded.
# NOTE(review): episodes are recorded in completion order across the
# parallel envs, so envs with shorter episodes contribute more samples.

start_time = time.time()

obss = env.reset()

log_done_counter = 0
# Running return / frame count of the episode currently in progress in each env.
log_episode_return = torch.zeros(PROCS, device=device)
log_episode_num_frames = torch.zeros(PROCS, device=device)

while log_done_counter < EPISODES:
    actions = agent.get_actions(obss)
    obss, rewards, terminateds, truncateds, _ = env.step(actions)
    # An episode is over when it is either terminated or truncated.
    dones = tuple(a | b for a, b in zip(terminateds, truncateds))
    agent.analyze_feedbacks(rewards, dones)

    log_episode_return += torch.tensor(rewards, device=device, dtype=torch.float)
    log_episode_num_frames += 1

    for i, done in enumerate(dones):
        # Several envs may finish on the same step; stop recording once
        # EPISODES episodes are logged so the stats cover exactly that many.
        if done and log_done_counter < EPISODES:
            log_done_counter += 1
            logs["return_per_episode"].append(log_episode_return[i].item())
            logs["num_frames_per_episode"].append(log_episode_num_frames[i].item())

    # Zero the accumulators of every env whose episode just ended.
    mask = 1 - torch.tensor(dones, device=device, dtype=torch.float)
    log_episode_return *= mask
    log_episode_num_frames *= mask

end_time = time.time()

# Print a one-line summary: total frames, throughput, wall-clock duration,
# and mean/std/min/max of per-episode return and frame count.

elapsed = end_time - start_time
num_frames = sum(logs["num_frames_per_episode"])
fps = num_frames / elapsed
duration = int(elapsed)
return_per_episode = utils.synthesize(logs["return_per_episode"])
num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])

summary_fmt = "F {} | FPS {:.0f} | D {} | R:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {}"
summary_args = (
    num_frames,
    fps,
    duration,
    *return_per_episode.values(),
    *num_frames_per_episode.values(),
)
print(summary_fmt.format(*summary_args))

# Print the lowest-return episodes (worst first).

returns = logs["return_per_episode"]
frames = logs["num_frames_per_episode"]
# Never announce more episodes than were actually recorded.
n = min(WORST_EPS_TO_SHOW, len(returns))
if n > 0:
    print("\n{} worst episodes:".format(n))

    # Episode indexes ordered by ascending return; sorted() is stable, so
    # ties keep completion order.
    indexes = sorted(range(len(returns)), key=returns.__getitem__)
    for i in indexes[:n]:
        print("- episode {}: R={}, F={}".format(i, returns[i], frames[i]))