Skip to content

Commit

Permalink
add docs for connect
Browse files Browse the repository at this point in the history
Signed-off-by: zhongzc <zhongzc_arch@outlook.com>
  • Loading branch information
zhongzc committed Nov 22, 2020
1 parent 22b97bc commit 4a8caf0
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 66 deletions.
109 changes: 80 additions & 29 deletions py_hcl/core/stmt/connect.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,46 @@
"""
Implement connection between two PyHCL expressions.
Examples
--------
>>> from py_hcl import *
Connect literal to output:
>>> class _(Module):
... io = IO(o=Output(U.w(5)))
... io.o <<= U(10)
Connect input to output:
>>> class _(Module):
... io = IO(i=Input(U.w(8)), o=Output(U.w(5)))
... io.o <<= io.i
Connect wire to output and connect input to wire:
>>> class _(Module):
... io = IO(i=Input(U.w(8)), o=Output(U.w(5)))
... w = Wire(U.w(6))
... io.o <<= w
... w <<= io.i
Connection with wrong direction
>>> class _(Module):
... io = IO(i=Input(U.w(8)))
... lit = U(8)
... lit <<= io.i
Traceback (most recent call last):
...
py_hcl.core.stmt.error.StatementError: Connection statement with unexpected \
direction.
"""

import logging
from enum import Enum

Expand Down Expand Up @@ -30,14 +73,31 @@ def __init__(self, left, right):
connector = op_register('<<=')


def check_connect_direction(f):
def _(left: HclType, right: HclType):
if left.variable_type not in (VariableType.LOCATION,
VariableType.ASSIGNABLE_VALUE):
direction = left.variable_type
raise StatementError.connect_direction_error(
f'The lhs of connection statement can not be a {direction}')
if right.variable_type not in (VariableType.VALUE,
VariableType.ASSIGNABLE_VALUE):
direction = right.variable_type
raise StatementError.connect_direction_error(
f'The rhs of connection statement can not be a {direction}')

return f(left, right)

return _


@connector(UIntT, UIntT)
@check_connect_direction
def _(left, right):
check_connect_dir(left, right)

if left.hcl_type.width < right.hcl_type.width:
msg = 'connect(): connecting {} to {} will truncate the bits'.format(
right.hcl_type, left.hcl_type)
logging.warning(msg)
logging.warning(
f'connect(): connecting {right.hcl_type} to {left.hcl_type} '
f'will truncate the bits')
right = right[left.hcl_type.width - 1:0]

if left.hcl_type.width > right.hcl_type.width:
Expand All @@ -49,13 +109,12 @@ def _(left, right):


@connector(SIntT, SIntT)
@check_connect_direction
def _(left, right):
check_connect_dir(left, right)

if left.hcl_type.width < right.hcl_type.width:
logging.warning(
'connect(): connecting {} to {} will truncate the bits'.format(
right.hcl_type, left.hcl_type))
f'connect(): connecting {right.hcl_type} to {left.hcl_type} '
f'will truncate the bits')
right = right[left.hcl_type.width - 1:0].to_sint()

if left.hcl_type.width > right.hcl_type.width:
Expand All @@ -67,37 +126,38 @@ def _(left, right):


@connector(UIntT, SIntT)
@check_connect_direction
def _(left, right):
msg = 'connect(): connecting SInt to UInt causes auto-conversion'
logging.warning(msg)
logging.warning(
'connect(): connecting SInt to UInt will cause auto-conversion')

if left.hcl_type.width < right.hcl_type.width:
logging.warning(
'connect(): connecting {} to {} will truncate the bits'.format(
right.hcl_type, left.hcl_type))
f'connect(): connect {right.hcl_type} to {left.hcl_type} '
f'will truncate the bits')
return op_apply('<<=')(left, right[left.hcl_type.width - 1:0])

return op_apply('<<=')(left, right.to_uint())


@connector(SIntT, UIntT)
@check_connect_direction
def _(left, right):
msg = 'connect(): connecting UInt to SInt causes auto-conversion'
logging.warning(msg)
logging.warning(
'connect(): connecting UInt to SInt will cause auto-conversion')

if left.hcl_type.width < right.hcl_type.width:
logging.warning(
'connect(): connecting {} to {} will truncate the bits'.format(
right.hcl_type, left.hcl_type))
f'connect(): connecting {right.hcl_type} to {left.hcl_type} '
f'will truncate the bits')
right = right[left.hcl_type.width - 1:0]

return op_apply('<<=')(left, right.to_sint())


@connector(BundleT, BundleT)
@check_connect_direction
def _(left, right):
check_connect_dir(left, right)

# TODO: Accurate Error Message
dir_and_types = right.hcl_type.fields
keys = dir_and_types.keys()
Expand All @@ -115,9 +175,8 @@ def _(left, right):


@connector(VectorT, VectorT)
@check_connect_direction
def _(left, right):
check_connect_dir(left, right)

# TODO: Accurate Error Message
assert left.hcl_type.size == right.hcl_type.size

Expand All @@ -130,11 +189,3 @@ def _(left, right):
@connector(HclType, HclType)
def _(_0, _1):
raise StatementError.connect_type_error(_0, _1)


def check_connect_dir(left, right):
# TODO: Accurate Error Message
assert left.variable_type in (VariableType.LOCATION,
VariableType.ASSIGNABLE_VALUE)
assert right.variable_type in (VariableType.VALUE,
VariableType.ASSIGNABLE_VALUE)
12 changes: 11 additions & 1 deletion py_hcl/core/stmt/error/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@ def set_up():
'ConnectTypeError': {
'code': 301,
'value':
StatementError('connect statement contains unexpected type')
StatementError('Connect statement contains unexpected type.')
},
'ConnectDirectionError': {
'code':
302,
'value':
StatementError('Connection statement with unexpected direction.')
}
})

Expand All @@ -29,5 +35,9 @@ def connect_type_error(*args):
msg = 'connect(): unsupported connect type: {}'.format(ts)
return StatementError.err('ConnectTypeError', msg)

@staticmethod
def connect_direction_error(msg):
return StatementError.err('ConnectDirectionError', msg)


set_up()
8 changes: 0 additions & 8 deletions py_hcl/transformer/pyhcl_to_firrtl/context.py

This file was deleted.

36 changes: 18 additions & 18 deletions py_hcl/transformer/pyhcl_to_firrtl/conv_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from multipledispatch import dispatch

from py_hcl.transformer.pyhcl_to_firrtl.context import Context
from py_hcl.transformer.pyhcl_to_firrtl.global_context import GlobalContext
from py_hcl.transformer.pyhcl_to_firrtl.conv_port import ports_to_bundle_type
from py_hcl.transformer.pyhcl_to_firrtl.conv_type import convert_type
from py_hcl.transformer.pyhcl_to_firrtl.utils import build_io_name, get_io_obj
Expand Down Expand Up @@ -35,9 +35,9 @@


def convert_expr_by_id(expr_id: int):
obj = Context.expr_table[expr_id]
if id(obj) in Context.expr_obj_id_to_ref:
return [], Context.expr_obj_id_to_ref[id(obj)]
obj = GlobalContext.expr_table[expr_id]
if id(obj) in GlobalContext.expr_obj_id_to_ref:
return [], GlobalContext.expr_obj_id_to_ref[id(obj)]

return convert_expr(obj)

Expand Down Expand Up @@ -120,22 +120,22 @@ def convert_expr_op(expr_holder: ExprHolder, et: Extend):
typ = convert_type(expr_holder.hcl_type)
ref = copy.copy(v_ref)
ref.tpe = typ
Context.expr_obj_id_to_ref[id(expr_holder)] = ref
GlobalContext.expr_obj_id_to_ref[id(expr_holder)] = ref
return stmts, ref


@dispatch()
def convert_expr_op(expr_holder: ExprHolder, vi: VecIndex):
stmts, v_ref = convert_expr_by_id(vi.ref_expr_id)
ref = SubIndex(v_ref, vi.index, convert_type(expr_holder.hcl_type))
Context.expr_obj_id_to_ref[id(expr_holder)] = ref
GlobalContext.expr_obj_id_to_ref[id(expr_holder)] = ref
return stmts, ref


@dispatch()
def convert_expr_op(expr_holder: ExprHolder, fa: FieldAccess):
typ = convert_type(expr_holder.hcl_type)
obj = Context.expr_table[fa.ref_expr_id]
obj = GlobalContext.expr_table[fa.ref_expr_id]

def fetch_current_io_holder(obj):
current_node = obj.io_chain_head
Expand All @@ -148,21 +148,21 @@ def fetch_current_io_holder(obj):
io_holder = fetch_current_io_holder(obj)
name = build_io_name(io_holder.module_name, fa.item)
ref = Reference(name, typ)
Context.expr_obj_id_to_ref[id(expr_holder)] = ref
GlobalContext.expr_obj_id_to_ref[id(expr_holder)] = ref
return [], ref

elif isinstance(obj, ModuleInst):
stmts, b_ref = convert_expr_by_id(obj.id)
io_holder = fetch_current_io_holder(get_io_obj(obj.packed_module))
name = build_io_name(io_holder.module_name, fa.item)
ref = SubField(b_ref, name, typ)
Context.expr_obj_id_to_ref[id(expr_holder)] = ref
GlobalContext.expr_obj_id_to_ref[id(expr_holder)] = ref
return stmts, ref

else:
stmts, src_ref = convert_expr_by_id(fa.ref_expr_id)
ref = SubField(src_ref, fa.item, typ)
Context.expr_obj_id_to_ref[id(expr_holder)] = ref
GlobalContext.expr_obj_id_to_ref[id(expr_holder)] = ref
return stmts, ref


Expand All @@ -174,14 +174,14 @@ def convert_expr(expr_holder: ExprHolder):
@dispatch()
def convert_expr(slit: SLiteral):
ft = SIntLiteral(slit.value, convert_type(slit.hcl_type).width)
Context.expr_obj_id_to_ref[id(slit)] = ft
GlobalContext.expr_obj_id_to_ref[id(slit)] = ft
return [], ft


@dispatch()
def convert_expr(ulit: ULiteral):
ft = UIntLiteral(ulit.value, convert_type(ulit.hcl_type).width)
Context.expr_obj_id_to_ref[id(ulit)] = ft
GlobalContext.expr_obj_id_to_ref[id(ulit)] = ft
return [], ft


Expand All @@ -192,16 +192,16 @@ def convert_expr(wire: Wire):

stmt = DefWire(name, typ)
ref = Reference(name, typ)
Context.expr_obj_id_to_ref[id(wire)] = ref
GlobalContext.expr_obj_id_to_ref[id(wire)] = ref
return [stmt], ref


@dispatch()
def convert_expr(mi: ModuleInst):
if mi.module_name not in Context.modules:
if mi.module_name not in GlobalContext.modules:
from .conv_module import convert_module
convert_module(mi.packed_module)
module = Context.modules[mi.module_name]
module = GlobalContext.modules[mi.module_name]
name = NameGetter.get(mi.id)
ref = Reference(name, ports_to_bundle_type(module.ports))
stmts = [
Expand All @@ -211,7 +211,7 @@ def convert_expr(mi: ModuleInst):
Connect(SubField(ref, 'reset', UIntType(Width(1))),
Reference('reset', UIntType(Width(1))))
]
Context.expr_obj_id_to_ref[id(mi)] = ref
GlobalContext.expr_obj_id_to_ref[id(mi)] = ref
return stmts, ref


Expand All @@ -221,7 +221,7 @@ class NameGetter(object):
@classmethod
def get(cls, expr_id: int):
try:
return Context.expr_id_to_name[expr_id]
return GlobalContext.expr_id_to_name[expr_id]
except KeyError:
cls.cnt += 1
return "_T_" + str(cls.cnt)
Expand All @@ -230,5 +230,5 @@ def get(cls, expr_id: int):
def save_node_ref(op_ir, name, tpe, obj_id):
stmt = DefNode(name, op_ir)
ref = Reference(name, tpe)
Context.expr_obj_id_to_ref[obj_id] = ref
GlobalContext.expr_obj_id_to_ref[obj_id] = ref
return stmt, ref
6 changes: 3 additions & 3 deletions py_hcl/transformer/pyhcl_to_firrtl/conv_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from py_hcl.transformer.pyhcl_to_firrtl.context import Context
from py_hcl.transformer.pyhcl_to_firrtl.global_context import GlobalContext
from py_hcl.transformer.pyhcl_to_firrtl.conv_port import convert_ports
from py_hcl.transformer.pyhcl_to_firrtl.conv_stmt import convert_stmt
from py_hcl.transformer.pyhcl_to_firrtl.utils import build_reserve_name, \
Expand All @@ -12,7 +12,7 @@


def convert_module(packed_module: PackedModule):
Context.expr_id_to_name.update(
GlobalContext.expr_id_to_name.update(
flatten_named_expr_chain(packed_module.named_expr_chain))

name = packed_module.name
Expand All @@ -23,7 +23,7 @@ def convert_module(packed_module: PackedModule):
final_stmts = [ss for s in stmts for ss in convert_stmt(s)]

module = DefModule(name, ports, Block(final_stmts))
Context.modules[name] = module
GlobalContext.modules[name] = module
return module


Expand Down
10 changes: 5 additions & 5 deletions py_hcl/transformer/pyhcl_to_firrtl/convertor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from py_hcl.transformer.pyhcl_to_firrtl.context import Context
from py_hcl.transformer.pyhcl_to_firrtl.global_context import GlobalContext
from py_hcl.transformer.pyhcl_to_firrtl.conv_module import convert_module
from py_hcl.core.module.packed_module import PackedModule
from py_hcl.firrtl_ir.stmt.defn.circuit import DefCircuit
Expand All @@ -7,10 +7,10 @@

def convert(packed_module: PackedModule):
convert_module(packed_module)
modules = list(Context.modules.values())
Context.modules.clear()
Context.expr_obj_id_to_ref.clear()
Context.expr_id_to_name.clear()
modules = list(GlobalContext.modules.values())

GlobalContext.clear()

cir = DefCircuit(packed_module.name, modules)
assert check(cir)
return cir
Loading

0 comments on commit 4a8caf0

Please sign in to comment.