In [1]:
# Stretch every element with the `.container` CSS class to the full page width
# by injecting a <style> tag into the notebook output.
# `display` and `HTML` come from the public `IPython.display` module; importing
# `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [1]:
import pandas as pd
import numpy as np
import altair as alt
import altair_saver
import glob
import os
# Set CUDA_VISIBLE_DEVICES to the empty string so that libraries honouring this
# variable see no CUDA devices in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Remove Altair's cap on the number of rows a chart's dataset may contain.
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

In [2]:
def personal():
    """Build the theme dictionary used for everyday, exploratory charts.

    Returns:
        dict: a mapping with a single ``'config'`` key holding the default
        view size, the ``set2`` colour scheme for categorical and ordinal
        ranges, an unlimited legend label width and a white background.
        A new dictionary is created on every call.
    """
    palette = 'set2'

    view_settings = {}
    view_settings['height'] = 400
    view_settings['width'] = 600

    range_settings = {
        name: {'scheme': palette} for name in ('category', 'ordinal')
    }

    # A label limit of 0 is passed so legend labels are not truncated.
    legend_settings = {'labelLimit': 0}

    config = {
        'view': view_settings,
        'range': range_settings,
        'legend': legend_settings,
        'background': 'white',
    }
    return {'config': config}

def publication():
    """Build the theme dictionary used for publication-quality figures.

    Compared to ``personal`` this theme uses larger fonts, thicker axis and
    line strokes, no grid, and a legend placed in the top-right corner.

    Returns:
        dict: a mapping with a single ``'config'`` key. A new dictionary is
        created on every call.
    """
    colorscheme = 'set2'
    # Hex colours need the leading '#'; a bare '333' is not a valid colour
    # string and renderers may ignore it.
    stroke_color = '#333'
    title_size = 24
    label_size = 20
    line_width = 5

    return {
        'config': {
            'view': {
                'height': 500,
                'width': 600,
                'strokeWidth': 0,
                # NOTE(review): `personal` sets 'background' at the top level
                # of 'config'; here it sits under 'view' -- confirm that this
                # placement actually takes effect.
                'background': 'white',
            },
            'title': {
                'fontSize': title_size,
            },
            'range': {
                'category': {'scheme': colorscheme},
                'ordinal': {'scheme': colorscheme},
            },
            'axis': {
                'titleFontSize': title_size,
                'labelFontSize': label_size,
                'grid': False,
                'domainWidth': 5,
                'domainColor': stroke_color,
                'tickWidth': 3,
                'tickSize': 9,
                'tickCount': 4,
                'tickColor': stroke_color,
                'tickOffset': 0,
            },
            'legend': {
                'titleFontSize': title_size,
                'labelFontSize': label_size,
                # A limit of 0 is passed so titles and labels are not truncated.
                'labelLimit': 0,
                'titleLimit': 0,
                'orient': 'top-right',
                'titlePadding': 10,
                # White with an alpha channel, so marks under the legend
                # remain partly visible.
                'fillColor': '#ffffff88',
                'cornerRadius': 0,
            },
            'rule': {
                'size': 3,
                'color': '#999',
            },
            'line': {
                'size': line_width,
            },
        }
    }

# Register both theme factories under their names so they can be selected
# with `alt.themes.enable(<name>)`.
alt.themes.register('personal', personal)
alt.themes.register('publication', publication)
# Activate the publication theme. This is the last expression of the cell,
# so its return value is what the notebook shows as the cell output.
alt.themes.enable('publication')

ThemeRegistry.enable('publication')

In [3]:
import json

import replay_buffer
import utils
import jax_specs
from observation_domains import DOMAINS


def load_replay(name, subdir='exploration'):
    """Load the replay buffer saved by the run called `name`.

    Args:
        name: run directory name inside ``results/<subdir>/``.
        subdir: results subdirectory that contains the run.

    Returns:
        Whatever ``replay_buffer.load`` returns for the run's ``replay.pkl``.
    """
    replay_path = f'results/{subdir}/{name}/replay.pkl'
    return replay_buffer.load(replay_path)

def load_args(name, subdir='exploration'):
    """Read the JSON-encoded arguments stored alongside the run called `name`.

    Args:
        name: run directory name inside ``results/<subdir>/``.
        subdir: results subdirectory that contains the run.

    Returns:
        The decoded contents of the run's ``args.json``.
    """
    args_path = f'results/{subdir}/{name}/args.json'
    with open(args_path, 'r') as handle:
        return json.load(handle)

def count_states(replay, args, bins=None):
    """Count the distinct rows among the states stored in `replay`.

    Args:
        replay: object whose ``s`` attribute holds the stored states.
        args: run arguments; only forwarded to ``discretize_states``.
        bins: if given, states are first discretized with this bin count;
            if ``None`` the raw states are compared as they are.

    Returns:
        int: number of unique rows (uniqueness is taken along axis 0).
    """
    observed = replay.s if bins is None else discretize_states(replay.s, args, bins)
    distinct_rows = np.unique(observed, axis=0)
    return len(distinct_rows)

def discretize_states(states, args, bins):
    """Discretize `states` using the observation spec of the run's domain.

    The spec is looked up in ``DOMAINS`` via ``args['env']`` and
    ``args['task']``, flattened, converted with ``jax_specs.convert_dm_spec``
    and then handed to ``utils.discretize_observation``.

    Args:
        states: batch of states to discretize.
        args: run arguments providing the ``'env'`` and ``'task'`` keys.
        bins: bin count forwarded to ``utils.discretize_observation``.

    Returns:
        The result of ``utils.discretize_observation`` called with
        ``preserve_batch=True``.
    """
    env_name, task_name = args['env'], args['task']
    flattened = utils.flatten_observation_spec(DOMAINS[env_name][task_name])
    converted = jax_specs.convert_dm_spec(flattened)
    return utils.discretize_observation(
        states, converted, bins=bins, preserve_batch=True)

def count_states_incremental(replay, args, bins=None):
    """Return the running number of distinct states after each stored step.

    Args:
        replay: object exposing ``s`` (the stored states, iterated row by
            row) and ``length`` (how many leading rows to consider).
        args: run arguments; only forwarded to ``discretize_states``.
        bins: if given, states are discretized with this bin count before
            counting; if ``None`` the raw states are used.

    Returns:
        list[int]: element ``i`` is the number of distinct states among the
        first ``i + 1`` rows. The list has ``min(replay.length, len(states))``
        entries, so a replay with ``length == 0`` yields an empty list.
    """
    states = replay.s
    if bins is not None:
        states = discretize_states(states, args, bins)

    visited = set()
    visit_counts = []
    for i, state in enumerate(states):
        # Only the first `replay.length` rows are counted (presumably the
        # rest of the buffer has not been written yet). Checking before the
        # row is consumed also covers `length == 0`, where a post-hoc
        # `i + 1 == length` test would never fire.
        if i >= replay.length:
            break
        # Rows are turned into tuples so they can be stored in a set.
        visited.add(tuple(state))
        visit_counts.append(len(visited))
    return visit_counts

In [13]:
def max_number_of_states(args, bins=None):
    """Return the size of the state space used to normalise visit counts.

    Args:
        args: run arguments. Gridworld runs need ``'env_size'``; other
            environments need ``'env'``, ``'task'`` and, when `bins` is
            ``None``, ``'n_state_bins'``.
        bins: bin count for non-gridworld environments; defaults to the
            run's own ``args['n_state_bins']``.

    Returns:
        ``env_size ** 2`` for gridworld, otherwise ``bins`` raised to the
        first entry of the flattened observation spec shape.
    """
    # Guard clause: the gridworld size follows directly from the run arguments.
    if args['env'] == 'gridworld':
        return args['env_size'] ** 2

    env_name, task_name = args['env'], args['task']
    n_bins = args['n_state_bins'] if bins is None else bins
    flat_shape = utils.flatten_spec_shape(DOMAINS[env_name][task_name])
    return n_bins ** flat_shape[0]

def make_visited_df(name, subdir='exploration', bins=None):
    """Build a per-episode table of state-space coverage for a single run.

    Args:
        name: run directory name inside ``results/<subdir>/``.
        subdir: results subdirectory that contains the run.
        bins: bin count used to discretize states. When ``None``, gridworld
            runs are counted without discretization and all other
            environments use the run's own ``args['n_state_bins']``.

    Returns:
        pandas.DataFrame with one row per multiple of ``args['max_steps']``
        timesteps and the columns ``visit_count``, ``timestep``, ``name``,
        ``visit_fraction`` (count divided by ``max_number_of_states``),
        ``episode`` (0, 1, 2, ... in row order) and ``title`` (`name` with
        the ``_seed<digits>`` part removed, so seeds of one configuration
        share a title).
    """
    replay = load_replay(name, subdir)
    args = load_args(name, subdir)
    if bins is None and args['env'] != 'gridworld':
        bins = args['n_state_bins']
    visit_counts = count_states_incremental(replay, args, bins=bins)

    # Prepend a zero so the table starts at timestep 0 with nothing visited.
    df = pd.DataFrame({
        'visit_count': [0] + visit_counts,
        'timestep': list(range(0, len(visit_counts) + 1)),
    })
    df['name'] = name
    df['visit_fraction'] = df['visit_count'] / max_number_of_states(args, bins)

    # Keep only rows that fall on a multiple of `max_steps`. The explicit copy
    # makes the column assignments below write to an independent frame rather
    # than to a slice of the unfiltered one.
    df = df[df.timestep % args['max_steps'] == 0].copy()
    df['episode'] = list(range(0, len(df)))

    # `regex=True` is spelled out because the default of `str.replace`
    # differs between pandas versions; `\d+` strips multi-digit seeds too.
    df['title'] = df['name'].str.replace(r'_seed\d+', '', regex=True)
    return df

def make_visited_dfs(pattern, subdir='exploration', bins=None):
    """Build and concatenate coverage tables for every run matching `pattern`.

    Args:
        pattern: glob pattern matched against entries of ``results/<subdir>/``.
        subdir: results subdirectory to search.
        bins: forwarded to ``make_visited_df`` for every run.

    Returns:
        pandas.DataFrame: the per-run tables concatenated in glob order.

    Raises:
        ValueError: from ``pandas.concat`` when no run could be loaded
            (nothing matched, or every matching run failed).
    """
    jobs = glob.glob(f'results/{subdir}/{pattern}')
    results = []
    for job in jobs:
        name = os.path.basename(os.path.normpath(job))
        try:
            df = make_visited_df(name, subdir, bins)
        except Exception as e:
            # Best effort: report the run that could not be loaded and move
            # on. Skipping is essential -- falling through would append the
            # previous run's table a second time (or hit an undefined `df`
            # if the very first run fails).
            print(f'Skipping {name}: {e}')
            continue
        results.append(df)
    return pd.concat(results, sort=False)

In [14]:
# Use the exploratory theme for the chart built in this cell.
alt.themes.enable('personal')
# Coverage tables for every run directory under results/exploration matching
# either pattern, stacked into one frame with a fresh 0..n-1 index.
# make_visited_dfs prints the error for any run it fails to load.
data = pd.concat([
    make_visited_dfs('arxiv2_pv100_seed*', subdir='exploration'),
    make_visited_dfs('epv100kc*_seed*', subdir='exploration'),
], sort=False).reset_index(drop=True)

# Base chart: the circle marks have size 0 (presumably so they are invisible);
# it carries the encodings that the line layer below reuses. `visit_fraction`
# is averaged per episode within each `title` colour group.
chart = alt.Chart(data).mark_circle(size=0).encode(
    x=alt.X('episode:Q', title='Episode'),
    y=alt.Y('mean(visit_fraction):Q', title='Fraction of states visited'),
    color=alt.Color('title'),
)
# Layer a line mark with the same encodings on top of the base chart.
chart = chart + chart.mark_line(size=3).encode()
# altair_saver.save(chart, 'pv100_visits.pdf', method='node')
# NOTE(review): the episode filter is applied only to the chart displayed as
# this cell's output; `chart` itself is not reassigned, so later cells that
# use `chart` still see every episode -- confirm this is intended.
chart.transform_filter(alt.datum.episode <= 100)

[Errno 2] No such file or directory: 'results/exploration/epv100kc_scale0.1_seed1/replay.pkl'
[Errno 2] No such file or directory: 'results/exploration/epv100kc_scale0.01_seed2/replay.pkl'
[Errno 2] No such file or directory: 'results/exploration/epv100kc_scale0.001_seed3/replay.pkl'


In [32]:
# NOTE(review): this cell rebinds `chart` to a configured copy of whatever
# `chart` currently is in the kernel, so the result depends on which cell
# defined it last. The recorded warnings mention an errorband that the visible
# chart definition does not contain -- confirm which chart was actually saved.
chart = chart.configure(legend={'orient': 'bottom'})
# Write the configured chart to a PDF file using altair_saver's 'node' method.
altair_saver.save(chart, 'gw40_pv100_visits.pdf', method='node')

WARN Continuous axis should not have customized aggregation function mean; errorband already agregates the axis.
WARN Continuous axis should not have customized aggregation function mean; errorband already agregates the axis.


In [46]:
# Coverage tables for the `arxiv2_grid40*` runs, gathered from three results
# subdirectories ('intrinsic', 'slow', 'exploration') and stacked into one
# frame with a fresh index. `bins` is left at its default here.
gw_data = pd.concat([
    make_visited_dfs('arxiv2_grid40*noreward*', subdir='intrinsic'),
    make_visited_dfs('arxiv2_grid40*puniform*', subdir='slow'),
    make_visited_dfs('arxiv2_grid40*puniform*', subdir='exploration'),
], sort=False).reset_index(drop=True)

# Same three subdirectories for the `arxiv2_pv100*` runs. `bins=40` overrides
# each run's own `n_state_bins` when states are discretized and when the size
# of the state space is computed.
pv_data = pd.concat([
    make_visited_dfs('arxiv2_pv100*noreward*', subdir='intrinsic', bins=40),
    make_visited_dfs('arxiv2_pv100*puniform*', subdir='slow', bins=40),
    make_visited_dfs('arxiv2_pv100*puniform*', subdir='exploration', bins=40),
], sort=False).reset_index(drop=True)

In [53]:
def completion_times(data):
    """Average, per algorithm, the first point at which runs reach full coverage.

    Each row is labelled with an algorithm derived from substrings of its
    ``name``. For every run, the minimum of each numeric column is taken over
    the rows where ``visit_fraction == 1.0``; those per-run minima are then
    averaged per algorithm.

    Note that runs which never reach ``visit_fraction == 1.0`` contribute no
    rows and are therefore left out of the averages entirely.

    Args:
        data: DataFrame with at least the columns ``name`` and
            ``visit_fraction`` (as produced by ``make_visited_dfs``). It is
            not modified; the function works on a copy.

    Returns:
        pandas.DataFrame indexed by ``Algorithm`` holding the mean of the
        per-run minima of every numeric column.
    """
    data = data.copy()
    data['Algorithm'] = 'Ours: IR + FP + FA + Optimism'
    # Applied in order: a later entry overrides an earlier one when a run name
    # contains several of these substrings.
    labels_by_marker = (
        ('noopt', 'IR + FP + Fast adaptation'),
        ('slow', 'IR + Factored policies'),
        ('intrinsic', 'Intrinsic reward'),
        ('noexplore', 'No exploration'),
    )
    for marker, label in labels_by_marker:
        data.loc[data['name'].str.contains(marker), 'Algorithm'] = label

    completed = data[data.visit_fraction == 1.0]
    # The first positional parameter of GroupBy.min is `numeric_only`, not a
    # column name, so the flag is passed explicitly instead of relying on a
    # truthy string. Non-numeric columns are dropped here, which is what lets
    # the mean below succeed.
    min_data = completed.groupby(['name', 'Algorithm']).min(numeric_only=True)
    min_per_alg = min_data.groupby(['Algorithm']).mean()
    return min_per_alg

In [54]:
# Per-algorithm averages for the grid40 runs; the returned frame is displayed
# as the cell output.
completion_times(gw_data)

Unnamed: 0_level_0,visit_count,timestep,visit_fraction,episode
Algorithm,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
IR + FP + Fast adaptation,1600.0,56466.666667,1.0,564.666667
IR + Factored policies,1600.0,65500.0,1.0,655.0
Ours: IR + FP + FA + Optimism,1600.0,13575.0,1.0,135.75


In [55]:
# Per-algorithm averages for the pv100 runs (loaded with bins=40); the
# returned frame is displayed as the cell output.
completion_times(pv_data)

Unnamed: 0_level_0,visit_count,timestep,visit_fraction,episode
Algorithm,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
IR + FP + Fast adaptation,1600.0,26675.0,1.0,266.75
IR + Factored policies,1600.0,28400.0,1.0,284.0
Intrinsic reward,1600.0,46875.0,1.0,468.75
Ours: IR + FP + FA + Optimism,1600.0,21825.0,1.0,218.25
