<h1>Imports</h1>

In [None]:
import gymnasium as gym
from gymnasium import spaces

import numpy as np

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D
from tensorflow.keras.optimizers import Adam

import collections
import random
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import time 
import pickle

import ray

2025-08-07 15:57:02.749580: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754582222.782243  182357 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754582222.795108  182357 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


<h1>Neural Network</h1>

In [None]:
class NeuralNetwork:
    """Two-headed convolutional network producing a policy and a value.

    The body is two 3x3 'same'-padded ReLU convolutions (64 then 128 filters),
    a flatten and a 256-unit ReLU dense layer. Two heads share that body:

    * ``policy_output`` - softmax over ``action_space_size`` actions.
    * ``value_output``  - a single tanh unit, i.e. a scalar in [-1, 1].

    Args:
        action_space_size: Number of units in the policy head.
        learning_rate: Learning rate handed to the Adam optimizer.
        state_shape: Shape of one encoded state (without the batch axis).
            Defaults to ``(23, 23, 4)``, the shape that was previously
            hard-coded.
    """

    def __init__(self, action_space_size, learning_rate=0.001, state_shape=(23, 23, 4)):
        self.state_shape = tuple(state_shape)
        self.action_space_size = action_space_size
        self.learning_rate = learning_rate
        self.model = self._build_model()

    def _build_model(self):
        """Build and compile the Keras model described in the class docstring."""
        input_layer = Input(shape=self.state_shape, name='matrix_input')
        x = Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu')(input_layer)
        x = Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu')(x)
        x = Flatten()(x)
        x = Dense(256, activation='relu')(x)

        policy_output = Dense(self.action_space_size, activation='softmax', name='policy_output')(x)

        value_output = Dense(1, activation='tanh', name='value_output')(x)

        model = Model(inputs=input_layer, outputs=[policy_output, value_output])
        # One loss per head, matched to the heads by output-layer name.
        model.compile(optimizer=Adam(self.learning_rate),
                      loss={'policy_output': 'categorical_crossentropy', 'value_output': 'mean_squared_error'})

        return model

    def predict(self, matrix_state):
        """Evaluate a single encoded state.

        Args:
            matrix_state: One state of shape ``state_shape`` (no batch axis).

        Returns:
            ``[policy, value]`` as NumPy arrays, each with a leading batch
            axis of size 1 (callers index ``policy[0]`` and ``value[0][0]``).
        """
        batch = np.expand_dims(np.asarray(matrix_state, dtype=np.float32), axis=0)
        # This is called once per MCTS node. Keras recommends calling the model
        # directly for inputs that fit in one batch: `Model.predict` sets up a
        # data iterator on every call, which dominates the cost for one sample.
        policy, value = self.model(batch, training=False)
        return [np.asarray(policy), np.asarray(value)]

    def train(self, states, target_policies, target_values, batch_size=32):
        """Fit both heads on a batch of samples with one ``model.fit`` call.

        No ``epochs`` argument is passed, so Keras' default applies.

        Returns:
            Whatever ``model.fit`` returns.
        """
        return self.model.fit(states, [target_policies, target_values], batch_size=batch_size, verbose=0)

<h1>Alpha Zero Agent</h1>

In [None]:
class AlphaZeroAgent:
    """Monte Carlo Tree Search agent guided by a policy/value network.

    The search tree lives in four tables keyed by ``_get_state_key(state)``:

    * ``P``    - prior action probabilities returned by the network.
    * ``Q``    - running mean value of each action.
    * ``N_sa`` - visit count of each action.
    * ``N_s``  - number of times an action was selected from the state.

    Args:
        env: Game environment. The code here uses ``action_space.n``, ``adj``,
            ``NUM_GOATS``, ``copy()``, ``is_action_valid(a)`` and ``step(a)``.
        network: Object exposing ``predict(matrix_state) -> (policy, value)``.
        simulations_per_move: Number of ``search`` calls per ``get_action``.
        max_depth: Recursion depth at which ``search`` stops and falls back to
            the network's value estimate.
        c_puct: Weight of the exploration term in the selection score.
    """

    # Added under the square root of the state visit count. Without it the
    # exploration term is exactly zero on a node's first selection, every valid
    # action ties at Q == 0, and the first valid action wins regardless of prior.
    _EPS = 1e-8

    def _zero_array_factory(self):
        """Return a zero vector with one entry per action.

        Used as the ``defaultdict`` factory for ``Q`` and ``N_sa``. It is a
        named method rather than a lambda because lambdas cannot be pickled.
        """
        return np.zeros(self.env.action_space.n)

    def __init__(self, env, network, simulations_per_move=50, max_depth=25, c_puct=1.0):
        self.env = env
        self.network = network
        self.simulations_per_move = simulations_per_move
        self.max_depth = max_depth
        self.c_puct = c_puct
        self.Q = collections.defaultdict(self._zero_array_factory)
        self.N_sa = collections.defaultdict(self._zero_array_factory)
        self.N_s = collections.defaultdict(int)
        self.P = {}

    def _get_matrix_state(self, state):
        """Encode ``state`` as a ``(num_nodes, num_nodes, 4)`` float32 array.

        Channels:
            0 - diagonal marks nodes where ``board == 1``.
            1 - diagonal marks nodes where ``board == 2``.
            2 - adjacency matrix built from ``env.adj`` (1-based node ids).
            3 - every cell holds ``state['player_turn']``.
        """
        board = state['board']
        num_nodes = len(board)
        matrix = np.zeros((num_nodes, num_nodes, 4), dtype=np.float32)

        # NOTE(review): which piece the board codes 1 and 2 stand for is not
        # visible here - confirm against the environment.
        np.fill_diagonal(matrix[:, :, 0], board == 1)
        np.fill_diagonal(matrix[:, :, 1], board == 2)

        adj_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
        for start, end_list in self.env.adj.items():
            for end in end_list:
                # env.adj uses 1-based node ids; the matrix is 0-based.
                adj_matrix[start - 1, end - 1] = 1
        matrix[:, :, 2] = adj_matrix

        # Channel 3 - Player turn, matrix filled with 0 for goat and 1 for tiger
        matrix[:, :, 3] = state['player_turn']

        return matrix

    def search(self, state, depth):
        """Run one MCTS simulation starting at ``state``.

        An unseen state is expanded (its priors are stored) and the network
        value is returned. A known state selects the valid action with the
        highest score ``Q + c_puct * P * sqrt(N_s) / (1 + N_sa)``, plays it on
        a copy of the environment, recurses, and updates ``Q``/``N_sa``/``N_s``.

        Returns:
            The negated value of ``state`` for the player to move in it, so the
            caller one level up can use it directly (this presumes the players
            alternate on every move). Returns 0 when no action is valid.
        """
        if depth >= self.max_depth:
            # Depth cut-off: evaluate with the network but do not expand.
            _, value = self.network.predict(self._get_matrix_state(state))
            return -value[0][0]

        state_key = self._get_state_key(state)
        if state_key not in self.P:
            # Leaf: remember the priors and back up the network's value.
            policy, value = self.network.predict(self._get_matrix_state(state))
            self.P[state_key] = policy[0]
            return -value[0][0]

        # Rebuild an environment positioned at `state` so that validity checks
        # and the chosen move can be applied without touching self.env.
        node_env = self.env.copy()
        node_env.board = state['board'].copy()
        node_env.player_turn = state['player_turn']
        node_env.goats_placed_count = self.env.NUM_GOATS - state['goats_to_place'][0]
        node_env.goats_captured_count = state['goats_captured'][0]

        valid_actions = [a for a in range(node_env.action_space.n) if node_env.is_action_valid(a)[0]]
        if not valid_actions:
            return 0

        priors = self.P[state_key]
        q_values = self.Q[state_key]
        visits = self.N_sa[state_key]
        sqrt_total = np.sqrt(self.N_s[state_key] + self._EPS)

        best_ucb = -np.inf
        best_action = -1
        for action in valid_actions:
            ucb = q_values[action] + self.c_puct * priors[action] * sqrt_total / (1 + visits[action])
            if ucb > best_ucb:
                best_ucb = ucb
                best_action = action
        if best_action == -1:
            # Only reachable if no score compared greater than -inf (e.g. NaN).
            return 0

        action = best_action
        next_state, _, done, info = node_env.step(action)
        if done:
            winner = info.get('winner', -1)
            value = 0
            if winner != -1:
                value = 1 if winner == state['player_turn'] else -1
        else:
            value = self.search(next_state, depth + 1)

        # Incremental mean update of Q, then bump the visit counters.
        count = self.N_sa[state_key][action]
        self.Q[state_key][action] = (count * self.Q[state_key][action] + value) / (count + 1)
        self.N_sa[state_key][action] += 1
        self.N_s[state_key] += 1
        return -value

    def get_action(self, state, turn_count, training=True):
        """Run the configured number of simulations and choose a move.

        Args:
            state: Current game state.
            turn_count: Number of moves played so far. In training mode the
                temperature is 1.0 while it is below 10 and 0.1 afterwards.
            training: Sample from the tempered visit distribution when True;
                pick the most visited action when False.

        Returns:
            ``(action, policy_target)`` where ``policy_target`` is the visit
            distribution over all actions (all zeros if nothing was visited).
        """
        state_key = self._get_state_key(state)
        for _ in range(self.simulations_per_move):
            self.search(state, 0)

        visit_counts = self.N_sa[state_key]
        num_actions = self.env.action_space.n
        if training:
            tau = 1.0 if turn_count < 10 else 0.1
            action_probs = visit_counts ** (1 / tau)
            action_probs[np.isnan(action_probs)] = 0
            if np.sum(action_probs) > 0:
                action_probs /= np.sum(action_probs)
            else:
                # Nothing was visited: fall back to a uniform choice over the
                # moves that self.env currently considers valid.
                valid_actions = [a for a in range(num_actions) if self.env.is_action_valid(a)[0]]
                action_probs = np.zeros(num_actions)
                if valid_actions:
                    action_probs[valid_actions] = 1 / len(valid_actions)
            action = np.random.choice(num_actions, p=action_probs)
        else:
            action = np.argmax(visit_counts)

        total_visits = np.sum(visit_counts)
        policy_target = visit_counts / total_visits if total_visits > 0 else np.zeros(num_actions)
        return action, policy_target

    def _get_state_key(self, state):
        """Return the hashable key under which ``state`` is stored in the tree."""
        # NOTE(review): the key ignores 'goats_to_place' and 'goats_captured',
        # which `search` reads when rebuilding the environment. Positions that
        # differ only in those counters share one tree node - confirm intended.
        return (state['board'].tobytes(), state['player_turn'])

    def save_mcts_tree(self, file_path):
        """Pickle the four search tables to ``file_path``."""
        print(f"Saving MCTS tree to {file_path}")
        # Store plain dicts. Pickling the defaultdicts would also pickle their
        # factory, a bound method, and through it the whole agent including the
        # environment and the network.
        tree_data = {'Q': dict(self.Q), 'N_sa': dict(self.N_sa), 'N_s': dict(self.N_s), 'P': self.P}
        with open(file_path, 'wb') as f:
            pickle.dump(tree_data, f)

    def load_mcts_tree(self, file_path):
        """Restore the search tables from ``file_path`` if the file exists.

        Accepts files written with either plain dicts or defaultdicts.
        """
        if not os.path.exists(file_path):
            print("No existing MCTS tree found")
            return

        print(f"Loading MCTS tree from {file_path}")
        # SECURITY: pickle.load can execute arbitrary code; only load tree
        # files that this project wrote itself.
        with open(file_path, 'rb') as f:
            tree_data = pickle.load(f)
        # Re-wrap so missing keys keep defaulting, with factories bound to
        # this agent rather than to whatever was stored in the file.
        self.Q = collections.defaultdict(self._zero_array_factory, tree_data['Q'])
        self.N_sa = collections.defaultdict(self._zero_array_factory, tree_data['N_sa'])
        self.N_s = collections.defaultdict(int, tree_data['N_s'])
        self.P = dict(tree_data['P'])

<h1>Self Play</h1>

In [None]:
# --- Ray-based Distributed AlphaZero Training ---

def train_agent(agent, training_data, batch_size=64):
    """Fit the agent's network on a list of self-play samples.

    Args:
        agent: Object providing ``_get_matrix_state(state)`` and
            ``network.train(states, policies, values, batch_size=...)``.
        training_data: Sequence of ``(state, policy_target, value_target)``.
            When empty, a message is printed and nothing is trained.
        batch_size: Forwarded to ``network.train``.
    """
    if not training_data:
        print("Training data is empty, skipping training")
        return

    encoded_states = []
    policy_rows = []
    value_rows = []
    for sample in training_data:
        encoded_states.append(agent._get_matrix_state(sample[0]))
        policy_rows.append(sample[1])
        value_rows.append(sample[2])

    agent.network.train(
        np.array(encoded_states),
        np.array(policy_rows),
        np.array(value_rows),
        batch_size=batch_size,
    )

'''
Ray's @ray.remote decorator is similar to PySpark's distributed function execution.
It turns a Python function into a remote task that can be executed in parallel on different workers.
PySpark -> .map() or .foreach() on an RDD/DataFrame to distribute work.
@ray.remote distributes self-play games across CPU cores.
'''
# The @ray.remote decorator turns a normal Python function into a distributed task
@ray.remote(num_cpus=1, num_gpus=0.1)  # each task reserves 1 CPU and a tenth of a GPU, so up to 10 tasks can share one GPU
def self_play_task(model_weights, simulations_per_move=None, max_depth=25):
    """Play one self-play game and return its training samples.

    Args:
        model_weights: Weights to load into a freshly built ``NeuralNetwork``
            (passed in directly instead of being read from a file).
        simulations_per_move: MCTS simulations per move. ``None`` (default)
            uses the module-level ``SIMULATIONS_PER_MOVE``.
        max_depth: Depth limit handed to ``AlphaZeroAgent``.

    Returns:
        A list of ``(state, policy_target, value)`` tuples, one per move.
        ``value`` is +1 if the player to move in that state won the game,
        -1 if they lost, and 0 when the game reported no winner.
    """
    # NOTE(review): TensorFlow reads this variable when it is first loaded; if
    # the worker has already imported TensorFlow by the time this runs, the
    # setting may have no effect - confirm.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    # Several of these tasks share a GPU. By default TensorFlow reserves almost
    # all device memory per process, so ask it to grow allocations on demand.
    for gpu in tf.config.list_physical_devices('GPU'):
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError:
            # Raised when the device is already initialised, e.g. in a worker
            # process that Ray reuses for a later task. Nothing to do then.
            pass

    if simulations_per_move is None:
        simulations_per_move = SIMULATIONS_PER_MOVE

    env = AaduPulliEnv()
    action_space_size = env.action_space.n

    network = NeuralNetwork(action_space_size)
    network.model.set_weights(model_weights)

    agent = AlphaZeroAgent(env, network, simulations_per_move=simulations_per_move, max_depth=max_depth)

    # Play the game to the end, recording the state and the MCTS visit
    # distribution before every move.
    game_history = []
    state = agent.env.reset()
    done = False
    turn = 0
    info = {}
    while not done:
        action, policy = agent.get_action(state, turn, training=True)
        game_history.append((state, policy))
        state, _, done, info = agent.env.step(action)
        turn += 1

    # Label every recorded position with the final outcome, seen from the
    # player who was to move in that position.
    winner = info.get('winner', -1)
    game_training_data = []
    for hist_state, hist_policy in game_history:
        value = 0
        if winner != -1:
            value = 1 if hist_state['player_turn'] == winner else -1
        game_training_data.append((hist_state, hist_policy, value))
    return game_training_data

'''
Ray actors are similar to PySpark's accumulators or broadcast variables, but with mutable state.
A Ray actor is a Python class that runs in its own process and can maintain state across method calls.
This is useful for things like model weights and replay buffers.
'''
@ray.remote(num_gpus=1)  # This actor reserves one whole GPU for its lifetime
class Trainer:
    """Ray actor that owns the network being trained and the replay buffer.

    NOTE(review): this actor reserves a full GPU while ``self_play_task``
    requests a GPU fraction. On a machine that exposes a single GPU to Ray the
    self-play tasks could not be scheduled while this actor is alive - confirm
    the target machine has at least two GPUs.

    Args:
        action_space_size: Size of the network's policy head.
        learning_rate: Learning rate for the network's optimizer.
        replay_buffer_size: Maximum number of samples kept; the oldest samples
            are dropped first.
    """

    def __init__(self, action_space_size, learning_rate, replay_buffer_size):
        self.network = NeuralNetwork(action_space_size, learning_rate)
        # The agent is what train_agent uses to encode states and reach the network.
        self.agent = AlphaZeroAgent(AaduPulliEnv(), self.network)
        self.buffer = collections.deque(maxlen=replay_buffer_size)

    def train(self, batch_size):
        """Train on ``batch_size`` samples drawn at random from the buffer.

        Returns:
            A status string; training is skipped while the buffer holds fewer
            than ``batch_size`` samples.
        """
        if len(self.buffer) < batch_size:
            return "Buffer too small, skipping training."

        training_data = random.sample(list(self.buffer), batch_size)
        train_agent(self.agent, training_data, batch_size)
        return "Training step complete."

    def add_data(self, data):
        """Append new self-play samples to the replay buffer."""
        self.buffer.extend(data)

    def get_weights(self):
        """Return the current model weights."""
        return self.network.model.get_weights()

    def get_buffer_size(self):
        """Return the number of samples currently in the replay buffer."""
        return len(self.buffer)

    def save_state(self, model_path, buffer_path):
        """Write the model weights and the replay buffer to disk.

        Returns:
            A message naming both destination paths.
        """
        self.network.model.save_weights(model_path)
        # Write to a sibling file and rename it into place, so an interrupted
        # save cannot leave a truncated buffer at `buffer_path`.
        tmp_path = f"{buffer_path}.tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump(self.buffer, f)
        os.replace(tmp_path, buffer_path)
        return f"Saved model to {model_path} and buffer to {buffer_path}"

    def load_state(self, model_path, buffer_path):
        """Load weights and/or the replay buffer; missing files are skipped."""
        if os.path.exists(model_path):
            self.network.model.load_weights(model_path)
            print(f"Loaded model weights from {model_path}")
        if os.path.exists(buffer_path):
            # SECURITY: pickle.load can execute arbitrary code; only load
            # buffer files that this project wrote itself.
            with open(buffer_path, 'rb') as f:
                loaded = pickle.load(f)
            # Rebuild the deque so the configured capacity stays in force even
            # if the file was saved with a different one; when the file holds
            # more than fits, the most recently added samples are kept.
            self.buffer = collections.deque(loaded, maxlen=self.buffer.maxlen)
            print(f"Loaded replay buffer from {buffer_path}")


if __name__ == '__main__':
    # Driver: repeatedly generate self-play games in parallel, feed them to
    # the Trainer actor, train, and checkpoint periodically.
    ray.init()  # Initialize Ray runtime (like SparkContext in PySpark)

    NUM_ITERATIONS = 10
    GAMES_PER_ITERATION = 20
    LEARNING_RATE = 0.001
    BATCH_SIZE = 1024
    REPLAY_BUFFER_SIZE = 50000
    SIMULATIONS_PER_MOVE = 80
    MODEL_SAVE_PATH = '...'
    BUFFER_SAVE_PATH = '...'
    MODEL_IMPORT_PATH = '...'
    BUFFER_IMPORT_PATH = '...'
    MAX_ELAPSED_SECONDS = 39600  # wall-clock budget for launching new games

    try:
        action_space_size = AaduPulliEnv().action_space.n

        # Create the central trainer actor (like a driver node in Spark)
        trainer = Trainer.remote(action_space_size, LEARNING_RATE, REPLAY_BUFFER_SIZE)

        # Resume from the imported checkpoint first; a checkpoint already
        # present at the save location is loaded afterwards and therefore
        # takes precedence. Files that do not exist are skipped by load_state.
        ray.get(trainer.load_state.remote(MODEL_IMPORT_PATH, BUFFER_IMPORT_PATH))
        if (MODEL_SAVE_PATH, BUFFER_SAVE_PATH) != (MODEL_IMPORT_PATH, BUFFER_IMPORT_PATH):
            ray.get(trainer.load_state.remote(MODEL_SAVE_PATH, BUFFER_SAVE_PATH))

        start_time = time.monotonic()
        out_of_time = False
        for iteration in range(NUM_ITERATIONS):
            print(f"\n{'='*20} ITERATION {iteration+1} {'='*20}")

            # Passing the ObjectRef lets every task of this iteration share
            # the same snapshot of the weights.
            model_weights_id = trainer.get_weights.remote()

            results_ids = []
            for _ in tqdm(range(GAMES_PER_ITERATION)):
                # In PySpark, this would be like submitting jobs to the cluster.
                # Here, we launch remote self-play tasks in parallel.
                if time.monotonic() - start_time >= MAX_ELAPSED_SECONDS:
                    print(f"Max seconds - {MAX_ELAPSED_SECONDS} elapsed")
                    out_of_time = True
                    break

                results_ids.append(self_play_task.remote(model_weights_id))

            if out_of_time and not results_ids:
                # Budget exhausted before any game was launched: there is
                # nothing new to learn from, so stop instead of retraining.
                break

            # Gather results from all remote tasks (like collecting RDD results in PySpark)
            all_new_data = ray.get(results_ids)

            # Wait for the inserts so that a failure inside add_data surfaces here.
            ray.get([trainer.add_data.remote(data) for data in all_new_data])
            train_status = ray.get(trainer.train.remote(BATCH_SIZE))
            print(train_status)

            if (iteration + 1) % 2 == 0:
                print("Saving training state for long-term learning")
                save_message = ray.get(trainer.save_state.remote(MODEL_SAVE_PATH, BUFFER_SAVE_PATH))
                print(save_message)

            if out_of_time:
                # Do not start further iterations once the budget is spent.
                break

        # Always checkpoint at the end so an early stop or an odd-numbered
        # last iteration is not lost.
        print("Training complete. Performing final save...")
        print(ray.get(trainer.save_state.remote(MODEL_SAVE_PATH, BUFFER_SAVE_PATH)))
    finally:
        ray.shutdown()  # Clean up Ray resources (like SparkContext.stop())

2025-08-07 15:57:08,337	INFO worker.py:1917 -- Started a local Ray instance.
[36m(pid=182522)[0m 2025-08-07 15:57:10.586611: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=182522)[0m E0000 00:00:1754582230.608417  182522 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=182522)[0m E0000 00:00:1754582230.615372  182522 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(Trainer pid=182522)[0m I0000 00:00:1754582235.040520  182522 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
[36m(Trainer pid=182522)[0m   saveable.lo

[36m(Trainer pid=182522)[0m Loaded model weights from /kaggle/working/alphazero_aadu_pulli_ray.weights.h5



100%|██████████| 2/2 [00:00<00:00, 110.78it/s]

[36m(Trainer pid=182522)[0m Loaded replay buffer from /kaggle/working/replay_buffer.pkl



[36m(pid=182523)[0m 2025-08-07 15:57:16.981159: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=182523)[0m E0000 00:00:1754582237.003964  182523 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=182523)[0m E0000 00:00:1754582237.010801  182523 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(self_play_task pid=182523)[0m I0000 00:00:1754582240.974561  182523 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13906 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5
[36m(self_play_task pid=182524)[0m I0000 00:00:1754582240.985054  182524 gpu_device.cc:2022] Created device /j

[36m(Trainer pid=182522)[0m Training network on 1024 samples...


[36m(Trainer pid=182522)[0m I0000 00:00:1754582274.959749  182737 service.cc:148] XLA service 0x79456c00c1d0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:[32m [repeated 2x across cluster][0m
[36m(Trainer pid=182522)[0m I0000 00:00:1754582274.959834  182737 service.cc:156]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5[32m [repeated 2x across cluster][0m
[36m(self_play_task pid=182524)[0m I0000 00:00:1754582242.296518  182795 cuda_dnn.cc:529] Loaded cuDNN version 90300
[36m(self_play_task pid=182524)[0m I0000 00:00:1754582243.315855  182795 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
[36m(Trainer pid=182522)[0m I0000 00:00:1754582275.159969  182737 cuda_dnn.cc:529] Loaded cuDNN version 90300
[36m(Trainer pid=182522)[0m I0000 00:00:1754582282.398775  182737 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once fo

Training step complete.



100%|██████████| 2/2 [00:00<00:00, 1893.16it/s]
[36m(pid=182521)[0m 2025-08-07 15:58:03.944829: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=182521)[0m E0000 00:00:1754582283.968560  182521 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=182521)[0m E0000 00:00:1754582283.976780  182521 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(self_play_task pid=182521)[0m I0000 00:00:1754582288.499825  182521 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5
[36m(pid=188260)[0m 2025-08-07 15:58:04.396529: E external/loca

[36m(Trainer pid=182522)[0m Training network on 1024 samples...
Training step complete.

Saving training state for long-term learning...
Saved model to /kaggle/working/alphazero_aadu_pulli_ray.weights.h5 and buffer to /kaggle/working/replay_buffer.pkl



  0%|          | 0/2 [00:00<?, ?it/s]

time done





[36m(Trainer pid=182522)[0m Training network on 1024 samples...
Training step complete.



  0%|          | 0/2 [00:00<?, ?it/s]

time done
[36m(Trainer pid=182522)[0m Training network on 1024 samples...





Training step complete.

Saving training state for long-term learning...
Saved model to /kaggle/working/alphazero_aadu_pulli_ray.weights.h5 and buffer to /kaggle/working/replay_buffer.pkl



  0%|          | 0/2 [00:00<?, ?it/s]

time done
[36m(Trainer pid=182522)[0m Training network on 1024 samples...





Training step complete.



  0%|          | 0/2 [00:00<?, ?it/s]

time done





[36m(Trainer pid=182522)[0m Training network on 1024 samples...
Training step complete.

Saving training state for long-term learning...
Saved model to /kaggle/working/alphazero_aadu_pulli_ray.weights.h5 and buffer to /kaggle/working/replay_buffer.pkl



  0%|          | 0/2 [00:00<?, ?it/s]

time done
[36m(Trainer pid=182522)[0m Training network on 1024 samples...





Training step complete.



  0%|          | 0/2 [00:00<?, ?it/s]

time done
[36m(Trainer pid=182522)[0m Training network on 1024 samples...





Training step complete.

Saving training state for long-term learning...
Saved model to /kaggle/working/alphazero_aadu_pulli_ray.weights.h5 and buffer to /kaggle/working/replay_buffer.pkl



  0%|          | 0/2 [00:00<?, ?it/s]

time done
[36m(Trainer pid=182522)[0m Training network on 1024 samples...





Training step complete.



  0%|          | 0/2 [00:00<?, ?it/s]

time done
[36m(Trainer pid=182522)[0m Training network on 1024 samples...





Training step complete.

Saving training state for long-term learning...
Saved model to /kaggle/working/alphazero_aadu_pulli_ray.weights.h5 and buffer to /kaggle/working/replay_buffer.pkl

Training complete. Performing final save...
Saved model to /kaggle/working/alphazero_aadu_pulli_ray.weights.h5 and buffer to /kaggle/working/replay_buffer.pkl


[36m(self_play_task pid=188260)[0m I0000 00:00:1754582288.959029  188260 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13840 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5
[36m(self_play_task pid=188260)[0m I0000 00:00:1754582290.102959  188406 service.cc:148] XLA service 0x7f20e0006850 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
[36m(self_play_task pid=188260)[0m I0000 00:00:1754582290.103012  188406 service.cc:156]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
[36m(self_play_task pid=188260)[0m I0000 00:00:1754582290.144306  188406 cuda_dnn.cc:529] Loaded cuDNN version 90300
[36m(self_play_task pid=188260)[0m I0000 00:00:1754582290.841156  188406 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
