In [1]:
# Widen the classic-notebook container to the full browser width so the
# side-by-side charts below are not squeezed. `display` and `HTML` are taken
# from the public `IPython.display` module; importing `display` from
# `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [27]:
import altair as alt
import pandas as pd
import glob
import os

In [3]:
def personal():
    """Altair theme for interactive, on-screen exploration.

    Returns a Vega-Lite ``{'config': ...}`` dict with 400x300 views, the
    ``set2`` categorical and ``plasma`` ordinal colour schemes, a legend
    ``labelLimit`` of 0 (intended to stop label truncation), a white
    background, clipped marks and ``size=3`` lines.
    """
    view = dict(height=300, width=400)
    color_ranges = dict(
        category=dict(scheme='set2'),
        ordinal=dict(scheme='plasma'),
    )
    config = dict(
        view=view,
        range=color_ranges,
        legend=dict(labelLimit=0),
        background='white',
        mark=dict(clip=True),  # hide marks drawn outside the plot area
        line=dict(size=3),
    )
    return dict(config=config)

def publication():
    """Altair theme for paper figures: large fonts, thick strokes, no grid.

    Returns a Vega-Lite ``{'config': ...}`` dict with 600x500 views without a
    border stroke, 24px titles and 20px labels, the ``set2`` scheme for both
    categorical and ordinal ranges, a top-left translucent legend, grey axis
    and rule strokes, and ``size=5`` lines.
    """
    colorscheme = 'set2'
    # CSS hex colours need the leading '#'; bare '333' / '999' are not valid
    # colour strings, so the strokes would not render as the intended greys.
    stroke_color = '#333'
    rule_color = '#999'
    title_size = 24
    label_size = 20
    line_width = 5

    return {
        'config': {
            'view': {
                'height': 500,
                'width': 600,
                'strokeWidth': 0,  # no border stroke around the plot area
                # NOTE(review): `background` may not be a recognised
                # view-config property (the view fill is usually `fill`) — confirm.
                'background': 'white',
            },
            'title': {
                'fontSize': title_size,
            },
            'range': {
                'category': {'scheme': colorscheme},
                'ordinal': {'scheme': colorscheme},
            },
            'axis': {
                'titleFontSize': title_size,
                'labelFontSize': label_size,
                'grid': False,
                'domainWidth': 5,
                'domainColor': stroke_color,
                'tickWidth': 3,
                'tickSize': 9,
                'tickCount': 4,
                'tickColor': stroke_color,
                'tickOffset': 0,
            },
            'legend': {
                'titleFontSize': title_size,
                'labelFontSize': label_size,
                'labelLimit': 0,
                'titleLimit': 0,
                'orient': 'top-left',
                'titlePadding': 10,
                'fillColor': '#ffffff88',  # white with 0x88 alpha (#RRGGBBAA)
                'cornerRadius': 0,
            },
            'rule': {
                'size': 3,
                'color': rule_color,
            },
            'line': {
                'size': line_width,
            },
        }
    }

# Register both themes with Altair and activate 'personal' for on-screen work;
# call alt.themes.enable('publication') before rendering paper figures.
# (A stray `ThemeRegistry.enable(...)` line was removed: that name is never
# defined, so it raised NameError on a fresh kernel, and it duplicated the
# enable call below.)
alt.themes.register('personal', personal)
alt.themes.register('publication', publication)
alt.themes.enable('personal')

In [4]:
def load_sac_results(env, task, path='results/sac.csv'):
    """Load SAC baseline results for one environment/task pair.

    Parameters
    ----------
    env, task : str
        Joined as ``f'{env}_{task}'`` and matched against the CSV's ``env``
        column (e.g. ``'cheetah'`` and ``'run'`` select ``'cheetah_run'``).
    path : str, optional
        CSV file to read; defaults to ``'results/sac.csv'``.

    Returns
    -------
    pandas.DataFrame
        The matching rows with four added columns: ``test`` (always True),
        ``score`` (a copy of ``episode_reward``), and ``name`` / ``title``
        (both ``'SAC'``), matching the ``test``/``name``/``title`` columns
        that ``load_jobs`` produces.
    """
    all_results = pd.read_csv(path)
    matching = all_results[all_results['env'] == f'{env}_{task}']
    # .assign builds a new frame instead of writing columns into a
    # boolean-filtered slice (the SettingWithCopy pattern).
    return matching.assign(
        test=True,
        score=matching['episode_reward'],
        name='SAC',
        title='SAC',
    )

In [24]:
def load_jobs(pattern, title=None, root='exp'):
    """Load train/eval logs for every experiment directory matching `pattern`.

    Job directories are found with
    ``glob(os.path.join(root, '*', f'*{pattern}'))`` (``exp/<group>/<job>`` by
    default) and visited in sorted order so the result is reproducible. Each
    job must contain ``train.csv`` and ``eval.csv``. Any exception while
    loading a job is printed together with the job path, and that job is
    skipped.

    Parameters
    ----------
    pattern : str
        Glob suffix matched against job directory names.
    title : str, optional
        Value for the ``title`` column. If None, the title is the job name
        with every ``_seed<digits>`` occurrence removed, so the seeds of one
        configuration share a title.
    root : str, optional
        Top-level experiment directory; defaults to ``'exp'``.

    Returns
    -------
    pandas.DataFrame
        For each job, its train rows (``test=False``) followed by its eval
        rows (``test=True``), with ``name`` (job directory name) and
        ``title`` columns and a fresh RangeIndex.

    Raises
    ------
    ValueError
        If no job could be loaded (nothing matched, or every job failed).
    """
    search = os.path.join(root, '*', f'*{pattern}')
    jobs = sorted(glob.glob(search))
    results = []
    for job in jobs:
        try:
            name = os.path.basename(os.path.normpath(job))
            train_data = pd.read_csv(os.path.join(job, 'train.csv'))
            train_data['test'] = False
            test_data = pd.read_csv(os.path.join(job, 'eval.csv'))
            test_data['test'] = True
            data = pd.concat([train_data, test_data], sort=False)
            data['name'] = name
            results.append(data)
        except Exception as e:
            # Best effort: one empty or half-written log should not abort the
            # whole load, but report *which* job was skipped.
            print(f'skipping {job}: {e}')
    if not results:
        raise ValueError(f'no loadable jobs matched {search!r}')
    df = pd.concat(results, sort=False)
    if title is None:
        # regex=True is required: since pandas 2.0 str.replace treats the
        # pattern literally by default. `\d+` also strips seeds like `_seed10`.
        df['title'] = df['name'].str.replace(r'_seed\d+', '', regex=True)
    else:
        df['title'] = title
    return df.reset_index(drop=True)

In [25]:
def plot_with_bars(data, y_col='episode_reward', test=True, extent='ci',
                   y_args=None, x_col='episode', **kwargs):
    """Plot the per-x mean of `y_col` as a line with an error band around it.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain `x_col`, `y_col` and a boolean ``test`` column.
    y_col : str
        Column aggregated on the y axis.
    test : bool
        Keep only rows whose ``test`` column equals this value (``load_jobs``
        marks train rows False and eval rows True).
    extent : str
        Passed to ``mark_errorband(extent=...)``.
    y_args : dict, optional
        Extra keyword arguments for every ``alt.Y`` encoding (e.g. a scale).
    x_col : str
        Column on the x axis; defaults to ``'episode'``.
    **kwargs
        Extra encodings shared by all layers (e.g. ``color='title'``).

    Returns
    -------
    Layered Altair chart: legend layer + error band + mean line.
    """
    # None instead of a shared mutable {} default.
    y_args = {} if y_args is None else y_args
    # Filter once on the shared base so every layer sees the same rows.
    base_chart = (
        alt.Chart(data)
        .mark_line()
        .encode(x=x_col, **kwargs)
        .transform_filter(alt.datum.test == test)
    )

    # Zero-size, fully opaque line layer. NOTE(review): presumably added so the
    # merged legend shows an opaque line symbol; the layer itself draws nothing.
    legend_chart = base_chart.mark_line(size=0, opacity=1).encode(
        y=alt.Y(f'mean({y_col}):Q', **y_args),
    )
    err_chart = base_chart.encode(
        y=alt.Y(f'{y_col}:Q', **y_args),
    ).mark_errorband(extent=extent)
    mean_chart = base_chart.encode(
        y=alt.Y(f'mean({y_col}):Q', **y_args),
    )
    return legend_chart + err_chart + mean_chart

In [36]:
# SAC baseline vs. the `sactrunc_v1*` runs on cheetah/run, first 100 episodes.
# Left panel: rows with test == False; right panel: rows with test == True.
# (load_sac_results marks every SAC row test=True, so SAC appears on the right.)
data = pd.concat(
    [
        load_sac_results('cheetah', 'run'),
        load_jobs('sactrunc_v1*'),
    ]
)
data = data[data['episode'] <= 100]

(
    plot_with_bars(data, y_col='episode_reward', color='title', test=False)
    | plot_with_bars(data, y_col='episode_reward', color='title', test=True)
)

No columns to parse from file
No columns to parse from file
No columns to parse from file
No columns to parse from file
No columns to parse from file
No columns to parse from file
No columns to parse from file
No columns to parse from file
No columns to parse from file
