In [1]:
import pandas as pd

# Load the mission catalogue and order it by mission ID in one expression
# (no `inplace=True`). The original row labels are preserved, so the
# displayed index is not 0..N-1 after sorting.
missions = pd.read_csv('./data/missions.csv').sort_values('ID')

missions

Unnamed: 0,ID,type,target
7,0,action,1
27,1,action,2
13,2,activity,1
5,3,activity,10
8,4,activity,2
14,5,activity,3
1,6,activity,4
10,7,activity,5
28,8,activity,6
20,9,activity,7


In [2]:
from src.tree import TreeNode

# Two-level tree: root -> one node per mission type -> one leaf per mission.
# `groupby` sorts its keys, so the type nodes come out in alphabetical order;
# within a type, missions keep the ID order of `missions`.
root = TreeNode('root')
for mission_type, missions_of_type in missions.groupby('type'):
    type_node = TreeNode(mission_type)
    root.add_child(type_node)
    for _, row in missions_of_type.iterrows():
        leaf = TreeNode(row.to_dict())  # payload: the mission row as a dict
        type_node.add_child(leaf)

# `print` (not a bare expression) so TreeNode's indented str() form is shown.
print(root)

root
	action
		{'ID': 0, 'type': 'action', 'target': 1}
		{'ID': 1, 'type': 'action', 'target': 2}
	activity
		{'ID': 2, 'type': 'activity', 'target': 1}
		{'ID': 3, 'type': 'activity', 'target': 10}
		{'ID': 4, 'type': 'activity', 'target': 2}
		{'ID': 5, 'type': 'activity', 'target': 3}
		{'ID': 6, 'type': 'activity', 'target': 4}
		{'ID': 7, 'type': 'activity', 'target': 5}
		{'ID': 8, 'type': 'activity', 'target': 6}
		{'ID': 9, 'type': 'activity', 'target': 7}
		{'ID': 10, 'type': 'activity', 'target': 8}
		{'ID': 11, 'type': 'activity', 'target': 9}
	episode
		{'ID': 12, 'type': 'episode', 'target': 1}
		{'ID': 13, 'type': 'episode', 'target': 2}
		{'ID': 14, 'type': 'episode', 'target': 3}
		{'ID': 15, 'type': 'episode', 'target': 4}
		{'ID': 16, 'type': 'episode', 'target': 5}
		{'ID': 17, 'type': 'episode', 'target': 6}
	exp
		{'ID': 18, 'type': 'exp', 'target': 100}
		{'ID': 19, 'type': 'exp', 'target': 50}
	mobility
		{'ID': 20, 'type': 'mobility', 'target': 1}
		{'ID': 21, 

In [3]:
import numpy as np
import math
from src.policy import Policy
from src.tree_bandit import TreeBandit

class EpsilonGreedyPolicy(Policy):
    """
    Epsilon-Greedy implementation of the selection policy.

    A visit count and a cumulative reward are stored per leaf node. Internal
    nodes are scored on the fly by aggregating the leaves beneath them.
    """
    def __init__(self, epsilon=0.1):
        super().__init__()
        # Probability, evaluated independently for every pick in `select`, of
        # taking a uniformly random candidate instead of the current best one.
        self.epsilon = epsilon
        # Maps each registered leaf TreeNode -> {"count": int, "reward": float}.
        self.__leaves_stats = {}

    def init(self, node: TreeNode):
        """
        Register (or reset to zero) the stats of every leaf below ``node``.
        """
        if node.is_leaf:
            self.__leaves_stats[node] = {"count": 0, "reward": 0.0}
        else:
            for child in node.children:
                self.init(child)

    def select(self, nodes: list[TreeNode], n: int) -> set:
        """
        Select up to ``n`` distinct nodes using the epsilon-greedy strategy.

        For each pick, with probability ``epsilon`` a uniformly random
        remaining candidate is taken; otherwise the remaining candidate with
        the highest average reward (ties: first in ``nodes`` order). Picked
        nodes leave the pool, so the result has no duplicates.

        Args:
            nodes: candidate nodes (leaves and/or internal nodes).
            n: how many nodes to pick. Clamped to the number of distinct
                candidates; ``n <= 0`` picks nothing.

        Returns:
            A set of ``min(n, number of distinct candidates)`` nodes.

        Raises:
            ValueError: if a leaf in ``nodes`` was never registered via ``init``.
        """
        nodes = list(nodes)  # traversed more than once below
        if not all(node in self.__leaves_stats for node in nodes if node.is_leaf):
            raise ValueError("All nodes must be initialized before selection.")

        # Average reward per candidate. The 1e-5 avoids division by zero, so an
        # unvisited node (count 0, reward 0.0) scores 0.0.
        selectable_nodes = {}
        for node in nodes:
            stats = self.__retrieve_node_data(node)
            selectable_nodes[node] = stats["reward"] / (stats["count"] + 1e-5)

        selected_nodes = set()
        # Clamp: picking from an empty pool would make max()/the random draw fail.
        for _ in range(min(n, len(selectable_nodes))):
            candidates = list(selectable_nodes)
            if np.random.rand() < self.epsilon:
                # Draw an index rather than calling np.random.choice on the
                # nodes, which would first coerce TreeNode objects to an array.
                choice = candidates[np.random.randint(len(candidates))]
            else:
                choice = max(candidates, key=selectable_nodes.get)
            selected_nodes.add(choice)
            del selectable_nodes[choice]

        return selected_nodes

    def __retrieve_node_data(self, node: TreeNode) -> dict:
        """
        Return the ``{"count", "reward"}`` stats for ``node``.

        A leaf returns its own stored stats dict. An internal node aggregates
        its children recursively: ``count`` is the sum of the children's
        counts and ``reward`` the largest of the children's rewards (0.0 when
        it has no children).
        """
        if node.is_leaf:
            return self.__leaves_stats[node]

        children_stats = [self.__retrieve_node_data(child) for child in node.children]
        return {
            "count": sum(stats["count"] for stats in children_stats),
            # NOTE(review): `select` divides this max of cumulative rewards by
            # the summed count, mixing two aggregations — kept as-is, confirm intent.
            "reward": max((stats["reward"] for stats in children_stats), default=0.0),
        }

    def update(self, node, reward):
        """
        Record one observation of ``reward`` for the leaf ``node``.

        Raises:
            ValueError: if ``node`` was not registered via ``init`` (only
                leaves are registered, so internal nodes are rejected too).
        """
        if node not in self.__leaves_stats:
            raise ValueError("Node must be initialized before updating.")

        self.__leaves_stats[node]["count"] += 1
        self.__leaves_stats[node]["reward"] += reward

class UCBPolicy(Policy):
    """
    Upper Confidence Bound (UCB) implementation of the selection policy.

    Each candidate is scored as
    ``avg_reward + c * sqrt(log(total_count + 1) / count)``, and never-visited
    candidates score ``inf``. Stats are stored per leaf; internal nodes are
    scored by aggregating the leaves beneath them.
    """
    def __init__(self, c=1.0):
        super().__init__()
        self.c = c  # weight of the exploration term
        self.total_count = 0  # number of `update` calls so far
        # Maps each registered leaf TreeNode -> {"count": int, "reward": float}.
        self.__leaves_stats = {}

    def init(self, node: TreeNode):
        """
        Register (or reset to zero) the stats of every leaf below ``node``.
        """
        if node.is_leaf:
            self.__leaves_stats[node] = {"count": 0, "reward": 0.0}
        else:
            for child in node.children:
                self.init(child)

    def select(self, nodes: list[TreeNode], n: int) -> set:
        """
        Select the ``n`` candidates with the highest UCB score.

        Args:
            nodes: candidate nodes (leaves and/or internal nodes).
            n: how many nodes to pick; ``n <= 0`` picks nothing, and values
                larger than ``len(nodes)`` return every candidate.

        Returns:
            A set of the selected nodes.

        Raises:
            ValueError: if a leaf in ``nodes`` was never registered via ``init``.
        """
        nodes = list(nodes)  # traversed more than once and indexed below
        if not all(node in self.__leaves_stats for node in nodes if node.is_leaf):
            raise ValueError("All nodes must be initialized before selection.")
        if n <= 0:
            # Without this, `order[-0:]` below would be the whole array.
            return set()

        ucb_values = []
        for node in nodes:
            stats = self.__retrieve_node_data(node)
            if stats["count"] == 0:
                ucb_values.append(float("inf"))  # Prioritize unvisited nodes
            else:
                avg_reward = stats["reward"] / stats["count"]
                exploration_term = self.c * math.sqrt(math.log(self.total_count + 1) / stats["count"])
                ucb_values.append(avg_reward + exploration_term)

        # Indices in ascending UCB order; the last n are the best. Index the
        # list directly instead of coercing TreeNode objects into a numpy array.
        order = np.argsort(ucb_values)
        return {nodes[i] for i in order[-n:]}

    def __retrieve_node_data(self, node: TreeNode) -> dict:
        """
        Return the ``{"count", "reward"}`` stats for ``node``.

        A leaf returns its own stored stats dict. An internal node aggregates
        its children recursively: ``count`` is the sum of the children's
        counts and ``reward`` the largest of the children's rewards (0.0 when
        it has no children).
        """
        if node.is_leaf:
            return self.__leaves_stats[node]

        children_stats = [self.__retrieve_node_data(child) for child in node.children]
        return {
            "count": sum(stats["count"] for stats in children_stats),
            # NOTE(review): `select` divides this max of cumulative rewards by
            # the summed count, mixing two aggregations — kept as-is, confirm intent.
            "reward": max((stats["reward"] for stats in children_stats), default=0.0),
        }

    def update(self, node, reward):
        """
        Record one observation of ``reward`` for the leaf ``node`` and bump
        the global ``total_count`` used by the exploration term.

        Raises:
            ValueError: if ``node`` was not registered via ``init`` (only
                leaves are registered, so internal nodes are rejected too).
        """
        if node not in self.__leaves_stats:
            raise ValueError("Node must be initialized before updating.")

        self.__leaves_stats[node]["count"] += 1
        self.__leaves_stats[node]["reward"] += reward
        self.total_count += 1

In [20]:
# EpsilonGreedyPolicy draws from numpy's global RNG; seed it so the selection
# shown below is reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)

# epsilon=1: np.random.rand() < 1 always holds, so every pick is uniformly random.
policy = EpsilonGreedyPolicy(epsilon=1)
policy.init(root)

bandit = TreeBandit(root, policy)
# presumably (3, 1) = 3 type nodes, then 1 mission under each — TODO confirm
# against TreeBandit.select
bandit.select((3, 1))

{TreeNode({'ID': 17, 'type': 'episode', 'target': 6}),
 TreeNode({'ID': 22, 'type': 'mobility', 'target': 3}),
 TreeNode({'ID': 9, 'type': 'activity', 'target': 7})}