In [9]:
import chess
import chess.engine
import random
from collections import defaultdict
from sarfa_saliency import computeSaliencyUsingSarfa
from sys import platform as _platform
from utils import get_all_pos, get_move_obj

In [10]:
def get_engine(engine_file='./stockfish_15_x64_avx2'):
    """Start a UCI chess engine subprocess and return a python-chess handle to it.

    engine_file: path to the engine executable (defaults to a Stockfish 15
        AVX2 build in the current working directory).

    The returned ``chess.engine.SimpleEngine`` owns a child process; the caller
    is responsible for calling ``.quit()`` on it when finished.
    """
    return chess.engine.SimpleEngine.popen_uci(engine_file)

def q_values(board, candidate_actions, selected_action, multipv=3, runtime=2.0, use_optimal_action=False,
             engine=None, mate_score=40):
    """Score each candidate move of ``board`` with a UCI engine ("Q-values" for SARFA).

    Parameters
    ----------
    board : chess.Board
        Position to analyse.
    candidate_actions : iterable of chess.Move
        Moves whose values are returned.
    selected_action : str
        UCI string of the action being explained; returned unchanged unless
        ``use_optimal_action`` is set.
    multipv : int
        Number of principal variations requested from the engine.
    runtime : float
        Search time in seconds.
    use_optimal_action : bool
        If True, replace ``selected_action`` with the engine's top move.
    engine : chess.engine.SimpleEngine or None
        Engine to query; ``None`` falls back to the notebook-level
        ``chess_engine`` global (the original behaviour).
    mate_score : float
        Value (in pawns) used for a forced mate: ``+mate_score`` when the side
        to move mates, ``-mate_score`` when it gets mated.

    Returns
    -------
    (dict, str)
        ``{uci_string: score_in_pawns}`` for every candidate move, scored from
        the point of view of the side to move, and the (possibly replaced)
        selected action. Candidate moves the engine did not report (e.g. more
        legal moves than ``multipv``) get 0, which is indistinguishable from an
        equal evaluation -- keep ``multipv`` above the number of legal moves.
    """
    if engine is None:
        engine = chess_engine

    options = engine.analyse(board, chess.engine.Limit(time=runtime), multipv=multipv)
    # Terminal positions (mate/stalemate) are reported without a principal
    # variation; indexing option["pv"][0] on those would raise.
    options = [option for option in options if option.get("pv")]

    if use_optimal_action and options:
        selected_action = str(options[0]["pv"][0])

    score_per_move = {}
    for option in options:
        # .relative is the score from the side to move's point of view,
        # i.e. white() when White is to move and black() otherwise.
        score = option["score"].relative
        if score.is_mate():
            value = mate_score if score.mate() > 0 else -mate_score
        else:
            value = round(score.cp / 100.0, 2)
        score_per_move[str(option["pv"][0])] = value

    q_vals = {str(move): score_per_move.get(str(move), 0) for move in candidate_actions}
    return q_vals, selected_action

In [11]:
def removal_perturb(board, position):
    """Perturb ``board`` by deleting the piece standing on ``position``.

    ``position`` is converted to a square with ``get_move_obj``.

    Returns ``([perturbed_board], square)`` on success, or ``None`` when the
    square is empty, holds a king (kings are never removed), or the removal
    leaves the side that is not to move in check (an illegal position,
    detected with ``Board.was_into_check``).
    """
    square = get_move_obj(position)
    target = board.piece_at(square)

    if target is None or target.piece_type == chess.KING:
        return None

    candidate = board.copy()
    candidate.remove_piece_at(square)

    # Removing a blocker can expose the opponent's king; such a position
    # could never arise in a real game, so it is discarded.
    if candidate.was_into_check():
        return None

    return [candidate], square

def empty_space_perturb(board, position):
    """Perturb ``board`` by placing a pawn of the side to move on an empty square.

    ``position`` is converted to a square with ``get_move_obj``.

    Returns ``([perturbed_board], square)`` on success, or ``None`` when:
      * the square is occupied,
      * the square is on the first or eighth rank (a pawn there is an invalid
        position -- python-chess flags it as STATUS_PAWNS_ON_BACKRANK -- and
        must not be handed to the engine), or
      * the new pawn checks the side that is not to move (illegal position,
        detected with ``Board.was_into_check``).
    """
    square = get_move_obj(position)
    if board.piece_at(square) is not None:
        return None

    if chess.square_rank(square) in (0, 7):
        return None

    perturbed_board = board.copy()
    # board.turn is chess.WHITE/chess.BLACK, i.e. the colour of the side to move.
    perturbed_board.set_piece_at(square, chess.Piece(chess.PAWN, board.turn))
    if perturbed_board.was_into_check():
        return None

    return [perturbed_board], square

def opp_piece_perturb(board, position):
    """Perturb ``board`` by switching the piece on ``position`` to the other colour.

    NOTE(review): the original body stopped at ``if piece.color and board.turn:``
    with no statements after it, which is an IndentationError that kept the
    whole cell from compiling. This completion flips the colour of any non-king
    piece, mirroring the other perturbation functions. If the intent was to
    flip only the mover's own pieces, add that filter -- confirm with the author.

    Returns ``([perturbed_board], square)`` on success, or ``None`` when the
    square is empty, holds a king (each side must keep exactly one king), or
    the flip leaves the side that is not to move in check (illegal position).
    """
    square = get_move_obj(position)
    piece = board.piece_at(square)
    if piece is None or piece.piece_type == chess.KING:
        return None

    perturbed_board = board.copy()
    perturbed_board.set_piece_at(square, chess.Piece(piece.piece_type, not piece.color))
    if perturbed_board.was_into_check():
        return None

    return [perturbed_board], square

In [15]:
#########################################
num_actions = 100            # multipv: engine lines requested per analysis
runtime = 1.0                # seconds of engine search per analysis
use_optimal_action = False   # True: explain the engine's best move instead of `action`
allow_defense = True         # forwarded as computeSaliencyUsingSarfa(allow_defense_check=...)
perturb = removal_perturb    # any perturbation function, e.g. empty_space_perturb
saliency_floor = 0.001       # |saliency| below this is reported as 0
defense_threshold = 0.1      # dP below -defense_threshold marks the piece as defensive
#########################################

# Bishop Pins Rooks
FEN = "5rk1/6pp/1B1r4/5p2/8/2P5/PP6/1K4R1 w - - 0 1"
action = "b6c5"

# Bishop Checks on Open Diagonal - showcases empty space perturbation
# FEN = "6N1/1pkb4/p3pQ1p/2P1P3/3P4/4q1P1/PP5P/5R1K b - - 0 33"
# action = "d7c6"

# Backrank Check protected by Knight - identifying defensive pieces?
# FEN = "6k1/1n3ppp/8/8/8/8/7P/1K1R4 w - - 0 1"
# action = "d1d8"

# More complicated Backrank Check - identify defensive pieces
# FEN = "3r2kr/5ppp/8/1n6/3B3R/8/1PP4P/RK6 b - - 0 1"
# action = "d8d4"


chess_engine = get_engine()
try:
    board = chess.Board(FEN)
    legal_moves = set(board.legal_moves)

    # The unperturbed board is identical for every perturbation, so analyse it
    # once for all legal moves and slice per perturbation below. This halves the
    # engine time and gives every perturbation the same baseline. It also fixes
    # the explained action once, from the ORIGINAL position.
    q_vals_before_all, selected_action = q_values(board, legal_moves, action, multipv=num_actions,
                                                  runtime=runtime, use_optimal_action=use_optimal_action)
    selected_move = chess.Move.from_uci(selected_action)

    board_positions = get_all_pos()
    for pos in board_positions:
        perturbed_state = perturb(board, pos)
        if perturbed_state is None:
            continue

        perturbed_boards, _ = perturbed_state

        # list of perturbed boards allows for perturb function to test different
        # perturbations on the same state
        for i, perturbed_board in enumerate(perturbed_boards):
            candidate_actions = legal_moves.intersection(perturbed_board.legal_moves)

            # The explained move must still be playable after the perturbation
            # (previously unchecked when use_optimal_action was True).
            if selected_move not in candidate_actions:
                continue

            q_vals_before = {str(move): q_vals_before_all[str(move)] for move in candidate_actions}
            # use_optimal_action=False on purpose: keep explaining the original
            # move rather than switching to the perturbed board's best move.
            q_vals_after, _ = q_values(perturbed_board, candidate_actions, selected_action,
                                       multipv=num_actions, runtime=runtime, use_optimal_action=False)

            saliency, dP, K, QmaxAnswer, _, _ = computeSaliencyUsingSarfa(
                selected_action, q_vals_before, q_vals_after, allow_defense_check=allow_defense)
            saliency = 0 if abs(saliency) < saliency_floor else saliency

            saliency_type = ""
            if dP < -defense_threshold:
                saliency_type = "Defensive "
                saliency = abs(saliency)

            print(f'Perturbed Board {i}, Perturbed position: {pos}, Action: {selected_action}, {saliency_type}Saliency: {saliency}')
finally:
    # Shut the Stockfish subprocess down even if the loop raises, so re-running
    # this cell does not leak engine processes.
    chess_engine.quit()

Perturbed Board 0, Perturbed position: a2, Action: b6c5, Saliency: 0.23584494299623263
Perturbed Board 0, Perturbed position: b2, Action: b6c5, Saliency: 0.7961583839447218
Perturbed Board 0, Perturbed position: c3, Action: b6c5, Saliency: -0.08521809526130364
Perturbed Board 0, Perturbed position: d6, Action: b6c5, Saliency: 0.5802403549049537
Perturbed Board 0, Perturbed position: f5, Action: b6c5, Saliency: 0.2777414715410583
Perturbed Board 0, Perturbed position: f8, Action: b6c5, Saliency: 0.7671264784824567
Perturbed Board 0, Perturbed position: g1, Action: b6c5, Saliency: 0.7936239753577831
Perturbed Board 0, Perturbed position: h7, Action: b6c5, Saliency: 0.1522039529993271
