In [2]:
import quartz
quartz_context = quartz.QuartzContext(gate_set=['h', 'cx', 'x', 't', 'tdg'], filename='../bfs_verified_simplified.json')
parser = quartz.PyQASMParser(context=quartz_context)
my_dag = parser.load_qasm(filename="../circuit/example-circuits/barenco_tof_3.qasm")
my_dag.num_qubits, my_dag.num_gates
init_graph = quartz.PyGraph(context=quartz_context, dag=my_dag)

Using backend: pytorch


In [70]:
import numpy as np
gate_type_num = 20
class QAgent:
    def __init__(self, lr, a_size):
        torch.manual_seed(42)
        self.q_net = QGNN(gate_type_num, 16, a_size, 16)
        self.target_net = copy.deepcopy(self.q_net)
        self.loss_fn = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr = lr)
        self.a_size = a_size 
               
    def select_a(self, g, e):
        a_size = self.a_size
        
        if random.random() < e:
            node = np.random.randint(0, g.num_nodes())
            A = np.random.randint(0, a_size)
            #print("random")
            
        else:
            with torch.no_grad():
                pred = self.q_net(g)
            Qs, As = torch.max(pred)
            Q, node = torch.max(Qs, dim = 0, keepdim = True)
            A = As[node]                 
        
        #print(node)
        #print(A)
        return node, A
    

    def train(self, data, batch_size):
        losses = 0
        pred_rs = []
        target_rs = []
        for i in range(batch_size):
            s, node, a, r, s_next = data.get_data()

            pred = self.q_net(s)
            pred_r = pred[node][a]
            #s_a = s_as.gather(1, a)

            if s_next == None:
                target_r = torch.tensor(-1.0)
            else:
                q_next = self.target_net(s_next).detach()
                target_r = r + self.gamma * q_next
            
            pred_rs.append(pred_r)
            target_rs.append(target_r)
        loss = self.loss_fn(torch.stack(pred_rs), torch.stack(target_rs))
        self.optimizer.zero_grad()
        loss.backward(retain_graph=True)
        for param in self.q_net.parameters():
            param.grad.data.clamp_(-1,1)
        self.optimizer.step()
              
        return loss.item()    

    
class QData:
    def __init__(self):
        self.data = deque(maxlen=100000)
        
    def add_data(self, d):
        self.data.append(d)
        
    def get_data(self):
        s = random.sample(self.data, 1)[0]
        #print(s)
        return s[0],s[1],s[2],s[3],s[4]
    

In [73]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

class QConv(nn.Module):

    def __init__(self, in_feat, inter_dim, out_feat):
        super(QConv, self).__init__()
        self.linear2 = nn.Linear(in_feat + inter_dim, out_feat)
        self.linear1 = nn.Linear(in_feat + 3, inter_dim, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.linear1.weight, gain=gain)
        nn.init.xavier_normal_(self.linear2.weight, gain=gain)

    def message_func(self, edges):
        #print(f'node h {edges.src["h"].shape}')
        #print(f'node w {edges.data["w"].shape}')
        return {'m': torch.cat([edges.src['h'], edges.data['w']], dim=1)}

    def reduce_func(self, nodes):
        #print(f'node m {nodes.mailbox["m"].shape}')
        tmp = self.linear1(nodes.mailbox['m'])
        tmp = F.leaky_relu(tmp)
        h = torch.mean(tmp, dim=1)
        return {'h_N': h}

    def forward(self, g, h):
        g.ndata['h'] = h
        #g.edata['w'] = w #self.embed(torch.unsqueeze(w,1))
        g.update_all(self.message_func, self.reduce_func)
        h_N = g.ndata['h_N']
        h_total = torch.cat([h, h_N], dim=1)
        return self.linear2(h_total)



class QGNN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes, inter_dim):
        super(QGNN, self).__init__()
        self.conv1 = QConv(in_feats, inter_dim, h_feats)
        self.conv2 = QConv(h_feats, inter_dim, h_feats)
        self.conv3 = QConv(h_feats, inter_dim, h_feats)
        self.conv4 = QConv(h_feats, inter_dim, h_feats)
        self.conv5 = QConv(h_feats, inter_dim, num_classes)
        self.embedding = nn.Embedding(in_feats, in_feats)
    
    def forward(self, g):
        #print(g.ndata['gate_type'])
        #print(self.embedding)
        g.ndata['h'] = self.embedding(g.ndata['gate_type'])
        w = torch.cat([torch.unsqueeze(g.edata['src_idx'],1),torch.unsqueeze(g.edata['dst_idx'],1),torch.unsqueeze(g.edata['reversed'],1)],dim = 1)
        g.edata['w'] = w 
        h = self.conv1(g, g.ndata['h'])
        h = F.relu(h)
        h = self.conv2(g, h)
        h = F.relu(h)
        h = self.conv3(g, h)
        h = F.relu(h)
        h = self.conv4(g, h)
        h = F.relu(h)
        h = self.conv5(g, h)
        return h



In [75]:
import random
from tqdm import tqdm
import copy
from collections import deque
import torch


agent = QAgent(lr = 1e-3, a_size = 1118)
data = QData()

replay_times = 10
episodes = 10
epsilon = 1
train_epoch = 5

for i in tqdm(range(episodes)):
    rewards = 0
    losses = 0
    for j in range(replay_times):
        count = 0
        end = False
        g = init_graph
        while(count < 10 and not end):
            dgl_g = g.to_dgl_graph()
            count += 1 
            node, A = agent.select_a(dgl_g, epsilon)
            print(A)
            new_g = g.apply_xfer(xfer=quartz_context.get_xfer_from_id(id=A), node = g.all_nodes()[node])
            
            if new_g == None:
                end = True
                data.add_data([dgl_g, torch.tensor(node), torch.tensor(A), torch.tensor(-1), None])
            
            else:
                dgl_new_g = new_g.to_dgl_graph()
                reward = g.num_gates() - new_g.num_gates()
                                         
                data.add_data([dgl_g, torch.tensor(node), torch.tensor(A), torch.tensor(reward), dgl_new_g])
            
                g = new_g
                rewards += reward
        

    for j in range(train_epoch):
        loss = agent.train(data, 3)
        losses += loss  
        
    if epsilon > 0.05 :
        epsilon -= 0.0001
        

    agent.target_net.load_state_dict(agent.q_net.state_dict())


 10%|██████████████▊                                                                                                                                     | 1/10 [00:00<00:01,  7.64it/s]

766
284
275
266
407
788
216
662
1078
686
185
356
747
359
713
648
957
859
743
496


 30%|████████████████████████████████████████████▍                                                                                                       | 3/10 [00:00<00:00,  7.70it/s]

1011
531
658
121
713
977
378
370
372
332
379
155
840
197
118
880
875
1112
624
843


 50%|██████████████████████████████████████████████████████████████████████████                                                                          | 5/10 [00:00<00:00,  7.68it/s]

268
235
193
14
714
909
530
64
564
493
663
852
549
846
171
841
672
906
355
78


 70%|███████████████████████████████████████████████████████████████████████████████████████████████████████▌                                            | 7/10 [00:00<00:00,  7.61it/s]

493
269
248
1012
61
589
894
131
983
757
1082
980
164
1099
267
640
908
191
1065
883


 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏              | 9/10 [00:01<00:00,  7.65it/s]

836
962
457
979
1048
505
168
362
676
461
364
18
188
416
1097
381
132
657
132
42


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  7.63it/s]
