<a href="https://colab.research.google.com/github/predatorx7/boring/blob/master/3_A_AlphaBetaPruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
class Node:
    """A node of an n-ary tree, used here to model a game tree.

    Leaf (terminal) nodes carry a heuristic ``value``; inner nodes own an
    ordered list of children. ``tag`` is an optional free-form label.
    """

    def __declare_instance_variables(this) -> None:
        this.parent: Node = None
        # Root recorded when this node was attached. It goes stale when a tree
        # is built bottom-up (children before parents), so get_root() walks
        # the parent chain instead of trusting it.
        this.root: Node = None
        this.__children: list = []

    def __init__(this, child=None, children: list = None, value: float = None, tag: str = None):
        """Create a node, optionally attaching children.

        child: a single Node to attach first.
        children: a list of Nodes to attach after ``child``.
        value: heuristic value (used for leaf nodes).
        tag: optional label.
        """
        this.__declare_instance_variables()
        this.tag = tag
        this.value = value
        if child is not None:
            this.add(child)
        if children is not None:
            this.add_children(children)

    def get_neighbors(this) -> list:
        """Return this node's siblings, itself included (``[this]`` for a root)."""
        if this.parent is None:
            return [this]
        children = this.parent.get_children()
        if children is None:
            return []
        return children

    def get_first(this):
        """Return the first child of this node, or None when it has none."""
        # is_empty must be *called*: the bound method object itself is truthy.
        if this.is_empty():
            return None
        return this.__children[0]

    def is_root(this) -> bool:
        """True when this node has no parent."""
        return this.parent is None

    def is_leaf(this) -> bool:
        """True when this node has no children."""
        return this.is_empty()

    def is_inner(this) -> bool:
        """True for nodes that are neither the root nor a leaf."""
        return not (this.is_leaf() or this.is_root())

    def get_children(this) -> list:
        """Return the (live, mutable) list of children."""
        return this.__children

    def get_root(this):
        """Return the root Node of the tree containing this node."""
        # Walk up the parent links: the cached ``root`` attribute can be stale,
        # and it is None on the root node itself.
        node = this
        while node.parent is not None:
            node = node.parent
        return node

    def get_height(this) -> int:
        """Return the number of edges on the longest downward path (0 for a leaf)."""
        if this.is_empty():
            return 0
        # Manual loop on purpose: this module defines its own two-argument
        # max(), which shadows the builtin when this method runs.
        tallest = 0
        for child in this.__children:
            height = child.get_height()
            if height > tallest:
                tallest = height
        return tallest + 1

    def get_depth(this) -> int:
        """Return the number of edges between this node and the root."""
        depth = 0
        node = this
        while node.parent is not None:
            node = node.parent
            depth += 1
        return depth

    def is_empty(this) -> bool:
        """True when this node has no children."""
        return not this.__children

    def is_not_empty(this) -> bool:
        """True when this node has at least one child."""
        return not this.is_empty()

    def add(this, child) -> None:
        """Attach ``child`` (a Node) as this node's last child."""
        assert child is not None
        if this.__children is None:
            this.__children = []
        child.parent = this
        child.root = this.get_root()
        this.__children.append(child)

    def add_children(this, children: list) -> None:
        """Attach every Node in ``children``, preserving order."""
        assert children is not None
        for element in children:
            this.add(element)

    def __len__(this) -> int:
        """Return the number of nodes in the subtree rooted here, this node included."""
        return 1 + sum(len(child) for child in this.__children or ())

In [None]:
from math import inf

def max(a: float, b: float) -> float:
    """Return the larger of ``a`` and ``b``; on a tie ``b`` is returned.

    Note: this deliberately shadows the builtin ``max`` in this module and
    only accepts exactly two non-None arguments.
    """
    assert a is not None and b is not None
    return a if a > b else b

def min(a: float, b: float) -> float:
    """Return the smaller of ``a`` and ``b``; on a tie ``b`` is returned.

    Note: this deliberately shadows the builtin ``min`` in this module and
    only accepts exactly two non-None arguments.
    """
    assert a is not None and b is not None
    return a if a < b else b

class Counter:
    """A small mutable tally (here: how many tree nodes were visited)."""

    def __init__(this, count = None):
        # Start from the supplied count, or from zero when none is given.
        if count is None:
            count = 0
        this.__count = count

    def get_count(this):
        """Return the current tally."""
        return this.__count

    def increment(this):
        """Add one to the tally."""
        this.__count = this.__count + 1

    def __str__(this):
        current = this.__count
        return str(current)

class State:
    """Alpha-beta search window plus the side to move at a node.

    ``alpha`` is the lower and ``beta`` the upper bound of the window; they
    default to -inf and +inf respectively.
    """

    def __declare(this):
        this.__isMaximizer = True
        this.__isMinimizer = not this.__isMaximizer
        this.__alpha = - inf
        this.__beta = inf
        # Backing field for increment()/getcount(); without it those methods
        # raise AttributeError.
        this.__count = 0

    def __init__(this, alpha: float, beta: float, isMaximizingPlayer: bool):
        """Create a state.

        alpha, beta: window bounds, or None to keep -inf / +inf.
        isMaximizingPlayer: True when the maximizer moves (must not be None).
        """
        this.__declare()
        if alpha is not None:
            this.set_alpha(alpha)
        if beta is not None:
            this.set_beta(beta)
        assert isMaximizingPlayer is not None
        # Go through the setter so is_minimizer() stays consistent with
        # is_maximizer().
        this.set_maximizer(isMaximizingPlayer)

    def get_alpha(this) -> float:
        return this.__alpha

    def set_alpha(this, alpha: float) -> None:
        assert alpha is not None
        this.__alpha = alpha

    def get_beta(this) -> float:
        return this.__beta

    def set_beta(this, beta: float) -> None:
        assert beta is not None
        this.__beta = beta

    def is_maximizer(this) -> bool:
        return this.__isMaximizer

    def set_maximizer(this, isMaximizer: bool) -> None:
        this.__isMaximizer = isMaximizer
        this.__isMinimizer = not isMaximizer

    def is_minimizer(this) -> bool:
        return this.__isMinimizer

    def set_minimizer(this, isMinimizer: bool) -> None:
        this.__isMinimizer = isMinimizer
        this.__isMaximizer = not isMinimizer

    def get_statevalue(this) -> float:
        """Return alpha for a maximizing state, beta for a minimizing one."""
        return this.get_alpha() if this.is_maximizer() else this.get_beta()

    def increment(this) -> None:
        this.__count += 1

    def getcount(this) -> int:
        return this.__count

    def __str__(this) -> str:
        return (
        f'STATE: '
        f'alpha: {this.get_alpha()} '
        f'beta: {this.get_beta()} '
        f'isMaximizer: {this.is_maximizer()} '
        f'value: {this.get_statevalue()}')

In [None]:
def alphabetaPruning(node: Node, state: State, counter: Counter = None) -> float:
    """Evaluate ``node`` by minimax search with alpha-beta pruning.

    ```pseudo-code
    function alphabeta(node, depth, α, β, maximizingPlayer) is
        if depth = 0 or node is a terminal node then
            return the heuristic value of node
        if maximizingPlayer then
            value := −∞
            for each child of node do
                value := max(value, alphabeta(child, depth − 1, α, β, FALSE))
                α := max(α, value)
                if α ≥ β then
                    break (* β cut-off *)
            return value
        else
            value := +∞
            for each child of node do
                value := min(value, alphabeta(child, depth − 1, α, β, TRUE))
                β := min(β, value)
                if β ≤ α then
                    break (* α cut-off *)
            return value
    ```

    node: root of the (sub)tree to evaluate; nodes without children must
        carry a non-None ``value``.
    state: the (alpha, beta) window and whether the maximizer moves at
        ``node``. It is only read; the window is narrowed in local variables.
    counter: optional Counter, incremented once per visited node.

    Returns ``value`` as in the pseudo-code above (fail-soft). With the full
    (-inf, +inf) window this is the exact minimax value of ``node``.
    """
    if counter is not None:
        counter.increment()
    # A node without children is terminal. is_empty() is O(1); computing the
    # subtree height here would re-walk the whole subtree at every node.
    if node.is_empty():
        assert node.value is not None, "leaf nodes must have a value"
        return node.value
    alpha = state.get_alpha()
    beta = state.get_beta()
    if state.is_maximizer():
        value = -inf
        for child in node.get_children():
            assert isinstance(child, Node)
            value = max(
                value,
                alphabetaPruning(child, State(alpha, beta, False), counter=counter),
            )
            alpha = max(alpha, value)
            if alpha >= beta:
                break  # beta cut-off
    else:
        value = inf
        for child in node.get_children():
            assert isinstance(child, Node)
            value = min(
                value,
                alphabetaPruning(child, State(alpha, beta, True), counter=counter),
            )
            beta = min(beta, value)
            if beta <= alpha:
                break  # alpha cut-off
    return value

In [None]:
# Build the example game tree out of Node objects.
# Every leaf (terminal) node has a heuristic value; inner nodes have none.
# The tree has 30 nodes: the root, 3 children, 9 grandchildren and 17 leaves.
head: Node = Node(
    # This is the root node; its children are the branches below
    children=[
        # The first child of the root node
        Node(
            children=[
                Node(
                    # This node has only leaf (terminal) nodes as children
                    children=[
                        Node(value=-5),
                        Node(value=7),
                    ],
                ),
                Node(
                    # This node has only leaf (terminal) nodes as children
                    children=[
                        Node(value=1),
                        Node(value=-2),
                        Node(value=8),
                    ],
                ),
                Node(
                    child=Node(
                        # This is a leaf (terminal) node with a heuristic value
                        value=-4),
                ),
            ],
        ),
        # The second child of the root node
        Node(
            children=[
                Node(
                    # This node has only leaf (terminal) nodes as children
                    children=[
                        Node(value=3),
                        Node(value=7),
                    ]),
                Node(
                    child=Node(
                        # This is a leaf (terminal) node with a heuristic value
                        value=-9),
                ),
                Node(
                    # This node has only leaf (terminal) nodes as children
                    children=[
                        Node(value=2),
                        Node(value=3),
                    ]),
            ],
        ),
        # The third child of the root node
        Node(children=[
            Node(
                # This node has only leaf (terminal) nodes as children
                children=[
                    Node(value=1),
                    Node(value=-5),
                ]),
            Node(
                child=Node(
                    # This is a leaf (terminal) node with a heuristic value
                    value=8),
            ),
            Node(
                # This node has only leaf (terminal) nodes as children
                children=[
                    Node(value=1),
                    Node(value=3),
                    Node(value=2),
                ]),
        ]),
    ],
)

In [None]:
# Search from the root as the maximizing player with the full (-inf, +inf)
# window (None keeps State's defaults); ``counter`` tallies visited nodes.
counter: Counter = Counter()
optimalValue: float = alphabetaPruning(
    head, State(None, None, True), counter)
# len(head) counts every node in the tree, the root included.
totalNodes: int = len(head)

In [None]:
# Report the result. Nodes the search never visited were pruned, so
# pruned = total - visited.
print(
    f'Optimal value: {optimalValue},\n'
    f'{totalNodes - counter.get_count()} nodes pruned + '
    f'{counter} nodes traversed = {totalNodes} total nodes'
)

Optimal value: 1,
5 nodes pruned + 25 nodes traversed = 30 total nodes
