Skip to content

Commit

Permalink
[mypyc] Implement CallC IR (#8880)
Browse files Browse the repository at this point in the history
Relates to mypyc/mypyc#709

This PR adds a new IR op CallC to replace some PrimitiveOp that simply calls a C 
function. To demonstrate this prototype, str.join primitive is now switched from 
PrimitiveOp to CallC, with identical generated C code.
  • Loading branch information
TH3CHARLie committed May 27, 2020
1 parent b3d4398 commit f94fc7e
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 9 deletions.
5 changes: 4 additions & 1 deletion mypyc/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Value, ControlOp,
BasicBlock, OpVisitor, Assign, LoadInt, LoadErrorValue, RegisterOp, Goto, Branch, Return, Call,
Environment, Box, Unbox, Cast, Op, Unreachable, TupleGet, TupleSet, GetAttr, SetAttr,
LoadStatic, InitStatic, PrimitiveOp, MethodCall, RaiseStandardError,
LoadStatic, InitStatic, PrimitiveOp, MethodCall, RaiseStandardError, CallC
)


Expand Down Expand Up @@ -195,6 +195,9 @@ def visit_cast(self, op: Cast) -> GenAndKill:
def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill:
return self.visit_register_op(op)

def visit_call_c(self, op: CallC) -> GenAndKill:
return self.visit_register_op(op)


class DefinedVisitor(BaseAnalysisVisitor):
"""Visitor for finding defined registers.
Expand Down
7 changes: 6 additions & 1 deletion mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
OpVisitor, Goto, Branch, Return, Assign, LoadInt, LoadErrorValue, GetAttr, SetAttr,
LoadStatic, InitStatic, TupleGet, TupleSet, Call, IncRef, DecRef, Box, Cast, Unbox,
BasicBlock, Value, MethodCall, PrimitiveOp, EmitterInterface, Unreachable, NAMESPACE_STATIC,
NAMESPACE_TYPE, NAMESPACE_MODULE, RaiseStandardError
NAMESPACE_TYPE, NAMESPACE_MODULE, RaiseStandardError, CallC
)
from mypyc.ir.rtypes import RType, RTuple
from mypyc.ir.func_ir import FuncIR, FuncDecl, FUNC_STATICMETHOD, FUNC_CLASSMETHOD
Expand Down Expand Up @@ -415,6 +415,11 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
self.emitter.emit_line('PyErr_SetNone(PyExc_{});'.format(op.class_name))
self.emitter.emit_line('{} = 0;'.format(self.reg(op)))

def visit_call_c(self, op: CallC) -> None:
dest = self.get_dest_assign(op)
args = ', '.join(self.reg(arg) for arg in op.args)
self.emitter.emit_line("{}{}({});".format(dest, op.function_name, args))

# Helpers

def label(self, label: BasicBlock) -> str:
Expand Down
29 changes: 29 additions & 0 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,31 @@ def accept(self, visitor: 'OpVisitor[T]') -> T:
return visitor.visit_raise_standard_error(self)


class CallC(RegisterOp):
"""ret = func_call(arg0, arg1, ...)
A call to a C function
"""

error_kind = ERR_MAGIC

def __init__(self, function_name: str, args: List[Value], ret_type: RType, line: int) -> None:
super().__init__(line)
self.function_name = function_name
self.args = args
self.type = ret_type

def to_str(self, env: Environment) -> str:
args_str = ', '.join(env.format('%r', arg) for arg in self.args)
return env.format('%r = %s(%s)', self, self.function_name, args_str)

def sources(self) -> List[Value]:
return self.args

def accept(self, visitor: 'OpVisitor[T]') -> T:
return visitor.visit_call_c(self)


@trait
class OpVisitor(Generic[T]):
"""Generic visitor over ops (uses the visitor design pattern)."""
Expand Down Expand Up @@ -1228,6 +1253,10 @@ def visit_unbox(self, op: Unbox) -> T:
def visit_raise_standard_error(self, op: RaiseStandardError) -> T:
raise NotImplementedError

@abstractmethod
def visit_call_c(self, op: CallC) -> T:
raise NotImplementedError


# TODO: Should this live somewhere else?
LiteralsMap = Dict[Tuple[Type[object], Union[int, float, str, bytes, complex]], str]
Expand Down
50 changes: 47 additions & 3 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,22 @@
BasicBlock, Environment, Op, LoadInt, Value, Register,
Assign, Branch, Goto, Call, Box, Unbox, Cast, GetAttr,
LoadStatic, MethodCall, PrimitiveOp, OpDescription, RegisterOp,
NAMESPACE_TYPE, NAMESPACE_MODULE, LoadErrorValue,
NAMESPACE_TYPE, NAMESPACE_MODULE, LoadErrorValue, CallC
)
from mypyc.ir.rtypes import (
RType, RUnion, RInstance, optional_value_type, int_rprimitive, float_rprimitive,
bool_rprimitive, list_rprimitive, str_rprimitive, is_none_rprimitive, object_rprimitive
bool_rprimitive, list_rprimitive, str_rprimitive, is_none_rprimitive, object_rprimitive,
void_rtype
)
from mypyc.ir.func_ir import FuncDecl, FuncSignature
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
from mypyc.common import (
FAST_ISINSTANCE_MAX_SUBCLASSES, MAX_LITERAL_SHORT_INT,
)
from mypyc.primitives.registry import binary_ops, unary_ops, method_ops, func_ops
from mypyc.primitives.registry import (
binary_ops, unary_ops, method_ops, func_ops,
c_method_call_ops, CFunctionDescription
)
from mypyc.primitives.list_ops import (
list_extend_op, list_len_op, new_list_op
)
Expand Down Expand Up @@ -644,6 +648,41 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) ->
value = self.primitive_op(bool_op, [value], value.line)
self.add(Branch(value, true, false, Branch.BOOL_EXPR))

def call_c(self,
function_name: str,
args: List[Value],
line: int,
result_type: Optional[RType]) -> Value:
# handle void function via singleton RVoid instance
ret_type = void_rtype if result_type is None else result_type
target = self.add(CallC(function_name, args, ret_type, line))
return target

def matching_call_c(self,
candidates: List[CFunctionDescription],
args: List[Value],
line: int,
result_type: Optional[RType] = None) -> Optional[Value]:
# TODO: this function is very similar to matching_primitive_op
# we should remove the old one or refactor both them into only as we move forward
matching = None # type: Optional[CFunctionDescription]
for desc in candidates:
if len(desc.arg_types) != len(args):
continue
if all(is_subtype(actual.type, formal)
for actual, formal in zip(args, desc.arg_types)):
if matching:
assert matching.priority != desc.priority, 'Ambiguous:\n1) %s\n2) %s' % (
matching, desc)
if desc.priority > matching.priority:
matching = desc
else:
matching = desc
if matching:
target = self.call_c(matching.c_function_name, args, line, result_type)
return target
return None

# Internal helpers

def decompose_union_helper(self,
Expand Down Expand Up @@ -728,6 +767,11 @@ def translate_special_method_call(self,
Return None if no translation found; otherwise return the target register.
"""
ops = method_ops.get(name, [])
call_c_ops_candidates = c_method_call_ops.get(name, [])
call_c_op = self.matching_call_c(call_c_ops_candidates, [base_reg] + args, line,
result_type=result_type)
if call_c_op is not None:
return call_c_op
return self.matching_primitive_op(ops, [base_reg] + args, line, result_type=result_type)

def translate_eq_cmp(self,
Expand Down
23 changes: 22 additions & 1 deletion mypyc/primitives/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,20 @@
optimized implementations of all ops.
"""

from typing import Dict, List, Optional
from typing import Dict, List, Optional, NamedTuple

from mypyc.ir.ops import (
OpDescription, EmitterInterface, EmitCallback, StealsDescription, short_name
)
from mypyc.ir.rtypes import RType, bool_rprimitive

CFunctionDescription = NamedTuple(
'CFunctionDescription', [('name', str),
('arg_types', List[RType]),
('result_type', Optional[RType]),
('c_function_name', str),
('error_kind', int),
('priority', int)])

# Primitive binary ops (key is operator such as '+')
binary_ops = {} # type: Dict[str, List[OpDescription]]
Expand All @@ -58,6 +65,8 @@
# Primitive ops for reading module attributes (key is name such as 'builtins.None')
name_ref_ops = {} # type: Dict[str, OpDescription]

c_method_call_ops = {} # type: Dict[str, List[CFunctionDescription]]


def simple_emit(template: str) -> EmitCallback:
"""Construct a simple PrimitiveOp emit callback function.
Expand Down Expand Up @@ -312,6 +321,18 @@ def custom_op(arg_types: List[RType],
emit, steals, is_borrowed, 0)


def c_method_op(name: str,
arg_types: List[RType],
result_type: Optional[RType],
c_function_name: str,
error_kind: int,
priority: int = 1) -> None:
ops = c_method_call_ops.setdefault(name, [])
desc = CFunctionDescription(name, arg_types, result_type,
c_function_name, error_kind, priority)
ops.append(desc)


# Import various modules that set up global state.
import mypyc.primitives.int_ops # noqa
import mypyc.primitives.str_ops # noqa
Expand Down
8 changes: 5 additions & 3 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
from mypyc.primitives.registry import (
func_op, binary_op, simple_emit, name_ref_op, method_op, call_emit, name_emit,
c_method_op
)


Expand All @@ -33,12 +34,13 @@
emit=call_emit('PyUnicode_Concat'))

# str.join(obj)
method_op(
c_method_op(
name='join',
arg_types=[str_rprimitive, object_rprimitive],
result_type=str_rprimitive,
error_kind=ERR_MAGIC,
emit=call_emit('PyUnicode_Join'))
c_function_name='PyUnicode_Join',
error_kind=ERR_MAGIC
)

# str[index] (for an int index)
method_op(
Expand Down
12 changes: 12 additions & 0 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -3383,3 +3383,15 @@ L0:
r5 = None
return r5

[case testCallCWithStrJoin]
from typing import List
def f(x: str, y: List[str]) -> str:
return x.join(y)
[out]
def f(x, y):
x :: str
y :: list
r0 :: str
L0:
r0 = PyUnicode_Join(x, y)
return r0

0 comments on commit f94fc7e

Please sign in to comment.