# Compute stats for paper (Table 1).

In [1]:
import sys
sys.path.append("..")  # HACK: needed to import generic.py.

In [2]:
import glob
import itertools
from pprint import pprint

import textworld
from textworld import Game

from generic import process_fully_obs_facts, serialize_facts

In [4]:

def filter_commands(commands):
    """ Keep only the commands that are relevant for this project.

    Commands whose verb (first word) is "examine" or "look" are dropped,
    with one exception: "examine cookbook" is always kept.

    Parameters
    ----------
    commands : iterable of str
        Text commands to filter. Each command must contain at least one word.

    Returns
    -------
    list of str
        The retained commands, in their original order.
    """
    ignored_verbs = ("examine", "look")
    return [
        command
        for command in commands
        if command == "examine cookbook" or command.split()[0] not in ignored_verbs
    ]


def get_stats(gamefile):
    """ Compute graph and action-space statistics for a single game.

    The game is loaded, its initial facts are turned into a graph, and the
    game is then played along its walkthrough. The following statistics
    get computed:
        - The density of the graph built from the initial facts.
        - The number of nodes in that graph.
        - The number of action candidates (after `filter_commands`) at the
          initial state and after each step along the walkthrough.

    Parameters
    ----------
    gamefile
        Path of the game to load (forwarded to `textworld.start`).

    Returns
    -------
    density : float
        Graph density (see Notes). It is 0.0 when the graph has fewer than
        two nodes or no relation, since no edge is possible in that case.
    nb_nodes : int
        Number of distinct nodes in the graph.
    candidates : list of int
        Number of filtered admissible commands at each visited state.

    Raises
    ------
    AssertionError
        If following the walkthrough does not end the game with the
        maximum score.

    Notes
    -----
        The formula used to compute the graph density is
        D = |E| / ( |R| * |V| * (|V|-1) )
    """
    env_infos = textworld.EnvInfos(facts=True, game=True, max_score=True,
                                   admissible_commands=True, extras=["walkthrough"])
    env = textworld.start(gamefile, env_infos)
    try:
        infos = env.reset()

        facts = process_fully_obs_facts(infos["game"], infos["facts"])
        triplets = serialize_facts(facts)

        # The first two entries of a triplet are nodes, the last is the relation.
        nodes = {e for t in triplets for e in t[:2]}
        relations = {t[-1] for t in triplets}

        # Guard against degenerate graphs (|V| < 2 or |R| == 0) for which the
        # denominator of the density formula is zero.
        max_nb_edges = len(relations) * len(nodes) * (len(nodes) - 1)
        density = len(triplets) / max_nb_edges if max_nb_edges else 0.0

        # Average actions candidates.
        candidates = [len(filter_commands(infos["admissible_commands"]))]

        # Defaults so that an empty walkthrough fails the checks below
        # instead of raising a NameError.
        score, done = 0, False
        for cmd in infos["extra.walkthrough"]:
            infos, score, done = env.step(cmd)
            candidates.append(len(filter_commands(infos["admissible_commands"])))
    finally:
        # Always release the environment, even if something went wrong.
        env.close()

    # Sanity checks: the walkthrough is expected to complete the game.
    assert score == infos["max_score"], (
        "Walkthrough of {} ended with score {} instead of {}.".format(
            gamefile, score, infos["max_score"]))
    assert done, "Walkthrough of {} did not end the game.".format(gamefile)

    return density, len(nodes), candidates

def get_densities_and_nb_nodes(gamefiles):
    """ Gather the statistics returned by `get_stats` for several games.

    Parameters
    ----------
    gamefiles : iterable
        Game files to process; each one is passed to `get_stats`.

    Returns
    -------
    densities : tuple of float
        Graph density of each game.
    nb_nodes : tuple of int
        Number of graph nodes of each game.
    nb_candidates : list of int
        Number of action candidates at every step of every game, flattened
        into a single list.

    Notes
    -----
        When `gamefiles` is empty, empty collections are returned instead of
        failing while unpacking the (nonexistent) per-game statistics.
    """
    stats = [get_stats(f) for f in gamefiles]
    if not stats:
        # `zip(*[])` yields nothing, so there would be nothing to unpack.
        return (), (), []

    densities, nb_nodes, nb_candidates = zip(*stats)
    nb_candidates = list(itertools.chain.from_iterable(nb_candidates))
    return densities, nb_nodes, nb_candidates


# Print one row of Table 1 per difficulty level, pooling the games of the
# train_100, valid and test subsets.
for i in range(1, 10 + 1):
    densities, nb_nodes, nb_candidates = [], [], []
    for subset in ("train_100", "valid", "test"):
        pattern = "../rl.0.2/{}/difficulty_level_{}/*z8".format(subset, i)
        subset_stats = get_densities_and_nb_nodes(glob.glob(pattern))
        densities.extend(subset_stats[0])
        nb_nodes.extend(subset_stats[1])
        nb_candidates.extend(subset_stats[2])

    # Arithmetic means: per game for density and nodes, per step for candidates.
    avg_density = sum(densities) / len(densities)
    avg_nb_nodes = sum(nb_nodes) / len(nb_nodes)
    avg_nb_candidates = sum(nb_candidates) / len(nb_candidates)

    row = (i, avg_density, avg_nb_nodes, avg_nb_candidates)
    print("{}.\tGraphs Density: {:.1%}\tNodes/Objects: {:.1f}\tAction Candidates: {:.1f}".format(*row))

1.	Graphs Density: 1.5%	Nodes/Objects: 15.9	Action Candidates: 8.0
2.	Graphs Density: 1.4%	Nodes/Objects: 16.7	Action Candidates: 8.9
3.	Graphs Density: 1.3%	Nodes/Objects: 17.1	Action Candidates: 11.5
4.	Graphs Density: 0.6%	Nodes/Objects: 24.8	Action Candidates: 7.6
5.	Graphs Density: 0.4%	Nodes/Objects: 34.1	Action Candidates: 7.2
6.	Graphs Density: 0.4%	Nodes/Objects: 40.4	Action Candidates: 6.6
7.	Graphs Density: 1.3%	Nodes/Objects: 17.5	Action Candidates: 11.8
8.	Graphs Density: 0.5%	Nodes/Objects: 31.0	Action Candidates: 13.8
9.	Graphs Density: 0.5%	Nodes/Objects: 33.4	Action Candidates: 28.4
10.	Graphs Density: 0.3%	Nodes/Objects: 49.5	Action Candidates: 20.2
