Skip to content

Commit

Permalink
Increase type coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
sloria committed Jul 8, 2019
1 parent 043d932 commit 101e113
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 24 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

### 5.1.0 (unreleased)

Other changes:

- Improve typings.

### 5.0.0 (2019-07-06)

Features:
Expand Down
46 changes: 22 additions & 24 deletions environs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from pathlib import Path

import marshmallow as ma
from dotenv import load_dotenv
from dotenv.main import _walk_to_root
from dotenv.main import load_dotenv, DotEnv, _walk_to_root

__version__ = "5.0.0"
__all__ = ["EnvError", "Env"]
Expand All @@ -26,16 +25,15 @@ class EnvError(ValueError):
_PROXIED_PATTERN = re.compile(r"\s*{{\s*(\S*)\s*}}\s*")

FieldFactory = typing.Callable[..., ma.fields.Field]
Subcast = typing.Union[typing.Type, typing.Callable]


def _field2method(
field_or_factory: typing.Union[typing.Type[ma.fields.Field], FieldFactory],
method_name: str,
preprocess: typing.Callable = None,
) -> typing.Callable:
def method(
self: "Env", name: str, default: typing.Any = ma.missing, subcast: typing.Type = None, **kwargs
):
def method(self: "Env", name: str, default: typing.Any = ma.missing, subcast: Subcast = None, **kwargs):
missing = kwargs.pop("missing", None) or default
if isinstance(field_or_factory, type) and issubclass(field_or_factory, ma.fields.Field):
field = typing.cast(typing.Type[ma.fields.Field], field_or_factory)(missing=missing, **kwargs)
Expand Down Expand Up @@ -94,32 +92,32 @@ class Meta:
return type("", (ma.Schema,), attrs)


def _make_list_field(**kwargs) -> ma.fields.List:
subcast = kwargs.pop("subcast", None)
def _make_list_field(*, subcast: Subcast, **kwargs) -> ma.fields.List:
inner_field = ma.Schema.TYPE_MAPPING[subcast] if subcast else ma.fields.Field
return ma.fields.List(inner_field, **kwargs)


def _preprocess_list(value, **kwargs) -> typing.Iterable:
return value if ma.utils.is_iterable_but_not_string(value) else value.split(",")
def _preprocess_list(value: typing.Union[str, typing.Iterable], **kwargs) -> typing.Iterable:
return value if ma.utils.is_iterable_but_not_string(value) else typing.cast(str, value).split(",")


def _preprocess_dict(value: typing.Union[str, typing.Mapping], **kwargs) -> typing.Mapping:
def _preprocess_dict(
value: typing.Union[str, typing.Mapping], subcast: Subcast, **kwargs
) -> typing.Mapping[str, typing.Any]:
if isinstance(value, Mapping):
return value

subcast = kwargs.get("subcast")
return {
key.strip(): subcast(val.strip()) if subcast else val.strip()
for key, val in (item.split("=") for item in value.split(",") if value)
}


def _preprocess_json(value, **kwargs):
def _preprocess_json(value: str, **kwargs):
return pyjson.loads(value)


def _dj_db_url_parser(value, **kwargs) -> dict:
def _dj_db_url_parser(value: str, **kwargs) -> dict:
try:
import dj_database_url
except ImportError:
Expand All @@ -130,7 +128,7 @@ def _dj_db_url_parser(value, **kwargs) -> dict:
return dj_database_url.parse(value, **kwargs)


def _dj_email_url_parser(value, **kwargs) -> dict:
def _dj_email_url_parser(value: str, **kwargs) -> dict:
try:
import dj_email_url
except ImportError:
Expand All @@ -142,7 +140,7 @@ def _dj_email_url_parser(value, **kwargs) -> dict:


class URLField(ma.fields.URL):
def _serialize(self, value, attr, obj):
def _serialize(self, value: ParseResult, *args, **kwargs) -> str:
return value.geturl()

# Override deserialize rather than _deserialize because we need
Expand Down Expand Up @@ -180,15 +178,15 @@ class Env:
url=_field2method(URLField, "url"),
dj_db_url=_func2method(_dj_db_url_parser, "dj_db_url"),
dj_email_url=_func2method(_dj_email_url_parser, "dj_email_url"),
)
) # type: typing.Dict[str, typing.Callable]

def __init__(self):
self._fields = {}
self._values = {}
self._prefix = None
self._fields = {} # type: typing.Dict[str, ma.fields.Field]
self._values = {} # type: typing.Dict[str, typing.Any]
self._prefix = None # type: typing.Optional[str]
self.__parser_map__ = self.default_parser_map.copy()

def __repr__(self):
def __repr__(self) -> str:
return "<{} {}>".format(self.__class__.__name__, self._values)

__str__ = __repr__
Expand All @@ -200,7 +198,7 @@ def read_env(
stream: str = None,
verbose: bool = False,
override: bool = False,
):
) -> DotEnv:
"""Read a .env file into os.environ.
If .env is not found in the directory from which this method is called,
Expand Down Expand Up @@ -229,7 +227,7 @@ def read_env(
return load_dotenv(start, stream=stream, verbose=verbose, override=override)

@contextlib.contextmanager
def prefixed(self, prefix: str):
def prefixed(self, prefix: str) -> typing.Iterator["Env"]:
"""Context manager for parsing envvars with a common prefix."""
try:
old_prefix = self._prefix
Expand All @@ -243,7 +241,7 @@ def prefixed(self, prefix: str):
self._prefix = None
self._prefix = old_prefix

def __getattr__(self, name, **kwargs):
def __getattr__(self, name: str, **kwargs):
try:
return functools.partial(self.__parser_map__[name], self)
except KeyError:
Expand Down Expand Up @@ -271,7 +269,7 @@ def add_parser_from_field(self, name: str, field_cls: typing.Type[ma.fields.Fiel
"""Register a new parser method with name ``name``, given a marshmallow ``Field``."""
self.__parser_map__[name] = _field2method(field_cls, method_name=name)

def dump(self) -> typing.Mapping:
def dump(self) -> typing.Mapping[str, typing.Any]:
"""Dump parsed environment variables to a dictionary of simple data types (numbers
and strings).
"""
Expand Down

0 comments on commit 101e113

Please sign in to comment.