In [33]:
from torch_geometric.datasets import Planetoid
from torch_geometric.data.data import Data
from scipy.sparse import csc_matrix 
import numpy as np

# construct CSC data: adjacency matrix and node-feature matrix of Cora
dataset = Planetoid(root='/tmp/Cora', name='Cora')
graph = dataset[0]  # index once instead of re-evaluating dataset[0] per use

num_nodes = graph['x'].shape[0]
# edge_index is indexed as [0] / [1] above, i.e. two rows: sources and targets.
src, dst = graph['edge_index'].numpy()

# Dense adjacency is kept (later cells may use ``A``), but it is filled with a
# single vectorized scatter rather than a Python loop over every edge tensor.
A = np.zeros((num_nodes, num_nodes))
A[src, dst] = 1

A_CSC = csc_matrix(A)
# Convert explicitly to NumPy rather than relying on scipy coercing a tensor.
F1_CSC = csc_matrix(graph['x'].numpy())


In [35]:
class HyGCN():
    """Skeleton cycle-level simulator with HyGCN-style components.

    Holds an aggregation engine, a combination engine, a memory model and a
    statistics container, and advances them together through :meth:`tick`.
    """

    def __init__(self):
        self.cycle       = 0
        self.agg_engine  = self.AggEngine()
        self.comb_engine = self.CombEngine()
        self.mem         = self.MEM()
        self.stats       = self.Stats()

    def tick(self):
        """
        Global ticker, tick each component
        """
        self.cycle += 1
        self.agg_engine.tick()
        self.comb_engine.tick()
        # Memory must be ticked too, otherwise queued requests never count down.
        self.mem.tick()

    class AggEngine():
        """Aggregation engine: a bank of SIMD cores plus placeholder units."""

        def __init__(self, num_simd_cores=32):
            # One distinct core per slot; ``[core] * n`` would alias a single
            # object n times so every core would share the same ALU state.
            self.simd_cores = [self.SIMDCore() for _ in range(num_simd_cores)]
            self.esched = self.eSched()
            self.sampler = self.Sampler()
            self.edge_buffer = self.EdgeBuffer()
            self.input_buffer = self.InputBuffer()
            self.prefetcher = self.Prefetcher()

        def tick(self):
            """Advance one cycle (not yet modelled)."""
            pass

        class SIMDCore():
            """A SIMD core holding a per-ALU status slot (initially 0)."""

            def __init__(self, num_alus=16):
                self.alu_status = [0] * num_alus  # ints are immutable, so * is safe here

        class eSched():
            """Placeholder for the aggregation engine's eSched unit."""
            pass

        class Sampler():
            """Placeholder for the aggregation engine's sampler."""
            pass

        class EdgeBuffer():
            """Placeholder for the aggregation engine's edge buffer."""
            pass

        class InputBuffer():
            """Placeholder for the aggregation engine's input buffer."""
            pass

        class Prefetcher():
            """Placeholder for the aggregation engine's prefetcher."""
            pass

    class CombEngine():
        """Combination engine: a set of systolic modules plus placeholder units."""

        def __init__(self, num_systolic_modules=8):
            # Distinct modules; ``[module] * n`` would alias one module n times.
            self.systolic_module = [
                self.SystolicModule() for _ in range(num_systolic_modules)
            ]
            self.vsched = self.vSched()
            self.weight_buffer = self.WeightBuffer()

        def tick(self):
            """Advance one cycle (not yet modelled)."""
            pass

        class SystolicModule():
            """A systolic module made of independent processing elements."""

            def __init__(self, num_pes=128):
                # Distinct PEs so each keeps its own pe_status.
                self.pes = [self.PE() for _ in range(num_pes)]

            class PE():
                """A processing element with a single status field (initially 0)."""

                def __init__(self):
                    self.pe_status = 0

        class vSched():
            """Placeholder for the combination engine's vSched unit."""
            pass

        class WeightBuffer():
            """Placeholder for the combination engine's weight buffer."""
            pass

    class MEM():
        """Fixed-latency memory model with an unbounded request queue."""

        def __init__(self, latency=200):
            self.latency = latency
            self.queue   = [] # unbounded; each entry is a list [addr, cycles_left]
            self.addr    = [] # addresses currently in flight (merges duplicates)

        def get(self, addr):
            """Issue a request for ``addr``; an in-flight duplicate is ignored."""
            if addr in self.addr:
                return
            self.addr.append(addr)
            self.queue.append([addr, self.latency])

        def tick(self):
            """Advance one cycle and notify every request whose countdown expired."""
            for entry in self.queue:
                entry[1] -= 1
            # ``<= 0`` so a latency of 0 still completes instead of going negative.
            finished = [entry for entry in self.queue if entry[1] <= 0]
            self.queue[:] = [entry for entry in self.queue if entry[1] > 0]
            # Queue/addr are updated before notify so the hook may re-issue requests.
            for entry in finished:
                # Release the address so it can be requested again later.
                self.addr.remove(entry[0])
                self.notify(entry)

        def notify(self, item):
            """Completion hook for a finished ``[addr, cycles_left]`` entry; no-op."""
            pass

    class Stats():
        """Placeholder container for simulation statistics."""
        pass

In [47]:
class MEM():
    """Fixed-latency memory model that prints its queue state every cycle.

    Requests are kept as ``[addr, cycles_left]`` entries; a request for an
    address that is already in flight is merged with the existing one.
    """

    def __init__(self, latency=5):
        self.latency = latency
        self.queue   = [] # unbounded; each entry is a list [addr, cycles_left]
        self.addr    = [] # addresses currently in flight, in issue order

    def get(self, addr):
        """Issue a request for ``addr``; an in-flight duplicate is ignored."""
        if addr in self.addr:
            return
        self.addr.append(addr)
        self.queue.append([addr, self.latency])

    def tick(self):
        """Advance one cycle, notify finished requests, then print queue and addr."""
        for entry in self.queue:
            entry[1] -= 1

        # notify finished accesses; ``<= 0`` keeps latency 0 from never finishing
        finished = [entry for entry in self.queue if entry[1] <= 0]
        self.queue[:] = [entry for entry in self.queue if entry[1] > 0]
        for entry in finished:
            # Remove by value rather than relying on queue/addr index alignment.
            self.addr.remove(entry[0])
            self.notify(entry)
        print(self.queue)
        print(self.addr)

    def notify(self, item):
        """Report completion of a finished ``[addr, cycles_left]`` entry."""
        print(item[0], "Finished!")

    
mem = MEM()

# Requests issued before each of the first three cycles.
warmup_batches = (
    ('3241241',),
    ('54q98',),
    ('5932lkc', 'jlkfja'),
)
for batch in warmup_batches:
    for request in batch:
        mem.get(request)
    mem.tick()

mem.get('foiwaertoi')

# Run ten more cycles so the outstanding requests count down.
for i in range(10):
    mem.tick()



[['3241241', 4]]
['3241241']
[['3241241', 3], ['54q98', 4]]
['3241241', '54q98']
[['3241241', 2], ['54q98', 3], ['5932lkc', 4], ['jlkfja', 4]]
['3241241', '54q98', '5932lkc', 'jlkfja']
[['3241241', 1], ['54q98', 2], ['5932lkc', 3], ['jlkfja', 3], ['foiwaertoi', 4]]
['3241241', '54q98', '5932lkc', 'jlkfja', 'foiwaertoi']
3241241 Finished!
[['54q98', 1], ['5932lkc', 2], ['jlkfja', 2], ['foiwaertoi', 3]]
['54q98', '5932lkc', 'jlkfja', 'foiwaertoi']
54q98 Finished!
[['5932lkc', 1], ['jlkfja', 1], ['foiwaertoi', 2]]
['5932lkc', 'jlkfja', 'foiwaertoi']
5932lkc Finished!
jlkfja Finished!
[['foiwaertoi', 1]]
['foiwaertoi']
foiwaertoi Finished!
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
