In [2]:
import gymnasium as gym
import numpy as np
!pip install ale_py
import ale_py
import pickle

Collecting ale_py
  Downloading ale_py-0.10.1-cp311-cp311-win_amd64.whl (1.4 MB)
     ---------------------------------------- 1.4/1.4 MB 6.8 MB/s eta 0:00:00
Installing collected packages: ale_py
Successfully installed ale_py-0.10.1



[notice] A new release of pip available: 22.3 -> 24.3.1
[notice] To update, run: C:\Users\v\AppData\Local\Programs\Python\Python311\python.exe -m pip install --upgrade pip


In [3]:
MAX_STEPS = 1000
TRACK_SEED = 42 # in order to reproduce the same random choices
GAMMA = 1  # discount factor
NUM_ACTIONS = 5
UCT_C = 1.4  # exploration-exploitation trade-off constant

In [7]:
class Node:
    def __init__(self, parent, terminal):
        self.parent = parent
        self.terminal = terminal

        self.children = {action: None for action in range(NUM_ACTIONS)}
        self.visit_count = 0
        self.total_reward = 0

    def get_value(self):
        return self.total_reward / self.visit_count if self.visit_count > 0 else 0

    def height(self):
        height = 1
        for child in self.children.values():
            if child is not None:
                height = max(height, 1 + child.height())
        return height

    def backpropagate(self, reward):
        current = self
        while current is not None:
            current.visit_count += 1
            current.total_reward += reward
            reward *= GAMMA
            current = current.parent


def select_child_with_uct(root):
    children = list(root.children.values())
    total_visits = sum(child.visit_count for child in children if child is not None)
    uct_values = [
        (child.get_value() + UCT_C * np.sqrt(np.log(total_visits) / child.visit_count))
        if child is not None and child.visit_count > 0
        else float('inf')  # prioritize unvisited nodes
        for child in children
    ]
    best_index = np.argmax(uct_values)
    return best_index, children[best_index]


def rollout():
    total_reward = 0
    for _ in range(MAX_STEPS):
        action = np.random.randint(0, NUM_ACTIONS)
        _, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        if terminated or truncated:
            break
    return total_reward


def expand_and_simulate(root):
    if root.terminal:
        return 0

    # expansion: find unexplored action
    for action in range(NUM_ACTIONS):
        if root.children[action] is None:
            _, reward, terminated, truncated, _ = env.step(action)
            expanded = Node(root, terminated or truncated)
            root.children[action] = expanded
            
            # simulate
            total_reward = reward + rollout()
            expanded.backpropagate(total_reward)
            return total_reward

    # select UCT
    best_index, best_child = select_child_with_uct(root)
    env.step(best_index)
    return expand_and_simulate(best_child)


def reset_env(history):
    observation, info = env.reset()
    for action in history:
        env.step(action)
    return observation


def monte_carlo(history):
    global root, env

    for _ in range(100):  # number of iterations per decision
        expand_and_simulate(root)
        reset_env(history)

    # choosing the best action
    child_values = [child.get_value() if child else -float('inf') for child in root.children.values()]
    best_action = np.argmax(child_values)
    print(f"Tree height: {root.height()}")
    return best_action


In [9]:
root = Node(None, False)
history = []
np.random.seed(TRACK_SEED)
gym.register_envs(ale_py)
env = gym.make('ALE/Pacman-v5', render_mode="rgb_array")
observation, info = env.reset()
cumulative_reward = 0

while True:
    action = monte_carlo(history)
    observation, reward, terminated, truncated, info = env.step(action)
    history.append(action)
    root = root.children[action]
    cumulative_reward += reward # the reward is marked only for a frame (out of many), so most frames will have reward 0
    print(f"History: {history}, Reward: {cumulative_reward}")

    if terminated or truncated:
        break

env.close()

Tree height: 8
History: [0], Reward: 0.0
Tree height: 11
History: [0, 1], Reward: 0.0
Tree height: 14
History: [0, 1, 3], Reward: 0.0
Tree height: 13
History: [0, 1, 3, 0], Reward: 0.0
Tree height: 8
History: [0, 1, 3, 0, 2], Reward: 0.0
Tree height: 10
History: [0, 1, 3, 0, 2, 1], Reward: 0.0
Tree height: 10
History: [0, 1, 3, 0, 2, 1, 1], Reward: 0.0
Tree height: 10
History: [0, 1, 3, 0, 2, 1, 1, 2], Reward: 0.0
Tree height: 17
History: [0, 1, 3, 0, 2, 1, 1, 2, 3], Reward: 0.0
Tree height: 18
History: [0, 1, 3, 0, 2, 1, 1, 2, 3, 0], Reward: 0.0
Tree height: 17
History: [0, 1, 3, 0, 2, 1, 1, 2, 3, 0, 4], Reward: 0.0
Tree height: 17
History: [0, 1, 3, 0, 2, 1, 1, 2, 3, 0, 4, 0], Reward: 0.0
Tree height: 17
History: [0, 1, 3, 0, 2, 1, 1, 2, 3, 0, 4, 0, 3], Reward: 0.0
Tree height: 9
History: [0, 1, 3, 0, 2, 1, 1, 2, 3, 0, 4, 0, 3, 2], Reward: 0.0
Tree height: 9
History: [0, 1, 3, 0, 2, 1, 1, 2, 3, 0, 4, 0, 3, 2, 4], Reward: 0.0
Tree height: 10
History: [0, 1, 3, 0, 2, 1, 1, 2, 3, 0, 4, 

KeyboardInterrupt: 

In [10]:
# save the action history
import pickle
with open('action_history.pkl', 'wb') as f:
    pickle.dump(history, f)
print('Action history saved')

Action history saved


In [15]:
saved_history = None
try:
    with open('action_history.pkl', 'rb') as f:
        saved_history = pickle.load(f)
        print('Action history loaded')
except FileNotFoundError:
    print('Action history not found')


# replay the action history
env = gym.make('ALE/Pacman-v5', render_mode="human")
observation, info = env.reset()

# here the reward can actually be seen
for i in range(len(saved_history)):
    action = saved_history[i]
    observation, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        break
env.close()

Action history loaded
