In [1]:
from ai import ai_by_score, dprint
from marubatsu import Marubatsu
from copy import deepcopy

@ai_by_score
def ai_mtdf(mb, debug=False, shortest_victory=False, init_ab=False, use_tt=False,
            f=0, ai_for_mo=None, params=None, sort_allnodes=False, calc_count=False):
    """Score the position ``mb`` with the MTD(f) minimax algorithm.

    MTD(f) repeatedly runs a fail-soft alpha-beta search with a null window
    (beta - 1, beta). Each search tells whether the minimax value is >= beta
    (raising ``lbound``) or < beta (lowering ``ubound``). The loop stops once
    the two bounds meet.

    Args:
        mb: Marubatsu position to evaluate.
        debug: if True, print the number of searched nodes via ``dprint``.
        shortest_victory: if True, terminal wins are scored by move count:
            an O win scores (11 - move_count) / 2 and an X win scores
            (move_count - 10) / 2. Otherwise wins score 1 / -1. Draws score 0.
        init_ab: if True, start MTD(f) with [lbound, ubound] set to
            [min_score, max_score] instead of (-inf, inf).
        use_tt: if True, cache (lower_bound, upper_bound) pairs in a
            transposition table keyed by board text.
        f: initial guess for the minimax value (first null-window position).
        ai_for_mo: optional AI used for move ordering at every node.
        params: keyword arguments forwarded to ``ai_for_mo``.
            None means no extra arguments.
        sort_allnodes: if True, order all moves by ``ai_for_mo``'s
            ``score_by_move``. Otherwise only ``ai_for_mo``'s best move is
            swapped to the front.
        calc_count: if True, return the number of searched nodes
            instead of the score.

    Returns:
        The node count when ``calc_count`` is True. Otherwise the minimax value
        from O's viewpoint, negated when ``mb.turn`` is O. Presumably this
        gives the score for the player who just moved, as ``ai_by_score``
        expects (confirm against the decorator).
    """
    # Imported once per call instead of once per searched node (it used to
    # live inside ab_search). Kept function-local rather than at cell top in
    # case util imports this module (circular import). TODO confirm and hoist.
    from util import calc_same_boardtexts

    if params is None:
        params = {}

    # Range of possible scores; used as the default TT bounds and, with
    # init_ab, as the initial MTD(f) interval.
    min_score = -2 if shortest_victory else -1
    max_score = 3 if shortest_victory else 1

    count = 0

    def order_moves(mborig):
        """Return mborig's legal moves, reordered by ai_for_mo when it is given."""
        legal_moves = mborig.calc_legal_moves()
        if ai_for_mo is None:
            return legal_moves
        if sort_allnodes:
            score_by_move = ai_for_mo(mborig, analyze=True, **params)["score_by_move"]
            ranked = sorted(score_by_move.items(), key=lambda item: item[1], reverse=True)
            return [move for move, _ in ranked]
        # Only swap the best move to the front; the rest keep their order.
        bestmove = ai_for_mo(mborig, rand=False, **params)
        index = legal_moves.index(bestmove)
        legal_moves[0], legal_moves[index] = legal_moves[index], legal_moves[0]
        return legal_moves

    def ab_search(mborig, tt, alpha=float("-inf"), beta=float("inf")):
        """Fail-soft alpha-beta search of mborig; scores are from O's viewpoint.

        ``tt`` maps board text to a (lower_bound, upper_bound) pair on the
        node's value. It is read and written only when ``use_tt`` is True.
        """
        nonlocal count
        count += 1
        if mborig.status == Marubatsu.CIRCLE:
            return (11 - mborig.move_count) / 2 if shortest_victory else 1
        elif mborig.status == Marubatsu.CROSS:
            return (mborig.move_count - 10) / 2 if shortest_victory else -1
        elif mborig.status == Marubatsu.DRAW:
            return 0

        if use_tt:
            boardtxt = mborig.board_to_str()
            if boardtxt in tt:
                lower_bound, upper_bound = tt[boardtxt]
                if lower_bound == upper_bound:
                    return lower_bound          # exact value already known
                elif upper_bound <= alpha:
                    return upper_bound          # known to fail low
                elif beta <= lower_bound:
                    return lower_bound          # known to fail high
                else:
                    # Narrow the window with what the table already proves.
                    alpha = max(alpha, lower_bound)
                    beta = min(beta, upper_bound)
            else:
                lower_bound = min_score
                upper_bound = max_score

        # Window actually searched; decides which bound the result proves.
        alphaorig = alpha
        betaorig = beta

        legal_moves = order_moves(mborig)
        if mborig.turn == Marubatsu.CIRCLE:
            score = float("-inf")
            for x, y in legal_moves:
                childmb = deepcopy(mborig)
                childmb.move(x, y)
                score = max(score, ab_search(childmb, tt, alpha, beta))
                if score >= beta:
                    break
                alpha = max(alpha, score)
        else:
            score = float("inf")
            for x, y in legal_moves:
                childmb = deepcopy(mborig)
                childmb.move(x, y)
                score = min(score, ab_search(childmb, tt, alpha, beta))
                if score <= alpha:
                    break
                beta = min(beta, score)

        if use_tt:
            if score <= alphaorig:
                upper_bound = score             # failed low: score is an upper bound
            elif score < betaorig:
                lower_bound = score             # inside the window: exact value
                upper_bound = score
            else:
                lower_bound = score             # failed high: score is a lower bound
            # Store under every equivalent board text that calc_same_boardtexts
            # returns (presumably symmetric boards; confirm in util).
            for boardtxt in calc_same_boardtexts(mborig):
                tt[boardtxt] = (lower_bound, upper_bound)
        return score

    lbound = min_score if init_ab else float("-inf")
    ubound = max_score if init_ab else float("inf")

    tt = {}
    # MTD(f) driver. Using beta = f + 1 when f equals lbound (instead of
    # max(f, lbound + 1)) assumes integer-valued scores. This is presumably
    # true because O wins on odd and X on even move counts (confirm).
    while lbound != ubound:
        beta = f + 1 if lbound == f else f
        f = ab_search(mb, tt, alpha=beta - 1, beta=beta)
        if f >= beta:
            lbound = f
        else:
            ubound = f
    score = f

    dprint(debug, "count =", count)
    if calc_count:
        return count
    if mb.turn == Marubatsu.CIRCLE:
        score *= -1
    return score

In [2]:
from util import Check_solved

# Check ai_mtdf against every combination of the three boolean options, in
# the order (shortest_victory, use_tt, init_ab) with init_ab varying fastest.
flag_combinations = [
    (sv_flag, tt_flag, ab_flag)
    for sv_flag in [False, True]
    for tt_flag in [False, True]
    for ab_flag in [False, True]
]
for shortest_victory, use_tt, init_ab in flag_combinations:
    print(f"sv: {shortest_victory}, use_tt: {use_tt}, init_ab: {init_ab}")
    params = dict(
        shortest_victory=shortest_victory,
        init_ab=init_ab,
        use_tt=use_tt,
    )
    Check_solved.is_strongly_solved(ai_mtdf, params=params)

sv: False, use_tt: False, init_ab: False


100%|██████████| 431/431 [00:04<00:00, 87.39it/s] 


431/431 100.00%
sv: False, use_tt: False, init_ab: True


100%|██████████| 431/431 [00:03<00:00, 124.23it/s]


431/431 100.00%
sv: False, use_tt: True, init_ab: False


100%|██████████| 431/431 [00:03<00:00, 127.08it/s]


431/431 100.00%
sv: False, use_tt: True, init_ab: True


100%|██████████| 431/431 [00:02<00:00, 149.80it/s]


431/431 100.00%
sv: True, use_tt: False, init_ab: False


100%|██████████| 431/431 [00:04<00:00, 88.04it/s] 


431/431 100.00%
sv: True, use_tt: False, init_ab: True


100%|██████████| 431/431 [00:05<00:00, 80.79it/s] 


431/431 100.00%
sv: True, use_tt: True, init_ab: False


100%|██████████| 431/431 [00:03<00:00, 123.87it/s]


431/431 100.00%
sv: True, use_tt: True, init_ab: True


100%|██████████| 431/431 [00:03<00:00, 113.19it/s]

431/431 100.00%



