In [1]:
import pandas as pd

missions = pd.read_csv('./data/missions.csv')

missions

Unnamed: 0,ID,type,target
0,30,streak,1
1,6,activity,4
2,25,quiz,2
3,14,episode,3
4,10,activity,8
5,3,activity,10
6,19,exp,50
7,0,action,1
8,4,activity,2
9,26,quiz,3


In [2]:
import numpy as np
import math
from src.tree import TreeNode
from src.policy import Policy
from src.tree_bandit import TreeBandit

class EpsilonGreedyPolicy(Policy):
    """
    Epsilon-Greedy implementation of the selection policy.
    """
    def __init__(self, epsilon=0.1):
        super().__init__()
        self.epsilon = epsilon
        self._node_stats = {}
    
    def init(self, node):
        if node not in self._node_stats:
            self._node_stats[node] = {"count": 0, "reward": 0.0}

    def select(self, nodes):
        """
        Select a node using the epsilon-greedy strategy.
        """
        if not all(node in self._node_stats for node in nodes):
            raise ValueError("All nodes must be initialized before selection.")

        if np.random.rand() < self.epsilon:
            return np.random.choice(nodes)  # Exploration
        else:
            node_stats = {node: self.__retrieve_node_data(node) for node in nodes}
            average_rewards = [
                node_stats[node]["reward"] / (node_stats[node]["count"] + 1e-5)
                for node in nodes
            ]
            return nodes[np.argmax(average_rewards)]  # Exploitation
    
    def __retrieve_node_data(self, node: TreeNode):
        if node not in self._node_stats:
            raise ValueError("Node must be initialized before updating.")
        
        if node.is_leaf:
            return self._node_stats[node]
        
        return {
            "count": sum(self._node_stats[child]["count"] for child in node.children),
            "reward": max(self._node_stats[child]["reward"] for child in node.children)
        }

    def update(self, node, reward):
        """
        Update the stats for a node after a reward is received.
        """
        if node not in self._node_stats:
            raise ValueError("Node must be initialized before updating.")

        self._node_stats[node]["count"] += 1
        self._node_stats[node]["reward"] += reward

class UCBPolicy(Policy):
    """
    Upper Confidence Bound (UCB) implementation of the selection policy.
    """
    def __init__(self, c=1.0):
        super().__init__()
        self.c = c
        self.total_count = 0
        self._node_stats = {}
    
    def init(self, node):
        if node not in self._node_stats:
            self._node_stats[node] = {"count": 0, "reward": 0.0}

    def select(self, nodes):
        """
        Select a node using the UCB strategy.
        """
        if not all(node in self._node_stats for node in nodes):
            raise ValueError("All nodes must be initialized before selection.")

        ucb_values = []
        for node in nodes:
            stats = self.__retrieve_node_data(node)
            if stats["count"] == 0:
                ucb_values.append(float("inf"))  # Prioritize unvisited nodes
            else:
                avg_reward = stats["reward"] / stats["count"]
                exploration_term = self.c * math.sqrt(math.log(self.total_count + 1) / stats["count"])
                ucb_values.append(avg_reward + exploration_term)

        return nodes[np.argmax(ucb_values)]
    
    def __retrieve_node_data(self, node: TreeNode):
        if node not in self._node_stats:
            raise ValueError("Node must be initialized before updating.")
        
        if node.is_leaf:
            return self._node_stats[node]
        
        return {
            "count": sum(self._node_stats[child]["count"] for child in node.children),
            "reward": max(self._node_stats[child]["reward"] for child in node.children)
        }

    def update(self, node, reward):
        """
        Update the stats for a node after a reward is received.
        """
        if node not in self._node_stats:
            raise ValueError("Node must be initialized before updating.")

        self._node_stats[node]["count"] += 1
        self._node_stats[node]["reward"] += reward
        self.total_count += 1

# Example Usage
if __name__ == "__main__":
    # Create the tree structure
    root = TreeNode("root")
    child1 = TreeNode("child1")
    child2 = TreeNode("child2")
    leaf1 = TreeNode("leaf1")
    leaf2 = TreeNode("leaf2")
    leaf3 = TreeNode("leaf3")

    root.add_child(child1)
    root.add_child(child2)
    child1.add_child(leaf1)
    child1.add_child(leaf2)
    child2.add_child(leaf3)

    # Initialize bandit with epsilon-greedy policy
    epsilon_policy = EpsilonGreedyPolicy(epsilon=0.1)
    bandit = TreeBandit(root, epsilon_policy)
    bandit.initialize_tree(root)

    # Simulate bandit process
    print("Running Epsilon-Greedy Policy...")
    for _ in range(1000):
        selected_leaf = bandit.select()
        reward = np.random.binomial(1, p=(0.8 if selected_leaf.value == "leaf3" else 0.2))
        bandit.update(selected_leaf, reward)

    # Display results
    print("\nEpsilon-Greedy Node Statistics:")
    for node, stats in epsilon_policy._node_stats.items():
        if node.is_leaf:
            print(f"{node.value}: {stats}")

    # Initialize bandit with UCB policy
    ucb_policy = UCBPolicy(c=1.0)
    bandit = TreeBandit(root, ucb_policy)
    bandit.initialize_tree(root)

    # Simulate bandit process
    print("\nRunning UCB Policy...")
    for _ in range(1000):
        selected_leaf = bandit.select()
        reward = np.random.binomial(1, p=(0.8 if selected_leaf.value == "leaf3" else 0.2))
        bandit.update(selected_leaf, reward)

    # Display results
    print("\nUCB Node Statistics:")
    for node, stats in ucb_policy._node_stats.items():
        if node.is_leaf:
            print(f"{node.value}: {stats}")


Running Epsilon-Greedy Policy...

Epsilon-Greedy Node Statistics:
leaf1: {'count': 52, 'reward': 9.0}
leaf2: {'count': 2, 'reward': 0.0}
leaf3: {'count': 946, 'reward': 744.0}

Running UCB Policy...

UCB Node Statistics:
leaf1: {'count': 7, 'reward': 2.0}
leaf2: {'count': 6, 'reward': 2.0}
leaf3: {'count': 987, 'reward': 815.0}
