In [None]:
import os
import glob

from pathlib import Path
import shutil

import tensorflow as tf
import tf_keras as keras
import numpy as np
from tqdm import tqdm
import yaml

from rl.network import ResNet
from rl.mcts import MCTS
from rl.buffer import ReplayBuffer, Sample
from rl.game import Game, encode_state

# ---- Configuration ---------------------------------------------------------
# Every tunable setting lives in config.yaml; the sections used by the cells
# below are unpacked into module-level names here.
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Directory holding the evaluation graphs, and the tag embedded in
# checkpoint file names.
base_path = "graphs"
index = "20241201"

qubits = config["game_settings"]["N"]
training_settings = config["training_settings"]
network_settings = config["network_settings"]
mcts_settings = config["mcts_settings"]

# Training-loop knobs (num_cpus / num_gpus are not referenced elsewhere in
# this notebook).
num_cpus, num_gpus = training_settings["num_cpus"], training_settings["num_gpus"]
n_episodes = training_settings["n_episodes"]
buffer_size = training_settings["buffer_size"]
batch_size = training_settings["batch_size"]
epochs_per_update = training_settings["epochs_per_update"]
update_period, save_period = (
    training_settings["update_period"],
    training_settings["save_period"],
)
# eval_period is optional in the config; evaluate every 100 episodes by default.
eval_period = training_settings.get("eval_period", 100)


def selfplay(weights, qubits, current_episode, config):
    """Play one MCTS-guided self-play episode and return its training samples.

    A fresh ``Game`` and ``ResNet`` are created, the network is loaded with
    ``weights``, and at each step an action is sampled from the MCTS policy
    until the game reports ``done`` or ``game.MAX_STEPS`` moves were made.
    The number of simulations comes from the module-level ``mcts_settings``.

    Args:
        weights: Network weights as returned by ``network.get_weights()``.
        qubits: Number of qubits, passed to ``Game`` and ``encode_state``.
        current_episode: Episode counter from the caller; currently unused,
            kept for interface compatibility.
        config: Parsed ``config.yaml`` dictionary.

    Returns:
        list of ``Sample``: one per move, holding the state before the move,
        the MCTS policy at that state, and the final episode reward (the same
        reward is assigned to every sample).
    """
    record = []
    game = Game(qubits, config)
    state = game.get_initial_state()
    game.reset_used_columns()
    network = ResNet(action_space=len(game.coupling_map), config=config)
    # One forward pass creates the network's variables so set_weights works.
    network.predict(encode_state(state, qubits))
    network.set_weights(weights)

    mcts = MCTS(qubits=qubits, network=network, config=config)
    num_actions = len(game.coupling_map)
    done = False
    total_score = 0
    step_count = 0
    prev_action = None

    while not done and step_count < game.MAX_STEPS:
        mcts_policy = mcts.search(
            root_state=state,
            prev_action=prev_action,
            num_simulations=mcts_settings["num_mcts_simulations"],
        )
        policy = np.asarray(mcts_policy, dtype=np.float64)
        if prev_action is None:
            candidates = list(range(num_actions))
            prob = policy
        else:
            candidates = game.get_valid_actions(state, prev_action)
            prob = policy[candidates]
        # Renormalise over the allowed actions. If they carry no probability
        # mass, sample uniformly instead of dividing by zero (NaN probabilities).
        mass = prob.sum()
        if mass > 0:
            prob = prob / mass
        else:
            prob = np.ones_like(prob) / max(prob.size, 1)
        action = np.random.choice(candidates, p=prob)

        record.append(Sample(state.copy(), mcts_policy, reward=None))
        state, done, action_score = game.step(state, action, prev_action)
        prev_action = action
        total_score += action_score
        step_count += 1

    reward = game.get_reward(state, total_score)
    for sample in record:
        sample.reward = reward
    return record


def evaluate_self_play(qubits, network, config):
    """Roll out ``network``'s policy on every stored evaluation graph.

    Each file matching ``{base_path}/adj_matrix_{qubits}_*.npy`` is loaded as
    an initial state. Moves are *sampled* from the network's policy output
    (not chosen greedily), so repeated calls can give different results.

    Args:
        qubits: Number of qubits; selects the graph files and is passed to
            ``Game`` and ``encode_state``.
        network: Model whose ``predict`` returns ``(policy, value)`` for a
            batch containing one encoded state.
        config: Parsed ``config.yaml`` dictionary.

    Returns:
        tuple[list, list]: ``(depths, swap_counts)``, one entry per graph in
        sorted file-path order. ``depth`` is ``game.current_layer`` and
        ``swap_count`` is the number of SWAPs applied. A game that reaches
        ``game.MAX_STEPS`` without finishing scores ``MAX_STEPS`` for both.
    """
    pattern = os.path.join(base_path, f"adj_matrix_{qubits}_*.npy")
    # glob order is filesystem-dependent. Sort so entry i always refers to the
    # same graph, because callers stack repeated runs and reduce element-wise.
    file_paths = sorted(glob.glob(pattern))
    depths = []
    swap_counts = []
    for file_path in tqdm(file_paths):
        state = np.load(file_path)
        game = Game(qubits, config)
        num_actions = len(game.coupling_map)
        swap_pairs = []
        done = False
        step_count = 0
        prev_action = None
        while not done and step_count < game.MAX_STEPS:
            input_state = np.expand_dims(encode_state(state, qubits), axis=0)
            policy_output, _value_output = network.predict(input_state)
            # Work in float64 so the renormalised probabilities sum to 1 tightly.
            policy = np.asarray(policy_output, dtype=np.float64)[0]
            if prev_action is None:
                candidates = list(range(num_actions))
                prob = policy
            else:
                try:
                    candidates = game.get_valid_actions(state, prev_action)
                    prob = policy[candidates]
                except Exception:
                    # Fallback: any action except repeating the previous one.
                    # Candidates and probabilities must come from the same
                    # list; otherwise sampling pairs them with a stale or
                    # undefined action list.
                    candidates = [i for i in range(num_actions) if i != prev_action]
                    prob = policy[candidates]
            mass = prob.sum()
            if mass > 0:
                prob = prob / mass
            else:
                # No probability on any allowed move: sample uniformly rather
                # than dividing by zero (NaN probabilities).
                prob = np.ones_like(prob) / max(prob.size, 1)
            action = np.random.choice(candidates, p=prob)
            swap_pairs.append(game.coupling_map[action])
            state, done, _ = game.step(state, action, prev_action)
            prev_action = action
            step_count += 1
        if done:
            depths.append(game.current_layer)
            swap_counts.append(len(swap_pairs))
        else:
            depths.append(game.MAX_STEPS)
            swap_counts.append(game.MAX_STEPS)
    return depths, swap_counts

In [None]:
# ---- Training loop ---------------------------------------------------------
# Each round plays `update_period` self-play episodes with the current weights
# and adds them to the replay buffer. It then runs gradient updates on
# minibatches drawn from the buffer.
logdir = Path("log")
if logdir.exists():
    # Remove logs from earlier runs so TensorBoard only shows this run.
    shutil.rmtree(logdir)
summary_writer = tf.summary.create_file_writer(str(logdir))

game = Game(qubits, config)
network = ResNet(action_space=len(game.coupling_map), config=config)

# A forward pass on a real encoded state creates the network's variables so
# that get_weights() returns them.
dummy_state = encode_state(game.state, qubits)
network.predict(dummy_state)
current_weights = network.get_weights()

optimizer = keras.optimizers.legacy.Adam(
    learning_rate=network_settings["learning_rate"]
)

replay = ReplayBuffer(buffer_size=buffer_size)

# Relative weights of the value (squared-error) and policy (cross-entropy)
# terms in the total loss.
value_loss_weight = 1
policy_loss_weight = 1

n_updates = 0

n = 0
while n < n_episodes:
    n_before = n
    for _ in tqdm(range(update_period)):
        finished = selfplay(current_weights, qubits, n, config)
        replay.add_record(finished)
        n += 1
    print("-" * 50)
    if len(replay) >= batch_size:
        num_iters = epochs_per_update * (len(replay) // batch_size)
        for i in tqdm(range(num_iters)):
            states, mcts_policy, rewards = replay.get_minibatch(batch_size=batch_size)
            with tf.GradientTape() as tape:
                p_pred, v_pred = network(states, training=True)
                # NOTE(review): assumes `rewards` has the same shape as `v_pred`
                # (e.g. (batch, 1)). A (batch,) vector would broadcast to
                # (batch, batch) here; confirm against ReplayBuffer.get_minibatch.
                value_loss = tf.square(rewards - v_pred)
                # Cross-entropy between the MCTS target policy and the
                # predicted policy; the 1e-5 offset guards against log(0).
                policy_loss = -tf.reduce_sum(
                    mcts_policy * tf.math.log(p_pred + 1e-5), axis=1, keepdims=True
                )
                loss = tf.reduce_mean(
                    value_loss_weight * value_loss + policy_loss_weight * policy_loss
                )
            grads = tape.gradient(loss, network.trainable_variables)
            grads, _ = tf.clip_by_global_norm(grads, 1.0)
            optimizer.apply_gradients(zip(grads, network.trainable_variables))
            n_updates += 1

            if i % 10 == 0:
                with summary_writer.as_default():
                    tf.summary.scalar(
                        "value_loss", tf.reduce_mean(value_loss), step=n_updates
                    )
                    tf.summary.scalar(
                        "policy_loss", tf.reduce_mean(policy_loss), step=n_updates
                    )

        # The next round of self-play uses the freshly trained weights.
        current_weights = network.get_weights()

    # `n` advances by `update_period` per round, so act whenever this round
    # crossed a multiple of the period. A plain `n % period == 0` never fires
    # when `update_period` does not divide the period; when it does, the two
    # tests agree.
    if n // save_period > n_before // save_period:
        network.save(f"checkpoints/network{qubits}_{index}_{n}", save_format="tf")
        network.save_weights(f"checkpoints/network{qubits}_{index}_{n}.weights.h5")
        print("-" * 50)
    if n // eval_period > n_before // eval_period:
        depth, count = evaluate_self_play(qubits, network, config)
        print(
            f"Episode {n}: SWAP depth is {np.mean(depth)}, SWAP count is {np.mean(count)}"
        )
        print("-" * 50)
        print("-" * 50)

In [None]:
# Load a trained checkpoint for the inspection and evaluation cells below.
# load_model reconstructs the network from the saved directory, so no ResNet
# has to be built beforehand (one built here would be overwritten unused).
checkpoint_episode = 700  # episode count `n` embedded in the checkpoint name
game = Game(qubits, config)
network = keras.models.load_model(
    f"checkpoints/network{qubits}_{index}_{checkpoint_episode}"
)

In [None]:
# Play 20 games from the default start state with the loaded network. Each
# move is sampled from the policy, never repeating the previous action.
for _ in range(20):
    game = Game(qubits, config)
    state = game.state
    ans = []
    done = False
    total_score = 0
    step_count = 0
    prev_action = None
    print(state)
    num_actions = len(game.coupling_map)
    while not done and step_count < game.MAX_STEPS:
        input_state = np.expand_dims(encode_state(state, qubits), axis=0)
        policy_output, _value_output = network.predict(input_state)
        policy = np.asarray(policy_output, dtype=np.float64)[0]
        if prev_action is None:
            indices = list(range(num_actions))
        else:
            indices = [i for i in range(num_actions) if i != prev_action]
        prob = policy[indices]
        # Renormalise in both branches: np.random.choice requires p to sum to 1.
        action = np.random.choice(indices, p=prob / prob.sum())
        ans.append(game.coupling_map[action])
        state, done, action_score = game.step(state, action, prev_action)
        # Accumulate the step score so the "Total score" report is meaningful.
        total_score += action_score
        prev_action = action
        step_count += 1
    if done:
        print(f"Game finished successfully in {step_count} steps with {ans}")
    else:
        print(f"Game terminated after reaching the maximum steps ({game.MAX_STEPS}).")
        print(f"Total score: {total_score}")

In [12]:
# Evaluation samples moves stochastically, so repeat it several times and
# keep, for every graph, the shallowest SWAP depth observed in any run.
depths = []
for _ in range(10):
    depth, count = evaluate_self_play(qubits, network, config)
    depths.append(depth)
depth_table = np.vstack(depths)  # one row per run, one column per graph
min_depth = depth_table.min(axis=0)



 30%|███       | 9/30 [00:11<00:29,  1.42s/it]



 33%|███▎      | 10/30 [00:12<00:25,  1.27s/it]



 37%|███▋      | 11/30 [00:13<00:20,  1.07s/it]



 40%|████      | 12/30 [00:14<00:18,  1.04s/it]



 43%|████▎     | 13/30 [00:15<00:16,  1.02it/s]



 47%|████▋     | 14/30 [00:16<00:17,  1.08s/it]



 50%|█████     | 15/30 [00:17<00:15,  1.06s/it]



 53%|█████▎    | 16/30 [00:18<00:15,  1.14s/it]



 57%|█████▋    | 17/30 [00:19<00:13,  1.04s/it]



 60%|██████    | 18/30 [00:21<00:13,  1.10s/it]



 63%|██████▎   | 19/30 [00:21<00:11,  1.06s/it]



 67%|██████▋   | 20/30 [00:22<00:08,  1.20it/s]



 70%|███████   | 21/30 [00:23<00:07,  1.15it/s]



 73%|███████▎  | 22/30 [00:24<00:08,  1.05s/it]



 77%|███████▋  | 23/30 [00:25<00:07,  1.03s/it]



 80%|████████  | 24/30 [00:26<00:06,  1.08s/it]



 83%|████████▎ | 25/30 [00:27<00:04,  1.01it/s]



 87%|████████▋ | 26/30 [00:28<00:03,  1.00it/s]



 90%|█████████ | 27/30 [00:30<00:03,  1.09s/it]



 93%|█████████▎| 28/30 [00:30<00:01,  1.03it/s]



 97%|█████████▋| 29/30 [00:31<00:00,  1.27it/s]



100%|██████████| 30/30 [00:31<00:00,  1.06s/it]
  0%|          | 0/30 [00:00<?, ?it/s]



  3%|▎         | 1/30 [00:00<00:15,  1.84it/s]



  7%|▋         | 2/30 [00:01<00:21,  1.32it/s]



 10%|█         | 3/30 [00:02<00:23,  1.16it/s]



 13%|█▎        | 4/30 [00:03<00:25,  1.00it/s]



 17%|█▋        | 5/30 [00:04<00:22,  1.09it/s]



 20%|██        | 6/30 [00:05<00:26,  1.11s/it]



 23%|██▎       | 7/30 [00:07<00:27,  1.20s/it]



 27%|██▋       | 8/30 [00:07<00:22,  1.02s/it]



 30%|███       | 9/30 [00:09<00:23,  1.10s/it]



 33%|███▎      | 10/30 [00:09<00:18,  1.11it/s]



 37%|███▋      | 11/30 [00:10<00:14,  1.32it/s]



 40%|████      | 12/30 [00:11<00:15,  1.19it/s]



 43%|████▎     | 13/30 [00:11<00:12,  1.40it/s]



 47%|████▋     | 14/30 [00:12<00:13,  1.20it/s]



 50%|█████     | 15/30 [00:13<00:12,  1.23it/s]



 53%|█████▎    | 16/30 [00:14<00:13,  1.07it/s]



 57%|█████▋    | 17/30 [00:15<00:13,  1.04s/it]



 60%|██████    | 18/30 [00:16<00:11,  1.03it/s]



 63%|██████▎   | 19/30 [00:18<00:12,  1.10s/it]



 67%|██████▋   | 20/30 [00:19<00:10,  1.08s/it]



 70%|███████   | 21/30 [00:20<00:10,  1.11s/it]



 73%|███████▎  | 22/30 [00:21<00:09,  1.20s/it]



 77%|███████▋  | 23/30 [00:22<00:07,  1.12s/it]



 80%|████████  | 24/30 [00:23<00:07,  1.17s/it]



 83%|████████▎ | 25/30 [00:24<00:05,  1.12s/it]



 87%|████████▋ | 26/30 [00:25<00:03,  1.00it/s]



 90%|█████████ | 27/30 [00:26<00:02,  1.00it/s]



 93%|█████████▎| 28/30 [00:27<00:01,  1.14it/s]



 97%|█████████▋| 29/30 [00:28<00:00,  1.20it/s]



100%|██████████| 30/30 [00:28<00:00,  1.06it/s]
  0%|          | 0/30 [00:00<?, ?it/s]



  3%|▎         | 1/30 [00:00<00:09,  3.08it/s]



  7%|▋         | 2/30 [00:01<00:17,  1.60it/s]



 10%|█         | 3/30 [00:02<00:21,  1.26it/s]



 13%|█▎        | 4/30 [00:03<00:24,  1.07it/s]



 17%|█▋        | 5/30 [00:04<00:23,  1.06it/s]



 20%|██        | 6/30 [00:05<00:27,  1.13s/it]



 23%|██▎       | 7/30 [00:06<00:25,  1.12s/it]



 27%|██▋       | 8/30 [00:07<00:20,  1.07it/s]



 30%|███       | 9/30 [00:08<00:20,  1.05it/s]



 33%|███▎      | 10/30 [00:08<00:15,  1.27it/s]



 37%|███▋      | 11/30 [00:09<00:13,  1.43it/s]



 40%|████      | 12/30 [00:10<00:14,  1.22it/s]



 43%|████▎     | 13/30 [00:11<00:12,  1.32it/s]



 47%|████▋     | 14/30 [00:12<00:16,  1.04s/it]



 50%|█████     | 15/30 [00:14<00:17,  1.18s/it]



 53%|█████▎    | 16/30 [00:15<00:15,  1.12s/it]



 57%|█████▋    | 17/30 [00:16<00:16,  1.26s/it]



 60%|██████    | 18/30 [00:17<00:14,  1.24s/it]



 63%|██████▎   | 19/30 [00:18<00:12,  1.17s/it]



 67%|██████▋   | 20/30 [00:19<00:10,  1.02s/it]



 70%|███████   | 21/30 [00:20<00:09,  1.11s/it]



 73%|███████▎  | 22/30 [00:22<00:10,  1.36s/it]



 77%|███████▋  | 23/30 [00:24<00:09,  1.40s/it]



 80%|████████  | 24/30 [00:25<00:07,  1.20s/it]



 83%|████████▎ | 25/30 [00:26<00:05,  1.15s/it]



 87%|████████▋ | 26/30 [00:27<00:04,  1.18s/it]



 90%|█████████ | 27/30 [00:29<00:04,  1.50s/it]



 93%|█████████▎| 28/30 [00:30<00:02,  1.31s/it]



 97%|█████████▋| 29/30 [00:30<00:01,  1.04s/it]



100%|██████████| 30/30 [00:31<00:00,  1.05s/it]
  0%|          | 0/30 [00:00<?, ?it/s]



  3%|▎         | 1/30 [00:00<00:11,  2.47it/s]



  7%|▋         | 2/30 [00:02<00:31,  1.12s/it]



 10%|█         | 3/30 [00:03<00:29,  1.09s/it]



 13%|█▎        | 4/30 [00:04<00:29,  1.13s/it]



 17%|█▋        | 5/30 [00:05<00:33,  1.33s/it]



 20%|██        | 6/30 [00:08<00:39,  1.64s/it]



 23%|██▎       | 7/30 [00:09<00:33,  1.47s/it]



 27%|██▋       | 8/30 [00:09<00:25,  1.14s/it]



 30%|███       | 9/30 [00:11<00:28,  1.34s/it]



 33%|███▎      | 10/30 [00:11<00:20,  1.04s/it]



 37%|███▋      | 11/30 [00:12<00:16,  1.16it/s]



 40%|████      | 12/30 [00:13<00:15,  1.19it/s]



 43%|████▎     | 13/30 [00:13<00:13,  1.25it/s]



 47%|████▋     | 14/30 [00:15<00:18,  1.14s/it]



 50%|█████     | 15/30 [00:17<00:20,  1.38s/it]



 53%|█████▎    | 16/30 [00:18<00:17,  1.28s/it]



 57%|█████▋    | 17/30 [00:19<00:15,  1.22s/it]



 60%|██████    | 18/30 [00:20<00:12,  1.06s/it]



 63%|██████▎   | 19/30 [00:21<00:11,  1.06s/it]



 67%|██████▋   | 20/30 [00:22<00:08,  1.16it/s]



 70%|███████   | 21/30 [00:22<00:08,  1.12it/s]



 73%|███████▎  | 22/30 [00:24<00:09,  1.15s/it]



 77%|███████▋  | 23/30 [00:26<00:09,  1.34s/it]



 80%|████████  | 24/30 [00:27<00:07,  1.21s/it]



 83%|████████▎ | 25/30 [00:28<00:05,  1.12s/it]



 87%|████████▋ | 26/30 [00:29<00:04,  1.05s/it]



 90%|█████████ | 27/30 [00:30<00:03,  1.07s/it]



 93%|█████████▎| 28/30 [00:30<00:01,  1.07it/s]



 97%|█████████▋| 29/30 [00:31<00:00,  1.24it/s]



100%|██████████| 30/30 [00:31<00:00,  1.06s/it]


In [13]:
# Best (minimum) SWAP depth per evaluation graph across the repeated runs.
min_depth

array([1, 5, 6, 7, 6, 7, 7, 1, 7, 3, 2, 5, 3, 5, 7, 7, 8, 4, 6, 1, 7, 8,
       4, 6, 4, 5, 4, 2, 2, 2])

In [14]:
# Mean over all evaluation graphs of the per-graph best SWAP depth.
np.mean(min_depth)

4.733333333333333

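Another per-graph minimum-depth result, kept for comparison (the run/checkpoint that produced it is not recorded); its mean is shown after the arrow:
array([1, 5, 6, 4, 6, 7, 7, 2, 6, 3, 2, 5, 3, 7, 5, 7, 6, 3, 5, 2, 4, 9,
       2, 6, 5, 5, 3, 2, 2, 1]) -> 4.366666666666666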