Skip to content

Commit

Permalink
Now route parameters are required if not at the end of string
Browse files Browse the repository at this point in the history
The mypy plugin was updated to mark all route parameters as required
unless they are at the end of the string.
Route parsing now has a dedicated helper class.
  • Loading branch information
sanjacob committed Jan 11, 2024
1 parent 2bbbc70 commit 101bb83
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 14 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [1.2.2] - 2024-01-11

### Changed
- mypy plugin now treats most route params as required
- mypy route parsing slightly improved

## [1.2.1] - 2024-01-06

### Added
- New mypy plugin for route parameter type checking support

## [1.2.0] - 2024-01-04

### Changed
Expand Down
30 changes: 30 additions & 0 deletions docs/typing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,36 @@ Therefore, errors like this will be highlighted:
expected features. Additional requests parameters in endpoint calls
are not supported yet.

**Required and optional parameters**

By design, only the last positional parameter can be optional. For instance:

::

@get('/users/{user_id}/comments/{comment_id}')
def get_comment(self, response: dict[str, str]) -> Comment:
...

client.get_comment(comment_id="...") # forgot to add user id

>>> error: Missing named argument...

.. note::
If you want to leave the starting parameters empty, you will have to
explicitly pass an empty string.

The parameter can only be marked optional if the route ends on that
parameter. This implies the route cannot end in a slash, for instance.
The endpoint below has no optional route parameters.

::

@get('/users/{user_id}/comments/{comment_id}/likes')
def get_likes(self, response: int) -> int:
...

**Enabling the plugin**

To enable the plugin, add this to your pyproject.toml, or check the
`mypy_config`_ documentation if you are using a different file format.

Expand Down
50 changes: 50 additions & 0 deletions tests/test_mypy_plugin.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
- case: mypy_plugin_no_placeholders
main: |
from tiny_api_client import get, api_client
@api_client('https://api.example.org')
class MyClient:
@get('/users')
def get_users(self, response: list[str]) -> list[str]:
return response
client = MyClient()
client.get_users()
env:
- PYTHONPATH=$(pwd)/../

- case: mypy_plugin_correct_route_param
main: |
from tiny_api_client import get, api_client
Expand All @@ -13,6 +28,23 @@
env:
- PYTHONPATH=$(pwd)/../


- case: mypy_plugin_multiple_route_params
main: |
from tiny_api_client import get, api_client
@api_client('https://api.example.org')
class MyClient:
@get('/users/{user_id}/comments/{comment_id}')
def get_comment(self, response: str) -> str:
return response
client = MyClient()
client.get_comment(user_id='peterparker', comment_id='001')
env:
- PYTHONPATH=$(pwd)/../


- case: mypy_plugin_optional_route_param
main: |
from tiny_api_client import get, api_client
Expand Down Expand Up @@ -62,3 +94,21 @@
- PYTHONPATH=$(pwd)/../
out: |
main:10: error: Unexpected keyword argument "unknown_id" for "get_users" of "MyClient" [call-arg]
- case: mypy_plugin_non_optional_args
main: |
from tiny_api_client import get, api_client
@api_client('https://api.example.org')
class MyClient:
@get('/category/{category_id}/product/{product_id}')
def get_product(self, response: str) -> str:
return response
client = MyClient()
client.get_product(product_id='peterparker')
env:
- PYTHONPATH=$(pwd)/../
out: |
main:10: error: Missing named argument "category_id" for "get_product" of "MyClient" [call-arg]
53 changes: 39 additions & 14 deletions tiny_api_client/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,42 @@
# 02110-1301 USA

import string
from collections.abc import Callable
from typing import NamedTuple
from collections.abc import Callable, Iterable

from mypy.nodes import ARG_NAMED_OPT, StrExpr
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, StrExpr
from mypy.options import Options
from mypy.plugin import MethodContext, Plugin
from mypy.types import Type, CallableType


class RouteParser:
formatter = string.Formatter()

class FormatTuple(NamedTuple):
literal_text: str | None
field_name: str | None
format_spec: str | None
conversion: str | None

def __init__(self, route: str):
parsed = self.formatter.parse(route)
self.params = []

for t in parsed:
self.params.append(self.FormatTuple(*t))

@property
def fields(self) -> Iterable[str]:
return (x.field_name for x in self.params if x.field_name is not None)

@property
def has_optional(self) -> bool:
if not len(self.params):
return False
return self.params[-1].field_name is not None


class TinyAPIClientPlugin(Plugin):
"""Companion mypy plugin for tiny-api-client.
Expand All @@ -42,7 +70,7 @@ class TinyAPIClientPlugin(Plugin):
they were factual ones.
"""
def __init__(self, options: Options) -> None:
self._ctx_cache: dict[str, list[str]] = {}
self._ctx_cache: dict[str, RouteParser] = {}
super().__init__(options)

def get_method_hook(self, fullname: str
Expand All @@ -65,15 +93,9 @@ def _factory_callback(self, ctx: MethodContext) -> Type:
"""
if len(ctx.args) and len(ctx.args[0]):
pos = f"{ctx.context.line},{ctx.context.column}"
route_params = []
formatter = string.Formatter()
route = ctx.args[0][0]
assert isinstance(route, StrExpr)

for x in formatter.parse(route.value):
if x[1] is not None:
route_params.append(x[1])
self._ctx_cache[pos] = route_params
self._ctx_cache[pos] = RouteParser(route.value)
return ctx.default_return_type

def _decorator_callback(self, ctx: MethodContext) -> Type:
Expand All @@ -89,15 +111,18 @@ def _decorator_callback(self, ctx: MethodContext) -> Type:
assert isinstance(default_ret, CallableType)

# Modify default return type in place (probably fine)
for p in self._ctx_cache[pos]:
route_parser = self._ctx_cache[pos]
for p in route_parser.fields:
default_ret.arg_types.append(
# Since the URL is a string, type of arguments
# should also be string
# API endpoint URL params must be strings
ctx.api.named_generic_type("builtins.str", [])
)
default_ret.arg_kinds.append(ARG_NAMED_OPT)
default_ret.arg_kinds.append(ARG_NAMED)
default_ret.arg_names.append(p)

if route_parser.has_optional:
default_ret.arg_kinds[-1] = ARG_NAMED_OPT

return ctx.default_return_type


Expand Down

0 comments on commit 101bb83

Please sign in to comment.