In [1]:
# `display` and `HTML` are public in IPython.display; importing `display` from
# IPython.core.display is deprecated since IPython 7.14.
from IPython.display import display, HTML
# Stretch the notebook's `.container` element to the full window width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [1]:
import pandas as pd
import numpy as np
import altair as alt
import altair_saver
import glob
import os
# Lift Altair's default cap on the number of data rows embedded in a chart;
# the trailing ';' hides the registry object this call returns.
alt.data_transformers.disable_max_rows();

DataTransformerRegistry.enable('default')

In [12]:
def personal():
    """Altair theme for exploratory plots.

    Returns a theme dict with a 600x400 (width x height) view, the ``set2``
    color scheme for both category and ordinal ranges, legend labels with no
    length limit, and a white chart background.
    """
    view = {'height': 400, 'width': 600}
    ranges = {kind: {'scheme': 'set2'} for kind in ('category', 'ordinal')}
    # A limit of 0 removes the pixel-width cap, so labels are never truncated.
    legend = {'labelLimit': 0}
    config = {
        'view': view,
        'range': ranges,
        'legend': legend,
        'background': 'white',
    }
    return {'config': config}

def publication():
    """Altair theme for paper figures.

    Large title/label fonts, thick axis domains and ticks, no gridlines, a
    translucent white legend in the top-left corner, thick lines, the ``set2``
    color scheme, and a white chart background.

    Returns:
        dict: ``{'config': {...}}`` suitable for ``alt.themes.register``.
    """
    colorscheme = 'set2'
    # CSS hex colors need the leading '#'; bare '333' / '999' are not valid colors.
    stroke_color = '#333'
    rule_color = '#999'
    title_size = 24
    label_size = 20
    line_width = 5

    return {
        'config': {
            'view': {
                'height': 500,
                'width': 600,
                'strokeWidth': 0,  # no border around the plotting area
            },
            'title': {
                'fontSize': title_size,
            },
            'range': {
                'category': {'scheme': colorscheme},
                'ordinal': {'scheme': colorscheme},
            },
            'axis': {
                'titleFontSize': title_size,
                'labelFontSize': label_size,
                'grid': False,
                'domainWidth': 5,
                'domainColor': stroke_color,
                'tickWidth': 3,
                'tickSize': 9,
                'tickCount': 4,
                'tickColor': stroke_color,
                'tickOffset': 0,
            },
            'legend': {
                'titleFontSize': title_size,
                'labelFontSize': label_size,
                # A limit of 0 removes the pixel-width cap on labels and titles.
                'labelLimit': 0,
                'titleLimit': 0,
                'orient': 'top-left',
                'titlePadding': 10,
                'fillColor': '#ffffff88',  # RGBA hex: partially transparent white
                'cornerRadius': 0,
            },
            'rule': {
                'size': 3,
                'color': rule_color,
            },
            'line': {
                'size': line_width,
            },
            # The chart background is a top-level config property (as in
            # personal()); Vega-Lite's view config has no `background` key.
            'background': 'white',
        }
    }

# Make both themes selectable by name and activate 'publication' for every
# chart rendered after this cell. The trailing ';' hides the registry repr.
alt.themes.register('personal', personal)
alt.themes.register('publication', publication)
alt.themes.enable('publication');

ThemeRegistry.enable('publication')

In [3]:
import json

import replay_buffer
import utils
import jax_specs
from observation_domains import DOMAINS


def load_replay(name, subdir='exploration'):
    """Load the replay buffer stored for run ``name`` under ``results/<subdir>/``.

    Delegates to ``replay_buffer.load`` on ``results/<subdir>/<name>/replay.pkl``.
    NOTE(review): a .pkl file is presumably unpickled — only load trusted results.
    """
    replay_path = f'results/{subdir}/{name}/replay.pkl'
    return replay_buffer.load(replay_path)

def load_args(name, subdir='exploration'):
    """Return the parsed ``args.json`` of run ``name`` under ``results/<subdir>/``."""
    args_path = f'results/{subdir}/{name}/args.json'
    with open(args_path) as args_file:
        return json.load(args_file)

def count_states(replay, args, bins=None):
    """Return the number of distinct states stored in ``replay``.

    Only the first ``replay.length`` rows of ``replay.s`` are counted, matching
    ``count_states_incremental``. Rows past that point are presumably unused
    buffer capacity and would otherwise add a spurious state. If ``replay`` has
    no ``length`` attribute, every row is counted.

    Args:
        replay: object whose ``s`` attribute holds one state per row.
        args: run arguments; only read (``env``/``task``) when ``bins`` is given.
        bins: if not None, states are discretized with ``discretize_states``
            using this many bins before counting.

    Returns:
        int: number of unique rows.
    """
    states = replay.s
    if bins is not None:
        states = discretize_states(states, args, bins)
    length = getattr(replay, 'length', None)
    if length is not None:
        states = states[:length]
    return len(np.unique(states, axis=0))

def discretize_states(states, args, bins):
    """Discretize a batch of raw observations for the run described by ``args``.

    Looks up the observation spec for ``args['env']`` / ``args['task']`` in
    ``DOMAINS``, flattens it with ``utils.flatten_observation_spec``, converts it
    with ``jax_specs.convert_dm_spec``, and passes it to
    ``utils.discretize_observation`` with ``bins`` and ``preserve_batch=True``
    (presumably keeping the leading batch axis — confirm in utils).
    """
    env_name, task_name = args['env'], args['task']
    observation_spec = DOMAINS[env_name][task_name]
    flat_jax_spec = jax_specs.convert_dm_spec(
        utils.flatten_observation_spec(observation_spec))
    return utils.discretize_observation(
        states, flat_jax_spec, bins=bins, preserve_batch=True)

def count_states_incremental(replay, args, bins=None):
    """Return the running number of distinct states after each stored transition.

    Element ``i`` of the result is the number of distinct states among rows
    ``0..i`` of ``replay.s``. Only the first ``replay.length`` rows are
    considered; later rows are presumably unfilled buffer capacity.

    Args:
        replay: object with ``s`` (one state per row) and ``length`` attributes.
        args: run arguments; only read (``env``/``task``) when ``bins`` is given.
        bins: if not None, states are discretized with ``discretize_states``
            using this many bins before counting.

    Returns:
        list[int] of length ``min(replay.length, len(replay.s))``.
    """
    states = replay.s
    if bins is not None:
        states = discretize_states(states, args, bins)

    visited = set()
    visit_counts = []
    # Slicing replaces a manual `break` at i + 1 == replay.length, which never
    # fired for an empty buffer (length 0) and so counted every row.
    for state in states[:replay.length]:
        visited.add(tuple(state))  # rows are arrays; tuples are hashable
        visit_counts.append(len(visited))
    return visit_counts

In [4]:
def max_number_of_states(args):
    """Return the size of the (discretized) state space for the run in ``args``.

    Gridworld: ``env_size ** 2``. Otherwise: ``n_state_bins`` raised to the
    first entry of ``utils.flatten_spec_shape`` for the env/task spec in
    ``DOMAINS`` (the number of flattened state dimensions).
    """
    env_name = args['env']
    if env_name == 'gridworld':
        return args['env_size'] ** 2
    task_name = args['task']
    bins_per_dim = args['n_state_bins']
    state_dims = utils.flatten_spec_shape(DOMAINS[env_name][task_name])[0]
    return bins_per_dim ** state_dims

def make_visited_df(name, subdir='exploration'):
    """Build a per-episode table of state-space coverage for run ``name``.

    Columns:
        visit_count: distinct states seen after ``timestep`` transitions.
        timestep: number of transitions so far.
        name: the run name.
        visit_fraction: ``visit_count / max_number_of_states(args)``.
        episode: 0, 1, 2, ... over the kept rows.

    Only rows whose timestep is a multiple of ``args['max_steps']`` are kept
    (presumably one row per fixed-length episode boundary). The original integer
    index, which equals the timestep, is preserved.
    """
    replay = load_replay(name, subdir)
    args = load_args(name, subdir)
    # Gridworld states are counted as-is; other envs are binned first.
    bins = None if args['env'] == 'gridworld' else args['n_state_bins']
    visit_counts = count_states_incremental(replay, args, bins=bins)
    # Prepend 0 so row t holds the count after t transitions.
    df = pd.DataFrame({'visit_count': [0] + visit_counts,
                       'timestep': list(range(len(visit_counts) + 1))})
    df['name'] = name
    df['visit_fraction'] = df['visit_count'] / max_number_of_states(args)
    episode_ends = df[df['timestep'] % args['max_steps'] == 0]
    # .assign returns a new frame, avoiding the SettingWithCopyWarning raised by
    # adding a column to a boolean-mask slice.
    return episode_ends.assign(episode=list(range(len(episode_ends))))

def make_visited_dfs(pattern, subdir='exploration'):
    """Concatenate ``make_visited_df`` for every run directory matching ``pattern``.

    Args:
        pattern: glob pattern for run names under ``results/<subdir>/``.
        subdir: results subdirectory to search.

    Returns:
        pd.DataFrame: the per-run frames stacked in sorted run-path order
        (original indices kept).

    Raises:
        ValueError: if no run directory matches.
    """
    # Sort for a deterministic row order; skip stray files that match the pattern.
    run_dirs = sorted(path for path in glob.glob(f'results/{subdir}/{pattern}')
                      if os.path.isdir(path))
    if not run_dirs:
        raise ValueError(f'No run directories match results/{subdir}/{pattern}')
    frames = [make_visited_df(os.path.basename(os.path.normpath(run_dir)), subdir)
              for run_dir in run_dirs]
    return pd.concat(frames, sort=False)

In [17]:
# Load one grid20 run for interactive inspection (the later cells shown here
# reload runs through make_visited_df instead of using these globals).
replay, args = (
    load_replay('arxiv_grid20_puniform_seed0', subdir='exploration'),
    load_args('arxiv_grid20_puniform_seed0', subdir='exploration'),
)

In [45]:
# Per-episode coverage table for the same run.
make_visited_df(name='arxiv_grid20_puniform_seed0', subdir='exploration')

Unnamed: 0,visit_count,timestep,name,visit_fraction,episode
0,0,0,arxiv_grid20_puniform_seed0,0.0000,0
100,57,100,arxiv_grid20_puniform_seed0,0.1425,1
200,91,200,arxiv_grid20_puniform_seed0,0.2275,2
300,114,300,arxiv_grid20_puniform_seed0,0.2850,3
400,138,400,arxiv_grid20_puniform_seed0,0.3450,4
...,...,...,...,...,...
19600,400,19600,arxiv_grid20_puniform_seed0,1.0000,196
19700,400,19700,arxiv_grid20_puniform_seed0,1.0000,197
19800,400,19800,arxiv_grid20_puniform_seed0,1.0000,198
19900,400,19900,arxiv_grid20_puniform_seed0,1.0000,199


In [14]:
# Coverage curves for every grid40 run across the three results subdirectories.
data = pd.concat(
    [make_visited_dfs('arxiv_grid40*', subdir=subdir)
     for subdir in ('intrinsic', 'slow', 'exploration')],
    sort=False,
).reset_index(drop=True)

# Map run names to legend labels. Rules are applied in order and later matches
# override earlier ones; names matching no rule are labelled as our full method.
# NOTE(review): labels come from the run *name*, not the subdir — confirm every
# run name carries the expected substring.
label_rules = [
    ('noopt', 'IR + FP + Fast adaptation'),
    ('slow', 'IR + Factored policies'),
    ('intrinsic', 'Intrinsic reward'),
    ('noexplore', 'No exploration'),
]
data['Algorithm'] = 'Ours: IR + FP + FA + Optimism'
for substring, label in label_rules:
    data.loc[data['name'].str.contains(substring), 'Algorithm'] = label

# Legend / color order.
algorithms = [
    'No exploration',
    'Intrinsic reward',
    'IR + Factored policies',
    'IR + FP + Fast adaptation',
    'Ours: IR + FP + FA + Optimism'
]

# NOTE(review): size=0 makes these circles invisible — presumably kept only for
# the legend symbols; confirm.
points = alt.Chart(data).mark_circle(size=0).encode(
    x=alt.X('episode:Q', title='Episode'),
    y=alt.Y('visit_fraction:Q', title='Fraction of states visited'),
    color=alt.Color('Algorithm', scale=alt.Scale(domain=algorithms)),
)
# detail='name' draws one line per run; without it, all runs sharing an
# Algorithm are joined into a single polyline that zigzags between seeds.
chart = points + points.mark_line(size=5).encode(detail='name:N')
altair_saver.save(chart, 'grid40_visits.pdf', method='node')
chart