In [4]:
# import numpy as np
# N = 3
# def get_coupling_map():
#     """
#     Generates a set of possible column pairs.
#     Here, we use all adjacent column pairs as an example.
#     """
#     coupling_map = set()
#     for i in range(N - 1):
#         coupling_map.add((i, i + 1))
#     return coupling_map

# def get_coupling_map_mat(coupling_map):
#     """
#     Generates an upper triangular matrix from coupling_map.
#     For each column pair (i, j) in coupling_map, sets the position (i, j) in the matrix to 1.
#     """
#     coupling_map_mat = np.zeros((N, N))
#     for (i, j) in coupling_map:
#         coupling_map_mat[i, j] = 1
#         coupling_map_mat[j,i] = 1
#     return coupling_map_mat  # Make it an upper triangular matrix

# def get_initial_state():
#     """
#     Generates an N x N upper triangular matrix consisting of random 0s and 1s.
#     """
#     def _generate_binary_symmetric_matrix():
#         upper_triangle = np.triu(np.random.randint(0, 2, size=(N, N)), k=1)
#         symmetric_matrix = upper_triangle + upper_triangle.T
#         np.fill_diagonal(symmetric_matrix, 0)
#         return symmetric_matrix
#     mat = _generate_binary_symmetric_matrix()
#     mat = mat - np.multiply(mat, coupling_map_mat)
#     return mat

# coupling_map_mat = get_coupling_map_mat(get_coupling_map())
# mat = get_initial_state()
# mat = mat-coupling_map_mat*mat
# col1, col2 = 1,2
# print(mat)
# print(col1,col2)
# print(coupling_map_mat)
# new_mat = mat.copy()
# # Swap columns
# new_mat[:, [col1, col2]] = new_mat[:, [col2, col1]]
# new_mat[[col1, col2],:] = new_mat[[col2, col1],:]
# new_mat = new_mat - np.multiply(new_mat, coupling_map_mat)
# new_mat = np.clip(new_mat, 0, 1)  # Prevent elements from becoming negative
# new_mat

In [5]:
from dataclasses import dataclass
import time
import random
from pathlib import Path
import shutil

import tensorflow as tf
import numpy as np
from tqdm import tqdm
import yaml

from rl.network import ResNet
from rl.mcts import MCTS
from rl.buffer import ReplayBuffer
from rl import game

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

training_settings = config["training_settings"]
network_settings = config["network_settings"]
mcts_settings = config["mcts_settings"]
num_cpus = training_settings["num_cpus"]
n_episodes = training_settings["n_episodes"]
buffer_size = training_settings["buffer_size"]
batch_size = training_settings["batch_size"]
epochs_per_update = training_settings["epochs_per_update"]
update_period = training_settings["update_period"]
save_period = training_settings["save_period"]


@dataclass
class Sample:
    state: np.ndarray
    mcts_policy: np.ndarray
    reward: float


def selfplay(weights, test=False):
    """Perform a self-play game and collect training data."""
    record = []
    if test:
        state = game.get_initial_test_state()
    else:
        state = game.get_initial_state()
    game.reset_used_columns_set()
    network = ResNet(action_space=game.ACTION_SPACE)

    # Initialize network parameters
    network.predict(game.encode_state(state))
    network.set_weights(weights)

    mcts = MCTS(network=network)
    done = False
    total_score = 0
    step_count = 0

    while not done and step_count < game.MAX_STEPS:
        mcts_policy = mcts.search(
            root_state=state, num_simulations=mcts_settings["num_mcts_simulations"]
        )
        action = np.random.choice(range(game.ACTION_SPACE), p=mcts_policy)
        record.append(Sample(state.copy(), mcts_policy, reward=None))
        state, action_score, done = game.step(state, action, mcts_policy)
        print(state, action_score, done)
        total_score += action_score
        step_count += 1
    print("======================")

    # The reward is calculated based on the final state
    reward = game.get_reward(state, total_score)

    # Assign the reward to each sample
    for sample in record:
        sample.reward = reward

    return record

In [6]:
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

training_settings = config["training_settings"]
network_settings = config["network_settings"]
mcts_settings = config["mcts_settings"]
num_cpus = training_settings["num_cpus"]
n_episodes = training_settings["n_episodes"]
buffer_size = training_settings["buffer_size"]
batch_size = training_settings["batch_size"]
epochs_per_update = training_settings["epochs_per_update"]
update_period = training_settings["update_period"]
save_period = training_settings["save_period"]

# ray.init(num_cpus=num_cpus, num_gpus=1, local_mode=False)

logdir = Path("log")
if logdir.exists():
    shutil.rmtree(logdir)
summary_writer = tf.summary.create_file_writer(str(logdir))

game.initialize_game()  # Initialize game variables

network = ResNet(action_space=game.ACTION_SPACE)

dummy_state = game.encode_state(game.get_initial_state())
network.predict(dummy_state)

current_weights = network.get_weights()

optimizer = tf.keras.optimizers.Adam(learning_rate=network_settings["learning_rate"])

replay = ReplayBuffer(buffer_size=buffer_size)

# # Start self-play workers
work_in_progresses = [selfplay(current_weights, True)]


n_updates = 0
n = 0
while n <= n_episodes:
    for _ in range(update_period):
        # Wait for a self-play worker to finish
        finished = selfplay(current_weights, True)
        replay.add_record(finished)
        n += 1

    # Update network
    if len(replay) >= batch_size:
        num_iters = epochs_per_update * (len(replay) // batch_size)
        for i in range(num_iters):
            states, mcts_policy, rewards = replay.get_minibatch(batch_size=batch_size)
            with tf.GradientTape() as tape:
                p_pred, v_pred = network(states, training=True)
                value_loss = tf.square(rewards - v_pred)
                policy_loss = -tf.reduce_sum(
                    mcts_policy * tf.math.log(p_pred + 1e-10), axis=1, keepdims=True
                )
                loss = tf.reduce_mean(value_loss + policy_loss)

            grads = tape.gradient(loss, network.trainable_variables)
            optimizer.apply_gradients(zip(grads, network.trainable_variables))
            n_updates += 1

            if i % 100 == 0:
                with summary_writer.as_default():
                    tf.summary.scalar(
                        "value_loss", tf.reduce_mean(value_loss), step=n_updates
                    )
                    tf.summary.scalar(
                        "policy_loss", tf.reduce_mean(policy_loss), step=n_updates
                    )

        current_weights = network.get_weights()

    if n % save_period == 0:
        network.save_weights("checkpoints/network")

[[0. 0. 0. 1.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [1. 0. 0. 0.]] 4 False
[[0. 0. 1. 0.]
 [0. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 0. 0.]] 4 False
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]] 1 True
[[0. 0. 1. 1.]
 [0. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]] 4 False
[[0. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 0. 0.]
 [0. 1. 0. 0.]] 1 False
[[0. 0. 0. 1.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [1. 0. 0. 0.]] 1 False
[[0. 0. 1. 0.]
 [0. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 0. 0.]] 4 False
[[0. 0. 0. 1.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [1. 0. 0. 0.]] 4 False
[[0. 0. 1. 0.]
 [0. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 0. 0.]] 4 False
[[0. 0. 0. 1.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [1. 0. 0. 0.]] 4 False
[[0. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 0. 0.]
 [0. 1. 0. 0.]] 1 False
[[0. 0. 0. 1.]
 [0. 0. 0. 1.]
 [0. 0. 0. 0.]
 [1. 1. 0. 0.]] 1 False
[[0. 0. 0. 1.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [1. 0. 0. 0.]] 4 False
[[0. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 0. 0.]
 [0. 1. 0. 0.]] 1 False
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 