In [13]:
from queue import Queue
import numpy as np
import math
from operator import itemgetter
from cacheout import Cache

# Global tally of computed edges; SystolicArrayCell.compute() increments it
# every time a PE fires an edge.
edge_count = 0

class SystolicArrayCell:
    """One processing element (PE) of the ring-connected systolic array.

    A PE at column ``pos_x`` / row ``pos_y`` takes part in two rings that run
    along its column in opposite directions, pulls ``(src, dst)`` edges from
    its row's edge queue, records every node id it sees in its row's cache,
    and, once the ``src`` of its held edge is cached, adds its forward-ring
    value into its result bank at depth ``dst // row_n``.

    The clock is modelled in two phases: ``read()`` samples registers on the
    positive edge and ``compute()`` performs the combinational update.
    """

    def __init__(self, row_n, col_n):
        self.pos_x = 0
        self.pos_y = 0
        self.row_n = row_n
        self.col_n = col_n

        # Forward ring: registers fed by the PE in the next row (wrapping around).
        self.receive_cell = None
        self.receive_reg = 1
        self.receive_out = 0

        # Backward ring: registers fed by the PE in the previous row (wrapping around).
        self.receive_cell_ring = None
        self.receive_reg_ring = 1
        self.receive_out_ring = 0

        # Per-PE result list and per-row edge queue / cache, all bound in connect().
        self.result_bank_input = None
        self.edge_bank = None
        self.cache_bank = None

        # Node ids currently sampled (process_id*) and forwarded (*_out) on each ring.
        self.process_id = 0
        self.process_id_out = 0
        self.process_id_ring = 0
        self.process_id_ring_out = 0

        # Edge-processing state; src/dst stay -1 until the first edge is fetched.
        self.next_src = -1   # currently unused; kept for compatibility
        self.next_dst = -1   # currently unused; kept for compatibility
        self.src = -1
        self.dst = -1
        self.rb_depth = 0
        self.rb_value = 0
        self.edge_empty = False    # set once this row's edge queue has been drained
        self.edge_compute = True   # True when the held edge fires this cycle
        self.hold = False          # True while a fetched edge waits for its src id
        self.edge_number = 0

    def connect(self, pos_x, pos_y, array):
        """Wire this PE to its two ring neighbours and to its row's buffers.

        ``array.cells`` is indexed ``[row][column]``, i.e. ``[pos_y][pos_x]``.
        The forward ring receives from row ``pos_y + 1`` and the backward ring
        from row ``pos_y - 1``; both wrap around modulo ``array.row_n``.
        """
        self.pos_x = pos_x
        self.pos_y = pos_y
        self.edge_number = pos_y
        # Modular arithmetic replaces the former `pos_y is row_n - 1` / `pos_y is 0`
        # tests: `is` compares object identity, which only happens to work for
        # CPython's cached small ints and breaks for arrays with > 257 rows.
        next_row = (self.pos_y + 1) % array.row_n
        prev_row = (self.pos_y - 1) % array.row_n
        self.receive_cell = array.cells[next_row][self.pos_x]
        self.receive_cell_ring = array.cells[prev_row][self.pos_x]

        # Each PE has its own result list; the edge queue and cache are shared per row.
        self.result_bank_input = array.result_bank[self.pos_y][self.pos_x]
        self.edge_bank = array.edge_bank[self.pos_y]
        self.cache_bank = array.cache_bank[self.pos_y]

    def set_process_id(self, idx):
        """Seed both rings with node id ``idx`` and mark it present in the row cache."""
        self.process_id = idx
        self.process_id_ring = idx
        self.cache_bank.set(idx, 'none')

    # read() models the registers sampling data at the positive clock edge.
    def read(self, edge_update):
        """Positive clock edge: fetch edges, shift the rings, decide whether to fire.

        Args:
            edge_update: when True the ring registers keep their current
                ids/values instead of sampling their neighbours, and an edge is
                fetched even if the previous one has not fired.
        """
        # Fetch the next edge once the previous one has fired (or on an update
        # cycle); once the queue drains, flag it but keep any edge still held.
        if self.edge_bank.empty():
            self.edge_empty = True
        elif self.edge_compute or edge_update:
            self.src, self.dst = self.edge_bank.get()
            self.hold = True
            self.edge_compute = False
        else:
            self.edge_compute = False

        if not edge_update:
            # Shift both rings by one PE and remember every id that passes by.
            self.receive_reg = self.receive_cell.receive_out
            self.process_id = self.receive_cell.process_id_out
            self.receive_reg_ring = self.receive_cell_ring.receive_out_ring
            self.process_id_ring = self.receive_cell_ring.process_id_ring_out
            if not self.cache_bank.has(self.process_id):
                self.cache_bank.set(self.process_id, 'none')
            if not self.cache_bank.has(self.process_id_ring):
                self.cache_bank.set(self.process_id_ring, 'none')

        # The held edge fires as soon as its source id is present in the cache.
        if self.cache_bank.has(self.src) and self.hold:
            self.edge_compute = True
            self.hold = False
        else:
            self.edge_compute = False

        # dst is -1 before the first edge arrives; clamp so the depth stays 0
        # (matching the old int(dst / row_n) truncation) while using exact
        # integer division for real node ids.
        self.rb_depth = max(self.dst, 0) // self.row_n
        self.rb_value = self.result_bank_input[self.rb_depth]

    # compute() models the combinational logic between positive clock edges.
    def compute(self):
        """Accumulate into the result bank if an edge fired, then forward registers."""
        global edge_count
        if self.edge_compute:
            print("compute cell({:d},{:d}) src {:d}, dst {:d}". format(self.pos_x, self.pos_y, self.src, self.dst))
            edge_count = edge_count + 1
            self.result_bank_input[self.rb_depth] = self.rb_value + self.receive_reg
        # Expose the sampled registers to the neighbouring PEs for the next cycle.
        self.receive_out = self.receive_reg
        self.process_id_out = self.process_id
        self.receive_out_ring = self.receive_reg_ring
        self.process_id_ring_out = self.process_id_ring

    def cell_state(self):
        """Print a one-line debug summary of this PE's registers."""
        print("cell({:d},{:d}),rec_reg={:d}, proc_id={:d}, rb_value={:d}". format(self.pos_x, self.pos_y, self.receive_reg, self.process_id, self.rb_value))

In [14]:
# This represents our entire array: cells, inputs, and outputs
class SystolicArray:
    """A ``row_n`` x ``col_n`` grid of SystolicArrayCell PEs plus their buffers.

    Row ``r`` owns one edge queue (``edge_bank[r]``) and one cache
    (``cache_bank[r]``) shared by its PEs; every PE has its own result list
    (``result_bank[r][c]``). An edge is routed to row ``dst % row_n`` and
    accumulates at depth ``dst // row_n`` of that row's result list.
    """

    def __init__(self, row_n, col_n):
        self.row_n = row_n
        self.col_n = col_n

        # cells[r][c]: row r (pos_y), column c (pos_x), created row-major.
        self.cells = []
        for _ in range(self.row_n):
            row = []
            for _ in range(self.col_n):
                row.append(SystolicArrayCell(row_n, col_n))
            self.cells.append(row)

        self.cache_bank = [Cache(maxsize=self.row_n) for _ in range(self.row_n)]
        self.edge_bank = [Queue() for _ in range(self.row_n)]
        self.result_bank = [[list() for _ in range(self.col_n)] for _ in range(self.row_n)]

        # Wire the PEs only after all of them exist, since connect() looks up neighbours.
        for row_num, row in enumerate(self.cells):
            for col_num, cell in enumerate(row):
                cell.connect(col_num, row_num, self)

    def edge_bucket_empty(self, e_b):
        """Return True while ANY queue in the 2-D bucket grid ``e_b`` still holds edges.

        NOTE: despite its name this is a "not all empty" test; the name is kept
        for backward compatibility.
        """
        return any(not bucket.empty() for row in e_b for bucket in row)

    def edge_load_balance(self, row_n, src, dst):
        """Distribute edges over the per-row edge queues in a diagonal round-robin.

        Edges are sorted, then bucketed by ``(src % row_n, dst % row_n)``. Each
        pass visits rotations ``i = 0 .. row_n-1``; in rotation ``i`` row ``r``
        takes at most one edge from bucket ``((i + r) % row_n, r)``, so within a
        rotation every row draws from a different source residue. Passes repeat
        until every bucket is drained.

        Args:
            row_n: number of rows to balance over (expected to equal ``self.row_n``).
            src, dst: source / destination node ids, zipped pairwise.
        """
        buckets = [[Queue() for _ in range(row_n)] for _ in range(row_n)]
        for edge in sorted(zip(src, dst)):
            edge_src, edge_dst = edge
            buckets[edge_src % row_n][edge_dst % row_n].put(edge)
        while self.edge_bucket_empty(buckets):
            for rotation in range(row_n):
                for row in range(row_n):
                    bucket = buckets[(rotation + row) % row_n][row]
                    if not bucket.empty():
                        self.edge_bank[row].put(bucket.get())

    def edge_preprocess(self, num_node, edge_src, edge_dst):
        """Group edges by destination and order each group for the PE rows.

        Returns a list of length ``num_node`` whose entry ``d`` holds the
        ``(src, d)`` edges (in input order), re-ordered so that edges with
        ``src >= d % row_n`` come before the others (stable within each half).
        Empty input yields ``num_node`` empty lists instead of raising.
        """
        grouped = [[] for _ in range(num_node)]
        # Stable sort by dst keeps input order inside each destination group.
        for edge_s, edge_d in sorted(zip(edge_src, edge_dst), key=itemgetter(1)):
            grouped[edge_d].append((edge_s, edge_d))
        for node, edges in enumerate(grouped):
            leading = [e for e in edges if e[0] >= e[1] % self.row_n]
            trailing = [e for e in edges if e[0] < e[1] % self.row_n]
            grouped[node] = leading + trailing
        return grouped

    def fill_edges(self, num_node, edge_src, edge_dst):
        """Queue each destination group from edge_preprocess() on row ``dst % row_n``."""
        edge_ = self.edge_preprocess(num_node, edge_src, edge_dst)
        for i, val in enumerate(edge_):
            for e in val:
                self.edge_bank[i % self.row_n].put(e)

    def fill_result_banks(self, num_nodes):
        """Append ``ceil(num_nodes / row_n)`` zero slots to every PE's result list."""
        depth = math.ceil(num_nodes / self.row_n)
        for row_num in range(self.row_n):
            for col_num in range(self.col_n):
                self.result_bank[row_num][col_num].extend([0] * depth)

    def fill_idx(self, idx):
        """Seed every PE of row ``r`` with node id ``idx[r]`` (needs len(idx) >= row_n)."""
        for row_num in range(self.row_n):
            for col_num in range(self.col_n):
                self.cells[row_num][col_num].set_process_id(idx[row_num])

    def read(self, edge_update):
        """Positive clock edge: every PE samples its inputs before any PE computes."""
        for row in self.cells:
            for cell in row:
                cell.read(edge_update)

    def compute(self):
        """Combinational phase: every PE updates after all reads are done."""
        for row in self.cells:
            for cell in row:
                cell.compute()

    def terminal_signal(self):
        """Return True once no PE holds a pending edge and every edge queue is drained."""
        return all(not cell.hold and cell.edge_empty
                   for row in self.cells for cell in row)

    def show_staus(self):
        """Print every PE's debug state (name misspelled; kept for compatibility)."""
        for row in self.cells:
            for cell in row:
                cell.cell_state()

    def cycle(self, edge_update):
        """One clock cycle: read() on the positive edge, then compute()."""
        self.read(edge_update)
        self.compute()

    def run(self, num_nodes, max_cycles=None):
        """Clock the array until terminal_signal() reports completion.

        The first cycle is an edge-update cycle; all later cycles shift the rings.

        Args:
            num_nodes: currently unused; kept for interface compatibility.
            max_cycles: optional safety cap. When given, RuntimeError is raised
                if the array has not terminated after that many cycles (e.g. an
                edge whose src id never reaches its row's cache). Default None
                runs until termination.

        Returns:
            1 (placeholder status; results are left in ``self.result_bank``).
        """
        edge_update = True
        cycle = 0
        while True:
            print("-----Cycle----{:d}----------". format(cycle))
            self.cycle(edge_update)
            edge_update = False
            if self.terminal_signal():
                return 1
            cycle = cycle + 1
            if max_cycles is not None and cycle >= max_cycles:
                raise RuntimeError(
                    "systolic array did not terminate within {:d} cycles".format(max_cycles))

    def get_outputs(self):
        """Placeholder for formatted outputs; currently returns an empty list."""
        ret = []
        return ret

    def get_edge_output(self, num_nodes):
        """Print the first-column accumulated value for every node id < num_nodes."""
        for id_x in range(num_nodes):
            # node id_x lives in row id_x % row_n at depth id_x // row_n
            value = self.result_bank[id_x % self.row_n][0][id_x // self.row_n]
            print("id={:d}-|-{:d}". format(id_x, value))

In [15]:
# Smoke test: a 3x1 ring array processing a 9-edge graph over 6 nodes.
# Each fired edge adds the PE's forward-ring value (every PE starts with 1 and
# only forwards it) into the result slot of its dst, so every edge must be
# computed exactly once and the result banks must end up holding in-degrees.
row_n = 3
col_n = 1
num_nodes = 6
myArray = SystolicArray(row_n, col_n)

# Alternative test graphs (swap in for src/dst below):
#   src = [0,1,2,1,2,0,2,0,1]; dst = [0,1,2,0,1,2,0,1,2]
#   src = [0,0,0,0,1,1,2,2,2]; dst = [0,1,3,5,2,4,0,3,5]
#   src = [0,1,1,2,0,1,0,2,2]; dst = [0,2,4,0,1,3,5,4,5]
#   src = [0,0,1,2];           dst = [0,1,2,0]
src = [0,1,1,2,0,0,0,2,2]
dst = [0,2,4,0,1,3,5,3,5]

# Alternative edge placement without load balancing:
#   myArray.fill_edges(num_nodes, src, dst)
myArray.edge_load_balance(row_n, src, dst)
idx = [0,1,2]
myArray.fill_idx(idx)
myArray.fill_result_banks(num_nodes)

edge_count = 0  # reset the global counter incremented by SystolicArrayCell.compute()
res = myArray.run(num_nodes)

assert edge_count == len(src), "every edge must be computed exactly once"
for node in range(num_nodes):
    # node d lives in row d % row_n at depth d // row_n (single column)
    assert myArray.result_bank[node % row_n][0][node // row_n] == dst.count(node)

-----Cycle----0----------
compute cell(0,0) src 0, dst 0
compute cell(0,1) src 1, dst 4
compute cell(0,2) src 2, dst 5
-----Cycle----1----------
compute cell(0,0) src 2, dst 0
compute cell(0,1) src 0, dst 1
compute cell(0,2) src 0, dst 5
-----Cycle----2----------
compute cell(0,0) src 0, dst 3
compute cell(0,2) src 1, dst 2
-----Cycle----3----------
compute cell(0,0) src 2, dst 3
-----Cycle----4----------


In [16]:
import argparse, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl import graph_index
from dgl.graph_index import disjoint_partition
from dgl.data import register_data_args, load_data
import math

In [17]:
    # GCN experiment hyper-parameters. parse_args(args=[]) parses an empty list,
    # so the notebook kernel's own argv is ignored and only the defaults apply.
    parser = argparse.ArgumentParser(description='GCN')
    parser.add_argument("--dataset", type=str, default="cora",
            help="name of the dataset to load")
    parser.add_argument("--dropout", type=float, default=0.5,
            help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
            help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2,
            help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
            help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=16,
            help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
            help="number of hidden gcn layers")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
            help="Weight for L2 loss")
    parser.add_argument("--self-loop", action='store_true',
            help="graph self-loop (default=False)")
    parser.set_defaults(self_loop=False)
    args = parser.parse_args(args=[])

In [18]:
    # Load the dataset described by the parsed args (DGL's load_data presumably
    # dispatches on args.dataset) and wrap its arrays as torch tensors.
    data = load_data(args)
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    # NOTE(review): uint8 (Byte) masks are deprecated for boolean indexing in
    # recent PyTorch releases; consider torch.bool if these are used as masks.
    train_mask, val_mask, test_mask = (
        torch.ByteTensor(split) for split in (data.train_mask, data.val_mask, data.test_mask)
    )
    in_feats = features.shape[1]          # feature width = number of columns
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

In [19]:
# Wrap the dataset's graph in a DGLGraph; the partitioning cell below queries
# g.number_of_nodes() and g.out_edges() on it.
g = DGLGraph(data.graph)

In [20]:
# Split node ids [0, N) into consecutive chunks of `partition_size` ids and record,
# per chunk: its node ids (Node_index), the (src, dst) pairs returned by
# g.out_edges for those nodes (Edge), and how many such edges there are (Edge_number).
partition_size = 32
Node_index = []
Edge = []
Edge_number = []
total_nodes = g.number_of_nodes()
partition_number = math.ceil(total_nodes / partition_size)
print("the graph split to {:d} part". format(partition_number))
for node_id in range(partition_number):
    start = partition_size * node_id
    # Every chunk is full except the last, which is clipped to the node count.
    index = list(range(start, min(start + partition_size, total_nodes)))
    Node_index.append(index)
    src, dst = g.out_edges(index)
    Edge.append(list(zip(src.tolist(), dst.tolist())))
    Edge_number.append(src.shape[0])

the graph split to 85 part


In [21]:
# Unzip partition 0's (src, dst) pairs into two parallel tuples for the array run below.
src, dst = zip(*Edge[0])

In [22]:
# Number of (src, dst) pairs collected for partition 0.
len(Edge[0])

127

In [23]:
# Node ids of partition 0; fill_idx() seeds PE row r with idx[r].
idx = Node_index[0]

In [24]:
# Run partition 0 of the graph on a ring array with one PE row per partition node.
row_n = len(idx)  # == partition_size (32) for every partition except possibly the last;
                  # fill_idx() needs len(idx) >= row_n
col_n = 1
num_nodes = g.number_of_nodes()  # result banks must cover every possible dst id
myArray = SystolicArray(row_n, col_n)

# Alternative edge placement without load balancing:
#   myArray.fill_edges(num_nodes, src, dst)
myArray.edge_load_balance(row_n, src, dst)
myArray.fill_idx(idx)
myArray.fill_result_banks(num_nodes)
edge_count = 0  # reset the global counter incremented by SystolicArrayCell.compute()
res = myArray.run(num_nodes)
assert edge_count == len(src), "every edge of the partition must be computed exactly once"

-----Cycle----0----------
compute cell(0,0) src 0, dst 544
compute cell(0,15) src 15, dst 399
compute cell(0,23) src 23, dst 759
-----Cycle----1----------
compute cell(0,29) src 30, dst 285
-----Cycle----2----------
compute cell(0,12) src 14, dst 268
compute cell(0,28) src 30, dst 1148
-----Cycle----3----------
compute cell(0,7) src 10, dst 519
compute cell(0,28) src 31, dst 1116
-----Cycle----4----------
compute cell(0,5) src 9, dst 453
compute cell(0,10) src 14, dst 746
-----Cycle----5----------
compute cell(0,2) src 7, dst 258
compute cell(0,9) src 14, dst 393
compute cell(0,10) src 15, dst 234
compute cell(0,16) src 11, dst 624
compute cell(0,17) src 22, dst 2257
compute cell(0,18) src 13, dst 1234
compute cell(0,20) src 25, dst 20
compute cell(0,26) src 31, dst 250
-----Cycle----6----------
compute cell(0,2) src 8, dst 258
compute cell(0,4) src 10, dst 420
compute cell(0,8) src 14, dst 8
compute cell(0,12) src 18, dst 1932
compute cell(0,14) src 20, dst 334
compute cell(0,16) src 

In [51]:
# Edges computed by the last run (global counter incremented in SystolicArrayCell.compute()).
edge_count

127