In [240]:
import abc
import pathlib
from typing import List, Sequence, Optional, Set, Tuple
import typing as tp
import attr
import nptyping as npt
import numpy as np
import pandas as pd
import scann
from scipy.special import softmax

In [77]:
# `os` is only needed by static type checkers for the forward reference below,
# so it is not imported at runtime.
if tp.TYPE_CHECKING:
    import os  # noqa


# Anything accepted as a filesystem path by the loaders in this notebook:
# a plain string or an `os.PathLike` object (kept as a string forward reference).
PathLike = tp.Union[str, "os.PathLike[str]"]

In [366]:
!ls ../../

directional.txt			    LICENSE    README.md  wiki-100k.txt
google-10000-english-no-swears.txt  notebooks  src	  wordlist-eng.txt


In [3]:
def regularize(list_of_tokens: List[str]) -> List[str]:
    """Return the tokens with surrounding whitespace removed and upper-cased.

    Args:
        list_of_tokens: raw tokens, e.g. the lines of a word-list file.

    Returns:
        A new list of the same length; the input list is not modified.
    """
    cleaned: List[str] = []
    for raw_token in list_of_tokens:
        cleaned.append(raw_token.strip().upper())
    return cleaned

In [78]:
class WordList:
    """Word lists used to build a board and to validate hint words.

    Attributes set by the constructor:
        words: regularized (stripped, upper-cased) lines of ``wordlist_path``.
        illegals: set of regularized lines from every file in ``illegals_paths``.
        allowed: set of regularized lines from every file in ``allowed_paths``,
            plus every entry of ``words``.
    """

    def __init__(
            self,
            wordlist_path: str,
            illegals_paths: Optional[List[str]] = None,
            allowed_paths: Optional[List[str]] = None
    ):
        """Read the word list and the optional illegal/allowed lists from disk.

        Args:
            wordlist_path: file with one candidate board word per line.
            illegals_paths: files whose lines may never be used as hints.
            allowed_paths: files whose lines are acceptable hint vocabulary.
        """
        with pathlib.Path(wordlist_path).open() as handle:
            raw_words = handle.read().splitlines()
        self.words = regularize(raw_words)

        self.illegals = set()
        if illegals_paths:
            self.illegals = self.load_texts(illegals_paths)

        self.allowed = set()
        if allowed_paths:
            self.allowed = self.load_texts(allowed_paths)
        # Board words are always part of the allowed vocabulary; whether a word is
        # illegal for a *specific* board is decided later by the board itself.
        self.allowed.update(self.words)

    def load_texts(self, paths: tp.List[PathLike]) -> tp.Set[str]:
        """Return the union of the regularized lines of every file in ``paths``."""
        collected = set()
        for raw_path in paths:
            with pathlib.Path(raw_path).open() as handle:
                lines = handle.read().splitlines()
            collected |= set(regularize(lines))
        return collected

In [5]:
pathlib.Path("../../wordlist-eng.txt").exists()

True

In [6]:
# Positional arguments follow WordList.__init__: the board word list, then the
# files of illegal hint words, then the files of allowed hint vocabulary.
wordlist = WordList(
    "../../wordlist-eng.txt",
    ["../../directional.txt"],
    ["../../wiki-100k.txt", "../../google-10000-english-no-swears.txt", "../../custom_whitelist.txt"]
)

In [7]:
wordlist

<__main__.WordList at 0x7fdd4c2521f0>

In [231]:
# Shared fallback generator used by Board when no `rng` is passed.
# NOTE(review): created without a seed, so boards differ between kernel runs.
default_rng = np.random.default_rng()

In [9]:
wordlist.illegals

{'BOTTOM',
 'DOWN',
 'EAST',
 'LEFT',
 'NORTH',
 'RIGHT',
 'SOUTH',
 'TOP',
 'UP',
 'WEST'}

In [153]:
# ScaNN search parameters: passed as `num_leaves_to_search` to `.tree(...)` and as
# the argument of `.reorder(...)` when the approximate searcher is built.
NUM_LEAVES_TO_SEARCH = 300
PRE_REORDER_NUM_NEIGHBOURS = 250


# Card composition of a board: 9 + 8 + 7 + 1 = 25 labels, shuffled per board.
labels = ["BLUE"] * 9 + ["RED"] * 8 + ["BYSTANDER"] * 7 + ["ASSASSIN"]
# Team-relative label vocabulary ("OURS"/"THEIRS" replace the colour names).
bot_labels = np.array(["OURS", "THEIRS", "BYSTANDER", "ASSASSIN"])
# The two labels that correspond to playing teams.
valid_teams = {"BLUE", "RED"}
# Sorted distinct labels; used as the keys of Board.bag_state().
unique_labels = np.unique(labels).tolist()

In [11]:
len(labels)

25

In [228]:
def is_superstring_or_substring(word: str, target: str) -> bool:
    """Return True when either string contains the other (case-sensitive)."""
    if word in target:
        return True
    return target in word


class Board:
    """A 25-word board with a hidden label per word and guess-tracking state.

    Attributes:
        wordlist: the WordList the board was drawn from (also used for legality checks).
        rng: NumPy random generator driving every random choice made by the board.
        words: array of the 25 board words.
        word2index: mapping from board word to its position in ``words``.
        labels: array of the 25 labels (a permutation of the module-level ``labels``).
        chosen: boolean array, True where the word has already been guessed.
        which_team_guessing: label of the team whose turn it is.
    """

    def __init__(self, wordlist: WordList, rng: tp.Optional[object] = None) -> None:
        """Draw 25 distinct words and a random label assignment.

        Args:
            wordlist: source of board words and of the allowed/illegal hint sets.
            rng: anything accepted by ``np.random.default_rng``; when None the
                module-level ``default_rng`` is used.
        """
        self.wordlist = wordlist
        if rng is None:
            rng = default_rng
        self.rng = np.random.default_rng(rng)
        self.words = self.rng.choice(wordlist.words, 25, replace=False)
        self.word2index = {word: i for i, word in enumerate(self.words)}
        # `permutation` returns an array, which the boolean-mask helpers below rely on.
        self.labels: np.ndarray = self.rng.permutation(labels)
        self.reset_game()

    def is_related_word(self, word: str) -> bool:
        """Return True if ``word`` contains, or is contained in, any board word."""
        word = word.upper()
        return any(is_superstring_or_substring(word, target) for target in self.words)

    def is_illegal(self, word: str) -> bool:
        """Return True if ``word`` may not be used as a hint on this board.

        A word is illegal when it is related to a board word, is listed in
        ``wordlist.illegals``, or is missing from ``wordlist.allowed``.
        """
        word = word.upper()
        return (
            self.is_related_word(word)
            or word in self.wordlist.illegals
            or word not in self.wordlist.allowed
        )

    def batch_is_illegal(self, words: npt.NDArray[npt.Shape["*"], npt.typing_.Str]) -> npt.NDArray[npt.Shape["*"], npt.typing_.Bool]:
        """Vector form of ``is_illegal``: one boolean per entry of ``words``."""
        # Explicit dtype keeps the result a boolean mask even for empty input
        # (NumPy would otherwise infer float64 for an empty list).
        return np.array([self.is_illegal(w) for w in words], dtype=bool)

    def reset_game(self) -> None:
        """Mark every word as unchosen and give the first turn to BLUE."""
        self.chosen = np.array([False] * 25)
        self.which_team_guessing = "BLUE"
        # self.hint_history = []
        # self.state_history = None

    def opponent_of(self, team: str):
        """Return the other member of ``valid_teams``; asserts ``team`` is valid."""
        assert team in valid_teams
        return list(valid_teams.difference([team]))[0]

    def end_turn(self):
        """Hand the turn to the opposing team."""
        self.which_team_guessing = self.opponent_of(self.which_team_guessing)

    def is_chosen_with_index(self, word: str) -> tp.Tuple[bool, int]:
        """Return ``(already_chosen, index)`` for a board word (case-insensitive).

        Raises:
            KeyError: if ``word`` is not on the board.
        """
        # Normalise once and use the same key for the membership test and the lookup;
        # previously the lookup used the raw word, so lower-case input passed the
        # check and then failed with a bare KeyError.
        key = word.upper()
        if key not in self.words:
            raise KeyError(f"Word '{word}' is not on the board.")
        index = self.word2index[key]
        return self.chosen[index], index

    def is_chosen(self, word: str) -> bool:
        """Return whether the board word has already been chosen.

        Raises:
            KeyError: if ``word`` is not on the board.
        """
        # The result used to be computed and discarded, so this always returned None.
        return bool(self.is_chosen_with_index(word)[0])

    def choose_word(self, word: str) -> str:
        """Mark ``word`` as chosen and return its label.

        Raises:
            KeyError: if ``word`` is not on the board.
            ValueError: if ``word`` was already chosen.
        """
        chosen, index = self.is_chosen_with_index(word)
        if chosen:
            raise ValueError(f"Word '{word}' has already been chosen!")
        self.chosen[index] = True
        return self.labels[index]

    def words_that_are_label(self, label):
        """Return every board word (chosen or not) carrying ``label``."""
        return self.words[self.labels == label]

    @property
    def blue_words(self):
        return self.words_that_are_label("BLUE")

    @property
    def red_words(self):
        return self.words_that_are_label("RED")

    @property
    def bystander_words(self):
        return self.words_that_are_label("BYSTANDER")

    @property
    def assassin_words(self):
        """There is only one assassin in a regular game, but for the sake of generality, here we go!"""
        return self.words_that_are_label("ASSASSIN")

    def indices_for_label(self, label):
        """Return the positions in ``words`` whose label equals ``label``."""
        return np.where(self.labels == label)[0]

    @property
    def blue_indices(self):
        return self.indices_for_label("BLUE")

    @property
    def red_indices(self):
        return self.indices_for_label("RED")

    @property
    def bystander_indices(self):
        return self.indices_for_label("BYSTANDER")

    @property
    def assassin_indices(self):
        """There is only one assassin in a regular game, but for the sake of generality, here we go!"""
        return self.indices_for_label("ASSASSIN")

    def jump_to_random_state(self) -> None:
        """Jump to a random mid-game state.

        The game is reset, then 0-8 blue, 0-7 red and 0-6 bystander words are
        marked as chosen (the upper bound of ``rng.integers`` is exclusive), so at
        least one word of each of those labels stays unchosen. The assassin is never
        chosen. The team to move is picked at random.
        """
        self.reset_game()
        num_blue = self.rng.integers(0, 9)
        num_red = self.rng.integers(0, 8)
        num_bystanders = self.rng.integers(0, 7)
        chosen_blue = self.rng.choice(self.blue_indices, num_blue, replace=False)
        chosen_red = self.rng.choice(self.red_indices, num_red, replace=False)
        chosen_bystanders = self.rng.choice(
            self.bystander_indices, num_bystanders, replace=False
        )
        chosen_indices = np.concatenate([chosen_blue, chosen_red, chosen_bystanders])
        self.chosen[chosen_indices] = True
        self.which_team_guessing = self.rng.choice(["BLUE", "RED"])

    def bag_state(self) -> tp.Dict[str, tp.Set[str]]:
        """Return, per label in ``unique_labels``, the set of still-unchosen words."""
        return {
            label: set(self.words[(self.labels == label) & ~self.chosen])
            for label in unique_labels
        }

    def remaining_words(self) -> npt.NDArray[npt.Shape["*"], npt.typing_.Str]:
        """Return the words that have not been chosen yet, in board order."""
        return self.words[~self.chosen]

    def remaining_words_for_team(self, team: str) -> int:
        """Return how many words labelled ``team`` are still unchosen."""
        return (~self.chosen & (self.labels == team)).sum()

    def orient_label(self, my_team, opponent_team, label):
        """Map a colour label to "OURS"/"THEIRS"; other labels pass through unchanged."""
        if label == my_team:
            return "OURS"
        elif label == opponent_team:
            return "THEIRS"
        return label

    def orient_labels_for_team(self, my_team: str) -> npt.NDArray[npt.Shape["*"], npt.typing_.Str]:
        """Return all board labels re-expressed from ``my_team``'s point of view."""
        opponent_team = self.opponent_of(my_team)
        return np.array(
            [self.orient_label(my_team, opponent_team, label) for label in self.labels]
        )

In [18]:
# `rng` is never defined at module level in this notebook; use the shared generator.
default_rng.choice([0, 1])

0

In [57]:
class CliView:
    """Plain-text renderer for a ``Board``.

    Displayed words carry a ``_<first letter of label>`` suffix when their label is
    revealed, and are lower-cased once they have been chosen.
    """

    def __init__(self, board: Board):
        self.board = board

    def _cells(self):
        """Iterate over ``(word, label, chosen)`` triples in board order."""
        return zip(self.board.words, self.board.labels, self.board.chosen)

    def spymaster_words_to_display(self):
        """Return every word with its label suffix; chosen words are lower-cased."""
        displayed = []
        for word, label, is_chosen in self._cells():
            tagged = word + f"_{label[0]}"
            displayed.append(tagged.lower() if is_chosen else tagged)
        return displayed

    def operative_words_to_display(self):
        """Return the words as an operative sees them: labels only on chosen words."""
        displayed = []
        for word, label, is_chosen in self._cells():
            if not is_chosen:
                displayed.append(word)
                continue
            revealed = word + f"_{label[0]}"
            displayed.append(revealed.lower())
        return displayed

    def generic_view(self, words_to_display):
        """Print the 5x5 grid produced by ``words_to_display()`` and whose turn it is."""
        grid = np.array(words_to_display()).reshape(5, 5)
        print(grid)
        print(f"It is {self.board.which_team_guessing}'s turn.")

    def spymaster_view(self):
        """Print the board with every label revealed."""
        self.generic_view(self.spymaster_words_to_display)

    def operative_view(self):
        """Print the board with labels revealed only for chosen words."""
        self.generic_view(self.operative_words_to_display)

    def bag_words_team_view(self):
        """Return a four-line summary of the unchosen words, from the guessing team's side."""
        bags = self.board.bag_state()
        my_team = self.board.which_team_guessing
        other_team = "BLUE" if my_team != "BLUE" else "RED"
        headings = (
            ("My team", my_team),
            ("Enemy team", other_team),
            ("Neutral", "BYSTANDER"),
            ("Death", "ASSASSIN"),
        )
        lines = []
        for heading, label in headings:
            lowered = [w.lower() for w in bags[label]]
            lines.append(f"{heading}: {', '.join(lowered)}")
        return "\n".join(lines)

In [20]:
np.concatenate([np.arange(3), np.arange(2)])

array([0, 1, 2, 0, 1])

In [232]:
board = Board(wordlist)
view = CliView(board)

In [67]:
view.spymaster_view()

[['DICE_R' 'GOLD_R' 'heart_r' 'NOTE_B' 'VET_R']
 ['BOOM_B' 'hospital_b' 'COURT_B' 'SUPERHERO_B' 'WIND_B']
 ['SOUL_B' 'BELT_B' 'SCREEN_B' 'TRIP_R' 'SNOW_R']
 ['RULER_B' 'CLOAK_R' 'water_b' 'AGENT_R' 'washer_b']
 ['snowman_b' 'KANGAROO_B' 'HOLE_A' 'CENTAUR_B' 'PUMPKIN_B']]
It is BLUE's turn.


In [68]:
print(view.bag_words_team_view())

My team: court, screen, boom, ruler, pumpkin, soul, wind, belt
Enemy team: cloak, trip, dice, gold, vet, snow, agent
Neutral: superhero, kangaroo, centaur, note
Death: hole


In [23]:
board.bag_state()

{'ASSASSIN': {'EGYPT'},
 'BLUE': {'EAGLE',
  'FIGHTER',
  'FLUTE',
  'JACK',
  'JET',
  'MAPLE',
  'PASTE',
  'PISTOL',
  'TOOTH'},
 'BYSTANDER': {'DRAFT',
  'MILLIONAIRE',
  'PILOT',
  'POISON',
  'SHOP',
  'SNOWMAN',
  'WEB'},
 'RED': {'CIRCLE',
  'CONCERT',
  'CRASH',
  'ROULETTE',
  'SNOW',
  'STREAM',
  'SUIT',
  'WAKE'}}

In [24]:
board.blue_indices

array([ 3,  5,  6,  7,  8, 14, 22, 23, 24])

In [65]:
board.jump_to_random_state()

In [66]:
view.operative_view()

[['DICE' 'GOLD' 'heart_r' 'NOTE' 'VET']
 ['BOOM' 'hospital_b' 'COURT' 'SUPERHERO' 'WIND']
 ['SOUL' 'BELT' 'SCREEN' 'TRIP' 'SNOW']
 ['RULER' 'CLOAK' 'water_b' 'AGENT' 'washer_b']
 ['snowman_b' 'KANGAROO' 'HOLE' 'CENTAUR' 'PUMPKIN']]
It is BLUE's turn.


In [237]:
@attr.s(auto_attribs=True)
class Hint:
    """A hint given to a team, plus running counters of how it has been used."""

    # The hint word itself.
    word: str
    # Number of board words the hint refers to; None means no limit was stated.
    count: tp.Optional[int]
    # Team the hint is addressed to.
    team: str
    # Counters start at zero (plain defaults are equivalent under auto_attribs).
    num_guessed: int = 0
    num_guessed_correctly: int = 0


In [28]:
!ls ../../../codenames/dataset/glove

ls: cannot access '../../../codenames/dataset/glove': No such file or directory


In [29]:
glove_path = pathlib.Path("../../../codenames/dataset/glove.6B.300d.npy")


In [30]:
with glove_path.open("rb") as f:
    glove_vectors = np.load(f)

In [70]:
class TextVectorEngine(metaclass=abc.ABCMeta):
    """Abstract base for engines that map text to rows of an embedding matrix.

    Subclasses must provide ``self.vectors`` (indexed by token id) in addition to
    the abstract methods below.
    """
    # TODO: add self.vectors, self.tokens, self.token2id here too

    @abc.abstractmethod
    def is_valid_token(self, token):
        """Return whether ``token`` is known to the engine."""

    @abc.abstractmethod
    def tokenize(self, phrase):
        """Return token ids for a phrase, or a list of id lists for a sequence of phrases."""

    def vectorize(self, phrase: tp.Union[str, tp.Sequence[str]]) -> npt.NDArray:
        """Look up embedding vectors for one phrase or for a sequence of phrases.

        A single string yields one row per token. A sequence yields one row per
        phrase: token-id lists are padded to a common length by repetition and the
        token vectors of each phrase are averaged.
        """
        if not isinstance(phrase, str):
            id_lists = self.tokenize(phrase)
            id_matrix = np.array(standardize_length(id_lists))
            return self.vectors[id_matrix].mean(axis=1)
        return self.vectors[self.tokenize(phrase)]

In [34]:
def batched_norm(vec: np.ndarray) -> np.ndarray:
    """Scale every row of a batch of vectors to unit Euclidean length.

    Args:
        vec: (batch, dim)

    Returns:
        Array of the same shape; each row is divided by its own L2 norm.
    """
    row_lengths = np.linalg.norm(vec, axis=1, keepdims=True)  # (batch, 1)
    return vec / row_lengths

In [283]:
# NOTE: depends on `glove`, which is only instantiated in a later cell of this
# notebook; on a fresh kernel that cell has to run first.
glove.vectors[np.array([x in wordlist.allowed for x in glove.tokens])]

array([[ 0.04656  ,  0.21318  , -0.0074364, ...,  0.0090611, -0.20989  ,
         0.053913 ],
       [-0.076947 , -0.021211 ,  0.21271  , ...,  0.18351  , -0.29183  ,
        -0.046533 ],
       [-0.25756  , -0.057132 , -0.6719   , ..., -0.16043  ,  0.046744 ,
        -0.070621 ],
       ...,
       [-0.026489 , -0.12316  ,  0.37794  , ...,  0.24287  , -0.013818 ,
        -0.26118  ],
       [ 1.1097   , -0.033433 , -0.30873  , ...,  0.18013  ,  0.25812  ,
        -0.35347  ],
       [ 0.13382  , -0.21265  , -0.71983  , ...,  0.46108  ,  0.49777  ,
        -0.5036   ]], dtype=float32)

In [292]:
class NumpyVectorEngine(TextVectorEngine):
    """Vector engine backed by an in-memory embedding matrix and a parallel token array."""

    def __init__(
        self,
        vectors: npt.NDArray[npt.Shape["*, *"], npt.typing_.Float],
        tokens: npt.NDArray[npt.Shape["*"], npt.typing_.Str],
        normalized: bool = False,
        use_approximate: bool = True,
    ):
        """Store the embeddings and, if requested, build the approximate searcher.

        Args:
            vectors: embedding matrix, one row per token.
            tokens: token strings aligned with the rows of ``vectors``.
            normalized: when True, rows are rescaled to unit length.
            use_approximate: when True *and* ``normalized`` is True, a ScaNN
                searcher is built and stored in ``self.searcher``.
        """
        self.vectors = vectors
        self.tokens = tokens
        self.normalized = normalized
        self.use_approximate = use_approximate
        if self.normalized:
            self.vectors = batched_norm(self.vectors)
        # The searcher is only created for normalized vectors.
        if self.normalized and self.use_approximate:
            self.searcher = self._build_searcher()
        self.token2id = {token: idx for idx, token in enumerate(self.tokens)}

    def _build_searcher(self):
        """Build the ScaNN searcher over ``self.vectors``."""
        return (
            scann.scann_ops_pybind.builder(self.vectors, 20, "dot_product")
            .tree(
                num_leaves=2000,
                num_leaves_to_search=NUM_LEAVES_TO_SEARCH,
                training_sample_size=250000,
            )
            .score_ah(2, anisotropic_quantization_threshold=0.2)
            .reorder(PRE_REORDER_NUM_NEIGHBOURS)
            .build()
        )

    def is_valid_token(self, token: str) -> bool:
        """Return whether the stripped, upper-cased ``token`` is in the vocabulary."""
        canonical = token.strip().upper()
        return canonical in self.token2id

    def is_tokenizable(self, phrase: str) -> bool:
        """Return True when tokenizing ``phrase`` produces no unknown (None) token."""
        return not any(token is None for token in self.tokenize(phrase))

    def tokenize(self, phrase):
        """Simple one-word tokenization. Ignores punctuation.

        A string is split on whitespace and each piece mapped to its id (None when
        unknown). Any other input is treated as a sequence of phrases and
        tokenized element by element.
        """
        if not isinstance(phrase, str):
            return [self.tokenize(token) for token in regularize(phrase)]
        token_ids = []
        for piece in phrase.strip().upper().split():
            token_ids.append(self.token2id[piece] if self.is_valid_token(piece) else None)
        return token_ids

    def calculate_similarity_to_word_vector(
        self, word_vector: npt.NDArray
    ) -> npt.NDArray:
        """Return similarity of ``word_vector`` rows against every stored vector.

        With normalized storage this is a plain dot product; otherwise the cosine
        similarity is computed explicitly.
        """
        if not self.normalized:
            return batched_cosine_similarity(word_vector, self.vectors)
        return word_vector @ self.vectors.T


class Glove(NumpyVectorEngine):
    """Vector engine loaded from a ``.npy`` embedding matrix and a token file."""

    def __init__(
        self,
        glove_vector_path: PathLike,
        glove_tokens_path: PathLike,
        normalized: bool = False,
        use_approximate: bool = True,
        wordlist: tp.Optional[WordList] = None,
    ):
        """Load vectors and tokens, optionally restricted to an allowed vocabulary.

        Args:
            glove_vector_path: path of the ``.npy`` file holding the embedding matrix.
            glove_tokens_path: text file with one token per line, aligned with the
                rows of the matrix.
            normalized: forwarded to ``NumpyVectorEngine``.
            use_approximate: forwarded to ``NumpyVectorEngine``.
            wordlist: when given, only tokens present in ``wordlist.allowed`` are kept.

        Raises:
            AssertionError: if either path does not exist.
        """
        gv_path = pathlib.Path(glove_vector_path)
        gt_path = pathlib.Path(glove_tokens_path)
        assert gv_path.exists(), f"GloVe vector file not found: {gv_path}"
        assert gt_path.exists(), f"GloVe token file not found: {gt_path}"
        with gv_path.open("rb") as f:
            # Read from the handle that was just opened (it was previously ignored).
            vectors = np.load(f)
        with gt_path.open() as f:
            tokens = np.array(regularize(f.read().splitlines()))

        if wordlist is not None:
            # Filter on the tokens loaded above. The original read the notebook-global
            # `glove.tokens`, which does not exist until a Glove has been built and
            # would otherwise reuse a stale instance's vocabulary.
            allowed_mask = np.array(
                [token in wordlist.allowed for token in tokens], dtype=bool
            )
            vectors = vectors[allowed_mask]
            tokens = tokens[allowed_mask]

        super().__init__(vectors, tokens, normalized=normalized, use_approximate=use_approximate)


def prefix_to_word_stub(prefix: str):
    """Turn a ``/c/en/...`` URI into an upper-case phrase with spaces for underscores.

    Raises:
        IndexError: if ``prefix`` does not contain ``/c/en/``.
    """
    remainder = prefix.split('/c/en/', 1)[1]
    return remainder.replace("_", " ").upper()


def is_valid_prefix(prefix: str, wordlist: WordList) -> bool:
    """Return True for ``/c/en/`` URIs (not ``/c/en/#...``) whose phrase is in ``wordlist.allowed``."""
    is_english = prefix.startswith('/c/en/')
    if not is_english or prefix.startswith('/c/en/#'):
        return False
    return prefix_to_word_stub(prefix) in wordlist.allowed


class ConceptNet(NumpyVectorEngine):
    """Vector engine loaded from an HDF5 embedding table indexed by ``/c/en/...`` URIs."""

    def __init__(
        self,
        conceptnet_hd5_path: PathLike,
        wordlist: WordList,
        normalized: bool = False,
        use_approximate: bool = True,
    ):
        """Load the embedding table and keep only rows allowed by ``wordlist``.

        Args:
            conceptnet_hd5_path: HDF5 file readable by ``pandas.read_hdf``.
            wordlist: rows are kept when ``is_valid_prefix(index, wordlist)`` holds.
            normalized: forwarded to ``NumpyVectorEngine``.
            use_approximate: forwarded to ``NumpyVectorEngine``.

        Raises:
            AssertionError: if the path is not an existing file.
        """
        cn_path = pathlib.Path(conceptnet_hd5_path)
        assert cn_path.is_file()
        # Local name chosen so it does not shadow the notebook-level `conceptnet` engine.
        embeddings = pd.read_hdf(cn_path)

        valid_prefixes = [p for p in embeddings.index if is_valid_prefix(p, wordlist)]
        vectors = embeddings.loc[valid_prefixes].values
        # `valid_prefixes` is already filtered, so no second validity check is needed.
        tokens = np.array([prefix_to_word_stub(p) for p in valid_prefixes])

        super().__init__(vectors, tokens, normalized=normalized, use_approximate=use_approximate)

    def tokenize(self, phrase):
        """Basic phrase tokenization. Assumes everything is part of the same phrase. Ignores punctuation.

        A string that is itself a known token (possibly multi-word) maps to a
        single id; otherwise it is split on whitespace and each piece is looked up
        (None when unknown). Any other input is treated as a sequence of phrases.
        """
        if isinstance(phrase, str):
            phrase = phrase.strip().upper()
            if self.is_valid_token(phrase):
                return [self.token2id[phrase]]
            phrase_arr = phrase.split()
            return [
                self.token2id[x] if self.is_valid_token(x) else None for x in phrase_arr
            ]
        else:
            phrase = regularize(phrase)
            return [self.tokenize(token) for token in phrase]

In [35]:
def batched_cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Take the batched cosine similarity.

    Args:
        a: (batch1, dim)
        b: (batch2, dim)

    Returns:
        (batch1, batch2) matrix whose entry ``[i, j]`` is the cosine similarity of
        ``a[i]`` and ``b[j]`` (a matrix, not a scalar as previously annotated).
    """
    a_norm = batched_norm(a)  # (batch1, dim)
    b_norm = batched_norm(b)  # (batch2, dim)
    return a_norm @ b_norm.T  # (batch1, batch2)

In [36]:
?board

[0;31mType:[0m        Board
[0;31mString form:[0m <__main__.Board object at 0x7fdd33cc6be0>
[0;31mDocstring:[0m   <no docstring>


In [242]:
def find_x_in_y(x: npt.NDArray, y: npt.NDArray) -> npt.NDArray:
    """Return, for each element of ``y`` that occurs in ``x``, its index in ``x``.

    Elements of ``y`` that are absent from ``x`` are dropped from the result.
    """
    # https://stackoverflow.com/a/8251757
    haystack = np.array(x)  # Allow array-based indexing
    order = np.argsort(haystack)
    insertion_points = np.searchsorted(haystack[order], y)
    # Clip so that values sorting past the end still map to a valid position.
    candidate_positions = np.take(order, insertion_points, mode="clip")
    is_match = haystack[candidate_positions] == y
    return candidate_positions[is_match]

In [132]:
def standardize_length(ragged_matrix: tp.Sequence[tp.Sequence]) -> tp.List[tp.Sequence]:
    """Repeat each row until all rows share the least common multiple of their lengths.

    Example:
        [[3], [4, 5]] -> [[3, 3], [4, 5]]
        [[3], [4, 5], [6, 7, 8]] -> [[3, 3, 3, 3, 3, 3], [4, 5, 4, 5, 4, 5], [6, 7, 8, 6, 7, 8]]
    """
    row_lengths = [len(row) for row in ragged_matrix]
    common_length = np.lcm.reduce(row_lengths)
    repeats = np.floor_divide(common_length, row_lengths)
    padded = []
    for row, times in zip(ragged_matrix, repeats):
        padded.append(row * times)
    return padded

In [319]:
# Maps a guessing-strategy name to a callable taking
# (remaining_words, similarity_scores, limit) and returning (chosen_words, scores).
GuessStrategyLookup = tp.Dict[
    str,
    tp.Callable[[npt.NDArray, npt.NDArray, int], tp.Tuple[npt.NDArray, npt.NDArray]],
]


class GloveGuesser:
    """Hint generator and guesser driven by a word-vector engine.

    Despite the name, ``glove`` only needs the vector-engine interface used below
    (``vectorize``, ``tokens``, ``vectors``, ``is_tokenizable``, ...); this notebook
    also builds one around a ``ConceptNet`` engine.
    """

    def __init__(self, glove: Glove, board: Board, limit: int = 10, p_threshold: float = 0.05):
        """
        Args:
            glove: vector engine used for all similarity computations.
            board: the board being played.
            limit: number of hint candidates returned by ``give_hint_candidates``.
            p_threshold: minimum similarity score a guess needs to be returned by ``guess``.
        """
        self.glove = glove
        self.board = board
        self.limit = limit
        self.p_threshold = p_threshold
        self.word_suggestion_strategy_lookup = {
            "mean": self.generate_word_suggestions_mean,
            "minimax": self.generate_word_suggestions_minimax,
            "approx_mean": self.generate_word_suggestions_mean_approx,
            "approx_minimax": self.generate_word_suggestions_minimax_approx,
        }
        self.guess_strategy_lookup: GuessStrategyLookup = {
            "greedy": self.guess_greedy,
            "softmax": self.guess_softmax,
        }
        # Vectors of all 25 board words, computed once; masked by `board.chosen` later.
        self.board_vectors = self.glove.vectorize(self.board.words)

    @staticmethod
    def _top_k_indices(scores, k):
        """Return the (unordered) indices of the ``k`` largest entries of ``scores``.

        ``np.argpartition`` raises when ``k`` is not smaller than the array length,
        so in that case every index is returned instead.
        """
        if k >= len(scores):
            return np.arange(len(scores))
        return np.argpartition(-scores, k)[:k]

    def indices_illegal_words(self, chosen_words: npt.NDArray):
        """Return a boolean mask (not integer indices) marking illegal hint words."""
        return self.board.batch_is_illegal(chosen_words)

    def generate_word_suggestions_abstract(
        self,
        get_similarity_scores: tp.Callable[[tp.List[str]], npt.NDArray],
        words: tp.List[str],
        limit: int = 10,
    ) -> tp.Tuple[tp.Sequence[str], tp.Sequence[float]]:
        """Score the whole vocabulary against ``words`` and keep the top ``limit`` (unordered).

        Raises:
            ValueError: if any entry of ``words`` cannot be tokenized by the engine.
        """
        for word in words:
            if not self.glove.is_tokenizable(word):
                raise ValueError(f"Hint {word} is not a valid hint word!")
        similarity_scores = get_similarity_scores(words)
        indices = self._top_k_indices(similarity_scores, limit)
        chosen_words = self.glove.tokens[indices]
        similarity_scores = similarity_scores[indices]
        return chosen_words, similarity_scores

    def get_similarity_scores_mean(self, words: tp.List[str]) -> npt.NDArray:
        """Similarity of every vocabulary token to the mean vector of ``words``."""
        word_vector = self.glove.vectorize(words).mean(0)[None, :]
        similarity_scores = self.glove.calculate_similarity_to_word_vector(word_vector)[
            0
        ]
        return similarity_scores

    def generate_word_suggestions_mean(
        self, words: tp.List[str], limit: int = 10
    ) -> tp.Tuple[tp.Sequence[str], tp.Sequence[float]]:
        """Exact search using the mean-vector score."""
        return self.generate_word_suggestions_abstract(
            self.get_similarity_scores_mean, words, limit
        )

    def get_similarity_scores_minimax(self, words: tp.List[str]) -> npt.NDArray:
        """For every vocabulary token, its lowest similarity to any of ``words``."""
        word_vector = self.glove.vectorize(words)
        similarity_scores = self.glove.calculate_similarity_to_word_vector(
            word_vector
        ).min(axis=0)
        return similarity_scores

    def generate_word_suggestions_minimax(
        self, words: tp.List[str], limit: int = 10
    ) -> tp.Tuple[tp.Sequence[str], tp.Sequence[float]]:
        """Exact search using the worst-case (minimum) similarity score."""
        return self.generate_word_suggestions_abstract(
            self.get_similarity_scores_minimax, words, limit
        )

    def generate_word_suggestions_mean_approx(
        self, words: tp.List[str], limit: int = 10
    ) -> tp.Tuple[tp.Sequence[str], tp.Sequence[float]]:
        """Approximate nearest-neighbour search around the mean vector of ``words``.

        Requires an engine built with ``normalized=True`` and ``use_approximate=True``.
        """
        assert self.glove.normalized and self.glove.use_approximate
        word_vector = self.glove.vectorize(words).mean(0)
        chosen_words, similarity_scores = self.glove.searcher.search(
            word_vector, final_num_neighbors=limit
        )
        chosen_words = self.glove.tokens[chosen_words]
        return chosen_words, similarity_scores

    def generate_word_suggestions_minimax_approx(
        self, words: tp.List[str], limit: int = 10
    ) -> tp.Tuple[tp.Sequence[str], tp.Sequence[float]]:
        """Approximate minimax search.

        The neighbours of each target are pooled, every pooled candidate is scored
        by its minimum similarity to the targets, and the top ``limit`` are kept.
        Requires an engine built with ``normalized=True`` and ``use_approximate=True``.
        """
        assert self.glove.normalized and self.glove.use_approximate
        word_vectors = self.glove.vectorize(words)
        chosen_words = np.unique(
            self.glove.searcher.search_batched(word_vectors, final_num_neighbors=limit)[
                0
            ]
        )
        similarity_scores = np.min(
            word_vectors @ self.glove.vectors[chosen_words].T, axis=0
        )
        # The pool can hold as few as `limit` candidates (e.g. a single target),
        # which a bare argpartition with kth=limit would reject.
        indices = self._top_k_indices(similarity_scores, limit)
        chosen_words = chosen_words[indices]
        chosen_words = self.glove.tokens[chosen_words]
        similarity_scores = similarity_scores[indices]
        return chosen_words, similarity_scores

    def filter_words(
        self,
        chosen_words: npt.NDArray[npt.Shape["*"], npt.typing_.Str],
        similarity_scores: npt.NDArray[npt.Shape["*"], npt.typing_.Float],
        similarity_threshold=0.0,
    ):
        """Drop candidates that are illegal on the board or score below the threshold."""
        # Coerce to bool so that `|` also works if the mask comes back empty.
        illegal_mask = np.asarray(self.indices_illegal_words(chosen_words), dtype=bool)
        words_to_filter = illegal_mask | (similarity_scores < similarity_threshold)
        return chosen_words[~words_to_filter], similarity_scores[~words_to_filter]

    def re_rank(
        self,
        chosen_words: npt.NDArray[npt.Shape["*"], npt.typing_.Str],
        similarity_scores: npt.NDArray[npt.Shape["*"], npt.typing_.Float],
        limit: int,
    ):
        """Sort candidates by descending score and keep at most ``limit`` of them."""
        indices = np.argsort(-similarity_scores)
        chosen_words = chosen_words[indices][:limit]
        similarity_scores = similarity_scores[indices][:limit]
        return chosen_words, similarity_scores

    def give_hint_candidates(
        self, targets: tp.List[str], similarity_threshold=0.0, strategy: str = "minimax"
    ):
        """Return up to ``self.limit`` legal hint candidates for ``targets``, best first.

        Twice ``self.limit`` suggestions are generated so that enough remain after
        illegal or low-scoring words are filtered out.

        Raises:
            KeyError: if ``strategy`` is not a key of ``word_suggestion_strategy_lookup``.
        """
        generate_word_suggestions = self.word_suggestion_strategy_lookup[strategy]
        chosen_words, similarity_scores = generate_word_suggestions(
            targets, self.limit * 2
        )

        # The threshold used to be accepted but never forwarded to the filter.
        chosen_words, similarity_scores = self.filter_words(
            chosen_words, similarity_scores, similarity_threshold
        )

        return self.re_rank(chosen_words, similarity_scores, self.limit)

    def give_hint(
        self, targets: tp.List[str], similarity_threshold=0.0, strategy: str = "minimax"
    ):
        """Greedily choose the best hint.

        Raises:
            IndexError: if no candidate survives filtering.
        """
        chosen_words, _ = self.give_hint_candidates(
            targets, similarity_threshold, strategy
        )
        return chosen_words[0]

    def choose_hint_parameters(self, hint: Hint) -> tp.Tuple[str, int]:
        """Return the hint word and how many guesses to make for it.

        TODO: Add strategy mixins
        """
        num_words_remaining = self.board.remaining_words_for_team(
            self.board.which_team_guessing
        )
        if hint.count is None:
            limit = num_words_remaining
        else:
            limit = min(hint.count - hint.num_guessed_correctly, num_words_remaining)
            # More correct guesses than `count` would give a negative limit, which
            # the slicing in `guess_greedy` would read as "all but the last k".
            limit = max(limit, 0)
        return hint.word, limit

    def remaining_word_vectors(self) -> npt.NDArray:
        """Return the vectors of the board words that are still unchosen."""
        return self.board_vectors[~self.board.chosen]

    def guess_greedy(
        self, remaining_words: npt.NDArray, similarity_scores: npt.NDArray, limit: int
    ) -> tp.Tuple[npt.NDArray, npt.NDArray]:
        """Pick the ``limit`` highest-scoring words, best first."""
        indices = np.argsort(-similarity_scores)
        chosen_words = remaining_words[indices][:limit]
        similarity_scores = similarity_scores[indices][:limit]
        return chosen_words, similarity_scores

    def guess_softmax(
        self,
        remaining_words: npt.NDArray,
        similarity_scores: npt.NDArray,
        limit: int,
        temperature: float = 0.05,
    ) -> tp.Tuple[npt.NDArray, npt.NDArray]:
        """Sample ``limit`` distinct words with probability softmax(score / temperature).

        Sampling uses NumPy's global ``np.random`` state, not ``board.rng``.
        """
        chosen_words = np.random.choice(
            remaining_words,
            limit,
            p=softmax(similarity_scores / temperature),
            replace=False,
        )
        chosen_words_indices = find_x_in_y(remaining_words, chosen_words)
        return chosen_words, similarity_scores[chosen_words_indices]

    def guess(self, hint: Hint, strategy: str = "softmax") -> tp.Tuple[npt.NDArray, npt.NDArray]:
        """Guess board words for ``hint``.

        Returns:
            ``(words, scores)``: the guessed words whose similarity to the hint is at
            least ``self.p_threshold``, together with those similarity scores.

        Raises:
            ValueError: if the hint word is not a valid token of the engine.
            KeyError: if ``strategy`` is not a key of ``guess_strategy_lookup``.
        """
        word, limit = self.choose_hint_parameters(hint)
        if not self.glove.is_valid_token(word):
            raise ValueError(f"Hint {word} is not a valid hint word!")
        word_vector = self.glove.vectorize(word)
        remaining_words = self.board.remaining_words()
        remaining_word_vectors = self.remaining_word_vectors()
        similarity_scores = batched_cosine_similarity(
            word_vector, remaining_word_vectors
        )[0]
        guess_with_strategy = self.guess_strategy_lookup[strategy]

        chosen_words, similarity_scores = guess_with_strategy(remaining_words, similarity_scores, limit)
        indices_above_p_threshold = similarity_scores >= self.p_threshold
        return chosen_words[indices_above_p_threshold], similarity_scores[indices_above_p_threshold]


In [73]:
np.arange(3).min()

0

In [293]:
glove = Glove("../../../codenames/dataset/glove.6B.300d.npy", "../../../codenames/dataset/words", wordlist=wordlist, normalized=True)

2022-07-16 05:43:16.633879: I scann/partitioning/partitioner_factory_base.cc:59] Size of sampled dataset for training partition: 53044
2022-07-16 05:43:18.101845: I ./scann/partitioning/kmeans_tree_partitioner_utils.h:88] PartitionerFactory ran in 1.467920834s.


In [75]:
glove.vectors / np.linalg.norm(glove.vectors, axis=1)[:, None]

array([[ 0.00898599,  0.04114335, -0.00143521, ...,  0.00174878,
        -0.04050838,  0.01040511],
       [-0.05874773, -0.05917099,  0.03029284, ..., -0.05357432,
        -0.02812364,  0.08165886],
       [-0.02817159,  0.0030574 ,  0.0231178 , ..., -0.07676922,
        -0.00502329,  0.03069513],
       ...,
       [ 0.01319627, -0.00705923,  0.03197411, ...,  0.03806217,
         0.05397341,  0.0762725 ],
       [ 0.13020873, -0.05790341,  0.04985439, ...,  0.0120673 ,
         0.04541344, -0.02807007],
       [ 0.12707426, -0.08790484,  0.04444436, ...,  0.08578877,
         0.09657492, -0.01748439]], dtype=float32)

In [87]:
type(glove.tokens)

numpy.ndarray

In [89]:
glove.vectorize("Test").shape

(1, 300)

In [154]:
conceptnet = ConceptNet(
    '../../../conceptnet/mini.h5',
    wordlist,
    normalized=True,
)

2022-07-16 04:18:46.087522: I scann/partitioning/partitioner_factory_base.cc:59] Size of sampled dataset for training partition: 47240
2022-07-16 04:18:47.397724: I ./scann/partitioning/kmeans_tree_partitioner_utils.h:88] PartitionerFactory ran in 1.309583514s.


In [128]:
for w in wordlist.words:
    assert conceptnet.is_valid_token(w)

In [42]:
board.is_illegal("AMERICA")

False

In [320]:
guesser = GloveGuesser(glove, board)

In [321]:
cn_guesser = GloveGuesser(conceptnet, board)

In [295]:
glove.tokenize(board.remaining_words())

[[419],
 [18417],
 [439],
 [3304],
 [2391],
 [1975],
 [14862, 18649],
 [20749],
 [5935],
 [2354],
 [7257],
 [1237],
 [530],
 [2348],
 [2235],
 [493],
 [7756],
 [6560],
 [4054],
 [3835],
 [2126],
 [832],
 [2524],
 [20175],
 [2726]]

In [256]:
board.bag_state()

{'ASSASSIN': {'CLUB'},
 'BLUE': {'CAP',
  'DANCE',
  'ENGLAND',
  'LASER',
  'PILOT',
  'PLATE',
  'ROUND',
  'SCORPION',
  'SUIT'},
 'BYSTANDER': {'AZTEC',
  'KNIFE',
  'LIGHT',
  'LOCH NESS',
  'ORGAN',
  'SCREEN',
  'SHOE'},
 'RED': {'CHECK',
  'CLOAK',
  'COVER',
  'CYCLE',
  'ENGINE',
  'EUROPE',
  'IRON',
  'WAVE'}}

In [426]:
%%time
guesser.give_hint_candidates(["cap", "suit", "pilot"], strategy="mean")

CPU times: user 51 ms, sys: 4.44 ms, total: 55.4 ms
Wall time: 4.9 ms


(array(['WEARING', 'JACKET', 'UNIFORM', 'SHIRT', 'BLUE', 'WEAR', 'WORE',
        'SIMILAR', 'PANTS', 'INSTEAD'], dtype='<U68'),
 array([0.34837234, 0.34546995, 0.33026087, 0.32659402, 0.30743918,
        0.30555645, 0.30005243, 0.29982635, 0.29893532, 0.2983142 ],
       dtype=float32))

In [432]:
%%time
guesser.give_hint_candidates(["cap", "suit", "pilot"], strategy="approx_mean")

CPU times: user 3.43 ms, sys: 0 ns, total: 3.43 ms
Wall time: 2.8 ms


(array(['WEARING', 'JACKET', 'UNIFORM', 'SHIRT', 'BLUE', 'WEAR', 'WORE',
        'SIMILAR', 'PANTS', 'INSTEAD'], dtype='<U68'),
 array([0.34837234, 0.34546995, 0.33026087, 0.326594  , 0.3074392 ,
        0.30555642, 0.30005243, 0.29982635, 0.29893535, 0.2983142 ],
       dtype=float32))

In [431]:
%%time
cn_guesser.give_hint_candidates(["cap", "suit", "pilot"], strategy="approx_mean")

CPU times: user 2.56 ms, sys: 0 ns, total: 2.56 ms
Wall time: 2.21 ms


(array(['DOFFING', 'HAT', 'GLENGARRY', 'AVIATOR', 'HEADGEAR', 'HELMET',
        'VISOR', 'GARB', 'HEADDRESS', 'HATS'], dtype='<U19'),
 array([0.29665267, 0.295963  , 0.27957982, 0.27109832, 0.26435703,
        0.2633745 , 0.25370625, 0.25083083, 0.24353588, 0.24231923],
       dtype=float32))

In [264]:
%%time
cn_guesser.give_hint_candidates(["england", "laser", "pilot"], strategy="approx_mean")

CPU times: user 3.23 ms, sys: 208 µs, total: 3.44 ms
Wall time: 3.01 ms


(array(['MOSELEY', 'GLOUCESTERSHIRE', 'YORKSHIRE', 'WARWICKSHIRE',
        'BRISTOL', 'LEICESTERSHIRE', 'WILTSHIRE', 'MALDON',
        'NORTHAMPTONSHIRE', 'NOTTINGHAM'], dtype='<U19'),
 array([0.24276574, 0.22583926, 0.22396934, 0.21759802, 0.21735172,
        0.21629402, 0.21550824, 0.21187842, 0.20884211, 0.20858294],
       dtype=float32))

In [433]:
guesser.guess(Hint("WEARING", None, "BLUE"))

(array(['SUIT', 'CLOAK', 'CAP', 'COVER', 'SHOE', 'LIGHT', 'KNIFE', 'DANCE',
        'IRON'], dtype='<U11'),
 array([0.5179965 , 0.40340713, 0.41855264, 0.2653496 , 0.35332826,
        0.27642316, 0.25963306, 0.24954134, 0.16813308], dtype=float32))

In [440]:
cn_guesser.guess(Hint("WEARING", None, "BLUE"))

(array(['CAP', 'SUIT', 'CLOAK', 'SHOE', 'PLATE', 'CYCLE', 'SCORPION'],
       dtype='<U11'),
 array([0.29517464, 0.39653951, 0.37184003, 0.29065106, 0.05091697,
        0.07256906, 0.05281432]))

In [245]:
board.is_related_word("moles")

False

In [290]:
board.blue_words

array(['ROUND', 'CAP', 'SUIT', 'SCORPION', 'PILOT', 'LASER', 'ENGLAND',
       'PLATE', 'DANCE'], dtype='<U11')

In [53]:
guesser.word_suggestion_strategy_lookup["minimax"]

<bound method GloveGuesser.generate_word_suggestions_minimax of <__main__.GloveGuesser object at 0x7fdd4c289f40>>

In [54]:
batched_cosine_similarity(glove.vectors[:1], glove.vectors[:10]).shape

(1, 10)

In [55]:
glove.tokenize(" ".join(wordlist.words))

[637,
 1967,
 325,
 8334,
 12038,
 8427,
 9214,
 453,
 5239,
 12451,
 3292,
 2647,
 12339,
 603,
 26900,
 137,
 1083,
 775,
 231,
 2069,
 14924,
 4925,
 5747,
 1497,
 3045,
 960,
 3827,
 942,
 2913,
 5077,
 2499,
 11012,
 9307,
 480,
 1963,
 534,
 11869,
 1211,
 1846,
 4707,
 10002,
 6676,
 7331,
 1930,
 1641,
 9248,
 8998,
 4794,
 11035,
 41652,
 6910,
 14105,
 774,
 3539,
 351,
 569,
 1904,
 20786,
 5391,
 1784,
 5450,
 2114,
 46995,
 313,
 3845,
 511,
 1090,
 2375,
 5778,
 17489,
 132,
 6242,
 512,
 4012,
 8525,
 23755,
 449,
 2280,
 1866,
 4249,
 3990,
 3031,
 8475,
 953,
 3387,
 5139,
 5223,
 202,
 1333,
 10216,
 2005,
 2162,
 1007,
 3120,
 4124,
 2312,
 2261,
 1257,
 122,
 336,
 6744,
 1714,
 5188,
 14621,
 13308,
 1289,
 2082,
 2926,
 1737,
 7394,
 4635,
 8416,
 1560,
 7774,
 15389,
 5838,
 1598,
 1847,
 2100,
 563,
 525,
 2090,
 621,
 1791,
 807,
 3267,
 6043,
 307263,
 3510,
 1265,
 2854,
 319,
 484,
 2120,
 16677,
 2361,
 2149,
 352,
 2061,
 11532,
 387,
 186,
 851,
 9780,
 5

In [56]:
glove_vectors

array([[ 0.04656  ,  0.21318  , -0.0074364, ...,  0.0090611, -0.20989  ,
         0.053913 ],
       [-0.25539  , -0.25723  ,  0.13169  , ..., -0.2329   , -0.12226  ,
         0.35499  ],
       [-0.12559  ,  0.01363  ,  0.10306  , ..., -0.34224  , -0.022394 ,
         0.13684  ],
       ...,
       [ 0.075713 , -0.040502 ,  0.18345  , ...,  0.21838  ,  0.30967  ,
         0.43761  ],
       [ 0.81451  , -0.36221  ,  0.31186  , ...,  0.075486 ,  0.28408  ,
        -0.17559  ],
       [ 0.429191 , -0.296897 ,  0.15011  , ...,  0.28975  ,  0.32618  ,
        -0.0590532]], dtype=float32)