In [1]:
from torch.utils.tensorboard import SummaryWriter

2024-12-16 00:08:45.171921: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
import os
import glob

from pathlib import Path
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
import yaml


from rl.network import ResNet
from rl.mcts import MCTS
from rl.buffer import ReplayBuffer, Sample
from rl.game import Game

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

base_path = "graphs"
index = "20241212"
qubits = config["game_settings"]["N"]
training_settings = config["training_settings"]
network_settings = config["network_settings"]
mcts_settings = config["mcts_settings"]
num_cpus = training_settings["num_cpus"]
num_gpus = training_settings["num_gpus"]
n_episodes = training_settings["n_episodes"]
buffer_size = training_settings["buffer_size"]
batch_size = training_settings["batch_size"]
epochs_per_update = training_settings["epochs_per_update"]
update_period = training_settings["update_period"]
save_period = training_settings["save_period"]
eval_period = training_settings["eval_period"]


def selfplay(qubits, network, config, device="cpu"):
    record = []
    game = Game(qubits, config)
    state = game.get_initial_state()
    game.reset_used_columns()

    mcts = MCTS(qubits=qubits, network=network, config=config)
    done = False
    total_score = 0
    step_count = 0
    prev_action = None

    while not done and step_count < game.MAX_STEPS:
        mcts_policy = mcts.search(
            root_state=state,
            prev_action=prev_action,
            num_simulations=mcts_settings["num_mcts_simulations"],
        )

        if prev_action is not None:
            indices = [i for i in range(game.action_space) if i != prev_action]
            valid_actions = game.get_valid_actions(state, prev_action)
            prob = mcts_policy[valid_actions]
            prob = prob / prob.sum()
            action = np.random.choice(valid_actions, p=prob)
        else:
            indices = list(range(game.action_space))
            prob = mcts_policy
            action = np.random.choice(indices, p=prob)
        record.append(Sample(state.copy(), mcts_policy, reward=None))
        state, done, action_score = game.step(state, action, prev_action)
        prev_action = action
        total_score += action_score
        step_count += 1

    reward = game.get_reward(state, total_score)
    for sample in record:
        sample.reward = reward
    return record


def evaluate_self_play(qubits, network, config, device="cpu"):
    pattern = os.path.join(base_path, f"adj_matrix_{qubits}_*.npy")
    file_paths = glob.glob(pattern)
    avg_depth = []
    avg_counts = []
    for file_path in file_paths:
        state = np.load(file_path)
        game = Game(qubits, config)
        swap_pairs = []
        done = False
        step_count = 0
        prev_action = None
        while not done and step_count < game.MAX_STEPS:
            network.eval()
            with torch.no_grad():
                policy_output, value_output = network(
                    torch.tensor(state, dtype=torch.float32)
                    .unsqueeze(0)
                    .unsqueeze(0)
                    .to(device)
                )
                policy = policy_output.cpu().numpy()[0]
            if prev_action is not None:
                indices = [i for i in range(game.action_space) if i != prev_action]
                try:
                    valid_actions = game.get_valid_actions(state, prev_action)
                    prob = policy[valid_actions]
                except:
                    prob = policy[indices]
                try:
                    action = np.random.choice(valid_actions, p=prob / prob.sum())
                except:
                    action = np.random.choice(valid_actions)
            else:
                indices = list(range(game.action_space))
                prob = policy
                action = np.random.choice(indices, p=prob / prob.sum())
            if action < len(game.coupling_map):
                selected_action = game.coupling_map[action]
                swap_pairs.append(selected_action)
            else:
                for pair in game.coupling_map[action % 2 :: 2]:
                    swap_pairs.append(pair)
            state, done, _ = game.step(state, action, prev_action)
            prev_action = action
            step_count += 1
        if not done:
            depth = game.MAX_STEPS
            swap_count = game.MAX_STEPS
        else:
            game.current_layer += 1
            depth = game.current_layer
            swap_count = len(swap_pairs)
        print(f"depth: {depth}, count: {swap_count}")
        avg_counts.append(swap_count)
        avg_depth.append(depth)
    return avg_depth, avg_counts

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
logdir = Path("log")
if logdir.exists():
    shutil.rmtree(logdir)
summary_writer = SummaryWriter(log_dir=logdir)

game = Game(qubits, config)
network = ResNet(action_space=game.action_space, config=config).to("cpu")

dummy_input = (
    torch.tensor(game.state, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to("cpu")
)
network(dummy_input)

optimizer = optim.Adam(network.parameters(), lr=network_settings["learning_rate"])


replay = ReplayBuffer(buffer_size=buffer_size)

n_updates = 0

n = 0
while n < n_episodes:
    for _ in tqdm(range(update_period)):
        network.eval()
        finished = selfplay(qubits, network, config)
        replay.add_record(finished)
        n += 1

    print("-" * 50)
    network.to(device)
    if len(replay) >= batch_size:
        num_iters = epochs_per_update * (len(replay) // batch_size)
        value_loss_weight = 0.5
        policy_loss_weight = 1.5

        for i in tqdm(range(num_iters)):
            states, mcts_policy, rewards = replay.get_minibatch(batch_size=batch_size)
            states = torch.tensor(states, dtype=torch.float32).to(device)
            mcts_policy = torch.tensor(mcts_policy, dtype=torch.float32).to(device)
            rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
            network.train()

            policy_pred, value_pred = network(states)
            value_loss = torch.mean((rewards - value_pred.squeeze()) ** 2)
            policy_loss = -torch.sum(
                mcts_policy * torch.log(policy_pred + 1e-5), dim=1
            ).mean()
            loss = value_loss_weight * value_loss + policy_loss_weight * policy_loss

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.5)
            optimizer.step()

            n_updates += 1

            if i % 5 == 0:
                summary_writer.add_scalar("value_loss", value_loss.item(), n_updates)
                summary_writer.add_scalar("policy_loss", policy_loss.item(), n_updates)

    if n % save_period == 0:
        torch.save(network.state_dict(), f"checkpoints/network{qubits}_{index}_{n}.pth")
        print(f"Model saved: checkpoints/network{qubits}_{index}_{n}.pth")
        print("-" * 50)
    if n % eval_period == 0:
        network.eval()
        with torch.no_grad():
            depth, count = evaluate_self_play(qubits, network, config, device=device)
        print(
            f"Episode {n}: SWAP depth is {np.mean(depth)}, SWAP count is {np.mean(count)}"
        )
        print("-" * 50)

  0%|          | 0/50 [00:00<?, ?it/s]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper_CUDA___slow_conv2d_forward)

In [None]:
game = Game(qubits, config)

checkpoint_path = f"checkpoints/network{qubits}_{index}_5.pth"
network = ResNet(action_space=game.action_space, config=config)
network.load_state_dict(torch.load(checkpoint_path))
network.eval()

print("Model loaded successfully.")

Model loaded successfully.


  network.load_state_dict(torch.load(checkpoint_path))


In [None]:
depths = []
for _ in range(4):
    depth, count = evaluate_self_play(qubits, network, config)
    depths.append(depth)
min_depth = np.min(np.vstack(depths), axis=0)

depth: 13, count: 13
depth: 13, count: 13
depth: 13, count: 13
depth: 13, count: 13
depth: 13, count: 13
depth: 8, count: 12
depth: 13, count: 13
depth: 13, count: 13
depth: 13, count: 13
depth: 9, count: 11
depth: 7, count: 11
depth: 13, count: 13
depth: 8, count: 13
depth: 13, count: 13
depth: 7, count: 9
depth: 13, count: 13
depth: 13, count: 13
depth: 8, count: 9
depth: 7, count: 12
depth: 13, count: 13
depth: 9, count: 11
depth: 13, count: 13
depth: 13, count: 13
depth: 13, count: 13
depth: 9, count: 12
depth: 13, count: 13
depth: 10, count: 11
depth: 13, count: 13
depth: 6, count: 9
depth: 13, count: 13
depth: 13, count: 13
depth: 6, count: 7
depth: 13, count: 13
depth: 13, count: 13
depth: 7, count: 10
depth: 13, count: 13
depth: 11, count: 13
depth: 7, count: 12
depth: 9, count: 12
depth: 13, count: 13
depth: 13, count: 13
depth: 8, count: 11
depth: 13, count: 13
depth: 13, count: 13
depth: 13, count: 13
depth: 13, count: 13
depth: 13, count: 13
depth: 13, count: 13
depth: 8, c

In [None]:
min_depth

array([13,  6,  8, 13,  7,  8, 10,  7,  9,  7,  7,  8,  8,  7,  7, 13,  7,
        8,  7,  9,  9,  7, 13, 13,  9, 13,  9,  9,  6, 13])

In [None]:
dummy_input = torch.randn(1, 1, qubits, qubits)  # 例: 入力が8x8の行列の場合
onnx_path = f"checkpoints/network{qubits}_{index}_5.onnx"

# モデルをONNX形式でエクスポート
torch.onnx.export(
    network,
    dummy_input,
    onnx_path,
    input_names=["input"],
    output_names=["policy", "value"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "policy": {0: "batch_size"},
        "value": {0: "batch_size"},
    },
    opset_version=15,
)

print(f"ONNX model saved to {onnx_path}")

ONNX model saved to checkpoints/network6_20241212_5.onnx


array([1, 5, 6, 4, 6, 7, 7, 2, 6, 3, 2, 5, 3, 7, 5, 7, 6, 3, 5, 2, 4, 9,
       2, 6, 5, 5, 3, 2, 2, 1]) -> 4.366666666666666