Skip to content

Commit

Permalink
[mypyc] Add primitives for list, str and tuple slicing (#9283)
Browse files Browse the repository at this point in the history
This speeds up some microbenchmarks by between 40% (list) and 100% (str)
when I run them on Ubuntu 20.04.

Non-default strides aren't optimized, since they are fairly rare.

`a[::-1]` for lists might be worth special casing in the future. Also, once
we have primitives for `bytes`, it should also be special cased.

Fixes mypyc/mypyc#725.
  • Loading branch information
JukkaL committed Sep 27, 2020
1 parent fa538f8 commit 765acca
Show file tree
Hide file tree
Showing 22 changed files with 286 additions and 16 deletions.
7 changes: 6 additions & 1 deletion mypyc/common.py
@@ -1,5 +1,6 @@
import sys
from typing import Dict, Any
import sys

from typing_extensions import Final

Expand All @@ -22,7 +23,11 @@

# Max short int we accept as a literal is based on 32-bit platforms,
# so that we can just always emit the same code.
MAX_LITERAL_SHORT_INT = (1 << 30) - 1 # type: Final

# Maximum value for a short tagged integer.
#
# Note: Assume that the compiled code uses the same bit width as mypyc.
MAX_LITERAL_SHORT_INT = sys.maxsize >> 1 # type: Final

TOP_LEVEL_NAME = '__top_level__' # type: Final # Special function representing module top level

Expand Down
4 changes: 4 additions & 0 deletions mypyc/doc/list_operations.rst
Expand Up @@ -29,6 +29,10 @@ Get item by integer index:

* ``lst[n]``

Slicing:

* ``lst[n:m]``, ``lst[n:]``, ``lst[:m]``, ``lst[:]``

Repeat list ``n`` times:

* ``lst * n``, ``n * lst``
Expand Down
4 changes: 4 additions & 0 deletions mypyc/doc/str_operations.rst
Expand Up @@ -24,6 +24,10 @@ Indexing:

* ``s[n]`` (integer index)

Slicing:

* ``s[n:m]``, ``s[n:]``, ``s[:m]``

Comparisons:

* ``s1 == s2``, ``s1 != s2``
Expand Down
1 change: 1 addition & 0 deletions mypyc/doc/tuple_operations.rst
Expand Up @@ -20,6 +20,7 @@ Operators
---------

* ``tup[n]`` (integer index)
* ``tup[n:m]``, ``tup[n:]``, ``tup[:m]`` (slicing)

Statements
----------
Expand Down
3 changes: 3 additions & 0 deletions mypyc/irbuild/builder.py
Expand Up @@ -185,6 +185,9 @@ def py_get_attr(self, obj: Value, attr: str, line: int) -> Value:
def load_static_unicode(self, value: str) -> Value:
return self.builder.load_static_unicode(value)

def load_static_int(self, value: int) -> Value:
return self.builder.load_static_int(value)

def primitive_op(self, desc: OpDescription, args: List[Value], line: int) -> Value:
return self.builder.primitive_op(desc, args, line)

Expand Down
58 changes: 53 additions & 5 deletions mypyc/irbuild/expression.py
Expand Up @@ -15,18 +15,22 @@
)
from mypy.types import TupleType, get_proper_type

from mypyc.common import MAX_LITERAL_SHORT_INT
from mypyc.ir.ops import (
Value, TupleGet, TupleSet, BasicBlock, Assign, LoadAddress
)
from mypyc.ir.rtypes import RTuple, object_rprimitive, is_none_rprimitive, is_int_rprimitive
from mypyc.ir.rtypes import (
RTuple, object_rprimitive, is_none_rprimitive, int_rprimitive, is_int_rprimitive
)
from mypyc.ir.func_ir import FUNC_CLASSMETHOD, FUNC_STATICMETHOD
from mypyc.primitives.registry import CFunctionDescription, builtin_names
from mypyc.primitives.generic_ops import iter_op
from mypyc.primitives.misc_ops import new_slice_op, ellipsis_op, type_op
from mypyc.primitives.list_ops import list_append_op, list_extend_op
from mypyc.primitives.tuple_ops import list_tuple_op
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op
from mypyc.primitives.dict_ops import dict_new_op, dict_set_item_op
from mypyc.primitives.set_ops import new_set_op, set_add_op, set_update_op
from mypyc.primitives.str_ops import str_slice_op
from mypyc.primitives.int_ops import int_comparison_op_mapping
from mypyc.irbuild.specialize import specializers
from mypyc.irbuild.builder import IRBuilder
Expand Down Expand Up @@ -323,15 +327,59 @@ def transform_op_expr(builder: IRBuilder, expr: OpExpr) -> Value:

def transform_index_expr(builder: IRBuilder, expr: IndexExpr) -> Value:
    """Generate IR for an index expression ``base[index]``.

    Two fast paths are tried before the generic fallback:

    * A base whose type is RTuple indexed by an integer literal becomes a
      direct TupleGet.
    * A slice index is handed to try_gen_slice_op, which may produce a
      specialized slice primitive.

    Anything else is compiled as a ``__getitem__`` method call on the base.
    """
    base = builder.accept(expr.base)
    index = expr.index

    if isinstance(base.type, RTuple) and isinstance(index, IntExpr):
        return builder.add(TupleGet(base, index.value, expr.line))

    if isinstance(index, SliceExpr):
        value = try_gen_slice_op(builder, base, index)
        # None means "no specialized op available"; test for it explicitly
        # instead of relying on the truthiness of the returned Value.
        if value is not None:
            return value

    index_reg = builder.accept(index)
    return builder.gen_method_call(
        base, '__getitem__', [index_reg], builder.node_type(expr), expr.line)


def try_gen_slice_op(builder: IRBuilder, base: Value, index: SliceExpr) -> Optional[Value]:
    """Try to emit a specialized slice primitive for ``base[index]``.

    Only slices with the default stride are considered, and each bound
    must either be omitted or have an int type. This covers obj[x:y],
    obj[:x] and obj[x:] for a few types.

    Return None if a specialized op isn't available.
    """
    if index.stride:
        # Only the default stride of 1 is specialized.
        return None

    lower, upper = index.begin_index, index.end_index
    begin_type = builder.node_type(lower) if lower else int_rprimitive
    end_type = builder.node_type(upper) if upper else int_rprimitive

    # Both bounds must be int (an omitted bound is treated as int).
    if not (is_int_rprimitive(begin_type) and is_int_rprimitive(end_type)):
        return None

    begin = builder.accept(lower) if lower else builder.load_static_int(0)
    # An omitted end index is replaced with the largest short integer
    # (a sequence can't be longer).
    end = builder.accept(upper) if upper else builder.load_static_int(MAX_LITERAL_SHORT_INT)

    candidates = [list_slice_op, tuple_slice_op, str_slice_op]
    return builder.builder.matching_call_c(candidates, [base, begin, end], index.line)


def transform_conditional_expr(builder: IRBuilder, expr: ConditionalExpr) -> Value:
if_body, else_body, next = BasicBlock(), BasicBlock(), BasicBlock()

Expand Down
4 changes: 4 additions & 0 deletions mypyc/lib-rt/CPy.h
Expand Up @@ -303,6 +303,7 @@ CPyTagged CPyObject_Hash(PyObject *o);
PyObject *CPyObject_GetAttr3(PyObject *v, PyObject *name, PyObject *defl);
PyObject *CPyIter_Next(PyObject *iter);
PyObject *CPyNumber_Power(PyObject *base, PyObject *index);
PyObject *CPyObject_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);


// List operations
Expand All @@ -318,6 +319,7 @@ CPyTagged CPyList_Count(PyObject *obj, PyObject *value);
PyObject *CPyList_Extend(PyObject *o1, PyObject *o2);
PyObject *CPySequence_Multiply(PyObject *seq, CPyTagged t_size);
PyObject *CPySequence_RMultiply(CPyTagged t_size, PyObject *seq);
PyObject *CPyList_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);


// Dict operations
Expand Down Expand Up @@ -367,6 +369,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split);
PyObject *CPyStr_Append(PyObject *o1, PyObject *o2);
PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);


// Set operations
Expand All @@ -379,6 +382,7 @@ bool CPySet_Remove(PyObject *set, PyObject *key);


PyObject *CPySequenceTuple_GetItem(PyObject *tuple, CPyTagged index);
PyObject *CPySequenceTuple_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);


// Exception operations
Expand Down
17 changes: 17 additions & 0 deletions mypyc/lib-rt/generic_ops.c
Expand Up @@ -40,3 +40,20 @@ PyObject *CPyNumber_Power(PyObject *base, PyObject *index)
{
return PyNumber_Power(base, index, Py_None);
}

// Generic implementation of obj[start:end]: box both tagged integers,
// build a slice object from them and index obj with it.
//
// Returns a new reference to the result, or NULL on failure.
PyObject *CPyObject_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) {
    PyObject *start_obj = CPyTagged_AsObject(start);
    PyObject *end_obj = CPyTagged_AsObject(end);
    if (unlikely(start_obj == NULL || end_obj == NULL)) {
        // One of the two conversions may have succeeded; release it so the
        // error path doesn't leak a reference.
        Py_XDECREF(start_obj);
        Py_XDECREF(end_obj);
        return NULL;
    }
    PyObject *slice = PySlice_New(start_obj, end_obj, NULL);
    // Our references to the bounds are no longer needed, whether or not
    // the slice object was created.
    Py_DECREF(start_obj);
    Py_DECREF(end_obj);
    if (unlikely(slice == NULL)) {
        return NULL;
    }
    PyObject *result = PyObject_GetItem(obj, slice);
    Py_DECREF(slice);
    return result;
}
16 changes: 16 additions & 0 deletions mypyc/lib-rt/list_ops.c
Expand Up @@ -123,3 +123,19 @@ PyObject *CPySequence_Multiply(PyObject *seq, CPyTagged t_size) {
PyObject *CPySequence_RMultiply(CPyTagged t_size, PyObject *seq) {
return CPySequence_Multiply(seq, t_size);
}

// list[start:end]. Exact list objects with short tagged bounds are sliced
// directly with PyList_GetSlice; every other case is delegated to the
// generic CPyObject_GetSlice.
PyObject *CPyList_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) {
    if (unlikely(!PyList_CheckExact(obj)
                 || !CPyTagged_CheckShort(start) || !CPyTagged_CheckShort(end))) {
        return CPyObject_GetSlice(obj, start, end);
    }

    Py_ssize_t size = PyList_GET_SIZE(obj);
    Py_ssize_t lo = CPyTagged_ShortAsSsize_t(start);
    Py_ssize_t hi = CPyTagged_ShortAsSsize_t(end);
    // Negative bounds are offsets from the end of the list.
    if (lo < 0) {
        lo += size;
    }
    if (hi < 0) {
        hi += size;
    }
    return PyList_GetSlice(obj, lo, hi);
}
2 changes: 1 addition & 1 deletion mypyc/lib-rt/setup.py
Expand Up @@ -17,7 +17,7 @@
version='0.1',
ext_modules=[Extension(
'test_capi',
['test_capi.cc', 'init.c', 'int_ops.c', 'list_ops.c', 'exc_ops.c'],
['test_capi.cc', 'init.c', 'int_ops.c', 'list_ops.c', 'exc_ops.c', 'generic_ops.c'],
depends=['CPy.h', 'mypyc_util.h', 'pythonsupport.h'],
extra_compile_args=['-Wno-unused-function', '-Wno-sign-compare'] + compile_args,
library_dirs=['../external/googletest/make'],
Expand Down
22 changes: 22 additions & 0 deletions mypyc/lib-rt/str_ops.c
Expand Up @@ -58,3 +58,25 @@ PyObject *CPyStr_Append(PyObject *o1, PyObject *o2) {
PyUnicode_Append(&o1, o2);
return o1;
}

// str[start:end]. Exact str objects with short tagged bounds are sliced
// directly with PyUnicode_Substring; every other case is delegated to the
// generic CPyObject_GetSlice.
PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) {
    if (unlikely(!PyUnicode_CheckExact(obj)
                 || !CPyTagged_CheckShort(start) || !CPyTagged_CheckShort(end))) {
        return CPyObject_GetSlice(obj, start, end);
    }

    Py_ssize_t lo = CPyTagged_ShortAsSsize_t(start);
    Py_ssize_t hi = CPyTagged_ShortAsSsize_t(end);
    // Negative bounds are offsets from the end of the string. The result is
    // clamped at 0 so PyUnicode_Substring never receives a negative index.
    if (lo < 0) {
        lo += PyUnicode_GET_LENGTH(obj);
        lo = lo < 0 ? 0 : lo;
    }
    if (hi < 0) {
        hi += PyUnicode_GET_LENGTH(obj);
        hi = hi < 0 ? 0 : hi;
    }
    return PyUnicode_Substring(obj, lo, hi);
}
16 changes: 16 additions & 0 deletions mypyc/lib-rt/tuple_ops.c
Expand Up @@ -29,3 +29,19 @@ PyObject *CPySequenceTuple_GetItem(PyObject *tuple, CPyTagged index) {
return NULL;
}
}

// tuple[start:end]. Exact tuple objects with short tagged bounds are sliced
// directly with PyTuple_GetSlice; every other case is delegated to the
// generic CPyObject_GetSlice.
PyObject *CPySequenceTuple_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) {
    if (unlikely(!PyTuple_CheckExact(obj)
                 || !CPyTagged_CheckShort(start) || !CPyTagged_CheckShort(end))) {
        return CPyObject_GetSlice(obj, start, end);
    }

    Py_ssize_t size = PyTuple_GET_SIZE(obj);
    Py_ssize_t lo = CPyTagged_ShortAsSsize_t(start);
    Py_ssize_t hi = CPyTagged_ShortAsSsize_t(end);
    // Negative bounds are offsets from the end of the tuple.
    if (lo < 0) {
        lo += size;
    }
    if (hi < 0) {
        hi += size;
    }
    return PyTuple_GetSlice(obj, lo, hi);
}
8 changes: 8 additions & 0 deletions mypyc/primitives/list_ops.py
Expand Up @@ -128,3 +128,11 @@ def emit_len(emitter: EmitterInterface, args: List[str], dest: str) -> None:
emitter.emit_declaration('Py_ssize_t %s;' % temp)
emitter.emit_line('%s = PyList_GET_SIZE(%s);' % (temp, args[0]))
emitter.emit_line('%s = CPyTagged_ShortFromSsize_t(%s);' % (dest, temp))


# list[begin:end]
#
# Calls the C helper CPyList_GetSlice with the list and two int bounds.
list_slice_op = c_custom_op(
    arg_types=[list_rprimitive, int_rprimitive, int_rprimitive],
    return_type=object_rprimitive,
    c_function_name='CPyList_GetSlice',
    error_kind=ERR_MAGIC,
)
5 changes: 3 additions & 2 deletions mypyc/primitives/registry.py
Expand Up @@ -151,7 +151,8 @@ def custom_op(arg_types: List[RType],
format_str: Optional[str] = None,
steals: StealsDescription = False,
is_borrowed: bool = False,
is_var_arg: bool = False) -> OpDescription:
is_var_arg: bool = False,
priority: int = 1) -> OpDescription:
"""Create a one-off op that can't be automatically generated from the AST.
Note that if the format_str argument is not provided, then a
Expand All @@ -174,7 +175,7 @@ def custom_op(arg_types: List[RType],
typename)
assert format_str is not None
return OpDescription('<custom>', arg_types, result_type, is_var_arg, error_kind, format_str,
emit, steals, is_borrowed, 0)
emit, steals, is_borrowed, priority)


def c_method_op(name: str,
Expand Down
8 changes: 7 additions & 1 deletion mypyc/primitives/str_ops.py
Expand Up @@ -79,9 +79,15 @@
error_kind=ERR_MAGIC,
steals=[True, False])


unicode_compare = c_custom_op(
arg_types=[str_rprimitive, str_rprimitive],
return_type=c_int_rprimitive,
c_function_name='PyUnicode_Compare',
error_kind=ERR_NEVER)

# str[begin:end]
#
# Calls the C helper CPyStr_GetSlice with the string and two int bounds.
str_slice_op = c_custom_op(
    arg_types=[str_rprimitive, int_rprimitive, int_rprimitive],
    return_type=object_rprimitive,
    c_function_name='CPyStr_GetSlice',
    error_kind=ERR_MAGIC,
)
11 changes: 8 additions & 3 deletions mypyc/primitives/tuple_ops.py
Expand Up @@ -8,9 +8,7 @@
from mypyc.ir.rtypes import (
tuple_rprimitive, int_rprimitive, list_rprimitive, object_rprimitive, c_pyssize_t_rprimitive
)
from mypyc.primitives.registry import (
c_method_op, c_function_op, c_custom_op
)
from mypyc.primitives.registry import c_method_op, c_function_op, c_custom_op


# tuple[index] (for an int index)
Expand Down Expand Up @@ -45,3 +43,10 @@
return_type=tuple_rprimitive,
c_function_name='PySequence_Tuple',
error_kind=ERR_MAGIC)

# tuple[begin:end]
#
# Calls the C helper CPySequenceTuple_GetSlice with the tuple and two int
# bounds.
tuple_slice_op = c_custom_op(
    arg_types=[tuple_rprimitive, int_rprimitive, int_rprimitive],
    return_type=object_rprimitive,
    c_function_name='CPySequenceTuple_GetSlice',
    error_kind=ERR_MAGIC,
)
9 changes: 9 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Expand Up @@ -45,6 +45,9 @@ def __le__(self, n: int) -> bool: pass
def __ge__(self, n: int) -> bool: pass

class str:
@overload
def __init__(self) -> None: pass
@overload
def __init__(self, x: object) -> None: pass
def __add__(self, x: str) -> str: pass
def __eq__(self, x: object) -> bool: pass
Expand All @@ -53,7 +56,10 @@ def __lt__(self, x: str) -> bool: ...
def __le__(self, x: str) -> bool: ...
def __gt__(self, x: str) -> bool: ...
def __ge__(self, x: str) -> bool: ...
@overload
def __getitem__(self, i: int) -> str: pass
@overload
def __getitem__(self, i: slice) -> str: pass
def __contains__(self, item: str) -> bool: pass
def __iter__(self) -> Iterator[str]: ...
def split(self, sep: Optional[str] = None, max: Optional[int] = None) -> List[str]: pass
Expand Down Expand Up @@ -89,7 +95,10 @@ def __init__(self, o: object = ...) -> None: ...

class tuple(Generic[T_co], Sequence[T_co], Iterable[T_co]):
def __init__(self, i: Iterable[T_co]) -> None: pass
@overload
def __getitem__(self, i: int) -> T_co: pass
@overload
def __getitem__(self, i: slice) -> Tuple[T_co, ...]: pass
def __len__(self) -> int: pass
def __iter__(self) -> Iterator[T_co]: ...
def __contains__(self, item: object) -> int: ...
Expand Down

0 comments on commit 765acca

Please sign in to comment.