### Setup

In [None]:
import json
import faiss
import numpy as np
import chess
import chess.svg
import tensorflow as tf

from chesspos.convert import bitboard_to_board
from chesspos.binary_index import board_to_bitboard
from chesspos.utils import files_from_directory
import chesspos.embedding_index as iemb

### Download some embeddings

These files were obtained by extracting bitboards, training an embedding model, and then embedding the bitboards with that model.

In [None]:
# download embeddings
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1fiPUEBTnxzbnFvSSKspmGA-1OPRx1tnn' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1fiPUEBTnxzbnFvSSKspmGA-1OPRx1tnn" -O ../data/embeddings_d64.tar.bz2 && rm -rf /tmp/cookies.txt
# unpack
!tar -xjf ../data/embeddings_d64.tar.bz2 -C ../data/
# clean up
!mv ../data/embeddings ../data/embeddings_d64
!rm ../data/embeddings_d64.tar.bz2
# download the belonging model
!curl -L -o '../data/deep64.tar.bz2' 'https://docs.google.com/uc?export=download&id=1MHBTMx7yCJTL_l-BD72Nr3EEcwLa1myq'
# clean up
!tar -xjf ../data/deep64.tar.bz2 -C ../data/
!mv ../data/deep64 ../data/model_deep64
!rm deep64.tar.bz2

### Create and populate an index (the easy way)

In [None]:
# for help run
# !python ../tools/index_from_embedding.py -h
!python ../tools/index_from_embedding.py\
PCA16,SQ4\
../data/embeddings_d64\
--save_path ../data/embeddings_d64\
--table_id test_embedding\
--train_frac 0.001\
--chunks 10000

### Create and populate index (the hard way)

In [None]:
# --- locations -------------------------------------------------------------
# directory that is scanned for embedding files, and where the index is stored
embedding_path = "../data/embeddings_d64"
save_path = "../data/embeddings_d64"

# --- index definition ------------------------------------------------------
# faiss factory string and the dimension of the stored embeddings
factory_string = "PCA16,SQ4"
embedding_dimension = 64
table_id = "test_embedding"

# --- training / loading ----------------------------------------------------
chunks = 10_000
train_frac = 0.001

# --- example queries (FEN strings) ------------------------------------------
queries = [
    "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
    "8/1R6/4p1k1/1p6/p1b2K2/P1Br4/1P6/8 b - - 8 49",
]
num_results = 10

In [None]:
# create index
# Build the faiss index described by `factory_string` for vectors of size
# `embedding_dimension`; it is trained and filled in the following cells.
index = faiss.index_factory(embedding_dimension, factory_string)

In [None]:
%%time
# train faiss index
train_file_list = files_from_directory(embedding_path, file_type="h5")
index = iemb.index_train_embeddings(train_file_list, table_id, index, chunks=chunks, train_frac=train_frac)

In [None]:
%%time
# populate faiss index
file_list = files_from_directory(embedding_path, file_type="h5")
index, table_dict = iemb.index_load_file_array(file_list, table_id, index, chunks=chunks)

In [None]:
%%time
# save index
faiss.write_index(index, f"{save_path}/{factory_string}.faiss")
json.dump( table_dict, open( f"{save_path}/{factory_string}.json", 'w' ) )

### Search index

In [None]:
# model files used to embed the queries and to reconstruct bitboards
encoder_path = "../data/model_deep64/model_encoder.h5"
decoder_path = "../data/model_deep64/model_decoder.h5"
save_path = "../data/embeddings_d64"
factory_string = "PCA16,SQ4"
# load the previously created index and its table lookup;
# the context manager closes the JSON file as soon as it has been read
with open(f"{save_path}/{factory_string}.json") as table_file:
    table_dict = json.load(table_file)
index = faiss.read_index(f"{save_path}/{factory_string}.faiss")
# query positions as FEN strings
queries = np.array([
    "r1bqk1nr/pp1pbppp/2n1p3/8/3N4/6P1/PPP1PPBP/RNBQK2R w KQkq - 3 6",
    "8/1R6/4p1k1/1p6/p1b2K2/P1Br4/1P6/8 b - - 8 49",
    "8/8/5p2/R3pkp1/5n2/5K2/8/8 w - - 0 42"
])
# number of results requested per query
num_results = 3

In [None]:
%%time
# search index
D, I, E = iemb.index_query_positions(queries, index, encoder_path,
                                     input_format='fen', num_results=num_results)

In [None]:
%%time
# retrieve the bitboards that belong to the search results
# translate the result ids in I into file / table / offset via table_dict
file, table, offset = iemb.location_from_index(I, table_dict)
# switch the table name(s) to the "position" prefix
# NOTE(review): presumably these are the tables holding the bitboards — confirm
bb_table = iemb.manipulate_prefix(table, "position")
bitboards = iemb.retrieve_elements_from_file(file, bb_table, offset)
# quick look at what was loaded
print(bitboards.shape, bitboards.dtype)

In [None]:
# retrieve the embeddings that belong to the search results
embeddings = iemb.retrieve_elements_from_file(file, table, offset)
e_shape = embeddings.shape
print(f"embedding shape {e_shape}")
# flatten all leading axes so the decoder receives a 2-D batch
embeddings = embeddings.reshape((-1, e_shape[-1]))

# reconstruct with decoder, then restore the leading axes of the embeddings
decoder = tf.keras.models.load_model(decoder_path)
decoded_bitboards = decoder(embeddings).numpy().reshape((*e_shape[:-1], -1))
print(f"reconstructed bitboard shape {decoded_bitboards.shape}")

In [None]:
# convert bitboards to fen
def fen_converter(bb):
    """Turn one bitboard into a board via `bitboard_to_board` and return its FEN."""
    return bitboard_to_board(bb).fen()

# vectorized variant: consumes the last axis of its input and yields one FEN
# string per remaining position
fc = np.vectorize(fen_converter, signature='(773)->()')

bb_shape = bitboards.shape

bitboards_fen = fc(bitboards)
decoded_bitboards_fen = fc(decoded_bitboards)

In [None]:
from IPython.display import HTML

# Collect the HTML fragments in a list and join them once at the end.
# Per query/result pair: a heading, the query board, the retrieved board and
# the board reconstructed by the decoder; a rule separates the queries.
fragments = []
for i in range(bb_shape[0]):
    for j in range(bb_shape[1]):
        fragments.append(
            f"<h4>Query {i} | Retrieved bitboard {j}: euclidean distance {D[i][j]} to query | Reconstructed bitboard {j}</h4>"
        )
        for fen in (queries[i], bitboards_fen[i][j], decoded_bitboards_fen[i][j]):
            fragments.append(chess.svg.board(chess.Board(fen), size=300))
    fragments.append("<hr>")
html = "".join(fragments)
HTML(html)