Skip to content

Commit

Permalink
perf: levenshtein optimization (#3780)
Browse files Browse the repository at this point in the history
Optimize compile time. `levenshtein` is a hotspot because it is called
frequently during type analysis to construct exceptions (which are then
caught as part of the validation routines). This commit delays calling
`levenshtein` until the last minute, and also adds a mechanism to
`VyperException` so that hints can be constructed lazily in general.

On a couple of test contracts, overall compilation time comes down 7%.
However, measured as a portion of the time spent in the frontend alone,
compilation time comes down 20-30%. This will become important as projects
become larger (that is, many imports, of which only some functions are
actually present in codegen) and compilation time is dominated by the frontend.
  • Loading branch information
charles-cooper committed Feb 14, 2024
1 parent 8e5e1c2 commit 4b4e188
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 23 deletions.
2 changes: 1 addition & 1 deletion tests/functional/syntax/modules/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,4 +1178,4 @@ def test_ownership_decl_errors_not_swallowed(make_input_bundle):
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(UndeclaredDefinition) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "'lib2' has not been declared. "
assert e.value._message == "'lib2' has not been declared."
39 changes: 32 additions & 7 deletions tests/functional/syntax/test_for_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def foo():
""",
StructureException,
"Invalid syntax for loop iterator",
None,
"a[1]",
),
(
Expand All @@ -32,6 +33,7 @@ def bar():
""",
StructureException,
"Bound must be at least 1",
None,
"0",
),
(
Expand All @@ -44,6 +46,7 @@ def foo():
""",
StateAccessViolation,
"Bound must be a literal",
None,
"x",
),
(
Expand All @@ -55,6 +58,7 @@ def foo():
""",
StructureException,
"Please remove the `bound=` kwarg when using range with constants",
None,
"5",
),
(
Expand All @@ -66,6 +70,7 @@ def foo():
""",
StructureException,
"Bound must be at least 1",
None,
"0",
),
(
Expand All @@ -78,6 +83,7 @@ def bar():
""",
ArgumentException,
"Invalid keyword argument 'extra'",
None,
"extra=3",
),
(
Expand All @@ -89,6 +95,7 @@ def bar():
""",
StructureException,
"End must be greater than start",
None,
"0",
),
(
Expand All @@ -101,6 +108,7 @@ def bar():
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
),
(
Expand All @@ -113,6 +121,7 @@ def bar():
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
),
(
Expand All @@ -125,6 +134,7 @@ def repeat(n: uint256) -> uint256:
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"n * 10",
),
(
Expand All @@ -137,6 +147,7 @@ def bar():
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x + 1",
),
(
Expand All @@ -148,6 +159,7 @@ def bar():
""",
StructureException,
"End must be greater than start",
None,
"1",
),
(
Expand All @@ -160,6 +172,7 @@ def bar():
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
),
(
Expand All @@ -172,6 +185,7 @@ def foo():
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
),
(
Expand All @@ -184,6 +198,7 @@ def repeat(n: uint256) -> uint256:
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"n",
),
(
Expand All @@ -196,6 +211,7 @@ def foo(x: int128):
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
),
(
Expand All @@ -207,6 +223,7 @@ def bar(x: uint256):
""",
StateAccessViolation,
"Value must be a literal integer, unless a bound is specified",
None,
"x",
),
(
Expand All @@ -221,6 +238,7 @@ def foo():
""",
TypeMismatch,
"Given reference has type int128, expected uint256",
None,
"FOO",
),
(
Expand All @@ -234,6 +252,7 @@ def foo():
""",
StructureException,
"Bound must be at least 1",
None,
"FOO",
),
(
Expand All @@ -244,7 +263,8 @@ def foo():
pass
""",
UnknownType,
"No builtin or user-defined type named 'DynArra'. Did you mean 'DynArray'?",
"No builtin or user-defined type named 'DynArra'.",
"Did you mean 'DynArray'?",
"DynArra",
),
(
Expand All @@ -262,7 +282,8 @@ def foo():
pass
""",
UnknownType,
"No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?",
"No builtin or user-defined type named 'uint9'.",
"Did you mean 'uint96', or maybe 'uint8'?",
"uint9",
),
(
Expand All @@ -278,7 +299,8 @@ def foo():
pass
""",
UnknownType,
"No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?",
"No builtin or user-defined type named 'uint9'.",
"Did you mean 'uint96', or maybe 'uint8'?",
"uint9",
),
]
Expand All @@ -289,15 +311,18 @@ def foo():
f"{i:02d}: {for_code_regex.search(code).group(1)}" # type: ignore[union-attr]
f" raises {type(err).__name__}"
)
for i, (code, err, msg, src) in enumerate(fail_list)
for i, (code, err, msg, hint, src) in enumerate(fail_list)
]


@pytest.mark.parametrize("bad_code,error_type,message,source_code", fail_list, ids=fail_test_names)
def test_range_fail(bad_code, error_type, message, source_code):
@pytest.mark.parametrize(
"bad_code,error_type,message,hint,source_code", fail_list, ids=fail_test_names
)
def test_range_fail(bad_code, error_type, message, hint, source_code):
with pytest.raises(error_type) as exc_info:
compiler.compile_code(bad_code)
assert message == exc_info.value.message
assert message == exc_info.value._message
assert hint == exc_info.value.hint
assert source_code == exc_info.value.args[1].get_original_node().node_source_code


Expand Down
13 changes: 11 additions & 2 deletions vyper/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,20 @@ def with_annotation(self, *annotations):
exc.annotations = annotations
return exc

@property
def hint(self):
    """Return the hint for this exception, computing it on first access.

    `_hint` may hold either a ready-made value or a zero-argument callable
    that produces one. Some hints are expensive to compute, so a callable is
    not invoked until the hint is actually requested.
    """
    if callable(self._hint):
        # Store the computed value back on `_hint` so that repeated reads
        # (e.g. a truthiness check followed by string formatting) do not
        # re-run the expensive computation.
        self._hint = self._hint()
    return self._hint

@property
def message(self):
msg = self._message
if self._hint:
msg += f"\n\n (hint: {self._hint})"
if self.hint:
msg += f"\n\n (hint: {self.hint})"
return msg

def __str__(self):
Expand Down
10 changes: 8 additions & 2 deletions vyper/semantics/analysis/levenshtein_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Callable


def levenshtein_norm(source: str, target: str) -> float:
Expand Down Expand Up @@ -73,7 +73,13 @@ def levenshtein(source: str, target: str) -> int:
return matrix[len(source)][len(target)]


def get_levenshtein_error_suggestions(key: str, namespace: Dict[str, Any], threshold: float) -> str:
def get_levenshtein_error_suggestions(*args, **kwargs) -> Callable:
    """Return a zero-argument callable that defers the suggestion lookup.

    The positional and keyword arguments are captured as given and are only
    forwarded to `_get_levenshtein_error_suggestions` when the returned
    callable is invoked; nothing is computed at call time of this function.
    """

    def _deferred_suggestions():
        return _get_levenshtein_error_suggestions(*args, **kwargs)

    return _deferred_suggestions


def _get_levenshtein_error_suggestions(
key: str, namespace: dict[str, Any], threshold: float
) -> str:
"""
Generate an error message snippet for the suggested closest values in the provided namespace
with the shortest normalized Levenshtein distance from the given key if that distance
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ def _raise_invalid_reference(name, node):
if name in self.namespace:
_raise_invalid_reference(name, node)

suggestions_str = get_levenshtein_error_suggestions(name, t.members, 0.4)
hint = get_levenshtein_error_suggestions(name, t.members, 0.4)
raise UndeclaredDefinition(
f"Storage variable '{name}' has not been declared. {suggestions_str}", node
f"Storage variable '{name}' has not been declared.", node, hint=hint
) from None

def types_from_BinOp(self, node):
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def __setitem__(self, attr, obj):

def __getitem__(self, key):
if key not in self:
suggestions_str = get_levenshtein_error_suggestions(key, self, 0.2)
raise UndeclaredDefinition(f"'{key}' has not been declared. {suggestions_str}")
hint = get_levenshtein_error_suggestions(key, self, 0.2)
raise UndeclaredDefinition(f"'{key}' has not been declared.", hint=hint)
return super().__getitem__(key)

def __enter__(self):
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ def get_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType":
if not self.members:
raise StructureException(f"{self} instance does not have members", node)

suggestions_str = get_levenshtein_error_suggestions(key, self.members, 0.3)
raise UnknownAttribute(f"{self} has no member '{key}'. {suggestions_str}", node)
hint = get_levenshtein_error_suggestions(key, self.members, 0.3)
raise UnknownAttribute(f"{self} has no member '{key}'.", node, hint=hint)

def __repr__(self):
return self._id
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/types/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,9 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT":
keys = list(self.member_types.keys())
for i, (key, value) in enumerate(zip(node.args[0].keys, node.args[0].values)):
if key is None or key.get("id") not in members:
suggestions_str = get_levenshtein_error_suggestions(key.get("id"), members, 1.0)
hint = get_levenshtein_error_suggestions(key.get("id"), members, 1.0)
raise UnknownAttribute(
f"Unknown or duplicate struct member. {suggestions_str}", key or value
"Unknown or duplicate struct member.", key or value, hint=hint
)
expected_key = keys[i]
if key.id != expected_key:
Expand Down
5 changes: 2 additions & 3 deletions vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,9 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType:
raise InvalidType(err_msg, node)

if node.id not in namespace: # type: ignore
suggestions_str = get_levenshtein_error_suggestions(node.node_source_code, namespace, 0.3)
hint = get_levenshtein_error_suggestions(node.node_source_code, namespace, 0.3)
raise UnknownType(
f"No builtin or user-defined type named '{node.node_source_code}'. {suggestions_str}",
node,
f"No builtin or user-defined type named '{node.node_source_code}'.", node, hint=hint
) from None

typ_ = namespace[node.id]
Expand Down

0 comments on commit 4b4e188

Please sign in to comment.