Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions serialx/platforms/serial_esphome.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,23 @@
from collections.abc import Callable, Coroutine
from contextlib import suppress
from enum import IntFlag
import functools
import logging
import threading
from typing import Any, TypeVar
from typing import Any, TypeVar, cast
import urllib.parse

import aioesphomeapi
from aioesphomeapi import APIClient, SerialProxyDataReceived, SerialProxyParity
from aioesphomeapi.core import APIConnectionError, PingRequest, PingResponse
from aioesphomeapi.core import (
APIConnectionError,
PingRequest,
PingResponse,
TimeoutAPIError,
)
from typing_extensions import Buffer

from serialx import UnsupportedSetting
from serialx import SerialException, UnsupportedSetting
from serialx.common import (
BaseSerial,
BaseSerialTransport,
Expand All @@ -27,6 +33,7 @@
)

_T = TypeVar("_T")
_F = TypeVar("_F", bound=Callable[..., Any])

LOGGER = logging.getLogger(__name__)

Expand All @@ -44,6 +51,21 @@
}


def translate_esphome_errors(func: _F) -> _F:
"""Translate aioesphomeapi errors into standard serialx exceptions."""

@functools.wraps(func)
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except TimeoutAPIError as exc:
raise TimeoutError(str(exc)) from exc
except APIConnectionError as exc:
raise SerialException(str(exc)) from exc

return cast(_F, wrapper)


class LineStateFlag(IntFlag):
"""Bitmap of serial line states."""

Expand All @@ -66,6 +88,8 @@ def __init__(
api: APIClient | None = None,
port_name: str | None = None,
port_instance: int | None = None,
password: str | None = None,
noise_psk: str | None = None,
**kwargs,
) -> None:
"""Initialize ESPHome serial port."""
Expand All @@ -86,6 +110,8 @@ def __init__(
self._api: APIClient | None = api
self._port_name: str | None = port_name
self._instance_id: int | None = port_instance
self._password: str | None = password
self._noise_psk: str | None = noise_psk
self._disconnect_api: bool = False

self._read_buffer = bytearray()
Expand Down Expand Up @@ -121,6 +147,7 @@ def is_open(self) -> bool:
"""Return whether the serial port is open."""
return self._api is not None

@translate_esphome_errors
async def _async_open(self) -> None:
# Only connect if the API was not passed in externally
if self._api is None:
Expand All @@ -139,11 +166,17 @@ async def _async_open(self) -> None:
else:
self._port_name = port_value

if "password" in params:
self._password = params["password"][0]

if "noise_psk" in params:
self._noise_psk = params["noise_psk"][0]

self._api = aioesphomeapi.APIClient(
address=parsed.hostname,
port=parsed.port or ESPHOME_DEFAULT_PORT,
password=params["password"][0] if "password" in params else None,
noise_psk=params["noise_psk"][0] if "noise_psk" in params else None,
password=self._password,
noise_psk=self._noise_psk,
)

self._disconnect_api = True
Expand All @@ -152,11 +185,13 @@ async def _async_open(self) -> None:
# Don't disconnect an externally-passed API
self._disconnect_api = False

@translate_esphome_errors
async def _async_subscribe(self) -> None:
assert self._api is not None
await self._subscribe_instance()
self._unsub = self._api.subscribe_serial_proxy_data(self._on_data)

@translate_esphome_errors
async def _ping(self, *, timeout: float) -> None:
"""Ping the ESPHome API."""
assert self._api is not None
Expand Down Expand Up @@ -227,6 +262,7 @@ def _set_modem_pins(self, modem_pins: ModemPins) -> None:
"""Set modem control bits."""
self._call_on_loop(self._async_set_modem_pins(modem_pins))

@translate_esphome_errors
async def _async_set_modem_pins(self, modem_pins: ModemPins) -> None:
assert self._api is not None
line_states = self._last_line_state
Expand All @@ -252,6 +288,7 @@ async def _async_set_modem_pins(self, modem_pins: ModemPins) -> None:
def _get_modem_pins(self) -> ModemPins:
return self._call_on_loop(self._async_get_modem_pins())

@translate_esphome_errors
async def _async_get_modem_pins(self) -> ModemPins:
assert self._api is not None
rsp = await self._api.serial_proxy_get_modem_pins(instance=self._instance_id)
Expand All @@ -277,6 +314,7 @@ def reset_read_buffer(self) -> None:
def reset_write_buffer(self) -> None:
"""Reset the write buffer."""

@translate_esphome_errors
async def _async_flush(self) -> None:
"""Flush write buffers."""
assert self._api is not None
Expand Down Expand Up @@ -345,6 +383,7 @@ def __init__(
super().__init__(loop, protocol)
self._unsub: Callable[[], None] | None = None

@translate_esphome_errors
async def _connect(self, **kwargs) -> None:
self._serial = ESPHomeSerial(loop=self._loop, **kwargs)
self._extra["serial"] = self._serial
Expand Down
8 changes: 7 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,14 +286,20 @@ def create_adapter_pair(left: str, right: str) -> Iterator[tuple[str, str]]:


@contextlib.contextmanager
def create_esphome_pair(left_tty: str, right_tty: str) -> Iterator[tuple[str, str]]:
def create_esphome_pair(
left_tty: str,
right_tty: str,
*,
noise_psk: str = "",
) -> Iterator[tuple[str, str]]:
"""Create an esphome:// pair."""
assert ESPHOME_HOST_BINARY is not None

env = os.environ.copy()
env["SERIALX_UART_LEFT"] = left_tty
env["SERIALX_UART_RIGHT"] = right_tty
env["SERIALX_API_PORT"] = "0"
env["SERIALX_NOISE_PSK"] = noise_psk

process = subprocess.Popen( # noqa: S603
[ESPHOME_HOST_BINARY],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
DEPENDENCIES = ["api", "uart"]

CONF_API_PORT_ENV = "api_port_env"
CONF_NOISE_PSK_ENV = "noise_psk_env"
CONF_LEFT_UART_ENV = "left_uart_env"
CONF_LEFT_UART_ID = "left_uart_id"
CONF_RIGHT_UART_ENV = "right_uart_env"
CONF_RIGHT_UART_ID = "right_uart_id"

DEFAULT_API_PORT_ENV = "SERIALX_API_PORT"
DEFAULT_NOISE_PSK_ENV = "SERIALX_NOISE_PSK"
DEFAULT_LEFT_UART_ENV = "SERIALX_UART_LEFT"
DEFAULT_RIGHT_UART_ENV = "SERIALX_UART_RIGHT"

Expand All @@ -39,6 +41,9 @@
cv.Optional(
CONF_API_PORT_ENV, default=DEFAULT_API_PORT_ENV
): cv.string_strict,
cv.Optional(
CONF_NOISE_PSK_ENV, default=DEFAULT_NOISE_PSK_ENV
): cv.string_strict,
}
).extend(cv.COMPONENT_SCHEMA),
cv.only_on(PLATFORM_HOST),
Expand All @@ -58,3 +63,4 @@ async def to_code(config):
cg.add(var.set_left_uart_env(config[CONF_LEFT_UART_ENV]))
cg.add(var.set_right_uart_env(config[CONF_RIGHT_UART_ENV]))
cg.add(var.set_api_port_env(config[CONF_API_PORT_ENV]))
cg.add(var.set_noise_psk_env(config[CONF_NOISE_PSK_ENV]))
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "serialx_host_overrides.h"

#include "esphome/components/api/api_server.h"
#include "esphome/core/helpers.h"
#include "esphome/core/log.h"

#include <cerrno>
Expand Down Expand Up @@ -50,6 +51,24 @@ void SerialxHostOverridesComponent::setup() {

api::global_api_server->set_port(static_cast<uint16_t>(parsed));
ESP_LOGI(TAG, "Overrode API port from %s", this->api_port_env_.c_str());

#ifdef USE_API_NOISE
const char *noise_psk_value = std::getenv(this->noise_psk_env_.c_str());
if (noise_psk_value != nullptr) {
if (noise_psk_value[0] == '\0') {
// Empty string: disable encryption
api::global_api_server->set_noise_psk({});
ESP_LOGI(TAG, "Disabled noise encryption from %s", this->noise_psk_env_.c_str());
} else {
// Base64-encoded PSK
auto decoded = base64_decode(noise_psk_value);
api::psk_t psk{};
std::copy_n(decoded.begin(), std::min(decoded.size(), psk.size()), psk.begin());
api::global_api_server->set_noise_psk(psk);
ESP_LOGI(TAG, "Overrode noise PSK from %s", this->noise_psk_env_.c_str());
}
}
#endif // USE_API_NOISE
}

void SerialxHostOverridesComponent::loop() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ class SerialxHostOverridesComponent : public Component {
void set_left_uart_env(std::string env_name) { this->left_uart_env_ = std::move(env_name); }
void set_right_uart_env(std::string env_name) { this->right_uart_env_ = std::move(env_name); }
void set_api_port_env(std::string env_name) { this->api_port_env_ = std::move(env_name); }
void set_noise_psk_env(std::string env_name) { this->noise_psk_env_ = std::move(env_name); }

protected:
uart::HostUartComponent *left_uart_{nullptr};
uart::HostUartComponent *right_uart_{nullptr};
std::string left_uart_env_;
std::string right_uart_env_;
std::string api_port_env_;
std::string noise_psk_env_;
bool ready_printed_{false};
};

Expand Down
1 change: 1 addition & 0 deletions tests/esphome/host_daemon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ host:

api:
port: ${api_port}
encryption:

logger:
level: INFO
Expand Down
70 changes: 67 additions & 3 deletions tests/test_serial_esphome.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
allow_module_level=True,
)

from base64 import b64encode
from unittest.mock import patch
import urllib.parse

from serialx import open_serial_connection
from serialx import SerialException, open_serial_connection
from serialx.platforms.serial_esphome import (
ESPHOME_DEFAULT_PORT,
ESPHomeSerialTransport,
Expand All @@ -21,6 +23,13 @@
from .common import ESPHOME_HOST_BINARY, create_esphome_pair, create_socat_pair


def base64(key: bytes) -> str:
"""Base64 encode a Noise key."""
assert len(key) == 32

return b64encode(key).decode("ascii")


@pytest.mark.skipif(not ESPHOME_HOST_BINARY, reason="esphome host binary not available")
async def test_externally_passed_api() -> None:
"""Test passing an ESPHome API instance externally."""
Expand Down Expand Up @@ -90,8 +99,8 @@ async def test_connect_by_instance_id() -> None:
with create_esphome_pair(socat_left, socat_right) as (left, _right):
parsed = urllib.parse.urlparse(left)

# Connect by instance ID instead of name
url = f"esphome://{parsed.hostname}:{parsed.port}/0"
# Connect by instance ID instead of name, with a password
url = f"esphome://{parsed.hostname}:{parsed.port}/0?password=unused"

reader, writer = await open_serial_connection(
url=url,
Expand All @@ -118,3 +127,58 @@ async def test_connect_by_invalid_name() -> None:
url=url,
baudrate=115200,
)


@pytest.mark.skipif(not ESPHOME_HOST_BINARY, reason="esphome host binary not available")
async def test_connect_plaintext_to_encrypted_server() -> None:
"""Test that connecting without encryption to an encrypted server raises."""
with create_socat_pair() as (socat_left, socat_right):
with create_esphome_pair(
socat_left,
socat_right,
noise_psk=base64(b"A noise PSK we do not provide..."),
) as (left, _right):
parsed = urllib.parse.urlparse(left)
url = (
f"esphome://{parsed.hostname}:{parsed.port}?port_name=Serial+Proxy+Left"
)

with pytest.raises(SerialException, match="Connection requires encryption"):
await open_serial_connection(
url=url,
baudrate=115200,
)


@pytest.mark.skipif(not ESPHOME_HOST_BINARY, reason="esphome host binary not available")
async def test_connect_encrypted_plaintext_to_server() -> None:
"""Test that connecting with encryption to an unencrypted server raises."""
with create_socat_pair() as (socat_left, socat_right):
with create_esphome_pair(
socat_left,
socat_right,
) as (left, _right):
parsed = urllib.parse.urlparse(left)
noise_psk = base64(b"An unnecessary noise PSK we use.")

url = (
f"esphome://{parsed.hostname}:{parsed.port}"
f"?port_name=Serial+Proxy+Left"
f"&noise_psk={noise_psk}"
)

with pytest.raises(
SerialException, match="The device is using plaintext protocol"
):
await open_serial_connection(url=url, baudrate=115200)


async def test_connect_timeout_raises_timeout_error() -> None:
"""Test that a TCP connect timeout is translated to TimeoutError."""

with patch("aioesphomeapi.connection.TCP_CONNECT_TIMEOUT", 1.0):
with pytest.raises(TimeoutError, match="Timeout while connecting"):
# 192.0.2.1 is TEST-NET-1 (RFC 5737), packets are silently dropped
await open_serial_connection(
url="esphome://192.0.2.1:6053?port_name=test", baudrate=115200
)
Loading