In [None]:
import os
import sys
import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

In [None]:
def draw_reward(paths, window=100, ylim=(0, 200)):
    """Plot windowed mean evaluation curves for several runs on one shared axes.

    Each run directory must contain an ``eval.log`` CSV (parsed as float). The
    CSV's second column is averaged over consecutive, non-overlapping blocks of
    ``window`` rows. It is plotted as ``'reward'`` against ``(block index + 1) * window``,
    which is plotted as ``'episode'``. A trailing partial block is dropped.

    Args:
        paths: iterable of ``(run_dir, label)`` pairs; ``label`` is the legend entry.
        window: number of rows averaged per plotted point (default 100).
        ylim: y-axis limits as ``(bottom, top)`` (default ``(0, 200)``).
    """
    sns.set()
    # plt.subplots() creates its own figure; calling plt.figure() first only
    # left an empty figure behind.
    _, ax = plt.subplots()
    for run_dir, label in paths:
        log = pd.read_csv(run_dir + '/eval.log', dtype='float')
        n_blocks = len(log) // window
        if n_blocks == 0:
            # Fewer than `window` rows: nothing to average, and plotting an
            # empty frame would raise.
            continue
        used = n_blocks * window
        # Mean of the second column per block of `window` rows. Series.mean
        # skips NaN, matching the previous DataFrame.mean() behaviour.
        block_means = (
            log.iloc[:used, 1]
            .groupby(np.arange(used) // window)
            .mean()
        )
        # Build the frame once (DataFrame.append was removed in pandas 2.0 and
        # row-by-row growth is quadratic).
        summary = pd.DataFrame({
            'episode': (np.arange(n_blocks) + 1) * window,
            'reward': block_means.to_numpy(),
        })
        summary.plot(ax=ax, x='episode', y='reward', label=label)

    ax.set_ylim(list(ylim))
    plt.show()

In [None]:
# Runs that differ only in their `mem_*` setting; each entry of `paths` is
# [run output directory, legend label].
paths = [
    [run_dir, label]
    for run_dir, label in (
        ('../outputs/rv_n14_v3_mem_4000_20181206_151528/', 'mem_4000'),
        ('../outputs/rv_n14_v3_mem_40000_20181206_151546/', 'mem_40000'),
        ('../outputs/rv_n14_v3_mem_400000_20181206_151617/', 'mem_400000'),
    )
]
draw_reward(paths)

In [None]:
# Runs that differ only in their `gamma_*` setting; each entry of `paths` is
# [run output directory, legend label].
paths = [
    [run_dir, label]
    for run_dir, label in (
        ('../outputs/rv_n14_v3_gamma_90_20181206_151747/', 'gamma_90'),
        ('../outputs/rv_n14_v3_gamma_95_20181206_151412/', 'gamma_95'),
        ('../outputs/rv_n14_v3_gamma_97_20181206_151449/', 'gamma_97'),
    )
]
draw_reward(paths)

In [None]:
def get_nrow(path):
    """Return the number of data rows in ``<path>/eval.log``.

    The file is parsed with ``pd.read_csv`` (every column as float), so the
    header line is not counted.
    """
    eval_log = pd.read_csv(path + '/eval.log', dtype='float')
    return eval_log.shape[0]

In [None]:
def draw_heatmap(path, nrow, agent_num=6, height=20, width=20, split_ep=10000, per_row=3,
                 sample_every=50, vmax=50):
    """Draw per-agent task-position count heatmaps, one figure per chunk of episodes.

    For each complete chunk of ``split_ep`` episodes (``nrow // split_ep`` chunks;
    a trailing partial chunk is skipped), every ``sample_every``-th episode's
    ``<path>/eval/episode{NNNNNN}/task.log`` is read.

    Columns 1-3 of each row are taken as ``(agent id, y, x)``. Rows whose agent id
    is ``-1`` are ignored; every other row increments ``heatmap[agent_id, y, x]``.
    The counts are drawn as one seaborn heatmap per agent, on a grid with
    ``per_row`` panels per row and a single shared colour bar.

    Args:
        path: run output directory.
        nrow: total episode/row count used to decide how many chunks to draw.
        agent_num, height, width: shape of the per-chunk count array.
        split_ep: number of episodes covered by each figure.
        per_row: number of heatmap panels per grid row.
        sample_every: stride between sampled episodes inside a chunk (default 50).
        vmax: upper limit of the colour scale (default 50).
    """
    grid_rows = (agent_num + per_row - 1) // per_row
    for chunk in range(nrow // split_ep):
        start = chunk * split_ep
        print("Episode {:d} ~ {:d}".format(start + 1, start + split_ep))
        heatmap = np.zeros((agent_num, height, width))
        for offset in range(0, split_ep, sample_every):
            fn = path + "/eval/episode{:06d}/task.log".format(start + offset)
            # reshape(-1, 3) guarantees a 2-D (rows, 3) array even when the log has
            # a single row (loadtxt then returns 1-D) or no rows at all.
            task = np.loadtxt(fn, delimiter=",", usecols=(1, 2, 3), dtype=np.int16).reshape(-1, 3)
            task = task[task[:, 0] != -1]
            # np.add.at accumulates repeated (aid, y, x) triples, exactly like a
            # per-row `+= 1` loop.
            np.add.at(heatmap, (task[:, 0], task[:, 1], task[:, 2]), 1)

        # squeeze=False always returns a 2-D array of Axes, so a 1x1 grid works too.
        fig, axes = plt.subplots(grid_rows, per_row, sharex=True, sharey=True, squeeze=False)
        cbar_ax = fig.add_axes([0.91, .3, .03, .4])
        cbar_ax.tick_params(labelsize=8)
        for j, panel in enumerate(axes.flat):
            if j >= agent_num:
                # The grid can have more cells than agents; hide the surplus ones
                # instead of indexing past the end of `heatmap`.
                panel.set_visible(False)
                continue
            sns.heatmap(heatmap[j], ax=panel, cmap='Blues', square=True,
                        cbar=j == 0,
                        vmin=0, vmax=vmax,
                        cbar_ax=None if j else cbar_ax)
            panel.tick_params(labelsize=8)
        plt.tight_layout(rect=[0, 0, .9, 1])
        plt.show()
        # Release the figure; this loop can create many of them.
        plt.close(fig)