In [100]:
# emv.py -- compute the expected monetary value of a decision tree
#        -- use the composite design pattern for structure, and the
#           visitor pattern for operations (just emv())

import math
from typing import cast, List, Tuple, Union

In [101]:
# Using strings as types below since these are forward references
# (the Visitor and Tree classes are mutually dependent).

class Visitor(object):
    '''
    Interface for operations over a decision tree (Visitor pattern).

    Tree.accept() calls exactly one of the methods below, chosen by the
    concrete type of the node, and ignores whatever it returns. A concrete
    visitor therefore records its result on the node it receives (Emv, for
    instance, sets the node's emv attribute), and every method is declared
    to return None.

    The node types are written as strings because Decision, Chance and
    Terminal are defined after this class (forward references).
    '''
    def visit_d(self, d: 'Decision') -> None:
        '''
        Process a Decision node.

        Parameters:
        -----------
        d : (Decision) the decision node being visited

        Return:
        -------
        Void

        Raises:
        -------
        NotImplementedError : always; concrete visitors must override it
        '''
        raise NotImplementedError('Visitor.visit_d()')

    def visit_c(self, c: 'Chance') -> None:
        '''
        Process a Chance node.

        Parameters:
        -----------
        c : (Chance) the chance node being visited

        Return:
        -------
        Void

        Raises:
        -------
        NotImplementedError : always; concrete visitors must override it
        '''
        raise NotImplementedError('Visitor.visit_c()')

    def visit_t(self, t: 'Terminal') -> None:
        '''
        Process a Terminal (leaf) node.

        Parameters:
        -----------
        t : (Terminal) the terminal node being visited

        Return:
        -------
        Void

        Raises:
        -------
        NotImplementedError : always; concrete visitors must override it
        '''
        raise NotImplementedError('Visitor.visit_t()')

In [102]:
class Tree(object):
    '''
    Base class of every node in a decision tree (Composite pattern).

    Each node has a name, the probability of reaching it from a parent
    Chance node, and its expected monetary value (emv). Both numeric
    attributes start as float('nan'): probability is only assigned by
    Chance.add_child(), and emv only by a visitor such as Emv.
    '''
    def __init__(self, name: str) -> None:
        '''
        Create a node with the given name and unset (nan) probability and
        emv.

        Parameters:
        -----------
        name : (str) the label shown for this node when it is printed

        Return:
        -------
        Void
        '''
        self.name = name
        self.probability = float('nan')
        self.emv = float('nan')

    def accept(self, v: Visitor) -> None:
        '''
        Hand this node to the visitor method that matches its type:
        visit_d for a Decision, visit_c for a Chance and visit_t for a
        Terminal. Only this node is dispatched here; NonTerminal overrides
        this method to visit the children first.

        Parameters:
        -----------
        v : (Visitor) the operation to apply, e.g. an Emv instance

        Return:
        -------
        Void

        Raises:
        -------
        TypeError : if the node is not a Decision, Chance or Terminal
        '''
        # isinstance() rather than type(self) == ... so that subclasses of
        # the concrete node types are dispatched too, instead of being
        # silently skipped.
        if isinstance(self, Decision):
            v.visit_d(self)
        elif isinstance(self, Chance):
            v.visit_c(self)
        elif isinstance(self, Terminal):
            v.visit_t(self)
        else:
            raise TypeError('cannot visit a node of type %s'
                            % type(self).__name__)

    def prn(self, level: int) -> None:
        '''
        Print this node on one line, indented by 4 spaces per level. The
        node's probability is printed in front of it when it has one.

        Parameters:
        -----------
        level : (int) depth of the node, used for the indentation

        Return:
        -------
        Void
        '''
        print(' '*4*level, end='')
        # Only nodes added through Chance.add_child() have a probability;
        # every other node (root, children of a Decision) keeps nan.
        if not math.isnan(self.probability):
            print(self.probability, '', end='')
        print(self)

    def __str__(self) -> str:
        '''
        Return "<first letter of the class name>(<name>, <emv>)", e.g.
        "T(Crash, -3600)" for a Terminal.
        '''
        return '%s(%s, %s)' % (type(self).__name__[0], self.name, self.emv)

In [103]:
class NonTerminal(Tree):
    '''
    Common base of the inner nodes of a decision tree (Decision and
    Chance). On top of what Tree provides, it owns an ordered list of
    child subtrees and forwards accept() and prn() down to them.
    '''

    def __init__(self, name: str) -> None:
        '''
        Initialise the node through Tree.__init__ and give it an empty
        list of children.

        Parameters:
        -----------
        name : (str) the label of this node

        Return:
        -------
        Void
        '''
        super().__init__(name)
        self.children = list()  # type: List[Tree]

    def accept(self, v: Visitor) -> None:
        '''
        Visit the subtree rooted at this node in post-order. Every child
        subtree is handed to the visitor before the node itself, so a
        visitor such as Emv finds the children's results already stored.

        Parameters:
        -----------
        v : (Visitor) the operation to apply, e.g. an Emv instance

        Return:
        -------
        Void
        '''
        for child in self.children:
            child.accept(v)
        # All children are done; now dispatch this node itself.
        super().accept(v)

    def prn(self, level: int) -> None:
        '''
        Print this node at the given indentation level, then every child
        subtree one level deeper (pre-order).

        Parameters:
        -----------
        level : (int) depth of this node, used for the indentation

        Return:
        -------
        Void
        '''
        super().prn(level)
        deeper = level + 1
        for child in self.children:
            child.prn(deeper)


class Decision(NonTerminal):
    '''
    A node where the decision maker picks one of several alternatives.

    Its children can be any kind of Tree (Decision, Chance or Terminal)
    and carry no probability. When visited by Emv, this node gets the
    maximum emv of its children. The constructor is inherited unchanged
    from NonTerminal.
    '''
    def add_child(self, t: Tree) -> None:
        '''
        Append t as a new alternative of this decision.

        Parameters:
        -----------
        t : (Tree) the subtree for this alternative; any node type
            (Decision, Chance or Terminal) is accepted

        Return:
        -------
        Void
        '''
        self.children.append(t)


class Chance(NonTerminal):
    '''
    A node whose outcome is decided by chance.

    Each child is an outcome reached with some probability, which is
    stored on the child's probability attribute. When visited by Emv, this
    node gets the probability-weighted sum of its children's emv. The
    constructor is inherited unchanged from NonTerminal.

    The probabilities of the children are not checked to sum to 1.
    '''
    def add_child(self, p: float, t: Tree) -> None:
        '''
        Append t as a possible outcome of this chance node, reached with
        probability p.

        Parameters:
        -----------
        p : (float) probability of this outcome, in the closed range [0, 1]
        t : (Tree) the subtree for this outcome; any node type
            (Decision, Chance or Terminal) is accepted

        Return:
        -------
        Void

        Raises:
        -------
        ValueError : if p is outside [0, 1] or is nan
        '''
        # Validate before mutating anything, so a bad call leaves both this
        # node and t unchanged. The chained comparison is False for nan.
        if not 0.0 <= p <= 1.0:
            raise ValueError('probability must be in [0, 1], got %r' % (p,))
        self.children.append(t)
        t.probability = p

In [104]:
class Terminal(Tree):
    '''
    A leaf of the decision tree: a final outcome with a fixed monetary
    value. It cannot have children. Everything else (probability, emv,
    printing, accept) is inherited from Tree, and the constructor extends
    Tree's by also storing the monetary value. When visited by Emv, this
    node's emv is set to its value.

    Parameters:
    -----------
    name : (str) the label of this outcome (passed on to Tree)
    value : (float) the monetary value of this outcome; stored as given,
        so an int stays an int
    '''

    def __init__(self, name: str, value: float) -> None:
        super().__init__(name)
        self.value = value

In [105]:
# Add the ability to do "Emv" operations
class Emv(Visitor):
    '''
    Visitor that computes the expected monetary value (EMV) of each node
    and stores it in the node's emv attribute.

    It relies on NonTerminal.accept() visiting the children before their
    parent, so each child's emv is already set when the parent is visited.
    '''
    def visit_d(self, d: Decision) -> None:
        '''
        Set the emv of a Decision node to the best (maximum) emv among its
        alternatives.

        Parameters:
        -----------
        d : (Decision) the decision node; its children must already have
            been visited

        Return:
        -------
        Void

        Raises:
        -------
        ValueError : if d has no children (there is nothing to choose from)
        '''
        if not d.children:
            # max() would raise ValueError anyway, but without naming the node.
            raise ValueError('Decision node %r has no children' % (d.name,))
        d.emv = max(t.emv for t in d.children)

    def visit_c(self, c: Chance) -> None:
        '''
        Set the emv of a Chance node to the probability-weighted sum of its
        children's emv. A Chance node without children gets an emv of 0.

        Parameters:
        -----------
        c : (Chance) the chance node; its children must already have been
            visited

        Return:
        -------
        Void
        '''
        c.emv = sum(t.probability * t.emv for t in c.children)

    def visit_t(self, t: Terminal) -> None:
        '''
        Set the emv of a Terminal node to its monetary value.

        Parameters:
        -----------
        t : (Terminal) the terminal node whose value is copied into emv

        Return:
        -------
        Void
        '''
        t.emv = t.value

In [106]:
def _solve_and_report(root: Tree) -> float:
    '''
    Evaluate the tree rooted at root with an Emv visitor, print the
    annotated tree followed by the root's EMV, and return that EMV.

    Parameters:
    -----------
    root : (Tree) the root node of a fully built decision tree

    Return:
    -------
    The emv computed for root
    '''
    root.accept(Emv())
    root.prn(0)
    print('\nFinal value:')
    print(root.emv, '\n')
    return root.emv


def ex1() -> float:
    '''
    Build, evaluate and print the "Equipment Problem" example tree.

    Return:
    -------
    The EMV of the root decision (also printed)
    '''
    a = Decision('Equipment Problem')
    b = Decision("Don't move equipment")
    c = Chance('Do nothing')
    d = Chance('Build platform')

    a.add_child(Terminal('Move equipment', -1800))
    a.add_child(b)

    b.add_child(c)
    b.add_child(d)

    c.add_child(0.73, Terminal('Normal water level', 0))
    c.add_child(0.25, Terminal('High water level', -10000))
    c.add_child(0.02, Terminal('Flood waters', -60000))

    d.add_child(0.98, Terminal('Normal-to-high water levels', -500))
    d.add_child(0.02, Terminal('Flood waters', -60500))

    return _solve_and_report(a)


def ex2() -> float:
    '''
    Build, evaluate and print the second example tree ("Crash" vs. the
    uncertain chance node b).

    Return:
    -------
    The EMV of the root decision (also printed)
    '''
    a = Decision('a')
    b = Chance('b')
    c = Decision('c')
    d = Decision('d')
    e = Chance('e')
    f = Chance('f')

    a.add_child(Terminal('Crash', -3600))
    a.add_child(b)

    b.add_child(0.5, c)
    b.add_child(0.1, d)
    b.add_child(0.4, Terminal('Fair', 0))

    c.add_child(e)
    c.add_child(Terminal("Don't Crash", -5000))

    e.add_child(0.6, Terminal('Save 1 Week', -6000))
    e.add_child(0.3, Terminal('Save 2 Weeks', -4500))
    e.add_child(0.1, Terminal('Save 3 Weeks', -3000))

    d.add_child(f)
    d.add_child(Terminal("Don't Crash", -20000))

    f.add_child(0.7, Terminal('Save 2 Weeks', -15000))
    f.add_child(0.2, Terminal('Save 3 Weeks', -12500))
    f.add_child(0.1, Terminal('Save 4 Weeks', -10000))

    return _solve_and_report(a)

In [107]:
if __name__ == '__main__':
    # Run every example in order; each one prints its own tree and result.
    for example in (ex1, ex2):
        example()

D(Equipment Problem, -1700.0)
    T(Move equipment, -1800)
    D(Don't move equipment, -1700.0)
        C(Do nothing, -3700.0)
            0.73 T(Normal water level, 0)
            0.25 T(High water level, -10000)
            0.02 T(Flood waters, -60000)
        C(Build platform, -1700.0)
            0.98 T(Normal-to-high water levels, -500)
            0.02 T(Flood waters, -60500)

Final value:
-1700.0 

D(a, -3600)
    T(Crash, -3600)
    C(b, -3900.0)
        0.5 D(c, -5000)
            C(e, -5250.0)
                0.6 T(Save 1 Week, -6000)
                0.3 T(Save 2 Weeks, -4500)
                0.1 T(Save 3 Weeks, -3000)
            T(Don't Crash, -5000)
        0.1 D(d, -14000.0)
            C(f, -14000.0)
                0.7 T(Save 2 Weeks, -15000)
                0.2 T(Save 3 Weeks, -12500)
                0.1 T(Save 4 Weeks, -10000)
            T(Don't Crash, -20000)
        0.4 T(Fair, 0)

Final value:
-3600 

