In [1]:
import quartz
# Quartz context over the gate set below; the JSON file is loaded into the
# context (presumably the transformation/xfer set used later via num_xfers).
quartz_context = quartz.QuartzContext(
    gate_set=['h', 'cx', 'x', 't', 'tdg'],
    filename='../bfs_verified_simplified.json',
)
parser = quartz.PyQASMParser(context=quartz_context)

# Circuit that every RL rollout below starts from.
my_dag = parser.load_qasm(
    filename="../circuit/nam-circuits/qasm_files/tof_4_before.qasm"
)
init_graph = quartz.PyGraph(context=quartz_context, dag=my_dag)

Using backend: pytorch[00:47:54] /opt/dgl/src/runtime/tensordispatch.cc:43: TensorDispatcher: dlopen failed: /home/zikunli/anaconda3/envs/quantum/lib/python3.9/site-packages/dgl/tensoradapter/pytorch/libtensoradapter_pytorch_1.10.2.so: cannot open shared object file: No such file or directory



In [3]:
import numpy as np
gate_type_num = 26  # size of the gate-type vocabulary fed to QGNN's embedding


class QAgent:
    """Deep-Q agent that picks a (node, transformation) pair on a circuit graph.

    ``q_net`` maps a graph to a ``(num_nodes, a_size)`` matrix with one Q-value
    per (node, transformation) pair. ``target_net`` is a periodically synced copy
    used for bootstrapped targets in ``train``.
    """

    def __init__(self, lr, a_size, gamma=0.99):
        """
        Args:
            lr: Adam learning rate for the online Q-network.
            a_size: number of candidate transformations (columns of the Q matrix).
            gamma: discount factor for bootstrapped targets in ``train``.
        """
        torch.manual_seed(42)
        self.q_net = QGNN(gate_type_num, 16, a_size, 16)
        self.target_net = copy.deepcopy(self.q_net)
        self.loss_fn = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        self.a_size = a_size
        # Was never set before, so ``train`` raised AttributeError on the first
        # non-terminal transition.
        self.gamma = gamma

    def select_a(self, g, e):
        """Epsilon-greedy choice of a (node, action) pair.

        Args:
            g: graph of the current circuit (must provide ``num_nodes()`` and be
               accepted by ``q_net``).
            e: exploration probability; with probability ``e`` a uniformly random
               node and action are returned.

        Returns:
            ``(node, A)`` as Python ints: node index into ``g`` and action id.
        """
        if random.random() < e:
            node = int(np.random.randint(0, g.num_nodes()))
            A = int(np.random.randint(0, self.a_size))
        else:
            with torch.no_grad():
                pred = self.q_net(g)  # (num_nodes, a_size)
            # Best action for every node, then the node whose best action scores
            # highest. ``torch.max(pred)`` without ``dim`` returns a single scalar
            # and cannot be unpacked into (values, indices).
            best_q_per_node, best_a_per_node = torch.max(pred, dim=1)
            node = int(torch.argmax(best_q_per_node))
            A = int(best_a_per_node[node])
        return node, A

    def train(self, data, batch_size):
        """Run one optimizer step on ``batch_size`` transitions sampled from ``data``.

        Each transition is ``(s, node, a, r, s_next)``. ``s_next is None`` marks a
        terminal step whose target is the reward ``r`` itself; otherwise the target
        is ``r + gamma * max Q_target(s_next)`` over all (node, action) pairs.

        Returns:
            The batch MSE loss as a Python float.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        pred_rs = []
        target_rs = []
        for _ in range(batch_size):
            s, node, a, r, s_next = data.get_data()

            pred_r = self.q_net(s)[node][a]
            reward = torch.as_tensor(r, dtype=pred_r.dtype)

            if s_next is None:
                target_r = reward
            else:
                with torch.no_grad():
                    # Bootstrap from the best next value, not the whole Q matrix
                    # (a full matrix cannot be stacked with scalar predictions).
                    best_next = self.target_net(s_next).max()
                target_r = reward + self.gamma * best_next

            pred_rs.append(pred_r)
            target_rs.append(target_r)

        loss = self.loss_fn(torch.stack(pred_rs), torch.stack(target_rs))
        self.optimizer.zero_grad()
        # Every sample builds a fresh autograd graph, so retain_graph is not needed.
        loss.backward()
        # Same as clamping each gradient to [-1, 1]; parameters without a grad are skipped.
        torch.nn.utils.clip_grad_value_(self.q_net.parameters(), 1)
        self.optimizer.step()

        return loss.item()

    
class QData:
    """Fixed-capacity replay buffer; transitions are sampled uniformly with replacement."""

    def __init__(self, maxlen=100000):
        """
        Args:
            maxlen: capacity; once reached, the oldest transitions are evicted.
        """
        self.data = deque(maxlen=maxlen)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.data)

    def add_data(self, d):
        """Append one transition (indexable as ``s, node, a, r, s_next``)."""
        self.data.append(d)

    def get_data(self):
        """Return the first five fields of one uniformly random stored transition.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.data:
            raise ValueError("cannot sample from an empty replay buffer")
        item = random.choice(self.data)
        return item[0], item[1], item[2], item[3], item[4]
    

In [7]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

class QConv(nn.Module):
    """One message-passing layer over the circuit graph.

    Each edge sends ``[h_src, w_edge]``. Every message is projected by ``linear1``
    followed by leaky ReLU, and the projections are averaged per destination node
    into ``h_N``. The output is ``linear2([h, h_N])``.

    Args:
        in_feat: width of the incoming node features.
        inter_dim: width of the per-message projection.
        out_feat: width of the output node features.
        edge_feat: width of the edge features ``g.edata['w']`` (QGNN builds three
            columns). Previously hard-coded as 3.
    """

    def __init__(self, in_feat, inter_dim, out_feat, edge_feat=3):
        super(QConv, self).__init__()
        # Construction order kept so seeded initialisation is unchanged.
        self.linear2 = nn.Linear(in_feat + inter_dim, out_feat)
        self.linear1 = nn.Linear(in_feat + edge_feat, inter_dim, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.linear1.weight, gain=gain)
        nn.init.xavier_normal_(self.linear2.weight, gain=gain)

    def message_func(self, edges):
        """Message = source node features concatenated with the edge features."""
        return {'m': torch.cat([edges.src['h'], edges.data['w']], dim=1)}

    def reduce_func(self, nodes):
        """Project each incoming message, apply leaky ReLU, and average over messages."""
        # NOTE(review): assumes DGL's mailbox layout (nodes, in_degree, feat);
        # dim=1 is the message axis.
        tmp = F.leaky_relu(self.linear1(nodes.mailbox['m']))
        return {'h_N': torch.mean(tmp, dim=1)}

    def forward(self, g, h):
        """Apply the layer to node features ``h``; returns a (num_nodes, out_feat) tensor.

        Reads ``g.edata['w']``. The temporary node fields ``h`` and ``h_N`` are
        written inside ``g.local_scope()``, so they do not leak into the caller's graph.
        """
        with g.local_scope():
            g.ndata['h'] = h
            g.update_all(self.message_func, self.reduce_func)
            h_N = g.ndata['h_N']
        return self.linear2(torch.cat([h, h_N], dim=1))



class QGNN(nn.Module):
    """Gate-type embedding followed by five stacked QConv layers.

    Produces one score per (node, class), e.g. per (node, transformation).

    Args:
        in_feats: number of gate types; also the embedding width.
        h_feats: hidden node-feature width between conv layers.
        num_classes: output width per node.
        inter_dim: per-message projection width inside each QConv.
    """

    def __init__(self, in_feats, h_feats, num_classes, inter_dim):
        super(QGNN, self).__init__()
        # Attribute names (conv1..conv5, embedding) and creation order are kept
        # so state_dict keys and seeded initialisation stay unchanged.
        self.conv1 = QConv(in_feats, inter_dim, h_feats)
        self.conv2 = QConv(h_feats, inter_dim, h_feats)
        self.conv3 = QConv(h_feats, inter_dim, h_feats)
        self.conv4 = QConv(h_feats, inter_dim, h_feats)
        self.conv5 = QConv(h_feats, inter_dim, num_classes)
        self.embedding = nn.Embedding(in_feats, in_feats)

    def forward(self, g):
        """Return a (num_nodes, num_classes) tensor of per-node scores.

        Reads ``g.ndata['gate_type']`` (ids for the embedding) and
        ``g.edata['src_idx']``, ``g.edata['dst_idx']``, ``g.edata['reversed']``.
        Intermediate features are written inside ``g.local_scope()``, which leaves
        the caller's graph untouched. Graphs kept in a replay buffer therefore no
        longer hold autograd-tracked tensors from their last forward pass.
        """
        with g.local_scope():
            g.ndata['h'] = self.embedding(g.ndata['gate_type'])
            # Three edge-feature columns: src_idx, dst_idx, reversed.
            g.edata['w'] = torch.cat([torch.unsqueeze(g.edata['src_idx'], 1),
                                      torch.unsqueeze(g.edata['dst_idx'], 1),
                                      torch.unsqueeze(g.edata['reversed'], 1)], dim=1)
            h = g.ndata['h']
            for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
                h = F.relu(conv(g, h))
            return self.conv5(g, h)



In [5]:
# Pretraining using the tof_3 circuit
# Preparing data: load every tof_3_opt_path/subst_history_<i>.qasm snapshot as a labelled DGL graph
from concurrent.futures import ProcessPoolExecutor

def get_dataset(i):
    """Load snapshot ``i`` of the tof_3 substitution history as a labelled DGL graph.

    ``ndata['label']`` holds Quartz's applicable-xfer matrix as float32.
    """
    qasm_path = f"tof_3_opt_path/subst_history_{i}.qasm"
    snapshot = quartz.PyGraph(context=quartz_context,
                              dag=parser.load_qasm(filename=qasm_path))
    labelled = snapshot.to_dgl_graph()
    xfer_matrix = snapshot.get_available_xfers_matrix(context=quartz_context)
    labelled.ndata['label'] = torch.tensor(xfer_matrix, dtype=torch.float)
    return labelled

idx_list = list(range(40))
# Parse snapshots in parallel; map() preserves the index order.
with ProcessPoolExecutor(max_workers=32) as pool:
    opt_path_dgls = list(pool.map(get_dataset, idx_list))

In [32]:
def train_supervised(g, model, lr=0.01, epochs=20):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['gate_type']

    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    for e in range(epochs):
        # Forward
        logits = model(g)

        # Compute loss
        # Note that we should only compute the losses of the nodes in the training set,
        # i.e. with train_mask 1.
        #print(logits)
        
        loss = torch.nn.MSELoss()(logits[train_mask], labels[train_mask])
        pred = logits > 0.5

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        train_recall = torch.sum((torch.logical_and((pred[train_mask] == 1), (labels[train_mask] == 1))).float()) / torch.sum((labels[train_mask] == 1).float())
        val_recall = torch.sum((torch.logical_and((pred[val_mask] == 1), (labels[val_mask] == 1))).float()) / torch.sum((labels[val_mask] == 1).float())
        test_recall = torch.sum((torch.logical_and((pred[test_mask] == 1), (labels[test_mask] == 1))).float()) / torch.sum((labels[test_mask] == 1).float())

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        all_logits.append(logits.detach())

        # TODO: print out false negative
        if e % 1 == 0:
            print('In epoch {}, loss: {:.5f}, train acc: {:.5f}, train recall: {:.5f}, val acc: {:.5f} (best {:.5f}), val recall: {:.5f}, test acc: {:.5f} (best {:.5f}), test recall: {:.5f}'.format(
                e, loss, train_acc, train_recall, val_acc, best_val_acc,val_recall, test_acc, best_test_acc, test_recall))

def test(*, filename, net=None):
    """Print and return the element-wise accuracy of a model on one QASM circuit.

    The circuit is parsed with the module-level ``parser``/``quartz_context``.
    Labels are Quartz's applicable-xfer matrix, and predictions are logits
    thresholded at 0.5.

    Args:
        filename: path of the QASM file to evaluate on.
        net: model to evaluate. Defaults to the notebook-global ``model``, which
            was previously a hidden, non-overridable dependency.

    Returns:
        The accuracy as a Python float.
    """
    if net is None:
        net = model
    test_dag = parser.load_qasm(filename=filename)
    test_graph = quartz.PyGraph(context=quartz_context, dag=test_dag)
    test_graph_dgl = test_graph.to_dgl_graph()
    appliable_xfer_matrix = test_graph.get_available_xfers_matrix(context=quartz_context)
    test_graph_dgl.ndata['label'] = torch.tensor(appliable_xfer_matrix, dtype=torch.float)
    labels = test_graph_dgl.ndata['label']

    with torch.no_grad():
        logits = net(test_graph_dgl)
        pred = logits > 0.5
        test_acc = (pred == labels).float().mean()
        print(f"test_acc: {test_acc:.6f}")
    return test_acc.item()

In [33]:
from random import sample
import random

bg = dgl.batch(opt_path_dgls)
node_cnt = bg.num_nodes()
train_rate = 0.7
val_rate = 0.15
split_seed = 0  # fixed so the train/val/test split is reproducible across runs

train_num = int(node_cnt * train_rate)
val_num = int(node_cnt * val_rate)
test_num = node_cnt - train_num - val_num

# Shuffle all node ids once and slice the permutation into three disjoint splits.
# The previous list-membership scans (`n not in train_sample`, `i in train_sample`)
# were O(node_cnt^2).
perm = random.Random(split_seed).sample(range(node_cnt), node_cnt)
train_sample = perm[:train_num]
val_sample = perm[train_num:train_num + val_num]
test_sample = perm[train_num + val_num:]


def _split_mask(indices):
    """Boolean mask over all nodes that is True exactly at ``indices``."""
    mask = torch.zeros(node_cnt, dtype=torch.bool)
    mask[torch.tensor(indices, dtype=torch.long)] = True
    return mask


train_mask = _split_mask(train_sample)
val_mask = _split_mask(val_sample)
test_mask = _split_mask(test_sample)
# Every node must land in exactly one split.
assert torch.all(train_mask.int() + val_mask.int() + test_mask.int() == 1)

bg.ndata['train_mask'] = train_mask
bg.ndata['val_mask'] = val_mask
bg.ndata['test_mask'] = test_mask

# gate_type_num (26) is the embedding vocabulary shared with the RL agent.
model = QGNN(gate_type_num, 16, quartz_context.num_xfers, 16)
train_supervised(bg, model, lr=0.05, epochs=20)

In epoch 0, loss: 0.08711, train acc: 0.95660, train recall: 0.02133, val acc: 0.95812 (best 0.95812), val recall: 0.01309, test acc: 0.95799 (best 0.95799), test recall: 0.01314
In epoch 1, loss: 0.04278, train acc: 0.99810, train recall: 0.00690, val acc: 0.99806 (best 0.99806), val recall: 0.00773, test acc: 0.99824 (best 0.99824), test recall: 0.00751
In epoch 2, loss: 0.03640, train acc: 0.99845, train recall: 0.08594, val acc: 0.99848 (best 0.99848), val recall: 0.09518, test acc: 0.99855 (best 0.99855), test recall: 0.10701
In epoch 3, loss: 0.00817, train acc: 0.99848, train recall: 0.00000, val acc: 0.99850 (best 0.99850), val recall: 0.00000, test acc: 0.99858 (best 0.99858), test recall: 0.00000
In epoch 4, loss: 0.00578, train acc: 0.99850, train recall: 0.01117, val acc: 0.99852 (best 0.99852), val recall: 0.01428, test acc: 0.99861 (best 0.99861), test recall: 0.01564
In epoch 5, loss: 0.00406, train acc: 0.99860, train recall: 0.07615, val acc: 0.99863 (best 0.99863), va

In [75]:
import random
from tqdm import tqdm
import copy
from collections import deque
import torch


# --- RL hyper-parameters ---
replay_times = 10      # rollouts collected per episode
episodes = 10
max_steps = 10         # cap on xfers applied per rollout
train_epoch = 5        # gradient steps per episode
batch_size = 3         # transitions per gradient step
epsilon = 1            # initial exploration probability
epsilon_min = 0.05
epsilon_decay = 0.0001  # per episode; with 10 episodes epsilon barely moves

# One action per xfer in the context (was hard-coded as 1118). This matches the
# output width of the supervised model, which also uses num_xfers.
agent = QAgent(lr=1e-3, a_size=quartz_context.num_xfers)
data = QData()

episode_rewards = []  # total gate-count reduction collected per episode
episode_losses = []   # mean training loss per episode

pbar = tqdm(range(episodes))
for i in pbar:
    rewards = 0
    losses = 0
    for j in range(replay_times):
        g = init_graph
        for _ in range(max_steps):
            dgl_g = g.to_dgl_graph()
            node, A = agent.select_a(dgl_g, epsilon)
            node, A = int(node), int(A)
            new_g = g.apply_xfer(xfer=quartz_context.get_xfer_from_id(id=A), node=g.all_nodes()[node])

            if new_g is None:
                # Inapplicable xfer: store a terminal transition with reward -1
                # and end this rollout.
                data.add_data([dgl_g, torch.tensor(node), torch.tensor(A), torch.tensor(-1), None])
                break

            dgl_new_g = new_g.to_dgl_graph()
            reward = g.num_gates() - new_g.num_gates()  # gates removed by this xfer
            data.add_data([dgl_g, torch.tensor(node), torch.tensor(A), torch.tensor(reward), dgl_new_g])
            g = new_g
            rewards += reward

    for j in range(train_epoch):
        losses += agent.train(data, batch_size)

    if epsilon > epsilon_min:
        epsilon -= epsilon_decay

    # Sync the bootstrap target with the online network once per episode.
    agent.target_net.load_state_dict(agent.q_net.state_dict())

    mean_loss = losses / max(train_epoch, 1)
    episode_rewards.append(rewards)
    episode_losses.append(mean_loss)
    pbar.set_postfix(reward=rewards, loss=mean_loss)


 10%|██████████████▊                                                                                                                                     | 1/10 [00:00<00:01,  7.64it/s]

766
284
275
266
407
788
216
662
1078
686
185
356
747
359
713
648
957
859
743
496


 30%|████████████████████████████████████████████▍                                                                                                       | 3/10 [00:00<00:00,  7.70it/s]

1011
531
658
121
713
977
378
370
372
332
379
155
840
197
118
880
875
1112
624
843


 50%|██████████████████████████████████████████████████████████████████████████                                                                          | 5/10 [00:00<00:00,  7.68it/s]

268
235
193
14
714
909
530
64
564
493
663
852
549
846
171
841
672
906
355
78


 70%|███████████████████████████████████████████████████████████████████████████████████████████████████████▌                                            | 7/10 [00:00<00:00,  7.61it/s]

493
269
248
1012
61
589
894
131
983
757
1082
980
164
1099
267
640
908
191
1065
883


 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏              | 9/10 [00:01<00:00,  7.65it/s]

836
962
457
979
1048
505
168
362
676
461
364
18
188
416
1097
381
132
657
132
42


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  7.63it/s]
