Skip to content

Commit

Permalink
Merge pull request #58 from snok/sondrelg/pydantic-support
Browse files Browse the repository at this point in the history
Add Pydantic support
  • Loading branch information
sondrelg committed Jan 4, 2022
2 parents 2f5573e + 2a308fd commit 5549a5e
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 28 deletions.
37 changes: 31 additions & 6 deletions README.md
Expand Up @@ -55,23 +55,48 @@ so please only choose one.

To select one of the ranges, just specify the code in your flake8 config:

```
```ini
[flake8]
max-line-length = 80
max-complexity = 12
...
ignore = E501
select = C,E,F,W,..., TC, TC2 # or TC1
# alternatively:
# You can use 'select':
select = C,E,F..., TC, TC2 # or TC1
# or 'enable-extensions':
enable-extensions = TC, TC2 # or TC1
```

## Configuration

The plugin currently only has one setting:
These options are configurable,
and can be set in your flake8 config.

### Exempt modules

If you wish to exempt certain modules from
needing to be moved into type-checking blocks, you can specify which
modules to ignore.

- `type-checking-exempt-modules`: Specify a list of modules to ignore TC001/TC002 errors from.
Could be useful if you, e.g., want to handle big libraries, but prefer not to handle `typing` imports.
- **setting name**: `type-checking-exempt-modules`
- **type**: `list`

```ini
[flake8]
type-checking-exempt-modules: typing, typing_extensions # default []
```

### Pydantic support

If you use Pydantic models in your code, you should enable Pydantic support.
This will treat any class variable annotation as being needed during runtime.

- **name**: `type-checking-pydantic-enabled`
- **type**: `bool`
```ini
[flake8]
type-checking-pydantic-enabled: true # default false
```

## Rationale

Expand Down
35 changes: 19 additions & 16 deletions flake8_type_checking/checker.py
Expand Up @@ -38,17 +38,8 @@
class ImportVisitor(ast.NodeTransformer):
"""Map all imports outside of type-checking blocks."""

__slots__ = (
'cwd',
'exempt_imports',
'local_imports',
'remote_imports',
'import_names',
'uses',
'unwrapped_annotations',
)

def __init__(self, cwd: Path, exempt_modules: Optional[list[str]] = None) -> None:
def __init__(self, cwd: Path, pydantic_enabled: bool, exempt_modules: Optional[list[str]] = None) -> None:
self.pydantic_enabled = pydantic_enabled
self.cwd = cwd # we need to know the current directory to guess at which imports are remote and which are not

# Import patterns we want to avoid mapping
Expand Down Expand Up @@ -278,6 +269,17 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:

def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
"""Note down class names."""
if self.pydantic_enabled and node.bases:
# When pydantic support is enabled, treat any class variable
# annotation as being required during runtime.
# We need to do this, or users run the risk of guarding imports
# to resources that actually are required at runtime -- required
# because Pydantic unlike most libraries, evaluates annotations
# *at* runtime.
for element in node.body:
if isinstance(element, ast.AnnAssign):
self.visit(element.annotation)

self.class_names.add(node.name)
self.generic_visit(node)
return node
Expand All @@ -288,6 +290,7 @@ def visit_Name(self, node: ast.Name) -> ast.Name:
return node
if hasattr(node, ATTRIBUTE_PROPERTY):
self.uses[f'{node.id}.{getattr(node, ATTRIBUTE_PROPERTY)}'] = node

self.uses[node.id] = node
return node

Expand Down Expand Up @@ -429,11 +432,11 @@ class TypingOnlyImportsChecker:

def __init__(self, node: ast.Module, options: Optional[Namespace]) -> None:
self.cwd = Path(os.getcwd())
if options and hasattr(options, 'type_checking_exempt_modules'):
exempt_modules = options.type_checking_exempt_modules
else:
exempt_modules = []
self.visitor = ImportVisitor(self.cwd, exempt_modules=exempt_modules)

exempt_modules = getattr(options, 'type_checking_exempt_modules', [])
pydantic_enabled = getattr(options, 'type_checking_pydantic_enabled', False)

self.visitor = ImportVisitor(self.cwd, pydantic_enabled=pydantic_enabled, exempt_modules=exempt_modules)
self.visitor.visit(node)

self.generators = [
Expand Down
7 changes: 7 additions & 0 deletions flake8_type_checking/plugin.py
Expand Up @@ -46,6 +46,13 @@ def add_options(cls, option_manager: OptionManager) -> None:
default=[],
help='Skip TC001 and TC002 checks for specified modules or libraries.',
)
option_manager.add_option(
'--type-checking-pydantic-enabled',
action='store_true',
parse_from_config=True,
default=False,
help='Add compatibility for Pydantic models.',
)

def run(self) -> Flake8Generator:
"""Run flake8 plugin and return any relevant errors."""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = 'flake8-type-checking'
version = "1.1.2"
version = "1.2.0"
description = 'A flake8 plugin for managing type-checking imports & forward references'
homepage = 'https://github.com/snok'
repository = 'https://github.com/sondrelg/flake8-type-checking'
Expand Down
5 changes: 4 additions & 1 deletion tests/__init__.py
Expand Up @@ -17,9 +17,12 @@ def _get_error(example: str, error_code_filter: Optional[str] = None, **kwargs:
if error_code_filter:
mock_options = Mock()
mock_options.select = [error_code_filter]
# defaults
mock_options.extended_default_select = []
mock_options.enable_extensions = []
mock_options.type_checking_pydantic_enabled = False
mock_options.type_checking_exempt_modules = []
# kwarg overrides
for k, v in kwargs.items():
setattr(mock_options, k, v)
plugin = Plugin(ast.parse(example), options=mock_options)
Expand All @@ -29,4 +32,4 @@ def _get_error(example: str, error_code_filter: Optional[str] = None, **kwargs:
errors = {f'{line}:{col} {msg}' for line, col, msg, _ in plugin.run()}
if error_code_filter is None:
error_code_filter = ''
return {error for error in errors if error_code_filter in error}
return {error for error in errors if any(error_code in error for error_code in error_code_filter.split(','))}
2 changes: 1 addition & 1 deletion tests/test_errors.py
Expand Up @@ -158,7 +158,7 @@ def test_import_is_local():
def raise_value_error(*args, **kwargs):
raise ValueError('test')

visitor = ImportVisitor(REPO_ROOT)
visitor = ImportVisitor(REPO_ROOT, False)
assert visitor._import_is_local(mod) is True

patch('flake8_type_checking.checker.find_spec', raise_value_error).start()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_import_visitors.py
Expand Up @@ -10,13 +10,13 @@


def _get_remote_imports(example):
visitor = ImportVisitor(REPO_ROOT)
visitor = ImportVisitor(REPO_ROOT, False)
visitor.visit(ast.parse(example.replace('; ', '\n')))
return list(visitor.remote_imports.keys())


def _get_local_imports(example):
visitor = ImportVisitor(REPO_ROOT)
visitor = ImportVisitor(REPO_ROOT, False)
visitor.visit(ast.parse(example.replace('; ', '\n')))
return list(visitor.local_imports.keys())

Expand Down
2 changes: 1 addition & 1 deletion tests/test_name_visitor.py
Expand Up @@ -8,7 +8,7 @@


def _get_names(example: str) -> Set[str]:
visitor = ImportVisitor('fake cwd') # type: ignore
visitor = ImportVisitor('fake cwd', False) # type: ignore
visitor.visit(ast.parse(example))
return visitor.names

Expand Down
107 changes: 107 additions & 0 deletions tests/test_pydantic.py
@@ -0,0 +1,107 @@
"""
This file tests pydantic support.
See https://github.com/snok/flake8-type-checking/issues/52
for discussion on the implementation.
"""

import textwrap

import pytest

from flake8_type_checking.codes import TC002
from tests import _get_error


@pytest.mark.parametrize(
'enabled, expected',
(
[True, {'2:0 ' + TC002.format(module='decimal.Decimal')}],
[False, {'2:0 ' + TC002.format(module='decimal.Decimal')}],
),
)
def test_non_pydantic_model(enabled, expected):
"""
A class cannot be a pydantic model if it doesn't have a base class,
so we should raise the same error here in both cases.
"""
example = textwrap.dedent(
'''
from decimal import Decimal
class X:
x: Decimal
'''
)
assert _get_error(example, error_code_filter='TC001,TC002', type_checking_pydantic_enabled=enabled) == expected


def test_class_with_base_class():
"""
Whenever a class inherits from anything, we need
to assume it might be a pydantic model, for which
we need to register annotations as uses.
"""
example = textwrap.dedent(
'''
from decimal import Decimal
class X(Y):
x: Decimal
'''
)
assert _get_error(example, error_code_filter='TC001,TC002', type_checking_pydantic_enabled=True) == set()


def test_complex_pydantic_model():
"""
Test actual Pydantic models, with different annotation types.
"""
example = textwrap.dedent(
'''
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
from typing import TYPE_CHECKING
from pydantic import BaseModel, condecimal, validator
if TYPE_CHECKING:
from datetime import date
from typing import Union
def format_datetime(value: Union[str, datetime]) -> datetime:
if isinstance(value, str):
value = datetime.strptime(value, '%Y-%m-%dT%H:%M:%S.%f%z')
assert isinstance(value, datetime)
return value
class ModelBase(BaseModel):
id: int
created_at: datetime
updated_at: datetime
_format_datetime = validator('created_at', 'updated_at', pre=True, allow_reuse=True)(format_datetime)
class NestedModel(ModelBase):
z: Decimal
x: int
y: str
class FinalModel(ModelBase):
a: str
b: int
c: float
d: bool
e: date
f: NestedModel
g: condecimal(ge=Decimal(0)) = Decimal(0)
'''
)
assert _get_error(example, error_code_filter='TC001,TC002', type_checking_pydantic_enabled=True) == set()

0 comments on commit 5549a5e

Please sign in to comment.