In [1]:
import sys, os, time

sys.path.append("../..")
import pyzx as zx
from benchmarking import Benchmark
import numpy as np
import argparse
import json
import random
import time
from distutils.util import strtobool
from typing import Optional

import gym
import gym_zx
import numpy as np
import pyzx as zx
import torch
import torch.nn as nn
import torch_geometric.nn as geom_nn
from torch import Tensor
from torch.distributions.categorical import Categorical
from torch.nn.functional import softmax
from torch_geometric.data import Batch, Data
from torch_geometric.nn import Sequential as geo_Sequential
from torch_geometric.nn import aggr
from torch_geometric.typing import Adj, OptTensor, PairTensor, SparseTensor, torch_sparse
from torch_geometric.utils import softmax as geom_softmax

global device

# device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
device = torch.device("cuda")

In [2]:
class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=None, device=None):
        if masks is None:
            masks = []
        self.masks = masks
        if len(self.masks) != 0:
            self.masks = masks.type(torch.BoolTensor).to(device)
            logits = torch.where(self.masks, logits, torch.tensor(-1e8).to(device))
        super(CategoricalMasked, self).__init__(probs, logits, validate_args)

    def entropy(self):
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.0).to(device))
        return -p_log_p.sum(-1)


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class ModAttentionalAggregation(geom_nn.Aggregation):
    r"""The soft attention aggregation layer from the `"Graph Matching Networks
    for Learning the Similarity of Graph Structured Objects"
    <https://arxiv.org/abs/1904.12787>`_ paper

    .. math::
        \mathbf{r}_i = \sum_{n=1}^{N_i} \mathrm{softmax} \left(
        h_{\mathrm{gate}} ( \mathbf{x}_n ) \right) \cdot
        h_{\mathbf{\Theta}} ( \mathbf{x}_n ),

    where :math:`h_{\mathrm{gate}} \colon \mathbb{R}^F \to
    \mathbb{R}` and :math:`h_{\mathbf{\Theta}}` denote neural networks, *i.e.*
    MLPs.

    Args:
        gate_nn (torch.nn.Module): A neural network :math:`h_{\mathrm{gate}}`
            that computes attention scores by mapping node features :obj:`x` of
            shape :obj:`[-1, in_channels]` to shape :obj:`[-1, 1]` (for
            node-level gating) or :obj:`[1, out_channels]` (for feature-level
            gating), *e.g.*, defined by :class:`torch.nn.Sequential`.
        nn (torch.nn.Module, optional): A neural network
            :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x` of
            shape :obj:`[-1, in_channels]` to shape :obj:`[-1, out_channels]`
            before combining them with the attention scores, *e.g.*, defined by
            :class:`torch.nn.Sequential`. (default: :obj:`None`)
    """

    def __init__(self, gate_nn: torch.nn.Module, nn: Optional[torch.nn.Module] = None):
        super().__init__()
        self.gate_nn = gate_nn
        self.nn = nn
        self.reset_parameters()

    def reset_parameters(self):
        geom_nn.inits.reset(self.gate_nn)
        geom_nn.inits.reset(self.nn)

    def forward(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
    ) -> Tensor:
        self.assert_two_dimensional_input(x, dim)
        gate = self.gate_nn(x)
        x = self.nn(x) if self.nn is not None else x
        gate = geom_softmax(gate, index, ptr, dim_size, dim)

        return self.reduce(gate * x, index, ptr, dim_size, dim), gate

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(gate_nn={self.gate_nn}, " f"nn={self.nn})"


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class AgentGNN(nn.Module):
    def __init__(
        self,
        envs,
        device,
        c_hidden=32,
        c_hidden_v=32,
        c_out=64,
        c_at=4,
        dp_rate=0.0,
        num_layers=1,
        layer_type="GraphConv",
        **kwargs,
    ):
        super().__init__()

        self.device = device
        self.obs_shape = envs.envs[0].shape
        self.bin_required = int(np.ceil(np.log2(self.obs_shape)))
        self.qubits = envs.envs[0].qubits

        c_in_p = 16
        c_in_v = 11
        edge_dim = 6
        edge_dim_v = 3
        # self.critic_gnn = ModAttentionalAggregation(gate_nn=nn.Sequential(nn.Linear(c_out, 1)))
        self.global_attention_critic = geom_nn.GlobalAttention(
            gate_nn=nn.Sequential(
                nn.Linear(c_hidden, c_hidden),
                nn.ReLU(),
                nn.Linear(c_hidden, c_hidden),
                nn.ReLU(),
                nn.Linear(c_hidden, 1),
            ),
            nn=nn.Sequential(nn.Linear(c_hidden, c_hidden_v), nn.ReLU(), nn.Linear(c_hidden_v, c_hidden_v), nn.ReLU()),
        )

        self.critic_gnn = geo_Sequential(
            "x, edge_index, edge_attr",
            [
                (
                    geom_nn.GATv2Conv(c_in_v, c_hidden, edge_dim=edge_dim_v, add_self_loops=True),
                    "x, edge_index, edge_attr -> x",
                ),
                nn.ReLU(),
                (
                    geom_nn.GATv2Conv(c_hidden, c_hidden, edge_dim=edge_dim_v, add_self_loops=True),
                    "x, edge_index, edge_attr -> x",
                ),
                nn.ReLU(),
                (
                    geom_nn.GATv2Conv(c_hidden, c_hidden, edge_dim=edge_dim_v, add_self_loops=True),
                    "x, edge_index, edge_attr -> x",
                ),
                nn.ReLU(),
                (
                    geom_nn.GATv2Conv(c_hidden, c_hidden, edge_dim=edge_dim_v, add_self_loops=True),
                    "x, edge_index, edge_attr -> x",
                ),
                nn.ReLU(),
                (
                    geom_nn.GATv2Conv(c_hidden, c_hidden, edge_dim=edge_dim_v, add_self_loops=True),
                    "x, edge_index, edge_attr -> x",
                ),
                nn.ReLU(),
            ],
        )

        self.actor_gnn = geom_nn.Sequential(
            "x, edge_index, edge_attr",
            [
                (
                    geom_nn.GATv2Conv(c_in_p, c_hidden, edge_dim=edge_dim, add_self_loops=True),
                    "x, edge_index, edge_attr -> x",
                ),
                nn.ReLU(),
                (
                    geom_nn.GATv2Conv(c_hidden, c_hidden, edge_dim=edge_dim, add_self_loops=True),
                    "x, edge_index, edge_attr -> x",
                ),
                nn.ReLU(),
                (
                    geom_nn.GATv2Conv(c_hidden, c_hidden, edge_dim=edge_dim, add_self_loops=True),
                    "x, edge_index, edge_attr -> x",
                ),
                nn.ReLU(),
                (
                    geom_nn.GATv2Conv(c_hidden, c_hidden, edge_dim=edge_dim, add_self_loops=True),
                    "x, edge_index, edge_attr -> x",
                ),
                nn.ReLU(),
                (
                    geom_nn.GATv2Conv(c_hidden, c_hidden, edge_dim=edge_dim, add_self_loops=True),
                    "x, edge_index, edge_attr -> x",
                ),
                nn.ReLU(),
                (nn.Linear(c_hidden, c_hidden),),
                nn.ReLU(),
                (nn.Linear(c_hidden, 1),),
            ],
        )

        self.critic_ff = nn.Sequential(
            nn.Linear(c_hidden_v, c_hidden_v),
            nn.ReLU(),
            nn.Linear(c_hidden_v, c_hidden_v),
            nn.ReLU(),
            nn.Linear(c_hidden_v, out_features=1),
        )

    def actor(self, x):
        logits = self.actor_gnn(x.x, x.edge_index, x.edge_attr)
        # logits = self.actor_gnn(x.x, x.edge_index)
        return logits

    """
    def critic(self, x, y, batch):
        non_action_nodes = torch.where(y != -1)[0].to(device)
        aggregated = self.global_attention_critic(x[non_action_nodes], batch[non_action_nodes])
        return self.critic_ff(aggregated)
    """

    def critic(self, x):
        features = self.critic_gnn(x.x, x.edge_index, x.edge_attr)
        aggregated = self.global_attention_critic(features, x.batch)
        return self.critic_ff(aggregated)

    def get_action_and_value(self, x, action_mask=None, action=None):
        """
        features = self.gnn(x.x, x.edge_index, x.edge_attr)
        logits = self.actor_gnn(features, x.edge_index, x.edge_attr)
        values = self.critic(features, x.y, x.batch)
        """
        policy_obs, value_obs = x
        logits = self.actor(policy_obs)
        # values = self.critic(value_obs)

        batch_logits = torch.zeros([x[0].num_graphs, self.obs_shape]).to(device)
        act_mask = torch.zeros([x[0].num_graphs, self.obs_shape]).to(device)
        act_ids = torch.zeros([x[0].num_graphs, self.obs_shape]).to(device)
        action_logits = torch.tensor([]).to(device)
        # indices = torch.triu_indices(max_node, max_node).to(device)
        for b in range(x[0].num_graphs):

            ids = x[0].y[x[0].batch == b].to(device)
            action_nodes = torch.where(ids != -1)[0].to(device)
            probs = logits[x[0].batch == b][action_nodes]
            batch_logits[b, : probs.shape[0]] = probs.flatten()
            act_mask[b, : probs.shape[0]] = torch.tensor([True] * probs.shape[0])
            act_ids[b, : action_nodes.shape[0]] = ids[action_nodes]
            action_logits = torch.cat((action_logits, probs.flatten()), 0).reshape(-1)

        categoricals = CategoricalMasked(logits=batch_logits, masks=act_mask, device=device)

        # Convert the list of samples back to a tensor
        # values = values.squeeze(-1)
        if action is None:
            action = categoricals.sample()
            stop_indx = self.obs_shape - self.qubits
            batch_id = torch.arange(x[0].num_graphs)
            action_id = act_ids[batch_id, action]

        return action.T, action_id.T

    def get_value(self, x):
        values = self.critic(x)
        return values


def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    # Algorithm specific arguments
    parser.add_argument("--num-envs", type=int, default=1,
        help="the number of parallel game environments") #default 8
    parser.add_argument("--num-episodes", type=int, default=100,
        help="the number of episodes to run")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--seed", type=int, default=10000,
        help="seed of the experiment")
    parser.add_argument("--num-steps", type=int, default=128,
        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--gym-id", type=str, default="zx-v0",
        help="the id of the gym environment")

    return parser.parse_known_args()[0]

In [3]:
b = Benchmark()

In [4]:
b.load_circuits(
    dirname=os.path.join("..", "..", "pyzx", "circuits", "benchmarking_circuits", "Fast", "before", "before_noQFT"),
    group_name="fast",
)
b.load_circuits(
    dirname=os.path.join("..", "..", "pyzx", "circuits", "benchmarking_circuits", "Fast", "nrscm", "nrscm_noQFT"),
    group_name="fast",
    simp_strategy="NRSCM",
)
b.load_circuits(
    dirname=os.path.join("..", "..", "pyzx", "circuits", "benchmarking_circuits", "Fast", "tpar", "tpar_noQFT"),
    group_name="fast",
    simp_strategy="TPar",
)
b.show_attributes()

Circuit attributes:  ['Qubits', 'Gates', '2Q Count', 'T Count', 't_opt']
No loaded functions
Loaded routines:  ['TPar', 'NRSCM']
Loaded circuit groups:  ['fast']


Unnamed: 0,Original,NRSCM,TPar
fast,Y,Y,-


In [5]:
def basic_optimise(c):
    c1 = zx.basic_optimization(c.copy(), do_swaps=False).to_basic_gates()
    c2 = zx.basic_optimization(c.copy(), do_swaps=True).to_basic_gates()
    if c2.twoqubitcount() < c1.twoqubitcount():
        return c2  # As this optimisation algorithm is targetted at reducting H-gates, we use the circuit with the smaller 2-qubit gate count here, either using SWAP rules or not.
    return c1


def flow_opt(c):
    g = c.to_graph()
    zx.teleport_reduce(g)
    zx.to_graph_like(g)
    zx.flow_2Q_simp(g)
    c2 = zx.extract_simple(g).to_basic_gates()
    return basic_optimise(c2)

In [6]:
def make_env(gym_id, seed, idx, capture_video, run_name, qubits, gates, circ):
    def thunk():
        env = gym.make(gym_id, qubits=qubits, depth=gates, circuit=circ)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


def rl_zx(c):
    best_result = np.inf
    stats = c.stats_dict()
    qubits = stats["qubits"]
    depth = stats["depth"]
    args = parse_args()
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.gym_id, args.seed + i, i, True, "Benchmarking", qubits, depth, c) for i in range(args.num_envs)]
    )

    agent = AgentGNN(envs, device).to(device)  # Pass the envs argument here

    agent.load_state_dict(
        torch.load("state_dict_model5x70_twoqubits_new.pt", map_location=torch.device("cpu"))
    )  # Add the map_location argument here
    agent.eval()

    done = False

    for episode in range(10):
        done = False
        obs0, reset_info = envs.reset()
        new_value_data = []
        new_policy_data = []
        for item in reset_info["graph_obs"]:
            policy_items, value_items = item[0], item[1]
            value_graph = Data(x=value_items[0], edge_index=value_items[1])
            policy_graph = Data(
                x=policy_items[0], edge_index=policy_items[1], edge_attr=policy_items[2], y=policy_items[3]
            )
            new_value_data.append(value_graph)
            new_policy_data.append(policy_graph)

        # next_obs = torch.Tensor(obs0).to(device)  # Torch.size([8,1,40,40])
        next_obs_graph = (
            Batch.from_data_list(new_policy_data).to(device),
            Batch.from_data_list(new_value_data).to(device),
        )
        state = next_obs_graph
        start = time.time()
        while not done:

            # action_masks = torch.Tensor(np.array(envs.envs[0].action_mask()).reshape(1, -1))
            # Get the action from the model
            action, action_id = agent.get_action_and_value(state, None)
            action = action.flatten()
            # Take the action in the environment
            next_obs, reward, done, deprecated, info = envs.step(action_id.cpu().numpy())
            new_value_data = []
            new_policy_data = []

            for item in info["graph_obs"]:
                policy_items, value_items = item[0], item[1]
                value_graph = Data(x=value_items[0], edge_index=value_items[1])
                policy_graph = Data(
                    x=policy_items[0], edge_index=policy_items[1], edge_attr=policy_items[2], y=policy_items[3]
                )
                new_value_data.append(value_graph)
                new_policy_data.append(policy_graph)

            # next_obs = torch.Tensor(next_obs).to(device)  # Torch.size([8,1,40,40])
            next_obs_graph = (
                Batch.from_data_list(new_policy_data).to(device),
                Batch.from_data_list(new_value_data).to(device),
            )
            next_done = torch.zeros(args.num_envs).to(device)
            state = next_obs_graph

        info = info["final_info"][0]
        twoq_gates = info["rl_stats"]["twoqubits"]
        if twoq_gates < best_result:
            circuit = info["final_circuit"]
            action_seq = info["action_sequence"]
            graph = info["initial_graph"]
            best_result = twoq_gates
    return circuit, action_seq, graph

In [7]:
graph = zx.generate.cliffordT(qubits=10, depth=100)
c =zx.Circuit.from_graph(graph)

final_circuit,action = rl_zx(c)



  logger.warn(
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")
  return action.T, action_id.T
  if not isinstance(terminated, (bool, np.bool8)):
  logger.warn(f"{pre} is not within the observation space.")


In [8]:
graph_like = zx.to_gh(c)

S L65 L37 I65 L29 I27 L31 L29 L53 L36 I13 L39 L33 L36 L29 L41 P23,33 P18,25 P29,34 P29,45 P25,26 P35,27 N


In [14]:
zx.draw()

In [7]:
b.add_simplification_func(func=rl_zx, name="RL-ZX", groups_to_run=None)
b.add_simplification_func(func=flow_opt, name="flow-opt", groups_to_run=["fast"], verify=True, rerun=False)



b.run(funcs_to_run=["RL-ZX"], groups_to_run=["fast"], verify=True, rerun=False)

b.show_attributes()

Processing flow-opt on barenco_tof_3                                  :   0%|          | 0/28 [00:00<?, ?it/s]

Processing flow-opt on gf2^7_mult                                     : 100%|██████████| 28/28 [02:29<00:00,  5.32s/it]
  logger.warn(
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")
  return action.T, action_id.T
  if not isinstance(terminated, (bool, np.bool8)):
  logger.warn(f"{pre} is not within the observation space.")
Processing RL-ZX on gf2^7_mult                                        : 100%|██████████| 28/28 [1:51:17<00:00, 238.49s/it]

Circuit attributes:  ['Qubits', 'Gates', '2Q Count', 'T Count', 't_opt']
Loaded functions:  ['RL-ZX', 'flow-opt']
Loaded routines:  ['NRSCM', 'TPar']
Loaded circuit groups:  ['fast']





Unnamed: 0,Original,NRSCM,RL-ZX,TPar,flow-opt
fast,Y,Y,Y,-,Y


In [8]:
df = b.df(groups=["fast"], routines=["all"], funcs=["all"], atts=["Qubits","Gates", "2Q Count"])

Unnamed: 0_level_0,Original,Original,Original,NRSCM,NRSCM,RL-ZX,RL-ZX,flow-opt,flow-opt
Unnamed: 0_level_1,Qubits,Gates,2Q Count,Gates,2Q Count,Gates,2Q Count,Gates,2Q Count
Circuits,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Adder8,23,637,243,190,94,334,183,237,112
adder_8,24,900,409,606,291,621,320,589,277
barenco_tof_10,19,450,192,264,130,365,176,320,146
barenco_tof_3,5,58,24,40,18,50,22,45,20
barenco_tof_4,7,114,48,72,34,95,44,89,37
barenco_tof_5,9,170,72,104,50,140,66,121,55
csla_mux_3_original,15,170,80,155,70,153,72,154,73
csum_mux_9_corrected,30,448,168,266,140,308,168,288,140
gf2^4_mult,12,243,99,187,99,178,99,174,94
gf2^5_mult,15,379,154,296,154,281,154,274,146


In [9]:
b.show_attributes()

Circuit attributes:  ['Qubits', 'Gates', '2Q Count', 'T Count', 't_opt']
Loaded functions:  ['full-reduce', 'RL-ZX']
Loaded routines:  ['NRSCM', 'TPar']
Loaded circuit groups:  ['fast']


Unnamed: 0,Original,NRSCM,RL-ZX,TPar,full-reduce
fast,Y,Y,Y,Y,Y


In [14]:
df = b.df(groups=["fast"], routines=["all"], funcs=["all"], atts=["Qubits", "Gates", "2Q Count"])

Unnamed: 0_level_0,Original,Original,Original,NRSCM,NRSCM,TPar,TPar,RL-ZX,RL-ZX,full-reduce,full-reduce
Unnamed: 0_level_1,Qubits,Gates,2Q Count,Gates,2Q Count,Gates,2Q Count,Gates,2Q Count,Gates,2Q Count
Circuits,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2
Adder8,23,637,243,190,94,-,-,334,183,271,163
adder_8,24,900,409,606,291,1280,885,621,320,707,442
barenco_tof_10,19,450,192,264,130,517,328,365,176,374,216
barenco_tof_3,5,58,24,40,18,82,54,50,22,66,34
barenco_tof_4,7,114,48,72,34,141,90,95,44,102,56
barenco_tof_5,9,170,72,104,50,206,132,140,66,130,68
csla_mux_3_original,15,170,80,155,70,-,-,153,72,255,172
csum_mux_9_corrected,30,448,168,266,140,-,-,308,168,484,336
gf2^4_mult,12,243,99,187,99,419,324,178,99,368,290
gf2^5_mult,15,379,154,296,154,682,535,281,154,704,578
