# Lab 3: Policy Search

## Task

Write agents able to play [*Nim*](https://en.wikipedia.org/wiki/Nim), with an arbitrary number of rows and an upper bound $k$ on the number of objects that can be removed in a turn (a.k.a., *subtraction game*).

The player taking the last object wins.

* Task3.3: An agent using minmax

In [79]:
from nim_utils import Nimply, Nim 
from functools import lru_cache

## **MinMax for Nim**

In [87]:
# Check terminal state
def check_teminal(rows, is_maximizing) -> int:
    if sum(rows) == 0:
        return -1 if is_maximizing else 1

# Get moves
def possible_new_states(rows):
    # row order is not imortant, to avoid different symmetrical solutions, order rows
    sorted_rows = tuple(sorted(rows)) 
    for row, num_objs in enumerate(sorted_rows):
        for remain in range(num_objs):
            yield sorted_rows[:row] + (remain,) + sorted_rows[row + 1 :]

# Minmax
@lru_cache
def minmax(rows, is_maximizing, alpha=-1, beta=1) -> Nimply:
    if (score := check_teminal(rows, is_maximizing)) is not None:
        return score

    scores = []
    for new_state in possible_new_states(rows):
        scores.append(
            score := minmax(new_state, not is_maximizing, alpha, beta)
        )
        if is_maximizing:
            # Update alpha: min score of maximizing player
            alpha = max(alpha, score)
        else:
            # Update beta: maximum score of minimizing player
            beta = min(beta, score)
        # Do not consider further moves (maximizer already found a better solution than any of the unexplored ones)
        if beta <= alpha:
            break
    return (max if is_maximizing else min)(scores)
    

# Best move
def best_move(state: Nim) -> Nimply:
    return max(
        (minmax(new_state, is_maximizing=False), new_state) for new_state in possible_new_states(state.rows)
    )

### **MinMax Strategy**

In [95]:
def minmax_strategy(board: Nim) -> Nimply:
    _, new_state = best_move(board)
    for idx, (curr_row, new_row) in enumerate(zip(board.rows, new_state)):
        print(curr_row, new_row)
        if curr_row != new_row:
            return Nimply(idx, curr_row - new_row)


In [96]:
minmax_strategy(nim)

1 1
3 3
5 5
7 6


Nimply(row=3, num_objects=1)