In [None]:
import json
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict

from src.checkpoint import RESULT_DIR, MODEL_DIR, STAT_DIR
from src.task_selectors.factory import get_selector_name

# Directory holding one stats file per run (every entry is loaded as a Run below).
# Built by plain string concatenation, and filenames are later appended the same way,
# so RESULT_DIR / STAT_DIR presumably carry their own trailing path separator — TODO confirm.
stat_directory = RESULT_DIR + STAT_DIR

In [None]:
class Run():
    """One training run, loaded from a stats file in `stat_directory`.

    Metadata is parsed from the filename, which is split on "_" as
    "<env>_<alg>_<seed>_<selector index>_<id>.<ext>" (so the env and algorithm
    names must not themselves contain underscores).

    Stats are split into:
      * `stats_generic`: stat name -> list of values;
      * `stats_envs`: stat name -> env index -> list of values (ordered by env index);
      * `stats_envs_loc`: same layout, holding the x-coordinates of each series.
    """

    def __init__(self, filename: str):
        filename_splits = filename.split("_")
        self.env_name = filename_splits[0]
        self.alg_name = filename_splits[1]
        self.seed = int(filename_splits[2])
        self.selector = get_selector_name(int(filename_splits[3]))
        self.id = filename_splits[4].split(".")[0]
        self.filename = filename

        with open(f"{stat_directory}{filename}", "r", encoding="utf-8") as f:
            self.stats: dict = json.load(f)
        # "loc" maps stat keys to the x-coordinates of their values. Files without it
        # fall back to plotting against the sample index instead of raising KeyError.
        self.stats_loc: dict = self.stats.pop("loc", {})

        generic, per_env, per_env_loc = self._group_stats(self.stats, self.stats_loc)
        self.stats_generic: dict[str, list[float]] = generic
        self.stats_envs: dict[str, dict[int, list[float]]] = per_env
        # Entries are None when the file had no x-locations for that key.
        self.stats_envs_loc: dict[str, dict[int, list[int]]] = per_env_loc

    @staticmethod
    def _group_stats(stats: dict, stats_loc: dict) -> tuple[dict, dict, dict]:
        """Split a flat stat mapping into generic and per-environment stats.

        Keys shaped like "Env <index> <stat name>" are grouped by stat name and
        integer env index; every other key is a generic stat.

        Args:
            stats: stat key -> list of values.
            stats_loc: stat key -> x-coordinates; missing keys map to None.

        Returns:
            (generic, envs, envs_loc) where the inner dicts of `envs` and
            `envs_loc` are OrderedDicts sorted by env index.
        """
        generic = {}
        envs = {}
        envs_loc = {}
        for key, values in stats.items():
            parts = str(key).split(" ")
            if parts[0] != "Env":
                generic[key] = values
                continue
            env_index = int(parts[1])
            stat = " ".join(parts[2:])
            envs.setdefault(stat, {})[env_index] = values
            envs_loc.setdefault(stat, {})[env_index] = stats_loc.get(key)
        envs = {stat: OrderedDict(sorted(by_env.items())) for stat, by_env in envs.items()}
        envs_loc = {stat: OrderedDict(sorted(by_env.items())) for stat, by_env in envs_loc.items()}
        return generic, envs, envs_loc

    @staticmethod
    def _smooth(values: list[float], passes: int) -> list[float]:
        """Apply `passes` rounds of a 3-point moving average, keeping both endpoints.

        The result always has the same length as the input so it stays aligned with
        its x-coordinates. Series shorter than 3 points are returned unchanged
        (there is no interior to average). The input list is not modified.
        """
        smoothed = list(values)
        if len(smoothed) < 3:
            return smoothed
        for _ in range(passes):
            interior = [
                (smoothed[i - 1] + smoothed[i] + smoothed[i + 1]) / 3
                for i in range(1, len(smoothed) - 1)
            ]
            # Keep the true first AND last points (the old code repeated the
            # second-to-last point and lost the final value).
            smoothed = smoothed[:1] + interior + smoothed[-1:]
        return smoothed

    def __str__(self):
        return f"{self.selector} with seed {self.seed} - {self.id}"

    def plot_stat(self, values, title, loc=None):
        """Plot one series, or several labelled series, on a new figure.

        Args:
            values: a list of values (one line labelled `title`), or a dict
                mapping a legend label to a list of values (one line per entry).
            title: figure title.
            loc: optional x-coordinates. For a single series, a list aligned with
                `values`; for a dict, a dict keyed like `values`. Missing or None
                entries are plotted against the sample index.
        """
        _, ax = plt.subplots(figsize=(15, 4))

        def plot_line(series, label, x=None):
            if x is None:
                ax.plot(series, label=label)
            else:
                ax.plot(x, series, label=label)

        if isinstance(values, dict):
            for label, series in values.items():
                plot_line(series, label, None if loc is None else loc.get(label))
        else:
            plot_line(values, title, loc)
        ax.set_title(title)
        ax.legend()
        plt.show()

    def plot_generic(self):
        """Plot each generic (non per-environment) stat on its own figure."""
        # NOTE(review): generic stats are plotted against sample index even though
        # stats_loc may also hold their x-locations — confirm that's intended.
        for stat, values in self.stats_generic.items():
            self.plot_stat(values, stat)

    def plot_envs(self, envs: list[int] = None, flattenings: int = 0):
        """Plot every per-environment stat, one figure per stat, one line per env.

        Args:
            envs: env indices to include; None includes all of them.
            flattenings: number of 3-point moving-average passes applied to each
                series before plotting (0 plots the raw values).
        """
        for stat, by_env in self.stats_envs.items():
            to_plot = {}
            to_plot_loc = {}
            for env, values in by_env.items():
                if envs is not None and env not in envs:
                    continue
                to_plot[env] = self._smooth(values, flattenings)
                to_plot_loc[env] = self.stats_envs_loc[stat][env]
            if not to_plot:
                # None of the requested envs logged this stat; skip the empty figure.
                continue
            self.plot_stat(to_plot, stat, to_plot_loc)
            

In [None]:
# Load every stats file as a Run. os.listdir order is arbitrary, so sort for a
# reproducible run order; skip hidden files and subdirectories
# (e.g. .ipynb_checkpoints) that are not stats files and would crash Run().
run_ids = sorted(
    name
    for name in os.listdir(stat_directory)
    if not name.startswith(".") and os.path.isfile(f"{stat_directory}{name}")
)
runs = [Run(name) for name in run_ids]

In [None]:
# Overview: every per-environment stat of each run, smoothed with 10 averaging passes.
for run in runs:
    print(run)
    run.plot_envs(flattenings=10)

In [None]:
# Raw (unsmoothed) curves restricted to a fixed subset of environment indices.
for run in runs:
    print(run)
    run.plot_envs(envs=[1, 4, 6, 8], flattenings=0)