In [1]:
class Node:
    """Abstract node of the selection tree.

    Subclasses supply the four read-only properties below; the base versions
    all raise ``NotImplementedError``. The traversal helpers (``get_leaves``
    and ``__str__``) are implemented here in terms of those properties.
    """

    @property
    def has_children(self) -> bool:
        """Whether this node has child nodes to descend into."""
        raise NotImplementedError()

    @property
    def children(self) -> list['Node']:
        """The direct children of this node."""
        raise NotImplementedError()

    @property
    def Q(self) -> float:
        """Current value estimate of this node."""
        raise NotImplementedError()

    @property
    def N(self) -> int:
        """Visit count of this node."""
        raise NotImplementedError()

    def update(self, reward: float, alpha: float):
        """Fold a reward into the node's estimate; the base class refuses."""
        raise NotImplementedError("Cannot update a node which is not a leaf node")

    def get_leaves(self) -> list['LeafNode']:
        """Return every childless node below (or at) this node, depth-first."""
        # A node without children is its own single leaf.
        if not self.has_children:
            return [self]

        collected = []
        for subtree in self.children:
            collected += subtree.get_leaves()
        return collected

    def __str__(self, level=0):
        """Render the subtree, one node per line, indented by ``level`` tabs."""
        pieces = ["\t" * level + repr(self) + "\n"]
        if self.has_children:
            pieces.extend(child.__str__(level + 1) for child in self.children)
        return "".join(pieces)

class DecisionNode(Node):
    """Internal node whose statistics are aggregated from its children.

    ``Q`` is the unweighted mean of the children's ``Q`` values and ``N`` is
    the sum of the children's ``N`` values. The node stores no statistics of
    its own and does not override ``update``.
    """

    def __init__(self, children: list[Node]):
        """Store ``children`` as given (the list is not copied)."""
        self._children = children

    @Node.has_children.getter
    def has_children(self):
        """Always ``True``, even when the child list is empty."""
        return True

    @Node.children.getter
    def children(self):
        """The child list passed to the constructor."""
        return self._children

    @Node.Q.getter
    def Q(self):
        """Mean of the children's ``Q``; ``0.0`` when there are no children."""
        # Guard the division: an empty node used to raise ZeroDivisionError.
        if len(self._children) == 0:
            return 0.0
        return sum(child.Q for child in self._children) / len(self._children)

    @Node.N.getter
    def N(self):
        """Total visit count over all children (``0`` when there are none)."""
        return sum(child.N for child in self._children)

    def __repr__(self):
        return f'Children: {len(self._children)}'

class LeafNode(Node):
    """Terminal node holding its own value estimate and visit count.

    Any keyword arguments given to the constructor are stored as instance
    attributes and show up in ``repr``.
    """

    def __init__(self, **kwargs):
        """Attach ``kwargs`` as attributes, then reset ``_Q`` and ``_N`` to 0."""
        self.__dict__.update(kwargs)
        # Assigned after the kwargs so the statistics always start from zero.
        self._Q = 0
        self._N = 0

    @Node.has_children.getter
    def has_children(self):
        """Always ``False``."""
        return False

    @Node.children.getter
    def children(self):
        """An empty list, so the value is safe to iterate or take ``len()`` of."""
        # Previously returned None, which contradicted the list[Node] contract
        # declared on Node.children and raised TypeError when iterated.
        return []

    @Node.Q.getter
    def Q(self):
        """Current value estimate."""
        return self._Q

    @Node.N.getter
    def N(self):
        """Number of times ``update`` has been called."""
        return self._N

    def update(self, reward: float, alpha: float):
        """Count one visit and move ``Q`` towards ``reward`` by a step of ``alpha``."""
        self._N += 1
        self._Q += alpha * (reward - self._Q)

    def __repr__(self):
        attrs = [f'{key}: {value}' for key, value in self.__dict__.items()]
        return ', '.join(attrs)

In [2]:
import pandas as pd

missions = pd.read_csv('./data/missions.csv')

# Two-level hierarchy: the root has one DecisionNode per distinct value of the
# 'type' column, and each of those holds one LeafNode per row of that group.
tree = DecisionNode(children=[
    DecisionNode(children=[
        LeafNode(ID=mission_id, type=mission_type, target=mission_target)
        for mission_id, mission_type, mission_target
        in zip(rows['ID'], rows['type'], rows['target'])
    ])
    for _type_name, rows in missions.groupby('type')
])

print(tree)

Children: 7
	Children: 2
		ID: 0, type: action, target: 1, _Q: 0, _N: 0
		ID: 1, type: action, target: 2, _Q: 0, _N: 0
	Children: 10
		ID: 6, type: activity, target: 4, _Q: 0, _N: 0
		ID: 10, type: activity, target: 8, _Q: 0, _N: 0
		ID: 3, type: activity, target: 10, _Q: 0, _N: 0
		ID: 4, type: activity, target: 2, _Q: 0, _N: 0
		ID: 7, type: activity, target: 5, _Q: 0, _N: 0
		ID: 2, type: activity, target: 1, _Q: 0, _N: 0
		ID: 5, type: activity, target: 3, _Q: 0, _N: 0
		ID: 9, type: activity, target: 7, _Q: 0, _N: 0
		ID: 11, type: activity, target: 9, _Q: 0, _N: 0
		ID: 8, type: activity, target: 6, _Q: 0, _N: 0
	Children: 6
		ID: 14, type: episode, target: 3, _Q: 0, _N: 0
		ID: 16, type: episode, target: 5, _Q: 0, _N: 0
		ID: 17, type: episode, target: 6, _Q: 0, _N: 0
		ID: 15, type: episode, target: 4, _Q: 0, _N: 0
		ID: 13, type: episode, target: 2, _Q: 0, _N: 0
		ID: 12, type: episode, target: 1, _Q: 0, _N: 0
	Children: 2
		ID: 19, type: exp, target: 50, _Q: 0, _N: 0
		ID: 18

In [3]:
import random
import math

class Policy:
    """Interface for strategies that choose which children of a node to follow."""

    def select(self, tree: Node, n=1):
        """Choose ``n`` children of ``tree``; concrete policies must override this."""
        raise NotImplementedError

class EpsilonGreedyPolicy:
    """Epsilon-greedy choice of several distinct children.

    For each pick, a uniformly random remaining child is taken with
    probability ``epsilon``; otherwise the remaining child with the highest
    ``Q`` is taken. Exposes the same ``select(tree, n)`` signature as
    ``Policy``.
    """

    def __init__(self, epsilon: float):
        """Store the exploration probability ``epsilon``."""
        self.epsilon = epsilon

    def select(self, tree: 'Node', n=1):
        """Return a set of ``min(n, number of children)`` distinct children.

        Args:
            tree: node whose ``children`` are the candidates.
            n: how many children to pick.

        Returns:
            A ``set`` of the chosen child nodes.
        """
        # An ordered, de-duplicated list (rather than a set) keeps tie-breaking
        # in max() and the indexing in random.choice() tied to child order, so
        # runs are reproducible under a fixed random seed.
        candidates = list(dict.fromkeys(tree.children))
        selected = set()

        # Clamp so that asking for more children than exist returns them all
        # instead of failing on an exhausted candidate pool.
        for _ in range(min(n, len(candidates))):
            if random.random() < self.epsilon:
                pick = random.choice(candidates)
            else:
                pick = max(candidates, key=lambda node: node.Q)
            # Remove in BOTH branches: the exploration branch used to leave the
            # pick in the pool, so it could be drawn again and the result
            # could contain fewer than n nodes.
            candidates.remove(pick)
            selected.add(pick)

        return selected

class UCBPolicy:
    """Upper-confidence-bound choice of the ``n`` best-scoring children.

    A visited child scores ``Q + c * sqrt(2 * ln(tree.N) / child.N)``; a child
    that has never been visited scores infinity so it is tried first.
    """

    def __init__(self, c: float):
        """Store the exploration weight ``c``."""
        self.c = c

    def select(self, tree: 'Node', n=1):
        """Return a set with the ``n`` highest-scoring children of ``tree``.

        Args:
            tree: node whose ``children`` are scored; its ``N`` is the total
                visit count used in the exploration term.
            n: how many children to return.

        Returns:
            A ``set`` of at most ``n`` child nodes.
        """
        # Read the parent's count once; it was previously re-evaluated for
        # every child.
        total_visits = tree.N
        # math.log(0) raises ValueError, which happened on a fresh tree.
        log_total = math.log(total_visits) if total_visits > 0 else 0.0

        scores = {}
        for child in tree.children:
            visits = child.N
            if visits == 0:
                # Avoids division by zero and forces unexplored children first.
                scores[child] = math.inf
            else:
                scores[child] = child.Q + self.c * math.sqrt(2 * log_total / visits)

        # sorted() is stable, so equal scores keep the children's order.
        return set(sorted(scores, key=scores.get, reverse=True)[:n])

In [4]:
import numpy as np
class Bandit:
    """Applies a selection policy level by level down a tree of nodes."""

    def __init__(self, tree: 'Node', policy: 'Policy'):
        """Remember the root ``tree`` and the ``policy`` used at every level."""
        self.tree = tree
        self.policy = policy

    def select(self, n, tree=None):
        """Pick nodes by descending the tree, ``n[k]`` children at depth ``k``.

        Args:
            n: sequence of counts; ``n[0]`` children are chosen at this node
                and ``n[1:]`` is passed on to each chosen child.
            tree: node to start from; defaults to the root given at
                construction.

        Returns:
            ``tree`` itself when it has no children; otherwise a 1-D object
            ``numpy`` array with every node reached.

        Raises:
            IndexError: if ``n`` is exhausted before a childless node is reached.
            AssertionError: if ``n[0]`` exceeds the number of children.
        """
        if tree is None:
            tree = self.tree

        if not tree.has_children:
            return tree

        if len(n) == 0:
            raise IndexError(
                "n has no entry for this depth: the tree is deeper than len(n)"
            )
        assert n[0] <= len(tree.children), f"You can only select {len(tree.children)} children, but you selected {n[0]}"

        # Gather results in a flat Python list. Stacking the per-child results
        # with np.array() breaks when branches return different numbers of
        # nodes, or when a childless node sits next to a subtree.
        reached = []
        for child in self.policy.select(tree, n=n[0]):
            result = self.select(n=n[1:], tree=child)
            if isinstance(result, np.ndarray):
                reached.extend(result)
            else:
                reached.append(result)

        # Fill element-wise so numpy never tries to interpret the nodes.
        flat = np.empty(len(reached), dtype=object)
        for index, node in enumerate(reached):
            flat[index] = node
        return flat

In [5]:
# Hidden success probability for every leaf, one independent Beta(1, 1) draw
# each (Beta(1, 1) is the uniform distribution on [0, 1]).
true_probs = {leaf: np.random.beta(1, 1) for leaf in tree.get_leaves()}

In [None]:
# Simulate 1000 rounds of an epsilon-greedy agent (epsilon = 0.2) on the tree.
policy = EpsilonGreedyPolicy(epsilon=0.2)
agent = Bandit(tree, policy)

for _ in range(1000):
    # [3, 1]: choose 3 children of the root, then 1 child under each of them.
    # NOTE(review): this rebinds `missions`, which held the DataFrame read from
    # missions.csv in the earlier cell; the DataFrame is no longer reachable
    # under that name once this cell has run.
    missions: np.ndarray = agent.select([3, 1])
    for mission in missions:
        # Bernoulli reward: 1 with the leaf's hidden probability, otherwise 0.
        reward = np.random.binomial(1, true_probs[mission])
        # Step size alpha = 0.1 for the leaf's running value estimate.
        mission.update(reward=reward, alpha=0.1)

0.7396666666666667
