In [1]:
from glob import glob
import textworld

In [2]:
filenames_train = glob("./games/train/*.json")
filenames_valid = glob("./games/valid/*.json")
filenames_test = glob("./games/test/*.json")

In [3]:
# filter games (keep only one game per skillset)
names = {}
for filename in filenames_valid:
    prefix, seed = filename[:-5].rsplit("-", 1)
    names[prefix] = seed

filenames_valid = [name + "-" + seed + ".json" for name, seed in names.items()]
len(filenames_valid)

222

In [18]:
for filename in filenames_valid:
    print(filename)

./games/valid/tw-cooking-recipe3+take3+open+drop-M3OYhOPxSWyJHyY1.json
./games/valid/tw-cooking-recipe3+cook+open+go12-7WDjCN08tjabF5Qg.json
./games/valid/tw-cooking-recipe1+cook+cut+open+drop+go6-rXmRCdv0Ia1YfpYG.json
./games/valid/tw-cooking-recipe1+take1+cook+cut+open+drop-p35nFk2NSJVqiV5L.json
./games/valid/tw-cooking-recipe3+take3+cut+go6-m67BIYnEUElZI3Ey.json
./games/valid/tw-cooking-recipe3+take3+cook+cut-bRoRtWeZho95hq8o.json
./games/valid/tw-cooking-recipe2+cut-pY7eFyRjs3lPhZb3.json
./games/valid/tw-cooking-recipe2+go12-v3gjIWOOUr9Rc5eE.json
./games/valid/tw-cooking-recipe1+take1+cut+open-2jOyfByjFOVnhZ89.json
./games/valid/tw-cooking-recipe2+cook+cut-P0J2c1D5UXeQF80W.json
./games/valid/tw-cooking-recipe2+cut+open+drop+go9-mNyji3a8tPLgFnaa.json
./games/valid/tw-cooking-recipe1+take1+open+drop-N0rRC7LgTmELf39g.json
./games/valid/tw-cooking-recipe3+take3+cut+open+drop-E6KPFabXUr3pU0va.json
./games/valid/tw-cooking-recipe3+go12-QP7RT6gDFWO7iO8m.json
./games/valid/tw-cooking-recip

In [4]:
def load_games(filenames):
    return [textworld.Game.load(filename) for filename in filenames]

games_train = load_games(filenames_train)
games_valid = load_games(filenames_valid)
games_test = load_games(filenames_test)

In [5]:
def get_ingredients(games):
    return [tuple(ingredient) for game in games for ingredient in game.metadata["ingredients"]]

ingredients_train = get_ingredients(games_train)
ingredients_valid = get_ingredients(games_valid)
ingredients_test = get_ingredients(games_test)

In [None]:
from textworld.challenges.cooking import FOODS_SPLITS

def check_all_food_split(dataset, split):
    to_find = list(FOODS_SPLITS[split])
    print(to_find)
    for game in dataset:
        for food, cook, cut in game.metadata["ingredients"]:
            if food in to_find:
                to_find.remove(food)
                
    print(to_find)
    assert len(to_find) == 0
    
check_all_food_split(games_valid, "valid")
check_all_food_split(games_valid, "train")

In [15]:
from collections import defaultdict
from textworld.challenges.cooking import FOOD_PREPARATIONS_SPLITS

def count_food_preparation_in_split(dataset, split):
    to_find = [(food,) + preparation
               for food, preparations in FOOD_PREPARATIONS_SPLITS[split].items()
               for preparation in preparations]
    
    total = len(to_find)
    for game in dataset:
        for food, cook, cut in game.metadata["ingredients"]:
            preparation = (food, cook, cut)
            if preparation in to_find:
                to_find.remove(preparation)
    
    found = total - len(to_find)
    print("{} out of {} ({:.1%})".format(found, total, found/total))
    #print(to_find)

count_food_preparation_in_split(games_valid, "valid")
count_food_preparation_in_split(games_valid, "train")

        

56 out of 57 (98.2%)
5 out of 120 (4.2%)


In [None]:
import textwrap
import numpy as np
import matplotlib.pyplot as plt
from matplotlib_venn import venn3

def view_dataset(train, valid, test, titles=["train", "valid", "test"]):
    
    def build_formatter_fct():
        cpt = [0]
        def _fct(arg=None):
            if cpt[0] == 0:
                words = set(train) - set(valid) - set(test)
            elif cpt[0] == 1:
                words = set(valid) - set(train) - set(test)
            elif cpt[0] == 2:
                words = (set(train) & set(valid)) - set(test)
            elif cpt[0] == 3:
                words = set(test) - set(valid) - set(train)
            elif cpt[0] == 4:
                words = (set(test) & set(train)) - set(valid)
            elif cpt[0] == 5:
                words = (set(test) & set(valid)) - set(train)
            elif cpt[0] == 6:
                words = set(train) & set(valid) & set(test)
            else:
                return None

            cpt[0] += 1
            return "\n".join(sorted(words, key=lambda s: s.split()[::-1]))
            
        return _fct

    fct = build_formatter_fct()
    subsets = [fct() for _ in range(7)]
    sizes = [len(subset)**0.5 for subset in subsets]

    plt.figure(figsize=(16*1.5,9*1.5))
    venn = venn3(sizes, titles, subset_label_formatter=build_formatter_fct())
    plt.show()
    

In [None]:
# Check food items splits of the datasets.
view_dataset(set(food for food, _, _ in ingredients_train),
             set(food for food, _, _ in ingredients_valid),
             set(food for food, _, _ in ingredients_test))

In [None]:
# Check for missing food items in all datasets.
from textworld.challenges.cooking import FOODS
print(set(FOODS)
      - set(food for food, _, _ in ingredients_train)
      - set(food for food, _, _ in ingredients_valid)
      - set(food for food, _, _ in ingredients_test))

In [None]:
titles = ("sliced", "diced", "chopped")
view_dataset([ingredient[0] for ingredient in ingredients_train if ingredient[2] == "sliced"],
             [ingredient[0] for ingredient in ingredients_train if ingredient[2] == "diced"],
             [ingredient[0] for ingredient in ingredients_train if ingredient[2] == "chopped"],
             titles)

view_dataset([ingredient[0] for ingredient in ingredients_valid if ingredient[2] == "sliced"],
             [ingredient[0] for ingredient in ingredients_valid if ingredient[2] == "diced"],
             [ingredient[0] for ingredient in ingredients_valid if ingredient[2] == "chopped"],
             titles)


view_dataset([ingredient[0] for ingredient in ingredients_test if ingredient[2] == "sliced"],
             [ingredient[0] for ingredient in ingredients_test if ingredient[2] == "diced"],
             [ingredient[0] for ingredient in ingredients_test if ingredient[2] == "chopped"],
             titles)

In [None]:
import textwrap
import numpy as np
import matplotlib.pyplot as plt
from matplotlib_venn import venn3

def view_dataset(train, valid, test=[]):
    
    def build_formatter_fct():
        cpt = [0]
        def _fct(arg=None):
            if cpt[0] == 0:
                words = set(train) - set(valid) - set(test)
            elif cpt[0] == 1:
                words = set(valid) - set(train) - set(test)
            elif cpt[0] == 2:
                words = (set(train) & set(valid)) - set(test)
            elif cpt[0] == 3:
                words = set(test) - set(valid) - set(train)
            elif cpt[0] == 4:
                words = (set(test) & set(train)) - set(valid)
            elif cpt[0] == 5:
                words = (set(test) & set(valid)) - set(train)
            elif cpt[0] == 6:
                words = set(train) & set(valid) & set(test)
            else:
                return None

            cpt[0] += 1
            return "\n".join(sorted(words, key=lambda s: s.split()[::-1]))
            
        return _fct

    fct = build_formatter_fct()
    subsets = [fct() for _ in range(7)]
    sizes = [len(subset)**0.5 for subset in subsets]

    plt.figure(figsize=(16*1.5,9*1.5))
    venn = venn3(sizes, ("train", "valid", "test"), subset_label_formatter=build_formatter_fct())
    plt.show()

foods_train = set(food for food, _, _ in ingredients_train)
foods_valid = set(food for food, _, _ in ingredients_valid)
common_food = foods_train & foods_valid

for food in common_food:
    view_dataset([ingredient for ingredient in ingredients_train if ingredient[0] == food],
                 [ingredient for ingredient in ingredients_valid if ingredient[0] == food])