In [None]:
import importlib
import importlib.util
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd

# --- Run configuration -------------------------------------------------------

# Name of the helper module (utils/<MODULE_NAME>.py) that is loaded below.
MODULE_NAME = "infer"

# Repository root. Set the ATARI_RL_HOME environment variable to run this
# notebook from a different checkout; the default keeps the original location.
MAIN_PATH = os.environ.get("ATARI_RL_HOME", "/home/sequenzia/dev/repos/atari-rl")

WANDB_ON = False
PROJECT = "solen-rl-project-eval-2"

NO_RENDER = True

N_ENVS = 2
N_STEPS = 10000

module_path = f"{MAIN_PATH}/utils/{MODULE_NAME}.py"
agents_path = f"{MAIN_PATH}/agents"
data_path = f"{MAIN_PATH}/data"

# --- Load utils/infer.py directly from the repository checkout ---------------

# Fail early with a message that names the path, instead of an error raised
# from inside the import machinery.
if not Path(module_path).is_file():
    raise FileNotFoundError(
        f"Inference module not found at {module_path}; "
        "check MAIN_PATH or the ATARI_RL_HOME environment variable."
    )

spec = importlib.util.spec_from_file_location(MODULE_NAME, module_path)
infer = importlib.util.module_from_spec(spec)
# Register under MODULE_NAME before executing, so lookups of that name in
# sys.modules resolve to this module object.
sys.modules[MODULE_NAME] = infer
spec.loader.exec_module(infer)

# --- Result containers, keyed by run key ("<ALGO>_<Game>") --------------------
all_infer_logs = {}
all_infer_data = {}

# --- Evaluation grid ----------------------------------------------------------
ALGOS = ["ppo", "a2c"]

# Uncomment entries to evaluate additional games.
GAMES = [
    "Breakout",
    "Pong",
    "SpaceInvaders",
    # "Qbert",
    # "Seaquest",
    # "Centipede",
    # "MsPacman",
    # "Asterix",
    # "Asteroids",
    # "Assault",
]

# Column order of the per-episode table; each name is also the attribute read
# from every inference log object.
INFER_COLUMNS = ["scores",
                 "times",
                 "lengths",
                 "frame_numbers",
                 "run_frame_numbers"]


def _infer_logs_to_frame(infer_logs):
    """Stack the per-log metric sequences into one DataFrame.

    Args:
        infer_logs: iterable of log objects exposing ``scores``, ``times``,
            ``lengths``, ``frame_numbers`` and ``run_frame_numbers``
            (assumed to be equal-length sequences within a log — TODO confirm
            against ``infer.infer``).

    Returns:
        DataFrame with one row per entry and the columns in ``INFER_COLUMNS``;
        empty (0 rows, 5 columns) when ``infer_logs`` is empty.
    """
    per_log = [np.array([log.scores,
                         log.times,
                         log.lengths,
                         log.frame_numbers,
                         log.run_frame_numbers]).T
               for log in infer_logs]

    # Stack once instead of re-allocating the whole array on every iteration;
    # the leading empty block keeps the (0, 5) shape when there are no logs.
    stacked = np.vstack([np.empty((0, len(INFER_COLUMNS))), *per_log])

    return pd.DataFrame(stacked, columns=INFER_COLUMNS)


# Run inference for every (algorithm, game) pair and keep both the raw logs
# and the flattened table under the same run key.
for algo in ALGOS:

    for game in GAMES:

        ENV_ID = f"ALE/{game}-v5"

        RUN_KEY = f"{algo.upper()}_{game}"

        infer_logs = infer.infer(run_key=RUN_KEY,
                                 env_id=ENV_ID,
                                 algo=algo,
                                 game=game,
                                 agents_path=agents_path,
                                 n_envs=N_ENVS,
                                 n_steps=N_STEPS,
                                 no_render=NO_RENDER,
                                 project=PROJECT,
                                 wandb_on=WANDB_ON,
                                 debug_on=False)

        # Kept because the W&B episode-logging cell reads
        # all_infer_logs['PPO_Breakout'].
        all_infer_logs[RUN_KEY] = infer_logs

        infer_data = _infer_logs_to_frame(infer_logs)

        all_infer_data[RUN_KEY] = infer_data


In [None]:
# Display the inference table stored under run key "A2C_Pong"
# (run keys are built as f"{algo.upper()}_{game}").
all_infer_data['A2C_Pong']

In [None]:
# Scratch example: sample values used to preview the metric-key naming scheme.
losses = [
    223, 2232, 4, 5, 6, 7, 8, 9, 10, 11,
    12, 13, 14, 15, 16, 17, 18, 19, 20,
]

# Map each value to a "losses/loss-<position>" key; shown as the cell output.
dict(
    (f"losses/loss-{ii}", loss)
    for ii, loss in enumerate(losses)
)

In [None]:
import wandb

# Collect every episode of the PPO/Breakout logs and order them by global
# frame number BEFORE a W&B run is opened, so a missing key or attribute
# cannot leave an empty run behind. `sorted` is stable, so episodes sharing a
# run_frame_number keep their collection order.
episodes = sorted(
    (episode
     for infer_log in all_infer_logs['PPO_Breakout']
     for episode in infer_log.episode_logs),
    key=lambda episode: episode.run_frame_number,
)

# (run_frame_number, value) pairs, kept as notebook-level names because later
# cells inspect them.
run_frame_numbers = [(ep.run_frame_number, ep.run_frame_number) for ep in episodes]

episode_scores = [(ep.run_frame_number, ep.episode_score) for ep in episodes]
episode_times = [(ep.run_frame_number, ep.episode_time) for ep in episodes]
episode_lengths = [(ep.run_frame_number, ep.episode_length) for ep in episodes]
episode_lives = [(ep.run_frame_number, ep.episode_lives) for ep in episodes]
episode_frame_numbers = [(ep.run_frame_number, ep.episode_frame_number) for ep in episodes]

# The context manager finishes the run even if logging raises, so a failed
# cell does not leave a run open in the kernel.
with wandb.init(project=PROJECT,
                name="ppo_breakout_eval_14",
                group="ppo",
                job_type="eval",
                settings=wandb.Settings(disable_job_creation=True)) as run:

    for ep in episodes:

        # One log call per episode, using the global frame number as the step.
        run.log(data={"episode_score": ep.episode_score,
                      "episode_time": ep.episode_time,
                      "episode_length": ep.episode_length,
                      "episode_lives": ep.episode_lives,
                      "run_frame_number": ep.run_frame_number},
                step=ep.run_frame_number)




In [None]:
# Display the (run_frame_number, episode_frame_number) pairs built in the
# W&B episode-logging cell.
episode_frame_numbers

In [None]:
# Display the (run_frame_number, run_frame_number) pairs built in the W&B
# episode-logging cell. The previous name `_run_frame_numbers` is not defined
# in any visible cell and raised NameError on a fresh kernel.
run_frame_numbers

In [None]:
    # wandb.log(data={f"episode_times/time-{idx}": time for idx, time in enumerate(episode_times)})
    # wandb.log(data={f"episode_lengths/length-{idx}": length for idx, length in enumerate(episode_lengths)})
    # wandb.log(data={f"episode_lives/lives-{idx}": lives for idx, lives in enumerate(episode_lives)})
    # wandb.log(data={f"episode_frame_numbers/frame_number-{idx}": frame_number for idx, frame_number in enumerate(episode_frame_numbers)})
    # wandb.log(data={f"run_frame_numbers/frame_number-{idx}": frame_number for idx, frame_number in enumerate(run_frame_numbers)})


In [None]:
# NOTE(review): `wandb_log` is not defined in any visible cell, so this raises
# NameError on a fresh kernel — confirm the intended name or remove this cell.
wandb_log

In [None]:
import wandb

# Look the frame up first so a missing run key fails before a W&B run exists.
ppo_breakout_data = all_infer_data['PPO_Breakout']

# The context manager finishes the run even if the upload raises.
with wandb.init(project=PROJECT,
                name="ppo_breakout_eval",
                group="ppo",
                job_type="eval") as run:

    # Upload the whole inference table as a single W&B Table artifact.
    wandb_tbl = wandb.Table(dataframe=ppo_breakout_data)

    run.log({"ppo_breakout_eval": wandb_tbl})

In [None]:
# Display the dict meant to hold the raw inference logs per run key; it only
# has entries if the inference loop stores `infer_logs` into it.
all_infer_logs

In [None]:
# Run keys are built as f"{algo.upper()}_{game}", so the game name keeps its
# original casing: "PPO_Breakout", not "PPO_BREAKOUT" (which raised KeyError).
all_infer_data['PPO_Breakout']

In [None]:
# Run keys are built as f"{algo.upper()}_{game}", so the game name keeps its
# original casing: "A2C_Breakout", not "A2C_BREAKOUT" (which raised KeyError).
all_infer_data['A2C_Breakout']