# Playground
Understand the schema

## Script provided as `depthwise_graphviz.py`

In [1]:
##############################################
#                                            #
#        Depthwise convolutional layer       #
#                                            #
##############################################

# Implort pygears types
from pygears import gear, datagear, sim, find, reg

# Import pygears types
from pygears.typing import Array, Fixp, Queue, Tuple, Uint

# Import pygears built-in modules
from pygears.lib import accum, ccat, collect, czip, dreg, drv, flatten, mul, qdeal, qrange, qround, queuemap, replicate, saturate, sdp, when

# Packages used for verification and visualization
import matplotlib.image as mpimg
import numpy as np
import matplotlib.pyplot as plt

reg['gear/memoize'] = False

##############################################
#                                            #
#                   Design                   #
#                                            #
##############################################


# Dot product implementation with saturation and rounding
@gear
def dot(din):
    return din \
        | queuemap(f=mul) \
        | accum(init=Fixp[10, 18](0)) \
        | qround \
        | saturate(t=Uint[8])


# Reorganizes data on the bus for proper distribution among 3 dot product
# modules
@datagear
def reorder(din: Queue[Tuple['pixel', 'weight']]
            ) -> Array[Queue[Tuple['pixel[0]', 'weight[0]']], 3]:
    p = din.data[0]
    w = din.data[1]

    return (
        ((p[0], w[0]), din.eot),
        ((p[1], w[1]), din.eot),
        ((p[2], w[2]), din.eot),
    )


# Generates write addresses based on filter weights stream. This is used for
# caching of filter weights.
@gear(hdl={'compile': True})
async def wr_req(weights: Queue) -> Tuple[Uint[4], 'weights.data']:
    cnt = Uint[4](0)
    async for w, last in weights:
        yield cnt, w
        cnt += 1


# Implements:
#     1. filter weights caching for a single CNN filter
#     2. filter weights readout synchronized with the input image segment
#     3. 3 dot product modules in operating in parallel on different slices
#        along the image segment depth
#     4. outputs the result as a vector of 3 output feature map elements
@gear
def filter(
        img: Queue[Array[Uint, 3]],  # Image segment stream
        weights: Queue,  # Filter weights stream
) -> b'img.data':

    # - Performs the readout of the cached filter weights out of a simple
    #   dual-port (sdp) memory
    # - Before readout waits for the last of the filter weigts to be streamed in
    #   "weights['eot']"
    # - Kernel weights will be read out 30*30=900 times for each of the image
    #   segments
    w = when(weights['eot'] | dreg, 9) \
        | replicate(30 * 30) \
        | flatten \
        | qrange \
        | flatten \
        | sdp(wr_req(weights))

    # Pair up corresponding slices of the kernel and image segment and send
    # them for processing to a set of "dot" modules, one for each slice along
    # the tensor depth
    res = [dot(d) for d in reorder(czip(img, w))]

    # Synchronize outputs of the "dot" modules and combine them into a vector
    return ccat(*res)


# Top level design module - distributes image segments for processing on "num"
# filters in parallel
@gear
def depthwise(
        img,  # Image segment stream
        weights,  # Filter weights stream
        *,
        num,  # Number of parallel filters available
):
    res = [filter(img, w) for w in qdeal(weights, num=num, lvl=1)]

    return ccat(*res)


##############################################
#                                            #
#                 Simulation                 #
#                                            #
##############################################

res = []

# Driver that outputs image segments
img_drv = drv(t=Queue[Array[Uint[8], 3]], seq=[])
# Driver that outputs filter weights
w_drv = drv(t=Queue[Array[Fixp[3, 8], 3]], seq=[])

# Top level connection between drivers, dut and a monitor
#  - "depthwise" module will be first converted to SystemVerilog and simulated
#    using "verilator" HDL simulator

depthwise(img_drv, w_drv, num=2) \
    | Array[Array[int, 3], 2] \
    | collect(result=res)

##############################################
#                                            #
#     Graphviz hierarchy visualization       #
#                                            #
##############################################

top = find('/')

# Traverse hierarchy starting from the 'top' and generate graphviz graph


## Utils
Custom functions

In [2]:
import pygears


def visit(node):
    if not node:
        return
    
    visited.append(node)
        
    # For Gear visit it's local interfaces and out ports
    if isinstance(node, pygears.core.gear.Gear) and node not in gears:
        gears.append(node)
        for intf in node.local_intfs:
            print(f'Visiting local_intfs of {node}')
            visit(intf)
        for out_port in node.out_ports:
            visit(out_port)
    # For interface visit it's in and out ports
    elif isinstance(node, pygears.core.intf.Intf):  # and node not in intfs:
        print(node)
        intfs.append(node)
        visit(node.producer)
        for consumer in node.consumers:
            visit(consumer)
    # For in port visit it's gear
    elif isinstance(node, pygears.core.port.InPort) and node not in in_ports:
        in_ports.append(node)
        visit(node.gear)
    # For out port visit it's gear
    elif isinstance(node, pygears.core.port.OutPort) and node not in out_ports:
        out_ports.append(node)
        visit(node.gear)
    else:
        return
        print(f'Unexpected type: {type(node)}!')


# class Node():
#     def __init__(self, gear):
#         self.name = gear.name
#         self.gear = gear
#         self.in_nodes = []
#         self.out_nodes = []
#         self.depth = None
        
#     def set_in_nodes()
        
#     def depth(self):
#         if depth:
#             return depth
#         else:
#             maxx = 0
#             for in_port in gear.in_port_intfs:
#                 in_port_intf.prod


# def pretty(obj):
#     if isinstance(obj, pygears.core.gear.Gear):
        
#         gear_template = f"""
# Gear: {obj}
#   Input ports: {obj.in_ports}
#   Output ports: {obj.out_ports}
#         """
#         print(gear_template)

## Analyze `top`

In [3]:
visited = []
gears = []
intfs = []
in_ports = []
out_ports = []

visit(top)

Visiting local_intfs of Top
Intf([Array[u8, 3]])
Visiting local_intfs of depthwise("/depthwise")
Intf([Array[q3.5, 3]])
Visiting local_intfs of depthwise("/depthwise")
Intf([Array[q3.5, 3]])
Visiting local_intfs of filter("/depthwise/filter0")
Intf([Array[q3.5, 3]])
Visiting local_intfs of filter("/depthwise/filter0")
Intf(u1)
Visiting local_intfs of filter("/depthwise/filter0")
Intf(u1)
Visiting local_intfs of when_pass("/depthwise/filter0/when")
Intf(u4)
Visiting local_intfs of when_pass("/depthwise/filter0/when")
Intf(u1)
Visiting local_intfs of when_pass("/depthwise/filter0/when")
Intf((u4, u1))
Visiting local_intfs of when_pass("/depthwise/filter0/when")
Intf(u4 | u4)
Visiting local_intfs of filt_fix_sel("/depthwise/filter0/when/filt")
Intf(u4 | u4)
Visiting local_intfs of filt_fix_sel("/depthwise/filter0/when/filt")
Intf(u1)
Visiting local_intfs of filt_fix_sel("/depthwise/filter0/when/filt")
Intf((u4 | u4, u1))
Visiting local_intfs of filt_fix_sel("/depthwise/filter0/when/filt")

In [4]:
gears[:3]

[Top, drv("/drv0"), depthwise("/depthwise")]

In [5]:
intfs[:3]

[Intf(Queue[Array[Uint[8], 3], 1]),
 Intf(Queue[Array[Fixp[3, 8], 3], 1]),
 Intf(Queue[Array[Fixp[3, 8], 3], 1])]

In [6]:
in_ports[:3]

[InPort("/depthwise.img"),
 InPort("/depthwise.weights"),
 InPort("/depthwise/qdeal.din")]

In [7]:
out_ports[:3]

[OutPort("/drv0.dout"),
 OutPort("/depthwise/qdeal.dout0"),
 OutPort("/depthwise/qdeal.dout1")]

In [8]:
def sanity_check(objs):
    print(len(objs), len(set(objs)))

In [9]:
sanity_check(gears)
sanity_check(in_ports)
sanity_check(out_ports)
sanity_check(intfs)

164 164
211 211
169 169
241 241


## GraphViz

In [10]:
import graphviz

### Revision 0 - Naive

In [11]:
g = graphviz.Digraph(comment='Anari AI')
g.attr(rankdir='LR')
for gear in gears:
    with g.subgraph(name=f'cluster_{gear.name}') as sg:
        sg.attr(label=gear.name)
        for in_port in gear.in_ports:
            sg.node(name=in_port.name, label=in_port.name.split('.')[-1])
        with sg.subgraph() as s:
            s.attr(rank='0')
            for in_port in gear.in_ports:
                s.node(in_port.name)
        for out_port in gear.out_ports:
            sg.node(name=out_port.name, label=out_port.name.split('.')[-1])
        with sg.subgraph() as s:
            s.attr(rank='1')
            for out_port in gear.out_ports:
                s.node(out_port.name)
#     sg.attr = {'label': gear.name}
#     print(sg.attr['label'])
#     dot.subgraph(sg)

for intf in intfs:
    for out_port in intf.consumers:
        g.edge(intf.producer.name, out_port.name)

# with g.subgraph() as sg:
#     sg.attr(rank='same')
#     sg.node('/depthwise/filter1/replicate')
#     sg.node('/depthwise/filter1/flatten0')

# g.view()
# print(g.source)

### Revision 1 - Blocks

In [12]:
g = graphviz.Digraph(comment='Anari AI', node_attr={'shape': 'record'})
g.attr(rankdir='LR')

def short(port):
    return port.name.split('.')[-1]


def dot_string(ports):
    body = '|'.join(f'<{short(port)}> {short(port)}' for port in ports)
    return '{' + body + '}'
    # return body
    

for gear in gears:
    ins = dot_string(gear.in_ports)
    outs = dot_string(gear.out_ports)
    g.node(name=gear.name, label='{' + f'{ins}|{gear.name}|{outs}' + '}')
    # print(in_port_template)
#     with g.node(name=gear.name, label=gear.name) as node:
#         for in_port in gear.in_ports:
#             sg.node(name=in_port.name, label=in_port.name.split('.')[-1])
#         with sg.subgraph() as s:
#             s.attr(rank='0')
#             for in_port in gear.in_ports:
#                 s.node(in_port.name)
#         for out_port in gear.out_ports:
#             sg.node(name=out_port.name, label=out_port.name.split('.')[-1])
#         with sg.subgraph() as s:
#             s.attr(rank='1')
#             for out_port in gear.out_ports:
#                 s.node(out_port.name)
#     sg.attr = {'label': gear.name}
#     print(sg.attr['label'])
#     dot.subgraph(sg)

for intf in intfs:
    for out_port in intf.consumers:
        src = f'{intf.producer.gear.name}:{short(intf.producer)}'
        dst = f'{out_port.gear.name}:{short(out_port)}'
        g.edge(src, dst)

# with g.subgraph() as sg:
#     sg.attr(rank='same')
#     sg.node('/depthwise/filter1/replicate')
#     sg.node('/depthwise/filter1/flatten0')

# g.view()
# print(g.source)

### Revision 2 - Subgraphs - Missed the point

In [13]:
g = graphviz.Digraph(comment='Anari AI', node_attr={'shape': 'record'})
g.attr(rankdir='LR')


def short(port):
    return port.name.split('.')[-1]


def dot_string(ports):
    body = '|'.join(f'<{short(port)}> {short(port)}' for port in ports)
    return '{' + body + '}'
    # return body


def cluster_name(gear):
    return f'cluster_{gear.name}' if gear.hierarchical else gear.name


def find_gear(name):
    return find(name.split('cluster_')[1]) if 'cluster_' in name else find(name)


# for gear in gears:
#     if gear.hierarchical:
#         continue
#     ins = dot_string(gear.in_ports)
#     outs = dot_string(gear.out_ports)
#     g.node(name=gear.name, label='{' + f'{ins}|{gear.name}|{outs}' + '}')
    # print(in_port_template)
#     with g.node(name=gear.name, label=gear.name) as node:
#         for in_port in gear.in_ports:
#             sg.node(name=in_port.name, label=in_port.name.split('.')[-1])
#         with sg.subgraph() as s:
#             s.attr(rank='0')
#             for in_port in gear.in_ports:
#                 s.node(in_port.name)
#         for out_port in gear.out_ports:
#             sg.node(name=out_port.name, label=out_port.name.split('.')[-1])
#         with sg.subgraph() as s:
#             s.attr(rank='1')
#             for out_port in gear.out_ports:
#                 s.node(out_port.name)
#     sg.attr = {'label': gear.name}
#     print(sg.attr['label'])
#     dot.subgraph(sg)

subgraphs = {}

for gear in gears:
    if gear.hierarchical and gear.name and cluster_name(gear) not in subgraphs:
        subgraphs[cluster_name(gear)] = []
    
    parent = find('/'.join(gear.name.split('/')[:-1]))
    if parent:
        if cluster_name(parent) in subgraphs:
            subgraphs[cluster_name(parent)].append(cluster_name(gear))
        else:
            subgraphs[cluster_name(parent)] = [cluster_name(gear)]

for subgraph in sorted(subgraphs, key=lambda x: len(x.split('/')), reverse=True):
    with g.subgraph(name=subgraph) as sg:
        label = subgraph  # .split('/')[-1]
        sg.attr(label=label)
        for node in subgraphs[subgraph]:
            gear = find_gear(node)
            if gear.hierarchical:
                sg.node(node)
            else:
                ins = dot_string(gear.in_ports)
                outs = dot_string(gear.out_ports)
                sg.node(name=gear.name, label='{' + f'{ins}|{gear.name}|{outs}' + '}')
                
for intf in intfs:
    for out_port in intf.consumers:
        src = f'{intf.producer.gear.name}:{short(intf.producer)}'
        dst = f'{out_port.gear.name}:{short(out_port)}'
        if intf.producer.gear.hierarchical or out_port.gear.hierarchical:
            continue
        g.edge(src, dst)

# g.view()
# print(g.source)

'Digraph.gv.pdf'

In [14]:
subgraphs

{'cluster_': ['cluster_',
  '/drv0',
  'cluster_/depthwise',
  '/drv1',
  '/cast_dout',
  '/collect'],
 'cluster_/depthwise': ['/depthwise/qdeal',
  'cluster_/depthwise/filter0',
  'cluster_/depthwise/filter1',
  '/depthwise/ccat'],
 'cluster_/depthwise/filter0': ['/depthwise/filter0/sieve_eot',
  '/depthwise/filter0/wr_req',
  '/depthwise/filter0/dreg',
  'cluster_/depthwise/filter0/when',
  '/depthwise/filter0/const0',
  '/depthwise/filter0/ccat0',
  '/depthwise/filter0/const1',
  '/depthwise/filter0/replicate',
  '/depthwise/filter0/flatten0',
  '/depthwise/filter0/qrange',
  '/depthwise/filter0/flatten1',
  'cluster_/depthwise/filter0/sdp',
  'cluster_/depthwise/filter0/czip',
  '/depthwise/filter0/reorder',
  '/depthwise/filter0/sieve_0',
  '/depthwise/filter0/sieve_1',
  '/depthwise/filter0/sieve_2',
  'cluster_/depthwise/filter0/dot0',
  '/depthwise/filter0/ccat1',
  'cluster_/depthwise/filter0/dot1',
  'cluster_/depthwise/filter0/dot2',
  '/depthwise/filter0/cast'],
 'cluster_/

### Revision 3 - Subgraphs done right

In [25]:
g = graphviz.Digraph(comment='Anari AI', node_attr={'shape': 'record'})
g.attr(rankdir='LR')


def short(port):
    return port.name.split('.')[-1]


def dot_string(ports):
    body = '|'.join(f'<{short(port)}> {short(port)}' for port in ports)
    return '{' + body + '}'
    # return body


def cluster_name(gear):
    return f'cluster_{gear.name}' if gear.hierarchical else gear.name


def find_gear(name):
    return find(name.split('cluster_')[1]) if 'cluster_' in name else find(name)


def populate(subgraph, name, subgraphs):
    print(f'Populating: {name} with {subgraphs[name]}')
    for child in subgraphs[name]:
        gear = find_gear(child)
        if gear.hierarchical:
            with subgraph.subgraph(name=child) as sg:
                label = gear.name.split('/')[-1]
                sg.attr(label=label)
                populate(sg, child, subgraphs)
        else:
            ins = dot_string(gear.in_ports)
            outs = dot_string(gear.out_ports)
            subgraph.node(name=gear.name, label='{' + f'{ins}|{gear.name}|{outs}' + '}')
    

subgraphs = {}

for gear in gears:
    if gear.hierarchical and gear.name and cluster_name(gear) not in subgraphs:
        subgraphs[cluster_name(gear)] = []
    
    parent = find('/'.join(gear.name.split('/')[:-1]))
    if parent and cluster_name(parent) != cluster_name(gear):
        if cluster_name(parent) in subgraphs:
            subgraphs[cluster_name(parent)].append(cluster_name(gear))
        else:
            subgraphs[cluster_name(parent)] = [cluster_name(gear)]

for subgraph in subgraphs:
    if len(subgraph.split('/')) == 1:
        with g.subgraph(name=subgraph) as sg:
            gear = find_gear(subgraph)
            label = gear.name.split('/')[-1]
            populate(sg, subgraph, subgraphs)
        print(f'Processing: {subgraph}')
    else:
        continue

for intf in intfs:
    for out_port in intf.consumers:
        src = f'{intf.producer.gear.name}:{short(intf.producer)}'
        dst = f'{out_port.gear.name}:{short(out_port)}'
        if intf.producer.gear.hierarchical or out_port.gear.hierarchical:
            continue
        g.edge(src, dst)

g.view()
# print(g.source)

Populating: cluster_ with ['/drv0', 'cluster_/depthwise', '/drv1', '/cast_dout', '/collect']
Populating: cluster_/depthwise with ['/depthwise/qdeal', 'cluster_/depthwise/filter0', 'cluster_/depthwise/filter1', '/depthwise/ccat']
Populating: cluster_/depthwise/filter0 with ['/depthwise/filter0/sieve_eot', '/depthwise/filter0/wr_req', '/depthwise/filter0/dreg', 'cluster_/depthwise/filter0/when', '/depthwise/filter0/const0', '/depthwise/filter0/ccat0', '/depthwise/filter0/const1', '/depthwise/filter0/replicate', '/depthwise/filter0/flatten0', '/depthwise/filter0/qrange', '/depthwise/filter0/flatten1', 'cluster_/depthwise/filter0/sdp', 'cluster_/depthwise/filter0/czip', '/depthwise/filter0/reorder', '/depthwise/filter0/sieve_0', '/depthwise/filter0/sieve_1', '/depthwise/filter0/sieve_2', 'cluster_/depthwise/filter0/dot0', '/depthwise/filter0/ccat1', 'cluster_/depthwise/filter0/dot1', 'cluster_/depthwise/filter0/dot2', '/depthwise/filter0/cast']
Populating: cluster_/depthwise/filter0/when w

'Digraph.gv.pdf'

## Stop here

In [15]:
stop here

SyntaxError: invalid syntax (4067800170.py, line 1)

In [None]:
type(top)

In [None]:
top

In [None]:
top.out_port_intfs

In [None]:
top.dout

In [None]:
top.out_ports

In [None]:
top.params

In [None]:
top.basename

In [None]:
top.const_args

In [None]:
top.hierarchical

In [None]:
top.in_port_intfs

In [None]:
top.in_ports

In [None]:
top.inputs

In [None]:
top.in_port_intfs

In [None]:
type(top.local_intfs[0])

In [None]:
producers = []
consumers = []
for ifc in top.local_intfs:
    print(f'producer: {ifc.producer}')
    print(f'producer type: {type(ifc.producer)}')
    producers.append(ifc.producer)
    print(f'consumers: {ifc.consumers}')
    print(f'consumer type: {type(ifc.consumers[0])}')
    consumers.extend(ifc.consumers)

In [None]:
top.name

In [None]:
top.out_port_intfs

In [None]:
top.outnames

In [None]:
top.parent

In [None]:
top.trace

## Traverse the tree

In [None]:
producers[0].name

In [None]:
drv0 = find('/drv0')

In [None]:
type(drv0)

In [None]:
drv0.out_ports

In [None]:
pretty(producers[0].gear)