In [1]:
# The agent model for the environment input, candidate: list[tuple[int, int]], sabre_dag: graph, current_layout: list[int], distance_matrix: list[list[int]]
# action -> tuple[int, int]
import numpy as np
import tqdm
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

In [2]:
class RolloutBuffer:
    def __init__(self) -> None:
        self.buffer = list()

    def store(self, transition):
        self.buffer.append(transition)

    def sample(self):
        s, a, r, s_prime, done = zip(*self.buffer)
        s = list(map(lambda x: map(lambda y: torch.tensor(y, dtype=torch.float32), x), s))
        a = list(map(lambda x: torch.tensor(x, dtype=torch.int64), a))
        s_prime = list(map(lambda x: map(lambda y: torch.tensor(y, dtype=torch.float32), x), s_prime))
        self.buffer.clear()
        return (
            s,
            a,
            torch.tensor(r).unsqueeze(1),
            s_prime,
            torch.tensor(done, dtype=torch.int).unsqueeze(1)
        )
    
    @property
    def size(self):
        return len(self.buffer)
    

In [None]:
# Suppose the candidate is fixed, S_a, and the next swap gate is H
# the input of the model is (batch_size, N, H)

class PolicyModel(nn.Module):
    def __init__(self, state_dim: tuple[int, int] = (20, 20), action_dim: int = 20, hidden_dims: tuple = (512, )):
        super(PolicyModel, self).__init__()
        self.flatten_layer = nn.Flatten()
        self.input_layer = nn.Linear(state_dim[0] * state_dim[1], hidden_dims[0])
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims) - 1):
            self.hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
        self.mu_layer = nn.Linear(hidden_dims[-1], action_dim)
        self.log_std_layer = nn.Linear(hidden_dims[-1], action_dim)
        self.activation_fn = F.tanh

    def forward(self, x):
        x = self.activation_fn(self.input_layer(self.flatten_layer(x)))
        for layer in self.hidden_layers:
            x = self.activation_fn(layer(x))

        mu = self.mu_layer(x)
        log_std = torch.tanh(self.log_std_layer(x))
        
        return mu, log_std

class ValueModel(nn.Module):
    def __init__(self, state_dim: tuple[int, int] = (20, 20), hidden_dims: tuple = (512, )):
        super(PolicyModel, self).__init__()
        self.flatten_layer = nn.Flatten()
        self.input_layer = nn.Linear(state_dim[0] * state_dim[1], hidden_dims[0])
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims) - 1):
            self.hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
        self.output_layer = nn.Linear(hidden_dims[-1], 1)
        self.activation_fn = F.tanh

    def forward(self, x): 
        x = self.activation_fn(self.input_layer(self.flatten_layer(x)))
        for layer in self.hidden_layers:
            x = self.activation_fn(layer(x))
        return self.output_layer(x)

In [4]:
class RLAgent:
    def __init__(self, state_dim, action_dim, hidden_dims=(512,), n_steps=2048, n_epochs=10, batch_size=64, policy_lr=3e-4, value_lr=1e-3, gamma=0.99, lmda=0.95, clip_ratio=0.2, vf_coef = 1.0, ent_coef = 0.01):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy = PolicyModel(state_dim, action_dim, hidden_dims).to(self.device)
        self.value = ValueModel(state_dim, hidden_dims).to(self.device)
        self.n_steps = n_steps
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.batch_size = batch_size
        self.lmda = lmda
        self.gamma = gamma
        self.clip_ratio = clip_ratio
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef

        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=policy_lr)
        self.value_optimizer = torch.optim.Adam(self.value.parameters(), lr=value_lr)

        self.buffer = RolloutBuffer()

    @torch.no_grad()
    def act(self, s, training=True):
        self.policy.train(training)
        
        s = tuple(map(lambda x: torch.tensor(x, dtype=torch.float32).unsqueeze(0).to(self.device), s))
        mu, std = map(lambda x: x.squeeze(0) ,self.policy(s[0], s[1], s[2]))
        z = torch.normal(mu, std) if training else mu
        action = torch.tanh(z)

        return action.cpu().numpy()
    
    def learn(self):
        # set train
        self.policy.train()
        self.value.train()

        # sample from buffer
        s, a, r, s_prime, done = self.buffer.sample()
        candidate, adjacency, current_layout = map(lambda x: list(map(lambda y: torch.tensor(y, dtype=torch.float32), x)), map(list, zip(*s)))
        candidate = pad_sequence(candidate, batch_first=True, padding_value=-1).to(self.device)
        max_N = max(map(lambda x: x.shape[0], adjacency))
        adjacency = [torch.tensor(np.pad(x, ((0, max_N - x.shape[0]), (0, max_N - x.shape[0])), mode='constant'), dtype=torch.float32) for x in adjacency]
        adjacency = torch.stack(adjacency).to(self.device)
        current_layout = torch.stack(current_layout).to(self.device)
        
        candidate_prime, adjacency_prime, current_layout_prime = map(lambda x: list(map(lambda y: torch.tensor(y, dtype=torch.float32), x)), map(list, zip(*s_prime)))
        candidate_prime = pad_sequence(candidate_prime, batch_first=True, padding_value=-1).to(self.device)
        max_N_prime = max(map(lambda x: x.shape[0], adjacency_prime))
        adjacency_prime = [torch.tensor(np.pad(x, ((0, max_N_prime - x.shape[0]), (0, max_N_prime - x.shape[0])), mode='constant'), dtype=torch.float32) for x in adjacency_prime]
        adjacency_prime = torch.stack(adjacency_prime).to(self.device)
        current_layout_prime = torch.stack(current_layout_prime).to(self.device)
        
        a = pad_sequence(a, batch_first=True, padding_value=0).to(self.device)
        r = r.to(self.device)
        done = done.to(self.device)

        # calculate advantages and returns
        with torch.no_grad():
            delta = r + (1 - done) * self.gamma * self.value(candidate, adjacency, current_layout) - self.value(candidate_prime, adjacency_prime, current_layout_prime) 
            adv = torch.clone(delta) 
            ret = torch.clone(r) 
            for t in reversed(range(len(r) - 1)):
                adv[t] += (1 - done[t]) * self.gamma * self.lmda * adv[t + 1]
                ret[t] += (1 - done[t]) * self.gamma * ret[t + 1]

            mu, std = self.policy(candidate, adjacency, current_layout)
            m = torch.distributions.Normal(mu, std)
            z = torch.atanh(torch.clamp(a, -1.0 + 1e-7, 1.0 - 1e-7))
            log_prob_old = m.log_prob(z).sum(dim=-1, keepdims=True)

        dts = TensorDataset(candidate, adjacency, current_layout, a, ret, adv, log_prob_old)
        loader = DataLoader(dts, batch_size=self.batch_size, shuffle=True)
        policy_losses, value_losses, entropy_bonuses = [], [], []
        for e in range(self.n_epochs):
            value_losses, policy_losses, entropy_bonuses = [], [], []
            for batch in loader:
                candidate_, adjacency_, current_layout_, a_, ret_, adv_, log_prob_old_ = batch
                # 가치 네트워크의 손실함수 계산
                value = self.value(candidate_, adjacency_, current_layout_)
                value_loss = F.mse_loss(value, ret_)

                # 정책 네트워크의 손실함수 계산
                mu, std = self.policy(candidate_, adjacency_, current_layout_)
                m = torch.distributions.Normal(mu, std)
                z = torch.atanh(torch.clamp(a_, -1.0 + 1e-7, 1.0 - 1e-7))
                log_prob = m.log_prob(z).sum(dim=-1, keepdims=True)
                
                ratio = (log_prob - log_prob_old_).exp()
                surr1 = adv_ * ratio
                surr2 = adv_ * torch.clamp(ratio, 1.0 - self.clip_ratio, 1.0 + self.clip_ratio)

                policy_loss = -torch.min(surr1, surr2).mean()
                entropy_bonus = -m.entropy().mean()

                loss = policy_loss + self.vf_coef * value_loss + self.ent_coef * entropy_bonus
                self.value_optimizer.zero_grad()
                self.policy_optimizer.zero_grad()
                loss.backward()
                self.value_optimizer.step()
                self.policy_optimizer.step()

                value_losses.append(value_loss.item())
                policy_losses.append(policy_loss.item())
                entropy_bonuses.append(-entropy_bonus.item())

        result = {'policy_loss': np.mean(policy_losses),
                  'value_loss': np.mean(value_losses),
                  'entropy_bonus': np.mean(entropy_bonuses)}

        return result

    def step(self, transition):
        result = None
        self.buffer.store(transition)
        if self.buffer.size >= self.n_steps:
            result = self.learn()

        return result


In [5]:
import os
import gymnasium as gym

import qiskit.qasm2
from qiskit_ibm_runtime.fake_provider import FakeAlmadenV2
from qiskit.transpiler import CouplingMap

In [6]:
# Set up the environment
provider = FakeAlmadenV2()
coupling_map = CouplingMap(provider.configuration().coupling_map)
data_path = '../data'
file_list = os.listdir(data_path)
file_list = [f for f in file_list if f.endswith('.qasm')]
paper_file_list = [
    '4mod5-v1_22.qasm',
    'mod5mils_65.qasm',
    'alu-v0_27.qasm',
    'decod24-v2_43.qasm',
    '4gt13_92.qasm',
    'ising_model_10.qasm',
    'ising_model_13.qasm',
    'ising_model_16.qasm',
    'qft_10.qasm',
    'qft_13.qasm',
    'qft_16.qasm',
    'qft_20.qasm',
    'rd84_142.qasm',
    'adr4_197.qasm',
    'radd_250.qasm',
    'z4_268.qasm',
    'sym6_145.qasm',
    'misex1_241.qasm',
    'rd73_252.qasm',
    'cycle10_2_110.qasm',
    'square_root_7.qasm',
    'sqn_258.qasm',
    'rd84_253.qasm',
    'co14_215.qasm',
    'sym9_193.qasm',
    '9symml_195.qasm',
]
paper_file_list = [file for file in file_list if file in paper_file_list]
file_list = [f for i, f in enumerate(file_list) if i % 5 == 0]
circuits = [qiskit.qasm2.load(os.path.join(data_path, f)) for f in file_list]
gym.register(
    id="SaberSwap-v0", entry_point="algorithm.sabre_env:SabreSwapEnv")
env = gym.make(
    id="SaberSwap-v0", circuits = circuits, coupling_map=coupling_map)

In [None]:
# In the IBM's paper, the input is fixed length S_a, number of Swap candidates, and H, the next swap candidate
def unpack_state(state):
    pass

In [7]:
def evaluate(agent, eval_iterations):
    env = gym.make("SaberSwap-v0", circuits=circuits, coupling_map=coupling_map)
    scores = []
    for i in range(eval_iterations):
        s, _ = env.reset()
        s = tuple(s.values()) 
        s = s[0], nx.to_numpy_array(s[1]), list(map(lambda x: x._index, s[2])) # Convert the graph to a numpy array of adjacency matrix
        done = False
        score = 0.0
        while not done:
            action = agent.act(s, training=False)
            s_prime, r, terminated, truncated, info = env.step(s[0][action.argmax()])
            s_prime = tuple(s_prime.values())
            s_prime = s_prime[0], nx.to_numpy_array(s_prime[1]), list(map(lambda x: x._index, s_prime[2]))
            score += float(r)
            done = terminated or truncated
            s = s_prime
        scores.append(score)
    env.close()
    return round(np.mean(scores), 4)

In [None]:
# train the agent
max_iterations = 1000000
eval_intervals = 100000
eval_iterations = 10

env = gym.make("SaberSwap-v0", circuits=circuits, coupling_map=coupling_map)
agent = RLAgent(
    state_dim=20,
    action_dim=10,
    hidden_dims=(512, 512),
    n_steps=1024,
    n_epochs=10,
    batch_size=64,
    policy_lr=3e-4,
    value_lr=1e-3,
    gamma=0.99,
    lmda=0.95,
    clip_ratio=0.2,
    vf_coef=1.0,
    ent_coef=0.01
)

logger = []
s, _ = env.reset()
s = tuple(s.values()) 
s = s[0], nx.to_numpy_array(s[1]), list(map(lambda x: x._index, s[2])) # Convert the graph to a numpy array of adjacency matrix
for i in tqdm.tqdm(range(1, 1 + max_iterations)):
    a = agent.act(s, training=True)
    s_prime, r, terminated, truncated, info = env.step(s[0][a.argmax()])
    s_prime = tuple(s_prime.values())
    s_prime = s_prime[0], nx.to_numpy_array(s_prime[1]), list(map(lambda x: x._index, s_prime[2]))
    done = terminated or truncated
    transition = (s, a, r, s_prime, done)
    result = agent.step(transition)
    s = s_prime

    if result is not None:
        logger.append(result)

    if done:
        s, _ = env.reset()
        s = tuple(s.values()) 
        s = s[0], nx.to_numpy_array(s[1]), list(map(lambda x: x._index, s[2]))

    if (i + 1) % eval_intervals == 0:
        score = evaluate(agent, eval_iterations)
        print(f"Iteration {i + 1}, Score: {score}")

  candidate, adjacency, current_layout = map(lambda x: list(map(lambda y: torch.tensor(y, dtype=torch.float32), x)), map(list, zip(*s)))
  candidate_prime, adjacency_prime, current_layout_prime = map(lambda x: list(map(lambda y: torch.tensor(y, dtype=torch.float32), x)), map(list, zip(*s_prime)))
  value_loss = F.mse_loss(value, ret_)
  7%|▋         | 66141/1000000 [13:39<7:04:33, 36.66it/s] 