In [1]:
import os
import glob

from pathlib import Path
import shutil

import tensorflow as tf
import tf_keras as keras
import numpy as np
from tqdm import tqdm
import yaml

from rl.network import ResNet
from rl.mcts import MCTS
from rl.buffer import ReplayBuffer, Sample
from rl.game import Game, encode_state

# Load the experiment configuration from config.yaml in the working directory.
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Directory that evaluate_self_play searches for adj_matrix_<qubits>_*.npy files.
base_path = "graphs"
# Run tag embedded in the checkpoint file names written and read by later cells.
index = "20241212"
# game_settings.N; passed as `qubits` to Game, MCTS and encode_state.
qubits = config["game_settings"]["N"]
# Per-section views of the configuration.
training_settings = config["training_settings"]
network_settings = config["network_settings"]
mcts_settings = config["mcts_settings"]
# NOTE(review): num_cpus / num_gpus are read here but not used by any visible cell.
num_cpus = training_settings["num_cpus"]
num_gpus = training_settings["num_gpus"]
# Training-loop controls: total self-play episodes, replay capacity, minibatch size.
n_episodes = training_settings["n_episodes"]
buffer_size = training_settings["buffer_size"]
batch_size = training_settings["batch_size"]
# Multiplier on (len(replay) // batch_size) that sets gradient steps per update.
epochs_per_update = training_settings["epochs_per_update"]
# Number of self-play episodes generated between two training updates.
update_period = training_settings["update_period"]
# Episode intervals for checkpointing and evaluation (eval defaults to 100 if absent).
save_period = training_settings["save_period"]
eval_period = training_settings.get("eval_period", 100)


def selfplay(weights, qubits, current_episode, config):
    """Play one MCTS-guided self-play episode and return its training samples.

    Args:
        weights: Weights loaded into a freshly constructed ResNet via
            ``set_weights``.
        qubits: Forwarded to ``Game``, ``MCTS`` and ``encode_state``.
        current_episode: Episode index supplied by the caller; not used by the
            function body, kept so existing call sites keep working.
        config: Configuration mapping. The number of simulations per move is
            read from ``config["mcts_settings"]["num_mcts_simulations"]``.

    Returns:
        list: One ``Sample`` per step taken (state copy and MCTS policy), each
        with ``reward`` set to the final value of ``game.get_reward``.
    """
    record = []
    game = Game(qubits, config)
    state = game.get_initial_state()
    game.reset_used_columns()

    network = ResNet(action_space=game.action_space, config=config)
    # One forward pass before set_weights — presumably so the model's variables
    # exist when the weights are assigned (TODO confirm against ResNet).
    network.predict(encode_state(state, qubits))
    network.set_weights(weights)

    mcts = MCTS(qubits=qubits, network=network, config=config)
    # Read from the config argument rather than a notebook-level global so the
    # function does not depend on which cells have been executed.
    num_simulations = config["mcts_settings"]["num_mcts_simulations"]

    done = False
    total_score = 0
    step_count = 0
    prev_action = None

    while not done and step_count < game.MAX_STEPS:
        mcts_policy = mcts.search(
            root_state=state,
            prev_action=prev_action,
            num_simulations=num_simulations,
        )

        if prev_action is not None:
            valid_actions = game.get_valid_actions(state, prev_action)
            prob = mcts_policy[valid_actions]
            mass = prob.sum()
            if np.isfinite(mass) and mass > 0:
                action = np.random.choice(valid_actions, p=prob / mass)
            else:
                # The search put no usable mass on the valid actions; dividing
                # by the sum would yield NaN probabilities, so draw uniformly.
                action = np.random.choice(valid_actions)
        else:
            # First move: every action is a candidate and the MCTS policy is
            # used as returned.
            action = np.random.choice(np.arange(game.action_space), p=mcts_policy)

        # Reward is unknown until the episode ends; it is filled in below.
        record.append(Sample(state.copy(), mcts_policy, reward=None))
        state, done, action_score = game.step(state, action, prev_action)
        prev_action = action
        total_score += action_score
        step_count += 1

    reward = game.get_reward(state, total_score)
    for sample in record:
        sample.reward = reward
    return record


def _sample_action(policy, candidates):
    """Draw one action from ``candidates`` in proportion to ``policy``.

    Falls back to a uniform draw over ``candidates`` when the policy values on
    them are negative, non-finite, or sum to zero.
    """
    prob = np.asarray(policy, dtype=np.float64)[candidates]
    mass = prob.sum()
    if np.isfinite(mass) and mass > 0 and np.all(prob >= 0):
        return np.random.choice(candidates, p=prob / mass)
    return np.random.choice(candidates)


def evaluate_self_play(qubits, network, config):
    """Roll out the raw network policy on every saved graph and report results.

    For each ``adj_matrix_<qubits>_*.npy`` file under ``base_path`` the loaded
    array is used as the starting state and actions are sampled from the
    network's policy output until ``game.step`` reports completion or
    ``game.MAX_STEPS`` is reached.

    Args:
        qubits: Forwarded to ``Game`` and ``encode_state``; also part of the
            file-name pattern.
        network: Object whose ``predict`` returns ``(policy_output, value_output)``.
        config: Configuration mapping forwarded to ``Game``.

    Returns:
        tuple: ``(avg_depth, avg_counts)``. Despite the names these are lists
        holding one value per file, in sorted file-name order. Unfinished
        episodes record ``game.MAX_STEPS`` for both values.
    """
    pattern = os.path.join(base_path, f"adj_matrix_{qubits}_*.npy")
    # glob returns files in arbitrary order; sort so position i refers to the
    # same graph on every call (callers combine results across calls by index).
    file_paths = sorted(glob.glob(pattern))
    avg_depth = []
    avg_counts = []
    for file_path in file_paths:
        state = np.load(file_path)
        game = Game(qubits, config)
        swap_pairs = []
        done = False
        step_count = 0
        prev_action = None
        while not done and step_count < game.MAX_STEPS:
            encoded_state = encode_state(state, qubits)
            input_state = np.expand_dims(encoded_state, axis=0)
            policy_output, _ = network.predict(input_state)
            policy = np.array(policy_output)[0]

            if prev_action is None:
                candidates = list(range(game.action_space))
            else:
                # Best-effort fallback: every action except the previous one.
                fallback = [i for i in range(game.action_space) if i != prev_action]
                try:
                    candidates = game.get_valid_actions(state, prev_action)
                except Exception:
                    # Only ordinary errors are absorbed, so Ctrl-C still
                    # interrupts a long evaluation.
                    candidates = fallback
                if len(candidates) == 0:
                    candidates = fallback
            # Probabilities and candidates always come from the same set here,
            # so they can never be mismatched in length.
            action = _sample_action(policy, candidates)

            if action < len(game.coupling_map):
                swap_pairs.append(game.coupling_map[action])
            else:
                # Actions past the coupling map select every second pair,
                # starting at the parity of the action index.
                swap_pairs.extend(game.coupling_map[action % 2::2])
            state, done, _ = game.step(state, action, prev_action)
            prev_action = action
            step_count += 1

        if not done:
            depth = game.MAX_STEPS
            swap_count = game.MAX_STEPS
        else:
            game.current_layer += 1
            depth = game.current_layer
            swap_count = len(swap_pairs)
        print(f"depth: {depth}, count: {swap_count}")
        avg_counts.append(swap_count)
        avg_depth.append(depth)
    return avg_depth, avg_counts

2024-12-14 14:31:09.417889: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# Start every run with an empty TensorBoard log directory.
logdir = Path("log")
if logdir.exists():
    shutil.rmtree(logdir)
summary_writer = tf.summary.create_file_writer(str(logdir))

game = Game(qubits, config)
network = ResNet(action_space=game.action_space, config=config)

# One forward pass before reading the weights — presumably so the model's
# variables exist (TODO confirm against ResNet). The encoded state is computed
# once and reused for the warm-up call.
dummy_state = encode_state(game.state, qubits)
network.predict(dummy_state)
current_weights = network.get_weights()

optimizer = keras.optimizers.legacy.Adam(
    learning_rate=network_settings["learning_rate"]
)

replay = ReplayBuffer(buffer_size=buffer_size)

# Relative weights of the two loss terms; constant for the whole run.
value_loss_weight = 0.5
policy_loss_weight = 1.5

n_updates = 0

n = 0
while n < n_episodes:
    episodes_before = n

    # Phase 1: generate `update_period` self-play episodes with the latest weights.
    for _ in tqdm(range(update_period)):
        finished = selfplay(current_weights, qubits, n, config)
        replay.add_record(finished)
        n += 1
    print("-" * 50)

    # Phase 2: train once the buffer can supply at least one full minibatch.
    if len(replay) >= batch_size:
        num_iters = epochs_per_update * (len(replay) // batch_size)
        for i in tqdm(range(num_iters)):
            states, mcts_policy, rewards = replay.get_minibatch(batch_size=batch_size)
            with tf.GradientTape() as tape:
                p_pred, v_pred = network(states, training=True)
                # NOTE(review): assumes rewards and v_pred have the same shape;
                # a (B,) vs (B, 1) pair would broadcast to (B, B) — confirm
                # against ReplayBuffer.get_minibatch.
                value_loss = tf.square(rewards - v_pred)
                # Cross-entropy against the MCTS policy; 1e-5 keeps log finite.
                policy_loss = -tf.reduce_sum(
                    mcts_policy * tf.math.log(p_pred + 1e-5), axis=1, keepdims=True
                )
                loss = tf.reduce_mean(
                    value_loss_weight * value_loss + policy_loss_weight * policy_loss
                )
            grads = tape.gradient(loss, network.trainable_variables)
            grads, _ = tf.clip_by_global_norm(grads, 0.5)
            optimizer.apply_gradients(zip(grads, network.trainable_variables))
            n_updates += 1

            if i % 10 == 0:
                with summary_writer.as_default():
                    tf.summary.scalar(
                        "value_loss", tf.reduce_mean(value_loss), step=n_updates
                    )
                    tf.summary.scalar(
                        "policy_loss", tf.reduce_mean(policy_loss), step=n_updates
                    )

        current_weights = network.get_weights()

    # n advances in steps of update_period, so `n % period == 0` would skip
    # every multiple that update_period does not land on. Trigger instead when
    # this batch of episodes crossed a multiple of the period; this is identical
    # to the modulo test whenever the periods are aligned.
    if n // save_period > episodes_before // save_period:
        network.save(f"checkpoints/network{qubits}_{index}_{n}", save_format="tf")
        network.save_weights(f"checkpoints/network{qubits}_{index}_{n}.weights.h5")
        print("-" * 50)
    if n // eval_period > episodes_before // eval_period:
        depth, count = evaluate_self_play(qubits, network, config)
        print(
            f"Episode {n}: SWAP depth is {np.mean(depth)}, SWAP count is {np.mean(count)}"
        )
        print("-" * 50)

2024-12-14 14:31:11.888423: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-12-14 14:31:11.991698: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-12-14 14:31:11.992497: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-12-14 14:31:11.997151: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-12-14 14:31:11.997669: I external/local_xla/xla/stream_executor

In [None]:
# Episode count of the checkpoint to restore; it must match a directory written
# by the training loop as checkpoints/network<qubits>_<index>_<episode>.
checkpoint_episode = 500

game = Game(qubits, config)
# load_model returns the restored model, so no ResNet needs to be constructed
# first; the previous instance was discarded immediately.
network = keras.models.load_model(
    f"checkpoints/network{qubits}_{index}_{checkpoint_episode}"
)

In [None]:
# Run the stochastic evaluation 20 times and keep, for each graph position,
# the smallest depth observed in any run.
depths = []
for _ in range(20):
    run_depths, _run_counts = evaluate_self_play(qubits, network, config)
    depths.append(run_depths)
depth_table = np.vstack(depths)  # one row per run, one column per graph file
min_depth = depth_table.min(axis=0)

In [None]:
# Play 20 games from the default initial state, sampling from the raw network
# policy, and print how each one ended.
for _ in range(20):
    game = Game(qubits, config)
    state = game.state
    ans = []
    done = False
    total_score = 0
    step_count = 0
    prev_action = None
    print(state)
    while not done and step_count < game.MAX_STEPS:
        encoded_state = encode_state(state, qubits)
        input_state = np.expand_dims(encoded_state, axis=0)

        policy_output, value_output = network.predict(input_state)
        policy = np.asarray(policy_output[0])
        if prev_action is not None:
            # Never repeat the previous action.
            indices = [i for i in range(game.action_space) if i != prev_action]
        else:
            indices = list(range(game.action_space))
        # Renormalise on both branches so the first move is drawn the same way
        # as later ones.
        prob = policy[indices]
        action = np.random.choice(indices, p=prob / prob.sum())

        if action < len(game.coupling_map):
            ans.append(game.coupling_map[action])
        else:
            # Same handling as evaluate_self_play: actions past the coupling map
            # select every second pair, so indexing coupling_map[action]
            # directly would raise IndexError.
            ans.extend(game.coupling_map[action % 2::2])

        # Keep the per-step score so the total printed below is meaningful.
        state, done, action_score = game.step(state, action, prev_action)
        total_score += action_score
        prev_action = action
        step_count += 1
    if done:
        print(f"Game finished successfully in {step_count} steps with {ans}")
    else:
        print(f"Game terminated after reaching the maximum steps ({game.MAX_STEPS}).")
        print(f"Total score: {total_score}")

In [None]:
# Smallest depth per graph position across the 20 evaluation runs.
min_depth

In [None]:
# Mean of the per-graph best depths.
np.mean(min_depth)

array([1, 5, 6, 4, 6, 7, 7, 2, 6, 3, 2, 5, 3, 7, 5, 7, 6, 3, 5, 2, 4, 9,
       2, 6, 5, 5, 3, 2, 2, 1]) -> 4.366666666666666