Skip to content

Commit

Permalink
Add type annotations for backend/base.py (#717)
Browse files Browse the repository at this point in the history
  • Loading branch information
amaslenn committed Jul 21, 2023
1 parent 1121554 commit 8ecab66
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions testinfra/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import shlex
import subprocess
import urllib.parse
from typing import Any, Optional

logger = logging.getLogger("testinfra")

Expand All @@ -26,13 +27,13 @@
class CommandResult:
def __init__(
self,
backend,
exit_status,
command,
stdout_bytes,
stderr_bytes,
stdout=None,
stderr=None,
backend: "BaseBackend",
exit_status: int,
command: bytes,
stdout_bytes: bytes,
stderr_bytes: bytes,
stdout: Optional[str] = None,
stderr: Optional[str] = None,
):
self.exit_status = exit_status
self._stdout_bytes = stdout_bytes
Expand All @@ -44,7 +45,7 @@ def __init__(
super().__init__()

@property
def succeeded(self):
def succeeded(self) -> bool:
"""Returns whether the command was successful
>>> host.run("true").succeeded
Expand All @@ -53,7 +54,7 @@ def succeeded(self):
return self.exit_status == 0

@property
def failed(self):
def failed(self) -> bool:
"""Returns whether the command failed
>>> host.run("false").failed
Expand All @@ -62,7 +63,7 @@ def failed(self):
return self.exit_status != 0

@property
def rc(self):
def rc(self) -> int:
"""Gets the returncode of a command
>>> host.run("true").rc
Expand All @@ -71,30 +72,30 @@ def rc(self):
return self.exit_status

@property
def stdout(self):
def stdout(self) -> str:
if self._stdout is None:
self._stdout = self._backend.decode(self._stdout_bytes)
return self._stdout

@property
def stderr(self):
def stderr(self) -> str:
if self._stderr is None:
self._stderr = self._backend.decode(self._stderr_bytes)
return self._stderr

@property
def stdout_bytes(self):
def stdout_bytes(self) -> bytes:
if self._stdout_bytes is None:
self._stdout_bytes = self._backend.encode(self._stdout)
return self._stdout_bytes

@property
def stderr_bytes(self):
def stderr_bytes(self) -> bytes:
if self._stderr_bytes is None:
self._stderr_bytes = self._backend.encode(self._stderr)
return self._stderr_bytes

def __repr__(self):
def __repr__(self) -> str:
return (
"CommandResult(command={!r}, exit_status={}, stdout={!r}, " "stderr={!r})"
).format(
Expand All @@ -112,7 +113,14 @@ class BaseBackend(metaclass=abc.ABCMeta):
HAS_RUN_ANSIBLE = False
NAME: str

def __init__(self, hostname, sudo=False, sudo_user=None, *args, **kwargs):
def __init__(
self,
hostname: str,
sudo: bool = False,
sudo_user: Optional[bool] = None,
*args: Any,
**kwargs: Any,
):
self._encoding = None
self._host = None
self.hostname = hostname
Expand Down Expand Up @@ -245,7 +253,7 @@ def parse_containerspec(containerspec):
user, name = name.split("@", 1)
return name, user

def get_encoding(self) -> str:
def get_encoding(self):
encoding = None
for python in ("python3", "python"):
cmd = self.run(
Expand All @@ -271,19 +279,19 @@ def encoding(self):
self._encoding = self.get_encoding()
return self._encoding

def decode(self, data):
def decode(self, data: bytes) -> str:
try:
return data.decode("ascii")
except UnicodeDecodeError:
return data.decode(self.encoding)

def encode(self, data):
def encode(self, data: str) -> bytes:
try:
return data.encode("ascii")
except UnicodeEncodeError:
return data.encode(self.encoding)

def result(self, *args, **kwargs):
def result(self, *args: Any, **kwargs: Any) -> CommandResult:
result = CommandResult(self, *args, **kwargs)
logger.debug("RUN %s", result)
return result

0 comments on commit 8ecab66

Please sign in to comment.