In [None]:
import os
import glob

from pathlib import Path
import shutil

import tensorflow as tf
import tf_keras as keras
import numpy as np
from tqdm import tqdm
import yaml

from rl.network import ResNet
from rl.mcts import MCTS
from rl.buffer import ReplayBuffer, Sample
from rl.game import Game, encode_state

# Load the run configuration; everything below is derived from this mapping.
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# NOTE(review): tensorflow is imported before this assignment, so this may be
# too late to silence TensorFlow's import-time native logging — confirm, and
# consider setting the variable ahead of the tensorflow import.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Directory searched for adj_matrix_*.npy files, and the run tag used in
# checkpoint names.
base_path, index = "graphs", "20241212"

qubits = config["game_settings"]["N"]
training_settings, network_settings, mcts_settings = (
    config[section]
    for section in ("training_settings", "network_settings", "mcts_settings")
)

# Required training settings: a missing key raises KeyError.
(
    num_cpus,
    num_gpus,
    n_episodes,
    buffer_size,
    batch_size,
    epochs_per_update,
    update_period,
    save_period,
) = (
    training_settings[key]
    for key in (
        "num_cpus",
        "num_gpus",
        "n_episodes",
        "buffer_size",
        "batch_size",
        "epochs_per_update",
        "update_period",
        "save_period",
    )
)
# Optional setting: falls back to 100 when absent from the config file.
eval_period = training_settings.get("eval_period", 100)


def selfplay(weights, qubits, current_episode, config):
    """Play one MCTS-guided self-play episode and return its training samples.

    A fresh ``Game``, ``ResNet`` and ``MCTS`` are built from ``config``. The
    network is called once on the encoded initial state before
    ``set_weights`` (presumably so its variables exist before the weights are
    assigned — TODO confirm).

    Args:
        weights: Weights handed to ``network.set_weights``.
        qubits: Problem size forwarded to ``Game``, ``MCTS`` and
            ``encode_state``.
        current_episode: Not used by this function; kept so existing callers
            keep working.
        config: Configuration mapping. The number of simulations per search is
            read from ``config["mcts_settings"]["num_mcts_simulations"]``
            rather than from a notebook-level global, so the function honours
            the configuration it is actually given.

    Returns:
        list: One ``Sample`` per step taken, each holding a copy of the state
        before the move and the MCTS policy for it; every sample's ``reward``
        is set to the final value of ``game.get_reward``.
    """
    record = []
    game = Game(qubits, config)
    state = game.get_initial_state()
    game.reset_used_columns()
    network = ResNet(action_space=game.action_space, config=config)
    network.predict(encode_state(state, qubits))
    network.set_weights(weights)

    mcts = MCTS(qubits=qubits, network=network, config=config)
    num_simulations = config["mcts_settings"]["num_mcts_simulations"]
    done = False
    total_score = 0
    step_count = 0
    prev_action = None

    while not done and step_count < game.MAX_STEPS:
        mcts_policy = mcts.search(
            root_state=state,
            prev_action=prev_action,
            num_simulations=num_simulations,
        )
        if prev_action is not None:
            # After the first move only the actions the game reports as valid
            # are sampled, with the policy renormalised over them.
            valid_actions = game.get_valid_actions(state, prev_action)
            prob = mcts_policy[valid_actions]
            prob = prob / prob.sum()
            action = np.random.choice(valid_actions, p=prob)
        else:
            # First move: sample over the whole action space.
            action = np.random.choice(game.action_space, p=mcts_policy)
        # The reward is unknown until the episode ends; it is filled in below.
        record.append(Sample(state.copy(), mcts_policy, reward=None))
        state, done, action_score = game.step(state, action, prev_action)
        prev_action = action
        total_score += action_score
        step_count += 1

    reward = game.get_reward(state, total_score)
    for sample in record:
        sample.reward = reward
    return record


def evaluate_self_play(qubits, network, config):
    """Roll out the policy network once on every saved graph.

    Every file matching ``adj_matrix_{qubits}_*.npy`` under ``base_path`` is
    loaded as a start state. Actions are sampled from the network's policy
    output until the game reports ``done`` or ``game.MAX_STEPS`` is reached.

    Args:
        qubits: Problem size; selects the files and is forwarded to ``Game``
            and ``encode_state``.
        network: Model whose ``predict(batch, verbose=0)`` returns a
            ``(policy, value)`` pair.
        config: Configuration mapping forwarded to ``Game``.

    Returns:
        tuple: ``(avg_depth, avg_counts)``, two lists with one entry per file
        in sorted file-path order. Despite the names they hold per-graph
        values, not averages. A rollout that does not finish records
        ``game.MAX_STEPS`` for both values.
    """
    pattern = os.path.join(base_path, f"adj_matrix_{qubits}_*.npy")
    # glob order is unspecified; sorting keeps the per-graph results aligned
    # when callers stack the output of repeated calls.
    file_paths = sorted(glob.glob(pattern))
    avg_depth = []
    avg_counts = []
    for file_path in file_paths:
        state = np.load(file_path)
        game = Game(qubits, config)
        swap_pairs = []
        done = False
        step_count = 0
        prev_action = None
        while not done and step_count < game.MAX_STEPS:
            encoded_state = encode_state(state, qubits)
            input_state = np.expand_dims(encoded_state, axis=0)
            policy_output, _ = network.predict(input_state, verbose=0)
            policy = np.array(policy_output)[0]
            if prev_action is not None:
                try:
                    candidates = game.get_valid_actions(state, prev_action)
                    prob = policy[candidates]
                except Exception:
                    # Best-effort fallback: any action except the previous
                    # one. Candidates and probabilities are taken from the
                    # same index list so they always line up.
                    candidates = [
                        i for i in range(game.action_space) if i != prev_action
                    ]
                    prob = policy[candidates]
            else:
                candidates = list(range(game.action_space))
                prob = policy
            action = np.random.choice(candidates, p=prob / prob.sum())
            if action < len(game.coupling_map):
                swap_pairs.append(game.coupling_map[action])
            else:
                # Actions past the coupling map record every second pair,
                # starting at index action % 2.
                swap_pairs.extend(game.coupling_map[action % 2::2])
            state, done, _ = game.step(state, action, prev_action)
            prev_action = action
            step_count += 1
        if not done:
            depth = game.MAX_STEPS
            swap_count = game.MAX_STEPS
        else:
            # Depth is reported as current_layer + 1 (presumably the layer
            # index is zero-based — TODO confirm).
            game.current_layer += 1
            depth = game.current_layer
            swap_count = len(swap_pairs)
        print(f"depth: {depth}, count: {swap_count},swap:{swap_pairs}")
        avg_counts.append(swap_count)
        avg_depth.append(depth)
    return avg_depth, avg_counts

In [None]:
# A Game instance is built here to obtain the action space for the network.
game = Game(qubits, config)
# NOTE(review): this freshly built ResNet is replaced by the loaded model on
# the next line; it looks redundant unless constructing it has a needed side
# effect — confirm before removing.
network = ResNet(action_space=game.action_space, config=config)
# Load the saved model for this qubit count and run tag (the trailing 50 is
# presumably a save/episode index — TODO confirm).
network = keras.models.load_model(f"checkpoints/network{qubits}_{index}_50")

In [None]:
def evaluate_state_depth_like_sabre(qubits, network, config, state, reps=1):
    """Repeat the policy rollout ``reps`` times and keep the shallowest result.

    After each rollout the swaps it produced are applied to the working state
    by exchanging the corresponding columns and rows, so the next repetition
    starts from the permuted matrix.

    Args:
        qubits: Problem size forwarded to ``evaluate_self_play_like_sabre``.
        network: Policy/value model forwarded to the rollout.
        config: Configuration mapping forwarded to the rollout.
        state: Start matrix. It is copied first, so the caller's array is not
            modified.
        reps: Number of rollouts to run.

    Returns:
        tuple: ``(min_depth, res_count, swap_pairs)`` taken from the same,
        shallowest rollout. With ``reps=0`` the result is
        ``(float("inf"), 0, [])``.
    """
    # Work on a private copy: the permutation below writes in place.
    state = state.copy()
    min_depth = float("inf")
    res_count = 0
    best_swap_pairs = []
    for _ in range(reps):
        depth, count, swap_pairs = evaluate_self_play_like_sabre(
            qubits, network, config, state
        )
        # Relabel the matrix according to the swaps found, column then row.
        for col1, col2 in swap_pairs:
            state[:, [col1, col2]] = state[:, [col2, col1]]
            state[[col1, col2], :] = state[[col2, col1], :]
        if depth < min_depth:
            min_depth = depth
            res_count = count
            # Keep the swaps that belong to the best depth/count so the
            # returned triple is consistent.
            best_swap_pairs = swap_pairs
    return min_depth, res_count, best_swap_pairs

def evaluate_self_play_like_sabre(qubits, network, config, state):
    """Roll out the policy network once from ``state``.

    Actions are sampled from the network's policy output until the game
    reports ``done`` or ``game.MAX_STEPS`` is reached.

    Args:
        qubits: Problem size forwarded to ``Game`` and ``encode_state``.
        network: Model whose ``predict(batch)`` returns a ``(policy, value)``
            pair.
        config: Configuration mapping forwarded to ``Game``.
        state: Start state handed to ``encode_state`` and ``game.step``.

    Returns:
        tuple: ``(depth, swap_count, swap_pairs)``. A rollout that does not
        finish reports ``game.MAX_STEPS`` for both depth and swap count.
    """
    game = Game(qubits, config)
    swap_pairs = []
    done = False
    step_count = 0
    prev_action = None
    while not done and step_count < game.MAX_STEPS:
        encoded_state = encode_state(state, qubits)
        input_state = np.expand_dims(encoded_state, axis=0)
        policy_output, _ = network.predict(input_state)
        policy = np.array(policy_output)[0]
        if prev_action is not None:
            try:
                candidates = game.get_valid_actions(state, prev_action)
                prob = policy[candidates]
            except Exception:
                # Best-effort fallback: any action except the previous one.
                # Candidates and probabilities are taken from the same index
                # list so they always line up.
                candidates = [
                    i for i in range(game.action_space) if i != prev_action
                ]
                prob = policy[candidates]
        else:
            candidates = list(range(game.action_space))
            prob = policy
        action = np.random.choice(candidates, p=prob / prob.sum())
        if action < len(game.coupling_map):
            swap_pairs.append(game.coupling_map[action])
        else:
            # Actions past the coupling map record every second pair,
            # starting at index action % 2.
            swap_pairs.extend(game.coupling_map[action % 2::2])
        state, done, _ = game.step(state, action, prev_action)
        prev_action = action
        step_count += 1
    if not done:
        depth = game.MAX_STEPS
        swap_count = game.MAX_STEPS
    else:
        # Depth is reported as current_layer + 1 (presumably the layer index
        # is zero-based — TODO confirm).
        game.current_layer += 1
        depth = game.current_layer
        swap_count = len(swap_pairs)
    print(f"depth: {depth}, count: {swap_count}")
    return depth, swap_count, swap_pairs

# Evaluate every saved adjacency matrix for this qubit count, keeping the
# shallowest of 30 rollouts per graph.
pattern = os.path.join(base_path, f"adj_matrix_{qubits}_*.npy")
file_paths = glob.glob(pattern)
# One entry per graph: best depth and the swap count of that same rollout.
depths = []
counts = []
for file_path in tqdm(file_paths):
    state = np.load(file_path)
    # swap_pairs is returned but not used further in this cell.
    min_depth, count,swap_pairs = evaluate_state_depth_like_sabre(qubits,network,config,state,reps=30)
    depths.append(min_depth)
    counts.append(count)

In [None]:
# Mean over graphs of the per-graph depths collected in `depths`.
np.mean(depths)

In [None]:
# Mean over graphs of the per-graph swap counts collected in `counts`.
np.mean(counts)

In [None]:
# Run the single-rollout evaluation 40 times over all graphs.
# NOTE: this rebinds `depths`; any earlier list of the same name is replaced
# by a list of per-run depth lists.
depths = []
for _ in range(40):
    print("==========================")
    # `count` is overwritten each iteration, so only the last run's counts survive.
    depth, count = evaluate_self_play(qubits, network, config)
    depths.append(depth)
# Element-wise minimum across runs; this assumes every run returns its
# per-graph depths in the same order.
min_depth = np.min(np.vstack(depths), axis=0)

In [None]:
# Per-graph minimum depth across the repeated runs.
min_depth

In [None]:
# Mean over graphs of the per-graph minimum depth.
np.mean(min_depth)

In [None]:
# NOTE(review): `count` holds the swap counts of the most recent
# evaluate_self_play call only, not those of the minimum-depth runs.
np.mean(count)

In [None]:
# Qubit count (config["game_settings"]["N"]) the results above refer to.
qubits

array([1, 5, 6, 4, 6, 7, 7, 2, 6, 3, 2, 5, 3, 7, 5, 7, 6, 3, 5, 2, 4, 9,
       2, 6, 5, 5, 3, 2, 2, 1]) -> 4.366666666666666

#7 100 3.1666  24.366666