# TODO

Ambiguous continents:
  - Timor-Leste (OC, AS)
  - Armenia, Azerbaijan, Georgia (EU, AS)
  - Trinidad and Tobago (NA, SA)

Adding ambiguous continents should not create new cells!

In [None]:
# Side length of the square game board (number of rows and of columns).
# Also used below as the minimum number of countries a category set must hold.
FIELD_SIZE = 3
# A cell (row/column intersection) is only kept when its number of
# solutions lies within these bounds (inclusive).
MIN_CELL_SIZE = 1
MAX_CELL_SIZE = 10

In [None]:
import json
import pandas as pd

# Import country data

In [None]:
# Load the pre-processed country records; the context manager closes the file.
with open("local/countries_processed.json", encoding="utf-8") as countries_file:
    data = json.load(countries_file)
df = pd.DataFrame(data)

# Filter & rename columns
df.columns = ['iso', 'iso3', 'iso_numeric', 'fips', 'name', 'capital',
              'area_km2', 'population', 'continent', 'tld',
              'currency_code', 'currency_name', 'phone', 'zip_format', 'zip_regex',
              'languages', 'geonameid', 'neighbors', 'eq_fips', 'parent', 'territories', 'neighbors_t']
subset = ['iso', 'name', 'capital', 'continent',
          'area_km2', 'population',
          'currency_code', 'currency_name', 'languages',
          'territories', 'neighbors_t']
# Select and rename in one chain so `df` is a fresh frame, not a mutated slice.
df = df[subset].rename(columns={"neighbors_t": "neighbors"})

# Additional columns & global fixes
# North America fix: missing continent values are set to "NA"
# (presumably the code "NA" was read as a missing value upstream -- TODO confirm).
# Assigning the result back avoids relying on an in-place update of a column view.
df["continent"] = df["continent"].fillna("NA")

# ISO codes of landlocked countries. ZM (Zambia) replaces ZA (South Africa),
# which has a coastline; KG, TM and UZ were missing from the list.
landlocked_isos = [
    "AF", "AD", "AM", "AT", "AZ", "BY", "BT", "BO", "BW", "BF", "BI", "CF",
    "TD", "CZ", "SZ", "ET", "HU", "KZ", "XK", "KG", "LA", "LS", "LI", "LU",
    "MW", "ML", "MD", "MN", "NP", "NE", "MK", "PY", "RW", "SM", "RS", "SK",
    "SS", "CH", "TJ", "TM", "UG", "UZ", "VA", "ZM", "ZW",
]
df["landlocked"] = df["iso"].isin(landlocked_isos)
# A country without any listed neighbor is treated as an island.
df["island"] = df["neighbors"].apply(len) == 0

def add_alternative_value(df, col, country, *values):
    """Register alternative values for one country in column ``col``.

    Alternatives are stored as a list in the companion column ``<col>_alt``,
    which is inserted right after ``col`` on first use. ``country`` is looked
    up by name first, then by ISO code. Each entry of ``values`` may be a
    single value or a list of values.

    If the current main value is among ``values`` it is not stored as an
    alternative of itself: when it is the first value nothing changes, when
    it is not, the first value becomes the new main value and the previous
    main value ends up among the alternatives.

    Returns True on success, False if the column or the country is unknown.
    """
    columns = list(df.columns)
    if col not in columns:
        return False

    # Create the companion column (one fresh empty list per row) if needed
    altcol = col + "_alt"
    if altcol not in columns:
        empty_lists = df[col].apply(lambda x: [])
        df.insert(columns.index(col) + 1, altcol, empty_lists)

    # Resolve the row: match by name first, fall back to the ISO code
    for lookup in ("name", "iso"):
        if country in df[lookup].values:
            index = df.index[df[lookup] == country][0]
            break
    else:
        print(f"country {country} not found!")
        return False

    # Flatten: every argument is either a single value or a list of values
    candidates = []
    for item in values:
        if isinstance(item, list):
            candidates.extend(item)
        else:
            candidates.append(item)

    # The main value must not be listed as its own alternative
    current = df.loc[index, col]
    if current in candidates:
        if candidates[0] == current:
            print(f"{country}/{col}: '{current}' is already set as main value - skipping")
        else:
            print(f"{country}/{col}: '{current}' is already set as main value - swapping with '{candidates[0]}'")
            df.loc[index, col] = candidates[0]
        candidates = candidates[1:]

    # Append in place to the list stored in the cell, skipping duplicates
    for val in candidates:
        alternatives = df.loc[index, altcol]
        if val not in alternatives:
            alternatives.append(val)
    return True

# Individual fixes
df.loc[df["name"] == "Palau", "capital"] = "Ngerulmud"
df.loc[df["iso"] == "PS", "name"] = "Palestine"

# Alternative values
# Names
add_alternative_value(df, "name", "CI", "Ivory Coast", "Côte d'Ivoire")
add_alternative_value(df, "name", "TR", "Türkiye", "Turkey")
add_alternative_value(df, "name", "VA", "Vatican", "Vatican City")

# Multiple continents (source: https://en.wikipedia.org/wiki/List_of_transcontinental_countries)
add_alternative_value(df, "continent", "Armenia", "AS", "EU")
add_alternative_value(df, "continent", "Georgia", "AS", "EU")
add_alternative_value(df, "continent", "Azerbaijan", "AS", "EU")
add_alternative_value(df, "continent", "Trinidad and Tobago", "SA", "NA")
add_alternative_value(df, "continent", "Panama", "NA", "SA")
add_alternative_value(df, "continent", "Egypt", "AF", "AS")
add_alternative_value(df, "continent", "Russia", "EU", "AS")
add_alternative_value(df, "continent", "TR", "AS", "EU")
add_alternative_value(df, "continent", "Timor Leste", "AS", "OC")

# Multiple/unclear capital (source: https://en.wikipedia.org/wiki/List_of_countries_with_multiple_capitals)
add_alternative_value(df, "capital", "Kazakhstan", "Astana", "Nur-Sultan")
add_alternative_value(df, "capital", "Bolivia", "La Paz", "Sucre")
add_alternative_value(df, "capital", "Burundi", "Gitega", "Bujumbura")
add_alternative_value(df, "capital", "CI", "Yamoussoukro", "Abidjan")
add_alternative_value(df, "capital", "Eswatini", "Mbabane", "Lobamba")
add_alternative_value(df, "capital", "Malaysia", "Kuala Lumpur", "Putrajaya")
add_alternative_value(df, "capital", "Netherlands", "Amsterdam", "The Hague")
add_alternative_value(df, "capital", "Palestine", "Ramallah", "Jerusalem", "East Jerusalem")
add_alternative_value(df, "capital", "South Africa", "Pretoria", "Cape Town", "Bloemfontein")
add_alternative_value(df, "capital", "Sri Lanka", "Colombo", "Sri Jayawardenepura Kotte")

# Capitals with multiple spellings / alternative names
add_alternative_value(df, "capital", "US", "Washington", "Washington, DC")


# Display all changes
altcols = [col for col in df.columns if col.endswith("_alt")]
print("\nAll countries with alternative values:")
df[df[altcols].applymap(len).sum(axis=1) > 0]

# df.reset_index(inplace=True)
# df

In [None]:
# Spot check: show the row with ISO code "PS" (renamed to "Palestine" above)
df[df["iso"] == "PS"]

# Import flag colors

In [None]:
# Assign colors
colors = pd.read_csv("local/flag-colors.csv", sep=";")
colors.columns = ["country", "color"]
# Carry the last seen country forward onto rows that have no country of their
# own (presumably the CSV names the country only on the first row of each
# group -- TODO confirm). Assigning the result back avoids relying on an
# in-place update of a column view.
colors["country"] = colors["country"].ffill()
colors = colors.dropna()
colors = colors.applymap(lambda x: x.strip())


# Collapse color shades onto the base colors used by the game
cmap = {
    "Light Blue": "Blue",
    "Dark Blue": "Blue",
    "Sky Blue": "Blue",
    "Aquamarine Blue": "Blue",
    "Fulvous": "Orange",
    "Crimson": "Red",
    "Saffron Orange": "Orange",
    "Green Or Blue": "Green",
    "Maroon": "Red",  # Qatar, Sri Lanka,
    "Olive Green": "Green",
    "Yellow": "Yellow/Gold",
    "Gold": "Yellow/Gold",
    "Golden": "Yellow/Gold",
}
# Translate CSV country names to the names used in `df`.
# Names mapped to None are dropped by the groupby below.
name_map = {
    'American Samoa': None,
    'Anguilla': None,
    'Antigua And Barbuda': 'Antigua and Barbuda',
    'Aruba': None,
    'Bermuda': None,
    'Bosnia And Herzegovina': 'Bosnia and Herzegovina',
    'Bouvet Island': None,
    'Brunei Darussalam': "Brunei",
    'Cook Islands': None,
    'Curaçao': None,
    "Côte D'Ivoire": 'Ivory Coast',
    'Democratic Republic Of The Congo': 'Democratic Republic of the Congo',
    'French Polynesia': None,
    'Holy See (Vatican City State)': "Vatican",
    'Niue': None,
    'Norfolk Island': None,
#     'Palestine': 'Palestinian Territory',
    'Pitcairn Islands': None,
    'Republic Of The Congo': 'Republic of the Congo',
    'Russian Federation': "Russia",
    'Saint Kitts And Nevis': 'Saint Kitts and Nevis',
    'Saint Vincent And The Grenadines': 'Saint Vincent and the Grenadines',
    'Sao Tome And Principe': 'Sao Tome and Principe',
    'Syrian Arab Republic': "Syria",
    'Tanzania, United Republic Of': "Tanzania",
    'Trinidad And Tobago': 'Trinidad and Tobago',
    'Åland Islands': None,
    'Turkey': 'Türkiye'
}
# Manually added entries (appended to the colors read from the CSV)
add = [
    {"country": "Timor Leste", "color": ["Red", "Yellow/Gold", "Black", "White"]},
    {"country": "Kosovo", "color": ["Blue", "Yellow/Gold", "White"]},
    {"country": "Taiwan", "color": ["Red", "Blue", "White"]}
]

colors["color"] = colors["color"].apply(lambda c: cmap.get(c, c))
colors["country"] = colors["country"].apply(lambda x: name_map.get(x, x))
colors1 = colors.groupby(by="country")["color"].agg(list).reset_index()
# pd.concat replaces DataFrame.append, which was removed in pandas 2.0
colors1 = pd.concat([colors1, pd.DataFrame(add)], ignore_index=True)

colors2 = pd.merge(df[["name"]], colors1, how="left", left_on="name", right_on="country")
# Outer merge with indicator: kept for interactive inspection of unmatched names
colors2_outer = pd.merge(df[["name"]], colors1, how="outer", left_on="name", right_on="country", indicator=True)
df["flag_colors"] = colors2["color"]

# Count countries without colors BEFORE normalising: unmatched rows hold NaN,
# which cannot be passed to set(). They get an empty color list instead.
has_colors = df["flag_colors"].apply(lambda cc: isinstance(cc, list))
no_flag = int((~has_colors).sum())
df["flag_colors"] = df["flag_colors"].apply(lambda cc: list(set(cc)) if isinstance(cc, list) else [])

print(f"Assigned colors. {no_flag} countries missing a flag.")

### Apply flag color fixes

In [None]:
import re

# Grammar of one color-fix spec (every part is optional):
#   "[a, b]"   -> group `main_set`  (bracketed list)
#   "a, b"     -> group `main_add`  (bare list, alternative to the bracketed form)
#   "(a, b)"   -> group `optional`  (single parentheses)
#   "((a, b))" -> group `ignore`    (double parentheses)
# None of the lists may themselves contain parentheses.
cfre = re.compile(r"^(?:(?:\[(?P<main_set>[^\(\)]*)\])|(?:(?P<main_add>[^\(\)]*)))?(?:,?\s*\((?P<optional>[^\(\)]*?)\))?,?\s*(?:\(\((?P<ignore>[^\(\)]*?)\)\))?$")

def parse_color(c):
    """Expand a color abbreviation to its full name; other input is returned unchanged."""
    abbreviations = {
        "Y/G": "Yellow/Gold",
        "R": "Red",
        "W": "White",
        "B": "Blue",
        "Gr": "Green",
        "O": "Orange",
    }
    if c in abbreviations:
        return abbreviations[c]
    return c
    
def parse_fixes(specs):
    """Parse color-fix lines and yield ``(iso, mode, colors)`` tuples.

    The first two characters of each line are the ISO code; the spec starts
    at the fourth character and is matched against ``cfre``. One tuple is
    yielded per non-empty named group of the match, ``mode`` being the group
    name and ``colors`` the comma-separated entries passed through
    ``parse_color``. Lines that do not match are reported via ``print``.
    """
    for line in specs:
        iso, spec = line[:2], line[3:]
        match = cfre.match(spec)
        if match is None:
            print(spec, "no match")
            continue
        for mode, cc in match.groupdict().items():
            if not cc:
                # Group absent or empty: nothing to report for this mode
                continue
            yield (iso, mode, [parse_color(part.strip()) for part in cc.split(",")])

# Read the manual fix list; the context manager closes the file.
with open("local/flag color fixes.txt") as fixes_file:
    color_fixes = fixes_file.read().split("\n")
color_fixes = list(parse_fixes(color_fixes))

# Apply the fixes
for iso, mode, cc in color_fixes:
    if iso not in df["iso"].values:
        print(f"Country {iso} not found.")
        continue
    index = df.index[df["iso"] == iso][0]
    if mode == "main_set":
        # Replace the whole color list
        df.at[index, "flag_colors"] = cc
    elif mode == "main_add":
        # Extend the existing list in place
        for c in cc:
            df.loc[index, "flag_colors"].append(c)
    elif mode == "optional":
        # Optional colors become alternative values (column flag_colors_alt)
        add_alternative_value(df, "flag_colors", iso, *cc)
    # Any other mode (e.g. "ignore") is deliberately not applied

# Normalise: unique, sorted color lists
df["flag_colors"] = df["flag_colors"].apply(lambda cc: list(sorted(set(cc))))
if "flag_colors_alt" in list(df.columns):
    df["flag_colors_alt"] = df["flag_colors_alt"].apply(lambda cc: list(sorted(set(cc))))

changes = set(iso for iso, _, _ in color_fixes)

# flag_colors_alt only exists once an optional fix was applied, so select it
# conditionally instead of failing with a KeyError.
shown_columns = [col for col in ["iso", "name", "flag_colors", "flag_colors_alt"] if col in df.columns]
df[df["iso"].isin(changes)][shown_columns]

In [None]:
# # from IPython.display import Image, display
# from IPython.display import SVG, display

# for row in list(df.iterrows()):
#     country = row[1].to_dict()
#     print(f'{country["name"]} ({country["iso"]}): {", ".join(country["flag_colors"])}')
#     display(SVG(url=f'https://hatscripts.github.io/circle-flags/flags/{country["iso"].lower()}.svg'))
#     print()
# # display(Image(filename=''))

In [None]:
from collections import Counter
# Summing the per-country lists concatenates them; count each color's occurrences
Counter(df["flag_colors"].sum())

# Init Categories

In [None]:
from functools import total_ordering
from typing import Callable

@total_ordering
class Category:
    """A property by which countries are grouped (e.g. continent, flag color).

    Attributes:
        key: identifier of the category; equality, ordering and hashing use it.
        alt_key: ``key + "_alt"``, the name under which alternative values are stored.
        name: human-readable name.
        difficulty: numeric difficulty rating.
        values: per-country main value(s).
        alt_values: per-country alternative values, or None.
        sets, alt_sets: initialised to None; filled in later by the notebook.
    """

    def __init__(self, key: str, name: str, difficulty: float, values: pd.Series, alt_values: pd.Series):
        self.key = key
        self.alt_key = key + "_alt"
        self.name = name
        self.difficulty = difficulty
        self.values = values
        self.alt_values = alt_values
        self.sets = None
        self.alt_sets = None

    def __str__(self):
        # Required because __repr__ delegates to str(): without an explicit
        # __str__, object.__str__ calls __repr__ again and recurses forever.
        return f"{type(self).__name__}('{self.key}')"

    def __repr__(self):
        return str(self)

    def __lt__(self, other):
        if not isinstance(other, Category):
            return NotImplemented
        return self.key < other.key

    def __eq__(self, other):
        # NotImplemented lets Python fall back to its default comparison
        # instead of raising AttributeError for non-Category operands.
        if not isinstance(other, Category):
            return NotImplemented
        return self.key == other.key

    def __hash__(self):
        return hash(self.key)

@total_ordering
class NominalCategory(Category):
    """Category in which each country has exactly one main value.

    The value is taken from column ``col`` of ``df``, optionally transformed
    by ``extractor`` (e.g. first letter of the name). If ``df`` has a column
    ``<col>_alt``, the extracted alternative values are collected per country.
    """

    def __init__(self, df: pd.DataFrame, key: str, name: str, difficulty: float, col: str, extractor: Callable = None):
        self.col = col
        # Keep the caller's extractor (possibly None) on the instance; the
        # local name falls back to the identity function.
        self.extractor = extractor
        if extractor is None:
            extractor = lambda x: x
        values = df[col].apply(extractor)
        altcol = col + "_alt"
        self.alt_col = altcol
        alt_values = None
        if self.alt_col in list(df.columns):
            # NOTE(review): `values` appears to keep the name `col`, so `data`
            # would hold two columns labelled `col` and row[col] would yield
            # both the raw and the extracted main value; difference() then
            # removes both from the extracted alternatives. Confirm before
            # refactoring.
            data = pd.concat([df, values], axis=1)
            alt_values = data.apply(lambda row: list(sorted({extractor(x) for x in row[altcol]}.difference(row[col]))), axis=1)
#                 print(alt_values[alt_values.apply(len) > 0])
        super().__init__(key, name, difficulty, values, alt_values)
        
    def __str__(self):
        # Requires self.sets to have been filled in (it is None right after construction)
        return f"NominalCategory('{self.key}', {len(self.sets)} values)"
    
    # Ordering, equality and hashing are based on the key only
    def __lt__(self, other):
        return self.key < other.key
    
    def __eq__(self, other):
        return self.key == other.key
    
    def __hash__(self):
        return hash(self.key)

@total_ordering
class MultiNominalCategory(Category):
    """Category in which a country can have several values at once (e.g. flag colors).

    Column ``col`` of ``df`` holds one list of values per country. If a column
    ``<col>_alt`` exists, the alternative values are those listed there that
    are not already main values of the same country.
    """

    def __init__(self, df: pd.DataFrame, key: str, name: str, difficulty: float, col: str):
        # Main values: de-duplicated and sorted per country
        values = df[col].apply(lambda items: sorted(set(items)))

        alt_col = col + "_alt"
        if alt_col in list(df.columns):
            alt_values = df.apply(lambda row: sorted(set(row[alt_col]) - set(row[col])), axis=1)
        else:
            alt_values = None
        super().__init__(key, name, difficulty, values, alt_values)

    def __str__(self):
        # Requires self.sets to have been filled in (it is None right after construction)
        return f"MultiNominalCategory('{self.key}', {len(self.sets)} values)"

    # Ordering, equality and hashing are based on the key only
    def __eq__(self, other):
        return self.key == other.key

    def __lt__(self, other):
        return self.key < other.key

    def __hash__(self):
        return hash(self.key)
    
        
# All categories that can label a row or column of the game
nominal_categories = [
    NominalCategory(df, key="continent", name="Continent", difficulty=1, col="continent"),
    NominalCategory(df, key="starting_letter", name="Starting letter", difficulty=1, col="name", extractor=lambda x: x[0].upper()),
    NominalCategory(df, key="ending_letter", name="Ending letter", difficulty=2, col="name", extractor=lambda x: x[-1].upper()),
    NominalCategory(df, key="capital_starting_letter", name="Capital starting letter", difficulty=1.5, col="capital", extractor=lambda x: x[0].upper()),
    NominalCategory(df, key="capital_ending_letter", name="Capital ending letter", difficulty=3, col="capital", extractor=lambda x: x[-1].upper()),
    MultiNominalCategory(df, key="flag_colors", name="Flag color", difficulty=1.5, col="flag_colors"),
]
# Index the categories by key
nominal_categories = {cat.key: cat for cat in nominal_categories}

# One row per country: identifiers, main value per category, alternative values per category
values = pd.concat([
    df[["iso", "name"]],
    pd.DataFrame({cat.key: cat.values for cat in nominal_categories.values()}),
    pd.DataFrame({cat.alt_key: cat.alt_values for cat in nominal_categories.values() if cat.alt_values is not None}),
#     pd.DataFrame(bool_categories)
], axis=1)

# Show countries having any alternative value. Only categories that produced
# alternative values have an *_alt column in `values`, so apply the same
# guard as above instead of risking a KeyError.
alt_keys = [cat.alt_key for cat in nominal_categories.values() if cat.alt_values is not None]
values[values[alt_keys].applymap(len).sum(axis=1) > 0]

## Category Ideas

- Starting/ending with letter
- Capital starting/ending with letter
- Top/Bottom 20 (area/population)
- (dynamic): Bigger/smaller/More/less populated than X
- Island?
- Landlocked?

In [None]:
# bool_categories = {
#     "Island": df["neighbours"].apply(lambda x: not x),
#     "Landlocked": None,
#     "Top 10 Area": df.ISO.isin(df.nlargest(10, 'Area(in sq km)').ISO),
#     "Bottom 10 Area": df.ISO.isin(df.nsmallest(10, 'Area(in sq km)').ISO),
#     "Top 10 Pop.": df.ISO.isin(df.nlargest(10, 'Population').ISO),
#     "Bottom 10 Pop.": df.ISO.isin(df.nsmallest(10, 'Population').ISO),
# }
# values = pd.concat([df[["ISO", "Country"]], pd.DataFrame(nominal_categories), pd.DataFrame(bool_categories)], axis=1)
# values

In [None]:
# Build, per category, a Series mapping each value to the sorted list of ISO
# codes having it (`sets`), and likewise for alternative values (`alt_sets`).
for cat in nominal_categories.values():
    if isinstance(cat, NominalCategory):
        cat.sets = values.groupby(by=cat.key)["iso"].agg(sorted)
        cat.alt_sets = values.explode(column=cat.alt_key).groupby(by=cat.alt_key)["iso"].agg(sorted)
    elif isinstance(cat, MultiNominalCategory):
        # Main values are lists here, so they are exploded before grouping
        cat.sets = values.explode(column=cat.key).groupby(by=cat.key)["iso"].agg(sorted)
        cat.alt_sets = values.explode(column=cat.alt_key).groupby(by=cat.alt_key)["iso"].agg(sorted)

# Prune until stable: dropping small sets can leave countries in fewer than
# two categories, and dropping those countries can shrink sets again.
# Only `sets` is pruned here; `alt_sets` is left untouched.
while True:
    # Retain only sets with at least 3 (FIELD_SIZE) elements
    num_sets_0 = sum(len(cat.sets) for cat in nominal_categories.values())
    for cat in nominal_categories.values():
        cat.sets = cat.sets[cat.sets.apply(len) >= FIELD_SIZE]
    num_sets_1 = sum(len(cat.sets) for cat in nominal_categories.values())
    if num_sets_0 != num_sets_1:
        print(f"Removed {num_sets_0 - num_sets_1} category sets")
    # Retain only countries contained in sets of at least 2 different categories (-> has matching row+column)
    # (summing the Series of lists concatenates them into one list per category)
    category_contents = {cat.key: cat.sets.sum() for cat in nominal_categories.values()}
    # {cat: len(cc) for cat, cc in category_contents.items()}
    contents = set().union(*[set(cc) for cc in category_contents.values()])
    print("contents:", len(contents))
    retain = {c for c in contents if len([key for key, cc in category_contents.items() if c in cc]) >= 2}
    print("retain:", len(retain))
    remove = contents.difference(retain)
    
    for cat in nominal_categories.values():
        cat.sets = cat.sets.apply(lambda cc: [c for c in cc if c in retain])
    if not remove:
        break
    print(f"Removed {len(remove)} countries:", remove)
    print("Repeat ...")

# Show the resulting categories
list(nominal_categories.values())

In [None]:
def compare_sets(set1, set2):
    """Three-way comparison of two collections: -1, 0 or 1.

    The shorter collection is smaller. For equal lengths the sorted lists of
    unique elements are compared.
    """
    size1, size2 = len(set1), len(set2)
    if size1 != size2:
        return (size1 > size2) - (size1 < size2)
    unique1 = sorted(set(set1))
    unique2 = sorted(set(set2))
    return (unique1 > unique2) - (unique1 < unique2)

def min_set(set1, set2):
    """Return the smaller collection according to compare_sets (set1 on a tie)."""
    if compare_sets(set1, set2) > 0:
        return set2
    return set1

def max_set(set1, set2):
    """Return the larger collection according to compare_sets (set1 on a tie)."""
    if compare_sets(set1, set2) < 0:
        return set2
    return set1

In [None]:
# Inspect the category dictionary (key -> category object)
nominal_categories

In [None]:
# dfv = pd.concat([pd.DataFrame({"iso": values["iso"], "category": cat.key, "value": values[cat.key], "alt_values": values[cat.alt_key]}) for cat in nominal_categories.values()], ignore_index=True)
# dfv["category_type"] = dfv["category"].apply(lambda key: nominal_categories[key].__class__.__name__)

# def merge_values(row):
#     if row["category_type"] == "NominalCategory":
#         return [row["value"]] + row["alt_values"]
#     if row["category_type"] == "MultiNominalCategory":
#         return row["value"] + row["alt_values"]
#     return []
# dfv["all_values"] = dfv.apply(merge_values, axis=1)
# dfv

In [None]:
import itertools
import matplotlib.pyplot as plt

# Every remaining (category key, value) pair; each one can label a row or a column
setkeys = [(cat.key, value)
           for cat in nominal_categories.values()
           for value in cat.sets.index]

def is_cell_allowed(key1, value1, key2, value2):
    """Tell whether the sets (key1, value1) and (key2, value2) may cross in a cell.

    Sets of different categories always may. Two sets of the same category
    may only cross when their values differ and the category is a
    MultiNominalCategory.
    """
    if key1 != key2:
        return True
    category = nominal_categories[key1]
    return value1 != value2 and isinstance(category, MultiNominalCategory)

def init_cell_contents(key1, value1, key2, value2, alt=False):
    """Compute the countries (ISO codes) solving the cell of two crossing sets.

    With ``alt=False`` the result is the sorted intersection of the two main
    sets. With ``alt=True`` the result is the sorted list of countries that
    only become solutions when alternative values are taken into account;
    regular solutions are excluded from it.
    """
    cat1, cat2 = nominal_categories[key1], nominal_categories[key2]
    contents = set(cat1.sets[value1]).intersection(cat2.sets[value2])
    if not alt:
        return sorted(contents)
    
    # Solutions caused by alternative values
    # (main set plus alternative set per side; a value without alternative set contributes nothing)
    all1 = set(cat1.sets[value1] + cat1.alt_sets.get(value1, []))
    all2 = set(cat2.sets[value2] + cat2.alt_sets.get(value2, []))
    alt_contents = all1.intersection(all2).difference(contents)
    if not alt_contents:
        return []
    
    # Prevent that two different alternative values are used to create a solution
    # (e.g. Capital starting with P and ending with N -> South Africa - because of [P]retoria and Cape Tow[n])
    if isinstance(cat1, NominalCategory) and isinstance(cat2, NominalCategory):
        # Only relevant when both categories are extracted from the same source column
        if cat1.col == cat2.col and cat1.extractor and cat2.extractor:
            col = cat1.col
            altcol = cat1.alt_col
            dfx = df[df["iso"].isin(alt_contents)][["iso", col, altcol]].copy()
            # All candidate source values of a country: main value followed by its alternatives
            dfx["values"] = dfx.apply(lambda row: [row[col]] + list(row[altcol]), axis=1)
            # Source values that satisfy the first / the second set
            dfx["src1"] = dfx["values"].apply(lambda xx: [x for x in xx if cat1.extractor(x) == value1])
            dfx["src2"] = dfx["values"].apply(lambda xx: [x for x in xx if cat2.extractor(x) == value2])
            # Keep a country only if one and the same source value satisfies both sets
            dfx["keep"] = dfx.apply(lambda row: not set(row["src1"]).isdisjoint(row["src2"]), axis=1)
            alt_contents = dfx[dfx["keep"]]["iso"].tolist()
            
#             print(f"{key1}/{value1} - {key2}/{value2}")
#             display(dfx)
#             print(f"keep {alt_contents}")
    
    return sorted(set(alt_contents))


# One cell per allowed pair of sets. The two sets are stored in a canonical
# order (min_set first), so a pair maps to one key regardless of orientation.
cells = {(min_set(row, col), max_set(row, col)): init_cell_contents(*row, *col)
         for row, col in itertools.combinations(setkeys, 2) if is_cell_allowed(*row, *col)}

print(f"Generated {len(cells)} cells")

# Bring cells to DataFrame to do filtering (cell size etc.)
cell_info = pd.DataFrame([{"row_cat": row[0], "row_val": row[1],
                           "col_cat": col[0], "col_val": col[1],
                           "contents": contents, "alt_contents": init_cell_contents(*row, *col, alt=True)} for (row, col), contents in cells.items()])
cell_info["size"] = cell_info["contents"].apply(len)

# display(cell_info[cell_info["row_cat"] == cell_info["col_cat"]])

# cell_info["size"].value_counts()
cell_info = cell_info[(cell_info["size"] >= MIN_CELL_SIZE) & (cell_info["size"] <= MAX_CELL_SIZE)]
if not cell_info.empty:
    # Bin edges sit at k - 0.5. The final edge must be max + 0.5, otherwise
    # cells of the largest size fall outside the last bin and are not drawn.
    smallest, largest = cell_info["size"].min(), cell_info["size"].max()
    plt.hist(cell_info["size"], rwidth=.9, bins=[x - .5 for x in range(smallest, largest + 2)])
plt.title("Distribution of cell sizes")

# Bring back to dict with tuple access; values are (contents, alt_contents) pairs
cell_keys = cell_info.apply(lambda row: ((row["row_cat"], row["row_val"]), (row["col_cat"], row["col_val"])), axis=1)
cells = {key: (contents, alt_contents) for key, contents, alt_contents in zip(cell_keys, cell_info["contents"], cell_info["alt_contents"])}

print(f"Retained {len(cells)} cells (of size {MIN_CELL_SIZE}-{MAX_CELL_SIZE})")
# cell_info[cell_info["alt_contents"].apply(len) > 0].head(50)

In [None]:
# cell_info[cell_info.apply(lambda row: len({row["row_cat"], row["col_cat"]}.intersection(["capital_starting_letter", "capital_ending_letter"])) == 2, axis=1)]

In [None]:
import json


def get_label(cat: Category, value):
    """Return the display label of a set: the continent's name, or "<category name>: <value>"."""
    if cat.key != "continent":
        return f"{cat.name}: {value}"
    continent_names = {
        "AF": "Africa",
        "EU": "Europe",
        "AS": "Asia",
        "NA": "N. America",
        "SA": "S. America",
        "OC": "Oceania",
    }
    return continent_names[value]


class Game:
    def __init__(self, values, solutions, alt_solutions, rows, cols):
        self.size = FIELD_SIZE
        self.values = values  # All possible values to be guessed (list of dicts)
        self.solutions = solutions  # 3x3 array containing list of possible solutions
        self.alt_solutions = alt_solutions  # 3x3 array containing list of possible alternative solutions
        self.rows = rows  # rows (tuples of form (Category, value))
        self.cols = cols  # columns (as above)
    
    def to_json(self, include_values=False):
        data = {
            "size": self.size,
            "solutions": [[list(cell) for cell in row] for row in self.solutions],
            "alternativeSolutions": [[list(cell) for cell in row] for row in self.alt_solutions],
            "labels": {
                "rows": [get_label(cat, value) for cat, value in self.rows],
                "cols": [get_label(cat, value) for cat, value in self.cols]
            }
        }
        if include_values:
            data["values"] = self.values
        return data
    
    def to_dataframe(self, solution=False):
        game_df = pd.DataFrame(data=[[",".join(c1) + (",(" + ",".join(c2) + ")" if c2 else "") for c1, c2 in zip(row1, row2)] for row1, row2 in zip(self.solutions, self.alt_solutions)] if solution else None,
                               index=[get_label(cat, value) for cat, value in self.rows],
                               columns=[get_label(cat, value) for cat, value in self.cols])
        game_df.fillna("", inplace=True)
        return game_df


def get_solutions(cells, row, col, alt=False):
    """Look up the solutions of the cell where ``row`` and ``col`` cross.

    ``cells`` maps a pair of sets to a ``(solutions, alternative solutions)``
    tuple; the pair is tried in both orders. Returns the alternative solutions
    when ``alt`` is true, and None if the cell does not exist.
    """
    position = 1 if alt else 0
    for key in ((row, col), (col, row)):
        if key in cells:
            return cells[key][position]
    return None


# Sample possible game setups

In [None]:
# ISO code and name of every country: the values a player can guess
country_values = values[["iso", "name"]]

In [None]:
class Constraint:
    """Restriction on how many sets of a game may satisfy a predicate.

    Attributes:
        prop: predicate called as ``prop(key, value)`` for a set
            ``(category key, value)``; returns a truthy value if the set matches.
        num: reference number of matching sets.
        mode: -1: at most ``num`` matching sets. 0: exactly ``num``.
            1 (or anything else): at least ``num``.
    """

    def __init__(self, prop, num, mode):
        self.prop = prop
        self.num = num
        self.mode = mode

    def match(self, key, value):
        """Return whether the set ``(key, value)`` satisfies the predicate."""
        return self.prop(key, value)

    def count(self, sets):
        """Return the number of sets in ``sets`` that satisfy the predicate."""
        return sum(1 for key, value in sets if self.match(key, value))

    def balance(self, sets):
        """Return ``(needs x more, only x more allowed)`` for ``sets``.

        The first entry is None for "at most" constraints, the second is None
        for "at least" constraints.
        """
        remaining = self.num - self.count(sets)
        return (remaining if self.mode >= 0 else None,
                remaining if self.mode <= 0 else None)

    def apply(self, sets):
        """Return whether ``sets`` fulfils the constraint."""
        n = self.count(sets)
        if self.mode == -1:
            return n <= self.num
        if self.mode == 0:
            return n == self.num
        return n >= self.num

    def is_once(self):
        """Return whether this is an "exactly one" constraint."""
        # Fixed: previously read an undefined name `n` instead of self.num
        return self.num == 1 and self.mode == 0

    def is_never(self):
        """Return whether this is an "exactly zero" constraint."""
        # Fixed: previously read an undefined name `n` instead of self.num
        return self.num == 0 and self.mode == 0

    @staticmethod
    def category(key, n, mode):
        """Constraint on the number of sets belonging to category ``key``."""
        return Constraint(lambda k, _: k == key, n, mode)

    @staticmethod
    def exactly(prop, n):
        return Constraint(prop, n, 0)

    @staticmethod
    def at_most(prop, n):
        return Constraint(prop, n, -1)

    @staticmethod
    def at_least(prop, n):
        return Constraint(prop, n, 1)

    @staticmethod
    def category_exactly(key, n):
        return Constraint.category(key, n, 0)

    @staticmethod
    def category_at_most(key, n):
        return Constraint.category(key, n, -1)

    @staticmethod
    def category_at_least(key, n):
        return Constraint.category(key, n, 1)

    @staticmethod
    def once(prop):
        return Constraint.exactly(prop, 1)

    @staticmethod
    def never(prop, n=0):
        """No set may match ``prop``.

        ``n`` is ignored; it is kept (now optional) only so that existing
        two-argument calls continue to work.
        """
        return Constraint.exactly(prop, 0)

    @staticmethod
    def at_most_once(prop):
        return Constraint.at_most(prop, 1)

    @staticmethod
    def dummy():
        """Constraint that is always fulfilled (at least 0 matches)."""
        # The predicate takes (key, value), like every other predicate;
        # the former one-argument lambda raised TypeError when called.
        return Constraint(lambda key, value: True, 0, 1)
    

In [None]:
import random
from collections import Counter

def _get_allowed_sets(cross_sets, parallel_sets, constraints):
    """Return the set of (category key, value) pairs that may be placed next.

    ``cross_sets`` are the sets the new one will cross, ``parallel_sets`` the
    ones it will lie alongside. Candidates come from ``setkeys`` and are
    narrowed by the constraint balances, by sets already placed and by the
    per-category limits.
    """
    placed = cross_sets + parallel_sets

    # underfed: needs more. overfed: maximum is reached.
    underfed, overfed = [], []
    for constraint in constraints:
        needed, allowed = constraint.balance(placed)
        if needed is not None and needed > 0:
            underfed.append(constraint)
        if allowed == 0:
            overfed.append(constraint)

    candidates = setkeys
    if underfed or overfed:
        # With underfed constraints, only sets satisfying one of them qualify;
        # sets matching an overfed constraint never do.
        candidates = [
            (key, value) for key, value in candidates
            if (not underfed or any(c.match(key, value) for c in underfed))
            and not any(c.match(key, value) for c in overfed)
        ]

    # Not 2 identical (cat, value) sets in the game
    candidates = set(candidates) - set(cross_sets) - set(parallel_sets)

    # Not 2 crossing identical categories, except MultiNominal, but then only 1 each
    # Each category only allowed twice
    cross_counts = Counter(cat for cat, _ in cross_sets)
    parallel_counts = Counter(cat for cat, _ in parallel_sets)

    def category_fits(cat):
        if cat not in cross_counts and parallel_counts.get(cat, 0) <= 1:
            return True
        return (isinstance(nominal_categories[cat], MultiNominalCategory)
                and cross_counts[cat] == 1
                and cat not in parallel_counts)

    return {(cat, value) for cat, value in candidates if category_fits(cat)}


def _sample_fitting_set(cross_sets, parallel_sets, constraints):
    """Sample a new column, taking cross_sets as the rows and parallel_sets as the previous columns (or the other way round).

    Returns a random allowed set that has solutions with every crossing set,
    or None if there is none.
    """
    candidates = list(_get_allowed_sets(cross_sets, parallel_sets, constraints))
    # Try the candidates in random order and take the first fitting one
    random.shuffle(candidates)
    fitting = (candidate for candidate in candidates
               if all(get_solutions(cells, candidate, crossing) for crossing in cross_sets))
    return next(fitting, None)

def _sample_game_setup(constraints):
    """Try once to sample the rows and columns of a game.

    Columns and rows are added alternately, FIELD_SIZE of each. Returns
    ``(rows, cols)`` on success and ``(None, None)`` on every failure, so the
    result can always be unpacked into two names.
    """
    rows, cols = [], []
    for _ in range(FIELD_SIZE):
        # Sample a new column, then a new row
        new_col = _sample_fitting_set(cross_sets=rows, parallel_sets=cols, constraints=constraints)
        if new_col is None:
            return None, None
        cols.append(new_col)
        new_row = _sample_fitting_set(cross_sets=cols, parallel_sets=rows, constraints=constraints)
        if new_row is None:
            return None, None
        rows.append(new_row)
    if len(rows) != FIELD_SIZE or len(cols) != FIELD_SIZE:
        # Fixed: this branch returned a bare None, which callers cannot unpack
        return None, None
    # Check constraints
    if not all(c.apply(rows + cols) for c in constraints):
        return None, None
    return rows, cols

def sample_game(constraints=None, shuffle=True):
    """Sample a complete game, retrying up to MAX_TRIES times.

    Args:
        constraints: Constraint objects the setup must fulfil (default: none).
        shuffle: if true, rows and columns are shuffled and randomly swapped.

    Returns:
        A Game whose attribute ``i`` is the zero-based index of the successful
        attempt, or None if no valid setup was found.
    """
    MAX_TRIES = 100
    # None replaces the former mutable default argument `[]`
    constraints = [] if constraints is None else constraints
    rows, cols = None, None
    for i in range(MAX_TRIES):
        rows, cols = _sample_game_setup(constraints)
        if rows is not None and cols is not None:
            break

    if rows is None or cols is None:
        print(f"Could not create game setup ({MAX_TRIES} tries)")
        return None

    if shuffle:
        random.shuffle(rows)
        random.shuffle(cols)
        if random.random() > .5:
            rows, cols = cols, rows

    game = Game(values=country_values.to_dict(orient="records"),
                solutions=[[get_solutions(cells, row, col) for col in cols] for row in rows],
                alt_solutions=[[get_solutions(cells, row, col, alt=True) for col in cols] for row in rows],
                rows=[(nominal_categories[cat], value) for cat, value in rows],
                cols=[(nominal_categories[cat], value) for cat, value in cols])
    game.i = i
    return game

constraints = [
    # We always want a continent
    Constraint.category_at_least("continent", 1),
    
    # Some categories are pretty boring to appear multiple times
    Constraint.category_at_most("capital_ending_letter", 1),
    Constraint.category_at_most("capital_starting_letter", 1),
    Constraint.category_at_most("ending_letter", 1)
]

games = [sample_game(constraints=constraints, shuffle=False) for _ in range(1000)]

# Preview the first games: number of failed attempts, then the solved board
for game in games[:10]:
    if game is None:
        # sample_game returns None when no valid setup was found
        continue
    print(game.i)
    display(game.to_dataframe(solution=True))
# get_allowed_sets([('capital_ending_letter', 'T'), ('flag_colors', 'Red'), ('starting_letter', 'T')], [('capital_starting_letter', 'O')])

In [None]:
constraints = [
    # We always want a continent
    Constraint.category_at_least("continent", 1),
    
    # Some categories are pretty boring to appear multiple times
    Constraint.category_at_most("capital_ending_letter", 1),
    Constraint.category_at_most("capital_starting_letter", 1),
    Constraint.category_at_most("ending_letter", 1)
]

games = [sample_game(constraints=constraints, shuffle=False) for _ in range(1000)]

# Set to True to write the sampled games to games.json
SAVE = False

if SAVE:
    # The context manager flushes and closes the file. Entries that are None
    # (sample_game found no valid setup) are skipped.
    with open("games.json", mode="w", encoding="utf-8") as games_file:
        json.dump([game.to_json() for game in games if game is not None], games_file)

In [None]:
# Show every game in which at least one cell has alternative solutions
for game in games:
    if game is None:
        # sample_game returns None when no valid setup was found
        continue
    if any(cell for row in game.alt_solutions for cell in row):
        display(game.to_dataframe(solution=True))