New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic support for user-defined mypy plugins #3517

Merged
merged 6 commits into from Jun 13, 2017
Jump to file or symbol
Failed to load files and symbols.
+253 −18
Diff settings

Always

Just for now

View
@@ -42,7 +42,7 @@
from mypy.stats import dump_type_stats
from mypy.types import Type
from mypy.version import __version__
from mypy.plugin import DefaultPlugin
from mypy.plugin import Plugin, DefaultPlugin, ChainedPlugin
# We need to know the location of this file to load data, but
@@ -183,7 +183,9 @@ def build(sources: List[BuildSource],
reports=reports,
options=options,
version_id=__version__,
)
plugin=DefaultPlugin(options.python_version))
manager.plugin = load_custom_plugins(manager.plugin, options, manager.errors)
try:
graph = dispatch(sources, manager)
@@ -333,6 +335,67 @@ def import_priority(imp: ImportBase, toplevel_priority: int) -> int:
return toplevel_priority
def load_custom_plugins(default_plugin: Plugin, options: Options, errors: Errors) -> Plugin:
"""Load custom plugins if any are configured.
Return a plugin that chains all custom plugins (if any) and falls
back to default_plugin.
"""
def plugin_error(message: str) -> None:
errors.report(0, 0, message)
errors.raise_error()
custom_plugins = []
for plugin_path in options.plugins:
if options.config_file:
# Plugin paths are relative to the config file location.
plugin_path = os.path.join(os.path.dirname(options.config_file), plugin_path)
errors.set_file(plugin_path, None)
if not os.path.isfile(plugin_path):
plugin_error("Can't find plugin")
plugin_dir = os.path.dirname(plugin_path)
fnam = os.path.basename(plugin_path)
if not fnam.endswith('.py'):
plugin_error("Plugin must have .py extension")
module_name = fnam[:-3]
import importlib
sys.path.insert(0, plugin_dir)
try:
m = importlib.import_module(module_name)
except Exception:
print('Error importing plugin {}\n'.format(plugin_path))
raise # Propagate to display traceback
finally:
assert sys.path[0] == plugin_dir
del sys.path[0]
if not hasattr(m, 'plugin'):
plugin_error('Plugin does not define entry point function "plugin"')
try:
plugin_type = getattr(m, 'plugin')(__version__)
except Exception:
print('Error calling the plugin(version) entry point of {}\n'.format(plugin_path))
raise # Propagate to display traceback
if not isinstance(plugin_type, type):
plugin_error(
'Type object expected as the return value of "plugin" (got {!r})'.format(
plugin_type))
if not issubclass(plugin_type, Plugin):
plugin_error(
'Return value of "plugin" must be a subclass of "mypy.plugin.Plugin"')
try:
custom_plugins.append(plugin_type(options.python_version))
except Exception:
print('Error constructing plugin instance of {}\n'.format(plugin_type.__name__))
raise # Propagate to display traceback
if not custom_plugins:
return default_plugin
else:
# Custom plugins take precendence over built-in plugins.
return ChainedPlugin(options.python_version, custom_plugins + [default_plugin])
# TODO: Get rid of all_types. It's not used except for one log message.
# Maybe we could instead publish a map from module ID to its type_map.
class BuildManager:
@@ -356,6 +419,7 @@ class BuildManager:
missing_modules: Set of modules that could not be imported encountered so far
stale_modules: Set of modules that needed to be rechecked
version_id: The current mypy version (based on commit id when possible)
plugin: Active mypy plugin(s)
"""
def __init__(self, data_dir: str,
@@ -364,7 +428,8 @@ def __init__(self, data_dir: str,
source_set: BuildSourceSet,
reports: Reports,
options: Options,
version_id: str) -> None:
version_id: str,
plugin: Plugin) -> None:
self.start_time = time.time()
self.data_dir = data_dir
self.errors = Errors(options.show_error_context, options.show_column_numbers)
@@ -384,6 +449,7 @@ def __init__(self, data_dir: str,
self.indirection_detector = TypeIndirectionVisitor()
self.stale_modules = set() # type: Set[str]
self.rechecked_modules = set() # type: Set[str]
self.plugin = plugin
def maybe_swap_for_shadow_path(self, path: str) -> str:
if (self.options.shadow_file and
@@ -1506,9 +1572,8 @@ def type_check_first_pass(self) -> None:
if self.options.semantic_analysis_only:
return
with self.wrap_context():
plugin = DefaultPlugin(self.options.python_version)
self.type_checker = TypeChecker(manager.errors, manager.modules, self.options,
self.tree, self.xpath, plugin)
self.tree, self.xpath, manager.plugin)
self.type_checker.check_first_pass()
def type_check_second_pass(self) -> bool:
View
@@ -372,7 +372,7 @@ def disallow_any_argument_type(raw_options: str) -> List[str]:
parser.parse_args(args, dummy)
config_file = dummy.config_file
if config_file is not None and not os.path.exists(config_file):
parser.error("Cannot file config file '%s'" % config_file)
parser.error("Cannot find config file '%s'" % config_file)
# Parse config file first, so command line can override.
options = Options()
@@ -605,6 +605,7 @@ def get_init_file(dir: str) -> Optional[str]:
# These two are for backwards compatibility
'silent_imports': bool,
'almost_silent': bool,
'plugins': lambda s: [p.strip() for p in s.split(',')],
}
SHARED_CONFIG_FILES = ('setup.cfg',)
View
@@ -113,6 +113,9 @@ def __init__(self) -> None:
self.debug_cache = False
self.quick_and_dirty = False
# Paths of user plugins
self.plugins = [] # type: List[str]
# Per-module options (raw)
self.per_module_options = {} # type: Dict[Pattern[str], Dict[str, object]]
View
@@ -1,4 +1,4 @@
from typing import Callable, List, Tuple, Optional, NamedTuple
from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar
from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context
from mypy.types import (
@@ -60,7 +60,7 @@
class Plugin:
"""Base class of type checker plugins.
"""Base class of all type checker plugins.
This defines a no-op plugin. Subclasses can override some methods to
provide some actual functionality.
@@ -69,8 +69,6 @@ class Plugin:
results might be cached).
"""
# TODO: Way of chaining multiple plugins
def __init__(self, python_version: Tuple[int, int]) -> None:
self.python_version = python_version
@@ -86,6 +84,46 @@ def get_method_hook(self, fullname: str) -> Optional[MethodHook]:
# TODO: metaclass / class decorator hook
T = TypeVar('T')
class ChainedPlugin(Plugin):
"""A plugin that represents a sequence of chained plugins.
Each lookup method returns the hook for the first plugin that
reports a match.
This class should not be subclassed -- use Plugin as the base class
for all plugins.
"""
# TODO: Support caching of lookup results (through a LRU cache, for example).
def __init__(self, python_version: Tuple[int, int], plugins: List[Plugin]) -> None:
"""Initialize chained plugin.
Assume that the child plugins aren't mutated (results may be cached).
"""
super().__init__(python_version)
self._plugins = plugins
def get_function_hook(self, fullname: str) -> Optional[FunctionHook]:
return self._find_hook(lambda plugin: plugin.get_function_hook(fullname))
def get_method_signature_hook(self, fullname: str) -> Optional[MethodSignatureHook]:
return self._find_hook(lambda plugin: plugin.get_method_signature_hook(fullname))
def get_method_hook(self, fullname: str) -> Optional[MethodHook]:
return self._find_hook(lambda plugin: plugin.get_method_hook(fullname))
def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]:

This comment has been minimized.

@chadrik

chadrik Jun 9, 2017

Contributor

The result of this should probably be cached based on hook-type and fullname.

@chadrik

chadrik Jun 9, 2017

Contributor

The result of this should probably be cached based on hook-type and fullname.

This comment has been minimized.

@JukkaL

JukkaL Jun 13, 2017

Collaborator

Created an issue about caching (#3533). This may need some analysis or experimentation to decide a caching strategy (e.g. unlimited cache size vs bounded cache size; maximum size of the cache) so I feel that it's better to do it separately.

@JukkaL

JukkaL Jun 13, 2017

Collaborator

Created an issue about caching (#3533). This may need some analysis or experimentation to decide a caching strategy (e.g. unlimited cache size vs bounded cache size; maximum size of the cache) so I feel that it's better to do it separately.

for plugin in self._plugins:
hook = lookup(plugin)
if hook:
return hook
return None
class DefaultPlugin(Plugin):
"""Type checker plugin that is enabled by default."""
View
@@ -13,6 +13,9 @@
from mypy.myunit import TestCase, SkipTestCaseException
root_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), '..', '..'))
def parse_test_cases(
path: str,
perform: Optional[Callable[['DataDrivenTestCase'], None]],
@@ -62,7 +65,9 @@ def parse_test_cases(
# Record an extra file needed for the test case.
arg = p[i].arg
assert arg is not None
file_entry = (join(base_path, arg), '\n'.join(p[i].data))
contents = '\n'.join(p[i].data)
contents = expand_variables(contents)
file_entry = (join(base_path, arg), contents)
if p[i].id == 'file':
files.append(file_entry)
elif p[i].id == 'outfile':
@@ -119,13 +124,15 @@ def parse_test_cases(
deleted_paths.setdefault(num, set()).add(full)
elif p[i].id == 'out' or p[i].id == 'out1':
tcout = p[i].data
if native_sep and os.path.sep == '\\':
tcout = [expand_variables(line) for line in tcout]
if os.path.sep == '\\':
tcout = [fix_win_path(line) for line in tcout]
ok = True
elif re.match(r'out[0-9]*$', p[i].id):
passnum = int(p[i].id[3:])
assert passnum > 1
output = p[i].data
output = [expand_variables(line) for line in output]
if native_sep and os.path.sep == '\\':
output = [fix_win_path(line) for line in output]
tcout2[passnum] = output
@@ -415,6 +422,10 @@ def expand_includes(a: List[str], base_path: str) -> List[str]:
return res
def expand_variables(s: str) -> str:
return s.replace('<ROOT>', root_dir)
def expand_errors(input: List[str], output: List[str], fnam: str) -> None:
"""Transform comments such as '# E: message' or
'# E:3: message' in input.
@@ -445,16 +456,17 @@ def expand_errors(input: List[str], output: List[str], fnam: str) -> None:
def fix_win_path(line: str) -> str:
r"""Changes paths to Windows paths in error messages.
r"""Changes Windows paths to Linux paths in error messages.
E.g. foo/bar.py -> foo\bar.py.
E.g. foo\bar.py -> foo/bar.py.
"""
line = line.replace(root_dir, root_dir.replace('\\', '/'))
m = re.match(r'^([\S/]+):(\d+:)?(\s+.*)', line)
if not m:
return line
else:
filename, lineno, message = m.groups()
return '{}:{}{}'.format(filename.replace('/', '\\'),
return '{}:{}{}'.format(filename.replace('\\', '/'),
lineno or '', message)
View
@@ -76,6 +76,7 @@
'check-classvar.test',
'check-enum.test',
'check-incomplete-fixture.test',
'check-custom-plugin.test',
]
@@ -261,7 +262,8 @@ def find_error_paths(self, a: List[str]) -> Set[str]:
for line in a:
m = re.match(r'([^\s:]+):\d+: error:', line)
if m:
p = m.group(1).replace('/', os.path.sep)
# Normalize to Linux paths.
p = m.group(1).replace(os.path.sep, '/')
hits.add(p)
return hits
View
@@ -15,7 +15,7 @@
from mypy.test.config import test_data_prefix, test_temp_dir
from mypy.test.data import fix_cobertura_filename
from mypy.test.data import parse_test_cases, DataDrivenTestCase
from mypy.test.helpers import assert_string_arrays_equal
from mypy.test.helpers import assert_string_arrays_equal, normalize_error_messages
from mypy.version import __version__, base_version
# Path to Python 3 interpreter
@@ -71,10 +71,12 @@ def test_python_evaluation(testcase: DataDrivenTestCase) -> None:
os.path.abspath(test_temp_dir))
if testcase.native_sep and os.path.sep == '\\':
normalized_output = [fix_cobertura_filename(line) for line in normalized_output]
normalized_output = normalize_error_messages(normalized_output)
assert_string_arrays_equal(expected_content.splitlines(), normalized_output,
'Output file {} did not match its expected output'.format(
path))
else:
out = normalize_error_messages(out)
assert_string_arrays_equal(testcase.output, out,
'Invalid output ({}, line {})'.format(
testcase.file, testcase.line))
View
@@ -8,6 +8,8 @@
from mypy.version import __version__
from mypy.options import Options
from mypy.report import Reports
from mypy.plugin import Plugin
from mypy import defaults
class GraphSuite(Suite):
@@ -42,6 +44,7 @@ def _make_manager(self) -> BuildManager:
reports=Reports('', {}),
options=Options(),
version_id=__version__,
plugin=Plugin(defaults.PYTHON3_VERSION),
)
return manager
View
@@ -89,6 +89,7 @@ def test_semanal(testcase: DataDrivenTestCase) -> None:
a += str(f).split('\n')
except CompileError as e:
a = e.messages
a = normalize_error_messages(a)
assert_string_arrays_equal(
testcase.output, a,
'Invalid semantic analyzer output ({}, line {})'.format(testcase.file,
@@ -7,7 +7,9 @@
from mypy import build
from mypy.build import BuildSource
from mypy.myunit import Suite
from mypy.test.helpers import assert_string_arrays_equal, testfile_pyversion
from mypy.test.helpers import (
assert_string_arrays_equal, testfile_pyversion, normalize_error_messages
)
from mypy.test.data import parse_test_cases, DataDrivenTestCase
from mypy.test.config import test_data_prefix, test_temp_dir
from mypy.errors import CompileError
@@ -73,6 +75,7 @@ def test_transform(testcase: DataDrivenTestCase) -> None:
a += str(f).split('\n')
except CompileError as e:
a = e.messages
a = normalize_error_messages(a)
assert_string_arrays_equal(
testcase.output, a,
'Invalid semantic analyzer output ({}, line {})'.format(testcase.file,
Oops, something went wrong.
ProTip! Use n and p to navigate between commits in a pull request.