Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to sync methods #60

Merged
merged 1 commit into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,13 @@ ignore_missing_imports = true
[[tool.mypy.overrides]]
module = [
'samsungtvws.art',
'samsungtvws.connection',
'samsungtvws.helper',
'samsungtvws.remote',
'samsungtvws.rest',
'samsungtvws.shortcuts',
]
disallow_untyped_calls = false
disallow_untyped_defs = false

[[tool.mypy.overrides]]
module = [
'samsungtvws.async_connection',
'samsungtvws.async_rest',
'samsungtvws.remote',
]
disallow_untyped_calls = false
26 changes: 13 additions & 13 deletions samsungtvws/async_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,38 +19,38 @@
Boston, MA 02110-1335 USA

"""
from asyncio import Task, ensure_future, sleep
import contextlib
import json
import logging
import ssl
import sys
from asyncio import sleep
from types import TracebackType
from typing import Any, Awaitable, Callable, Dict, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Union

from websockets.client import WebSocketClientProtocol, connect
from websockets.exceptions import ConnectionClosed

from . import connection, exceptions, helper
from .command import SamsungTVCommand

if sys.version_info >= (3, 7):
from asyncio import create_task
else:
from asyncio import ensure_future as create_task
from .event import MS_CHANNEL_CONNECT

_LOGGING = logging.getLogger(__name__)


class SamsungTVWSAsyncConnection(connection.SamsungTVWSBaseConnection):

connection: Optional[WebSocketClientProtocol]
_recv_loop: Optional[Task[None]]

async def __aenter__(self) -> "SamsungTVWSAsyncConnection":
return self

async def __aexit__(
self,
exc_type: Union[type, None],
exc_val: Union[BaseException, None],
exc_tb: Union[TracebackType, None],
exc_type: Optional[type],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.close()

Expand All @@ -69,7 +69,7 @@ async def open(self) -> WebSocketClientProtocol:
response = helper.process_api_response(await connection.recv())
self._check_for_token(response)

if response["event"] != "ms.channel.connect":
if response["event"] != MS_CHANNEL_CONNECT:
await self.close()
raise exceptions.ConnectionFailure(response)

Expand All @@ -83,7 +83,7 @@ async def start_listening(
if self.connection is None:
self.connection = await self.open()

self._recv_loop = create_task(
self._recv_loop = ensure_future(
self._do_start_listening(callback, self.connection)
)

Expand Down Expand Up @@ -111,7 +111,7 @@ async def close(self) -> None:
async def send_command(
self,
command: Union[SamsungTVCommand, Dict[str, Any]],
key_press_delay: Union[float, None] = None,
key_press_delay: Optional[float] = None,
) -> None:
if self.connection is None:
self.connection = await self.open()
Expand Down
8 changes: 4 additions & 4 deletions samsungtvws/async_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

"""
import logging
from typing import Any, Dict, Union
from typing import Any, Dict, Optional

import aiohttp

Expand All @@ -36,11 +36,11 @@ def __init__(
*,
session: aiohttp.ClientSession,
port: int = 8001,
timeout: Union[float, None] = None,
timeout: Optional[float] = None,
) -> None:
super().__init__(
host,
endpoint=None,
endpoint="",
port=port,
timeout=timeout,
)
Expand All @@ -60,7 +60,7 @@ async def _rest_request(self, target: str, method: str = "GET") -> Dict[str, Any
else:
future = self.session.get(url, timeout=self.timeout, verify_ssl=False)
async with future as resp:
return helper.process_api_response(await resp.text()) # type: ignore[no-any-return]
return helper.process_api_response(await resp.text())
except aiohttp.ClientConnectionError:
raise exceptions.HttpApiError(
"TV unreachable or feature not supported on this model."
Expand Down
2 changes: 1 addition & 1 deletion samsungtvws/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class SamsungTVCommand:
def __init__(self, method: str, params: Any) -> None:
def __init__(self, method: str, params: Dict[str, Any]) -> None:
self.method = method
self.params = params

Expand Down
67 changes: 41 additions & 26 deletions samsungtvws/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
import logging
import ssl
import time
from types import TracebackType
from typing import Any, Dict, Optional, Union

import websocket

from . import exceptions, helper
from .command import SamsungTVCommand
from .event import MS_CHANNEL_CONNECT

_LOGGING = logging.getLogger(__name__)

Expand All @@ -41,15 +44,15 @@ class SamsungTVWSBaseConnection:

def __init__(
self,
host,
host: str,
*,
endpoint,
token=None,
token_file=None,
port=8001,
timeout=None,
key_press_delay=1,
name="SamsungTvRemote",
endpoint: str,
token: Optional[str] = None,
token_file: Optional[str] = None,
port: int = 8001,
timeout: Optional[float] = None,
key_press_delay: float = 1,
name: str = "SamsungTvRemote",
):
self.host = host
self.token = token
Expand All @@ -59,13 +62,13 @@ def __init__(
self.key_press_delay = key_press_delay
self.name = name
self.endpoint = endpoint
self.connection = None
self._recv_loop = None
self.connection: Optional[Any] = None
self._recv_loop: Optional[Any] = None

def _is_ssl_connection(self):
def _is_ssl_connection(self) -> bool:
return self.port == 8002

def _format_websocket_url(self, app):
def _format_websocket_url(self, app: str) -> str:
params = {
"host": self.host,
"port": self.port,
Expand All @@ -79,7 +82,7 @@ def _format_websocket_url(self, app):
else:
return self._URL_FORMAT.format(**params)

def _format_rest_url(self, route=""):
def _format_rest_url(self, route: str = "") -> str:
params = {
"protocol": "https" if self._is_ssl_connection() else "http",
"host": self.host,
Expand All @@ -89,17 +92,17 @@ def _format_rest_url(self, route=""):

return self._REST_URL_FORMAT.format(**params)

def _get_token(self):
def _get_token(self) -> Optional[str]:
if self.token_file is not None:
try:
with open(self.token_file) as token_file:
return token_file.readline()
except:
return ""
return None
else:
return self.token

def _set_token(self, token):
def _set_token(self, token: str) -> None:
_LOGGING.info("New token %s", token)
if self.token_file is not None:
_LOGGING.debug("Save token to file: %s", token)
Expand All @@ -108,18 +111,26 @@ def _set_token(self, token):
else:
self.token = token

def _check_for_token(self, response):
if response.get("data") and response["data"].get("token"):
token = response["data"].get("token")
def _check_for_token(self, response: Dict[str, Any]) -> None:
token = response.get("data", {}).get("token")
if token:
_LOGGING.debug("Got token %s", token)
self._set_token(token)


class SamsungTVWSConnection(SamsungTVWSBaseConnection):
def __enter__(self):

connection: Optional[websocket.WebSocket]

def __enter__(self) -> "SamsungTVWSConnection":
return self

def __exit__(self, type, value, traceback):
def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.close()

def open(self) -> websocket.WebSocket:
Expand All @@ -145,21 +156,25 @@ def open(self) -> websocket.WebSocket:
response = helper.process_api_response(connection.recv())
self._check_for_token(response)

if response["event"] != "ms.channel.connect":
if response["event"] != MS_CHANNEL_CONNECT:
self.close()
raise exceptions.ConnectionFailure(response)

self.connection = connection
return connection

def close(self):
def close(self) -> None:
if self.connection:
self.connection.close()

self.connection = None
_LOGGING.debug("Connection closed.")

def send_command(self, command, key_press_delay=None):
def send_command(
self,
command: Union[SamsungTVCommand, Dict[str, Any]],
key_press_delay: Optional[float] = None,
) -> None:
if self.connection is None:
self.connection = self.open()

Expand All @@ -172,5 +187,5 @@ def send_command(self, command, key_press_delay=None):
delay = self.key_press_delay if key_press_delay is None else key_press_delay
time.sleep(delay)

def is_alive(self):
return self.connection and self.connection.connected
def is_alive(self) -> bool:
return self.connection is not None and self.connection.connected
5 changes: 3 additions & 2 deletions samsungtvws/event.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from typing import Any, Dict, List, cast
from typing import Any, Dict, List

from .exceptions import MessageError

ED_APPS_LAUNCH_EVENT = "ed.apps.launch"
ED_EDENTV_UPDATE_EVENT = "ed.edenTV.update"
ED_INSTALLED_APP_EVENT = "ed.installedApp.get"
MS_CHANNEL_CONNECT = "ms.channel.connect"
MS_CHANNEL_CLIENT_CONNECT_EVENT = "ms.channel.clientConnect"
MS_CHANNEL_CLIENT_DISCONNECT_EVENT = "ms.channel.clientDisconnect"
MS_ERROR_EVENT = "ms.error"


def parse_installed_app(event: Dict[str, Any]) -> List[Dict[str, Any]]:
assert event["event"] == ED_INSTALLED_APP_EVENT
return cast(List[Dict[str, Any]], event["data"]["data"])
return event["data"]["data"] # type:ignore[no-any-return]


def parse_ms_error(event: Dict[str, Any]) -> MessageError:
Expand Down
7 changes: 4 additions & 3 deletions samsungtvws/helper.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
import base64
import json
import logging
from typing import Any, Dict, Union

from . import exceptions

_LOGGING = logging.getLogger(__name__)


def serialize_string(string):
def serialize_string(string: Union[str, bytes]) -> str:
if isinstance(string, str):
string = str.encode(string)

return base64.b64encode(string).decode("utf-8")


def process_api_response(response):
def process_api_response(response: Union[str, bytes]) -> Dict[str, Any]:
_LOGGING.debug("Processing API response: %s", response)
try:
return json.loads(response)
return json.loads(response) # type:ignore[no-any-return]
except json.JSONDecodeError:
raise exceptions.ResponseError(
"Failed to parse response from TV. Maybe feature not supported on this model"
Expand Down
Loading