In [176]:
import random

# Load the candidate words, one per line. The encoding is pinned so the read
# does not depend on the platform's default codec, and blank lines are dropped
# so an empty string can never be dealt onto the board.
with open('words/game_wordpool.txt', encoding='utf-8') as f:
    word_pool = [line for line in f.read().splitlines() if line.strip()]

# Draw 25 words without replacement for this game's board; random.sample raises
# ValueError if the pool holds fewer than 25 entries.
board_words = random.sample(word_pool, k=25)

In [158]:
# Fixed 25-word board, written as five rows of five words.
# NOTE(review): when the notebook is run top to bottom this assignment replaces
# whatever `board_words` held before it — confirm that is intended.
board_words = """
RIVER DEPOSIT PARACHUTE BUFFALO STRING
MISSILE LUCK PORT CENTER GAS
GOLD NUT PAPER BOX SPOT
SCALE CONCERT BERMUDA INDIA SLUG
ROULETTE MEXICO POST CAR KID
""".split()

In [177]:
from sentence_transformers import SentenceTransformer, util

# Sentence-embedding model shared by the board words and the clue.
model = SentenceTransformer(
    model_name_or_path='all-MiniLM-L12-v2',
)

In [178]:
# Encode the whole board in one call so the model can batch the words instead
# of running a separate encode per word. Wrapping the result in list() splits
# the returned array into one vector per word, in the same order as
# `board_words`, which is the layout the per-word version produced.
board_words_embeddings = list(model.encode(board_words))

In [179]:
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
)
# Register the "default" connection to the Milvus server on localhost:19530.
connections.connect(alias="default", host="localhost", port="19530")

In [180]:
# Primary key: caller-supplied integer id (auto_id is off, so ids come from the insert payload).
id_field = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False)

# The board word itself, stored as a string of at most 32 characters.
word_field = FieldSchema(name="word", dtype=DataType.VARCHAR, max_length=32)

# The word's embedding; dim is fixed at 384 and every inserted vector must have that length.
embedding_field = FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=384)

fields = [id_field, word_field, embedding_field]
schema = CollectionSchema(fields=fields, description="Embeddings of the Codenames word pool")

# Handle to the collection that holds the board embeddings.
board_db = Collection(name="sbert_board_embeddings", schema=schema)

In [181]:
# Column-wise insert payload in schema field order: ids, words, vectors.
# Ids are simply the positions 0..len(board_words)-1.
row_ids = list(range(len(board_words)))
entries = [row_ids, board_words, board_words_embeddings]

In [182]:
# Clear out rows from a previous board. Ids written by this notebook start at 0,
# so the filter "id > -1" matches every one of them.
board_db.delete(expr="id > -1")
board_db.flush()

In [183]:
# Write the current board (ids, words, vectors) into the collection, then flush.
board_db.insert(data=entries)
board_db.flush()

In [184]:
# Show the entity count the collection reports after the insert.
# NOTE(review): the recorded output is 50 although a board has only 25 words —
# presumably rows removed by the earlier delete are still being counted; confirm.
board_db.num_entities

50

In [185]:
# Index settings for the "embeddings" vector field: IVF_FLAT with 128 clusters,
# compared by cosine similarity.
index = dict(
    index_type="IVF_FLAT",
    metric_type="COSINE",
    params=dict(nlist=128),
)
board_db.create_index(field_name="embeddings", index_params=index)

Status(code=0, message=)

In [186]:
# Load the collection before running the search further down.
board_db.load()

In [187]:
# Display the board words currently held in `board_words`.
board_words

['SCHOOL',
 'FLUTE',
 'FOOT',
 'TAIL',
 'DATE',
 'MOSCOW',
 'CYCLE',
 'KETCHUP',
 'POLE',
 'FALL',
 'BOX',
 'KID',
 'TRAIN',
 'MERCURY',
 'AFRICA',
 'ATLANTIS',
 'POINT',
 'GREEN',
 'BERMUDA',
 'TIE',
 'FLY',
 'SMUGGLER',
 'CHEST',
 'HOOD',
 'EGYPT']

In [197]:
# Clue to score against the board, embedded with the same model as the board words.
clue = "poland"
clue_embedding = model.encode(sentences=clue)

In [198]:
# Search settings: cosine metric, probing 10 clusters per query.
# NOTE(review): the index in this notebook is built with nlist=128, so
# nprobe=10 makes the search approximate — confirm that is acceptable for a
# 25-word board.
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}

# One query vector (the clue); ask for the 5 best-matching board words and
# return their id and word alongside the score.
vectors_to_search = [clue_embedding]
result = board_db.search(
    data=vectors_to_search,
    anns_field="embeddings",
    param=search_params,
    limit=5,
    output_fields=["id", "word"],
)

In [199]:
# `result` holds one hit list per query vector; print a separator, then each
# hit of that query on its own line.
for hits in result:
    print("====")
    for hit in hits:
        print(hit.entity)

====
id: 14, distance: 0.5735673904418945, entity: {'id': 14, 'word': 'AFRICA'}
id: 5, distance: 0.5577887296676636, entity: {'id': 5, 'word': 'MOSCOW'}
id: 24, distance: 0.5297727584838867, entity: {'id': 24, 'word': 'EGYPT'}
id: 8, distance: 0.40159177780151367, entity: {'id': 8, 'word': 'POLE'}
id: 23, distance: 0.3251984119415283, entity: {'id': 23, 'word': 'HOOD'}
