# Gridworld

Script for generating a dataset of pre/post images of actions performed on random gridworld states.

In [None]:
%load_ext autoreload
%autoreload 2

## Configure environment

In [None]:
# Configure SDL/pygame through environment variables before pygame is
# imported and initialized below.
import os
# Select SDL's "dummy" video driver (presumably so no display is required).
os.environ["SDL_VIDEODRIVER"] = "dummy"
# Ask pygame not to print its support prompt on import.
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
import yaml

import pygame
pygame.init()

import symbolic
from config import EnvironmentPaths

# Resolve the gridworld environment's paths and load its PDDL domain/problem.
paths = EnvironmentPaths(environment="gridworld")
pddl = symbolic.Pddl(str(paths.domain_pddl), str(paths.problem_pddl))

# Load the world configuration; `config` is passed to World(...) in later cells.
with open(paths.env / "config.yaml") as f:
    config = yaml.full_load(f)

## Generate dataset

### Define state generation functions

In [None]:
import typing

import numpy as np

from env.gridworld import propositions
from env.gridworld.dataset import LogDatabase
from env.gridworld.propositions import ArgumentTypeError, PropositionValueError
from env.gridworld.world import World

def random_state(pddl, config, state=None, prob_proposition_on=0.05, log=None):
    """Generate random world state.

    Each proposition in ``pddl.state_index`` is independently selected with
    probability ``prob_proposition_on`` and added to a scratch world with
    validation enabled. Propositions rejected by validation are skipped and,
    if they were left in the state, removed again.

    Args:
        pddl (symbolic.Pddl): Pddl instance.
        config (dict): World config.
        state (set(str), optional): Initial state (default None, meaning a fresh empty set).
        prob_proposition_on (double, optional): Probability of turning a proposition on (default 0.05).
        log (text file, optional): Print debug statements to this file (default None).
    Returns:
        (set(str)): Random symbolic state as a set of propositions.
    """
    # Imported locally so the function does not depend on another notebook
    # cell having imported `random` first.
    import random

    # Build a fresh set per call; a `set()` default would be shared by every
    # call that omits `state`.
    if state is None:
        state = set()

    # Scratch world used only to test propositions for consistency
    world_test = World(pddl, config, state, log=log)

    # Iterate over all propositions
    state_index = pddl.state_index
    for i in range(len(state_index)):
        # Keep each proposition with probability prob_proposition_on
        if random.random() > prob_proposition_on:
            continue

        prop = state_index.get_proposition(i)

        # Check prop for consistency
        try:
            world_test.state.add(prop, validate=True)
        except (ArgumentTypeError, PropositionValueError) as e:
            if log is not None:
                log.write(f"! {e}")
            # Try to remove prop if added in the state. This must happen
            # whether or not a log is attached, otherwise a rejected
            # proposition could stay in the state when log is None.
            # Need to check existence before removing because
            # prop may not be constructable with invalid args.
            if prop in world_test.state.stringify():
                try:
                    world_test.state.remove(prop)
                except (KeyError, ValueError) as e_remove:
                    if log is not None:
                        log.write(f"! {e_remove}")

    world_test.place_objects()
    return pddl.derived_state(world_test.state.stringify())

def generate_pre_post(
    pddl: symbolic.Pddl,
    config: typing.Dict,
    action_call: str,
    conj: symbolic.PartialState,
    log: typing.Optional[LogDatabase] = None
) -> typing.Tuple[
    typing.Tuple[np.ndarray, np.ndarray, np.ndarray],
    typing.Tuple[np.ndarray, np.ndarray, np.ndarray],
]:
    """Generate pre/post images of given action performed on a random state.

    The world state is the positive propositions of ``conj`` plus a random
    state with the negative propositions of ``conj`` removed.

    Args:
        pddl: Pddl instance.
        config: World config.
        action_call: Action call.
        conj: Precondition dnf conjunction.
        log: Print debug statements to this file (default None).
    Returns:
        (image [220, 220, 3], state [N], boxes [O, 4]) tuple for pre- and post-conditions each.
    Raises:
        PropositionValueError: If the world state after placing objects
            contains the alias of any proposition in ``conj.neg``.
    """
    # Generate random state
    s_random = random_state(pddl, config, conj.pos, log=log) - conj.neg
    s = conj.pos | s_random
    if log is not None:
        log.write(f"s_random: {s_random}")
        log.write(f"s_combined: {s}")

    # Initialize world
    world = World(pddl, config, s, log=log, validate=True)
    world.place_objects()

    # Check that world did not reintroduce negative propositions
    s_neg = {propositions.alias(prop) for prop in conj.neg}
    s = world.state.stringify()
    if s & s_neg:
        raise PropositionValueError(f"Conflicting propositions:\ns:{s}\ns_neg:{s_neg}")

    # Capture image, indexed state and bounding boxes before the action
    img_pre = world.render()
    s_pre = pddl.state_index.get_indexed_state(world.state.stringify())
    boxes_pre = world.get_bounding_boxes()

    # Apply the action and capture the same three items afterwards
    world.execute(action_call)
    img_post = world.render()
    s_post = pddl.state_index.get_indexed_state(world.state.stringify())
    boxes_post = world.get_bounding_boxes()

    return (img_pre, s_pre, boxes_pre), (img_post, s_post, boxes_post)

### Define IO functions

In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np

class Stdout:
    """Minimal log stand-in that sends every message to standard output.

    Provides only ``write``, which is enough for code that reports through
    ``log.write(...)``.
    """

    def write(self, message) -> None:
        """Print ``message`` on its own line."""
        print(message)

def save_images(log, img_pre, img_post):
    """Write the pre and post images to disk under the current log key.

    Files are named ``<log.key>_pre.png`` and ``<log.key>_post.png`` inside
    ``log.path_images``.

    Args:
        log (env.gridworld.dataset.LogDatabase): Log database.
        img_pre (np.ndarray): Pre image.
        img_post (np.ndarray): Post image.
    """
    path_pre = log.path_images / f"{log.key}_pre.png"
    plt.imsave(path_pre, img_pre)

    path_post = log.path_images / f"{log.key}_post.png"
    plt.imsave(path_post, img_post)
    
def load_images(log, key):
    """Read back the pre and post images stored for a log key.

    Args:
        log (env.gridworld.dataset.LogDatabase): Log database.
        key (int): Log to load.
    Returns:
        (np.ndarray, np.ndarray): Pair of pre, post images
    """
    # The tuple is evaluated left to right, so the pre image is read first.
    return (
        plt.imread(log.path_images / f"{key}_pre.png"),
        plt.imread(log.path_images / f"{key}_post.png"),
    )

def render_images(img_pre, img_post):
    """Show the pre image (left) and post image (right) in one figure.

    Args:
        img_pre (np.ndarray): Pre image.
        img_post (np.ndarray): Post image.
    """
    plt.figure(figsize=(14, 7))
    # 121 / 122: one row, two columns, first and second panel.
    for position, image in ((121, img_pre), (122, img_post)):
        panel = plt.subplot(position)
        panel.imshow(image, interpolation='none')
    plt.show()

def save_variables(log, pddl, config, action_call, conj, debug: bool = True):
    """Store variables for the current log key via ``log.save``.

    With ``debug`` on, the inputs are stored together with the current states
    of the ``random`` and ``np.random`` generators. With ``debug`` off, only
    the action call is stored.

    Args:
        log (env.gridworld.dataset.LogDatabase): Log database.
        pddl (symbolic.Pddl): Pddl instance.
        config (dict): World config.
        action_call (str): Action call.
        conj (symbolic.PartialState): Precondition dnf conjunction.
        debug: Save all variables if true, otherwise save only necessary variables.
    """
    if not debug:
        log.save({"action_call": action_call})
        return

    snapshot = {
        "pddl": pddl,
        "config": config,
        "action_call": action_call,
        "conj": conj,
        "state_random": random.getstate(),
        "state_np_random": np.random.get_state(),
    }
    log.save(snapshot)

def load_variables(log, key, verbose=True):
    """Restore the variables and random generator states saved at a log key.

    The entry is read with ``log.load``; ``verbose`` is forwarded to it.
    The entry must contain the variables and generator states that
    ``save_variables`` stores in debug mode, since all of them are looked up.

    Args:
        log (env.gridworld.dataset.LogDatabase): Log database.
        key (int): Log to load.
        verbose (bool, optional): Whether to print the log (default True).
    Returns:
        (pddl, config, action_call, conj): Tuple of saved variables.
    """
    saved = log.load(key, verbose=verbose)

    # Look up every returned variable before touching the generators.
    restored = tuple(saved[name] for name in ("pddl", "config", "action_call", "conj"))

    # Put both random generators back into their saved states.
    random.setstate(saved["state_random"])
    np.random.set_state(saved["state_np_random"])

    return restored

### Generate dataset

Logs, images, and variables are saved in `data/gridworld`.

In [None]:
import traceback

import tqdm.notebook

from gpred import dnf_utils

def generate_dataset(size_dataset: int = 10000, debug: bool = False):
    """Generates dataset by saving png images and logging variables to disk.

    Repeatedly sweeps over every action, every parameter combination and every
    precondition conjunction, attempting one pre/post sample per conjunction.
    Each attempt gets its own log key (``idx``), whether or not it succeeds.
    The target is only checked between sweeps, so the final count can exceed
    ``size_dataset``.

    Failed attempts are recorded in ``warnings.log`` (ArgumentTypeError,
    PropositionValueError) or ``exceptions.log`` (anything else) under
    ``paths.data`` and do not count towards the dataset size.

    Args:
        size_dataset: Minimum number of entries in dataset.
        debug: Forwarded to ``save_variables``; if true, all variables and the
            random generator states are saved for every attempt.
    """
    # Create log database for saving data
    log = LogDatabase(path=paths.data)

    # Set random seed
    random.seed(0)
    np.random.seed(0)

    # idx counts attempts (log keys); num_generated counts successful samples
    idx = 0
    num_generated = 0
    loop = tqdm.notebook.tqdm(total=size_dataset)
    while num_generated < size_dataset:
        idx_sweep_start = idx

        # Iterate over all actions
        for action in pddl.actions:
            # Iterate over all parameter combinations for action
            for args in action.parameter_generator:
                action_call = action.to_string(args)

                # Get pre/post conditions
                pre_post = dnf_utils.get_normalized_conditions(pddl, action_call, apply_axioms=True)

                if pre_post is None:
                    # No valid normalized conditions (violated axioms)
                    continue

                # Iterate over precondition conjunctions
                pre, post = pre_post
                for conj in pre.conjunctions:
                    # Initialize log
                    log.key = idx
                    save_variables(log, pddl, config, action_call, conj, debug=debug)
                    log.write(f"{action_call}")
                    log.write("=========================")
                    log.write(f"idx: {idx}")
                    log.write(f"s: {conj.pos}")

                    try:
                        # Generate images
                        (img_pre, s_pre, boxes_pre), (img_post, s_post, boxes_post) = generate_pre_post(pddl, config, action_call, conj, log=log)
                    except (ArgumentTypeError, PropositionValueError) as e:
                        # Expected rejection of an inconsistent random state
                        with open(paths.data / "warnings.log", "a+") as f:
                            f.write(f"{log.key}: {e.__class__.__name__}: {e}\n")
                        log.write(traceback.format_exc())
                    except Exception as e:
                        # Unexpected failure: record it and keep generating
                        tb = traceback.format_exc()
                        with open(paths.data / "exceptions.log", "a+") as f:
                            f.write(f"{log.key}: {e.__class__.__name__}: {e}\n{tb}\n")
                        log.write(tb)
                    else:
                        # Save images
                        save_images(log, img_pre, img_post)

                        # Save state
                        log.save({"s_pre": s_pre, "s_post": s_post, "boxes_pre": boxes_pre, "boxes_post": boxes_post})

                        # Count the sample; advance the progress bar by
                        # successes (not attempts) and never past its total
                        num_generated += 1
                        if num_generated <= size_dataset:
                            loop.update(1)

                    # Increment log index
                    log.commit()
                    idx += 1

        if idx == idx_sweep_start:
            # A sweep that attempted nothing would repeat forever
            print("No precondition conjunctions to sample from; stopping early.")
            break

    # Close loop
    print(f"Generated {num_generated} out of {size_dataset}")
    loop.close()

    # Reset log key
    log.key = "stdout"

In [None]:
# Generate at least 100000 entries (debug variables are off by default).
generate_dataset(100000)

### Save dataset

Take logs, images, and variables from `data/gridworld` and convert them to vector format in `data/gridworld/dataset.hdf5`.

In [None]:
from env.gridworld.dataset import LogDatabase
import tqdm.notebook

# Create log database in case previous cell was not run
log = LogDatabase(path=paths.data)

# Publish the dataset as "dataset.hdf5", passing the notebook progress bar.
log.publish_dataset("dataset.hdf5", tqdm=tqdm.notebook.tqdm)

In [None]:
import collections

import h5py

# Give every dataset entry a 0/1 flag that alternates per action: the first
# entry seen for an action gets 0, the next entry for that action gets 1,
# then 0 again, and so on. The flags are written to "dataset_half.hdf5".
# NOTE(review): presumably 0/1 selects the pre vs. post sample - confirm
# against the consumer of idx_pre_post.
idx_pre_post = collections.defaultdict(lambda: 0)
with h5py.File(paths.data / "dataset.hdf5", "r") as f:
    # One flag per entry in the "actions" dataset
    D = len(f["actions"])
    with h5py.File(paths.data / "dataset_half.hdf5", "w") as f_out:
        dset = f_out.create_dataset("idx_pre_post", (D,), dtype=int)
        for idx_data, action in enumerate(f["actions"]):
            dset[idx_data] = idx_pre_post[action]
            # Toggle the flag for the next occurrence of this action
            idx_pre_post[action] = 1 - idx_pre_post[action]

In [None]:
import numpy as np

# Sanity check: the sum of the 0/1 flags is the number of entries flagged 1.
with h5py.File(paths.data / "dataset_half.hdf5", "r") as f:
    print(np.array(f["idx_pre_post"]).sum())

## Compute image distribution
Comment out the image normalization transform inside the Problem class before running this.

In [None]:
# NOTE(review): `Experiment` and `device` are not defined in the visible
# cells - they must come from an import/cell outside this section.
experiment = Experiment(device)
# Presumably iterating the experiment advances `experiment.problem` for each
# run; `path_run` itself is not used here - confirm against Experiment.
for path_run in experiment:
    problem = experiment.problem
    problem.compute_image_distribution()

# Analyze Dataset Distribution

In [None]:
import math

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import tqdm

from gpred import dnf_utils


def plot_predicate_counts(stats: pd.DataFrame):
    """Draw a count plot with one bar group per predicate, split by label.

    Rows are sorted by the "Predicate" column before plotting and the x tick
    labels are rotated to vertical.

    Args:
        stats: Longform dataframe output by `compute_pddl_statistics()`.
    """
    plt.subplots(figsize=(20, 10))

    sns.set_style("whitegrid")
    ordered = stats.sort_values("Predicate")
    chart = sns.countplot(data=ordered, x="Predicate", hue="Label")
    for tick_label in chart.get_xticklabels():
        tick_label.set_rotation(90)

def plot_dnfs(stats: pd.DataFrame):
    """Plots a heatmap of actions vs. propositions specified by their DNFs.

    The table is pivoted to one row per (Action, Condition) pair and one
    column per Proposition, with the Label cast to float as the cell value.
    Rows are drawn in stacked heatmaps of at most SIZE_SECTION rows each.
    Nothing is drawn if the pivoted table has no rows.

    Args:
        stats: Longform table output by compute_pddl_statistics().
    """
    # Imported locally: a plain `import tqdm` does not load the
    # `tqdm.notebook` submodule used for the progress bar below.
    import tqdm.notebook

    SIZE_SECTION = 10
    CMAP = sns.diverging_palette(10, 130, n=100)

    df_action_v_prop = stats.astype({"Label": float}).pivot(index=["Action", "Condition"], columns="Proposition", values="Label")
    num_rows = len(df_action_v_prop)
    num_sections = math.ceil(num_rows / SIZE_SECTION)
    if num_sections == 0:
        # plt.subplots() rejects a grid with zero rows
        return

    # squeeze=False keeps `axs` two-dimensional even for a single section
    _, axs = plt.subplots(num_sections, 1, figsize=(10, num_sections * 5), squeeze=False)

    for i in tqdm.notebook.tqdm(range(num_sections)):
        # Positional row slice; the end is clamped to the table length
        section = df_action_v_prop.iloc[i * SIZE_SECTION:(i + 1) * SIZE_SECTION]
        sns.heatmap(data=section, ax=axs[i, 0], square=True, cmap=CMAP, linewidths=0.5, linecolor="#eee", cbar_kws={"shrink": 0.5})

In [None]:
import h5py

# Collect the distinct action strings stored in "dataset_mini.hdf5".
# NOTE(review): this reads dataset_mini.hdf5, which none of the visible
# cells create - confirm where it is produced.
with h5py.File(paths.data / "dataset_mini.hdf5", "r") as f:
    # set() removes duplicates; entries are decoded from UTF-8 bytes to str
    actions = [action.decode("utf-8") for action in set(f["actions"])]

# Reload the PDDL and compute statistics for those actions only.
pddl = symbolic.Pddl(str(paths.domain_pddl), str(paths.problem_pddl))
stats = dnf_utils.compute_pddl_statistics(pddl, actions=actions)

plot_predicate_counts(stats)

In [None]:
# Heatmaps of (action, condition) rows vs. propositions for the same stats.
plot_dnfs(stats)

## Run example action skeleton

In [None]:
import matplotlib.pyplot as plt

def render(world):
    """Display the image returned by ``world.render()`` in a 7x7 figure."""
    # To dump the raw canvas instead: pygame.image.save(world.canvas, 'test.png')
    frame = world.render()
    plt.figure(figsize=(7, 7))
    plt.imshow(frame, interpolation='none')
    plt.show()

In [None]:
import random
import numpy as np

from env.gridworld.world import World

# Initialize random seed
# Seed both random generators so the example below is repeatable.
random.seed(0)
np.random.seed(0)

# Create world
world = World(pddl, config)
print(f"Initial state: {world.state}\n")
render(world)

# List valid actions
world.list_valid_actions()

In [None]:
# Step 1: execute goto(door_key), then render the world and list valid actions.
world.execute('goto(door_key)')
render(world)
world.list_valid_actions()

In [None]:
# Step 2: execute pick(door_key, room_a), then render and list valid actions.
world.execute('pick(door_key, room_a)')
render(world)
world.list_valid_actions()

In [None]:
# Step 3: execute goto(door), then render and list valid actions.
world.execute('goto(door)')
render(world)
world.list_valid_actions()

In [None]:
# Step 4: execute unlock(door, door_key), then render and list valid actions.
world.execute('unlock(door, door_key)')
render(world)
world.list_valid_actions()

In [None]:
# Step 5: execute place(door_key, room_a), then render and list valid actions.
world.execute('place(door_key, room_a)')
render(world)
world.list_valid_actions()

In [None]:
# Step 6: execute open(door), then render and list valid actions.
world.execute('open(door)')
render(world)
world.list_valid_actions()

In [None]:
# Step 7: execute goto(chest_key), then render and list valid actions.
world.execute('goto(chest_key)')
render(world)
world.list_valid_actions()

In [None]:
# Step 8: execute pick(chest_key, room_a), then render and list valid actions.
world.execute('pick(chest_key, room_a)')
render(world)
world.list_valid_actions()

In [None]:
# Step 9: execute enter(room_b, door), then render and list valid actions.
world.execute('enter(room_b, door)')
render(world)
world.list_valid_actions()

In [None]:
# Step 10: execute goto(chest), then render and list valid actions.
world.execute('goto(chest)')
render(world)
world.list_valid_actions()

In [None]:
# Step 11: execute unlock(chest, chest_key), then render and list valid actions.
world.execute('unlock(chest, chest_key)')
render(world)
world.list_valid_actions()

In [None]:
# Step 12: execute place(chest_key, room_b), then render and list valid actions.
world.execute('place(chest_key, room_b)')
render(world)
world.list_valid_actions()

In [None]:
# Step 13: execute open(chest), then render and list valid actions.
world.execute('open(chest)')
render(world)
world.list_valid_actions()

In [None]:
# Step 14: execute pick(trophy, chest), then render and list valid actions.
world.execute('pick(trophy, chest)')
render(world)
world.list_valid_actions()

In [None]:
# Query whether the goal is satisfied after the action sequence above.
world.is_goal_satisfied()