# TODO

Ambiguous continents:
  - Timor-Leste (OC, AS)
  - Armenia, Azerbaijan, Georgia (EU, AS)
  - Trinidad and Tobago (NA, SA)

Adding alternative continents for these countries must not create new cells!

In [None]:
# Board dimension: games are FIELD_SIZE x FIELD_SIZE grids; a value set also needs
# at least FIELD_SIZE countries to survive pruning.
FIELD_SIZE = 3
# Inclusive bounds on how many countries may solve a single cell.
MIN_CELL_SIZE, MAX_CELL_SIZE = 1, 10

In [None]:
import json
import pandas as pd

In [None]:
# Load the pre-processed country data; the context manager closes the file handle.
with open("data/countries_processed.json", encoding="utf-8") as fh:
    data = json.load(fh)
df = pd.DataFrame(data)

# Filter & rename columns
df.columns = ['iso', 'iso3', 'iso_numeric', 'fips', 'name', 'capital',
              'area_km2', 'population', 'continent', 'tld',
              'currency_code', 'currency_name', 'phone', 'zip_format', 'zip_regex',
              'languages', 'geonameid', 'neighbors', 'eq_fips', 'parent', 'territories', 'neighbors_t']
subset = ['iso', 'name', 'capital', 'continent',
          'area_km2', 'population',
          'currency_code', 'currency_name', 'languages',
          'territories', 'neighbors_t']
# Select + rename in one expression so `df` is a standalone frame rather than a slice
# that later in-place edits would warn about (SettingWithCopy).
df = df[subset].rename(columns={"neighbors_t": "neighbors"})

# Additional columns & global fixes
# Assign the result: chained `fillna(..., inplace=True)` on a column does not write back
# under pandas copy-on-write.
df["continent"] = df["continent"].fillna("NA")  # North America fix (presumably its code "NA" arrives as missing)
# ISO codes of landlocked countries (Caspian-only coasts count as landlocked, as for AZ/KZ).
# Fixed: South Africa (ZA) is coastal -> Zambia (ZM); added Kyrgyzstan, Turkmenistan, Uzbekistan.
landlocked_iso = ("AF,AD,AM,AT,AZ,BY,BT,BO,BW,BF,BI,CF,TD,CZ,SZ,ET,HU,KZ,XK,KG,LA,LS,LI,LU,MW,ML,MD,MN,"
                  "NP,NE,MK,PY,RW,SM,RS,SK,SS,CH,TJ,TM,UG,UZ,VA,ZM,ZW").split(",")
df["landlocked"] = df["iso"].isin(landlocked_iso)
# A country without land neighbors is treated as an island.
df["island"] = df["neighbors"].apply(len) == 0

def add_alternative_value(df, col, country, *values):
    """Register alternative accepted values for one country in column `col`.

    On first use for `col`, a companion list column ``<col>_alt`` is inserted right
    after `col`. If the current main value is among `values` but is not the first
    one, the first value is promoted to main value and the old one is demoted to the
    alternatives. The main value is never kept in the alternatives list.

    Args:
        df: Country table, modified in place. Must have "name" and "iso" columns.
        col: Column to add alternatives for.
        country: Country name or ISO code (name is tried first).
        *values: Alternative values; list arguments are flattened one level.

    Returns:
        True if the values were registered, False if `col` or `country` is unknown.
    """
    cols = list(df.columns)
    if col not in cols:
        return False

    # Add alternative column if not exists (a fresh, independent empty list per row)
    altcol = col + "_alt"
    if altcol not in cols:
        df.insert(cols.index(col) + 1, altcol, df[col].apply(lambda x: []))

    # Find country
    if country in df["name"].values:
        index = df.index[df["name"] == country][0]
    elif country in df["iso"].values:
        index = df.index[df["iso"] == country][0]
    else:
        print(f"country {country} not found!")
        return False

    # Flatten one level of list arguments (linear, unlike sum(..., []))
    values = [item for value in values for item in (value if isinstance(value, list) else [value])]

    # Warn if value to-be-added is the actual value. Swap if not first-named
    main_value = df.loc[index, col]
    if main_value in values:
        if values[0] == main_value:
            print(f"{country}/{col}: '{main_value}' is already set as main value - skipping")
        else:
            print(f"{country}/{col}: '{main_value}' is already set as main value - swapping with '{values[0]}'")
            df.loc[index, col] = values[0]
        values = values[1:]

    # The cell holds a list object; mutate it in place.
    alternatives = df.loc[index, altcol]
    # A value promoted to main value (possibly registered as alternative earlier)
    # must not also remain listed as an alternative.
    new_main = df.loc[index, col]
    if new_main in alternatives:
        alternatives.remove(new_main)
    for val in values:
        if val not in alternatives:
            alternatives.append(val)
    return True

# Individual fixes
df.loc[df["name"] == "Palau", "capital"] = "Ngerulmud"

# Alternative values
# Names
add_alternative_value(df, "name", "CI", "Ivory Coast", "Côte d'Ivoire")
add_alternative_value(df, "name", "TR", "Türkiye", "Turkey")
add_alternative_value(df, "name", "VA", "Vatican", "Vatican City")

# Multiple continents (source: https://en.wikipedia.org/wiki/List_of_transcontinental_countries)
add_alternative_value(df, "continent", "Armenia", "AS", "EU")
add_alternative_value(df, "continent", "Georgia", "AS", "EU")
add_alternative_value(df, "continent", "Azerbaijan", "AS", "EU")
add_alternative_value(df, "continent", "Trinidad and Tobago", "SA", "NA")
add_alternative_value(df, "continent", "Panama", "NA", "SA")
add_alternative_value(df, "continent", "Egypt", "AF", "AS")
add_alternative_value(df, "continent", "Russia", "EU", "AS")
add_alternative_value(df, "continent", "TR", "AS", "EU")
add_alternative_value(df, "continent", "Timor Leste", "AS", "OC")

# Multiple capitals (source: https://en.wikipedia.org/wiki/List_of_countries_with_multiple_capitals)
add_alternative_value(df, "capital", "Kazakhstan", "Astana", "Nur-Sultan")
add_alternative_value(df, "capital", "Bolivia", "La Paz", "Sucre")
add_alternative_value(df, "capital", "Burundi", "Gitega", "Bujumbura")
add_alternative_value(df, "capital", "CI", "Yamoussoukro", "Abidjan")
add_alternative_value(df, "capital", "Eswatini", "Mbabane", "Lobamba")
add_alternative_value(df, "capital", "Malaysia", "Kuala Lumpur", "Putrajaya")
add_alternative_value(df, "capital", "Netherlands", "Amsterdam", "The Hague")
add_alternative_value(df, "capital", "South Africa", "Pretoria", "Cape Town", "Bloemfontein")
add_alternative_value(df, "capital", "Sri Lanka", "Colombo", "Sri Jayawardenepura Kotte")

# Display all changes
altcols = [col for col in df.columns if col.endswith("_alt")]
print("\nAll countries with alternative values:")
# Per-column Series.map instead of DataFrame.applymap (deprecated in pandas 2.1, removed later).
alt_counts = df[altcols].apply(lambda column: column.map(len))
df[alt_counts.sum(axis=1) > 0]

# df.reset_index(inplace=True)
# df

In [None]:
# Assign colors
colors = pd.read_csv("data/flag-colors.csv", sep=";")
colors.columns = ["country", "color"]
# The CSV names a country only on its first color row; forward-fill it onto the rows below.
# (Assigned instead of chained inplace/method= calls, which are deprecated.)
colors["country"] = colors["country"].ffill()
colors = colors.dropna()
colors = colors.apply(lambda column: column.str.strip())


cmap = {
    "Light Blue": "Blue",
    "Dark Blue": "Blue",
    "Sky Blue": "Blue",
    "Aquamarine Blue": "Blue",
    "Fulvous": "Orange",
    "Crimson": "Red",
    "Saffron Orange": "Orange",
    "Green Or Blue": "Green",
    "Maroon": "Red",  # Qatar, Sri Lanka,
    "Olive Green": "Green",
    "Yellow": "Yellow/Gold",
    "Gold": "Yellow/Gold",
    "Golden": "Yellow/Gold",
}
name_map = {
    'American Samoa': None,
    'Anguilla': None,
    'Antigua And Barbuda': 'Antigua and Barbuda',
    'Aruba': None,
    'Bermuda': None,
    'Bosnia And Herzegovina': 'Bosnia and Herzegovina',
    'Bouvet Island': None,
    'Brunei Darussalam': "Brunei",
    'Cook Islands': None,
    'Curaçao': None,
    "Côte D'Ivoire": 'Ivory Coast',
    'Democratic Republic Of The Congo': 'Democratic Republic of the Congo',
    'French Polynesia': None,
    'Holy See (Vatican City State)': "Vatican",
    'Niue': None,
    'Norfolk Island': None,
    'Palestine': 'Palestinian Territory',
    'Pitcairn Islands': None,
    'Republic Of The Congo': 'Republic of the Congo',
    'Russian Federation': "Russia",
    'Saint Kitts And Nevis': 'Saint Kitts and Nevis',
    'Saint Vincent And The Grenadines': 'Saint Vincent and the Grenadines',
    'Sao Tome And Principe': 'Sao Tome and Principe',
    'Syrian Arab Republic': "Syria",
    'Tanzania, United Republic Of': "Tanzania",
    'Trinidad And Tobago': 'Trinidad and Tobago',
    'Åland Islands': None,
    'Turkey': 'Türkiye'
}
add = [
    {"country": "Timor Leste", "color": ["Red", "Yellow/Gold", "Black", "White"]},
    {"country": "Kosovo", "color": ["Blue", "Yellow/Gold", "White"]},
    {"country": "Taiwan", "color": ["Red", "Blue", "White"]}
]

# Normalize color shades and CSV country names (None = not a country; dropped by groupby).
colors["color"] = colors["color"].apply(lambda c: cmap.get(c, c))
colors["country"] = colors["country"].apply(lambda x: name_map.get(x, x))
colors1 = colors.groupby(by="country")["color"].agg(list).reset_index()
# DataFrame.append was removed in pandas 2.0; concat is its replacement.
colors1 = pd.concat([colors1, pd.DataFrame(add)], ignore_index=True)

# Kept for inspection (the outer merge with indicator shows unmatched names on both sides).
colors2 = pd.merge(df[["name"]], colors1, how="left", left_on="name", right_on="country")
colors2_outer = pd.merge(df[["name"]], colors1, how="outer", left_on="name", right_on="country", indicator=True)
# Look colors up by name rather than assigning the merge result positionally: a country listed
# twice in colors1 (e.g. CSV + `add`) would add merge rows and shift every later country.
# Later entries (from `add`) win.
color_lookup = colors1.drop_duplicates(subset="country", keep="last").set_index("country")["color"]
df["flag_colors"] = df["name"].map(color_lookup)

no_flag = df["flag_colors"].isna().sum()
print(f"Assigned colors. {no_flag} countries missing a flag.")

In [None]:
# from IPython.display import Image, display
from IPython.display import SVG, display

# Print each country's normalized flag colors next to its rendered flag for manual checking.
for _, row in df.iterrows():
    country = row.to_dict()
    flag_colors = country["flag_colors"]
    # Countries without a flag entry hold NaN here; str.join would raise on it.
    color_text = ", ".join(flag_colors) if isinstance(flag_colors, list) else "(no flag colors)"
    print(f'{country["name"]} ({country["iso"]}): {color_text}')
    display(SVG(url=f'https://hatscripts.github.io/circle-flags/flags/{country["iso"].lower()}.svg'))
    print()
# display(Image(filename=''))

In [None]:
from collections import Counter
# Frequency of each normalized flag color. Rows without colors (NaN) are skipped:
# Series.sum() over the lists would fail on a float NaN.
Counter(color for palette in df["flag_colors"].dropna() for color in palette)

In [None]:
# Countries whose flag contains yellow/gold. The raw "Gold" shade is normalized to
# "Yellow/Gold" by cmap, and countries without colors hold NaN, hence the guard.
df[df["flag_colors"].map(lambda x: isinstance(x, list) and "Yellow/Gold" in x)]

In [None]:
# Inspect the processed country table.
df

In [None]:
# List the available columns.
df.columns

In [None]:
from functools import total_ordering

@total_ordering
class Category:
    """A quiz category (e.g. "Continent") holding one entry per country.

    Categories are identified, ordered and hashed by `key` alone.
    """

    def __init__(self, key: str, name: str, difficulty: float, values: pd.Series):
        self.key = key  # identifier; also used as column name of the per-country value table
        self.name = name  # human-readable label, used in row/column captions
        self.difficulty = difficulty
        self.values = values  # per-country value(s) of this category
        # Groups of countries per category value; assigned after construction.
        self.sets = None

    def __str__(self):
        # Without this, __repr__ -> str() -> object.__str__ -> __repr__ recursed forever.
        return f"{type(self).__name__}('{self.key}')"

    def __repr__(self):
        return str(self)

    def __lt__(self, other):
        if not isinstance(other, Category):
            return NotImplemented
        return self.key < other.key

    def __eq__(self, other):
        if not isinstance(other, Category):
            return NotImplemented
        return self.key == other.key

    def __hash__(self):
        return hash(self.key)

class NominalCategory(Category):
    """Category with a single value per country (e.g. continent, starting letter).

    Ordering, equality and hashing by `key` are inherited from Category.
    """

    def __init__(self, key: str, name: str, difficulty: float, values: pd.Series):
        super().__init__(key, name, difficulty, values)
        # Series: category value -> sorted list of ISO codes; assigned after construction.
        self.sets = None

    def __str__(self):
        # Printable before `sets` is computed (len(None) used to raise).
        num_values = "?" if self.sets is None else len(self.sets)
        return f"NominalCategory('{self.key}', {num_values} values)"

class MultiNominalCategory(Category):
    """Category where a country carries a list of values (e.g. flag colors).

    The lists are exploded before grouping, so one country can appear in several value
    sets, and two values of the same category may cross each other in the grid.
    Ordering, equality and hashing by `key` are inherited from Category.
    """

    def __init__(self, key: str, name: str, difficulty: float, values: pd.Series):
        super().__init__(key, name, difficulty, values)
        # Series: category value -> sorted list of ISO codes; assigned after construction.
        # Initialized here so __str__ works beforehand (previously AttributeError).
        self.sets = None

    def __str__(self):
        num_values = "?" if self.sets is None else len(self.sets)
        return f"MultiNominalCategory('{self.key}', {num_values} values)"
    
        
def _initial(text):
    """Return the upper-cased first character of `text`, or None if it is missing/empty."""
    return text[0].upper() if isinstance(text, str) and text else None


def _final(text):
    """Return the upper-cased last character of `text`, or None if it is missing/empty."""
    return text[-1].upper() if isinstance(text, str) and text else None


# df has lower-case columns (iso, name, capital, continent) after the renaming step;
# the former "Continent"/"Country"/"Capital" names raised KeyError.
# None values (e.g. an empty capital) are dropped later by groupby instead of raising IndexError.
nominal_categories = [
    NominalCategory(key="continent", name="Continent", difficulty=1, values=df["continent"]),
    NominalCategory(key="starting_letter", name="Starting letter", difficulty=1, values=df["name"].apply(_initial)),
    NominalCategory(key="ending_letter", name="Ending letter", difficulty=2, values=df["name"].apply(_final)),
    NominalCategory(key="capital_starting_letter", name="Capital starting letter", difficulty=1.5, values=df["capital"].apply(_initial)),
    NominalCategory(key="capital_ending_letter", name="Capital ending letter", difficulty=3, values=df["capital"].apply(_final)),
    MultiNominalCategory(key="flag_colors", name="Flag color", difficulty=1.5, values=df["flag_colors"]),
]
nominal_categories = {cat.key: cat for cat in nominal_categories}

# One row per country, one column per category key. The identifier columns keep the
# "ISO"/"Country" names that the following cells use.
values = pd.concat([
    df[["iso", "name"]].rename(columns={"iso": "ISO", "name": "Country"}),
    pd.DataFrame({cat.key: cat.values for cat in nominal_categories.values()}),
#     pd.DataFrame(bool_categories)
], axis=1)
values

In [None]:
import itertools

# Group ISO codes by category value.
for cat in nominal_categories.values():
    if isinstance(cat, NominalCategory):
        cat.sets = values.groupby(by=cat.key)["ISO"].agg(sorted)
    elif isinstance(cat, MultiNominalCategory):
        # A country with several values (e.g. flag colors) lands in several sets.
        cat.sets = values.explode(column=cat.key).groupby(by=cat.key)["ISO"].agg(sorted)

# Prune until stable: removing countries can shrink sets below FIELD_SIZE, which can
# in turn leave further countries without a second category.
while True:
    # Retain only sets with at least 3 (FIELD_SIZE) elements
    num_sets_0 = sum(len(cat.sets) for cat in nominal_categories.values())
    for cat in nominal_categories.values():
        cat.sets = cat.sets[cat.sets.apply(len) >= FIELD_SIZE]
    num_sets_1 = sum(len(cat.sets) for cat in nominal_categories.values())
    if num_sets_0 != num_sets_1:
        print(f"Removed {num_sets_0 - num_sets_1} category sets")
    # Retain only countries contained in sets of at least 2 different categories (-> has matching row+column)
    # chain() instead of Series.sum(): the sum of an empty Series is 0, not an empty list.
    category_contents = {cat.key: set(itertools.chain.from_iterable(cat.sets))
                         for cat in nominal_categories.values()}
    contents = set().union(*category_contents.values())
    print("contents:", len(contents))
    retain = {c for c in contents if sum(c in cc for cc in category_contents.values()) >= 2}
    print("retain:", len(retain))
    remove = contents.difference(retain)

    for cat in nominal_categories.values():
        cat.sets = cat.sets.apply(lambda cc: [c for c in cc if c in retain])
    if not remove:
        break
    print(f"Removed {len(remove)} countries:", remove)
    print("Repeat ...")

list(nominal_categories.values())

In [None]:
# Inspect one category after pruning.
nominal_categories["starting_letter"]

In [None]:
def compare_sets(set1, set2):
    """Three-way compare two collections: first by length, then by sorted distinct elements.

    Returns -1, 0 or 1, like a classic ``cmp``. Used to put pairs of labels or
    category sets into a canonical order.
    """
    size_diff = len(set1) - len(set2)
    if size_diff:
        return -1 if size_diff < 0 else 1
    left = sorted(set(set1))
    right = sorted(set(set2))
    return (left > right) - (left < right)


def min_set(set1, set2):
    """Return the smaller collection under compare_sets; ties return `set1`."""
    return set1 if compare_sets(set1, set2) <= 0 else set2


def max_set(set1, set2):
    """Return the larger collection under compare_sets; ties return `set2`.

    Tie-breaking opposite to min_set guarantees that ``(min_set(a, b), max_set(a, b))``
    is always a permutation of ``(a, b)``. Previously both returned `set1` on a tie,
    so such a pair key silently lost `set2`.
    """
    return set1 if compare_sets(set1, set2) > 0 else set2

In [None]:
# Inspect all categories.
nominal_categories

In [None]:
# Spot check: canonical (min, max) ordering of two labels and the countries matching both.
# Note: sets[value] raises KeyError if that value was pruned earlier.
row = ("flag_colors", "Red")
col = ("flag_colors", "White")

print(min_set(row, col), max_set(row, col))
print(set(nominal_categories[row[0]].sets[row[1]]).intersection(nominal_categories[col[0]].sets[col[1]]))

In [None]:
import itertools
import matplotlib.pyplot as plt

# All surviving (category key, value) labels in category order; each may label a row or a column.
setkeys = [(cat.key, value) for cat in nominal_categories.values() for value in cat.sets.index]

print("Number of values per category:")
print({cat.key: len(cat.sets) for cat in nominal_categories.values()})

# A cell is the intersection of the country sets of two labels. Two labels of the same category
# are only paired for multi-valued categories (e.g. flag colors "Red" x "White").
# Keys are put into canonical order with min_set/max_set.
cells = {(min_set(row, col), max_set(row, col)): set(nominal_categories[row[0]].sets[row[1]]).intersection(nominal_categories[col[0]].sets[col[1]])
         for row, col in itertools.combinations(setkeys, 2) if row[0] != col[0] or (row[1] != col[1] and isinstance(nominal_categories[row[0]], MultiNominalCategory))}

print(f"Generated {len(cells)} cells")

# Bring cells to DataFrame to do filtering (cell size etc.)
cell_info = pd.DataFrame([{"row_cat": row[0], "row_val": row[1], "col_cat": col[0], "col_val": col[1], "contents": contents} for (row, col), contents in cells.items()])
cell_info["size"] = cell_info["contents"].apply(len)

# Keep cells whose number of solutions lies within the inclusive bounds.
cell_info = cell_info[(cell_info["size"] >= MIN_CELL_SIZE) & (cell_info["size"] <= MAX_CELL_SIZE)]
# Integer-centred bins with edges k - .5 for k = min .. max + 1. The former range(min, max + 1)
# ended at max - .5, so the largest size fell outside every bin (and a single size gave < 2 edges).
size_min, size_max = cell_info["size"].min(), cell_info["size"].max()
plt.hist(cell_info["size"], rwidth=.9, bins=[x - .5 for x in range(size_min, size_max + 2)])
plt.title("Distribution of cell sizes")

# Bring back to dict with tuple access: ((row_cat, row_val), (col_cat, col_val)) -> contents.
cell_keys = [((row_cat, row_val), (col_cat, col_val))
             for row_cat, row_val, col_cat, col_val
             in zip(cell_info["row_cat"], cell_info["row_val"], cell_info["col_cat"], cell_info["col_val"])]
cells = dict(zip(cell_keys, cell_info["contents"]))

In [None]:
# Cells whose two labels share a category (only generated for multi-valued categories).
[(row, col) for row, col in cells.keys() if row[0] == col[0]]

In [None]:
import networkx as nx

# Label graph: nodes are (category key, value) labels; edges join labels that form a cell
# which passed the size filter.
G = nx.Graph()
G.add_nodes_from(setkeys)
G.add_edges_from(cells.keys())
nx.draw(G)

In [None]:
# Inspect all categories.
nominal_categories

In [None]:
# Now find 2 distinct node sets of size FIELD_SIZE that are mutually completely connected
# (connections within the set also allowed, that's why it is not necessarily a complete bipartite subgraph)

def subsets_of_sizes(S, sizes):
    """Return an iterator over all combinations of S whose size is in `sizes`.

    Combinations are grouped by size, in the order the sizes are given.
    """
    # Create every per-size combinations iterator up front, then chain them together.
    per_size = tuple(itertools.combinations(S, size) for size in sizes)
    return itertools.chain(*per_size)

def row_col_assignments(categories):
    """Yield candidate (row_cats, col_cats) pairs of Category sets for a FIELD_SIZE grid.

    `categories` is a dict whose values are Category objects (keys are ignored).
    A category may appear on both sides only if it is a MultiNominalCategory.

    Yields:
        (row_cats, col_cats): two sets of Category objects.
    """
    categories = set(categories.values())
    # TODO incorporate number of category values into iterations

    possible_numbers_of_distinct_row_cats = range(1, min(len(categories) // 2, FIELD_SIZE) + 1)
    # // 2 assuming there are at least as many distinct categories in the columns
    for row_cats in subsets_of_sizes(categories, possible_numbers_of_distinct_row_cats):
        row_cats = set(row_cats)
        # Columns may reuse a row category only if it is multi-valued.
        possible_col_cats = categories.difference([cat for cat in row_cats if not isinstance(cat, MultiNominalCategory)])
        # Column side never has fewer distinct categories than the row side.
        possible_numbers_of_distinct_col_cats = range(len(row_cats), min(len(possible_col_cats), FIELD_SIZE) + 1)
        for col_cats in subsets_of_sizes(possible_col_cats, possible_numbers_of_distinct_col_cats):
            col_cats = set(col_cats)
            if len(row_cats) == len(col_cats):
                # When equal size: Assume row_cats < col_cats (with lexicographical order)
                # (presumably so a pair and its mirror image are not both yielded)
                if compare_sets(row_cats, col_cats) == 1:
                    continue
            # Sampled too many? (more distinct categories than 2 * FIELD_SIZE label slots)
            only_rows = len(row_cats.difference(col_cats))
            only_cols = len(col_cats.difference(row_cats))
            both = len(row_cats.intersection(col_cats))
            if only_rows + only_cols + both > 2 * FIELD_SIZE:
                continue

            # Only allow for up to 2 appearances of a category in rows/columns.
            # When having a multi-category in row&column, only allow for one appearance each.
            # -> each side must still be able to fill all FIELD_SIZE labels.
            if 2 * only_rows + both < FIELD_SIZE or 2 * only_cols + both < FIELD_SIZE:
                continue

            yield (row_cats, col_cats)

# Enumerate all candidate row/column category layouts.
row_col_cats = list(row_col_assignments(nominal_categories))
print(f"Generated {len(row_col_cats)} row-column category assignments.")

In [None]:
# Inspect all categories.
nominal_categories

In [None]:
# Layouts in which some (multi-valued) category labels both rows and columns.
[(rows, cols) for rows, cols in row_col_cats if rows.intersection(cols)]

In [None]:
# set([(len(rows), len(cols)) for rows, cols in row_col_cats])

In [None]:
# One row per layout; difficulty sums the difficulties of the distinct categories used
# (a category on both axes is counted once, since rows.union(cols) is a set).
row_col_info = pd.DataFrame([{"rows": rows, "cols": cols, "difficulty": sum([cat.difficulty for cat in rows.union(cols)])} for rows, cols in row_col_cats])
# row_col_info.sort_values(by="difficulty")

In [None]:
# Distribution of summed category difficulty over all candidate layouts.
plt.hist(row_col_info["difficulty"], rwidth=.9)
plt.title("Difficulty of generated row-column assignments")
plt.show()

In [None]:
import json


def get_label(cat: Category, value):
    """Return the display label for a (category, value) pair.

    Continent codes are spelled out (KeyError for unknown codes); every other
    category is rendered as "<category name>: <value>".
    """
    if cat.key != "continent":
        return f"{cat.name}: {value}"
    continent_names = {
        "AF": "Africa",
        "EU": "Europe",
        "AS": "Asia",
        "NA": "N. America",
        "SA": "S. America",
        "OC": "Oceania",
    }
    return continent_names[value]


class Game:
    """One generated puzzle: a FIELD_SIZE x FIELD_SIZE grid with labels and solutions."""

    def __init__(self, values, cells, rows, cols):
        self.size = FIELD_SIZE
        self.values = values  # All possible values to be guessed (list of dicts)
        self.cells = cells  # FIELD_SIZE x FIELD_SIZE nested list; each entry holds the ISO codes solving that cell
        self.rows = rows  # rows (tuples of form (Category, value))
        self.cols = cols  # columns (as above)

    def to_json(self):
        """Return a JSON-serializable dict with size, guessable values, solutions and labels.

        Cell solutions are sorted: set iteration order of strings changes between
        interpreter runs (hash randomization), which made the exported file non-reproducible.
        """
        return {
            "size": self.size,
            "values": self.values,
            "cells": [[sorted(cell) for cell in row] for row in self.cells],
            "labels": {
                "rows": [get_label(cat, value) for cat, value in self.rows],
                "cols": [get_label(cat, value) for cat, value in self.cols]
            }
        }

    def to_dataframe(self, solution=False):
        """Return the grid as a labelled DataFrame; cells are filled only if `solution` is true."""
        game_df = pd.DataFrame(data=self.cells if solution else None,
                               index=[get_label(cat, value) for cat, value in self.rows],
                               columns=[get_label(cat, value) for cat, value in self.cols])
        game_df.fillna("", inplace=True)
        return game_df


def get_cell(cells, row, col):
    """Look up the cell for two labels, whichever order they were stored in.

    Returns None if neither (row, col) nor (col, row) is a key of `cells`.
    """
    for key in ((row, col), (col, row)):
        if key in cells:
            return cells[key]
    return None


# Sample possible game setups

In [None]:
# Guessable values for the export: identifier columns renamed to lower-case keys.
country_values = values[["ISO", "Country"]].rename(columns={"ISO": "iso", "Country": "name"})

In [None]:
import random
from collections import Counter

def get_allowed_sets(cross_sets, parallel_sets):
    """Return the labels that may be added parallel to `parallel_sets`, crossing `cross_sets`.

    A label (category key, value) is allowed if it is not already used, and either
    - its category does not appear on the crossing axis and at most once on its own axis, or
    - it is multi-valued, appears exactly once on the crossing axis and not on its own axis.
    """
    cross_counts = Counter(cat for cat, _ in cross_sets)
    parallel_counts = Counter(cat for cat, _ in parallel_sets)
    taken = set(cross_sets) | set(parallel_sets)

    allowed = set()
    for cat, value in setkeys:
        if (cat, value) in taken:
            continue
        fresh_category = cat not in cross_counts and parallel_counts[cat] <= 1
        if fresh_category or (isinstance(nominal_categories[cat], MultiNominalCategory)
                              and cross_counts[cat] == 1 and cat not in parallel_counts):
            allowed.add((cat, value))
    return allowed


def sample_fitting_set(cross_sets, parallel_sets):
    """Pick a random allowed label forming a valid cell with every label in `cross_sets`.

    `cross_sets` is the other axis (e.g. the rows when sampling a new column) and
    `parallel_sets` the labels already on this axis. Returns None if nothing fits.
    """
    candidates = list(get_allowed_sets(cross_sets, parallel_sets))
    random.shuffle(candidates)
    fitting = (candidate for candidate in candidates
               if all(get_cell(cells, candidate, other) for other in cross_sets))
    return next(fitting, None)

def sample_game():
    """Randomly build a FIELD_SIZE x FIELD_SIZE Game.

    Starts from a random existing cell, then alternately adds a column and a row that
    form a valid cell with every label on the other axis. Retries up to MAX_TRIES times.

    Raises:
        RuntimeError: if no complete grid was found within MAX_TRIES attempts
            (previously an incomplete grid was silently turned into a Game).
    """
    MAX_TRIES = 100
    cell_keys = list(cells.keys())  # hoisted: identical for every try
    for _ in range(MAX_TRIES):
        row0, col0 = random.choice(cell_keys)
        rows, cols = [row0], [col0]
        for _ in range(FIELD_SIZE - 1):
            new_col = sample_fitting_set(cross_sets=rows, parallel_sets=cols)
            if new_col is None:
                break
            cols.append(new_col)
            new_row = sample_fitting_set(cross_sets=cols, parallel_sets=rows)
            if new_row is None:
                break
            rows.append(new_row)
        if len(rows) == FIELD_SIZE and len(cols) == FIELD_SIZE:
            break
    else:
        raise RuntimeError(f"Could not sample a complete {FIELD_SIZE}x{FIELD_SIZE} game in {MAX_TRIES} tries")

    # Randomly transpose: growth is asymmetric (columns are added before rows, and the
    # starting cell key is canonically ordered).
    if random.random() > .5:
        rows, cols = cols, rows

    game = Game(values=country_values.to_dict(orient="records"),
                cells=[[get_cell(cells, row, col) for col in cols] for row in rows],
                rows=[(nominal_categories[cat], value) for cat, value in rows],
                cols=[(nominal_categories[cat], value) for cat, value in cols])
    return game

# games = [sample_game() for _ in range(1000)]
# game = sample_game()
# game.to_json()
# game.to_dataframe(solution=True)

# get_allowed_sets([('capital_ending_letter', 'T'), ('flag_colors', 'Red'), ('starting_letter', 'T')], [('capital_starting_letter', 'O')])

In [None]:
games = [sample_game() for _ in range(1000)]

# The context manager flushes and closes the file even if serialization fails midway.
with open("games.json", mode="w", encoding="utf-8") as fh:
    json.dump([game.to_json() for game in games], fh)

In [None]:
# import random

# # NUM_CAT_ORIENTATIONS = 10  # assignments of categories to being row or column
# # NUM_SETUPS_PER_ORIENTATION = 1  # how many fields to generate per orientation
# NUM_GAMES = 10

# setups = []

# setkeys_by_cat = {cat: [(cat, value) for cat1, value in setkeys if cat == cat1] for cat in nominal_categories.keys()}

# def sample_setup(row_col_cats):
#     while True:
#         row_cats, col_cats = random.choice(row_col_cats)
#         if random.random() < .5:
#             row_cats, col_cats = col_cats, row_cats

#     #     print(row_cats, col_cats)

#         # Multi-value category is allowed in 1 row and 1 column at once, but not more than that
#         multi_cats = row_cats.intersection(col_cats)
#         # Check self-intersection of multi-value category
#         multi_setkeys = set().union(*[setkeys_by_cat[cat.key] for cat in multi_cats])

#         row_setkeys = set().union(*[setkeys_by_cat[cat.key] for cat in row_cats])
#         col_setkeys = set().union(*[setkeys_by_cat[cat.key] for cat in col_cats])

#         # Cell subgraph with the selected categories (all values)
#         G1 = nx.subgraph(G, row_setkeys.union(col_setkeys)).copy()
#     #     G1 = nx.subgraph(G, [(cat, value) for cat, value in setkeys if cat in row_cats or cat in col_cats]).copy()

#         # Make "anti-bipartite" (make subgraph of row resp. col categories complete)
#         # (Except for mult-categories appearing in both rows and cols)
#         # (The anti-bipartite property can only be ensured if the same multi-category only appears in ONE row and column)
#         if not multi_cats:
#             G1.add_edges_from(itertools.combinations(row_setkeys.difference(multi_setkeys), 2))
#             G1.add_edges_from(itertools.combinations(col_setkeys.difference(multi_setkeys), 2))
#         else:
#             # TODO add these edges only if the cell exists! (and this should already have happened)
#     #         G1.add_edges_from(itertools.product(row_setkeys.difference(multi_setkeys), multi_setkeys))
#     #         G1.add_edges_from(itertools.product(col_setkeys.difference(multi_setkeys), multi_setkeys))
#     #         G1.add_edges_from(((cat1, value1), (cat2, value2))
#     #                           for (cat1, value1), (cat2, value2)
#     #                           in itertools.combinations(multi_setkeys, 2)
#     #                           if cat1 != cat2)
#             # TODO have to assign multi-cat values (setkeys) to being row/col
#             # Idea: first sample clique treating the multi-cat as row only / col only.
#             # Then merge both results and extract possible row-col assignments of the multi setkeys.
#             # Then do another final clique check.
#             G1r = G1.copy()
#             G1r.add_edges_from(itertools.combinations(row_setkeys, 2))
#             G1r.add_edges_from(itertools.combinations(col_setkeys.difference(multi_setkeys), 2))
#             G1c = G1.copy()
#             G1c.add_edges_from(itertools.combinations(row_setkeys.difference(multi_setkeys), 2))
#             G1c.add_edges_from(itertools.combinations(col_setkeys, 2))
#             cliques_r = [clique for clique in nx.find_cliques(G1r) if len(clique) >= 2 * FIELD_SIZE - len(multi_cats) and len(set(cat for cat, _ in clique)) > 1]
#             cliques_c = [clique for clique in nx.find_cliques(G1c) if len(clique) >= 2 * FIELD_SIZE - len(multi_cats) and len(set(cat for cat, _ in clique)) > 1]
        
#             # merge: intersection of non-multi, union of multi setkeys.
# #             clique_pairs = [(clr, clc) for clr, clc in itertools.product(cliques_r, cliques_c)
# #                             if set(clr).difference(multi_setkeys) == set(clc).difference(multi_setkeys)
# #                             and set(clr).intersection(multi_setkeys).isdisjoint(set(clc).intersection(multi_setkeys))]
#             for clr, clc in itertools.product(cliques_r, cliques_c):
#         # tTODO bää
#                 both = set(clr).intersection(clc)
#                 multi_r = set(clr).intersection(multi_setkeys)
#                 multi_c = set(clc).intersection(multi_setkeys)
#                 single_r = both.intersection(row_setkeys).difference(multi_c)
#                 single_c = both.intersection(col_setkeys).difference(multi_r)
#                 if len(both_r)
                
#                 row_setkeys = set((cat, value) for cat, value in both if nominal_categories[cat] in row_cats)
#                 col_setkeys = set((cat, value) for cat, value in both if nominal_categories[cat] in col_cats)
    
#             return None

#         # Compute cliques
#         cliques = [clique for clique in nx.find_cliques(G1) if len(clique) >= 2 * FIELD_SIZE and len(set(cat for cat, _ in clique)) > 1]
#         del G1
#         random.shuffle(cliques)
#         for i, clique in enumerate(cliques):
#             row_setkeys = set((cat, value) for cat, value in clique if nominal_categories[cat] in row_cats)
#             col_setkeys = set((cat, value) for cat, value in clique if nominal_categories[cat] in col_cats)
# #             col_setkeys = set(clique).difference(row_setkeys)

#             # Is clique large enough?
#             if len(row_setkeys) < FIELD_SIZE or len(col_setkeys) < FIELD_SIZE:
#                 continue

#     #         print(f"#{i}, size {len(row_setkeys)}x{len(col_setkeys)}")#, "rows:", row_setkeys, "cols:", col_setkeys)
#             #set((cat, value) for cat, value in clique if cat in col_cats)

#             # visualize complete bipartite graph
#     #         G2 = nx.Graph()
#     #         G2.add_nodes_from(row_setkeys, bipartite=0)
#     #         G2.add_nodes_from(col_setkeys, bipartite=1)
#     #         G2.add_edges_from([(u, v) for u, v in cells.keys() if u in row_setkeys and v in col_setkeys])
#     #         nx.draw(G2)
#     #         plt.show()
#     #         del G2
        
#             rows = random.sample(row_setkeys, FIELD_SIZE)
#             cols = random.sample(col_setkeys, FIELD_SIZE)
#             return rows, cols

#         if not cliques:
#             print("Found no cliques for category setup", (row_cats, col_cats))
#         del cliques
        
# # for rows, cols in setups:
# #     print(rows, cols)

# setups = []
# while len(setups) < NUM_GAMES:
#     setups.append(sample_setup(row_col_cats))

# games = [Game(values=values["Country"].tolist(),
#               cells=[[get_cell(cells, row, col) for col in cols] for row in rows],
#               rows=[(nominal_categories[cat], value) for cat, value in rows],
#               cols=[(nominal_categories[cat], value) for cat, value in cols]) for rows, cols in setups]

# for game in games:
#     display(game.to_dataframe(solution=True))
# #     print(list(G1.nodes))

In [None]:
# import random

# NUM_CAT_ORIENTATIONS = 10  # assignments of categories to being row or column
# NUM_SETUPS_PER_ORIENTATION = 1  # how many fields to generate per orientation

# setups = []

# setkeys_by_cat = {cat: [(cat, value) for cat1, value in setkeys if cat == cat1] for cat in nominal_categories.keys()}

# for row_cats, col_cats in random.sample(row_col_cats, NUM_CAT_ORIENTATIONS):
#     if random.random() < .5:
#         row_cats, col_cats = col_cats, row_cats
    
# #     print(row_cats, col_cats)
    
#     # Multi-value category is allowed in 1 row and 1 column at once, but not more than that
#     multi_cats = row_cats.intersection(col_cats)
#     # Check self-intersection of multi-value category
#     multi_setkeys = set().union(*[setkeys_by_cat[cat] for cat in multi_cats])
    
#     row_setkeys = set().union(*[setkeys_by_cat[cat] for cat in row_cats])
#     col_setkeys = set().union(*[setkeys_by_cat[cat] for cat in col_cats])
    
#     # Cell subgraph with the selected categories (all values)
#     G1 = nx.subgraph(G, row_setkeys.union(col_setkeys)).copy()
# #     G1 = nx.subgraph(G, [(cat, value) for cat, value in setkeys if cat in row_cats or cat in col_cats]).copy()
    
#     # Make "anti-bipartite" (make subgraph of row resp. col categories complete)
#     # (Except for mult-categories appearing in both rows and cols)
#     # (The anti-bipartite property can only be ensured if the same multi-category only appears in ONE row and column)
#     G1.add_edges_from(itertools.combinations(row_setkeys.difference(multi_setkeys), 2))
#     G1.add_edges_from(itertools.combinations(col_setkeys.difference(multi_setkeys), 2))
#     G1.add_edges_from(itertools.product(row_setkeys, multi_setkeys))
#     G1.add_edges_from(itertools.product(col_setkeys, multi_setkeys))
#     G1.add_edges_from(((cat1, value1), (cat2, value2))
#                       for (cat1, value1), (cat2, value2)
#                       in itertools.combinations(multi_setkeys, 2)
#                       if cat1 != cat2)
#     # Compute cliques
#     cliques = [clique for clique in nx.find_cliques(G1) if len(clique) >= 2 * FIELD_SIZE and len(set(cat for cat, _ in clique)) > 1]
#     del G1
#     random.shuffle(cliques)
#     bipartite_cliques = []
#     for i, clique in enumerate(cliques):
#         row_setkeys = set((cat, value) for cat, value in clique if cat in row_cats)
#         col_setkeys = set(clique).difference(row_setkeys)
        
#         if len(row_setkeys) < FIELD_SIZE or len(col_setkeys) < FIELD_SIZE:
#             continue
        
# #         print(f"#{i}, size {len(row_setkeys)}x{len(col_setkeys)}")#, "rows:", row_setkeys, "cols:", col_setkeys)
#         #set((cat, value) for cat, value in clique if cat in col_cats)
#         bipartite_cliques.append((row_setkeys, col_setkeys))
        
#         # visualize complete bipartite graph
# #         G2 = nx.Graph()
# #         G2.add_nodes_from(row_setkeys, bipartite=0)
# #         G2.add_nodes_from(col_setkeys, bipartite=1)
# #         G2.add_edges_from([(u, v) for u, v in cells.keys() if u in row_setkeys and v in col_setkeys])
# #         nx.draw(G2)
# #         plt.show()
# #         del G2
        
#         if len(bipartite_cliques) >= NUM_SETUPS_PER_ORIENTATION and NUM_SETUPS_PER_ORIENTATION is not None:
#             break
            
#     for row_setkeys, col_setkeys in bipartite_cliques:
#         rows = random.sample(row_setkeys, FIELD_SIZE)
#         cols = random.sample(col_setkeys, FIELD_SIZE)
#         setups.append((rows, cols))
    
#     if not cliques:
#         print("Found no cliques for category setup", (row_cats, col_cats))
#     del cliques
        
# # for rows, cols in setups:
# #     print(rows, cols)


# games = [Game(values=values["Country"].tolist(),
#               cells=[[get_cell(cells, row, col) for col in cols] for row in rows],
#               rows=[(nominal_categories[cat], value) for cat, value in rows],
#               cols=[(nominal_categories[cat], value) for cat, value in cols]) for rows, cols in setups]

# for game in games:
#     display(game.to_dataframe())
# #     print(list(G1.nodes))

In [None]:
# Generate all possible combinations of categories along the rows/columns (FIELD_SIZE)
# Can use same category multiple times, but not in both a row and a column.



In [None]:
# Island check: countries without land neighbors. The column is "neighbors" (renamed from
# "neighbors_t"); "neighbours" does not exist and raised KeyError.
df["neighbors"].apply(lambda x: not x)

## Category Ideas

- Starting/ending with letter
- Capital starting/ending with letter
- Top/Bottom 20 (area/population)
- (dynamic) Bigger/smaller than X, or more/less populated than X
- Island?
- Landlocked?

In [None]:
# nominal_categories = {
#     "Continent": df["Continent"],
#     "Starting letter": df["Country"].apply(lambda x: x[0].upper()),
#     "Ending letter": df["Country"].apply(lambda x: x[-1].upper()),
#     "Capital starting letter": df["Capital"].apply(lambda x: x[0].upper()),
#     "Capital ending letter": df["Capital"].apply(lambda x: x[-1].upper()),
# }
# bool_categories = {
#     "Island": df["neighbours"].apply(lambda x: not x),
#     "Landlocked": None,
#     "Top 20 Area": df.ISO.isin(df.nlargest(10, 'Area(in sq km)').ISO),
#     "Bottom 20 Area": df.ISO.isin(df.nsmallest(10, 'Area(in sq km)').ISO),
#     "Top 20 Pop.": df.ISO.isin(df.nlargest(10, 'Population').ISO),
#     "Bottom 20 Pop.": df.ISO.isin(df.nsmallest(10, 'Population').ISO),
# }
# values = pd.concat([df[["ISO", "Country"]], pd.DataFrame(nominal_categories), pd.DataFrame(bool_categories)], axis=1)
# values