In [1]:
import os
import glob

from dataclasses import dataclass
from pathlib import Path
import shutil

import tensorflow as tf
import numpy as np
from tqdm import tqdm
from qiskit import QuantumCircuit
import yaml
import ray

from rl.network import ResNet
from rl.mcts import MCTS
from rl.buffer import ReplayBuffer

from rl.game import Game, encode_state

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)
base_path = "graphs"
index = "20241120"
qubits = config["game_settings"]["N"]
gate = config["game_settings"]["gate"]
layer = config["game_settings"]["layer"]
training_settings = config["training_settings"]
network_settings = config["network_settings"]
mcts_settings = config["mcts_settings"]
num_cpus = training_settings["num_cpus"]
num_gpus = training_settings["num_gpus"]
n_episodes = training_settings["n_episodes"]
buffer_size = training_settings["buffer_size"]
batch_size = training_settings["batch_size"]
epochs_per_update = training_settings["epochs_per_update"]
update_period = training_settings["update_period"]
save_period = training_settings["save_period"]
game = Game(qubits, config)


def evaluate_self_play(qubits, network):
    pattern = os.path.join(base_path, f"adj_matrix_{qubits}_*.npy")
    file_paths = glob.glob(pattern)
    avg_depth = []
    avg_counts = []
    for file_path in file_paths:
        state = np.load(file_path)
        swap_pairs = []
        done = False
        total_score = 0
        step_count = 0
        prev_action = None
        while not done and step_count < game.MAX_STEPS:
            encoded_state = encode_state(state, qubits)
            input_state = np.expand_dims(encoded_state, axis=0)
            policy_output, value_output = network.predict(input_state)
            policy = policy_output[0]
            if prev_action is not None:
                indices = [i for i in range(len(game.coupling_map)) if i != prev_action]
                prob = policy[indices]
                if prob.sum() < 1e-9:
                    action = np.random.choice(indices)
                else:
                    action = np.random.choice(indices, p=prob / prob.sum())
            else:
                indices = list(range(len(game.coupling_map)))
                action = np.random.choice(indices, p=policy)
            selected_action = game.coupling_map[action]
            swap_pairs.append(selected_action)
            state, done, _ = game.step(state, action, prev_action)
            prev_action = action
            step_count += 1
        qc = QuantumCircuit(qubits)
        if not done:
            depth = game.MAX_STEPS
            swap_count = game.MAX_STEPS
        else:
            for swap in swap_pairs:
                qc.swap(*swap)
            depth = qc.depth(lambda instr: instr.name == "swap")
            swap_count = len(swap_pairs)
        avg_counts.append(swap_count)
        avg_depth.append(depth)
    return np.mean(avg_depth), np.mean(avg_counts)
@dataclass
class Sample:
    state: np.ndarray
    mcts_policy: np.ndarray
    reward: float


@ray.remote(num_cpus=num_cpus, num_gpus=num_gpus)
def selfplay(weights, qubits, current_episode, test=False):
    """Perform a self-play game and collect training data."""
    record = []
    if test:
        state = game.get_initial_test_state()
    else:
        state = game.state
    game.reset_used_columns()
    network = ResNet(action_space=len(game.coupling_map), config=config)

    # Initialize network parameters
    network.predict(encode_state(state, qubits))
    network.set_weights(weights)

    mcts = MCTS(qubits=qubits, network=network, config=config)
    mcts.current_episode = current_episode
    done = False
    total_score = 0
    step_count = 0
    prev_action = None
    while not done and step_count < game.MAX_STEPS:
        mcts_policy = mcts.search(
            root_state=state,
            num_simulations=mcts_settings["num_mcts_simulations"],
            prev_action=prev_action,
        )
        if prev_action is not None:
            indices = [i for i in range(len(game.coupling_map)) if i != prev_action]
            prob = mcts_policy[indices]
            prob = prob / prob.sum()
            # if use_network_policy:
            #     prob = np.ones(len(prob))/len(prob)
            action = np.random.choice(indices, p=prob)
        else:
            indices = list(range(len(game.coupling_map)))
            prob = mcts_policy
            # if use_network_policy:
            #     prob = np.ones(len(mcts_policy))/len(mcts_policy)
            action = np.random.choice(indices, p=prob)
        record.append(Sample(state.copy(), mcts_policy, reward=None))
        state, done, action_score = game.step(state, action, prev_action)
        prev_action = action
        # print(state, action_score, done)
        total_score += action_score
        step_count += 1

    # The reward is calculated based on the final state
    reward = game.get_reward(state, total_score)
    # Assign the reward to each sample
    for sample in record:
        sample.reward = reward
    return record

In [2]:
ray.init(num_cpus=num_cpus, num_gpus=num_gpus, local_mode=False, include_dashboard=True)
print(ray.available_resources())
logdir = Path("log")
if logdir.exists():
    shutil.rmtree(logdir)
summary_writer = tf.summary.create_file_writer(str(logdir))


network = ResNet(action_space=len(game.coupling_map), config=config)

dummy_state = encode_state(game.state, qubits)
network.predict(dummy_state)

# current_weights = network.get_weights()
current_weights = ray.put(network.get_weights())

optimizer = tf.keras.optimizers.Adam(learning_rate=network_settings["learning_rate"])

replay = ReplayBuffer(buffer_size=buffer_size)

# # Start self-play workers
# work_in_progresses = [selfplay(current_weights,qubits, True)]
work_in_progresses = [
    selfplay.remote(current_weights, qubits, True) for _ in range(num_cpus - 1)
]
n_updates = 0

n = 0
while n < n_episodes:
    for _ in tqdm(range(update_period)):
        # finished = selfplay(current_weights,qubits, n, test=False)
        # replay.add_record(finished)
        finished, work_in_progresses = ray.wait(work_in_progresses, num_returns=1)
        replay.add_record(ray.get(finished[0]))
        work_in_progresses.extend([selfplay.remote(current_weights, qubits, True)])
        n += 1

    # Update network
    if len(replay) >= batch_size:
        num_iters = epochs_per_update * (len(replay) // batch_size)
        value_loss_weight = 1
        policy_loss_weight = 1
        for i in range(num_iters):
            states, mcts_policy, rewards = replay.get_minibatch(batch_size=batch_size)
            with tf.GradientTape() as tape:
                p_pred, v_pred = network(states, training=True)
                value_loss = tf.square(rewards - v_pred)
                policy_loss = -tf.reduce_sum(
                    mcts_policy * tf.math.log(p_pred + 1e-5), axis=1, keepdims=True
                )
                loss = tf.reduce_mean(
                    value_loss_weight * value_loss + policy_loss_weight * policy_loss
                )
            grads = tape.gradient(loss, network.trainable_variables)
            grads, _ = tf.clip_by_global_norm(grads, 1.0)
            optimizer.apply_gradients(zip(grads, network.trainable_variables))
            n_updates += 1

            if i % 10 == 0:
                with summary_writer.as_default():
                    tf.summary.scalar(
                        "value_loss", tf.reduce_mean(value_loss), step=n_updates
                    )
                    tf.summary.scalar(
                        "policy_loss", tf.reduce_mean(policy_loss), step=n_updates
                    )

        # current_weights = network.get_weights()k？
        current_weights = ray.put(network.get_weights())

    if n % save_period == 0:
        network.save(f"checkpoints/network{qubits}_{index}_{n}.keras")
        network.save_weights(f"checkpoints/network{qubits}_{index}_{n}.weights.h5")
        depth, count = evaluate_self_play(qubits,network)
        print(f"SWAP depth is {depth}, SWAP count is {count}")
        print("-" * 50)

2024-11-30 12:28:40,875	INFO worker.py:1807 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


{'CPU': 8.0, 'memory': 40702418944.0, 'node:127.0.0.1': 1.0, 'object_store_memory': 2147483648.0, 'node:__internal_head__': 1.0}


100%|██████████| 50/50 [03:32<00:00,  4.25s/it]


SWAP depth is 7.533333333333333, SWAP count is 11.133333333333333
--------------------------------------------------


100%|██████████| 50/50 [02:51<00:00,  3.44s/it]


SWAP depth is 6.8, SWAP count is 10.533333333333333
--------------------------------------------------


100%|██████████| 50/50 [02:06<00:00,  2.54s/it]


SWAP depth is 7.033333333333333, SWAP count is 10.8
--------------------------------------------------


100%|██████████| 50/50 [01:59<00:00,  2.39s/it]


SWAP depth is 7.033333333333333, SWAP count is 11.033333333333333
--------------------------------------------------


100%|██████████| 50/50 [01:55<00:00,  2.32s/it]


SWAP depth is 6.9, SWAP count is 10.6
--------------------------------------------------


100%|██████████| 50/50 [01:45<00:00,  2.11s/it]


SWAP depth is 6.733333333333333, SWAP count is 9.933333333333334
--------------------------------------------------


100%|██████████| 50/50 [01:40<00:00,  2.00s/it]


SWAP depth is 6.233333333333333, SWAP count is 9.9
--------------------------------------------------


100%|██████████| 50/50 [01:36<00:00,  1.93s/it]


SWAP depth is 6.533333333333333, SWAP count is 10.3
--------------------------------------------------


100%|██████████| 50/50 [01:36<00:00,  1.92s/it]


SWAP depth is 6.466666666666667, SWAP count is 10.133333333333333
--------------------------------------------------


100%|██████████| 50/50 [01:30<00:00,  1.81s/it]


SWAP depth is 6.866666666666666, SWAP count is 10.7
--------------------------------------------------


100%|██████████| 50/50 [01:31<00:00,  1.82s/it]


SWAP depth is 7.1, SWAP count is 11.233333333333333
--------------------------------------------------


100%|██████████| 50/50 [01:31<00:00,  1.82s/it]


SWAP depth is 6.266666666666667, SWAP count is 9.766666666666667
--------------------------------------------------


 88%|████████▊ | 44/50 [01:22<00:13,  2.18s/it]

In [3]:
import numpy as np
import tensorflow as tf
from rl.network import ResNet
import rl.game as game

game = Game(qubits, config)

network = ResNet(action_space=len(game.coupling_map),config=config)
# network = tf.keras.models.load_model(f"checkpoints/network{qubits}_{index}_700.keras")
# 初期状態の生成
network.load_weights(f"checkpoints/network{qubits}_{index}_700.weights.h5")

In [14]:
for _ in range(12):
    state = game.state
    ans = []
    done = False
    total_score = 0
    step_count = 0
    prev_action = None
    print(state)
    while not done and step_count < game.MAX_STEPS:
        encoded_state = encode_state(state, qubits)
        input_state = np.expand_dims(encoded_state, axis=0)

        policy_output, value_output = network.predict(input_state)
        policy = policy_output[0]
        if prev_action is not None:
            indices = [i for i in range(len(game.coupling_map)) if i != prev_action]

            prob = policy[indices]
            if prob.sum() < 1e-6:
                action = np.random.choice(indices)
            else:
                action = np.random.choice(indices, p=prob / prob.sum())
        else:
            indices = list(range(len(game.coupling_map)))
            action = np.random.choice(indices, p=policy)
        selected_action = game.coupling_map[action]
        ans.append(selected_action)
        state, done, _ = game.step(state, action, prev_action)
        prev_action = action
        step_count += 1

    if done:
        print(f"Game finished successfully in {step_count} steps with {ans}")
    else:
        print(f"Game terminated after reaching the maximum steps ({game.MAX_STEPS}).")
        print(f"Total score: {total_score}")

[[0. 0. 1. 1. 1. 1. 0.]
 [0. 0. 0. 1. 0. 1. 1.]
 [1. 0. 0. 0. 1. 1. 0.]
 [1. 1. 0. 0. 0. 1. 1.]
 [1. 0. 1. 0. 0. 0. 1.]
 [1. 1. 1. 1. 0. 0. 0.]
 [0. 1. 0. 1. 1. 0. 0.]]
Game terminated after reaching the maximum steps (25).
Total score: 0
[[0. 0. 1. 1. 1. 1. 0.]
 [0. 0. 0. 1. 0. 1. 1.]
 [1. 0. 0. 0. 1. 1. 0.]
 [1. 1. 0. 0. 0. 1. 1.]
 [1. 0. 1. 0. 0. 0. 1.]
 [1. 1. 1. 1. 0. 0. 0.]
 [0. 1. 0. 1. 1. 0. 0.]]
Game terminated after reaching the maximum steps (25).
Total score: 0
[[0. 0. 1. 1. 1. 1. 0.]
 [0. 0. 0. 1. 0. 1. 1.]
 [1. 0. 0. 0. 1. 1. 0.]
 [1. 1. 0. 0. 0. 1. 1.]
 [1. 0. 1. 0. 0. 0. 1.]
 [1. 1. 1. 1. 0. 0. 0.]
 [0. 1. 0. 1. 1. 0. 0.]]
Game terminated after reaching the maximum steps (25).
Total score: 0
[[0. 0. 1. 1. 1. 1. 0.]
 [0. 0. 0. 1. 0. 1. 1.]
 [1. 0. 0. 0. 1. 1. 0.]
 [1. 1. 0. 0. 0. 1. 1.]
 [1. 0. 1. 0. 0. 0. 1.]
 [1. 1. 1. 1. 0. 0. 0.]
 [0. 1. 0. 1. 1. 0. 0.]]
Game terminated after reaching the maximum steps (25).
Total score: 0
[[0. 0. 1. 1. 1. 1. 0.]
 [0. 0. 0. 1. 0. 1. 