Skip to content

Commit

Permalink
Merge pull request #3 from piccolo-orm/optional_args
Browse files Browse the repository at this point in the history
add support for t.Optional types
  • Loading branch information
dantownsend committed May 12, 2021
2 parents dfeee12 + 8fd1542 commit 1f51ba0
Show file tree
Hide file tree
Showing 3 changed files with 321 additions and 41 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
colorama==0.4.*
docstring-parser==0.7.1
typing_inspect==0.6.0; python_version < '3.8'
54 changes: 42 additions & 12 deletions targ/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
import decimal
import inspect
import json
import sys
import traceback
import typing as t

try:
from typing import get_args, get_origin # type: ignore
except ImportError:
# For Python 3.7 support
from typing_extensions import get_args, get_origin

from docstring_parser import parse, Docstring, DocstringParam # type: ignore

from .format import Color, format_text, get_underline
Expand All @@ -15,6 +22,11 @@
__VERSION__ = "0.1.9"


# If an annotation is one of these values, we will convert the string value
# to it.
CONVERTABLE_TYPES = (int, float, decimal.Decimal)


@dataclass
class Arguments:
args: t.List[str] = field(default_factory=list)
Expand Down Expand Up @@ -81,7 +93,7 @@ def arguments_description(self) -> str:
"""
output = []

for arg_name, annotation in self.annotations.items():
for arg_name, _ in self.annotations.items():
arg_description = self._get_arg_description(arg_name=arg_name)

arg_default = self._get_arg_default(arg_name=arg_name)
Expand Down Expand Up @@ -150,33 +162,51 @@ def print_help(self):
print("No args")
print("")

def _convert_arg_type(self):
pass

def call_with(self, arg_class: Arguments):
"""
Call the command function with the given arguments.
The arguments are all strings at this point, as they're come from the
command line.
"""
if arg_class.kwargs.get("help"):
self.print_help()
return

annotations = t.get_type_hints(self.command)

kwargs = {}
kwargs = arg_class.kwargs.copy()
for index, value in enumerate(arg_class.args):
key = list(annotations.keys())[index]
kwargs[key] = value

for kwarg_key, kwarg_value in arg_class.kwargs.items():
annotation = annotations.get(kwarg_key)
cleaned_kwargs = {}

for key, value in kwargs.items():
annotation = annotations.get(key)
# This only works with basic types like str at the moment.
if callable(annotation):
kwargs[kwarg_key] = annotation(kwarg_value)

for index, arg in enumerate(arg_class.args):
kwarg_key, annotation = list(annotations.items())[index]
if callable(annotation):
kwargs[kwarg_key] = annotation(arg)
if annotation in CONVERTABLE_TYPES:
value = annotation(value)
elif get_origin(annotation) is t.Union:
# t.Union is used to detect t.Optional
inner_annotations = get_args(annotation)
filtered = [i for i in inner_annotations if i is not None]
if len(filtered) == 1:
annotation = filtered[0]
if annotation in CONVERTABLE_TYPES:
value = annotation(value)

cleaned_kwargs[key] = value

if inspect.iscoroutinefunction(self.command):
asyncio.run(self.command(**kwargs))
asyncio.run(self.command(**cleaned_kwargs))
else:
self.command(**kwargs)
self.command(**cleaned_kwargs)


@dataclass
Expand Down

0 comments on commit 1f51ba0

Please sign in to comment.