Skip to content

Commit

Permalink
added support for command groups
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend committed Apr 22, 2020
1 parent d2067c0 commit b76bd4e
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 31 deletions.
11 changes: 11 additions & 0 deletions example_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,22 @@ def compound_interest(interest_rate: float, years: int):
print(((interest_rate + 1) ** years) - 1)


def create(username: str):
"""
Create a new user.
:param username:
The new user's username.
"""
print(f"Creating {username}")


if __name__ == "__main__":
cli = CLI()
cli.register(say_hello)
cli.register(echo)
cli.register(add)
cli.register(print_pi)
cli.register(compound_interest)
cli.register(create, group_name="user")
cli.run()
98 changes: 67 additions & 31 deletions targ/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,21 @@ class Arguments:
@dataclass
class Command:
command: t.Callable
group_name: t.Optional[str] = None

def __post_init__(self):
self.command_docstring: Docstring = parse(self.command.__doc__)
self.annotations = t.get_type_hints(self.command)
self.signature = inspect.signature(self.command)
self.command_name = self.command.__name__

@cached_property
def full_name(self):
return (
f"{self.group_name} {self.command_name}"
if self.group_name
else self.command_name
)

@cached_property
def description(self) -> str:
Expand Down Expand Up @@ -93,7 +103,7 @@ def usage(self) -> str:
some_command required_arg [--optional_arg=value] [--some_flag]
"""
output = [format_text(self.command.__name__, color=Color.green)]
output = [format_text(self.command_name, color=Color.green)]

for arg_name, parameter in self.signature.parameters.items():
if parameter.default is inspect._empty: # type: ignore
Expand All @@ -110,29 +120,29 @@ def usage(self) -> str:

return " ".join(output)

def print_help(self):
print("")
print(self.command_name)
print(get_underline(len(self.command_name)))
print(self.description)

print("")
print("Usage")
print(get_underline(5, character="-"))
print(self.usage)
print("")

print("Args")
print(get_underline(4, character="-"))
print(self.arguments_description)
print("")

def call_with(self, arg_class: Arguments):
"""
Call the command function with the given arguments.
"""
if arg_class.kwargs.get("help"):
name = self.command.__name__

print("")
print(name)
print(get_underline(len(name)))
print(self.description)

print("")
print("Usage")
print(get_underline(5, character="-"))
print(self.usage)
print("")

print("Args")
print(get_underline(4, character="-"))
print(self.arguments_description)
print("")

self.print_help()
return

annotations = t.get_type_hints(self.command)
Expand Down Expand Up @@ -167,8 +177,18 @@ class CLI:
description: str = "Targ CLI"
commands: t.List[Command] = field(default_factory=list)

def register(self, command: t.Callable, group: t.Optional[str] = None):
self.commands.append(Command(command))
def _validate_group_name(self, group_name: str) -> bool:
if " " in group_name:
return False
return True

def register(
self, command: t.Callable, group_name: t.Optional[str] = None
):
if group_name and not self._validate_group_name(group_name):
raise ValueError("The group name should not contain spaces.")

self.commands.append(Command(command=command, group_name=group_name))

def get_help_text(self) -> str:
lines = [
Expand All @@ -186,9 +206,7 @@ def get_help_text(self) -> str:
]

for command in self.commands:
lines.append(
format_text(command.command.__name__, color=Color.green)
)
lines.append(format_text(command.full_name, color=Color.green))
lines.append(command.description)
lines.append("")

Expand All @@ -206,9 +224,13 @@ def get_cleaned_args(self) -> t.List[str]:
output.append(arg)
return output

def get_command(self, command_name: str) -> t.Optional[Command]:
def get_command(
self, command_name: str, group_name: t.Optional[str] = None
) -> t.Optional[Command]:
for command in self.commands:
if command.command.__name__ == command_name:
if command.command_name == command_name:
if group_name and command.group_name != group_name:
continue
return command
return None

Expand Down Expand Up @@ -240,24 +262,38 @@ def get_arg_class(self, args: t.List[str]) -> Arguments:
return arguments

def run(self):
args = self.get_cleaned_args()
cleaned_args = self.get_cleaned_args()

if len(args) == 0:
if len(cleaned_args) == 0:
print(self.get_help_text())
return

command_name = args[0]
args = args[1:]
command_name = cleaned_args[0]

command = self.get_command(command_name=command_name)

if command:
cleaned_args = cleaned_args[1:]
else:
# See if it belongs to a group:
if len(cleaned_args) >= 2:
group_name = cleaned_args[0]
command_name = cleaned_args[1]
command = self.get_command(
command_name=command_name, group_name=group_name
)
if command:
cleaned_args = cleaned_args[2:]

if not command:
print(f"Unrecognised command - {command_name}")
print(self.get_help_text())
else:
try:
arg_class = self.get_arg_class(args)
arg_class = self.get_arg_class(cleaned_args)
command.call_with(arg_class)
except Exception as exception:
print(format_text("The command failed.", color=Color.red))
print(exception)
command.print_help()
sys.exit(1)

0 comments on commit b76bd4e

Please sign in to comment.