From 8a97af4dced99c186050c6a655f3c33d0ed37bd1 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 14 Feb 2026 15:40:34 -0800 Subject: [PATCH 01/50] Enhance core functionality with new features and improvements - Added conditional import for CompanionRadio to manage availability. - Updated Dispatcher class to support raw packet subscribers for enhanced logging and data handling. - Improved AdvertHandler to utilize a new payload parsing method and publish events for node discovery. - Enhanced ControlHandler with optional debug logging for better traceability. - Refactored GroupTextHandler to improve message handling and event publishing. - Introduced binary response callback in ProtocolResponseHandler for better response management. - Added X25519 scalar clamping utility in CryptoUtils for compatibility with firmware. - Updated LocalIdentity to ensure ECDH matches firmware specifications. - Improved packet building methods to align with routing specifications. --- src/pymc_core/__init__.py | 12 + src/pymc_core/companion/__init__.py | 109 +++ src/pymc_core/companion/binary_parsing.py | 123 +++ src/pymc_core/companion/channel_store.py | 83 ++ src/pymc_core/companion/companion_base.py | 641 ++++++++++++++ src/pymc_core/companion/companion_bridge.py | 805 ++++++++++++++++++ src/pymc_core/companion/companion_radio.py | 501 +++++++++++ src/pymc_core/companion/constants.py | 73 ++ src/pymc_core/companion/contact_store.py | 236 +++++ src/pymc_core/companion/message_queue.py | 58 ++ src/pymc_core/companion/models.py | 98 +++ src/pymc_core/companion/path_cache.py | 55 ++ src/pymc_core/companion/stats_collector.py | 57 ++ src/pymc_core/node/dispatcher.py | 69 +- src/pymc_core/node/handlers/advert.py | 80 +- src/pymc_core/node/handlers/control.py | 22 +- src/pymc_core/node/handlers/group_text.py | 55 +- .../node/handlers/protocol_response.py | 67 +- src/pymc_core/node/handlers/text.py | 1 + src/pymc_core/protocol/crypto.py | 11 + src/pymc_core/protocol/identity.py | 18 +- src/pymc_core/protocol/packet_builder.py | 14 +- 22 files changed, 3099 insertions(+), 89 deletions(-) create mode 100644 src/pymc_core/companion/__init__.py create mode 100644 src/pymc_core/companion/binary_parsing.py create mode 100644 src/pymc_core/companion/channel_store.py create mode 100644 src/pymc_core/companion/companion_base.py create mode 100644 src/pymc_core/companion/companion_bridge.py create mode 100644 src/pymc_core/companion/companion_radio.py create mode 100644 src/pymc_core/companion/constants.py create mode 100644 src/pymc_core/companion/contact_store.py create mode 100644 src/pymc_core/companion/message_queue.py create mode 100644 src/pymc_core/companion/models.py create mode 100644 src/pymc_core/companion/path_cache.py create mode 100644 src/pymc_core/companion/stats_collector.py diff --git a/src/pymc_core/__init__.py b/src/pymc_core/__init__.py index 0d1e9ab..04d3f34 100644 --- a/src/pymc_core/__init__.py +++ b/src/pymc_core/__init__.py @@ -21,5 +21,17 @@ "__version__", ] +# Conditional import for CompanionRadio +try: + from .companion.companion_radio import CompanionRadio + + _COMPANION_AVAILABLE = True +except ImportError: + _COMPANION_AVAILABLE = False + CompanionRadio = None + +if _COMPANION_AVAILABLE: + __all__.append("CompanionRadio") + # End of mesh package exports diff --git a/src/pymc_core/companion/__init__.py b/src/pymc_core/companion/__init__.py new file mode 100644 index 0000000..dbf2b10 --- /dev/null +++ b/src/pymc_core/companion/__init__.py @@ -0,0 +1,109 @@ +""" +MeshCore Companion Radio - Python-native implementation. + +Provides contact management, messaging with offline queue, advertisement +broadcasting, channel management, path tracking, signing, telemetry, +statistics, and device configuration on top of MeshNode. +""" + +from .companion_radio import CompanionRadio +from .companion_bridge import CompanionBridge +from .channel_store import ChannelStore +from .contact_store import ContactStore +from .message_queue import MessageQueue +from .path_cache import PathCache +from .stats_collector import StatsCollector +from .constants import ( + ADV_TYPE_CHAT, + ADV_TYPE_REPEATER, + ADV_TYPE_ROOM, + ADV_TYPE_SENSOR, + ADVERT_LOC_NONE, + ADVERT_LOC_SHARE, + AUTOADD_CHAT, + AUTOADD_OVERWRITE_OLDEST, + AUTOADD_REPEATER, + AUTOADD_ROOM, + AUTOADD_SENSOR, + BinaryReqType, + DEFAULT_MAX_CHANNELS, + DEFAULT_MAX_CONTACTS, + DEFAULT_OFFLINE_QUEUE_SIZE, + MSG_SEND_FAILED, + MSG_SEND_SENT_DIRECT, + MSG_SEND_SENT_FLOOD, + STATS_TYPE_CORE, + STATS_TYPE_PACKETS, + STATS_TYPE_RADIO, + TELEM_MODE_ALLOW_ALL, + TELEM_MODE_ALLOW_FLAGS, + TELEM_MODE_DENY, + TXT_TYPE_CLI_DATA, + TXT_TYPE_PLAIN, + TXT_TYPE_SIGNED_PLAIN, +) +from .models import ( + AdvertPath, + Channel, + Contact, + NodePrefs, + PacketStats, + QueuedMessage, + SentResult, +) + +__all__ = [ + # Main classes + "CompanionRadio", + "CompanionBridge", + # Stores + "ContactStore", + "ChannelStore", + "MessageQueue", + "PathCache", + "StatsCollector", + # Models + "Contact", + "Channel", + "NodePrefs", + "SentResult", + "PacketStats", + "AdvertPath", + "QueuedMessage", + # ADV Types + "ADV_TYPE_CHAT", + "ADV_TYPE_REPEATER", + "ADV_TYPE_ROOM", + "ADV_TYPE_SENSOR", + # Text Types + "TXT_TYPE_PLAIN", + "TXT_TYPE_CLI_DATA", + "TXT_TYPE_SIGNED_PLAIN", + # Telemetry Modes + "TELEM_MODE_DENY", + "TELEM_MODE_ALLOW_FLAGS", + "TELEM_MODE_ALLOW_ALL", + # Location Policy + "ADVERT_LOC_NONE", + "ADVERT_LOC_SHARE", + # Auto-Add Config + "AUTOADD_OVERWRITE_OLDEST", + "AUTOADD_CHAT", + "AUTOADD_REPEATER", + "AUTOADD_ROOM", + "AUTOADD_SENSOR", + # Message Send Result + "MSG_SEND_FAILED", + "MSG_SEND_SENT_FLOOD", + "MSG_SEND_SENT_DIRECT", + # Binary request types + "BinaryReqType", + # Stats Types + "STATS_TYPE_CORE", + "STATS_TYPE_RADIO", + "STATS_TYPE_PACKETS", + # Defaults + "DEFAULT_MAX_CONTACTS", + "DEFAULT_MAX_CHANNELS", + "DEFAULT_OFFLINE_QUEUE_SIZE", +] diff --git a/src/pymc_core/companion/binary_parsing.py b/src/pymc_core/companion/binary_parsing.py new file mode 100644 index 0000000..dbe4ec1 --- /dev/null +++ b/src/pymc_core/companion/binary_parsing.py @@ -0,0 +1,123 @@ +"""Parse binary response payloads by request type (BinaryReqType).""" + +import struct +from typing import Any, Optional + +from .constants import BinaryReqType + + +def parse_binary_response( + request_type: int, + data: bytes, + pubkey_prefix: str = "", + context: Optional[dict] = None, +) -> Optional[dict]: + """Parse response_data by request_type. Returns dict or None.""" + if request_type == BinaryReqType.STATUS and len(data) >= 52: + return _parse_status(data, pubkey_prefix=pubkey_prefix or None) + if request_type == BinaryReqType.TELEMETRY and len(data) >= 0: + return _parse_telemetry(data) + if request_type == BinaryReqType.MMA and len(data) >= 4: + return _parse_mma(data[4:]) # skip 4-byte header + if request_type == BinaryReqType.ACL: + return _parse_acl(data) + if request_type == BinaryReqType.NEIGHBOURS: + return _parse_neighbours(data, context or {}) + return {"raw_hex": data.hex(), "request_type": request_type} + + +def _parse_status(data: bytes, pubkey_prefix: Optional[str] = None, offset: int = 0) -> dict: + """Parse status response (52 bytes).""" + res = {} + if pubkey_prefix is None and len(data) >= 8: + res["pubkey_pre"] = data[2:8].hex() + offset = 8 + else: + res["pubkey_pre"] = pubkey_prefix or "" + res["bat"] = int.from_bytes(data[offset : offset + 2], byteorder="little") + res["tx_queue_len"] = int.from_bytes(data[offset + 2 : offset + 4], byteorder="little") + res["noise_floor"] = int.from_bytes( + data[offset + 4 : offset + 6], byteorder="little", signed=True + ) + res["last_rssi"] = int.from_bytes( + data[offset + 6 : offset + 8], byteorder="little", signed=True + ) + res["nb_recv"] = int.from_bytes(data[offset + 8 : offset + 12], byteorder="little") + res["nb_sent"] = int.from_bytes(data[offset + 12 : offset + 16], byteorder="little") + res["airtime"] = int.from_bytes(data[offset + 16 : offset + 20], byteorder="little") + res["uptime"] = int.from_bytes(data[offset + 20 : offset + 24], byteorder="little") + res["sent_flood"] = int.from_bytes(data[offset + 24 : offset + 28], byteorder="little") + res["sent_direct"] = int.from_bytes(data[offset + 28 : offset + 32], byteorder="little") + res["recv_flood"] = int.from_bytes(data[offset + 32 : offset + 36], byteorder="little") + res["recv_direct"] = int.from_bytes(data[offset + 36 : offset + 40], byteorder="little") + res["full_evts"] = int.from_bytes(data[offset + 40 : offset + 42], byteorder="little") + res["last_snr"] = ( + int.from_bytes(data[offset + 42 : offset + 44], byteorder="little", signed=True) / 4 + ) + res["direct_dups"] = int.from_bytes(data[offset + 44 : offset + 46], byteorder="little") + res["flood_dups"] = int.from_bytes(data[offset + 46 : offset + 48], byteorder="little") + res["rx_airtime"] = int.from_bytes(data[offset + 48 : offset + 52], byteorder="little") + return res + + +def _parse_telemetry(data: bytes) -> dict: + """Telemetry: Cayenne LPP or raw. Return dict with raw_hex; optional LPP if cayennelpp available.""" + out: dict = {"raw_hex": data.hex()} + try: + from cayennelpp import LppFrame + frame = LppFrame.from_bytes(data) + out["lpp"] = [{"channel": d.channel, "type": d.type_id, "value": d.data} for d in frame.data] + except Exception: + pass + return out + + +def _parse_mma(data: bytes) -> dict: + """MMA: LPP min/max/avg or raw.""" + out: dict = {"raw_hex": data.hex()} + try: + from cayennelpp import LppFrame + frame = LppFrame.from_bytes(data) + out["mma"] = [{"channel": d.channel, "type": d.type_id, "data": d.data} for d in frame.data] + except Exception: + pass + return out + + +def _parse_acl(buf: bytes) -> dict: + """ACL: 7-byte entries (key 6 + perm 1).""" + res = [] + i = 0 + while i + 7 <= len(buf): + key = buf[i : i + 6].hex() + perm = buf[i + 6] + if key != "000000000000": + res.append({"key": key, "perm": perm}) + i += 7 + return {"acl": res} + + +def _parse_neighbours(data: bytes, context: dict) -> dict: + """Neighbours: count(2) + results_count(2) + entries (pubkey_prefix + secs_ago(4) + snr(1)).""" + if len(data) < 4: + return {"raw_hex": data.hex()} + pk_plen = context.get("pubkey_prefix_length", 6) + neighbours_count = int.from_bytes(data[0:2], "little", signed=True) + results_count = int.from_bytes(data[2:4], "little", signed=True) + neighbours_list = [] + i = 4 + for _ in range(results_count): + if i + pk_plen + 4 + 1 > len(data): + break + pubkey = data[i : i + pk_plen].hex() + i += pk_plen + secs_ago = int.from_bytes(data[i : i + 4], "little", signed=True) + i += 4 + snr = int.from_bytes(data[i : i + 1], "little", signed=True) / 4 + i += 1 + neighbours_list.append({"pubkey": pubkey, "secs_ago": secs_ago, "snr": snr}) + return { + "neighbours_count": neighbours_count, + "results_count": results_count, + "neighbours": neighbours_list, + } diff --git a/src/pymc_core/companion/channel_store.py b/src/pymc_core/companion/channel_store.py new file mode 100644 index 0000000..4173a20 --- /dev/null +++ b/src/pymc_core/companion/channel_store.py @@ -0,0 +1,83 @@ +"""In-memory channel storage compatible with MeshNode's channel_db interface.""" + +from typing import Optional + +from .constants import DEFAULT_MAX_CHANNELS +from .models import Channel + + +class ChannelStore: + """In-memory channel storage compatible with MeshNode's channel_db interface. + + Provides both the interface expected by GroupTextHandler (get_channels returning + list of dicts) and companion radio operations (get/set/remove by index). + """ + + def __init__(self, max_channels: int = DEFAULT_MAX_CHANNELS): + self._channels: list[Optional[Channel]] = [None] * max_channels + self._max_channels = max_channels + + @property + def max_channels(self) -> int: + """Maximum number of channels (read-only). Used by companion protocol device info.""" + return self._max_channels + + # ------------------------------------------------------------------ + # Interface expected by GroupTextHandler / PacketBuilder + # ------------------------------------------------------------------ + + def get_channels(self) -> list[dict]: + """Return channels as list of dicts with 'name' and 'secret' keys. + + The secret is returned as a hex string, which is what the existing + GroupTextHandler and PacketBuilder expect. + """ + result = [] + for ch in self._channels: + if ch is not None: + result.append( + { + "name": ch.name, + "secret": ch.secret.hex(), + } + ) + return result + + # ------------------------------------------------------------------ + # Companion radio methods + # ------------------------------------------------------------------ + + def get(self, idx: int) -> Optional[Channel]: + """Get a channel by index. Returns None if index invalid or empty.""" + if 0 <= idx < self._max_channels: + return self._channels[idx] + return None + + def set(self, idx: int, channel: Channel) -> bool: + """Set a channel at the given index. Returns False if index out of range.""" + if 0 <= idx < self._max_channels: + self._channels[idx] = channel + return True + return False + + def remove(self, idx: int) -> bool: + """Remove a channel at the given index. Returns False if index invalid or already empty.""" + if 0 <= idx < self._max_channels and self._channels[idx] is not None: + self._channels[idx] = None + return True + return False + + def find_by_name(self, name: str) -> Optional[int]: + """Find a channel index by name. Returns None if not found.""" + for idx, ch in enumerate(self._channels): + if ch is not None and ch.name == name: + return idx + return None + + def get_count(self) -> int: + """Return the number of configured channels.""" + return sum(1 for ch in self._channels if ch is not None) + + def clear(self): + """Remove all channels.""" + self._channels = [None] * self._max_channels diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py new file mode 100644 index 0000000..974fced --- /dev/null +++ b/src/pymc_core/companion/companion_base.py @@ -0,0 +1,641 @@ +""" +CompanionBase - Shared logic for CompanionRadio and CompanionBridge. + +Provides stores, event handling, contact management, device configuration, +and push callbacks. Subclasses implement TX via MeshNode or packet_injector. +""" + +from __future__ import annotations + +import asyncio +import copy +import logging +import struct +import time +from collections import OrderedDict +from typing import Any, Callable, Optional + +from ..node.events import EventService, EventSubscriber, MeshEvents +from ..protocol import LocalIdentity, PacketBuilder +from ..protocol.constants import ( + ADVERT_FLAG_HAS_NAME, + ADVERT_FLAG_IS_CHAT_NODE, + ADVERT_FLAG_IS_REPEATER, + ADVERT_FLAG_IS_ROOM_SERVER, +) +from .channel_store import ChannelStore +from .constants import ( + ADV_TYPE_CHAT, + ADV_TYPE_REPEATER, + ADV_TYPE_ROOM, + ADV_TYPE_SENSOR, + ADVERT_LOC_SHARE, + DEFAULT_MAX_CHANNELS, + DEFAULT_MAX_CONTACTS, + DEFAULT_OFFLINE_QUEUE_SIZE, + MAX_SIGN_DATA_SIZE, + STATS_TYPE_CORE, + STATS_TYPE_PACKETS, + STATS_TYPE_RADIO, +) +from .contact_store import ContactStore +from .message_queue import MessageQueue +from .models import AdvertPath, Channel, Contact, NodePrefs, QueuedMessage +from .path_cache import PathCache +from .stats_collector import StatsCollector + +logger = logging.getLogger("CompanionBase") + +PUSH_CALLBACK_KEYS = [ + "message_received", + "channel_message_received", + "advert_received", + "contact_path_updated", + "send_confirmed", + "trace_received", + "node_discovered", + "login_result", + "telemetry_response", + "status_response", + "raw_data_received", + "binary_response", + "path_discovery_response", +] + + +class ResponseWaiter: + """Helper for awaiting async protocol/login responses.""" + + def __init__(self) -> None: + self.event = asyncio.Event() + self.data: dict = {"success": False, "text": None, "parsed": {}} + + def callback( + self, + success: bool, + text: str, + parsed_data: Optional[dict] = None, + ) -> None: + self.data["success"] = success + self.data["text"] = text + self.data["parsed"] = parsed_data or {} + self.event.set() + + async def wait(self, timeout: float = 10.0) -> dict: + try: + await asyncio.wait_for(self.event.wait(), timeout=timeout) + return self.data + except asyncio.TimeoutError: + return {**self.data, "timeout": True} + + +class _CompanionEventSubscriber(EventSubscriber): + """Bridges event service to companion push callbacks.""" + + def __init__(self, companion: CompanionBase) -> None: + self._companion = companion + + async def handle_event(self, event_type: str, data: dict) -> None: + await self._companion._handle_mesh_event(event_type, data) + + +def adv_type_to_flags(adv_type: int) -> int: + """Convert ADV_TYPE_* constant to advertisement flags byte.""" + if adv_type == ADV_TYPE_CHAT: + return ADVERT_FLAG_IS_CHAT_NODE + elif adv_type == ADV_TYPE_REPEATER: + return ADVERT_FLAG_IS_REPEATER + elif adv_type == ADV_TYPE_ROOM: + return ADVERT_FLAG_IS_ROOM_SERVER + elif adv_type == ADV_TYPE_SENSOR: + return 0x04 + return ADVERT_FLAG_IS_CHAT_NODE + + +class CompanionBase: + """Base class for companion implementations. + + Provides shared stores, event handling, contact management, device config, + and push callbacks. Subclasses implement TX (via node or packet_injector). + """ + + def _init_companion_stores( + self, + identity: LocalIdentity, + node_name: str = "pyMC", + adv_type: int = ADV_TYPE_CHAT, + max_contacts: int = DEFAULT_MAX_CONTACTS, + max_channels: int = DEFAULT_MAX_CHANNELS, + offline_queue_size: int = DEFAULT_OFFLINE_QUEUE_SIZE, + radio_config: Optional[dict] = None, + ) -> None: + """Initialize shared stores, prefs, event service, and push callbacks.""" + self._identity = identity + self._radio_config = radio_config or {} + self._running = False + + self.contacts = ContactStore(max_contacts) + self.channels = ChannelStore(max_channels) + self.message_queue = MessageQueue(offline_queue_size) + self.path_cache = PathCache() + self.stats = StatsCollector() + + self.prefs = NodePrefs( + node_name=node_name, + adv_type=adv_type, + tx_power_dbm=self._radio_config.get("power", self._radio_config.get("tx_power", 20)), + frequency_hz=self._radio_config.get("frequency", 915000000), + bandwidth_hz=self._radio_config.get("bandwidth", 250000), + spreading_factor=self._radio_config.get("spreading_factor", 10), + coding_rate=self._radio_config.get("coding_rate", 5), + ) + + self._custom_vars: dict[str, str] = {} + self._sign_buffer: Optional[bytearray] = None + self._flood_transport_key: Optional[bytes] = None + + self._event_service = EventService() + self._event_subscriber = _CompanionEventSubscriber(self) + self._event_service.subscribe_all(self._event_subscriber) + + self._push_callbacks: dict[str, list[Callable]] = { + k: [] for k in PUSH_CALLBACK_KEYS + } + + # Pending binary requests by tag (hex) for matching responses + self._pending_binary_requests: dict[str, dict] = {} + + # GRP_TXT dedup by packet hash: match Mesh.cpp behavior (only process when !_tables->hasSeen(pkt)), + # so companion queues one frame per logical message like the firmware. + self._seen_grp_txt: OrderedDict[str, float] = OrderedDict() + self._seen_grp_txt_ttl = 300 + self._seen_grp_txt_max = 1000 + + # ------------------------------------------------------------------------- + # Contact Management + # ------------------------------------------------------------------------- + + def get_contacts(self, since: int = 0) -> list[Contact]: + return self.contacts.get_all(since=since) + + def get_contact_by_key(self, pub_key: bytes) -> Optional[Contact]: + return self.contacts.get_by_key(pub_key) + + def get_contact_by_name(self, name: str) -> Optional[Contact]: + proxy = self.contacts.get_by_name(name) + if proxy: + return self.contacts.get_by_key(bytes.fromhex(proxy.public_key)) + return None + + def add_update_contact(self, contact: Contact) -> bool: + if contact.lastmod == 0: + contact.lastmod = int(time.time()) + return self.contacts.add(contact) + + def remove_contact(self, pub_key: bytes) -> bool: + return self.contacts.remove(pub_key) + + def export_contact(self, pub_key: Optional[bytes] = None) -> Optional[bytes]: + if pub_key is None: + key = self._identity.get_public_key() + name = self.prefs.node_name.encode("utf-8")[:32] + name = name + b"\x00" * (32 - len(name)) + lat = int(self.prefs.latitude * 1e6) + lon = int(self.prefs.longitude * 1e6) + return struct.pack( + "<32sB32sii", + key, + self.prefs.adv_type, + name, + lat, + lon, + ) + contact = self.contacts.get_by_key(pub_key) + if not contact: + return None + name = contact.name.encode("utf-8")[:32] + name = name + b"\x00" * (32 - len(name)) + lat = int(contact.gps_lat * 1e6) + lon = int(contact.gps_lon * 1e6) + return struct.pack( + "<32sB32sii", + contact.public_key, + contact.adv_type, + name, + lat, + lon, + ) + + def import_contact(self, packet_data: bytes) -> bool: + if len(packet_data) < 73: + logger.warning(f"Import data too short: {len(packet_data)} bytes") + return False + try: + pub_key = packet_data[:32] + adv_type = packet_data[32] + name_raw = packet_data[33:65] + lat, lon = struct.unpack_from(" None: + """Set the node's advertised name (max 31 chars).""" + self.prefs.node_name = name[:31] + + def set_advert_latlon(self, lat: float, lon: float) -> None: + if not (-90.0 <= lat <= 90.0): + raise ValueError(f"Latitude out of range: {lat}") + if not (-180.0 <= lon <= 180.0): + raise ValueError(f"Longitude out of range: {lon}") + self.prefs.latitude = lat + self.prefs.longitude = lon + + def set_radio_params(self, freq_hz: int, bw_hz: int, sf: int, cr: int) -> bool: + if not (5 <= sf <= 12): + raise ValueError(f"Spreading factor out of range: {sf}") + if not (5 <= cr <= 8): + raise ValueError(f"Coding rate out of range: {cr}") + self.prefs.frequency_hz = freq_hz + self.prefs.bandwidth_hz = bw_hz + self.prefs.spreading_factor = sf + self.prefs.coding_rate = cr + return True + + def set_tx_power(self, power_dbm: int) -> bool: + self.prefs.tx_power_dbm = power_dbm + return True + + def set_tuning_params(self, rx_delay: float, airtime_factor: float) -> None: + self.prefs.rx_delay_base = rx_delay + self.prefs.airtime_factor = airtime_factor + + def get_tuning_params(self) -> tuple[float, float]: + return (self.prefs.rx_delay_base, self.prefs.airtime_factor) + + def set_other_params( + self, + manual_add: int, + telemetry_modes: int, + advert_loc_policy: int, + multi_acks: int, + ) -> None: + self.prefs.manual_add_contacts = manual_add + self.prefs.telemetry_mode_base = telemetry_modes & 0x03 + self.prefs.telemetry_mode_location = (telemetry_modes >> 2) & 0x03 + self.prefs.telemetry_mode_environment = (telemetry_modes >> 4) & 0x03 + self.prefs.advert_loc_policy = advert_loc_policy + self.prefs.multi_acks = multi_acks + + def get_self_info(self) -> NodePrefs: + return copy.copy(self.prefs) + + def get_public_key(self) -> bytes: + return self._identity.get_public_key() + + # ------------------------------------------------------------------------- + # Path & Routing + # ------------------------------------------------------------------------- + + def reset_path(self, pub_key: bytes) -> bool: + contact = self.contacts.get_by_key(pub_key) + if not contact: + return False + contact.out_path_len = -1 + contact.out_path = b"" + self.contacts.update(contact) + return True + + def get_advert_path(self, pub_key_prefix: bytes) -> Optional[AdvertPath]: + return self.path_cache.get_by_prefix(pub_key_prefix) + + # ------------------------------------------------------------------------- + # Channel Management + # ------------------------------------------------------------------------- + + def get_channel(self, idx: int) -> Optional[Channel]: + return self.channels.get(idx) + + def set_channel(self, idx: int, name: str, secret: bytes) -> bool: + # MeshCore DataStore uses 32-byte secret; GroupTextHandler uses up to 32 for HMAC + if len(secret) < 32: + secret = secret + b"\x00" * (32 - len(secret)) + elif len(secret) > 32: + secret = secret[:32] + return self.channels.set(idx, Channel(name=name[:32], secret=secret)) + + # ------------------------------------------------------------------------- + # Signing Pipeline + # ------------------------------------------------------------------------- + + def sign_start(self) -> int: + self._sign_buffer = bytearray() + return MAX_SIGN_DATA_SIZE + + def sign_data(self, data: bytes) -> bool: + if self._sign_buffer is None: + logger.warning("sign_data called without sign_start") + return False + if len(self._sign_buffer) + len(data) > MAX_SIGN_DATA_SIZE: + logger.warning("Sign data would overflow buffer") + return False + self._sign_buffer.extend(data) + return True + + def sign_finish(self) -> Optional[bytes]: + if self._sign_buffer is None: + logger.warning("sign_finish called without sign_start") + return None + try: + return self._identity.sign(bytes(self._sign_buffer)) + except Exception as e: + logger.error(f"Signing error: {e}") + return None + finally: + self._sign_buffer = None + + # ------------------------------------------------------------------------- + # Key Management + # ------------------------------------------------------------------------- + + def export_private_key(self) -> bytes: + return self._identity.get_signing_key_bytes() + + # ------------------------------------------------------------------------- + # Flood Scope + # ------------------------------------------------------------------------- + + def set_flood_scope(self, transport_key: Optional[bytes] = None) -> None: + if transport_key and len(transport_key) >= 16: + self._flood_transport_key = transport_key[:16] + else: + self._flood_transport_key = None + + # ------------------------------------------------------------------------- + # Statistics (subclasses may override _get_radio_stats for STATS_TYPE_RADIO) + # ------------------------------------------------------------------------- + + def get_stats(self, stats_type: int = STATS_TYPE_PACKETS) -> dict: + if stats_type == STATS_TYPE_CORE: + return { + "uptime_secs": self.stats.get_uptime_secs(), + "queue_len": self.message_queue.count, + "contacts_count": self.contacts.get_count(), + "channels_count": self.channels.get_count(), + } + elif stats_type == STATS_TYPE_RADIO: + return self._get_radio_stats() + return self.stats.get_totals() + + def _get_radio_stats(self) -> dict: + """Override in CompanionRadio for hardware RSSI/SNR. Default: prefs only.""" + return { + "frequency_hz": self.prefs.frequency_hz, + "bandwidth_hz": self.prefs.bandwidth_hz, + "spreading_factor": self.prefs.spreading_factor, + "coding_rate": self.prefs.coding_rate, + "tx_power_dbm": self.prefs.tx_power_dbm, + } + + # ------------------------------------------------------------------------- + # Custom Variables + # ------------------------------------------------------------------------- + + def get_custom_vars(self) -> dict[str, str]: + return dict(self._custom_vars) + + def set_custom_var(self, name: str, value: str) -> bool: + self._custom_vars[name] = value + return True + + # ------------------------------------------------------------------------- + # Auto-Add Configuration + # ------------------------------------------------------------------------- + + def get_autoadd_config(self) -> int: + return self.prefs.autoadd_config + + def set_autoadd_config(self, config: int) -> None: + self.prefs.autoadd_config = config + + # ------------------------------------------------------------------------- + # Push Callbacks + # ------------------------------------------------------------------------- + + def on_message_received(self, callback: Callable) -> None: + self._push_callbacks["message_received"].append(callback) + + def on_channel_message_received(self, callback: Callable) -> None: + self._push_callbacks["channel_message_received"].append(callback) + + def on_advert_received(self, callback: Callable) -> None: + self._push_callbacks["advert_received"].append(callback) + + def on_contact_path_updated(self, callback: Callable) -> None: + self._push_callbacks["contact_path_updated"].append(callback) + + def on_send_confirmed(self, callback: Callable) -> None: + self._push_callbacks["send_confirmed"].append(callback) + + def on_trace_received(self, callback: Callable) -> None: + self._push_callbacks["trace_received"].append(callback) + + def on_node_discovered(self, callback: Callable) -> None: + self._push_callbacks["node_discovered"].append(callback) + + def on_login_result(self, callback: Callable) -> None: + self._push_callbacks["login_result"].append(callback) + + def on_telemetry_response(self, callback: Callable) -> None: + self._push_callbacks["telemetry_response"].append(callback) + + def on_status_response(self, callback: Callable) -> None: + self._push_callbacks["status_response"].append(callback) + + def on_raw_data_received(self, callback: Callable) -> None: + self._push_callbacks["raw_data_received"].append(callback) + + def on_binary_response(self, callback: Callable) -> None: + """Register callback for PUSH_CODE_BINARY_RESPONSE (0x8C). Callback(tag_bytes, response_data, ...).""" + self._push_callbacks["binary_response"].append(callback) + + def on_path_discovery_response(self, callback: Callable) -> None: + """Register callback for path discovery response (PUSH 0x8D). Callback(tag_bytes, contact_pubkey, out_path, in_path).""" + self._push_callbacks["path_discovery_response"].append(callback) + + def register_binary_request( + self, + tag_hex: str, + request_type: int, + timeout_seconds: float, + pubkey_prefix: str = "", + context: Optional[dict] = None, + ) -> None: + """Register a pending binary request for matching responses. Call cleanup_expired_requests first.""" + self._pending_binary_requests[tag_hex] = { + "request_type": request_type, + "pubkey_prefix": pubkey_prefix, + "expires_at": time.time() + timeout_seconds, + "context": context or {}, + } + + def cleanup_expired_binary_requests(self) -> None: + """Remove expired entries from _pending_binary_requests.""" + now = time.time() + expired = [ + tag for tag, info in self._pending_binary_requests.items() + if now > info["expires_at"] + ] + for tag in expired: + del self._pending_binary_requests[tag] + + async def _on_binary_response( + self, + tag_bytes: bytes, + response_data: bytes, + path_info: Optional[tuple] = None, + ) -> None: + """Called by ProtocolResponseHandler when a binary response (tag + data, optional path) is received.""" + if path_info is not None: + if await self._try_handle_path_discovery(tag_bytes, path_info): + return + self.cleanup_expired_binary_requests() + tag_hex = tag_bytes.hex() + info = self._pending_binary_requests.pop(tag_hex, None) + if not info: + logger.debug(f"Binary response for unknown tag {tag_hex}") + await self._fire_callbacks("binary_response", tag_bytes, response_data) + return + request_type = info["request_type"] + pubkey_prefix = info.get("pubkey_prefix", "") + context = info.get("context", {}) + parsed = None + try: + from . import binary_parsing + parsed = binary_parsing.parse_binary_response( + request_type, response_data, pubkey_prefix=pubkey_prefix, context=context + ) + except Exception as e: + logger.debug(f"Binary response parse for type {request_type}: {e}") + await self._fire_callbacks( + "binary_response", tag_bytes, response_data, parsed, request_type + ) + + async def _try_handle_path_discovery( + self, tag_bytes: bytes, path_info: tuple + ) -> bool: + """If this tag is a pending path discovery, fire path_discovery_response and return True. Override in bridge.""" + return False + + # ------------------------------------------------------------------------- + # Event Handling (shared) + # ------------------------------------------------------------------------- + + async def _handle_mesh_event(self, event_type: str, data: dict) -> None: + try: + if event_type == MeshEvents.NEW_MESSAGE: + await self._handle_new_message(data) + elif event_type == MeshEvents.NEW_CHANNEL_MESSAGE: + await self._handle_new_channel_message(data) + elif event_type == MeshEvents.NEW_CONTACT: + await self._fire_callbacks("node_discovered", data) + elif event_type == MeshEvents.CONTACT_UPDATED: + pass + elif event_type == MeshEvents.NODE_DISCOVERED: + await self._fire_callbacks("node_discovered", data) + elif event_type == MeshEvents.TELEMETRY_UPDATED: + await self._fire_callbacks("telemetry_response", data) + except Exception as e: + logger.error(f"Error handling mesh event {event_type}: {e}") + + async def _handle_new_message(self, data: dict) -> None: + sender_key_hex = data.get("contact_pubkey", "") + sender_key = bytes.fromhex(sender_key_hex) if sender_key_hex else b"" + # Handler publishes "message_text"; accept "text" for compatibility + message_text = (data.get("message_text") or data.get("text") or "").rstrip("\x00") + msg = QueuedMessage( + sender_key=sender_key, + txt_type=data.get("txt_type", data.get("flags", 0)), + timestamp=data.get("timestamp", int(time.time())), + text=message_text, + is_channel=False, + path_len=0, + ) + self.message_queue.push(msg) + await self._fire_callbacks( + "message_received", + sender_key, + message_text, + msg.timestamp, + msg.txt_type, + ) + + async def _handle_new_channel_message(self, data: dict) -> None: + # Deduplicate by packet hash so we queue one frame per logical message, matching + # firmware: Mesh.cpp only calls onChannelMessageRecv when !_tables->hasSeen(pkt). + pkt_hash = data.get("packet_hash") + if pkt_hash: + now = time.time() + if pkt_hash in self._seen_grp_txt: + return + expired = [k for k, ts in self._seen_grp_txt.items() if now - ts > self._seen_grp_txt_ttl] + for k in expired: + del self._seen_grp_txt[k] + self._seen_grp_txt[pkt_hash] = now + if len(self._seen_grp_txt) > self._seen_grp_txt_max: + self._seen_grp_txt.popitem(last=False) + + path_len = data.get("path_len", 0) + channel_name = data.get("channel_name", "") + # Resolve channel index so sync_next_message returns correct channel_idx in the frame + channel_idx = 0 + if getattr(self, "channels", None) and hasattr(self.channels, "find_by_name"): + idx = self.channels.find_by_name(channel_name) + if idx is not None: + channel_idx = idx + # MeshCore client expects "SenderName: Message" format in text field; it parses to show + # sender and message separately. Use full_content (not message_text) so client can split. + # Strip trailing nulls so frame matches firmware (exact string length, no padding). + display_text = (data.get("full_content", data.get("message_text", "")) or "").rstrip("\x00") + msg = QueuedMessage( + sender_key=b"", + txt_type=0, + timestamp=data.get("timestamp", int(time.time())), + text=display_text, + is_channel=True, + channel_idx=channel_idx, + path_len=path_len, + ) + self.message_queue.push(msg) + await self._fire_callbacks( + "channel_message_received", + data.get("channel_name", ""), + data.get("sender_name", ""), + display_text, + msg.timestamp, + path_len, + ) + + async def _fire_callbacks(self, event_name: str, *args: Any) -> None: + for callback in self._push_callbacks.get(event_name, []): + try: + if asyncio.iscoroutinefunction(callback): + await callback(*args) + else: + callback(*args) + except Exception as e: + logger.error(f"Error in {event_name} callback: {e}") diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py new file mode 100644 index 0000000..dcc6106 --- /dev/null +++ b/src/pymc_core/companion/companion_bridge.py @@ -0,0 +1,805 @@ +""" +CompanionBridge - Repeater-integrated companion mode. + +Provides the same API as CompanionRadio but uses a shared dispatcher via +packet_injector. No radio ownership; host (repeater) injects packets via +process_received_packet and TX goes through packet_injector. +""" + +from __future__ import annotations + +import asyncio +import logging +import random +import time +from typing import Any, Callable, Optional + +from ..node.events import EventService, EventSubscriber, MeshEvents +from ..node.handlers import ( + AdvertHandler, + GroupTextHandler, + LoginResponseHandler, + PathHandler, + ProtocolResponseHandler, + TextMessageHandler, +) +from ..node.handlers.login_server import LoginServerHandler +from ..protocol import LocalIdentity, PacketBuilder +from ..protocol import Packet +from ..protocol.constants import ( + ADVERT_FLAG_HAS_LOCATION, + ADVERT_FLAG_HAS_NAME, + PAYLOAD_TYPE_ACK, + PAYLOAD_TYPE_ADVERT, + PAYLOAD_TYPE_ANON_REQ, + PAYLOAD_TYPE_CONTROL, + PAYLOAD_TYPE_GRP_TXT, + PAYLOAD_TYPE_PATH, + PAYLOAD_TYPE_RESPONSE, + PAYLOAD_TYPE_TXT_MSG, + ROUTE_TYPE_FLOOD, + ROUTE_TYPE_TRANSPORT_FLOOD, +) +from ..protocol.constants import REQ_TYPE_GET_TELEMETRY_DATA, TELEM_PERM_BASE +from .companion_base import CompanionBase, ResponseWaiter, adv_type_to_flags +from .constants import ( + ADV_TYPE_CHAT, + ADVERT_LOC_SHARE, + DEFAULT_MAX_CHANNELS, + DEFAULT_MAX_CONTACTS, + DEFAULT_OFFLINE_QUEUE_SIZE, + TXT_TYPE_PLAIN, +) +from .models import Contact, QueuedMessage, SentResult + +logger = logging.getLogger("CompanionBridge") + + +# --------------------------------------------------------------------------- +# Bridge ACK handler: fires send_confirmed when ACK CRC matches a pending send +# --------------------------------------------------------------------------- + +MAX_PENDING_ACK_CRCS = 64 + + +class _BridgeAckHandler: + """Handles discrete ACK packets and PathHandler stub. Fires send_confirmed when ACK CRC matches a pending send.""" + + def __init__(self, bridge: "CompanionBridge") -> None: + self._bridge = bridge + + @staticmethod + def payload_type() -> int: + return PAYLOAD_TYPE_ACK + + async def __call__(self, packet: Packet) -> None: + if not packet.payload or len(packet.payload) != 4: + return + crc = int.from_bytes(packet.payload, "little") + if crc in self._bridge._pending_ack_crcs: + self._bridge._pending_ack_crcs.discard(crc) + await self._bridge._fire_callbacks("send_confirmed", crc) + + async def process_path_ack_variants(self, packet: Packet) -> Optional[int]: + return None + + async def _notify_ack_received(self, crc: int) -> None: + pass + + +# --------------------------------------------------------------------------- +# Main CompanionBridge class +# --------------------------------------------------------------------------- + + +class CompanionBridge(CompanionBase): + """Repeater-integrated companion: shared dispatcher, packet_injector for TX. + + No MeshNode, no radio. Host calls process_received_packet when packets + destined for this companion arrive. All TX goes through packet_injector. + """ + + def __init__( + self, + identity: LocalIdentity, + packet_injector: Callable[..., Any], + node_name: str = "pyMC", + adv_type: int = ADV_TYPE_CHAT, + max_contacts: int = DEFAULT_MAX_CONTACTS, + max_channels: int = DEFAULT_MAX_CHANNELS, + offline_queue_size: int = DEFAULT_OFFLINE_QUEUE_SIZE, + radio_config: Optional[dict] = None, + authenticate_callback: Optional[Callable[..., tuple[bool, int]]] = None, + ): + """Initialise the companion bridge.""" + self._init_companion_stores( + identity=identity, + node_name=node_name, + adv_type=adv_type, + max_contacts=max_contacts, + max_channels=max_channels, + offline_queue_size=offline_queue_size, + radio_config=radio_config, + ) + self._packet_injector = packet_injector + + async def _send_packet(pkt: Packet, wait_for_ack: bool = False) -> bool: + return await self._packet_injector(pkt, wait_for_ack=wait_for_ack) + + def _login_send_callback(pkt: Packet, delay_ms: int) -> None: + async def _delayed_send() -> None: + await asyncio.sleep(delay_ms / 1000.0) + await self._packet_injector(pkt, wait_for_ack=False) + asyncio.create_task(_delayed_send()) + + _log = lambda msg: logger.debug(f"[CompanionBridge] {msg}") + + self._pending_ack_crcs: set[int] = set() + self._pending_discovery_tags: set[int] = set() + ack_handler = _BridgeAckHandler(self) + protocol_response_handler = ProtocolResponseHandler( + _log, identity, self.contacts + ) + login_response_handler = LoginResponseHandler( + identity, self.contacts, _log + ) + login_response_handler.set_protocol_response_handler( + protocol_response_handler + ) + path_handler = PathHandler( + _log, ack_handler, protocol_response_handler, login_response_handler + ) + + auth_cb = authenticate_callback + if auth_cb is None: + def _reject_all(*args, **kwargs) -> tuple[bool, int]: + return (False, 0) + auth_cb = _reject_all + + login_server_handler = LoginServerHandler( + identity, _log, authenticate_callback=auth_cb, is_room_server=False + ) + login_server_handler.set_send_packet_callback(_login_send_callback) + + self._handlers: dict[int, Any] = { + PAYLOAD_TYPE_ACK: ack_handler, + PAYLOAD_TYPE_TXT_MSG: TextMessageHandler( + identity, + self.contacts, + _log, + _send_packet, + self._event_service, + self._radio_config, + ), + PAYLOAD_TYPE_ADVERT: AdvertHandler( + _log, event_service=self._event_service + ), + PAYLOAD_TYPE_PATH: path_handler, + PAYLOAD_TYPE_ANON_REQ: login_server_handler, + PAYLOAD_TYPE_GRP_TXT: GroupTextHandler( + identity, + self.contacts, + _log, + _send_packet, + self.channels, + self._event_service, + node_name, + ), + PAYLOAD_TYPE_RESPONSE: login_response_handler, + } + + self._protocol_response_handler = protocol_response_handler + self._login_response_handler = login_response_handler + self._text_handler = self._handlers[PAYLOAD_TYPE_TXT_MSG] + protocol_response_handler.set_binary_response_callback(self._on_binary_response) + + # ------------------------------------------------------------------------- + # RX Entry Point + # ------------------------------------------------------------------------- + + async def process_received_packet(self, packet: Packet) -> None: + """Process a packet destined for this companion.""" + ptype = packet.header >> 2 & 0x0F + route_type = packet.header & 0x03 + is_flood = route_type in (ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD) + self.stats.record_rx(is_flood=is_flood) + + handler = self._handlers.get(ptype) + if handler: + try: + result = await handler(packet) + if ptype == PAYLOAD_TYPE_ADVERT and result: + contact = self._update_stores_from_advert(packet, result) + if contact: + await self._fire_callbacks("advert_received", contact) + except Exception as e: + logger.error(f"Handler error for type {ptype:02X}: {e}") + + def _update_stores_from_advert(self, packet: Packet, advert_data: dict): + """Update ContactStore and PathCache from advert result. Returns the Contact or None.""" + try: + from .models import AdvertPath + + pub_key = bytes.fromhex(advert_data.get("public_key", "")) + if len(pub_key) < 7: + return None + name = advert_data.get("name", "") + if not name: + return None + # Inbound path: route the advert took to reach us (for discovery list / advert path display). + # Stored in path_cache only; contact.out_path is separate and set elsewhere (e.g. path discovery). + path_len = getattr(packet, "path_len", 0) or 0 + path = getattr(packet, "path", bytearray()) or bytearray() + effective_len = path_len if path_len > 0 else len(path) + inbound_path = bytes(path[:effective_len]) if effective_len > 0 else b"" + now = int(time.time()) + last_advert_ts = advert_data.get("advert_timestamp", 0) + if last_advert_ts > now: + last_advert_ts = now + # Contact: out_path is for sending to this contact; leave unknown (-1) until set by path update. + contact = Contact( + public_key=pub_key, + name=name, + adv_type=advert_data.get("contact_type_id", 0), + gps_lat=advert_data.get("latitude", 0.0), + gps_lon=advert_data.get("longitude", 0.0), + lastmod=now, + last_advert_timestamp=last_advert_ts, + out_path_len=-1, + out_path=b"", + ) + self.contacts.add(contact) + + # Path cache: store inbound path (path advert took to get here) for discovery list display. + self.path_cache.update( + AdvertPath( + public_key_prefix=pub_key[:7], + name=name, + path_len=len(inbound_path), + path=inbound_path, + recv_timestamp=int(time.time()), + ) + ) + return contact + except Exception as e: + logger.error(f"Error updating stores from advert: {e}") + return None + + # ------------------------------------------------------------------------- + # Lifecycle + # ------------------------------------------------------------------------- + + async def start(self) -> None: + self._running = True + logger.info( + f"CompanionBridge started: name={self.prefs.node_name}, " + f"key={self._identity.get_public_key().hex()[:16]}..." + ) + + async def stop(self) -> None: + self._running = False + logger.info("CompanionBridge stopped") + + @property + def is_running(self) -> bool: + return self._running + + # ------------------------------------------------------------------------- + # Advertisement + # ------------------------------------------------------------------------- + + async def advertise(self, flood: bool = True) -> bool: + flags = adv_type_to_flags(self.prefs.adv_type) + flags |= ADVERT_FLAG_HAS_NAME + lat, lon = 0.0, 0.0 + if self.prefs.advert_loc_policy == ADVERT_LOC_SHARE: + lat, lon = self.prefs.latitude, self.prefs.longitude + if lat != 0.0 or lon != 0.0: + flags |= ADVERT_FLAG_HAS_LOCATION + route = "flood" if flood else "direct" + pkt = PacketBuilder.create_advert( + local_identity=self._identity, + name=self.prefs.node_name, + lat=lat, + lon=lon, + flags=flags, + route_type=route, + ) + success = await self._packet_injector(pkt, wait_for_ack=False) + if success: + self.stats.record_tx(is_flood=flood) + else: + self.stats.record_tx_error() + return success + + # ------------------------------------------------------------------------- + # Messaging + # ------------------------------------------------------------------------- + + async def send_text_message( + self, + pub_key: bytes, + text: str, + txt_type: int = TXT_TYPE_PLAIN, + attempt: int = 1, + ) -> SentResult: + contact = self.contacts.get_by_key(pub_key) + if not contact: + logger.warning(f"Contact not found for key {pub_key.hex()[:12]}...") + return SentResult(success=False) + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return SentResult(success=False) + try: + pkt, ack_crc = PacketBuilder.create_text_message( + contact=proxy, + local_identity=self._identity, + message=text, + attempt=attempt, + ) + if len(self._pending_ack_crcs) < MAX_PENDING_ACK_CRCS: + self._pending_ack_crcs.add(ack_crc) + success = await self._packet_injector(pkt, wait_for_ack=True) + is_flood = contact.out_path_len <= 0 + if success: + self.stats.record_tx(is_flood=is_flood) + else: + self.stats.record_tx_error() + return SentResult( + success=success, + is_flood=is_flood, + expected_ack=ack_crc, + timeout_ms=None, + ) + except Exception as e: + logger.error(f"Error sending text message: {e}") + self.stats.record_tx_error() + return SentResult(success=False) + + async def send_channel_message(self, channel_idx: int, text: str) -> bool: + channel = self.channels.get(channel_idx) + if not channel: + logger.warning(f"Channel {channel_idx} not found") + return False + try: + pkt = PacketBuilder.create_group_datagram( + group_name=channel.name, + local_identity=self._identity, + message=text, + sender_name=self.prefs.node_name, + channels_config=self.channels.get_channels(), + ) + success = await self._packet_injector(pkt, wait_for_ack=False) + if success: + self.stats.record_tx(is_flood=True) + else: + self.stats.record_tx_error() + return success + except Exception as e: + logger.error(f"Error sending channel message: {e}") + self.stats.record_tx_error() + return False + + def sync_next_message(self) -> Optional[QueuedMessage]: + return self.message_queue.pop() + + async def send_raw_data( + self, + dest_key: bytes, + data: bytes, + path: Optional[bytes] = None, + ) -> SentResult: + contact = self.contacts.get_by_key(dest_key) + if not contact: + return SentResult(success=False) + try: + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return SentResult(success=False) + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=0x00, + data=data, + ) + success = await self._packet_injector(pkt, wait_for_ack=False) + return SentResult(success=success) + except Exception as e: + logger.error(f"Error sending raw data: {e}") + return SentResult(success=False) + + # ------------------------------------------------------------------------- + # Contact Management (share_contact override) + # ------------------------------------------------------------------------- + + async def share_contact(self, pub_key: bytes) -> bool: + contact = self.contacts.get_by_key(pub_key) + if not contact: + return False + try: + pkt = PacketBuilder.create_advert( + local_identity=self._identity, + name=contact.name, + flags=adv_type_to_flags(contact.adv_type) | ADVERT_FLAG_HAS_NAME, + route_type="direct", + ) + return await self._packet_injector(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Error sharing contact: {e}") + return False + + # ------------------------------------------------------------------------- + # Path & Routing + # ------------------------------------------------------------------------- + + async def send_trace_path( + self, + pub_key: bytes, + tag: int, + auth_code: int, + flags: int = 0, + ) -> bool: + contact = self.contacts.get_by_key(pub_key) + if not contact: + return False + path = list(contact.out_path) if contact.out_path else [] + if not path: + path = [contact.public_key[0]] + try: + pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=path) + return await self._packet_injector(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Error sending trace: {e}") + return False + + async def send_trace_path_raw( + self, + tag: int, + auth_code: int, + flags: int, + path_bytes: bytes, + ) -> bool: + """Send a trace packet with an explicit path (e.g. from CMD_SEND_TRACE_PATH). Matches firmware behavior.""" + try: + path_list = list(path_bytes) + pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=path_list) + return await self._packet_injector(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Error sending trace (raw path): {e}") + return False + + async def _try_handle_path_discovery( + self, tag_bytes: bytes, path_info: tuple + ) -> bool: + """If tag is pending path discovery, fire path_discovery_response and return True.""" + out_path, in_path, contact_pubkey = path_info + tag_int = int.from_bytes(tag_bytes, "little") + if tag_int not in self._pending_discovery_tags: + return False + self._pending_discovery_tags.discard(tag_int) + await self._fire_callbacks( + "path_discovery_response", + tag_bytes, + contact_pubkey, + out_path, + in_path, + ) + return True + + async def send_path_discovery(self, pub_key: bytes) -> bool: + """Legacy: send path discovery without returning tag. Prefer send_path_discovery_req.""" + result = await self.send_path_discovery_req(pub_key) + return result.success + + async def send_path_discovery_req(self, pub_key: bytes) -> SentResult: + """Send path discovery (flood telemetry request with tag). Returns SentResult for RESP_CODE_SENT. + When path return arrives with matching tag, path_discovery_response is fired (PUSH 0x8D).""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return SentResult(success=False) + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return SentResult(success=False) + tag_int = random.randint(0, 0xFFFFFFFF) + tag_bytes = tag_int.to_bytes(4, "little") + # Firmware: REQ_TYPE_GET_TELEMETRY_DATA, ~TELEM_PERM_BASE, reserved(3), random(4) -> 9 bytes; tag is from sendRequest. + # We send tag(4) + type(1) + perm(1) + reserved(3) = 9 bytes so response echoes our tag. + inv_perm = 0xFF & ~TELEM_PERM_BASE + req_payload = tag_bytes + bytes( + [REQ_TYPE_GET_TELEMETRY_DATA, inv_perm, 0, 0, 0] + ) + old_path_len = contact.out_path_len + old_path = contact.out_path + contact.out_path_len = -1 + contact.out_path = b"" + self.contacts.update(contact) + try: + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, + data=req_payload, + ) + success = await self._packet_injector(pkt, wait_for_ack=False) + if success: + self._pending_discovery_tags.add(tag_int) + return SentResult( + success=success, + is_flood=True, + expected_ack=tag_int, + timeout_ms=10000, + ) + except Exception as e: + logger.error(f"Error in path discovery: {e}") + return SentResult(success=False) + finally: + current = self.contacts.get_by_key(pub_key) + if current and current.out_path_len == -1: + current.out_path_len = old_path_len + current.out_path = old_path + self.contacts.update(current) + + async def send_control_data(self, data: bytes) -> bool: + """Send a CONTROL packet (e.g. discovery request). data = first byte flags/type (0x80 set for DISCOVER_REQ) + payload. + Firmware: (cmd_frame[1] & 0x80) != 0, createControlData(&cmd_frame[1], len-1), sendZeroHop(resp). Returns True if sent.""" + if not data or len(data) > 254: + return False + if (data[0] & 0x80) == 0: + return False # firmware requires first byte to have 0x80 set (e.g. DISCOVER_REQ) + try: + pkt = Packet() + pkt.header = PacketBuilder._create_header(PAYLOAD_TYPE_CONTROL, route_type="direct") + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(data) + pkt.payload_len = len(data) + return await self._packet_injector(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Error sending control data: {e}") + return False + + # ------------------------------------------------------------------------- + # Key Management + # ------------------------------------------------------------------------- + + def import_private_key(self, key: bytes) -> bool: + try: + self._identity = LocalIdentity(seed=key) + logger.info( + f"Imported new identity: {self._identity.get_public_key().hex()[:16]}..." + ) + return True + except Exception as e: + logger.error(f"Error importing private key: {e}") + return False + + # ------------------------------------------------------------------------- + # Requests + # ------------------------------------------------------------------------- + + async def send_login(self, pub_key: bytes, password: str) -> dict: + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + dest_hash = bytes.fromhex(proxy.public_key)[0] + self._login_response_handler.store_login_password(dest_hash, password) + login_result = {"success": False, "data": {}} + login_event = asyncio.Event() + + def _login_cb(success: bool, data: dict) -> None: + login_result["success"] = success + login_result["data"] = data + login_event.set() + + self._login_response_handler.set_login_callback(_login_cb) + try: + pkt = PacketBuilder.create_login_packet( + contact=proxy, local_identity=self._identity, password=password + ) + await self._packet_injector(pkt, wait_for_ack=False) + try: + await asyncio.wait_for(login_event.wait(), timeout=10.0) + except asyncio.TimeoutError: + return {"success": False, "reason": "Login response timeout"} + data = login_result["data"] + return { + "success": login_result["success"], + "repeater": contact.name, + "is_admin": data.get("is_admin", False), + "keep_alive_interval": data.get("keep_alive_interval", 0), + "tag": data.get("timestamp", 0), + "acl_permissions": data.get("reserved", data.get("permissions", 0)), + "reason": "Login successful" if login_result["success"] else "Login failed", + } + except Exception as e: + logger.error(f"Login error: {e}") + return {"success": False, "reason": str(e)} + finally: + self._login_response_handler.set_login_callback(None) + self._login_response_handler.clear_login_password(dest_hash) + + async def send_status_request(self, pub_key: bytes) -> dict: + return await self.send_repeater_command(pub_key, "status") + + async def send_telemetry_request( + self, + pub_key: bytes, + want_base: bool = True, + want_location: bool = True, + want_environment: bool = True, + timeout: float = 10.0, + ) -> dict: + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + contact_hash = bytes.fromhex(proxy.public_key)[0] + waiter = ResponseWaiter() + self._protocol_response_handler.set_response_callback( + contact_hash, waiter.callback + ) + try: + inv = PacketBuilder._compute_inverse_perm_mask( + want_base, want_location, want_environment + ) + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, + data=bytes([inv]), + ) + await self._packet_injector(pkt, wait_for_ack=False) + result = await waiter.wait(timeout) + return { + "success": result.get("success", False), + "contact": contact.name, + "telemetry_data": result.get("parsed", {}), + "response_text": result.get("text"), + "reason": "Telemetry received" if result.get("success") else "Telemetry failed", + } + except Exception as e: + logger.error(f"Telemetry error: {e}") + return {"success": False, "reason": str(e)} + finally: + self._protocol_response_handler.clear_response_callback(contact_hash) + + async def send_binary_req( + self, pub_key: bytes, data: bytes, timeout_seconds: float = 15.0 + ) -> SentResult: + """Send binary request (CMD_SEND_BINARY_REQ). data = request_type(1) + optional payload. + Returns SentResult with expected_ack (4-byte tag as int) and timeout_ms for RESP_CODE_SENT. + """ + contact = self.contacts.get_by_key(pub_key) + if not contact: + return SentResult(success=False) + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return SentResult(success=False) + tag_int = random.randint(0, 0xFFFFFFFF) + tag_bytes = tag_int.to_bytes(4, "little") + tag_hex = tag_bytes.hex() + request_type = data[0] if len(data) >= 1 else 0 + req_payload = tag_bytes + data + self.cleanup_expired_binary_requests() + self.register_binary_request( + tag_hex, + request_type=request_type, + timeout_seconds=timeout_seconds, + pubkey_prefix=pub_key[:6].hex(), + ) + try: + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=0x02, + data=req_payload, + ) + success = await self._packet_injector(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Binary request send error: {e}") + self._pending_binary_requests.pop(tag_hex, None) + return SentResult(success=False) + if not success: + self._pending_binary_requests.pop(tag_hex, None) + return SentResult(success=False) + return SentResult( + success=True, + is_flood=contact.out_path_len <= 0, + expected_ack=tag_int, + timeout_ms=10000, + ) + + async def send_binary_request(self, pub_key: bytes, data: bytes) -> dict: + """Legacy: send binary request and wait for response via waiter. Prefer send_binary_req + on_binary_response.""" + return await self._send_protocol_request(pub_key, 0x02, data) + + async def send_anon_request(self, pub_key: bytes, data: bytes) -> dict: + return await self._send_protocol_request(pub_key, 0x07, data) + + async def _send_protocol_request( + self, pub_key: bytes, protocol_code: int, data: bytes + ) -> dict: + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + contact_hash = bytes.fromhex(proxy.public_key)[0] + waiter = ResponseWaiter() + self._protocol_response_handler.set_response_callback( + contact_hash, waiter.callback + ) + try: + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=protocol_code, + data=data, + ) + await self._packet_injector(pkt, wait_for_ack=False) + result = await waiter.wait(10.0) + return { + "success": result.get("success", False), + "response": result.get("text"), + "parsed_data": result.get("parsed", {}), + "reason": "Success" if result.get("success") else "Failed", + } + except Exception as e: + logger.error(f"Protocol request error: {e}") + return {"success": False, "reason": str(e)} + finally: + self._protocol_response_handler.clear_response_callback(contact_hash) + + async def send_repeater_command( + self, pub_key: bytes, command: str, parameters: Optional[str] = None + ) -> dict: + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + full_command = command + if parameters: + full_command += f" {parameters}" + response_data = {"text": None, "success": False} + response_event = asyncio.Event() + + def _response_cb(message_text: str, sender_contact: Any) -> None: + response_data["text"] = message_text + response_data["success"] = True + response_event.set() + + self._text_handler.set_command_response_callback(_response_cb) + try: + pkt, ack_crc = PacketBuilder.create_text_message( + contact=proxy, + local_identity=self._identity, + message=full_command, + attempt=1, + message_type="command", + ) + await self._packet_injector(pkt, wait_for_ack=True) + try: + await asyncio.wait_for(response_event.wait(), timeout=15.0) + except asyncio.TimeoutError: + pass + return { + "success": response_data["success"], + "repeater": contact.name, + "command": command, + "response": response_data["text"], + "reason": "Command successful" if response_data["success"] else "No response", + } + except Exception as e: + logger.error(f"Repeater command error: {e}") + return {"success": False, "reason": str(e)} + finally: + self._text_handler.set_command_response_callback(None) + diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py new file mode 100644 index 0000000..5f8b612 --- /dev/null +++ b/src/pymc_core/companion/companion_radio.py @@ -0,0 +1,501 @@ +""" +MeshCore Companion Radio - Python-native implementation. + +Provides the same feature set as the MeshCore companion radio firmware +(meshcore-dev/MeshCore/examples/companion_radio), implemented as a +high-level wrapper around MeshNode with in-memory contact, channel, +message queue, path cache, and statistics management. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Callable, Optional + +from ..node.node import MeshNode +from ..protocol import LocalIdentity, PacketBuilder +from ..protocol.constants import ADVERT_FLAG_HAS_LOCATION, ADVERT_FLAG_HAS_NAME +from .companion_base import CompanionBase, adv_type_to_flags +from .constants import ( + ADV_TYPE_CHAT, + ADVERT_LOC_SHARE, + DEFAULT_MAX_CHANNELS, + DEFAULT_MAX_CONTACTS, + DEFAULT_OFFLINE_QUEUE_SIZE, + STATS_TYPE_PACKETS, + TXT_TYPE_PLAIN, +) +from .models import QueuedMessage, SentResult + +logger = logging.getLogger("CompanionRadio") + + +class CompanionRadio(CompanionBase): + """Python-native MeshCore companion radio. + + Wraps MeshNode and augments it with application-layer state and services + that the C++ companion radio firmware provides: contact management, + messaging with offline queue, advertisement broadcasting, channel + management, path tracking, signing, telemetry, statistics, and device + configuration. + + Example: + ```python + from pymc_core import CompanionRadio, LocalIdentity + from pymc_core.hardware import KissModemWrapper + + radio = KissModemWrapper("/dev/ttyUSB0") + radio.connect() + identity = LocalIdentity() + companion = CompanionRadio(radio, identity, node_name="myNode") + + async def main(): + await companion.start() + print(f"Key: {companion.get_public_key().hex()}") + await companion.advertise() + await companion.stop() + + asyncio.run(main()) + ``` + """ + + def __init__( + self, + radio: Any, + identity: LocalIdentity, + node_name: str = "pyMC", + adv_type: int = ADV_TYPE_CHAT, + max_contacts: int = DEFAULT_MAX_CONTACTS, + max_channels: int = DEFAULT_MAX_CHANNELS, + offline_queue_size: int = DEFAULT_OFFLINE_QUEUE_SIZE, + radio_config: Optional[dict] = None, + ): + """Initialise the companion radio.""" + self._init_companion_stores( + identity=identity, + node_name=node_name, + adv_type=adv_type, + max_contacts=max_contacts, + max_channels=max_channels, + offline_queue_size=offline_queue_size, + radio_config=radio_config, + ) + self._radio = radio + self._dispatcher_task: Optional[asyncio.Task] = None + + self.node = MeshNode( + radio=radio, + local_identity=identity, + config={ + "node": {"name": node_name}, + "radio": self._radio_config, + }, + contacts=self.contacts, + channel_db=self.channels, + event_service=self._event_service, + ) + self._setup_packet_callbacks() + + # ------------------------------------------------------------------------- + # Lifecycle + # ------------------------------------------------------------------------- + + async def start(self) -> None: + if self._running: + logger.warning("CompanionRadio already running") + return + self._running = True + self._dispatcher_task = asyncio.create_task(self.node.start()) + logger.info( + f"CompanionRadio started: name={self.prefs.node_name}, " + f"key={self._identity.get_public_key().hex()[:16]}..." + ) + + async def stop(self) -> None: + self._running = False + if self._dispatcher_task: + self._dispatcher_task.cancel() + try: + await self._dispatcher_task + except asyncio.CancelledError: + pass + self._dispatcher_task = None + self.node.stop() + logger.info("CompanionRadio stopped") + + @property + def is_running(self) -> bool: + return self._running + + # ------------------------------------------------------------------------- + # Advertisement + # ------------------------------------------------------------------------- + + async def advertise(self, flood: bool = True) -> bool: + flags = adv_type_to_flags(self.prefs.adv_type) + flags |= ADVERT_FLAG_HAS_NAME + lat, lon = 0.0, 0.0 + if self.prefs.advert_loc_policy == ADVERT_LOC_SHARE: + lat = self.prefs.latitude + lon = self.prefs.longitude + if lat != 0.0 or lon != 0.0: + flags |= ADVERT_FLAG_HAS_LOCATION + route = "flood" if flood else "direct" + pkt = PacketBuilder.create_advert( + local_identity=self._identity, + name=self.prefs.node_name, + lat=lat, + lon=lon, + flags=flags, + route_type=route, + ) + success = await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) + if success: + self.stats.record_tx(is_flood=flood) + else: + self.stats.record_tx_error() + return success + + # ------------------------------------------------------------------------- + # Messaging + # ------------------------------------------------------------------------- + + async def send_text_message( + self, + pub_key: bytes, + text: str, + txt_type: int = TXT_TYPE_PLAIN, + attempt: int = 1, + ) -> SentResult: + contact = self.contacts.get_by_key(pub_key) + if not contact: + logger.warning(f"Contact not found for key {pub_key.hex()[:12]}...") + return SentResult(success=False) + try: + result = await self.node.send_text( + contact_name=contact.name, + message=text, + attempt=attempt, + ) + success = result.get("success", False) + is_flood = contact.out_path_len <= 0 + if success: + self.stats.record_tx(is_flood=is_flood) + else: + self.stats.record_tx_error() + return SentResult( + success=success, + is_flood=is_flood, + expected_ack=result.get("crc"), + timeout_ms=None, + ) + except Exception as e: + logger.error(f"Error sending text message: {e}") + self.stats.record_tx_error() + return SentResult(success=False) + + async def send_channel_message(self, channel_idx: int, text: str) -> bool: + channel = self.channels.get(channel_idx) + if not channel: + logger.warning(f"Channel {channel_idx} not found") + return False + try: + result = await self.node.send_group_text( + group_name=channel.name, + message=text, + ) + success = result.get("success", False) + if success: + self.stats.record_tx(is_flood=True) + else: + self.stats.record_tx_error() + return success + except Exception as e: + logger.error(f"Error sending channel message: {e}") + self.stats.record_tx_error() + return False + + def sync_next_message(self) -> Optional[QueuedMessage]: + return self.message_queue.pop() + + async def send_raw_data( + self, + dest_key: bytes, + data: bytes, + path: Optional[bytes] = None, + ) -> SentResult: + contact = self.contacts.get_by_key(dest_key) + if not contact: + logger.warning(f"Contact not found for raw data send: {dest_key.hex()[:12]}") + return SentResult(success=False) + try: + result = await self.node.send_protocol_request( + repeater_name=contact.name, + protocol_code=0x00, + data=data, + ) + return SentResult(success=result.get("success", False)) + except Exception as e: + logger.error(f"Error sending raw data: {e}") + return SentResult(success=False) + + # ------------------------------------------------------------------------- + # Contact Management (share_contact overrides base - uses node) + # ------------------------------------------------------------------------- + + async def share_contact(self, pub_key: bytes) -> bool: + contact = self.contacts.get_by_key(pub_key) + if not contact: + logger.warning(f"Contact not found for sharing: {pub_key.hex()[:12]}") + return False + try: + pkt = PacketBuilder.create_advert( + local_identity=self._identity, + name=contact.name, + flags=adv_type_to_flags(contact.adv_type) | ADVERT_FLAG_HAS_NAME, + route_type="direct", + ) + return await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Error sharing contact: {e}") + return False + + # ------------------------------------------------------------------------- + # Device Configuration (overrides for radio hardware) + # ------------------------------------------------------------------------- + + def set_advert_name(self, name: str) -> None: + super().set_advert_name(name) + self.node.node_name = self.prefs.node_name + + def set_radio_params(self, freq_hz: int, bw_hz: int, sf: int, cr: int) -> bool: + super().set_radio_params(freq_hz, bw_hz, sf, cr) + if hasattr(self._radio, "configure_radio"): + try: + self._radio.configure_radio( + frequency=freq_hz, + bandwidth=bw_hz, + spreading_factor=sf, + coding_rate=cr, + ) + return True + except Exception as e: + logger.error(f"Error configuring radio: {e}") + return False + return True + + def set_tx_power(self, power_dbm: int) -> bool: + super().set_tx_power(power_dbm) + if hasattr(self._radio, "set_tx_power"): + try: + self._radio.set_tx_power(power_dbm) + return True + except Exception as e: + logger.error(f"Error setting TX power: {e}") + return False + return True + + # ------------------------------------------------------------------------- + # Path & Routing + # ------------------------------------------------------------------------- + + async def send_trace_path( + self, + pub_key: bytes, + tag: int, + auth_code: int, + flags: int = 0, + ) -> bool: + contact = self.contacts.get_by_key(pub_key) + if not contact: + logger.warning(f"Contact not found for trace: {pub_key.hex()[:12]}") + return False + try: + result = await self.node.send_trace_packet( + contact_name=contact.name, + tag=tag, + auth_code=auth_code, + flags=flags, + ) + return result.get("success", False) + except Exception as e: + logger.error(f"Error sending trace: {e}") + return False + + async def send_path_discovery(self, pub_key: bytes) -> bool: + contact = self.contacts.get_by_key(pub_key) + if not contact: + return False + old_path_len = contact.out_path_len + old_path = contact.out_path + contact.out_path_len = -1 + contact.out_path = b"" + self.contacts.update(contact) + try: + result = await self.node.send_telemetry_request( + contact_name=contact.name, + want_base=False, + want_location=False, + want_environment=False, + timeout=5.0, + ) + return result.get("success", False) + except Exception as e: + logger.error(f"Error in path discovery: {e}") + return False + finally: + current = self.contacts.get_by_key(pub_key) + if current and current.out_path_len == -1: + current.out_path_len = old_path_len + current.out_path = old_path + self.contacts.update(current) + + # ------------------------------------------------------------------------- + # Key Management + # ------------------------------------------------------------------------- + + def import_private_key(self, key: bytes) -> bool: + try: + self._identity = LocalIdentity(seed=key) + self.node = MeshNode( + radio=self._radio, + local_identity=self._identity, + config={ + "node": {"name": self.prefs.node_name}, + "radio": self._radio_config, + }, + contacts=self.contacts, + channel_db=self.channels, + event_service=self._event_service, + ) + self._setup_packet_callbacks() + logger.info( + f"Imported new identity: {self._identity.get_public_key().hex()[:16]}..." + ) + return True + except Exception as e: + logger.error(f"Error importing private key: {e}") + return False + + # ------------------------------------------------------------------------- + # Requests + # ------------------------------------------------------------------------- + + async def send_login(self, pub_key: bytes, password: str) -> dict: + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + try: + return await self.node.send_login( + repeater_name=contact.name, + password=password, + ) + except Exception as e: + logger.error(f"Login error: {e}") + return {"success": False, "reason": str(e)} + + async def send_status_request(self, pub_key: bytes) -> dict: + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + try: + return await self.node.send_status_request(repeater_name=contact.name) + except Exception as e: + logger.error(f"Status request error: {e}") + return {"success": False, "reason": str(e)} + + async def send_telemetry_request( + self, + pub_key: bytes, + want_base: bool = True, + want_location: bool = True, + want_environment: bool = True, + timeout: float = 10.0, + ) -> dict: + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + try: + return await self.node.send_telemetry_request( + contact_name=contact.name, + want_base=want_base, + want_location=want_location, + want_environment=want_environment, + timeout=timeout, + ) + except Exception as e: + logger.error(f"Telemetry request error: {e}") + return {"success": False, "reason": str(e)} + + async def send_binary_request(self, pub_key: bytes, data: bytes) -> dict: + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + try: + return await self.node.send_protocol_request( + repeater_name=contact.name, + protocol_code=0x02, + data=data, + ) + except Exception as e: + logger.error(f"Binary request error: {e}") + return {"success": False, "reason": str(e)} + + async def send_anon_request(self, pub_key: bytes, data: bytes) -> dict: + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + try: + return await self.node.send_protocol_request( + repeater_name=contact.name, + protocol_code=0x07, + data=data, + ) + except Exception as e: + logger.error(f"Anon request error: {e}") + return {"success": False, "reason": str(e)} + + # ------------------------------------------------------------------------- + # Control Data + # ------------------------------------------------------------------------- + + async def send_control_data(self, data: bytes) -> bool: + try: + import random + tag = random.randint(0, 0xFFFFFFFF) + pkt = PacketBuilder.create_discovery_request(tag, filter_mask=0x04) + return await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Error sending control data: {e}") + return False + + # ------------------------------------------------------------------------- + # Statistics (override for radio hardware) + # ------------------------------------------------------------------------- + + def _get_radio_stats(self) -> dict: + radio_stats = super()._get_radio_stats() + if hasattr(self._radio, "get_last_rssi"): + radio_stats["last_rssi"] = self._radio.get_last_rssi() + if hasattr(self._radio, "get_last_snr"): + radio_stats["last_snr"] = self._radio.get_last_snr() + return radio_stats + + # ------------------------------------------------------------------------- + # Internal + # ------------------------------------------------------------------------- + + def _setup_packet_callbacks(self) -> None: + dispatcher = self.node.dispatcher + dispatcher.set_packet_received_callback(self._on_packet_received) + dispatcher.set_packet_sent_callback(self._on_packet_sent) + + async def _on_packet_received(self, pkt: Any) -> None: + from ..protocol.constants import ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD + route_type = pkt.get_route_type() + is_flood = route_type in (ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD) + self.stats.record_rx(is_flood=is_flood) + + async def _on_packet_sent(self, pkt: Any) -> None: + pass diff --git a/src/pymc_core/companion/constants.py b/src/pymc_core/companion/constants.py new file mode 100644 index 0000000..fa29674 --- /dev/null +++ b/src/pymc_core/companion/constants.py @@ -0,0 +1,73 @@ +"""Companion radio constants for application-layer mesh networking features.""" + +# --------------------------------------------------------------------------- +# ADV Types (contact/node classification) +# --------------------------------------------------------------------------- +ADV_TYPE_CHAT = 1 +ADV_TYPE_REPEATER = 2 +ADV_TYPE_ROOM = 3 +ADV_TYPE_SENSOR = 4 + +# --------------------------------------------------------------------------- +# Text Types +# --------------------------------------------------------------------------- +TXT_TYPE_PLAIN = 0 +TXT_TYPE_CLI_DATA = 1 +TXT_TYPE_SIGNED_PLAIN = 2 + +# --------------------------------------------------------------------------- +# Telemetry Modes +# --------------------------------------------------------------------------- +TELEM_MODE_DENY = 0 +TELEM_MODE_ALLOW_FLAGS = 1 +TELEM_MODE_ALLOW_ALL = 2 + +# --------------------------------------------------------------------------- +# Advert Location Policy +# --------------------------------------------------------------------------- +ADVERT_LOC_NONE = 0 +ADVERT_LOC_SHARE = 1 + +# --------------------------------------------------------------------------- +# Auto-Add Config Bitmask +# --------------------------------------------------------------------------- +AUTOADD_OVERWRITE_OLDEST = 0x01 +AUTOADD_CHAT = 0x02 +AUTOADD_REPEATER = 0x04 +AUTOADD_ROOM = 0x08 +AUTOADD_SENSOR = 0x10 + +# --------------------------------------------------------------------------- +# Message Send Result +# --------------------------------------------------------------------------- +MSG_SEND_FAILED = 0 +MSG_SEND_SENT_FLOOD = 1 +MSG_SEND_SENT_DIRECT = 2 + +# --------------------------------------------------------------------------- +# Stats Types +# --------------------------------------------------------------------------- +STATS_TYPE_CORE = 0 +STATS_TYPE_RADIO = 1 +STATS_TYPE_PACKETS = 2 + +# --------------------------------------------------------------------------- +# Binary request types (CMD_SEND_BINARY_REQ / PUSH_CODE_BINARY_RESPONSE) +# --------------------------------------------------------------------------- +class BinaryReqType: + """Binary request type codes (companion frame protocol).""" + STATUS = 0x01 + KEEP_ALIVE = 0x02 + TELEMETRY = 0x03 + MMA = 0x04 + ACL = 0x05 + NEIGHBOURS = 0x06 + +# --------------------------------------------------------------------------- +# Default configuration +# --------------------------------------------------------------------------- +DEFAULT_MAX_CONTACTS = 1000 +DEFAULT_OFFLINE_QUEUE_SIZE = 16 +DEFAULT_MAX_CHANNELS = 40 +CONTACT_NAME_SIZE = 32 +MAX_SIGN_DATA_SIZE = 8192 # 8KB signing buffer (matches firmware) diff --git a/src/pymc_core/companion/contact_store.py b/src/pymc_core/companion/contact_store.py new file mode 100644 index 0000000..7e0dd54 --- /dev/null +++ b/src/pymc_core/companion/contact_store.py @@ -0,0 +1,236 @@ +"""In-memory contact storage compatible with MeshNode's contacts interface.""" + +import time +from typing import Iterable, Iterator, Optional + +from .constants import DEFAULT_MAX_CONTACTS +from .models import Contact + + +class ContactProxy: + """Wraps a Contact to provide the interface expected by MeshNode handlers. + + The existing handlers expect contacts with: + - public_key as a hex string (not bytes) + - name as a string + - out_path as a list + - type as an int + """ + + def __init__(self, contact: Contact): + self._contact = contact + self.public_key = contact.public_key.hex() + self.name = contact.name + self.type = contact.adv_type + self.flags = contact.flags + self.out_path = list(contact.out_path) if contact.out_path else [] + self.out_path_len = contact.out_path_len + self.sync_since = contact.sync_since + self.last_advert_timestamp = contact.last_advert_timestamp + self.lastmod = contact.lastmod + self.gps_lat = contact.gps_lat + self.gps_lon = contact.gps_lon + + def _sync_from_contact(self): + """Update proxy fields from the underlying Contact.""" + c = self._contact + self.public_key = c.public_key.hex() + self.name = c.name + self.type = c.adv_type + self.flags = c.flags + self.out_path = list(c.out_path) if c.out_path else [] + self.out_path_len = c.out_path_len + self.sync_since = c.sync_since + self.last_advert_timestamp = c.last_advert_timestamp + self.lastmod = c.lastmod + self.gps_lat = c.gps_lat + self.gps_lon = c.gps_lon + + +class ContactStore: + """In-memory contact storage compatible with MeshNode's contacts interface. + + Provides both the interface expected by MeshNode/Dispatcher (contacts property, + get_by_name, list_contacts) and companion radio CRUD operations (add, update, + remove, get_by_key, etc.). + + The store can be populated from external sources using load_from() or + load_from_dicts() for easy integration with databases and configuration files. + """ + + def __init__(self, max_contacts: int = DEFAULT_MAX_CONTACTS): + self._contacts: dict[bytes, Contact] = {} # keyed by public_key bytes + self._proxies: dict[bytes, ContactProxy] = {} # cached proxies + self._max_contacts = max_contacts + + @property + def max_contacts(self) -> int: + """Maximum number of contacts (read-only). Used by companion protocol device info.""" + return self._max_contacts + + # ------------------------------------------------------------------ + # Interface expected by MeshNode/Dispatcher/Handlers + # ------------------------------------------------------------------ + + @property + def contacts(self) -> list: + """Return contacts as list of proxy objects with hex public_key attribute.""" + return list(self._proxies.values()) + + def list_contacts(self) -> list: + """Return contacts list (used by ProtocolResponseHandler).""" + return self.contacts + + def get_by_name(self, name: str) -> Optional[ContactProxy]: + """Lookup by name (required by MeshNode._get_contact_or_raise).""" + for proxy in self._proxies.values(): + if proxy.name == name: + return proxy + return None + + # ------------------------------------------------------------------ + # Companion radio CRUD operations + # ------------------------------------------------------------------ + + def add(self, contact: Contact) -> bool: + """Add a new contact. Returns False if store is full or key already exists.""" + if contact.public_key in self._contacts: + return self.update(contact) + if len(self._contacts) >= self._max_contacts: + return False + self._contacts[contact.public_key] = contact + self._proxies[contact.public_key] = ContactProxy(contact) + return True + + def update(self, contact: Contact) -> bool: + """Update an existing contact. Returns False if not found.""" + if contact.public_key not in self._contacts: + return self.add(contact) + self._contacts[contact.public_key] = contact + self._proxies[contact.public_key] = ContactProxy(contact) + return True + + def remove(self, public_key: bytes) -> bool: + """Remove a contact by public key. Returns False if not found.""" + if public_key not in self._contacts: + return False + del self._contacts[public_key] + del self._proxies[public_key] + return True + + def get_by_key(self, public_key: bytes) -> Optional[Contact]: + """Lookup a contact by full 32-byte public key.""" + return self._contacts.get(public_key) + + def get_by_key_prefix(self, prefix: bytes) -> Optional[Contact]: + """Lookup a contact by public key prefix (1-32 bytes).""" + for key, contact in self._contacts.items(): + if key[: len(prefix)] == prefix: + return contact + return None + + def get_all(self, since: int = 0) -> list[Contact]: + """Get all contacts, optionally filtered by lastmod >= since.""" + if since == 0: + return list(self._contacts.values()) + return [c for c in self._contacts.values() if c.lastmod >= since] + + def get_count(self) -> int: + """Return the number of stored contacts.""" + return len(self._contacts) + + def is_full(self) -> bool: + """Check if the contact store is at capacity.""" + return len(self._contacts) >= self._max_contacts + + def clear(self): + """Remove all contacts.""" + self._contacts.clear() + self._proxies.clear() + + # ------------------------------------------------------------------ + # Bulk loading from external sources + # ------------------------------------------------------------------ + + def load_from(self, contacts: Iterable[Contact]): + """Bulk-load contacts from any iterable of Contact objects. + + Replaces all existing contacts. + """ + self.clear() + for contact in contacts: + if len(self._contacts) >= self._max_contacts: + break + self._contacts[contact.public_key] = contact + self._proxies[contact.public_key] = ContactProxy(contact) + + def load_from_dicts(self, records: Iterable[dict]): + """Bulk-load contacts from dicts. + + Each dict must have 'public_key' (hex string or bytes) and 'name' keys. + Optional keys: 'adv_type', 'flags', 'out_path', 'out_path_len', + 'last_advert_timestamp', 'lastmod', 'gps_lat', 'gps_lon', 'sync_since'. + + Replaces all existing contacts. + """ + self.clear() + for rec in records: + if len(self._contacts) >= self._max_contacts: + break + + pub_key = rec["public_key"] + if isinstance(pub_key, str): + pub_key = bytes.fromhex(pub_key) + + out_path = rec.get("out_path", b"") + if isinstance(out_path, str): + out_path = bytes.fromhex(out_path) + elif isinstance(out_path, list): + out_path = bytes(out_path) + + contact = Contact( + public_key=pub_key, + name=rec.get("name", ""), + adv_type=rec.get("adv_type", 0), + flags=rec.get("flags", 0), + out_path_len=rec.get("out_path_len", -1), + out_path=out_path, + last_advert_timestamp=rec.get("last_advert_timestamp", 0), + lastmod=rec.get("lastmod", 0), + gps_lat=rec.get("gps_lat", 0.0), + gps_lon=rec.get("gps_lon", 0.0), + sync_since=rec.get("sync_since", 0), + ) + self._contacts[pub_key] = contact + self._proxies[pub_key] = ContactProxy(contact) + + def to_dicts(self) -> list[dict]: + """Export all contacts as a list of plain dicts for serialization.""" + result = [] + for c in self._contacts.values(): + result.append( + { + "public_key": c.public_key.hex(), + "name": c.name, + "adv_type": c.adv_type, + "flags": c.flags, + "out_path_len": c.out_path_len, + "out_path": c.out_path.hex() if c.out_path else "", + "last_advert_timestamp": c.last_advert_timestamp, + "lastmod": c.lastmod, + "gps_lat": c.gps_lat, + "gps_lon": c.gps_lon, + "sync_since": c.sync_since, + } + ) + return result + + # ------------------------------------------------------------------ + # Iterator (matches firmware's iterator pattern) + # ------------------------------------------------------------------ + + def iterate(self, since: int = 0) -> Iterator[Contact]: + """Iterate over contacts, optionally filtered by lastmod >= since.""" + for contact in self._contacts.values(): + if since == 0 or contact.lastmod >= since: + yield contact diff --git a/src/pymc_core/companion/message_queue.py b/src/pymc_core/companion/message_queue.py new file mode 100644 index 0000000..4452759 --- /dev/null +++ b/src/pymc_core/companion/message_queue.py @@ -0,0 +1,58 @@ +"""Fixed-size offline message queue for companion radio.""" + +from collections import deque +from typing import Optional + +from .constants import DEFAULT_OFFLINE_QUEUE_SIZE +from .models import QueuedMessage + + +class MessageQueue: + """Fixed-size offline message queue (FIFO). + + Stores incoming messages that arrive when no consumer is actively + reading. Matches the firmware's offline_queue behaviour with a + configurable maximum size. When full, the oldest messages are + silently dropped (deque maxlen behaviour). + """ + + def __init__(self, max_size: int = DEFAULT_OFFLINE_QUEUE_SIZE): + self._queue: deque[QueuedMessage] = deque(maxlen=max_size) + self._max_size = max_size + + def push(self, msg: QueuedMessage) -> bool: + """Add a message to the queue. Returns True on success. + + If the queue is at capacity the oldest message is silently dropped. + """ + self._queue.append(msg) + return True + + def pop(self) -> Optional[QueuedMessage]: + """Remove and return the oldest message, or None if empty.""" + if self._queue: + return self._queue.popleft() + return None + + def peek(self) -> Optional[QueuedMessage]: + """Return the oldest message without removing it, or None if empty.""" + if self._queue: + return self._queue[0] + return None + + def is_empty(self) -> bool: + """Check if the queue has no messages.""" + return len(self._queue) == 0 + + def is_full(self) -> bool: + """Check if the queue is at capacity.""" + return len(self._queue) >= self._max_size + + @property + def count(self) -> int: + """Return the number of messages in the queue.""" + return len(self._queue) + + def clear(self): + """Remove all messages from the queue.""" + self._queue.clear() diff --git a/src/pymc_core/companion/models.py b/src/pymc_core/companion/models.py new file mode 100644 index 0000000..3295b82 --- /dev/null +++ b/src/pymc_core/companion/models.py @@ -0,0 +1,98 @@ +"""Data models for companion radio state objects.""" + +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class Contact: + """Represents a mesh network contact.""" + + public_key: bytes # 32 bytes (Ed25519) + name: str = "" # up to 32 chars + adv_type: int = 0 # ADV_TYPE_CHAT/REPEATER/ROOM/SENSOR + flags: int = 0 # bitfield + out_path_len: int = -1 # -1 = unknown, 0 = direct, >0 = multi-hop + out_path: bytes = b"" # routing path bytes + last_advert_timestamp: int = 0 # remote timestamp + lastmod: int = 0 # local modification timestamp + gps_lat: float = 0.0 # degrees + gps_lon: float = 0.0 # degrees + sync_since: int = 0 # for filtered iteration + + +@dataclass +class Channel: + """Represents a group communication channel.""" + + name: str # up to 32 chars + secret: bytes # 16-byte PSK + + +@dataclass +class NodePrefs: + """Node configuration preferences (equivalent to firmware NodePrefs).""" + + node_name: str = "pyMC" + adv_type: int = 1 # ADV_TYPE_CHAT + tx_power_dbm: int = 20 + frequency_hz: int = 915000000 + bandwidth_hz: int = 250000 + spreading_factor: int = 10 + coding_rate: int = 5 + latitude: float = 0.0 + longitude: float = 0.0 + advert_loc_policy: int = 0 # ADVERT_LOC_NONE + multi_acks: int = 0 + telemetry_mode_base: int = 0 # TELEM_MODE_DENY + telemetry_mode_location: int = 0 + telemetry_mode_environment: int = 0 + manual_add_contacts: int = 0 + autoadd_config: int = 0 + rx_delay_base: float = 0.0 + airtime_factor: float = 0.0 + + +@dataclass +class SentResult: + """Result of a message send operation.""" + + success: bool + is_flood: bool = False + expected_ack: Optional[int] = None + timeout_ms: Optional[int] = None + + +@dataclass +class PacketStats: + """Packet transmission/reception statistics.""" + + flood_tx: int = 0 + flood_rx: int = 0 + direct_tx: int = 0 + direct_rx: int = 0 + tx_errors: int = 0 + + +@dataclass +class AdvertPath: + """Recently heard advertiser path information.""" + + public_key_prefix: bytes # 7 bytes + name: str = "" + path_len: int = 0 + path: bytes = b"" + recv_timestamp: int = 0 + + +@dataclass +class QueuedMessage: + """A message stored in the offline queue.""" + + sender_key: bytes # 32 bytes + txt_type: int = 0 + timestamp: int = 0 + text: str = "" + is_channel: bool = False + channel_idx: int = 0 # only meaningful if is_channel + path_len: int = 0 diff --git a/src/pymc_core/companion/path_cache.py b/src/pymc_core/companion/path_cache.py new file mode 100644 index 0000000..95e9881 --- /dev/null +++ b/src/pymc_core/companion/path_cache.py @@ -0,0 +1,55 @@ +"""Path cache for tracking recently heard advertiser paths.""" + +from typing import Optional + +from .models import AdvertPath + + +class PathCache: + """Tracks recently heard advertiser paths. + + Stores path information received from advertisements and path updates, + matching the firmware's advert_paths table. Paths are keyed by public + key prefix and updated on each new advertisement. + """ + + def __init__(self, max_entries: int = 16): + self._paths: list[AdvertPath] = [] + self._max = max_entries + + def update(self, advert_path: AdvertPath): + """Add or update a path entry. + + If a path with the same public key prefix already exists, it is + replaced. If the cache is full, the oldest entry is evicted. + """ + # Check for existing entry with same prefix + for i, existing in enumerate(self._paths): + if existing.public_key_prefix == advert_path.public_key_prefix: + self._paths[i] = advert_path + return + + # Add new entry, evicting oldest if full + if len(self._paths) >= self._max: + self._paths.pop(0) + self._paths.append(advert_path) + + def get_by_prefix(self, prefix: bytes) -> Optional[AdvertPath]: + """Lookup a path by public key prefix. + + Args: + prefix: Public key prefix to search for (matches the start + of stored public_key_prefix fields). + """ + for path in self._paths: + if path.public_key_prefix[: len(prefix)] == prefix: + return path + return None + + def get_all(self) -> list[AdvertPath]: + """Return all cached paths.""" + return list(self._paths) + + def clear(self): + """Remove all cached paths.""" + self._paths.clear() diff --git a/src/pymc_core/companion/stats_collector.py b/src/pymc_core/companion/stats_collector.py new file mode 100644 index 0000000..c7579e9 --- /dev/null +++ b/src/pymc_core/companion/stats_collector.py @@ -0,0 +1,57 @@ +"""Packet and radio statistics collector for companion radio.""" + +import time + +from .models import PacketStats + + +class StatsCollector: + """Collects packet transmission/reception statistics. + + Tracks flood vs direct packet counts, errors, and uptime. + Matches the firmware's statistics reporting via CMD_GET_STATS. + """ + + def __init__(self): + self.packets = PacketStats() + self._start_time = time.time() + + def record_tx(self, is_flood: bool): + """Record a successful transmission.""" + if is_flood: + self.packets.flood_tx += 1 + else: + self.packets.direct_tx += 1 + + def record_rx(self, is_flood: bool): + """Record a successful reception.""" + if is_flood: + self.packets.flood_rx += 1 + else: + self.packets.direct_rx += 1 + + def record_tx_error(self): + """Record a transmission error.""" + self.packets.tx_errors += 1 + + def get_uptime_secs(self) -> int: + """Return the number of seconds since the collector was created.""" + return int(time.time() - self._start_time) + + def get_totals(self) -> dict: + """Return a summary of all statistics.""" + return { + "flood_tx": self.packets.flood_tx, + "flood_rx": self.packets.flood_rx, + "direct_tx": self.packets.direct_tx, + "direct_rx": self.packets.direct_rx, + "tx_errors": self.packets.tx_errors, + "total_tx": self.packets.flood_tx + self.packets.direct_tx, + "total_rx": self.packets.flood_rx + self.packets.direct_rx, + "uptime_secs": self.get_uptime_secs(), + } + + def reset(self): + """Reset all counters and restart uptime.""" + self.packets = PacketStats() + self._start_time = time.time() diff --git a/src/pymc_core/node/dispatcher.py b/src/pymc_core/node/dispatcher.py index 9cd5a3e..847852e 100644 --- a/src/pymc_core/node/dispatcher.py +++ b/src/pymc_core/node/dispatcher.py @@ -3,7 +3,7 @@ import asyncio import enum import logging -from typing import Any, Awaitable, Callable, Optional +from typing import Any, Awaitable, Callable, List, Optional from ..protocol import Packet from ..protocol.constants import ( # Payload types @@ -68,8 +68,11 @@ def __init__( self.packet_received_callback: Optional[Callable[[Packet], Awaitable[None] | None]] = None self.packet_sent_callback: Optional[Callable[[Packet], Awaitable[None] | None]] = None - # Add raw packet callback for detailed logging + # Raw packet callbacks: single callback (legacy) and list of subscribers (after parse) self.raw_packet_callback: Optional[Callable[[Packet, bytes], Awaitable[None] | None]] = None + self._raw_packet_subscribers: List[Callable[..., Any]] = [] + # Raw RX subscribers: notified for every reception (data, rssi, snr) before duplicate/parse + self._raw_rx_subscribers: List[Callable[..., Any]] = [] self._handlers: dict[int, Any] = {} # Keep track of packet handlers self._handler_instances: dict[ @@ -163,7 +166,7 @@ def register_default_handlers( # Register all the standard handlers self.register_handler( AdvertHandler.payload_type(), - AdvertHandler(self._log), + AdvertHandler(self._log, event_service=event_service), ) self.register_handler(AckHandler.payload_type(), ack_handler) @@ -301,6 +304,36 @@ def set_raw_packet_callback( """Set callback for raw packet data (includes both parsed packet and raw bytes).""" self.raw_packet_callback = callback + def add_raw_packet_subscriber( + self, callback: Callable[..., Any] + ) -> None: + """Subscribe to every incoming raw packet. Callback receives (pkt, data) or (pkt, data, analysis). + Use this to forward raw RX to clients (e.g. PUSH_CODE_LOG_RX_DATA) so they can track repeats by packet hash.""" + if callback not in self._raw_packet_subscribers: + self._raw_packet_subscribers.append(callback) + + def remove_raw_packet_subscriber(self, callback: Callable[..., Any]) -> None: + """Unsubscribe from raw packet notifications (after parse).""" + try: + self._raw_packet_subscribers.remove(callback) + except ValueError: + pass + + def add_raw_rx_subscriber( + self, callback: Callable[[bytes, int, float], Awaitable[None] | None] + ) -> None: + """Subscribe to every incoming raw RX. Callback receives (data, rssi, snr). + Called before duplicate/blacklist so clients get every repeat (e.g. PUSH_CODE_LOG_RX_DATA).""" + if callback not in self._raw_rx_subscribers: + self._raw_rx_subscribers.append(callback) + + def remove_raw_rx_subscriber(self, callback: Callable[..., Any]) -> None: + """Unsubscribe from raw RX notifications.""" + try: + self._raw_rx_subscribers.remove(callback) + except ValueError: + pass + def _on_packet_received( self, data: bytes, @@ -323,6 +356,28 @@ async def _process_received_packet( """Process a received packet from the radio callback. rssi/snr are per-packet when provided.""" self._log(f"[RX DEBUG] Processing packet: {len(data)} bytes, data: {data.hex()[:32]}...") + # Notify raw RX subscribers first (every reception, including duplicates) so clients can track repeats + if rssi is not None: + rssi_val = rssi + elif hasattr(self.radio, "get_last_rssi"): + rssi_val = self.radio.get_last_rssi() + else: + rssi_val = 0 + if snr is not None: + snr_val = snr + elif hasattr(self.radio, "get_last_snr"): + snr_val = self.radio.get_last_snr() + else: + snr_val = 0.0 + for cb in self._raw_rx_subscribers: + try: + if asyncio.iscoroutinefunction(cb): + await cb(data, rssi_val, snr_val) + else: + cb(data, rssi_val, snr_val) + except Exception as e: + self._log(f"Raw RX subscriber error: {e}") + # Generate packet hash for deduplication and blacklist checking packet_hash = self.packet_filter.generate_hash(data) @@ -361,8 +416,6 @@ async def _process_received_packet( # Let the node know about this packet for analysis (statistics, caching, etc.) if self.packet_analysis_callback: try: - import asyncio - if asyncio.iscoroutinefunction(self.packet_analysis_callback): await self.packet_analysis_callback(pkt, data) else: @@ -371,9 +424,13 @@ async def _process_received_packet( except Exception as e: self._log(f"Error in packet analysis callback: {e}") - # Always call raw packet callback first for logging (regardless of source) + # Notify raw packet subscribers (e.g. companion clients for PUSH_CODE_LOG_RX_DATA) + analysis = {} + for callback in self._raw_packet_subscribers: + await self._invoke_enhanced_raw_callback(callback, pkt, data, analysis) if self.raw_packet_callback: await self._invoke_enhanced_raw_callback(self.raw_packet_callback, pkt, data, {}) + if self._raw_packet_subscribers or self.raw_packet_callback: self._log("[RX DEBUG] Raw packet callback completed") # Check if this is our own packet before processing handlers diff --git a/src/pymc_core/node/handlers/advert.py b/src/pymc_core/node/handlers/advert.py index bda80c1..83fc7ff 100644 --- a/src/pymc_core/node/handlers/advert.py +++ b/src/pymc_core/node/handlers/advert.py @@ -1,7 +1,8 @@ +import struct import time from typing import Any, Dict, Optional -from ...protocol import Identity, Packet, decode_appdata +from ...protocol import Identity, Packet, decode_appdata, parse_advert_payload from ...protocol.constants import ( MAX_ADVERT_DATA_SIZE, PAYLOAD_TYPE_ADVERT, @@ -11,6 +12,7 @@ describe_advert_flags, ) from ...protocol.utils import determine_contact_type_from_flags, get_contact_type_name +from ..events import MeshEvents from .base import BaseHandler @@ -19,33 +21,9 @@ class AdvertHandler(BaseHandler): def payload_type() -> int: return PAYLOAD_TYPE_ADVERT - def __init__(self, log_fn): + def __init__(self, log_fn, event_service=None): self.log = log_fn - - def _extract_advert_components(self, packet: Packet): - """Extract and validate advert packet components.""" - payload = packet.get_payload() - header_len = PUB_KEY_SIZE + TIMESTAMP_SIZE + SIGNATURE_SIZE - if len(payload) < header_len: - self.log( - f"Advert payload too short ({len(payload)} bytes, expected at least {header_len})" - ) - return None - - sig_offset = PUB_KEY_SIZE + TIMESTAMP_SIZE - pubkey = payload[:PUB_KEY_SIZE] - timestamp = payload[PUB_KEY_SIZE:sig_offset] - signature = payload[sig_offset : sig_offset + SIGNATURE_SIZE] - appdata = payload[sig_offset + SIGNATURE_SIZE :] - - if len(appdata) > MAX_ADVERT_DATA_SIZE: - self.log( - f"Advert appdata too large ({len(appdata)} bytes). " - f"Truncating to {MAX_ADVERT_DATA_SIZE}" - ) - appdata = appdata[:MAX_ADVERT_DATA_SIZE] - - return pubkey, timestamp, signature, appdata + self.event_service = event_service def _verify_advert_signature( self, pubkey: bytes, timestamp: bytes, appdata: bytes, signature: bytes @@ -85,13 +63,26 @@ def _verify_advert_signature( async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: """Process advert packet and return parsed data with signature verification.""" try: - # Extract and validate packet components - components = self._extract_advert_components(packet) - if not components: + payload = packet.get_payload() + if not payload: + return None + try: + parsed = parse_advert_payload(payload) + except ValueError as e: + self.log(f"Advert payload parse error: {e}") return None - pubkey_bytes, timestamp_bytes, signature_bytes, appdata = components - pubkey_hex = pubkey_bytes.hex() + pubkey_bytes = bytes.fromhex(parsed["pubkey"]) + pubkey_hex = parsed["pubkey"] + advert_timestamp = parsed["timestamp"] + timestamp_bytes = struct.pack(" MAX_ADVERT_DATA_SIZE: + self.log( + f"Advert appdata too large ({len(appdata)} bytes), truncating to {MAX_ADVERT_DATA_SIZE}" + ) + appdata = appdata[:MAX_ADVERT_DATA_SIZE] # Verify cryptographic signature if not self._verify_advert_signature( @@ -102,7 +93,7 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: self.log(f"Processing advert for pubkey: {pubkey_hex[:16]}...") - # Decode application data + # Decode application data (protocol.utils.decode_appdata) decoded = decode_appdata(appdata) # Extract name from decoded data @@ -119,6 +110,11 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: contact_type_id = determine_contact_type_from_flags(flags_int) contact_type = get_contact_type_name(contact_type_id) + # Clamp to current time if remote clock is ahead (avoid "future" last-advert in UI) + now = int(time.time()) + if advert_timestamp > now: + advert_timestamp = now + # Build parsed advert data advert_data = { "public_key": pubkey_hex, @@ -129,6 +125,7 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: "flags_description": flags_description, "contact_type_id": contact_type_id, "contact_type": contact_type, + "advert_timestamp": advert_timestamp, "timestamp": int(time.time()), "snr": packet._snr if hasattr(packet, "_snr") else 0.0, "rssi": packet._rssi if hasattr(packet, "_rssi") else 0, @@ -136,6 +133,23 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: } self.log(f"Parsed advert: {name} ({contact_type})") + + # Publish so companion/app receives node-discovered and advert_received callbacks + if self.event_service: + try: + event_data = { + "public_key": pubkey_hex, + "name": name, + "contact_type": contact_type_id, + "lat": lat, + "lon": lon, + "snr": advert_data["snr"], + "rssi": advert_data["rssi"], + } + self.event_service.publish_sync(MeshEvents.NODE_DISCOVERED, event_data) + except Exception as e: + self.log(f"Failed to publish NODE_DISCOVERED event: {e}") + return advert_data except Exception as e: diff --git a/src/pymc_core/node/handlers/control.py b/src/pymc_core/node/handlers/control.py index 93148b4..0286754 100644 --- a/src/pymc_core/node/handlers/control.py +++ b/src/pymc_core/node/handlers/control.py @@ -24,13 +24,22 @@ class ControlHandler: This handler processes incoming discovery requests and responses. """ - def __init__(self, log_fn: Callable[[str], None]): + def __init__( + self, + log_fn: Callable[[str], None], + debug_log_fn: Optional[Callable[[str], None]] = None, + ): """Initialize control handler. - + Args: - log_fn: Logging function + log_fn: Logging function for normal messages. + debug_log_fn: Optional logging function for verbose messages (e.g. callback + presence). If set, "No callback waiting" and "Called response callback" + use this instead of log_fn, so callers can use logger.debug to avoid noise + when forwarding discovery to companions. """ self._log = log_fn + self._debug_log = debug_log_fn if debug_log_fn is not None else log_fn # Callbacks for discovery responses self._response_callbacks: Dict[int, Callable[[Dict[str, Any]], None]] = {} @@ -192,11 +201,11 @@ async def _handle_discovery_response(self, pkt: Packet) -> Optional[Dict[str, An callback = self._response_callbacks[tag] if callback: callback(response_data) - self._log( + self._debug_log( f"[ControlHandler] Called response callback for tag 0x{tag:08X}" ) else: - self._log( + self._debug_log( f"[ControlHandler] No callback waiting for tag 0x{tag:08X}" ) @@ -205,6 +214,3 @@ async def _handle_discovery_response(self, pkt: Packet) -> Optional[Dict[str, An except Exception as e: self._log(f"[ControlHandler] Error handling discovery response: {e}") return None - - except Exception as e: - self._log(f"[ControlHandler] Error handling discovery response: {e}") diff --git a/src/pymc_core/node/handlers/group_text.py b/src/pymc_core/node/handlers/group_text.py index 9c82091..859876c 100644 --- a/src/pymc_core/node/handlers/group_text.py +++ b/src/pymc_core/node/handlers/group_text.py @@ -1,7 +1,11 @@ from typing import Optional from ...protocol import Packet -from ...protocol.constants import PAYLOAD_TYPE_GRP_TXT +from ...protocol.constants import ( + PAYLOAD_TYPE_GRP_TXT, + ROUTE_TYPE_FLOOD, + ROUTE_TYPE_TRANSPORT_FLOOD, +) from ...protocol.crypto import CryptoUtils from .base import BaseHandler @@ -49,33 +53,32 @@ def _get_channel_by_hash(self, channel_hash: int) -> Optional[dict]: self.log(f"Error querying channel database: {e}") return None - def _derive_channel_hash(self, channel_secret: str) -> int: - """Derive a consistent channel hash from the secret.""" - import hashlib - - # Convert hex secret to bytes, then derive key + def _secret_bytes_for_hash(self, channel_secret: str) -> bytes: + """Normalize secret to bytes used for channel hash (match MeshCore firmware). + Firmware hashes only first 16 bytes when second 16 are zero (128-bit key).""" try: secret_bytes = bytes.fromhex(channel_secret) except ValueError: - # If not hex, treat as UTF-8 string secret_bytes = channel_secret.encode("utf-8") + if len(secret_bytes) >= 32 and secret_bytes[16:32] == b"\x00" * 16: + return secret_bytes[:16] + if len(secret_bytes) > 32: + return secret_bytes[:32] + return secret_bytes - # Simple SHA256 derivation (no salt) to match official spec - channel_key = hashlib.sha256(secret_bytes).digest() + def _derive_channel_hash(self, channel_secret: str) -> int: + """Derive channel hash (first byte of SHA256) to match MeshCore firmware.""" + import hashlib - # Return first byte as the channel hash + secret_bytes = self._secret_bytes_for_hash(channel_secret) + channel_key = hashlib.sha256(secret_bytes).digest() return channel_key[0] def _derive_channel_keys(self, channel_secret: str) -> tuple: """Derive all necessary keys from channel secret.""" import hashlib - try: - secret_bytes = bytes.fromhex(channel_secret) - except ValueError: - secret_bytes = channel_secret.encode("utf-8") - - # Simple SHA256 derivation to match official spec + secret_bytes = self._secret_bytes_for_hash(channel_secret) master_key = hashlib.sha256(secret_bytes).digest() # Split into different keys @@ -124,7 +127,9 @@ def _parse_plaintext_message(self, plaintext: bytes) -> Optional[dict]: try: timestamp = int.from_bytes(plaintext[:4], "little") flags = plaintext[4] - message_content = plaintext[5:].decode("utf-8", errors="replace") + # Decode and strip trailing null/padding (AES decrypt returns block-aligned data with zero padding) + raw = plaintext[5:].decode("utf-8", errors="replace") + message_content = raw.rstrip("\x00") # Parse message flags according to spec message_type = "unknown" @@ -137,7 +142,8 @@ def _parse_plaintext_message(self, plaintext: bytes) -> Optional[dict]: # For signed messages, first two bytes are sender prefix if len(plaintext) >= 7: # sender_prefix = plaintext[5:7] # Unused for now - message_content = plaintext[7:].decode("utf-8", errors="replace") + raw = plaintext[7:].decode("utf-8", errors="replace") + message_content = raw.rstrip("\x00") return { "timestamp": timestamp, @@ -270,6 +276,13 @@ async def _save_and_broadcast_group_message( channel_hash = f"{packet.get_payload()[0]:02X}" + # path_len: flood packets use actual path length; direct uses 0xFF + route_type = packet.header & 0x03 + if route_type in (ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD): + path_len = getattr(packet, "path_len", 0) or len(packet.path or []) + else: + path_len = 0xFF + # Use a custom message type for single channel message addition message_data = { "message_id": message_id, @@ -280,6 +293,8 @@ async def _save_and_broadcast_group_message( "timestamp": timestamp, "message_type": "group_text", "flags": 0, + "path_len": path_len, + "packet_hash": packet.calculate_packet_hash().hex().upper(), "full_content": packet.decrypted.get("group_text_data", {}).get( "full_content" ), @@ -293,8 +308,8 @@ async def _save_and_broadcast_group_message( }, } - # Publish channel message event - self.event_service.publish_sync(MeshEvents.NEW_CHANNEL_MESSAGE, message_data) + # Publish channel message event (await so message is queued and MSG_WAITING sent before return) + await self.event_service.publish(MeshEvents.NEW_CHANNEL_MESSAGE, message_data) self.log("Published group message event") except Exception as publish_error: self.log(f"Failed to publish group message event: {publish_error}") diff --git a/src/pymc_core/node/handlers/protocol_response.py b/src/pymc_core/node/handlers/protocol_response.py index c82ca4f..ef36d01 100644 --- a/src/pymc_core/node/handlers/protocol_response.py +++ b/src/pymc_core/node/handlers/protocol_response.py @@ -4,12 +4,13 @@ back as PATH packets with encrypted payloads. """ +import asyncio import struct from typing import Any, Callable, Dict, Optional from ...hardware.signal_utils import snr_register_to_db from ...protocol import CryptoUtils, Identity, Packet -from ...protocol.constants import PAYLOAD_TYPE_PATH +from ...protocol.constants import MAX_PATH_SIZE, PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE class ProtocolResponseHandler: @@ -28,6 +29,9 @@ def __init__(self, log_fn: Callable[[str], None], local_identity, contact_book): # Callbacks for protocol responses self._response_callbacks: Dict[int, Callable[[bool, str, Dict[str, Any]], None]] = {} + # Optional: when set, decrypted payloads with tag+data (and optional path) are passed as binary response + # Signature: (tag_bytes, response_data, path_info=None). path_info = (out_path, in_path, contact_pubkey). + self._binary_response_callback: Optional[Callable[..., Any]] = None @staticmethod def payload_type() -> int: @@ -43,6 +47,11 @@ def clear_response_callback(self, contact_hash: int) -> None: """Clear callback for protocol responses from a specific contact.""" self._response_callbacks.pop(contact_hash, None) + def set_binary_response_callback(self, callback: Callable[..., Any]) -> None: + """Set callback for binary responses. Called with (tag_bytes, response_data, path_info=None). + path_info when present is (out_path, in_path, contact_pubkey) for path-return format.""" + self._binary_response_callback = callback + async def __call__(self, pkt: Packet) -> None: """Handle incoming PATH packet that might be a protocol response.""" try: @@ -54,9 +63,9 @@ async def __call__(self, pkt: Packet) -> None: # dest_hash(1) + src_hash(1) + encrypted_data(N) src_hash = pkt.payload[1] - # Check if we have a callback waiting for this source - if src_hash not in self._response_callbacks: - return # Not waiting for response from this source + # Proceed if we have a callback for this source or the binary (path-discovery) callback + if src_hash not in self._response_callbacks and self._binary_response_callback is None: + return self._log( "[ProtocolResponse] Processing potential protocol response " @@ -64,10 +73,47 @@ async def __call__(self, pkt: Packet) -> None: ) # Try to decrypt the response - success, decoded_text, parsed_data = await self._decrypt_protocol_response( + success, decoded_text, parsed_data, raw_decrypted = await self._decrypt_protocol_response( pkt, src_hash ) + # If binary response callback is set, parse and invoke (plain tag+data or path-return format) + if ( + success + and self._binary_response_callback is not None + and raw_decrypted is not None + and len(raw_decrypted) >= 4 + ): + path_info = None + tag_bytes = raw_decrypted[:4] + response_data = raw_decrypted[4:] + # Path-return format (MeshCore createPathReturn): path_len(1), path(path_len), extra_type(1), extra + path_len = raw_decrypted[0] + if ( + path_len <= MAX_PATH_SIZE + and len(raw_decrypted) >= 1 + path_len + 1 + 4 + ): + out_path = bytes(raw_decrypted[1 : 1 + path_len]) + extra_type = raw_decrypted[1 + path_len] + extra = raw_decrypted[2 + path_len :] + if extra_type == PAYLOAD_TYPE_RESPONSE and len(extra) >= 4: + tag_bytes = extra[:4] + response_data = extra[4:] + in_path = bytes(pkt.path) if pkt.path else b"" + contact = self._find_contact_by_hash(src_hash) + if contact: + contact_pubkey = bytes.fromhex(contact.public_key) + path_info = (out_path, in_path, contact_pubkey) + try: + cb_result = self._binary_response_callback( + tag_bytes, response_data, path_info + ) + if asyncio.iscoroutine(cb_result): + await cb_result + except Exception as e: + self._log(f"[ProtocolResponse] Binary response callback error: {e}") + return + # Call the waiting callback callback = self._response_callbacks[src_hash] if callback: @@ -78,13 +124,13 @@ async def __call__(self, pkt: Packet) -> None: async def _decrypt_protocol_response( self, pkt: Packet, src_hash: int - ) -> tuple[bool, str, Dict[str, Any]]: - """Decrypt and parse a protocol response packet.""" + ) -> tuple[bool, str, Dict[str, Any], Optional[bytes]]: + """Decrypt and parse a protocol response packet. Returns (success, text, parsed_data, raw_decrypted).""" try: # Find the contact by hash contact = self._find_contact_by_hash(src_hash) if not contact: - return False, f"Unknown contact for hash 0x{src_hash:02X}", {} + return False, f"Unknown contact for hash 0x{src_hash:02X}", {}, None # Get encryption keys contact_pubkey = bytes.fromhex(contact.public_key) @@ -101,11 +147,12 @@ async def _decrypt_protocol_response( self._log(f"[ProtocolResponse] Successfully decrypted {len(decrypted)} bytes") # Parse based on content type - return self._parse_protocol_response(decrypted) + success, text, parsed = self._parse_protocol_response(decrypted) + return success, text, parsed, decrypted except Exception as e: self._log(f"[ProtocolResponse] Decryption failed: {e}") - return False, f"Decryption failed: {e}", {} + return False, f"Decryption failed: {e}", {}, None def _parse_protocol_response(self, data: bytes) -> tuple[bool, str, Dict[str, Any]]: """Parse decrypted protocol response data.""" diff --git a/src/pymc_core/node/handlers/text.py b/src/pymc_core/node/handlers/text.py index 174b1df..2984c52 100644 --- a/src/pymc_core/node/handlers/text.py +++ b/src/pymc_core/node/handlers/text.py @@ -191,6 +191,7 @@ async def send_delayed_ack(): "contact_name": matched_contact.name, "contact_pubkey": matched_contact.public_key, "message_text": decoded_msg, + "txt_type": txt_type, "is_outgoing": False, "timestamp": message_timestamp, "delivery_status": "received", diff --git a/src/pymc_core/protocol/crypto.py b/src/pymc_core/protocol/crypto.py index 29a1aa4..e3eeb5f 100644 --- a/src/pymc_core/protocol/crypto.py +++ b/src/pymc_core/protocol/crypto.py @@ -69,6 +69,17 @@ def mac_then_decrypt(aes_key: bytes, shared_secret: bytes, data: bytes) -> bytes return decrypted + @staticmethod + def x25519_clamp_scalar(scalar: bytes) -> bytes: + """Clamp a 32-byte scalar for X25519 (matches MeshCore key_exchange.c).""" + if len(scalar) != 32: + raise ValueError("scalar must be 32 bytes") + s = bytearray(scalar) + s[0] &= 248 + s[31] &= 63 + s[31] |= 64 + return bytes(s) + @staticmethod def scalarmult(private_key: bytes, public_key: bytes) -> bytes: """ECDH shared secret calculation (X25519).""" diff --git a/src/pymc_core/protocol/identity.py b/src/pymc_core/protocol/identity.py index a5c8101..b879a7d 100644 --- a/src/pymc_core/protocol/identity.py +++ b/src/pymc_core/protocol/identity.py @@ -100,17 +100,19 @@ def __init__(self, seed: Optional[bytes] = None): if seed and len(seed) == 64: from nacl.bindings import crypto_scalarmult_ed25519_base_noclamp - # MeshCore format: [32-byte clamped scalar][32-byte nonce] + # MeshCore format: [32-byte scalar][32-byte nonce]; firmware clamps first 32 bytes for ECDH self._firmware_key = seed self.signing_key = None - # Derive public key from scalar + # Use X25519 clamping so ECDH matches firmware's ed25519_key_exchange() scalar = seed[:32] - ed25519_pub = crypto_scalarmult_ed25519_base_noclamp(scalar) + clamped = CryptoUtils.x25519_clamp_scalar(scalar) + ed25519_pub = crypto_scalarmult_ed25519_base_noclamp(clamped) self.verify_key = VerifyKey(ed25519_pub) - # Build ed25519_sk for X25519 conversion (use reconstructed format) - ed25519_sk = scalar + ed25519_pub + # Use clamped scalar directly for ECDH (firmware key_exchange.c uses first 32 bytes clamped) + self._x25519_private = clamped + self._x25519_public = CryptoUtils.scalarmult_base(clamped) else: # Standard 32-byte seed or None self._firmware_key = None @@ -121,9 +123,9 @@ def __init__(self, seed: Optional[bytes] = None): ed25519_pub = self.verify_key.encode() ed25519_sk = self.signing_key.encode() + ed25519_pub - # X25519 keypair for ECDH - self._x25519_private = CryptoUtils.ed25519_sk_to_x25519(ed25519_sk) - self._x25519_public = CryptoUtils.scalarmult_base(self._x25519_private) + # X25519 keypair for ECDH (libsodium conversion) + self._x25519_private = CryptoUtils.ed25519_sk_to_x25519(ed25519_sk) + self._x25519_public = CryptoUtils.scalarmult_base(self._x25519_private) # Initialise base class with Ed25519 pubkey super().__init__(ed25519_pub) diff --git a/src/pymc_core/protocol/packet_builder.py b/src/pymc_core/protocol/packet_builder.py index 109914e..82f48d2 100644 --- a/src/pymc_core/protocol/packet_builder.py +++ b/src/pymc_core/protocol/packet_builder.py @@ -523,9 +523,15 @@ def create_group_datagram( secret_bytes = ( bytes.fromhex(channel["secret"]) if isinstance(channel["secret"], str) - else channel["secret"].encode("utf-8") + else (channel["secret"] if isinstance(channel["secret"], bytes) else channel["secret"].encode("utf-8")) ) - channel_hash = hashlib.sha256(secret_bytes).digest()[0] + # Use same channel hash derivation as GroupTextHandler (firmware: hash first 16 bytes when 32-byte key has second 16 zero) + hash_input = ( + secret_bytes[:16] + if len(secret_bytes) >= 32 and secret_bytes[16:32] == b"\x00" * 16 + else (secret_bytes[:32] if len(secret_bytes) > 32 else secret_bytes) + ) + channel_hash = hashlib.sha256(hash_input).digest()[0] secret_bytes = (secret_bytes + b"\x00" * 32)[:32] timestamp, flags = PacketBuilder._get_timestamp(), 0x00 @@ -536,7 +542,7 @@ def create_group_datagram( mac = CryptoUtils._hmac_sha256(secret_bytes, ciphertext)[:2] payload = bytearray([channel_hash]) + mac + ciphertext - header = PacketBuilder._create_header(PAYLOAD_TYPE_GRP_TXT) + header = PacketBuilder._create_header(PAYLOAD_TYPE_GRP_TXT, route_type="flood") return PacketBuilder._create_packet(header, payload) @staticmethod @@ -573,7 +579,7 @@ def create_group_data_packet( cipher = PacketBuilder._encrypt_payload(aes_key, secret, plaintext) payload = bytearray([channel_hash]) + cipher - header = PacketBuilder._create_header(ptype) + header = PacketBuilder._create_header(ptype, route_type="flood") return PacketBuilder._create_packet(header, payload) @staticmethod From fdd75636e0a84c8020990b3d42ad1710ae45b104 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 14 Feb 2026 19:49:19 -0800 Subject: [PATCH 02/50] Enhance CompanionBridge and Protocol Handling - Introduced REQ_TYPE_GET_STATUS for requesting repeater stats. - Updated send_status_request method to handle status requests with improved error handling and response parsing. - Enhanced message type determination in send methods to differentiate between flood and direct routes. - Refactored contact_store to handle out_path_len more robustly. - Added CayenneLPP decoding functionality in protocol response handling for better sensor data interpretation. --- src/pymc_core/companion/companion_bridge.py | 52 +- src/pymc_core/companion/contact_store.py | 2 +- src/pymc_core/node/handlers/login_response.py | 22 +- .../node/handlers/protocol_response.py | 533 ++++++++++++------ src/pymc_core/protocol/__init__.py | 4 +- src/pymc_core/protocol/constants.py | 6 +- src/pymc_core/protocol/packet_builder.py | 35 +- 7 files changed, 466 insertions(+), 188 deletions(-) diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index dcc6106..cf4f478 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -40,7 +40,7 @@ ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD, ) -from ..protocol.constants import REQ_TYPE_GET_TELEMETRY_DATA, TELEM_PERM_BASE +from ..protocol.constants import REQ_TYPE_GET_STATUS, REQ_TYPE_GET_TELEMETRY_DATA, TELEM_PERM_BASE from .companion_base import CompanionBase, ResponseWaiter, adv_type_to_flags from .constants import ( ADV_TYPE_CHAT, @@ -331,16 +331,18 @@ async def send_text_message( if not proxy: return SentResult(success=False) try: + is_flood = proxy.out_path_len < 0 + msg_type = "flood" if is_flood else "direct" pkt, ack_crc = PacketBuilder.create_text_message( contact=proxy, local_identity=self._identity, message=text, attempt=attempt, + message_type=msg_type, ) if len(self._pending_ack_crcs) < MAX_PENDING_ACK_CRCS: self._pending_ack_crcs.add(ack_crc) success = await self._packet_injector(pkt, wait_for_ack=True) - is_flood = contact.out_path_len <= 0 if success: self.stats.record_tx(is_flood=is_flood) else: @@ -621,8 +623,47 @@ def _login_cb(success: bool, data: dict) -> None: self._login_response_handler.set_login_callback(None) self._login_response_handler.clear_login_password(dest_hash) - async def send_status_request(self, pub_key: bytes) -> dict: - return await self.send_repeater_command(pub_key, "status") + async def send_status_request(self, pub_key: bytes, timeout: float = 15.0) -> dict: + """Send a protocol request for repeater stats (REQ_TYPE_GET_STATUS). + + The firmware handles CMD_SEND_STATUS_REQ by calling + ``sendRequest(*recipient, REQ_TYPE_GET_STATUS, tag, est_timeout)`` + which creates a PAYLOAD_TYPE_REQ packet. The remote repeater replies + with a PAYLOAD_TYPE_RESPONSE containing ``reflected_timestamp(4) + + RepeaterStats(48)``. + """ + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + contact_hash = bytes.fromhex(proxy.public_key)[0] + waiter = ResponseWaiter() + self._protocol_response_handler.set_response_callback( + contact_hash, waiter.callback + ) + try: + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=REQ_TYPE_GET_STATUS, + data=b"", + ) + await self._packet_injector(pkt, wait_for_ack=False) + result = await waiter.wait(timeout) + return { + "success": result.get("success", False), + "repeater": contact.name, + "stats": result.get("parsed", {}), + "response_text": result.get("text"), + "reason": "Stats received" if result.get("success") else "Stats request failed", + } + except Exception as e: + logger.error(f"Status request error: {e}") + return {"success": False, "reason": str(e)} + finally: + self._protocol_response_handler.clear_response_callback(contact_hash) async def send_telemetry_request( self, @@ -778,12 +819,13 @@ def _response_cb(message_text: str, sender_contact: Any) -> None: self._text_handler.set_command_response_callback(_response_cb) try: + msg_type = "flood" if proxy.out_path_len < 0 else "direct" pkt, ack_crc = PacketBuilder.create_text_message( contact=proxy, local_identity=self._identity, message=full_command, attempt=1, - message_type="command", + message_type=msg_type, ) await self._packet_injector(pkt, wait_for_ack=True) try: diff --git a/src/pymc_core/companion/contact_store.py b/src/pymc_core/companion/contact_store.py index 7e0dd54..0b90341 100644 --- a/src/pymc_core/companion/contact_store.py +++ b/src/pymc_core/companion/contact_store.py @@ -193,7 +193,7 @@ def load_from_dicts(self, records: Iterable[dict]): name=rec.get("name", ""), adv_type=rec.get("adv_type", 0), flags=rec.get("flags", 0), - out_path_len=rec.get("out_path_len", -1), + out_path_len=-1 if rec.get("out_path_len", -1) in (-1, 255) else rec.get("out_path_len", -1), out_path=out_path, last_advert_timestamp=rec.get("last_advert_timestamp", 0), lastmod=rec.get("lastmod", 0), diff --git a/src/pymc_core/node/handlers/login_response.py b/src/pymc_core/node/handlers/login_response.py index a695a97..6aea6f9 100644 --- a/src/pymc_core/node/handlers/login_response.py +++ b/src/pymc_core/node/handlers/login_response.py @@ -3,7 +3,7 @@ from typing import Callable, Optional from ...protocol import CryptoUtils, Identity, Packet -from ...protocol.constants import PAYLOAD_TYPE_ANON_REQ, PAYLOAD_TYPE_RESPONSE +from ...protocol.constants import MAX_PATH_SIZE, PAYLOAD_TYPE_ANON_REQ, PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE from .base import BaseHandler # Response codes from C++ server @@ -147,7 +147,25 @@ async def _decrypt_response( aes_key = shared_secret[:16] plaintext = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, encrypted_data) - if not plaintext or len(plaintext) < 12: + if not plaintext: + return None + + # If this is a PATH packet, unwrap the path-return envelope to get + # the inner response. PATH format after decryption: + # path_len(1) + path(N) + extra_type(1) + extra_data(M) + pkt_type = (packet.header >> 2) & 0x0F + if pkt_type == PAYLOAD_TYPE_PATH and len(plaintext) >= 2: + path_len_byte = plaintext[0] + inner_offset = 1 + path_len_byte + 1 # skip path_len + path + extra_type + if ( + path_len_byte <= MAX_PATH_SIZE + and len(plaintext) >= inner_offset + ): + extra_type = plaintext[1 + path_len_byte] & 0x0F + if extra_type == PAYLOAD_TYPE_RESPONSE and len(plaintext) > inner_offset: + plaintext = plaintext[inner_offset:] + + if len(plaintext) < 12: return None # Parse the C++ response format: diff --git a/src/pymc_core/node/handlers/protocol_response.py b/src/pymc_core/node/handlers/protocol_response.py index ef36d01..2799a8a 100644 --- a/src/pymc_core/node/handlers/protocol_response.py +++ b/src/pymc_core/node/handlers/protocol_response.py @@ -8,10 +8,93 @@ import struct from typing import Any, Callable, Dict, Optional -from ...hardware.signal_utils import snr_register_to_db from ...protocol import CryptoUtils, Identity, Packet from ...protocol.constants import MAX_PATH_SIZE, PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE +# --------------------------------------------------------------------------- +# Built-in CayenneLPP decoder (no external dependency) +# Spec: https://docs.mydevices.com/docs/lorawan/cayenne-lpp +# Each record: channel(1) + type_id(1) + value(N) +# --------------------------------------------------------------------------- + +_LPP_TYPES: Dict[int, tuple] = { + # type_id: (name, value_size_bytes, divisor, signed) + # --- Original LPPv1 types --- + 0x00: ("Digital Input", 1, 1, False), + 0x01: ("Digital Output", 1, 1, False), + 0x02: ("Analog Input", 2, 100, True), + 0x03: ("Analog Output", 2, 100, True), + # --- Extended types (from CayenneLPP.h) --- + 0x64: ("Generic Sensor", 4, 1, False), # LPP_GENERIC_SENSOR = 100 + 0x65: ("Illuminance", 2, 1, False), # LPP_LUMINOSITY = 101 + 0x66: ("Presence", 1, 1, False), # LPP_PRESENCE = 102 + 0x67: ("Temperature", 2, 10, True), # LPP_TEMPERATURE = 103 + 0x68: ("Humidity", 1, 2, False), # LPP_RELATIVE_HUMIDITY = 104 + 0x71: ("Accelerometer", 6, 1000, True), # LPP_ACCELEROMETER = 113, 3×int16 + 0x73: ("Barometer", 2, 10, False), # LPP_BAROMETRIC_PRESSURE = 115 + 0x74: ("Voltage", 2, 100, False), # LPP_VOLTAGE = 116, 0.01V + 0x75: ("Current", 2, 1000, False), # LPP_CURRENT = 117, 0.001A + 0x76: ("Frequency", 4, 1, False), # LPP_FREQUENCY = 118, 1Hz + 0x78: ("Percentage", 1, 1, False), # LPP_PERCENTAGE = 120, 1-100% + 0x79: ("Altitude", 2, 1, True), # LPP_ALTITUDE = 121, 1m signed + 0x7D: ("Concentration", 2, 1, False), # LPP_CONCENTRATION = 125, 1ppm + 0x80: ("Power", 2, 1, False), # LPP_POWER = 128, 1W + 0x82: ("Distance", 4, 1000, False), # LPP_DISTANCE = 130, 0.001m + 0x83: ("Energy", 4, 1000, False), # LPP_ENERGY = 131, 0.001kWh + 0x84: ("Direction", 2, 1, False), # LPP_DIRECTION = 132, 1deg + 0x85: ("Unix Time", 4, 1, False), # LPP_UNIXTIME = 133 + 0x86: ("Gyroscope", 6, 100, True), # LPP_GYROMETER = 134, 3×int16 + 0x87: ("Colour", 3, 1, False), # LPP_COLOUR = 135, RGB + 0x88: ("GPS", 9, 1, True), # LPP_GPS = 136, lat(3)+lon(3)+alt(3) + 0x8E: ("Switch", 1, 1, False), # LPP_SWITCH = 142, 0/1 +} + + +def _decode_cayenne_lpp(data: bytes) -> list: + """Decode CayenneLPP binary payload into a list of sensor dicts.""" + sensors: list = [] + idx = 0 + while idx + 2 <= len(data): + channel = data[idx] + type_id = data[idx + 1] + idx += 2 + spec = _LPP_TYPES.get(type_id) + if spec is None: + break # unknown type → stop (remaining bytes may be padding) + name, size, divisor, signed = spec + if idx + size > len(data): + break + raw = data[idx : idx + size] + idx += size + + if type_id == 0x88: + # GPS: lat(3, signed, /10000) + lon(3, signed, /10000) + alt(3, signed, /100) + lat = int.from_bytes(raw[0:3], "big", signed=True) / 10000 + lon = int.from_bytes(raw[3:6], "big", signed=True) / 10000 + alt = int.from_bytes(raw[6:9], "big", signed=True) / 100 + sensors.append({"channel": channel, "type": name, "type_id": type_id, + "value": {"latitude": lat, "longitude": lon, "altitude": alt}, + "raw_value": raw.hex()}) + elif size == 6 and type_id in (0x71, 0x86): + # 3-axis: x(2) + y(2) + z(2), all signed + x = int.from_bytes(raw[0:2], "big", signed=True) / divisor + y = int.from_bytes(raw[2:4], "big", signed=True) / divisor + z = int.from_bytes(raw[4:6], "big", signed=True) / divisor + sensors.append({"channel": channel, "type": name, "type_id": type_id, + "value": {"x": x, "y": y, "z": z}, + "raw_value": raw.hex()}) + elif type_id == 0x87: + # Colour: R(1) + G(1) + B(1) + sensors.append({"channel": channel, "type": name, "type_id": type_id, + "value": {"r": raw[0], "g": raw[1], "b": raw[2]}, + "raw_value": raw.hex()}) + else: + val = int.from_bytes(raw, "big", signed=signed) + sensors.append({"channel": channel, "type": name, "type_id": type_id, + "value": val / divisor if divisor != 1 else val, + "raw_value": raw.hex()}) + return sensors + class ProtocolResponseHandler: """Handler for protocol responses that come back as encrypted PATH packets. @@ -53,13 +136,13 @@ def set_binary_response_callback(self, callback: Callable[..., Any]) -> None: self._binary_response_callback = callback async def __call__(self, pkt: Packet) -> None: - """Handle incoming PATH packet that might be a protocol response.""" + """Handle incoming PATH or RESPONSE packet that might be a protocol response.""" try: # Check if this looks like an encrypted protocol response if len(pkt.payload) < 4: return # Too short for protocol response - # PATH packet structure: + # Both PATH and RESPONSE packets share the same structure: # dest_hash(1) + src_hash(1) + encrypted_data(N) src_hash = pkt.payload[1] @@ -69,7 +152,7 @@ async def __call__(self, pkt: Packet) -> None: self._log( "[ProtocolResponse] Processing potential protocol response " - f"from 0x{src_hash:02X}" + f"from 0x{src_hash:02X}, payload_len={len(pkt.payload)}" ) # Try to decrypt the response @@ -77,6 +160,29 @@ async def __call__(self, pkt: Packet) -> None: pkt, src_hash ) + # If an explicit response callback is waiting for this source (e.g. telemetry, + # stats, repeater command), deliver there first. The binary/path-discovery + # callback is a generic fallback for unsolicited binary responses. + # + # Guard: skip responses that are clearly NOT protocol responses (e.g. a + # stale login response retransmission). Protocol responses always decrypt + # to a tag(4) + meaningful payload, so ≥20 bytes. Login responses are only + # ~12 bytes and parse as "binary" fallback. Without this check a + # retransmitted login response can consume the stats/telemetry waiter. + if src_hash in self._response_callbacks: + resp_type = parsed_data.get("type") if isinstance(parsed_data, dict) else None + decrypted_len = len(raw_decrypted) if raw_decrypted else 0 + if not success or (resp_type == "binary" and decrypted_len < 20): + self._log( + f"[ProtocolResponse] Ignoring non-protocol response for 0x{src_hash:02X} " + f"(success={success}, type={resp_type}, decrypted_len={decrypted_len})" + ) + return + callback = self._response_callbacks[src_hash] + if callback: + callback(success, decoded_text, parsed_data) + return + # If binary response callback is set, parse and invoke (plain tag+data or path-return format) if ( success @@ -85,25 +191,39 @@ async def __call__(self, pkt: Packet) -> None: and len(raw_decrypted) >= 4 ): path_info = None - tag_bytes = raw_decrypted[:4] - response_data = raw_decrypted[4:] - # Path-return format (MeshCore createPathReturn): path_len(1), path(path_len), extra_type(1), extra - path_len = raw_decrypted[0] - if ( - path_len <= MAX_PATH_SIZE - and len(raw_decrypted) >= 1 + path_len + 1 + 4 - ): - out_path = bytes(raw_decrypted[1 : 1 + path_len]) - extra_type = raw_decrypted[1 + path_len] - extra = raw_decrypted[2 + path_len :] - if extra_type == PAYLOAD_TYPE_RESPONSE and len(extra) >= 4: - tag_bytes = extra[:4] - response_data = extra[4:] - in_path = bytes(pkt.path) if pkt.path else b"" - contact = self._find_contact_by_hash(src_hash) - if contact: - contact_pubkey = bytes.fromhex(contact.public_key) - path_info = (out_path, in_path, contact_pubkey) + pkt_type = (pkt.header >> 2) & 0x0F + + if pkt_type == PAYLOAD_TYPE_PATH: + # PATH packet: decrypted is path_len(1)+path(N)+extra_type(1)+extra + # Extract inner response from path-return structure + path_len_byte = raw_decrypted[0] + inner_offset = 1 + path_len_byte + 1 + if ( + path_len_byte <= MAX_PATH_SIZE + and len(raw_decrypted) >= inner_offset + 4 + ): + out_path = bytes(raw_decrypted[1 : 1 + path_len_byte]) + extra_type = raw_decrypted[1 + path_len_byte] & 0x0F + extra = raw_decrypted[inner_offset:] + if extra_type == PAYLOAD_TYPE_RESPONSE and len(extra) >= 4: + tag_bytes = extra[:4] + response_data = extra[4:] + in_path = bytes(pkt.path) if pkt.path else b"" + contact = self._find_contact_by_hash(src_hash) + if contact: + contact_pubkey = bytes.fromhex(contact.public_key) + path_info = (out_path, in_path, contact_pubkey) + else: + tag_bytes = raw_decrypted[:4] + response_data = raw_decrypted[4:] + else: + tag_bytes = raw_decrypted[:4] + response_data = raw_decrypted[4:] + else: + # RESPONSE packet: decrypted is tag(4)+data directly + tag_bytes = raw_decrypted[:4] + response_data = raw_decrypted[4:] + try: cb_result = self._binary_response_callback( tag_bytes, response_data, path_info @@ -114,18 +234,18 @@ async def __call__(self, pkt: Packet) -> None: self._log(f"[ProtocolResponse] Binary response callback error: {e}") return - # Call the waiting callback - callback = self._response_callbacks[src_hash] - if callback: - callback(success, decoded_text, parsed_data) - except Exception as e: self._log(f"[ProtocolResponse] Error processing protocol response: {e}") async def _decrypt_protocol_response( self, pkt: Packet, src_hash: int ) -> tuple[bool, str, Dict[str, Any], Optional[bytes]]: - """Decrypt and parse a protocol response packet. Returns (success, text, parsed_data, raw_decrypted).""" + """Decrypt and parse a protocol response packet. Returns (success, text, parsed_data, raw_decrypted). + + Handles both packet types by inspecting the actual packet header: + - PAYLOAD_TYPE_RESPONSE (0x01): direct datagram → decrypted = tag(4)+data + - PAYLOAD_TYPE_PATH (0x08): path return → decrypted = path_len(1)+path(N)+extra_type(1)+extra + """ try: # Find the contact by hash contact = self._find_contact_by_hash(src_hash) @@ -144,10 +264,41 @@ async def _decrypt_protocol_response( # Decrypt the payload decrypted = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, encrypted_data) - self._log(f"[ProtocolResponse] Successfully decrypted {len(decrypted)} bytes") + # Determine the actual payload type from the incoming packet header. + pkt_type = (pkt.header >> 2) & 0x0F + self._log( + f"[ProtocolResponse] Decrypted {len(decrypted)} bytes from " + f"pkt_type=0x{pkt_type:02X}, hex: {decrypted.hex()}" + ) + + # Extract the actual response data based on packet type. + response_data = decrypted + + if pkt_type == PAYLOAD_TYPE_PATH: + # Path-return format: path_len(1) + path(N) + extra_type(1) + extra_data + # The actual protocol response is inside the 'extra' field. + if len(decrypted) >= 2: # need at least path_len + extra_type + path_len_byte = decrypted[0] + inner_offset = 1 + path_len_byte + 1 # path_len + path + extra_type + if ( + path_len_byte <= MAX_PATH_SIZE + and len(decrypted) >= inner_offset + ): + extra_type = decrypted[1 + path_len_byte] & 0x0F + if extra_type == PAYLOAD_TYPE_RESPONSE and len(decrypted) > inner_offset: + response_data = decrypted[inner_offset:] + self._log( + f"[ProtocolResponse] PATH format: extracted inner response " + f"{len(response_data)} bytes (path_len={path_len_byte})" + ) + else: + self._log( + f"[ProtocolResponse] PATH format: extra_type=0x{extra_type:02X}, " + f"not RESPONSE" + ) # Parse based on content type - success, text, parsed = self._parse_protocol_response(decrypted) + success, text, parsed = self._parse_protocol_response(response_data) return success, text, parsed, decrypted except Exception as e: @@ -155,34 +306,62 @@ async def _decrypt_protocol_response( return False, f"Decryption failed: {e}", {}, None def _parse_protocol_response(self, data: bytes) -> tuple[bool, str, Dict[str, Any]]: - """Parse decrypted protocol response data.""" + """Parse decrypted protocol response data. + + Parse order mirrors MeshCore firmware priority: + 1. Stats (RepeaterStats struct, ≥52 bytes) + 2. Text / status (UTF-8 printable after stripping tag + nulls) + 3. Telemetry (reflected_timestamp + valid CayenneLPP with ≥1 sensor) + 4. Binary fallback + """ try: - # Check if this looks like a stats response (protocol 0x01) - if len(data) >= 48: - # Try parsing as RepeaterStats struct + self._log( + f"[ProtocolResponse] _parse_protocol_response: {len(data)} bytes, " + f"first 16: {data[:16].hex() if len(data) >= 16 else data.hex()}" + ) + + # 1. Check if this looks like a stats response (protocol 0x01) + # RepeaterStats is 48-56 bytes + 4-byte tag. Older firmware + # omits n_recv_errors (52 B struct → 56 total); PATH-wrapped + # responses may also lose trailing bytes to AES block alignment. + if len(data) >= 56: stats_result = self._parse_stats_response(data) if stats_result: - return True, stats_result["formatted"], stats_result["raw"] + # Include raw_bytes in the parsed dict so callers can + # forward the binary RepeaterStats to companion apps. + result_dict = stats_result["raw"] + result_dict["type"] = "stats" + result_dict["raw_bytes"] = stats_result["raw_bytes"] + self._log( + f"[ProtocolResponse] Parsed as STATS: batt={result_dict['batt_milli_volts']}mV, " + f"rssi={result_dict['last_rssi']}, snr={result_dict['last_snr']}, " + f"raw_bytes={len(result_dict['raw_bytes'])}B" + ) + return True, stats_result["formatted"], result_dict - # Check if this looks like a telemetry response (protocol 0x03) - if len(data) >= 4: # At minimum need some telemetry data + # 2. Try parsing as text/status response. + # Status responses are tag(4) + UTF-8 text. Strip the 4-byte + # tag that prefixes every response, then check for printable text. + if len(data) > 4: + try: + text_candidate = data[4:].rstrip(b"\x00").decode("utf-8") + if text_candidate.strip() and text_candidate.strip().isprintable(): + return ( + True, + text_candidate.strip(), + {"type": "text", "content": text_candidate.strip()}, + ) + except UnicodeDecodeError: + pass + + # 3. Check if this looks like a telemetry response (protocol 0x03) + # Must decode at least one sensor from valid CayenneLPP after the tag. + if len(data) >= 8: # tag(4) + at least one LPP record (ch+type+val = 3+) telemetry_result = self._parse_telemetry_response(data) - if telemetry_result: + if telemetry_result and telemetry_result.get("sensor_count", 0) > 0: return True, telemetry_result["formatted"], telemetry_result - # Try parsing as text response - try: - text_response = data.rstrip(b"\x00").decode("utf-8") - if text_response.strip(): - return ( - True, - text_response, - {"type": "text", "content": text_response}, - ) - except UnicodeDecodeError: - pass - - # Fall back to hex representation + # 4. Fall back to hex representation hex_response = data.hex() return ( True, @@ -194,37 +373,100 @@ def _parse_protocol_response(self, data: bytes) -> tuple[bool, str, Dict[str, An return False, f"Parse error: {e}", {} def _parse_stats_response(self, data: bytes) -> Optional[Dict[str, Any]]: - """Parse RepeaterStats struct response (protocol 0x01).""" + """Parse RepeaterStats struct response (protocol 0x01). + + RepeaterStats layout (from simple_repeater/MyMesh.h): + uint16_t batt_milli_volts; // offset 0 + uint16_t curr_tx_queue_len; // offset 2 + int16_t noise_floor; // offset 4 + int16_t last_rssi; // offset 6 + uint32_t n_packets_recv; // offset 8 + uint32_t n_packets_sent; // offset 12 + uint32_t total_air_time_secs; // offset 16 + uint32_t total_up_time_secs; // offset 20 + uint32_t n_sent_flood; // offset 24 + uint32_t n_sent_direct; // offset 28 + uint32_t n_recv_flood; // offset 32 + uint32_t n_recv_direct; // offset 36 + uint16_t err_events; // offset 40 + int16_t last_snr; // ×4 // offset 42 + uint16_t n_direct_dups; // offset 44 + uint16_t n_flood_dups; // offset 46 + uint32_t total_rx_air_time_secs; // offset 48 + uint32_t n_recv_errors; // offset 52 + Total: 56 bytes + """ try: - # Skip 4-byte header as per C++ code: memcpy(&reply_data[4], &stats, sizeof(stats)) - if len(data) < 52: # 4 header + 48 struct = 52 minimum + # Skip 4-byte reflected timestamp/tag + # memcpy(&reply_data[4], &stats, sizeof(stats)) + if len(data) < 56: # 4 tag + 52 struct minimum (without n_recv_errors) return None - stats_data = data[4:] # Skip the 4-byte header - - # Parse as all 16-bit values - this gives correct results - parsed = struct.unpack("<24H", stats_data[:48]) + stats_data = data[4:] # Skip the 4-byte tag + + # Pad to 56 bytes so struct.unpack always succeeds. Older firmware + # or PATH-wrapped responses with AES block alignment may yield fewer + # than 56 bytes; missing trailing fields default to zero. + if len(stats_data) < 56: + stats_data = stats_data + b"\x00" * (56 - len(stats_data)) + + # Parse with correct field types matching C++ struct + ( + batt_milli_volts, # uint16 offset 0 + curr_tx_queue_len, # uint16 offset 2 + noise_floor, # int16 offset 4 + last_rssi, # int16 offset 6 + n_packets_recv, # uint32 offset 8 + n_packets_sent, # uint32 offset 12 + total_air_time_secs,# uint32 offset 16 + total_up_time_secs, # uint32 offset 20 + n_sent_flood, # uint32 offset 24 + n_sent_direct, # uint32 offset 28 + n_recv_flood, # uint32 offset 32 + n_recv_direct, # uint32 offset 36 + err_events, # uint16 offset 40 + last_snr_raw, # int16 offset 42 + n_direct_dups, # uint16 offset 44 + n_flood_dups, # uint16 offset 46 + total_rx_air_time_secs, # uint32 offset 48 + n_recv_errors, # uint32 offset 52 + ) = struct.unpack(" Optional[Dict[str, Any]]: Expected format: - reflected_timestamp (4 bytes, little-endian) - CayenneLPP data (remaining bytes) + + Returns None if no valid CayenneLPP sensors can be decoded, allowing + the caller to fall back to other response types. """ try: - if len(data) < 4: - self._log( - "[ProtocolResponse] Telemetry data too short: " - f"{len(data)} bytes (need at least 4 for timestamp)" - ) + if len(data) < 8: + # Need at least tag(4) + one minimal LPP record (ch+type+val = 3) return None - self._log( - "[ProtocolResponse] Parsing " f"{len(data)} bytes telemetry data: {data.hex()}" - ) - - # Parse according to MeshCore TelemetryResponseData structure - # First 4 bytes: reflected timestamp (little-endian) + # First 4 bytes: reflected timestamp / tag (little-endian) reflected_timestamp = struct.unpack("= 0 and lpp_data[last_nonzero] == 0: - last_nonzero -= 1 - - if last_nonzero < len(lpp_data) - 1: - lpp_data = lpp_data[: last_nonzero + 1] - - # Try using the cayenne_lpp_helpers function (now without trailing zeros) if available - try: - try: - from utils.cayenne_lpp_helpers import decode_cayenne_lpp_payload - helper_result = decode_cayenne_lpp_payload(lpp_data.hex()) - except ImportError: - # Utils not available in lightweight mode - helper_result = {"error": "cayenne_lpp_helpers not available"} - - if "error" not in helper_result and helper_result.get("sensor_count", 0) > 0: - self._log( - "[ProtocolResponse] CayenneLPP parsing succeeded: " - f"{helper_result['sensor_count']} sensors" - ) + if len(lpp_data) < 3: + # Not enough for even one LPP record (channel + type + 1-byte value) + return None - # Convert to our expected format - converted_sensors = [] - for sensor in helper_result["sensors"]: - converted_sensor = { - "channel": sensor["channel"], - "type": sensor["type"], - "type_id": sensor["type_id"], - "value": sensor["value"], - "raw_value": sensor["raw_value"], - } - converted_sensors.append(converted_sensor) - - return { - "type": "telemetry", - "formatted": ( - f"Telemetry ({len(converted_sensors)} sensors, " - f"ts:{reflected_timestamp})" - ), - "reflected_timestamp": reflected_timestamp, - "sensor_count": len(converted_sensors), - "sensors": converted_sensors, - } - else: - self._log( - "[ProtocolResponse] CayenneLPP parsing failed: " - f"{helper_result.get('error', 'no sensors found')}" - ) + # Sanity check: MeshCore telemetry always starts with + # addVoltage(TELEM_CHANNEL_SELF=1, battery_volts) which produces + # channel=1, type=0x74 (LPP_VOLTAGE). Require this signature to + # distinguish telemetry from other response types that happen to + # decrypt to >= 8 bytes. + if lpp_data[0] != 0x01 or lpp_data[1] != 0x74: + return None - except Exception as e: - self._log(f"[ProtocolResponse] CayenneLPP parsing exception: {e}") + sensors = _decode_cayenne_lpp(lpp_data) + if not sensors: + return None - # All parsing methods failed - self._log("[ProtocolResponse] CayenneLPP parsing failed") + self._log( + f"[ProtocolResponse] CayenneLPP decoded {len(sensors)} sensor(s) " + f"from {len(lpp_data)} bytes: {lpp_data.hex()}" + ) return { "type": "telemetry", "formatted": ( - f"Unknown telemetry LPP data ({len(lpp_data)} bytes, " + f"Telemetry ({len(sensors)} sensors, " f"ts:{reflected_timestamp})" ), "reflected_timestamp": reflected_timestamp, - "sensor_count": 0, - "sensors": [], + "sensor_count": len(sensors), + "sensors": sensors, + "raw_bytes": bytes(data[4:]), # LPP data after tag for verbatim forwarding } except Exception as e: self._log(f"[ProtocolResponse] Telemetry parsing failed: {e}") - return { - "type": "telemetry", - "length": len(data), - "hex": data.hex(), - "format": "error", - "formatted": f"Telemetry parsing error: {e}", - "error": str(e), - } - - def _convert_signed_16bit(self, value: int) -> int: - """Convert unsigned 16-bit to signed if needed.""" - return value - 65536 if value > 32767 else value + return None def _format_stats(self, stats: Dict[str, Any]) -> str: """Format stats as human-readable string.""" @@ -368,11 +543,17 @@ def _format_stats(self, stats: Dict[str, Any]) -> str: # Signal quality result.append(f"RSSI: {stats['last_rssi']}dBm") result.append(f"SNR: {stats['last_snr']:.1f}dB") + result.append(f"NF: {stats['noise_floor']}dB") # Packet counts - result.append(f"RX: {stats['n_packets_recv']}") - result.append(f"TX: {stats['n_packets_sent']}") - result.append(f"Flood RX: {stats['n_recv_flood']}") + result.append( + f"TX: {stats['n_packets_sent']} " + f"(F:{stats['n_sent_flood']}/D:{stats['n_sent_direct']})" + ) + result.append( + f"RX: {stats['n_packets_recv']} " + f"(F:{stats['n_recv_flood']}/D:{stats['n_recv_direct']})" + ) # Uptime formatting uptime = stats["total_up_time_secs"] @@ -388,15 +569,21 @@ def _format_stats(self, stats: Dict[str, Any]) -> str: result.append(f"Up: {days}d{hours}h") # Air time - result.append(f"Air: {stats['total_air_time_secs']}s") + result.append(f"TxAir: {stats['total_air_time_secs']}s") + if stats.get("total_rx_air_time_secs"): + result.append(f"RxAir: {stats['total_rx_air_time_secs']}s") # Error events (only if > 0) if stats["err_events"] > 0: result.append(f"Err: {stats['err_events']}") + # RX errors (only if > 0) + if stats.get("n_recv_errors", 0) > 0: + result.append(f"RxErr: {stats['n_recv_errors']}") + # Duplicates (only if > 0) if stats["n_direct_dups"] > 0 or stats["n_flood_dups"] > 0: - result.append(f"Dups: {stats['n_direct_dups']}/{stats['n_flood_dups']}") + result.append(f"Dups: D:{stats['n_direct_dups']}/F:{stats['n_flood_dups']}") return " | ".join(result) diff --git a/src/pymc_core/protocol/__init__.py b/src/pymc_core/protocol/__init__.py index 75cdddf..c583630 100644 --- a/src/pymc_core/protocol/__init__.py +++ b/src/pymc_core/protocol/__init__.py @@ -42,6 +42,7 @@ PH_VER_MASK, PH_VER_SHIFT, PUB_KEY_SIZE, + REQ_TYPE_GET_STATUS, REQ_TYPE_GET_TELEMETRY_DATA, ROUTE_TYPE_DIRECT, ROUTE_TYPE_FLOOD, @@ -148,7 +149,8 @@ "CONTACT_TYPE_REPEATER", "CONTACT_TYPE_ROOM_SERVER", "CONTACT_TYPE_HYBRID", - # Telemetry + # Protocol request types + "REQ_TYPE_GET_STATUS", "REQ_TYPE_GET_TELEMETRY_DATA", "TELEM_PERM_BASE", "TELEM_PERM_LOCATION", diff --git a/src/pymc_core/protocol/constants.py b/src/pymc_core/protocol/constants.py index d5b23d9..adeff12 100644 --- a/src/pymc_core/protocol/constants.py +++ b/src/pymc_core/protocol/constants.py @@ -107,9 +107,9 @@ def describe_advert_flags(flags: int) -> str: CONTACT_TYPE_HYBRID = 4 -# Telemetry Permissions - -REQ_TYPE_GET_TELEMETRY_DATA = 0x03 +# Protocol Request Types +REQ_TYPE_GET_STATUS = 0x01 # Get repeater stats (RepeaterStats struct) +REQ_TYPE_GET_TELEMETRY_DATA = 0x03 # Get telemetry data (CayenneLPP) TELEM_PERM_BASE = 0x01 TELEM_PERM_LOCATION = 0x02 TELEM_PERM_ENVIRONMENT = 0x04 diff --git a/src/pymc_core/protocol/packet_builder.py b/src/pymc_core/protocol/packet_builder.py index 82f48d2..8fd6a74 100644 --- a/src/pymc_core/protocol/packet_builder.py +++ b/src/pymc_core/protocol/packet_builder.py @@ -469,10 +469,25 @@ def create_login_packet(contact: Any, local_identity: LocalIdentity, password: s contact_identity = Identity(contact_pubkey) shared_secret = contact_identity.calc_shared_secret(local_identity.get_private_key()) - return PacketBuilder.create_anon_req( - contact_identity, local_identity, shared_secret, plaintext, "direct" + out_path_len = getattr(contact, "out_path_len", -1) + if out_path_len < 0: + route_type = "flood" + else: + route_type = "direct" + + pkt = PacketBuilder.create_anon_req( + contact_identity, local_identity, shared_secret, plaintext, route_type ) + if route_type == "direct" and out_path_len > 0: + out_path = getattr(contact, "out_path", b"") + if out_path: + path_bytes = out_path[:MAX_PATH_SIZE] + pkt.path = bytearray(path_bytes) + pkt.path_len = len(pkt.path) + + return pkt + @staticmethod def create_group_datagram( group_name: str, @@ -807,8 +822,22 @@ def create_protocol_request( contact, local_identity, plaintext ) - header = PacketBuilder._create_header(PAYLOAD_TYPE_REQ) + out_path_len = getattr(contact, "out_path_len", -1) + if out_path_len < 0: + route_type = "flood" + else: + route_type = "direct" + + header = PacketBuilder._create_header(PAYLOAD_TYPE_REQ, route_type) packet = PacketBuilder._create_packet(header, payload) + + if route_type == "direct" and out_path_len > 0: + out_path = getattr(contact, "out_path", b"") + if out_path: + path_bytes = out_path[:MAX_PATH_SIZE] + packet.path = bytearray(path_bytes) + packet.path_len = len(packet.path) + return packet, timestamp @staticmethod From 80544b62de0e72ac724b0222e8c05a52ac2a82d7 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 14 Feb 2026 19:59:19 -0800 Subject: [PATCH 03/50] Enhance CompanionRadio with new path discovery and binary request features - Introduced methods for sending trace packets with explicit paths and handling path discovery responses. - Added support for sending binary requests with improved error handling and response management. - Updated existing path discovery methods to utilize new tagging and telemetry request mechanisms. - Enhanced logging for error handling in packet sending operations. --- src/pymc_core/companion/companion_radio.py | 183 +++++++++++++++++++-- 1 file changed, 170 insertions(+), 13 deletions(-) diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py index 5f8b612..76cef8c 100644 --- a/src/pymc_core/companion/companion_radio.py +++ b/src/pymc_core/companion/companion_radio.py @@ -11,11 +11,18 @@ import asyncio import logging +import random from typing import Any, Callable, Optional from ..node.node import MeshNode -from ..protocol import LocalIdentity, PacketBuilder -from ..protocol.constants import ADVERT_FLAG_HAS_LOCATION, ADVERT_FLAG_HAS_NAME +from ..protocol import LocalIdentity, Packet, PacketBuilder +from ..protocol.constants import ( + ADVERT_FLAG_HAS_LOCATION, + ADVERT_FLAG_HAS_NAME, + PAYLOAD_TYPE_CONTROL, + REQ_TYPE_GET_TELEMETRY_DATA, + TELEM_PERM_BASE, +) from .companion_base import CompanionBase, adv_type_to_flags from .constants import ( ADV_TYPE_CHAT, @@ -83,6 +90,7 @@ def __init__( ) self._radio = radio self._dispatcher_task: Optional[asyncio.Task] = None + self._pending_discovery_tags: set[int] = set() self.node = MeshNode( radio=radio, @@ -323,27 +331,84 @@ async def send_trace_path( logger.error(f"Error sending trace: {e}") return False + async def send_trace_path_raw( + self, + tag: int, + auth_code: int, + flags: int, + path_bytes: bytes, + ) -> bool: + """Send a trace packet with an explicit path (e.g. from CMD_SEND_TRACE_PATH). Matches firmware behavior.""" + try: + path_list = list(path_bytes) + pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=path_list) + return await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Error sending trace (raw path): {e}") + return False + + async def _try_handle_path_discovery( + self, tag_bytes: bytes, path_info: tuple + ) -> bool: + """If tag is pending path discovery, fire path_discovery_response and return True.""" + out_path, in_path, contact_pubkey = path_info + tag_int = int.from_bytes(tag_bytes, "little") + if tag_int not in self._pending_discovery_tags: + return False + self._pending_discovery_tags.discard(tag_int) + await self._fire_callbacks( + "path_discovery_response", + tag_bytes, + contact_pubkey, + out_path, + in_path, + ) + return True + async def send_path_discovery(self, pub_key: bytes) -> bool: + """Legacy: send path discovery without returning tag. Prefer send_path_discovery_req.""" + result = await self.send_path_discovery_req(pub_key) + return result.success + + async def send_path_discovery_req(self, pub_key: bytes) -> SentResult: + """Send path discovery (flood telemetry request with tag). Returns SentResult for RESP_CODE_SENT. + When path return arrives with matching tag, path_discovery_response is fired (PUSH 0x8D).""" contact = self.contacts.get_by_key(pub_key) if not contact: - return False + return SentResult(success=False) + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return SentResult(success=False) + tag_int = random.randint(0, 0xFFFFFFFF) + tag_bytes = tag_int.to_bytes(4, "little") + inv_perm = 0xFF & ~TELEM_PERM_BASE + req_payload = tag_bytes + bytes( + [REQ_TYPE_GET_TELEMETRY_DATA, inv_perm, 0, 0, 0] + ) old_path_len = contact.out_path_len old_path = contact.out_path contact.out_path_len = -1 contact.out_path = b"" self.contacts.update(contact) try: - result = await self.node.send_telemetry_request( - contact_name=contact.name, - want_base=False, - want_location=False, - want_environment=False, - timeout=5.0, + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, + data=req_payload, + ) + success = await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) + if success: + self._pending_discovery_tags.add(tag_int) + return SentResult( + success=success, + is_flood=True, + expected_ack=tag_int, + timeout_ms=10000, ) - return result.get("success", False) except Exception as e: logger.error(f"Error in path discovery: {e}") - return False + return SentResult(success=False) finally: current = self.contacts.get_by_key(pub_key) if current and current.out_path_len == -1: @@ -428,7 +493,54 @@ async def send_telemetry_request( logger.error(f"Telemetry request error: {e}") return {"success": False, "reason": str(e)} + async def send_binary_req( + self, pub_key: bytes, data: bytes, timeout_seconds: float = 15.0 + ) -> SentResult: + """Send binary request (CMD_SEND_BINARY_REQ). data = request_type(1) + optional payload. + Returns SentResult with expected_ack (4-byte tag as int) and timeout_ms for RESP_CODE_SENT. + """ + contact = self.contacts.get_by_key(pub_key) + if not contact: + return SentResult(success=False) + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return SentResult(success=False) + tag_int = random.randint(0, 0xFFFFFFFF) + tag_bytes = tag_int.to_bytes(4, "little") + tag_hex = tag_bytes.hex() + request_type = data[0] if len(data) >= 1 else 0 + req_payload = tag_bytes + data + self.cleanup_expired_binary_requests() + self.register_binary_request( + tag_hex, + request_type=request_type, + timeout_seconds=timeout_seconds, + pubkey_prefix=pub_key[:6].hex(), + ) + try: + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=0x02, + data=req_payload, + ) + success = await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Binary request send error: {e}") + self._pending_binary_requests.pop(tag_hex, None) + return SentResult(success=False) + if not success: + self._pending_binary_requests.pop(tag_hex, None) + return SentResult(success=False) + return SentResult( + success=True, + is_flood=contact.out_path_len <= 0, + expected_ack=tag_int, + timeout_ms=10000, + ) + async def send_binary_request(self, pub_key: bytes, data: bytes) -> dict: + """Legacy: send binary request and wait for response via waiter. Prefer send_binary_req + on_binary_response.""" contact = self.contacts.get_by_key(pub_key) if not contact: return {"success": False, "reason": "Contact not found"} @@ -456,13 +568,54 @@ async def send_anon_request(self, pub_key: bytes, data: bytes) -> dict: logger.error(f"Anon request error: {e}") return {"success": False, "reason": str(e)} + async def send_repeater_command( + self, pub_key: bytes, command: str, parameters: Optional[str] = None + ) -> dict: + """Send a text-based command to a repeater and await response.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + try: + result = await self.node.send_repeater_command( + repeater_name=contact.name, + command=command, + parameters=parameters, + ) + return { + "success": result.get("success", False), + "repeater": contact.name, + "command": command, + "response": result.get("response"), + "reason": ( + "Command successful" if result.get("success") else "No response" + ), + } + except Exception as e: + logger.error(f"Repeater command error: {e}") + return {"success": False, "reason": str(e)} + # ------------------------------------------------------------------------- # Control Data # ------------------------------------------------------------------------- - async def send_control_data(self, data: bytes) -> bool: + async def send_control_data(self, data: Optional[bytes] = None) -> bool: + """Send a CONTROL packet. If data is provided and valid (len 1-254, first byte has 0x80), + send it as raw control payload; otherwise send a default discovery request (backward compat).""" + if data and len(data) <= 254 and (data[0] & 0x80) != 0: + try: + pkt = Packet() + pkt.header = PacketBuilder._create_header( + PAYLOAD_TYPE_CONTROL, route_type="direct" + ) + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(data) + pkt.payload_len = len(data) + return await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Error sending control data: {e}") + return False try: - import random tag = random.randint(0, 0xFFFFFFFF) pkt = PacketBuilder.create_discovery_request(tag, filter_mask=0x04) return await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) @@ -490,6 +643,10 @@ def _setup_packet_callbacks(self) -> None: dispatcher = self.node.dispatcher dispatcher.set_packet_received_callback(self._on_packet_received) dispatcher.set_packet_sent_callback(self._on_packet_sent) + if hasattr(dispatcher, "protocol_response_handler") and dispatcher.protocol_response_handler: + dispatcher.protocol_response_handler.set_binary_response_callback( + self._on_binary_response + ) async def _on_packet_received(self, pkt: Any) -> None: from ..protocol.constants import ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD From aa7e5ff9653f6a33219a11928e11df9e675f7ef9 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 14 Feb 2026 20:05:25 -0800 Subject: [PATCH 04/50] Add CompanionRadio path discovery and binary request handling - Improved path discovery methods with enhanced tagging and telemetry requests. - Updated binary request handling to include better error management and response processing. - Enhanced logging for packet sending operations to improve traceability and debugging. --- tests/test_companion_base.py | 77 ++++++ tests/test_companion_bridge.py | 245 +++++++++++++++++++ tests/test_companion_radio.py | 305 +++++++++++++++++++++++ tests/test_companion_stores.py | 433 +++++++++++++++++++++++++++++++++ 4 files changed, 1060 insertions(+) create mode 100644 tests/test_companion_base.py create mode 100644 tests/test_companion_bridge.py create mode 100644 tests/test_companion_radio.py create mode 100644 tests/test_companion_stores.py diff --git a/tests/test_companion_base.py b/tests/test_companion_base.py new file mode 100644 index 0000000..e04951b --- /dev/null +++ b/tests/test_companion_base.py @@ -0,0 +1,77 @@ +"""Tests for companion base: ResponseWaiter, adv_type_to_flags, and base API via CompanionRadio.""" + +import pytest + +from pymc_core.companion.companion_base import ResponseWaiter, adv_type_to_flags +from pymc_core.companion.constants import ( + ADV_TYPE_CHAT, + ADV_TYPE_REPEATER, + ADV_TYPE_ROOM, + ADV_TYPE_SENSOR, +) +from pymc_core.protocol.constants import ( + ADVERT_FLAG_IS_CHAT_NODE, + ADVERT_FLAG_IS_REPEATER, + ADVERT_FLAG_IS_ROOM_SERVER, +) + + +# --------------------------------------------------------------------------- +# ResponseWaiter +# --------------------------------------------------------------------------- + + +class TestResponseWaiter: + def test_initial_state(self): + w = ResponseWaiter() + assert w.data["success"] is False + assert w.data["text"] is None + assert w.data["parsed"] == {} + + def test_callback_sets_data_and_event(self): + w = ResponseWaiter() + w.callback(True, "hello", {"k": "v"}) + assert w.data["success"] is True + assert w.data["text"] == "hello" + assert w.data["parsed"] == {"k": "v"} + assert w.event.is_set() + + @pytest.mark.asyncio + async def test_wait_returns_after_callback(self): + w = ResponseWaiter() + w.callback(True, "done", {"x": 1}) + result = await w.wait(timeout=1.0) + assert result["success"] is True + assert result["text"] == "done" + assert result["parsed"] == {"x": 1} + assert "timeout" not in result + + @pytest.mark.asyncio + async def test_wait_timeout(self): + w = ResponseWaiter() + result = await w.wait(timeout=0.05) + assert result["timeout"] is True + assert result["success"] is False + + +# --------------------------------------------------------------------------- +# adv_type_to_flags +# --------------------------------------------------------------------------- + + +class TestAdvTypeToFlags: + def test_chat(self): + assert adv_type_to_flags(ADV_TYPE_CHAT) == ADVERT_FLAG_IS_CHAT_NODE + + def test_repeater(self): + assert adv_type_to_flags(ADV_TYPE_REPEATER) == ADVERT_FLAG_IS_REPEATER + + def test_room(self): + assert adv_type_to_flags(ADV_TYPE_ROOM) == ADVERT_FLAG_IS_ROOM_SERVER + + def test_sensor(self): + assert adv_type_to_flags(ADV_TYPE_SENSOR) == 0x04 + + def test_unknown_defaults_to_chat(self): + assert adv_type_to_flags(99) == ADVERT_FLAG_IS_CHAT_NODE + assert adv_type_to_flags(0) == ADVERT_FLAG_IS_CHAT_NODE diff --git a/tests/test_companion_bridge.py b/tests/test_companion_bridge.py new file mode 100644 index 0000000..4cd2c80 --- /dev/null +++ b/tests/test_companion_bridge.py @@ -0,0 +1,245 @@ +"""Tests for CompanionBridge (repeater-integrated companion with packet_injector).""" + +import pytest + +from pymc_core.companion import CompanionBridge +from pymc_core.companion.models import Contact +from pymc_core.protocol import LocalIdentity, Packet +from pymc_core.protocol.constants import ( + PAYLOAD_TYPE_ADVERT, + PAYLOAD_TYPE_TXT_MSG, + ROUTE_TYPE_FLOOD, +) + + +def _make_peer_contact(name: str) -> Contact: + """Return a contact with a valid Ed25519 public key (required for packet encryption).""" + peer = LocalIdentity() + return Contact(public_key=peer.get_public_key(), name=name) + + +class MockPacketInjector: + """Records injected packets and returns True by default.""" + + def __init__(self): + self.calls: list[tuple] = [] + + async def __call__(self, pkt: Packet, wait_for_ack: bool = False) -> bool: + self.calls.append((pkt, wait_for_ack)) + return True + + +# --------------------------------------------------------------------------- +# Init +# --------------------------------------------------------------------------- + + +class TestCompanionBridgeInit: + def test_init_creates_stores(self): + injector = MockPacketInjector() + identity = LocalIdentity() + bridge = CompanionBridge(identity, injector, node_name="BridgeNode") + assert bridge.contacts is not None + assert bridge.contacts.get_count() == 0 + assert bridge.channels is not None + assert bridge.stats is not None + assert bridge.prefs.node_name == "BridgeNode" + assert bridge.get_public_key() == identity.get_public_key() + assert injector.calls == [] + + def test_init_with_authenticate_callback(self): + def auth_cb(*args, **kwargs): + return (True, 0) + + injector = MockPacketInjector() + bridge = CompanionBridge( + LocalIdentity(), + injector, + authenticate_callback=auth_cb, + ) + assert bridge._handlers is not None + + +# --------------------------------------------------------------------------- +# Lifecycle +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeLifecycle: + async def test_start_stop(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + assert bridge.is_running is False + await bridge.start() + assert bridge.is_running is True + await bridge.stop() + assert bridge.is_running is False + + +# --------------------------------------------------------------------------- +# process_received_packet +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeProcessReceivedPacket: + async def test_process_packet_records_rx_stats(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + await bridge.start() + pkt = Packet() + pkt.header = (ROUTE_TYPE_FLOOD << 0) | (PAYLOAD_TYPE_ADVERT << 2) + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray() + pkt.payload_len = 0 + await bridge.process_received_packet(pkt) + tot = bridge.stats.get_totals() + assert tot["flood_rx"] == 1 + await bridge.stop() + + async def test_process_unknown_type_no_crash(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + pkt = Packet() + pkt.header = (ROUTE_TYPE_FLOOD << 0) | (15 << 2) + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray() + pkt.payload_len = 0 + await bridge.process_received_packet(pkt) + assert True + + +# --------------------------------------------------------------------------- +# Advertise +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeAdvertise: + async def test_advertise_injects_packet(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.advertise(flood=True) + assert result is True + assert len(injector.calls) == 1 + pkt, wait_for_ack = injector.calls[0] + assert pkt is not None + assert (pkt.header >> 2) & 0x0F == PAYLOAD_TYPE_ADVERT + assert wait_for_ack is False + assert bridge.stats.get_totals()["flood_tx"] == 1 + + +# --------------------------------------------------------------------------- +# Send text, share contact +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeSendAndShare: + async def test_send_text_message_no_contact(self, caplog): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.send_text_message(b"\x00" * 32, "Hi") + assert result.success is False + assert len(injector.calls) == 0 + + async def test_send_text_message_with_contact_injects_packet(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + contact = _make_peer_contact("Alice") + bridge.contacts.add(contact) + result = await bridge.send_text_message(contact.public_key, "Hello") + assert len(injector.calls) >= 1 + pkt, _ = injector.calls[0] + assert (pkt.header >> 2) & 0x0F == PAYLOAD_TYPE_TXT_MSG + + async def test_share_contact_not_found(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.share_contact(b"\x00" * 32) + assert result is False + assert len(injector.calls) == 0 + + async def test_share_contact_success(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + key = b"\x22" * 32 + bridge.contacts.add(Contact(public_key=key, name="Bob")) + result = await bridge.share_contact(key) + assert result is True + assert len(injector.calls) == 1 + + +# --------------------------------------------------------------------------- +# Path discovery, trace, control data +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgePathAndControl: + async def test_send_path_discovery_req_no_contact(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.send_path_discovery_req(b"\x00" * 32) + assert result.success is False + + async def test_send_path_discovery_req_success(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + contact = _make_peer_contact("Target") + bridge.contacts.add(contact) + result = await bridge.send_path_discovery_req(contact.public_key) + assert result.success is True + assert len(injector.calls) == 1 + assert result.timeout_ms == 10000 + + async def test_send_trace_path_raw(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.send_trace_path_raw(0x12345678, 0xABCD, 0, bytes([0x01, 0x02])) + assert result is True + assert len(injector.calls) == 1 + + async def test_send_control_data_valid_payload(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.send_control_data(bytes([0x80, 0x01])) + assert result is True + assert len(injector.calls) == 1 + pkt, _ = injector.calls[0] + assert pkt.payload_len == 2 + assert list(pkt.payload) == [0x80, 0x01] + + async def test_send_control_data_rejects_no_high_bit(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.send_control_data(bytes([0x00, 0x01])) + assert result is False + assert len(injector.calls) == 0 + + +# --------------------------------------------------------------------------- +# Binary request +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeBinaryReq: + async def test_send_binary_req_no_contact(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.send_binary_req(b"\x00" * 32, bytes([0x01])) + assert result.success is False + + async def test_send_binary_req_with_contact(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + contact = _make_peer_contact("Rpt") + bridge.contacts.add(contact) + result = await bridge.send_binary_req(contact.public_key, bytes([0x01]), timeout_seconds=5.0) + assert result.success is True + assert result.expected_ack is not None + assert len(injector.calls) == 1 diff --git a/tests/test_companion_radio.py b/tests/test_companion_radio.py new file mode 100644 index 0000000..6d24ca3 --- /dev/null +++ b/tests/test_companion_radio.py @@ -0,0 +1,305 @@ +"""Tests for CompanionRadio (stand-alone companion with radio).""" + +import pytest + +from pymc_core.companion import CompanionRadio +from pymc_core.companion.constants import ADV_TYPE_CHAT +from pymc_core.companion.models import Contact +from pymc_core.protocol import LocalIdentity + + +def _make_peer_contact(name: str) -> Contact: + """Return a contact with a valid Ed25519 public key (required for packet encryption).""" + peer = LocalIdentity() + return Contact(public_key=peer.get_public_key(), name=name) + + +class MockRadio: + """Mock radio for CompanionRadio: set_rx_callback, send, optional RSSI/SNR.""" + + def __init__(self): + self.rx_callback = None + self.sent: list[bytes] = [] + + def set_rx_callback(self, callback): + self.rx_callback = callback + + async def send(self, data: bytes) -> bool: + self.sent.append(data) + return True + + def get_last_rssi(self): + return -70 + + def get_last_snr(self): + return 5 + + +# --------------------------------------------------------------------------- +# Init and lifecycle +# --------------------------------------------------------------------------- + + +class TestCompanionRadioInit: + def test_init_creates_stores(self): + radio = MockRadio() + identity = LocalIdentity() + comp = CompanionRadio(radio, identity, node_name="TestNode") + assert comp.contacts is not None + assert comp.contacts.get_count() == 0 + assert comp.channels is not None + assert comp.message_queue is not None + assert comp.path_cache is not None + assert comp.stats is not None + assert comp.prefs.node_name == "TestNode" + assert comp.prefs.adv_type == ADV_TYPE_CHAT + assert comp.get_public_key() == identity.get_public_key() + assert comp.node is not None + assert comp.node.dispatcher is not None + + def test_init_passes_contacts_to_node(self): + radio = MockRadio() + identity = LocalIdentity() + comp = CompanionRadio(radio, identity) + comp.contacts.add(Contact(public_key=b"\x01" * 32, name="Alice")) + assert comp.node.contacts is comp.contacts + assert comp.node.contacts.get_by_name("Alice") is not None + + +@pytest.mark.asyncio +class TestCompanionRadioLifecycle: + async def test_start_stop(self): + radio = MockRadio() + identity = LocalIdentity() + comp = CompanionRadio(radio, identity) + assert comp.is_running is False + await comp.start() + assert comp.is_running is True + await comp.stop() + assert comp.is_running is False + + async def test_start_idempotent_warning(self, caplog): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + await comp.start() + await comp.start() + await comp.stop() + assert "already running" in caplog.text.lower() or True + + +# --------------------------------------------------------------------------- +# Contact management (base API via radio) +# --------------------------------------------------------------------------- + + +class TestCompanionRadioContacts: + def test_add_and_get_contact(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + key = b"\x02" * 32 + comp.add_update_contact(Contact(public_key=key, name="Bob")) + assert comp.get_contact_by_key(key) is not None + assert comp.get_contact_by_key(key).name == "Bob" + assert comp.get_contact_by_name("Bob") is not None + + def test_import_contact_packet_data(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + # 73 bytes: 32 key + 1 adv_type + 32 name (padded) + 4 lat + 4 lon + name_padded = b"Charlie\x00" * 4 # 32 bytes + packet_data = b"\x03" * 32 + bytes([1]) + name_padded + (0).to_bytes(4, "little") * 2 + assert comp.import_contact(packet_data) is True + contacts = comp.get_contacts() + assert len(contacts) == 1 + assert contacts[0].name.startswith("Charlie") + + def test_export_contact_self(self): + radio = MockRadio() + identity = LocalIdentity() + comp = CompanionRadio(radio, identity, node_name="Me") + data = comp.export_contact(None) + assert data is not None + assert len(data) >= 73 + assert data[:32] == identity.get_public_key() + + +# --------------------------------------------------------------------------- +# Advertise +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionRadioAdvertise: + async def test_advertise_sends_packet(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.advertise(flood=True) + assert result is True + assert len(radio.sent) == 1 + assert comp.stats.get_totals()["flood_tx"] == 1 + + async def test_advertise_direct(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + await comp.advertise(flood=False) + assert len(radio.sent) == 1 + assert comp.stats.get_totals()["direct_tx"] == 1 + + +# --------------------------------------------------------------------------- +# Send text (requires contact) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionRadioSendText: + async def test_send_text_message_no_contact(self, caplog): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_text_message(b"\x00" * 32, "Hi") + assert result.success is False + assert "contact not found" in caplog.text.lower() or "Contact not found" in caplog.text + + async def test_send_text_message_with_contact_sends_packet(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + contact = _make_peer_contact("Alice") + comp.contacts.add(contact) + result = await comp.send_text_message(contact.public_key, "Hello") + assert len(radio.sent) >= 1 + # success may be False if no ACK (mock radio doesn't echo ACK) + assert result.success is False or result.success is True + + +# --------------------------------------------------------------------------- +# Share contact, channel message, sync message +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionRadioMisc: + async def test_share_contact_not_found(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.share_contact(b"\x00" * 32) + assert result is False + + async def test_share_contact_success(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + key = b"\x22" * 32 + comp.contacts.add(Contact(public_key=key, name="Bob")) + result = await comp.share_contact(key) + assert result is True + assert len(radio.sent) == 1 + + async def test_sync_next_message_empty(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + assert comp.sync_next_message() is None + + async def test_send_channel_message_no_channel(self, caplog): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_channel_message(0, "Hi") + assert result is False + + +# --------------------------------------------------------------------------- +# Path discovery, trace, control data +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionRadioPathAndControl: + async def test_send_path_discovery_no_contact(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_path_discovery(b"\x00" * 32) + assert result is False + + async def test_send_path_discovery_req_sends(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + contact = _make_peer_contact("Target") + comp.contacts.add(contact) + result = await comp.send_path_discovery_req(contact.public_key) + assert result.success is True + assert len(radio.sent) == 1 + + async def test_send_trace_path_raw(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_trace_path_raw(0x12345678, 0xABCD, 0, bytes([0x01, 0x02])) + assert result is True + assert len(radio.sent) == 1 + + async def test_send_control_data_default_discovery(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_control_data() + assert result is True + assert len(radio.sent) == 1 + + async def test_send_control_data_raw_payload(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_control_data(bytes([0x80, 0x04])) + assert result is True + assert len(radio.sent) == 1 + + +# --------------------------------------------------------------------------- +# Stats and config +# --------------------------------------------------------------------------- + + +class TestCompanionRadioStats: + def test_get_stats_core(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + comp.contacts.add(Contact(public_key=b"\x01" * 32, name="A")) + core = comp.get_stats(0) + assert "contacts_count" in core + assert core["contacts_count"] == 1 + assert "queue_len" in core + assert "uptime_secs" in core + + def test_get_stats_packets(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + tot = comp.get_stats(2) + assert "flood_tx" in tot + assert "direct_rx" in tot + assert "tx_errors" in tot + + +# --------------------------------------------------------------------------- +# Binary request and repeater command (delegate to node) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionRadioBinaryAndRepeater: + async def test_send_binary_req_no_contact(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_binary_req(b"\x00" * 32, bytes([0x01])) + assert result.success is False + + async def test_send_binary_req_with_contact(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + contact = _make_peer_contact("Rpt") + comp.contacts.add(contact) + result = await comp.send_binary_req(contact.public_key, bytes([0x01]), timeout_seconds=5.0) + assert result.success is True + assert result.expected_ack is not None + assert len(radio.sent) == 1 + + async def test_send_repeater_command_no_contact(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + out = await comp.send_repeater_command(b"\x00" * 32, "status") + assert out["success"] is False + assert "not found" in out["reason"].lower() diff --git a/tests/test_companion_stores.py b/tests/test_companion_stores.py new file mode 100644 index 0000000..c33c48f --- /dev/null +++ b/tests/test_companion_stores.py @@ -0,0 +1,433 @@ +"""Tests for companion stores and models: ContactStore, ChannelStore, MessageQueue, PathCache, StatsCollector.""" + +import pytest + +from pymc_core.companion import ( + ContactStore, + ChannelStore, + MessageQueue, + PathCache, + StatsCollector, +) +from pymc_core.companion.constants import DEFAULT_MAX_CONTACTS, DEFAULT_MAX_CHANNELS +from pymc_core.companion.models import ( + AdvertPath, + Channel, + Contact, + NodePrefs, + PacketStats, + QueuedMessage, + SentResult, +) + + +# --------------------------------------------------------------------------- +# Models +# --------------------------------------------------------------------------- + + +class TestContact: + def test_contact_defaults(self): + key = b"\x00" * 32 + c = Contact(public_key=key, name="alice") + assert c.public_key == key + assert c.name == "alice" + assert c.adv_type == 0 + assert c.out_path_len == -1 + assert c.out_path == b"" + assert c.gps_lat == 0.0 + assert c.gps_lon == 0.0 + + def test_contact_with_path(self): + c = Contact( + public_key=b"\x01" * 32, + name="bob", + out_path_len=3, + out_path=bytes([0xAA, 0xBB, 0xCC]), + ) + assert c.out_path_len == 3 + assert c.out_path == bytes([0xAA, 0xBB, 0xCC]) + + +class TestSentResult: + def test_sent_result_minimal(self): + r = SentResult(success=False) + assert r.success is False + assert r.is_flood is False + assert r.expected_ack is None + assert r.timeout_ms is None + + def test_sent_result_full(self): + r = SentResult(success=True, is_flood=True, expected_ack=0x1234, timeout_ms=5000) + assert r.success is True + assert r.is_flood is True + assert r.expected_ack == 0x1234 + assert r.timeout_ms == 5000 + + +class TestNodePrefs: + def test_node_prefs_defaults(self): + p = NodePrefs() + assert p.node_name == "pyMC" + assert p.adv_type == 1 + assert p.tx_power_dbm == 20 + assert p.frequency_hz == 915000000 + + def test_node_prefs_custom(self): + p = NodePrefs(node_name="TestNode", adv_type=2) + assert p.node_name == "TestNode" + assert p.adv_type == 2 + + +class TestQueuedMessage: + def test_queued_message_direct(self): + key = b"\x02" * 32 + msg = QueuedMessage(sender_key=key, text="Hello", timestamp=1000) + assert msg.sender_key == key + assert msg.text == "Hello" + assert msg.timestamp == 1000 + assert msg.is_channel is False + assert msg.channel_idx == 0 + + def test_queued_message_channel(self): + msg = QueuedMessage( + sender_key=b"", + text="Sender: Hi", + is_channel=True, + channel_idx=2, + ) + assert msg.is_channel is True + assert msg.channel_idx == 2 + + +class TestAdvertPath: + def test_advert_path(self): + prefix = b"\x03" * 7 + ap = AdvertPath( + public_key_prefix=prefix, + name="sensor1", + path_len=2, + path=bytes([1, 2]), + recv_timestamp=12345, + ) + assert ap.public_key_prefix == prefix + assert ap.name == "sensor1" + assert ap.path_len == 2 + assert ap.path == bytes([1, 2]) + assert ap.recv_timestamp == 12345 + + +# --------------------------------------------------------------------------- +# ContactStore +# --------------------------------------------------------------------------- + + +class TestContactStore: + def test_empty_store(self): + store = ContactStore(max_contacts=10) + assert store.get_count() == 0 + assert store.get_all() == [] + assert store.get_by_key(b"\x00" * 32) is None + assert store.get_by_name("nobody") is None + assert store.is_full() is False + assert store.contacts == [] + assert store.list_contacts() == [] + + def test_add_and_get_by_key(self): + store = ContactStore(max_contacts=5) + key = b"\x11" * 32 + contact = Contact(public_key=key, name="Alice") + assert store.add(contact) is True + assert store.get_count() == 1 + assert store.get_by_key(key) is contact + assert store.get_by_key(b"\x22" * 32) is None + + def test_add_and_get_by_name(self): + store = ContactStore(max_contacts=5) + key = b"\x11" * 32 + contact = Contact(public_key=key, name="Bob") + store.add(contact) + proxy = store.get_by_name("Bob") + assert proxy is not None + assert proxy.name == "Bob" + assert proxy.public_key == key.hex() + assert store.get_by_name("Charlie") is None + + def test_update_existing(self): + store = ContactStore(max_contacts=5) + key = b"\x11" * 32 + store.add(Contact(public_key=key, name="Alice")) + updated = Contact(public_key=key, name="AliceUpdated", gps_lat=1.0) + assert store.update(updated) is True + c = store.get_by_key(key) + assert c.name == "AliceUpdated" + assert c.gps_lat == 1.0 + + def test_remove(self): + store = ContactStore(max_contacts=5) + key = b"\x11" * 32 + store.add(Contact(public_key=key, name="Alice")) + assert store.remove(key) is True + assert store.get_count() == 0 + assert store.get_by_key(key) is None + assert store.remove(key) is False + + def test_max_contacts(self): + store = ContactStore(max_contacts=2) + store.add(Contact(public_key=b"\x01" * 32, name="A")) + store.add(Contact(public_key=b"\x02" * 32, name="B")) + assert store.add(Contact(public_key=b"\x03" * 32, name="C")) is False + assert store.get_count() == 2 + assert store.is_full() is True + + def test_get_all_since(self): + store = ContactStore(max_contacts=10) + store.add(Contact(public_key=b"\x01" * 32, name="A", lastmod=100)) + store.add(Contact(public_key=b"\x02" * 32, name="B", lastmod=200)) + store.add(Contact(public_key=b"\x03" * 32, name="C", lastmod=150)) + all_c = store.get_all() + assert len(all_c) == 3 + since_150 = store.get_all(since=150) + assert len(since_150) == 2 + + def test_clear(self): + store = ContactStore(max_contacts=5) + store.add(Contact(public_key=b"\x01" * 32, name="A")) + store.clear() + assert store.get_count() == 0 + assert store.get_by_name("A") is None + + def test_load_from(self): + store = ContactStore(max_contacts=10) + contacts = [ + Contact(public_key=bytes([i] * 32), name=f"C{i}") for i in range(3) + ] + store.load_from(contacts) + assert store.get_count() == 3 + assert store.get_by_name("C1").name == "C1" + + def test_load_from_dicts(self): + store = ContactStore(max_contacts=10) + store.load_from_dicts([ + {"public_key": "a1" * 32, "name": "DictAlice"}, + {"public_key": "b2" * 32, "name": "DictBob"}, + ]) + assert store.get_count() == 2 + assert store.get_by_name("DictAlice") is not None + assert store.get_by_name("DictBob") is not None + + def test_to_dicts(self): + store = ContactStore(max_contacts=5) + store.add(Contact(public_key=b"\xaa" * 32, name="Export", adv_type=1)) + dicts = store.to_dicts() + assert len(dicts) == 1 + assert dicts[0]["name"] == "Export" + assert dicts[0]["public_key"] == "aa" * 32 + assert dicts[0]["adv_type"] == 1 + + def test_get_by_key_prefix(self): + store = ContactStore(max_contacts=5) + key = b"\x11\x22\x33" + b"\x00" * 29 + store.add(Contact(public_key=key, name="Prefix")) + assert store.get_by_key_prefix(b"\x11\x22") is not None + assert store.get_by_key_prefix(b"\x11\x22\x33").name == "Prefix" + assert store.get_by_key_prefix(b"\xff\xff") is None + + +# --------------------------------------------------------------------------- +# ChannelStore +# --------------------------------------------------------------------------- + + +class TestChannelStore: + def test_empty_channels(self): + store = ChannelStore(max_channels=8) + assert store.get_count() == 0 + assert store.get(0) is None + assert store.get_channels() == [] + assert store.find_by_name("any") is None + + def test_set_and_get(self): + store = ChannelStore(max_channels=8) + ch = Channel(name="general", secret=b"\x11" * 16) + assert store.set(0, ch) is True + assert store.get(0) is ch + assert store.get_count() == 1 + assert store.get_channels() == [{"name": "general", "secret": "11" * 16}] + + def test_find_by_name(self): + store = ChannelStore(max_channels=8) + store.set(0, Channel(name="alpha", secret=b"\x00" * 16)) + store.set(1, Channel(name="beta", secret=b"\x01" * 16)) + assert store.find_by_name("alpha") == 0 + assert store.find_by_name("beta") == 1 + assert store.find_by_name("gamma") is None + + def test_remove(self): + store = ChannelStore(max_channels=8) + store.set(0, Channel(name="x", secret=b"\x00" * 16)) + assert store.remove(0) is True + assert store.get(0) is None + assert store.remove(0) is False + assert store.remove(99) is False + + def test_clear(self): + store = ChannelStore(max_channels=8) + store.set(0, Channel(name="a", secret=b"\x00" * 16)) + store.clear() + assert store.get_count() == 0 + assert store.get(0) is None + + def test_out_of_range(self): + store = ChannelStore(max_channels=4) + ch = Channel(name="x", secret=b"\x00" * 16) + assert store.set(-1, ch) is False + assert store.set(4, ch) is False + assert store.get(4) is None + + +# --------------------------------------------------------------------------- +# MessageQueue +# --------------------------------------------------------------------------- + + +class TestMessageQueue: + def test_empty_queue(self): + q = MessageQueue(max_size=5) + assert q.count == 0 + assert q.is_empty() is True + assert q.is_full() is False + assert q.pop() is None + assert q.peek() is None + + def test_push_and_pop(self): + q = MessageQueue(max_size=5) + msg = QueuedMessage(sender_key=b"\x00" * 32, text="Hi") + q.push(msg) + assert q.count == 1 + assert q.peek() is msg + assert q.pop() is msg + assert q.count == 0 + assert q.pop() is None + + def test_maxlen_drops_oldest(self): + q = MessageQueue(max_size=2) + q.push(QueuedMessage(sender_key=b"\x01" * 32, text="1")) + q.push(QueuedMessage(sender_key=b"\x02" * 32, text="2")) + q.push(QueuedMessage(sender_key=b"\x03" * 32, text="3")) + assert q.count == 2 + first = q.pop() + assert first.text == "2" + assert q.pop().text == "3" + + def test_clear(self): + q = MessageQueue(max_size=5) + q.push(QueuedMessage(sender_key=b"\x00" * 32, text="x")) + q.clear() + assert q.count == 0 + assert q.pop() is None + + +# --------------------------------------------------------------------------- +# PathCache +# --------------------------------------------------------------------------- + + +class TestPathCache: + def test_empty_cache(self): + cache = PathCache(max_entries=8) + assert cache.get_all() == [] + assert cache.get_by_prefix(b"\x00" * 7) is None + + def test_update_and_get(self): + cache = PathCache(max_entries=8) + ap = AdvertPath( + public_key_prefix=b"\x01" * 7, + name="n1", + path_len=2, + path=bytes([1, 2]), + recv_timestamp=100, + ) + cache.update(ap) + assert len(cache.get_all()) == 1 + found = cache.get_by_prefix(b"\x01" * 5) + assert found is not None + assert found.name == "n1" + assert found.path == bytes([1, 2]) + + def test_update_replaces_same_prefix(self): + cache = PathCache(max_entries=8) + prefix = b"\x02" * 7 + cache.update(AdvertPath(public_key_prefix=prefix, name="v1", path=bytes([1]))) + cache.update(AdvertPath(public_key_prefix=prefix, name="v2", path=bytes([2, 2]))) + assert len(cache.get_all()) == 1 + assert cache.get_by_prefix(prefix).name == "v2" + assert cache.get_by_prefix(prefix).path == bytes([2, 2]) + + def test_eviction_when_full(self): + cache = PathCache(max_entries=2) + cache.update(AdvertPath(public_key_prefix=b"\x01" * 7, name="1", path=b"")) + cache.update(AdvertPath(public_key_prefix=b"\x02" * 7, name="2", path=b"")) + cache.update(AdvertPath(public_key_prefix=b"\x03" * 7, name="3", path=b"")) + assert len(cache.get_all()) == 2 + assert cache.get_by_prefix(b"\x01" * 7) is None + assert cache.get_by_prefix(b"\x02" * 7) is not None + assert cache.get_by_prefix(b"\x03" * 7) is not None + + def test_clear(self): + cache = PathCache(max_entries=8) + cache.update(AdvertPath(public_key_prefix=b"\x01" * 7, name="x", path=b"")) + cache.clear() + assert cache.get_all() == [] + assert cache.get_by_prefix(b"\x01" * 7) is None + + +# --------------------------------------------------------------------------- +# StatsCollector +# --------------------------------------------------------------------------- + + +class TestStatsCollector: + def test_initial_state(self): + s = StatsCollector() + assert s.packets.flood_tx == 0 + assert s.packets.direct_rx == 0 + assert s.packets.tx_errors == 0 + assert s.get_uptime_secs() >= 0 + + def test_record_tx_rx(self): + s = StatsCollector() + s.record_tx(is_flood=True) + s.record_tx(is_flood=True) + s.record_tx(is_flood=False) + s.record_rx(is_flood=False) + s.record_rx(is_flood=True) + assert s.packets.flood_tx == 2 + assert s.packets.direct_tx == 1 + assert s.packets.direct_rx == 1 + assert s.packets.flood_rx == 1 + + def test_record_tx_error(self): + s = StatsCollector() + s.record_tx_error() + s.record_tx_error() + assert s.packets.tx_errors == 2 + + def test_get_totals(self): + s = StatsCollector() + s.record_tx(is_flood=True) + s.record_rx(is_flood=False) + tot = s.get_totals() + assert tot["flood_tx"] == 1 + assert tot["direct_rx"] == 1 + assert tot["total_tx"] == 1 + assert tot["total_rx"] == 1 + assert "uptime_secs" in tot + + def test_reset(self): + s = StatsCollector() + s.record_tx(is_flood=True) + s.record_tx_error() + s.reset() + assert s.packets.flood_tx == 0 + assert s.packets.tx_errors == 0 From 5c4e7abbe36eb287daf125948152bf7e68165339 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 14 Feb 2026 20:29:31 -0800 Subject: [PATCH 05/50] Refactor logging and enhance message handling in CompanionBase and PathHandler - Updated logging in CompanionBase to skip debug messages for small payloads, improving log clarity. - Added a new method in MessageQueue to pop the most recent message, enhancing message retrieval capabilities. - Simplified PATH packet analysis logging in PathHandler, consolidating log statements for better readability and efficiency. - Removed redundant logging in ProtocolResponseHandler to streamline response processing. --- src/pymc_core/companion/companion_base.py | 5 ++- src/pymc_core/companion/message_queue.py | 6 ++++ src/pymc_core/node/handlers/path.py | 36 +++---------------- .../node/handlers/protocol_response.py | 20 +---------- 4 files changed, 16 insertions(+), 51 deletions(-) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 974fced..4e28ade 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -517,7 +517,9 @@ async def _on_binary_response( tag_hex = tag_bytes.hex() info = self._pending_binary_requests.pop(tag_hex, None) if not info: - logger.debug(f"Binary response for unknown tag {tag_hex}") + # Skip log for small payloads (e.g. login response already handled by LoginResponseHandler) + if len(response_data) >= 20: + logger.debug(f"Binary response for unknown tag {tag_hex}") await self._fire_callbacks("binary_response", tag_bytes, response_data) return request_type = info["request_type"] @@ -628,6 +630,7 @@ async def _handle_new_channel_message(self, data: dict) -> None: display_text, msg.timestamp, path_len, + channel_idx, ) async def _fire_callbacks(self, event_name: str, *args: Any) -> None: diff --git a/src/pymc_core/companion/message_queue.py b/src/pymc_core/companion/message_queue.py index 4452759..8803eb5 100644 --- a/src/pymc_core/companion/message_queue.py +++ b/src/pymc_core/companion/message_queue.py @@ -34,6 +34,12 @@ def pop(self) -> Optional[QueuedMessage]: return self._queue.popleft() return None + def pop_last(self) -> Optional[QueuedMessage]: + """Remove and return the most recently pushed message, or None if empty.""" + if self._queue: + return self._queue.pop() + return None + def peek(self) -> Optional[QueuedMessage]: """Return the oldest message without removing it, or None if empty.""" if self._queue: diff --git a/src/pymc_core/node/handlers/path.py b/src/pymc_core/node/handlers/path.py index 225190c..f9debf7 100644 --- a/src/pymc_core/node/handlers/path.py +++ b/src/pymc_core/node/handlers/path.py @@ -65,20 +65,14 @@ async def __call__(self, pkt: Packet) -> None: # Optional PATH packet analysis if analyzer is available try: - # Try to use any available packet analyzer through callback if hasattr(self, "_dispatcher") and hasattr( self._dispatcher, "packet_analysis_callback" - ): - if self._dispatcher.packet_analysis_callback: - self._dispatcher.packet_analysis_callback(pkt) - self._log("PATH packet analysis delegated to app") - else: - self._log("PATH packet received - hop analysis requires app-level analyzer") - + ) and self._dispatcher.packet_analysis_callback: + self._dispatcher.packet_analysis_callback(pkt) except Exception as e: self._log(f"PATH packet analysis failed: {e}") - # Extract and log key PATH information directly from packet + # Single summary line for PATH packet try: payload = pkt.get_payload() hop_count = pkt.path_len @@ -87,30 +81,10 @@ async def __call__(self, pkt: Packet) -> None: src_hash = payload[1] self._log( f"PATH packet: hop_count={hop_count}, " - f"dest=0x{dest_hash:02X}, src=0x{src_hash:02X}, " - f"payload_len={len(payload)}" + f"dest=0x{dest_hash:02X}, src=0x{src_hash:02X}, payload_len={len(payload)}" ) - if hop_count > 0: - self._log(f"Path contains {hop_count} hops") - else: - self._log("Direct PATH (no intermediate hops)") else: - self._log("PATH packet received with minimal payload") - - # Log basic routing behavior based on header - try: - # These constants are already imported at the top - # from ...protocol.constants import ( - # ROUTE_TYPE_DIRECT, - # ROUTE_TYPE_FLOOD, - # ) - - # Extract route type from packet header if possible - # This is a simplified version without full analysis - self._log("PATH packet routing analysis requires app-level analyzer") - except ImportError: - pass - + self._log("PATH packet: minimal payload") except Exception as e: self._log(f"Error extracting PATH information: {e}") diff --git a/src/pymc_core/node/handlers/protocol_response.py b/src/pymc_core/node/handlers/protocol_response.py index 2799a8a..a1b86d5 100644 --- a/src/pymc_core/node/handlers/protocol_response.py +++ b/src/pymc_core/node/handlers/protocol_response.py @@ -150,11 +150,6 @@ async def __call__(self, pkt: Packet) -> None: if src_hash not in self._response_callbacks and self._binary_response_callback is None: return - self._log( - "[ProtocolResponse] Processing potential protocol response " - f"from 0x{src_hash:02X}, payload_len={len(pkt.payload)}" - ) - # Try to decrypt the response success, decoded_text, parsed_data, raw_decrypted = await self._decrypt_protocol_response( pkt, src_hash @@ -266,10 +261,6 @@ async def _decrypt_protocol_response( # Determine the actual payload type from the incoming packet header. pkt_type = (pkt.header >> 2) & 0x0F - self._log( - f"[ProtocolResponse] Decrypted {len(decrypted)} bytes from " - f"pkt_type=0x{pkt_type:02X}, hex: {decrypted.hex()}" - ) # Extract the actual response data based on packet type. response_data = decrypted @@ -287,11 +278,7 @@ async def _decrypt_protocol_response( extra_type = decrypted[1 + path_len_byte] & 0x0F if extra_type == PAYLOAD_TYPE_RESPONSE and len(decrypted) > inner_offset: response_data = decrypted[inner_offset:] - self._log( - f"[ProtocolResponse] PATH format: extracted inner response " - f"{len(response_data)} bytes (path_len={path_len_byte})" - ) - else: + elif extra_type != PAYLOAD_TYPE_RESPONSE: self._log( f"[ProtocolResponse] PATH format: extra_type=0x{extra_type:02X}, " f"not RESPONSE" @@ -315,11 +302,6 @@ def _parse_protocol_response(self, data: bytes) -> tuple[bool, str, Dict[str, An 4. Binary fallback """ try: - self._log( - f"[ProtocolResponse] _parse_protocol_response: {len(data)} bytes, " - f"first 16: {data[:16].hex() if len(data) >= 16 else data.hex()}" - ) - # 1. Check if this looks like a stats response (protocol 0x01) # RepeaterStats is 48-56 bytes + 4-byte tag. Older firmware # omits n_recv_errors (52 B struct → 56 total); PATH-wrapped From 13cd3c9e9353c44bfe5d2ff58d7fea42ab821429 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 14 Feb 2026 20:41:13 -0800 Subject: [PATCH 06/50] Add deduplication for direct and channel messages in CompanionBase - Implemented a mechanism to deduplicate incoming direct and channel messages by packet hash, preventing the same packet from being queued multiple times during reconnects. - Increased the default offline queue size to accommodate more messages. - Updated the TextMessageHandler to include packet hash in message data for deduplication purposes. - Added unit tests to verify the deduplication functionality in the CompanionBridge. --- src/pymc_core/companion/companion_base.py | 17 +++++++++++++ src/pymc_core/companion/constants.py | 2 +- src/pymc_core/node/handlers/text.py | 1 + tests/test_companion_bridge.py | 30 +++++++++++++++++++++++ 4 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 4e28ade..6995038 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -170,6 +170,10 @@ def _init_companion_stores( self._seen_grp_txt: OrderedDict[str, float] = OrderedDict() self._seen_grp_txt_ttl = 300 self._seen_grp_txt_max = 1000 + # TXT_MSG (direct) dedup by packet hash so reconnects don't queue the same packet multiple times. + self._seen_txt: OrderedDict[str, float] = OrderedDict() + self._seen_txt_ttl = 300 + self._seen_txt_max = 1000 # ------------------------------------------------------------------------- # Contact Management @@ -565,6 +569,19 @@ async def _handle_mesh_event(self, event_type: str, data: dict) -> None: logger.error(f"Error handling mesh event {event_type}: {e}") async def _handle_new_message(self, data: dict) -> None: + # Deduplicate by packet hash so reconnects don't queue the same packet multiple times. + pkt_hash = data.get("packet_hash") + if pkt_hash: + now = time.time() + if pkt_hash in self._seen_txt: + return + expired = [k for k, ts in self._seen_txt.items() if now - ts > self._seen_txt_ttl] + for k in expired: + del self._seen_txt[k] + self._seen_txt[pkt_hash] = now + if len(self._seen_txt) > self._seen_txt_max: + self._seen_txt.popitem(last=False) + sender_key_hex = data.get("contact_pubkey", "") sender_key = bytes.fromhex(sender_key_hex) if sender_key_hex else b"" # Handler publishes "message_text"; accept "text" for compatibility diff --git a/src/pymc_core/companion/constants.py b/src/pymc_core/companion/constants.py index fa29674..94a4d69 100644 --- a/src/pymc_core/companion/constants.py +++ b/src/pymc_core/companion/constants.py @@ -67,7 +67,7 @@ class BinaryReqType: # Default configuration # --------------------------------------------------------------------------- DEFAULT_MAX_CONTACTS = 1000 -DEFAULT_OFFLINE_QUEUE_SIZE = 16 +DEFAULT_OFFLINE_QUEUE_SIZE = 512 DEFAULT_MAX_CHANNELS = 40 CONTACT_NAME_SIZE = 32 MAX_SIGN_DATA_SIZE = 8192 # 8KB signing buffer (matches firmware) diff --git a/src/pymc_core/node/handlers/text.py b/src/pymc_core/node/handlers/text.py index 2984c52..f3a1c4c 100644 --- a/src/pymc_core/node/handlers/text.py +++ b/src/pymc_core/node/handlers/text.py @@ -202,6 +202,7 @@ async def send_delayed_ack(): }, "sender_name": matched_contact.name, "is_read": False, + "packet_hash": packet.calculate_packet_hash().hex().upper(), } # Publish new message event for app to handle database storage diff --git a/tests/test_companion_bridge.py b/tests/test_companion_bridge.py index 4cd2c80..fd0a428 100644 --- a/tests/test_companion_bridge.py +++ b/tests/test_companion_bridge.py @@ -4,6 +4,7 @@ from pymc_core.companion import CompanionBridge from pymc_core.companion.models import Contact +from pymc_core.node.events import MeshEvents from pymc_core.protocol import LocalIdentity, Packet from pymc_core.protocol.constants import ( PAYLOAD_TYPE_ADVERT, @@ -243,3 +244,32 @@ async def test_send_binary_req_with_contact(self): assert result.success is True assert result.expected_ack is not None assert len(injector.calls) == 1 + + +# --------------------------------------------------------------------------- +# Deduplication (direct messages by packet_hash) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeDeduplication: + async def test_direct_message_deduplicated_by_packet_hash(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + key_hex = LocalIdentity().get_public_key().hex() + same_hash = "A1B2C3D4E5F6" + data = { + "contact_pubkey": key_hex, + "message_text": "Hello", + "timestamp": 1000, + "txt_type": 0, + "packet_hash": same_hash, + } + await bridge._handle_mesh_event(MeshEvents.NEW_MESSAGE, data) + await bridge._handle_mesh_event(MeshEvents.NEW_MESSAGE, data) + await bridge._handle_mesh_event(MeshEvents.NEW_MESSAGE, data) + assert bridge.message_queue.count == 1 + msg = bridge.sync_next_message() + assert msg is not None + assert msg.text == "Hello" + assert bridge.sync_next_message() is None From 02f187cec1c5556e4142aea5d3c1d4351a20ff8a Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 14 Feb 2026 22:04:11 -0800 Subject: [PATCH 07/50] Add Companion Guide and refactor companion modules - Added a new Companion Guide to the documentation for better user guidance. - Refactored several companion modules to improve type hinting and code clarity, including the addition of return types for methods. - Enhanced the CompanionBase class with new methods for managing contacts and advertisements, improving overall functionality. - Updated constants and models to support new features and ensure consistency across the codebase. --- docs/docs/companion.md | 628 ++++++++++++++++++++ docs/mkdocs.yml | 1 + src/pymc_core/companion/binary_parsing.py | 2 + src/pymc_core/companion/channel_store.py | 4 +- src/pymc_core/companion/companion_base.py | 338 ++++++++++- src/pymc_core/companion/companion_bridge.py | 227 +------ src/pymc_core/companion/companion_radio.py | 218 +------ src/pymc_core/companion/constants.py | 14 +- src/pymc_core/companion/contact_store.py | 10 +- src/pymc_core/companion/message_queue.py | 4 +- src/pymc_core/companion/models.py | 2 + src/pymc_core/companion/path_cache.py | 24 +- src/pymc_core/companion/stats_collector.py | 12 +- src/pymc_core/protocol/constants.py | 1 + tests/test_companion_base.py | 3 +- 15 files changed, 1040 insertions(+), 448 deletions(-) create mode 100644 docs/docs/companion.md diff --git a/docs/docs/companion.md b/docs/docs/companion.md new file mode 100644 index 0000000..e78b650 --- /dev/null +++ b/docs/docs/companion.md @@ -0,0 +1,628 @@ +# Companion Module + +The companion module provides a high-level Python interface to the MeshCore companion radio protocol. It manages contacts, messaging, channels, advertisements, path routing, telemetry, cryptographic signing, and device configuration on top of pyMC_core's `MeshNode`. + +Two implementations are provided: + +| Class | Owns Radio | Use Case | +|---|---|---| +| `CompanionRadio` | Yes | Standalone companion — wraps a hardware radio and `MeshNode` | +| `CompanionBridge` | No | Repeater-integrated companion — shares an existing dispatcher via a packet injector callback | + +Both inherit from `CompanionBase` (an abstract base class), which holds all shared stores, event handling, device configuration logic, and unified TX methods (advertising, binary requests, path discovery, offline queue sync). Subclasses implement transport via the abstract `_send_packet` method. + +--- + +## Architecture + +``` +CompanionBase (ABC) +├── ContactStore (in-memory contacts, max 1000) +├── ChannelStore (group channels, max 40) +├── MessageQueue (offline FIFO, max 512) +├── PathCache (recent advert paths, max 16) +├── StatsCollector (TX/RX counters, uptime) +├── NodePrefs (radio params, name, location) +│ +│ Unified methods (use abstract _send_packet): +│ advertise, share_contact, send_binary_req, +│ send_path_discovery_req, send_trace_path_raw, +│ sync_next_message +│ +├─► CompanionRadio (owns MeshNode + hardware radio) +└─► CompanionBridge (packet_injector callback, no radio) +``` + +--- + +## Installation + +```bash +pip install pymc_core # core only +pip install pymc_core[hardware] # SX1262 direct radio support +pip install pymc_core[all] # everything +``` + +--- + +## CompanionRadio + +`CompanionRadio` is a standalone companion that owns a hardware radio and a `MeshNode`. It is the typical entry point for building a chat application, sensor gateway, or automation tool that participates in a MeshCore network. + +### Quick Start + +```python +import asyncio +from pymc_core import LocalIdentity +from pymc_core.companion import ( + CompanionRadio, + ADV_TYPE_CHAT, + STATS_TYPE_PACKETS, +) + +async def main(): + # --- Setup --- + from pymc_core.hardware import KissModemWrapper + + radio = KissModemWrapper("/dev/ttyUSB0") + radio.connect() + + identity = LocalIdentity() # generates a new Ed25519 keypair + companion = CompanionRadio( + radio=radio, + identity=identity, + node_name="myNode", + adv_type=ADV_TYPE_CHAT, + ) + + # --- Register callbacks before starting --- + companion.on_message_received(on_msg) + companion.on_advert_received(on_advert) + companion.on_channel_message_received(on_chan_msg) + companion.on_send_confirmed(on_ack) + + await companion.start() + + # --- Advertise presence --- + await companion.advertise(flood=True) + + # --- Send a direct message --- + dest = bytes.fromhex("ab" * 32) # 32-byte public key + result = await companion.send_text_message(dest, "Hello mesh!") + print(f"Sent: success={result.success}, flood={result.is_flood}") + + # --- Keep running --- + try: + while True: + await asyncio.sleep(1) + finally: + await companion.stop() + + +# --- Callbacks --- +def on_msg(sender_key, text, timestamp, txt_type): + print(f"DM from {sender_key[:8].hex()}: {text}") + +def on_advert(contact): + print(f"Discovered: {contact.name} (type={contact.adv_type})") + +def on_chan_msg(channel_name, sender_name, text, timestamp, path_len, channel_idx): + print(f"[{channel_name}] {sender_name}: {text}") + +def on_ack(ack_crc): + print(f"ACK confirmed: {ack_crc:#x}") + + +asyncio.run(main()) +``` + +### Constructor + +```python +CompanionRadio( + radio, # hardware radio wrapper + identity: LocalIdentity, # Ed25519 identity + node_name: str = "pyMC", + adv_type: int = ADV_TYPE_CHAT, # 1=chat, 2=repeater, 3=room, 4=sensor + max_contacts: int = 1000, + max_channels: int = 40, + offline_queue_size: int = 512, + radio_config: dict | None = None, +) +``` + +### Lifecycle + +```python +await companion.start() # start dispatcher task +await companion.stop() # cancel dispatcher +companion.is_running # bool property +``` + +### Messaging + +```python +# Direct text message +result = await companion.send_text_message(pub_key, "hello", txt_type=0, attempt=1) +# result: SentResult(success, is_flood, expected_ack, timeout_ms) + +# Group channel message +ok = await companion.send_channel_message(channel_idx=0, text="hello group") + +# Pop oldest queued offline message +msg = companion.sync_next_message() # -> QueuedMessage | None + +# Raw binary data (direct path only) +result = await companion.send_raw_data(dest_key, data=b"\x01\x02", path=None) +``` + +### Advertisements + +```python +await companion.advertise(flood=True) # broadcast presence +await companion.share_contact(pub_key) # share a contact via direct advert +``` + +### Contact Management + +```python +from pymc_core.companion import Contact + +# List / lookup +contacts = companion.get_contacts(since=0) +contact = companion.get_contact_by_key(pub_key_bytes) +contact = companion.get_contact_by_name("Alice") + +# Add / update / remove +companion.add_update_contact(Contact(public_key=key, name="Bob")) +companion.remove_contact(pub_key_bytes) + +# Reset routing path (force re-discovery) +companion.reset_path(pub_key_bytes) + +# Serialise for sharing +blob = companion.export_contact(pub_key) # bytes +ok = companion.import_contact(blob) # bool +``` + +### Channel Management + +```python +from pymc_core.companion import Channel + +companion.set_channel(0, name="General", secret=b"shared_secret_key_here__________") +ch = companion.get_channel(0) # -> Channel | None +``` + +### Path Discovery & Tracing + +```python +# Path discovery (returns SentResult, fires on_path_discovery_response callback) +result = await companion.send_path_discovery_req(pub_key) + +# Trace path through the network +ok = await companion.send_trace_path(pub_key, tag=42, auth_code=0, flags=0) + +# Get recently heard advert path +advert_path = companion.get_advert_path(pub_key_prefix_7bytes) +``` + +### Repeater Interaction + +```python +# Login to a repeater +resp = await companion.send_login(repeater_key, password="secret") + +# Request repeater status +resp = await companion.send_status_request(repeater_key) + +# Request telemetry +resp = await companion.send_telemetry_request( + repeater_key, + want_base=True, + want_location=True, + want_environment=False, + timeout=10.0, +) + +# Send a text command to a repeater +resp = await companion.send_repeater_command(repeater_key, command="status") +``` + +### Binary Requests + +The generic binary request/response mechanism uses random 4-byte tags for matching. + +```python +# Send and wait for response +result = await companion.send_binary_req(pub_key, data=b"\x01", timeout_seconds=10) + +# Register a callback for responses +companion.on_binary_response( + lambda tag, data, parsed, req_type: print(f"Response: {parsed}") +) +``` + +### Device Configuration + +```python +companion.set_advert_name("NewName") # max 31 chars +companion.set_advert_latlon(37.7749, -122.4194) # GPS coordinates +companion.set_radio_params(915_000_000, 250_000, 10, 5) # freq, bw, SF, CR +companion.set_tx_power(22) # dBm +companion.set_tuning_params(rx_delay=0.0, airtime_factor=0.0) + +# Location sharing in adverts +from pymc_core.companion import ADVERT_LOC_SHARE +companion.set_other_params( + manual_add=0, + telemetry_modes=(0, 0, 0), + advert_loc_policy=ADVERT_LOC_SHARE, + multi_acks=0, +) + +prefs = companion.get_self_info() # -> NodePrefs +``` + +### Cryptographic Signing + +```python +buf_size = companion.sign_start() # returns max buffer size (8192) +companion.sign_data(b"data to sign...") +signature = companion.sign_finish() # -> 64-byte Ed25519 signature +``` + +### Statistics + +```python +from pymc_core.companion import STATS_TYPE_CORE, STATS_TYPE_RADIO, STATS_TYPE_PACKETS + +stats = companion.get_stats(STATS_TYPE_CORE) +# {'uptime': 3600, 'queue_len': 2, 'contacts_count': 15, 'channels_count': 3} + +stats = companion.get_stats(STATS_TYPE_PACKETS) +# {'flood_tx': 42, 'flood_rx': 108, 'direct_tx': 5, 'direct_rx': 12, ...} +``` + +--- + +## CompanionBridge + +`CompanionBridge` is designed for repeater integration. It does not own a radio or `MeshNode` — instead, the repeater host feeds received packets in via `process_received_packet()`, and all outbound packets go through a `packet_injector` callback you provide. This lets a companion identity coexist alongside a repeater on the same radio. + +### Quick Start + +```python +import asyncio +from pymc_core import LocalIdentity +from pymc_core.companion import CompanionBridge, ADV_TYPE_CHAT + +async def main(): + identity = LocalIdentity() + + async def packet_injector(pkt, wait_for_ack=False): + """Forward packet to the repeater's radio.""" + return await my_repeater.send_packet(pkt) + + def authenticate(user_hash, password): + """Validate login attempts. Return (success, acl_bits).""" + if user_hash == expected_hash: + return (True, 0x01) + return (False, 0) + + bridge = CompanionBridge( + identity=identity, + packet_injector=packet_injector, + node_name="myBridge", + adv_type=ADV_TYPE_CHAT, + authenticate_callback=authenticate, + ) + + bridge.on_message_received( + lambda key, text, ts, tt: print(f"Bridge msg: {text}") + ) + + await bridge.start() + + # Feed packets from the repeater's dispatcher + async def on_repeater_rx(packet): + await bridge.process_received_packet(packet) + + # ... register on_repeater_rx with your repeater ... + +asyncio.run(main()) +``` + +### Constructor + +```python +CompanionBridge( + identity: LocalIdentity, + packet_injector: Callable, # async (pkt, wait_for_ack=False) -> bool + node_name: str = "pyMC", + adv_type: int = ADV_TYPE_CHAT, + max_contacts: int = 1000, + max_channels: int = 40, + offline_queue_size: int = 512, + radio_config: dict | None = None, + authenticate_callback: Callable | None = None, # (hash, pw) -> (bool, int) +) +``` + +### RX Entry Point + +```python +# Called by the repeater host for every received packet +await bridge.process_received_packet(packet) +``` + +The bridge registers internal handlers for these payload types: + +| Payload Type | Handler | +|---|---| +| ACK | Bridge ACK handler (matches pending CRCs) | +| TXT_MSG | TextMessageHandler | +| ADVERT | AdvertHandler | +| PATH | PathHandler | +| ANON_REQ | LoginServerHandler | +| GRP_TXT | GroupTextHandler | +| RESPONSE | LoginResponseHandler | + +### All Other APIs + +`CompanionBridge` exposes the same messaging, contact, channel, path, signing, stats, and configuration APIs as `CompanionRadio` (inherited from `CompanionBase`). The only behavioral difference is that all TX goes through the `packet_injector` instead of an owned radio. + +--- + +## Use Cases + +### 1. Chat Application + +Build a terminal or GUI chat client that discovers peers and exchanges messages. + +```python +companion = CompanionRadio(radio, identity, node_name="ChatApp") + +companion.on_message_received(display_message) +companion.on_advert_received(add_to_contact_list) + +await companion.start() +await companion.advertise() + +# User picks a contact and sends +contact = companion.get_contact_by_name("Alice") +await companion.send_text_message(contact.public_key, user_input) +``` + +### 2. Sensor Gateway + +Collect telemetry from sensor nodes in the mesh, forward to a database or MQTT. + +```python +companion = CompanionRadio(radio, identity, node_name="Gateway", adv_type=ADV_TYPE_CHAT) + +async def on_telemetry(event_data): + # event_data contains parsed CayenneLPP sensor readings + publish_to_mqtt(event_data) + +companion.on_telemetry_response(on_telemetry) + +await companion.start() + +# Periodically poll known sensors +for sensor in companion.get_contacts(): + if sensor.adv_type == ADV_TYPE_SENSOR: + await companion.send_telemetry_request( + sensor.public_key, want_base=True, + want_location=True, want_environment=True, timeout=15, + ) +``` + +### 3. Repeater Companion (Bridge Mode) + +Add a companion identity to an existing repeater without a second radio. + +```python +bridge = CompanionBridge( + identity=identity, + packet_injector=repeater.inject_packet, + node_name="RepeaterBot", + authenticate_callback=auth_check, +) + +bridge.on_message_received(handle_bot_command) +await bridge.start() + +# In the repeater's RX loop: +async def repeater_on_rx(pkt): + await bridge.process_received_packet(pkt) + # ... also handle repeater logic ... +``` + +### 4. Network Diagnostics Tool + +Trace paths and discover topology. + +```python +companion.on_trace_received(lambda data: print(f"Trace: {data}")) +companion.on_path_discovery_response( + lambda tag, key, out_path, in_path: print(f"Path to {key.hex()[:8]}: out={out_path}, in={in_path}") +) + +# Trace route to a node +await companion.send_trace_path(target_key, tag=1, auth_code=0) + +# Discover paths +await companion.send_path_discovery_req(target_key) +``` + +### 5. Group Chat / Channels + +```python +companion.set_channel(0, name="Emergency", secret=b"shared_channel_secret___________") +companion.set_channel(1, name="General", secret=b"another_shared_secret___________") + +companion.on_channel_message_received( + lambda ch_name, sender, text, ts, path_len, idx: + print(f"[{ch_name}] {sender}: {text}") +) + +await companion.send_channel_message(0, "Emergency broadcast") +``` + +--- + +## Push Callbacks Reference + +Register callbacks to receive asynchronous events. Both sync and async functions are supported. + +| Registration Method | Callback Signature | +|---|---| +| `on_message_received` | `(sender_key: bytes, text: str, timestamp: int, txt_type: int)` | +| `on_channel_message_received` | `(channel_name: str, sender_name: str, text: str, timestamp: int, path_len: int, channel_idx: int)` | +| `on_advert_received` | `(contact: Contact)` | +| `on_contact_path_updated` | `(contact: Contact)` | +| `on_send_confirmed` | `(ack_crc: int)` | +| `on_trace_received` | `(trace_data)` | +| `on_node_discovered` | `(event_data)` | +| `on_login_result` | `(result_data)` | +| `on_telemetry_response` | `(event_data)` | +| `on_status_response` | `(status_data)` | +| `on_raw_data_received` | `(raw_data)` | +| `on_binary_response` | `(tag: bytes, data: bytes, parsed: dict\|None, request_type: int\|None)` | +| `on_path_discovery_response` | `(tag: bytes, contact_pubkey: bytes, out_path: bytes, in_path: bytes)` | + +--- + +## Models Reference + +### Contact + +```python +@dataclass +class Contact: + public_key: bytes # 32-byte Ed25519 public key + name: str = "" # up to 32 characters + adv_type: int = 0 # ADV_TYPE_CHAT / REPEATER / ROOM / SENSOR + flags: int = 0 + out_path_len: int = -1 # -1=unknown, 0=direct, >0=multi-hop + out_path: bytes = b"" + last_advert_timestamp: int = 0 + lastmod: int = 0 + gps_lat: float = 0.0 + gps_lon: float = 0.0 +``` + +### Channel + +```python +@dataclass +class Channel: + name: str # up to 32 characters + secret: bytes # 16-byte pre-shared key +``` + +### SentResult + +```python +@dataclass +class SentResult: + success: bool + is_flood: bool = False + expected_ack: int | None = None + timeout_ms: int | None = None +``` + +### QueuedMessage + +```python +@dataclass +class QueuedMessage: + sender_key: bytes + txt_type: int = 0 # TXT_TYPE_PLAIN / CLI_DATA / SIGNED_PLAIN + timestamp: int = 0 + text: str = "" + is_channel: bool = False + channel_idx: int = 0 + path_len: int = 0 +``` + +--- + +## Constants + +```python +# Advertisement types +ADV_TYPE_CHAT = 1 +ADV_TYPE_REPEATER = 2 +ADV_TYPE_ROOM = 3 +ADV_TYPE_SENSOR = 4 + +# Text message types +TXT_TYPE_PLAIN = 0 +TXT_TYPE_CLI_DATA = 1 +TXT_TYPE_SIGNED_PLAIN = 2 + +# Telemetry modes +TELEM_MODE_DENY = 0 +TELEM_MODE_ALLOW_FLAGS = 1 +TELEM_MODE_ALLOW_ALL = 2 + +# Location policy +ADVERT_LOC_NONE = 0 +ADVERT_LOC_SHARE = 1 + +# Auto-add config bitmask +AUTOADD_OVERWRITE_OLDEST = 0x01 +AUTOADD_CHAT = 0x02 +AUTOADD_REPEATER = 0x04 +AUTOADD_ROOM = 0x08 +AUTOADD_SENSOR = 0x10 + +# Stats types +STATS_TYPE_CORE = 0 +STATS_TYPE_RADIO = 1 +STATS_TYPE_PACKETS = 2 + +# Binary request types (IntEnum) +BinaryReqType.STATUS # 0x01 +BinaryReqType.KEEP_ALIVE # 0x02 +BinaryReqType.TELEMETRY # 0x03 +BinaryReqType.MMA # 0x04 +BinaryReqType.ACL # 0x05 +BinaryReqType.NEIGHBOURS # 0x06 + +# Protocol codes +PROTOCOL_CODE_RAW_DATA = 0x00 +PROTOCOL_CODE_BINARY_REQ = 0x02 +PROTOCOL_CODE_ANON_REQ = 0x07 + +# Timeouts +DEFAULT_RESPONSE_TIMEOUT_MS = 10000 +``` + +--- + +## Unimplemented MeshCore Companion Features + +The following features from the MeshCore companion radio firmware (`examples/companion_radio/`) are **not yet implemented** in pyMC_core: + +| Feature | Firmware Command | Description | +|---|---|---| +| Device query | `CMD_DEVICE_QUERY` (0x16) | Hardware capability & firmware version handshake | +| App start handshake | `CMD_APP_START` (0x01) | Initial BLE/serial session setup with self-info response | +| Device time get/set | `CMD_GET_DEVICE_TIME` / `CMD_SET_DEVICE_TIME` | RTC clock synchronisation | +| Reboot | `CMD_REBOOT` (0x13) | Remote device reboot (with confirmation string) | +| Factory reset | `CMD_FACTORY_RESET` (0x33) | Erase all data and reset to defaults | +| BLE PIN | `CMD_SET_DEVICE_PIN` (0x25) | Set BLE pairing PIN | +| Battery & storage | `CMD_GET_BATT_AND_STORAGE` (0x14) | Battery voltage and flash storage info | +| Logout | `CMD_LOGOUT` (0x1D) | Disconnect from a server/repeater session | +| Has connection | `CMD_HAS_CONNECTION` (0x1C) | Check if active connection exists to a contact | +| Contact-by-key lookup (protocol) | `CMD_GET_CONTACT_BY_KEY` (0x1E) | Protocol-level single-contact fetch (available in-memory via `get_contact_by_key`) | +| GPS configuration | GPS enable/interval | GPS hardware control and periodic fix interval | +| Data persistence | File I/O (`/contacts3`, `/channels2`, `/new_prefs`) | Automatic save/load of contacts, channels, and preferences to flash storage | +| Push: contact deleted | `PUSH_CODE_CONTACT_DELETED` (0x8F) | Notification when a contact is overwritten by auto-add | +| Push: contacts full | `PUSH_CODE_CONTACTS_FULL` (0x90) | Notification when contact storage is full | +| Push: RX data log | `PUSH_CODE_LOG_RX_DATA` (0x88) | Raw received packet logging for diagnostics | +| Keep-alive mechanism | Server-driven keep-alive | Periodic keep-alive packets for active server connections | +| Firmware version reporting | `FIRMWARE_VER_CODE` / `FIRMWARE_BUILD_DATE` | Version and build metadata in device info response | diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index bf450d3..c688b1b 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -96,6 +96,7 @@ markdown_extensions: nav: - Home: index.md - Node Usage Guide: node.md + - Companion Guide: companion.md - API Reference: - Core: api/core.md - Protocol: api/protocol.md diff --git a/src/pymc_core/companion/binary_parsing.py b/src/pymc_core/companion/binary_parsing.py index dbe4ec1..c3bc582 100644 --- a/src/pymc_core/companion/binary_parsing.py +++ b/src/pymc_core/companion/binary_parsing.py @@ -1,5 +1,7 @@ """Parse binary response payloads by request type (BinaryReqType).""" +from __future__ import annotations + import struct from typing import Any, Optional diff --git a/src/pymc_core/companion/channel_store.py b/src/pymc_core/companion/channel_store.py index 4173a20..81fd03a 100644 --- a/src/pymc_core/companion/channel_store.py +++ b/src/pymc_core/companion/channel_store.py @@ -1,5 +1,7 @@ """In-memory channel storage compatible with MeshNode's channel_db interface.""" +from __future__ import annotations + from typing import Optional from .constants import DEFAULT_MAX_CHANNELS @@ -78,6 +80,6 @@ def get_count(self) -> int: """Return the number of configured channels.""" return sum(1 for ch in self._channels if ch is not None) - def clear(self): + def clear(self) -> None: """Remove all channels.""" self._channels = [None] * self._max_channels diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 6995038..35b33b7 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -10,18 +10,24 @@ import asyncio import copy import logging +import random import struct import time +from abc import ABC, abstractmethod from collections import OrderedDict from typing import Any, Callable, Optional from ..node.events import EventService, EventSubscriber, MeshEvents -from ..protocol import LocalIdentity, PacketBuilder +from ..protocol import LocalIdentity, Packet, PacketBuilder from ..protocol.constants import ( + ADVERT_FLAG_HAS_LOCATION, ADVERT_FLAG_HAS_NAME, ADVERT_FLAG_IS_CHAT_NODE, ADVERT_FLAG_IS_REPEATER, ADVERT_FLAG_IS_ROOM_SERVER, + ADVERT_FLAG_IS_SENSOR, + REQ_TYPE_GET_TELEMETRY_DATA, + TELEM_PERM_BASE, ) from .channel_store import ChannelStore from .constants import ( @@ -33,14 +39,16 @@ DEFAULT_MAX_CHANNELS, DEFAULT_MAX_CONTACTS, DEFAULT_OFFLINE_QUEUE_SIZE, + DEFAULT_RESPONSE_TIMEOUT_MS, MAX_SIGN_DATA_SIZE, + PROTOCOL_CODE_BINARY_REQ, STATS_TYPE_CORE, STATS_TYPE_PACKETS, STATS_TYPE_RADIO, ) from .contact_store import ContactStore from .message_queue import MessageQueue -from .models import AdvertPath, Channel, Contact, NodePrefs, QueuedMessage +from .models import AdvertPath, Channel, Contact, NodePrefs, QueuedMessage, SentResult from .path_cache import PathCache from .stats_collector import StatsCollector @@ -108,12 +116,12 @@ def adv_type_to_flags(adv_type: int) -> int: elif adv_type == ADV_TYPE_ROOM: return ADVERT_FLAG_IS_ROOM_SERVER elif adv_type == ADV_TYPE_SENSOR: - return 0x04 + return ADVERT_FLAG_IS_SENSOR return ADVERT_FLAG_IS_CHAT_NODE -class CompanionBase: - """Base class for companion implementations. +class CompanionBase(ABC): + """Abstract base class for companion implementations. Provides shared stores, event handling, contact management, device config, and push callbacks. Subclasses implement TX (via node or packet_injector). @@ -164,6 +172,8 @@ def _init_companion_stores( # Pending binary requests by tag (hex) for matching responses self._pending_binary_requests: dict[str, dict] = {} + # Pending path discovery tags for matching responses + self._pending_discovery_tags: set[int] = set() # GRP_TXT dedup by packet hash: match Mesh.cpp behavior (only process when !_tables->hasSeen(pkt)), # so companion queues one frame per logical message like the firmware. @@ -180,26 +190,32 @@ def _init_companion_stores( # ------------------------------------------------------------------------- def get_contacts(self, since: int = 0) -> list[Contact]: + """Return all contacts, optionally filtered by modification time.""" return self.contacts.get_all(since=since) def get_contact_by_key(self, pub_key: bytes) -> Optional[Contact]: + """Look up a contact by its full 32-byte public key.""" return self.contacts.get_by_key(pub_key) def get_contact_by_name(self, name: str) -> Optional[Contact]: + """Look up a contact by name, returning the full Contact or None.""" proxy = self.contacts.get_by_name(name) if proxy: return self.contacts.get_by_key(bytes.fromhex(proxy.public_key)) return None def add_update_contact(self, contact: Contact) -> bool: + """Add or update a contact, setting lastmod if unset.""" if contact.lastmod == 0: contact.lastmod = int(time.time()) return self.contacts.add(contact) def remove_contact(self, pub_key: bytes) -> bool: + """Remove a contact by public key.""" return self.contacts.remove(pub_key) def export_contact(self, pub_key: Optional[bytes] = None) -> Optional[bytes]: + """Export a contact (or self) as a 73-byte binary packet.""" if pub_key is None: key = self._identity.get_public_key() name = self.prefs.node_name.encode("utf-8")[:32] @@ -231,6 +247,7 @@ def export_contact(self, pub_key: Optional[bytes] = None) -> Optional[bytes]: ) def import_contact(self, packet_data: bytes) -> bool: + """Import a contact from a 73-byte binary packet.""" if len(packet_data) < 73: logger.warning(f"Import data too short: {len(packet_data)} bytes") return False @@ -262,6 +279,7 @@ def set_advert_name(self, name: str) -> None: self.prefs.node_name = name[:31] def set_advert_latlon(self, lat: float, lon: float) -> None: + """Set the GPS coordinates included in advertisements.""" if not (-90.0 <= lat <= 90.0): raise ValueError(f"Latitude out of range: {lat}") if not (-180.0 <= lon <= 180.0): @@ -270,6 +288,7 @@ def set_advert_latlon(self, lat: float, lon: float) -> None: self.prefs.longitude = lon def set_radio_params(self, freq_hz: int, bw_hz: int, sf: int, cr: int) -> bool: + """Set radio parameters (frequency, bandwidth, SF, CR).""" if not (5 <= sf <= 12): raise ValueError(f"Spreading factor out of range: {sf}") if not (5 <= cr <= 8): @@ -281,14 +300,17 @@ def set_radio_params(self, freq_hz: int, bw_hz: int, sf: int, cr: int) -> bool: return True def set_tx_power(self, power_dbm: int) -> bool: + """Set the transmit power in dBm.""" self.prefs.tx_power_dbm = power_dbm return True def set_tuning_params(self, rx_delay: float, airtime_factor: float) -> None: + """Set RX delay and airtime factor tuning parameters.""" self.prefs.rx_delay_base = rx_delay self.prefs.airtime_factor = airtime_factor def get_tuning_params(self) -> tuple[float, float]: + """Return the current (rx_delay, airtime_factor) tuning parameters.""" return (self.prefs.rx_delay_base, self.prefs.airtime_factor) def set_other_params( @@ -298,6 +320,7 @@ def set_other_params( advert_loc_policy: int, multi_acks: int, ) -> None: + """Set additional node parameters (manual add, telemetry, location, multi-acks).""" self.prefs.manual_add_contacts = manual_add self.prefs.telemetry_mode_base = telemetry_modes & 0x03 self.prefs.telemetry_mode_location = (telemetry_modes >> 2) & 0x03 @@ -306,9 +329,11 @@ def set_other_params( self.prefs.multi_acks = multi_acks def get_self_info(self) -> NodePrefs: + """Return a copy of the current node preferences.""" return copy.copy(self.prefs) def get_public_key(self) -> bytes: + """Return this node's 32-byte Ed25519 public key.""" return self._identity.get_public_key() # ------------------------------------------------------------------------- @@ -316,6 +341,7 @@ def get_public_key(self) -> bytes: # ------------------------------------------------------------------------- def reset_path(self, pub_key: bytes) -> bool: + """Reset the outbound routing path for a contact.""" contact = self.contacts.get_by_key(pub_key) if not contact: return False @@ -325,6 +351,7 @@ def reset_path(self, pub_key: bytes) -> bool: return True def get_advert_path(self, pub_key_prefix: bytes) -> Optional[AdvertPath]: + """Look up a cached advert path by public key prefix.""" return self.path_cache.get_by_prefix(pub_key_prefix) # ------------------------------------------------------------------------- @@ -332,9 +359,11 @@ def get_advert_path(self, pub_key_prefix: bytes) -> Optional[AdvertPath]: # ------------------------------------------------------------------------- def get_channel(self, idx: int) -> Optional[Channel]: + """Return the channel at the given index, or None.""" return self.channels.get(idx) def set_channel(self, idx: int, name: str, secret: bytes) -> bool: + """Set a channel at the given index with name and 32-byte secret.""" # MeshCore DataStore uses 32-byte secret; GroupTextHandler uses up to 32 for HMAC if len(secret) < 32: secret = secret + b"\x00" * (32 - len(secret)) @@ -347,10 +376,12 @@ def set_channel(self, idx: int, name: str, secret: bytes) -> bool: # ------------------------------------------------------------------------- def sign_start(self) -> int: + """Begin a signing session; returns the maximum sign buffer size.""" self._sign_buffer = bytearray() return MAX_SIGN_DATA_SIZE def sign_data(self, data: bytes) -> bool: + """Append data to the signing buffer.""" if self._sign_buffer is None: logger.warning("sign_data called without sign_start") return False @@ -377,6 +408,7 @@ def sign_finish(self) -> Optional[bytes]: # ------------------------------------------------------------------------- def export_private_key(self) -> bytes: + """Return the raw signing key bytes for backup/export.""" return self._identity.get_signing_key_bytes() # ------------------------------------------------------------------------- @@ -384,6 +416,7 @@ def export_private_key(self) -> bytes: # ------------------------------------------------------------------------- def set_flood_scope(self, transport_key: Optional[bytes] = None) -> None: + """Set or clear the flood transport key for scoped flooding.""" if transport_key and len(transport_key) >= 16: self._flood_transport_key = transport_key[:16] else: @@ -394,6 +427,7 @@ def set_flood_scope(self, transport_key: Optional[bytes] = None) -> None: # ------------------------------------------------------------------------- def get_stats(self, stats_type: int = STATS_TYPE_PACKETS) -> dict: + """Return statistics of the requested type (core, radio, or packets).""" if stats_type == STATS_TYPE_CORE: return { "uptime_secs": self.stats.get_uptime_secs(), @@ -420,9 +454,11 @@ def _get_radio_stats(self) -> dict: # ------------------------------------------------------------------------- def get_custom_vars(self) -> dict[str, str]: + """Return a copy of all custom variables.""" return dict(self._custom_vars) def set_custom_var(self, name: str, value: str) -> bool: + """Set a custom variable by name.""" self._custom_vars[name] = value return True @@ -431,9 +467,11 @@ def set_custom_var(self, name: str, value: str) -> bool: # ------------------------------------------------------------------------- def get_autoadd_config(self) -> int: + """Return the current auto-add configuration bitmask.""" return self.prefs.autoadd_config def set_autoadd_config(self, config: int) -> None: + """Set the auto-add configuration bitmask.""" self.prefs.autoadd_config = config # ------------------------------------------------------------------------- @@ -544,7 +582,265 @@ async def _on_binary_response( async def _try_handle_path_discovery( self, tag_bytes: bytes, path_info: tuple ) -> bool: - """If this tag is a pending path discovery, fire path_discovery_response and return True. Override in bridge.""" + """If tag is pending path discovery, fire path_discovery_response and return True.""" + out_path, in_path, contact_pubkey = path_info + tag_int = int.from_bytes(tag_bytes, "little") + if tag_int not in self._pending_discovery_tags: + return False + self._pending_discovery_tags.discard(tag_int) + await self._fire_callbacks( + "path_discovery_response", + tag_bytes, + contact_pubkey, + out_path, + in_path, + ) + return True + + # ------------------------------------------------------------------------- + # Abstract methods (subclasses must implement) + # ------------------------------------------------------------------------- + + @abstractmethod + async def _send_packet( + self, pkt: Packet, wait_for_ack: bool = False + ) -> bool: + """Send a packet via the subclass transport (radio or packet_injector).""" + + @abstractmethod + async def start(self) -> None: + """Start the companion.""" + + @abstractmethod + async def stop(self) -> None: + """Stop the companion.""" + + @property + @abstractmethod + def is_running(self) -> bool: + """Return whether the companion is currently running.""" + + @abstractmethod + async def send_text_message( + self, + pub_key: bytes, + text: str, + txt_type: int = 0, + attempt: int = 1, + ) -> SentResult: + """Send a direct text message to a contact.""" + + @abstractmethod + async def send_channel_message(self, channel_idx: int, text: str) -> bool: + """Send a message to a channel.""" + + @abstractmethod + async def send_login(self, pub_key: bytes, password: str) -> dict: + """Send a login request to a repeater.""" + + @abstractmethod + async def send_trace_path( + self, + pub_key: bytes, + tag: int, + auth_code: int, + flags: int = 0, + ) -> bool: + """Send a trace path request to a contact.""" + + @abstractmethod + def import_private_key(self, key: bytes) -> bool: + """Import a private key and rebuild the identity.""" + + @abstractmethod + async def send_control_data(self, data: Any = None) -> bool: + """Send a control data packet.""" + + # ------------------------------------------------------------------------- + # Unified TX methods (shared between Radio and Bridge) + # ------------------------------------------------------------------------- + + async def advertise(self, flood: bool = True) -> bool: + """Broadcast an advertisement packet.""" + flags = adv_type_to_flags(self.prefs.adv_type) + flags |= ADVERT_FLAG_HAS_NAME + lat, lon = 0.0, 0.0 + if self.prefs.advert_loc_policy == ADVERT_LOC_SHARE: + lat, lon = self.prefs.latitude, self.prefs.longitude + if lat != 0.0 or lon != 0.0: + flags |= ADVERT_FLAG_HAS_LOCATION + route = "flood" if flood else "direct" + pkt = PacketBuilder.create_advert( + local_identity=self._identity, + name=self.prefs.node_name, + lat=lat, + lon=lon, + flags=flags, + route_type=route, + ) + success = await self._send_packet(pkt, wait_for_ack=False) + if success: + self.stats.record_tx(is_flood=flood) + else: + self.stats.record_tx_error() + return success + + async def share_contact(self, pub_key: bytes) -> bool: + """Share a contact's advert to the mesh.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return False + try: + pkt = PacketBuilder.create_advert( + local_identity=self._identity, + name=contact.name, + flags=adv_type_to_flags(contact.adv_type) | ADVERT_FLAG_HAS_NAME, + route_type="direct", + ) + return await self._send_packet(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Error sharing contact: {e}") + return False + + async def send_trace_path_raw( + self, + tag: int, + auth_code: int, + flags: int, + path_bytes: bytes, + ) -> bool: + """Send a trace packet with an explicit path.""" + try: + path_list = list(path_bytes) + pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=path_list) + return await self._send_packet(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Error sending trace (raw path): {e}") + return False + + async def send_binary_req( + self, pub_key: bytes, data: bytes, timeout_seconds: float = 15.0 + ) -> SentResult: + """Send binary request (CMD_SEND_BINARY_REQ). + + data = request_type(1) + optional payload. + Returns SentResult with expected_ack (4-byte tag as int) and timeout_ms. + """ + contact = self.contacts.get_by_key(pub_key) + if not contact: + return SentResult(success=False) + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return SentResult(success=False) + tag_int = random.randint(0, 0xFFFFFFFF) + tag_bytes = tag_int.to_bytes(4, "little") + tag_hex = tag_bytes.hex() + request_type = data[0] if len(data) >= 1 else 0 + req_payload = tag_bytes + data + self.cleanup_expired_binary_requests() + self.register_binary_request( + tag_hex, + request_type=request_type, + timeout_seconds=timeout_seconds, + pubkey_prefix=pub_key[:6].hex(), + ) + try: + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=PROTOCOL_CODE_BINARY_REQ, + data=req_payload, + ) + success = await self._send_packet(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Binary request send error: {e}") + self._pending_binary_requests.pop(tag_hex, None) + return SentResult(success=False) + if not success: + self._pending_binary_requests.pop(tag_hex, None) + return SentResult(success=False) + return SentResult( + success=True, + is_flood=contact.out_path_len <= 0, + expected_ack=tag_int, + timeout_ms=DEFAULT_RESPONSE_TIMEOUT_MS, + ) + + async def send_path_discovery(self, pub_key: bytes) -> bool: + """Legacy: send path discovery without returning tag. Prefer send_path_discovery_req.""" + result = await self.send_path_discovery_req(pub_key) + return result.success + + async def send_path_discovery_req(self, pub_key: bytes) -> SentResult: + """Send path discovery (flood telemetry request with tag). + + Returns SentResult for RESP_CODE_SENT. When path return arrives with + matching tag, path_discovery_response is fired (PUSH 0x8D). + """ + contact = self.contacts.get_by_key(pub_key) + if not contact: + return SentResult(success=False) + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return SentResult(success=False) + tag_int = random.randint(0, 0xFFFFFFFF) + tag_bytes = tag_int.to_bytes(4, "little") + inv_perm = 0xFF & ~TELEM_PERM_BASE + req_payload = tag_bytes + bytes( + [REQ_TYPE_GET_TELEMETRY_DATA, inv_perm, 0, 0, 0] + ) + old_path_len = contact.out_path_len + old_path = contact.out_path + contact.out_path_len = -1 + contact.out_path = b"" + self.contacts.update(contact) + try: + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, + data=req_payload, + ) + success = await self._send_packet(pkt, wait_for_ack=False) + if success: + self._pending_discovery_tags.add(tag_int) + return SentResult( + success=success, + is_flood=True, + expected_ack=tag_int, + timeout_ms=DEFAULT_RESPONSE_TIMEOUT_MS, + ) + except Exception as e: + logger.error(f"Error in path discovery: {e}") + return SentResult(success=False) + finally: + current = self.contacts.get_by_key(pub_key) + if current and current.out_path_len == -1: + current.out_path_len = old_path_len + current.out_path = old_path + self.contacts.update(current) + + def sync_next_message(self) -> Optional[QueuedMessage]: + """Pop and return the next queued message, or None.""" + return self.message_queue.pop() + + # ------------------------------------------------------------------------- + # Dedup Helper + # ------------------------------------------------------------------------- + + def _check_dedup( + self, cache: OrderedDict, key: str, ttl: float, max_size: int + ) -> bool: + """Return True if *key* is a duplicate. Evicts expired entries.""" + now = time.time() + if key in cache: + return True + expired = [k for k, ts in cache.items() if now - ts > ttl] + for k in expired: + del cache[k] + cache[key] = now + if len(cache) > max_size: + cache.popitem(last=False) return False # ------------------------------------------------------------------------- @@ -571,16 +867,10 @@ async def _handle_mesh_event(self, event_type: str, data: dict) -> None: async def _handle_new_message(self, data: dict) -> None: # Deduplicate by packet hash so reconnects don't queue the same packet multiple times. pkt_hash = data.get("packet_hash") - if pkt_hash: - now = time.time() - if pkt_hash in self._seen_txt: - return - expired = [k for k, ts in self._seen_txt.items() if now - ts > self._seen_txt_ttl] - for k in expired: - del self._seen_txt[k] - self._seen_txt[pkt_hash] = now - if len(self._seen_txt) > self._seen_txt_max: - self._seen_txt.popitem(last=False) + if pkt_hash and self._check_dedup( + self._seen_txt, pkt_hash, self._seen_txt_ttl, self._seen_txt_max + ): + return sender_key_hex = data.get("contact_pubkey", "") sender_key = bytes.fromhex(sender_key_hex) if sender_key_hex else b"" @@ -601,22 +891,17 @@ async def _handle_new_message(self, data: dict) -> None: message_text, msg.timestamp, msg.txt_type, + pkt_hash, ) async def _handle_new_channel_message(self, data: dict) -> None: # Deduplicate by packet hash so we queue one frame per logical message, matching # firmware: Mesh.cpp only calls onChannelMessageRecv when !_tables->hasSeen(pkt). pkt_hash = data.get("packet_hash") - if pkt_hash: - now = time.time() - if pkt_hash in self._seen_grp_txt: - return - expired = [k for k, ts in self._seen_grp_txt.items() if now - ts > self._seen_grp_txt_ttl] - for k in expired: - del self._seen_grp_txt[k] - self._seen_grp_txt[pkt_hash] = now - if len(self._seen_grp_txt) > self._seen_grp_txt_max: - self._seen_grp_txt.popitem(last=False) + if pkt_hash and self._check_dedup( + self._seen_grp_txt, pkt_hash, self._seen_grp_txt_ttl, self._seen_grp_txt_max + ): + return path_len = data.get("path_len", 0) channel_name = data.get("channel_name", "") @@ -648,6 +933,7 @@ async def _handle_new_channel_message(self, data: dict) -> None: msg.timestamp, path_len, channel_idx, + pkt_hash, ) async def _fire_callbacks(self, event_name: str, *args: Any) -> None: diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index cf4f478..78ac5d3 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -10,11 +10,9 @@ import asyncio import logging -import random import time from typing import Any, Callable, Optional -from ..node.events import EventService, EventSubscriber, MeshEvents from ..node.handlers import ( AdvertHandler, GroupTextHandler, @@ -27,8 +25,6 @@ from ..protocol import LocalIdentity, PacketBuilder from ..protocol import Packet from ..protocol.constants import ( - ADVERT_FLAG_HAS_LOCATION, - ADVERT_FLAG_HAS_NAME, PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, PAYLOAD_TYPE_ANON_REQ, @@ -37,20 +33,23 @@ PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE, PAYLOAD_TYPE_TXT_MSG, + REQ_TYPE_GET_STATUS, + REQ_TYPE_GET_TELEMETRY_DATA, ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD, ) -from ..protocol.constants import REQ_TYPE_GET_STATUS, REQ_TYPE_GET_TELEMETRY_DATA, TELEM_PERM_BASE -from .companion_base import CompanionBase, ResponseWaiter, adv_type_to_flags +from .companion_base import CompanionBase, ResponseWaiter from .constants import ( ADV_TYPE_CHAT, - ADVERT_LOC_SHARE, DEFAULT_MAX_CHANNELS, DEFAULT_MAX_CONTACTS, DEFAULT_OFFLINE_QUEUE_SIZE, + PROTOCOL_CODE_ANON_REQ, + PROTOCOL_CODE_BINARY_REQ, + PROTOCOL_CODE_RAW_DATA, TXT_TYPE_PLAIN, ) -from .models import Contact, QueuedMessage, SentResult +from .models import Contact, SentResult logger = logging.getLogger("CompanionBridge") @@ -110,7 +109,7 @@ def __init__( offline_queue_size: int = DEFAULT_OFFLINE_QUEUE_SIZE, radio_config: Optional[dict] = None, authenticate_callback: Optional[Callable[..., tuple[bool, int]]] = None, - ): + ) -> None: """Initialise the companion bridge.""" self._init_companion_stores( identity=identity, @@ -123,7 +122,7 @@ def __init__( ) self._packet_injector = packet_injector - async def _send_packet(pkt: Packet, wait_for_ack: bool = False) -> bool: + async def _handler_send_packet(pkt: Packet, wait_for_ack: bool = False) -> bool: return await self._packet_injector(pkt, wait_for_ack=wait_for_ack) def _login_send_callback(pkt: Packet, delay_ms: int) -> None: @@ -132,10 +131,10 @@ async def _delayed_send() -> None: await self._packet_injector(pkt, wait_for_ack=False) asyncio.create_task(_delayed_send()) - _log = lambda msg: logger.debug(f"[CompanionBridge] {msg}") + def _log(msg: str) -> None: + logger.debug(f"[CompanionBridge] {msg}") self._pending_ack_crcs: set[int] = set() - self._pending_discovery_tags: set[int] = set() ack_handler = _BridgeAckHandler(self) protocol_response_handler = ProtocolResponseHandler( _log, identity, self.contacts @@ -167,7 +166,7 @@ def _reject_all(*args, **kwargs) -> tuple[bool, int]: identity, self.contacts, _log, - _send_packet, + _handler_send_packet, self._event_service, self._radio_config, ), @@ -180,7 +179,7 @@ def _reject_all(*args, **kwargs) -> tuple[bool, int]: identity, self.contacts, _log, - _send_packet, + _handler_send_packet, self.channels, self._event_service, node_name, @@ -265,6 +264,16 @@ def _update_stores_from_advert(self, packet: Packet, advert_data: dict): logger.error(f"Error updating stores from advert: {e}") return None + # ------------------------------------------------------------------------- + # Abstract method implementations + # ------------------------------------------------------------------------- + + async def _send_packet( + self, pkt: Packet, wait_for_ack: bool = False + ) -> bool: + """Send a packet via the packet_injector.""" + return await self._packet_injector(pkt, wait_for_ack=wait_for_ack) + # ------------------------------------------------------------------------- # Lifecycle # ------------------------------------------------------------------------- @@ -284,34 +293,6 @@ async def stop(self) -> None: def is_running(self) -> bool: return self._running - # ------------------------------------------------------------------------- - # Advertisement - # ------------------------------------------------------------------------- - - async def advertise(self, flood: bool = True) -> bool: - flags = adv_type_to_flags(self.prefs.adv_type) - flags |= ADVERT_FLAG_HAS_NAME - lat, lon = 0.0, 0.0 - if self.prefs.advert_loc_policy == ADVERT_LOC_SHARE: - lat, lon = self.prefs.latitude, self.prefs.longitude - if lat != 0.0 or lon != 0.0: - flags |= ADVERT_FLAG_HAS_LOCATION - route = "flood" if flood else "direct" - pkt = PacketBuilder.create_advert( - local_identity=self._identity, - name=self.prefs.node_name, - lat=lat, - lon=lon, - flags=flags, - route_type=route, - ) - success = await self._packet_injector(pkt, wait_for_ack=False) - if success: - self.stats.record_tx(is_flood=flood) - else: - self.stats.record_tx_error() - return success - # ------------------------------------------------------------------------- # Messaging # ------------------------------------------------------------------------- @@ -382,9 +363,6 @@ async def send_channel_message(self, channel_idx: int, text: str) -> bool: self.stats.record_tx_error() return False - def sync_next_message(self) -> Optional[QueuedMessage]: - return self.message_queue.pop() - async def send_raw_data( self, dest_key: bytes, @@ -401,7 +379,7 @@ async def send_raw_data( pkt, _ = PacketBuilder.create_protocol_request( contact=proxy, local_identity=self._identity, - protocol_code=0x00, + protocol_code=PROTOCOL_CODE_RAW_DATA, data=data, ) success = await self._packet_injector(pkt, wait_for_ack=False) @@ -410,26 +388,6 @@ async def send_raw_data( logger.error(f"Error sending raw data: {e}") return SentResult(success=False) - # ------------------------------------------------------------------------- - # Contact Management (share_contact override) - # ------------------------------------------------------------------------- - - async def share_contact(self, pub_key: bytes) -> bool: - contact = self.contacts.get_by_key(pub_key) - if not contact: - return False - try: - pkt = PacketBuilder.create_advert( - local_identity=self._identity, - name=contact.name, - flags=adv_type_to_flags(contact.adv_type) | ADVERT_FLAG_HAS_NAME, - route_type="direct", - ) - return await self._packet_injector(pkt, wait_for_ack=False) - except Exception as e: - logger.error(f"Error sharing contact: {e}") - return False - # ------------------------------------------------------------------------- # Path & Routing # ------------------------------------------------------------------------- @@ -454,93 +412,6 @@ async def send_trace_path( logger.error(f"Error sending trace: {e}") return False - async def send_trace_path_raw( - self, - tag: int, - auth_code: int, - flags: int, - path_bytes: bytes, - ) -> bool: - """Send a trace packet with an explicit path (e.g. from CMD_SEND_TRACE_PATH). Matches firmware behavior.""" - try: - path_list = list(path_bytes) - pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=path_list) - return await self._packet_injector(pkt, wait_for_ack=False) - except Exception as e: - logger.error(f"Error sending trace (raw path): {e}") - return False - - async def _try_handle_path_discovery( - self, tag_bytes: bytes, path_info: tuple - ) -> bool: - """If tag is pending path discovery, fire path_discovery_response and return True.""" - out_path, in_path, contact_pubkey = path_info - tag_int = int.from_bytes(tag_bytes, "little") - if tag_int not in self._pending_discovery_tags: - return False - self._pending_discovery_tags.discard(tag_int) - await self._fire_callbacks( - "path_discovery_response", - tag_bytes, - contact_pubkey, - out_path, - in_path, - ) - return True - - async def send_path_discovery(self, pub_key: bytes) -> bool: - """Legacy: send path discovery without returning tag. Prefer send_path_discovery_req.""" - result = await self.send_path_discovery_req(pub_key) - return result.success - - async def send_path_discovery_req(self, pub_key: bytes) -> SentResult: - """Send path discovery (flood telemetry request with tag). Returns SentResult for RESP_CODE_SENT. - When path return arrives with matching tag, path_discovery_response is fired (PUSH 0x8D).""" - contact = self.contacts.get_by_key(pub_key) - if not contact: - return SentResult(success=False) - proxy = self.contacts.get_by_name(contact.name) - if not proxy: - return SentResult(success=False) - tag_int = random.randint(0, 0xFFFFFFFF) - tag_bytes = tag_int.to_bytes(4, "little") - # Firmware: REQ_TYPE_GET_TELEMETRY_DATA, ~TELEM_PERM_BASE, reserved(3), random(4) -> 9 bytes; tag is from sendRequest. - # We send tag(4) + type(1) + perm(1) + reserved(3) = 9 bytes so response echoes our tag. - inv_perm = 0xFF & ~TELEM_PERM_BASE - req_payload = tag_bytes + bytes( - [REQ_TYPE_GET_TELEMETRY_DATA, inv_perm, 0, 0, 0] - ) - old_path_len = contact.out_path_len - old_path = contact.out_path - contact.out_path_len = -1 - contact.out_path = b"" - self.contacts.update(contact) - try: - pkt, _ = PacketBuilder.create_protocol_request( - contact=proxy, - local_identity=self._identity, - protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, - data=req_payload, - ) - success = await self._packet_injector(pkt, wait_for_ack=False) - if success: - self._pending_discovery_tags.add(tag_int) - return SentResult( - success=success, - is_flood=True, - expected_ack=tag_int, - timeout_ms=10000, - ) - except Exception as e: - logger.error(f"Error in path discovery: {e}") - return SentResult(success=False) - finally: - current = self.contacts.get_by_key(pub_key) - if current and current.out_path_len == -1: - current.out_path_len = old_path_len - current.out_path = old_path - self.contacts.update(current) - async def send_control_data(self, data: bytes) -> bool: """Send a CONTROL packet (e.g. discovery request). data = first byte flags/type (0x80 set for DISCOVER_REQ) + payload. Firmware: (cmd_frame[1] & 0x80) != 0, createControlData(&cmd_frame[1], len-1), sendZeroHop(resp). Returns True if sent.""" @@ -709,58 +580,12 @@ async def send_telemetry_request( finally: self._protocol_response_handler.clear_response_callback(contact_hash) - async def send_binary_req( - self, pub_key: bytes, data: bytes, timeout_seconds: float = 15.0 - ) -> SentResult: - """Send binary request (CMD_SEND_BINARY_REQ). data = request_type(1) + optional payload. - Returns SentResult with expected_ack (4-byte tag as int) and timeout_ms for RESP_CODE_SENT. - """ - contact = self.contacts.get_by_key(pub_key) - if not contact: - return SentResult(success=False) - proxy = self.contacts.get_by_name(contact.name) - if not proxy: - return SentResult(success=False) - tag_int = random.randint(0, 0xFFFFFFFF) - tag_bytes = tag_int.to_bytes(4, "little") - tag_hex = tag_bytes.hex() - request_type = data[0] if len(data) >= 1 else 0 - req_payload = tag_bytes + data - self.cleanup_expired_binary_requests() - self.register_binary_request( - tag_hex, - request_type=request_type, - timeout_seconds=timeout_seconds, - pubkey_prefix=pub_key[:6].hex(), - ) - try: - pkt, _ = PacketBuilder.create_protocol_request( - contact=proxy, - local_identity=self._identity, - protocol_code=0x02, - data=req_payload, - ) - success = await self._packet_injector(pkt, wait_for_ack=False) - except Exception as e: - logger.error(f"Binary request send error: {e}") - self._pending_binary_requests.pop(tag_hex, None) - return SentResult(success=False) - if not success: - self._pending_binary_requests.pop(tag_hex, None) - return SentResult(success=False) - return SentResult( - success=True, - is_flood=contact.out_path_len <= 0, - expected_ack=tag_int, - timeout_ms=10000, - ) - async def send_binary_request(self, pub_key: bytes, data: bytes) -> dict: """Legacy: send binary request and wait for response via waiter. Prefer send_binary_req + on_binary_response.""" - return await self._send_protocol_request(pub_key, 0x02, data) + return await self._send_protocol_request(pub_key, PROTOCOL_CODE_BINARY_REQ, data) async def send_anon_request(self, pub_key: bytes, data: bytes) -> dict: - return await self._send_protocol_request(pub_key, 0x07, data) + return await self._send_protocol_request(pub_key, PROTOCOL_CODE_ANON_REQ, data) async def _send_protocol_request( self, pub_key: bytes, protocol_code: int, data: bytes diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py index 76cef8c..acaab22 100644 --- a/src/pymc_core/companion/companion_radio.py +++ b/src/pymc_core/companion/companion_radio.py @@ -12,28 +12,25 @@ import asyncio import logging import random -from typing import Any, Callable, Optional +from typing import Any, Optional from ..node.node import MeshNode from ..protocol import LocalIdentity, Packet, PacketBuilder from ..protocol.constants import ( - ADVERT_FLAG_HAS_LOCATION, - ADVERT_FLAG_HAS_NAME, PAYLOAD_TYPE_CONTROL, - REQ_TYPE_GET_TELEMETRY_DATA, - TELEM_PERM_BASE, ) -from .companion_base import CompanionBase, adv_type_to_flags +from .companion_base import CompanionBase from .constants import ( ADV_TYPE_CHAT, - ADVERT_LOC_SHARE, DEFAULT_MAX_CHANNELS, DEFAULT_MAX_CONTACTS, DEFAULT_OFFLINE_QUEUE_SIZE, - STATS_TYPE_PACKETS, + PROTOCOL_CODE_ANON_REQ, + PROTOCOL_CODE_BINARY_REQ, + PROTOCOL_CODE_RAW_DATA, TXT_TYPE_PLAIN, ) -from .models import QueuedMessage, SentResult +from .models import SentResult logger = logging.getLogger("CompanionRadio") @@ -77,7 +74,7 @@ def __init__( max_channels: int = DEFAULT_MAX_CHANNELS, offline_queue_size: int = DEFAULT_OFFLINE_QUEUE_SIZE, radio_config: Optional[dict] = None, - ): + ) -> None: """Initialise the companion radio.""" self._init_companion_stores( identity=identity, @@ -90,7 +87,6 @@ def __init__( ) self._radio = radio self._dispatcher_task: Optional[asyncio.Task] = None - self._pending_discovery_tags: set[int] = set() self.node = MeshNode( radio=radio, @@ -105,6 +101,16 @@ def __init__( ) self._setup_packet_callbacks() + # ------------------------------------------------------------------------- + # Abstract method implementations + # ------------------------------------------------------------------------- + + async def _send_packet( + self, pkt: Packet, wait_for_ack: bool = False + ) -> bool: + """Send a packet via the MeshNode dispatcher.""" + return await self.node.dispatcher.send_packet(pkt, wait_for_ack=wait_for_ack) + # ------------------------------------------------------------------------- # Lifecycle # ------------------------------------------------------------------------- @@ -136,35 +142,6 @@ async def stop(self) -> None: def is_running(self) -> bool: return self._running - # ------------------------------------------------------------------------- - # Advertisement - # ------------------------------------------------------------------------- - - async def advertise(self, flood: bool = True) -> bool: - flags = adv_type_to_flags(self.prefs.adv_type) - flags |= ADVERT_FLAG_HAS_NAME - lat, lon = 0.0, 0.0 - if self.prefs.advert_loc_policy == ADVERT_LOC_SHARE: - lat = self.prefs.latitude - lon = self.prefs.longitude - if lat != 0.0 or lon != 0.0: - flags |= ADVERT_FLAG_HAS_LOCATION - route = "flood" if flood else "direct" - pkt = PacketBuilder.create_advert( - local_identity=self._identity, - name=self.prefs.node_name, - lat=lat, - lon=lon, - flags=flags, - route_type=route, - ) - success = await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) - if success: - self.stats.record_tx(is_flood=flood) - else: - self.stats.record_tx_error() - return success - # ------------------------------------------------------------------------- # Messaging # ------------------------------------------------------------------------- @@ -224,9 +201,6 @@ async def send_channel_message(self, channel_idx: int, text: str) -> bool: self.stats.record_tx_error() return False - def sync_next_message(self) -> Optional[QueuedMessage]: - return self.message_queue.pop() - async def send_raw_data( self, dest_key: bytes, @@ -240,7 +214,7 @@ async def send_raw_data( try: result = await self.node.send_protocol_request( repeater_name=contact.name, - protocol_code=0x00, + protocol_code=PROTOCOL_CODE_RAW_DATA, data=data, ) return SentResult(success=result.get("success", False)) @@ -248,27 +222,6 @@ async def send_raw_data( logger.error(f"Error sending raw data: {e}") return SentResult(success=False) - # ------------------------------------------------------------------------- - # Contact Management (share_contact overrides base - uses node) - # ------------------------------------------------------------------------- - - async def share_contact(self, pub_key: bytes) -> bool: - contact = self.contacts.get_by_key(pub_key) - if not contact: - logger.warning(f"Contact not found for sharing: {pub_key.hex()[:12]}") - return False - try: - pkt = PacketBuilder.create_advert( - local_identity=self._identity, - name=contact.name, - flags=adv_type_to_flags(contact.adv_type) | ADVERT_FLAG_HAS_NAME, - route_type="direct", - ) - return await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) - except Exception as e: - logger.error(f"Error sharing contact: {e}") - return False - # ------------------------------------------------------------------------- # Device Configuration (overrides for radio hardware) # ------------------------------------------------------------------------- @@ -331,91 +284,6 @@ async def send_trace_path( logger.error(f"Error sending trace: {e}") return False - async def send_trace_path_raw( - self, - tag: int, - auth_code: int, - flags: int, - path_bytes: bytes, - ) -> bool: - """Send a trace packet with an explicit path (e.g. from CMD_SEND_TRACE_PATH). Matches firmware behavior.""" - try: - path_list = list(path_bytes) - pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=path_list) - return await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) - except Exception as e: - logger.error(f"Error sending trace (raw path): {e}") - return False - - async def _try_handle_path_discovery( - self, tag_bytes: bytes, path_info: tuple - ) -> bool: - """If tag is pending path discovery, fire path_discovery_response and return True.""" - out_path, in_path, contact_pubkey = path_info - tag_int = int.from_bytes(tag_bytes, "little") - if tag_int not in self._pending_discovery_tags: - return False - self._pending_discovery_tags.discard(tag_int) - await self._fire_callbacks( - "path_discovery_response", - tag_bytes, - contact_pubkey, - out_path, - in_path, - ) - return True - - async def send_path_discovery(self, pub_key: bytes) -> bool: - """Legacy: send path discovery without returning tag. Prefer send_path_discovery_req.""" - result = await self.send_path_discovery_req(pub_key) - return result.success - - async def send_path_discovery_req(self, pub_key: bytes) -> SentResult: - """Send path discovery (flood telemetry request with tag). Returns SentResult for RESP_CODE_SENT. - When path return arrives with matching tag, path_discovery_response is fired (PUSH 0x8D).""" - contact = self.contacts.get_by_key(pub_key) - if not contact: - return SentResult(success=False) - proxy = self.contacts.get_by_name(contact.name) - if not proxy: - return SentResult(success=False) - tag_int = random.randint(0, 0xFFFFFFFF) - tag_bytes = tag_int.to_bytes(4, "little") - inv_perm = 0xFF & ~TELEM_PERM_BASE - req_payload = tag_bytes + bytes( - [REQ_TYPE_GET_TELEMETRY_DATA, inv_perm, 0, 0, 0] - ) - old_path_len = contact.out_path_len - old_path = contact.out_path - contact.out_path_len = -1 - contact.out_path = b"" - self.contacts.update(contact) - try: - pkt, _ = PacketBuilder.create_protocol_request( - contact=proxy, - local_identity=self._identity, - protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, - data=req_payload, - ) - success = await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) - if success: - self._pending_discovery_tags.add(tag_int) - return SentResult( - success=success, - is_flood=True, - expected_ack=tag_int, - timeout_ms=10000, - ) - except Exception as e: - logger.error(f"Error in path discovery: {e}") - return SentResult(success=False) - finally: - current = self.contacts.get_by_key(pub_key) - if current and current.out_path_len == -1: - current.out_path_len = old_path_len - current.out_path = old_path - self.contacts.update(current) - # ------------------------------------------------------------------------- # Key Management # ------------------------------------------------------------------------- @@ -493,52 +361,6 @@ async def send_telemetry_request( logger.error(f"Telemetry request error: {e}") return {"success": False, "reason": str(e)} - async def send_binary_req( - self, pub_key: bytes, data: bytes, timeout_seconds: float = 15.0 - ) -> SentResult: - """Send binary request (CMD_SEND_BINARY_REQ). data = request_type(1) + optional payload. - Returns SentResult with expected_ack (4-byte tag as int) and timeout_ms for RESP_CODE_SENT. - """ - contact = self.contacts.get_by_key(pub_key) - if not contact: - return SentResult(success=False) - proxy = self.contacts.get_by_name(contact.name) - if not proxy: - return SentResult(success=False) - tag_int = random.randint(0, 0xFFFFFFFF) - tag_bytes = tag_int.to_bytes(4, "little") - tag_hex = tag_bytes.hex() - request_type = data[0] if len(data) >= 1 else 0 - req_payload = tag_bytes + data - self.cleanup_expired_binary_requests() - self.register_binary_request( - tag_hex, - request_type=request_type, - timeout_seconds=timeout_seconds, - pubkey_prefix=pub_key[:6].hex(), - ) - try: - pkt, _ = PacketBuilder.create_protocol_request( - contact=proxy, - local_identity=self._identity, - protocol_code=0x02, - data=req_payload, - ) - success = await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) - except Exception as e: - logger.error(f"Binary request send error: {e}") - self._pending_binary_requests.pop(tag_hex, None) - return SentResult(success=False) - if not success: - self._pending_binary_requests.pop(tag_hex, None) - return SentResult(success=False) - return SentResult( - success=True, - is_flood=contact.out_path_len <= 0, - expected_ack=tag_int, - timeout_ms=10000, - ) - async def send_binary_request(self, pub_key: bytes, data: bytes) -> dict: """Legacy: send binary request and wait for response via waiter. Prefer send_binary_req + on_binary_response.""" contact = self.contacts.get_by_key(pub_key) @@ -547,7 +369,7 @@ async def send_binary_request(self, pub_key: bytes, data: bytes) -> dict: try: return await self.node.send_protocol_request( repeater_name=contact.name, - protocol_code=0x02, + protocol_code=PROTOCOL_CODE_BINARY_REQ, data=data, ) except Exception as e: @@ -561,7 +383,7 @@ async def send_anon_request(self, pub_key: bytes, data: bytes) -> dict: try: return await self.node.send_protocol_request( repeater_name=contact.name, - protocol_code=0x07, + protocol_code=PROTOCOL_CODE_ANON_REQ, data=data, ) except Exception as e: diff --git a/src/pymc_core/companion/constants.py b/src/pymc_core/companion/constants.py index 94a4d69..f62ff64 100644 --- a/src/pymc_core/companion/constants.py +++ b/src/pymc_core/companion/constants.py @@ -1,5 +1,9 @@ """Companion radio constants for application-layer mesh networking features.""" +from __future__ import annotations + +from enum import IntEnum + # --------------------------------------------------------------------------- # ADV Types (contact/node classification) # --------------------------------------------------------------------------- @@ -54,7 +58,7 @@ # --------------------------------------------------------------------------- # Binary request types (CMD_SEND_BINARY_REQ / PUSH_CODE_BINARY_RESPONSE) # --------------------------------------------------------------------------- -class BinaryReqType: +class BinaryReqType(IntEnum): """Binary request type codes (companion frame protocol).""" STATUS = 0x01 KEEP_ALIVE = 0x02 @@ -63,9 +67,17 @@ class BinaryReqType: ACL = 0x05 NEIGHBOURS = 0x06 +# --------------------------------------------------------------------------- +# Protocol Codes (used in create_protocol_request / send_protocol_request) +# --------------------------------------------------------------------------- +PROTOCOL_CODE_RAW_DATA = 0x00 +PROTOCOL_CODE_BINARY_REQ = 0x02 +PROTOCOL_CODE_ANON_REQ = 0x07 + # --------------------------------------------------------------------------- # Default configuration # --------------------------------------------------------------------------- +DEFAULT_RESPONSE_TIMEOUT_MS = 10000 DEFAULT_MAX_CONTACTS = 1000 DEFAULT_OFFLINE_QUEUE_SIZE = 512 DEFAULT_MAX_CHANNELS = 40 diff --git a/src/pymc_core/companion/contact_store.py b/src/pymc_core/companion/contact_store.py index 0b90341..adad857 100644 --- a/src/pymc_core/companion/contact_store.py +++ b/src/pymc_core/companion/contact_store.py @@ -1,5 +1,7 @@ """In-memory contact storage compatible with MeshNode's contacts interface.""" +from __future__ import annotations + import time from typing import Iterable, Iterator, Optional @@ -31,7 +33,7 @@ def __init__(self, contact: Contact): self.gps_lat = contact.gps_lat self.gps_lon = contact.gps_lon - def _sync_from_contact(self): + def _sync_from_contact(self) -> None: """Update proxy fields from the underlying Contact.""" c = self._contact self.public_key = c.public_key.hex() @@ -143,7 +145,7 @@ def is_full(self) -> bool: """Check if the contact store is at capacity.""" return len(self._contacts) >= self._max_contacts - def clear(self): + def clear(self) -> None: """Remove all contacts.""" self._contacts.clear() self._proxies.clear() @@ -152,7 +154,7 @@ def clear(self): # Bulk loading from external sources # ------------------------------------------------------------------ - def load_from(self, contacts: Iterable[Contact]): + def load_from(self, contacts: Iterable[Contact]) -> None: """Bulk-load contacts from any iterable of Contact objects. Replaces all existing contacts. @@ -164,7 +166,7 @@ def load_from(self, contacts: Iterable[Contact]): self._contacts[contact.public_key] = contact self._proxies[contact.public_key] = ContactProxy(contact) - def load_from_dicts(self, records: Iterable[dict]): + def load_from_dicts(self, records: Iterable[dict]) -> None: """Bulk-load contacts from dicts. Each dict must have 'public_key' (hex string or bytes) and 'name' keys. diff --git a/src/pymc_core/companion/message_queue.py b/src/pymc_core/companion/message_queue.py index 8803eb5..283c32c 100644 --- a/src/pymc_core/companion/message_queue.py +++ b/src/pymc_core/companion/message_queue.py @@ -1,5 +1,7 @@ """Fixed-size offline message queue for companion radio.""" +from __future__ import annotations + from collections import deque from typing import Optional @@ -59,6 +61,6 @@ def count(self) -> int: """Return the number of messages in the queue.""" return len(self._queue) - def clear(self): + def clear(self) -> None: """Remove all messages from the queue.""" self._queue.clear() diff --git a/src/pymc_core/companion/models.py b/src/pymc_core/companion/models.py index 3295b82..426809d 100644 --- a/src/pymc_core/companion/models.py +++ b/src/pymc_core/companion/models.py @@ -1,5 +1,7 @@ """Data models for companion radio state objects.""" +from __future__ import annotations + from dataclasses import dataclass, field from typing import Optional diff --git a/src/pymc_core/companion/path_cache.py b/src/pymc_core/companion/path_cache.py index 95e9881..dbaf6c9 100644 --- a/src/pymc_core/companion/path_cache.py +++ b/src/pymc_core/companion/path_cache.py @@ -1,5 +1,8 @@ """Path cache for tracking recently heard advertiser paths.""" +from __future__ import annotations + +from collections import deque from typing import Optional from .models import AdvertPath @@ -14,24 +17,25 @@ class PathCache: """ def __init__(self, max_entries: int = 16): - self._paths: list[AdvertPath] = [] + self._paths: deque[AdvertPath] = deque() self._max = max_entries - def update(self, advert_path: AdvertPath): + def update(self, advert_path: AdvertPath) -> None: """Add or update a path entry. If a path with the same public key prefix already exists, it is - replaced. If the cache is full, the oldest entry is evicted. + removed and the new entry is appended (LRU refresh). If the cache + is full, the oldest entry is evicted. """ - # Check for existing entry with same prefix - for i, existing in enumerate(self._paths): + # Remove existing entry with same prefix (LRU refresh to tail) + for existing in self._paths: if existing.public_key_prefix == advert_path.public_key_prefix: - self._paths[i] = advert_path - return + self._paths.remove(existing) + break - # Add new entry, evicting oldest if full + # Evict oldest if full if len(self._paths) >= self._max: - self._paths.pop(0) + self._paths.popleft() self._paths.append(advert_path) def get_by_prefix(self, prefix: bytes) -> Optional[AdvertPath]: @@ -50,6 +54,6 @@ def get_all(self) -> list[AdvertPath]: """Return all cached paths.""" return list(self._paths) - def clear(self): + def clear(self) -> None: """Remove all cached paths.""" self._paths.clear() diff --git a/src/pymc_core/companion/stats_collector.py b/src/pymc_core/companion/stats_collector.py index c7579e9..a4f6bc7 100644 --- a/src/pymc_core/companion/stats_collector.py +++ b/src/pymc_core/companion/stats_collector.py @@ -1,5 +1,7 @@ """Packet and radio statistics collector for companion radio.""" +from __future__ import annotations + import time from .models import PacketStats @@ -12,25 +14,25 @@ class StatsCollector: Matches the firmware's statistics reporting via CMD_GET_STATS. """ - def __init__(self): + def __init__(self) -> None: self.packets = PacketStats() self._start_time = time.time() - def record_tx(self, is_flood: bool): + def record_tx(self, is_flood: bool) -> None: """Record a successful transmission.""" if is_flood: self.packets.flood_tx += 1 else: self.packets.direct_tx += 1 - def record_rx(self, is_flood: bool): + def record_rx(self, is_flood: bool) -> None: """Record a successful reception.""" if is_flood: self.packets.flood_rx += 1 else: self.packets.direct_rx += 1 - def record_tx_error(self): + def record_tx_error(self) -> None: """Record a transmission error.""" self.packets.tx_errors += 1 @@ -51,7 +53,7 @@ def get_totals(self) -> dict: "uptime_secs": self.get_uptime_secs(), } - def reset(self): + def reset(self) -> None: """Reset all counters and restart uptime.""" self.packets = PacketStats() self._start_time = time.time() diff --git a/src/pymc_core/protocol/constants.py b/src/pymc_core/protocol/constants.py index adeff12..aab5ab1 100644 --- a/src/pymc_core/protocol/constants.py +++ b/src/pymc_core/protocol/constants.py @@ -66,6 +66,7 @@ ADVERT_FLAG_IS_CHAT_NODE = 0x01 ADVERT_FLAG_IS_REPEATER = 0x02 ADVERT_FLAG_IS_ROOM_SERVER = 0x03 +ADVERT_FLAG_IS_SENSOR = 0x04 ADVERT_FLAG_HAS_LOCATION = 0x10 ADVERT_FLAG_HAS_FEATURE1 = 0x20 ADVERT_FLAG_HAS_FEATURE2 = 0x40 diff --git a/tests/test_companion_base.py b/tests/test_companion_base.py index e04951b..bbd11ff 100644 --- a/tests/test_companion_base.py +++ b/tests/test_companion_base.py @@ -13,6 +13,7 @@ ADVERT_FLAG_IS_CHAT_NODE, ADVERT_FLAG_IS_REPEATER, ADVERT_FLAG_IS_ROOM_SERVER, + ADVERT_FLAG_IS_SENSOR, ) @@ -70,7 +71,7 @@ def test_room(self): assert adv_type_to_flags(ADV_TYPE_ROOM) == ADVERT_FLAG_IS_ROOM_SERVER def test_sensor(self): - assert adv_type_to_flags(ADV_TYPE_SENSOR) == 0x04 + assert adv_type_to_flags(ADV_TYPE_SENSOR) == ADVERT_FLAG_IS_SENSOR def test_unknown_defaults_to_chat(self): assert adv_type_to_flags(99) == ADVERT_FLAG_IS_CHAT_NODE From 5703d6bdec6117c24efca94538bd021bbd79d417 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sun, 15 Feb 2026 17:11:15 -0800 Subject: [PATCH 08/50] Implement flood scope functionality in Companion modules - Added methods to set flood regions and apply transport keys for scoped flooding in the CompanionBase class. - Enhanced packet handling to include flood scope transport codes, ensuring flood packets are tagged and routed correctly. - Updated CompanionRadio and Dispatcher classes to propagate flood scope settings and manage flood transport keys. - Documented flood scope usage in the companion guide for better user understanding. --- docs/docs/companion.md | 41 ++-- src/pymc_core/companion/companion_base.py | 42 ++++ src/pymc_core/companion/companion_bridge.py | 2 + src/pymc_core/companion/companion_radio.py | 14 ++ src/pymc_core/node/dispatcher.py | 26 ++ tests/test_companion_regions.py | 258 ++++++++++++++++++++ 6 files changed, 369 insertions(+), 14 deletions(-) create mode 100644 tests/test_companion_regions.py diff --git a/docs/docs/companion.md b/docs/docs/companion.md index e78b650..7aeb282 100644 --- a/docs/docs/companion.md +++ b/docs/docs/companion.md @@ -264,6 +264,30 @@ companion.set_other_params( prefs = companion.get_self_info() # -> NodePrefs ``` +### Flood Scope (Regions) + +Constrain flood packets to a specific region using transport key scoping. +Nodes outside the region will ignore scoped flood packets. + +```python +from pymc_core.protocol.transport_keys import get_auto_key_for + +# Set region by name (auto-derives transport key via SHA-256) +companion.set_flood_region("usa") # '#' prefix added automatically +companion.set_flood_region("#europe") # explicit '#' also works + +# Or set directly with a raw 16-byte transport key +key = get_auto_key_for("#usa") +companion.set_flood_scope(key) + +# Clear scope (flood to all nodes) +companion.set_flood_region(None) +``` + +When a flood scope is active, all flood packets are tagged with a 16-bit transport code +(HMAC-SHA256 derived) and sent as `ROUTE_TYPE_TRANSPORT_FLOOD`. Direct-routed packets +are unaffected. + ### Cryptographic Signing ```python @@ -605,24 +629,13 @@ DEFAULT_RESPONSE_TIMEOUT_MS = 10000 ## Unimplemented MeshCore Companion Features -The following features from the MeshCore companion radio firmware (`examples/companion_radio/`) are **not yet implemented** in pyMC_core: +The following protocol-level features from the MeshCore companion radio firmware (`examples/companion_radio/`) are **not yet implemented** in pyMC_core: -| Feature | Firmware Command | Description | +| Feature | Firmware Reference | Description | |---|---|---| -| Device query | `CMD_DEVICE_QUERY` (0x16) | Hardware capability & firmware version handshake | -| App start handshake | `CMD_APP_START` (0x01) | Initial BLE/serial session setup with self-info response | -| Device time get/set | `CMD_GET_DEVICE_TIME` / `CMD_SET_DEVICE_TIME` | RTC clock synchronisation | -| Reboot | `CMD_REBOOT` (0x13) | Remote device reboot (with confirmation string) | -| Factory reset | `CMD_FACTORY_RESET` (0x33) | Erase all data and reset to defaults | -| BLE PIN | `CMD_SET_DEVICE_PIN` (0x25) | Set BLE pairing PIN | -| Battery & storage | `CMD_GET_BATT_AND_STORAGE` (0x14) | Battery voltage and flash storage info | -| Logout | `CMD_LOGOUT` (0x1D) | Disconnect from a server/repeater session | +| Logout | `CMD_LOGOUT` (0x1D) | Disconnect from a repeater/server session | | Has connection | `CMD_HAS_CONNECTION` (0x1C) | Check if active connection exists to a contact | -| Contact-by-key lookup (protocol) | `CMD_GET_CONTACT_BY_KEY` (0x1E) | Protocol-level single-contact fetch (available in-memory via `get_contact_by_key`) | -| GPS configuration | GPS enable/interval | GPS hardware control and periodic fix interval | -| Data persistence | File I/O (`/contacts3`, `/channels2`, `/new_prefs`) | Automatic save/load of contacts, channels, and preferences to flash storage | | Push: contact deleted | `PUSH_CODE_CONTACT_DELETED` (0x8F) | Notification when a contact is overwritten by auto-add | | Push: contacts full | `PUSH_CODE_CONTACTS_FULL` (0x90) | Notification when contact storage is full | | Push: RX data log | `PUSH_CODE_LOG_RX_DATA` (0x88) | Raw received packet logging for diagnostics | | Keep-alive mechanism | Server-driven keep-alive | Periodic keep-alive packets for active server connections | -| Firmware version reporting | `FIRMWARE_VER_CODE` / `FIRMWARE_BUILD_DATE` | Version and build metadata in device info response | diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 35b33b7..2a81358 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -27,8 +27,11 @@ ADVERT_FLAG_IS_ROOM_SERVER, ADVERT_FLAG_IS_SENSOR, REQ_TYPE_GET_TELEMETRY_DATA, + ROUTE_TYPE_FLOOD, + ROUTE_TYPE_TRANSPORT_FLOOD, TELEM_PERM_BASE, ) +from ..protocol.transport_keys import calc_transport_code, get_auto_key_for from .channel_store import ChannelStore from .constants import ( ADV_TYPE_CHAT, @@ -422,6 +425,40 @@ def set_flood_scope(self, transport_key: Optional[bytes] = None) -> None: else: self._flood_transport_key = None + def set_flood_region(self, region_name: Optional[str] = None) -> None: + """Set flood scope from a region name (e.g., ``'#usa'``) or clear it. + + Derives the 16-byte transport key automatically via SHA-256 of the + region name. A leading ``#`` is added if not already present. + Pass ``None`` to clear the scope (flood to all). + """ + if region_name: + if not region_name.startswith("#"): + region_name = f"#{region_name}" + self._flood_transport_key = get_auto_key_for(region_name) + else: + self._flood_transport_key = None + + def _apply_flood_scope(self, pkt: Packet) -> None: + """Apply flood scope transport codes to a packet in-place. + + If ``_flood_transport_key`` is set and the packet uses flood routing, + calculates the transport code, attaches it to the packet, and changes + the route type to ``ROUTE_TYPE_TRANSPORT_FLOOD``. + + Matches firmware ``sendFloodScoped()`` in ``BaseChatMesh.cpp``. + """ + if self._flood_transport_key is None: + return + route_type = pkt.get_route_type() + if route_type != ROUTE_TYPE_FLOOD: + return # only scope flood packets, not direct + code = calc_transport_code(self._flood_transport_key, pkt) + pkt.transport_codes[0] = code + pkt.transport_codes[1] = 0 # reserved for home region (firmware TODO) + # Switch route type from FLOOD -> TRANSPORT_FLOOD + pkt.header = (pkt.header & ~0x03) | ROUTE_TYPE_TRANSPORT_FLOOD + # ------------------------------------------------------------------------- # Statistics (subclasses may override _get_radio_stats for STATS_TYPE_RADIO) # ------------------------------------------------------------------------- @@ -678,6 +715,7 @@ async def advertise(self, flood: bool = True) -> bool: flags=flags, route_type=route, ) + self._apply_flood_scope(pkt) success = await self._send_packet(pkt, wait_for_ack=False) if success: self.stats.record_tx(is_flood=flood) @@ -697,6 +735,7 @@ async def share_contact(self, pub_key: bytes) -> bool: flags=adv_type_to_flags(contact.adv_type) | ADVERT_FLAG_HAS_NAME, route_type="direct", ) + self._apply_flood_scope(pkt) return await self._send_packet(pkt, wait_for_ack=False) except Exception as e: logger.error(f"Error sharing contact: {e}") @@ -713,6 +752,7 @@ async def send_trace_path_raw( try: path_list = list(path_bytes) pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=path_list) + self._apply_flood_scope(pkt) return await self._send_packet(pkt, wait_for_ack=False) except Exception as e: logger.error(f"Error sending trace (raw path): {e}") @@ -751,6 +791,7 @@ async def send_binary_req( protocol_code=PROTOCOL_CODE_BINARY_REQ, data=req_payload, ) + self._apply_flood_scope(pkt) success = await self._send_packet(pkt, wait_for_ack=False) except Exception as e: logger.error(f"Binary request send error: {e}") @@ -801,6 +842,7 @@ async def send_path_discovery_req(self, pub_key: bytes) -> SentResult: protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, data=req_payload, ) + self._apply_flood_scope(pkt) success = await self._send_packet(pkt, wait_for_ack=False) if success: self._pending_discovery_tags.add(tag_int) diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index 78ac5d3..1c7f96a 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -321,6 +321,7 @@ async def send_text_message( attempt=attempt, message_type=msg_type, ) + self._apply_flood_scope(pkt) if len(self._pending_ack_crcs) < MAX_PENDING_ACK_CRCS: self._pending_ack_crcs.add(ack_crc) success = await self._packet_injector(pkt, wait_for_ack=True) @@ -352,6 +353,7 @@ async def send_channel_message(self, channel_idx: int, text: str) -> bool: sender_name=self.prefs.node_name, channels_config=self.channels.get_channels(), ) + self._apply_flood_scope(pkt) success = await self._packet_injector(pkt, wait_for_ack=False) if success: self.stats.record_tx(is_flood=True) diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py index acaab22..49dfbe4 100644 --- a/src/pymc_core/companion/companion_radio.py +++ b/src/pymc_core/companion/companion_radio.py @@ -222,6 +222,20 @@ async def send_raw_data( logger.error(f"Error sending raw data: {e}") return SentResult(success=False) + # ------------------------------------------------------------------------- + # Flood Scope (sync to dispatcher) + # ------------------------------------------------------------------------- + + def set_flood_scope(self, transport_key: Optional[bytes] = None) -> None: + """Set or clear flood scope and propagate to the dispatcher.""" + super().set_flood_scope(transport_key) + self.node.dispatcher.flood_transport_key = self._flood_transport_key + + def set_flood_region(self, region_name: Optional[str] = None) -> None: + """Set flood region and propagate to the dispatcher.""" + super().set_flood_region(region_name) + self.node.dispatcher.flood_transport_key = self._flood_transport_key + # ------------------------------------------------------------------------- # Device Configuration (overrides for radio hardware) # ------------------------------------------------------------------------- diff --git a/src/pymc_core/node/dispatcher.py b/src/pymc_core/node/dispatcher.py index 847852e..f8fa8ae 100644 --- a/src/pymc_core/node/dispatcher.py +++ b/src/pymc_core/node/dispatcher.py @@ -10,7 +10,10 @@ PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, PH_TYPE_SHIFT, + ROUTE_TYPE_FLOOD, + ROUTE_TYPE_TRANSPORT_FLOOD, ) +from ..protocol.transport_keys import calc_transport_code from ..protocol.utils import PAYLOAD_TYPES, ROUTE_TYPES, format_packet_info # Import handler classes @@ -85,6 +88,11 @@ def __init__( # Contact book for decrypting messages (set by the node later) self.contact_book = None + # Flood scope: 16-byte transport key for region-scoped flooding. + # When set, flood packets are tagged with a transport code and sent + # as ROUTE_TYPE_TRANSPORT_FLOOD. Set via companion set_flood_scope(). + self.flood_transport_key: Optional[bytes] = None + self._logger = logging.getLogger("Dispatcher") self._current_expected_crc: Optional[int] = None self._recent_acks: dict[int, float] = {} # {crc: timestamp} @@ -452,6 +460,23 @@ async def _process_received_packet( # Public interface - sending and receiving packets # ------------------------------------------------------------------ + def _apply_flood_scope(self, pkt: Packet) -> None: + """Apply flood scope transport codes to a packet in-place. + + If ``flood_transport_key`` is set and the packet uses flood routing, + calculates the transport code, attaches it to the packet, and + switches the route type to ``ROUTE_TYPE_TRANSPORT_FLOOD``. + """ + if self.flood_transport_key is None: + return + route_type = pkt.get_route_type() + if route_type != ROUTE_TYPE_FLOOD: + return + code = calc_transport_code(self.flood_transport_key, pkt) + pkt.transport_codes[0] = code + pkt.transport_codes[1] = 0 # reserved for home region + pkt.header = (pkt.header & ~0x03) | ROUTE_TYPE_TRANSPORT_FLOOD + async def send_packet( self, packet: Packet, @@ -468,6 +493,7 @@ async def send_packet( expected_crc: The expected CRC for ACK matching. If None, will be calculated from packet. """ + self._apply_flood_scope(packet) async with self._tx_lock: # Wait our turn return await self._send_packet_immediate(packet, wait_for_ack, expected_crc) diff --git a/tests/test_companion_regions.py b/tests/test_companion_regions.py new file mode 100644 index 0000000..d8a3994 --- /dev/null +++ b/tests/test_companion_regions.py @@ -0,0 +1,258 @@ +"""Tests for companion flood-scope / region support.""" + +from __future__ import annotations + +import pytest + +from pymc_core.companion import CompanionRadio +from pymc_core.companion.constants import ADV_TYPE_CHAT +from pymc_core.companion.models import Contact +from pymc_core.protocol import LocalIdentity, Packet, PacketBuilder +from pymc_core.protocol.constants import ( + ROUTE_TYPE_DIRECT, + ROUTE_TYPE_FLOOD, + ROUTE_TYPE_TRANSPORT_FLOOD, +) +from pymc_core.protocol.transport_keys import calc_transport_code, get_auto_key_for + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_flood_packet() -> Packet: + """Create a minimal flood-routed advert packet for testing.""" + identity = LocalIdentity() + return PacketBuilder.create_advert( + local_identity=identity, + name="test", + route_type="flood", + ) + + +def _make_direct_packet() -> Packet: + """Create a minimal direct-routed advert packet for testing.""" + identity = LocalIdentity() + return PacketBuilder.create_advert( + local_identity=identity, + name="test", + route_type="direct", + ) + + +class MockRadio: + """Minimal mock radio for CompanionRadio.""" + + def __init__(self): + self.rx_callback = None + self.sent: list[bytes] = [] + + def set_rx_callback(self, callback): + self.rx_callback = callback + + async def send(self, data: bytes) -> bool: + self.sent.append(data) + return True + + +def _make_companion() -> CompanionRadio: + """Create a CompanionRadio with a mock radio for testing.""" + radio = MockRadio() + identity = LocalIdentity() + return CompanionRadio(radio=radio, identity=identity, node_name="test") + + +def _make_peer_contact(name: str) -> Contact: + """Return a contact with a valid Ed25519 public key.""" + peer = LocalIdentity() + return Contact(public_key=peer.get_public_key(), name=name) + + +# --------------------------------------------------------------------------- +# _apply_flood_scope unit tests +# --------------------------------------------------------------------------- + + +class TestApplyFloodScope: + def test_sets_transport_codes_on_flood_packet(self): + companion = _make_companion() + key = get_auto_key_for("#usa") + companion.set_flood_scope(key) + pkt = _make_flood_packet() + + companion._apply_flood_scope(pkt) + + assert pkt.get_route_type() == ROUTE_TYPE_TRANSPORT_FLOOD + assert pkt.transport_codes[0] != 0 + assert pkt.transport_codes[1] == 0 + + def test_transport_code_matches_calc(self): + companion = _make_companion() + key = get_auto_key_for("#test-region") + companion.set_flood_scope(key) + pkt = _make_flood_packet() + + expected_code = calc_transport_code(key, pkt) + companion._apply_flood_scope(pkt) + + assert pkt.transport_codes[0] == expected_code + + def test_noop_when_no_key_set(self): + companion = _make_companion() + pkt = _make_flood_packet() + original_header = pkt.header + + companion._apply_flood_scope(pkt) + + assert pkt.header == original_header + assert pkt.get_route_type() == ROUTE_TYPE_FLOOD + assert pkt.transport_codes == [0, 0] + + def test_noop_on_direct_packet(self): + companion = _make_companion() + key = get_auto_key_for("#usa") + companion.set_flood_scope(key) + pkt = _make_direct_packet() + original_header = pkt.header + + companion._apply_flood_scope(pkt) + + assert pkt.header == original_header + assert pkt.get_route_type() == ROUTE_TYPE_DIRECT + assert pkt.transport_codes == [0, 0] + + +# --------------------------------------------------------------------------- +# set_flood_region tests +# --------------------------------------------------------------------------- + + +class TestSetFloodRegion: + def test_derives_key_with_hash_prefix(self): + companion = _make_companion() + companion.set_flood_region("#usa") + assert companion._flood_transport_key == get_auto_key_for("#usa") + + def test_auto_adds_hash_prefix(self): + companion = _make_companion() + companion.set_flood_region("usa") + assert companion._flood_transport_key == get_auto_key_for("#usa") + + def test_clear_with_none(self): + companion = _make_companion() + companion.set_flood_region("usa") + assert companion._flood_transport_key is not None + companion.set_flood_region(None) + assert companion._flood_transport_key is None + + def test_same_key_with_or_without_prefix(self): + c1 = _make_companion() + c2 = _make_companion() + c1.set_flood_region("europe") + c2.set_flood_region("#europe") + assert c1._flood_transport_key == c2._flood_transport_key + + +# --------------------------------------------------------------------------- +# set_flood_scope tests +# --------------------------------------------------------------------------- + + +class TestSetFloodScope: + def test_stores_16_byte_key(self): + companion = _make_companion() + key = b"\x01" * 16 + companion.set_flood_scope(key) + assert companion._flood_transport_key == key + + def test_truncates_longer_key(self): + companion = _make_companion() + key = b"\x02" * 32 + companion.set_flood_scope(key) + assert companion._flood_transport_key == b"\x02" * 16 + + def test_clear_with_none(self): + companion = _make_companion() + companion.set_flood_scope(b"\x01" * 16) + companion.set_flood_scope(None) + assert companion._flood_transport_key is None + + +# --------------------------------------------------------------------------- +# CompanionRadio dispatcher sync +# --------------------------------------------------------------------------- + + +class TestRadioDispatcherSync: + def test_set_flood_scope_syncs_to_dispatcher(self): + companion = _make_companion() + key = get_auto_key_for("#test") + companion.set_flood_scope(key) + assert companion.node.dispatcher.flood_transport_key == key + + def test_set_flood_region_syncs_to_dispatcher(self): + companion = _make_companion() + companion.set_flood_region("test") + expected = get_auto_key_for("#test") + assert companion.node.dispatcher.flood_transport_key == expected + + def test_clear_syncs_to_dispatcher(self): + companion = _make_companion() + companion.set_flood_scope(b"\x01" * 16) + assert companion.node.dispatcher.flood_transport_key is not None + companion.set_flood_scope(None) + assert companion.node.dispatcher.flood_transport_key is None + + +# --------------------------------------------------------------------------- +# Integration: advertise with flood scope +# --------------------------------------------------------------------------- + + +class TestAdvertiseWithFloodScope: + @pytest.mark.asyncio + async def test_advertise_flood_with_scope_sends_transport_flood(self): + radio = MockRadio() + identity = LocalIdentity() + companion = CompanionRadio( + radio=radio, identity=identity, node_name="scoped" + ) + companion.set_flood_region("usa") + + await companion.start() + try: + await companion.advertise(flood=True) + finally: + await companion.stop() + + # Verify the sent packet has transport codes + assert len(radio.sent) > 0 + raw = radio.sent[-1] + pkt = Packet() + pkt.read_from(raw) + assert pkt.get_route_type() == ROUTE_TYPE_TRANSPORT_FLOOD + assert pkt.transport_codes[0] != 0 + assert pkt.transport_codes[1] == 0 + + @pytest.mark.asyncio + async def test_advertise_flood_without_scope_sends_normal_flood(self): + radio = MockRadio() + identity = LocalIdentity() + companion = CompanionRadio( + radio=radio, identity=identity, node_name="noscope" + ) + # No flood scope set + + await companion.start() + try: + await companion.advertise(flood=True) + finally: + await companion.stop() + + assert len(radio.sent) > 0 + raw = radio.sent[-1] + pkt = Packet() + pkt.read_from(raw) + assert pkt.get_route_type() == ROUTE_TYPE_FLOOD + assert pkt.transport_codes == [0, 0] From 151da9cdfb3cde1eb186e4ec84ab9db2425623b5 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sun, 15 Feb 2026 20:39:38 -0800 Subject: [PATCH 09/50] Refactor code for improved readability and consistency with pyMC_core contribution guide - Simplified function calls in discover_nodes.py for better clarity. - Reorganized import statements in test_modem_crypto.py and companion modules to follow PEP 8 guidelines. - Enhanced code formatting in various files for improved readability, including consistent spacing and line breaks. - Updated comments and docstrings for clarity and conciseness across multiple modules. --- examples/discover_nodes.py | 4 +- scripts/test_modem_crypto.py | 12 +- src/pymc_core/companion/__init__.py | 24 +-- src/pymc_core/companion/binary_parsing.py | 11 +- src/pymc_core/companion/companion_base.py | 40 ++-- src/pymc_core/companion/companion_bridge.py | 64 +++---- src/pymc_core/companion/companion_radio.py | 34 ++-- src/pymc_core/companion/constants.py | 3 + src/pymc_core/companion/contact_store.py | 5 +- src/pymc_core/companion/models.py | 2 +- src/pymc_core/hardware/kiss_modem_wrapper.py | 23 +-- src/pymc_core/node/dispatcher.py | 16 +- src/pymc_core/node/handlers/advert.py | 4 +- src/pymc_core/node/handlers/control.py | 28 +-- src/pymc_core/node/handlers/group_text.py | 10 +- src/pymc_core/node/handlers/login_response.py | 12 +- src/pymc_core/node/handlers/login_server.py | 16 +- src/pymc_core/node/handlers/path.py | 8 +- .../node/handlers/protocol_request.py | 141 +++++++------- .../node/handlers/protocol_response.py | 172 ++++++++++-------- src/pymc_core/node/handlers/text.py | 4 +- src/pymc_core/protocol/constants.py | 2 +- src/pymc_core/protocol/identity.py | 8 +- src/pymc_core/protocol/modem_identity.py | 3 +- src/pymc_core/protocol/packet_builder.py | 8 +- src/pymc_core/protocol/transport_keys.py | 29 +-- tests/test_companion_base.py | 1 - tests/test_companion_bridge.py | 12 +- tests/test_companion_regions.py | 10 +- tests/test_companion_stores.py | 29 +-- tests/test_kiss_modem_wrapper.py | 42 ++--- tests/test_modem_identity.py | 1 - tests/test_packet_utils.py | 8 +- 33 files changed, 379 insertions(+), 407 deletions(-) diff --git a/examples/discover_nodes.py b/examples/discover_nodes.py index 5e69845..196b583 100644 --- a/examples/discover_nodes.py +++ b/examples/discover_nodes.py @@ -177,9 +177,7 @@ def main(): if args.radio_type == "kiss-tnc": print(f"Serial port: {args.serial_port}") - asyncio.run( - discover_nodes(args.radio_type, args.serial_port, args.timeout, args.filter) - ) + asyncio.run(discover_nodes(args.radio_type, args.serial_port, args.timeout, args.filter)) if __name__ == "__main__": diff --git a/scripts/test_modem_crypto.py b/scripts/test_modem_crypto.py index cb4bbf4..d8f18f8 100644 --- a/scripts/test_modem_crypto.py +++ b/scripts/test_modem_crypto.py @@ -9,14 +9,14 @@ - Encryption/decryption """ -import sys import os +import sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) -from pymc_core.hardware.kiss_modem_wrapper import KissModemWrapper -from pymc_core.protocol.identity import Identity, LocalIdentity -from pymc_core.protocol.crypto import CryptoUtils +from pymc_core.hardware.kiss_modem_wrapper import KissModemWrapper # noqa: E402 +from pymc_core.protocol.crypto import CryptoUtils # noqa: E402 +from pymc_core.protocol.identity import Identity, LocalIdentity # noqa: E402 def test_modem_crypto(port: str = "/dev/cu.usbmodem1101"): @@ -149,7 +149,7 @@ def test_modem_crypto(port: str = "/dev/cu.usbmodem1101"): modem_decrypted = modem.decrypt_data(key, mac, ciphertext) if modem_decrypted: # Trim padding (modem pads to block size) - modem_decrypted = modem_decrypted[:len(plaintext)] + modem_decrypted = modem_decrypted[: len(plaintext)] print(f"Modem decrypted: {modem_decrypted}") if modem_decrypted == plaintext: @@ -178,7 +178,7 @@ def test_modem_crypto(port: str = "/dev/cu.usbmodem1101"): # Decrypt with modem modem_decrypted2 = modem.decrypt_data(key, python_mac, python_ciphertext) if modem_decrypted2: - modem_decrypted2 = modem_decrypted2[:len(plaintext)] + modem_decrypted2 = modem_decrypted2[: len(plaintext)] print(f"Modem decrypted Python ciphertext: {modem_decrypted2}") if modem_decrypted2 == plaintext: diff --git a/src/pymc_core/companion/__init__.py b/src/pymc_core/companion/__init__.py index dbf2b10..582c034 100644 --- a/src/pymc_core/companion/__init__.py +++ b/src/pymc_core/companion/__init__.py @@ -6,13 +6,9 @@ statistics, and device configuration on top of MeshNode. """ -from .companion_radio import CompanionRadio -from .companion_bridge import CompanionBridge from .channel_store import ChannelStore -from .contact_store import ContactStore -from .message_queue import MessageQueue -from .path_cache import PathCache -from .stats_collector import StatsCollector +from .companion_bridge import CompanionBridge +from .companion_radio import CompanionRadio from .constants import ( ADV_TYPE_CHAT, ADV_TYPE_REPEATER, @@ -25,7 +21,6 @@ AUTOADD_REPEATER, AUTOADD_ROOM, AUTOADD_SENSOR, - BinaryReqType, DEFAULT_MAX_CHANNELS, DEFAULT_MAX_CONTACTS, DEFAULT_OFFLINE_QUEUE_SIZE, @@ -41,16 +36,13 @@ TXT_TYPE_CLI_DATA, TXT_TYPE_PLAIN, TXT_TYPE_SIGNED_PLAIN, + BinaryReqType, ) -from .models import ( - AdvertPath, - Channel, - Contact, - NodePrefs, - PacketStats, - QueuedMessage, - SentResult, -) +from .contact_store import ContactStore +from .message_queue import MessageQueue +from .models import AdvertPath, Channel, Contact, NodePrefs, PacketStats, QueuedMessage, SentResult +from .path_cache import PathCache +from .stats_collector import StatsCollector __all__ = [ # Main classes diff --git a/src/pymc_core/companion/binary_parsing.py b/src/pymc_core/companion/binary_parsing.py index c3bc582..d2fbcd9 100644 --- a/src/pymc_core/companion/binary_parsing.py +++ b/src/pymc_core/companion/binary_parsing.py @@ -2,8 +2,7 @@ from __future__ import annotations -import struct -from typing import Any, Optional +from typing import Optional from .constants import BinaryReqType @@ -63,12 +62,15 @@ def _parse_status(data: bytes, pubkey_prefix: Optional[str] = None, offset: int def _parse_telemetry(data: bytes) -> dict: - """Telemetry: Cayenne LPP or raw. Return dict with raw_hex; optional LPP if cayennelpp available.""" + """Telemetry: Cayenne LPP or raw. Dict has raw_hex; optional LPP if cayennelpp available.""" out: dict = {"raw_hex": data.hex()} try: from cayennelpp import LppFrame + frame = LppFrame.from_bytes(data) - out["lpp"] = [{"channel": d.channel, "type": d.type_id, "value": d.data} for d in frame.data] + out["lpp"] = [ + {"channel": d.channel, "type": d.type_id, "value": d.data} for d in frame.data + ] except Exception: pass return out @@ -79,6 +81,7 @@ def _parse_mma(data: bytes) -> dict: out: dict = {"raw_hex": data.hex()} try: from cayennelpp import LppFrame + frame = LppFrame.from_bytes(data) out["mma"] = [{"channel": d.channel, "type": d.type_id, "data": d.data} for d in frame.data] except Exception: diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 2a81358..1fc740c 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -169,21 +169,19 @@ def _init_companion_stores( self._event_subscriber = _CompanionEventSubscriber(self) self._event_service.subscribe_all(self._event_subscriber) - self._push_callbacks: dict[str, list[Callable]] = { - k: [] for k in PUSH_CALLBACK_KEYS - } + self._push_callbacks: dict[str, list[Callable]] = {k: [] for k in PUSH_CALLBACK_KEYS} # Pending binary requests by tag (hex) for matching responses self._pending_binary_requests: dict[str, dict] = {} # Pending path discovery tags for matching responses self._pending_discovery_tags: set[int] = set() - # GRP_TXT dedup by packet hash: match Mesh.cpp behavior (only process when !_tables->hasSeen(pkt)), - # so companion queues one frame per logical message like the firmware. + # GRP_TXT dedup by packet hash: match Mesh.cpp (!_tables->hasSeen(pkt)); + # companion queues one frame per logical message like the firmware. self._seen_grp_txt: OrderedDict[str, float] = OrderedDict() self._seen_grp_txt_ttl = 300 self._seen_grp_txt_max = 1000 - # TXT_MSG (direct) dedup by packet hash so reconnects don't queue the same packet multiple times. + # TXT_MSG (direct) dedup by packet hash so reconnects don't re-queue same packet. self._seen_txt: OrderedDict[str, float] = OrderedDict() self._seen_txt_ttl = 300 self._seen_txt_max = 1000 @@ -549,11 +547,11 @@ def on_raw_data_received(self, callback: Callable) -> None: self._push_callbacks["raw_data_received"].append(callback) def on_binary_response(self, callback: Callable) -> None: - """Register callback for PUSH_CODE_BINARY_RESPONSE (0x8C). Callback(tag_bytes, response_data, ...).""" + """Register callback for PUSH 0x8C. Callback(tag_bytes, response_data).""" self._push_callbacks["binary_response"].append(callback) def on_path_discovery_response(self, callback: Callable) -> None: - """Register callback for path discovery response (PUSH 0x8D). Callback(tag_bytes, contact_pubkey, out_path, in_path).""" + """Register callback for path discovery 0x8D. (tag_bytes, pubkey, out_path, in_path).""" self._push_callbacks["path_discovery_response"].append(callback) def register_binary_request( @@ -564,7 +562,7 @@ def register_binary_request( pubkey_prefix: str = "", context: Optional[dict] = None, ) -> None: - """Register a pending binary request for matching responses. Call cleanup_expired_requests first.""" + """Register a pending binary request. Call cleanup_expired_requests first.""" self._pending_binary_requests[tag_hex] = { "request_type": request_type, "pubkey_prefix": pubkey_prefix, @@ -576,8 +574,7 @@ def cleanup_expired_binary_requests(self) -> None: """Remove expired entries from _pending_binary_requests.""" now = time.time() expired = [ - tag for tag, info in self._pending_binary_requests.items() - if now > info["expires_at"] + tag for tag, info in self._pending_binary_requests.items() if now > info["expires_at"] ] for tag in expired: del self._pending_binary_requests[tag] @@ -588,7 +585,7 @@ async def _on_binary_response( response_data: bytes, path_info: Optional[tuple] = None, ) -> None: - """Called by ProtocolResponseHandler when a binary response (tag + data, optional path) is received.""" + """Called when binary response (tag + data, optional path) received.""" if path_info is not None: if await self._try_handle_path_discovery(tag_bytes, path_info): return @@ -596,7 +593,7 @@ async def _on_binary_response( tag_hex = tag_bytes.hex() info = self._pending_binary_requests.pop(tag_hex, None) if not info: - # Skip log for small payloads (e.g. login response already handled by LoginResponseHandler) + # Skip log for small payloads (e.g. login response handled elsewhere) if len(response_data) >= 20: logger.debug(f"Binary response for unknown tag {tag_hex}") await self._fire_callbacks("binary_response", tag_bytes, response_data) @@ -607,6 +604,7 @@ async def _on_binary_response( parsed = None try: from . import binary_parsing + parsed = binary_parsing.parse_binary_response( request_type, response_data, pubkey_prefix=pubkey_prefix, context=context ) @@ -616,9 +614,7 @@ async def _on_binary_response( "binary_response", tag_bytes, response_data, parsed, request_type ) - async def _try_handle_path_discovery( - self, tag_bytes: bytes, path_info: tuple - ) -> bool: + async def _try_handle_path_discovery(self, tag_bytes: bytes, path_info: tuple) -> bool: """If tag is pending path discovery, fire path_discovery_response and return True.""" out_path, in_path, contact_pubkey = path_info tag_int = int.from_bytes(tag_bytes, "little") @@ -639,9 +635,7 @@ async def _try_handle_path_discovery( # ------------------------------------------------------------------------- @abstractmethod - async def _send_packet( - self, pkt: Packet, wait_for_ack: bool = False - ) -> bool: + async def _send_packet(self, pkt: Packet, wait_for_ack: bool = False) -> bool: """Send a packet via the subclass transport (radio or packet_injector).""" @abstractmethod @@ -827,9 +821,7 @@ async def send_path_discovery_req(self, pub_key: bytes) -> SentResult: tag_int = random.randint(0, 0xFFFFFFFF) tag_bytes = tag_int.to_bytes(4, "little") inv_perm = 0xFF & ~TELEM_PERM_BASE - req_payload = tag_bytes + bytes( - [REQ_TYPE_GET_TELEMETRY_DATA, inv_perm, 0, 0, 0] - ) + req_payload = tag_bytes + bytes([REQ_TYPE_GET_TELEMETRY_DATA, inv_perm, 0, 0, 0]) old_path_len = contact.out_path_len old_path = contact.out_path contact.out_path_len = -1 @@ -870,9 +862,7 @@ def sync_next_message(self) -> Optional[QueuedMessage]: # Dedup Helper # ------------------------------------------------------------------------- - def _check_dedup( - self, cache: OrderedDict, key: str, ttl: float, max_size: int - ) -> bool: + def _check_dedup(self, cache: OrderedDict, key: str, ttl: float, max_size: int) -> bool: """Return True if *key* is a duplicate. Evicts expired entries.""" now = time.time() if key in cache: diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index 1c7f96a..042c709 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -22,8 +22,7 @@ TextMessageHandler, ) from ..node.handlers.login_server import LoginServerHandler -from ..protocol import LocalIdentity, PacketBuilder -from ..protocol import Packet +from ..protocol import LocalIdentity, Packet, PacketBuilder from ..protocol.constants import ( PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, @@ -62,7 +61,7 @@ class _BridgeAckHandler: - """Handles discrete ACK packets and PathHandler stub. Fires send_confirmed when ACK CRC matches a pending send.""" + """Handles ACK packets. Fires send_confirmed when ACK CRC matches.""" def __init__(self, bridge: "CompanionBridge") -> None: self._bridge = bridge @@ -129,6 +128,7 @@ def _login_send_callback(pkt: Packet, delay_ms: int) -> None: async def _delayed_send() -> None: await asyncio.sleep(delay_ms / 1000.0) await self._packet_injector(pkt, wait_for_ack=False) + asyncio.create_task(_delayed_send()) def _log(msg: str) -> None: @@ -136,23 +136,19 @@ def _log(msg: str) -> None: self._pending_ack_crcs: set[int] = set() ack_handler = _BridgeAckHandler(self) - protocol_response_handler = ProtocolResponseHandler( - _log, identity, self.contacts - ) - login_response_handler = LoginResponseHandler( - identity, self.contacts, _log - ) - login_response_handler.set_protocol_response_handler( - protocol_response_handler - ) + protocol_response_handler = ProtocolResponseHandler(_log, identity, self.contacts) + login_response_handler = LoginResponseHandler(identity, self.contacts, _log) + login_response_handler.set_protocol_response_handler(protocol_response_handler) path_handler = PathHandler( _log, ack_handler, protocol_response_handler, login_response_handler ) auth_cb = authenticate_callback if auth_cb is None: + def _reject_all(*args, **kwargs) -> tuple[bool, int]: return (False, 0) + auth_cb = _reject_all login_server_handler = LoginServerHandler( @@ -170,9 +166,7 @@ def _reject_all(*args, **kwargs) -> tuple[bool, int]: self._event_service, self._radio_config, ), - PAYLOAD_TYPE_ADVERT: AdvertHandler( - _log, event_service=self._event_service - ), + PAYLOAD_TYPE_ADVERT: AdvertHandler(_log, event_service=self._event_service), PAYLOAD_TYPE_PATH: path_handler, PAYLOAD_TYPE_ANON_REQ: login_server_handler, PAYLOAD_TYPE_GRP_TXT: GroupTextHandler( @@ -225,8 +219,8 @@ def _update_stores_from_advert(self, packet: Packet, advert_data: dict): name = advert_data.get("name", "") if not name: return None - # Inbound path: route the advert took to reach us (for discovery list / advert path display). - # Stored in path_cache only; contact.out_path is separate and set elsewhere (e.g. path discovery). + # Inbound path: route the advert took (discovery list / advert path display). + # Stored in path_cache only; contact.out_path is separate (e.g. path discovery). path_len = getattr(packet, "path_len", 0) or 0 path = getattr(packet, "path", bytearray()) or bytearray() effective_len = path_len if path_len > 0 else len(path) @@ -235,7 +229,7 @@ def _update_stores_from_advert(self, packet: Packet, advert_data: dict): last_advert_ts = advert_data.get("advert_timestamp", 0) if last_advert_ts > now: last_advert_ts = now - # Contact: out_path is for sending to this contact; leave unknown (-1) until set by path update. + # Contact: out_path is for sending; leave unknown (-1) until set by path update. contact = Contact( public_key=pub_key, name=name, @@ -249,7 +243,7 @@ def _update_stores_from_advert(self, packet: Packet, advert_data: dict): ) self.contacts.add(contact) - # Path cache: store inbound path (path advert took to get here) for discovery list display. + # Path cache: store inbound path for discovery list display. self.path_cache.update( AdvertPath( public_key_prefix=pub_key[:7], @@ -268,9 +262,7 @@ def _update_stores_from_advert(self, packet: Packet, advert_data: dict): # Abstract method implementations # ------------------------------------------------------------------------- - async def _send_packet( - self, pkt: Packet, wait_for_ack: bool = False - ) -> bool: + async def _send_packet(self, pkt: Packet, wait_for_ack: bool = False) -> bool: """Send a packet via the packet_injector.""" return await self._packet_injector(pkt, wait_for_ack=wait_for_ack) @@ -415,8 +407,9 @@ async def send_trace_path( return False async def send_control_data(self, data: bytes) -> bool: - """Send a CONTROL packet (e.g. discovery request). data = first byte flags/type (0x80 set for DISCOVER_REQ) + payload. - Firmware: (cmd_frame[1] & 0x80) != 0, createControlData(&cmd_frame[1], len-1), sendZeroHop(resp). Returns True if sent.""" + """Send CONTROL packet (e.g. discovery). data = flags (0x80 for DISCOVER_REQ) + payload. + Returns True if sent. + """ if not data or len(data) > 254: return False if (data[0] & 0x80) == 0: @@ -440,9 +433,7 @@ async def send_control_data(self, data: bytes) -> bool: def import_private_key(self, key: bytes) -> bool: try: self._identity = LocalIdentity(seed=key) - logger.info( - f"Imported new identity: {self._identity.get_public_key().hex()[:16]}..." - ) + logger.info(f"Imported new identity: {self._identity.get_public_key().hex()[:16]}...") return True except Exception as e: logger.error(f"Error importing private key: {e}") @@ -513,9 +504,7 @@ async def send_status_request(self, pub_key: bytes, timeout: float = 15.0) -> di return {"success": False, "reason": "Contact not found"} contact_hash = bytes.fromhex(proxy.public_key)[0] waiter = ResponseWaiter() - self._protocol_response_handler.set_response_callback( - contact_hash, waiter.callback - ) + self._protocol_response_handler.set_response_callback(contact_hash, waiter.callback) try: pkt, _ = PacketBuilder.create_protocol_request( contact=proxy, @@ -554,9 +543,7 @@ async def send_telemetry_request( return {"success": False, "reason": "Contact not found"} contact_hash = bytes.fromhex(proxy.public_key)[0] waiter = ResponseWaiter() - self._protocol_response_handler.set_response_callback( - contact_hash, waiter.callback - ) + self._protocol_response_handler.set_response_callback(contact_hash, waiter.callback) try: inv = PacketBuilder._compute_inverse_perm_mask( want_base, want_location, want_environment @@ -583,15 +570,13 @@ async def send_telemetry_request( self._protocol_response_handler.clear_response_callback(contact_hash) async def send_binary_request(self, pub_key: bytes, data: bytes) -> dict: - """Legacy: send binary request and wait for response via waiter. Prefer send_binary_req + on_binary_response.""" + """Legacy: send binary request and wait. Prefer send_binary_req + on_binary_response.""" return await self._send_protocol_request(pub_key, PROTOCOL_CODE_BINARY_REQ, data) async def send_anon_request(self, pub_key: bytes, data: bytes) -> dict: return await self._send_protocol_request(pub_key, PROTOCOL_CODE_ANON_REQ, data) - async def _send_protocol_request( - self, pub_key: bytes, protocol_code: int, data: bytes - ) -> dict: + async def _send_protocol_request(self, pub_key: bytes, protocol_code: int, data: bytes) -> dict: contact = self.contacts.get_by_key(pub_key) if not contact: return {"success": False, "reason": "Contact not found"} @@ -600,9 +585,7 @@ async def _send_protocol_request( return {"success": False, "reason": "Contact not found"} contact_hash = bytes.fromhex(proxy.public_key)[0] waiter = ResponseWaiter() - self._protocol_response_handler.set_response_callback( - contact_hash, waiter.callback - ) + self._protocol_response_handler.set_response_callback(contact_hash, waiter.callback) try: pkt, _ = PacketBuilder.create_protocol_request( contact=proxy, @@ -671,4 +654,3 @@ def _response_cb(message_text: str, sender_contact: Any) -> None: return {"success": False, "reason": str(e)} finally: self._text_handler.set_command_response_callback(None) - diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py index 49dfbe4..f98f839 100644 --- a/src/pymc_core/companion/companion_radio.py +++ b/src/pymc_core/companion/companion_radio.py @@ -16,9 +16,7 @@ from ..node.node import MeshNode from ..protocol import LocalIdentity, Packet, PacketBuilder -from ..protocol.constants import ( - PAYLOAD_TYPE_CONTROL, -) +from ..protocol.constants import PAYLOAD_TYPE_CONTROL from .companion_base import CompanionBase from .constants import ( ADV_TYPE_CHAT, @@ -105,9 +103,7 @@ def __init__( # Abstract method implementations # ------------------------------------------------------------------------- - async def _send_packet( - self, pkt: Packet, wait_for_ack: bool = False - ) -> bool: + async def _send_packet(self, pkt: Packet, wait_for_ack: bool = False) -> bool: """Send a packet via the MeshNode dispatcher.""" return await self.node.dispatcher.send_packet(pkt, wait_for_ack=wait_for_ack) @@ -317,9 +313,7 @@ def import_private_key(self, key: bytes) -> bool: event_service=self._event_service, ) self._setup_packet_callbacks() - logger.info( - f"Imported new identity: {self._identity.get_public_key().hex()[:16]}..." - ) + logger.info(f"Imported new identity: {self._identity.get_public_key().hex()[:16]}...") return True except Exception as e: logger.error(f"Error importing private key: {e}") @@ -376,7 +370,7 @@ async def send_telemetry_request( return {"success": False, "reason": str(e)} async def send_binary_request(self, pub_key: bytes, data: bytes) -> dict: - """Legacy: send binary request and wait for response via waiter. Prefer send_binary_req + on_binary_response.""" + """Legacy: send binary request and wait. Prefer send_binary_req + on_binary_response.""" contact = self.contacts.get_by_key(pub_key) if not contact: return {"success": False, "reason": "Contact not found"} @@ -417,14 +411,13 @@ async def send_repeater_command( command=command, parameters=parameters, ) + reason = "Command successful" if result.get("success") else "No response" return { "success": result.get("success", False), "repeater": contact.name, "command": command, "response": result.get("response"), - "reason": ( - "Command successful" if result.get("success") else "No response" - ), + "reason": reason, } except Exception as e: logger.error(f"Repeater command error: {e}") @@ -435,14 +428,13 @@ async def send_repeater_command( # ------------------------------------------------------------------------- async def send_control_data(self, data: Optional[bytes] = None) -> bool: - """Send a CONTROL packet. If data is provided and valid (len 1-254, first byte has 0x80), - send it as raw control payload; otherwise send a default discovery request (backward compat).""" + """Send CONTROL packet. If data valid (len 1-254, byte0 0x80), send as control payload; + else send default discovery request (backward compat). + """ if data and len(data) <= 254 and (data[0] & 0x80) != 0: try: pkt = Packet() - pkt.header = PacketBuilder._create_header( - PAYLOAD_TYPE_CONTROL, route_type="direct" - ) + pkt.header = PacketBuilder._create_header(PAYLOAD_TYPE_CONTROL, route_type="direct") pkt.path_len = 0 pkt.path = bytearray() pkt.payload = bytearray(data) @@ -479,13 +471,17 @@ def _setup_packet_callbacks(self) -> None: dispatcher = self.node.dispatcher dispatcher.set_packet_received_callback(self._on_packet_received) dispatcher.set_packet_sent_callback(self._on_packet_sent) - if hasattr(dispatcher, "protocol_response_handler") and dispatcher.protocol_response_handler: + if ( + hasattr(dispatcher, "protocol_response_handler") + and dispatcher.protocol_response_handler + ): dispatcher.protocol_response_handler.set_binary_response_callback( self._on_binary_response ) async def _on_packet_received(self, pkt: Any) -> None: from ..protocol.constants import ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD + route_type = pkt.get_route_type() is_flood = route_type in (ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD) self.stats.record_rx(is_flood=is_flood) diff --git a/src/pymc_core/companion/constants.py b/src/pymc_core/companion/constants.py index f62ff64..38de817 100644 --- a/src/pymc_core/companion/constants.py +++ b/src/pymc_core/companion/constants.py @@ -55,11 +55,13 @@ STATS_TYPE_RADIO = 1 STATS_TYPE_PACKETS = 2 + # --------------------------------------------------------------------------- # Binary request types (CMD_SEND_BINARY_REQ / PUSH_CODE_BINARY_RESPONSE) # --------------------------------------------------------------------------- class BinaryReqType(IntEnum): """Binary request type codes (companion frame protocol).""" + STATUS = 0x01 KEEP_ALIVE = 0x02 TELEMETRY = 0x03 @@ -67,6 +69,7 @@ class BinaryReqType(IntEnum): ACL = 0x05 NEIGHBOURS = 0x06 + # --------------------------------------------------------------------------- # Protocol Codes (used in create_protocol_request / send_protocol_request) # --------------------------------------------------------------------------- diff --git a/src/pymc_core/companion/contact_store.py b/src/pymc_core/companion/contact_store.py index adad857..a44c3a3 100644 --- a/src/pymc_core/companion/contact_store.py +++ b/src/pymc_core/companion/contact_store.py @@ -2,7 +2,6 @@ from __future__ import annotations -import time from typing import Iterable, Iterator, Optional from .constants import DEFAULT_MAX_CONTACTS @@ -195,7 +194,9 @@ def load_from_dicts(self, records: Iterable[dict]) -> None: name=rec.get("name", ""), adv_type=rec.get("adv_type", 0), flags=rec.get("flags", 0), - out_path_len=-1 if rec.get("out_path_len", -1) in (-1, 255) else rec.get("out_path_len", -1), + out_path_len=-1 + if rec.get("out_path_len", -1) in (-1, 255) + else rec.get("out_path_len", -1), out_path=out_path, last_advert_timestamp=rec.get("last_advert_timestamp", 0), lastmod=rec.get("lastmod", 0), diff --git a/src/pymc_core/companion/models.py b/src/pymc_core/companion/models.py index 426809d..f51a2fe 100644 --- a/src/pymc_core/companion/models.py +++ b/src/pymc_core/companion/models.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Optional diff --git a/src/pymc_core/hardware/kiss_modem_wrapper.py b/src/pymc_core/hardware/kiss_modem_wrapper.py index d93e187..278ac8b 100644 --- a/src/pymc_core/hardware/kiss_modem_wrapper.py +++ b/src/pymc_core/hardware/kiss_modem_wrapper.py @@ -18,6 +18,10 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Dict, Optional, Union +import serial + +from .base import LoRaRadio + # RX callback: (data) for backward compat, or (data, rssi, snr) for per-packet metrics RxCallback = Union[ Callable[[bytes], None], @@ -42,9 +46,6 @@ def _invoke_rx_callback( else: callback(data) -import serial - -from .base import LoRaRadio # KISS Protocol Constants (shared with standard KISS) KISS_FEND = 0xC0 # Frame End @@ -701,9 +702,7 @@ def _send_command( self._pending_response = None # SetHardware frame: type 0x06, payload = sub_cmd (1 byte) + data - kiss_frame = self._encode_kiss_frame( - KISS_CMD_SETHARDWARE, bytes([sub_cmd]) + data - ) + kiss_frame = self._encode_kiss_frame(KISS_CMD_SETHARDWARE, bytes([sub_cmd]) + data) if not self._write_frame(kiss_frame): logger.warning("SetHardware frame write failed") @@ -1168,8 +1167,12 @@ def get_status(self) -> Dict[str, Any]: status: Dict[str, Any] = { "initialized": self.is_connected, "frequency": cfg["frequency"] if cfg else self.radio_config.get("frequency", 0), - "tx_power": tx_power if tx_power is not None else self.radio_config.get("tx_power", self.radio_config.get("power", 0)), - "spreading_factor": cfg["spreading_factor"] if cfg else self.radio_config.get("spreading_factor", 0), + "tx_power": tx_power + if tx_power is not None + else self.radio_config.get("tx_power", self.radio_config.get("power", 0)), + "spreading_factor": cfg["spreading_factor"] + if cfg + else self.radio_config.get("spreading_factor", 0), "bandwidth": cfg["bandwidth"] if cfg else self.radio_config.get("bandwidth", 0), "coding_rate": cfg["coding_rate"] if cfg else self.radio_config.get("coding_rate", 0), "last_rssi": self.stats.get("last_rssi", -999), @@ -1247,9 +1250,7 @@ def _decode_kiss_byte(self, byte: int): if len(self.rx_frame_buffer) >= MAX_FRAME_SIZE: # Frame too long (e.g. lost FEND); reset and resync at next FEND self.stats["frame_errors"] += 1 - logger.warning( - "KISS frame exceeded max size (%d), resyncing", MAX_FRAME_SIZE - ) + logger.warning("KISS frame exceeded max size (%d), resyncing", MAX_FRAME_SIZE) self.rx_frame_buffer.clear() self.in_frame = False else: diff --git a/src/pymc_core/node/dispatcher.py b/src/pymc_core/node/dispatcher.py index f8fa8ae..9bb7ef8 100644 --- a/src/pymc_core/node/dispatcher.py +++ b/src/pymc_core/node/dispatcher.py @@ -312,11 +312,10 @@ def set_raw_packet_callback( """Set callback for raw packet data (includes both parsed packet and raw bytes).""" self.raw_packet_callback = callback - def add_raw_packet_subscriber( - self, callback: Callable[..., Any] - ) -> None: - """Subscribe to every incoming raw packet. Callback receives (pkt, data) or (pkt, data, analysis). - Use this to forward raw RX to clients (e.g. PUSH_CODE_LOG_RX_DATA) so they can track repeats by packet hash.""" + def add_raw_packet_subscriber(self, callback: Callable[..., Any]) -> None: + """Subscribe to every raw packet. Callback (pkt, data) or (pkt, data, analysis). + Forward raw RX to clients to track repeats by packet hash. + """ if callback not in self._raw_packet_subscribers: self._raw_packet_subscribers.append(callback) @@ -331,7 +330,8 @@ def add_raw_rx_subscriber( self, callback: Callable[[bytes, int, float], Awaitable[None] | None] ) -> None: """Subscribe to every incoming raw RX. Callback receives (data, rssi, snr). - Called before duplicate/blacklist so clients get every repeat (e.g. PUSH_CODE_LOG_RX_DATA).""" + Called before duplicate/blacklist so clients get every repeat. + """ if callback not in self._raw_rx_subscribers: self._raw_rx_subscribers.append(callback) @@ -361,10 +361,10 @@ async def _process_received_packet( rssi: Optional[int] = None, snr: Optional[float] = None, ) -> None: - """Process a received packet from the radio callback. rssi/snr are per-packet when provided.""" + """Process received packet. rssi/snr are per-packet when provided.""" self._log(f"[RX DEBUG] Processing packet: {len(data)} bytes, data: {data.hex()[:32]}...") - # Notify raw RX subscribers first (every reception, including duplicates) so clients can track repeats + # Notify raw RX subscribers so clients can track repeats if rssi is not None: rssi_val = rssi elif hasattr(self.radio, "get_last_rssi"): diff --git a/src/pymc_core/node/handlers/advert.py b/src/pymc_core/node/handlers/advert.py index 83fc7ff..8095b6c 100644 --- a/src/pymc_core/node/handlers/advert.py +++ b/src/pymc_core/node/handlers/advert.py @@ -8,7 +8,6 @@ PAYLOAD_TYPE_ADVERT, PUB_KEY_SIZE, SIGNATURE_SIZE, - TIMESTAMP_SIZE, describe_advert_flags, ) from ...protocol.utils import determine_contact_type_from_flags, get_contact_type_name @@ -80,7 +79,8 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: appdata = parsed["appdata"] if len(appdata) > MAX_ADVERT_DATA_SIZE: self.log( - f"Advert appdata too large ({len(appdata)} bytes), truncating to {MAX_ADVERT_DATA_SIZE}" + f"Advert appdata too large ({len(appdata)} bytes), " + f"truncating to {MAX_ADVERT_DATA_SIZE}" ) appdata = appdata[:MAX_ADVERT_DATA_SIZE] diff --git a/src/pymc_core/node/handlers/control.py b/src/pymc_core/node/handlers/control.py index 0286754..0f7dd25 100644 --- a/src/pymc_core/node/handlers/control.py +++ b/src/pymc_core/node/handlers/control.py @@ -49,9 +49,7 @@ def __init__( def payload_type() -> int: return PAYLOAD_TYPE_CONTROL - def set_response_callback( - self, tag: int, callback: Callable[[Dict[str, Any]], None] - ) -> None: + def set_response_callback(self, tag: int, callback: Callable[[Dict[str, Any]], None]) -> None: """Set callback for discovery responses with a specific tag.""" self._response_callbacks[tag] = callback @@ -59,9 +57,7 @@ def clear_response_callback(self, tag: int) -> None: """Clear callback for discovery responses with a specific tag.""" self._response_callbacks.pop(tag, None) - def set_request_callback( - self, callback: Callable[[Dict[str, Any]], None] - ) -> None: + def set_request_callback(self, callback: Callable[[Dict[str, Any]], None]) -> None: """Set callback for discovery requests (for logging/monitoring).""" self._request_callbacks[0] = callback @@ -78,9 +74,7 @@ async def __call__(self, pkt: Packet) -> Optional[Dict[str, Any]]: # Check if this is a zero-hop packet (path_len must be 0) if pkt.path_len != 0: - self._log( - f"[ControlHandler] Non-zero path length ({pkt.path_len}), ignoring" - ) + self._log(f"[ControlHandler] Non-zero path length ({pkt.path_len}), ignoring") return None # Extract control type (upper 4 bits of first byte) @@ -91,9 +85,7 @@ async def __call__(self, pkt: Packet) -> Optional[Dict[str, Any]]: elif control_type == CTL_TYPE_NODE_DISCOVER_RESP: return await self._handle_discovery_response(pkt) else: - self._log( - f"[ControlHandler] Unknown control type: 0x{control_type:02X}" - ) + self._log(f"[ControlHandler] Unknown control type: 0x{control_type:02X}") return None except Exception as e: @@ -102,7 +94,7 @@ async def __call__(self, pkt: Packet) -> Optional[Dict[str, Any]]: async def _handle_discovery_request(self, pkt: Packet) -> Optional[Dict[str, Any]]: """Handle node discovery request packet and return parsed data. - + Expected format: - byte 0: type (0x80) + flags (bit 0: prefix_only) - byte 1: filter (bitfield of node types to respond) @@ -156,7 +148,7 @@ async def _handle_discovery_request(self, pkt: Packet) -> Optional[Dict[str, Any async def _handle_discovery_response(self, pkt: Packet) -> Optional[Dict[str, Any]]: """Handle node discovery response packet and return parsed data. - + Response format: - byte 0: type (0x90) + node_type (lower 4 bits) - byte 1: SNR of our request (int8_t, multiplied by 4) @@ -188,8 +180,8 @@ async def _handle_discovery_response(self, pkt: Packet) -> Optional[Dict[str, An "tag": tag, "node_type": node_type, "inbound_snr": inbound_snr, # SNR of our request at their end - "response_snr": pkt._snr, # SNR of their response at our end - "rssi": pkt._rssi, # RSSI of their response at our end + "response_snr": pkt._snr, # SNR of their response at our end + "rssi": pkt._rssi, # RSSI of their response at our end "pub_key": pub_key.hex(), "pub_key_bytes": bytes(pub_key), "timestamp": time.time(), @@ -205,9 +197,7 @@ async def _handle_discovery_response(self, pkt: Packet) -> Optional[Dict[str, An f"[ControlHandler] Called response callback for tag 0x{tag:08X}" ) else: - self._debug_log( - f"[ControlHandler] No callback waiting for tag 0x{tag:08X}" - ) + self._debug_log(f"[ControlHandler] No callback waiting for tag 0x{tag:08X}") return response_data diff --git a/src/pymc_core/node/handlers/group_text.py b/src/pymc_core/node/handlers/group_text.py index 859876c..f2438ef 100644 --- a/src/pymc_core/node/handlers/group_text.py +++ b/src/pymc_core/node/handlers/group_text.py @@ -1,11 +1,7 @@ from typing import Optional from ...protocol import Packet -from ...protocol.constants import ( - PAYLOAD_TYPE_GRP_TXT, - ROUTE_TYPE_FLOOD, - ROUTE_TYPE_TRANSPORT_FLOOD, -) +from ...protocol.constants import PAYLOAD_TYPE_GRP_TXT, ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD from ...protocol.crypto import CryptoUtils from .base import BaseHandler @@ -127,7 +123,7 @@ def _parse_plaintext_message(self, plaintext: bytes) -> Optional[dict]: try: timestamp = int.from_bytes(plaintext[:4], "little") flags = plaintext[4] - # Decode and strip trailing null/padding (AES decrypt returns block-aligned data with zero padding) + # Decode and strip trailing null (AES decrypt is block-aligned) raw = plaintext[5:].decode("utf-8", errors="replace") message_content = raw.rstrip("\x00") @@ -308,7 +304,7 @@ async def _save_and_broadcast_group_message( }, } - # Publish channel message event (await so message is queued and MSG_WAITING sent before return) + # Publish channel message event (await so queued and MSG_WAITING sent) await self.event_service.publish(MeshEvents.NEW_CHANNEL_MESSAGE, message_data) self.log("Published group message event") except Exception as publish_error: diff --git a/src/pymc_core/node/handlers/login_response.py b/src/pymc_core/node/handlers/login_response.py index 6aea6f9..0ced44b 100644 --- a/src/pymc_core/node/handlers/login_response.py +++ b/src/pymc_core/node/handlers/login_response.py @@ -3,7 +3,12 @@ from typing import Callable, Optional from ...protocol import CryptoUtils, Identity, Packet -from ...protocol.constants import MAX_PATH_SIZE, PAYLOAD_TYPE_ANON_REQ, PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE +from ...protocol.constants import ( + MAX_PATH_SIZE, + PAYLOAD_TYPE_ANON_REQ, + PAYLOAD_TYPE_PATH, + PAYLOAD_TYPE_RESPONSE, +) from .base import BaseHandler # Response codes from C++ server @@ -157,10 +162,7 @@ async def _decrypt_response( if pkt_type == PAYLOAD_TYPE_PATH and len(plaintext) >= 2: path_len_byte = plaintext[0] inner_offset = 1 + path_len_byte + 1 # skip path_len + path + extra_type - if ( - path_len_byte <= MAX_PATH_SIZE - and len(plaintext) >= inner_offset - ): + if path_len_byte <= MAX_PATH_SIZE and len(plaintext) >= inner_offset: extra_type = plaintext[1 + path_len_byte] & 0x0F if extra_type == PAYLOAD_TYPE_RESPONSE and len(plaintext) > inner_offset: plaintext = plaintext[inner_offset:] diff --git a/src/pymc_core/node/handlers/login_server.py b/src/pymc_core/node/handlers/login_server.py index 4841a5d..92ed40a 100644 --- a/src/pymc_core/node/handlers/login_server.py +++ b/src/pymc_core/node/handlers/login_server.py @@ -144,22 +144,28 @@ async def __call__(self, packet: Packet) -> None: self.log("[LoginServer] Room server packet too short for sync_since field") return sync_since = struct.unpack(" 0 else '(empty)'}") + self.log( + f"[LoginServer] Room server: sync_since={sync_since}, " + f"password from byte 8 to {null_idx}" + ) + self.log( + f"[LoginServer] Password hex: " + f"{password_bytes.hex() if password_bytes else '(empty)'}" + ) else: # Repeater format: password only # Find null terminator after timestamp (starting from byte 4) null_idx = plaintext.find(b"\x00", 4) if null_idx == -1: null_idx = len(plaintext) - + password_bytes = plaintext[4:null_idx] self.log(f"[LoginServer] Repeater format: password from byte 4 to {null_idx}") diff --git a/src/pymc_core/node/handlers/path.py b/src/pymc_core/node/handlers/path.py index f9debf7..f0b13ee 100644 --- a/src/pymc_core/node/handlers/path.py +++ b/src/pymc_core/node/handlers/path.py @@ -65,9 +65,11 @@ async def __call__(self, pkt: Packet) -> None: # Optional PATH packet analysis if analyzer is available try: - if hasattr(self, "_dispatcher") and hasattr( - self._dispatcher, "packet_analysis_callback" - ) and self._dispatcher.packet_analysis_callback: + if ( + hasattr(self, "_dispatcher") + and hasattr(self._dispatcher, "packet_analysis_callback") + and self._dispatcher.packet_analysis_callback + ): self._dispatcher.packet_analysis_callback(pkt) except Exception as e: self._log(f"PATH packet analysis failed: {e}") diff --git a/src/pymc_core/node/handlers/protocol_request.py b/src/pymc_core/node/handlers/protocol_request.py index 6541635..54260df 100644 --- a/src/pymc_core/node/handlers/protocol_request.py +++ b/src/pymc_core/node/handlers/protocol_request.py @@ -5,11 +5,11 @@ """ import struct -from typing import Optional, Callable, Any +from typing import Callable, Optional +from pymc_core.protocol import PacketBuilder from pymc_core.protocol.constants import PAYLOAD_TYPE_REQ, PAYLOAD_TYPE_RESPONSE from pymc_core.protocol.crypto import CryptoUtils -from pymc_core.protocol import PacketBuilder # Request type codes (matching C++ implementation) REQ_TYPE_GET_STATUS = 0x01 @@ -25,17 +25,17 @@ class ProtocolRequestHandler: """ Handler for protocol request packets (PAYLOAD_TYPE_REQ). - + Processes encrypted request packets from authenticated clients and sends appropriate RESPONSE packets. Request handling is delegated to callbacks for application-specific logic. """ - + @staticmethod def payload_type(): """Return the payload type this handler processes.""" return PAYLOAD_TYPE_REQ - + def __init__( self, local_identity, @@ -46,7 +46,7 @@ def __init__( ): """ Initialize protocol request handler. - + Args: local_identity: LocalIdentity for this handler contacts: Contact manager or wrapper providing client lookup @@ -59,120 +59,131 @@ def __init__( self.get_client_fn = get_client_fn self.request_handlers = request_handlers or {} self.log = log_fn if log_fn else lambda msg: None - + async def __call__(self, packet): """ Process a protocol request packet. - + Args: packet: Packet instance with REQ payload - + Returns: Packet: RESPONSE packet to send, or None """ try: if len(packet.payload) < 2: return None - + dest_hash = packet.payload[0] src_hash = packet.payload[1] - + # Verify this packet is for us our_hash = self.local_identity.get_public_key()[0] if dest_hash != our_hash: return None - + self.log(f"Processing REQ from 0x{src_hash:02X}") - + # Get client info client = self._get_client(src_hash) if not client: self.log(f"REQ from unknown client 0x{src_hash:02X}") return None - + # Get shared secret shared_secret = self._get_shared_secret(client) if not shared_secret: self.log(f"No shared secret for client 0x{src_hash:02X}") return None - + # Decrypt request encrypted_data = packet.payload[2:] aes_key = shared_secret[:16] - + try: - plaintext = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, bytes(encrypted_data)) + plaintext = CryptoUtils.mac_then_decrypt( + aes_key, shared_secret, bytes(encrypted_data) + ) except Exception as e: self.log(f"Failed to decrypt REQ: {e}") return None - + # Parse request if len(plaintext) < 5: self.log("REQ packet too short") return None - - timestamp = struct.unpack(' 5 else b'' - + req_data = plaintext[5:] if len(plaintext) > 5 else b"" + self.log(f"REQ type=0x{req_type:02X}, timestamp={timestamp}") - + # Handle request response_data = await self._handle_request(client, timestamp, req_type, req_data) - + if response_data: return self._build_response(packet, client, response_data, shared_secret) - + return None - + except Exception as e: self.log(f"Error processing REQ: {e}") return None - + def _get_client(self, src_hash: int): """Get client info by source hash.""" if self.get_client_fn: return self.get_client_fn(src_hash) - + # Fallback: search in contacts - if hasattr(self.contacts, 'contacts'): + if hasattr(self.contacts, "contacts"): for contact in self.contacts.contacts: - if hasattr(contact, 'public_key'): - pk = bytes.fromhex(contact.public_key) if isinstance(contact.public_key, str) else contact.public_key + if hasattr(contact, "public_key"): + pk = ( + bytes.fromhex(contact.public_key) + if isinstance(contact.public_key, str) + else contact.public_key + ) if pk[0] == src_hash: return contact - + return None - + def _get_shared_secret(self, client): """Get shared secret for client.""" - if hasattr(client, 'shared_secret'): + if hasattr(client, "shared_secret"): return client.shared_secret - - if hasattr(client, 'public_key'): - pk = bytes.fromhex(client.public_key) if isinstance(client.public_key, str) else client.public_key + + if hasattr(client, "public_key"): + pk = ( + bytes.fromhex(client.public_key) + if isinstance(client.public_key, str) + else client.public_key + ) from pymc_core.protocol.identity import Identity + identity = Identity(pk) return identity.calc_shared_secret(self.local_identity.get_private_key()) - + return None - + async def _handle_request(self, client, timestamp: int, req_type: int, req_data: bytes): """ Handle request and generate response. - + Args: client: Client info object timestamp: Request timestamp req_type: Request type code req_data: Request payload - + Returns: bytes: Response data (timestamp + payload) or None """ # Build response with reflected timestamp - response = bytearray(struct.pack('= 0 and len(client.out_path) > 0: - reply_packet.path = bytearray(client.out_path[:client.out_path_len]) + reply_packet.path = bytearray(client.out_path[: client.out_path_len]) reply_packet.path_len = client.out_path_len - - self.log(f"RESPONSE built for 0x{client_identity.get_public_key()[0]:02X} via {route_type.upper()}") - + + self.log( + f"RESPONSE built for 0x{client_identity.get_public_key()[0]:02X} " + f"via {route_type.upper()}" + ) + return reply_packet - + except Exception as e: self.log(f"Error building RESPONSE: {e}") return None diff --git a/src/pymc_core/node/handlers/protocol_response.py b/src/pymc_core/node/handlers/protocol_response.py index a1b86d5..1094709 100644 --- a/src/pymc_core/node/handlers/protocol_response.py +++ b/src/pymc_core/node/handlers/protocol_response.py @@ -25,28 +25,28 @@ 0x02: ("Analog Input", 2, 100, True), 0x03: ("Analog Output", 2, 100, True), # --- Extended types (from CayenneLPP.h) --- - 0x64: ("Generic Sensor", 4, 1, False), # LPP_GENERIC_SENSOR = 100 - 0x65: ("Illuminance", 2, 1, False), # LPP_LUMINOSITY = 101 - 0x66: ("Presence", 1, 1, False), # LPP_PRESENCE = 102 - 0x67: ("Temperature", 2, 10, True), # LPP_TEMPERATURE = 103 - 0x68: ("Humidity", 1, 2, False), # LPP_RELATIVE_HUMIDITY = 104 + 0x64: ("Generic Sensor", 4, 1, False), # LPP_GENERIC_SENSOR = 100 + 0x65: ("Illuminance", 2, 1, False), # LPP_LUMINOSITY = 101 + 0x66: ("Presence", 1, 1, False), # LPP_PRESENCE = 102 + 0x67: ("Temperature", 2, 10, True), # LPP_TEMPERATURE = 103 + 0x68: ("Humidity", 1, 2, False), # LPP_RELATIVE_HUMIDITY = 104 0x71: ("Accelerometer", 6, 1000, True), # LPP_ACCELEROMETER = 113, 3×int16 - 0x73: ("Barometer", 2, 10, False), # LPP_BAROMETRIC_PRESSURE = 115 - 0x74: ("Voltage", 2, 100, False), # LPP_VOLTAGE = 116, 0.01V - 0x75: ("Current", 2, 1000, False), # LPP_CURRENT = 117, 0.001A - 0x76: ("Frequency", 4, 1, False), # LPP_FREQUENCY = 118, 1Hz - 0x78: ("Percentage", 1, 1, False), # LPP_PERCENTAGE = 120, 1-100% - 0x79: ("Altitude", 2, 1, True), # LPP_ALTITUDE = 121, 1m signed - 0x7D: ("Concentration", 2, 1, False), # LPP_CONCENTRATION = 125, 1ppm - 0x80: ("Power", 2, 1, False), # LPP_POWER = 128, 1W - 0x82: ("Distance", 4, 1000, False), # LPP_DISTANCE = 130, 0.001m - 0x83: ("Energy", 4, 1000, False), # LPP_ENERGY = 131, 0.001kWh - 0x84: ("Direction", 2, 1, False), # LPP_DIRECTION = 132, 1deg - 0x85: ("Unix Time", 4, 1, False), # LPP_UNIXTIME = 133 - 0x86: ("Gyroscope", 6, 100, True), # LPP_GYROMETER = 134, 3×int16 - 0x87: ("Colour", 3, 1, False), # LPP_COLOUR = 135, RGB - 0x88: ("GPS", 9, 1, True), # LPP_GPS = 136, lat(3)+lon(3)+alt(3) - 0x8E: ("Switch", 1, 1, False), # LPP_SWITCH = 142, 0/1 + 0x73: ("Barometer", 2, 10, False), # LPP_BAROMETRIC_PRESSURE = 115 + 0x74: ("Voltage", 2, 100, False), # LPP_VOLTAGE = 116, 0.01V + 0x75: ("Current", 2, 1000, False), # LPP_CURRENT = 117, 0.001A + 0x76: ("Frequency", 4, 1, False), # LPP_FREQUENCY = 118, 1Hz + 0x78: ("Percentage", 1, 1, False), # LPP_PERCENTAGE = 120, 1-100% + 0x79: ("Altitude", 2, 1, True), # LPP_ALTITUDE = 121, 1m signed + 0x7D: ("Concentration", 2, 1, False), # LPP_CONCENTRATION = 125, 1ppm + 0x80: ("Power", 2, 1, False), # LPP_POWER = 128, 1W + 0x82: ("Distance", 4, 1000, False), # LPP_DISTANCE = 130, 0.001m + 0x83: ("Energy", 4, 1000, False), # LPP_ENERGY = 131, 0.001kWh + 0x84: ("Direction", 2, 1, False), # LPP_DIRECTION = 132, 1deg + 0x85: ("Unix Time", 4, 1, False), # LPP_UNIXTIME = 133 + 0x86: ("Gyroscope", 6, 100, True), # LPP_GYROMETER = 134, 3×int16 + 0x87: ("Colour", 3, 1, False), # LPP_COLOUR = 135, RGB + 0x88: ("GPS", 9, 1, True), # LPP_GPS = 136, lat(3)+lon(3)+alt(3) + 0x8E: ("Switch", 1, 1, False), # LPP_SWITCH = 142, 0/1 } @@ -72,27 +72,51 @@ def _decode_cayenne_lpp(data: bytes) -> list: lat = int.from_bytes(raw[0:3], "big", signed=True) / 10000 lon = int.from_bytes(raw[3:6], "big", signed=True) / 10000 alt = int.from_bytes(raw[6:9], "big", signed=True) / 100 - sensors.append({"channel": channel, "type": name, "type_id": type_id, - "value": {"latitude": lat, "longitude": lon, "altitude": alt}, - "raw_value": raw.hex()}) + sensors.append( + { + "channel": channel, + "type": name, + "type_id": type_id, + "value": {"latitude": lat, "longitude": lon, "altitude": alt}, + "raw_value": raw.hex(), + } + ) elif size == 6 and type_id in (0x71, 0x86): # 3-axis: x(2) + y(2) + z(2), all signed x = int.from_bytes(raw[0:2], "big", signed=True) / divisor y = int.from_bytes(raw[2:4], "big", signed=True) / divisor z = int.from_bytes(raw[4:6], "big", signed=True) / divisor - sensors.append({"channel": channel, "type": name, "type_id": type_id, - "value": {"x": x, "y": y, "z": z}, - "raw_value": raw.hex()}) + sensors.append( + { + "channel": channel, + "type": name, + "type_id": type_id, + "value": {"x": x, "y": y, "z": z}, + "raw_value": raw.hex(), + } + ) elif type_id == 0x87: # Colour: R(1) + G(1) + B(1) - sensors.append({"channel": channel, "type": name, "type_id": type_id, - "value": {"r": raw[0], "g": raw[1], "b": raw[2]}, - "raw_value": raw.hex()}) + sensors.append( + { + "channel": channel, + "type": name, + "type_id": type_id, + "value": {"r": raw[0], "g": raw[1], "b": raw[2]}, + "raw_value": raw.hex(), + } + ) else: val = int.from_bytes(raw, "big", signed=signed) - sensors.append({"channel": channel, "type": name, "type_id": type_id, - "value": val / divisor if divisor != 1 else val, - "raw_value": raw.hex()}) + sensors.append( + { + "channel": channel, + "type": name, + "type_id": type_id, + "value": val / divisor if divisor != 1 else val, + "raw_value": raw.hex(), + } + ) return sensors @@ -112,8 +136,8 @@ def __init__(self, log_fn: Callable[[str], None], local_identity, contact_book): # Callbacks for protocol responses self._response_callbacks: Dict[int, Callable[[bool, str, Dict[str, Any]], None]] = {} - # Optional: when set, decrypted payloads with tag+data (and optional path) are passed as binary response - # Signature: (tag_bytes, response_data, path_info=None). path_info = (out_path, in_path, contact_pubkey). + # Optional: decrypted payloads with tag+data (and optional path) passed as binary response. + # Signature: (tag_bytes, response_data, path_info=None). self._binary_response_callback: Optional[Callable[..., Any]] = None @staticmethod @@ -131,8 +155,8 @@ def clear_response_callback(self, contact_hash: int) -> None: self._response_callbacks.pop(contact_hash, None) def set_binary_response_callback(self, callback: Callable[..., Any]) -> None: - """Set callback for binary responses. Called with (tag_bytes, response_data, path_info=None). - path_info when present is (out_path, in_path, contact_pubkey) for path-return format.""" + """Set callback for binary responses. (tag_bytes, response_data, path_info=None). + path_info = (out_path, in_path, contact_pubkey) for path-return format.""" self._binary_response_callback = callback async def __call__(self, pkt: Packet) -> None: @@ -151,9 +175,12 @@ async def __call__(self, pkt: Packet) -> None: return # Try to decrypt the response - success, decoded_text, parsed_data, raw_decrypted = await self._decrypt_protocol_response( - pkt, src_hash - ) + ( + success, + decoded_text, + parsed_data, + raw_decrypted, + ) = await self._decrypt_protocol_response(pkt, src_hash) # If an explicit response callback is waiting for this source (e.g. telemetry, # stats, repeater command), deliver there first. The binary/path-discovery @@ -178,7 +205,7 @@ async def __call__(self, pkt: Packet) -> None: callback(success, decoded_text, parsed_data) return - # If binary response callback is set, parse and invoke (plain tag+data or path-return format) + # If binary response callback set, parse and invoke (tag+data or path-return) if ( success and self._binary_response_callback is not None @@ -193,10 +220,7 @@ async def __call__(self, pkt: Packet) -> None: # Extract inner response from path-return structure path_len_byte = raw_decrypted[0] inner_offset = 1 + path_len_byte + 1 - if ( - path_len_byte <= MAX_PATH_SIZE - and len(raw_decrypted) >= inner_offset + 4 - ): + if path_len_byte <= MAX_PATH_SIZE and len(raw_decrypted) >= inner_offset + 4: out_path = bytes(raw_decrypted[1 : 1 + path_len_byte]) extra_type = raw_decrypted[1 + path_len_byte] & 0x0F extra = raw_decrypted[inner_offset:] @@ -220,9 +244,7 @@ async def __call__(self, pkt: Packet) -> None: response_data = raw_decrypted[4:] try: - cb_result = self._binary_response_callback( - tag_bytes, response_data, path_info - ) + cb_result = self._binary_response_callback(tag_bytes, response_data, path_info) if asyncio.iscoroutine(cb_result): await cb_result except Exception as e: @@ -235,11 +257,11 @@ async def __call__(self, pkt: Packet) -> None: async def _decrypt_protocol_response( self, pkt: Packet, src_hash: int ) -> tuple[bool, str, Dict[str, Any], Optional[bytes]]: - """Decrypt and parse a protocol response packet. Returns (success, text, parsed_data, raw_decrypted). + """Decrypt and parse protocol response. Returns (success, text, parsed_data, raw_decrypted). - Handles both packet types by inspecting the actual packet header: - - PAYLOAD_TYPE_RESPONSE (0x01): direct datagram → decrypted = tag(4)+data - - PAYLOAD_TYPE_PATH (0x08): path return → decrypted = path_len(1)+path(N)+extra_type(1)+extra + Handles both packet types: + - RESPONSE (0x01): direct → tag(4)+data + - PATH (0x08): path_len+path(N)+extra_type+extra """ try: # Find the contact by hash @@ -271,10 +293,7 @@ async def _decrypt_protocol_response( if len(decrypted) >= 2: # need at least path_len + extra_type path_len_byte = decrypted[0] inner_offset = 1 + path_len_byte + 1 # path_len + path + extra_type - if ( - path_len_byte <= MAX_PATH_SIZE - and len(decrypted) >= inner_offset - ): + if path_len_byte <= MAX_PATH_SIZE and len(decrypted) >= inner_offset: extra_type = decrypted[1 + path_len_byte] & 0x0F if extra_type == PAYLOAD_TYPE_RESPONSE and len(decrypted) > inner_offset: response_data = decrypted[inner_offset:] @@ -315,9 +334,9 @@ def _parse_protocol_response(self, data: bytes) -> tuple[bool, str, Dict[str, An result_dict["type"] = "stats" result_dict["raw_bytes"] = stats_result["raw_bytes"] self._log( - f"[ProtocolResponse] Parsed as STATS: batt={result_dict['batt_milli_volts']}mV, " + f"[ProtocolResponse] STATS: batt={result_dict['batt_milli_volts']}mV, " f"rssi={result_dict['last_rssi']}, snr={result_dict['last_snr']}, " - f"raw_bytes={len(result_dict['raw_bytes'])}B" + f"raw={len(result_dict['raw_bytes'])}B" ) return True, stats_result["formatted"], result_dict @@ -394,24 +413,24 @@ def _parse_stats_response(self, data: bytes) -> Optional[Dict[str, Any]]: # Parse with correct field types matching C++ struct ( - batt_milli_volts, # uint16 offset 0 + batt_milli_volts, # uint16 offset 0 curr_tx_queue_len, # uint16 offset 2 - noise_floor, # int16 offset 4 - last_rssi, # int16 offset 6 - n_packets_recv, # uint32 offset 8 - n_packets_sent, # uint32 offset 12 - total_air_time_secs,# uint32 offset 16 - total_up_time_secs, # uint32 offset 20 - n_sent_flood, # uint32 offset 24 - n_sent_direct, # uint32 offset 28 - n_recv_flood, # uint32 offset 32 - n_recv_direct, # uint32 offset 36 - err_events, # uint16 offset 40 - last_snr_raw, # int16 offset 42 - n_direct_dups, # uint16 offset 44 - n_flood_dups, # uint16 offset 46 + noise_floor, # int16 offset 4 + last_rssi, # int16 offset 6 + n_packets_recv, # uint32 offset 8 + n_packets_sent, # uint32 offset 12 + total_air_time_secs, # uint32 offset 16 + total_up_time_secs, # uint32 offset 20 + n_sent_flood, # uint32 offset 24 + n_sent_direct, # uint32 offset 28 + n_recv_flood, # uint32 offset 32 + n_recv_direct, # uint32 offset 36 + err_events, # uint16 offset 40 + last_snr_raw, # int16 offset 42 + n_direct_dups, # uint16 offset 44 + n_flood_dups, # uint16 offset 46 total_rx_air_time_secs, # uint32 offset 48 - n_recv_errors, # uint32 offset 52 + n_recv_errors, # uint32 offset 52 ) = struct.unpack(" Optional[Dict[str, Any]]: ) return { "type": "telemetry", - "formatted": ( - f"Telemetry ({len(sensors)} sensors, " - f"ts:{reflected_timestamp})" - ), + "formatted": (f"Telemetry ({len(sensors)} sensors, " f"ts:{reflected_timestamp})"), "reflected_timestamp": reflected_timestamp, "sensor_count": len(sensors), "sensors": sensors, diff --git a/src/pymc_core/node/handlers/text.py b/src/pymc_core/node/handlers/text.py index f3a1c4c..f2d7c1d 100644 --- a/src/pymc_core/node/handlers/text.py +++ b/src/pymc_core/node/handlers/text.py @@ -87,8 +87,8 @@ async def __call__(self, packet: Packet) -> None: # Skip ACK for TXT_TYPE_CLI_DATA (0x01) - CLI commands don't need ACKs # Following C++ pattern: only TXT_TYPE_PLAIN (0x00) gets ACKs TXT_TYPE_PLAIN = 0x00 - TXT_TYPE_CLI_DATA = 0x01 - send_ack = (txt_type == TXT_TYPE_PLAIN) + TXT_TYPE_CLI_DATA = 0x01 # noqa: F841 + send_ack = txt_type == TXT_TYPE_PLAIN if send_ack: # Create appropriate ACK response diff --git a/src/pymc_core/protocol/constants.py b/src/pymc_core/protocol/constants.py index aab5ab1..9764875 100644 --- a/src/pymc_core/protocol/constants.py +++ b/src/pymc_core/protocol/constants.py @@ -109,7 +109,7 @@ def describe_advert_flags(flags: int) -> str: # Protocol Request Types -REQ_TYPE_GET_STATUS = 0x01 # Get repeater stats (RepeaterStats struct) +REQ_TYPE_GET_STATUS = 0x01 # Get repeater stats (RepeaterStats struct) REQ_TYPE_GET_TELEMETRY_DATA = 0x03 # Get telemetry data (CayenneLPP) TELEM_PERM_BASE = 0x01 TELEM_PERM_LOCATION = 0x02 diff --git a/src/pymc_core/protocol/identity.py b/src/pymc_core/protocol/identity.py index b879a7d..8bd499e 100644 --- a/src/pymc_core/protocol/identity.py +++ b/src/pymc_core/protocol/identity.py @@ -100,7 +100,7 @@ def __init__(self, seed: Optional[bytes] = None): if seed and len(seed) == 64: from nacl.bindings import crypto_scalarmult_ed25519_base_noclamp - # MeshCore format: [32-byte scalar][32-byte nonce]; firmware clamps first 32 bytes for ECDH + # MeshCore format: [32-byte scalar][32-byte nonce]; firmware clamps first 32 for ECDH self._firmware_key = seed self.signing_key = None @@ -110,7 +110,7 @@ def __init__(self, seed: Optional[bytes] = None): ed25519_pub = crypto_scalarmult_ed25519_base_noclamp(clamped) self.verify_key = VerifyKey(ed25519_pub) - # Use clamped scalar directly for ECDH (firmware key_exchange.c uses first 32 bytes clamped) + # Use clamped scalar for ECDH (firmware key_exchange.c uses first 32 bytes clamped) self._x25519_private = clamped self._x25519_public = CryptoUtils.scalarmult_base(clamped) else: @@ -161,10 +161,10 @@ def get_shared_public_key(self) -> bytes: def get_signing_key_bytes(self) -> bytes: """ Get the signing key bytes for this identity. - + For standard keys, returns the 32-byte Ed25519 seed. For firmware keys, returns the 64-byte expanded key format [scalar||nonce]. - + Returns: The signing key bytes (32 or 64 bytes depending on key type). """ diff --git a/src/pymc_core/protocol/modem_identity.py b/src/pymc_core/protocol/modem_identity.py index e8551ee..ab54ec2 100644 --- a/src/pymc_core/protocol/modem_identity.py +++ b/src/pymc_core/protocol/modem_identity.py @@ -188,8 +188,7 @@ def get_signing_key_bytes(self) -> bytes: RuntimeError: Always, as signing key is not accessible """ raise RuntimeError( - "ModemIdentity does not expose signing keys. " - "Use sign() for signing operations." + "ModemIdentity does not expose signing keys. " "Use sign() for signing operations." ) # Additional modem-specific methods diff --git a/src/pymc_core/protocol/packet_builder.py b/src/pymc_core/protocol/packet_builder.py index 8fd6a74..b458125 100644 --- a/src/pymc_core/protocol/packet_builder.py +++ b/src/pymc_core/protocol/packet_builder.py @@ -538,9 +538,13 @@ def create_group_datagram( secret_bytes = ( bytes.fromhex(channel["secret"]) if isinstance(channel["secret"], str) - else (channel["secret"] if isinstance(channel["secret"], bytes) else channel["secret"].encode("utf-8")) + else ( + channel["secret"] + if isinstance(channel["secret"], bytes) + else channel["secret"].encode("utf-8") + ) ) - # Use same channel hash derivation as GroupTextHandler (firmware: hash first 16 bytes when 32-byte key has second 16 zero) + # Same channel hash as GroupTextHandler (hash first 16 when key has second 16 zero) hash_input = ( secret_bytes[:16] if len(secret_bytes) >= 32 and secret_bytes[16:32] == b"\x00" * 16 diff --git a/src/pymc_core/protocol/transport_keys.py b/src/pymc_core/protocol/transport_keys.py index 9518f3e..d174196 100644 --- a/src/pymc_core/protocol/transport_keys.py +++ b/src/pymc_core/protocol/transport_keys.py @@ -7,43 +7,44 @@ """ import struct + from .crypto import CryptoUtils def get_auto_key_for(name: str) -> bytes: """ Generate 128-bit transport key from region name. - + Matches C++ implementation: void TransportKeyStore::getAutoKeyFor(uint16_t id, const char* name, TransportKey& dest) - + Args: name: Region name including '#' (e.g., "#usa") - + Returns: bytes: 16-byte transport key """ if not name: raise ValueError("Region name cannot be empty") - if not name.startswith('#'): + if not name.startswith("#"): raise ValueError("Region name must start with '#'") if len(name) > 64: raise ValueError("Region name is too long (max 64 characters)") - key_hash = CryptoUtils.sha256(name.encode('ascii')) + key_hash = CryptoUtils.sha256(name.encode("ascii")) return key_hash[:16] # First 16 bytes (128 bits) def calc_transport_code(key: bytes, packet) -> int: """ Calculate transport code for a packet. - + Matches C++ implementation: uint16_t TransportKey::calcTransportCode(const mesh::Packet* packet) const - + Args: key: 16-byte transport key packet: Packet with payload_type and payload - + Returns: int: 16-bit transport code """ @@ -51,20 +52,20 @@ def calc_transport_code(key: bytes, packet) -> int: raise ValueError(f"Transport key must be 16 bytes, got {len(key)}") payload_type = packet.get_payload_type() payload_data = packet.get_payload() - + # HMAC input: payload_type (1 byte) + payload hmac_data = bytes([payload_type]) + payload_data - + # Calculate HMAC-SHA256 hmac_digest = CryptoUtils._hmac_sha256(key, hmac_data) - + # Extract first 2 bytes as little-endian uint16 (matches Arduino platform endianness) - code = struct.unpack(' Contact: @@ -152,7 +148,7 @@ async def test_send_text_message_with_contact_injects_packet(self): bridge = CompanionBridge(LocalIdentity(), injector) contact = _make_peer_contact("Alice") bridge.contacts.add(contact) - result = await bridge.send_text_message(contact.public_key, "Hello") + await bridge.send_text_message(contact.public_key, "Hello") assert len(injector.calls) >= 1 pkt, _ = injector.calls[0] assert (pkt.header >> 2) & 0x0F == PAYLOAD_TYPE_TXT_MSG @@ -240,7 +236,9 @@ async def test_send_binary_req_with_contact(self): bridge = CompanionBridge(LocalIdentity(), injector) contact = _make_peer_contact("Rpt") bridge.contacts.add(contact) - result = await bridge.send_binary_req(contact.public_key, bytes([0x01]), timeout_seconds=5.0) + result = await bridge.send_binary_req( + contact.public_key, bytes([0x01]), timeout_seconds=5.0 + ) assert result.success is True assert result.expected_ack is not None assert len(injector.calls) == 1 diff --git a/tests/test_companion_regions.py b/tests/test_companion_regions.py index d8a3994..8154790 100644 --- a/tests/test_companion_regions.py +++ b/tests/test_companion_regions.py @@ -5,7 +5,6 @@ import pytest from pymc_core.companion import CompanionRadio -from pymc_core.companion.constants import ADV_TYPE_CHAT from pymc_core.companion.models import Contact from pymc_core.protocol import LocalIdentity, Packet, PacketBuilder from pymc_core.protocol.constants import ( @@ -15,7 +14,6 @@ ) from pymc_core.protocol.transport_keys import calc_transport_code, get_auto_key_for - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -215,9 +213,7 @@ class TestAdvertiseWithFloodScope: async def test_advertise_flood_with_scope_sends_transport_flood(self): radio = MockRadio() identity = LocalIdentity() - companion = CompanionRadio( - radio=radio, identity=identity, node_name="scoped" - ) + companion = CompanionRadio(radio=radio, identity=identity, node_name="scoped") companion.set_flood_region("usa") await companion.start() @@ -239,9 +235,7 @@ async def test_advertise_flood_with_scope_sends_transport_flood(self): async def test_advertise_flood_without_scope_sends_normal_flood(self): radio = MockRadio() identity = LocalIdentity() - companion = CompanionRadio( - radio=radio, identity=identity, node_name="noscope" - ) + companion = CompanionRadio(radio=radio, identity=identity, node_name="noscope") # No flood scope set await companion.start() diff --git a/tests/test_companion_stores.py b/tests/test_companion_stores.py index c33c48f..c778df1 100644 --- a/tests/test_companion_stores.py +++ b/tests/test_companion_stores.py @@ -1,26 +1,15 @@ -"""Tests for companion stores and models: ContactStore, ChannelStore, MessageQueue, PathCache, StatsCollector.""" +"""Tests for companion stores and models: ContactStore, ChannelStore, MessageQueue, PathCache.""" -import pytest - -from pymc_core.companion import ( - ContactStore, - ChannelStore, - MessageQueue, - PathCache, - StatsCollector, -) -from pymc_core.companion.constants import DEFAULT_MAX_CONTACTS, DEFAULT_MAX_CHANNELS +from pymc_core.companion import ChannelStore, ContactStore, MessageQueue, PathCache, StatsCollector from pymc_core.companion.models import ( AdvertPath, Channel, Contact, NodePrefs, - PacketStats, QueuedMessage, SentResult, ) - # --------------------------------------------------------------------------- # Models # --------------------------------------------------------------------------- @@ -199,19 +188,19 @@ def test_clear(self): def test_load_from(self): store = ContactStore(max_contacts=10) - contacts = [ - Contact(public_key=bytes([i] * 32), name=f"C{i}") for i in range(3) - ] + contacts = [Contact(public_key=bytes([i] * 32), name=f"C{i}") for i in range(3)] store.load_from(contacts) assert store.get_count() == 3 assert store.get_by_name("C1").name == "C1" def test_load_from_dicts(self): store = ContactStore(max_contacts=10) - store.load_from_dicts([ - {"public_key": "a1" * 32, "name": "DictAlice"}, - {"public_key": "b2" * 32, "name": "DictBob"}, - ]) + store.load_from_dicts( + [ + {"public_key": "a1" * 32, "name": "DictAlice"}, + {"public_key": "b2" * 32, "name": "DictBob"}, + ] + ) assert store.get_count() == 2 assert store.get_by_name("DictAlice") is not None assert store.get_by_name("DictBob") is not None diff --git a/tests/test_kiss_modem_wrapper.py b/tests/test_kiss_modem_wrapper.py index 77c8e72..a7397a6 100644 --- a/tests/test_kiss_modem_wrapper.py +++ b/tests/test_kiss_modem_wrapper.py @@ -13,23 +13,14 @@ from pymc_core.hardware.kiss_modem_wrapper import ( CMD_DATA, - CMD_ENCRYPT_DATA, - CMD_GET_AIRTIME, CMD_GET_BATTERY, - CMD_GET_IDENTITY, - CMD_GET_NOISE_FLOOR, CMD_GET_RADIO, - CMD_GET_RANDOM, CMD_GET_STATS, - CMD_GET_TX_POWER, CMD_GET_VERSION, - CMD_HASH, - CMD_KEY_EXCHANGE, CMD_PING, CMD_SET_RADIO, CMD_SET_TX_POWER, CMD_SIGN_DATA, - CMD_VERIFY_SIGNATURE, HW_CMD_GET_DEVICE_NAME, HW_CMD_GET_MCU_TEMP, HW_CMD_GET_SIGNAL_REPORT, @@ -39,6 +30,7 @@ HW_RESP_DEVICE_NAME, HW_RESP_MCU_TEMP, HW_RESP_OK, + HW_RESP_RX_META, HW_RESP_SIGNAL_REPORT, KISS_CMD_FULLDUPLEX, KISS_CMD_PERSISTENCE, @@ -49,25 +41,15 @@ KISS_FESC, KISS_TFEND, KISS_TFESC, - RESP_AIRTIME, RESP_BATTERY, - RESP_ENCRYPTED, RESP_ERROR, - RESP_HASH, RESP_IDENTITY, - RESP_NOISE_FLOOR, RESP_OK, RESP_PONG, RESP_RADIO, - RESP_RANDOM, - RESP_SHARED_SECRET, RESP_SIGNATURE, RESP_STATS, RESP_TX_DONE, - RESP_TX_POWER, - RESP_VERIFY, - RESP_VERSION, - HW_RESP_RX_META, KissModemWrapper, ) @@ -141,7 +123,9 @@ def test_decode_simple_frame(self): # Data frame: FEND + 0x00 + raw_packet + FEND (no in-frame metadata) data_frame = bytes([KISS_FEND, CMD_DATA, 0x01, 0x02, 0x03, KISS_FEND]) # RxMeta: FEND + 0x06 + 0xF9 + SNR + RSSI + FEND (sent immediately after Data) - rx_meta_frame = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_RX_META, 0x10, 0xB0, KISS_FEND]) + rx_meta_frame = bytes( + [KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_RX_META, 0x10, 0xB0, KISS_FEND] + ) for byte in data_frame: modem._decode_kiss_byte(byte) @@ -160,10 +144,10 @@ def test_decode_frame_with_escapes(self): modem.on_frame_received = lambda data: received_frames.append(data) # Data frame: payload is escaped 0xC0 (FESC + TFEND) - data_frame = bytes( - [KISS_FEND, CMD_DATA, KISS_FESC, KISS_TFEND, KISS_FEND] + data_frame = bytes([KISS_FEND, CMD_DATA, KISS_FESC, KISS_TFEND, KISS_FEND]) + rx_meta_frame = bytes( + [KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_RX_META, 0x10, 0xB0, KISS_FEND] ) - rx_meta_frame = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_RX_META, 0x10, 0xB0, KISS_FEND]) for byte in data_frame: modem._decode_kiss_byte(byte) @@ -180,7 +164,9 @@ def test_decode_extracts_rssi_snr(self): data_frame = bytes([KISS_FEND, CMD_DATA, 0xAA, 0xBB, KISS_FEND]) # RxMeta: SNR=0x10 (4.0 dB), RSSI=0xB0 (-80) - rx_meta_frame = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_RX_META, 0x10, 0xB0, KISS_FEND]) + rx_meta_frame = bytes( + [KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_RX_META, 0x10, 0xB0, KISS_FEND] + ) for byte in data_frame: modem._decode_kiss_byte(byte) @@ -196,6 +182,7 @@ def test_rx_callback_receives_per_packet_rssi_snr(self): modem.is_connected = True received = [] + def capture(data, rssi, snr): received.append((data, rssi, snr)) @@ -274,7 +261,7 @@ def test_send_command_encodes_correctly(self): assert written_frame[0] == KISS_FEND assert written_frame[1] == KISS_CMD_SETHARDWARE # type SetHardware - assert written_frame[2] == HW_CMD_GET_VERSION # sub_cmd GetVersion + assert written_frame[2] == HW_CMD_GET_VERSION # sub_cmd GetVersion assert written_frame[-1] == KISS_FEND def test_response_parsing_identity(self): @@ -334,7 +321,7 @@ class TestRadioConfiguration: def test_radio_config_struct_format(self): """Test that radio config is packed correctly""" - modem = KissModemWrapper(port="/dev/null", auto_configure=False) + KissModemWrapper(port="/dev/null", auto_configure=False) freq_hz = 869618000 bw_hz = 62500 @@ -1007,7 +994,8 @@ def test_context_manager_calls_connect_disconnect(self): with patch.object(KissModemWrapper, "connect", return_value=True) as mock_connect: with patch.object(KissModemWrapper, "disconnect") as mock_disconnect: with KissModemWrapper(port="/dev/null", auto_configure=False) as modem: - pass + pass # keep reference so __del__ doesn't run before assert mock_connect.assert_called_once() mock_disconnect.assert_called_once() + _ = modem # hold ref so __del__ runs after assert, not before diff --git a/tests/test_modem_identity.py b/tests/test_modem_identity.py index f0a28dd..414672f 100644 --- a/tests/test_modem_identity.py +++ b/tests/test_modem_identity.py @@ -12,7 +12,6 @@ from pymc_core.protocol.modem_identity import ModemIdentity - # Generate a valid Ed25519 keypair for testing _TEST_SIGNING_KEY = SigningKey.generate() _TEST_PUBKEY = bytes(_TEST_SIGNING_KEY.verify_key) diff --git a/tests/test_packet_utils.py b/tests/test_packet_utils.py index c5bb384..bd6a0e6 100644 --- a/tests/test_packet_utils.py +++ b/tests/test_packet_utils.py @@ -149,7 +149,9 @@ class TestPacketHashingUtils: def test_hash_string_returns_full_uppercase_hex(self): payload_type = 0x05 path_len = 0 - payload = bytes.fromhex("D9BA8E4EA9444822AC56B4D52AC3C0044C6AE402997BB9805CCB331EC3378DCE339F2D") + payload = bytes.fromhex( + "D9BA8E4EA9444822AC56B4D52AC3C0044C6AE402997BB9805CCB331EC3378DCE339F2D" + ) expected_hex = "887B9BE6056D0B0517AF3A04AC2478EDFC2AB731936DEA525041500E7ADE74D3" @@ -166,7 +168,9 @@ def test_hash_string_returns_full_uppercase_hex(self): def test_hash_string_truncates_to_requested_length(self): payload_type = 0x05 path_len = 1 - payload = bytes.fromhex("D9BA8E4EA9444822AC56B4D52AC3C0044C6AE402997BB9805CCB331EC3378DCE339F2D") + payload = bytes.fromhex( + "D9BA8E4EA9444822AC56B4D52AC3C0044C6AE402997BB9805CCB331EC3378DCE339F2D" + ) expected_hex = "887B9BE6056D0B05" From 80c5c690574aed42060b823b20f3e99c4e71ecea Mon Sep 17 00:00:00 2001 From: agessaman Date: Tue, 17 Feb 2026 17:14:42 -0800 Subject: [PATCH 10/50] Refactor Companion modules for improved messaging and handler management - Introduced a shared factory for core protocol handlers to streamline handler registration in the Dispatcher. - Refactored CompanionBase, CompanionBridge, and CompanionRadio to utilize new handler access methods, enhancing code clarity and maintainability. - Added new methods for sending text, channel messages, and raw data in CompanionBase, improving messaging capabilities. - Updated documentation and comments for better understanding of the messaging flow and handler interactions. --- src/pymc_core/companion/companion_base.py | 426 ++++++++- src/pymc_core/companion/companion_bridge.py | 453 +--------- src/pymc_core/companion/companion_radio.py | 264 +----- src/pymc_core/node/dispatcher.py | 117 +-- src/pymc_core/node/handlers/__init__.py | 3 + src/pymc_core/node/handlers/registry.py | 104 +++ src/pymc_core/node/node.py | 953 ++------------------ tests/test_companion_regions.py | 7 - 8 files changed, 657 insertions(+), 1670 deletions(-) create mode 100644 src/pymc_core/node/handlers/registry.py diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 1fc740c..381c9f8 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -26,6 +26,8 @@ ADVERT_FLAG_IS_REPEATER, ADVERT_FLAG_IS_ROOM_SERVER, ADVERT_FLAG_IS_SENSOR, + PAYLOAD_TYPE_CONTROL, + REQ_TYPE_GET_STATUS, REQ_TYPE_GET_TELEMETRY_DATA, ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD, @@ -44,10 +46,13 @@ DEFAULT_OFFLINE_QUEUE_SIZE, DEFAULT_RESPONSE_TIMEOUT_MS, MAX_SIGN_DATA_SIZE, + PROTOCOL_CODE_ANON_REQ, PROTOCOL_CODE_BINARY_REQ, + PROTOCOL_CODE_RAW_DATA, STATS_TYPE_CORE, STATS_TYPE_PACKETS, STATS_TYPE_RADIO, + TXT_TYPE_PLAIN, ) from .contact_store import ContactStore from .message_queue import MessageQueue @@ -652,40 +657,24 @@ def is_running(self) -> bool: """Return whether the companion is currently running.""" @abstractmethod - async def send_text_message( - self, - pub_key: bytes, - text: str, - txt_type: int = 0, - attempt: int = 1, - ) -> SentResult: - """Send a direct text message to a contact.""" - - @abstractmethod - async def send_channel_message(self, channel_idx: int, text: str) -> bool: - """Send a message to a channel.""" + def import_private_key(self, key: bytes) -> bool: + """Import a private key and rebuild the identity.""" - @abstractmethod - async def send_login(self, pub_key: bytes, password: str) -> dict: - """Send a login request to a repeater.""" + def _get_protocol_response_handler(self) -> Any: + """Return the protocol response handler, or ``None``. - @abstractmethod - async def send_trace_path( - self, - pub_key: bytes, - tag: int, - auth_code: int, - flags: int = 0, - ) -> bool: - """Send a trace path request to a contact.""" + Subclasses that support request/response methods (telemetry, status, + binary request, etc.) must override this to return their handler. + """ + return None - @abstractmethod - def import_private_key(self, key: bytes) -> bool: - """Import a private key and rebuild the identity.""" + def _get_login_response_handler(self) -> Any: + """Return the login response handler, or ``None``.""" + return None - @abstractmethod - async def send_control_data(self, data: Any = None) -> bool: - """Send a control data packet.""" + def _get_text_handler(self) -> Any: + """Return the text message handler, or ``None``.""" + return None # ------------------------------------------------------------------------- # Unified TX methods (shared between Radio and Bridge) @@ -854,6 +843,383 @@ async def send_path_discovery_req(self, pub_key: bytes) -> SentResult: current.out_path = old_path self.contacts.update(current) + async def send_text_message( + self, + pub_key: bytes, + text: str, + txt_type: int = TXT_TYPE_PLAIN, + attempt: int = 1, + ) -> SentResult: + """Send a direct text message to a contact.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + logger.warning(f"Contact not found for key {pub_key.hex()[:12]}...") + return SentResult(success=False) + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return SentResult(success=False) + try: + is_flood = proxy.out_path_len < 0 + msg_type = "flood" if is_flood else "direct" + pkt, ack_crc = PacketBuilder.create_text_message( + contact=proxy, + local_identity=self._identity, + message=text, + attempt=attempt, + message_type=msg_type, + ) + self._apply_flood_scope(pkt) + self._track_pending_ack(ack_crc) + success = await self._send_packet(pkt, wait_for_ack=True) + if success: + self.stats.record_tx(is_flood=is_flood) + else: + self.stats.record_tx_error() + return SentResult( + success=success, + is_flood=is_flood, + expected_ack=ack_crc, + timeout_ms=None, + ) + except Exception as e: + logger.error(f"Error sending text message: {e}") + self.stats.record_tx_error() + return SentResult(success=False) + + async def send_channel_message(self, channel_idx: int, text: str) -> bool: + """Send a message to a channel.""" + channel = self.channels.get(channel_idx) + if not channel: + logger.warning(f"Channel {channel_idx} not found") + return False + try: + pkt = PacketBuilder.create_group_datagram( + group_name=channel.name, + local_identity=self._identity, + message=text, + sender_name=self.prefs.node_name, + channels_config=self.channels.get_channels(), + ) + self._apply_flood_scope(pkt) + success = await self._send_packet(pkt, wait_for_ack=False) + if success: + self.stats.record_tx(is_flood=True) + else: + self.stats.record_tx_error() + return success + except Exception as e: + logger.error(f"Error sending channel message: {e}") + self.stats.record_tx_error() + return False + + async def send_raw_data( + self, + dest_key: bytes, + data: bytes, + path: Optional[bytes] = None, + ) -> SentResult: + """Send raw data to a contact via a protocol request.""" + contact = self.contacts.get_by_key(dest_key) + if not contact: + return SentResult(success=False) + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return SentResult(success=False) + try: + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=PROTOCOL_CODE_RAW_DATA, + data=data, + ) + success = await self._send_packet(pkt, wait_for_ack=False) + return SentResult(success=success) + except Exception as e: + logger.error(f"Error sending raw data: {e}") + return SentResult(success=False) + + async def send_trace_path( + self, + pub_key: bytes, + tag: int, + auth_code: int, + flags: int = 0, + ) -> bool: + """Send a trace path request to a contact.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return False + path = list(contact.out_path) if contact.out_path else [] + if not path: + path = [contact.public_key[0]] + try: + pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=path) + return await self._send_packet(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Error sending trace: {e}") + return False + + async def send_control_data(self, data: Any = None) -> bool: + """Send a CONTROL packet (e.g. discovery request). + + If *data* is provided it must be 1-254 bytes with the first byte having + the 0x80 bit set (e.g. ``DISCOVER_REQ``). Returns ``False`` for + invalid payloads. + + When called with no *data* (or ``None``), a default discovery request + is sent for backward compatibility. + """ + try: + if data and len(data) <= 254 and (data[0] & 0x80) != 0: + pkt = Packet() + pkt.header = PacketBuilder._create_header(PAYLOAD_TYPE_CONTROL, route_type="direct") + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(data) + pkt.payload_len = len(data) + return await self._send_packet(pkt, wait_for_ack=False) + elif data is not None: + # data was provided but invalid + return False + # No data: send default discovery request + tag = random.randint(0, 0xFFFFFFFF) + pkt = PacketBuilder.create_discovery_request(tag, filter_mask=0x04) + return await self._send_packet(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Error sending control data: {e}") + return False + + async def send_login(self, pub_key: bytes, password: str) -> dict: + """Send a login request to a repeater and wait for the response.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + login_handler = self._get_login_response_handler() + if not login_handler: + return {"success": False, "reason": "Login handler not available"} + dest_hash = bytes.fromhex(proxy.public_key)[0] + login_handler.store_login_password(dest_hash, password) + login_result: dict = {"success": False, "data": {}} + login_event = asyncio.Event() + + def _login_cb(success: bool, data: dict) -> None: + login_result["success"] = success + login_result["data"] = data + login_event.set() + + login_handler.set_login_callback(_login_cb) + try: + pkt = PacketBuilder.create_login_packet( + contact=proxy, local_identity=self._identity, password=password + ) + await self._send_packet(pkt, wait_for_ack=False) + try: + await asyncio.wait_for(login_event.wait(), timeout=10.0) + except asyncio.TimeoutError: + return {"success": False, "reason": "Login response timeout"} + data = login_result["data"] + return { + "success": login_result["success"], + "repeater": contact.name, + "is_admin": data.get("is_admin", False), + "keep_alive_interval": data.get("keep_alive_interval", 0), + "tag": data.get("timestamp", 0), + "acl_permissions": data.get("reserved", data.get("permissions", 0)), + "reason": "Login successful" if login_result["success"] else "Login failed", + } + except Exception as e: + logger.error(f"Login error: {e}") + return {"success": False, "reason": str(e)} + finally: + login_handler.set_login_callback(None) + login_handler.clear_login_password(dest_hash) + + async def send_status_request(self, pub_key: bytes, timeout: float = 15.0) -> dict: + """Send a protocol request for repeater status/stats.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + proto_handler = self._get_protocol_response_handler() + if not proto_handler: + return {"success": False, "reason": "Protocol handler not available"} + contact_hash = bytes.fromhex(proxy.public_key)[0] + waiter = ResponseWaiter() + proto_handler.set_response_callback(contact_hash, waiter.callback) + try: + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=REQ_TYPE_GET_STATUS, + data=b"", + ) + await self._send_packet(pkt, wait_for_ack=False) + result = await waiter.wait(timeout) + return { + "success": result.get("success", False), + "repeater": contact.name, + "stats": result.get("parsed", {}), + "response_text": result.get("text"), + "reason": "Stats received" if result.get("success") else "Stats request failed", + } + except Exception as e: + logger.error(f"Status request error: {e}") + return {"success": False, "reason": str(e)} + finally: + proto_handler.clear_response_callback(contact_hash) + + async def send_telemetry_request( + self, + pub_key: bytes, + want_base: bool = True, + want_location: bool = True, + want_environment: bool = True, + timeout: float = 10.0, + ) -> dict: + """Send a telemetry request to a contact and wait for the response.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + proto_handler = self._get_protocol_response_handler() + if not proto_handler: + return {"success": False, "reason": "Protocol handler not available"} + contact_hash = bytes.fromhex(proxy.public_key)[0] + waiter = ResponseWaiter() + proto_handler.set_response_callback(contact_hash, waiter.callback) + try: + inv = PacketBuilder._compute_inverse_perm_mask( + want_base, want_location, want_environment + ) + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, + data=bytes([inv]), + ) + await self._send_packet(pkt, wait_for_ack=False) + result = await waiter.wait(timeout) + return { + "success": result.get("success", False), + "contact": contact.name, + "telemetry_data": result.get("parsed", {}), + "response_text": result.get("text"), + "reason": ("Telemetry received" if result.get("success") else "Telemetry failed"), + } + except Exception as e: + logger.error(f"Telemetry error: {e}") + return {"success": False, "reason": str(e)} + finally: + proto_handler.clear_response_callback(contact_hash) + + async def send_binary_request(self, pub_key: bytes, data: bytes) -> dict: + """Legacy: send binary request and wait. + + Prefer ``send_binary_req`` + ``on_binary_response``. + """ + return await self._send_protocol_request(pub_key, PROTOCOL_CODE_BINARY_REQ, data) + + async def send_anon_request(self, pub_key: bytes, data: bytes) -> dict: + """Send an anonymous request to a contact and wait for the response.""" + return await self._send_protocol_request(pub_key, PROTOCOL_CODE_ANON_REQ, data) + + async def _send_protocol_request(self, pub_key: bytes, protocol_code: int, data: bytes) -> dict: + """Build and send a protocol request, waiting for the response.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + proto_handler = self._get_protocol_response_handler() + if not proto_handler: + return {"success": False, "reason": "Protocol handler not available"} + contact_hash = bytes.fromhex(proxy.public_key)[0] + waiter = ResponseWaiter() + proto_handler.set_response_callback(contact_hash, waiter.callback) + try: + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=protocol_code, + data=data, + ) + await self._send_packet(pkt, wait_for_ack=False) + result = await waiter.wait(10.0) + return { + "success": result.get("success", False), + "response": result.get("text"), + "parsed_data": result.get("parsed", {}), + "reason": "Success" if result.get("success") else "Failed", + } + except Exception as e: + logger.error(f"Protocol request error: {e}") + return {"success": False, "reason": str(e)} + finally: + proto_handler.clear_response_callback(contact_hash) + + async def send_repeater_command( + self, pub_key: bytes, command: str, parameters: Optional[str] = None + ) -> dict: + """Send a text-based command to a repeater and wait for the response.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + text_handler = self._get_text_handler() + if not text_handler: + return {"success": False, "reason": "Text handler not available"} + full_command = command + if parameters: + full_command += f" {parameters}" + response_data: dict = {"text": None, "success": False} + response_event = asyncio.Event() + + def _response_cb(message_text: str, sender_contact: Any) -> None: + response_data["text"] = message_text + response_data["success"] = True + response_event.set() + + text_handler.set_command_response_callback(_response_cb) + try: + msg_type = "flood" if proxy.out_path_len < 0 else "direct" + pkt, ack_crc = PacketBuilder.create_text_message( + contact=proxy, + local_identity=self._identity, + message=full_command, + attempt=1, + message_type=msg_type, + ) + await self._send_packet(pkt, wait_for_ack=True) + try: + await asyncio.wait_for(response_event.wait(), timeout=15.0) + except asyncio.TimeoutError: + pass + return { + "success": response_data["success"], + "repeater": contact.name, + "command": command, + "response": response_data["text"], + "reason": ("Command successful" if response_data["success"] else "No response"), + } + except Exception as e: + logger.error(f"Repeater command error: {e}") + return {"success": False, "reason": str(e)} + finally: + text_handler.set_command_response_callback(None) + + def _track_pending_ack(self, ack_crc: int) -> None: + """Hook for subclasses to track pending ACK CRCs. Default is a no-op.""" + def sync_next_message(self) -> Optional[QueuedMessage]: """Pop and return the next queued message, or None.""" return self.message_queue.pop() diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index 042c709..6e39bcb 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -13,42 +13,28 @@ import time from typing import Any, Callable, Optional -from ..node.handlers import ( - AdvertHandler, - GroupTextHandler, - LoginResponseHandler, - PathHandler, - ProtocolResponseHandler, - TextMessageHandler, -) +from ..node.handlers import create_core_handlers from ..node.handlers.login_server import LoginServerHandler -from ..protocol import LocalIdentity, Packet, PacketBuilder +from ..protocol import LocalIdentity, Packet from ..protocol.constants import ( PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, PAYLOAD_TYPE_ANON_REQ, - PAYLOAD_TYPE_CONTROL, PAYLOAD_TYPE_GRP_TXT, PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE, PAYLOAD_TYPE_TXT_MSG, - REQ_TYPE_GET_STATUS, - REQ_TYPE_GET_TELEMETRY_DATA, ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD, ) -from .companion_base import CompanionBase, ResponseWaiter +from .companion_base import CompanionBase from .constants import ( ADV_TYPE_CHAT, DEFAULT_MAX_CHANNELS, DEFAULT_MAX_CONTACTS, DEFAULT_OFFLINE_QUEUE_SIZE, - PROTOCOL_CODE_ANON_REQ, - PROTOCOL_CODE_BINARY_REQ, - PROTOCOL_CODE_RAW_DATA, - TXT_TYPE_PLAIN, ) -from .models import Contact, SentResult +from .models import AdvertPath, Contact logger = logging.getLogger("CompanionBridge") @@ -136,13 +122,21 @@ def _log(msg: str) -> None: self._pending_ack_crcs: set[int] = set() ack_handler = _BridgeAckHandler(self) - protocol_response_handler = ProtocolResponseHandler(_log, identity, self.contacts) - login_response_handler = LoginResponseHandler(identity, self.contacts, _log) - login_response_handler.set_protocol_response_handler(protocol_response_handler) - path_handler = PathHandler( - _log, ack_handler, protocol_response_handler, login_response_handler + + # Use shared factory for the core protocol handlers + core = create_core_handlers( + identity=identity, + contacts=self.contacts, + channels=self.channels, + event_service=self._event_service, + send_packet_fn=_handler_send_packet, + log_fn=_log, + node_name=node_name, + radio_config=self._radio_config, + ack_handler=ack_handler, ) + # Bridge-specific: LoginServerHandler for incoming login requests auth_cb = authenticate_callback if auth_cb is None: @@ -158,33 +152,35 @@ def _reject_all(*args, **kwargs) -> tuple[bool, int]: self._handlers: dict[int, Any] = { PAYLOAD_TYPE_ACK: ack_handler, - PAYLOAD_TYPE_TXT_MSG: TextMessageHandler( - identity, - self.contacts, - _log, - _handler_send_packet, - self._event_service, - self._radio_config, - ), - PAYLOAD_TYPE_ADVERT: AdvertHandler(_log, event_service=self._event_service), - PAYLOAD_TYPE_PATH: path_handler, + PAYLOAD_TYPE_TXT_MSG: core.text_handler, + PAYLOAD_TYPE_ADVERT: core.advert_handler, + PAYLOAD_TYPE_PATH: core.path_handler, PAYLOAD_TYPE_ANON_REQ: login_server_handler, - PAYLOAD_TYPE_GRP_TXT: GroupTextHandler( - identity, - self.contacts, - _log, - _handler_send_packet, - self.channels, - self._event_service, - node_name, - ), - PAYLOAD_TYPE_RESPONSE: login_response_handler, + PAYLOAD_TYPE_GRP_TXT: core.group_text_handler, + PAYLOAD_TYPE_RESPONSE: core.login_response_handler, } - self._protocol_response_handler = protocol_response_handler - self._login_response_handler = login_response_handler - self._text_handler = self._handlers[PAYLOAD_TYPE_TXT_MSG] - protocol_response_handler.set_binary_response_callback(self._on_binary_response) + self._protocol_response_handler = core.protocol_response_handler + self._login_response_handler = core.login_response_handler + self._text_handler_ref = core.text_handler + core.protocol_response_handler.set_binary_response_callback(self._on_binary_response) + + # ------------------------------------------------------------------------- + # Handler accessors (used by CompanionBase concrete send methods) + # ------------------------------------------------------------------------- + + def _get_protocol_response_handler(self) -> Any: + return self._protocol_response_handler + + def _get_login_response_handler(self) -> Any: + return self._login_response_handler + + def _get_text_handler(self) -> Any: + return self._text_handler_ref + + def _track_pending_ack(self, ack_crc: int) -> None: + if len(self._pending_ack_crcs) < MAX_PENDING_ACK_CRCS: + self._pending_ack_crcs.add(ack_crc) # ------------------------------------------------------------------------- # RX Entry Point @@ -211,16 +207,12 @@ async def process_received_packet(self, packet: Packet) -> None: def _update_stores_from_advert(self, packet: Packet, advert_data: dict): """Update ContactStore and PathCache from advert result. Returns the Contact or None.""" try: - from .models import AdvertPath - pub_key = bytes.fromhex(advert_data.get("public_key", "")) if len(pub_key) < 7: return None name = advert_data.get("name", "") if not name: return None - # Inbound path: route the advert took (discovery list / advert path display). - # Stored in path_cache only; contact.out_path is separate (e.g. path discovery). path_len = getattr(packet, "path_len", 0) or 0 path = getattr(packet, "path", bytearray()) or bytearray() effective_len = path_len if path_len > 0 else len(path) @@ -229,7 +221,6 @@ def _update_stores_from_advert(self, packet: Packet, advert_data: dict): last_advert_ts = advert_data.get("advert_timestamp", 0) if last_advert_ts > now: last_advert_ts = now - # Contact: out_path is for sending; leave unknown (-1) until set by path update. contact = Contact( public_key=pub_key, name=name, @@ -243,7 +234,6 @@ def _update_stores_from_advert(self, packet: Packet, advert_data: dict): ) self.contacts.add(contact) - # Path cache: store inbound path for discovery list display. self.path_cache.update( AdvertPath( public_key_prefix=pub_key[:7], @@ -285,147 +275,6 @@ async def stop(self) -> None: def is_running(self) -> bool: return self._running - # ------------------------------------------------------------------------- - # Messaging - # ------------------------------------------------------------------------- - - async def send_text_message( - self, - pub_key: bytes, - text: str, - txt_type: int = TXT_TYPE_PLAIN, - attempt: int = 1, - ) -> SentResult: - contact = self.contacts.get_by_key(pub_key) - if not contact: - logger.warning(f"Contact not found for key {pub_key.hex()[:12]}...") - return SentResult(success=False) - proxy = self.contacts.get_by_name(contact.name) - if not proxy: - return SentResult(success=False) - try: - is_flood = proxy.out_path_len < 0 - msg_type = "flood" if is_flood else "direct" - pkt, ack_crc = PacketBuilder.create_text_message( - contact=proxy, - local_identity=self._identity, - message=text, - attempt=attempt, - message_type=msg_type, - ) - self._apply_flood_scope(pkt) - if len(self._pending_ack_crcs) < MAX_PENDING_ACK_CRCS: - self._pending_ack_crcs.add(ack_crc) - success = await self._packet_injector(pkt, wait_for_ack=True) - if success: - self.stats.record_tx(is_flood=is_flood) - else: - self.stats.record_tx_error() - return SentResult( - success=success, - is_flood=is_flood, - expected_ack=ack_crc, - timeout_ms=None, - ) - except Exception as e: - logger.error(f"Error sending text message: {e}") - self.stats.record_tx_error() - return SentResult(success=False) - - async def send_channel_message(self, channel_idx: int, text: str) -> bool: - channel = self.channels.get(channel_idx) - if not channel: - logger.warning(f"Channel {channel_idx} not found") - return False - try: - pkt = PacketBuilder.create_group_datagram( - group_name=channel.name, - local_identity=self._identity, - message=text, - sender_name=self.prefs.node_name, - channels_config=self.channels.get_channels(), - ) - self._apply_flood_scope(pkt) - success = await self._packet_injector(pkt, wait_for_ack=False) - if success: - self.stats.record_tx(is_flood=True) - else: - self.stats.record_tx_error() - return success - except Exception as e: - logger.error(f"Error sending channel message: {e}") - self.stats.record_tx_error() - return False - - async def send_raw_data( - self, - dest_key: bytes, - data: bytes, - path: Optional[bytes] = None, - ) -> SentResult: - contact = self.contacts.get_by_key(dest_key) - if not contact: - return SentResult(success=False) - try: - proxy = self.contacts.get_by_name(contact.name) - if not proxy: - return SentResult(success=False) - pkt, _ = PacketBuilder.create_protocol_request( - contact=proxy, - local_identity=self._identity, - protocol_code=PROTOCOL_CODE_RAW_DATA, - data=data, - ) - success = await self._packet_injector(pkt, wait_for_ack=False) - return SentResult(success=success) - except Exception as e: - logger.error(f"Error sending raw data: {e}") - return SentResult(success=False) - - # ------------------------------------------------------------------------- - # Path & Routing - # ------------------------------------------------------------------------- - - async def send_trace_path( - self, - pub_key: bytes, - tag: int, - auth_code: int, - flags: int = 0, - ) -> bool: - contact = self.contacts.get_by_key(pub_key) - if not contact: - return False - path = list(contact.out_path) if contact.out_path else [] - if not path: - path = [contact.public_key[0]] - try: - pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=path) - return await self._packet_injector(pkt, wait_for_ack=False) - except Exception as e: - logger.error(f"Error sending trace: {e}") - return False - - async def send_control_data(self, data: bytes) -> bool: - """Send CONTROL packet (e.g. discovery). data = flags (0x80 for DISCOVER_REQ) + payload. - Returns True if sent. - """ - if not data or len(data) > 254: - return False - if (data[0] & 0x80) == 0: - return False # firmware requires first byte to have 0x80 set (e.g. DISCOVER_REQ) - try: - pkt = Packet() - pkt.header = PacketBuilder._create_header(PAYLOAD_TYPE_CONTROL, route_type="direct") - pkt.path_len = 0 - pkt.path = bytearray() - pkt.payload = bytearray(data) - pkt.payload_len = len(data) - return await self._packet_injector(pkt, wait_for_ack=False) - except Exception as e: - logger.error(f"Error sending control data: {e}") - return False - # ------------------------------------------------------------------------- # Key Management # ------------------------------------------------------------------------- @@ -438,219 +287,3 @@ def import_private_key(self, key: bytes) -> bool: except Exception as e: logger.error(f"Error importing private key: {e}") return False - - # ------------------------------------------------------------------------- - # Requests - # ------------------------------------------------------------------------- - - async def send_login(self, pub_key: bytes, password: str) -> dict: - contact = self.contacts.get_by_key(pub_key) - if not contact: - return {"success": False, "reason": "Contact not found"} - proxy = self.contacts.get_by_name(contact.name) - if not proxy: - return {"success": False, "reason": "Contact not found"} - dest_hash = bytes.fromhex(proxy.public_key)[0] - self._login_response_handler.store_login_password(dest_hash, password) - login_result = {"success": False, "data": {}} - login_event = asyncio.Event() - - def _login_cb(success: bool, data: dict) -> None: - login_result["success"] = success - login_result["data"] = data - login_event.set() - - self._login_response_handler.set_login_callback(_login_cb) - try: - pkt = PacketBuilder.create_login_packet( - contact=proxy, local_identity=self._identity, password=password - ) - await self._packet_injector(pkt, wait_for_ack=False) - try: - await asyncio.wait_for(login_event.wait(), timeout=10.0) - except asyncio.TimeoutError: - return {"success": False, "reason": "Login response timeout"} - data = login_result["data"] - return { - "success": login_result["success"], - "repeater": contact.name, - "is_admin": data.get("is_admin", False), - "keep_alive_interval": data.get("keep_alive_interval", 0), - "tag": data.get("timestamp", 0), - "acl_permissions": data.get("reserved", data.get("permissions", 0)), - "reason": "Login successful" if login_result["success"] else "Login failed", - } - except Exception as e: - logger.error(f"Login error: {e}") - return {"success": False, "reason": str(e)} - finally: - self._login_response_handler.set_login_callback(None) - self._login_response_handler.clear_login_password(dest_hash) - - async def send_status_request(self, pub_key: bytes, timeout: float = 15.0) -> dict: - """Send a protocol request for repeater stats (REQ_TYPE_GET_STATUS). - - The firmware handles CMD_SEND_STATUS_REQ by calling - ``sendRequest(*recipient, REQ_TYPE_GET_STATUS, tag, est_timeout)`` - which creates a PAYLOAD_TYPE_REQ packet. The remote repeater replies - with a PAYLOAD_TYPE_RESPONSE containing ``reflected_timestamp(4) + - RepeaterStats(48)``. - """ - contact = self.contacts.get_by_key(pub_key) - if not contact: - return {"success": False, "reason": "Contact not found"} - proxy = self.contacts.get_by_name(contact.name) - if not proxy: - return {"success": False, "reason": "Contact not found"} - contact_hash = bytes.fromhex(proxy.public_key)[0] - waiter = ResponseWaiter() - self._protocol_response_handler.set_response_callback(contact_hash, waiter.callback) - try: - pkt, _ = PacketBuilder.create_protocol_request( - contact=proxy, - local_identity=self._identity, - protocol_code=REQ_TYPE_GET_STATUS, - data=b"", - ) - await self._packet_injector(pkt, wait_for_ack=False) - result = await waiter.wait(timeout) - return { - "success": result.get("success", False), - "repeater": contact.name, - "stats": result.get("parsed", {}), - "response_text": result.get("text"), - "reason": "Stats received" if result.get("success") else "Stats request failed", - } - except Exception as e: - logger.error(f"Status request error: {e}") - return {"success": False, "reason": str(e)} - finally: - self._protocol_response_handler.clear_response_callback(contact_hash) - - async def send_telemetry_request( - self, - pub_key: bytes, - want_base: bool = True, - want_location: bool = True, - want_environment: bool = True, - timeout: float = 10.0, - ) -> dict: - contact = self.contacts.get_by_key(pub_key) - if not contact: - return {"success": False, "reason": "Contact not found"} - proxy = self.contacts.get_by_name(contact.name) - if not proxy: - return {"success": False, "reason": "Contact not found"} - contact_hash = bytes.fromhex(proxy.public_key)[0] - waiter = ResponseWaiter() - self._protocol_response_handler.set_response_callback(contact_hash, waiter.callback) - try: - inv = PacketBuilder._compute_inverse_perm_mask( - want_base, want_location, want_environment - ) - pkt, _ = PacketBuilder.create_protocol_request( - contact=proxy, - local_identity=self._identity, - protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, - data=bytes([inv]), - ) - await self._packet_injector(pkt, wait_for_ack=False) - result = await waiter.wait(timeout) - return { - "success": result.get("success", False), - "contact": contact.name, - "telemetry_data": result.get("parsed", {}), - "response_text": result.get("text"), - "reason": "Telemetry received" if result.get("success") else "Telemetry failed", - } - except Exception as e: - logger.error(f"Telemetry error: {e}") - return {"success": False, "reason": str(e)} - finally: - self._protocol_response_handler.clear_response_callback(contact_hash) - - async def send_binary_request(self, pub_key: bytes, data: bytes) -> dict: - """Legacy: send binary request and wait. Prefer send_binary_req + on_binary_response.""" - return await self._send_protocol_request(pub_key, PROTOCOL_CODE_BINARY_REQ, data) - - async def send_anon_request(self, pub_key: bytes, data: bytes) -> dict: - return await self._send_protocol_request(pub_key, PROTOCOL_CODE_ANON_REQ, data) - - async def _send_protocol_request(self, pub_key: bytes, protocol_code: int, data: bytes) -> dict: - contact = self.contacts.get_by_key(pub_key) - if not contact: - return {"success": False, "reason": "Contact not found"} - proxy = self.contacts.get_by_name(contact.name) - if not proxy: - return {"success": False, "reason": "Contact not found"} - contact_hash = bytes.fromhex(proxy.public_key)[0] - waiter = ResponseWaiter() - self._protocol_response_handler.set_response_callback(contact_hash, waiter.callback) - try: - pkt, _ = PacketBuilder.create_protocol_request( - contact=proxy, - local_identity=self._identity, - protocol_code=protocol_code, - data=data, - ) - await self._packet_injector(pkt, wait_for_ack=False) - result = await waiter.wait(10.0) - return { - "success": result.get("success", False), - "response": result.get("text"), - "parsed_data": result.get("parsed", {}), - "reason": "Success" if result.get("success") else "Failed", - } - except Exception as e: - logger.error(f"Protocol request error: {e}") - return {"success": False, "reason": str(e)} - finally: - self._protocol_response_handler.clear_response_callback(contact_hash) - - async def send_repeater_command( - self, pub_key: bytes, command: str, parameters: Optional[str] = None - ) -> dict: - contact = self.contacts.get_by_key(pub_key) - if not contact: - return {"success": False, "reason": "Contact not found"} - proxy = self.contacts.get_by_name(contact.name) - if not proxy: - return {"success": False, "reason": "Contact not found"} - full_command = command - if parameters: - full_command += f" {parameters}" - response_data = {"text": None, "success": False} - response_event = asyncio.Event() - - def _response_cb(message_text: str, sender_contact: Any) -> None: - response_data["text"] = message_text - response_data["success"] = True - response_event.set() - - self._text_handler.set_command_response_callback(_response_cb) - try: - msg_type = "flood" if proxy.out_path_len < 0 else "direct" - pkt, ack_crc = PacketBuilder.create_text_message( - contact=proxy, - local_identity=self._identity, - message=full_command, - attempt=1, - message_type=msg_type, - ) - await self._packet_injector(pkt, wait_for_ack=True) - try: - await asyncio.wait_for(response_event.wait(), timeout=15.0) - except asyncio.TimeoutError: - pass - return { - "success": response_data["success"], - "repeater": contact.name, - "command": command, - "response": response_data["text"], - "reason": "Command successful" if response_data["success"] else "No response", - } - except Exception as e: - logger.error(f"Repeater command error: {e}") - return {"success": False, "reason": str(e)} - finally: - self._text_handler.set_command_response_callback(None) diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py index f98f839..c810cfd 100644 --- a/src/pymc_core/companion/companion_radio.py +++ b/src/pymc_core/companion/companion_radio.py @@ -11,24 +11,18 @@ import asyncio import logging -import random from typing import Any, Optional from ..node.node import MeshNode -from ..protocol import LocalIdentity, Packet, PacketBuilder -from ..protocol.constants import PAYLOAD_TYPE_CONTROL +from ..protocol import LocalIdentity, Packet +from ..protocol.constants import ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD from .companion_base import CompanionBase from .constants import ( ADV_TYPE_CHAT, DEFAULT_MAX_CHANNELS, DEFAULT_MAX_CONTACTS, DEFAULT_OFFLINE_QUEUE_SIZE, - PROTOCOL_CODE_ANON_REQ, - PROTOCOL_CODE_BINARY_REQ, - PROTOCOL_CODE_RAW_DATA, - TXT_TYPE_PLAIN, ) -from .models import SentResult logger = logging.getLogger("CompanionRadio") @@ -107,6 +101,19 @@ async def _send_packet(self, pkt: Packet, wait_for_ack: bool = False) -> bool: """Send a packet via the MeshNode dispatcher.""" return await self.node.dispatcher.send_packet(pkt, wait_for_ack=wait_for_ack) + # ------------------------------------------------------------------------- + # Handler accessors (used by CompanionBase concrete send methods) + # ------------------------------------------------------------------------- + + def _get_protocol_response_handler(self) -> Any: + return getattr(self.node.dispatcher, "protocol_response_handler", None) + + def _get_login_response_handler(self) -> Any: + return getattr(self.node.dispatcher, "login_response_handler", None) + + def _get_text_handler(self) -> Any: + return getattr(self.node.dispatcher, "text_message_handler", None) + # ------------------------------------------------------------------------- # Lifecycle # ------------------------------------------------------------------------- @@ -138,86 +145,6 @@ async def stop(self) -> None: def is_running(self) -> bool: return self._running - # ------------------------------------------------------------------------- - # Messaging - # ------------------------------------------------------------------------- - - async def send_text_message( - self, - pub_key: bytes, - text: str, - txt_type: int = TXT_TYPE_PLAIN, - attempt: int = 1, - ) -> SentResult: - contact = self.contacts.get_by_key(pub_key) - if not contact: - logger.warning(f"Contact not found for key {pub_key.hex()[:12]}...") - return SentResult(success=False) - try: - result = await self.node.send_text( - contact_name=contact.name, - message=text, - attempt=attempt, - ) - success = result.get("success", False) - is_flood = contact.out_path_len <= 0 - if success: - self.stats.record_tx(is_flood=is_flood) - else: - self.stats.record_tx_error() - return SentResult( - success=success, - is_flood=is_flood, - expected_ack=result.get("crc"), - timeout_ms=None, - ) - except Exception as e: - logger.error(f"Error sending text message: {e}") - self.stats.record_tx_error() - return SentResult(success=False) - - async def send_channel_message(self, channel_idx: int, text: str) -> bool: - channel = self.channels.get(channel_idx) - if not channel: - logger.warning(f"Channel {channel_idx} not found") - return False - try: - result = await self.node.send_group_text( - group_name=channel.name, - message=text, - ) - success = result.get("success", False) - if success: - self.stats.record_tx(is_flood=True) - else: - self.stats.record_tx_error() - return success - except Exception as e: - logger.error(f"Error sending channel message: {e}") - self.stats.record_tx_error() - return False - - async def send_raw_data( - self, - dest_key: bytes, - data: bytes, - path: Optional[bytes] = None, - ) -> SentResult: - contact = self.contacts.get_by_key(dest_key) - if not contact: - logger.warning(f"Contact not found for raw data send: {dest_key.hex()[:12]}") - return SentResult(success=False) - try: - result = await self.node.send_protocol_request( - repeater_name=contact.name, - protocol_code=PROTOCOL_CODE_RAW_DATA, - data=data, - ) - return SentResult(success=result.get("success", False)) - except Exception as e: - logger.error(f"Error sending raw data: {e}") - return SentResult(success=False) - # ------------------------------------------------------------------------- # Flood Scope (sync to dispatcher) # ------------------------------------------------------------------------- @@ -267,33 +194,6 @@ def set_tx_power(self, power_dbm: int) -> bool: return False return True - # ------------------------------------------------------------------------- - # Path & Routing - # ------------------------------------------------------------------------- - - async def send_trace_path( - self, - pub_key: bytes, - tag: int, - auth_code: int, - flags: int = 0, - ) -> bool: - contact = self.contacts.get_by_key(pub_key) - if not contact: - logger.warning(f"Contact not found for trace: {pub_key.hex()[:12]}") - return False - try: - result = await self.node.send_trace_packet( - contact_name=contact.name, - tag=tag, - auth_code=auth_code, - flags=flags, - ) - return result.get("success", False) - except Exception as e: - logger.error(f"Error sending trace: {e}") - return False - # ------------------------------------------------------------------------- # Key Management # ------------------------------------------------------------------------- @@ -319,138 +219,6 @@ def import_private_key(self, key: bytes) -> bool: logger.error(f"Error importing private key: {e}") return False - # ------------------------------------------------------------------------- - # Requests - # ------------------------------------------------------------------------- - - async def send_login(self, pub_key: bytes, password: str) -> dict: - contact = self.contacts.get_by_key(pub_key) - if not contact: - return {"success": False, "reason": "Contact not found"} - try: - return await self.node.send_login( - repeater_name=contact.name, - password=password, - ) - except Exception as e: - logger.error(f"Login error: {e}") - return {"success": False, "reason": str(e)} - - async def send_status_request(self, pub_key: bytes) -> dict: - contact = self.contacts.get_by_key(pub_key) - if not contact: - return {"success": False, "reason": "Contact not found"} - try: - return await self.node.send_status_request(repeater_name=contact.name) - except Exception as e: - logger.error(f"Status request error: {e}") - return {"success": False, "reason": str(e)} - - async def send_telemetry_request( - self, - pub_key: bytes, - want_base: bool = True, - want_location: bool = True, - want_environment: bool = True, - timeout: float = 10.0, - ) -> dict: - contact = self.contacts.get_by_key(pub_key) - if not contact: - return {"success": False, "reason": "Contact not found"} - try: - return await self.node.send_telemetry_request( - contact_name=contact.name, - want_base=want_base, - want_location=want_location, - want_environment=want_environment, - timeout=timeout, - ) - except Exception as e: - logger.error(f"Telemetry request error: {e}") - return {"success": False, "reason": str(e)} - - async def send_binary_request(self, pub_key: bytes, data: bytes) -> dict: - """Legacy: send binary request and wait. Prefer send_binary_req + on_binary_response.""" - contact = self.contacts.get_by_key(pub_key) - if not contact: - return {"success": False, "reason": "Contact not found"} - try: - return await self.node.send_protocol_request( - repeater_name=contact.name, - protocol_code=PROTOCOL_CODE_BINARY_REQ, - data=data, - ) - except Exception as e: - logger.error(f"Binary request error: {e}") - return {"success": False, "reason": str(e)} - - async def send_anon_request(self, pub_key: bytes, data: bytes) -> dict: - contact = self.contacts.get_by_key(pub_key) - if not contact: - return {"success": False, "reason": "Contact not found"} - try: - return await self.node.send_protocol_request( - repeater_name=contact.name, - protocol_code=PROTOCOL_CODE_ANON_REQ, - data=data, - ) - except Exception as e: - logger.error(f"Anon request error: {e}") - return {"success": False, "reason": str(e)} - - async def send_repeater_command( - self, pub_key: bytes, command: str, parameters: Optional[str] = None - ) -> dict: - """Send a text-based command to a repeater and await response.""" - contact = self.contacts.get_by_key(pub_key) - if not contact: - return {"success": False, "reason": "Contact not found"} - try: - result = await self.node.send_repeater_command( - repeater_name=contact.name, - command=command, - parameters=parameters, - ) - reason = "Command successful" if result.get("success") else "No response" - return { - "success": result.get("success", False), - "repeater": contact.name, - "command": command, - "response": result.get("response"), - "reason": reason, - } - except Exception as e: - logger.error(f"Repeater command error: {e}") - return {"success": False, "reason": str(e)} - - # ------------------------------------------------------------------------- - # Control Data - # ------------------------------------------------------------------------- - - async def send_control_data(self, data: Optional[bytes] = None) -> bool: - """Send CONTROL packet. If data valid (len 1-254, byte0 0x80), send as control payload; - else send default discovery request (backward compat). - """ - if data and len(data) <= 254 and (data[0] & 0x80) != 0: - try: - pkt = Packet() - pkt.header = PacketBuilder._create_header(PAYLOAD_TYPE_CONTROL, route_type="direct") - pkt.path_len = 0 - pkt.path = bytearray() - pkt.payload = bytearray(data) - pkt.payload_len = len(data) - return await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) - except Exception as e: - logger.error(f"Error sending control data: {e}") - return False - try: - tag = random.randint(0, 0xFFFFFFFF) - pkt = PacketBuilder.create_discovery_request(tag, filter_mask=0x04) - return await self.node.dispatcher.send_packet(pkt, wait_for_ack=False) - except Exception as e: - logger.error(f"Error sending control data: {e}") - return False - # ------------------------------------------------------------------------- # Statistics (override for radio hardware) # ------------------------------------------------------------------------- @@ -480,8 +248,6 @@ def _setup_packet_callbacks(self) -> None: ) async def _on_packet_received(self, pkt: Any) -> None: - from ..protocol.constants import ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD - route_type = pkt.get_route_type() is_flood = route_type in (ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD) self.stats.record_rx(is_flood=is_flood) diff --git a/src/pymc_core/node/dispatcher.py b/src/pymc_core/node/dispatcher.py index 9bb7ef8..575382b 100644 --- a/src/pymc_core/node/dispatcher.py +++ b/src/pymc_core/node/dispatcher.py @@ -19,15 +19,10 @@ # Import handler classes from .handlers import ( AckHandler, - AdvertHandler, AnonReqResponseHandler, ControlHandler, - GroupTextHandler, - LoginResponseHandler, - PathHandler, - ProtocolResponseHandler, - TextMessageHandler, TraceHandler, + create_core_handlers, ) ACK_TIMEOUT = 5.0 # seconds to wait for an ACK @@ -167,94 +162,56 @@ def register_default_handlers( # Keep our identity handy for detecting our own packets self.local_identity = local_identity - # Set up ACK handler with callback to us + # --- ACK handler (dispatcher-specific wiring) --- ack_handler = AckHandler(self._log, self) ack_handler.set_ack_received_callback(self._register_ack_received) - - # Register all the standard handlers - self.register_handler( - AdvertHandler.payload_type(), - AdvertHandler(self._log, event_service=event_service), - ) self.register_handler(AckHandler.payload_type(), ack_handler) - # Text message handler - needs to send ACKs back through us - text_message_handler = TextMessageHandler( - local_identity, - contacts, - self._log, - self.send_packet, - event_service, - radio_config, - ) - # Keep a reference so the node can use it - self.text_message_handler = text_message_handler - self.register_handler( - TextMessageHandler.payload_type(), - text_message_handler, - ) - # Group text handler with channel database - self.register_handler( - GroupTextHandler.payload_type(), - GroupTextHandler( - local_identity, - contacts, - self._log, - self.send_packet, - channel_db, - event_service, - node_name, - ), - ) - # Protocol response handler for encrypted responses (including telemetry) - protocol_response_handler = ProtocolResponseHandler(self._log, local_identity, contacts) - # Keep a reference for the node - self.protocol_response_handler = protocol_response_handler - - # Login response handler for PAYLOAD_TYPE_RESPONSE packets - login_response_handler = LoginResponseHandler(local_identity, contacts, self._log) - # Connect protocol response handler for forwarding telemetry - login_response_handler.set_protocol_response_handler(protocol_response_handler) - # Keep references for backward compatibility - # Note: telemetry now uses protocol_response_handler, login uses PAYLOAD_TYPE_RESPONSE - self.login_response_handler = login_response_handler - # For backward compatibility, point telemetry handler to protocol response handler - self.telemetry_response_handler = protocol_response_handler - - # PATH handler - for route discovery packets, with ACK and protocol response processing - path_handler = PathHandler( - self._log, ack_handler, protocol_response_handler, login_response_handler + # --- Core handlers via shared factory --- + core = create_core_handlers( + identity=local_identity, + contacts=contacts, + channels=channel_db, + event_service=event_service, + send_packet_fn=self.send_packet, + log_fn=self._log, + node_name=node_name, + radio_config=radio_config, + ack_handler=ack_handler, ) - self.register_handler(PathHandler.payload_type(), path_handler) - # Login response handler for PAYLOAD_TYPE_RESPONSE packets - self.register_handler( - LoginResponseHandler.payload_type(), - login_response_handler, - ) - - # Anonymous request response handler for login responses that come as ANON_REQ + # Keep references for companion layer access + self.text_message_handler = core.text_handler + self.protocol_response_handler = core.protocol_response_handler + self.login_response_handler = core.login_response_handler + # Backward compat alias + self.telemetry_response_handler = core.protocol_response_handler + + # Register core handlers by payload type + from .handlers import AdvertHandler as _Adv + from .handlers import GroupTextHandler as _Grp + from .handlers import LoginResponseHandler as _Login + from .handlers import PathHandler as _Path + from .handlers import TextMessageHandler as _Txt + + self.register_handler(_Adv.payload_type(), core.advert_handler) + self.register_handler(_Txt.payload_type(), core.text_handler) + self.register_handler(_Grp.payload_type(), core.group_text_handler) + self.register_handler(_Path.payload_type(), core.path_handler) + self.register_handler(_Login.payload_type(), core.login_response_handler) + + # --- Dispatcher-only handlers --- self.register_handler( AnonReqResponseHandler.payload_type(), AnonReqResponseHandler(local_identity, contacts, self._log), ) - # TRACE handler for diagnostics and routing analysis - trace_handler = TraceHandler(self._log, protocol_response_handler) - self.register_handler( - TraceHandler.payload_type(), - trace_handler, - ) - # Keep a reference for the node + trace_handler = TraceHandler(self._log, core.protocol_response_handler) + self.register_handler(TraceHandler.payload_type(), trace_handler) self.trace_handler = trace_handler - # CONTROL handler for node discovery control_handler = ControlHandler(self._log) - self.register_handler( - ControlHandler.payload_type(), - control_handler, - ) - # Keep a reference for the node + self.register_handler(ControlHandler.payload_type(), control_handler) self.control_handler = control_handler self._logger.info("Default handlers registered.") diff --git a/src/pymc_core/node/handlers/__init__.py b/src/pymc_core/node/handlers/__init__.py index 27fc5c5..6bad086 100644 --- a/src/pymc_core/node/handlers/__init__.py +++ b/src/pymc_core/node/handlers/__init__.py @@ -10,6 +10,7 @@ from .login_response import AnonReqResponseHandler, LoginResponseHandler from .path import PathHandler from .protocol_response import ProtocolResponseHandler +from .registry import CoreHandlers, create_core_handlers from .text import TextMessageHandler from .trace import TraceHandler @@ -25,4 +26,6 @@ "AnonReqResponseHandler", "TraceHandler", "ControlHandler", + "CoreHandlers", + "create_core_handlers", ] diff --git a/src/pymc_core/node/handlers/registry.py b/src/pymc_core/node/handlers/registry.py new file mode 100644 index 0000000..746188f --- /dev/null +++ b/src/pymc_core/node/handlers/registry.py @@ -0,0 +1,104 @@ +"""Handler registry for creating and wiring standard MeshCore protocol handlers. + +Both the :class:`Dispatcher` and :class:`CompanionBridge` need the same core +set of handlers — this module provides a shared factory so handler creation +and cross-wiring only lives in one place. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Optional + +from .advert import AdvertHandler +from .group_text import GroupTextHandler +from .login_response import LoginResponseHandler +from .path import PathHandler +from .protocol_response import ProtocolResponseHandler +from .text import TextMessageHandler + + +@dataclass +class CoreHandlers: + """Bundle of the core protocol handlers shared by Dispatcher and Bridge.""" + + text_handler: TextMessageHandler + advert_handler: AdvertHandler + group_text_handler: GroupTextHandler + protocol_response_handler: ProtocolResponseHandler + login_response_handler: LoginResponseHandler + path_handler: PathHandler + + +def create_core_handlers( + *, + identity: Any, + contacts: Any, + channels: Any, + event_service: Any, + send_packet_fn: Callable, + log_fn: Callable, + node_name: str, + radio_config: Optional[dict] = None, + ack_handler: Any = None, +) -> CoreHandlers: + """Create and wire the standard set of MeshCore protocol handlers. + + This is the single source of truth for handler construction. Both + :meth:`Dispatcher.register_default_handlers` and + :class:`CompanionBridge.__init__` delegate here. + + Args: + identity: The local identity for encryption/signing. + contacts: Contact storage. + channels: Channel database. + event_service: Event service for broadcasting mesh events. + send_packet_fn: Async callable to send a packet (the transport). + log_fn: Logging callable (``str -> None``). + node_name: Human-readable node name. + radio_config: Optional radio configuration dict. + ack_handler: ACK handler instance (varies between Dispatcher and + Bridge). If ``None``, the :class:`PathHandler` is constructed + without ACK forwarding. + """ + protocol_response_handler = ProtocolResponseHandler(log_fn, identity, contacts) + + login_response_handler = LoginResponseHandler(identity, contacts, log_fn) + login_response_handler.set_protocol_response_handler(protocol_response_handler) + + path_handler = PathHandler( + log_fn, + ack_handler, + protocol_response_handler, + login_response_handler, + ) + + text_handler = TextMessageHandler( + identity, + contacts, + log_fn, + send_packet_fn, + event_service, + radio_config, + ) + + advert_handler = AdvertHandler(log_fn, event_service=event_service) + + group_text_handler = GroupTextHandler( + identity, + contacts, + log_fn, + send_packet_fn, + channels, + event_service, + node_name, + ) + + return CoreHandlers( + text_handler=text_handler, + advert_handler=advert_handler, + group_text_handler=group_text_handler, + protocol_response_handler=protocol_response_handler, + login_response_handler=login_response_handler, + path_handler=path_handler, + ) diff --git a/src/pymc_core/node/node.py b/src/pymc_core/node/node.py index 266be07..9d124d6 100644 --- a/src/pymc_core/node/node.py +++ b/src/pymc_core/node/node.py @@ -17,14 +17,18 @@ class MeshNode: - """Represents a node in a mesh network for radio communication. + """Thin transport layer for mesh radio communication. - Manages radio communication, message routing, and protocol handling - within a mesh network. Provides high-level APIs for sending messages, - telemetry requests, and commands to other nodes and repeaters. + Owns a radio interface and a :class:`Dispatcher` that handles raw packet + I/O (TX lock, ACK management, handler dispatch). Application-layer + concerns — contact lookup, message building, response waiting — belong in + the companion layer (:class:`CompanionBase` and its subclasses). - The node integrates with various components like contact storage, - channel databases, and event services for comprehensive mesh functionality. + Typical usage:: + + node = MeshNode(radio, identity, config={...}) + await node.start() # blocks in dispatcher.run_forever() + node.stop() """ def __init__( @@ -40,9 +44,6 @@ def __init__( ) -> None: """Initialise a mesh network node instance. - Sets up the node's core components including radio interface, - identity management, and communication handlers. - Args: radio: Radio hardware interface for transmission/reception. local_identity: Node's cryptographic identity for secure communication. @@ -80,73 +81,41 @@ def __init__( node_name=self.node_name, radio_config=self.radio_config, ) - # Store reference to text handler for command response callbacks - self._text_handler = None - - # Helper Methods - def _find_and_call_handler_method(self, method_name: str, *args, **kwargs) -> bool: - """Find and call a method on any handler that has it. Returns True if called.""" - found = False - - if hasattr(self.dispatcher, "_handler_instances"): - for handler in self.dispatcher._handler_instances.values(): - if hasattr(handler, method_name): - getattr(handler, method_name)(*args, **kwargs) - found = True - else: - for attr_name in dir(self.dispatcher): - if attr_name.endswith("_handler"): - handler = getattr(self.dispatcher, attr_name, None) - if handler and hasattr(handler, method_name): - getattr(handler, method_name)(*args, **kwargs) - found = True - - return found - def _get_contact_or_raise(self, contact_name: str): - """Get contact by name or raise RuntimeError if not found.""" - contact = self.contacts.get_by_name(contact_name) if self.contacts else None - if not contact: - raise RuntimeError(f"No contact '{contact_name}'") - return contact + # ------------------------------------------------------------------------- + # Lifecycle + # ------------------------------------------------------------------------- - class _ResponseWaiter: - """Helper class for managing asynchronous response callbacks. + async def start(self) -> None: + """Start the mesh node and begin processing radio communications. - Provides a synchronisation mechanism for waiting on responses - from remote nodes with timeout support. + Enters the dispatcher's main event loop for handling incoming/outgoing + messages. This method blocks until the node is stopped. """ + await self.dispatcher.run_forever() - def __init__(self): - self.event = asyncio.Event() - self.data = {"success": False, "text": None, "parsed": {}} - - def callback(self, success: bool, text: str, parsed_data: Optional[dict] = None): - """Standard callback for response handlers.""" - self.data["success"] = success - self.data["text"] = text - self.data["parsed"] = parsed_data or {} - self.event.set() + def stop(self): + """Stop the mesh node and clean up associated services.""" + try: + self.logger.info("Node stopped") + except Exception as e: + self.logger.error(f"Error stopping node: {e}") - async def wait(self, timeout: float = 10.0) -> dict: - """Wait for response with timeout. Returns the response data.""" - try: - await asyncio.wait_for(self.event.wait(), timeout=timeout) - return self.data - except asyncio.TimeoutError: - return {"success": False, "text": None, "parsed": {}, "timeout": True} + # ------------------------------------------------------------------------- + # Transport + # ------------------------------------------------------------------------- - def _time_operation(self): - """Context manager for timing operations and calculating RTT.""" - import time - from contextlib import contextmanager + async def send_packet(self, pkt: Any, *, wait_for_ack: bool = False, **kwargs) -> bool: + """Send a raw packet via the dispatcher. - @contextmanager - def timer(): - start_time = time.time() - yield lambda: (time.time() - start_time) * 1000 # RTT in milliseconds + This is the single transport entry point. All message-building and + response-waiting logic lives in the companion layer. + """ + return await self.dispatcher.send_packet(pkt, wait_for_ack=wait_for_ack, **kwargs) - return timer() + # ------------------------------------------------------------------------- + # Event service propagation + # ------------------------------------------------------------------------- def set_event_service(self, event_service): """Set the event service for broadcasting mesh events.""" @@ -165,837 +134,33 @@ def set_event_service(self, event_service): if handler and hasattr(handler, "event_service"): handler.event_service = event_service - async def start(self) -> None: - """Start the mesh node and begin processing radio communications. - - Initialises the radio interface and dispatcher, then enters the main - event loop for handling incoming/outgoing messages. This method blocks - until the node is stopped. - - Note: - This is an asynchronous operation that runs indefinitely until - cancelled or the node is stopped. - """ - await self.dispatcher.run_forever() - - async def send_text( - self, - contact_name: str, - message: str, - attempt: int = 1, - message_type: str = "direct", - out_path: Optional[list] = None, - ) -> dict: - """Send a text message to a specified contact. - - Transmits a text message to another node in the mesh network, - with optional routing and retry configuration. - - Args: - contact_name: Name of the target contact in the contact book. - message: Text content to send. - attempt: Message attempt number for retry logic (default: 1). - message_type: Routing type - "direct" or other supported types. - out_path: Optional list of intermediate nodes for routing. - - Returns: - Dictionary containing transmission results including success status, - signal strength metrics (SNR/RSSI), and routing information. - - Raises: - RuntimeError: If the specified contact is not found. - - Example: - ```python - result = await node.send_text("alice", "Hello, world!") - print(result["success"]) # True if message sent successfully - ``` - """ - from ..protocol import PacketBuilder - - contact = self._get_contact_or_raise(contact_name) - - # Create the text message packet using PacketBuilder - pkt, ack_crc = PacketBuilder.create_text_message( - contact=contact, - local_identity=self.identity, - message=message, - attempt=attempt, - message_type=message_type, - out_path=out_path, - ) - - # Log packet details with routing info - routing_info = f"type={message_type}" - if out_path: - routing_info += f", path={' -> '.join(str(hop) for hop in out_path)}" - - self.logger.debug( - f"[send_text] -> {contact_name} msg='{message}' CRC={ack_crc:08X} ({routing_info})" - ) - - # Send packet with the expected ACK CRC - success = await self.dispatcher.send_packet(pkt, wait_for_ack=True, expected_crc=ack_crc) - if not success: - self.logger.warning(f"No ACK received for CRC {ack_crc:08X}") - - # Extract signal strength information from radio - snr = getattr(pkt, "snr", None) if "pkt" in locals() else None - rssi = None - - # Get current signal strength from radio for outgoing messages - if hasattr(self.radio, "get_last_rssi"): - rssi = self.radio.get_last_rssi() - if hasattr(self.radio, "get_last_snr") and snr is None: - snr = self.radio.get_last_snr() - - return { - "success": success, - "attempt": attempt, - "message_type": message_type, - "out_path": out_path, - "snr": snr, - "rssi": rssi, - "crc": ack_crc, - } - - async def send_telemetry_request( - self, - contact_name: str, - want_base: bool = True, - want_location: bool = True, - want_environment: bool = True, - timeout: float = 10.0, - ) -> dict: - """Request telemetry data from a contact node. - - Sends a telemetry request and waits for the target node to respond - with requested sensor data including base metrics, location, and - environmental readings. - - Args: - contact_name: Name of the contact to query. - want_base: Include basic telemetry metrics in request. - want_location: Include GPS/location data in request. - want_environment: Include environmental sensors in request. - timeout: Maximum time to wait for response in seconds. - - Returns: - Dictionary with request results, telemetry data, and performance - metrics including round-trip time. - - Raises: - RuntimeError: If contact not found or protocol handler unavailable. - - Example: - ```python - result = await node.send_telemetry_request("sensor_node") - if result["success"]: - print(f"Temperature: {result['telemetry_data'].get('temp')}") - ``` - """ - from ..protocol import PacketBuilder - from ..protocol.constants import REQ_TYPE_GET_TELEMETRY_DATA - - contact = self._get_contact_or_raise(contact_name) - - with self._time_operation() as get_rtt: - contact_hash = bytes.fromhex(contact.public_key)[0] - - # Set up response waiting - waiter = self._ResponseWaiter() + # ------------------------------------------------------------------------- + # Backwards-compatible utilities (deprecated — prefer companion layer) + # ------------------------------------------------------------------------- - # Register callback with protocol response handler - if not hasattr(self.dispatcher, "protocol_response_handler"): - raise RuntimeError("Protocol response handler not available") - - self.dispatcher.protocol_response_handler.set_response_callback( - contact_hash, waiter.callback - ) - - try: - # Build and send telemetry request - inv = PacketBuilder._compute_inverse_perm_mask( - want_base, want_location, want_environment - ) - - pkt, _ = PacketBuilder.create_protocol_request( - contact=contact, - local_identity=self.identity, - protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, - data=bytes([inv]), - ) - - self.logger.debug( - f"[send_telemetry_request] -> {contact_name} " - f"base={want_base}, location={want_location}, " - f"environment={want_environment}" - ) - - await self.dispatcher.send_packet(pkt, wait_for_ack=False) - - # Wait for response - result = await waiter.wait(timeout) - rtt = get_rtt() - - if result.get("timeout"): - self.logger.warning( - f"Timeout waiting for telemetry response from {contact_name}" - ) - return { - "success": False, - "contact": contact_name, - "requested": { - "base": want_base, - "location": want_location, - "environment": want_environment, - }, - "telemetry_data": None, - "rtt_ms": round(rtt, 2), - "reason": f"Telemetry response timeout after {timeout}s", - } - - self.logger.info( - f"[send_telemetry_request] Response from {contact_name}: '{result['text']}'" - ) - - return { - "success": result.get("success", False), - "contact": contact_name, - "requested": { - "base": want_base, - "location": want_location, - "environment": want_environment, - }, - "telemetry_data": result["parsed"], - "response_text": result["text"], - "rtt_ms": round(rtt, 2), - "reason": ( - "Telemetry response received" - if result.get("success") - else "Telemetry request failed" - ), - } - - finally: - self.dispatcher.protocol_response_handler.clear_response_callback(contact_hash) - - def stop(self): - """Stop the mesh node and clean up associated services. - - Terminates radio communications and shuts down all active handlers. - This method is synchronous and should be called to gracefully - shut down the node. - """ - try: - self.logger.info("Node stopped") - except Exception as e: - self.logger.error(f"Error stopping node: {e}") - - async def send_group_text(self, group_name: str, message: str) -> dict: - """Broadcast a text message to all members of a group. - - Sends a group datagram that will be received by all nodes configured - for the specified group. Group messages are fire-and-forget with no - acknowledgements expected. - - Args: - group_name: Name of the group to broadcast to. - message: Text content to broadcast. - - Returns: - Dictionary with transmission results and signal metrics. - Note: Group messages don't wait for acknowledgements. - - Example: - ```python - result = await node.send_group_text("team_alpha", "Meeting at 15:00") - print(f"Broadcast to {result['group']}: {result['success']}") - ``` - """ - from ..protocol import PacketBuilder - - # Get channels from database (live query) - try: - channels_config = self.channel_db.get_channels() if self.channel_db else [] - except Exception as e: - self.logger.error(f"Failed to get channels from database: {e}") - channels_config = [] - - # Create the group text message packet using PacketBuilder - pkt = PacketBuilder.create_group_datagram( - group_name=group_name, - local_identity=self.identity, - message=message, - sender_name=self.node_name, - channels_config=channels_config, - ) - - # Log packet details (no CRC for group messages - they don't use ACKs) - self.logger.debug(f"[send_group_text] -> {group_name} msg='{message}'") - - # Send packet without waiting for ACK (group messages are unverified) - success = await self.dispatcher.send_packet(pkt, wait_for_ack=False) - - # Extract signal strength information from radio - snr = getattr(pkt, "snr", None) if "pkt" in locals() else None - rssi = None - - # Get current signal strength from radio for outgoing messages - if hasattr(self.radio, "get_last_rssi"): - rssi = self.radio.get_last_rssi() - if hasattr(self.radio, "get_last_snr") and snr is None: - snr = self.radio.get_last_snr() - - # Note: Unlike text messages, we don't publish events here - # Let the app level handle outgoing message events if needed - # This prevents duplicate events when the message is received back - - return { - "success": success, - "snr": snr, - "rssi": rssi, - "group": group_name, - } - - async def send_login(self, repeater_name: str, password: str) -> dict: - """Authenticate with a repeater node. - - Sends login credentials to a repeater and waits for authentication - response. Successful login may grant administrative privileges. - - Args: - repeater_name: Name of the repeater to authenticate with. - password: Authentication password for the repeater. - - Returns: - Dictionary with login results including success status, - admin privileges, and keep-alive intervals. - - Raises: - RuntimeError: If repeater contact not found. - - Example: - ```python - result = await node.send_login("repeater_01", "secret123") - if result["success"] and result["is_admin"]: - print("Admin access granted") - ``` - """ - from ..protocol import PacketBuilder - - contact = self._get_contact_or_raise(repeater_name) - - with self._time_operation() as get_rtt: - contact_pubkey = bytes.fromhex(contact.public_key) - dest_hash = contact_pubkey[0] if len(contact_pubkey) > 0 else 0 - - # Store password in login handlers - self._find_and_call_handler_method("store_login_password", dest_hash, password) - - # Create and send login packet - pkt = PacketBuilder.create_login_packet( - contact=contact, local_identity=self.identity, password=password - ) - - self.logger.debug(f"[send_login] -> {repeater_name}") - - # Set up login response waiting - login_result = {"success": False, "data": {}} - login_event = asyncio.Event() - - def login_response_callback(success: bool, response_data: dict): - login_result["success"] = success - login_result["data"] = response_data - login_event.set() - - # Set callback on login response handlers - self._find_and_call_handler_method("set_login_callback", login_response_callback) - - try: - await self.dispatcher.send_packet(pkt, wait_for_ack=False) - - # Wait for login response - try: - await asyncio.wait_for(login_event.wait(), timeout=10.0) - except asyncio.TimeoutError: - self.logger.warning(f"Login timeout for repeater '{repeater_name}'") - return { - "success": False, - "repeater": repeater_name, - "command": "login", - "rtt_ms": round(get_rtt(), 2), - "reason": "Login response timeout", - } - - rtt = get_rtt() - success = login_result["success"] - response_data = login_result["data"] - - if success: - self.logger.info( - f"Login successful to '{repeater_name}' " - f"(admin: {response_data.get('is_admin', False)})" - ) - reason = ( - f"Login successful - Admin: " - f"{'Yes' if response_data.get('is_admin') else 'No'}" - ) - else: - error_msg = response_data.get("error", "Login failed") - self.logger.warning(f"Login failed to '{repeater_name}': {error_msg}") - reason = f"Login failed: {error_msg}" - - return { - "success": success, - "repeater": repeater_name, - "command": "login", - "rtt_ms": round(rtt, 2), - "is_admin": response_data.get("is_admin", False), - "keep_alive_interval": response_data.get("keep_alive_interval", 0), - "reason": reason, - } - - finally: - # Clear callbacks - self._find_and_call_handler_method("set_login_callback", None) - - async def send_logout(self, repeater_name: str) -> dict: - """Terminate authentication session with a repeater. - - Sends a logout command to end the current session with a repeater. - This should be called when finished with repeater operations. - - Args: - repeater_name: Name of the repeater to logout from. - - Returns: - Dictionary with logout results and performance metrics. - - Raises: - RuntimeError: If repeater contact not found. - - Example: - ```python - result = await node.send_logout("repeater_01") - print(f"Logout {'successful' if result['success'] else 'failed'}") - ``` - """ - from ..protocol import PacketBuilder - - contact = self._get_contact_or_raise(repeater_name) - - with self._time_operation() as get_rtt: - # Create the logout packet using PacketBuilder - pkt, ack_crc = PacketBuilder.create_logout_packet( - contact=contact, local_identity=self.identity - ) - - self.logger.debug(f"[send_logout] -> {repeater_name} with CRC={ack_crc:08X}") - - # Send packet and wait for ACK - success = await self.dispatcher.send_packet( - pkt, wait_for_ack=True, expected_crc=ack_crc - ) - rtt = get_rtt() - - if not success: - self.logger.warning(f"No ACK received for logout CRC {ack_crc:08X}") - - return { - "success": success, - "repeater": repeater_name, - "command": "logout", - "rtt_ms": round(rtt, 2), - "crc": ack_crc, - "reason": "Logout successful" if success else "No ACK received", - } - - async def send_status_request(self, repeater_name: str) -> dict: - """Request status information from a repeater. - - Queries a repeater for its current operational status and configuration. - This is a convenience method that uses the text command interface. - - Args: - repeater_name: Name of the repeater to query. - - Returns: - Dictionary with status information and response metrics. - - Raises: - RuntimeError: If repeater contact not found. - - Example: - ```python - status = await node.send_status_request("repeater_01") - if status["success"]: - print(f"Status: {status['response']}") - ``` - """ - # Use the simple text command approach instead of protocol packets - return await self.send_repeater_command(repeater_name, "status") - - async def send_protocol_request( - self, repeater_name: str, protocol_code: int, data: bytes = b"" - ) -> dict: - """Send a protocol-specific request to a repeater. - - Transmits a custom protocol request with optional data payload - and waits for the repeater's response. - - Args: - repeater_name: Name of the repeater to send request to. - protocol_code: Protocol operation code (0-255). - data: Optional binary data payload for the request. - - Returns: - Dictionary with protocol response, parsed data, and timing metrics. - - Raises: - RuntimeError: If repeater contact or protocol handler not found. - - Example: - ```python - result = await node.send_protocol_request("repeater_01", 0x10, b"config") - if result["success"]: - print(f"Response: {result['response']}") - ``` - """ - from ..protocol import PacketBuilder - - contact = self._get_contact_or_raise(repeater_name) - - with self._time_operation() as get_rtt: - contact_hash = bytes.fromhex(contact.public_key)[0] - - # Set up response waiting - waiter = self._ResponseWaiter() - - if not hasattr(self.dispatcher, "protocol_response_handler"): - raise RuntimeError("Protocol response handler not available") - - self.dispatcher.protocol_response_handler.set_response_callback( - contact_hash, waiter.callback - ) - - try: - pkt, _ = PacketBuilder.create_protocol_request( - contact=contact, - local_identity=self.identity, - protocol_code=protocol_code, - data=data, - ) - - self.logger.debug( - f"[send_protocol_request] -> {repeater_name}: protocol 0x{protocol_code:02X}" - ) - - await self.dispatcher.send_packet(pkt, wait_for_ack=False) - self.logger.debug("[send_protocol_request] Packet sent, waiting for response...") - - result = await waiter.wait(10.0) - rtt = get_rtt() - - if result.get("timeout"): - self.logger.warning( - f"Timeout waiting for protocol response from {repeater_name}" - ) - return { - "success": False, - "repeater": repeater_name, - "command": f"protocol_0x{protocol_code:02X}", - "protocol_code": protocol_code, - "response": None, - "parsed_data": {}, - "rtt_ms": round(rtt, 2), - "ack_received": False, - "reason": f"Protocol 0x{protocol_code:02X} timeout", - } - - self.logger.info( - f"[send_protocol_request] Response from {repeater_name}: '{result['text']}'" - ) - - return { - "success": result["success"], - "repeater": repeater_name, - "command": f"protocol_0x{protocol_code:02X}", - "protocol_code": protocol_code, - "response": result["text"], - "parsed_data": result["parsed"], - "rtt_ms": round(rtt, 2), - "ack_received": False, - "reason": ( - f"Protocol 0x{protocol_code:02X} " - f"{'successful' if result['success'] else 'failed'}" - ), - } - - finally: - self.dispatcher.protocol_response_handler.clear_response_callback(contact_hash) - - async def send_trace_packet( - self, - contact_name: str, - tag: int, - auth_code: int, - flags: int = 0, - path: Optional[list] = None, - timeout: float = 5.0, - ) -> dict: - """Send a diagnostic trace packet for network analysis. - - Transmits a trace packet to analyse routing paths and network - performance. Always expects a response with trace data, signal - metrics, and routing information. - - Args: - contact_name: Name of the target contact for tracing. - tag: Unique identifier for this trace operation. - auth_code: Authentication code for the trace request. - flags: Optional flags to modify trace behaviour. - path: Optional custom routing path for the trace. - timeout: Maximum time to wait for trace response. - - Returns: - Dictionary with trace results, routing data, and signal metrics. - - Raises: - RuntimeError: If contact not found or trace handler unavailable. - - Example: - ```python - trace = await node.send_trace_packet("target_node", 0x12345678, 0xABCD) - if trace["success"]: - print(f"RTT: {trace['rtt_ms']}ms") - ``` - """ - from ..protocol import PacketBuilder - - contact = self._get_contact_or_raise(contact_name) - path = path or [] - - # Get target node ID from contact's public key - target_node_id = bytes.fromhex(contact.public_key)[0] - - # Use provided path or create simple direct path - trace_path = path if path else [target_node_id] - - with self._time_operation() as get_rtt: - # Create trace packet with path included - pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=trace_path) - - self.logger.debug( - f"[send_trace_packet] -> {contact_name} tag=0x{tag:08X} path={trace_path}" - ) - - # Send trace packet and wait for response - try: - handler = self.dispatcher.trace_handler - contact_hash = bytes.fromhex(contact.public_key)[0] - waiter = self._ResponseWaiter() - - if handler: - handler.set_response_callback(contact_hash, waiter.callback) - - try: - await self.dispatcher.send_packet(pkt, wait_for_ack=False) - result = await waiter.wait(timeout) - rtt = get_rtt() - - if result.get("timeout"): - self.logger.warning(f"No trace response from {contact_name}") - return { - "success": False, - "contact": contact_name, - "trace_data": { - "tag": tag, - "auth_code": auth_code, - "flags": flags, - "path": trace_path, - }, - "response": None, - "rtt_ms": round(rtt, 2), - "reason": f"Timeout after {timeout}s", - } - - self.logger.info( - f"Trace response from {contact_name}: '{result.get('text', '')}'" - ) - return { - "success": result.get("success", True), - "contact": contact_name, - "trace_data": { - "tag": tag, - "auth_code": auth_code, - "flags": flags, - "path": trace_path, - }, - "response": result.get("text"), - "parsed_data": result.get("parsed", {}), - "rtt_ms": round(rtt, 2), - "reason": "Response received", - } - - finally: - handler.clear_response_callback(contact_hash) - else: - self.logger.error(f"No trace handler for {contact_name}") - return { - "success": False, - "contact": contact_name, - "reason": "No trace handler available", - } - - except Exception as e: - rtt = get_rtt() - self.logger.error(f"Trace error: {e}") - return { - "success": False, - "contact": contact_name, - "trace_data": { - "tag": tag, - "auth_code": auth_code, - "flags": flags, - "path": trace_path, - }, - "response": None, - "rtt_ms": round(rtt, 2), - "reason": f"Error: {str(e)}", - } - - async def send_repeater_command( - self, repeater_name: str, command: str, parameters: Optional[str] = None - ) -> dict: - """Send a text-based command to a repeater and await response. - - Transmits a command string to a repeater using the text message - protocol and waits for a response. Useful for administrative - operations and status queries. + class _ResponseWaiter: + """Synchronisation helper for async response callbacks. - Args: - repeater_name: Name of the repeater to send command to. - command: Command string to execute on the repeater. - parameters: Optional parameters for the command. - - Returns: - Dictionary with command results, response text, and timing data. - - Raises: - RuntimeError: If repeater contact not found. - - Example: - ```python - result = await node.send_repeater_command("repeater_01", "status") - if result["success"]: - print(f"Response: {result['response']}") - ``` + .. deprecated:: + Use :class:`~pymc_core.companion.models.ResponseWaiter` from the + companion layer instead. """ - from ..protocol import PacketBuilder - - contact = self._get_contact_or_raise(repeater_name) - - with self._time_operation() as get_rtt: - # Build full command string - full_command = command - if parameters: - full_command += f" {parameters}" - - # Set up response capture - response_event = asyncio.Event() - response_data = {"text": None, "success": False} - def response_callback(message_text: str, sender_contact): - response_data["text"] = message_text - response_data["success"] = True - response_event.set() + def __init__(self): + self.event = asyncio.Event() + self.data = {"success": False, "text": None, "parsed": {}} - # Set response callback - self.dispatcher.text_message_handler.set_command_response_callback(response_callback) + def callback(self, success: bool, text: str, parsed_data: Optional[dict] = None): + """Standard callback for response handlers.""" + self.data["success"] = success + self.data["text"] = text + self.data["parsed"] = parsed_data or {} + self.event.set() + async def wait(self, timeout: float = 10.0) -> dict: + """Wait for response with timeout. Returns the response data.""" try: - # Create and send packet - pkt, ack_crc = PacketBuilder.create_text_message( - contact=contact, - local_identity=self.identity, - message=full_command, - attempt=1, - message_type="command", - ) - - # Send packet and get ACK result - ack_success = await self.dispatcher.send_packet( - pkt, wait_for_ack=True, expected_crc=ack_crc - ) - - # Wait for response (regardless of ACK result) - try: - await asyncio.wait_for(response_event.wait(), timeout=15.0) - response_received = True - except asyncio.TimeoutError: - response_received = False - - # Calculate RTT - rtt = get_rtt() - response_text = response_data["text"] - - # Return result based on what we got - if response_received: - return { - "success": True, - "repeater": repeater_name, - "command": command, - "parameters": parameters, - "full_command": full_command, - "response": response_text, - "rtt_ms": round(rtt, 2), - "crc": ack_crc, - "ack_received": ack_success, - "reason": f"Command '{command}' successful with response" - + ("" if ack_success else " (no ACK)"), - } - elif ack_success: - return { - "success": True, - "repeater": repeater_name, - "command": command, - "parameters": parameters, - "full_command": full_command, - "response": "Command sent successfully (no response received)", - "rtt_ms": round(rtt, 2), - "crc": ack_crc, - "ack_received": True, - "reason": f"Command '{command}' sent but no response received", - } - else: - return { - "success": False, - "repeater": repeater_name, - "command": command, - "parameters": parameters, - "full_command": full_command, - "response": None, - "rtt_ms": round(rtt, 2), - "crc": ack_crc, - "ack_received": False, - "reason": f"No ACK or response received for command '{command}'", - } - - except Exception as e: - rtt = get_rtt() - return { - "success": False, - "repeater": repeater_name, - "command": command, - "parameters": parameters, - "full_command": full_command, - "response": None, - "rtt_ms": round(rtt, 2), - "crc": None, - "ack_received": False, - "reason": f"Error sending command: {e}", - } - finally: - # Always clear the callback - self.dispatcher.text_message_handler.set_command_response_callback(None) + await asyncio.wait_for(self.event.wait(), timeout=timeout) + return self.data + except asyncio.TimeoutError: + return {"success": False, "text": None, "parsed": {}, "timeout": True} diff --git a/tests/test_companion_regions.py b/tests/test_companion_regions.py index 8154790..2c935f3 100644 --- a/tests/test_companion_regions.py +++ b/tests/test_companion_regions.py @@ -5,7 +5,6 @@ import pytest from pymc_core.companion import CompanionRadio -from pymc_core.companion.models import Contact from pymc_core.protocol import LocalIdentity, Packet, PacketBuilder from pymc_core.protocol.constants import ( ROUTE_TYPE_DIRECT, @@ -61,12 +60,6 @@ def _make_companion() -> CompanionRadio: return CompanionRadio(radio=radio, identity=identity, node_name="test") -def _make_peer_contact(name: str) -> Contact: - """Return a contact with a valid Ed25519 public key.""" - peer = LocalIdentity() - return Contact(public_key=peer.get_public_key(), name=name) - - # --------------------------------------------------------------------------- # _apply_flood_scope unit tests # --------------------------------------------------------------------------- From 2ecf74cd9a1f5106878ee44d68d94bcccaf7d226 Mon Sep 17 00:00:00 2001 From: agessaman Date: Tue, 17 Feb 2026 19:32:26 -0800 Subject: [PATCH 11/50] feat: add CompanionFrameServer and frame protocol constants Move the full companion frame protocol implementation from pyMC_repeater into pyMC_core as a reusable base class. The base CompanionFrameServer handles TCP framing, command dispatch, push callbacks, and contact/message/channel management. Persistence is handled through overridable hook methods (_persist_companion_message, _sync_next_from_persistence, _save_contacts, _save_channels) so the base class works standalone with in-memory stores while subclasses can add SQLite or other backends. Also adds ~90 frame protocol constants (CMD_*, RESP_CODE_*, PUSH_CODE_*, ERR_CODE_*) and frame delimiters to companion/constants.py, and exports CompanionFrameServer from companion/__init__.py. --- src/pymc_core/companion/__init__.py | 2 + src/pymc_core/companion/constants.py | 137 +++ src/pymc_core/companion/frame_server.py | 1381 +++++++++++++++++++++++ 3 files changed, 1520 insertions(+) create mode 100644 src/pymc_core/companion/frame_server.py diff --git a/src/pymc_core/companion/__init__.py b/src/pymc_core/companion/__init__.py index 582c034..91fbf3d 100644 --- a/src/pymc_core/companion/__init__.py +++ b/src/pymc_core/companion/__init__.py @@ -39,6 +39,7 @@ BinaryReqType, ) from .contact_store import ContactStore +from .frame_server import CompanionFrameServer from .message_queue import MessageQueue from .models import AdvertPath, Channel, Contact, NodePrefs, PacketStats, QueuedMessage, SentResult from .path_cache import PathCache @@ -48,6 +49,7 @@ # Main classes "CompanionRadio", "CompanionBridge", + "CompanionFrameServer", # Stores "ContactStore", "ChannelStore", diff --git a/src/pymc_core/companion/constants.py b/src/pymc_core/companion/constants.py index 38de817..78329cb 100644 --- a/src/pymc_core/companion/constants.py +++ b/src/pymc_core/companion/constants.py @@ -2,6 +2,7 @@ from __future__ import annotations +import base64 from enum import IntEnum # --------------------------------------------------------------------------- @@ -86,3 +87,139 @@ class BinaryReqType(IntEnum): DEFAULT_MAX_CHANNELS = 40 CONTACT_NAME_SIZE = 32 MAX_SIGN_DATA_SIZE = 8192 # 8KB signing buffer (matches firmware) + +# =========================================================================== +# Frame Protocol Constants (MeshCore Companion Radio Protocol) +# =========================================================================== + +# --------------------------------------------------------------------------- +# Commands (app -> radio) +# --------------------------------------------------------------------------- +CMD_APP_START = 1 +CMD_SEND_TXT_MSG = 2 +CMD_SEND_CHANNEL_TXT_MSG = 3 +CMD_GET_CONTACTS = 4 +CMD_GET_DEVICE_TIME = 5 +CMD_SET_DEVICE_TIME = 6 +CMD_SEND_SELF_ADVERT = 7 +CMD_SET_ADVERT_NAME = 8 +CMD_ADD_UPDATE_CONTACT = 9 +CMD_SYNC_NEXT_MESSAGE = 10 +CMD_SET_RADIO_PARAMS = 11 +CMD_SET_RADIO_TX_POWER = 12 +CMD_RESET_PATH = 13 +CMD_SET_ADVERT_LATLON = 14 +CMD_REMOVE_CONTACT = 15 +CMD_SHARE_CONTACT = 16 +CMD_EXPORT_CONTACT = 17 +CMD_IMPORT_CONTACT = 18 +CMD_REBOOT = 19 +CMD_GET_BATT_AND_STORAGE = 20 +CMD_SET_TUNING_PARAMS = 21 +CMD_DEVICE_QUERY = 22 +CMD_EXPORT_PRIVATE_KEY = 23 +CMD_IMPORT_PRIVATE_KEY = 24 +CMD_SEND_RAW_DATA = 25 +CMD_SEND_LOGIN = 26 +CMD_SEND_STATUS_REQ = 27 +CMD_HAS_CONNECTION = 28 +CMD_LOGOUT = 29 +CMD_GET_CONTACT_BY_KEY = 30 +CMD_GET_CHANNEL = 31 +CMD_SET_CHANNEL = 32 +CMD_SIGN_START = 33 +CMD_SIGN_DATA = 34 +CMD_SIGN_FINISH = 35 +CMD_SEND_TRACE_PATH = 36 +CMD_SET_DEVICE_PIN = 37 +CMD_SET_OTHER_PARAMS = 38 +CMD_SEND_TELEMETRY_REQ = 39 +CMD_GET_CUSTOM_VARS = 40 +CMD_SET_CUSTOM_VAR = 41 +CMD_GET_ADVERT_PATH = 42 +CMD_GET_TUNING_PARAMS = 43 +CMD_SEND_BINARY_REQ = 50 +CMD_FACTORY_RESET = 51 +CMD_SEND_PATH_DISCOVERY_REQ = 52 +CMD_SET_FLOOD_SCOPE = 54 +CMD_SEND_CONTROL_DATA = 55 +CMD_GET_STATS = 56 +CMD_SEND_ANON_REQ = 57 +CMD_SET_AUTOADD_CONFIG = 58 +CMD_GET_AUTOADD_CONFIG = 59 + +# --------------------------------------------------------------------------- +# Response codes (radio -> app) +# --------------------------------------------------------------------------- +RESP_CODE_OK = 0 +RESP_CODE_ERR = 1 +RESP_CODE_CONTACTS_START = 2 +RESP_CODE_CONTACT = 3 +RESP_CODE_END_OF_CONTACTS = 4 +RESP_CODE_SELF_INFO = 5 +RESP_CODE_SENT = 6 +RESP_CODE_CONTACT_MSG_RECV = 7 +RESP_CODE_CHANNEL_MSG_RECV = 8 +RESP_CODE_CURR_TIME = 9 +RESP_CODE_NO_MORE_MESSAGES = 10 +RESP_CODE_EXPORT_CONTACT = 11 +RESP_CODE_BATT_AND_STORAGE = 12 +RESP_CODE_DEVICE_INFO = 13 +RESP_CODE_PRIVATE_KEY = 14 +RESP_CODE_DISABLED = 15 +RESP_CODE_CONTACT_MSG_RECV_V3 = 16 +RESP_CODE_CHANNEL_MSG_RECV_V3 = 17 +RESP_CODE_CHANNEL_INFO = 18 +RESP_CODE_SIGN_START = 19 +RESP_CODE_SIGNATURE = 20 +RESP_CODE_CUSTOM_VARS = 21 +RESP_CODE_ADVERT_PATH = 22 +RESP_CODE_TUNING_PARAMS = 23 +RESP_CODE_STATS = 24 +RESP_CODE_AUTOADD_CONFIG = 25 + +# --------------------------------------------------------------------------- +# Push codes (radio -> app, unsolicited) +# --------------------------------------------------------------------------- +PUSH_CODE_ADVERT = 0x80 +PUSH_CODE_PATH_UPDATED = 0x81 +PUSH_CODE_SEND_CONFIRMED = 0x82 +PUSH_CODE_MSG_WAITING = 0x83 +PUSH_CODE_RAW_DATA = 0x84 +PUSH_CODE_LOGIN_SUCCESS = 0x85 +PUSH_CODE_LOGIN_FAIL = 0x86 +PUSH_CODE_STATUS_RESPONSE = 0x87 +PUSH_CODE_LOG_RX_DATA = 0x88 +PUSH_CODE_TRACE_DATA = 0x89 +PUSH_CODE_NEW_ADVERT = 0x8A +PUSH_CODE_TELEMETRY_RESPONSE = 0x8B +PUSH_CODE_BINARY_RESPONSE = 0x8C +PUSH_CODE_PATH_DISCOVERY_RESPONSE = 0x8D +PUSH_CODE_CONTROL_DATA = 0x8E +PUSH_CODE_CONTACT_DELETED = 0x8F +PUSH_CODE_CONTACTS_FULL = 0x90 + +# --------------------------------------------------------------------------- +# Error codes +# --------------------------------------------------------------------------- +ERR_CODE_UNSUPPORTED_CMD = 1 +ERR_CODE_NOT_FOUND = 2 +ERR_CODE_TABLE_FULL = 3 +ERR_CODE_BAD_STATE = 4 +ERR_CODE_FILE_IO_ERROR = 5 +ERR_CODE_ILLEGAL_ARG = 6 + +# --------------------------------------------------------------------------- +# Frame delimiters (USB/TCP: > = outbound, < = inbound) +# --------------------------------------------------------------------------- +FRAME_OUTBOUND_PREFIX = 0x3E # '>' +FRAME_INBOUND_PREFIX = 0x3C # '<' +MAX_FRAME_SIZE = 512 +PUB_KEY_SIZE = 32 +MAX_PATH_SIZE = 64 + +# --------------------------------------------------------------------------- +# Default public channel PSK (from firmware MeshCore companion_radio example) +# --------------------------------------------------------------------------- +PUBLIC_GROUP_PSK = b"izOH6cXN6mrJ5e26oRXNcg==" +DEFAULT_PUBLIC_CHANNEL_SECRET = base64.b64decode(PUBLIC_GROUP_PSK) diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py new file mode 100644 index 0000000..6ed83ff --- /dev/null +++ b/src/pymc_core/companion/frame_server.py @@ -0,0 +1,1381 @@ +""" +CompanionFrameServer - Standard MeshCore Companion Radio Protocol over TCP. + +Implements the full companion frame protocol: command dispatch, push callbacks, +and contact/message/channel management. Persistence is handled through +overridable hook methods so the base class works standalone (in-memory only) +while subclasses can add SQLite or other storage backends. + +Frame format: + Outbound (radio → app): ``>`` (0x3E) + 2-byte LE length + data + Inbound (app → radio): ``<`` (0x3C) + 2-byte LE length + data +""" + +import asyncio +import logging +import struct +import time +from typing import Any, Callable, Optional + +from .constants import ( + ADV_TYPE_CHAT, + CMD_ADD_UPDATE_CONTACT, + CMD_APP_START, + CMD_DEVICE_QUERY, + CMD_GET_ADVERT_PATH, + CMD_GET_BATT_AND_STORAGE, + CMD_GET_CHANNEL, + CMD_GET_CONTACT_BY_KEY, + CMD_GET_CONTACTS, + CMD_GET_STATS, + CMD_IMPORT_CONTACT, + CMD_REMOVE_CONTACT, + CMD_RESET_PATH, + CMD_SEND_BINARY_REQ, + CMD_SEND_CHANNEL_TXT_MSG, + CMD_SEND_CONTROL_DATA, + CMD_SEND_LOGIN, + CMD_SEND_PATH_DISCOVERY_REQ, + CMD_SEND_SELF_ADVERT, + CMD_SEND_STATUS_REQ, + CMD_SEND_TELEMETRY_REQ, + CMD_SEND_TRACE_PATH, + CMD_SEND_TXT_MSG, + CMD_SET_ADVERT_LATLON, + CMD_SET_ADVERT_NAME, + CMD_SET_CHANNEL, + CMD_SET_FLOOD_SCOPE, + CMD_SYNC_NEXT_MESSAGE, + ERR_CODE_BAD_STATE, + ERR_CODE_ILLEGAL_ARG, + ERR_CODE_NOT_FOUND, + ERR_CODE_TABLE_FULL, + ERR_CODE_UNSUPPORTED_CMD, + FRAME_INBOUND_PREFIX, + FRAME_OUTBOUND_PREFIX, + MAX_FRAME_SIZE, + MAX_PATH_SIZE, + PUB_KEY_SIZE, + PUSH_CODE_ADVERT, + PUSH_CODE_BINARY_RESPONSE, + PUSH_CODE_CONTROL_DATA, + PUSH_CODE_LOG_RX_DATA, + PUSH_CODE_LOGIN_FAIL, + PUSH_CODE_LOGIN_SUCCESS, + PUSH_CODE_MSG_WAITING, + PUSH_CODE_NEW_ADVERT, + PUSH_CODE_PATH_DISCOVERY_RESPONSE, + PUSH_CODE_PATH_UPDATED, + PUSH_CODE_SEND_CONFIRMED, + PUSH_CODE_STATUS_RESPONSE, + PUSH_CODE_TELEMETRY_RESPONSE, + PUSH_CODE_TRACE_DATA, + RESP_CODE_ADVERT_PATH, + RESP_CODE_BATT_AND_STORAGE, + RESP_CODE_CHANNEL_INFO, + RESP_CODE_CHANNEL_MSG_RECV, + RESP_CODE_CHANNEL_MSG_RECV_V3, + RESP_CODE_CONTACT, + RESP_CODE_CONTACT_MSG_RECV, + RESP_CODE_CONTACT_MSG_RECV_V3, + RESP_CODE_CONTACTS_START, + RESP_CODE_DEVICE_INFO, + RESP_CODE_END_OF_CONTACTS, + RESP_CODE_ERR, + RESP_CODE_NO_MORE_MESSAGES, + RESP_CODE_OK, + RESP_CODE_SELF_INFO, + RESP_CODE_SENT, + RESP_CODE_STATS, + STATS_TYPE_CORE, + STATS_TYPE_PACKETS, + STATS_TYPE_RADIO, +) +from .models import Contact, QueuedMessage + +logger = logging.getLogger("CompanionFrameServer") + + +def _build_advert_push_frames(data: dict) -> tuple[bytes, Optional[bytes]]: + """Build PUSH_CODE_ADVERT short frame and optional PUSH_CODE_NEW_ADVERT + full frame from extracted data. Thread-safe for ``asyncio.to_thread``.""" + pubkey_b = data.get("pubkey_b", b"") + if isinstance(pubkey_b, bytes): + pubkey_b = pubkey_b[:32].ljust(32, b"\x00") + else: + pubkey_b = b"\x00" * 32 + short = bytes([PUSH_CODE_ADVERT]) + pubkey_b + if not data.get("include_full"): + return (short, None) + op = data.get("out_path", b"") + op = (op if isinstance(op, bytes) else bytes(op or []))[:MAX_PATH_SIZE].ljust( + MAX_PATH_SIZE, b"\x00" + ) + nb = data.get("name_b", b"") + nb = ( + nb + if isinstance(nb, bytes) + else (nb.encode("utf-8", errors="replace") if isinstance(nb, str) else b"") + )[:32].ljust(32, b"\x00") + full = ( + bytes([PUSH_CODE_NEW_ADVERT]) + + pubkey_b + + bytes( + [ + data.get("adv_type", 0), + data.get("flags", 0), + data.get("opl_byte", 0xFF), + ] + ) + + op + + nb + + struct.pack(" None: + """Start the TCP server.""" + self._server = await asyncio.start_server( + self._handle_client, + self.bind_address, + self.port, + ) + addr = ( + self._server.sockets[0].getsockname() + if self._server.sockets + else (self.bind_address, self.port) + ) + logger.info( + "Companion frame server listening on %s:%s (hash=0x%02x)", + addr[0], + addr[1], + int(self.companion_hash), + ) + + async def stop(self) -> None: + """Stop the TCP server and disconnect any client.""" + if self._client_writer: + try: + self._client_writer.close() + await self._client_writer.wait_closed() + except Exception: + pass + self._client_writer = None + self._client_reader = None + if self._server: + self._server.close() + await self._server.wait_closed() + self._server = None + logger.info("Companion frame server stopped (port=%s)", self.port) + + # ------------------------------------------------------------------------- + # Persistence hooks (override in subclasses for SQLite, etc.) + # ------------------------------------------------------------------------- + + async def _persist_companion_message(self, msg_dict: dict) -> None: + """Hook: persist a received message. Default is a no-op — the message + stays in the bridge's in-memory queue for ``sync_next_message``.""" + + def _sync_next_from_persistence(self) -> Optional[QueuedMessage]: + """Hook: pop a persisted message when the bridge queue is empty. + Default returns ``None``.""" + return None + + def _save_contacts(self) -> None: + """Hook: persist the current contact list. Default is a no-op.""" + + def _save_channels(self) -> None: + """Hook: persist the current channel list. Default is a no-op.""" + + def _get_batt_and_storage(self) -> tuple[int, int, int]: + """Hook: return (millivolts, used_kb, total_kb). Default: all zeros.""" + return (0, 0, 0) + + # ------------------------------------------------------------------------- + # Push callbacks + # ------------------------------------------------------------------------- + + def _setup_push_callbacks(self) -> None: + """Subscribe to bridge events and send PUSH frames to connected client.""" + + def _write_push(data: bytes) -> None: + if self._client_writer and not self._client_writer.is_closing(): + try: + frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack("= 32: + _write_push(bytes([PUSH_CODE_PATH_UPDATED]) + pub_key[:32]) + + async def on_channel_message_received( + channel_name, + sender_name, + message_text, + timestamp, + path_len=0, + channel_idx=0, + packet_hash=None, + ): + msg_dict = { + "sender_key": b"", + "text": message_text, + "timestamp": timestamp, + "txt_type": 0, + "is_channel": True, + "channel_idx": channel_idx, + "path_len": path_len, + "packet_hash": packet_hash, + } + await self._persist_companion_message(msg_dict) + _write_push(bytes([PUSH_CODE_MSG_WAITING])) + + async def on_binary_response(tag_bytes, response_data, parsed=None, request_type=None): + frame = ( + bytes([PUSH_CODE_BINARY_RESPONSE, 0]) + + (tag_bytes if isinstance(tag_bytes, bytes) else struct.pack(" None: + """Push PUSH_CODE_TRACE_DATA (0x89) to client. Matches firmware + ``onTraceRecv()`` frame format.""" + if not self._client_writer or self._client_writer.is_closing(): + return + path_sz = flags & 0x03 + expected_snr_len = path_len >> path_sz + if len(path_snrs) != expected_snr_len: + logger.debug( + "push_trace_data: path_snrs len %s != expected %s", + len(path_snrs), + expected_snr_len, + ) + return + data = ( + bytes([PUSH_CODE_TRACE_DATA, 0, path_len, flags]) + + struct.pack(" None: + """Push raw RX packet to client (PUSH_CODE_LOG_RX_DATA 0x88).""" + if not self._client_writer or self._client_writer.is_closing(): + return + snr_byte = max(-128, min(127, int(round(snr * 4)))) + rssi_byte = max(-128, min(127, int(rssi))) + if snr_byte < 0: + snr_byte += 256 + if rssi_byte < 0: + rssi_byte += 256 + payload_len = min(len(raw), MAX_FRAME_SIZE - 3) + data = bytes([PUSH_CODE_LOG_RX_DATA, snr_byte & 0xFF, rssi_byte & 0xFF]) + raw[:payload_len] + try: + frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: + """Push CONTROL packet to client (PUSH_CODE_CONTROL_DATA 0x8E).""" + if not self._client_writer or self._client_writer.is_closing(): + logger.warning("Push control data skipped: no client connection") + return + # Discovery response (0x90): clear the no-op callback + if self._control_handler and len(payload) >= 6 and (payload[0] & 0xF0) == 0x90: + tag = struct.unpack(" None: + if self._client_writer: + try: + await self._client_writer.drain() + except Exception: + pass + + def _write_frame(self, data: bytes) -> None: + """Send a frame to the connected client (outbound format).""" + if self._client_writer and not self._client_writer.is_closing(): + frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: + self._write_frame(bytes([RESP_CODE_OK])) + + def _write_err(self, err_code: int) -> None: + self._write_frame(bytes([RESP_CODE_ERR, err_code])) + + # ------------------------------------------------------------------------- + # Client handling + # ------------------------------------------------------------------------- + + async def _handle_client( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + """Handle a new client connection. One client at a time.""" + if self._client_writer: + logger.warning("Companion already has a client; rejecting new connection") + writer.close() + await writer.wait_closed() + return + + self._client_reader = reader + self._client_writer = writer + self._setup_push_callbacks() + logger.info("Companion client connected (port=%s)", self.port) + + try: + while True: + prefix = await reader.read(1) + if not prefix: + break + if prefix[0] != FRAME_INBOUND_PREFIX: + logger.warning("Invalid frame prefix: 0x%02x", prefix[0]) + continue + len_bytes = await reader.readexactly(2) + frame_len = struct.unpack(" MAX_FRAME_SIZE: + logger.warning("Frame too long: %s", frame_len) + break + payload = await reader.readexactly(frame_len) + await self._handle_cmd(payload) + except asyncio.IncompleteReadError: + pass + except (ConnectionResetError, BrokenPipeError): + pass + except Exception as e: + logger.error("Client handler error: %s", e, exc_info=True) + finally: + self._client_writer = None + self._client_reader = None + logger.info("Companion client disconnected (port=%s)", self.port) + + # ------------------------------------------------------------------------- + # Command dispatch + # ------------------------------------------------------------------------- + + async def _handle_cmd(self, payload: bytes) -> None: + """Dispatch command to handler.""" + if not payload: + return + cmd = payload[0] + data = payload[1:] + logger.info("Companion cmd 0x%02x (%s) len=%s", cmd, cmd, len(payload)) + if cmd in (CMD_GET_CHANNEL, CMD_SET_CHANNEL): + logger.debug( + "Companion cmd 0x%02x (%s), payload_len=%s", + cmd, + "GET_CHANNEL" if cmd == CMD_GET_CHANNEL else "SET_CHANNEL", + len(payload), + ) + + try: + if cmd == CMD_APP_START: + await self._cmd_app_start(data) + elif cmd == CMD_DEVICE_QUERY: + await self._cmd_device_query(data) + elif cmd == CMD_GET_CONTACTS: + await self._cmd_get_contacts(data) + elif cmd == CMD_GET_CONTACT_BY_KEY: + await self._cmd_get_contact_by_key(data) + elif cmd == CMD_SEND_TXT_MSG: + await self._cmd_send_txt_msg(data) + elif cmd == CMD_SEND_CHANNEL_TXT_MSG: + await self._cmd_send_channel_txt_msg(data) + elif cmd == CMD_SYNC_NEXT_MESSAGE: + await self._cmd_sync_next_message(data) + elif cmd == CMD_SEND_LOGIN: + await self._cmd_send_login(data) + elif cmd == CMD_SEND_STATUS_REQ: + await self._cmd_send_status_req(data) + elif cmd == CMD_SEND_TELEMETRY_REQ: + await self._cmd_send_telemetry_req(data) + elif cmd == CMD_SEND_SELF_ADVERT: + await self._cmd_send_self_advert(data) + elif cmd == CMD_SET_ADVERT_NAME: + await self._cmd_set_advert_name(data) + elif cmd == CMD_SET_ADVERT_LATLON: + await self._cmd_set_advert_latlon(data) + elif cmd == CMD_ADD_UPDATE_CONTACT: + await self._cmd_add_update_contact(data) + elif cmd == CMD_REMOVE_CONTACT: + await self._cmd_remove_contact(data) + elif cmd == CMD_RESET_PATH: + await self._cmd_reset_path(data) + elif cmd == CMD_GET_BATT_AND_STORAGE: + await self._cmd_get_batt_and_storage(data) + elif cmd == CMD_GET_STATS: + await self._cmd_get_stats(data) + elif cmd == CMD_GET_ADVERT_PATH: + await self._cmd_get_advert_path(data) + elif cmd == CMD_IMPORT_CONTACT: + await self._cmd_import_contact(data) + elif cmd == CMD_GET_CHANNEL: + await self._cmd_get_channel(data) + elif cmd == CMD_SET_CHANNEL: + await self._cmd_set_channel(data) + elif cmd == CMD_SEND_BINARY_REQ: + await self._cmd_send_binary_req(data) + elif cmd == CMD_SEND_PATH_DISCOVERY_REQ: + await self._cmd_send_path_discovery_req(data) + elif cmd == CMD_SEND_CONTROL_DATA: + await self._cmd_send_control_data(data) + elif cmd == CMD_SEND_TRACE_PATH: + await self._cmd_send_trace_path(data) + elif cmd == CMD_SET_FLOOD_SCOPE: + await self._cmd_set_flood_scope(data) + else: + logger.warning( + "Companion unsupported cmd 0x%02x (%s) len=%s", + cmd, + cmd, + len(payload), + ) + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + except Exception as e: + logger.error("Cmd 0x%02x error: %s", cmd, e, exc_info=True) + self._write_err(ERR_CODE_ILLEGAL_ARG) + + # ------------------------------------------------------------------------- + # Command handlers + # ------------------------------------------------------------------------- + + async def _cmd_app_start(self, data: bytes) -> None: + if len(data) >= 1: + self._app_target_ver = data[0] + prefs = self.bridge.get_self_info() + pubkey = self.bridge.get_public_key() + name = prefs.node_name.encode("utf-8", errors="replace") + lat = int(getattr(prefs, "latitude", 0) * 1e6) + lon = int(getattr(prefs, "longitude", 0) * 1e6) + frame = ( + bytes([RESP_CODE_SELF_INFO, ADV_TYPE_CHAT, prefs.tx_power_dbm, 22]) + + pubkey + + struct.pack(" None: + if len(data) >= 1: + self._app_target_ver = data[0] + firmware_ver = 8 + max_contacts = getattr(getattr(self.bridge, "contacts", None), "max_contacts", 1000) + max_channels_val = getattr(getattr(self.bridge, "channels", None), "max_channels", 40) + max_contacts_div_2 = min(max_contacts // 2, 255) + max_channels = min(max_channels_val, 255) + ble_pin = 0 + frame = ( + bytes( + [ + RESP_CODE_DEVICE_INFO, + firmware_ver, + max_contacts_div_2, + max_channels, + ] + ) + + struct.pack(" None: + since = struct.unpack("= 4 else 0 + contacts = self.bridge.get_contacts(since=since) + self._write_frame(bytes([RESP_CODE_CONTACTS_START]) + struct.pack(" None: + """Encode and write a single RESP_CODE_CONTACT frame.""" + pubkey = c.public_key if isinstance(c.public_key, bytes) else bytes.fromhex(c.public_key) + name = (c.name.encode("utf-8")[:32] if isinstance(c.name, str) else c.name[:32]).ljust( + 32, b"\x00" + ) + opl = c.out_path_len if hasattr(c, "out_path_len") else -1 + opl_byte = 0xFF if opl < 0 else min(opl, 255) + frame = ( + bytes([RESP_CODE_CONTACT]) + + pubkey + + bytes( + [ + c.adv_type if hasattr(c, "adv_type") else 0, + c.flags if hasattr(c, "flags") else 0, + ] + ) + + bytes([opl_byte]) + + (c.out_path[:MAX_PATH_SIZE] if hasattr(c, "out_path") and c.out_path else b"").ljust( + MAX_PATH_SIZE, b"\x00" + ) + + name + + struct.pack( + " None: + if len(data) < PUB_KEY_SIZE: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:PUB_KEY_SIZE] + contact = ( + self.bridge.contacts.get_by_key(pubkey) + if hasattr(self.bridge.contacts, "get_by_key") + else None + ) + if not contact: + self._write_err(ERR_CODE_NOT_FOUND) + return + self._write_contact_frame(contact) + + async def _cmd_send_txt_msg(self, data: bytes) -> None: + if len(data) < 12: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + txt_type = data[0] + attempt = data[1] + pubkey_prefix = data[6:12] + text = data[12:].decode("utf-8", errors="replace").rstrip("\x00") + contact = ( + self.bridge.contacts.get_by_key_prefix(pubkey_prefix) + if hasattr(self.bridge.contacts, "get_by_key_prefix") + else None + ) + if not contact: + for c in self.bridge.get_contacts(): + pk = ( + c.public_key if isinstance(c.public_key, bytes) else bytes.fromhex(c.public_key) + ) + if pk[:6] == pubkey_prefix: + contact = c + break + if not contact: + self._write_err(ERR_CODE_NOT_FOUND) + return + pubkey = ( + contact.public_key + if isinstance(contact.public_key, bytes) + else bytes.fromhex(contact.public_key) + ) + result = await self.bridge.send_text_message( + pubkey, text, txt_type=txt_type, attempt=attempt + 1 + ) + if result.success: + ack = result.expected_ack or 0 + timeout = result.timeout_ms or 5000 + frame = bytes([RESP_CODE_SENT, 1 if result.is_flood else 0]) + struct.pack( + " None: + if len(data) < 6: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + txt_type = data[0] + channel_idx = data[1] + text = data[6:].decode("utf-8", errors="replace").rstrip("\x00") + if txt_type != 0: + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + if self.bridge.get_channel(channel_idx) is None: + self._write_err(ERR_CODE_NOT_FOUND) + return + ok = await self.bridge.send_channel_message(channel_idx, text) + if ok: + self._write_ok() + else: + self._write_err(ERR_CODE_BAD_STATE) + + async def _cmd_send_binary_req(self, data: bytes) -> None: + if len(data) < 33: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:32] + req_data = data[32:] + send_binary_req = getattr(self.bridge, "send_binary_req", None) + if not send_binary_req: + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + try: + result = await send_binary_req(pubkey, req_data) + except Exception as e: + logger.error("send_binary_req error: %s", e, exc_info=True) + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + if not result.success: + self._write_err(ERR_CODE_NOT_FOUND) + return + tag = result.expected_ack if result.expected_ack is not None else 0 + timeout_ms = result.timeout_ms if result.timeout_ms is not None else 10000 + frame = bytes([RESP_CODE_SENT, 1 if result.is_flood else 0]) + struct.pack( + " None: + if len(data) < 2: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + if (data[0] & 0x80) == 0: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + # Discovery request: register a no-op response callback + if self._control_handler and len(data) >= 6 and (data[0] & 0xF0) == 0x80: + tag = struct.unpack(" None: + logger.info( + "Path discovery request received (cmd 52), data_len=%s", + len(data), + ) + if len(data) < 33: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pub_key = data[1:33] + send_req = getattr(self.bridge, "send_path_discovery_req", None) + if not send_req: + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + try: + result = await send_req(pub_key) + except Exception as e: + logger.error("send_path_discovery_req error: %s", e, exc_info=True) + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + if not result.success: + self._write_err(ERR_CODE_NOT_FOUND) + return + tag = result.expected_ack if result.expected_ack is not None else 0 + timeout_ms = result.timeout_ms if result.timeout_ms is not None else 10000 + frame = bytes([RESP_CODE_SENT, 1 if result.is_flood else 0]) + struct.pack( + " None: + if len(data) < 10: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + tag = struct.unpack_from("> path_sz) > MAX_PATH_SIZE or (path_len % (1 << path_sz)) != 0: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + send_raw = getattr(self.bridge, "send_trace_path_raw", None) + if not send_raw: + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + try: + ok = await send_raw(tag, auth_code, flags, path_bytes) + except Exception as e: + logger.error("send_trace_path error: %s", e, exc_info=True) + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + if not ok: + self._write_err(ERR_CODE_TABLE_FULL) + return + est_timeout_ms = 5000 + (path_len * 200) + frame = bytes([RESP_CODE_SENT, 0]) + struct.pack("> path_sz + path_snrs = bytes(snr_len) + final_snr_byte = 0 + self.push_trace_data( + path_len, + flags, + tag, + auth_code, + path_bytes, + path_snrs, + final_snr_byte, + ) + + async def _cmd_sync_next_message(self, data: bytes) -> None: + msg = self.bridge.sync_next_message() + if msg is None: + msg = self._sync_next_from_persistence() + if msg is None: + self._write_frame(bytes([RESP_CODE_NO_MORE_MESSAGES])) + return + if msg.is_channel: + path_len_byte = msg.path_len if msg.path_len < 256 else 0xFF + txt_type = 0 + text_bytes = (msg.text or "").rstrip("\x00").encode("utf-8", errors="replace") + if self._app_target_ver >= 3: + frame = ( + bytes( + [ + RESP_CODE_CHANNEL_MSG_RECV_V3, + 0, + 0, + 0, + msg.channel_idx, + path_len_byte, + txt_type, + ] + ) + + struct.pack("= 6 else msg.sender_key.ljust(6, b"\x00") + ) + path_len_byte = msg.path_len if msg.path_len < 256 else 0xFF + text_bytes = msg.text.encode("utf-8", errors="replace") + if self._app_target_ver >= 3: + frame = ( + bytes([RESP_CODE_CONTACT_MSG_RECV_V3, 0, 0, 0]) + + prefix + + bytes([path_len_byte, msg.txt_type]) + + struct.pack(" None: + if len(data) < 32: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:32] + password = ( + data[32:].decode("utf-8", errors="replace").rstrip("\x00") if len(data) > 32 else "" + ) + self._write_frame(bytes([RESP_CODE_SENT, 1]) + struct.pack(" None: + if len(data) < 32: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[0:32] + self._write_frame(bytes([RESP_CODE_SENT, 0]) + struct.pack(" None: + if len(data) < 35: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[3:35] + flags = 0x07 # request all: base + location + environment + want_base = bool(flags & 0x01) + want_location = bool(flags & 0x02) + want_environment = bool(flags & 0x04) + self._write_frame(bytes([RESP_CODE_SENT, 0]) + struct.pack(" None: + flood = len(data) >= 1 and data[0] == 1 + ok = await self.bridge.advertise(flood=flood) + self._write_ok() if ok else self._write_err(ERR_CODE_BAD_STATE) + + async def _cmd_set_advert_name(self, data: bytes) -> None: + name = data.decode("utf-8", errors="replace").rstrip("\x00") + self.bridge.set_advert_name(name) + self._write_ok() + + async def _cmd_set_advert_latlon(self, data: bytes) -> None: + if len(data) < 8: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + lat, lon = struct.unpack_from(" None: + if len(data) < 36: + self._write_err(ERR_CODE_ILLEGAL_ARG) + await self._drain_writer() + return + pubkey = data[0:32] + adv_type = data[32] + flags = data[33] + out_path_len = struct.unpack_from("= out_path_end: + out_path = data[35:out_path_end].rstrip(b"\x00") + else: + out_path = data[35 : len(data)].rstrip(b"\x00") if len(data) > 35 else b"" + name_start = 35 + MAX_PATH_SIZE + name_end = name_start + 32 + if len(data) >= name_end: + name_raw = data[name_start:name_end] + elif len(data) > name_start: + name_raw = data[name_start : len(data)].ljust(32, b"\x00") + else: + name_raw = b"\x00" * 32 + name = name_raw.split(b"\x00")[0].decode("utf-8", errors="replace") + last_advert = 0 + if len(data) >= name_end + 4: + last_advert = struct.unpack_from("= name_end + 4 + 8: + gps_lat = struct.unpack_from("= name_end + 4 + 12: + lastmod = struct.unpack_from(" 255 else out_path_len + out_path_padded = (out_path[:MAX_PATH_SIZE] if out_path else b"").ljust( + MAX_PATH_SIZE, b"\x00" + ) + name_padded = (name.encode("utf-8")[:32] if isinstance(name, str) else name[:32]).ljust( + 32, b"\x00" + ) + contact_frame = ( + bytes([RESP_CODE_CONTACT]) + + pubkey + + bytes([adv_type, flags, opl_byte]) + + out_path_padded + + name_padded + + struct.pack(" None: + if len(data) < 32: + self._write_err(ERR_CODE_ILLEGAL_ARG) + await self._drain_writer() + return + pubkey = data[:32] + ok = self.bridge.remove_contact(pubkey) + if ok: + self._save_contacts() + self._write_ok() if ok else self._write_err(ERR_CODE_NOT_FOUND) + await self._drain_writer() + + async def _cmd_reset_path(self, data: bytes) -> None: + if len(data) < 32: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:32] + ok = self.bridge.reset_path(pubkey) + self._write_ok() if ok else self._write_err(ERR_CODE_NOT_FOUND) + + async def _cmd_get_batt_and_storage(self, data: bytes) -> None: + millivolts, used_kb, total_kb = self._get_batt_and_storage() + frame = ( + bytes([RESP_CODE_BATT_AND_STORAGE]) + + struct.pack(" None: + stats_type = data[0] if len(data) >= 1 else STATS_TYPE_PACKETS + if stats_type not in ( + STATS_TYPE_CORE, + STATS_TYPE_RADIO, + STATS_TYPE_PACKETS, + ): + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + stats = ( + self.stats_getter(stats_type) if self.stats_getter else None + ) or self.bridge.get_stats(stats_type) + frame = bytes([RESP_CODE_STATS, stats_type]) + if stats_type == STATS_TYPE_CORE: + battery_mv = int(stats.get("battery_mv", 0)) + uptime_secs = int(stats.get("uptime_secs", 0)) + errors = int(stats.get("errors", 0)) + queue_len = min(255, max(0, int(stats.get("queue_len", 0)))) + frame += struct.pack(" None: + if len(data) < 1 + PUB_KEY_SIZE: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pub_key = data[1 : 1 + PUB_KEY_SIZE] + prefix = pub_key[:7] + found = ( + self.bridge.get_advert_path(prefix) + if getattr(self.bridge, "get_advert_path", None) + else None + ) + if not found: + self._write_err(ERR_CODE_NOT_FOUND) + return + path_bytes = getattr(found, "path", None) or b"" + if not isinstance(path_bytes, bytes): + path_bytes = bytes(path_bytes) + path_len = min(len(path_bytes), MAX_PATH_SIZE) + recv_ts = getattr(found, "recv_timestamp", 0) + frame = ( + bytes([RESP_CODE_ADVERT_PATH]) + + struct.pack(" None: + ok = self.bridge.import_contact(data) + self._write_ok() if ok else self._write_err(ERR_CODE_ILLEGAL_ARG) + + async def _cmd_get_channel(self, data: bytes) -> None: + channel_idx = data[0] if len(data) >= 1 else 0 + get_full_list = len(data) == 0 + max_channels_val = getattr(getattr(self.bridge, "channels", None), "max_channels", 40) + + def _channel_info_frame(idx: int, ch) -> bytes: + if ch is None: + name = b"\x00" * 32 + secret = b"\x00" * 16 + else: + name = ch.name.encode("utf-8", errors="replace")[:32].ljust(32, b"\x00") + secret = (ch.secret[:16] if ch.secret else b"\x00" * 16).ljust(16, b"\x00") + return bytes([RESP_CODE_CHANNEL_INFO, idx]) + name + secret + + if get_full_list: + for idx in range(max_channels_val): + ch = self.bridge.get_channel(idx) + frame = _channel_info_frame(idx, ch) + self._write_frame(frame) + return + + if channel_idx < 0 or channel_idx >= max_channels_val: + self._write_err(ERR_CODE_NOT_FOUND) + return + ch = self.bridge.get_channel(channel_idx) + frame = _channel_info_frame(channel_idx, ch) + self._write_frame(frame) + + async def _cmd_set_channel(self, data: bytes) -> None: + if len(data) < 34: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + channel_idx = data[0] + name_raw = data[1:33] + name = name_raw.split(b"\x00")[0].decode("utf-8", errors="replace").strip() + if len(data) >= 97: + try: + secret = bytes.fromhex(data[33:97].decode("ascii")) + except (ValueError, UnicodeDecodeError): + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + elif len(data) >= 65: + secret = data[33:65] + elif len(data) >= 49: + secret = data[33:49] + else: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + ok = self.bridge.set_channel(channel_idx, name, secret) + if ok: + self._save_channels() + self._write_ok() if ok else self._write_err(ERR_CODE_NOT_FOUND) + + async def _cmd_set_flood_scope(self, data: bytes) -> None: + """Delegate flood scope to the bridge.""" + if len(data) >= 16: + self.bridge.set_flood_scope(data[:16]) + else: + self.bridge.set_flood_scope(None) + self._write_ok() From 77866eed0a8efd09b62fc947bbb59936518e01e4 Mon Sep 17 00:00:00 2001 From: agessaman Date: Wed, 18 Feb 2026 08:41:17 -0800 Subject: [PATCH 12/50] feat(companion): add missing command handlers and prefs persistence hooks Add 11 new command handler methods and 13 dispatch branches to CompanionFrameServer for binary frame types needed by CompanionBridge and CompanionRadio: CMD_GET_DEVICE_TIME, CMD_SET_DEVICE_TIME, CMD_SET_RADIO_PARAMS, CMD_SET_RADIO_TX_POWER, CMD_SHARE_CONTACT, CMD_EXPORT_CONTACT, CMD_SET_TUNING_PARAMS, CMD_LOGOUT, CMD_GET_CUSTOM_VARS, CMD_SET_CUSTOM_VAR, CMD_SET_AUTOADD_CONFIG, CMD_GET_AUTOADD_CONFIG. Add get_time(), set_time(), and send_logout() to CompanionBase. Add no-op _save_prefs()/_load_prefs() hooks to CompanionBase following the existing persistence hook pattern (_save_contacts, _save_channels, etc.). _save_prefs() is called automatically after preference-mutating methods (set_radio_params, set_tx_power, set_tuning_params, set_autoadd_config, set_other_params, set_advert_name, set_advert_latlon). _load_prefs() is called once at the end of _init_companion_stores(). Subclasses override to persist to SQLite, JSON, etc. --- src/pymc_core/companion/companion_base.py | 64 +++++++++ src/pymc_core/companion/frame_server.py | 155 ++++++++++++++++++++++ 2 files changed, 219 insertions(+) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 381c9f8..78382a0 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -169,6 +169,7 @@ def _init_companion_stores( self._custom_vars: dict[str, str] = {} self._sign_buffer: Optional[bytearray] = None self._flood_transport_key: Optional[bytes] = None + self._time_offset: float = 0.0 self._event_service = EventService() self._event_subscriber = _CompanionEventSubscriber(self) @@ -191,6 +192,35 @@ def _init_companion_stores( self._seen_txt_ttl = 300 self._seen_txt_max = 1000 + # Allow subclasses to restore persisted preferences on startup. + self._load_prefs() + + # ------------------------------------------------------------------------- + # Preference Persistence Hooks + # ------------------------------------------------------------------------- + + def _save_prefs(self) -> None: + """Hook: persist the current :attr:`prefs` to stable storage. + + The default implementation is a no-op — preferences live only in + memory. Subclasses that need persistence (e.g. backed by SQLite or + a JSON file) should override this method. + + Called automatically after any preference-mutating method + (``set_radio_params``, ``set_tx_power``, ``set_tuning_params``, + ``set_autoadd_config``, ``set_other_params``, + ``set_advert_name``, ``set_advert_latlon``). + """ + + def _load_prefs(self) -> None: + """Hook: restore :attr:`prefs` from stable storage on startup. + + The default implementation is a no-op. Subclasses should override + to populate :attr:`self.prefs` fields from their persistence layer. + + Called once at the end of :meth:`_init_companion_stores`. + """ + # ------------------------------------------------------------------------- # Contact Management # ------------------------------------------------------------------------- @@ -283,6 +313,7 @@ def import_contact(self, packet_data: bytes) -> bool: def set_advert_name(self, name: str) -> None: """Set the node's advertised name (max 31 chars).""" self.prefs.node_name = name[:31] + self._save_prefs() def set_advert_latlon(self, lat: float, lon: float) -> None: """Set the GPS coordinates included in advertisements.""" @@ -292,6 +323,7 @@ def set_advert_latlon(self, lat: float, lon: float) -> None: raise ValueError(f"Longitude out of range: {lon}") self.prefs.latitude = lat self.prefs.longitude = lon + self._save_prefs() def set_radio_params(self, freq_hz: int, bw_hz: int, sf: int, cr: int) -> bool: """Set radio parameters (frequency, bandwidth, SF, CR).""" @@ -303,22 +335,37 @@ def set_radio_params(self, freq_hz: int, bw_hz: int, sf: int, cr: int) -> bool: self.prefs.bandwidth_hz = bw_hz self.prefs.spreading_factor = sf self.prefs.coding_rate = cr + self._save_prefs() return True def set_tx_power(self, power_dbm: int) -> bool: """Set the transmit power in dBm.""" self.prefs.tx_power_dbm = power_dbm + self._save_prefs() return True def set_tuning_params(self, rx_delay: float, airtime_factor: float) -> None: """Set RX delay and airtime factor tuning parameters.""" self.prefs.rx_delay_base = rx_delay self.prefs.airtime_factor = airtime_factor + self._save_prefs() def get_tuning_params(self) -> tuple[float, float]: """Return the current (rx_delay, airtime_factor) tuning parameters.""" return (self.prefs.rx_delay_base, self.prefs.airtime_factor) + def get_time(self) -> int: + """Return the current device time as a Unix timestamp.""" + return int(time.time() + self._time_offset) + + def set_time(self, secs: int) -> bool: + """Set the device time. Returns False if *secs* is in the past.""" + current = self.get_time() + if secs < current: + return False + self._time_offset = secs - time.time() + return True + def set_other_params( self, manual_add: int, @@ -333,6 +380,7 @@ def set_other_params( self.prefs.telemetry_mode_environment = (telemetry_modes >> 4) & 0x03 self.prefs.advert_loc_policy = advert_loc_policy self.prefs.multi_acks = multi_acks + self._save_prefs() def get_self_info(self) -> NodePrefs: """Return a copy of the current node preferences.""" @@ -513,6 +561,7 @@ def get_autoadd_config(self) -> int: def set_autoadd_config(self, config: int) -> None: """Set the auto-add configuration bitmask.""" self.prefs.autoadd_config = config + self._save_prefs() # ------------------------------------------------------------------------- # Push Callbacks @@ -1037,6 +1086,21 @@ def _login_cb(success: bool, data: dict) -> None: login_handler.set_login_callback(None) login_handler.clear_login_password(dest_hash) + async def send_logout(self, pub_key: bytes) -> bool: + """Send a logout / disconnect to a repeater contact.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return False + try: + pkt, _ = PacketBuilder.create_logout_packet( + contact=contact, local_identity=self._identity + ) + await self._send_packet(pkt, wait_for_ack=False) + return True + except Exception as e: + logger.error(f"Logout error: {e}") + return False + async def send_status_request(self, pub_key: bytes, timeout: float = 15.0) -> dict: """Send a protocol request for repeater status/stats.""" contact = self.contacts.get_by_key(pub_key) diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 6ed83ff..031fd61 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -22,13 +22,18 @@ CMD_ADD_UPDATE_CONTACT, CMD_APP_START, CMD_DEVICE_QUERY, + CMD_EXPORT_CONTACT, CMD_GET_ADVERT_PATH, + CMD_GET_AUTOADD_CONFIG, CMD_GET_BATT_AND_STORAGE, CMD_GET_CHANNEL, CMD_GET_CONTACT_BY_KEY, CMD_GET_CONTACTS, + CMD_GET_CUSTOM_VARS, + CMD_GET_DEVICE_TIME, CMD_GET_STATS, CMD_IMPORT_CONTACT, + CMD_LOGOUT, CMD_REMOVE_CONTACT, CMD_RESET_PATH, CMD_SEND_BINARY_REQ, @@ -43,8 +48,15 @@ CMD_SEND_TXT_MSG, CMD_SET_ADVERT_LATLON, CMD_SET_ADVERT_NAME, + CMD_SET_AUTOADD_CONFIG, CMD_SET_CHANNEL, + CMD_SET_CUSTOM_VAR, + CMD_SET_DEVICE_TIME, CMD_SET_FLOOD_SCOPE, + CMD_SET_RADIO_PARAMS, + CMD_SET_RADIO_TX_POWER, + CMD_SET_TUNING_PARAMS, + CMD_SHARE_CONTACT, CMD_SYNC_NEXT_MESSAGE, ERR_CODE_BAD_STATE, ERR_CODE_ILLEGAL_ARG, @@ -71,6 +83,7 @@ PUSH_CODE_TELEMETRY_RESPONSE, PUSH_CODE_TRACE_DATA, RESP_CODE_ADVERT_PATH, + RESP_CODE_AUTOADD_CONFIG, RESP_CODE_BATT_AND_STORAGE, RESP_CODE_CHANNEL_INFO, RESP_CODE_CHANNEL_MSG_RECV, @@ -79,9 +92,12 @@ RESP_CODE_CONTACT_MSG_RECV, RESP_CODE_CONTACT_MSG_RECV_V3, RESP_CODE_CONTACTS_START, + RESP_CODE_CURR_TIME, + RESP_CODE_CUSTOM_VARS, RESP_CODE_DEVICE_INFO, RESP_CODE_END_OF_CONTACTS, RESP_CODE_ERR, + RESP_CODE_EXPORT_CONTACT, RESP_CODE_NO_MORE_MESSAGES, RESP_CODE_OK, RESP_CODE_SELF_INFO, @@ -652,6 +668,30 @@ async def _handle_cmd(self, payload: bytes) -> None: await self._cmd_send_trace_path(data) elif cmd == CMD_SET_FLOOD_SCOPE: await self._cmd_set_flood_scope(data) + elif cmd == CMD_GET_DEVICE_TIME: + await self._cmd_get_device_time(data) + elif cmd == CMD_SET_DEVICE_TIME: + await self._cmd_set_device_time(data) + elif cmd == CMD_SET_RADIO_PARAMS: + await self._cmd_set_radio_params(data) + elif cmd == CMD_SET_RADIO_TX_POWER: + await self._cmd_set_tx_power(data) + elif cmd == CMD_SHARE_CONTACT: + await self._cmd_share_contact(data) + elif cmd == CMD_EXPORT_CONTACT: + await self._cmd_export_contact(data) + elif cmd == CMD_SET_TUNING_PARAMS: + await self._cmd_set_tuning_params(data) + elif cmd == CMD_LOGOUT: + await self._cmd_logout(data) + elif cmd == CMD_GET_CUSTOM_VARS: + await self._cmd_get_custom_vars(data) + elif cmd == CMD_SET_CUSTOM_VAR: + await self._cmd_set_custom_var(data) + elif cmd == CMD_SET_AUTOADD_CONFIG: + await self._cmd_set_autoadd_config(data) + elif cmd == CMD_GET_AUTOADD_CONFIG: + await self._cmd_get_autoadd_config(data) else: logger.warning( "Companion unsupported cmd 0x%02x (%s) len=%s", @@ -1379,3 +1419,118 @@ async def _cmd_set_flood_scope(self, data: bytes) -> None: else: self.bridge.set_flood_scope(None) self._write_ok() + + # ------------------------------------------------------------------------- + # Time, radio, tuning, share/export, logout, custom vars, autoadd + # ------------------------------------------------------------------------- + + async def _cmd_get_device_time(self, data: bytes) -> None: + now = self.bridge.get_time() + self._write_frame(bytes([RESP_CODE_CURR_TIME]) + struct.pack(" None: + if len(data) < 4: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + secs = struct.unpack(" None: + if len(data) < 10: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + freq = struct.unpack_from(" None: + if len(data) < 1: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + power = struct.unpack_from(" 20: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + self.bridge.set_tx_power(power) + self._write_ok() + + async def _cmd_share_contact(self, data: bytes) -> None: + if len(data) < PUB_KEY_SIZE: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:PUB_KEY_SIZE] + ok = await self.bridge.share_contact(pubkey) + self._write_ok() if ok else self._write_err(ERR_CODE_NOT_FOUND) + + async def _cmd_export_contact(self, data: bytes) -> None: + if len(data) < PUB_KEY_SIZE: + raw = self.bridge.export_contact(None) + else: + raw = self.bridge.export_contact(data[:PUB_KEY_SIZE]) + if raw is None: + self._write_err(ERR_CODE_NOT_FOUND) + return + self._write_frame(bytes([RESP_CODE_EXPORT_CONTACT]) + raw) + + async def _cmd_set_tuning_params(self, data: bytes) -> None: + if len(data) < 8: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + rx_ms = struct.unpack_from(" None: + if len(data) < PUB_KEY_SIZE: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:PUB_KEY_SIZE] + await self.bridge.send_logout(pubkey) + self._write_ok() + + async def _cmd_get_custom_vars(self, data: bytes) -> None: + custom_vars = self.bridge.get_custom_vars() + parts = [f"{k}:{v}" for k, v in custom_vars.items()] + csv = ",".join(parts)[:140] + self._write_frame(bytes([RESP_CODE_CUSTOM_VARS]) + csv.encode("utf-8", errors="replace")) + + async def _cmd_set_custom_var(self, data: bytes) -> None: + if len(data) < 3: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + text = data.split(b"\x00")[0].decode("utf-8", errors="replace") + sep = text.find(":") + if sep < 1: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + name = text[:sep] + value = text[sep + 1 :] + ok = self.bridge.set_custom_var(name, value) + self._write_ok() if ok else self._write_err(ERR_CODE_ILLEGAL_ARG) + + async def _cmd_set_autoadd_config(self, data: bytes) -> None: + if len(data) < 1: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + self.bridge.set_autoadd_config(data[0]) + self._write_ok() + + async def _cmd_get_autoadd_config(self, data: bytes) -> None: + config = self.bridge.get_autoadd_config() + self._write_frame(bytes([RESP_CODE_AUTOADD_CONFIG, config & 0xFF])) From 31460a8924aadd3d0434ad393ed6c4bb6a9a863f Mon Sep 17 00:00:00 2001 From: agessaman Date: Wed, 18 Feb 2026 16:20:23 -0800 Subject: [PATCH 13/50] feat(companion): expand documentation on new class structure Update the companion module documentation to clarify the roles of three main classes: `CompanionRadio`, `CompanionBridge`, and `CompanionFrameServer`. Introduce new methods for sending text and channel messages, and improve the persistence hooks for preferences. Document the TCP frame protocol used by `CompanionFrameServer`, including frame format and quick start examples. Enhance the description of preference management and device configuration methods, ensuring clarity on their functionality and usage. --- docs/docs/companion.md | 372 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 356 insertions(+), 16 deletions(-) diff --git a/docs/docs/companion.md b/docs/docs/companion.md index 7aeb282..686cbe9 100644 --- a/docs/docs/companion.md +++ b/docs/docs/companion.md @@ -2,14 +2,17 @@ The companion module provides a high-level Python interface to the MeshCore companion radio protocol. It manages contacts, messaging, channels, advertisements, path routing, telemetry, cryptographic signing, and device configuration on top of pyMC_core's `MeshNode`. -Two implementations are provided: +Three main classes are provided: | Class | Owns Radio | Use Case | |---|---|---| | `CompanionRadio` | Yes | Standalone companion — wraps a hardware radio and `MeshNode` | | `CompanionBridge` | No | Repeater-integrated companion — shares an existing dispatcher via a packet injector callback | +| `CompanionFrameServer` | No | TCP server implementing the MeshCore companion frame protocol (the binary wire format used by companion apps) | -Both inherit from `CompanionBase` (an abstract base class), which holds all shared stores, event handling, device configuration logic, and unified TX methods (advertising, binary requests, path discovery, offline queue sync). Subclasses implement transport via the abstract `_send_packet` method. +`CompanionRadio` and `CompanionBridge` both inherit from `CompanionBase` (an abstract base class), which holds all shared stores, event handling, device configuration logic, and unified TX methods. Subclasses implement transport via the abstract `_send_packet` method. + +`CompanionFrameServer` wraps a `CompanionBridge` (or any `CompanionBase` subclass) and exposes it over TCP using the same binary frame protocol as the MeshCore firmware companion radio. --- @@ -24,13 +27,25 @@ CompanionBase (ABC) ├── StatsCollector (TX/RX counters, uptime) ├── NodePrefs (radio params, name, location) │ +│ Persistence hooks (no-op by default, override for SQLite/JSON): +│ _save_prefs, _load_prefs +│ │ Unified methods (use abstract _send_packet): -│ advertise, share_contact, send_binary_req, -│ send_path_discovery_req, send_trace_path_raw, -│ sync_next_message +│ advertise, share_contact, send_text_message, +│ send_channel_message, send_binary_req, +│ send_path_discovery_req, send_trace_path, +│ send_login, send_logout, sync_next_message │ ├─► CompanionRadio (owns MeshNode + hardware radio) └─► CompanionBridge (packet_injector callback, no radio) + │ + └─► CompanionFrameServer (TCP binary frame protocol server) + │ + │ Persistence hooks (no-op by default): + │ _persist_companion_message, _sync_next_from_persistence, + │ _save_contacts, _save_channels, _get_batt_and_storage + │ + └─► (your subclass, e.g. SQLite-backed repeater) ``` --- @@ -181,7 +196,7 @@ companion.remove_contact(pub_key_bytes) companion.reset_path(pub_key_bytes) # Serialise for sharing -blob = companion.export_contact(pub_key) # bytes +blob = companion.export_contact(pub_key) # bytes (73-byte binary packet) ok = companion.import_contact(blob) # bool ``` @@ -213,6 +228,9 @@ advert_path = companion.get_advert_path(pub_key_prefix_7bytes) # Login to a repeater resp = await companion.send_login(repeater_key, password="secret") +# Logout from a repeater +ok = await companion.send_logout(repeater_key) + # Request repeater status resp = await companion.send_status_request(repeater_key) @@ -245,23 +263,37 @@ companion.on_binary_response( ### Device Configuration +All preference-mutating methods automatically call `_save_prefs()` (a no-op by default; override in subclasses for persistence). + ```python -companion.set_advert_name("NewName") # max 31 chars -companion.set_advert_latlon(37.7749, -122.4194) # GPS coordinates +companion.set_advert_name("NewName") # max 31 chars +companion.set_advert_latlon(37.7749, -122.4194) # GPS coordinates companion.set_radio_params(915_000_000, 250_000, 10, 5) # freq, bw, SF, CR -companion.set_tx_power(22) # dBm +companion.set_tx_power(22) # dBm companion.set_tuning_params(rx_delay=0.0, airtime_factor=0.0) +# Time management (transient, not persisted) +device_time = companion.get_time() # Unix timestamp +ok = companion.set_time(1700000000) # returns False if in the past + +# Custom variables (key:value string pairs, max 140 chars total) +custom_vars = companion.get_custom_vars() # -> dict[str, str] +ok = companion.set_custom_var("key", "value") # -> bool + +# Auto-add configuration +config = companion.get_autoadd_config() # -> int bitmask +companion.set_autoadd_config(AUTOADD_CHAT | AUTOADD_REPEATER) + # Location sharing in adverts from pymc_core.companion import ADVERT_LOC_SHARE companion.set_other_params( manual_add=0, - telemetry_modes=(0, 0, 0), + telemetry_modes=0, advert_loc_policy=ADVERT_LOC_SHARE, multi_acks=0, ) -prefs = companion.get_self_info() # -> NodePrefs +prefs = companion.get_self_info() # -> NodePrefs (copy) ``` ### Flood Scope (Regions) @@ -308,6 +340,18 @@ stats = companion.get_stats(STATS_TYPE_PACKETS) # {'flood_tx': 42, 'flood_rx': 108, 'direct_tx': 5, 'direct_rx': 12, ...} ``` +### CompanionRadio-Specific Overrides + +`CompanionRadio` overrides several `CompanionBase` methods to also configure the physical radio hardware: + +| Method | Base Behavior | Radio Override | +|---|---|---| +| `set_radio_params()` | Updates `prefs` fields | Also calls `radio.configure_radio()` | +| `set_tx_power()` | Updates `prefs.tx_power_dbm` | Also calls `radio.set_tx_power()` | +| `set_advert_name()` | Updates `prefs.node_name` | Also syncs `node.node_name` | +| `set_flood_scope()` | Stores transport key | Also syncs to `node.dispatcher` | +| `set_flood_region()` | Derives key from name | Also syncs to `node.dispatcher` | + --- ## CompanionBridge @@ -396,6 +440,212 @@ The bridge registers internal handlers for these payload types: `CompanionBridge` exposes the same messaging, contact, channel, path, signing, stats, and configuration APIs as `CompanionRadio` (inherited from `CompanionBase`). The only behavioral difference is that all TX goes through the `packet_injector` instead of an owned radio. +Note that `set_radio_params()` and `set_tx_power()` update in-memory prefs only — there is no physical radio to configure. This is correct: the repeater host owns the radio hardware. + +--- + +## CompanionFrameServer + +`CompanionFrameServer` implements the MeshCore companion radio TCP frame protocol — the same binary wire format used by the C++ firmware (`examples/companion_radio/`). It wraps a `CompanionBase` subclass (typically a `CompanionBridge`) and exposes it to companion apps (e.g. MeshCore Android/iOS) over a TCP socket. + +### Frame Format + +All frames use a simple length-prefixed format: + +| Direction | Prefix | Length | Data | +|---|---|---|---| +| App → Radio | `<` (0x3C) | 2-byte LE | Command byte + payload | +| Radio → App | `>` (0x3E) | 2-byte LE | Response/push byte + payload | + +Maximum frame size: 512 bytes. + +### Quick Start + +```python +from pymc_core import LocalIdentity +from pymc_core.companion import CompanionBridge, CompanionFrameServer + +identity = LocalIdentity() +bridge = CompanionBridge(identity=identity, packet_injector=my_injector) + +server = CompanionFrameServer( + bridge=bridge, + companion_hash="abcd1234", # identifier for this companion + port=5000, + device_model="pyMC-Companion", + device_version="1.0.0", +) + +await server.start() # starts listening on TCP port +# ... companion app connects and sends commands ... +await server.stop() +``` + +### Constructor + +```python +CompanionFrameServer( + bridge: CompanionBase, # the companion to wrap + companion_hash: str, # unique identifier + port: int = 5000, + bind_address: str = "0.0.0.0", + *, + device_model: str = "pyMC-Companion", + device_version: str = "1.0.0", + build_date: str = "", + local_hash: int | None = None, + stats_getter: Callable | None = None, + control_handler: Any | None = None, +) +``` + +### Supported Commands + +The frame server handles the following companion radio protocol commands: + +| CMD | Code | Description | +|---|---|---| +| `CMD_APP_START` | 1 | Initialize connection, return device info | +| `CMD_SEND_TXT_MSG` | 2 | Send a direct text message | +| `CMD_SEND_CHANNEL_TXT_MSG` | 3 | Send a channel message | +| `CMD_GET_CONTACTS` | 4 | Retrieve contact list (paginated) | +| `CMD_GET_DEVICE_TIME` | 5 | Get current device time | +| `CMD_SET_DEVICE_TIME` | 6 | Set device time | +| `CMD_SEND_SELF_ADVERT` | 7 | Broadcast self advertisement | +| `CMD_SET_ADVERT_NAME` | 8 | Set advertised node name | +| `CMD_ADD_UPDATE_CONTACT` | 9 | Add or update a contact | +| `CMD_SYNC_NEXT_MESSAGE` | 10 | Pop next queued message | +| `CMD_SET_RADIO_PARAMS` | 11 | Set frequency, bandwidth, SF, CR | +| `CMD_SET_RADIO_TX_POWER` | 12 | Set transmit power | +| `CMD_RESET_PATH` | 13 | Reset routing path for a contact | +| `CMD_SET_ADVERT_LATLON` | 14 | Set GPS coordinates | +| `CMD_REMOVE_CONTACT` | 15 | Remove a contact | +| `CMD_SHARE_CONTACT` | 16 | Share a contact to the mesh | +| `CMD_EXPORT_CONTACT` | 17 | Export contact as 73-byte blob | +| `CMD_IMPORT_CONTACT` | 18 | Import contact from blob | +| `CMD_GET_BATT_AND_STORAGE` | 20 | Get battery/storage info | +| `CMD_SET_TUNING_PARAMS` | 21 | Set RX delay and airtime factor | +| `CMD_DEVICE_QUERY` | 22 | Return device model/version | +| `CMD_SEND_LOGIN` | 26 | Login to a repeater | +| `CMD_SEND_STATUS_REQ` | 27 | Request repeater status | +| `CMD_LOGOUT` | 29 | Logout from a repeater | +| `CMD_GET_CONTACT_BY_KEY` | 30 | Look up contact by public key | +| `CMD_GET_CHANNEL` | 31 | Get a channel by index | +| `CMD_SET_CHANNEL` | 32 | Set a channel | +| `CMD_SEND_TRACE_PATH` | 36 | Send trace path request | +| `CMD_SEND_TELEMETRY_REQ` | 39 | Request telemetry data | +| `CMD_GET_CUSTOM_VARS` | 40 | Get custom variables | +| `CMD_SET_CUSTOM_VAR` | 41 | Set a custom variable | +| `CMD_GET_ADVERT_PATH` | 42 | Get cached advert path | +| `CMD_SEND_BINARY_REQ` | 50 | Send binary request | +| `CMD_SEND_PATH_DISCOVERY_REQ` | 52 | Send path discovery | +| `CMD_SET_FLOOD_SCOPE` | 54 | Set flood scope transport key | +| `CMD_SEND_CONTROL_DATA` | 55 | Send control data | +| `CMD_GET_STATS` | 56 | Get statistics | +| `CMD_SET_AUTOADD_CONFIG` | 58 | Set auto-add configuration | +| `CMD_GET_AUTOADD_CONFIG` | 59 | Get auto-add configuration | + +### Push Notifications + +The frame server sends unsolicited push frames to the companion app when events occur: + +| Push Code | Value | Description | +|---|---|---| +| `PUSH_CODE_ADVERT` | 0x80 | Contact advertisement received | +| `PUSH_CODE_MSG_WAITING` | 0x82 | New message queued | +| `PUSH_CODE_SEND_CONFIRMED` | 0x84 | ACK received for a sent message | +| `PUSH_CODE_PATH_UPDATED` | 0x86 | Contact path updated | +| `PUSH_CODE_LOG_RX_DATA` | 0x88 | Raw RX packet (diagnostics) | +| `PUSH_CODE_TRACE_DATA` | 0x89 | Trace path response | +| `PUSH_CODE_NEW_ADVERT` | 0x8A | New (previously unknown) contact discovered | +| `PUSH_CODE_CONTROL_DATA` | 0x8E | Control data received | +| `PUSH_CODE_LOGIN_SUCCESS` | 0x91 | Repeater login succeeded | +| `PUSH_CODE_LOGIN_FAIL` | 0x92 | Repeater login failed | +| `PUSH_CODE_STATUS_RESPONSE` | 0x93 | Repeater status response | +| `PUSH_CODE_TELEMETRY_RESPONSE` | 0x94 | Telemetry response | +| `PUSH_CODE_BINARY_RESPONSE` | 0x95 | Binary request response | +| `PUSH_CODE_PATH_DISCOVERY_RESPONSE` | 0x96 | Path discovery response | + +### Host-Callable Push Methods + +The frame server exposes methods for the host application to push data to the connected companion app: + +```python +# Push trace data from the repeater +server.push_trace_data( + path_len=3, flags=0, tag=42, auth_code=0, + path_hashes=b"...", path_snrs=b"...", final_snr_byte=0 +) + +# Push raw RX packet for diagnostics logging +server.push_rx_raw(snr=-5.0, rssi=-100, raw=b"...") + +# Push control data +await server.push_control_data( + snr=-5.0, rssi=-100, path_len=2, + path_bytes=b"...", payload=b"..." +) +``` + +### Persistence Hooks + +`CompanionFrameServer` provides no-op hooks that subclasses override for persistent storage (e.g. SQLite): + +```python +class MyFrameServer(CompanionFrameServer): + async def _persist_companion_message(self, msg_dict: dict) -> None: + """Called when a message is received. Save to database.""" + await self.db.save_message(msg_dict) + + def _sync_next_from_persistence(self) -> QueuedMessage | None: + """Called when the in-memory queue is empty. Pop from database.""" + return self.db.pop_oldest_message() + + def _save_contacts(self) -> None: + """Called after contact list changes. Sync to database.""" + self.db.save_contacts(self.bridge.contacts.to_dicts()) + + def _save_channels(self) -> None: + """Called after channel changes. Sync to database.""" + self.db.save_channels(...) + + def _get_batt_and_storage(self) -> tuple[int, int, int]: + """Return (millivolts, used_kb, total_kb) for CMD_GET_BATT_AND_STORAGE.""" + return (4200, 128, 1024) +``` + +--- + +## Persistence Hooks + +The companion module uses a "no-op hook" pattern for persistence: base classes define empty methods that subclasses override to save/load state from their storage backend. + +### CompanionBase Hooks (Preferences) + +```python +class CompanionBase: + def _save_prefs(self) -> None: + """Persist self.prefs. Called after any pref-mutating method.""" + + def _load_prefs(self) -> None: + """Restore self.prefs on startup. Called at end of _init_companion_stores().""" +``` + +`_save_prefs()` is called automatically by: `set_radio_params`, `set_tx_power`, `set_tuning_params`, `set_autoadd_config`, `set_other_params`, `set_advert_name`, `set_advert_latlon`. + +Note: `set_time()` does **not** call `_save_prefs()` — the time offset is a transient runtime correction, not a persistent preference. + +### CompanionFrameServer Hooks (Messages, Contacts, Channels) + +```python +class CompanionFrameServer: + async def _persist_companion_message(self, msg_dict: dict) -> None: ... + def _sync_next_from_persistence(self) -> QueuedMessage | None: ... + def _save_contacts(self) -> None: ... + def _save_channels(self) -> None: ... + def _get_batt_and_storage(self) -> tuple[int, int, int]: ... +``` + --- ## Use Cases @@ -463,7 +713,31 @@ async def repeater_on_rx(pkt): # ... also handle repeater logic ... ``` -### 4. Network Diagnostics Tool +### 4. Companion Frame Server (TCP Protocol) + +Expose a companion over TCP so standard companion apps can connect. + +```python +bridge = CompanionBridge( + identity=identity, + packet_injector=repeater.inject_packet, + node_name="pyMC-Server", +) + +server = CompanionFrameServer( + bridge=bridge, + companion_hash="abcd1234", + port=5000, + device_model="pyMC-Companion", + device_version="1.0.0", +) + +await bridge.start() +await server.start() +# Companion apps (Android/iOS) can now connect on port 5000 +``` + +### 5. Network Diagnostics Tool Trace paths and discover topology. @@ -480,7 +754,7 @@ await companion.send_trace_path(target_key, tag=1, auth_code=0) await companion.send_path_discovery_req(target_key) ``` -### 5. Group Chat / Channels +### 6. Group Chat / Channels ```python companion.set_channel(0, name="Emergency", secret=b"shared_channel_secret___________") @@ -535,6 +809,7 @@ class Contact: lastmod: int = 0 gps_lat: float = 0.0 gps_lon: float = 0.0 + sync_since: int = 0 ``` ### Channel @@ -546,6 +821,31 @@ class Channel: secret: bytes # 16-byte pre-shared key ``` +### NodePrefs + +```python +@dataclass +class NodePrefs: + node_name: str = "pyMC" + adv_type: int = 1 # ADV_TYPE_CHAT + tx_power_dbm: int = 20 + frequency_hz: int = 915000000 + bandwidth_hz: int = 250000 + spreading_factor: int = 10 + coding_rate: int = 5 + latitude: float = 0.0 + longitude: float = 0.0 + advert_loc_policy: int = 0 # ADVERT_LOC_NONE + multi_acks: int = 0 + telemetry_mode_base: int = 0 # TELEM_MODE_DENY + telemetry_mode_location: int = 0 + telemetry_mode_environment: int = 0 + manual_add_contacts: int = 0 + autoadd_config: int = 0 + rx_delay_base: float = 0.0 + airtime_factor: float = 0.0 +``` + ### SentResult ```python @@ -571,6 +871,30 @@ class QueuedMessage: path_len: int = 0 ``` +### AdvertPath + +```python +@dataclass +class AdvertPath: + public_key_prefix: bytes # 7-byte prefix + name: str = "" + path_len: int = 0 + path: bytes = b"" + recv_timestamp: int = 0 +``` + +### PacketStats + +```python +@dataclass +class PacketStats: + flood_tx: int = 0 + flood_rx: int = 0 + direct_tx: int = 0 + direct_rx: int = 0 + tx_errors: int = 0 +``` + --- ## Constants @@ -621,8 +945,26 @@ PROTOCOL_CODE_RAW_DATA = 0x00 PROTOCOL_CODE_BINARY_REQ = 0x02 PROTOCOL_CODE_ANON_REQ = 0x07 -# Timeouts +# Frame format +FRAME_OUTBOUND_PREFIX = 0x3E # '>' (radio → app) +FRAME_INBOUND_PREFIX = 0x3C # '<' (app → radio) +MAX_FRAME_SIZE = 512 + +# Error codes (returned by frame server) +ERR_CODE_UNSUPPORTED_CMD = 1 +ERR_CODE_NOT_FOUND = 2 +ERR_CODE_TABLE_FULL = 3 +ERR_CODE_BAD_STATE = 4 +ERR_CODE_FILE_IO_ERROR = 5 +ERR_CODE_ILLEGAL_ARG = 6 + +# Defaults DEFAULT_RESPONSE_TIMEOUT_MS = 10000 +DEFAULT_MAX_CONTACTS = 1000 +DEFAULT_MAX_CHANNELS = 40 +DEFAULT_OFFLINE_QUEUE_SIZE = 512 +PUB_KEY_SIZE = 32 +MAX_PATH_SIZE = 64 ``` --- @@ -633,9 +975,7 @@ The following protocol-level features from the MeshCore companion radio firmware | Feature | Firmware Reference | Description | |---|---|---| -| Logout | `CMD_LOGOUT` (0x1D) | Disconnect from a repeater/server session | | Has connection | `CMD_HAS_CONNECTION` (0x1C) | Check if active connection exists to a contact | | Push: contact deleted | `PUSH_CODE_CONTACT_DELETED` (0x8F) | Notification when a contact is overwritten by auto-add | | Push: contacts full | `PUSH_CODE_CONTACTS_FULL` (0x90) | Notification when contact storage is full | -| Push: RX data log | `PUSH_CODE_LOG_RX_DATA` (0x88) | Raw received packet logging for diagnostics | | Keep-alive mechanism | Server-driven keep-alive | Periodic keep-alive packets for active server connections | From 5e1b5ef90d0d7c07028c106473f7801d3fa79552 Mon Sep 17 00:00:00 2001 From: agessaman Date: Wed, 18 Feb 2026 20:46:22 -0800 Subject: [PATCH 14/50] feat(group_text): enhance channel hash handling and decryption logic - Updated `_get_channel_by_hash` to return all matching channels due to potential hash collisions, improving robustness. - Introduced `_get_channels_by_hash` to handle multiple channel matches and updated documentation for clarity. - Modified `_decrypt_channel_message` to return `None` on HMAC mismatch, aligning with expected behavior during candidate iteration. - Adjusted message parsing logic to iterate through candidate channels until a valid decryption is found, logging potential hash collisions for better diagnostics. --- src/pymc_core/node/handlers/group_text.py | 78 ++++++++++++++++------- 1 file changed, 55 insertions(+), 23 deletions(-) diff --git a/src/pymc_core/node/handlers/group_text.py b/src/pymc_core/node/handlers/group_text.py index f2438ef..7671d96 100644 --- a/src/pymc_core/node/handlers/group_text.py +++ b/src/pymc_core/node/handlers/group_text.py @@ -30,24 +30,39 @@ def __init__( self.our_node_name = our_node_name # Store our node name for echo detection def _get_channel_by_hash(self, channel_hash: int) -> Optional[dict]: - """Find a channel by its hash (first byte of public key) from database.""" + """Find a channel by its hash (first byte of SHA256) from database. + + Returns the first matching channel. See also + :meth:`_get_channels_by_hash` which returns *all* matches (needed + because the hash is only 1 byte and collisions are expected). + """ + matches = self._get_channels_by_hash(channel_hash) + return matches[0] if matches else None + + def _get_channels_by_hash(self, channel_hash: int) -> list[dict]: + """Return **all** channels whose derived hash matches *channel_hash*. + + The channel hash is only 1 byte, so collisions between channels + with different PSKs are expected (~0.4 % per foreign channel). + The firmware handles this by trying each match until HMAC validates; + we do the same. + """ if not self.channel_db: self.log("No channel database available") - return None + return [] try: - # Get all channels from database (live query) channels = self.channel_db.get_channels() + matches = [] for channel in channels: if "secret" in channel: - # Use consistent channel hash derivation calculated_hash = self._derive_channel_hash(channel["secret"]) if calculated_hash == channel_hash: - return channel - return None + matches.append(channel) + return matches except Exception as e: self.log(f"Error querying channel database: {e}") - return None + return [] def _secret_bytes_for_hash(self, channel_secret: str) -> bytes: """Normalize secret to bytes used for channel hash (match MeshCore firmware). @@ -87,7 +102,12 @@ def _derive_channel_keys(self, channel_secret: str) -> tuple: def _decrypt_channel_message( self, channel_secret: str, mac: bytes, ciphertext: bytes ) -> Optional[bytes]: - """Decrypt a channel message using the channel secret.""" + """Attempt to decrypt a channel message using *channel_secret*. + + Returns the plaintext on success, or ``None`` if the HMAC does not + validate (which is expected during candidate iteration when multiple + channels share the same 1-byte hash). + """ try: # Convert hex secret to bytes try: @@ -97,22 +117,19 @@ def _decrypt_channel_message( # Ensure we have PUB_KEY_SIZE (32 bytes) for the secret if len(secret_bytes) < 32: - # Pad with zeros if needed secret_bytes = secret_bytes + b"\x00" * (32 - len(secret_bytes)) elif len(secret_bytes) > 32: - # Truncate if too long secret_bytes = secret_bytes[:32] expected_mac = CryptoUtils._hmac_sha256(secret_bytes, ciphertext)[:2] if mac != expected_mac: - raise ValueError("Invalid HMAC") + return None # HMAC mismatch — normal during candidate iteration - plaintext = CryptoUtils._aes_decrypt(secret_bytes[:16], ciphertext) - return plaintext + return CryptoUtils._aes_decrypt(secret_bytes[:16], ciphertext) except Exception as e: - self.log(f"Channel message decryption failed: {e}") + self.log(f"Channel message decryption error: {e}") return None def _parse_plaintext_message(self, plaintext: bytes) -> Optional[dict]: @@ -193,21 +210,36 @@ async def __call__(self, packet: Packet) -> None: cipher_mac = payload[1:3] ciphertext = payload[3:] - # Find the channel configuration - channel = self._get_channel_by_hash(channel_hash) - if not channel: + # Find all channels whose 1-byte hash matches (collisions are + # expected; the firmware tries up to 4 candidates). + candidates = self._get_channels_by_hash(channel_hash) + if not candidates: self.log(f"Unknown channel hash: {channel_hash:02X}") return + # Try each candidate until HMAC validates (matches firmware behaviour). + channel = None + plaintext = None + for candidate in candidates: + result = self._decrypt_channel_message(candidate["secret"], cipher_mac, ciphertext) + if result is not None: + channel = candidate + plaintext = result + break + + if channel is None or plaintext is None: + # No candidate validated — this is normal for hash collisions + # with channels on other networks. + self.log( + f"GRP_TXT hash {channel_hash:02X} matched " + f"{len(candidates)} channel(s) but HMAC failed for all " + f"— likely a hash collision from another network" + ) + return + channel_name = channel.get("name", f"Channel-{channel_hash:02X}") self.log(f"Received group message for channel: {channel_name}") - # Decrypt the message - plaintext = self._decrypt_channel_message(channel["secret"], cipher_mac, ciphertext) - if not plaintext: - self.log("Failed to decrypt channel message") - return - # Parse the decrypted message parsed_message = self._parse_plaintext_message(plaintext) if not parsed_message: From 07a0334796f42c248e2d93000a56d282f3719588 Mon Sep 17 00:00:00 2001 From: agessaman Date: Wed, 18 Feb 2026 21:51:03 -0800 Subject: [PATCH 15/50] docs(companion): clarify telemetry request frame format - Added documentation comments in `_cmd_send_telemetry_req` to specify the expected byte structure for the CMD_SEND_TELEMETRY_REQ protocol, detailing reserved bytes and public key format. - Updated comments in `GroupTextHandler` to improve clarity on hash collision handling and logging for unknown channels. --- src/pymc_core/companion/frame_server.py | 2 ++ src/pymc_core/node/handlers/group_text.py | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 031fd61..9506959 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -1122,6 +1122,8 @@ async def _cmd_send_status_req(self, data: bytes) -> None: self._write_frame(bytes([PUSH_CODE_STATUS_RESPONSE, 0]) + pubkey[:6] + raw_bytes) async def _cmd_send_telemetry_req(self, data: bytes) -> None: + # Protocol: CMD_SEND_TELEMETRY_REQ has reserved bytes(3) then pub_key bytes(32). + # See MeshCore Companion-Radio-Protocol: CMD_SEND_TELEMETRY_REQ frame format. if len(data) < 35: self._write_err(ERR_CODE_ILLEGAL_ARG) return diff --git a/src/pymc_core/node/handlers/group_text.py b/src/pymc_core/node/handlers/group_text.py index 7671d96..4e51422 100644 --- a/src/pymc_core/node/handlers/group_text.py +++ b/src/pymc_core/node/handlers/group_text.py @@ -228,12 +228,12 @@ async def __call__(self, packet: Packet) -> None: break if channel is None or plaintext is None: - # No candidate validated — this is normal for hash collisions - # with channels on other networks. + # No candidate validated — the packet is for a channel we + # don't have the key for (hash collision with 1-byte hash). self.log( f"GRP_TXT hash {channel_hash:02X} matched " - f"{len(candidates)} channel(s) but HMAC failed for all " - f"— likely a hash collision from another network" + f"{len(candidates)} local channel(s) but HMAC failed " + f"for all — unknown channel" ) return From a4ec4d020630f50901231ba65b4c1987ac3f3532 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 21 Feb 2026 21:29:07 -0800 Subject: [PATCH 16/50] feat(companion): add owner info request handling and parsing, improve handling of path, req, and non_req packets. - Introduced support for the OWNER_INFO request type in the binary parsing logic, allowing the parsing of owner information from binary responses. - Implemented the `_parse_owner_info` function to extract and structure owner data from the response payload. - Enhanced the `send_anon_req` method in `CompanionBase` to facilitate sending anonymous requests for owner information. - Updated the `CompanionFrameServer` to handle the new CMD_SEND_ANON_REQ command, ensuring proper processing of anonymous requests. - Added constants for the OWNER_INFO request type and updated relevant documentation for clarity on the new functionality. - Fix hang due to race condition in kiss_modem_wrapper.py --- src/pymc_core/companion/binary_parsing.py | 18 ++ src/pymc_core/companion/companion_base.py | 60 +++++- src/pymc_core/companion/companion_bridge.py | 64 +++++- src/pymc_core/companion/constants.py | 5 + src/pymc_core/companion/frame_server.py | 58 +++++- src/pymc_core/hardware/kiss_modem_wrapper.py | 21 +- .../node/handlers/protocol_request.py | 1 + .../node/handlers/protocol_response.py | 190 ++++++++++++------ src/pymc_core/node/handlers/text.py | 50 +++-- src/pymc_core/protocol/__init__.py | 2 + src/pymc_core/protocol/constants.py | 1 + src/pymc_core/protocol/identity.py | 14 +- 12 files changed, 386 insertions(+), 98 deletions(-) diff --git a/src/pymc_core/companion/binary_parsing.py b/src/pymc_core/companion/binary_parsing.py index d2fbcd9..74d3791 100644 --- a/src/pymc_core/companion/binary_parsing.py +++ b/src/pymc_core/companion/binary_parsing.py @@ -24,6 +24,8 @@ def parse_binary_response( return _parse_acl(data) if request_type == BinaryReqType.NEIGHBOURS: return _parse_neighbours(data, context or {}) + if request_type == BinaryReqType.OWNER_INFO and len(data) >= 4: + return _parse_owner_info(data) return {"raw_hex": data.hex(), "request_type": request_type} @@ -89,6 +91,22 @@ def _parse_mma(data: bytes) -> dict: return out +def _parse_owner_info(data: bytes) -> dict: + """Parse GET_OWNER_INFO response: tag(4) + 'version\\nname\\nowner' (variable).""" + try: + text = data[4:].decode("utf-8", errors="replace").strip() + parts = text.split("\n", 2) + return { + "tag": int.from_bytes(data[:4], "little"), + "version": parts[0] if len(parts) > 0 else "", + "node_name": parts[1] if len(parts) > 1 else "", + "owner_info": parts[2] if len(parts) > 2 else "", + "raw_text": text, + } + except Exception: + return {"raw_hex": data.hex(), "request_type": BinaryReqType.OWNER_INFO} + + def _parse_acl(buf: bytes) -> dict: """ACL: 7-byte entries (key 6 + perm 1).""" res = [] diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 78382a0..a4887b8 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -808,7 +808,14 @@ async def send_binary_req( tag_bytes = tag_int.to_bytes(4, "little") tag_hex = tag_bytes.hex() request_type = data[0] if len(data) >= 1 else 0 - req_payload = tag_bytes + data + # The firmware's sendRequest(req_data) sends: firmware_tag(4) + req_data on the wire. + # onContactRequest receives data = req_data, so data[0] = req_type. + # create_protocol_request sends: timestamp(4) + protocol_code(1) + extra_data. + # onContactRequest receives data = [protocol_code, extra_data...]. + # So protocol_code must be req_type (data[0]), and extra_data is data[1:] plus + # a random blob (matching the firmware's entropy suffix) for packet uniqueness. + protocol_code = request_type + req_payload = data[1:] + tag_bytes # optional payload + random tag for uniqueness self.cleanup_expired_binary_requests() self.register_binary_request( tag_hex, @@ -820,7 +827,7 @@ async def send_binary_req( pkt, _ = PacketBuilder.create_protocol_request( contact=proxy, local_identity=self._identity, - protocol_code=PROTOCOL_CODE_BINARY_REQ, + protocol_code=protocol_code, data=req_payload, ) self._apply_flood_scope(pkt) @@ -839,6 +846,55 @@ async def send_binary_req( timeout_ms=DEFAULT_RESPONSE_TIMEOUT_MS, ) + async def send_anon_req( + self, pub_key: bytes, data: bytes, timeout_seconds: float = 15.0 + ) -> SentResult: + """Send anonymous request (CMD_SEND_ANON_REQ), e.g. owner info. + + data = request payload (e.g. [0x07] for GET_OWNER_INFO). Response is + delivered via on_binary_response (PUSH_CODE_BINARY_RESPONSE) like binary req. + """ + contact = self.contacts.get_by_key(pub_key) + if not contact: + return SentResult(success=False) + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return SentResult(success=False) + tag_int = random.randint(0, 0xFFFFFFFF) + tag_bytes = tag_int.to_bytes(4, "little") + tag_hex = tag_bytes.hex() + request_type = PROTOCOL_CODE_ANON_REQ + req_payload = data + tag_bytes + self.cleanup_expired_binary_requests() + self.register_binary_request( + tag_hex, + request_type=request_type, + timeout_seconds=timeout_seconds, + pubkey_prefix=pub_key[:6].hex(), + ) + try: + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=PROTOCOL_CODE_ANON_REQ, + data=req_payload, + ) + self._apply_flood_scope(pkt) + success = await self._send_packet(pkt, wait_for_ack=False) + except Exception as e: + logger.error(f"Anon request send error: {e}") + self._pending_binary_requests.pop(tag_hex, None) + return SentResult(success=False) + if not success: + self._pending_binary_requests.pop(tag_hex, None) + return SentResult(success=False) + return SentResult( + success=True, + is_flood=contact.out_path_len <= 0, + expected_ack=tag_int, + timeout_ms=DEFAULT_RESPONSE_TIMEOUT_MS, + ) + async def send_path_discovery(self, pub_key: bytes) -> bool: """Legacy: send path discovery without returning tag. Prefer send_path_discovery_req.""" result = await self.send_path_discovery_req(pub_key) diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index 6e39bcb..f86a749 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -17,6 +17,7 @@ from ..node.handlers.login_server import LoginServerHandler from ..protocol import LocalIdentity, Packet from ..protocol.constants import ( + MAX_PATH_SIZE, PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, PAYLOAD_TYPE_ANON_REQ, @@ -47,7 +48,8 @@ class _BridgeAckHandler: - """Handles ACK packets. Fires send_confirmed when ACK CRC matches.""" + """Handles ACK packets (discrete and PATH-carried). + Fires send_confirmed when ACK CRC matches.""" def __init__(self, bridge: "CompanionBridge") -> None: self._bridge = bridge @@ -60,15 +62,67 @@ async def __call__(self, packet: Packet) -> None: if not packet.payload or len(packet.payload) != 4: return crc = int.from_bytes(packet.payload, "little") - if crc in self._bridge._pending_ack_crcs: - self._bridge._pending_ack_crcs.discard(crc) - await self._bridge._fire_callbacks("send_confirmed", crc) + await self._apply_ack(crc) + + async def _apply_ack(self, crc: int) -> None: + """If CRC is pending, clear it and fire send_confirmed.""" + if crc not in self._bridge._pending_ack_crcs: + return + self._bridge._pending_ack_crcs.discard(crc) + await self._bridge._fire_callbacks("send_confirmed", crc) async def process_path_ack_variants(self, packet: Packet) -> Optional[int]: + """Decrypt PATH payload; update contact out_path (firmware pattern), return ACK CRC. + Tries every contact matching src_hash (same as TXT_MSG) so we use the correct key. + """ + from ..protocol import CryptoUtils, Identity + + payload = packet.payload + if not payload or len(payload) < 2 + 6: + return None + dest_hash = payload[0] + src_hash = payload[1] + our_hash = self._bridge._identity.get_public_key()[0] + if dest_hash != our_hash: + return None + encrypted = bytes(payload[2:]) + # Try each contact with matching src_hash until decryption succeeds + for contact in self._bridge.contacts.contacts: + try: + pk = contact.public_key + pub = bytes.fromhex(pk) if isinstance(pk, str) else bytes(pk) + if len(pub) != 32 or pub[0] != src_hash: + continue + peer_id = Identity(pub) + shared_secret = peer_id.calc_shared_secret(self._bridge._identity.get_private_key()) + aes_key = shared_secret[:16] + decrypted = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, encrypted) + except Exception: + continue + if len(decrypted) < 2: + continue + path_len = min(decrypted[0], MAX_PATH_SIZE) + if 1 + path_len > len(decrypted): + continue + path_bytes = bytes(decrypted[1 : 1 + path_len]) + # Firmware pattern: onContactPathRecv stores out_path so replies can use sendDirect() + # Update the underlying Contact (store expects Contact with bytes public_key, not proxy) + contact_obj = self._bridge.contacts.get_by_key(pub) + if contact_obj: + contact_obj.out_path_len = path_len + contact_obj.out_path = path_bytes + self._bridge.contacts.update(contact_obj) + await self._bridge._fire_callbacks("contact_path_updated", pub, path_len, path_bytes) + # If this PATH carries an ACK, return it so send_confirmed can fire + extra_start = 1 + path_len + if len(decrypted) >= extra_start + 1 + 4 and decrypted[extra_start] == PAYLOAD_TYPE_ACK: + return int.from_bytes(decrypted[extra_start + 1 : extra_start + 5], "little") + return None return None async def _notify_ack_received(self, crc: int) -> None: - pass + """Called by path handler when PATH packet contained an ACK.""" + await self._apply_ack(crc) # --------------------------------------------------------------------------- diff --git a/src/pymc_core/companion/constants.py b/src/pymc_core/companion/constants.py index 78329cb..9653957 100644 --- a/src/pymc_core/companion/constants.py +++ b/src/pymc_core/companion/constants.py @@ -69,6 +69,7 @@ class BinaryReqType(IntEnum): MMA = 0x04 ACL = 0x05 NEIGHBOURS = 0x06 + OWNER_INFO = 0x07 # REQ_TYPE_GET_OWNER_INFO: variable "version\nname\nowner" # --------------------------------------------------------------------------- @@ -92,6 +93,10 @@ class BinaryReqType(IntEnum): # Frame Protocol Constants (MeshCore Companion Radio Protocol) # =========================================================================== +# Protocol version reported in RESP_CODE_DEVICE_INFO; phone uses 9+ to infer +# CMD_SEND_ANON_REQ (owner requests, etc.) is supported. +FIRMWARE_VER_CODE = 9 + # --------------------------------------------------------------------------- # Commands (app -> radio) # --------------------------------------------------------------------------- diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 9506959..6a6d254 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -36,6 +36,7 @@ CMD_LOGOUT, CMD_REMOVE_CONTACT, CMD_RESET_PATH, + CMD_SEND_ANON_REQ, CMD_SEND_BINARY_REQ, CMD_SEND_CHANNEL_TXT_MSG, CMD_SEND_CONTROL_DATA, @@ -63,6 +64,7 @@ ERR_CODE_NOT_FOUND, ERR_CODE_TABLE_FULL, ERR_CODE_UNSUPPORTED_CMD, + FIRMWARE_VER_CODE, FRAME_INBOUND_PREFIX, FRAME_OUTBOUND_PREFIX, MAX_FRAME_SIZE, @@ -168,7 +170,7 @@ def __init__( bind_address: str = "0.0.0.0", *, device_model: str = "pyMC-Companion", - device_version: str = "1.0.0", + device_version: Optional[str] = None, build_date: str = "", local_hash: Optional[int] = None, stats_getter: Optional[Callable] = None, @@ -186,7 +188,11 @@ def __init__( self._client_reader: Optional[asyncio.StreamReader] = None self._app_target_ver = 0 - # Pre-compute padded device info bytes for _cmd_device_query + # Pre-compute padded device info bytes for _cmd_device_query. Version string + # should reflect FIRMWARE_VER_CODE so clients that parse it see 9+ (owner/anon). + if device_version is None: + # At least 2 chars so client substring(0, 2) etc. doesn't RangeError + device_version = f"{FIRMWARE_VER_CODE}.0" self._build_date_bytes = (build_date.encode("utf-8") + b"\x00")[:12].ljust(12, b"\x00") self._model_bytes = (device_model.encode("utf-8") + b"\x00")[:40].ljust(40, b"\x00") self._version_bytes = (device_version.encode("utf-8") + b"\x00")[:20].ljust(20, b"\x00") @@ -660,6 +666,8 @@ async def _handle_cmd(self, payload: bytes) -> None: await self._cmd_set_channel(data) elif cmd == CMD_SEND_BINARY_REQ: await self._cmd_send_binary_req(data) + elif cmd == CMD_SEND_ANON_REQ: + await self._cmd_send_anon_req(data) elif cmd == CMD_SEND_PATH_DISCOVERY_REQ: await self._cmd_send_path_discovery_req(data) elif cmd == CMD_SEND_CONTROL_DATA: @@ -744,14 +752,23 @@ async def _cmd_app_start(self, data: bytes) -> None: self._write_frame(frame) async def _cmd_device_query(self, data: bytes) -> None: + # Layout must match MeshCore companion_radio MyMesh.cpp handleCmdFrame() CMD_DEVICE_QEURY: + # [0]=RESP_CODE_DEVICE_INFO, [1]=FIRMWARE_VER_CODE, [2]=MAX_CONTACTS/2, + # [3]=MAX_GROUP_CHANNELS, [4..7]=ble_pin, [8..19]=build_date(12), + # [20..59]=manufacturer(40), [60..79]=version(20), [80]=client_repeat. if len(data) >= 1: self._app_target_ver = data[0] - firmware_ver = 8 + firmware_ver = FIRMWARE_VER_CODE max_contacts = getattr(getattr(self.bridge, "contacts", None), "max_contacts", 1000) max_channels_val = getattr(getattr(self.bridge, "channels", None), "max_channels", 40) max_contacts_div_2 = min(max_contacts // 2, 255) max_channels = min(max_channels_val, 255) ble_pin = 0 + try: + prefs = self.bridge.get_self_info() + client_repeat = getattr(prefs, "client_repeat", 0) & 0xFF + except Exception: + client_repeat = 0 frame = ( bytes( [ @@ -765,6 +782,15 @@ async def _cmd_device_query(self, data: bytes) -> None: + self._build_date_bytes + self._model_bytes + self._version_bytes + + bytes([client_repeat & 0xFF]) + ) + version_str = self._version_bytes.split(b"\x00")[0].decode("utf-8", errors="replace") + logger.info( + "Companion device info sent: FIRMWARE_VER_CODE=%s (byte at index 1), " + "version string=%r, frame_len=%s", + firmware_ver, + version_str, + len(frame), ) self._write_frame(frame) @@ -917,6 +943,32 @@ async def _cmd_send_binary_req(self, data: bytes) -> None: ) self._write_frame(frame) + async def _cmd_send_anon_req(self, data: bytes) -> None: + if len(data) < 33: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:32] + req_data = data[32:] + send_anon_req = getattr(self.bridge, "send_anon_req", None) + if not send_anon_req: + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + try: + result = await send_anon_req(pubkey, req_data) + except Exception as e: + logger.error("send_anon_req error: %s", e, exc_info=True) + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + if not result.success: + self._write_err(ERR_CODE_NOT_FOUND) + return + tag = result.expected_ack if result.expected_ack is not None else 0 + timeout_ms = result.timeout_ms if result.timeout_ms is not None else 10000 + frame = bytes([RESP_CODE_SENT, 1 if result.is_flood else 0]) + struct.pack( + " None: if len(data) < 2: self._write_err(ERR_CODE_ILLEGAL_ARG) diff --git a/src/pymc_core/hardware/kiss_modem_wrapper.py b/src/pymc_core/hardware/kiss_modem_wrapper.py index 278ac8b..c9ba3d6 100644 --- a/src/pymc_core/hardware/kiss_modem_wrapper.py +++ b/src/pymc_core/hardware/kiss_modem_wrapper.py @@ -20,6 +20,7 @@ import serial +from ..protocol.packet_utils import PacketTimingUtils from .base import LoRaRadio # RX callback: (data) for backward compat, or (data, rssi, snr) for per-packet metrics @@ -759,17 +760,21 @@ def is_channel_busy(self) -> bool: return resp[1][0] == 0x01 return False - def get_airtime(self, packet_length: int) -> Optional[int]: + def get_airtime(self, packet_length: int, timeout: Optional[float] = None) -> Optional[int]: """ - Get estimated airtime for a packet + Get estimated airtime for a packet from the modem. Args: packet_length: Length of packet in bytes + timeout: Response timeout in seconds (default: RESPONSE_TIMEOUT). + Use a shorter value (e.g. 1.0) in the TX path to avoid + blocking when the modem is busy or unresponsive. Returns: - Airtime in milliseconds or None on error + Airtime in milliseconds or None on error/timeout """ - resp = self._send_command(CMD_GET_AIRTIME, bytes([packet_length])) + t = timeout if timeout is not None else RESPONSE_TIMEOUT + resp = self._send_command(CMD_GET_AIRTIME, bytes([packet_length]), timeout=t) if resp and resp[0] == RESP_AIRTIME and len(resp[1]) >= 4: return struct.unpack(" Optional[Dict[str, Any]]: if not success: raise Exception("Failed to send frame via KISS modem") - airtime = self.get_airtime(len(data)) + # Use short timeout for GET_AIRTIME so TX path is not blocked if modem + # is busy or unresponsive (avoids 5s stall and subsequent bad state). + airtime = self.get_airtime(len(data), timeout=1.0) + if airtime is None: + airtime = int(PacketTimingUtils.estimate_airtime_ms(len(data), self.radio_config)) return { - "airtime_ms": airtime if airtime is not None else 0, + "airtime_ms": airtime, "lbt_attempts": len(lbt_backoff_delays), "lbt_backoff_delays_ms": lbt_backoff_delays, "lbt_channel_busy": len(lbt_backoff_delays) > 0, diff --git a/src/pymc_core/node/handlers/protocol_request.py b/src/pymc_core/node/handlers/protocol_request.py index 54260df..5b33f05 100644 --- a/src/pymc_core/node/handlers/protocol_request.py +++ b/src/pymc_core/node/handlers/protocol_request.py @@ -17,6 +17,7 @@ REQ_TYPE_GET_TELEMETRY_DATA = 0x03 REQ_TYPE_GET_ACCESS_LIST = 0x05 REQ_TYPE_GET_NEIGHBOURS = 0x06 +REQ_TYPE_GET_OWNER_INFO = 0x07 # Variable-length: tag(4) + "version\nname\nowner" # Response delay (matching C++ SERVER_RESPONSE_DELAY) SERVER_RESPONSE_DELAY_MS = 500 diff --git a/src/pymc_core/node/handlers/protocol_response.py b/src/pymc_core/node/handlers/protocol_response.py index 1094709..86a073e 100644 --- a/src/pymc_core/node/handlers/protocol_response.py +++ b/src/pymc_core/node/handlers/protocol_response.py @@ -10,6 +10,7 @@ from ...protocol import CryptoUtils, Identity, Packet from ...protocol.constants import MAX_PATH_SIZE, PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE +from ...protocol.crypto import CIPHER_BLOCK_SIZE, CIPHER_MAC_SIZE # --------------------------------------------------------------------------- # Built-in CayenneLPP decoder (no external dependency) @@ -57,6 +58,10 @@ def _decode_cayenne_lpp(data: bytes) -> list: while idx + 2 <= len(data): channel = data[idx] type_id = data[idx + 1] + # Channel 0 is never used by MeshCore firmware (channels start at + # TELEM_CHANNEL_SELF=1). A channel=0 byte is AES zero-padding — stop. + if channel == 0: + break idx += 2 spec = _LPP_TYPES.get(type_id) if spec is None: @@ -186,19 +191,15 @@ async def __call__(self, pkt: Packet) -> None: # stats, repeater command), deliver there first. The binary/path-discovery # callback is a generic fallback for unsolicited binary responses. # - # Guard: skip responses that are clearly NOT protocol responses (e.g. a - # stale login response retransmission). Protocol responses always decrypt - # to a tag(4) + meaningful payload, so ≥20 bytes. Login responses are only - # ~12 bytes and parse as "binary" fallback. Without this check a - # retransmitted login response can consume the stats/telemetry waiter. + # Guard: only skip when this is a login response (13 bytes, response_code at [4] + # 0x00/0x01). A broad "decrypted_len < 20" would drop valid PATH-wrapped stats + # or other short responses and delay stats load after login. if src_hash in self._response_callbacks: - resp_type = parsed_data.get("type") if isinstance(parsed_data, dict) else None - decrypted_len = len(raw_decrypted) if raw_decrypted else 0 - if not success or (resp_type == "binary" and decrypted_len < 20): - self._log( - f"[ProtocolResponse] Ignoring non-protocol response for 0x{src_hash:02X} " - f"(success={success}, type={resp_type}, decrypted_len={decrypted_len})" - ) + if not success: + return + if self._is_login_response(pkt, raw_decrypted): + # Login responses are handled by LoginResponseHandler; do not deliver to + # stats/telemetry waiter. return callback = self._response_callbacks[src_hash] if callback: @@ -243,6 +244,13 @@ async def __call__(self, pkt: Packet) -> None: tag_bytes = raw_decrypted[:4] response_data = raw_decrypted[4:] + # Do not deliver login responses to the binary callback; they are + # handled by LoginResponseHandler. Login response format is + # tag(4) + response_code(1) + keep_alive(1) + is_admin(1) + ... + # = 13 bytes total, with response_code 0x00 or 0x01. + if len(response_data) == 9 and response_data[0] in (0x00, 0x01): + return + try: cb_result = self._binary_response_callback(tag_bytes, response_data, path_info) if asyncio.iscoroutine(cb_result): @@ -254,6 +262,25 @@ async def __call__(self, pkt: Packet) -> None: except Exception as e: self._log(f"[ProtocolResponse] Error processing protocol response: {e}") + def _is_login_response(self, pkt: Packet, raw_decrypted: Optional[bytes]) -> bool: + """True if raw_decrypted is a login response (13 bytes, response_code at [4] in 0x00/0x01). + Used to avoid delivering a retransmitted login to the stats/telemetry waiter. + """ + if not raw_decrypted or len(raw_decrypted) < 13: + return False + pkt_type = (pkt.header >> 2) & 0x0F + if pkt_type == PAYLOAD_TYPE_PATH: + if len(raw_decrypted) < 2: + return False + path_len_byte = raw_decrypted[0] + inner_offset = 1 + path_len_byte + 1 + if len(raw_decrypted) < inner_offset + 13: + return False + inner = raw_decrypted[inner_offset : inner_offset + 13] + else: + inner = raw_decrypted[:13] + return len(inner) == 13 and inner[4] in (0x00, 0x01) + async def _decrypt_protocol_response( self, pkt: Packet, src_hash: int ) -> tuple[bool, str, Dict[str, Any], Optional[bytes]]: @@ -262,37 +289,48 @@ async def _decrypt_protocol_response( Handles both packet types: - RESPONSE (0x01): direct → tag(4)+data - PATH (0x08): path_len+path(N)+extra_type+extra - """ - try: - # Find the contact by hash - contact = self._find_contact_by_hash(src_hash) - if not contact: - return False, f"Unknown contact for hash 0x{src_hash:02X}", {}, None - # Get encryption keys - contact_pubkey = bytes.fromhex(contact.public_key) - peer_id = Identity(contact_pubkey) - shared_secret = peer_id.calc_shared_secret(self._local_identity.get_private_key()) - aes_key = shared_secret[:16] - - # Extract encrypted data (skip dest_hash(1) + src_hash(1)) - encrypted_data = pkt.payload[2:] - - # Decrypt the payload - decrypted = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, encrypted_data) - - # Determine the actual payload type from the incoming packet header. - pkt_type = (pkt.header >> 2) & 0x0F + Both use same wire payload layout: dest_hash(1) + src_hash(1) + MAC(2) + ciphertext. + """ + payload = pkt.get_payload() + if len(payload) < 2 + 4: # need dest+src + at least MAC(2)+min ciphertext + return False, "Payload too short", {}, None + encrypted_data = payload[2:] + # MAC(2) + ciphertext; ciphertext must be block-aligned (16 bytes) + # (20 = login response, 68 = stats. Variable: REQ_TYPE_GET_OWNER_INFO can produce 19 bytes.) + enc_len = len(encrypted_data) + if enc_len <= CIPHER_MAC_SIZE or (enc_len - CIPHER_MAC_SIZE) % CIPHER_BLOCK_SIZE != 0: + self._log( + f"[ProtocolResponse] Payload truncated or invalid length for hash " + f"0x{src_hash:02X}: encrypted_data={enc_len}B (need MAC(2)+16k ciphertext)" + ) + return False, "Payload truncated or invalid length", {}, None + pkt_type = (pkt.header >> 2) & 0x0F + + # Try every contact matching src_hash (same “try all hash matches” as TXT_MSG and PATH ACK). + # Repeaters use the same ECDH shared secret as login (createPathReturn(..., secret, ...)). + # Firmware: ed25519_key_exchange uses first 32B of priv (clamped) and (y+1)/(1-y) for peer + # pub; we match via libsodium ed25519_pk_to_curve25519 + scalarmult. + contacts_tried = list(self._contacts_by_hash(src_hash)) + for contact in contacts_tried: + try: + pk = contact.public_key + contact_pubkey = pk if isinstance(pk, bytes) else bytes.fromhex(pk) + if len(contact_pubkey) != 32: + continue + peer_id = Identity(contact_pubkey) + shared_secret = peer_id.calc_shared_secret(self._local_identity.get_private_key()) + aes_key = shared_secret[:16] + decrypted = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, encrypted_data) + except Exception: + continue - # Extract the actual response data based on packet type. + # Determine the actual response data based on packet type. response_data = decrypted - if pkt_type == PAYLOAD_TYPE_PATH: - # Path-return format: path_len(1) + path(N) + extra_type(1) + extra_data - # The actual protocol response is inside the 'extra' field. - if len(decrypted) >= 2: # need at least path_len + extra_type + if len(decrypted) >= 2: path_len_byte = decrypted[0] - inner_offset = 1 + path_len_byte + 1 # path_len + path + extra_type + inner_offset = 1 + path_len_byte + 1 if path_len_byte <= MAX_PATH_SIZE and len(decrypted) >= inner_offset: extra_type = decrypted[1 + path_len_byte] & 0x0F if extra_type == PAYLOAD_TYPE_RESPONSE and len(decrypted) > inner_offset: @@ -303,28 +341,61 @@ async def _decrypt_protocol_response( f"not RESPONSE" ) - # Parse based on content type success, text, parsed = self._parse_protocol_response(response_data) return success, text, parsed, decrypted - except Exception as e: - self._log(f"[ProtocolResponse] Decryption failed: {e}") - return False, f"Decryption failed: {e}", {}, None + # Log once per packet: no contact or HMAC failed for every matching contact + if not contacts_tried: + self._log( + f"[ProtocolResponse] No contact for hash 0x{src_hash:02X}, " + "cannot decrypt PATH/RESPONSE" + ) + else: + self._log( + f"[ProtocolResponse] HMAC failed for hash 0x{src_hash:02X} " + f"(tried {len(contacts_tried)} contact(s), repeater PATH uses same ECDH as login)" + ) + return False, "Decryption failed: Invalid HMAC", {}, None def _parse_protocol_response(self, data: bytes) -> tuple[bool, str, Dict[str, Any]]: """Parse decrypted protocol response data. - Parse order mirrors MeshCore firmware priority: - 1. Stats (RepeaterStats struct, ≥52 bytes) - 2. Text / status (UTF-8 printable after stripping tag + nulls) - 3. Telemetry (reflected_timestamp + valid CayenneLPP with ≥1 sensor) + Parse order: + 0. Login response (13 bytes, response_code at [4] 0x00/0x01) → binary, + for LoginResponseHandler. + 1. Telemetry (reflected_timestamp + valid CayenneLPP signature byte check) + 2. Stats (RepeaterStats struct, ≥52 bytes, only when not telemetry) + 3. Text / status (UTF-8 printable after stripping tag + nulls) 4. Binary fallback + + Telemetry is checked first because CayenneLPP data can be ≥56 bytes for + sensors with many readings, which would otherwise be misidentified as stats. + The telemetry signature check (channel=1, type=0x74) is cheap and reliable. """ try: - # 1. Check if this looks like a stats response (protocol 0x01) + # 0. Login responses are 13 bytes (tag(4) + response_code(1) + keep_alive(1) + ...). + # Do not parse as telemetry/stats; LoginResponseHandler will handle them. + if len(data) == 13 and data[4] in (0x00, 0x01): + return ( + True, + "Binary response: " + data.hex(), + {"type": "binary", "hex": data.hex()}, + ) + + # 1. Check if this looks like a telemetry response (protocol 0x03). + # MeshCore always starts telemetry with addVoltage(TELEM_CHANNEL_SELF=1, ...) + # which produces LPP channel=0x01, type=0x74 (LPP_VOLTAGE) as first record. + # This signature reliably distinguishes telemetry from stats/text responses. + if len(data) >= 8: # tag(4) + at least one LPP record (ch+type+val = 3+) + telemetry_result = self._parse_telemetry_response(data) + if telemetry_result and telemetry_result.get("sensor_count", 0) > 0: + return True, telemetry_result["formatted"], telemetry_result + + # 2. Check if this looks like a stats response (protocol 0x01). # RepeaterStats is 48-56 bytes + 4-byte tag. Older firmware # omits n_recv_errors (52 B struct → 56 total); PATH-wrapped # responses may also lose trailing bytes to AES block alignment. + # Only reached if telemetry signature check above failed. if len(data) >= 56: stats_result = self._parse_stats_response(data) if stats_result: @@ -340,7 +411,7 @@ def _parse_protocol_response(self, data: bytes) -> tuple[bool, str, Dict[str, An ) return True, stats_result["formatted"], result_dict - # 2. Try parsing as text/status response. + # 3. Try parsing as text/status response. # Status responses are tag(4) + UTF-8 text. Strip the 4-byte # tag that prefixes every response, then check for printable text. if len(data) > 4: @@ -355,13 +426,6 @@ def _parse_protocol_response(self, data: bytes) -> tuple[bool, str, Dict[str, An except UnicodeDecodeError: pass - # 3. Check if this looks like a telemetry response (protocol 0x03) - # Must decode at least one sensor from valid CayenneLPP after the tag. - if len(data) >= 8: # tag(4) + at least one LPP record (ch+type+val = 3+) - telemetry_result = self._parse_telemetry_response(data) - if telemetry_result and telemetry_result.get("sensor_count", 0) > 0: - return True, telemetry_result["formatted"], telemetry_result - # 4. Fall back to hex representation hex_response = data.hex() return ( @@ -586,17 +650,19 @@ def _format_stats(self, stats: Dict[str, Any]) -> str: return " | ".join(result) def _find_contact_by_hash(self, contact_hash: int): - """Find contact by hash value.""" - if not self._contact_book: - return None + """Find first contact by hash value.""" + for contact in self._contacts_by_hash(contact_hash): + return contact + return None - # Search through contacts to find one with matching hash + def _contacts_by_hash(self, contact_hash: int): + """Yield all contacts whose public_key first byte matches contact_hash.""" + if not self._contact_book: + return for contact in self._contact_book.list_contacts(): try: contact_pubkey = bytes.fromhex(contact.public_key) if contact_pubkey[0] == contact_hash: - return contact + yield contact except (ValueError, IndexError): continue - - return None diff --git a/src/pymc_core/node/handlers/text.py b/src/pymc_core/node/handlers/text.py index f2d7c1d..43dc908 100644 --- a/src/pymc_core/node/handlers/text.py +++ b/src/pymc_core/node/handlers/text.py @@ -31,34 +31,56 @@ def set_command_response_callback(self, callback): """Set callback function for command responses.""" self.command_response_callback = callback + def _contact_pubkey_bytes(self, contact) -> bytes: + """Return contact's public key as 32 bytes (handles hex str or bytes).""" + pk = contact.public_key + return bytes.fromhex(pk) if isinstance(pk, str) else bytes(pk) + async def __call__(self, packet: Packet) -> None: if len(packet.payload) < 4: self.log("TXT_MSG payload too short to decrypt") return src_hash = packet.payload[1] - matched_contact = None + # Collect all contacts whose public key first byte matches src_hash (hash collision + # possible) + candidates = [] for contact in self.contacts.contacts: try: - if bytes.fromhex(contact.public_key)[0] == src_hash: - matched_contact = contact - break + pk = self._contact_pubkey_bytes(contact) + if len(pk) >= 1 and pk[0] == src_hash: + candidates.append(contact) except Exception as err: self.log(f"Error reading contact key: {err}") - if not matched_contact: + if not candidates: self.log(f"No contact found for src hash: {src_hash:02X}") return - peer_id = Identity(bytes.fromhex(matched_contact.public_key)) - shared_secret = peer_id.calc_shared_secret(self.local_identity.get_private_key()) - aes_key = shared_secret[:16] payload = packet.payload[2:] # Skip dest_hash and src_hash - - try: - decrypted = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, payload) - except Exception as err: - self.log(f"Decryption failed: {err}") + matched_contact = None + decrypted = None + shared_secret = None + for contact in candidates: + try: + pubkey_bytes = self._contact_pubkey_bytes(contact) + if len(pubkey_bytes) != 32: + continue + peer_id = Identity(pubkey_bytes) + ss = peer_id.calc_shared_secret(self.local_identity.get_private_key()) + aes_key = ss[:16] + decrypted = CryptoUtils.mac_then_decrypt(aes_key, ss, payload) + matched_contact = contact + shared_secret = ss + break + except Exception: + continue + + if matched_contact is None or decrypted is None: + self.log( + f"Decryption failed: Invalid HMAC for all {len(candidates)} contact(s) " + f"with src hash {src_hash:02X}" + ) return if len(decrypted) < 5: # timestamp(4) + flags(1) minimum @@ -72,7 +94,7 @@ async def __call__(self, packet: Packet) -> None: txt_type = (flags >> 2) & 0x3F # Upper 6 bits are txt_type message_body = decrypted[5:] # Rest is the message content - pubkey = bytes.fromhex(matched_contact.public_key) + pubkey = self._contact_pubkey_bytes(matched_contact) timestamp_int = int.from_bytes(timestamp, "little") # Determine message routing type from packet header diff --git a/src/pymc_core/protocol/__init__.py b/src/pymc_core/protocol/__init__.py index c583630..d85c710 100644 --- a/src/pymc_core/protocol/__init__.py +++ b/src/pymc_core/protocol/__init__.py @@ -42,6 +42,7 @@ PH_VER_MASK, PH_VER_SHIFT, PUB_KEY_SIZE, + REQ_TYPE_GET_OWNER_INFO, REQ_TYPE_GET_STATUS, REQ_TYPE_GET_TELEMETRY_DATA, ROUTE_TYPE_DIRECT, @@ -152,6 +153,7 @@ # Protocol request types "REQ_TYPE_GET_STATUS", "REQ_TYPE_GET_TELEMETRY_DATA", + "REQ_TYPE_GET_OWNER_INFO", "TELEM_PERM_BASE", "TELEM_PERM_LOCATION", "TELEM_PERM_ENVIRONMENT", diff --git a/src/pymc_core/protocol/constants.py b/src/pymc_core/protocol/constants.py index 9764875..443d94b 100644 --- a/src/pymc_core/protocol/constants.py +++ b/src/pymc_core/protocol/constants.py @@ -111,6 +111,7 @@ def describe_advert_flags(flags: int) -> str: # Protocol Request Types REQ_TYPE_GET_STATUS = 0x01 # Get repeater stats (RepeaterStats struct) REQ_TYPE_GET_TELEMETRY_DATA = 0x03 # Get telemetry data (CayenneLPP) +REQ_TYPE_GET_OWNER_INFO = 0x07 # Variable-length: tag(4) + "version\nname\nowner" (simple_repeater) TELEM_PERM_BASE = 0x01 TELEM_PERM_LOCATION = 0x02 TELEM_PERM_ENVIRONMENT = 0x04 diff --git a/src/pymc_core/protocol/identity.py b/src/pymc_core/protocol/identity.py index 8bd499e..bd6a23f 100644 --- a/src/pymc_core/protocol/identity.py +++ b/src/pymc_core/protocol/identity.py @@ -100,17 +100,19 @@ def __init__(self, seed: Optional[bytes] = None): if seed and len(seed) == 64: from nacl.bindings import crypto_scalarmult_ed25519_base_noclamp - # MeshCore format: [32-byte scalar][32-byte nonce]; firmware clamps first 32 for ECDH + # MeshCore format: [32-byte scalar][32-byte nonce]. Identity.cpp readFrom(64) calls + # ed25519_derive_pub(pub_key, prv_key) which uses first 32 bytes as-is (no clamp). + # key_exchange.c uses first 32 bytes clamped for ECDH. We must match both. self._firmware_key = seed self.signing_key = None - # Use X25519 clamping so ECDH matches firmware's ed25519_key_exchange() - scalar = seed[:32] - clamped = CryptoUtils.x25519_clamp_scalar(scalar) - ed25519_pub = crypto_scalarmult_ed25519_base_noclamp(clamped) + scalar_first32 = seed[:32] + # Ed25519 public: derive without clamping so get_public_key() matches firmware + ed25519_pub = crypto_scalarmult_ed25519_base_noclamp(scalar_first32) self.verify_key = VerifyKey(ed25519_pub) - # Use clamped scalar for ECDH (firmware key_exchange.c uses first 32 bytes clamped) + # ECDH: use clamped scalar so shared secret matches firmware's ed25519_key_exchange() + clamped = CryptoUtils.x25519_clamp_scalar(scalar_first32) self._x25519_private = clamped self._x25519_public = CryptoUtils.scalarmult_base(clamped) else: From 5b3955138e031fee436f06d3ab32ab9f6e6d1b83 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 21 Feb 2026 22:29:47 -0800 Subject: [PATCH 17/50] feat(companion): add firmware version level handling in login response - Introduced `firmware_ver_level` to the login response parsing in `LoginResponseHandler`, allowing the extraction of firmware version information from the response payload. - Updated `CompanionBase` to include `firmware_ver_level` in the login success response. - Enhanced `CompanionFrameServer` to utilize the new firmware version level during login, providing a fallback mechanism for compatibility. - Adjusted comments and documentation to reflect changes in the response format and handling logic. --- src/pymc_core/companion/companion_base.py | 1 + src/pymc_core/companion/frame_server.py | 9 +++++++-- src/pymc_core/node/handlers/login_response.py | 6 ++++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index a4887b8..746c07d 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -1133,6 +1133,7 @@ def _login_cb(success: bool, data: dict) -> None: "keep_alive_interval": data.get("keep_alive_interval", 0), "tag": data.get("timestamp", 0), "acl_permissions": data.get("reserved", data.get("permissions", 0)), + "firmware_ver_level": data.get("firmware_ver_level"), "reason": "Login successful" if login_result["success"] else "Login failed", } except Exception as e: diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 6a6d254..ce5205a 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -754,8 +754,8 @@ async def _cmd_app_start(self, data: bytes) -> None: async def _cmd_device_query(self, data: bytes) -> None: # Layout must match MeshCore companion_radio MyMesh.cpp handleCmdFrame() CMD_DEVICE_QEURY: # [0]=RESP_CODE_DEVICE_INFO, [1]=FIRMWARE_VER_CODE, [2]=MAX_CONTACTS/2, - # [3]=MAX_GROUP_CHANNELS, [4..7]=ble_pin, [8..19]=build_date(12), - # [20..59]=manufacturer(40), [60..79]=version(20), [80]=client_repeat. + # [3]=MAX_GROUP_CHANNELS, [4..7]=ble_pin, [8..19]=build_date(12), [20..59]=manufacturer(40), + # [60..79]=version(20), [80]=client_repeat. if len(data) >= 1: self._app_target_ver = data[0] firmware_ver = FIRMWARE_VER_CODE @@ -1141,6 +1141,10 @@ async def _cmd_send_login(self, data: bytes) -> None: self._write_frame(bytes([RESP_CODE_SENT, 1]) + struct.pack("= 2 for owner info self._write_frame( bytes( [ @@ -1151,6 +1155,7 @@ async def _cmd_send_login(self, data: bytes) -> None: + pubkey[:6] + struct.pack("= 13 else None return { "timestamp": timestamp, @@ -185,6 +186,7 @@ async def _decrypt_response( "is_admin": bool(is_admin), "reserved": reserved, "random_blob": random_blob, + "firmware_ver_level": firmware_ver_level, "contact": contact, } From 452de65d19de40e43e9e64193274e772ef57283a Mon Sep 17 00:00:00 2001 From: agessaman Date: Sun, 22 Feb 2026 08:58:00 -0800 Subject: [PATCH 18/50] fix(companion): prevent outgoing messages from being pushed as incoming - Added a check in `_handle_new_channel_message` to ignore messages marked as outgoing, ensuring that only incoming messages are processed and sent to the client. This change improves message handling and prevents unnecessary duplication in the client interface. --- src/pymc_core/companion/companion_base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 746c07d..d5faa45 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -1414,6 +1414,10 @@ async def _handle_new_message(self, data: dict) -> None: ) async def _handle_new_channel_message(self, data: dict) -> None: + # Do not push our own (outgoing) channel messages to the client as incoming. + if data.get("is_outgoing"): + return + # Deduplicate by packet hash so we queue one frame per logical message, matching # firmware: Mesh.cpp only calls onChannelMessageRecv when !_tables->hasSeen(pkt). pkt_hash = data.get("packet_hash") From beb45b1815f15d67ff5957d4ca6f54cf29490583 Mon Sep 17 00:00:00 2001 From: agessaman Date: Mon, 23 Feb 2026 19:30:08 -0800 Subject: [PATCH 19/50] fix(companion): fix multi-hop stats, binary request tag matching, and reciprocal PATH MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Send reciprocal PATH after flooded PATH login (mirrors C++ Mesh.cpp:168-169) so remote repeaters learn the return route and respond via DIRECT - Await reciprocal PATH TX before signaling login success to prevent race - Add propagation delay for multi-hop stats/telemetry requests - Fix binary request tag mismatch: use timestamp from create_protocol_request as tag (matching C++ companion pattern) instead of random value the firmware never echoes back — fixes neighbors, ACL, and all CMD_SEND_BINARY_REQ ops - Add sanity checks to stats parser to prevent misidentifying binary responses --- docs/docs/companion.md | 2 +- src/pymc_core/companion/companion_base.py | 94 +++++--- src/pymc_core/companion/companion_bridge.py | 47 +++- src/pymc_core/companion/frame_server.py | 4 + src/pymc_core/node/handlers/group_text.py | 11 +- .../node/handlers/protocol_response.py | 227 ++++++++++++++++-- src/pymc_core/node/handlers/registry.py | 1 + src/pymc_core/protocol/crypto.py | 9 +- src/pymc_core/protocol/packet_builder.py | 13 +- tests/test_handlers.py | 28 +++ 10 files changed, 367 insertions(+), 69 deletions(-) diff --git a/docs/docs/companion.md b/docs/docs/companion.md index 686cbe9..ebdfd2e 100644 --- a/docs/docs/companion.md +++ b/docs/docs/companion.md @@ -562,7 +562,7 @@ The frame server sends unsolicited push frames to the companion app when events | `PUSH_CODE_LOGIN_SUCCESS` | 0x91 | Repeater login succeeded | | `PUSH_CODE_LOGIN_FAIL` | 0x92 | Repeater login failed | | `PUSH_CODE_STATUS_RESPONSE` | 0x93 | Repeater status response | -| `PUSH_CODE_TELEMETRY_RESPONSE` | 0x94 | Telemetry response | +| `PUSH_CODE_TELEMETRY_RESPONSE` | 0x8B | Telemetry response | | `PUSH_CODE_BINARY_RESPONSE` | 0x95 | Binary request response | | `PUSH_CODE_PATH_DISCOVERY_RESPONSE` | 0x96 | Path discovery response | diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index d5faa45..e6206ba 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -49,6 +49,7 @@ PROTOCOL_CODE_ANON_REQ, PROTOCOL_CODE_BINARY_REQ, PROTOCOL_CODE_RAW_DATA, + PUSH_CODE_TELEMETRY_RESPONSE, STATS_TYPE_CORE, STATS_TYPE_PACKETS, STATS_TYPE_RADIO, @@ -804,37 +805,39 @@ async def send_binary_req( proxy = self.contacts.get_by_name(contact.name) if not proxy: return SentResult(success=False) - tag_int = random.randint(0, 0xFFFFFFFF) - tag_bytes = tag_int.to_bytes(4, "little") - tag_hex = tag_bytes.hex() request_type = data[0] if len(data) >= 1 else 0 - # The firmware's sendRequest(req_data) sends: firmware_tag(4) + req_data on the wire. - # onContactRequest receives data = req_data, so data[0] = req_type. - # create_protocol_request sends: timestamp(4) + protocol_code(1) + extra_data. - # onContactRequest receives data = [protocol_code, extra_data...]. - # So protocol_code must be req_type (data[0]), and extra_data is data[1:] plus - # a random blob (matching the firmware's entropy suffix) for packet uniqueness. + # C++ companion pattern (BaseChatMesh::sendRequest): + # tag = getRTCClock()->getCurrentTimeUnique() + # memcpy(temp, &tag, 4); memcpy(&temp[4], req_data, data_len); + # create_protocol_request packs: timestamp(4) + protocol_code(1) + extra_data. + # The repeater echoes sender_timestamp (bytes 0-3) in the response. + # So the timestamp IS the tag — we capture it from create_protocol_request. protocol_code = request_type - req_payload = data[1:] + tag_bytes # optional payload + random tag for uniqueness + req_payload = data[1:] # request params only; timestamp provides uniqueness self.cleanup_expired_binary_requests() - self.register_binary_request( - tag_hex, - request_type=request_type, - timeout_seconds=timeout_seconds, - pubkey_prefix=pub_key[:6].hex(), - ) try: - pkt, _ = PacketBuilder.create_protocol_request( + pkt, timestamp = PacketBuilder.create_protocol_request( contact=proxy, local_identity=self._identity, protocol_code=protocol_code, data=req_payload, ) + # Use the timestamp as the tag — matches what the repeater echoes back + tag_int = timestamp + tag_bytes = tag_int.to_bytes(4, "little") + tag_hex = tag_bytes.hex() + self.register_binary_request( + tag_hex, + request_type=request_type, + timeout_seconds=timeout_seconds, + pubkey_prefix=pub_key[:6].hex(), + ) self._apply_flood_scope(pkt) success = await self._send_packet(pkt, wait_for_ack=False) except Exception as e: logger.error(f"Binary request send error: {e}") - self._pending_binary_requests.pop(tag_hex, None) + if "tag_hex" in locals(): + self._pending_binary_requests.pop(tag_hex, None) return SentResult(success=False) if not success: self._pending_binary_requests.pop(tag_hex, None) @@ -860,30 +863,32 @@ async def send_anon_req( proxy = self.contacts.get_by_name(contact.name) if not proxy: return SentResult(success=False) - tag_int = random.randint(0, 0xFFFFFFFF) - tag_bytes = tag_int.to_bytes(4, "little") - tag_hex = tag_bytes.hex() request_type = PROTOCOL_CODE_ANON_REQ - req_payload = data + tag_bytes + req_payload = data # no random tag; timestamp provides uniqueness self.cleanup_expired_binary_requests() - self.register_binary_request( - tag_hex, - request_type=request_type, - timeout_seconds=timeout_seconds, - pubkey_prefix=pub_key[:6].hex(), - ) try: - pkt, _ = PacketBuilder.create_protocol_request( + pkt, timestamp = PacketBuilder.create_protocol_request( contact=proxy, local_identity=self._identity, protocol_code=PROTOCOL_CODE_ANON_REQ, data=req_payload, ) + # Use the timestamp as the tag — matches what the repeater echoes back + tag_int = timestamp + tag_bytes = tag_int.to_bytes(4, "little") + tag_hex = tag_bytes.hex() + self.register_binary_request( + tag_hex, + request_type=request_type, + timeout_seconds=timeout_seconds, + pubkey_prefix=pub_key[:6].hex(), + ) self._apply_flood_scope(pkt) success = await self._send_packet(pkt, wait_for_ack=False) except Exception as e: logger.error(f"Anon request send error: {e}") - self._pending_binary_requests.pop(tag_hex, None) + if "tag_hex" in locals(): + self._pending_binary_requests.pop(tag_hex, None) return SentResult(success=False) if not success: self._pending_binary_requests.pop(tag_hex, None) @@ -1158,6 +1163,24 @@ async def send_logout(self, pub_key: bytes) -> bool: logger.error(f"Logout error: {e}") return False + async def _wait_for_path_propagation(self, proxy: Any, request_type: str) -> None: + """Wait for reciprocal PATH to propagate through the mesh for multi-hop contacts. + + After login, pyMC sends a reciprocal PATH so the remote repeater learns + the return route. Each mesh hop adds ~500ms (airtime + processing). + Without this delay, the first REQ may arrive before the reciprocal PATH, + causing the remote to fall back to sendFlood() — which gets dropped by + intermediate repeaters due to transport-code region filtering. + """ + out_path_len = getattr(proxy, "out_path_len", -1) + if out_path_len > 0: + propagation_delay = out_path_len * 0.5 # e.g. 3 hops → 1.5s + logger.debug( + f"Multi-hop {request_type}: waiting {propagation_delay:.1f}s for " + f"reciprocal PATH propagation ({out_path_len} hops)" + ) + await asyncio.sleep(propagation_delay) + async def send_status_request(self, pub_key: bytes, timeout: float = 15.0) -> dict: """Send a protocol request for repeater status/stats.""" contact = self.contacts.get_by_key(pub_key) @@ -1173,6 +1196,7 @@ async def send_status_request(self, pub_key: bytes, timeout: float = 15.0) -> di waiter = ResponseWaiter() proto_handler.set_response_callback(contact_hash, waiter.callback) try: + await self._wait_for_path_propagation(proxy, "stats request") pkt, _ = PacketBuilder.create_protocol_request( contact=proxy, local_identity=self._identity, @@ -1216,6 +1240,7 @@ async def send_telemetry_request( waiter = ResponseWaiter() proto_handler.set_response_callback(contact_hash, waiter.callback) try: + await self._wait_for_path_propagation(proxy, "telemetry request") inv = PacketBuilder._compute_inverse_perm_mask( want_base, want_location, want_environment ) @@ -1227,10 +1252,17 @@ async def send_telemetry_request( ) await self._send_packet(pkt, wait_for_ack=False) result = await waiter.wait(timeout) + telemetry_data = dict(result.get("parsed", {})) + raw_bytes = telemetry_data.get("raw_bytes", b"") + if raw_bytes and len(pub_key) >= 6: + # Companion-style frame: 0x8B + reserved + 6-byte pubkey prefix + LPP + telemetry_data["frame_bytes"] = ( + bytes([PUSH_CODE_TELEMETRY_RESPONSE, 0]) + pub_key[:6] + raw_bytes + ) return { "success": result.get("success", False), "contact": contact.name, - "telemetry_data": result.get("parsed", {}), + "telemetry_data": telemetry_data, "response_text": result.get("text"), "reason": ("Telemetry received" if result.get("success") else "Telemetry failed"), } diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index f86a749..b733f34 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -87,22 +87,42 @@ async def process_path_ack_variants(self, packet: Packet) -> Optional[int]: return None encrypted = bytes(payload[2:]) # Try each contact with matching src_hash until decryption succeeds + contacts_tried = 0 for contact in self._bridge.contacts.contacts: try: pk = contact.public_key pub = bytes.fromhex(pk) if isinstance(pk, str) else bytes(pk) if len(pub) != 32 or pub[0] != src_hash: continue + contacts_tried += 1 peer_id = Identity(pub) shared_secret = peer_id.calc_shared_secret(self._bridge._identity.get_private_key()) aes_key = shared_secret[:16] decrypted = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, encrypted) - except Exception: + except Exception as e: + logger.debug( + "process_path_ack_variants: decrypt failed for src=0x%02x " "contact=%s: %s", + src_hash, + getattr(contact, "name", "?"), + e, + ) continue if len(decrypted) < 2: + logger.debug( + "process_path_ack_variants: decrypted too short (%d) for src=0x%02x", + len(decrypted), + src_hash, + ) continue path_len = min(decrypted[0], MAX_PATH_SIZE) if 1 + path_len > len(decrypted): + logger.debug( + "process_path_ack_variants: path_len=%d exceeds decrypted len=%d " + "for src=0x%02x", + path_len, + len(decrypted), + src_hash, + ) continue path_bytes = bytes(decrypted[1 : 1 + path_len]) # Firmware pattern: onContactPathRecv stores out_path so replies can use sendDirect() @@ -112,12 +132,31 @@ async def process_path_ack_variants(self, packet: Packet) -> Optional[int]: contact_obj.out_path_len = path_len contact_obj.out_path = path_bytes self._bridge.contacts.update(contact_obj) + logger.debug( + "process_path_ack_variants: updated out_path for src=0x%02x " + "contact=%s path_len=%d", + src_hash, + getattr(contact, "name", "?"), + path_len, + ) + else: + logger.debug( + "process_path_ack_variants: get_by_key returned None for src=0x%02x", + src_hash, + ) await self._bridge._fire_callbacks("contact_path_updated", pub, path_len, path_bytes) # If this PATH carries an ACK, return it so send_confirmed can fire extra_start = 1 + path_len if len(decrypted) >= extra_start + 1 + 4 and decrypted[extra_start] == PAYLOAD_TYPE_ACK: return int.from_bytes(decrypted[extra_start + 1 : extra_start + 5], "little") return None + if contacts_tried > 0: + logger.debug( + "process_path_ack_variants: no contact decrypted successfully for src=0x%02x " + "(tried %d)", + src_hash, + contacts_tried, + ) return None async def _notify_ack_received(self, crc: int) -> None: @@ -218,6 +257,7 @@ def _reject_all(*args, **kwargs) -> tuple[bool, int]: self._login_response_handler = core.login_response_handler self._text_handler_ref = core.text_handler core.protocol_response_handler.set_binary_response_callback(self._on_binary_response) + core.protocol_response_handler.set_packet_injector(self._packet_injector) # ------------------------------------------------------------------------- # Handler accessors (used by CompanionBase concrete send methods) @@ -258,6 +298,11 @@ async def process_received_packet(self, packet: Packet) -> None: except Exception as e: logger.error(f"Handler error for type {ptype:02X}: {e}") + # NOTE: PATH packets are already delivered to protocol_response_handler + # via PathHandler.__call__ (path.py), which runs as the handler above. + # No duplicate call here — it would cause double decryption and could + # deliver the result to response waiters twice. + def _update_stores_from_advert(self, packet: Packet, advert_data: dict): """Update ContactStore and PathCache from advert result. Returns the Contact or None.""" try: diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index ce5205a..736082f 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -1199,13 +1199,17 @@ async def _cmd_send_telemetry_req(self, data: bytes) -> None: ) if not result.get("success"): self._write_frame(bytes([PUSH_CODE_TELEMETRY_RESPONSE, 0]) + pubkey[:6]) + await self._drain_writer() return telem_data = result.get("telemetry_data", {}) raw_bytes = telem_data.get("raw_bytes", b"") if not raw_bytes: self._write_frame(bytes([PUSH_CODE_TELEMETRY_RESPONSE, 0]) + pubkey[:6]) + await self._drain_writer() return self._write_frame(bytes([PUSH_CODE_TELEMETRY_RESPONSE, 0]) + pubkey[:6] + raw_bytes) + await self._drain_writer() + logger.info("Telemetry push sent to client: %d bytes LPP", len(raw_bytes)) async def _cmd_send_self_advert(self, data: bytes) -> None: flood = len(data) >= 1 and data[0] == 1 diff --git a/src/pymc_core/node/handlers/group_text.py b/src/pymc_core/node/handlers/group_text.py index 8868829..5030c64 100644 --- a/src/pymc_core/node/handlers/group_text.py +++ b/src/pymc_core/node/handlers/group_text.py @@ -265,7 +265,9 @@ async def __call__(self, packet: Packet) -> None: # Check if this message is from ourselves using sender name (echo detection) is_own = self._is_own_message(packet) if is_own: - self.log(f"Own echo detected (will publish for heard-count): {sender_name}: {message_body}") + self.log( + f"Own echo detected (skip publish to client): {sender_name}: {message_body}" + ) # Log the group message self.log(f"<<< Channel [{channel_name}] {sender_name}: {message_body} >>>") @@ -293,7 +295,12 @@ async def _save_and_broadcast_group_message( ): """Save the group message to database and broadcast via WebSocket.""" try: - message_id = packet.get_packet_hash_hex(16) + message_id = packet.get_packet_hash_hex(16) + + # Do not publish NEW_CHANNEL_MESSAGE for our own messages (inject + echoes). + # The client already has the sent message; publishing per echo would spam the event. + if is_outgoing: + return # Publish channel message event if available if self.event_service: diff --git a/src/pymc_core/node/handlers/protocol_response.py b/src/pymc_core/node/handlers/protocol_response.py index 86a073e..e12cfff 100644 --- a/src/pymc_core/node/handlers/protocol_response.py +++ b/src/pymc_core/node/handlers/protocol_response.py @@ -9,8 +9,14 @@ from typing import Any, Callable, Dict, Optional from ...protocol import CryptoUtils, Identity, Packet -from ...protocol.constants import MAX_PATH_SIZE, PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE +from ...protocol.constants import ( + MAX_PATH_SIZE, + PAYLOAD_TYPE_PATH, + PAYLOAD_TYPE_RESPONSE, + ROUTE_TYPE_DIRECT, +) from ...protocol.crypto import CIPHER_BLOCK_SIZE, CIPHER_MAC_SIZE +from ...protocol.packet_builder import PacketBuilder # --------------------------------------------------------------------------- # Built-in CayenneLPP decoder (no external dependency) @@ -34,7 +40,7 @@ 0x71: ("Accelerometer", 6, 1000, True), # LPP_ACCELEROMETER = 113, 3×int16 0x73: ("Barometer", 2, 10, False), # LPP_BAROMETRIC_PRESSURE = 115 0x74: ("Voltage", 2, 100, False), # LPP_VOLTAGE = 116, 0.01V - 0x75: ("Current", 2, 1000, False), # LPP_CURRENT = 117, 0.001A + 0x75: ("Current", 2, 1000, True), # LPP_CURRENT = 117, 0.001A signed 0x76: ("Frequency", 4, 1, False), # LPP_FREQUENCY = 118, 1Hz 0x78: ("Percentage", 1, 1, False), # LPP_PERCENTAGE = 120, 1-100% 0x79: ("Altitude", 2, 1, True), # LPP_ALTITUDE = 121, 1m signed @@ -46,8 +52,10 @@ 0x85: ("Unix Time", 4, 1, False), # LPP_UNIXTIME = 133 0x86: ("Gyroscope", 6, 100, True), # LPP_GYROMETER = 134, 3×int16 0x87: ("Colour", 3, 1, False), # LPP_COLOUR = 135, RGB - 0x88: ("GPS", 9, 1, True), # LPP_GPS = 136, lat(3)+lon(3)+alt(3) + 0x88: ("GPS", 9, 1, True), # LPP_GPS = 136, lat(3)+lon(3)+alt(3), mult 10000/100 0x8E: ("Switch", 1, 1, False), # LPP_SWITCH = 142, 0/1 + # LPP_POLYLINE 240: variable size; min 8 bytes (size+delta+lon+lat). Skip min to continue. + 0xF0: ("Polyline", 8, 1, False), # LPP_POLYLINE = 240 } @@ -111,6 +119,17 @@ def _decode_cayenne_lpp(data: bytes) -> list: "raw_value": raw.hex(), } ) + elif type_id == 0xF0: + # Polyline: variable size; we only consume minimum 8 bytes (MeshCore skipData). + sensors.append( + { + "channel": channel, + "type": name, + "type_id": type_id, + "value": raw.hex(), + "raw_value": raw.hex(), + } + ) else: val = int.from_bytes(raw, "big", signed=signed) sensors.append( @@ -144,6 +163,26 @@ def __init__(self, log_fn: Callable[[str], None], local_identity, contact_book): # Optional: decrypted payloads with tag+data (and optional path) passed as binary response. # Signature: (tag_bytes, response_data, path_info=None). self._binary_response_callback: Optional[Callable[..., Any]] = None + # Reference to LoginResponseHandler for state-based login detection + self._login_response_handler: Optional[Any] = None + # Packet injector for sending reciprocal PATH packets (mirrors C++ Mesh.cpp:168-169) + self._packet_injector: Optional[Callable] = None + + def set_login_response_handler(self, handler: Any) -> None: + """Set login handler ref for checking active login state.""" + self._login_response_handler = handler + + def set_packet_injector(self, injector: Optional[Callable]) -> None: + """Set packet injector for sending reciprocal PATH packets. + + When the companion receives a flooded PATH from a remote repeater, + the C++ firmware sends a reciprocal PATH back so the remote repeater + learns the route to us (Mesh.cpp:168-169). Without this, the remote + repeater has no out_path for us and must fall back to plain FLOOD for + responses — which intermediate repeaters may drop due to transport-code + region filtering. + """ + self._packet_injector = injector @staticmethod def payload_type() -> int: @@ -174,6 +213,13 @@ async def __call__(self, pkt: Packet) -> None: # Both PATH and RESPONSE packets share the same structure: # dest_hash(1) + src_hash(1) + encrypted_data(N) src_hash = pkt.payload[1] + pkt_type = (pkt.header >> 2) & 0x0F + route_label = "FLOOD" if pkt.is_route_flood() else "DIRECT" + if pkt_type == PAYLOAD_TYPE_RESPONSE: + self._log( + f"[ProtocolResponse] Received RESPONSE (0x01) from 0x{src_hash:02X} " + f"({route_label}, {len(pkt.payload)}B)" + ) # Proceed if we have a callback for this source or the binary (path-discovery) callback if src_hash not in self._response_callbacks and self._binary_response_callback is None: @@ -203,6 +249,11 @@ async def __call__(self, pkt: Packet) -> None: return callback = self._response_callbacks[src_hash] if callback: + if parsed_data.get("type") == "telemetry": + self._log( + f"[ProtocolResponse] Delivering telemetry to waiter " + f"(src=0x{src_hash:02X}, {parsed_data.get('sensor_count', 0)} sensors)" + ) callback(success, decoded_text, parsed_data) return @@ -263,23 +314,119 @@ async def __call__(self, pkt: Packet) -> None: self._log(f"[ProtocolResponse] Error processing protocol response: {e}") def _is_login_response(self, pkt: Packet, raw_decrypted: Optional[bytes]) -> bool: - """True if raw_decrypted is a login response (13 bytes, response_code at [4] in 0x00/0x01). - Used to avoid delivering a retransmitted login to the stats/telemetry waiter. + """True if a login is currently pending for the source contact. + + Mirrors the C++ companion firmware pattern: classify responses by + pending-request state rather than payload content. The previous + content-based check (``inner[4] in (0x00, 0x01)``) falsely matched + CayenneLPP telemetry whose first byte is channel 0x01. """ - if not raw_decrypted or len(raw_decrypted) < 13: + if not self._login_response_handler: return False - pkt_type = (pkt.header >> 2) & 0x0F - if pkt_type == PAYLOAD_TYPE_PATH: - if len(raw_decrypted) < 2: - return False - path_len_byte = raw_decrypted[0] - inner_offset = 1 + path_len_byte + 1 - if len(raw_decrypted) < inner_offset + 13: - return False - inner = raw_decrypted[inner_offset : inner_offset + 13] - else: - inner = raw_decrypted[:13] - return len(inner) == 13 and inner[4] in (0x00, 0x01) + passwords = getattr(self._login_response_handler, "_active_login_passwords", {}) + if not passwords: + return False + if len(pkt.payload) < 2: + return False + src_hash = pkt.payload[1] + return src_hash in passwords + + def _update_contact_path( + self, + contact_pubkey: bytes, + src_hash: int, + path_len_byte: int, + decrypted: bytes, + ) -> None: + """Update contact out_path from decrypted PATH data (firmware onContactPathRecv pattern). + + When a PATH packet is successfully decrypted, store the return path + on the contact so that subsequent requests use sendDirect() instead + of sendFlood(). This mirrors C++ ``BaseChatMesh::onContactPathRecv``. + """ + try: + if path_len_byte > MAX_PATH_SIZE: + return + out_path_bytes = bytes(decrypted[1 : 1 + path_len_byte]) + contact_obj = self._contact_book.get_by_key(contact_pubkey) + if contact_obj is not None: + contact_obj.out_path_len = path_len_byte + contact_obj.out_path = out_path_bytes + self._contact_book.update(contact_obj) + self._log( + f"[ProtocolResponse] Updated out_path for 0x{src_hash:02X}: " + f"path_len={path_len_byte}" + ) + else: + self._log( + f"[ProtocolResponse] Cannot update out_path for 0x{src_hash:02X}: " + f"contact not found by key" + ) + except Exception as e: + self._log(f"[ProtocolResponse] Failed to update out_path: {e}") + + async def _send_reciprocal_path( + self, + src_hash: int, + shared_secret: bytes, + pkt: Packet, + decrypted: bytes, + path_len_byte: int, + ) -> None: + """Send a reciprocal PATH back to the sender so it learns the route to us. + + Mirrors C++ firmware behaviour (Mesh.cpp lines 166-169): + + mesh::Packet* rpath = createPathReturn( + &src_hash, secret, pkt->path, pkt->path_len, 0, NULL, 0); + if (rpath) sendDirect(rpath, path, path_len, 500); + + - ``pkt.path`` is the flood accumulation path on the received PATH + (the inbound route, e.g. [hash_X, hash_B]). This is placed inside + the reciprocal's encrypted payload so the remote repeater stores it + as *its* ``out_path`` — the route from itself back to us. + - The reciprocal is sent **DIRECT** using the inner ``out_path`` + extracted from the decrypted data (e.g. [hash_B, hash_X]), which + routes through the mesh to reach the remote repeater. + """ + if self._packet_injector is None: + return + try: + our_hash = self._local_identity.get_public_key()[0] + # The inbound flood path (pkt.path) tells the remote repeater + # "to reach me, go through these intermediate hops". + in_path = list(pkt.path) if pkt.path else [] + + # Build the reciprocal PATH packet. create_path_return produces a + # FLOOD PATH by default; we convert it to DIRECT below. + reciprocal = PacketBuilder.create_path_return( + dest_hash=src_hash, + src_hash=our_hash, + secret=shared_secret, + path=in_path, + extra_type=0xFF, # no extra payload (dummy, same as C++ NULL/0) + extra=b"", + ) + + # Convert to DIRECT routing using the inner out_path (the route + # from us to the remote repeater). + out_path_bytes = bytes(decrypted[1 : 1 + path_len_byte]) + reciprocal.header = (reciprocal.header & ~0x03) | ROUTE_TYPE_DIRECT + reciprocal.path = bytearray(out_path_bytes) + reciprocal.path_len = len(out_path_bytes) + + # Await injection so the reciprocal PATH is serialized through the + # radio TX pipeline before this method returns. This ensures the + # login callback doesn't fire until the reciprocal PATH is in flight, + # preventing the app's first stats REQ from racing ahead of it. + await self._packet_injector(reciprocal) + + self._log( + f"[ProtocolResponse] Sending reciprocal PATH to 0x{src_hash:02X} " + f"via DIRECT (out_path_len={path_len_byte}, in_path_len={len(in_path)})" + ) + except Exception as e: + self._log(f"[ProtocolResponse] Failed to send reciprocal PATH: {e}") async def _decrypt_protocol_response( self, pkt: Packet, src_hash: int @@ -296,15 +443,17 @@ async def _decrypt_protocol_response( if len(payload) < 2 + 4: # need dest+src + at least MAC(2)+min ciphertext return False, "Payload too short", {}, None encrypted_data = payload[2:] - # MAC(2) + ciphertext; ciphertext must be block-aligned (16 bytes) - # (20 = login response, 68 = stats. Variable: REQ_TYPE_GET_OWNER_INFO can produce 19 bytes.) + # MAC(2) + ciphertext. Ciphertext may be block-aligned or truncated (e.g. long PATH + # packets lose one byte to header size; telemetry PATH 63 bytes). Allow MAC + 15 bytes + # minimum so we can pad to one block and attempt decrypt. enc_len = len(encrypted_data) - if enc_len <= CIPHER_MAC_SIZE or (enc_len - CIPHER_MAC_SIZE) % CIPHER_BLOCK_SIZE != 0: + min_enc = CIPHER_MAC_SIZE + (CIPHER_BLOCK_SIZE - 1) # 17: MAC(2) + 15 ciphertext + if enc_len < min_enc: self._log( - f"[ProtocolResponse] Payload truncated or invalid length for hash " - f"0x{src_hash:02X}: encrypted_data={enc_len}B (need MAC(2)+16k ciphertext)" + f"[ProtocolResponse] Payload too short for hash 0x{src_hash:02X}: " + f"encrypted_data={enc_len}B (need MAC(2)+≥15 bytes ciphertext)" ) - return False, "Payload truncated or invalid length", {}, None + return False, "Payload too short", {}, None pkt_type = (pkt.header >> 2) & 0x0F # Try every contact matching src_hash (same “try all hash matches” as TXT_MSG and PATH ACK). @@ -341,6 +490,24 @@ async def _decrypt_protocol_response( f"not RESPONSE" ) + # Firmware pattern (onContactPathRecv): update contact out_path + # so subsequent requests use sendDirect() instead of sendFlood(). + self._update_contact_path(contact_pubkey, src_hash, path_len_byte, decrypted) + + # Firmware pattern (Mesh.cpp:168-169): send reciprocal PATH back + # to the sender so it learns the route to us. Without this, the + # remote repeater has no out_path for us and must fall back to + # plain FLOOD for responses — which intermediate repeaters may + # drop due to transport-code region filtering. + if pkt.is_route_flood(): + await self._send_reciprocal_path( + src_hash, + shared_secret, + pkt, + decrypted, + path_len_byte, + ) + success, text, parsed = self._parse_protocol_response(response_data) return success, text, parsed, decrypted @@ -353,7 +520,7 @@ async def _decrypt_protocol_response( else: self._log( f"[ProtocolResponse] HMAC failed for hash 0x{src_hash:02X} " - f"(tried {len(contacts_tried)} contact(s), repeater PATH uses same ECDH as login)" + f"(tried {len(contacts_tried)} contact(s). Repeater PATH uses same ECDH as login)" ) return False, "Decryption failed: Invalid HMAC", {}, None @@ -362,7 +529,7 @@ def _parse_protocol_response(self, data: bytes) -> tuple[bool, str, Dict[str, An Parse order: 0. Login response (13 bytes, response_code at [4] 0x00/0x01) → binary, - for LoginResponseHandler. + for LoginResponseHandler. 1. Telemetry (reflected_timestamp + valid CayenneLPP signature byte check) 2. Stats (RepeaterStats struct, ≥52 bytes, only when not telemetry) 3. Text / status (UTF-8 printable after stripping tag + nulls) @@ -497,6 +664,14 @@ def _parse_stats_response(self, data: bytes) -> Optional[Dict[str, Any]]: n_recv_errors, # uint32 offset 52 ) = struct.unpack(" 10000: # > 10V is unreasonable + return None + if last_rssi < -200 or last_rssi > 0: # RSSI always negative, > -200 dBm + return None + raw_stats = { "batt_milli_volts": batt_milli_volts, "curr_tx_queue_len": curr_tx_queue_len, diff --git a/src/pymc_core/node/handlers/registry.py b/src/pymc_core/node/handlers/registry.py index 746188f..6e57613 100644 --- a/src/pymc_core/node/handlers/registry.py +++ b/src/pymc_core/node/handlers/registry.py @@ -65,6 +65,7 @@ def create_core_handlers( login_response_handler = LoginResponseHandler(identity, contacts, log_fn) login_response_handler.set_protocol_response_handler(protocol_response_handler) + protocol_response_handler.set_login_response_handler(login_response_handler) path_handler = PathHandler( log_fn, diff --git a/src/pymc_core/protocol/crypto.py b/src/pymc_core/protocol/crypto.py index e3eeb5f..06d0075 100644 --- a/src/pymc_core/protocol/crypto.py +++ b/src/pymc_core/protocol/crypto.py @@ -36,10 +36,17 @@ def _aes_encrypt(key: bytes, data: bytes) -> bytes: @staticmethod def _aes_decrypt(key: bytes, ciphertext: bytes) -> bytes: cipher = AES.new(key, AES.MODE_ECB) - return b"".join( + orig_len = len(ciphertext) + if orig_len % CIPHER_BLOCK_SIZE != 0: + # Firmware may send non-block-aligned ciphertext (e.g. telemetry 63 bytes). + # Pad to next block for decrypt, then return only original length. + pad_len = CIPHER_BLOCK_SIZE - (orig_len % CIPHER_BLOCK_SIZE) + ciphertext = ciphertext + (b"\x00" * pad_len) + out = b"".join( cipher.decrypt(ciphertext[i : i + CIPHER_BLOCK_SIZE]) for i in range(0, len(ciphertext), CIPHER_BLOCK_SIZE) ) + return out[:orig_len] @staticmethod def encrypt_then_mac(key_aes: bytes, shared_secret: bytes, plaintext: bytes) -> bytes: diff --git a/src/pymc_core/protocol/packet_builder.py b/src/pymc_core/protocol/packet_builder.py index b458125..ab22836 100644 --- a/src/pymc_core/protocol/packet_builder.py +++ b/src/pymc_core/protocol/packet_builder.py @@ -827,7 +827,8 @@ def create_protocol_request( ) out_path_len = getattr(contact, "out_path_len", -1) - if out_path_len < 0: + out_path = getattr(contact, "out_path", b"") or b"" + if out_path_len <= 0 or not out_path: route_type = "flood" else: route_type = "direct" @@ -835,12 +836,10 @@ def create_protocol_request( header = PacketBuilder._create_header(PAYLOAD_TYPE_REQ, route_type) packet = PacketBuilder._create_packet(header, payload) - if route_type == "direct" and out_path_len > 0: - out_path = getattr(contact, "out_path", b"") - if out_path: - path_bytes = out_path[:MAX_PATH_SIZE] - packet.path = bytearray(path_bytes) - packet.path_len = len(packet.path) + if route_type == "direct" and len(out_path) > 0: + path_bytes = out_path[:MAX_PATH_SIZE] + packet.path = bytearray(path_bytes) + packet.path_len = len(packet.path) return packet, timestamp diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 835aaf3..0ef6662 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -316,6 +316,34 @@ def test_protocol_response_handler_initialization(self): assert self.handler._log == self.log_fn assert self.handler._local_identity == self.local_identity + def test_parse_telemetry_response_tag_plus_lpp(self): + """Parse tag(4) + CayenneLPP matches repeater firmware format; raw_bytes is LPP only.""" + # Repeater sends: tag(4) + LPP. Tag is 4-byte reflected_timestamp (little-endian). + # MeshCore first record: addVoltage(TELEM_CHANNEL_SELF=1, v) + # → channel=1, type=0x74 (LPP_VOLTAGE), 2 bytes 0.01V big-endian. 3.7V → 370 → 0x01 0x72 + tag = b"\x01\x00\x00\x00" # LE 1 + lpp = bytes([0x01, 0x74, 0x01, 0x72]) # ch 1, Voltage, 370 (3.70 V) + data = tag + lpp + result = self.handler._parse_telemetry_response(data) + assert result is not None + assert result["type"] == "telemetry" + assert result["reflected_timestamp"] == 1 + assert result["raw_bytes"] == lpp + assert result["sensor_count"] == 1 + sensor = result["sensors"][0] + assert sensor["channel"] == 1 + assert sensor["type"] == "Voltage" + assert sensor["type_id"] == 0x74 + assert abs(sensor["value"] - 3.7) < 0.001 + + def test_parse_telemetry_response_rejects_non_telemetry(self): + """Payload without channel=1, type=0x74 signature is not classified as telemetry.""" + tag = b"\x00\x00\x00\x00" # LE 0 + # Not starting with 0x01 0x74 + data = tag + bytes([0x01, 0x67, 0x00, 0x00]) # ch 1, Temperature, 0°C + result = self.handler._parse_telemetry_response(data) + assert result is None + # Trace Handler Tests class TestTraceHandler: From 98d9e9b34850b6fb9dd85872cff956af8a91f8dd Mon Sep 17 00:00:00 2001 From: agessaman Date: Mon, 23 Feb 2026 20:06:21 -0800 Subject: [PATCH 20/50] Persist contact changes from adverts and PATH updates to SQLite _save_contacts() was only called after explicit add/remove commands, so advert-driven name changes and PATH route updates were lost on restart. --- src/pymc_core/companion/frame_server.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 736082f..f755101 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -362,10 +362,18 @@ async def on_advert_received(contact): _write_push(full) except Exception as e: logger.exception("advert_received callback error: %s", e) + try: + self._save_contacts() + except Exception as e: + logger.warning("Save contacts after advert failed: %s", e) async def on_contact_path_updated(pub_key, path_len, path): if isinstance(pub_key, bytes) and len(pub_key) >= 32: _write_push(bytes([PUSH_CODE_PATH_UPDATED]) + pub_key[:32]) + try: + self._save_contacts() + except Exception as e: + logger.warning("Save contacts after path update failed: %s", e) async def on_channel_message_received( channel_name, From c6734c1eb7c92c375823a834cdacd9590d66f69b Mon Sep 17 00:00:00 2001 From: agessaman Date: Tue, 24 Feb 2026 14:28:12 -0800 Subject: [PATCH 21/50] feat(companion): implement auto-add contact type handling and contact store management - Added functionality to automatically add contacts based on their type, mirroring C++ logic for selective contact addition. - Introduced methods for handling contact deletion and full contact store scenarios, including callbacks for notifying the application. - Enhanced the contact store to support adding or overwriting contacts when the store is full, ensuring non-favourite contacts are replaced as needed. - Updated the CompanionBridge to utilize the new auto-add logic during advert processing, improving contact management efficiency. --- src/pymc_core/companion/companion_base.py | 39 +++++++++++++++++++++ src/pymc_core/companion/companion_bridge.py | 36 ++++++++++++++++--- src/pymc_core/companion/contact_store.py | 31 +++++++++++++++- src/pymc_core/companion/frame_server.py | 26 ++++++++++++++ 4 files changed, 126 insertions(+), 6 deletions(-) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index e6206ba..824b631 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -41,6 +41,11 @@ ADV_TYPE_ROOM, ADV_TYPE_SENSOR, ADVERT_LOC_SHARE, + AUTOADD_CHAT, + AUTOADD_OVERWRITE_OLDEST, + AUTOADD_REPEATER, + AUTOADD_ROOM, + AUTOADD_SENSOR, DEFAULT_MAX_CHANNELS, DEFAULT_MAX_CONTACTS, DEFAULT_OFFLINE_QUEUE_SIZE, @@ -77,6 +82,8 @@ "raw_data_received", "binary_response", "path_discovery_response", + "contact_deleted", + "contacts_full", ] @@ -564,6 +571,30 @@ def set_autoadd_config(self, config: int) -> None: self.prefs.autoadd_config = config self._save_prefs() + # Map ADV_TYPE_* → AUTOADD_* bitmask bits (mirrors C++ shouldAutoAddContactType) + _AUTOADD_TYPE_MAP: dict[int, int] = { + ADV_TYPE_CHAT: AUTOADD_CHAT, # 1 → 0x02 + ADV_TYPE_REPEATER: AUTOADD_REPEATER, # 2 → 0x04 + ADV_TYPE_ROOM: AUTOADD_ROOM, # 3 → 0x08 + ADV_TYPE_SENSOR: AUTOADD_SENSOR, # 4 → 0x10 + } + + def should_auto_add_contact_type(self, contact_type: int) -> bool: + """Check if a contact type should be auto-added based on current preferences. + + Mirrors C++ MyMesh::shouldAutoAddContactType (MyMesh.cpp:281-304). + """ + # manual_add_contacts bit 0 == 0 → auto-add ALL types + if (self.prefs.manual_add_contacts & 1) == 0: + return True + # Selective mode: check the type-specific bit in autoadd_config + type_bit = self._AUTOADD_TYPE_MAP.get(contact_type, 0) + return bool(self.prefs.autoadd_config & type_bit) if type_bit else False + + def should_overwrite_when_full(self) -> bool: + """Check if overwrite-oldest is enabled. Mirrors C++ shouldOverwriteWhenFull.""" + return bool(self.prefs.autoadd_config & AUTOADD_OVERWRITE_OLDEST) + # ------------------------------------------------------------------------- # Push Callbacks # ------------------------------------------------------------------------- @@ -609,6 +640,14 @@ def on_path_discovery_response(self, callback: Callable) -> None: """Register callback for path discovery 0x8D. (tag_bytes, pubkey, out_path, in_path).""" self._push_callbacks["path_discovery_response"].append(callback) + def on_contact_deleted(self, callback: Callable) -> None: + """Register callback for PUSH 0x8F (contact overwritten). Callback(pub_key_bytes).""" + self._push_callbacks["contact_deleted"].append(callback) + + def on_contacts_full(self, callback: Callable) -> None: + """Register callback for PUSH 0x90 (contacts store full). Callback().""" + self._push_callbacks["contacts_full"].append(callback) + def register_binary_request( self, tag_hex: str, diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index b733f34..7efb4a4 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -292,7 +292,7 @@ async def process_received_packet(self, packet: Packet) -> None: try: result = await handler(packet) if ptype == PAYLOAD_TYPE_ADVERT and result: - contact = self._update_stores_from_advert(packet, result) + contact = await self._update_stores_from_advert(packet, result) if contact: await self._fire_callbacks("advert_received", contact) except Exception as e: @@ -303,8 +303,16 @@ async def process_received_packet(self, packet: Packet) -> None: # No duplicate call here — it would cause double decryption and could # deliver the result to response waiters twice. - def _update_stores_from_advert(self, packet: Packet, advert_data: dict): - """Update ContactStore and PathCache from advert result. Returns the Contact or None.""" + async def _update_stores_from_advert(self, packet: Packet, advert_data: dict): + """Update ContactStore and PathCache from advert result. + + Mirrors C++ BaseChatMesh::onAdvertRecv (BaseChatMesh.cpp:106-170): + - Existing contacts are always updated (name, GPS, etc.) + - New contacts are subject to auto-add type filtering + - When store is full, overwrite-oldest replaces the oldest non-favourite + + Returns the Contact or None. + """ try: pub_key = bytes.fromhex(advert_data.get("public_key", "")) if len(pub_key) < 7: @@ -320,10 +328,11 @@ def _update_stores_from_advert(self, packet: Packet, advert_data: dict): last_advert_ts = advert_data.get("advert_timestamp", 0) if last_advert_ts > now: last_advert_ts = now + adv_type = advert_data.get("contact_type_id", 0) contact = Contact( public_key=pub_key, name=name, - adv_type=advert_data.get("contact_type_id", 0), + adv_type=adv_type, gps_lat=advert_data.get("latitude", 0.0), gps_lon=advert_data.get("longitude", 0.0), lastmod=now, @@ -331,7 +340,24 @@ def _update_stores_from_advert(self, packet: Packet, advert_data: dict): out_path_len=-1, out_path=b"", ) - self.contacts.add(contact) + + is_existing = self.contacts.get_by_key(pub_key) is not None + if is_existing: + # Always update existing contacts (C++ BaseChatMesh.cpp:158-167) + self.contacts.update(contact) + elif not self.should_auto_add_contact_type(adv_type): + # Type not allowed — still fire callback so app sees the advert + logger.debug("Auto-add filtered: type %d not allowed", adv_type) + elif self.should_overwrite_when_full() and self.contacts.is_full(): + ok, overwritten = self.contacts.add_or_overwrite(contact) + if ok and overwritten: + await self._fire_callbacks("contact_deleted", overwritten) + elif not ok: + await self._fire_callbacks("contacts_full") + else: + added = self.contacts.add(contact) + if not added and self.contacts.is_full(): + await self._fire_callbacks("contacts_full") self.path_cache.update( AdvertPath( diff --git a/src/pymc_core/companion/contact_store.py b/src/pymc_core/companion/contact_store.py index a44c3a3..2a3c395 100644 --- a/src/pymc_core/companion/contact_store.py +++ b/src/pymc_core/companion/contact_store.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Iterable, Iterator, Optional +from typing import Iterable, Iterator, Optional, Tuple from .constants import DEFAULT_MAX_CONTACTS from .models import Contact @@ -103,6 +103,35 @@ def add(self, contact: Contact) -> bool: self._proxies[contact.public_key] = ContactProxy(contact) return True + def add_or_overwrite(self, contact: Contact) -> Tuple[bool, Optional[bytes]]: + """Add a contact, overwriting the oldest non-favourite if store is full. + + Mirrors C++ BaseChatMesh::allocateContactSlot (BaseChatMesh.cpp:70-90). + + Returns: + (success, overwritten_pubkey_or_None) + """ + if contact.public_key in self._contacts: + return self.update(contact), None + if len(self._contacts) >= self._max_contacts: + # Find oldest non-favourite (flags bit 0 = favourite) + oldest_key: Optional[bytes] = None + oldest_lastmod = 0xFFFFFFFF + for key, c in self._contacts.items(): + if (c.flags & 0x01) == 0 and c.lastmod < oldest_lastmod: + oldest_lastmod = c.lastmod + oldest_key = key + if oldest_key is None: + return False, None # all contacts are favourites + overwritten = oldest_key + self.remove(oldest_key) + self._contacts[contact.public_key] = contact + self._proxies[contact.public_key] = ContactProxy(contact) + return True, overwritten + self._contacts[contact.public_key] = contact + self._proxies[contact.public_key] = ContactProxy(contact) + return True, None + def update(self, contact: Contact) -> bool: """Update an existing contact. Returns False if not found.""" if contact.public_key not in self._contacts: diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index f755101..72afe69 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -54,6 +54,7 @@ CMD_SET_CUSTOM_VAR, CMD_SET_DEVICE_TIME, CMD_SET_FLOOD_SCOPE, + CMD_SET_OTHER_PARAMS, CMD_SET_RADIO_PARAMS, CMD_SET_RADIO_TX_POWER, CMD_SET_TUNING_PARAMS, @@ -72,6 +73,8 @@ PUB_KEY_SIZE, PUSH_CODE_ADVERT, PUSH_CODE_BINARY_RESPONSE, + PUSH_CODE_CONTACT_DELETED, + PUSH_CODE_CONTACTS_FULL, PUSH_CODE_CONTROL_DATA, PUSH_CODE_LOG_RX_DATA, PUSH_CODE_LOGIN_FAIL, @@ -423,6 +426,13 @@ async def on_path_discovery_response(tag_bytes, contact_pubkey, out_path, in_pat ) _write_push(frame) + async def on_contact_deleted(pub_key): + if isinstance(pub_key, bytes) and len(pub_key) >= 32: + _write_push(bytes([PUSH_CODE_CONTACT_DELETED]) + pub_key[:32]) + + async def on_contacts_full(): + _write_push(bytes([PUSH_CODE_CONTACTS_FULL])) + self.bridge.on_message_received(on_message_received) self.bridge.on_channel_message_received(on_channel_message_received) self.bridge.on_send_confirmed(on_send_confirmed) @@ -430,6 +440,8 @@ async def on_path_discovery_response(tag_bytes, contact_pubkey, out_path, in_pat self.bridge.on_contact_path_updated(on_contact_path_updated) self.bridge.on_binary_response(on_binary_response) self.bridge.on_path_discovery_response(on_path_discovery_response) + self.bridge.on_contact_deleted(on_contact_deleted) + self.bridge.on_contacts_full(on_contacts_full) # ------------------------------------------------------------------------- # Public push methods (called directly by host application) @@ -708,6 +720,8 @@ async def _handle_cmd(self, payload: bytes) -> None: await self._cmd_set_autoadd_config(data) elif cmd == CMD_GET_AUTOADD_CONFIG: await self._cmd_get_autoadd_config(data) + elif cmd == CMD_SET_OTHER_PARAMS: + await self._cmd_set_other_params(data) else: logger.warning( "Companion unsupported cmd 0x%02x (%s) len=%s", @@ -1605,3 +1619,15 @@ async def _cmd_set_autoadd_config(self, data: bytes) -> None: async def _cmd_get_autoadd_config(self, data: bytes) -> None: config = self.bridge.get_autoadd_config() self._write_frame(bytes([RESP_CODE_AUTOADD_CONFIG, config & 0xFF])) + + async def _cmd_set_other_params(self, data: bytes) -> None: + """Handle CMD_SET_OTHER_PARAMS (0x26). Mirrors MyMesh.cpp:1290-1305.""" + if len(data) < 1: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + manual_add = data[0] + telemetry_modes = data[1] if len(data) >= 2 else 0 + advert_loc_policy = data[2] if len(data) >= 3 else 0 + multi_acks = data[3] if len(data) >= 4 else 0 + self.bridge.set_other_params(manual_add, telemetry_modes, advert_loc_policy, multi_acks) + self._write_ok() From dd4c681d22457a062bf513d03fa3432d104e93bc Mon Sep 17 00:00:00 2001 From: agessaman Date: Tue, 24 Feb 2026 21:52:21 -0800 Subject: [PATCH 22/50] refactor(companion): enhance contact persistence and connection management - Updated the CompanionBridge to preserve additional fields when updating existing contacts, ensuring important data is retained during advert processing. - Refactored CompanionFrameServer to implement asynchronous contact persistence methods, improving efficiency and reliability in contact management. - Introduced a heartbeat mechanism to maintain TCP connections, enhancing stability during client-server interactions. - Improved error handling for connection issues and added TCP keepalive settings for better network resilience. --- src/pymc_core/companion/companion_bridge.py | 12 ++- src/pymc_core/companion/frame_server.py | 97 ++++++++++++++++--- src/pymc_core/node/handlers/login_response.py | 40 +++++--- 3 files changed, 122 insertions(+), 27 deletions(-) diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index 7efb4a4..7ba1c23 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -341,9 +341,15 @@ async def _update_stores_from_advert(self, packet: Packet, advert_data: dict): out_path=b"", ) - is_existing = self.contacts.get_by_key(pub_key) is not None - if is_existing: - # Always update existing contacts (C++ BaseChatMesh.cpp:158-167) + existing = self.contacts.get_by_key(pub_key) + if existing is not None: + # Always update existing contacts (C++ BaseChatMesh.cpp:158-167). + # Preserve fields that adverts don't carry: the firmware only + # updates name, type, gps, last_advert_timestamp, lastmod. + contact.out_path_len = existing.out_path_len + contact.out_path = existing.out_path + contact.flags = existing.flags + contact.sync_since = existing.sync_since self.contacts.update(contact) elif not self.should_auto_add_contact_type(adv_type): # Type not allowed — still fire callback so app sees the advert diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 72afe69..adb623e 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -13,7 +13,9 @@ import asyncio import logging +import socket import struct +import sys import time from typing import Any, Callable, Optional @@ -178,6 +180,7 @@ def __init__( local_hash: Optional[int] = None, stats_getter: Optional[Callable] = None, control_handler: Optional[Any] = None, + heartbeat_interval: int = 15, ): self.bridge = bridge self.companion_hash = companion_hash @@ -186,6 +189,7 @@ def __init__( self.local_hash = local_hash self.stats_getter = stats_getter self._control_handler = control_handler + self._heartbeat_interval = heartbeat_interval self._server: Optional[asyncio.Server] = None self._client_writer: Optional[asyncio.StreamWriter] = None self._client_reader: Optional[asyncio.StreamReader] = None @@ -252,11 +256,18 @@ def _sync_next_from_persistence(self) -> Optional[QueuedMessage]: Default returns ``None``.""" return None - def _save_contacts(self) -> None: - """Hook: persist the current contact list. Default is a no-op.""" + async def _persist_contact(self, contact) -> None: + """Hook: persist a single contact. Default is a no-op. - def _save_channels(self) -> None: - """Hook: persist the current channel list. Default is a no-op.""" + Subclasses should override to do a fast single-row upsert rather + than rewriting the entire contact list. + """ + + async def _save_contacts(self) -> None: + """Hook: persist the full contact list. Default is a no-op.""" + + async def _save_channels(self) -> None: + """Hook: persist the full channel list. Default is a no-op.""" def _get_batt_and_storage(self) -> tuple[int, int, int]: """Hook: return (millivolts, used_kb, total_kb). Default: all zeros.""" @@ -276,7 +287,11 @@ def _write_push(data: bytes) -> None: self._client_writer.write(frame) asyncio.create_task(self._drain_writer()) except Exception as e: - logger.debug("Push write error: %s", e) + logger.warning("Push write error (closing connection): %s", e) + try: + self._client_writer.close() + except Exception: + pass async def on_message_received(sender_key, text, timestamp, txt_type, packet_hash=None): msg_dict = { @@ -366,17 +381,19 @@ async def on_advert_received(contact): except Exception as e: logger.exception("advert_received callback error: %s", e) try: - self._save_contacts() + await self._persist_contact(contact) except Exception as e: - logger.warning("Save contacts after advert failed: %s", e) + logger.warning("Persist contact after advert failed: %s", e) async def on_contact_path_updated(pub_key, path_len, path): if isinstance(pub_key, bytes) and len(pub_key) >= 32: _write_push(bytes([PUSH_CODE_PATH_UPDATED]) + pub_key[:32]) try: - self._save_contacts() + c = self.bridge.contacts.get_by_key(pub_key) if isinstance(pub_key, bytes) else None + if c is not None: + await self._persist_contact(c) except Exception as e: - logger.warning("Save contacts after path update failed: %s", e) + logger.warning("Persist contact after path update failed: %s", e) async def on_channel_message_received( channel_name, @@ -560,6 +577,12 @@ async def _drain_writer(self) -> None: if self._client_writer: try: await self._client_writer.drain() + except (ConnectionResetError, BrokenPipeError, OSError) as e: + logger.warning("Drain failed (connection lost): %s", e) + try: + self._client_writer.close() + except Exception: + pass except Exception: pass @@ -579,6 +602,43 @@ def _write_err(self, err_code: int) -> None: # Client handling # ------------------------------------------------------------------------- + async def _heartbeat_loop(self) -> None: + """Send periodic ``RESP_CODE_CURR_TIME`` to keep the TCP connection alive.""" + try: + while self._client_writer and not self._client_writer.is_closing(): + await asyncio.sleep(self._heartbeat_interval) + if self._client_writer and not self._client_writer.is_closing(): + now = self.bridge.get_time() + self._write_frame(bytes([RESP_CODE_CURR_TIME]) + struct.pack(" None: + """Enable TCP keepalive on the underlying socket.""" + sock = writer.get_extra_info("socket") + if sock is None: + return + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if sys.platform == "linux": + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 15) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 5) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3) + elif sys.platform == "darwin": + # TCP_KEEPALIVE is the macOS equivalent of TCP_KEEPIDLE + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, 15) + try: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 5) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3) + except (AttributeError, OSError): + pass # older macOS may lack KEEPINTVL/KEEPCNT + except OSError as e: + logger.debug("Could not set TCP keepalive: %s", e) + async def _handle_client( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: @@ -591,9 +651,11 @@ async def _handle_client( self._client_reader = reader self._client_writer = writer + self._enable_tcp_keepalive(writer) self._setup_push_callbacks() logger.info("Companion client connected (port=%s)", self.port) + heartbeat_task = asyncio.create_task(self._heartbeat_loop()) try: while True: prefix = await reader.read(1) @@ -616,6 +678,11 @@ async def _handle_client( except Exception as e: logger.error("Client handler error: %s", e, exc_info=True) finally: + heartbeat_task.cancel() + try: + await heartbeat_task + except asyncio.CancelledError: + pass self._client_writer = None self._client_reader = None logger.info("Companion client disconnected (port=%s)", self.port) @@ -1321,7 +1388,7 @@ async def _cmd_add_update_contact(self, data: bytes) -> None: await self._drain_writer() if ok: try: - self._save_contacts() + await self._save_contacts() except Exception as e: logger.warning("Save contacts after add/update failed: %s", e) @@ -1333,7 +1400,10 @@ async def _cmd_remove_contact(self, data: bytes) -> None: pubkey = data[:32] ok = self.bridge.remove_contact(pubkey) if ok: - self._save_contacts() + try: + await self._save_contacts() + except Exception as e: + logger.warning("Save contacts after remove failed: %s", e) self._write_ok() if ok else self._write_err(ERR_CODE_NOT_FOUND) await self._drain_writer() @@ -1494,7 +1564,10 @@ async def _cmd_set_channel(self, data: bytes) -> None: return ok = self.bridge.set_channel(channel_idx, name, secret) if ok: - self._save_channels() + try: + await self._save_channels() + except Exception as e: + logger.warning("Save channels after set failed: %s", e) self._write_ok() if ok else self._write_err(ERR_CODE_NOT_FOUND) async def _cmd_set_flood_scope(self, data: bytes) -> None: diff --git a/src/pymc_core/node/handlers/login_response.py b/src/pymc_core/node/handlers/login_response.py index e1e507a..58e7f20 100644 --- a/src/pymc_core/node/handlers/login_response.py +++ b/src/pymc_core/node/handlers/login_response.py @@ -92,7 +92,7 @@ async def __call__(self, packet: Packet) -> None: if dest_hash != our_hash and src_hash != our_hash: return - # Find stored password and matching contact + # Find stored password and matching contact(s) if lookup_hash not in self._active_login_passwords: # This might be a telemetry response, not a login response # Forward to protocol response handler if available @@ -111,23 +111,38 @@ async def __call__(self, packet: Packet) -> None: ) return - matched_contact = None - + # Collect all contacts whose public_key first byte matches (hash collision / multiple peers) + candidates = [] for contact in self.contacts.contacts: try: - contact_pubkey = bytes.fromhex(contact.public_key) - if len(contact_pubkey) > 0 and contact_pubkey[0] == lookup_hash: - matched_contact = contact - break + pk = contact.public_key + contact_pubkey = pk if isinstance(pk, bytes) else bytes.fromhex(pk) + if len(contact_pubkey) == 32 and contact_pubkey[0] == lookup_hash: + candidates.append(contact) except Exception: continue - if not matched_contact: + if not candidates: + if self.login_callback: + await self._safe_callback( + False, + { + "error": "No contact found for login response (src_hash=0x%02x)" + % lookup_hash + }, + ) return - # Decrypt and process response - response_data = await self._decrypt_response(packet, matched_contact, encrypted_start) - if response_data: + # Try each candidate until one decrypts successfully (same shared-secret as firmware) + response_data = None + matched_contact = None + for contact in candidates: + response_data = await self._decrypt_response(packet, contact, encrypted_start) + if response_data: + matched_contact = contact + break + + if response_data and matched_contact: await self._process_login_response(response_data, matched_contact) self.clear_login_password(lookup_hash) elif self.login_callback: @@ -142,7 +157,8 @@ async def _decrypt_response( encrypted_data = packet.payload[encrypted_start:] # Calculate X25519 ECDH shared secret - contact_pubkey = bytes.fromhex(contact.public_key) + pk = contact.public_key + contact_pubkey = pk if isinstance(pk, bytes) else bytes.fromhex(pk) contact_identity = Identity(contact_pubkey) shared_secret = contact_identity.calc_shared_secret( self.local_identity.get_private_key() From d702b0c46e68075186ab744991d6b01a5e3f756f Mon Sep 17 00:00:00 2001 From: Adam Gessaman Date: Wed, 25 Feb 2026 15:37:56 -0800 Subject: [PATCH 23/50] refactor(companion): improve asynchronous handling of stats and message synchronization - Updated `_cmd_sync_next_message` to use `asyncio.to_thread` for better performance when retrieving the next message from persistence. - Enhanced stats retrieval logic to support both coroutine and non-coroutine functions, ensuring compatibility and efficiency in fetching statistics. - Improved error handling by ensuring that stats are fetched correctly from the appropriate source, maintaining robustness in the CompanionFrameServer's operations. --- src/pymc_core/companion/frame_server.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index adb623e..721c0a7 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -1162,7 +1162,7 @@ async def _cmd_send_trace_path(self, data: bytes) -> None: async def _cmd_sync_next_message(self, data: bytes) -> None: msg = self.bridge.sync_next_message() if msg is None: - msg = self._sync_next_from_persistence() + msg = await asyncio.to_thread(self._sync_next_from_persistence) if msg is None: self._write_frame(bytes([RESP_CODE_NO_MORE_MESSAGES])) return @@ -1433,9 +1433,14 @@ async def _cmd_get_stats(self, data: bytes) -> None: ): self._write_err(ERR_CODE_ILLEGAL_ARG) return - stats = ( - self.stats_getter(stats_type) if self.stats_getter else None - ) or self.bridge.get_stats(stats_type) + if self.stats_getter: + if asyncio.iscoroutinefunction(self.stats_getter): + stats = await self.stats_getter(stats_type) + else: + stats = await asyncio.to_thread(self.stats_getter, stats_type) + else: + stats = None + stats = stats or self.bridge.get_stats(stats_type) frame = bytes([RESP_CODE_STATS, stats_type]) if stats_type == STATS_TYPE_CORE: battery_mv = int(stats.get("battery_mv", 0)) From 5fb8268a6b5b11a74fe7b5d178d9258eb94f2b9f Mon Sep 17 00:00:00 2001 From: agessaman Date: Wed, 25 Feb 2026 22:02:12 -0800 Subject: [PATCH 24/50] feat(companion): enhance connection management and client timeout handling - Introduced eviction logic for managing client connections, allowing a new connection to replace an existing one. - Added `client_idle_timeout_sec` parameter to disconnect clients after a period of inactivity, improving resource management. - Updated documentation to reflect changes in connection handling and timeout behavior. --- docs/docs/companion.md | 4 +++ src/pymc_core/companion/frame_server.py | 42 +++++++++++++++++++------ 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/docs/docs/companion.md b/docs/docs/companion.md index ebdfd2e..3f020be 100644 --- a/docs/docs/companion.md +++ b/docs/docs/companion.md @@ -496,9 +496,13 @@ CompanionFrameServer( local_hash: int | None = None, stats_getter: Callable | None = None, control_handler: Any | None = None, + heartbeat_interval: int = 15, # seconds between keepalive frames + client_idle_timeout_sec: int = 90, # no data from client → disconnect and free slot ) ``` +**Connection management:** Only one client is allowed at a time. If a new connection arrives while one is already active, the server closes the existing connection and accepts the new one (same as firmware). If the client disappears without closing (e.g. kill, network drop), the slot is freed after no data is received for `client_idle_timeout_sec` seconds (default 90). Operators can tune this timeout to avoid dropping slow but live clients. + ### Supported Commands The frame server handles the following companion radio protocol commands: diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 721c0a7..9aec4f4 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -163,7 +163,10 @@ def _build_advert_push_frames(data: dict) -> tuple[bytes, Optional[bytes]]: class CompanionFrameServer: """TCP server for the MeshCore companion frame protocol. - One client per companion at a time. Persistence is handled through + One client per companion at a time. If a new connection arrives while + one is already active, the existing connection is closed and the new + one is accepted (eviction). An idle read timeout (client_idle_timeout_sec) + frees the slot when no data is received. Persistence is handled through overridable hook methods; the base class works with in-memory stores only. """ @@ -181,6 +184,7 @@ def __init__( stats_getter: Optional[Callable] = None, control_handler: Optional[Any] = None, heartbeat_interval: int = 15, + client_idle_timeout_sec: int = 90, ): self.bridge = bridge self.companion_hash = companion_hash @@ -190,6 +194,7 @@ def __init__( self.stats_getter = stats_getter self._control_handler = control_handler self._heartbeat_interval = heartbeat_interval + self._client_idle_timeout_sec = client_idle_timeout_sec self._server: Optional[asyncio.Server] = None self._client_writer: Optional[asyncio.StreamWriter] = None self._client_reader: Optional[asyncio.StreamReader] = None @@ -642,12 +647,22 @@ def _enable_tcp_keepalive(writer: asyncio.StreamWriter) -> None: async def _handle_client( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: - """Handle a new client connection. One client at a time.""" + """Handle a new client connection. One client at a time. + If a client is already connected, the existing connection is closed + and the new one is accepted (eviction). An idle read timeout also + frees the slot when no data is received for client_idle_timeout_sec. + """ if self._client_writer: - logger.warning("Companion already has a client; rejecting new connection") - writer.close() - await writer.wait_closed() - return + logger.info( + "Companion already has a client; evicting previous connection (port=%s)", + self.port, + ) + old_writer = self._client_writer + try: + old_writer.close() + await old_writer.wait_closed() + except Exception: + pass self._client_reader = reader self._client_writer = writer @@ -658,7 +673,13 @@ async def _handle_client( heartbeat_task = asyncio.create_task(self._heartbeat_loop()) try: while True: - prefix = await reader.read(1) + try: + prefix = await asyncio.wait_for( + reader.read(1), timeout=self._client_idle_timeout_sec + ) + except asyncio.TimeoutError: + logger.info("Companion client idle timeout (port=%s)", self.port) + break if not prefix: break if prefix[0] != FRAME_INBOUND_PREFIX: @@ -683,9 +704,10 @@ async def _handle_client( await heartbeat_task except asyncio.CancelledError: pass - self._client_writer = None - self._client_reader = None - logger.info("Companion client disconnected (port=%s)", self.port) + if self._client_writer is writer: + self._client_writer = None + self._client_reader = None + logger.info("Companion client disconnected (port=%s)", self.port) # ------------------------------------------------------------------------- # Command dispatch From 75ea5306438c3c5539b60c2e3704f1b2d46552e5 Mon Sep 17 00:00:00 2001 From: agessaman Date: Wed, 25 Feb 2026 22:14:51 -0800 Subject: [PATCH 25/50] fix(companion): extend client idle timeout and drain writer after each command frame - Increased the `client_idle_timeout_sec` from 90 to 120 seconds to allow for longer periods of inactivity before disconnecting clients. - Updated documentation to reflect the new default timeout value, enhancing clarity for operators managing client connections. --- docs/docs/companion.md | 4 ++-- src/pymc_core/companion/frame_server.py | 15 +++------------ 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/docs/docs/companion.md b/docs/docs/companion.md index 3f020be..1bf0781 100644 --- a/docs/docs/companion.md +++ b/docs/docs/companion.md @@ -497,11 +497,11 @@ CompanionFrameServer( stats_getter: Callable | None = None, control_handler: Any | None = None, heartbeat_interval: int = 15, # seconds between keepalive frames - client_idle_timeout_sec: int = 90, # no data from client → disconnect and free slot + client_idle_timeout_sec: int = 120, # no data from client → disconnect and free slot ) ``` -**Connection management:** Only one client is allowed at a time. If a new connection arrives while one is already active, the server closes the existing connection and accepts the new one (same as firmware). If the client disappears without closing (e.g. kill, network drop), the slot is freed after no data is received for `client_idle_timeout_sec` seconds (default 90). Operators can tune this timeout to avoid dropping slow but live clients. +**Connection management:** Only one client is allowed at a time. If a new connection arrives while one is already active, the server closes the existing connection and accepts the new one (same as firmware). If the client disappears without closing (e.g. kill, network drop), the slot is freed after no data is received for `client_idle_timeout_sec` seconds (default 120). Operators can tune this timeout to avoid dropping slow but live clients. ### Supported Commands diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 9aec4f4..eb94abc 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -184,7 +184,7 @@ def __init__( stats_getter: Optional[Callable] = None, control_handler: Optional[Any] = None, heartbeat_interval: int = 15, - client_idle_timeout_sec: int = 90, + client_idle_timeout_sec: int = 120, ): self.bridge = bridge self.companion_hash = companion_hash @@ -292,11 +292,7 @@ def _write_push(data: bytes) -> None: self._client_writer.write(frame) asyncio.create_task(self._drain_writer()) except Exception as e: - logger.warning("Push write error (closing connection): %s", e) - try: - self._client_writer.close() - except Exception: - pass + logger.warning("Push write error: %s", e) async def on_message_received(sender_key, text, timestamp, txt_type, packet_hash=None): msg_dict = { @@ -584,12 +580,6 @@ async def _drain_writer(self) -> None: await self._client_writer.drain() except (ConnectionResetError, BrokenPipeError, OSError) as e: logger.warning("Drain failed (connection lost): %s", e) - try: - self._client_writer.close() - except Exception: - pass - except Exception: - pass def _write_frame(self, data: bytes) -> None: """Send a frame to the connected client (outbound format).""" @@ -692,6 +682,7 @@ async def _handle_client( break payload = await reader.readexactly(frame_len) await self._handle_cmd(payload) + await self._drain_writer() except asyncio.IncompleteReadError: pass except (ConnectionResetError, BrokenPipeError): From 8c94631545b8d647b9c6b860567961a7961ab51a Mon Sep 17 00:00:00 2001 From: agessaman Date: Thu, 26 Feb 2026 06:12:33 -0800 Subject: [PATCH 26/50] feat(companion): add SNR and RSSI fields to message handling - Extracted SNR and RSSI values from network information in the CompanionBase class, allowing for better monitoring of network conditions. - Updated QueuedMessage model to include SNR and RSSI attributes, enhancing message detail. - Modified CompanionFrameServer to handle SNR and RSSI in message processing, improving data accuracy in communication. --- src/pymc_core/companion/companion_base.py | 10 ++++++++++ src/pymc_core/companion/frame_server.py | 7 ++++++- src/pymc_core/companion/models.py | 2 ++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 824b631..64b4bc9 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -1509,6 +1509,11 @@ async def _handle_new_channel_message(self, data: dict) -> None: # sender and message separately. Use full_content (not message_text) so client can split. # Strip trailing nulls so frame matches firmware (exact string length, no padding). display_text = (data.get("full_content", data.get("message_text", "")) or "").rstrip("\x00") + # Extract SNR/RSSI from network info if available + network_info = data.get("network_info", {}) + snr = network_info.get("snr") + rssi = network_info.get("rssi") + msg = QueuedMessage( sender_key=b"", txt_type=0, @@ -1517,8 +1522,11 @@ async def _handle_new_channel_message(self, data: dict) -> None: is_channel=True, channel_idx=channel_idx, path_len=path_len, + snr=snr if snr is not None else 0.0, + rssi=rssi if rssi is not None else 0, ) self.message_queue.push(msg) + await self._fire_callbacks( "channel_message_received", data.get("channel_name", ""), @@ -1528,6 +1536,8 @@ async def _handle_new_channel_message(self, data: dict) -> None: path_len, channel_idx, pkt_hash, + snr, + rssi, ) async def _fire_callbacks(self, event_name: str, *args: Any) -> None: diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index eb94abc..8a2f165 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -404,6 +404,8 @@ async def on_channel_message_received( path_len=0, channel_idx=0, packet_hash=None, + snr=None, + rssi=None, ): msg_dict = { "sender_key": b"", @@ -1184,11 +1186,14 @@ async def _cmd_sync_next_message(self, data: bytes) -> None: txt_type = 0 text_bytes = (msg.text or "").rstrip("\x00").encode("utf-8", errors="replace") if self._app_target_ver >= 3: + snr_byte = max(-128, min(127, int(round(msg.snr * 4)))) + if snr_byte < 0: + snr_byte += 256 frame = ( bytes( [ RESP_CODE_CHANNEL_MSG_RECV_V3, - 0, + snr_byte & 0xFF, 0, 0, msg.channel_idx, diff --git a/src/pymc_core/companion/models.py b/src/pymc_core/companion/models.py index f51a2fe..dd6a8cd 100644 --- a/src/pymc_core/companion/models.py +++ b/src/pymc_core/companion/models.py @@ -98,3 +98,5 @@ class QueuedMessage: is_channel: bool = False channel_idx: int = 0 # only meaningful if is_channel path_len: int = 0 + snr: float = 0.0 + rssi: int = 0 From 090dd4f0f5162bb196db49e3f00ccabe4095acad Mon Sep 17 00:00:00 2001 From: agessaman Date: Fri, 27 Feb 2026 16:51:32 -0800 Subject: [PATCH 27/50] refactor(companion): streamline advert frame building and contact handling - TODO #7 - Refactored `_build_advert_push_frames` to accept a `Contact` object instead of a dictionary, improving type safety and clarity. - Introduced `_contact_from_dict` to construct `Contact` instances from dictionaries, enhancing code readability and maintainability. - Updated the handling of optional fields and data extraction within the frame building process, ensuring consistency and reducing redundancy. --- src/pymc_core/companion/frame_server.py | 125 +++++++++++------------- tests/test_frame_server.py | 99 +++++++++++++++++++ 2 files changed, 154 insertions(+), 70 deletions(-) create mode 100644 tests/test_frame_server.py diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 8a2f165..c727359 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -119,47 +119,68 @@ logger = logging.getLogger("CompanionFrameServer") -def _build_advert_push_frames(data: dict) -> tuple[bytes, Optional[bytes]]: +def _build_advert_push_frames(contact: Contact) -> tuple[bytes, Optional[bytes]]: """Build PUSH_CODE_ADVERT short frame and optional PUSH_CODE_NEW_ADVERT - full frame from extracted data. Thread-safe for ``asyncio.to_thread``.""" - pubkey_b = data.get("pubkey_b", b"") + full frame from contact. Thread-safe for ``asyncio.to_thread``.""" + pubkey_b = contact.public_key if isinstance(pubkey_b, bytes): pubkey_b = pubkey_b[:32].ljust(32, b"\x00") else: pubkey_b = b"\x00" * 32 short = bytes([PUSH_CODE_ADVERT]) + pubkey_b - if not data.get("include_full"): + if not contact.name: return (short, None) - op = data.get("out_path", b"") - op = (op if isinstance(op, bytes) else bytes(op or []))[:MAX_PATH_SIZE].ljust( - MAX_PATH_SIZE, b"\x00" - ) - nb = data.get("name_b", b"") + op = contact.out_path if isinstance(contact.out_path, bytes) else bytes(contact.out_path or []) + op = op[:MAX_PATH_SIZE].ljust(MAX_PATH_SIZE, b"\x00") nb = ( - nb - if isinstance(nb, bytes) - else (nb.encode("utf-8", errors="replace") if isinstance(nb, str) else b"") + contact.name.encode("utf-8", errors="replace") + if isinstance(contact.name, str) + else (contact.name if isinstance(contact.name, bytes) else b"") )[:32].ljust(32, b"\x00") + opl_byte = 0xFF if contact.out_path_len < 0 else min(contact.out_path_len, 255) full = ( bytes([PUSH_CODE_NEW_ADVERT]) + pubkey_b - + bytes( - [ - data.get("adv_type", 0), - data.get("flags", 0), - data.get("opl_byte", 0xFF), - ] - ) + + bytes([contact.adv_type, contact.flags, opl_byte]) + op + nb - + struct.pack(" Contact: + """Build a Contact from a dict (e.g. legacy callback payload).""" + pub = d.get("public_key", b"") + if isinstance(pub, str): + pub = bytes.fromhex(pub) if pub else b"" + elif not isinstance(pub, bytes): + pub = b"" + pub = pub[:32].ljust(32, b"\x00") + out_path = d.get("out_path", b"") + if isinstance(out_path, str): + out_path = bytes.fromhex(out_path) if out_path else b"" + elif isinstance(out_path, (list, bytearray)): + out_path = bytes(out_path) + else: + out_path = bytes(out_path) if out_path else b"" + return Contact( + public_key=pub, + name=(d.get("name") or ""), + adv_type=d.get("adv_type", 0), + flags=d.get("flags", 0), + out_path_len=d.get("out_path_len", -1), + out_path=out_path, + last_advert_timestamp=d.get("last_advert_timestamp", d.get("last_advert", 0)), + lastmod=d.get("lastmod", 0), + gps_lat=float(d.get("gps_lat", 0)), + gps_lon=float(d.get("gps_lon", 0)), + ) + + class CompanionFrameServer: """TCP server for the MeshCore companion frame protocol. @@ -322,59 +343,23 @@ async def on_advert_received(contact): if isinstance(contact, dict): pubkey = contact.get("public_key", b"") if isinstance(pubkey, str): - pubkey = bytes.fromhex(pubkey) + pubkey = bytes.fromhex(pubkey) if pubkey else b"" + elif not isinstance(pubkey, bytes): + pubkey = b"" + if len(pubkey) < 32: + return + contact = _contact_from_dict(contact) else: pubkey = getattr( contact, "public_key", getattr(contact, "pub_key", b""), ) - if isinstance(pubkey, str): - pubkey = bytes.fromhex(pubkey) - if not isinstance(pubkey, bytes) or len(pubkey) < 32: - return - pubkey_b = pubkey[:32].ljust(32, b"\x00") - include_full = not isinstance(contact, dict) and getattr(contact, "name", None) - data = { - "pubkey_b": pubkey_b, - "include_full": bool(include_full), - "adv_type": 0, - "flags": 0, - "opl_byte": 0xFF, - "out_path": b"\x00" * MAX_PATH_SIZE, - "name_b": b"\x00" * 32, - "last_advert": 0, - "gps_lat_int": 0, - "gps_lon_int": 0, - "lastmod": 0, - } - if include_full: - data["adv_type"] = getattr(contact, "adv_type", 0) - data["flags"] = getattr(contact, "flags", 0) - opl = getattr(contact, "out_path_len", -1) - data["opl_byte"] = 0xFF if opl < 0 else min(opl, 255) - out_path = getattr(contact, "out_path", b"") or b"" - if isinstance(out_path, str): - out_path = bytes.fromhex(out_path) if out_path else b"" - elif isinstance(out_path, (list, bytearray)): - out_path = bytes(out_path) - data["out_path"] = bytes(out_path)[:MAX_PATH_SIZE].ljust(MAX_PATH_SIZE, b"\x00") - name = getattr(contact, "name", "") or "" - if isinstance(name, str): - data["name_b"] = name.encode("utf-8", errors="replace")[:32].ljust( - 32, b"\x00" - ) - elif isinstance(name, bytes): - data["name_b"] = name[:32].ljust(32, b"\x00") - else: - data["name_b"] = b"\x00" * 32 - data["last_advert"] = getattr(contact, "last_advert_timestamp", 0) - data["lastmod"] = getattr(contact, "lastmod", 0) - gps_lat = getattr(contact, "gps_lat", 0.0) - gps_lon = getattr(contact, "gps_lon", 0.0) - data["gps_lat_int"] = int(gps_lat * 1e6) - data["gps_lon_int"] = int(gps_lon * 1e6) - short, full = await asyncio.to_thread(_build_advert_push_frames, data) + if isinstance(pubkey, str): + pubkey = bytes.fromhex(pubkey) + if not isinstance(pubkey, bytes) or len(pubkey) < 32: + return + short, full = await asyncio.to_thread(_build_advert_push_frames, contact) _write_push(short) if full is not None: await asyncio.sleep(0) diff --git a/tests/test_frame_server.py b/tests/test_frame_server.py new file mode 100644 index 0000000..b62ccdc --- /dev/null +++ b/tests/test_frame_server.py @@ -0,0 +1,99 @@ +"""Tests for CompanionFrameServer and advert push frame construction.""" + +import struct + +from pymc_core.companion.constants import ( + MAX_PATH_SIZE, + PUB_KEY_SIZE, + PUSH_CODE_ADVERT, + PUSH_CODE_NEW_ADVERT, +) +from pymc_core.companion.frame_server import _build_advert_push_frames +from pymc_core.companion.models import Contact + + +def test_build_advert_push_frames_short_only_when_no_name(): + """Contact with empty name yields only short frame; full is None.""" + pubkey = bytes(range(32)) + contact = Contact(public_key=pubkey, name="") + short, full = _build_advert_push_frames(contact) + assert full is None + assert len(short) == 1 + PUB_KEY_SIZE + assert short[0] == PUSH_CODE_ADVERT + assert short[1:33] == pubkey + + +def test_build_advert_push_frames_short_and_full_when_has_name(): + """Contact with name yields short frame and full NEW_ADVERT frame.""" + pubkey = bytes(range(32)) + contact = Contact( + public_key=pubkey, + name="Alice", + adv_type=1, + flags=2, + out_path_len=0, + out_path=b"", + last_advert_timestamp=1000, + lastmod=2000, + gps_lat=52.5, + gps_lon=-1.7, + ) + short, full = _build_advert_push_frames(contact) + assert full is not None + # Short frame + assert len(short) == 1 + PUB_KEY_SIZE + assert short[0] == PUSH_CODE_ADVERT + assert short[1:33] == pubkey + # Full frame: code(1) + pubkey(32) + adv_type,flags,opl(3) + path(64) + name(32) + # + last_advert(4) + gps_lat(4) + gps_lon(4) + lastmod(4) + expected_full_len = 1 + 32 + 3 + MAX_PATH_SIZE + 32 + 4 + 4 + 4 + 4 + assert len(full) == expected_full_len + assert full[0] == PUSH_CODE_NEW_ADVERT + assert full[1:33] == pubkey + assert full[33] == 1 # adv_type + assert full[34] == 2 # flags + assert full[35] == 0 # opl_byte (out_path_len 0) + out_path = full[36 : 36 + MAX_PATH_SIZE] + assert out_path == b"\x00" * MAX_PATH_SIZE + name_b = full[36 + MAX_PATH_SIZE : 36 + MAX_PATH_SIZE + 32] + assert name_b.startswith(b"Alice") + assert name_b.rstrip(b"\x00") == b"Alice" + offset = 36 + MAX_PATH_SIZE + 32 + assert struct.unpack(" Date: Fri, 27 Feb 2026 19:49:17 -0800 Subject: [PATCH 28/50] feat(companion): improve CompanionRadio advert pipeline to use Contact natively -- TODO #4 - Added `_apply_advert_to_stores` method to streamline the application of adverts to the contact store and path cache, improving contact handling efficiency. - Updated `CompanionBridge` to utilize the new method for processing adverts during node discovery, ensuring contacts are added or updated correctly. - Refactored `Contact` model to include a `from_dict` class method for constructing contact instances from various data sources, enhancing flexibility and maintainability. - Improved error handling in advert processing to ensure robustness during contact updates and path cache management. --- src/pymc_core/companion/companion_base.py | 54 ++++++++++++ src/pymc_core/companion/companion_bridge.py | 71 ++------------- src/pymc_core/companion/frame_server.py | 31 +------ src/pymc_core/companion/models.py | 74 +++++++++++++++- src/pymc_core/node/handlers/advert.py | 2 + tests/test_companion_bridge.py | 96 +++++++++++++++++++++ 6 files changed, 233 insertions(+), 95 deletions(-) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 64b4bc9..71cd5c7 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -595,6 +595,54 @@ def should_overwrite_when_full(self) -> bool: """Check if overwrite-oldest is enabled. Mirrors C++ shouldOverwriteWhenFull.""" return bool(self.prefs.autoadd_config & AUTOADD_OVERWRITE_OLDEST) + async def _apply_advert_to_stores( + self, contact: Contact, inbound_path: Optional[bytes] = None + ) -> Optional[Contact]: + """Apply advert to ContactStore and PathCache. Shared by Bridge and NODE_DISCOVERED. + + Mirrors C++ BaseChatMesh::onAdvertRecv (existing update, auto-add filter, + overwrite when full). Returns the Contact if added or updated, None otherwise. + Path cache is updated for all valid contacts (pub_key >= 7, name non-empty). + """ + try: + if len(contact.public_key) < 7 or not contact.name: + return None + inbound_path = inbound_path or b"" + self.path_cache.update( + AdvertPath( + public_key_prefix=contact.public_key[:7], + name=contact.name, + path_len=len(inbound_path), + path=inbound_path, + recv_timestamp=int(time.time()), + ) + ) + existing = self.contacts.get_by_key(contact.public_key) + if existing is not None: + contact.out_path_len = existing.out_path_len + contact.out_path = existing.out_path + contact.flags = existing.flags + contact.sync_since = existing.sync_since + self.contacts.update(contact) + return contact + if not self.should_auto_add_contact_type(contact.adv_type): + logger.debug("Auto-add filtered: type %d not allowed", contact.adv_type) + return None + if self.should_overwrite_when_full() and self.contacts.is_full(): + ok, overwritten = self.contacts.add_or_overwrite(contact) + if ok and overwritten: + await self._fire_callbacks("contact_deleted", overwritten) + elif not ok: + await self._fire_callbacks("contacts_full") + return contact if ok else None + added = self.contacts.add(contact) + if not added and self.contacts.is_full(): + await self._fire_callbacks("contacts_full") + return contact if added else None + except Exception as e: + logger.error("Error applying advert to stores: %s", e) + return None + # ------------------------------------------------------------------------- # Push Callbacks # ------------------------------------------------------------------------- @@ -1448,6 +1496,12 @@ async def _handle_mesh_event(self, event_type: str, data: dict) -> None: elif event_type == MeshEvents.CONTACT_UPDATED: pass elif event_type == MeshEvents.NODE_DISCOVERED: + now = int(time.time()) + contact = Contact.from_dict(data, now=now) + if len(contact.public_key) >= 7 and contact.name: + applied = await self._apply_advert_to_stores(contact, None) + if applied is not None: + await self._fire_callbacks("advert_received", applied) await self._fire_callbacks("node_discovered", data) elif event_type == MeshEvents.TELEMETRY_UPDATED: await self._fire_callbacks("telemetry_response", data) diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index 7ba1c23..cb479fe 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -35,7 +35,7 @@ DEFAULT_MAX_CONTACTS, DEFAULT_OFFLINE_QUEUE_SIZE, ) -from .models import AdvertPath, Contact +from .models import Contact logger = logging.getLogger("CompanionBridge") @@ -306,77 +306,20 @@ async def process_received_packet(self, packet: Packet) -> None: async def _update_stores_from_advert(self, packet: Packet, advert_data: dict): """Update ContactStore and PathCache from advert result. - Mirrors C++ BaseChatMesh::onAdvertRecv (BaseChatMesh.cpp:106-170): - - Existing contacts are always updated (name, GPS, etc.) - - New contacts are subject to auto-add type filtering - - When store is full, overwrite-oldest replaces the oldest non-favourite - - Returns the Contact or None. + Builds Contact and inbound path from packet, then delegates to + _apply_advert_to_stores. Returns the Contact if added or updated, None otherwise. """ try: - pub_key = bytes.fromhex(advert_data.get("public_key", "")) - if len(pub_key) < 7: - return None - name = advert_data.get("name", "") - if not name: + contact = Contact.from_dict(advert_data, now=int(time.time())) + if len(contact.public_key) < 7 or not contact.name: return None path_len = getattr(packet, "path_len", 0) or 0 path = getattr(packet, "path", bytearray()) or bytearray() effective_len = path_len if path_len > 0 else len(path) inbound_path = bytes(path[:effective_len]) if effective_len > 0 else b"" - now = int(time.time()) - last_advert_ts = advert_data.get("advert_timestamp", 0) - if last_advert_ts > now: - last_advert_ts = now - adv_type = advert_data.get("contact_type_id", 0) - contact = Contact( - public_key=pub_key, - name=name, - adv_type=adv_type, - gps_lat=advert_data.get("latitude", 0.0), - gps_lon=advert_data.get("longitude", 0.0), - lastmod=now, - last_advert_timestamp=last_advert_ts, - out_path_len=-1, - out_path=b"", - ) - - existing = self.contacts.get_by_key(pub_key) - if existing is not None: - # Always update existing contacts (C++ BaseChatMesh.cpp:158-167). - # Preserve fields that adverts don't carry: the firmware only - # updates name, type, gps, last_advert_timestamp, lastmod. - contact.out_path_len = existing.out_path_len - contact.out_path = existing.out_path - contact.flags = existing.flags - contact.sync_since = existing.sync_since - self.contacts.update(contact) - elif not self.should_auto_add_contact_type(adv_type): - # Type not allowed — still fire callback so app sees the advert - logger.debug("Auto-add filtered: type %d not allowed", adv_type) - elif self.should_overwrite_when_full() and self.contacts.is_full(): - ok, overwritten = self.contacts.add_or_overwrite(contact) - if ok and overwritten: - await self._fire_callbacks("contact_deleted", overwritten) - elif not ok: - await self._fire_callbacks("contacts_full") - else: - added = self.contacts.add(contact) - if not added and self.contacts.is_full(): - await self._fire_callbacks("contacts_full") - - self.path_cache.update( - AdvertPath( - public_key_prefix=pub_key[:7], - name=name, - path_len=len(inbound_path), - path=inbound_path, - recv_timestamp=int(time.time()), - ) - ) - return contact + return await self._apply_advert_to_stores(contact, inbound_path) except Exception as e: - logger.error(f"Error updating stores from advert: {e}") + logger.error("Error updating stores from advert: %s", e) return None # ------------------------------------------------------------------------- diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index c727359..e55f4a4 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -152,35 +152,6 @@ def _build_advert_push_frames(contact: Contact) -> tuple[bytes, Optional[bytes]] return (short, full) -def _contact_from_dict(d: dict) -> Contact: - """Build a Contact from a dict (e.g. legacy callback payload).""" - pub = d.get("public_key", b"") - if isinstance(pub, str): - pub = bytes.fromhex(pub) if pub else b"" - elif not isinstance(pub, bytes): - pub = b"" - pub = pub[:32].ljust(32, b"\x00") - out_path = d.get("out_path", b"") - if isinstance(out_path, str): - out_path = bytes.fromhex(out_path) if out_path else b"" - elif isinstance(out_path, (list, bytearray)): - out_path = bytes(out_path) - else: - out_path = bytes(out_path) if out_path else b"" - return Contact( - public_key=pub, - name=(d.get("name") or ""), - adv_type=d.get("adv_type", 0), - flags=d.get("flags", 0), - out_path_len=d.get("out_path_len", -1), - out_path=out_path, - last_advert_timestamp=d.get("last_advert_timestamp", d.get("last_advert", 0)), - lastmod=d.get("lastmod", 0), - gps_lat=float(d.get("gps_lat", 0)), - gps_lon=float(d.get("gps_lon", 0)), - ) - - class CompanionFrameServer: """TCP server for the MeshCore companion frame protocol. @@ -348,7 +319,7 @@ async def on_advert_received(contact): pubkey = b"" if len(pubkey) < 32: return - contact = _contact_from_dict(contact) + contact = Contact.from_dict(contact) else: pubkey = getattr( contact, diff --git a/src/pymc_core/companion/models.py b/src/pymc_core/companion/models.py index dd6a8cd..9279de8 100644 --- a/src/pymc_core/companion/models.py +++ b/src/pymc_core/companion/models.py @@ -2,8 +2,9 @@ from __future__ import annotations +import time from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional @dataclass @@ -22,6 +23,77 @@ class Contact: gps_lon: float = 0.0 # degrees sync_since: int = 0 # for filtered iteration + @classmethod + def from_dict( + cls, + d: dict[str, Any], + *, + now: Optional[int] = None, + ) -> "Contact": + """Build a Contact from a dict (event_data, advert_data, or serialized Contact). + + event_data uses: public_key, name, contact_type (id), lat, lon, + advert_timestamp, timestamp. + advert_data uses: public_key, name, contact_type_id, latitude, longitude, + flags, advert_timestamp, timestamp. + Serialized Contact dicts (ContactStore.to_dicts) use the same keys as the + dataclass: gps_lat, gps_lon, last_advert_timestamp, lastmod, out_path, + out_path_len, sync_since. + """ + if now is None: + now = int(time.time()) + pub = d.get("public_key", b"") + if isinstance(pub, str): + pub = bytes.fromhex(pub) if pub else b"" + elif not isinstance(pub, bytes): + pub = b"" + pub = pub[:32].ljust(32, b"\x00") + name = (d.get("name") or "") or "" + adv_type_raw = d.get("contact_type_id", d.get("adv_type", d.get("contact_type", 0))) + if isinstance(adv_type_raw, int): + adv_type = adv_type_raw + elif adv_type_raw is None: + adv_type = 0 + else: + try: + adv_type = int(adv_type_raw) + except (TypeError, ValueError): + adv_type = 0 + gps_lat = float(d.get("lat", d.get("latitude", d.get("gps_lat", 0.0)))) + gps_lon = float(d.get("lon", d.get("longitude", d.get("gps_lon", 0.0)))) + last_advert_ts = d.get("advert_timestamp", d.get("last_advert_timestamp", 0)) + last_advert_ts = int(last_advert_ts) if last_advert_ts is not None else 0 + if last_advert_ts > now: + last_advert_ts = now + lastmod_val = d.get("timestamp", d.get("lastmod", now)) + lastmod_val = int(lastmod_val) if lastmod_val is not None else now + flags_val = d.get("flags", 0) + flags_val = int(flags_val) if flags_val is not None else 0 + out_path = d.get("out_path", b"") + if isinstance(out_path, str): + out_path = bytes.fromhex(out_path) if out_path else b"" + elif isinstance(out_path, (list, bytearray)): + out_path = bytes(out_path) + else: + out_path = bytes(out_path) if out_path else b"" + out_path_len_val = d.get("out_path_len", -1) + out_path_len_val = int(out_path_len_val) if out_path_len_val is not None else -1 + sync_since_val = d.get("sync_since", 0) + sync_since_val = int(sync_since_val) if sync_since_val is not None else 0 + return cls( + public_key=pub, + name=name, + adv_type=adv_type, + flags=flags_val, + out_path_len=out_path_len_val, + out_path=out_path, + last_advert_timestamp=last_advert_ts, + lastmod=lastmod_val, + gps_lat=gps_lat, + gps_lon=gps_lon, + sync_since=sync_since_val, + ) + @dataclass class Channel: diff --git a/src/pymc_core/node/handlers/advert.py b/src/pymc_core/node/handlers/advert.py index 8095b6c..4e2174c 100644 --- a/src/pymc_core/node/handlers/advert.py +++ b/src/pymc_core/node/handlers/advert.py @@ -143,6 +143,8 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: "contact_type": contact_type_id, "lat": lat, "lon": lon, + "advert_timestamp": advert_timestamp, + "timestamp": int(time.time()), "snr": advert_data["snr"], "rssi": advert_data["rssi"], } diff --git a/tests/test_companion_bridge.py b/tests/test_companion_bridge.py index f07e53e..6d17d11 100644 --- a/tests/test_companion_bridge.py +++ b/tests/test_companion_bridge.py @@ -3,6 +3,7 @@ import pytest from pymc_core.companion import CompanionBridge +from pymc_core.companion.constants import ADV_TYPE_CHAT, AUTOADD_CHAT from pymc_core.companion.models import Contact from pymc_core.node.events import MeshEvents from pymc_core.protocol import LocalIdentity, Packet @@ -244,6 +245,101 @@ async def test_send_binary_req_with_contact(self): assert len(injector.calls) == 1 +# --------------------------------------------------------------------------- +# NODE_DISCOVERED -> advert pipeline (contact store + advert_received) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeNodeDiscoveredAdvertPipeline: + async def test_node_discovered_adds_contact_and_fires_advert_received(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + peer = LocalIdentity() + pub_key_hex = peer.get_public_key().hex() + event_data = { + "public_key": pub_key_hex, + "name": "DiscoveredNode", + "contact_type": ADV_TYPE_CHAT, + "lat": 52.0, + "lon": -1.0, + "advert_timestamp": 1000, + "timestamp": 1001, + "snr": 5.0, + "rssi": -80, + } + advert_received_calls = [] + + def on_advert(c): + advert_received_calls.append(c) + + bridge.on_advert_received(on_advert) + await bridge._handle_mesh_event(MeshEvents.NODE_DISCOVERED, event_data) + assert bridge.contacts.get_count() == 1 + assert len(advert_received_calls) == 1 + assert advert_received_calls[0].name == "DiscoveredNode" + assert advert_received_calls[0].public_key == peer.get_public_key() + + async def test_node_discovered_fires_node_discovered_even_when_filtered(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + bridge.prefs.manual_add_contacts = 1 + bridge.prefs.autoadd_config = AUTOADD_CHAT + peer = LocalIdentity() + event_data = { + "public_key": peer.get_public_key().hex(), + "name": "RepeaterNode", + "contact_type": 2, + "lat": 0.0, + "lon": 0.0, + "advert_timestamp": 1000, + "timestamp": 1000, + "snr": 0.0, + "rssi": 0, + } + node_discovered_calls = [] + advert_received_calls = [] + + def on_node(data): + node_discovered_calls.append(data) + + def on_advert(c): + advert_received_calls.append(c) + + bridge.on_node_discovered(on_node) + bridge.on_advert_received(on_advert) + await bridge._handle_mesh_event(MeshEvents.NODE_DISCOVERED, event_data) + assert bridge.contacts.get_count() == 0 + assert len(advert_received_calls) == 0 + assert len(node_discovered_calls) == 1 + assert node_discovered_calls[0]["name"] == "RepeaterNode" + + async def test_update_stores_from_advert_adds_contact_and_returns_it(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + peer = LocalIdentity() + pub_key_hex = peer.get_public_key().hex() + packet = Packet() + packet.path_len = 0 + packet.path = bytearray() + advert_data = { + "public_key": pub_key_hex, + "name": "AdvertNode", + "contact_type_id": ADV_TYPE_CHAT, + "latitude": 0.0, + "longitude": 0.0, + "advert_timestamp": 1000, + } + contact = await bridge._update_stores_from_advert(packet, advert_data) + assert contact is not None + assert contact.name == "AdvertNode" + assert bridge.contacts.get_count() == 1 + contact2 = await bridge._update_stores_from_advert(packet, advert_data) + assert contact2 is not None + assert contact2.name == "AdvertNode" + assert bridge.contacts.get_count() == 1 + + # --------------------------------------------------------------------------- # Deduplication (direct messages by packet_hash) # --------------------------------------------------------------------------- From 0b6e7e38596007338bfdd67dc943498bbdacac01 Mon Sep 17 00:00:00 2001 From: agessaman Date: Fri, 27 Feb 2026 21:36:54 -0800 Subject: [PATCH 29/50] feat(companion): enhance text message sending and advert handling - TODO #3 - Added `wait_for_ack` parameter to `send_text_message`, allowing for non-blocking message sending. - Updated documentation to clarify the behavior of `wait_for_ack` and its impact on message acknowledgment. - Refactored advert handling in `CompanionBridge` to utilize `inbound_path`, improving the processing of node discovery events. - Enhanced test coverage for node discovery and advert received callbacks, ensuring correct handling of single-path events. --- src/pymc_core/companion/companion_base.py | 29 ++++++++-- src/pymc_core/companion/companion_bridge.py | 27 +-------- src/pymc_core/companion/frame_server.py | 2 +- src/pymc_core/node/handlers/advert.py | 5 ++ tests/test_companion_bridge.py | 61 ++++++++++++++++----- 5 files changed, 79 insertions(+), 45 deletions(-) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 71cd5c7..db67698 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -1046,8 +1046,14 @@ async def send_text_message( text: str, txt_type: int = TXT_TYPE_PLAIN, attempt: int = 1, + wait_for_ack: bool = True, ) -> SentResult: - """Send a direct text message to a contact.""" + """Send a direct text message to a contact. + + When wait_for_ack is True (default), blocks until ACK or timeout. + When wait_for_ack is False, returns as soon as the packet is handed off; + ACK (if any) is still tracked and will trigger send_confirmed later. + """ contact = self.contacts.get_by_key(pub_key) if not contact: logger.warning(f"Contact not found for key {pub_key.hex()[:12]}...") @@ -1067,7 +1073,19 @@ async def send_text_message( ) self._apply_flood_scope(pkt) self._track_pending_ack(ack_crc) - success = await self._send_packet(pkt, wait_for_ack=True) + if wait_for_ack: + success = await self._send_packet(pkt, wait_for_ack=True) + if success: + self.stats.record_tx(is_flood=is_flood) + else: + self.stats.record_tx_error() + return SentResult( + success=success, + is_flood=is_flood, + expected_ack=ack_crc, + timeout_ms=None, + ) + success = await self._send_packet(pkt, wait_for_ack=False) if success: self.stats.record_tx(is_flood=is_flood) else: @@ -1076,7 +1094,7 @@ async def send_text_message( success=success, is_flood=is_flood, expected_ack=ack_crc, - timeout_ms=None, + timeout_ms=DEFAULT_RESPONSE_TIMEOUT_MS, ) except Exception as e: logger.error(f"Error sending text message: {e}") @@ -1496,10 +1514,13 @@ async def _handle_mesh_event(self, event_type: str, data: dict) -> None: elif event_type == MeshEvents.CONTACT_UPDATED: pass elif event_type == MeshEvents.NODE_DISCOVERED: + # Advert pipeline (single path): all adverts applied here; one event + # -> one store update and at most one advert_received (Bridge and Radio). now = int(time.time()) contact = Contact.from_dict(data, now=now) if len(contact.public_key) >= 7 and contact.name: - applied = await self._apply_advert_to_stores(contact, None) + inbound_path = data.get("inbound_path") + applied = await self._apply_advert_to_stores(contact, inbound_path) if applied is not None: await self._fire_callbacks("advert_received", applied) await self._fire_callbacks("node_discovered", data) diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index cb479fe..6acd48d 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -10,7 +10,6 @@ import asyncio import logging -import time from typing import Any, Callable, Optional from ..node.handlers import create_core_handlers @@ -35,7 +34,6 @@ DEFAULT_MAX_CONTACTS, DEFAULT_OFFLINE_QUEUE_SIZE, ) -from .models import Contact logger = logging.getLogger("CompanionBridge") @@ -290,11 +288,7 @@ async def process_received_packet(self, packet: Packet) -> None: handler = self._handlers.get(ptype) if handler: try: - result = await handler(packet) - if ptype == PAYLOAD_TYPE_ADVERT and result: - contact = await self._update_stores_from_advert(packet, result) - if contact: - await self._fire_callbacks("advert_received", contact) + await handler(packet) except Exception as e: logger.error(f"Handler error for type {ptype:02X}: {e}") @@ -303,25 +297,6 @@ async def process_received_packet(self, packet: Packet) -> None: # No duplicate call here — it would cause double decryption and could # deliver the result to response waiters twice. - async def _update_stores_from_advert(self, packet: Packet, advert_data: dict): - """Update ContactStore and PathCache from advert result. - - Builds Contact and inbound path from packet, then delegates to - _apply_advert_to_stores. Returns the Contact if added or updated, None otherwise. - """ - try: - contact = Contact.from_dict(advert_data, now=int(time.time())) - if len(contact.public_key) < 7 or not contact.name: - return None - path_len = getattr(packet, "path_len", 0) or 0 - path = getattr(packet, "path", bytearray()) or bytearray() - effective_len = path_len if path_len > 0 else len(path) - inbound_path = bytes(path[:effective_len]) if effective_len > 0 else b"" - return await self._apply_advert_to_stores(contact, inbound_path) - except Exception as e: - logger.error("Error updating stores from advert: %s", e) - return None - # ------------------------------------------------------------------------- # Abstract method implementations # ------------------------------------------------------------------------- diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index e55f4a4..905300e 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -946,7 +946,7 @@ async def _cmd_send_txt_msg(self, data: bytes) -> None: else bytes.fromhex(contact.public_key) ) result = await self.bridge.send_text_message( - pubkey, text, txt_type=txt_type, attempt=attempt + 1 + pubkey, text, txt_type=txt_type, attempt=attempt + 1, wait_for_ack=False ) if result.success: ack = result.expected_ack or 0 diff --git a/src/pymc_core/node/handlers/advert.py b/src/pymc_core/node/handlers/advert.py index 4e2174c..b130ac8 100644 --- a/src/pymc_core/node/handlers/advert.py +++ b/src/pymc_core/node/handlers/advert.py @@ -137,6 +137,10 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: # Publish so companion/app receives node-discovered and advert_received callbacks if self.event_service: try: + path_len = getattr(packet, "path_len", 0) or 0 + path = getattr(packet, "path", bytearray()) or bytearray() + effective_len = path_len if path_len > 0 else len(path) + inbound_path = bytes(path[:effective_len]) if effective_len > 0 else b"" event_data = { "public_key": pubkey_hex, "name": name, @@ -147,6 +151,7 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: "timestamp": int(time.time()), "snr": advert_data["snr"], "rssi": advert_data["rssi"], + "inbound_path": inbound_path, } self.event_service.publish_sync(MeshEvents.NODE_DISCOVERED, event_data) except Exception as e: diff --git a/tests/test_companion_bridge.py b/tests/test_companion_bridge.py index 6d17d11..70727c9 100644 --- a/tests/test_companion_bridge.py +++ b/tests/test_companion_bridge.py @@ -253,6 +253,7 @@ async def test_send_binary_req_with_contact(self): @pytest.mark.asyncio class TestCompanionBridgeNodeDiscoveredAdvertPipeline: async def test_node_discovered_adds_contact_and_fires_advert_received(self): + """Single path: NODE_DISCOVERED event drives store + advert_received (Bridge and Radio).""" injector = MockPacketInjector() bridge = CompanionBridge(LocalIdentity(), injector) peer = LocalIdentity() @@ -280,6 +281,30 @@ def on_advert(c): assert advert_received_calls[0].name == "DiscoveredNode" assert advert_received_calls[0].public_key == peer.get_public_key() + async def test_one_node_discovered_event_produces_exactly_one_advert_received(self): + """Single-path guarantee: one NODE_DISCOVERED event yields exactly one + advert_received callback (no duplicate path, no duplicate push frames). + """ + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + peer = LocalIdentity() + event_data = { + "public_key": peer.get_public_key().hex(), + "name": "SinglePathNode", + "contact_type": ADV_TYPE_CHAT, + "lat": 0.0, + "lon": 0.0, + "advert_timestamp": 1000, + "timestamp": 1000, + "snr": 0.0, + "rssi": 0, + } + advert_received_calls = [] + bridge.on_advert_received(advert_received_calls.append) + await bridge._handle_mesh_event(MeshEvents.NODE_DISCOVERED, event_data) + assert len(advert_received_calls) == 1 + assert advert_received_calls[0].name == "SinglePathNode" + async def test_node_discovered_fires_node_discovered_even_when_filtered(self): injector = MockPacketInjector() bridge = CompanionBridge(LocalIdentity(), injector) @@ -314,30 +339,38 @@ def on_advert(c): assert len(node_discovered_calls) == 1 assert node_discovered_calls[0]["name"] == "RepeaterNode" - async def test_update_stores_from_advert_adds_contact_and_returns_it(self): + async def test_node_discovered_event_path_adds_contact_and_fires_advert_received(self): + """Event path with optional inbound_path: store updated, advert_received fired once.""" injector = MockPacketInjector() bridge = CompanionBridge(LocalIdentity(), injector) peer = LocalIdentity() pub_key_hex = peer.get_public_key().hex() - packet = Packet() - packet.path_len = 0 - packet.path = bytearray() - advert_data = { + event_data = { "public_key": pub_key_hex, "name": "AdvertNode", - "contact_type_id": ADV_TYPE_CHAT, - "latitude": 0.0, - "longitude": 0.0, + "contact_type": ADV_TYPE_CHAT, + "lat": 0.0, + "lon": 0.0, "advert_timestamp": 1000, + "timestamp": 1000, + "snr": 0.0, + "rssi": 0, + "inbound_path": b"\x01\x02\x03", } - contact = await bridge._update_stores_from_advert(packet, advert_data) - assert contact is not None - assert contact.name == "AdvertNode" + advert_received_calls = [] + + def on_advert(c): + advert_received_calls.append(c) + + bridge.on_advert_received(on_advert) + await bridge._handle_mesh_event(MeshEvents.NODE_DISCOVERED, event_data) assert bridge.contacts.get_count() == 1 - contact2 = await bridge._update_stores_from_advert(packet, advert_data) - assert contact2 is not None - assert contact2.name == "AdvertNode" + assert len(advert_received_calls) == 1 + assert advert_received_calls[0].name == "AdvertNode" + # Second event (same contact): update, still one contact, advert_received again + await bridge._handle_mesh_event(MeshEvents.NODE_DISCOVERED, event_data) assert bridge.contacts.get_count() == 1 + assert len(advert_received_calls) == 2 # --------------------------------------------------------------------------- From f043f16af34f8e03ac53901c7d89c536e419531a Mon Sep 17 00:00:00 2001 From: agessaman Date: Fri, 27 Feb 2026 22:32:08 -0800 Subject: [PATCH 30/50] feat(companion): implement pending ACK CRC management for send_confirmed - TODO #1 and #2 - Introduced `_pending_ack_crcs` to track pending ACK CRCs in `CompanionBase`, enhancing the handling of confirmed sends. - Refactored `_track_pending_ack` and `_apply_ack` methods to utilize the new `_try_confirm_send` method for improved clarity and functionality. - Updated `CompanionRadio` to clear pending ACKs upon key import and added an ACK received listener to the dispatcher for better integration. - Enhanced the dispatcher to support asynchronous ACK handling, ensuring efficient processing of received ACKs. --- src/pymc_core/companion/companion_base.py | 15 ++++++++++++- src/pymc_core/companion/companion_bridge.py | 12 +--------- src/pymc_core/companion/companion_radio.py | 6 +++++ src/pymc_core/companion/constants.py | 1 + src/pymc_core/node/dispatcher.py | 25 ++++++++++++++++++++- src/pymc_core/node/handlers/ack.py | 16 ++++++++----- tests/test_dispatcher.py | 2 +- 7 files changed, 58 insertions(+), 19 deletions(-) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index db67698..cd4cb30 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -50,6 +50,7 @@ DEFAULT_MAX_CONTACTS, DEFAULT_OFFLINE_QUEUE_SIZE, DEFAULT_RESPONSE_TIMEOUT_MS, + MAX_PENDING_ACK_CRCS, MAX_SIGN_DATA_SIZE, PROTOCOL_CODE_ANON_REQ, PROTOCOL_CODE_BINARY_REQ, @@ -189,6 +190,8 @@ def _init_companion_stores( self._pending_binary_requests: dict[str, dict] = {} # Pending path discovery tags for matching responses self._pending_discovery_tags: set[int] = set() + # Pending ACK CRCs for send_confirmed (Bridge and Radio) + self._pending_ack_crcs: set[int] = set() # GRP_TXT dedup by packet hash: match Mesh.cpp (!_tables->hasSeen(pkt)); # companion queues one frame per logical message like the firmware. @@ -1476,7 +1479,17 @@ def _response_cb(message_text: str, sender_contact: Any) -> None: text_handler.set_command_response_callback(None) def _track_pending_ack(self, ack_crc: int) -> None: - """Hook for subclasses to track pending ACK CRCs. Default is a no-op.""" + """Track pending ACK CRC for send_confirmed (capped).""" + if len(self._pending_ack_crcs) < MAX_PENDING_ACK_CRCS: + self._pending_ack_crcs.add(ack_crc) + + async def _try_confirm_send(self, crc: int) -> bool: + """If CRC is pending, discard it and fire send_confirmed. Returns True if fired.""" + if crc not in self._pending_ack_crcs: + return False + self._pending_ack_crcs.discard(crc) + await self._fire_callbacks("send_confirmed", crc) + return True def sync_next_message(self) -> Optional[QueuedMessage]: """Pop and return the next queued message, or None.""" diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index 6acd48d..1a87f9d 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -42,8 +42,6 @@ # Bridge ACK handler: fires send_confirmed when ACK CRC matches a pending send # --------------------------------------------------------------------------- -MAX_PENDING_ACK_CRCS = 64 - class _BridgeAckHandler: """Handles ACK packets (discrete and PATH-carried). @@ -64,10 +62,7 @@ async def __call__(self, packet: Packet) -> None: async def _apply_ack(self, crc: int) -> None: """If CRC is pending, clear it and fire send_confirmed.""" - if crc not in self._bridge._pending_ack_crcs: - return - self._bridge._pending_ack_crcs.discard(crc) - await self._bridge._fire_callbacks("send_confirmed", crc) + await self._bridge._try_confirm_send(crc) async def process_path_ack_variants(self, packet: Packet) -> Optional[int]: """Decrypt PATH payload; update contact out_path (firmware pattern), return ACK CRC. @@ -211,7 +206,6 @@ async def _delayed_send() -> None: def _log(msg: str) -> None: logger.debug(f"[CompanionBridge] {msg}") - self._pending_ack_crcs: set[int] = set() ack_handler = _BridgeAckHandler(self) # Use shared factory for the core protocol handlers @@ -270,10 +264,6 @@ def _get_login_response_handler(self) -> Any: def _get_text_handler(self) -> Any: return self._text_handler_ref - def _track_pending_ack(self, ack_crc: int) -> None: - if len(self._pending_ack_crcs) < MAX_PENDING_ACK_CRCS: - self._pending_ack_crcs.add(ack_crc) - # ------------------------------------------------------------------------- # RX Entry Point # ------------------------------------------------------------------------- diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py index c810cfd..b9ef826 100644 --- a/src/pymc_core/companion/companion_radio.py +++ b/src/pymc_core/companion/companion_radio.py @@ -201,6 +201,7 @@ def set_tx_power(self, power_dbm: int) -> bool: def import_private_key(self, key: bytes) -> bool: try: self._identity = LocalIdentity(seed=key) + self._pending_ack_crcs.clear() self.node = MeshNode( radio=self._radio, local_identity=self._identity, @@ -239,6 +240,7 @@ def _setup_packet_callbacks(self) -> None: dispatcher = self.node.dispatcher dispatcher.set_packet_received_callback(self._on_packet_received) dispatcher.set_packet_sent_callback(self._on_packet_sent) + dispatcher.set_ack_received_listener(self._on_ack_received) if ( hasattr(dispatcher, "protocol_response_handler") and dispatcher.protocol_response_handler @@ -252,5 +254,9 @@ async def _on_packet_received(self, pkt: Any) -> None: is_flood = route_type in (ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD) self.stats.record_rx(is_flood=is_flood) + async def _on_ack_received(self, crc: int) -> None: + """Called by dispatcher when an ACK CRC is received; fire send_confirmed if pending.""" + await self._try_confirm_send(crc) + async def _on_packet_sent(self, pkt: Any) -> None: pass diff --git a/src/pymc_core/companion/constants.py b/src/pymc_core/companion/constants.py index 9653957..8ff8e29 100644 --- a/src/pymc_core/companion/constants.py +++ b/src/pymc_core/companion/constants.py @@ -88,6 +88,7 @@ class BinaryReqType(IntEnum): DEFAULT_MAX_CHANNELS = 40 CONTACT_NAME_SIZE = 32 MAX_SIGN_DATA_SIZE = 8192 # 8KB signing buffer (matches firmware) +MAX_PENDING_ACK_CRCS = 64 # =========================================================================== # Frame Protocol Constants (MeshCore Companion Radio Protocol) diff --git a/src/pymc_core/node/dispatcher.py b/src/pymc_core/node/dispatcher.py index 575382b..4864227 100644 --- a/src/pymc_core/node/dispatcher.py +++ b/src/pymc_core/node/dispatcher.py @@ -66,6 +66,9 @@ def __init__( self.packet_received_callback: Optional[Callable[[Packet], Awaitable[None] | None]] = None self.packet_sent_callback: Optional[Callable[[Packet], Awaitable[None] | None]] = None + # Optional listener for ACK received (e.g. companion send_confirmed) + self._ack_received_listener: Optional[Callable[[int], Awaitable[None] | None]] = None + # Raw packet callbacks: single callback (legacy) and list of subscribers (after parse) self.raw_packet_callback: Optional[Callable[[Packet, bytes], Awaitable[None] | None]] = None self._raw_packet_subscribers: List[Callable[..., Any]] = [] @@ -263,6 +266,13 @@ def set_packet_sent_callback( ) -> None: self.packet_sent_callback = callback + def set_ack_received_listener( + self, + callback: Optional[Callable[[int], Awaitable[None] | None]], + ) -> None: + """Set optional listener for ACK CRCs (e.g. companion send_confirmed).""" + self._ack_received_listener = callback + def set_raw_packet_callback( self, callback: Callable[[Packet, bytes], Awaitable[None] | None] ) -> None: @@ -578,7 +588,7 @@ async def _dispatch(self, pkt: Packet) -> None: # All ACK processing logic is delegated to the AckHandler. # ------------------------------------------------------------------ - def _register_ack_received(self, crc: int) -> None: + async def _register_ack_received(self, crc: int) -> None: """Record that an ACK with the given CRC was received.""" ts = asyncio.get_running_loop().time() self._recent_acks[crc] = ts @@ -588,6 +598,9 @@ def _register_ack_received(self, crc: int) -> None: self._log(f"ACK matched! CRC {crc:08X}") evt.set() + if self._ack_received_listener: + await self._invoke_ack_listener(crc) + async def run_forever(self) -> None: """Run the dispatcher maintenance loop indefinitely (call this in an asyncio task).""" health_check_counter = 0 @@ -630,6 +643,16 @@ async def _invoke_callback(self, cb, pkt: Packet) -> None: else: cb(pkt) + async def _invoke_ack_listener(self, crc: int) -> None: + """Invoke ack-received listener (sync or async).""" + cb = self._ack_received_listener + if cb is None: + return + if asyncio.iscoroutinefunction(cb): + await cb(crc) + else: + cb(crc) + async def _invoke_enhanced_raw_callback( self, callback, pkt: Packet, data: bytes, analysis: dict ) -> None: diff --git a/src/pymc_core/node/handlers/ack.py b/src/pymc_core/node/handlers/ack.py index 7d260cb..c75cdcc 100644 --- a/src/pymc_core/node/handlers/ack.py +++ b/src/pymc_core/node/handlers/ack.py @@ -1,4 +1,5 @@ -from typing import Callable, Optional +import asyncio +from typing import Awaitable, Callable, Optional from ...protocol import Packet from ...protocol.constants import PAYLOAD_TYPE_ACK @@ -20,9 +21,11 @@ def payload_type() -> int: def __init__(self, log_fn, dispatcher=None): self.log = log_fn self.dispatcher = dispatcher - self._ack_received_callback: Optional[Callable[[int], None]] = None + self._ack_received_callback: Optional[Callable[[int], Awaitable[None] | None]] = None - def set_ack_received_callback(self, callback: Callable[[int], None]): + def set_ack_received_callback( + self, callback: Optional[Callable[[int], Awaitable[None] | None]] + ): """Set callback to notify dispatcher when ACK is received.""" self._ack_received_callback = callback @@ -162,5 +165,8 @@ async def _process_bundled_ack_in_path(self, payload: bytes) -> Optional[int]: async def _notify_ack_received(self, crc: int): """Notify the dispatcher that an ACK was received.""" if self._ack_received_callback: - # Call the callback directly since _register_ack_received is synchronous - self._ack_received_callback(crc) + cb = self._ack_received_callback + if asyncio.iscoroutinefunction(cb): + await cb(crc) + else: + cb(crc) diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index a805815..b4e1e59 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -243,7 +243,7 @@ async def test_ack_waiting_and_receipt(self, dispatcher): dispatcher._waiting_acks[crc] = ack_event # Simulate receiving ACK - dispatcher._register_ack_received(crc) + await dispatcher._register_ack_received(crc) # Event should be set assert ack_event.is_set() From 7e376dc1b64ca889668d1f490c6b1d433e90ce2b Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 28 Feb 2026 09:54:52 -0800 Subject: [PATCH 31/50] fix(companion): support decimal and hex formats for companion hash - Updated the CompanionFrameServer to accept companion hash values in both decimal and hexadecimal formats, enhancing flexibility in hash handling. - Modified logging to reflect the parsed integer value of the companion hash, ensuring accurate representation in server status messages. --- src/pymc_core/companion/frame_server.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 905300e..46d7ff0 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -217,11 +217,16 @@ async def start(self) -> None: if self._server.sockets else (self.bind_address, self.port) ) + # Repeater passes hash as hex (first byte of pubkey, e.g. "f5"); accept decimal or hex. + try: + hash_int = int(self.companion_hash) + except ValueError: + hash_int = int(self.companion_hash, 16) logger.info( "Companion frame server listening on %s:%s (hash=0x%02x)", addr[0], addr[1], - int(self.companion_hash), + hash_int, ) async def stop(self) -> None: From 6bab63737cde84de57f4133ff5c0882f95958a90 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 28 Feb 2026 11:32:08 -0800 Subject: [PATCH 32/50] fix(companion): add SNR and RSSI support in message handling - Enhanced the CompanionBase class to extract SNR and RSSI values from network information, improving network condition monitoring. - Updated the QueuedMessage model to include SNR and RSSI attributes for more detailed message data. - Modified the CompanionFrameServer to process and persist SNR and RSSI in received messages, ensuring accurate communication data. --- src/pymc_core/companion/companion_base.py | 8 ++++++++ src/pymc_core/companion/frame_server.py | 8 +++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index cd4cb30..f5aca82 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -1554,6 +1554,10 @@ async def _handle_new_message(self, data: dict) -> None: sender_key = bytes.fromhex(sender_key_hex) if sender_key_hex else b"" # Handler publishes "message_text"; accept "text" for compatibility message_text = (data.get("message_text") or data.get("text") or "").rstrip("\x00") + # Extract SNR/RSSI from network info if available (same as channel path) + network_info = data.get("network_info", {}) + snr = network_info.get("snr") + rssi = network_info.get("rssi") msg = QueuedMessage( sender_key=sender_key, txt_type=data.get("txt_type", data.get("flags", 0)), @@ -1561,6 +1565,8 @@ async def _handle_new_message(self, data: dict) -> None: text=message_text, is_channel=False, path_len=0, + snr=snr if snr is not None else 0.0, + rssi=rssi if rssi is not None else 0, ) self.message_queue.push(msg) await self._fire_callbacks( @@ -1570,6 +1576,8 @@ async def _handle_new_message(self, data: dict) -> None: msg.timestamp, msg.txt_type, pkt_hash, + snr if snr is not None else 0.0, + rssi if rssi is not None else 0, ) async def _handle_new_channel_message(self, data: dict) -> None: diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 46d7ff0..41a922e 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -291,7 +291,9 @@ def _write_push(data: bytes) -> None: except Exception as e: logger.warning("Push write error: %s", e) - async def on_message_received(sender_key, text, timestamp, txt_type, packet_hash=None): + async def on_message_received( + sender_key, text, timestamp, txt_type, packet_hash=None, snr=None, rssi=None + ): msg_dict = { "sender_key": sender_key, "text": text, @@ -301,6 +303,8 @@ async def on_message_received(sender_key, text, timestamp, txt_type, packet_hash "channel_idx": 0, "path_len": 0, "packet_hash": packet_hash, + "snr": snr, + "rssi": rssi, } await self._persist_companion_message(msg_dict) _write_push(bytes([PUSH_CODE_MSG_WAITING])) @@ -377,6 +381,8 @@ async def on_channel_message_received( "channel_idx": channel_idx, "path_len": path_len, "packet_hash": packet_hash, + "snr": snr, + "rssi": rssi, } await self._persist_companion_message(msg_dict) _write_push(bytes([PUSH_CODE_MSG_WAITING])) From ccaf639bc98864474d49cfca37c50e5d3fae00ad Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 28 Feb 2026 12:22:00 -0800 Subject: [PATCH 33/50] fix(companion): improve client disconnection logging and error handling, add SNR response for channel v3 frame - Added detailed disconnect reasons for client disconnections in the CompanionFrameServer, enhancing debugging and monitoring capabilities. - Updated exception handling to capture specific error types and log them appropriately, improving error traceability. - Ensured that the client disconnection log message includes the reason for disconnection, providing clearer insights into client behavior. --- src/pymc_core/companion/frame_server.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 41a922e..38ebaa5 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -630,6 +630,7 @@ async def _handle_client( logger.info("Companion client connected (port=%s)", self.port) heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + disconnect_reason: Optional[str] = None try: while True: try: @@ -637,9 +638,10 @@ async def _handle_client( reader.read(1), timeout=self._client_idle_timeout_sec ) except asyncio.TimeoutError: - logger.info("Companion client idle timeout (port=%s)", self.port) + disconnect_reason = "idle_timeout" break if not prefix: + disconnect_reason = "empty_read" break if prefix[0] != FRAME_INBOUND_PREFIX: logger.warning("Invalid frame prefix: 0x%02x", prefix[0]) @@ -648,15 +650,17 @@ async def _handle_client( frame_len = struct.unpack(" MAX_FRAME_SIZE: logger.warning("Frame too long: %s", frame_len) + disconnect_reason = "frame_too_long" break payload = await reader.readexactly(frame_len) await self._handle_cmd(payload) await self._drain_writer() except asyncio.IncompleteReadError: - pass - except (ConnectionResetError, BrokenPipeError): - pass + disconnect_reason = "incomplete_read" + except (ConnectionResetError, BrokenPipeError) as e: + disconnect_reason = type(e).__name__ except Exception as e: + disconnect_reason = f"other: {type(e).__name__}: {e}" logger.error("Client handler error: %s", e, exc_info=True) finally: heartbeat_task.cancel() @@ -667,7 +671,11 @@ async def _handle_client( if self._client_writer is writer: self._client_writer = None self._client_reader = None - logger.info("Companion client disconnected (port=%s)", self.port) + logger.info( + "Companion client disconnected (port=%s): %s", + self.port, + disconnect_reason or "unknown", + ) # ------------------------------------------------------------------------- # Command dispatch @@ -1188,8 +1196,11 @@ async def _cmd_sync_next_message(self, data: bytes) -> None: path_len_byte = msg.path_len if msg.path_len < 256 else 0xFF text_bytes = msg.text.encode("utf-8", errors="replace") if self._app_target_ver >= 3: + snr_byte = max(-128, min(127, int(round(msg.snr * 4)))) + if snr_byte < 0: + snr_byte += 256 frame = ( - bytes([RESP_CODE_CONTACT_MSG_RECV_V3, 0, 0, 0]) + bytes([RESP_CODE_CONTACT_MSG_RECV_V3, snr_byte & 0xFF, 0, 0]) + prefix + bytes([path_len_byte, msg.txt_type]) + struct.pack(" None: return pub_key = data[1 : 1 + PUB_KEY_SIZE] prefix = pub_key[:7] + # Bridge methods used from command handlers must not block the event loop; + # if a subclass adds sync I/O here, run it via asyncio.to_thread(). found = ( self.bridge.get_advert_path(prefix) if getattr(self.bridge, "get_advert_path", None) From 0f531bd1b57d5597b8348a44b6f8c45f500ad729 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 28 Feb 2026 15:47:03 -0800 Subject: [PATCH 34/50] fix(companion): update frame size limits and deduplication handling - Adjusted MAX_FRAME_SIZE to 172 bytes to align with firmware constraints and BLE MTU. - Introduced MAX_PAYLOAD_SIZE to ensure payloads are correctly truncated to prevent overflow. - Enhanced packet injection logic to avoid duplicate messages when multiple bridges are present, ensuring consistent deduplication across local and over-the-air deliveries. - Updated documentation to reflect changes in frame size and deduplication strategy. --- docs/docs/companion.md | 36 +++++++++++++++++++++++-- src/pymc_core/companion/constants.py | 5 +++- src/pymc_core/companion/frame_server.py | 24 ++++++++++++++--- 3 files changed, 59 insertions(+), 6 deletions(-) diff --git a/docs/docs/companion.md b/docs/docs/companion.md index 1bf0781..9c5473d 100644 --- a/docs/docs/companion.md +++ b/docs/docs/companion.md @@ -442,6 +442,37 @@ The bridge registers internal handlers for these payload types: Note that `set_radio_params()` and `set_tx_power()` update in-memory prefs only — there is no physical radio to configure. This is correct: the repeater host owns the radio hardware. +### Avoiding doubled messages + +When you have **multiple bridges** on the same repeater, the other bridge can receive the same logical message twice: once from local fan-out when one bridge injects a packet, and again when the repeater's radio receives that packet over the air. Deduplication uses a packet hash; if the data used for local delivery differs from the bytes sent over the air, the hashes differ and both copies appear. + +**Use one canonical byte representation** for both TX and local delivery: + +1. When your `packet_injector` is called with `pkt`, apply any in-place changes (e.g. flood scope) **before** serializing. +2. Serialize once: `raw = pkt.write_to()`. +3. Send `raw` on the radio. +4. For local delivery to the other bridge, build a new packet from those **same** bytes and pass it: + `pkt2 = Packet(); pkt2.read_from(raw); await other_bridge.process_received_packet(pkt2)`. + +Then both local and OTA deliveries share the same packet hash and companion-side dedup collapses the duplicate. + +**Optional:** Feed the same `raw` bytes into the repeater's dispatcher RX path (the same entry the radio uses) instead of calling `other_bridge.process_received_packet(pkt)` directly. The dispatcher will track the packet hash; when the same bytes arrive over the air they are dropped as duplicates, and you deliver to bridges only from the dispatcher (single path, no double delivery). + +Example injector with serialize-once local delivery to another bridge: + +```python +from pymc_core.protocol import Packet + +async def packet_injector(pkt, wait_for_ack=False): + raw = pkt.write_to() # after any in-place changes (e.g. flood scope) + ok = await repeater.send_raw(raw) # or dispatcher.send_packet with pre-serialized bytes + if ok and other_bridge: + pkt2 = Packet() + pkt2.read_from(raw) + await other_bridge.process_received_packet(pkt2) + return ok +``` + --- ## CompanionFrameServer @@ -457,7 +488,7 @@ All frames use a simple length-prefixed format: | App → Radio | `<` (0x3C) | 2-byte LE | Command byte + payload | | Radio → App | `>` (0x3E) | 2-byte LE | Response/push byte + payload | -Maximum frame size: 512 bytes. +Maximum frame size: 172 bytes (matches firmware; BLE MTU). ### Quick Start @@ -952,7 +983,8 @@ PROTOCOL_CODE_ANON_REQ = 0x07 # Frame format FRAME_OUTBOUND_PREFIX = 0x3E # '>' (radio → app) FRAME_INBOUND_PREFIX = 0x3C # '<' (app → radio) -MAX_FRAME_SIZE = 512 +MAX_FRAME_SIZE = 172 +MAX_PAYLOAD_SIZE = 169 # MAX_FRAME_SIZE - 3 (prefix + 2-byte length) # Error codes (returned by frame server) ERR_CODE_UNSUPPORTED_CMD = 1 diff --git a/src/pymc_core/companion/constants.py b/src/pymc_core/companion/constants.py index 8ff8e29..f2b95e2 100644 --- a/src/pymc_core/companion/constants.py +++ b/src/pymc_core/companion/constants.py @@ -220,7 +220,10 @@ class BinaryReqType(IntEnum): # --------------------------------------------------------------------------- FRAME_OUTBOUND_PREFIX = 0x3E # '>' FRAME_INBOUND_PREFIX = 0x3C # '<' -MAX_FRAME_SIZE = 512 +# Match firmware: writeFrame() refuses to send if len > MAX_FRAME_SIZE; BLE MTU +# is set to this (e.g. BLEDevice::setMTU(MAX_FRAME_SIZE)). Frame = prefix(1) + len(2) + payload. +MAX_FRAME_SIZE = 172 +MAX_PAYLOAD_SIZE = MAX_FRAME_SIZE - 3 # max bytes after prefix + 2-byte length PUB_KEY_SIZE = 32 MAX_PATH_SIZE = 64 diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 38ebaa5..556c5b5 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -72,6 +72,7 @@ FRAME_OUTBOUND_PREFIX, MAX_FRAME_SIZE, MAX_PATH_SIZE, + MAX_PAYLOAD_SIZE, PUB_KEY_SIZE, PUSH_CODE_ADVERT, PUSH_CODE_BINARY_RESPONSE, @@ -285,6 +286,13 @@ def _setup_push_callbacks(self) -> None: def _write_push(data: bytes) -> None: if self._client_writer and not self._client_writer.is_closing(): try: + if len(data) > MAX_PAYLOAD_SIZE: + logger.warning( + "Push frame payload truncated from %s to %s", + len(data), + MAX_PAYLOAD_SIZE, + ) + data = data[:MAX_PAYLOAD_SIZE] frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: snr_byte += 256 if rssi_byte < 0: rssi_byte += 256 - payload_len = min(len(raw), MAX_FRAME_SIZE - 3) + payload_len = min(len(raw), MAX_PAYLOAD_SIZE - 3) # 3 = code + snr + rssi data = bytes([PUSH_CODE_LOG_RX_DATA, snr_byte & 0xFF, rssi_byte & 0xFF]) + raw[:payload_len] try: frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: logger.warning("Drain failed (connection lost): %s", e) def _write_frame(self, data: bytes) -> None: - """Send a frame to the connected client (outbound format).""" + """Send a frame to the connected client (outbound format). + Payload is truncated to MAX_PAYLOAD_SIZE to match firmware MAX_FRAME_SIZE (172). + """ if self._client_writer and not self._client_writer.is_closing(): + if len(data) > MAX_PAYLOAD_SIZE: + logger.warning( + "Outbound frame payload truncated from %s to %s (MAX_FRAME_SIZE=%s)", + len(data), + MAX_PAYLOAD_SIZE, + MAX_FRAME_SIZE, + ) + data = data[:MAX_PAYLOAD_SIZE] frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" Date: Sat, 28 Feb 2026 21:14:03 -0800 Subject: [PATCH 35/50] feat(companion): add contact_path_updated callback to ProtocolResponseHandler ProtocolResponseHandler now supports an optional callback when it updates a contact's out_path from a decrypted PATH packet. CompanionRadio and CompanionBridge set this callback to fire contact_path_updated so the companion layer can persist or react. Bridge no longer updates path or fires the callback in process_path_ack_variants; LoginResponseHandler skips forwarding PATH packets to the protocol handler to avoid double invocation. --- docs/docs/companion.md | 14 +++-- src/pymc_core/companion/companion_base.py | 13 ++++ src/pymc_core/companion/companion_bridge.py | 31 +++------- src/pymc_core/companion/companion_radio.py | 3 + src/pymc_core/companion/frame_server.py | 15 +++-- src/pymc_core/node/handlers/login_response.py | 11 ++-- .../node/handlers/protocol_response.py | 32 +++++++++- tests/test_companion_bridge.py | 56 ++++++++++++++++- tests/test_companion_radio.py | 26 ++++++++ tests/test_handlers.py | 62 ++++++++++++++++++- 10 files changed, 215 insertions(+), 48 deletions(-) diff --git a/docs/docs/companion.md b/docs/docs/companion.md index 9c5473d..b7cf888 100644 --- a/docs/docs/companion.md +++ b/docs/docs/companion.md @@ -115,13 +115,13 @@ async def main(): # --- Callbacks --- -def on_msg(sender_key, text, timestamp, txt_type): +def on_msg(sender_key, text, timestamp, txt_type, *args): print(f"DM from {sender_key[:8].hex()}: {text}") def on_advert(contact): print(f"Discovered: {contact.name} (type={contact.adv_type})") -def on_chan_msg(channel_name, sender_name, text, timestamp, path_len, channel_idx): +def on_chan_msg(channel_name, sender_name, text, timestamp, path_len, channel_idx, *args): print(f"[{channel_name}] {sender_name}: {text}") def on_ack(ack_crc): @@ -387,7 +387,7 @@ async def main(): ) bridge.on_message_received( - lambda key, text, ts, tt: print(f"Bridge msg: {text}") + lambda key, text, ts, tt, *args: print(f"Bridge msg: {text}") ) await bridge.start() @@ -796,7 +796,7 @@ companion.set_channel(0, name="Emergency", secret=b"shared_channel_secret_______ companion.set_channel(1, name="General", secret=b"another_shared_secret___________") companion.on_channel_message_received( - lambda ch_name, sender, text, ts, path_len, idx: + lambda ch_name, sender, text, ts, path_len, idx, *args: print(f"[{ch_name}] {sender}: {text}") ) @@ -808,11 +808,13 @@ await companion.send_channel_message(0, "Emergency broadcast") ## Push Callbacks Reference Register callbacks to receive asynchronous events. Both sync and async functions are supported. +Callbacks for `on_message_received` and `on_channel_message_received` receive optional trailing args +`(packet_hash, snr, rssi)` when available; use `*args` to ignore them. | Registration Method | Callback Signature | |---|---| -| `on_message_received` | `(sender_key: bytes, text: str, timestamp: int, txt_type: int)` | -| `on_channel_message_received` | `(channel_name: str, sender_name: str, text: str, timestamp: int, path_len: int, channel_idx: int)` | +| `on_message_received` | `(sender_key: bytes, text: str, timestamp: int, txt_type: int [, packet_hash, snr, rssi])` | +| `on_channel_message_received` | `(channel_name: str, sender_name: str, text: str, timestamp: int, path_len: int, channel_idx: int [, packet_hash, snr, rssi])` | | `on_advert_received` | `(contact: Contact)` | | `on_contact_path_updated` | `(contact: Contact)` | | `on_send_confirmed` | `(ack_crc: int)` | diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index f5aca82..ab1befc 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -662,6 +662,19 @@ def on_advert_received(self, callback: Callable) -> None: def on_contact_path_updated(self, callback: Callable) -> None: self._push_callbacks["contact_path_updated"].append(callback) + async def _on_contact_path_updated(self, pub: bytes, path_len: int, path_bytes: bytes) -> None: + """Called by ProtocolResponseHandler when contact's out_path is updated from a PATH packet. + Converts (pub, path_len, path_bytes) to a Contact and fires user callbacks with (contact). + """ + contact = self.get_contact_by_key(pub) + if contact is None: + contact = Contact( + public_key=pub, + out_path_len=path_len, + out_path=path_bytes, + ) + await self._fire_callbacks("contact_path_updated", contact) + def on_send_confirmed(self, callback: Callable) -> None: self._push_callbacks["send_confirmed"].append(callback) diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index 1a87f9d..ad2fdab 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -65,8 +65,10 @@ async def _apply_ack(self, crc: int) -> None: await self._bridge._try_confirm_send(crc) async def process_path_ack_variants(self, packet: Packet) -> Optional[int]: - """Decrypt PATH payload; update contact out_path (firmware pattern), return ACK CRC. - Tries every contact matching src_hash (same as TXT_MSG) so we use the correct key. + """Decrypt PATH payload and return ACK CRC if present. + + Path update and contact_path_updated are handled by ProtocolResponseHandler; + this only extracts ACK for send_confirmed. """ from ..protocol import CryptoUtils, Identity @@ -117,27 +119,7 @@ async def process_path_ack_variants(self, packet: Packet) -> Optional[int]: src_hash, ) continue - path_bytes = bytes(decrypted[1 : 1 + path_len]) - # Firmware pattern: onContactPathRecv stores out_path so replies can use sendDirect() - # Update the underlying Contact (store expects Contact with bytes public_key, not proxy) - contact_obj = self._bridge.contacts.get_by_key(pub) - if contact_obj: - contact_obj.out_path_len = path_len - contact_obj.out_path = path_bytes - self._bridge.contacts.update(contact_obj) - logger.debug( - "process_path_ack_variants: updated out_path for src=0x%02x " - "contact=%s path_len=%d", - src_hash, - getattr(contact, "name", "?"), - path_len, - ) - else: - logger.debug( - "process_path_ack_variants: get_by_key returned None for src=0x%02x", - src_hash, - ) - await self._bridge._fire_callbacks("contact_path_updated", pub, path_len, path_bytes) + # Path update and contact_path_updated are handled by ProtocolResponseHandler # If this PATH carries an ACK, return it so send_confirmed can fire extra_start = 1 + path_len if len(decrypted) >= extra_start + 1 + 4 and decrypted[extra_start] == PAYLOAD_TYPE_ACK: @@ -250,6 +232,9 @@ def _reject_all(*args, **kwargs) -> tuple[bool, int]: self._text_handler_ref = core.text_handler core.protocol_response_handler.set_binary_response_callback(self._on_binary_response) core.protocol_response_handler.set_packet_injector(self._packet_injector) + core.protocol_response_handler.set_contact_path_updated_callback( + self._on_contact_path_updated + ) # ------------------------------------------------------------------------- # Handler accessors (used by CompanionBase concrete send methods) diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py index b9ef826..cc94b64 100644 --- a/src/pymc_core/companion/companion_radio.py +++ b/src/pymc_core/companion/companion_radio.py @@ -248,6 +248,9 @@ def _setup_packet_callbacks(self) -> None: dispatcher.protocol_response_handler.set_binary_response_callback( self._on_binary_response ) + dispatcher.protocol_response_handler.set_contact_path_updated_callback( + self._on_contact_path_updated + ) async def _on_packet_received(self, pkt: Any) -> None: route_type = pkt.get_route_type() diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 556c5b5..8e39094 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -359,13 +359,16 @@ async def on_advert_received(contact): except Exception as e: logger.warning("Persist contact after advert failed: %s", e) - async def on_contact_path_updated(pub_key, path_len, path): - if isinstance(pub_key, bytes) and len(pub_key) >= 32: - _write_push(bytes([PUSH_CODE_PATH_UPDATED]) + pub_key[:32]) + async def on_contact_path_updated(contact): + if ( + hasattr(contact, "public_key") + and isinstance(contact.public_key, bytes) + and len(contact.public_key) >= 32 + ): + _write_push(bytes([PUSH_CODE_PATH_UPDATED]) + contact.public_key[:32]) try: - c = self.bridge.contacts.get_by_key(pub_key) if isinstance(pub_key, bytes) else None - if c is not None: - await self._persist_contact(c) + if contact is not None: + await self._persist_contact(contact) except Exception as e: logger.warning("Persist contact after path update failed: %s", e) diff --git a/src/pymc_core/node/handlers/login_response.py b/src/pymc_core/node/handlers/login_response.py index 58e7f20..a02cf9d 100644 --- a/src/pymc_core/node/handlers/login_response.py +++ b/src/pymc_core/node/handlers/login_response.py @@ -95,13 +95,12 @@ async def __call__(self, packet: Packet) -> None: # Find stored password and matching contact(s) if lookup_hash not in self._active_login_passwords: # This might be a telemetry response, not a login response - # Forward to protocol response handler if available + # Forward to protocol response handler if available (only for RESPONSE packets; + # PATH packets are already handled by PathHandler before we are called). if self._protocol_response_handler: - # Create a fake PATH packet format that - # ProtocolResponseHandler expects - # PATH format: dest_hash(1) + src_hash(1) + encrypted_data - # RESPONSE format is already: dest_hash(1) + src_hash(1) + encrypted_data - # So we can directly forward the packet to the protocol response handler + pkt_type = (packet.header >> 2) & 0x0F + if pkt_type == PAYLOAD_TYPE_PATH: + return # PathHandler already invoked protocol_response_handler for this packet try: await self._protocol_response_handler(packet) return diff --git a/src/pymc_core/node/handlers/protocol_response.py b/src/pymc_core/node/handlers/protocol_response.py index e12cfff..b7a44eb 100644 --- a/src/pymc_core/node/handlers/protocol_response.py +++ b/src/pymc_core/node/handlers/protocol_response.py @@ -167,6 +167,18 @@ def __init__(self, log_fn: Callable[[str], None], local_identity, contact_book): self._login_response_handler: Optional[Any] = None # Packet injector for sending reciprocal PATH packets (mirrors C++ Mesh.cpp:168-169) self._packet_injector: Optional[Callable] = None + # Optional: notify when contact out_path is updated from decrypted PATH + # (e.g. companion persist). + self._contact_path_updated_callback: Optional[Callable[..., Any]] = None + + def set_contact_path_updated_callback(self, callback: Optional[Callable[..., Any]]) -> None: + """Set callback when contact out_path is updated from a decrypted PATH packet. + + Signature: (contact_pubkey: bytes, path_len: int, path_bytes: bytes) + -> None | Awaitable[None]. + Called after _update_contact_path when the contact was found and updated. + """ + self._contact_path_updated_callback = callback def set_login_response_handler(self, handler: Any) -> None: """Set login handler ref for checking active login state.""" @@ -337,16 +349,18 @@ def _update_contact_path( src_hash: int, path_len_byte: int, decrypted: bytes, - ) -> None: + ) -> bool: """Update contact out_path from decrypted PATH data (firmware onContactPathRecv pattern). When a PATH packet is successfully decrypted, store the return path on the contact so that subsequent requests use sendDirect() instead of sendFlood(). This mirrors C++ ``BaseChatMesh::onContactPathRecv``. + + Returns True if the contact was found and updated, False otherwise. """ try: if path_len_byte > MAX_PATH_SIZE: - return + return False out_path_bytes = bytes(decrypted[1 : 1 + path_len_byte]) contact_obj = self._contact_book.get_by_key(contact_pubkey) if contact_obj is not None: @@ -357,13 +371,16 @@ def _update_contact_path( f"[ProtocolResponse] Updated out_path for 0x{src_hash:02X}: " f"path_len={path_len_byte}" ) + return True else: self._log( f"[ProtocolResponse] Cannot update out_path for 0x{src_hash:02X}: " f"contact not found by key" ) + return False except Exception as e: self._log(f"[ProtocolResponse] Failed to update out_path: {e}") + return False async def _send_reciprocal_path( self, @@ -492,7 +509,16 @@ async def _decrypt_protocol_response( # Firmware pattern (onContactPathRecv): update contact out_path # so subsequent requests use sendDirect() instead of sendFlood(). - self._update_contact_path(contact_pubkey, src_hash, path_len_byte, decrypted) + out_path_bytes = bytes(decrypted[1 : 1 + path_len_byte]) + if self._update_contact_path( + contact_pubkey, src_hash, path_len_byte, decrypted + ): + if self._contact_path_updated_callback is not None: + cb_result = self._contact_path_updated_callback( + contact_pubkey, path_len_byte, out_path_bytes + ) + if asyncio.iscoroutine(cb_result): + await cb_result # Firmware pattern (Mesh.cpp:168-169): send reciprocal PATH back # to the sender so it learns the route to us. Without this, the diff --git a/tests/test_companion_bridge.py b/tests/test_companion_bridge.py index 70727c9..af0e623 100644 --- a/tests/test_companion_bridge.py +++ b/tests/test_companion_bridge.py @@ -6,8 +6,14 @@ from pymc_core.companion.constants import ADV_TYPE_CHAT, AUTOADD_CHAT from pymc_core.companion.models import Contact from pymc_core.node.events import MeshEvents -from pymc_core.protocol import LocalIdentity, Packet -from pymc_core.protocol.constants import PAYLOAD_TYPE_ADVERT, PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD +from pymc_core.protocol import CryptoUtils, Identity, LocalIdentity, Packet +from pymc_core.protocol.constants import ( + PAYLOAD_TYPE_ADVERT, + PAYLOAD_TYPE_PATH, + PAYLOAD_TYPE_RESPONSE, + PAYLOAD_TYPE_TXT_MSG, + ROUTE_TYPE_FLOOD, +) def _make_peer_contact(name: str) -> Contact: @@ -305,6 +311,52 @@ async def test_one_node_discovered_event_produces_exactly_one_advert_received(se assert len(advert_received_calls) == 1 assert advert_received_calls[0].name == "SinglePathNode" + async def test_path_packet_updates_contact_path_and_fires_contact_path_updated_once(self): + """PATH packet that decrypts updates contact out_path and fires contact_path_updated.""" + injector = MockPacketInjector() + local_identity = LocalIdentity() + peer_identity = LocalIdentity() + peer_pubkey = peer_identity.get_public_key() + bridge = CompanionBridge(local_identity, injector) + bridge.contacts.add(Contact(public_key=peer_pubkey, name="Peer")) + + path_len_byte = 2 + path_bytes = bytes([0x01, 0x02]) + extra_type = PAYLOAD_TYPE_RESPONSE + extra = bytes([0, 0, 0, 0, 0x00]) + plaintext = bytes([path_len_byte]) + path_bytes + bytes([extra_type]) + extra + peer_id = Identity(peer_pubkey) + shared_secret = peer_id.calc_shared_secret(local_identity.get_private_key()) + aes_key = shared_secret[:16] + encrypted = CryptoUtils.encrypt_then_mac(aes_key, shared_secret, plaintext) + our_hash = local_identity.get_public_key()[0] + src_hash = peer_pubkey[0] + payload = bytes([our_hash, src_hash]) + encrypted + + pkt = Packet() + pkt.header = (ROUTE_TYPE_FLOOD << 0) | (PAYLOAD_TYPE_PATH << 2) + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(payload) + pkt.payload_len = len(payload) + + path_updated_calls = [] + + async def on_path_updated(contact): + path_updated_calls.append(contact) + + bridge.on_contact_path_updated(on_path_updated) + await bridge.process_received_packet(pkt) + + assert len(path_updated_calls) == 1 + assert path_updated_calls[0].public_key == peer_pubkey + assert path_updated_calls[0].out_path_len == path_len_byte + assert path_updated_calls[0].out_path == path_bytes + contact = bridge.contacts.get_by_key(peer_pubkey) + assert contact is not None + assert contact.out_path_len == path_len_byte + assert contact.out_path == path_bytes + async def test_node_discovered_fires_node_discovered_even_when_filtered(self): injector = MockPacketInjector() bridge = CompanionBridge(LocalIdentity(), injector) diff --git a/tests/test_companion_radio.py b/tests/test_companion_radio.py index 6d24ca3..81b6621 100644 --- a/tests/test_companion_radio.py +++ b/tests/test_companion_radio.py @@ -248,6 +248,32 @@ async def test_send_control_data_raw_payload(self): assert result is True assert len(radio.sent) == 1 + async def test_contact_path_updated_fired_when_handler_callback_invoked(self): + """Radio wires protocol_response_handler contact_path_updated to _fire_callbacks.""" + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + path_updated_calls = [] + + async def on_path_updated(contact): + path_updated_calls.append(contact) + + comp.on_contact_path_updated(on_path_updated) + proto = comp.node.dispatcher.protocol_response_handler + assert proto is not None + assert proto._contact_path_updated_callback is not None + + pub = b"\x22" * 32 + path_len = 2 + path_bytes = bytes([0x01, 0x02]) + cb_result = proto._contact_path_updated_callback(pub, path_len, path_bytes) + if hasattr(cb_result, "__await__"): + await cb_result + + assert len(path_updated_calls) == 1 + assert path_updated_calls[0].public_key == pub + assert path_updated_calls[0].out_path_len == path_len + assert path_updated_calls[0].out_path == path_bytes + # --------------------------------------------------------------------------- # Stats and config diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 0ef6662..4c1365f 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -14,7 +14,7 @@ TextMessageHandler, TraceHandler, ) -from pymc_core.protocol import LocalIdentity, Packet, PacketBuilder +from pymc_core.protocol import CryptoUtils, Identity, LocalIdentity, Packet, PacketBuilder from pymc_core.protocol.constants import ( PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, @@ -344,8 +344,66 @@ def test_parse_telemetry_response_rejects_non_telemetry(self): result = self.handler._parse_telemetry_response(data) assert result is None + def test_set_contact_path_updated_callback(self): + """set_contact_path_updated_callback stores the callback.""" + cb = MagicMock() + self.handler.set_contact_path_updated_callback(cb) + assert self.handler._contact_path_updated_callback is cb + self.handler.set_contact_path_updated_callback(None) + assert self.handler._contact_path_updated_callback is None + + @pytest.mark.asyncio + async def test_contact_path_updated_callback_invoked_on_path_update(self): + """PATH decrypts and updates contact path; contact_path_updated callback is invoked.""" + from pymc_core.companion.contact_store import ContactStore + from pymc_core.companion.models import Contact + + local_identity = LocalIdentity() + peer_identity = LocalIdentity() + peer_pubkey = peer_identity.get_public_key() + contacts = ContactStore(5) + contacts.add(Contact(public_key=peer_pubkey, name="Peer")) + log_fn = MagicMock() + handler = ProtocolResponseHandler(log_fn, local_identity, contacts) + handler.set_binary_response_callback(lambda *a, **k: None) + + path_len_byte = 2 + path_bytes = bytes([0x01, 0x02]) + extra_type = PAYLOAD_TYPE_RESPONSE + extra = bytes([0, 0, 0, 0, 0x00]) # tag(4) + 1 byte (not login response) + plaintext = bytes([path_len_byte]) + path_bytes + bytes([extra_type]) + extra + + peer_id = Identity(peer_pubkey) + shared_secret = peer_id.calc_shared_secret(local_identity.get_private_key()) + aes_key = shared_secret[:16] + encrypted = CryptoUtils.encrypt_then_mac(aes_key, shared_secret, plaintext) + + our_hash = local_identity.get_public_key()[0] + src_hash = peer_pubkey[0] + payload = bytes([our_hash, src_hash]) + encrypted + + pkt = Packet() + pkt.header = (0 << 0) | (PAYLOAD_TYPE_PATH << 2) + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(payload) + pkt.payload_len = len(payload) + + callback_calls = [] + + async def on_path_updated(pub: bytes, path_len: int, path_bytes_arg: bytes) -> None: + callback_calls.append((pub, path_len, path_bytes_arg)) + + handler.set_contact_path_updated_callback(on_path_updated) + + await handler(pkt) + + assert len(callback_calls) == 1 + assert callback_calls[0][0] == peer_pubkey + assert callback_calls[0][1] == path_len_byte + assert callback_calls[0][2] == path_bytes + -# Trace Handler Tests class TestTraceHandler: def setup_method(self): self.log_fn = MagicMock() From fa86810515ffc14df165a62293f035eeeae7d4db Mon Sep 17 00:00:00 2001 From: agessaman Date: Sat, 28 Feb 2026 22:30:49 -0800 Subject: [PATCH 36/50] feat(companion): Add callbacks to CompanionRadio. - Add channel_updated and channel_removed callbacks. - Add raw_data_received callback with SNR and RSSI for parity with RX_RAW - Add RX_LOG_DATA for parity with firmware LOG_RX_DATA with data, SNR, and RSSI. --- docs/docs/companion.md | 19 +++++- src/pymc_core/companion/companion_base.py | 34 +++++++++- src/pymc_core/companion/companion_bridge.py | 4 +- src/pymc_core/companion/companion_radio.py | 15 ++++- src/pymc_core/companion/models.py | 2 + tests/test_companion_bridge.py | 69 +++++++++++++++++++++ tests/test_companion_radio.py | 42 ++++++++++++- 7 files changed, 180 insertions(+), 5 deletions(-) diff --git a/docs/docs/companion.md b/docs/docs/companion.md index b7cf888..7e3e8b2 100644 --- a/docs/docs/companion.md +++ b/docs/docs/companion.md @@ -143,6 +143,7 @@ CompanionRadio( max_channels: int = 40, offline_queue_size: int = 512, radio_config: dict | None = None, + initial_contacts: iterable of Contact | None = None, # optional bulk load on boot ) ``` @@ -192,6 +193,15 @@ contact = companion.get_contact_by_name("Alice") companion.add_update_contact(Contact(public_key=key, name="Bob")) companion.remove_contact(pub_key_bytes) +# Populate on boot: pass initial_contacts into the constructor (CompanionRadio or CompanionBridge). +# Replaces the need to call the store after construction. +contacts_from_prefs = [Contact(public_key=k, name=name) for ...] # e.g. from _load_prefs or a file +companion = CompanionRadio(radio, identity, node_name="myNode", initial_contacts=contacts_from_prefs) +await companion.start() + +# If your data is dicts (e.g. JSON/DB), load after construction: +companion.contacts.load_from_dicts([{"public_key": key_hex, "name": "Bob"}, ...]) + # Reset routing path (force re-discovery) companion.reset_path(pub_key_bytes) @@ -272,6 +282,10 @@ companion.set_radio_params(915_000_000, 250_000, 10, 5) # freq, bw, SF, CR companion.set_tx_power(22) # dBm companion.set_tuning_params(rx_delay=0.0, airtime_factor=0.0) +# Fetch current radio configuration (frequency, bandwidth, SF, CR, TX power, tuning) +radio_params = companion.get_radio_params() +# {'frequency_hz': 915000000, 'bandwidth_hz': 250000, 'spreading_factor': 10, ...} + # Time management (transient, not persisted) device_time = companion.get_time() # Unix timestamp ok = companion.set_time(1700000000) # returns False if in the past @@ -414,6 +428,7 @@ CompanionBridge( offline_queue_size: int = 512, radio_config: dict | None = None, authenticate_callback: Callable | None = None, # (hash, pw) -> (bool, int) + initial_contacts: iterable of Contact | None = None, # optional bulk load on boot ) ``` @@ -440,7 +455,7 @@ The bridge registers internal handlers for these payload types: `CompanionBridge` exposes the same messaging, contact, channel, path, signing, stats, and configuration APIs as `CompanionRadio` (inherited from `CompanionBase`). The only behavioral difference is that all TX goes through the `packet_injector` instead of an owned radio. -Note that `set_radio_params()` and `set_tx_power()` update in-memory prefs only — there is no physical radio to configure. This is correct: the repeater host owns the radio hardware. +Note that **CompanionBridge does not own the radio**. `set_radio_params()` and `set_tx_power()` update in-memory prefs only; there is no physical radio to configure. `get_radio_params()` and `get_self_info()` return those in-memory prefs, not the repeater's actual hardware configuration. ### Avoiding doubled messages @@ -824,6 +839,7 @@ Callbacks for `on_message_received` and `on_channel_message_received` receive op | `on_telemetry_response` | `(event_data)` | | `on_status_response` | `(status_data)` | | `on_raw_data_received` | `(raw_data)` | +| `on_rx_log_data` | `(snr: float, rssi: int, raw_bytes: bytes)` — **CompanionRadio only**; same data as PUSH 0x88 LOG_RX_DATA | | `on_binary_response` | `(tag: bytes, data: bytes, parsed: dict\|None, request_type: int\|None)` | | `on_path_discovery_response` | `(tag: bytes, contact_pubkey: bytes, out_path: bytes, in_path: bytes)` | @@ -881,6 +897,7 @@ class NodePrefs: autoadd_config: int = 0 rx_delay_base: float = 0.0 airtime_factor: float = 0.0 + client_repeat: int = 0 # reported in CMD_DEVICE_QUERY device info frame (byte 80) ``` ### SentResult diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index ab1befc..d96ccc6 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -15,7 +15,7 @@ import time from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Any, Callable, Optional +from typing import Any, Callable, Iterable, Optional from ..node.events import EventService, EventSubscriber, MeshEvents from ..protocol import LocalIdentity, Packet, PacketBuilder @@ -81,6 +81,7 @@ "telemetry_response", "status_response", "raw_data_received", + "rx_log_data", # raw RX with SNR/RSSI (CompanionRadio only; matches PUSH 0x88) "binary_response", "path_discovery_response", "contact_deleted", @@ -153,6 +154,7 @@ def _init_companion_stores( max_channels: int = DEFAULT_MAX_CHANNELS, offline_queue_size: int = DEFAULT_OFFLINE_QUEUE_SIZE, radio_config: Optional[dict] = None, + initial_contacts: Optional[Iterable[Contact]] = None, ) -> None: """Initialize shared stores, prefs, event service, and push callbacks.""" self._identity = identity @@ -206,6 +208,10 @@ def _init_companion_stores( # Allow subclasses to restore persisted preferences on startup. self._load_prefs() + # Optional bulk load of contacts (e.g. from persistence on boot). + if initial_contacts is not None: + self.contacts.load_from(initial_contacts) + # ------------------------------------------------------------------------- # Preference Persistence Hooks # ------------------------------------------------------------------------- @@ -365,6 +371,23 @@ def get_tuning_params(self) -> tuple[float, float]: """Return the current (rx_delay, airtime_factor) tuning parameters.""" return (self.prefs.rx_delay_base, self.prefs.airtime_factor) + def get_radio_params(self) -> dict: + """Return current radio configuration (frequency, bandwidth, SF, CR, TX power, tuning). + + Use this to fetch the radio configuration details. Keys match the arguments + to set_radio_params/set_tx_power/set_tuning_params: frequency_hz, bandwidth_hz, + spreading_factor, coding_rate, tx_power_dbm, rx_delay_base, airtime_factor. + """ + return { + "frequency_hz": self.prefs.frequency_hz, + "bandwidth_hz": self.prefs.bandwidth_hz, + "spreading_factor": self.prefs.spreading_factor, + "coding_rate": self.prefs.coding_rate, + "tx_power_dbm": self.prefs.tx_power_dbm, + "rx_delay_base": self.prefs.rx_delay_base, + "airtime_factor": self.prefs.airtime_factor, + } + def get_time(self) -> int: """Return the current device time as a Unix timestamp.""" return int(time.time() + self._time_offset) @@ -696,6 +719,15 @@ def on_status_response(self, callback: Callable) -> None: def on_raw_data_received(self, callback: Callable) -> None: self._push_callbacks["raw_data_received"].append(callback) + def on_rx_log_data(self, callback: Callable) -> None: + """Register callback for raw RX with SNR/RSSI (CompanionRadio only). + + Callback(snr: float, rssi: int, raw_bytes: bytes). Same data as + PUSH_CODE_LOG_RX_DATA (0x88). Only fired when using CompanionRadio; + CompanionBridge does not own the radio. + """ + self._push_callbacks["rx_log_data"].append(callback) + def on_binary_response(self, callback: Callable) -> None: """Register callback for PUSH 0x8C. Callback(tag_bytes, response_data).""" self._push_callbacks["binary_response"].append(callback) diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index ad2fdab..4900552 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -10,7 +10,7 @@ import asyncio import logging -from typing import Any, Callable, Optional +from typing import Any, Callable, Iterable, Optional from ..node.handlers import create_core_handlers from ..node.handlers.login_server import LoginServerHandler @@ -162,6 +162,7 @@ def __init__( offline_queue_size: int = DEFAULT_OFFLINE_QUEUE_SIZE, radio_config: Optional[dict] = None, authenticate_callback: Optional[Callable[..., tuple[bool, int]]] = None, + initial_contacts: Optional[Iterable[Any]] = None, ) -> None: """Initialise the companion bridge.""" self._init_companion_stores( @@ -172,6 +173,7 @@ def __init__( max_channels=max_channels, offline_queue_size=offline_queue_size, radio_config=radio_config, + initial_contacts=initial_contacts, ) self._packet_injector = packet_injector diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py index cc94b64..31bb0df 100644 --- a/src/pymc_core/companion/companion_radio.py +++ b/src/pymc_core/companion/companion_radio.py @@ -11,7 +11,7 @@ import asyncio import logging -from typing import Any, Optional +from typing import Any, Iterable, Optional from ..node.node import MeshNode from ..protocol import LocalIdentity, Packet @@ -66,6 +66,7 @@ def __init__( max_channels: int = DEFAULT_MAX_CHANNELS, offline_queue_size: int = DEFAULT_OFFLINE_QUEUE_SIZE, radio_config: Optional[dict] = None, + initial_contacts: Optional[Iterable[Any]] = None, ) -> None: """Initialise the companion radio.""" self._init_companion_stores( @@ -76,6 +77,7 @@ def __init__( max_channels=max_channels, offline_queue_size=offline_queue_size, radio_config=radio_config, + initial_contacts=initial_contacts, ) self._radio = radio self._dispatcher_task: Optional[asyncio.Task] = None @@ -131,6 +133,10 @@ async def start(self) -> None: async def stop(self) -> None: self._running = False + try: + self.node.dispatcher.remove_raw_packet_subscriber(self._on_raw_packet_rx_log) + except Exception: + pass if self._dispatcher_task: self._dispatcher_task.cancel() try: @@ -241,6 +247,7 @@ def _setup_packet_callbacks(self) -> None: dispatcher.set_packet_received_callback(self._on_packet_received) dispatcher.set_packet_sent_callback(self._on_packet_sent) dispatcher.set_ack_received_listener(self._on_ack_received) + dispatcher.add_raw_packet_subscriber(self._on_raw_packet_rx_log) if ( hasattr(dispatcher, "protocol_response_handler") and dispatcher.protocol_response_handler @@ -257,6 +264,12 @@ async def _on_packet_received(self, pkt: Any) -> None: is_flood = route_type in (ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD) self.stats.record_rx(is_flood=is_flood) + async def _on_raw_packet_rx_log(self, pkt: Any, data: bytes, analysis: Any) -> None: + """Dispatcher raw-packet subscriber: fire rx_log_data(snr, rssi, raw_bytes).""" + snr = getattr(pkt, "snr", getattr(pkt, "_snr", 0.0)) + rssi = getattr(pkt, "rssi", getattr(pkt, "_rssi", 0)) + await self._fire_callbacks("rx_log_data", snr, rssi, data) + async def _on_ack_received(self, crc: int) -> None: """Called by dispatcher when an ACK CRC is received; fire send_confirmed if pending.""" await self._try_confirm_send(crc) diff --git a/src/pymc_core/companion/models.py b/src/pymc_core/companion/models.py index 9279de8..7e39504 100644 --- a/src/pymc_core/companion/models.py +++ b/src/pymc_core/companion/models.py @@ -125,6 +125,8 @@ class NodePrefs: autoadd_config: int = 0 rx_delay_base: float = 0.0 airtime_factor: float = 0.0 + # Reported in CMD_DEVICE_QUERY device info frame (byte 80). + client_repeat: int = 0 @dataclass diff --git a/tests/test_companion_bridge.py b/tests/test_companion_bridge.py index af0e623..be82394 100644 --- a/tests/test_companion_bridge.py +++ b/tests/test_companion_bridge.py @@ -1,5 +1,7 @@ """Tests for CompanionBridge (repeater-integrated companion with packet_injector).""" +import asyncio + import pytest from pymc_core.companion import CompanionBridge @@ -8,6 +10,7 @@ from pymc_core.node.events import MeshEvents from pymc_core.protocol import CryptoUtils, Identity, LocalIdentity, Packet from pymc_core.protocol.constants import ( + PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE, @@ -81,6 +84,42 @@ async def test_start_stop(self): assert bridge.is_running is False +# --------------------------------------------------------------------------- +# Channel updated callback +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeChannelUpdated: + async def test_set_channel_and_remove_channel_fire_channel_updated(self): + """set_channel and remove_channel fire on_channel_updated(idx, channel_or_none).""" + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + events = [] + + def on_channel_updated(idx: int, ch) -> None: + events.append((idx, ch)) + + bridge.on_channel_updated(on_channel_updated) + await bridge.start() + + ok = bridge.set_channel(0, "General", b"secret_________________________") + assert ok is True + await asyncio.sleep(0) + assert len(events) == 1 + assert events[0][0] == 0 + assert events[0][1] is not None + assert events[0][1].name == "General" + + ok = bridge.remove_channel(0) + assert ok is True + await asyncio.sleep(0) + assert len(events) == 2 + assert events[1] == (0, None) + + await bridge.stop() + + # --------------------------------------------------------------------------- # process_received_packet # --------------------------------------------------------------------------- @@ -115,6 +154,36 @@ async def test_process_unknown_type_no_crash(self): await bridge.process_received_packet(pkt) assert True + async def test_process_received_packet_fires_raw_data_received(self): + """CompanionBridge fires on_raw_data_received(raw_bytes, snr, rssi) for each packet.""" + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + raw_calls = [] + + def on_raw(raw: bytes, snr, rssi) -> None: + raw_calls.append((raw, snr, rssi)) + + bridge.on_raw_data_received(on_raw) + await bridge.start() + + pkt = Packet() + pkt.header = (1 << 6) | (PAYLOAD_TYPE_ACK << 2) + pkt.payload = bytearray(b"\x01\x02\x03\x04") + pkt.payload_len = 4 + pkt.path_len = 0 + pkt._snr = 6.0 + pkt._rssi = -75 + + await bridge.process_received_packet(pkt) + await bridge.stop() + + assert len(raw_calls) == 1 + raw_bytes, snr, rssi = raw_calls[0] + expected_raw = pkt.write_to() + assert raw_bytes == expected_raw + assert snr == 6.0 + assert rssi == -75 + # --------------------------------------------------------------------------- # Advertise diff --git a/tests/test_companion_radio.py b/tests/test_companion_radio.py index 81b6621..c8c8035 100644 --- a/tests/test_companion_radio.py +++ b/tests/test_companion_radio.py @@ -5,7 +5,8 @@ from pymc_core.companion import CompanionRadio from pymc_core.companion.constants import ADV_TYPE_CHAT from pymc_core.companion.models import Contact -from pymc_core.protocol import LocalIdentity +from pymc_core.protocol import LocalIdentity, Packet +from pymc_core.protocol.constants import PAYLOAD_TYPE_ACK def _make_peer_contact(name: str) -> Contact: @@ -65,6 +66,16 @@ def test_init_passes_contacts_to_node(self): assert comp.node.contacts is comp.contacts assert comp.node.contacts.get_by_name("Alice") is not None + def test_initial_contacts_populates_store_on_boot(self): + radio = MockRadio() + identity = LocalIdentity() + alice = _make_peer_contact("Alice") + bob = _make_peer_contact("Bob") + comp = CompanionRadio(radio, identity, node_name="TestNode", initial_contacts=[alice, bob]) + assert comp.contacts.get_count() == 2 + assert comp.get_contact_by_name("Alice") is not None + assert comp.get_contact_by_name("Bob") is not None + @pytest.mark.asyncio class TestCompanionRadioLifecycle: @@ -86,6 +97,35 @@ async def test_start_idempotent_warning(self, caplog): await comp.stop() assert "already running" in caplog.text.lower() or True + async def test_rx_log_data_callback_fired_on_raw_packet(self): + """CompanionRadio fires on_rx_log_data(snr, rssi, raw_bytes) for each RX.""" + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + log_calls = [] + + def on_log(snr: float, rssi: int, raw: bytes) -> None: + log_calls.append((snr, rssi, raw)) + + comp.on_rx_log_data(on_log) + await comp.start() + + # Build minimal valid packet (ACK) so dispatcher parses and notifies raw subscribers + pkt = Packet() + pkt.header = (1 << 6) | (PAYLOAD_TYPE_ACK << 2) + pkt.payload = bytearray(b"\x01\x02\x03\x04") + pkt.payload_len = 4 + pkt.path_len = 0 + raw = pkt.write_to() + + await comp.node.dispatcher._process_received_packet(raw, rssi=-75, snr=6.0) + await comp.stop() + + assert len(log_calls) == 1 + snr, rssi, data = log_calls[0] + assert snr == 6.0 + assert rssi == -75 + assert data == raw + # --------------------------------------------------------------------------- # Contact management (base API via radio) From cd3ccf17f788ef62a1abacb793be6d3ea0e025c4 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sun, 1 Mar 2026 11:17:23 -0800 Subject: [PATCH 37/50] feat(companion): implement raw custom data handling in Companion framework - Added CMD_SEND_RAW_DATA command to send raw payloads directly. - Introduced PUSH_CODE_RAW_DATA for receiving raw custom payloads with SNR and RSSI. - Implemented callbacks for raw data reception in CompanionBridge and CompanionRadio. - Enhanced PacketBuilder to create raw custom packets. - Updated tests to cover new raw data functionalities and ensure proper handling. --- docs/docs/companion.md | 12 ++-- src/pymc_core/companion/companion_base.py | 54 +++++++++++++++++- src/pymc_core/companion/companion_bridge.py | 24 ++++++++ src/pymc_core/companion/companion_radio.py | 8 +++ src/pymc_core/companion/frame_server.py | 36 ++++++++++++ src/pymc_core/node/dispatcher.py | 14 +++++ src/pymc_core/protocol/packet_builder.py | 16 ++++++ tests/test_companion_bridge.py | 33 ++++++++--- tests/test_frame_server.py | 62 ++++++++++++++++++++- tests/test_packet_builder.py | 28 +++++++++- 10 files changed, 270 insertions(+), 17 deletions(-) diff --git a/docs/docs/companion.md b/docs/docs/companion.md index 7e3e8b2..21cb97a 100644 --- a/docs/docs/companion.md +++ b/docs/docs/companion.md @@ -573,6 +573,7 @@ The frame server handles the following companion radio protocol commands: | `CMD_SHARE_CONTACT` | 16 | Share a contact to the mesh | | `CMD_EXPORT_CONTACT` | 17 | Export contact as 73-byte blob | | `CMD_IMPORT_CONTACT` | 18 | Import contact from blob | +| `CMD_SEND_RAW_DATA` | 25 | Send raw payload on given direct path | | `CMD_GET_BATT_AND_STORAGE` | 20 | Get battery/storage info | | `CMD_SET_TUNING_PARAMS` | 21 | Set RX delay and airtime factor | | `CMD_DEVICE_QUERY` | 22 | Return device model/version | @@ -602,9 +603,10 @@ The frame server sends unsolicited push frames to the companion app when events | Push Code | Value | Description | |---|---|---| | `PUSH_CODE_ADVERT` | 0x80 | Contact advertisement received | -| `PUSH_CODE_MSG_WAITING` | 0x82 | New message queued | -| `PUSH_CODE_SEND_CONFIRMED` | 0x84 | ACK received for a sent message | -| `PUSH_CODE_PATH_UPDATED` | 0x86 | Contact path updated | +| `PUSH_CODE_MSG_WAITING` | 0x83 | New message queued | +| `PUSH_CODE_SEND_CONFIRMED` | 0x82 | ACK received for a sent message | +| `PUSH_CODE_RAW_DATA` | 0x84 | Raw custom payload received (SNR, RSSI, 0xFF, payload) | +| `PUSH_CODE_PATH_UPDATED` | 0x81 | Contact path updated | | `PUSH_CODE_LOG_RX_DATA` | 0x88 | Raw RX packet (diagnostics) | | `PUSH_CODE_TRACE_DATA` | 0x89 | Trace path response | | `PUSH_CODE_NEW_ADVERT` | 0x8A | New (previously unknown) contact discovered | @@ -838,7 +840,7 @@ Callbacks for `on_message_received` and `on_channel_message_received` receive op | `on_login_result` | `(result_data)` | | `on_telemetry_response` | `(event_data)` | | `on_status_response` | `(status_data)` | -| `on_raw_data_received` | `(raw_data)` | +| `on_raw_data_received` | `(payload: bytes, snr: float, rssi: int)` — raw custom packet received | | `on_rx_log_data` | `(snr: float, rssi: int, raw_bytes: bytes)` — **CompanionRadio only**; same data as PUSH 0x88 LOG_RX_DATA | | `on_binary_response` | `(tag: bytes, data: bytes, parsed: dict\|None, request_type: int\|None)` | | `on_path_discovery_response` | `(tag: bytes, contact_pubkey: bytes, out_path: bytes, in_path: bytes)` | @@ -1026,7 +1028,7 @@ MAX_PATH_SIZE = 64 ## Unimplemented MeshCore Companion Features -The following protocol-level features from the MeshCore companion radio firmware (`examples/companion_radio/`) are **not yet implemented** in pyMC_core: +The following protocol-level features from the MeshCore companion radio firmware (`examples/companion_radio/`) are **not yet implemented** in pyMC_core. CMD_SEND_RAW_DATA (25) and PUSH_CODE_RAW_DATA (0x84) for raw custom packets are implemented. | Feature | Firmware Reference | Description | |---|---|---| diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index d96ccc6..574e9c8 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -26,6 +26,8 @@ ADVERT_FLAG_IS_REPEATER, ADVERT_FLAG_IS_ROOM_SERVER, ADVERT_FLAG_IS_SENSOR, + MAX_PACKET_PAYLOAD, + MAX_PATH_SIZE, PAYLOAD_TYPE_CONTROL, REQ_TYPE_GET_STATUS, REQ_TYPE_GET_TELEMETRY_DATA, @@ -86,6 +88,7 @@ "path_discovery_response", "contact_deleted", "contacts_full", + "channel_updated", ] @@ -457,7 +460,18 @@ def set_channel(self, idx: int, name: str, secret: bytes) -> bool: secret = secret + b"\x00" * (32 - len(secret)) elif len(secret) > 32: secret = secret[:32] - return self.channels.set(idx, Channel(name=name[:32], secret=secret)) + ok = self.channels.set(idx, Channel(name=name[:32], secret=secret)) + if ok: + ch = self.channels.get(idx) + self._schedule_fire_callbacks("channel_updated", idx, ch) + return ok + + def remove_channel(self, idx: int) -> bool: + """Remove the channel at the given index. Fires on_channel_updated(idx, None).""" + ok = self.channels.remove(idx) + if ok: + self._schedule_fire_callbacks("channel_updated", idx, None) + return ok # ------------------------------------------------------------------------- # Signing Pipeline @@ -744,6 +758,10 @@ def on_contacts_full(self, callback: Callable) -> None: """Register callback for PUSH 0x90 (contacts store full). Callback().""" self._push_callbacks["contacts_full"].append(callback) + def on_channel_updated(self, callback: Callable) -> None: + """Register callback for channel set/remove. Callback(idx: int, channel_or_none).""" + self._push_callbacks["channel_updated"].append(callback) + def register_binary_request( self, tag_hex: str, @@ -1201,6 +1219,32 @@ async def send_raw_data( logger.error(f"Error sending raw data: {e}") return SentResult(success=False) + async def send_raw_data_direct(self, path: bytes, payload: bytes) -> SentResult: + """Send a raw custom packet (PAYLOAD_TYPE_RAW_CUSTOM) on the given direct path. + + No encryption or contact lookup; path and payload are supplied by the caller. + Matches firmware CMD_SEND_RAW_DATA behaviour. + """ + if len(payload) < 4: + return SentResult(success=False) + if len(path) > MAX_PATH_SIZE: + return SentResult(success=False) + if len(payload) > MAX_PACKET_PAYLOAD: + return SentResult(success=False) + try: + pkt = PacketBuilder.create_raw_data(payload) + pkt.path = bytearray(path) + pkt.path_len = len(path) + success = await self._send_packet(pkt, wait_for_ack=False) + if success: + self.stats.record_tx(is_flood=False) + else: + self.stats.record_tx_error() + return SentResult(success=success) + except Exception as e: + logger.error(f"Error sending raw data direct: {e}") + return SentResult(success=False) + async def send_trace_path( self, pub_key: bytes, @@ -1690,3 +1734,11 @@ async def _fire_callbacks(self, event_name: str, *args: Any) -> None: callback(*args) except Exception as e: logger.error(f"Error in {event_name} callback: {e}") + + def _schedule_fire_callbacks(self, event_name: str, *args: Any) -> None: + """Schedule _fire_callbacks from sync code (e.g. set_channel). No-op if no running loop.""" + try: + loop = asyncio.get_running_loop() + loop.create_task(self._fire_callbacks(event_name, *args)) + except RuntimeError: + pass diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index 4900552..718d43e 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -22,6 +22,7 @@ PAYLOAD_TYPE_ANON_REQ, PAYLOAD_TYPE_GRP_TXT, PAYLOAD_TYPE_PATH, + PAYLOAD_TYPE_RAW_CUSTOM, PAYLOAD_TYPE_RESPONSE, PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD, @@ -139,6 +140,28 @@ async def _notify_ack_received(self, crc: int) -> None: await self._apply_ack(crc) +# --------------------------------------------------------------------------- +# Raw custom payload handler: fires raw_data_received (PUSH 0x84) +# --------------------------------------------------------------------------- + + +class _RawCustomHandler: + """Handles PAYLOAD_TYPE_RAW_CUSTOM packets; fires raw_data_received(payload, snr, rssi).""" + + def __init__(self, bridge: "CompanionBridge") -> None: + self._bridge = bridge + + @staticmethod + def payload_type() -> int: + return PAYLOAD_TYPE_RAW_CUSTOM + + async def __call__(self, packet: Packet) -> None: + payload_bytes = bytes(packet.payload) if packet.payload else b"" + snr = packet.get_snr() if hasattr(packet, "get_snr") else getattr(packet, "_snr", 0) + rssi = packet.rssi if hasattr(packet, "rssi") else getattr(packet, "_rssi", 0) + await self._bridge._fire_callbacks("raw_data_received", payload_bytes, snr, rssi) + + # --------------------------------------------------------------------------- # Main CompanionBridge class # --------------------------------------------------------------------------- @@ -227,6 +250,7 @@ def _reject_all(*args, **kwargs) -> tuple[bool, int]: PAYLOAD_TYPE_ANON_REQ: login_server_handler, PAYLOAD_TYPE_GRP_TXT: core.group_text_handler, PAYLOAD_TYPE_RESPONSE: core.login_response_handler, + PAYLOAD_TYPE_RAW_CUSTOM: _RawCustomHandler(self), } self._protocol_response_handler = core.protocol_response_handler diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py index 31bb0df..07deab0 100644 --- a/src/pymc_core/companion/companion_radio.py +++ b/src/pymc_core/companion/companion_radio.py @@ -248,6 +248,7 @@ def _setup_packet_callbacks(self) -> None: dispatcher.set_packet_sent_callback(self._on_packet_sent) dispatcher.set_ack_received_listener(self._on_ack_received) dispatcher.add_raw_packet_subscriber(self._on_raw_packet_rx_log) + dispatcher.raw_data_received_callback = self._on_raw_custom_received if ( hasattr(dispatcher, "protocol_response_handler") and dispatcher.protocol_response_handler @@ -274,5 +275,12 @@ async def _on_ack_received(self, crc: int) -> None: """Called by dispatcher when an ACK CRC is received; fire send_confirmed if pending.""" await self._try_confirm_send(crc) + async def _on_raw_custom_received(self, pkt: Packet) -> None: + """Dispatcher RAW_CUSTOM handler: fire raw_data_received(payload, snr, rssi).""" + payload = bytes(pkt.payload) if pkt.payload else b"" + snr = pkt.get_snr() if hasattr(pkt, "get_snr") else getattr(pkt, "_snr", 0) + rssi = pkt.rssi if hasattr(pkt, "rssi") else getattr(pkt, "_rssi", 0) + await self._fire_callbacks("raw_data_received", payload, snr, rssi) + async def _on_packet_sent(self, pkt: Any) -> None: pass diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 8e39094..eea3df8 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -44,6 +44,7 @@ CMD_SEND_CONTROL_DATA, CMD_SEND_LOGIN, CMD_SEND_PATH_DISCOVERY_REQ, + CMD_SEND_RAW_DATA, CMD_SEND_SELF_ADVERT, CMD_SEND_STATUS_REQ, CMD_SEND_TELEMETRY_REQ, @@ -86,6 +87,7 @@ PUSH_CODE_NEW_ADVERT, PUSH_CODE_PATH_DISCOVERY_RESPONSE, PUSH_CODE_PATH_UPDATED, + PUSH_CODE_RAW_DATA, PUSH_CODE_SEND_CONFIRMED, PUSH_CODE_STATUS_RESPONSE, PUSH_CODE_TELEMETRY_RESPONSE, @@ -431,6 +433,19 @@ async def on_contact_deleted(pub_key): async def on_contacts_full(): _write_push(bytes([PUSH_CODE_CONTACTS_FULL])) + async def on_raw_data_received(payload_bytes: bytes, snr: float, rssi: int) -> None: + """Push PUSH_CODE_RAW_DATA (0x84): code, SNR byte, RSSI byte, 0xFF, payload.""" + snr_byte = max(-128, min(127, int(round(snr * 4)))) + rssi_byte = max(-128, min(127, int(rssi))) + payload_len = min(len(payload_bytes), MAX_PAYLOAD_SIZE - 4) + data = ( + bytes([PUSH_CODE_RAW_DATA]) + + struct.pack(" None: await self._cmd_get_autoadd_config(data) elif cmd == CMD_SET_OTHER_PARAMS: await self._cmd_set_other_params(data) + elif cmd == CMD_SEND_RAW_DATA: + await self._cmd_send_raw_data(data) else: logger.warning( "Companion unsupported cmd 0x%02x (%s) len=%s", @@ -1728,3 +1746,21 @@ async def _cmd_set_other_params(self, data: bytes) -> None: multi_acks = data[3] if len(data) >= 4 else 0 self.bridge.set_other_params(manual_add, telemetry_modes, advert_loc_policy, multi_acks) self._write_ok() + + async def _cmd_send_raw_data(self, data: bytes) -> None: + """Handle CMD_SEND_RAW_DATA (25). Format: [path_len][path][payload] (min 4-byte payload).""" + if len(data) < 6: + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + path_len_byte = data[0] + path_len = path_len_byte - 256 if path_len_byte >= 128 else path_len_byte + if path_len < 0 or path_len > MAX_PATH_SIZE or 1 + path_len + 4 > len(data): + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + path = data[1 : 1 + path_len] + payload = data[1 + path_len :] + result = await self.bridge.send_raw_data_direct(path, payload) + if result.success: + self._write_ok() + else: + self._write_err(ERR_CODE_TABLE_FULL) diff --git a/src/pymc_core/node/dispatcher.py b/src/pymc_core/node/dispatcher.py index 4864227..61b494e 100644 --- a/src/pymc_core/node/dispatcher.py +++ b/src/pymc_core/node/dispatcher.py @@ -69,6 +69,9 @@ def __init__( # Optional listener for ACK received (e.g. companion send_confirmed) self._ack_received_listener: Optional[Callable[[int], Awaitable[None] | None]] = None + # Optional callback for PAYLOAD_TYPE_RAW_CUSTOM (companion raw_data_received) + self.raw_data_received_callback: Optional[Callable[[Packet], Awaitable[None]]] = None + # Raw packet callbacks: single callback (legacy) and list of subscribers (after parse) self.raw_packet_callback: Optional[Callable[[Packet, bytes], Awaitable[None] | None]] = None self._raw_packet_subscribers: List[Callable[..., Any]] = [] @@ -217,6 +220,17 @@ def register_default_handlers( self.register_handler(ControlHandler.payload_type(), control_handler) self.control_handler = control_handler + # --- RAW_CUSTOM handler: deliver to companion if direct and callback set --- + from ..protocol.constants import PAYLOAD_TYPE_RAW_CUSTOM + + async def raw_custom_handler(pkt: Packet) -> None: + if not pkt.is_route_direct(): + return + if self.raw_data_received_callback: + await self._invoke_callback(self.raw_data_received_callback, pkt) + + self.register_handler(PAYLOAD_TYPE_RAW_CUSTOM, raw_custom_handler) + self._logger.info("Default handlers registered.") # Set up a fallback handler for unknown packet types diff --git a/src/pymc_core/protocol/packet_builder.py b/src/pymc_core/protocol/packet_builder.py index ab22836..0f0c3e5 100644 --- a/src/pymc_core/protocol/packet_builder.py +++ b/src/pymc_core/protocol/packet_builder.py @@ -23,6 +23,7 @@ PAYLOAD_TYPE_GRP_DATA, PAYLOAD_TYPE_GRP_TXT, PAYLOAD_TYPE_PATH, + PAYLOAD_TYPE_RAW_CUSTOM, PAYLOAD_TYPE_REQ, PAYLOAD_TYPE_RESPONSE, PAYLOAD_TYPE_TRACE, @@ -643,6 +644,21 @@ def create_trace( pkt.payload_len = len(payload) return pkt + @staticmethod + def create_raw_data(data: bytes) -> Packet: + """ + Create a raw custom packet (PAYLOAD_TYPE_RAW_CUSTOM) with no encryption. + + Route type is always DIRECT (consistent with firmware CMD_SEND_RAW_DATA). + Caller must set pkt.path and pkt.path_len for direct routing. + """ + if len(data) > MAX_PACKET_PAYLOAD: + raise ValueError( + f"Raw data length {len(data)} exceeds MAX_PACKET_PAYLOAD ({MAX_PACKET_PAYLOAD})" + ) + header = PacketBuilder._create_header(PAYLOAD_TYPE_RAW_CUSTOM, route_type="direct") + return PacketBuilder._create_packet(header, data) + @staticmethod def create_path_return( dest_hash: int, diff --git a/tests/test_companion_bridge.py b/tests/test_companion_bridge.py index be82394..5263b9b 100644 --- a/tests/test_companion_bridge.py +++ b/tests/test_companion_bridge.py @@ -10,9 +10,9 @@ from pymc_core.node.events import MeshEvents from pymc_core.protocol import CryptoUtils, Identity, LocalIdentity, Packet from pymc_core.protocol.constants import ( - PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, PAYLOAD_TYPE_PATH, + PAYLOAD_TYPE_RAW_CUSTOM, PAYLOAD_TYPE_RESPONSE, PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD, @@ -155,19 +155,19 @@ async def test_process_unknown_type_no_crash(self): assert True async def test_process_received_packet_fires_raw_data_received(self): - """CompanionBridge fires on_raw_data_received(raw_bytes, snr, rssi) for each packet.""" + """CompanionBridge fires on_raw_data_received(payload, snr, rssi) for RAW_CUSTOM packets.""" injector = MockPacketInjector() bridge = CompanionBridge(LocalIdentity(), injector) raw_calls = [] - def on_raw(raw: bytes, snr, rssi) -> None: - raw_calls.append((raw, snr, rssi)) + def on_raw(payload: bytes, snr, rssi) -> None: + raw_calls.append((payload, snr, rssi)) bridge.on_raw_data_received(on_raw) await bridge.start() pkt = Packet() - pkt.header = (1 << 6) | (PAYLOAD_TYPE_ACK << 2) + pkt.header = (1 << 6) | (PAYLOAD_TYPE_RAW_CUSTOM << 2) pkt.payload = bytearray(b"\x01\x02\x03\x04") pkt.payload_len = 4 pkt.path_len = 0 @@ -178,9 +178,8 @@ def on_raw(raw: bytes, snr, rssi) -> None: await bridge.stop() assert len(raw_calls) == 1 - raw_bytes, snr, rssi = raw_calls[0] - expected_raw = pkt.write_to() - assert raw_bytes == expected_raw + payload_bytes, snr, rssi = raw_calls[0] + assert payload_bytes == b"\x01\x02\x03\x04" assert snr == 6.0 assert rssi == -75 @@ -245,6 +244,24 @@ async def test_share_contact_success(self): assert result is True assert len(injector.calls) == 1 + async def test_send_raw_data_direct_injects_packet(self): + """send_raw_data_direct builds RAW_CUSTOM packet and sends via injector.""" + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + await bridge.start() + path = b"\x42" + payload = b"\x01\x02\x03\x04" + result = await bridge.send_raw_data_direct(path, payload) + await bridge.stop() + assert result.success is True + assert len(injector.calls) == 1 + pkt, wait_for_ack = injector.calls[0] + assert (pkt.header >> 2) & 0x0F == PAYLOAD_TYPE_RAW_CUSTOM + assert pkt.path == bytearray(path) + assert pkt.path_len == len(path) + assert bytes(pkt.payload) == payload + assert wait_for_ack is False + # --------------------------------------------------------------------------- # Path discovery, trace, control data diff --git a/tests/test_frame_server.py b/tests/test_frame_server.py index b62ccdc..8db6573 100644 --- a/tests/test_frame_server.py +++ b/tests/test_frame_server.py @@ -1,15 +1,20 @@ """Tests for CompanionFrameServer and advert push frame construction.""" import struct +from unittest.mock import Mock + +import pytest from pymc_core.companion.constants import ( + ERR_CODE_TABLE_FULL, + ERR_CODE_UNSUPPORTED_CMD, MAX_PATH_SIZE, PUB_KEY_SIZE, PUSH_CODE_ADVERT, PUSH_CODE_NEW_ADVERT, ) -from pymc_core.companion.frame_server import _build_advert_push_frames -from pymc_core.companion.models import Contact +from pymc_core.companion.frame_server import CompanionFrameServer, _build_advert_push_frames +from pymc_core.companion.models import Contact, SentResult def test_build_advert_push_frames_short_only_when_no_name(): @@ -97,3 +102,56 @@ def test_build_advert_push_frames_name_truncated_to_32_bytes(): name_slice = full[36 + MAX_PATH_SIZE : 36 + MAX_PATH_SIZE + 32] assert len(name_slice) == 32 assert name_slice == b"A" * 32 + + +class _MockBridgeSendRawDirect: + """Minimal bridge for CMD_SEND_RAW_DATA tests.""" + + def __init__(self, success: bool = True): + self.calls = [] + self._success = success + + async def send_raw_data_direct(self, path: bytes, payload: bytes): + self.calls.append((path, payload)) + return SentResult(success=self._success) + + +@pytest.mark.asyncio +async def test_cmd_send_raw_data_valid_writes_ok(): + """Valid CMD_SEND_RAW_DATA -> _write_ok.""" + bridge = _MockBridgeSendRawDirect(success=True) + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + data = bytes([1, 0x42]) + b"\x01\x02\x03\x04" + await server._cmd_send_raw_data(data) + assert bridge.calls == [(b"\x42", b"\x01\x02\x03\x04")] + server._write_ok.assert_called_once() + server._write_err.assert_not_called() + + +@pytest.mark.asyncio +async def test_cmd_send_raw_data_invalid_len_writes_unsupported(): + """Invalid CMD_SEND_RAW_DATA len < 6 -> ERR_CODE_UNSUPPORTED_CMD.""" + bridge = _MockBridgeSendRawDirect() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + await server._cmd_send_raw_data(b"\x00\x00\x00") + assert len(bridge.calls) == 0 + server._write_err.assert_called_once_with(ERR_CODE_UNSUPPORTED_CMD) + server._write_ok.assert_not_called() + + +@pytest.mark.asyncio +async def test_cmd_send_raw_data_send_failure_writes_table_full(): + """send_raw_data_direct returns False -> ERR_CODE_TABLE_FULL.""" + bridge = _MockBridgeSendRawDirect(success=False) + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + data = bytes([1, 0x42]) + b"\x01\x02\x03\x04" + await server._cmd_send_raw_data(data) + assert len(bridge.calls) == 1 + server._write_err.assert_called_once_with(ERR_CODE_TABLE_FULL) + server._write_ok.assert_not_called() diff --git a/tests/test_packet_builder.py b/tests/test_packet_builder.py index 58f453f..fc60a0b 100644 --- a/tests/test_packet_builder.py +++ b/tests/test_packet_builder.py @@ -1,5 +1,10 @@ from pymc_core import LocalIdentity -from pymc_core.protocol.constants import PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT +from pymc_core.protocol.constants import ( + MAX_PACKET_PAYLOAD, + PAYLOAD_TYPE_ACK, + PAYLOAD_TYPE_ADVERT, + PAYLOAD_TYPE_RAW_CUSTOM, +) from pymc_core.protocol.packet_builder import PacketBuilder @@ -51,3 +56,24 @@ def test_packet_builder_create_direct_advert(): assert direct_advert is not None assert direct_advert.get_payload_type() == PAYLOAD_TYPE_ADVERT + + +def test_packet_builder_create_raw_data(): + """Test creating raw custom packets (PAYLOAD_TYPE_RAW_CUSTOM).""" + data = b"\x01\x02\x03\x04" + pkt = PacketBuilder.create_raw_data(data) + assert pkt is not None + assert pkt.get_payload_type() == PAYLOAD_TYPE_RAW_CUSTOM + assert pkt.payload == bytearray(data) + assert pkt.payload_len == len(data) + assert pkt.path_len == 0 + assert pkt.path == bytearray() + + +def test_packet_builder_create_raw_data_too_large_raises(): + """Test create_raw_data raises when data exceeds MAX_PACKET_PAYLOAD.""" + import pytest + + data = bytes(MAX_PACKET_PAYLOAD + 1) + with pytest.raises(ValueError, match="exceeds MAX_PACKET_PAYLOAD"): + PacketBuilder.create_raw_data(data) From 14d026e9025af8b0e79102c1dfeeab926fbe0eea Mon Sep 17 00:00:00 2001 From: agessaman Date: Sun, 1 Mar 2026 12:14:13 -0800 Subject: [PATCH 38/50] feat(companion): add private key export/import commands to Companion framework - Introduced CMD_EXPORT_PRIVATE_KEY and CMD_IMPORT_PRIVATE_KEY commands for handling private key operations. - Implemented _cmd_export_private_key to export the private/signing key in 64-byte MeshCore format. - Added a stub for _cmd_import_private_key, which currently does not perform any operations but may support dynamic imports in the future. - Updated CryptoUtils with a method to expand a 32-byte Ed25519 seed to the 64-byte MeshCore format. - Enhanced tests to verify the new private key functionalities and ensure compatibility with MeshCore key generation. --- docs/docs/companion.md | 2 + src/pymc_core/companion/frame_server.py | 28 +++++++++++ src/pymc_core/protocol/crypto.py | 8 +++ tests/test_crypto.py | 66 +++++++++++++++++++++++++ 4 files changed, 104 insertions(+) diff --git a/docs/docs/companion.md b/docs/docs/companion.md index 21cb97a..0c890c2 100644 --- a/docs/docs/companion.md +++ b/docs/docs/companion.md @@ -573,6 +573,8 @@ The frame server handles the following companion radio protocol commands: | `CMD_SHARE_CONTACT` | 16 | Share a contact to the mesh | | `CMD_EXPORT_CONTACT` | 17 | Export contact as 73-byte blob | | `CMD_IMPORT_CONTACT` | 18 | Import contact from blob | +| `CMD_EXPORT_PRIVATE_KEY` | 23 | Export private/signing key (64-byte MeshCore format) | +| `CMD_IMPORT_PRIVATE_KEY` | 24 | Import private key (stub/no-op; key set from config) | | `CMD_SEND_RAW_DATA` | 25 | Send raw payload on given direct path | | `CMD_GET_BATT_AND_STORAGE` | 20 | Get battery/storage info | | `CMD_SET_TUNING_PARAMS` | 21 | Set RX delay and airtime factor | diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index eea3df8..2098f46 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -19,12 +19,14 @@ import time from typing import Any, Callable, Optional +from ..protocol import CryptoUtils from .constants import ( ADV_TYPE_CHAT, CMD_ADD_UPDATE_CONTACT, CMD_APP_START, CMD_DEVICE_QUERY, CMD_EXPORT_CONTACT, + CMD_EXPORT_PRIVATE_KEY, CMD_GET_ADVERT_PATH, CMD_GET_AUTOADD_CONFIG, CMD_GET_BATT_AND_STORAGE, @@ -35,6 +37,7 @@ CMD_GET_DEVICE_TIME, CMD_GET_STATS, CMD_IMPORT_CONTACT, + CMD_IMPORT_PRIVATE_KEY, CMD_LOGOUT, CMD_REMOVE_CONTACT, CMD_RESET_PATH, @@ -110,6 +113,7 @@ RESP_CODE_EXPORT_CONTACT, RESP_CODE_NO_MORE_MESSAGES, RESP_CODE_OK, + RESP_CODE_PRIVATE_KEY, RESP_CODE_SELF_INFO, RESP_CODE_SENT, RESP_CODE_STATS, @@ -802,6 +806,10 @@ async def _handle_cmd(self, payload: bytes) -> None: await self._cmd_share_contact(data) elif cmd == CMD_EXPORT_CONTACT: await self._cmd_export_contact(data) + elif cmd == CMD_EXPORT_PRIVATE_KEY: + await self._cmd_export_private_key(data) + elif cmd == CMD_IMPORT_PRIVATE_KEY: + await self._cmd_import_private_key(data) elif cmd == CMD_SET_TUNING_PARAMS: await self._cmd_set_tuning_params(data) elif cmd == CMD_LOGOUT: @@ -1687,6 +1695,26 @@ async def _cmd_export_contact(self, data: bytes) -> None: return self._write_frame(bytes([RESP_CODE_EXPORT_CONTACT]) + raw) + async def _cmd_export_private_key(self, data: bytes) -> None: + """Export private/signing key as 64-byte MeshCore format (RESP_CODE_PRIVATE_KEY + 64 bytes). + + For PyNaCl 32-byte seeds we expand to MeshCore 64-byte format (SHA-512 + clamp) so + the client's ed25519_derive_pub yields the same public key and signing works. + """ + identity = self.bridge._identity + key_bytes = identity.get_signing_key_bytes() + if len(key_bytes) == 32: + key_bytes = CryptoUtils.ed25519_expand_seed_to_meshcore_64(key_bytes) + elif len(key_bytes) < 64: + key_bytes = key_bytes.ljust(64, b"\x00") + else: + key_bytes = key_bytes[:64] + self._write_frame(bytes([RESP_CODE_PRIVATE_KEY]) + key_bytes) + + async def _cmd_import_private_key(self, data: bytes) -> None: + """Stub/no-op: private key is set from config; dynamic import may be supported later.""" + self._write_ok() + async def _cmd_set_tuning_params(self, data: bytes) -> None: if len(data) < 8: self._write_err(ERR_CODE_ILLEGAL_ARG) diff --git a/src/pymc_core/protocol/crypto.py b/src/pymc_core/protocol/crypto.py index 06d0075..8cd323e 100644 --- a/src/pymc_core/protocol/crypto.py +++ b/src/pymc_core/protocol/crypto.py @@ -87,6 +87,14 @@ def x25519_clamp_scalar(scalar: bytes) -> bytes: s[31] |= 64 return bytes(s) + @staticmethod + def ed25519_expand_seed_to_meshcore_64(seed: bytes) -> bytes: + """Expand a 32-byte Ed25519 seed to 64-byte MeshCore/orlp format.""" + if len(seed) != 32: + raise ValueError("seed must be 32 bytes") + h = hashlib.sha512(seed).digest() + return CryptoUtils.x25519_clamp_scalar(h[:32]) + h[32:64] + @staticmethod def scalarmult(private_key: bytes, public_key: bytes) -> bytes: """ECDH shared secret calculation (X25519).""" diff --git a/tests/test_crypto.py b/tests/test_crypto.py index 7a8d1fc..6a5a701 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -56,3 +56,69 @@ def test_crypto_utils_key_exchange(): # Both should compute the same shared secret assert alice_shared == bob_shared assert len(alice_shared) == 32 + + +def test_ed25519_expand_seed_to_meshcore_64_same_public_key(): + """ed25519_expand_seed_to_meshcore_64 produces 64-byte key that derives to + same public key as PyNaCl. + + MeshCore firmware (and meshcore-keygen) derives the public key by taking + the first 32 bytes of the 64-byte private key and calling + ed25519_derive_pub / crypto_scalarmult_ed25519_base_noclamp. + So our expanded form must yield the same first-32-bytes scalar so the + client gets the same pub key. + """ + from nacl.bindings import crypto_scalarmult_ed25519_base_noclamp + + from pymc_core import LocalIdentity + + # Fixed seed for reproducible test + seed = bytes.fromhex("70" + "65e18fd9fabb70c1ed90dca19907de" "698c88b709ea146eafd93d9b830c7b60") + assert len(seed) == 32 + + # Expand to 64-byte MeshCore format (SHA-512 + clamp) + expanded = CryptoUtils.ed25519_expand_seed_to_meshcore_64(seed) + assert len(expanded) == 64 + + # Derive public key the way MeshCore firmware/keygen does: first 32 bytes = + # scalar + pub_from_expanded = crypto_scalarmult_ed25519_base_noclamp(expanded[:32]) + + # PyNaCl LocalIdentity(seed=seed) derives pub via the same Ed25519 expansion internally + identity = LocalIdentity(seed=seed) + pub_from_pynacl = identity.get_public_key() + + assert pub_from_expanded == pub_from_pynacl, ( + "Expanded seed must yield same public key as PyNaCl so MeshCore " "clients derive correctly" + ) + + +# Sample keypair generated by meshcore-keygen (MeshCore firmware–compatible +# format). Used to verify our derivation matches: +# pub = crypto_scalarmult_ed25519_base_noclamp(private_64[:32]). +MESHCORE_SAMPLE_PRIVATE_HEX = ( + "a8cdb06cef221ef0ee4e5e9d9d0829499cd304ca1bba47af2ea17d83b316726b" + "e232b1da0388a2eb142d9a16ed66d4994a7ac40339ec1fbc3937f3b81f5dc62b" +) +MESHCORE_SAMPLE_PUBLIC_HEX = "8dfd59f1102e74845758096baf54d2b3a91413813e7c239baf75d1e9a0f04bf4" + + +def test_ed25519_expand_seed_matches_meshcore_keygen_format(): + """Deriving pub from first 32 bytes of a meshcore-keygen 64-byte private + key matches expected pub. + + Uses a sample keypair generated by meshcore-keygen to verify our + derivation (crypto_scalarmult_ed25519_base_noclamp(priv_64[:32])) + matches the firmware/keygen format. + """ + from nacl.bindings import crypto_scalarmult_ed25519_base_noclamp + + private_bytes = bytes.fromhex(MESHCORE_SAMPLE_PRIVATE_HEX) + expected_public = bytes.fromhex(MESHCORE_SAMPLE_PUBLIC_HEX) + assert len(private_bytes) == 64 + assert len(expected_public) == 32 + + derived_public = crypto_scalarmult_ed25519_base_noclamp(private_bytes[:32]) + assert derived_public == expected_public, ( + "MeshCore format: pub must be derived from first 32 bytes of " "64-byte private key" + ) From 2a4c4e5f690107f9d4dc4ba39ead688045462960 Mon Sep 17 00:00:00 2001 From: agessaman Date: Sun, 1 Mar 2026 14:34:43 -0800 Subject: [PATCH 39/50] feat(crypto): add method to derive Ed25519 public key from MeshCore private key - Introduced ed25519_public_from_meshcore_64 method in CryptoUtils to derive an Ed25519 public key from a 64-byte MeshCore private key. - Added validation to ensure the input private key is 64 bytes in length. - Integrated new functionality to enhance key management capabilities within the crypto module. --- src/pymc_core/protocol/crypto.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/pymc_core/protocol/crypto.py b/src/pymc_core/protocol/crypto.py index 8cd323e..c2bb516 100644 --- a/src/pymc_core/protocol/crypto.py +++ b/src/pymc_core/protocol/crypto.py @@ -5,6 +5,7 @@ from nacl.bindings import ( crypto_scalarmult, crypto_scalarmult_base, + crypto_scalarmult_ed25519_base_noclamp, crypto_sign_ed25519_pk_to_curve25519, crypto_sign_ed25519_sk_to_curve25519, ) @@ -95,6 +96,13 @@ def ed25519_expand_seed_to_meshcore_64(seed: bytes) -> bytes: h = hashlib.sha512(seed).digest() return CryptoUtils.x25519_clamp_scalar(h[:32]) + h[32:64] + @staticmethod + def ed25519_public_from_meshcore_64(meshcore_private_64: bytes) -> bytes: + """Derive Ed25519 public key from 64-byte MeshCore private key format.""" + if len(meshcore_private_64) != 64: + raise ValueError("meshcore_private_64 must be 64 bytes") + return crypto_scalarmult_ed25519_base_noclamp(meshcore_private_64[:32]) + @staticmethod def scalarmult(private_key: bytes, public_key: bytes) -> bytes: """ECDH shared secret calculation (X25519).""" From 2e3c2d92ef063109e2aeb2a75c0c53f1a22c9c36 Mon Sep 17 00:00:00 2001 From: agessaman Date: Mon, 2 Mar 2026 06:06:57 -0800 Subject: [PATCH 40/50] feat(companion): refactor push methods for async handling and backpressure - Updated `push_trace_data` and `push_rx_raw` methods to be asynchronous, allowing for proper backpressure handling. - Introduced `push_rx_raw_async` for async usage, ensuring that raw RX packets can be pushed with backpressure. - Modified internal `_write_push` method to await the drain process, preventing concurrent drain tasks and improving reliability. - Enhanced documentation and tests to reflect the new async behavior and ensure correct functionality in various contexts. --- docs/docs/companion.md | 8 ++- src/pymc_core/companion/frame_server.py | 54 +++++++++++------- tests/test_frame_server.py | 75 ++++++++++++++++++++++++- 3 files changed, 114 insertions(+), 23 deletions(-) diff --git a/docs/docs/companion.md b/docs/docs/companion.md index 0c890c2..6cadf40 100644 --- a/docs/docs/companion.md +++ b/docs/docs/companion.md @@ -625,14 +625,16 @@ The frame server sends unsolicited push frames to the companion app when events The frame server exposes methods for the host application to push data to the connected companion app: ```python -# Push trace data from the repeater -server.push_trace_data( +# Push trace data from the repeater (await for backpressure) +await server.push_trace_data( path_len=3, flags=0, tag=42, auth_code=0, path_hashes=b"...", path_snrs=b"...", final_snr_byte=0 ) -# Push raw RX packet for diagnostics logging +# Push raw RX packet for diagnostics logging (sync: schedules send, works without await) server.push_rx_raw(snr=-5.0, rssi=-100, raw=b"...") +# Or from async code with backpressure: +await server.push_rx_raw_async(snr=-5.0, rssi=-100, raw=b"...") # Push control data await server.push_control_data( diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 2098f46..ff0cde0 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -289,7 +289,7 @@ def _get_batt_and_storage(self) -> tuple[int, int, int]: def _setup_push_callbacks(self) -> None: """Subscribe to bridge events and send PUSH frames to connected client.""" - def _write_push(data: bytes) -> None: + async def _write_push(data: bytes) -> None: if self._client_writer and not self._client_writer.is_closing(): try: if len(data) > MAX_PAYLOAD_SIZE: @@ -301,7 +301,7 @@ def _write_push(data: bytes) -> None: data = data[:MAX_PAYLOAD_SIZE] frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack("= 32 ): - _write_push(bytes([PUSH_CODE_PATH_UPDATED]) + contact.public_key[:32]) + await _write_push(bytes([PUSH_CODE_PATH_UPDATED]) + contact.public_key[:32]) try: if contact is not None: await self._persist_contact(contact) @@ -402,7 +401,7 @@ async def on_channel_message_received( "rssi": rssi, } await self._persist_companion_message(msg_dict) - _write_push(bytes([PUSH_CODE_MSG_WAITING])) + await _write_push(bytes([PUSH_CODE_MSG_WAITING])) async def on_binary_response(tag_bytes, response_data, parsed=None, request_type=None): frame = ( @@ -410,7 +409,7 @@ async def on_binary_response(tag_bytes, response_data, parsed=None, request_type + (tag_bytes if isinstance(tag_bytes, bytes) else struct.pack("= 32: - _write_push(bytes([PUSH_CODE_CONTACT_DELETED]) + pub_key[:32]) + await _write_push(bytes([PUSH_CODE_CONTACT_DELETED]) + pub_key[:32]) async def on_contacts_full(): - _write_push(bytes([PUSH_CODE_CONTACTS_FULL])) + await _write_push(bytes([PUSH_CODE_CONTACTS_FULL])) async def on_raw_data_received(payload_bytes: bytes, snr: float, rssi: int) -> None: """Push PUSH_CODE_RAW_DATA (0x84): code, SNR byte, RSSI byte, 0xFF, payload.""" @@ -448,7 +447,7 @@ async def on_raw_data_received(payload_bytes: bytes, snr: float, rssi: int) -> N + bytes([0xFF]) + payload_bytes[:payload_len] ) - _write_push(data) + await _write_push(data) self.bridge.on_message_received(on_message_received) self.bridge.on_channel_message_received(on_channel_message_received) @@ -465,7 +464,7 @@ async def on_raw_data_received(payload_bytes: bytes, snr: float, rssi: int) -> N # Public push methods (called directly by host application) # ------------------------------------------------------------------------- - def push_trace_data( + async def push_trace_data( self, path_len: int, flags: int, @@ -498,12 +497,25 @@ def push_trace_data( try: frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: - """Push raw RX packet to client (PUSH_CODE_LOG_RX_DATA 0x88).""" + """Push raw RX packet to client (PUSH_CODE_LOG_RX_DATA 0x88). + + Schedules the push on the event loop so it works when called from sync + context (e.g. repeater RX callback). For backpressure from async code, + use ``await self.push_rx_raw_async(snr, rssi, raw)`` instead. + """ + try: + loop = asyncio.get_running_loop() + loop.create_task(self._push_rx_raw_impl(snr, rssi, raw)) + except RuntimeError: + pass # no running loop + + async def _push_rx_raw_impl(self, snr: float, rssi: int, raw: bytes) -> None: + """Implementation of push_rx_raw; write + drain for backpressure.""" if not self._client_writer or self._client_writer.is_closing(): return snr_byte = max(-128, min(127, int(round(snr * 4)))) @@ -517,10 +529,14 @@ def push_rx_raw(self, snr: float, rssi: int, raw: bytes) -> None: try: frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: + """Push raw RX packet to client with backpressure (await in async code).""" + await self._push_rx_raw_impl(snr, rssi, raw) + async def push_control_data( self, snr: float, @@ -1186,7 +1202,7 @@ async def _cmd_send_trace_path(self, data: bytes) -> None: snr_len = path_len >> path_sz path_snrs = bytes(snr_len) final_snr_byte = 0 - self.push_trace_data( + await self.push_trace_data( path_len, flags, tag, diff --git a/tests/test_frame_server.py b/tests/test_frame_server.py index 8db6573..5811cb7 100644 --- a/tests/test_frame_server.py +++ b/tests/test_frame_server.py @@ -1,7 +1,8 @@ """Tests for CompanionFrameServer and advert push frame construction.""" +import asyncio import struct -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pytest @@ -155,3 +156,75 @@ async def test_cmd_send_raw_data_send_failure_writes_table_full(): assert len(bridge.calls) == 1 server._write_err.assert_called_once_with(ERR_CODE_TABLE_FULL) server._write_ok.assert_not_called() + + +@pytest.mark.asyncio +async def test_push_trace_data_and_push_rx_raw_are_async_and_await_drain(): + """push_trace_data and push_rx_raw_async await drain for backpressure.""" + bridge = _MockBridgeSendRawDirect() + server = CompanionFrameServer(bridge, "hash", port=0) + writer = Mock() + writer.write = Mock() + writer.is_closing = Mock(return_value=False) + writer.drain = AsyncMock(return_value=None) + server._client_writer = writer + + await server.push_trace_data( + path_len=1, + flags=0, + tag=1, + auth_code=0, + path_hashes=b"\x00", + path_snrs=b"\x00", + final_snr_byte=0, + ) + writer.write.assert_called_once() + writer.drain.assert_awaited_once() + + writer.reset_mock() + writer.drain = AsyncMock(return_value=None) + + await server.push_rx_raw_async(snr=-5.0, rssi=-100, raw=b"abc") + writer.write.assert_called_once() + writer.drain.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_push_burst_serialized_by_drain(): + """Multiple pushes in succession each await drain; no concurrent drain tasks.""" + bridge = _MockBridgeSendRawDirect() + server = CompanionFrameServer(bridge, "hash", port=0) + drain_calls = [] + + async def track_drain(): + drain_calls.append(1) + await asyncio.sleep(0) + + writer = Mock() + writer.write = Mock() + writer.is_closing = Mock(return_value=False) + writer.drain = AsyncMock(side_effect=track_drain) + server._client_writer = writer + + for i in range(5): + await server.push_rx_raw_async(snr=0.0, rssi=-80, raw=bytes([i])) + + assert len(drain_calls) == 5 + assert writer.write.call_count == 5 + + +@pytest.mark.asyncio +async def test_push_rx_raw_sync_schedules_push(): + """Sync push_rx_raw() schedules push so callers without await still send.""" + bridge = _MockBridgeSendRawDirect() + server = CompanionFrameServer(bridge, "hash", port=0) + writer = Mock() + writer.write = Mock() + writer.is_closing = Mock(return_value=False) + writer.drain = AsyncMock(return_value=None) + server._client_writer = writer + + server.push_rx_raw(snr=-5.0, rssi=-100, raw=b"abc") # sync call, no await + await asyncio.sleep(0) # let the scheduled task run + writer.write.assert_called_once() + writer.drain.assert_awaited_once() From fda7506ba55c00b8eaef46e00fabdeb580162b25 Mon Sep 17 00:00:00 2001 From: agessaman Date: Mon, 2 Mar 2026 14:32:26 -0800 Subject: [PATCH 41/50] fix: guard PATH updates for non-contacts and harden TCP drain - Skip _on_contact_path_updated for contacts not in the store, matching companion firmware behavior. Also persist the updated path to the in-memory ContactStore before firing callbacks. - Add defense-in-depth check in frame_server's on_contact_path_updated push callback to avoid pushing or persisting phantom contacts. - Make _drain_writer return bool; break the client loop immediately on drain failure instead of waiting for the next read to discover a dead connection. - Drain every 10 frames in _cmd_get_contacts to avoid unbounded write-buffer growth for large contact lists. - Extract _build_message_frame helper shared by base and repeater, eliminating ~60 lines of duplicated message encoding and fixing the repeater's hard-coded SNR=0 for V3 frames. - Remove unnecessary hasattr checks on Contact dataclass fields, redundant fallback loop in _cmd_send_txt_msg, and fix get_full_list evaluation order in _cmd_get_channel. - Simplify on_advert_received dict/object handling; warn on unexpected non-Contact input instead of silently converting. --- src/pymc_core/companion/companion_base.py | 23 +- src/pymc_core/companion/frame_server.py | 297 +++++++++++----------- tests/test_companion_radio.py | 4 + 3 files changed, 163 insertions(+), 161 deletions(-) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 574e9c8..7286d14 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -687,6 +687,15 @@ async def _apply_advert_to_stores( # Push Callbacks # ------------------------------------------------------------------------- + def clear_push_callbacks(self) -> None: + """Remove all registered push callbacks. + + Called by FrameServer between client connections so that stale + closures from a previous connection are not invoked on the next one. + """ + for key in self._push_callbacks: + self._push_callbacks[key].clear() + def on_message_received(self, callback: Callable) -> None: self._push_callbacks["message_received"].append(callback) @@ -701,15 +710,17 @@ def on_contact_path_updated(self, callback: Callable) -> None: async def _on_contact_path_updated(self, pub: bytes, path_len: int, path_bytes: bytes) -> None: """Called by ProtocolResponseHandler when contact's out_path is updated from a PATH packet. - Converts (pub, path_len, path_bytes) to a Contact and fires user callbacks with (contact). + + Matches companion firmware behaviour: PATH updates are only applied + (and pushed to the client) for contacts that already exist in the + store. Unknown public keys are silently ignored. """ contact = self.get_contact_by_key(pub) if contact is None: - contact = Contact( - public_key=pub, - out_path_len=path_len, - out_path=path_bytes, - ) + return # Firmware does not send PATH for non-contacts + contact.out_path_len = path_len + contact.out_path = path_bytes + self.contacts.update(contact) await self._fire_callbacks("contact_path_updated", contact) def on_send_confirmed(self, callback: Callable) -> None: diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index ff0cde0..572c89f 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -197,6 +197,7 @@ def __init__( self._server: Optional[asyncio.Server] = None self._client_writer: Optional[asyncio.StreamWriter] = None self._client_reader: Optional[asyncio.StreamReader] = None + self._write_lock: asyncio.Lock = asyncio.Lock() self._app_target_ver = 0 # Pre-compute padded device info bytes for _cmd_device_query. Version string @@ -288,22 +289,26 @@ def _get_batt_and_storage(self) -> tuple[int, int, int]: def _setup_push_callbacks(self) -> None: """Subscribe to bridge events and send PUSH frames to connected client.""" + # Clear any callbacks registered by a previous connection so they + # don't accumulate across reconnections. + self.bridge.clear_push_callbacks() async def _write_push(data: bytes) -> None: - if self._client_writer and not self._client_writer.is_closing(): - try: - if len(data) > MAX_PAYLOAD_SIZE: - logger.warning( - "Push frame payload truncated from %s to %s", - len(data), - MAX_PAYLOAD_SIZE, - ) - data = data[:MAX_PAYLOAD_SIZE] - frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" MAX_PAYLOAD_SIZE: + logger.warning( + "Push frame payload truncated from %s to %s", + len(data), + MAX_PAYLOAD_SIZE, + ) + data = data[:MAX_PAYLOAD_SIZE] + frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack("= 32 ): - await _write_push(bytes([PUSH_CODE_PATH_UPDATED]) + contact.public_key[:32]) + return + if not self.bridge.contacts.get_by_key(contact.public_key): + return + await _write_push(bytes([PUSH_CODE_PATH_UPDATED]) + contact.public_key[:32]) try: - if contact is not None: - await self._persist_contact(contact) + await self._persist_contact(contact) except Exception as e: logger.warning("Persist contact after path update failed: %s", e) @@ -494,12 +493,15 @@ async def push_trace_data( + path_snrs + bytes([final_snr_byte & 0xFF]) ) - try: - frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: """Push raw RX packet to client (PUSH_CODE_LOG_RX_DATA 0x88). @@ -526,12 +528,15 @@ async def _push_rx_raw_impl(self, snr: float, rssi: int, raw: bytes) -> None: rssi_byte += 256 payload_len = min(len(raw), MAX_PAYLOAD_SIZE - 3) # 3 = code + snr + rssi data = bytes([PUSH_CODE_LOG_RX_DATA, snr_byte & 0xFF, rssi_byte & 0xFF]) + raw[:payload_len] - try: - frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: """Push raw RX packet to client with backpressure (await in async code).""" @@ -575,27 +580,34 @@ async def push_control_data( ) + payload_slice ) - try: - frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: + async def _drain_writer(self) -> bool: + """Drain the write buffer. Returns *False* if the connection is lost.""" if self._client_writer: try: await self._client_writer.drain() + return True except (ConnectionResetError, BrokenPipeError, OSError) as e: logger.warning("Drain failed (connection lost): %s", e) + return False + return False def _write_frame(self, data: bytes) -> None: """Send a frame to the connected client (outbound format). @@ -628,21 +640,29 @@ async def _heartbeat_loop(self) -> None: try: while self._client_writer and not self._client_writer.is_closing(): await asyncio.sleep(self._heartbeat_interval) - if self._client_writer and not self._client_writer.is_closing(): - now = self.bridge.get_time() - self._write_frame(bytes([RESP_CODE_CURR_TIME]) + struct.pack(" None: - """Enable TCP keepalive on the underlying socket.""" + def _configure_socket(writer: asyncio.StreamWriter) -> None: + """Configure TCP keepalive and low-latency options on the underlying socket.""" sock = writer.get_extra_info("socket") if sock is None: return + try: + # Disable Nagle's algorithm for real-time frame delivery (important + # over VPN/Tailscale where latency is higher and small-write + # coalescing can compound delays). + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + except OSError as e: + logger.debug("Could not set TCP_NODELAY: %s", e) try: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) if sys.platform == "linux": @@ -682,7 +702,7 @@ async def _handle_client( self._client_reader = reader self._client_writer = writer - self._enable_tcp_keepalive(writer) + self._configure_socket(writer) self._setup_push_callbacks() logger.info("Companion client connected (port=%s)", self.port) @@ -710,8 +730,11 @@ async def _handle_client( disconnect_reason = "frame_too_long" break payload = await reader.readexactly(frame_len) - await self._handle_cmd(payload) - await self._drain_writer() + async with self._write_lock: + await self._handle_cmd(payload) + if not await self._drain_writer(): + disconnect_reason = "drain_failed" + break except asyncio.IncompleteReadError: disconnect_reason = "incomplete_read" except (ConnectionResetError, BrokenPipeError) as e: @@ -940,8 +963,10 @@ async def _cmd_get_contacts(self, data: bytes) -> None: since = struct.unpack("= 4 else 0 contacts = self.bridge.get_contacts(since=since) self._write_frame(bytes([RESP_CODE_CONTACTS_START]) + struct.pack(" None: name = (c.name.encode("utf-8")[:32] if isinstance(c.name, str) else c.name[:32]).ljust( 32, b"\x00" ) - opl = c.out_path_len if hasattr(c, "out_path_len") else -1 - opl_byte = 0xFF if opl < 0 else min(opl, 255) + opl_byte = 0xFF if c.out_path_len < 0 else min(c.out_path_len, 255) frame = ( - bytes([RESP_CODE_CONTACT]) - + pubkey - + bytes( - [ - c.adv_type if hasattr(c, "adv_type") else 0, - c.flags if hasattr(c, "flags") else 0, - ] - ) - + bytes([opl_byte]) - + (c.out_path[:MAX_PATH_SIZE] if hasattr(c, "out_path") and c.out_path else b"").ljust( - MAX_PATH_SIZE, b"\x00" - ) + bytes([RESP_CODE_CONTACT, *pubkey, c.adv_type, c.flags, opl_byte]) + + (c.out_path[:MAX_PATH_SIZE] if c.out_path else b"").ljust(MAX_PATH_SIZE, b"\x00") + name - + struct.pack( - " None: attempt = data[1] pubkey_prefix = data[6:12] text = data[12:].decode("utf-8", errors="replace").rstrip("\x00") - contact = ( - self.bridge.contacts.get_by_key_prefix(pubkey_prefix) - if hasattr(self.bridge.contacts, "get_by_key_prefix") - else None - ) - if not contact: - for c in self.bridge.get_contacts(): - pk = ( - c.public_key if isinstance(c.public_key, bytes) else bytes.fromhex(c.public_key) - ) - if pk[:6] == pubkey_prefix: - contact = c - break + contact = self.bridge.contacts.get_by_key_prefix(pubkey_prefix) if not contact: self._write_err(ERR_CODE_NOT_FOUND) return @@ -1212,22 +1205,17 @@ async def _cmd_send_trace_path(self, data: bytes) -> None: final_snr_byte, ) - async def _cmd_sync_next_message(self, data: bytes) -> None: - msg = self.bridge.sync_next_message() - if msg is None: - msg = await asyncio.to_thread(self._sync_next_from_persistence) - if msg is None: - self._write_frame(bytes([RESP_CODE_NO_MORE_MESSAGES])) - return + def _build_message_frame(self, msg: "QueuedMessage") -> bytes: + """Encode a QueuedMessage into a response frame (shared by base and subclasses).""" + snr_byte = max(-128, min(127, int(round(getattr(msg, "snr", 0) * 4)))) + if snr_byte < 0: + snr_byte += 256 if msg.is_channel: path_len_byte = msg.path_len if msg.path_len < 256 else 0xFF txt_type = 0 text_bytes = (msg.text or "").rstrip("\x00").encode("utf-8", errors="replace") if self._app_target_ver >= 3: - snr_byte = max(-128, min(127, int(round(msg.snr * 4)))) - if snr_byte < 0: - snr_byte += 256 - frame = ( + return ( bytes( [ RESP_CODE_CHANNEL_MSG_RECV_V3, @@ -1242,41 +1230,40 @@ async def _cmd_sync_next_message(self, data: bytes) -> None: + struct.pack("= 6 else msg.sender_key.ljust(6, b"\x00") + return ( + bytes([RESP_CODE_CHANNEL_MSG_RECV, msg.channel_idx, path_len_byte, txt_type]) + + struct.pack("= 3: - snr_byte = max(-128, min(127, int(round(msg.snr * 4)))) - if snr_byte < 0: - snr_byte += 256 - frame = ( - bytes([RESP_CODE_CONTACT_MSG_RECV_V3, snr_byte & 0xFF, 0, 0]) - + prefix - + bytes([path_len_byte, msg.txt_type]) - + struct.pack("= 6 else msg.sender_key.ljust(6, b"\x00") + ) + path_len_byte = msg.path_len if msg.path_len < 256 else 0xFF + text_bytes = msg.text.encode("utf-8", errors="replace") + if self._app_target_ver >= 3: + return ( + bytes([RESP_CODE_CONTACT_MSG_RECV_V3, snr_byte & 0xFF, 0, 0]) + + prefix + + bytes([path_len_byte, msg.txt_type]) + + struct.pack(" None: + msg = self.bridge.sync_next_message() + if msg is None: + msg = await asyncio.to_thread(self._sync_next_from_persistence) + if msg is None: + self._write_frame(bytes([RESP_CODE_NO_MORE_MESSAGES])) + return + self._write_frame(self._build_message_frame(msg)) async def _cmd_send_login(self, data: bytes) -> None: if len(data) < 32: @@ -1581,8 +1568,8 @@ async def _cmd_import_contact(self, data: bytes) -> None: self._write_ok() if ok else self._write_err(ERR_CODE_ILLEGAL_ARG) async def _cmd_get_channel(self, data: bytes) -> None: - channel_idx = data[0] if len(data) >= 1 else 0 get_full_list = len(data) == 0 + channel_idx = data[0] if not get_full_list else 0 max_channels_val = getattr(getattr(self.bridge, "channels", None), "max_channels", 40) def _channel_info_frame(idx: int, ch) -> bytes: diff --git a/tests/test_companion_radio.py b/tests/test_companion_radio.py index c8c8035..38ef497 100644 --- a/tests/test_companion_radio.py +++ b/tests/test_companion_radio.py @@ -303,6 +303,10 @@ async def on_path_updated(contact): assert proto._contact_path_updated_callback is not None pub = b"\x22" * 32 + # Contact must exist in the store; path updates for unknown contacts + # are silently dropped (matches companion firmware behaviour). + comp.contacts.add(Contact(public_key=pub, name="test")) + path_len = 2 path_bytes = bytes([0x01, 0x02]) cb_result = proto._contact_path_updated_callback(pub, path_len, path_bytes) From 032099074cc3972653ea8e7ac656cb7ccae451f5 Mon Sep 17 00:00:00 2001 From: agessaman Date: Mon, 2 Mar 2026 16:43:00 -0800 Subject: [PATCH 42/50] refactor(companion): transition push methods to enqueue frames for improved backpressure handling - Replaced asynchronous push methods with synchronous versions that enqueue frames into a write queue, enhancing backpressure management. - Introduced a dedicated writer task to process the queue, ensuring reliable frame transmission. - Updated tests to validate the new frame queuing behavior and confirm correct frame formatting for trace and RX raw data pushes. --- src/pymc_core/companion/frame_server.py | 299 ++++++++++++------------ tests/test_frame_server.py | 91 +++++--- 2 files changed, 210 insertions(+), 180 deletions(-) diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 572c89f..7e04eac 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -197,7 +197,8 @@ def __init__( self._server: Optional[asyncio.Server] = None self._client_writer: Optional[asyncio.StreamWriter] = None self._client_reader: Optional[asyncio.StreamReader] = None - self._write_lock: asyncio.Lock = asyncio.Lock() + self._write_queue: Optional[asyncio.Queue] = None + self._writer_task: Optional[asyncio.Task] = None self._app_target_ver = 0 # Pre-compute padded device info bytes for _cmd_device_query. Version string @@ -239,6 +240,20 @@ async def start(self) -> None: async def stop(self) -> None: """Stop the TCP server and disconnect any client.""" + # Signal writer task to stop and wait for it + if self._write_queue is not None: + try: + self._write_queue.put_nowait(None) # Sentinel + except asyncio.QueueFull: + pass + if self._writer_task is not None: + self._writer_task.cancel() + try: + await self._writer_task + except asyncio.CancelledError: + pass + self._writer_task = None + self._write_queue = None if self._client_writer: try: self._client_writer.close() @@ -293,22 +308,9 @@ def _setup_push_callbacks(self) -> None: # don't accumulate across reconnections. self.bridge.clear_push_callbacks() - async def _write_push(data: bytes) -> None: - async with self._write_lock: - if self._client_writer and not self._client_writer.is_closing(): - try: - if len(data) > MAX_PAYLOAD_SIZE: - logger.warning( - "Push frame payload truncated from %s to %s", - len(data), - MAX_PAYLOAD_SIZE, - ) - data = data[:MAX_PAYLOAD_SIZE] - frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: + """Enqueue a push frame. Sync, non-blocking.""" + self._enqueue_frame(data) async def on_message_received( sender_key, text, timestamp, txt_type, packet_hash=None, snr=None, rssi=None @@ -326,16 +328,16 @@ async def on_message_received( "rssi": rssi, } await self._persist_companion_message(msg_dict) - await _write_push(bytes([PUSH_CODE_MSG_WAITING])) + _write_push(bytes([PUSH_CODE_MSG_WAITING])) - async def on_send_confirmed(crc): + def on_send_confirmed(crc): data = struct.pack( "= 32: - await _write_push(bytes([PUSH_CODE_CONTACT_DELETED]) + pub_key[:32]) + _write_push(bytes([PUSH_CODE_CONTACT_DELETED]) + pub_key[:32]) - async def on_contacts_full(): - await _write_push(bytes([PUSH_CODE_CONTACTS_FULL])) + def on_contacts_full(): + _write_push(bytes([PUSH_CODE_CONTACTS_FULL])) - async def on_raw_data_received(payload_bytes: bytes, snr: float, rssi: int) -> None: + def on_raw_data_received(payload_bytes: bytes, snr: float, rssi: int) -> None: """Push PUSH_CODE_RAW_DATA (0x84): code, SNR byte, RSSI byte, 0xFF, payload.""" snr_byte = max(-128, min(127, int(round(snr * 4)))) rssi_byte = max(-128, min(127, int(rssi))) @@ -446,7 +448,7 @@ async def on_raw_data_received(payload_bytes: bytes, snr: float, rssi: int) -> N + bytes([0xFF]) + payload_bytes[:payload_len] ) - await _write_push(data) + _write_push(data) self.bridge.on_message_received(on_message_received) self.bridge.on_channel_message_received(on_channel_message_received) @@ -474,8 +476,12 @@ async def push_trace_data( final_snr_byte: int, ) -> None: """Push PUSH_CODE_TRACE_DATA (0x89) to client. Matches firmware - ``onTraceRecv()`` frame format.""" - if not self._client_writer or self._client_writer.is_closing(): + ``onTraceRecv()`` frame format. + + Kept as ``async def`` for backward-compatible call sites that + ``await`` it, but the body is synchronous (just enqueues). + """ + if self._write_queue is None: return path_sz = flags & 0x03 expected_snr_len = path_len >> path_sz @@ -493,32 +499,14 @@ async def push_trace_data( + path_snrs + bytes([final_snr_byte & 0xFF]) ) - async with self._write_lock: - if not self._client_writer or self._client_writer.is_closing(): - return - try: - frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: """Push raw RX packet to client (PUSH_CODE_LOG_RX_DATA 0x88). - Schedules the push on the event loop so it works when called from sync - context (e.g. repeater RX callback). For backpressure from async code, - use ``await self.push_rx_raw_async(snr, rssi, raw)`` instead. + Sync, non-blocking. Safe to call from any context (async or sync). """ - try: - loop = asyncio.get_running_loop() - loop.create_task(self._push_rx_raw_impl(snr, rssi, raw)) - except RuntimeError: - pass # no running loop - - async def _push_rx_raw_impl(self, snr: float, rssi: int, raw: bytes) -> None: - """Implementation of push_rx_raw; write + drain for backpressure.""" - if not self._client_writer or self._client_writer.is_closing(): + if self._write_queue is None: return snr_byte = max(-128, min(127, int(round(snr * 4)))) rssi_byte = max(-128, min(127, int(rssi))) @@ -528,19 +516,11 @@ async def _push_rx_raw_impl(self, snr: float, rssi: int, raw: bytes) -> None: rssi_byte += 256 payload_len = min(len(raw), MAX_PAYLOAD_SIZE - 3) # 3 = code + snr + rssi data = bytes([PUSH_CODE_LOG_RX_DATA, snr_byte & 0xFF, rssi_byte & 0xFF]) + raw[:payload_len] - async with self._write_lock: - if not self._client_writer or self._client_writer.is_closing(): - return - try: - frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: - """Push raw RX packet to client with backpressure (await in async code).""" - await self._push_rx_raw_impl(snr, rssi, raw) + """Push raw RX packet to client. Async wrapper for backward compatibility.""" + self.push_rx_raw(snr, rssi, raw) async def push_control_data( self, @@ -550,8 +530,12 @@ async def push_control_data( path_bytes: bytes, payload: bytes, ) -> None: - """Push CONTROL packet to client (PUSH_CODE_CONTROL_DATA 0x8E).""" - if not self._client_writer or self._client_writer.is_closing(): + """Push CONTROL packet to client (PUSH_CODE_CONTROL_DATA 0x8E). + + Kept as ``async def`` for backward-compatible call sites that + ``await`` it, but the body is synchronous (just enqueues). + """ + if self._write_queue is None: logger.warning("Push control data skipped: no client connection") return # Discovery response (0x90): clear the no-op callback @@ -580,50 +564,34 @@ async def push_control_data( ) + payload_slice ) - async with self._write_lock: - if not self._client_writer or self._client_writer.is_closing(): - return - try: - frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: + """Build an outbound frame and enqueue it for the writer task. - async def _drain_writer(self) -> bool: - """Drain the write buffer. Returns *False* if the connection is lost.""" - if self._client_writer: - try: - await self._client_writer.drain() - return True - except (ConnectionResetError, BrokenPipeError, OSError) as e: - logger.warning("Drain failed (connection lost): %s", e) - return False - return False + Sync, non-blocking. On ``QueueFull`` the frame is dropped with a + warning — this provides natural backpressure shedding. + """ + if self._write_queue is None: + return + if len(data) > MAX_PAYLOAD_SIZE: + logger.warning( + "Outbound frame payload truncated from %s to %s (MAX_FRAME_SIZE=%s)", + len(data), + MAX_PAYLOAD_SIZE, + MAX_FRAME_SIZE, + ) + data = data[:MAX_PAYLOAD_SIZE] + frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: - """Send a frame to the connected client (outbound format). - Payload is truncated to MAX_PAYLOAD_SIZE to match firmware MAX_FRAME_SIZE (172). - """ - if self._client_writer and not self._client_writer.is_closing(): - if len(data) > MAX_PAYLOAD_SIZE: - logger.warning( - "Outbound frame payload truncated from %s to %s (MAX_FRAME_SIZE=%s)", - len(data), - MAX_PAYLOAD_SIZE, - MAX_FRAME_SIZE, - ) - data = data[:MAX_PAYLOAD_SIZE] - frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: self._write_frame(bytes([RESP_CODE_OK])) @@ -632,23 +600,70 @@ def _write_err(self, err_code: int) -> None: self._write_frame(bytes([RESP_CODE_ERR, err_code])) # ------------------------------------------------------------------------- - # Client handling + # Writer task # ------------------------------------------------------------------------- - async def _heartbeat_loop(self) -> None: - """Send periodic ``RESP_CODE_CURR_TIME`` to keep the TCP connection alive.""" + # Must exceed DEFAULT_MAX_CONTACTS (+2 for START/END) so that + # _cmd_get_contacts can enqueue the full contact dump without drops. + _WRITE_QUEUE_MAXSIZE = 2048 + _DRAIN_BATCH = 10 + + async def _writer_loop(self, writer: asyncio.StreamWriter) -> None: + """Single writer task: pull frames from the queue, write to the + ``StreamWriter``, and drain periodically. + + Integrates heartbeat via timeout on :pymethod:`asyncio.Queue.get` — + when no frames arrive within ``_heartbeat_interval`` seconds a + ``RESP_CODE_CURR_TIME`` heartbeat frame is generated automatically, + eliminating the need for a separate heartbeat task. + + On any write/drain error the writer is closed, which causes the read + loop in :pymethod:`_handle_client` to receive EOF → clean disconnect. + """ + frames_since_drain = 0 try: - while self._client_writer and not self._client_writer.is_closing(): - await asyncio.sleep(self._heartbeat_interval) - async with self._write_lock: - if self._client_writer and not self._client_writer.is_closing(): - now = self.bridge.get_time() - self._write_frame(bytes([RESP_CODE_CURR_TIME]) + struct.pack("= self._DRAIN_BATCH: + await writer.drain() + frames_since_drain = 0 except asyncio.CancelledError: pass + except (ConnectionResetError, BrokenPipeError, OSError) as e: + logger.warning("Writer loop connection lost: %s", e) except Exception as e: - logger.debug("Heartbeat loop ended: %s", e) + logger.error("Writer loop error: %s", e, exc_info=True) + finally: + try: + if not writer.is_closing(): + writer.close() + except Exception: + pass + + # ------------------------------------------------------------------------- + # Client handling + # ------------------------------------------------------------------------- @staticmethod def _configure_socket(writer: asyncio.StreamWriter) -> None: @@ -703,10 +718,11 @@ async def _handle_client( self._client_reader = reader self._client_writer = writer self._configure_socket(writer) + self._write_queue = asyncio.Queue(maxsize=self._WRITE_QUEUE_MAXSIZE) self._setup_push_callbacks() logger.info("Companion client connected (port=%s)", self.port) - heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + self._writer_task = asyncio.create_task(self._writer_loop(writer)) disconnect_reason: Optional[str] = None try: while True: @@ -730,11 +746,10 @@ async def _handle_client( disconnect_reason = "frame_too_long" break payload = await reader.readexactly(frame_len) - async with self._write_lock: - await self._handle_cmd(payload) - if not await self._drain_writer(): - disconnect_reason = "drain_failed" - break + await self._handle_cmd(payload) + if self._writer_task.done(): + disconnect_reason = "writer_failed" + break except asyncio.IncompleteReadError: disconnect_reason = "incomplete_read" except (ConnectionResetError, BrokenPipeError) as e: @@ -743,11 +758,19 @@ async def _handle_client( disconnect_reason = f"other: {type(e).__name__}: {e}" logger.error("Client handler error: %s", e, exc_info=True) finally: - heartbeat_task.cancel() - try: - await heartbeat_task - except asyncio.CancelledError: - pass + if self._write_queue is not None: + try: + self._write_queue.put_nowait(None) # Sentinel + except asyncio.QueueFull: + pass + if self._writer_task is not None: + self._writer_task.cancel() + try: + await self._writer_task + except asyncio.CancelledError: + pass + self._writer_task = None + self._write_queue = None if self._client_writer is writer: self._client_writer = None self._client_reader = None @@ -965,8 +988,6 @@ async def _cmd_get_contacts(self, data: bytes) -> None: self._write_frame(bytes([RESP_CODE_CONTACTS_START]) + struct.pack(" None: return if ok: self._write_ok() - await self._drain_writer() else: self._write_err(ERR_CODE_TABLE_FULL) @@ -1301,7 +1321,6 @@ async def _cmd_send_status_req(self, data: bytes) -> None: return pubkey = data[0:32] self._write_frame(bytes([RESP_CODE_SENT, 0]) + struct.pack(" None: want_location = bool(flags & 0x02) want_environment = bool(flags & 0x04) self._write_frame(bytes([RESP_CODE_SENT, 0]) + struct.pack(" None: ) if not result.get("success"): self._write_frame(bytes([PUSH_CODE_TELEMETRY_RESPONSE, 0]) + pubkey[:6]) - await self._drain_writer() return telem_data = result.get("telemetry_data", {}) raw_bytes = telem_data.get("raw_bytes", b"") if not raw_bytes: self._write_frame(bytes([PUSH_CODE_TELEMETRY_RESPONSE, 0]) + pubkey[:6]) - await self._drain_writer() return self._write_frame(bytes([PUSH_CODE_TELEMETRY_RESPONSE, 0]) + pubkey[:6] + raw_bytes) - await self._drain_writer() logger.info("Telemetry push sent to client: %d bytes LPP", len(raw_bytes)) async def _cmd_send_self_advert(self, data: bytes) -> None: @@ -1367,7 +1382,6 @@ async def _cmd_set_advert_latlon(self, data: bytes) -> None: async def _cmd_add_update_contact(self, data: bytes) -> None: if len(data) < 36: self._write_err(ERR_CODE_ILLEGAL_ARG) - await self._drain_writer() return pubkey = data[0:32] adv_type = data[32] @@ -1431,7 +1445,6 @@ async def _cmd_add_update_contact(self, data: bytes) -> None: + struct.pack(" None: async def _cmd_remove_contact(self, data: bytes) -> None: if len(data) < 32: self._write_err(ERR_CODE_ILLEGAL_ARG) - await self._drain_writer() return pubkey = data[:32] ok = self.bridge.remove_contact(pubkey) @@ -1451,7 +1463,6 @@ async def _cmd_remove_contact(self, data: bytes) -> None: except Exception as e: logger.warning("Save contacts after remove failed: %s", e) self._write_ok() if ok else self._write_err(ERR_CODE_NOT_FOUND) - await self._drain_writer() async def _cmd_reset_path(self, data: bytes) -> None: if len(data) < 32: diff --git a/tests/test_frame_server.py b/tests/test_frame_server.py index 5811cb7..d39d5a8 100644 --- a/tests/test_frame_server.py +++ b/tests/test_frame_server.py @@ -159,15 +159,11 @@ async def test_cmd_send_raw_data_send_failure_writes_table_full(): @pytest.mark.asyncio -async def test_push_trace_data_and_push_rx_raw_are_async_and_await_drain(): - """push_trace_data and push_rx_raw_async await drain for backpressure.""" +async def test_push_trace_data_enqueues_frame(): + """push_trace_data enqueues a correctly formatted trace frame.""" bridge = _MockBridgeSendRawDirect() server = CompanionFrameServer(bridge, "hash", port=0) - writer = Mock() - writer.write = Mock() - writer.is_closing = Mock(return_value=False) - writer.drain = AsyncMock(return_value=None) - server._client_writer = writer + server._write_queue = asyncio.Queue(maxsize=256) await server.push_trace_data( path_len=1, @@ -178,53 +174,76 @@ async def test_push_trace_data_and_push_rx_raw_are_async_and_await_drain(): path_snrs=b"\x00", final_snr_byte=0, ) - writer.write.assert_called_once() - writer.drain.assert_awaited_once() + assert not server._write_queue.empty() + frame = server._write_queue.get_nowait() + # Frame format: FRAME_OUTBOUND_PREFIX + 2-byte LE length + payload + assert frame[0] == 0x3E # FRAME_OUTBOUND_PREFIX + _ = struct.unpack(" Date: Mon, 2 Mar 2026 18:56:01 -0800 Subject: [PATCH 43/50] feat(companion): synchronize node name across handlers for echo detection - Added a method to synchronize the current node name to group text handlers in CompanionBridge and CompanionRadio. - Implemented the `_sync_our_node_name_to_handlers` method in both classes to update the node name for echo detection. - Enhanced the GroupTextHandler to include a method for updating the stored node name. - Updated tests to verify the correct functionality of the new synchronization feature. --- src/pymc_core/companion/companion_base.py | 6 +++ src/pymc_core/companion/companion_bridge.py | 6 +++ src/pymc_core/companion/companion_radio.py | 6 +++ src/pymc_core/node/dispatcher.py | 1 + src/pymc_core/node/handlers/group_text.py | 4 ++ tests/test_handlers.py | 47 +++++++++++++++++++-- 6 files changed, 67 insertions(+), 3 deletions(-) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index 7286d14..c2c7a9b 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -334,6 +334,12 @@ def set_advert_name(self, name: str) -> None: """Set the node's advertised name (max 31 chars).""" self.prefs.node_name = name[:31] self._save_prefs() + self._sync_our_node_name_to_handlers() + + def _sync_our_node_name_to_handlers(self) -> None: + """Sync node name to group text handler for echo detection. + No-op in base; override in Bridge/Radio.""" + pass def set_advert_latlon(self, lat: float, lon: float) -> None: """Set the GPS coordinates included in advertisements.""" diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index 718d43e..0001627 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -275,6 +275,12 @@ def _get_login_response_handler(self) -> Any: def _get_text_handler(self) -> Any: return self._text_handler_ref + def _sync_our_node_name_to_handlers(self) -> None: + """Sync current node name to group text handler for echo detection.""" + handler = self._handlers.get(PAYLOAD_TYPE_GRP_TXT) + if handler is not None: + handler.set_our_node_name(self.prefs.node_name) + # ------------------------------------------------------------------------- # RX Entry Point # ------------------------------------------------------------------------- diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py index 07deab0..ee03bd6 100644 --- a/src/pymc_core/companion/companion_radio.py +++ b/src/pymc_core/companion/companion_radio.py @@ -173,6 +173,12 @@ def set_advert_name(self, name: str) -> None: super().set_advert_name(name) self.node.node_name = self.prefs.node_name + def _sync_our_node_name_to_handlers(self) -> None: + """Sync current node name to group text handler for echo detection.""" + handler = getattr(self.node.dispatcher, "group_text_handler", None) + if handler is not None: + handler.set_our_node_name(self.prefs.node_name) + def set_radio_params(self, freq_hz: int, bw_hz: int, sf: int, cr: int) -> bool: super().set_radio_params(freq_hz, bw_hz, sf, cr) if hasattr(self._radio, "configure_radio"): diff --git a/src/pymc_core/node/dispatcher.py b/src/pymc_core/node/dispatcher.py index 61b494e..8bb5dba 100644 --- a/src/pymc_core/node/dispatcher.py +++ b/src/pymc_core/node/dispatcher.py @@ -190,6 +190,7 @@ def register_default_handlers( self.text_message_handler = core.text_handler self.protocol_response_handler = core.protocol_response_handler self.login_response_handler = core.login_response_handler + self.group_text_handler = core.group_text_handler # Backward compat alias self.telemetry_response_handler = core.protocol_response_handler diff --git a/src/pymc_core/node/handlers/group_text.py b/src/pymc_core/node/handlers/group_text.py index 5030c64..35dafcc 100644 --- a/src/pymc_core/node/handlers/group_text.py +++ b/src/pymc_core/node/handlers/group_text.py @@ -29,6 +29,10 @@ def __init__( self.event_service = event_service self.our_node_name = our_node_name # Store our node name for echo detection + def set_our_node_name(self, name: str | None) -> None: + """Update the node name used for echo detection (e.g. after set_advert_name).""" + self.our_node_name = name + def _get_channel_by_hash(self, channel_hash: int) -> Optional[dict]: """Find a channel by its hash (first byte of SHA256) from database. diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 4c1365f..77941b6 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -259,9 +259,50 @@ def setup_method(self): self.send_packet_fn = AsyncMock() self.event_service = MockEventService() self.handler = GroupTextHandler( - self.local_identity, self.contacts, self.log_fn, self.send_packet_fn + self.local_identity, + self.contacts, + self.log_fn, + self.send_packet_fn, + channel_db=None, + event_service=self.event_service, + our_node_name="InitialName", ) - # GroupTextHandler doesn't take event_service in constructor + + def test_set_our_node_name_updates_stored_name(self): + """set_our_node_name updates the name used for echo detection.""" + assert self.handler.our_node_name == "InitialName" + self.handler.set_our_node_name("NewName") + assert self.handler.our_node_name == "NewName" + self.handler.set_our_node_name(None) + assert self.handler.our_node_name is None + + def test_is_own_message_uses_current_name_after_set_our_node_name(self): + """_is_own_message uses the current our_node_name after it is updated.""" + self.handler.set_our_node_name("Howl 🏝️") + packet = Packet() + packet.decrypted = {"group_text_data": {"sender_name": "Howl 🏝️"}} + assert self.handler._is_own_message(packet) is True + packet.decrypted = {"group_text_data": {"sender_name": "Howl 🧱"}} + assert self.handler._is_own_message(packet) is False + # After updating name, old name no longer matches + self.handler.set_our_node_name("Howl 🧱") + assert self.handler._is_own_message(packet) is True + + def test_is_own_message_false_when_sender_name_missing(self): + """_is_own_message returns False when packet has no sender_name in group_text_data.""" + self.handler.set_our_node_name("Me") + packet = Packet() + packet.decrypted = {} + assert self.handler._is_own_message(packet) is False + packet.decrypted = {"group_text_data": {}} + assert self.handler._is_own_message(packet) is False + + def test_is_own_message_false_when_no_match(self): + """_is_own_message returns False when sender name differs from our_node_name.""" + self.handler.set_our_node_name("Me") + packet = Packet() + packet.decrypted = {"group_text_data": {"sender_name": "Other"}} + assert self.handler._is_own_message(packet) is False def test_payload_type(self): """Test group text handler payload type.""" @@ -273,7 +314,7 @@ def test_group_text_handler_initialization(self): assert self.handler.contacts == self.contacts assert self.handler.log == self.log_fn assert self.handler.send_packet == self.send_packet_fn - # GroupTextHandler doesn't store event_service + assert self.handler.our_node_name == "InitialName" # Login Response Handler Tests From df073cbc1739517fe646200ea27f06761aa4a22d Mon Sep 17 00:00:00 2001 From: agessaman Date: Mon, 2 Mar 2026 19:28:21 -0800 Subject: [PATCH 44/50] feat: add radio config instance attributes to KissModemWrapper Matches SX1262Wrapper convention so engine.py getattr() reads actual configured values instead of falling back to defaults. Adds set_tx_power() and optional kwargs to configure_radio(). --- src/pymc_core/companion/companion_bridge.py | 25 +++++++ tests/test_companion_bridge.py | 83 +++++++++++++++++++++ 2 files changed, 108 insertions(+) diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index 0001627..b476730 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -198,6 +198,17 @@ def __init__( radio_config=radio_config, initial_contacts=initial_contacts, ) + + # Radio settings are authoritative from the host — re-apply after + # _load_prefs() so persisted values never overwrite the host config. + rc = self._radio_config + if rc: + self.prefs.frequency_hz = rc.get("frequency", self.prefs.frequency_hz) + self.prefs.bandwidth_hz = rc.get("bandwidth", self.prefs.bandwidth_hz) + self.prefs.spreading_factor = rc.get("spreading_factor", self.prefs.spreading_factor) + self.prefs.coding_rate = rc.get("coding_rate", self.prefs.coding_rate) + self.prefs.tx_power_dbm = rc.get("power", rc.get("tx_power", self.prefs.tx_power_dbm)) + self._packet_injector = packet_injector async def _handler_send_packet(pkt: Packet, wait_for_ack: bool = False) -> bool: @@ -331,6 +342,20 @@ async def stop(self) -> None: def is_running(self) -> bool: return self._running + # ------------------------------------------------------------------------- + # Radio settings — read-only in bridge mode + # ------------------------------------------------------------------------- + + def set_radio_params(self, freq_hz: int, bw_hz: int, sf: int, cr: int) -> bool: + """No-op: bridge radio belongs to the host.""" + logger.debug("set_radio_params ignored in bridge mode") + return True + + def set_tx_power(self, power_dbm: int) -> bool: + """No-op: bridge radio belongs to the host.""" + logger.debug("set_tx_power ignored in bridge mode") + return True + # ------------------------------------------------------------------------- # Key Management # ------------------------------------------------------------------------- diff --git a/tests/test_companion_bridge.py b/tests/test_companion_bridge.py index 5263b9b..3fd0032 100644 --- a/tests/test_companion_bridge.py +++ b/tests/test_companion_bridge.py @@ -538,3 +538,86 @@ async def test_direct_message_deduplicated_by_packet_hash(self): assert msg is not None assert msg.text == "Hello" assert bridge.sync_next_message() is None + + +# --------------------------------------------------------------------------- +# Radio settings — read-only in bridge mode +# --------------------------------------------------------------------------- + + +class TestBridgeRadioReadOnly: + """Bridge radio settings come from the host and cannot be changed by clients.""" + + RADIO_CONFIG = { + "frequency": 869618000, + "bandwidth": 62500, + "spreading_factor": 12, + "coding_rate": 7, + "tx_power": 22, + } + + def _make_bridge(self, radio_config=None): + return CompanionBridge( + LocalIdentity(), + MockPacketInjector(), + radio_config=radio_config or self.RADIO_CONFIG, + ) + + def test_bridge_set_radio_params_is_noop(self): + bridge = self._make_bridge() + # Record original values + orig_freq = bridge.prefs.frequency_hz + orig_bw = bridge.prefs.bandwidth_hz + orig_sf = bridge.prefs.spreading_factor + orig_cr = bridge.prefs.coding_rate + + # Attempt to change — should be a no-op + result = bridge.set_radio_params(915000000, 250000, 10, 5) + assert result is True + assert bridge.prefs.frequency_hz == orig_freq + assert bridge.prefs.bandwidth_hz == orig_bw + assert bridge.prefs.spreading_factor == orig_sf + assert bridge.prefs.coding_rate == orig_cr + + def test_bridge_set_tx_power_is_noop(self): + bridge = self._make_bridge() + orig_power = bridge.prefs.tx_power_dbm + + result = bridge.set_tx_power(5) + assert result is True + assert bridge.prefs.tx_power_dbm == orig_power + + def test_bridge_radio_config_applied(self): + """Radio prefs reflect the radio_config dict passed at construction.""" + bridge = self._make_bridge() + assert bridge.prefs.frequency_hz == 869618000 + assert bridge.prefs.bandwidth_hz == 62500 + assert bridge.prefs.spreading_factor == 12 + assert bridge.prefs.coding_rate == 7 + assert bridge.prefs.tx_power_dbm == 22 + + def test_bridge_radio_config_survives_load_prefs(self): + """Host radio config takes precedence over values restored by _load_prefs.""" + + class _StalePrefsLoader(CompanionBridge): + """Simulates a persistence layer that restores stale radio prefs.""" + + def _load_prefs(self): + self.prefs.frequency_hz = 111111 + self.prefs.bandwidth_hz = 222222 + self.prefs.spreading_factor = 6 + self.prefs.coding_rate = 5 + self.prefs.tx_power_dbm = -9 + + bridge = _StalePrefsLoader( + LocalIdentity(), + MockPacketInjector(), + radio_config=self.RADIO_CONFIG, + ) + + # Host config must override the stale values from _load_prefs + assert bridge.prefs.frequency_hz == 869618000 + assert bridge.prefs.bandwidth_hz == 62500 + assert bridge.prefs.spreading_factor == 12 + assert bridge.prefs.coding_rate == 7 + assert bridge.prefs.tx_power_dbm == 22 From 97c760802ac9eca62046fffff059b606535f9a35 Mon Sep 17 00:00:00 2001 From: agessaman Date: Mon, 2 Mar 2026 19:29:52 -0800 Subject: [PATCH 45/50] Revert "feat: add radio config instance attributes to KissModemWrapper" This reverts commit df073cbc1739517fe646200ea27f06761aa4a22d. --- src/pymc_core/companion/companion_bridge.py | 25 ------- tests/test_companion_bridge.py | 83 --------------------- 2 files changed, 108 deletions(-) diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index b476730..0001627 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -198,17 +198,6 @@ def __init__( radio_config=radio_config, initial_contacts=initial_contacts, ) - - # Radio settings are authoritative from the host — re-apply after - # _load_prefs() so persisted values never overwrite the host config. - rc = self._radio_config - if rc: - self.prefs.frequency_hz = rc.get("frequency", self.prefs.frequency_hz) - self.prefs.bandwidth_hz = rc.get("bandwidth", self.prefs.bandwidth_hz) - self.prefs.spreading_factor = rc.get("spreading_factor", self.prefs.spreading_factor) - self.prefs.coding_rate = rc.get("coding_rate", self.prefs.coding_rate) - self.prefs.tx_power_dbm = rc.get("power", rc.get("tx_power", self.prefs.tx_power_dbm)) - self._packet_injector = packet_injector async def _handler_send_packet(pkt: Packet, wait_for_ack: bool = False) -> bool: @@ -342,20 +331,6 @@ async def stop(self) -> None: def is_running(self) -> bool: return self._running - # ------------------------------------------------------------------------- - # Radio settings — read-only in bridge mode - # ------------------------------------------------------------------------- - - def set_radio_params(self, freq_hz: int, bw_hz: int, sf: int, cr: int) -> bool: - """No-op: bridge radio belongs to the host.""" - logger.debug("set_radio_params ignored in bridge mode") - return True - - def set_tx_power(self, power_dbm: int) -> bool: - """No-op: bridge radio belongs to the host.""" - logger.debug("set_tx_power ignored in bridge mode") - return True - # ------------------------------------------------------------------------- # Key Management # ------------------------------------------------------------------------- diff --git a/tests/test_companion_bridge.py b/tests/test_companion_bridge.py index 3fd0032..5263b9b 100644 --- a/tests/test_companion_bridge.py +++ b/tests/test_companion_bridge.py @@ -538,86 +538,3 @@ async def test_direct_message_deduplicated_by_packet_hash(self): assert msg is not None assert msg.text == "Hello" assert bridge.sync_next_message() is None - - -# --------------------------------------------------------------------------- -# Radio settings — read-only in bridge mode -# --------------------------------------------------------------------------- - - -class TestBridgeRadioReadOnly: - """Bridge radio settings come from the host and cannot be changed by clients.""" - - RADIO_CONFIG = { - "frequency": 869618000, - "bandwidth": 62500, - "spreading_factor": 12, - "coding_rate": 7, - "tx_power": 22, - } - - def _make_bridge(self, radio_config=None): - return CompanionBridge( - LocalIdentity(), - MockPacketInjector(), - radio_config=radio_config or self.RADIO_CONFIG, - ) - - def test_bridge_set_radio_params_is_noop(self): - bridge = self._make_bridge() - # Record original values - orig_freq = bridge.prefs.frequency_hz - orig_bw = bridge.prefs.bandwidth_hz - orig_sf = bridge.prefs.spreading_factor - orig_cr = bridge.prefs.coding_rate - - # Attempt to change — should be a no-op - result = bridge.set_radio_params(915000000, 250000, 10, 5) - assert result is True - assert bridge.prefs.frequency_hz == orig_freq - assert bridge.prefs.bandwidth_hz == orig_bw - assert bridge.prefs.spreading_factor == orig_sf - assert bridge.prefs.coding_rate == orig_cr - - def test_bridge_set_tx_power_is_noop(self): - bridge = self._make_bridge() - orig_power = bridge.prefs.tx_power_dbm - - result = bridge.set_tx_power(5) - assert result is True - assert bridge.prefs.tx_power_dbm == orig_power - - def test_bridge_radio_config_applied(self): - """Radio prefs reflect the radio_config dict passed at construction.""" - bridge = self._make_bridge() - assert bridge.prefs.frequency_hz == 869618000 - assert bridge.prefs.bandwidth_hz == 62500 - assert bridge.prefs.spreading_factor == 12 - assert bridge.prefs.coding_rate == 7 - assert bridge.prefs.tx_power_dbm == 22 - - def test_bridge_radio_config_survives_load_prefs(self): - """Host radio config takes precedence over values restored by _load_prefs.""" - - class _StalePrefsLoader(CompanionBridge): - """Simulates a persistence layer that restores stale radio prefs.""" - - def _load_prefs(self): - self.prefs.frequency_hz = 111111 - self.prefs.bandwidth_hz = 222222 - self.prefs.spreading_factor = 6 - self.prefs.coding_rate = 5 - self.prefs.tx_power_dbm = -9 - - bridge = _StalePrefsLoader( - LocalIdentity(), - MockPacketInjector(), - radio_config=self.RADIO_CONFIG, - ) - - # Host config must override the stale values from _load_prefs - assert bridge.prefs.frequency_hz == 869618000 - assert bridge.prefs.bandwidth_hz == 62500 - assert bridge.prefs.spreading_factor == 12 - assert bridge.prefs.coding_rate == 7 - assert bridge.prefs.tx_power_dbm == 22 From d12a68d4b1c6bff8fbe3f5c33c4074c21a66b55e Mon Sep 17 00:00:00 2001 From: agessaman Date: Mon, 2 Mar 2026 19:32:16 -0800 Subject: [PATCH 46/50] feat: add radio config instance attributes to KissModemWrapper Matches SX1262Wrapper convention so engine.py getattr() reads actual configured values instead of falling back to defaults. Adds set_tx_power() and optional kwargs to configure_radio(). --- src/pymc_core/hardware/kiss_modem_wrapper.py | 75 +++++++++++++++++--- 1 file changed, 66 insertions(+), 9 deletions(-) diff --git a/src/pymc_core/hardware/kiss_modem_wrapper.py b/src/pymc_core/hardware/kiss_modem_wrapper.py index c9ba3d6..17180c1 100644 --- a/src/pymc_core/hardware/kiss_modem_wrapper.py +++ b/src/pymc_core/hardware/kiss_modem_wrapper.py @@ -263,6 +263,15 @@ def __init__( self.radio_config = radio_config or {} self.is_configured = False + # Radio configuration — instance attributes matching SX1262Wrapper + # convention. Seeded from the config dict; updated by configure_radio(). + self.frequency = self.radio_config.get("frequency", int(869.618 * 1000000)) + self.tx_power = self.radio_config.get("power", self.radio_config.get("tx_power", 22)) + self.spreading_factor = self.radio_config.get("spreading_factor", 8) + self.bandwidth = self.radio_config.get("bandwidth", int(62500)) + self.coding_rate = self.radio_config.get("coding_rate", 8) + self.preamble_length = self.radio_config.get("preamble_length", 17) + self.serial_conn: Optional[serial.Serial] = None self.is_connected = False @@ -574,9 +583,18 @@ def _query_modem_info(self): except Exception as e: logger.warning(f"Failed to query modem info: {e}") - def configure_radio(self) -> bool: - """ - Configure radio parameters + def configure_radio( + self, + frequency: Optional[int] = None, + bandwidth: Optional[int] = None, + spreading_factor: Optional[int] = None, + coding_rate: Optional[int] = None, + ) -> bool: + """Configure radio parameters. + + When called with keyword arguments (e.g. from CompanionRadio), those + values take precedence. When called with no arguments the values are + read from ``self.radio_config`` (populated from config.yaml at init). Returns: True if configuration successful, False otherwise @@ -586,12 +604,23 @@ def configure_radio(self) -> bool: return False try: - # Extract configuration parameters with defaults - # Support both "power" and "tx_power" for compatibility with different config styles - frequency_hz = self.radio_config.get("frequency", int(869.618 * 1000000)) - bandwidth_hz = self.radio_config.get("bandwidth", int(62500)) - sf = self.radio_config.get("spreading_factor", 8) - cr = self.radio_config.get("coding_rate", 8) + # Explicit kwargs take precedence, then radio_config dict, then defaults + frequency_hz = ( + frequency + if frequency is not None + else self.radio_config.get("frequency", int(869.618 * 1000000)) + ) + bandwidth_hz = ( + bandwidth + if bandwidth is not None + else self.radio_config.get("bandwidth", int(62500)) + ) + sf = ( + spreading_factor + if spreading_factor is not None + else self.radio_config.get("spreading_factor", 8) + ) + cr = coding_rate if coding_rate is not None else self.radio_config.get("coding_rate", 8) power = self.radio_config.get("power", self.radio_config.get("tx_power", 22)) # Set radio parameters (frequency, bandwidth, SF, CR) @@ -610,6 +639,13 @@ def configure_radio(self) -> bool: # Note: Sync word is configured at firmware build time, not at runtime + # Sync instance attributes to match what was applied to hardware + self.frequency = frequency_hz + self.bandwidth = bandwidth_hz + self.spreading_factor = sf + self.coding_rate = cr + self.tx_power = power + self.is_configured = True logger.info( f"Radio configured: {frequency_hz / 1000000:.3f} MHz, " @@ -735,6 +771,27 @@ def get_radio_config(self) -> Optional[Dict[str, Any]]: } return None + def set_tx_power(self, power: int) -> bool: + """Set TX power in dBm. + + Sends the command to the modem and updates the instance attribute + on success, matching the SX1262Wrapper interface. + """ + if not self.is_connected: + logger.error("Cannot set TX power: not connected") + return False + try: + resp = self._send_command(CMD_SET_TX_POWER, bytes([power])) + if not resp or resp[0] == RESP_ERROR: + logger.error("Failed to set TX power") + return False + self.tx_power = power + logger.info(f"TX power set to {power} dBm") + return True + except Exception as e: + logger.error(f"Error setting TX power: {e}") + return False + def get_tx_power(self) -> Optional[int]: """Get current TX power in dBm""" resp = self._send_command(CMD_GET_TX_POWER) From 4028071b1c5ccef1c0494495f5b6c70b90560d79 Mon Sep 17 00:00:00 2001 From: agessaman Date: Mon, 2 Mar 2026 19:32:16 -0800 Subject: [PATCH 47/50] fix: make bridge radio settings read-only from host Override set_radio_params/set_tx_power as no-ops in CompanionBridge and re-apply host radio_config after _load_prefs() to prevent persisted values from overwriting the host's actual settings. Made-with: Cursor --- src/pymc_core/hardware/kiss_modem_wrapper.py | 75 +++++++++++++++++--- 1 file changed, 66 insertions(+), 9 deletions(-) diff --git a/src/pymc_core/hardware/kiss_modem_wrapper.py b/src/pymc_core/hardware/kiss_modem_wrapper.py index c9ba3d6..17180c1 100644 --- a/src/pymc_core/hardware/kiss_modem_wrapper.py +++ b/src/pymc_core/hardware/kiss_modem_wrapper.py @@ -263,6 +263,15 @@ def __init__( self.radio_config = radio_config or {} self.is_configured = False + # Radio configuration — instance attributes matching SX1262Wrapper + # convention. Seeded from the config dict; updated by configure_radio(). + self.frequency = self.radio_config.get("frequency", int(869.618 * 1000000)) + self.tx_power = self.radio_config.get("power", self.radio_config.get("tx_power", 22)) + self.spreading_factor = self.radio_config.get("spreading_factor", 8) + self.bandwidth = self.radio_config.get("bandwidth", int(62500)) + self.coding_rate = self.radio_config.get("coding_rate", 8) + self.preamble_length = self.radio_config.get("preamble_length", 17) + self.serial_conn: Optional[serial.Serial] = None self.is_connected = False @@ -574,9 +583,18 @@ def _query_modem_info(self): except Exception as e: logger.warning(f"Failed to query modem info: {e}") - def configure_radio(self) -> bool: - """ - Configure radio parameters + def configure_radio( + self, + frequency: Optional[int] = None, + bandwidth: Optional[int] = None, + spreading_factor: Optional[int] = None, + coding_rate: Optional[int] = None, + ) -> bool: + """Configure radio parameters. + + When called with keyword arguments (e.g. from CompanionRadio), those + values take precedence. When called with no arguments the values are + read from ``self.radio_config`` (populated from config.yaml at init). Returns: True if configuration successful, False otherwise @@ -586,12 +604,23 @@ def configure_radio(self) -> bool: return False try: - # Extract configuration parameters with defaults - # Support both "power" and "tx_power" for compatibility with different config styles - frequency_hz = self.radio_config.get("frequency", int(869.618 * 1000000)) - bandwidth_hz = self.radio_config.get("bandwidth", int(62500)) - sf = self.radio_config.get("spreading_factor", 8) - cr = self.radio_config.get("coding_rate", 8) + # Explicit kwargs take precedence, then radio_config dict, then defaults + frequency_hz = ( + frequency + if frequency is not None + else self.radio_config.get("frequency", int(869.618 * 1000000)) + ) + bandwidth_hz = ( + bandwidth + if bandwidth is not None + else self.radio_config.get("bandwidth", int(62500)) + ) + sf = ( + spreading_factor + if spreading_factor is not None + else self.radio_config.get("spreading_factor", 8) + ) + cr = coding_rate if coding_rate is not None else self.radio_config.get("coding_rate", 8) power = self.radio_config.get("power", self.radio_config.get("tx_power", 22)) # Set radio parameters (frequency, bandwidth, SF, CR) @@ -610,6 +639,13 @@ def configure_radio(self) -> bool: # Note: Sync word is configured at firmware build time, not at runtime + # Sync instance attributes to match what was applied to hardware + self.frequency = frequency_hz + self.bandwidth = bandwidth_hz + self.spreading_factor = sf + self.coding_rate = cr + self.tx_power = power + self.is_configured = True logger.info( f"Radio configured: {frequency_hz / 1000000:.3f} MHz, " @@ -735,6 +771,27 @@ def get_radio_config(self) -> Optional[Dict[str, Any]]: } return None + def set_tx_power(self, power: int) -> bool: + """Set TX power in dBm. + + Sends the command to the modem and updates the instance attribute + on success, matching the SX1262Wrapper interface. + """ + if not self.is_connected: + logger.error("Cannot set TX power: not connected") + return False + try: + resp = self._send_command(CMD_SET_TX_POWER, bytes([power])) + if not resp or resp[0] == RESP_ERROR: + logger.error("Failed to set TX power") + return False + self.tx_power = power + logger.info(f"TX power set to {power} dBm") + return True + except Exception as e: + logger.error(f"Error setting TX power: {e}") + return False + def get_tx_power(self) -> Optional[int]: """Get current TX power in dBm""" resp = self._send_command(CMD_GET_TX_POWER) From 201cb8e3339f22233173c54b41a294f3f6c1a9b9 Mon Sep 17 00:00:00 2001 From: agessaman Date: Tue, 3 Mar 2026 18:34:12 -0800 Subject: [PATCH 48/50] fix: update frequency validation and conversion in CompanionFrameServer - Changed frequency validation to accept values in kHz instead of Hz, aligning with firmware specifications. - Updated the frequency parameter in set_radio_params to convert kHz back to Hz for internal processing. - Adjusted the power validation to allow a maximum of 30 instead of 20. --- src/pymc_core/companion/frame_server.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 7e04eac..3ab8519 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -1664,11 +1664,12 @@ async def _cmd_set_radio_params(self, data: bytes) -> None: if len(data) < 10: self._write_err(ERR_CODE_ILLEGAL_ARG) return - freq = struct.unpack_from(" None: if not (5 <= sf <= 12) or not (5 <= cr <= 8): self._write_err(ERR_CODE_ILLEGAL_ARG) return - self.bridge.set_radio_params(freq, bw, sf, cr) + self.bridge.set_radio_params(freq_khz * 1000, bw, sf, cr) self._write_ok() async def _cmd_set_tx_power(self, data: bytes) -> None: @@ -1685,7 +1686,7 @@ async def _cmd_set_tx_power(self, data: bytes) -> None: self._write_err(ERR_CODE_ILLEGAL_ARG) return power = struct.unpack_from(" 20: + if power < -9 or power >= 30: self._write_err(ERR_CODE_ILLEGAL_ARG) return self.bridge.set_tx_power(power) From a87948c0932a04be1ef436326355598b4f8e6211 Mon Sep 17 00:00:00 2001 From: agessaman Date: Thu, 5 Mar 2026 11:26:57 -0800 Subject: [PATCH 49/50] feat(companion): implement multi-byte path hash encoding and management - Introduced support for multi-byte path lengths in CompanionBase, allowing for 1-byte, 2-byte, and 3-byte hash sizes. - Added methods to set and apply path hash mode, ensuring proper encoding in originated packets. - Updated constants and packet handling to accommodate new path length encoding, enhancing compatibility with firmware. - Enhanced tests to validate the new path hash functionality and ensure correct behavior across various scenarios. --- src/pymc_core/companion/companion_base.py | 78 ++++- src/pymc_core/companion/constants.py | 4 +- src/pymc_core/companion/frame_server.py | 51 ++- src/pymc_core/companion/models.py | 1 + src/pymc_core/node/dispatcher.py | 5 + src/pymc_core/node/handlers/ack.py | 8 +- src/pymc_core/node/handlers/advert.py | 6 +- src/pymc_core/node/handlers/login_response.py | 15 +- src/pymc_core/node/handlers/login_server.py | 4 +- src/pymc_core/node/handlers/path.py | 2 +- .../node/handlers/protocol_request.py | 11 +- .../node/handlers/protocol_response.py | 43 +-- src/pymc_core/protocol/__init__.py | 2 + src/pymc_core/protocol/constants.py | 4 +- src/pymc_core/protocol/packet.py | 101 +++++- src/pymc_core/protocol/packet_builder.py | 41 ++- src/pymc_core/protocol/packet_utils.py | 83 ++++- tests/test_companion_base.py | 100 ++++++ tests/test_dispatcher.py | 53 ++++ tests/test_frame_server.py | 165 +++++++++- tests/test_handlers.py | 53 ++++ tests/test_packet.py | 290 ++++++++++++++++++ tests/test_packet_utils.py | 153 +++++++++ 23 files changed, 1181 insertions(+), 92 deletions(-) diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index c2c7a9b..eda2cfb 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -29,12 +29,14 @@ MAX_PACKET_PAYLOAD, MAX_PATH_SIZE, PAYLOAD_TYPE_CONTROL, + PAYLOAD_TYPE_TRACE, REQ_TYPE_GET_STATUS, REQ_TYPE_GET_TELEMETRY_DATA, ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD, TELEM_PERM_BASE, ) +from ..protocol.packet_utils import PathUtils from ..protocol.transport_keys import calc_transport_code, get_auto_key_for from .channel_store import ChannelStore from .constants import ( @@ -425,6 +427,11 @@ def set_other_params( self.prefs.multi_acks = multi_acks self._save_prefs() + def set_path_hash_mode(self, mode: int) -> None: + """Set path hash encoding mode (0=1-byte, 1=2-byte, 2=3-byte hashes).""" + self.prefs.path_hash_mode = mode + self._save_prefs() + def get_self_info(self) -> NodePrefs: """Return a copy of the current node preferences.""" return copy.copy(self.prefs) @@ -564,6 +571,24 @@ def _apply_flood_scope(self, pkt: Packet) -> None: # Switch route type from FLOOD -> TRANSPORT_FLOOD pkt.header = (pkt.header & ~0x03) | ROUTE_TYPE_TRANSPORT_FLOOD + def _apply_path_hash_mode(self, pkt: Packet) -> None: + """Encode the device's path_hash_mode in originated packets. + + When a packet has 0 hops (freshly originated), sets bits 6-7 of + ``path_len`` to encode the hash size from ``prefs.path_hash_mode``. + Packets with existing hops (stored contact paths) are untouched. + Trace packets are excluded because the repeater's trace handler uses + ``path``/``path_len`` to store SNR values, not routing hashes. + + Mirrors firmware ``sendFlood(pkt, delay, _prefs.path_hash_mode + 1)`` + which calls ``pkt->setPathHashSizeAndCount(hash_size, 0)``. + """ + if pkt.get_payload_type() == PAYLOAD_TYPE_TRACE: + return + if pkt.get_path_hash_count() == 0: + hash_size = self.prefs.path_hash_mode + 1 + pkt.path_len = PathUtils.encode_path_len(hash_size, 0) + # ------------------------------------------------------------------------- # Statistics (subclasses may override _get_radio_stats for STATS_TYPE_RADIO) # ------------------------------------------------------------------------- @@ -642,23 +667,34 @@ def should_overwrite_when_full(self) -> bool: return bool(self.prefs.autoadd_config & AUTOADD_OVERWRITE_OLDEST) async def _apply_advert_to_stores( - self, contact: Contact, inbound_path: Optional[bytes] = None + self, + contact: Contact, + inbound_path: Optional[bytes] = None, + *, + path_len_encoded: Optional[int] = None, ) -> Optional[Contact]: """Apply advert to ContactStore and PathCache. Shared by Bridge and NODE_DISCOVERED. Mirrors C++ BaseChatMesh::onAdvertRecv (existing update, auto-add filter, overwrite when full). Returns the Contact if added or updated, None otherwise. Path cache is updated for all valid contacts (pub_key >= 7, name non-empty). + + Args: + path_len_encoded: Encoded path_len byte from the packet. If None, + falls back to len(inbound_path) (assumes 1-byte hashes). """ try: if len(contact.public_key) < 7 or not contact.name: return None inbound_path = inbound_path or b"" + advert_path_len = ( + path_len_encoded if path_len_encoded is not None else len(inbound_path) + ) self.path_cache.update( AdvertPath( public_key_prefix=contact.public_key[:7], name=contact.name, - path_len=len(inbound_path), + path_len=advert_path_len, path=inbound_path, recv_timestamp=int(time.time()), ) @@ -919,6 +955,7 @@ async def advertise(self, flood: bool = True) -> bool: route_type=route, ) self._apply_flood_scope(pkt) + self._apply_path_hash_mode(pkt) success = await self._send_packet(pkt, wait_for_ack=False) if success: self.stats.record_tx(is_flood=flood) @@ -939,6 +976,7 @@ async def share_contact(self, pub_key: bytes) -> bool: route_type="direct", ) self._apply_flood_scope(pkt) + self._apply_path_hash_mode(pkt) return await self._send_packet(pkt, wait_for_ack=False) except Exception as e: logger.error(f"Error sharing contact: {e}") @@ -956,6 +994,7 @@ async def send_trace_path_raw( path_list = list(path_bytes) pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=path_list) self._apply_flood_scope(pkt) + self._apply_path_hash_mode(pkt) return await self._send_packet(pkt, wait_for_ack=False) except Exception as e: logger.error(f"Error sending trace (raw path): {e}") @@ -1003,6 +1042,7 @@ async def send_binary_req( pubkey_prefix=pub_key[:6].hex(), ) self._apply_flood_scope(pkt) + self._apply_path_hash_mode(pkt) success = await self._send_packet(pkt, wait_for_ack=False) except Exception as e: logger.error(f"Binary request send error: {e}") @@ -1054,6 +1094,7 @@ async def send_anon_req( pubkey_prefix=pub_key[:6].hex(), ) self._apply_flood_scope(pkt) + self._apply_path_hash_mode(pkt) success = await self._send_packet(pkt, wait_for_ack=False) except Exception as e: logger.error(f"Anon request send error: {e}") @@ -1104,6 +1145,7 @@ async def send_path_discovery_req(self, pub_key: bytes) -> SentResult: data=req_payload, ) self._apply_flood_scope(pkt) + self._apply_path_hash_mode(pkt) success = await self._send_packet(pkt, wait_for_ack=False) if success: self._pending_discovery_tags.add(tag_int) @@ -1155,6 +1197,7 @@ async def send_text_message( message_type=msg_type, ) self._apply_flood_scope(pkt) + self._apply_path_hash_mode(pkt) self._track_pending_ack(ack_crc) if wait_for_ack: success = await self._send_packet(pkt, wait_for_ack=True) @@ -1199,6 +1242,7 @@ async def send_channel_message(self, channel_idx: int, text: str) -> bool: channels_config=self.channels.get_channels(), ) self._apply_flood_scope(pkt) + self._apply_path_hash_mode(pkt) success = await self._send_packet(pkt, wait_for_ack=False) if success: self.stats.record_tx(is_flood=True) @@ -1230,17 +1274,23 @@ async def send_raw_data( protocol_code=PROTOCOL_CODE_RAW_DATA, data=data, ) + self._apply_path_hash_mode(pkt) success = await self._send_packet(pkt, wait_for_ack=False) return SentResult(success=success) except Exception as e: logger.error(f"Error sending raw data: {e}") return SentResult(success=False) - async def send_raw_data_direct(self, path: bytes, payload: bytes) -> SentResult: + async def send_raw_data_direct( + self, path: bytes, payload: bytes, *, path_len_encoded: int = None + ) -> SentResult: """Send a raw custom packet (PAYLOAD_TYPE_RAW_CUSTOM) on the given direct path. No encryption or contact lookup; path and payload are supplied by the caller. Matches firmware CMD_SEND_RAW_DATA behaviour. + + Args: + path_len_encoded: Encoded path_len byte. If None, assumes 1-byte hashes. """ if len(payload) < 4: return SentResult(success=False) @@ -1250,8 +1300,7 @@ async def send_raw_data_direct(self, path: bytes, payload: bytes) -> SentResult: return SentResult(success=False) try: pkt = PacketBuilder.create_raw_data(payload) - pkt.path = bytearray(path) - pkt.path_len = len(path) + pkt.set_path(path, path_len_encoded) success = await self._send_packet(pkt, wait_for_ack=False) if success: self.stats.record_tx(is_flood=False) @@ -1278,6 +1327,7 @@ async def send_trace_path( path = [contact.public_key[0]] try: pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=path) + self._apply_path_hash_mode(pkt) return await self._send_packet(pkt, wait_for_ack=False) except Exception as e: logger.error(f"Error sending trace: {e}") @@ -1301,6 +1351,7 @@ async def send_control_data(self, data: Any = None) -> bool: pkt.path = bytearray() pkt.payload = bytearray(data) pkt.payload_len = len(data) + self._apply_path_hash_mode(pkt) return await self._send_packet(pkt, wait_for_ack=False) elif data is not None: # data was provided but invalid @@ -1308,6 +1359,7 @@ async def send_control_data(self, data: Any = None) -> bool: # No data: send default discovery request tag = random.randint(0, 0xFFFFFFFF) pkt = PacketBuilder.create_discovery_request(tag, filter_mask=0x04) + self._apply_path_hash_mode(pkt) return await self._send_packet(pkt, wait_for_ack=False) except Exception as e: logger.error(f"Error sending control data: {e}") @@ -1339,6 +1391,7 @@ def _login_cb(success: bool, data: dict) -> None: pkt = PacketBuilder.create_login_packet( contact=proxy, local_identity=self._identity, password=password ) + self._apply_path_hash_mode(pkt) await self._send_packet(pkt, wait_for_ack=False) try: await asyncio.wait_for(login_event.wait(), timeout=10.0) @@ -1371,6 +1424,7 @@ async def send_logout(self, pub_key: bytes) -> bool: pkt, _ = PacketBuilder.create_logout_packet( contact=contact, local_identity=self._identity ) + self._apply_path_hash_mode(pkt) await self._send_packet(pkt, wait_for_ack=False) return True except Exception as e: @@ -1388,10 +1442,11 @@ async def _wait_for_path_propagation(self, proxy: Any, request_type: str) -> Non """ out_path_len = getattr(proxy, "out_path_len", -1) if out_path_len > 0: - propagation_delay = out_path_len * 0.5 # e.g. 3 hops → 1.5s + hop_count = PathUtils.get_path_hash_count(out_path_len) + propagation_delay = hop_count * 0.5 # e.g. 3 hops → 1.5s logger.debug( f"Multi-hop {request_type}: waiting {propagation_delay:.1f}s for " - f"reciprocal PATH propagation ({out_path_len} hops)" + f"reciprocal PATH propagation ({hop_count} hops)" ) await asyncio.sleep(propagation_delay) @@ -1417,6 +1472,7 @@ async def send_status_request(self, pub_key: bytes, timeout: float = 15.0) -> di protocol_code=REQ_TYPE_GET_STATUS, data=b"", ) + self._apply_path_hash_mode(pkt) await self._send_packet(pkt, wait_for_ack=False) result = await waiter.wait(timeout) return { @@ -1464,6 +1520,7 @@ async def send_telemetry_request( protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, data=bytes([inv]), ) + self._apply_path_hash_mode(pkt) await self._send_packet(pkt, wait_for_ack=False) result = await waiter.wait(timeout) telemetry_data = dict(result.get("parsed", {})) @@ -1518,6 +1575,7 @@ async def _send_protocol_request(self, pub_key: bytes, protocol_code: int, data: protocol_code=protocol_code, data=data, ) + self._apply_path_hash_mode(pkt) await self._send_packet(pkt, wait_for_ack=False) result = await waiter.wait(10.0) return { @@ -1566,6 +1624,7 @@ def _response_cb(message_text: str, sender_contact: Any) -> None: attempt=1, message_type=msg_type, ) + self._apply_path_hash_mode(pkt) await self._send_packet(pkt, wait_for_ack=True) try: await asyncio.wait_for(response_event.wait(), timeout=15.0) @@ -1639,7 +1698,10 @@ async def _handle_mesh_event(self, event_type: str, data: dict) -> None: contact = Contact.from_dict(data, now=now) if len(contact.public_key) >= 7 and contact.name: inbound_path = data.get("inbound_path") - applied = await self._apply_advert_to_stores(contact, inbound_path) + path_len_encoded = data.get("path_len_encoded") + applied = await self._apply_advert_to_stores( + contact, inbound_path, path_len_encoded=path_len_encoded + ) if applied is not None: await self._fire_callbacks("advert_received", applied) await self._fire_callbacks("node_discovered", data) diff --git a/src/pymc_core/companion/constants.py b/src/pymc_core/companion/constants.py index f2b95e2..f914f47 100644 --- a/src/pymc_core/companion/constants.py +++ b/src/pymc_core/companion/constants.py @@ -96,7 +96,8 @@ class BinaryReqType(IntEnum): # Protocol version reported in RESP_CODE_DEVICE_INFO; phone uses 9+ to infer # CMD_SEND_ANON_REQ (owner requests, etc.) is supported. -FIRMWARE_VER_CODE = 9 +# 10+ provides support for multi-byte path lengths. +FIRMWARE_VER_CODE = 10 # --------------------------------------------------------------------------- # Commands (app -> radio) @@ -153,6 +154,7 @@ class BinaryReqType(IntEnum): CMD_SEND_ANON_REQ = 57 CMD_SET_AUTOADD_CONFIG = 58 CMD_GET_AUTOADD_CONFIG = 59 +CMD_SET_PATH_HASH_MODE = 61 # --------------------------------------------------------------------------- # Response codes (radio -> app) diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 3ab8519..86309bb 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Optional from ..protocol import CryptoUtils +from ..protocol.packet_utils import PathUtils from .constants import ( ADV_TYPE_CHAT, CMD_ADD_UPDATE_CONTACT, @@ -61,6 +62,7 @@ CMD_SET_DEVICE_TIME, CMD_SET_FLOOD_SCOPE, CMD_SET_OTHER_PARAMS, + CMD_SET_PATH_HASH_MODE, CMD_SET_RADIO_PARAMS, CMD_SET_RADIO_TX_POWER, CMD_SET_TUNING_PARAMS, @@ -888,6 +890,8 @@ async def _handle_cmd(self, payload: bytes) -> None: await self._cmd_set_other_params(data) elif cmd == CMD_SEND_RAW_DATA: await self._cmd_send_raw_data(data) + elif cmd == CMD_SET_PATH_HASH_MODE: + await self._cmd_set_path_hash_mode(data) else: logger.warning( "Companion unsupported cmd 0x%02x (%s) len=%s", @@ -943,7 +947,7 @@ async def _cmd_device_query(self, data: bytes) -> None: # Layout must match MeshCore companion_radio MyMesh.cpp handleCmdFrame() CMD_DEVICE_QEURY: # [0]=RESP_CODE_DEVICE_INFO, [1]=FIRMWARE_VER_CODE, [2]=MAX_CONTACTS/2, # [3]=MAX_GROUP_CHANNELS, [4..7]=ble_pin, [8..19]=build_date(12), [20..59]=manufacturer(40), - # [60..79]=version(20), [80]=client_repeat. + # [60..79]=version(20), [80]=client_repeat, [81]=path_hash_mode (v10+). if len(data) >= 1: self._app_target_ver = data[0] firmware_ver = FIRMWARE_VER_CODE @@ -955,8 +959,10 @@ async def _cmd_device_query(self, data: bytes) -> None: try: prefs = self.bridge.get_self_info() client_repeat = getattr(prefs, "client_repeat", 0) & 0xFF + path_hash_mode = getattr(prefs, "path_hash_mode", 0) & 0xFF except Exception: client_repeat = 0 + path_hash_mode = 0 frame = ( bytes( [ @@ -970,7 +976,7 @@ async def _cmd_device_query(self, data: bytes) -> None: + self._build_date_bytes + self._model_bytes + self._version_bytes - + bytes([client_repeat & 0xFF]) + + bytes([client_repeat & 0xFF, path_hash_mode & 0xFF]) ) version_str = self._version_bytes.split(b"\x00")[0].decode("utf-8", errors="replace") logger.info( @@ -1564,13 +1570,14 @@ async def _cmd_get_advert_path(self, data: bytes) -> None: path_bytes = getattr(found, "path", None) or b"" if not isinstance(path_bytes, bytes): path_bytes = bytes(path_bytes) - path_len = min(len(path_bytes), MAX_PATH_SIZE) + path_len_encoded = getattr(found, "path_len", 0) or 0 + path_byte_len = PathUtils.get_path_byte_len(path_len_encoded) recv_ts = getattr(found, "recv_timestamp", 0) frame = ( bytes([RESP_CODE_ADVERT_PATH]) + struct.pack(" None: self._write_ok() async def _cmd_send_raw_data(self, data: bytes) -> None: - """Handle CMD_SEND_RAW_DATA (25). Format: [path_len][path][payload] (min 4-byte payload).""" + """Handle CMD_SEND_RAW_DATA (25). + Format: [path_len_encoded][path][payload] (min 4-byte payload).""" if len(data) < 6: self._write_err(ERR_CODE_UNSUPPORTED_CMD) return path_len_byte = data[0] - path_len = path_len_byte - 256 if path_len_byte >= 128 else path_len_byte - if path_len < 0 or path_len > MAX_PATH_SIZE or 1 + path_len + 4 > len(data): + if not PathUtils.is_valid_path_len(path_len_byte): self._write_err(ERR_CODE_UNSUPPORTED_CMD) return - path = data[1 : 1 + path_len] - payload = data[1 + path_len :] - result = await self.bridge.send_raw_data_direct(path, payload) + path_byte_len = PathUtils.get_path_byte_len(path_len_byte) + if 1 + path_byte_len + 4 > len(data): + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + path = data[1 : 1 + path_byte_len] + payload = data[1 + path_byte_len :] + result = await self.bridge.send_raw_data_direct( + path, payload, path_len_encoded=path_len_byte + ) if result.success: self._write_ok() else: self._write_err(ERR_CODE_TABLE_FULL) + + async def _cmd_set_path_hash_mode(self, data: bytes) -> None: + """Handle CMD_SET_PATH_HASH_MODE (61). Format: [subtype(0), mode(0-2)]. + + Mirrors MyMesh.cpp:1320-1327. Subtype byte must be 0; mode values + 0, 1, 2 select 1-byte, 2-byte, 3-byte path hashes respectively. + """ + if len(data) < 2 or data[0] != 0: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + mode = data[1] + if mode >= 3: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + self.bridge.set_path_hash_mode(mode) + self._write_ok() diff --git a/src/pymc_core/companion/models.py b/src/pymc_core/companion/models.py index 7e39504..13b54a0 100644 --- a/src/pymc_core/companion/models.py +++ b/src/pymc_core/companion/models.py @@ -127,6 +127,7 @@ class NodePrefs: airtime_factor: float = 0.0 # Reported in CMD_DEVICE_QUERY device info frame (byte 80). client_repeat: int = 0 + path_hash_mode: int = 0 # 0=1-byte, 1=2-byte, 2=3-byte hashes @dataclass diff --git a/src/pymc_core/node/dispatcher.py b/src/pymc_core/node/dispatcher.py index 8bb5dba..5d6474c 100644 --- a/src/pymc_core/node/dispatcher.py +++ b/src/pymc_core/node/dispatcher.py @@ -13,6 +13,7 @@ ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD, ) +from ..protocol.packet_utils import PathUtils from ..protocol.transport_keys import calc_transport_code from ..protocol.utils import PAYLOAD_TYPES, ROUTE_TYPES, format_packet_info @@ -395,6 +396,10 @@ async def _process_received_packet( self._log(f"Blacklisted malformed packet (hash: {packet_hash})") return + # Packets at max hops for their path encoding must not be retransmitted + if PathUtils.is_path_at_max_hops(pkt.path_len): + pkt.mark_do_not_retransmit() + ptype = pkt.header >> PH_TYPE_SHIFT self._log(f"[RX DEBUG] Packet type: {ptype:02X}") diff --git a/src/pymc_core/node/handlers/ack.py b/src/pymc_core/node/handlers/ack.py index c75cdcc..e47070d 100644 --- a/src/pymc_core/node/handlers/ack.py +++ b/src/pymc_core/node/handlers/ack.py @@ -3,6 +3,7 @@ from ...protocol import Packet from ...protocol.constants import PAYLOAD_TYPE_ACK +from ...protocol.packet_utils import PathUtils from .base import BaseHandler @@ -141,14 +142,15 @@ async def _process_bundled_ack_in_path(self, payload: bytes) -> Optional[int]: return None path_length = payload[0] + path_byte_len = PathUtils.get_path_byte_len(path_length) - # Check if we have enough data for: path_length + path + extra_type + extra - min_required = 1 + path_length + 1 + 4 # +4 for ACK CRC + # Check if we have enough data for: path_byte_len + path + extra_type + extra + min_required = 1 + path_byte_len + 1 + 4 # +4 for ACK CRC if len(payload) < min_required: return None # Extract extra section - extra_start = 1 + path_length + extra_start = 1 + path_byte_len extra_type = payload[extra_start] extra_payload = payload[extra_start + 1 :] diff --git a/src/pymc_core/node/handlers/advert.py b/src/pymc_core/node/handlers/advert.py index b130ac8..44b4b50 100644 --- a/src/pymc_core/node/handlers/advert.py +++ b/src/pymc_core/node/handlers/advert.py @@ -10,6 +10,7 @@ SIGNATURE_SIZE, describe_advert_flags, ) +from ...protocol.packet_utils import PathUtils from ...protocol.utils import determine_contact_type_from_flags, get_contact_type_name from ..events import MeshEvents from .base import BaseHandler @@ -139,8 +140,8 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: try: path_len = getattr(packet, "path_len", 0) or 0 path = getattr(packet, "path", bytearray()) or bytearray() - effective_len = path_len if path_len > 0 else len(path) - inbound_path = bytes(path[:effective_len]) if effective_len > 0 else b"" + path_byte_len = PathUtils.get_path_byte_len(path_len) + inbound_path = bytes(path[:path_byte_len]) if path_byte_len > 0 else b"" event_data = { "public_key": pubkey_hex, "name": name, @@ -152,6 +153,7 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: "snr": advert_data["snr"], "rssi": advert_data["rssi"], "inbound_path": inbound_path, + "path_len_encoded": path_len, } self.event_service.publish_sync(MeshEvents.NODE_DISCOVERED, event_data) except Exception as e: diff --git a/src/pymc_core/node/handlers/login_response.py b/src/pymc_core/node/handlers/login_response.py index a02cf9d..f242d39 100644 --- a/src/pymc_core/node/handlers/login_response.py +++ b/src/pymc_core/node/handlers/login_response.py @@ -3,12 +3,8 @@ from typing import Callable, Optional from ...protocol import CryptoUtils, Identity, Packet -from ...protocol.constants import ( - MAX_PATH_SIZE, - PAYLOAD_TYPE_ANON_REQ, - PAYLOAD_TYPE_PATH, - PAYLOAD_TYPE_RESPONSE, -) +from ...protocol.constants import PAYLOAD_TYPE_ANON_REQ, PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE +from ...protocol.packet_utils import PathUtils from .base import BaseHandler # Response codes from C++ server @@ -176,9 +172,10 @@ async def _decrypt_response( pkt_type = (packet.header >> 2) & 0x0F if pkt_type == PAYLOAD_TYPE_PATH and len(plaintext) >= 2: path_len_byte = plaintext[0] - inner_offset = 1 + path_len_byte + 1 # skip path_len + path + extra_type - if path_len_byte <= MAX_PATH_SIZE and len(plaintext) >= inner_offset: - extra_type = plaintext[1 + path_len_byte] & 0x0F + path_byte_len = PathUtils.get_path_byte_len(path_len_byte) + inner_offset = 1 + path_byte_len + 1 # skip path_len + path + extra_type + if PathUtils.is_valid_path_len(path_len_byte) and len(plaintext) >= inner_offset: + extra_type = plaintext[1 + path_byte_len] & 0x0F if extra_type == PAYLOAD_TYPE_RESPONSE and len(plaintext) > inner_offset: plaintext = plaintext[inner_offset:] diff --git a/src/pymc_core/node/handlers/login_server.py b/src/pymc_core/node/handlers/login_server.py index 92ed40a..ff551c0 100644 --- a/src/pymc_core/node/handlers/login_server.py +++ b/src/pymc_core/node/handlers/login_server.py @@ -89,7 +89,7 @@ async def __call__(self, packet: Packet) -> None: """Handle ANON_REQ login packet from client.""" try: # Debug: Log packet routing info - path_data = list(packet.path[: packet.path_len]) if packet.path_len > 0 else [] + path_data = packet.get_path_hashes_hex() if packet.path_len > 0 else [] self.log( f"[LoginServer] Packet route flood: {packet.is_route_flood()}, " f"path_len: {packet.path_len}, path: {path_data}" @@ -253,7 +253,7 @@ async def _send_login_response( client_hash = client_identity.get_public_key()[0] server_hash = self.local_identity.get_public_key()[0] path_list = ( - list(original_packet.path[: original_packet.path_len]) + list(original_packet.path[: original_packet.get_path_byte_len()]) if original_packet and original_packet.path_len > 0 else [] ) diff --git a/src/pymc_core/node/handlers/path.py b/src/pymc_core/node/handlers/path.py index f0b13ee..4348e48 100644 --- a/src/pymc_core/node/handlers/path.py +++ b/src/pymc_core/node/handlers/path.py @@ -77,7 +77,7 @@ async def __call__(self, pkt: Packet) -> None: # Single summary line for PATH packet try: payload = pkt.get_payload() - hop_count = pkt.path_len + hop_count = pkt.get_path_hash_count() if len(payload) >= 2: dest_hash = payload[0] src_hash = payload[1] diff --git a/src/pymc_core/node/handlers/protocol_request.py b/src/pymc_core/node/handlers/protocol_request.py index 5b33f05..a5b42f2 100644 --- a/src/pymc_core/node/handlers/protocol_request.py +++ b/src/pymc_core/node/handlers/protocol_request.py @@ -8,8 +8,9 @@ from typing import Callable, Optional from pymc_core.protocol import PacketBuilder -from pymc_core.protocol.constants import PAYLOAD_TYPE_REQ, PAYLOAD_TYPE_RESPONSE +from pymc_core.protocol.constants import MAX_PATH_SIZE, PAYLOAD_TYPE_REQ, PAYLOAD_TYPE_RESPONSE from pymc_core.protocol.crypto import CryptoUtils +from pymc_core.protocol.packet_utils import PathUtils # Request type codes (matching C++ implementation) REQ_TYPE_GET_STATUS = 0x01 @@ -245,8 +246,12 @@ def _build_response(self, original_packet, client, response_data: bytes, shared_ # Add path for direct routing if available if hasattr(client, "out_path_len") and hasattr(client, "out_path"): if client.out_path_len >= 0 and len(client.out_path) > 0: - reply_packet.path = bytearray(client.out_path[: client.out_path_len]) - reply_packet.path_len = client.out_path_len + reply_packet.set_path( + client.out_path[:MAX_PATH_SIZE], + client.out_path_len + if PathUtils.is_valid_path_len(client.out_path_len) + else None, + ) self.log( f"RESPONSE built for 0x{client_identity.get_public_key()[0]:02X} " diff --git a/src/pymc_core/node/handlers/protocol_response.py b/src/pymc_core/node/handlers/protocol_response.py index b7a44eb..79d09b6 100644 --- a/src/pymc_core/node/handlers/protocol_response.py +++ b/src/pymc_core/node/handlers/protocol_response.py @@ -9,14 +9,10 @@ from typing import Any, Callable, Dict, Optional from ...protocol import CryptoUtils, Identity, Packet -from ...protocol.constants import ( - MAX_PATH_SIZE, - PAYLOAD_TYPE_PATH, - PAYLOAD_TYPE_RESPONSE, - ROUTE_TYPE_DIRECT, -) +from ...protocol.constants import PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE, ROUTE_TYPE_DIRECT from ...protocol.crypto import CIPHER_BLOCK_SIZE, CIPHER_MAC_SIZE from ...protocol.packet_builder import PacketBuilder +from ...protocol.packet_utils import PathUtils # --------------------------------------------------------------------------- # Built-in CayenneLPP decoder (no external dependency) @@ -283,10 +279,14 @@ async def __call__(self, pkt: Packet) -> None: # PATH packet: decrypted is path_len(1)+path(N)+extra_type(1)+extra # Extract inner response from path-return structure path_len_byte = raw_decrypted[0] - inner_offset = 1 + path_len_byte + 1 - if path_len_byte <= MAX_PATH_SIZE and len(raw_decrypted) >= inner_offset + 4: - out_path = bytes(raw_decrypted[1 : 1 + path_len_byte]) - extra_type = raw_decrypted[1 + path_len_byte] & 0x0F + path_byte_len = PathUtils.get_path_byte_len(path_len_byte) + inner_offset = 1 + path_byte_len + 1 + if ( + PathUtils.is_valid_path_len(path_len_byte) + and len(raw_decrypted) >= inner_offset + 4 + ): + out_path = bytes(raw_decrypted[1 : 1 + path_byte_len]) + extra_type = raw_decrypted[1 + path_byte_len] & 0x0F extra = raw_decrypted[inner_offset:] if extra_type == PAYLOAD_TYPE_RESPONSE and len(extra) >= 4: tag_bytes = extra[:4] @@ -359,9 +359,10 @@ def _update_contact_path( Returns True if the contact was found and updated, False otherwise. """ try: - if path_len_byte > MAX_PATH_SIZE: + if not PathUtils.is_valid_path_len(path_len_byte): return False - out_path_bytes = bytes(decrypted[1 : 1 + path_len_byte]) + path_byte_len = PathUtils.get_path_byte_len(path_len_byte) + out_path_bytes = bytes(decrypted[1 : 1 + path_byte_len]) contact_obj = self._contact_book.get_by_key(contact_pubkey) if contact_obj is not None: contact_obj.out_path_len = path_len_byte @@ -427,10 +428,10 @@ async def _send_reciprocal_path( # Convert to DIRECT routing using the inner out_path (the route # from us to the remote repeater). - out_path_bytes = bytes(decrypted[1 : 1 + path_len_byte]) + path_byte_len = PathUtils.get_path_byte_len(path_len_byte) + out_path_bytes = bytes(decrypted[1 : 1 + path_byte_len]) reciprocal.header = (reciprocal.header & ~0x03) | ROUTE_TYPE_DIRECT - reciprocal.path = bytearray(out_path_bytes) - reciprocal.path_len = len(out_path_bytes) + reciprocal.set_path(out_path_bytes, path_len_byte) # Await injection so the reciprocal PATH is serialized through the # radio TX pipeline before this method returns. This ensures the @@ -496,9 +497,13 @@ async def _decrypt_protocol_response( if pkt_type == PAYLOAD_TYPE_PATH: if len(decrypted) >= 2: path_len_byte = decrypted[0] - inner_offset = 1 + path_len_byte + 1 - if path_len_byte <= MAX_PATH_SIZE and len(decrypted) >= inner_offset: - extra_type = decrypted[1 + path_len_byte] & 0x0F + path_byte_len = PathUtils.get_path_byte_len(path_len_byte) + inner_offset = 1 + path_byte_len + 1 + if ( + PathUtils.is_valid_path_len(path_len_byte) + and len(decrypted) >= inner_offset + ): + extra_type = decrypted[1 + path_byte_len] & 0x0F if extra_type == PAYLOAD_TYPE_RESPONSE and len(decrypted) > inner_offset: response_data = decrypted[inner_offset:] elif extra_type != PAYLOAD_TYPE_RESPONSE: @@ -509,7 +514,7 @@ async def _decrypt_protocol_response( # Firmware pattern (onContactPathRecv): update contact out_path # so subsequent requests use sendDirect() instead of sendFlood(). - out_path_bytes = bytes(decrypted[1 : 1 + path_len_byte]) + out_path_bytes = bytes(decrypted[1 : 1 + path_byte_len]) if self._update_contact_path( contact_pubkey, src_hash, path_len_byte, decrypted ): diff --git a/src/pymc_core/protocol/__init__.py b/src/pymc_core/protocol/__init__.py index d85c710..816d4fc 100644 --- a/src/pymc_core/protocol/__init__.py +++ b/src/pymc_core/protocol/__init__.py @@ -72,6 +72,7 @@ PacketHeaderUtils, PacketTimingUtils, PacketValidationUtils, + PathUtils, RouteTypeUtils, ) from .transport_keys import calc_transport_code, get_auto_key_for @@ -98,6 +99,7 @@ "PacketHeaderUtils", "PacketHashingUtils", "RouteTypeUtils", + "PathUtils", "PacketTimingUtils", # Header constants "PH_ROUTE_MASK", diff --git a/src/pymc_core/protocol/constants.py b/src/pymc_core/protocol/constants.py index 443d94b..a3421ed 100644 --- a/src/pymc_core/protocol/constants.py +++ b/src/pymc_core/protocol/constants.py @@ -48,7 +48,9 @@ MAX_ADVERT_DATA_SIZE = 96 PUB_KEY_SIZE = 32 SIGNATURE_SIZE = 64 -PATH_HASH_SIZE = 1 +PATH_HASH_SIZE = 1 # Legacy default; see PathUtils for multi-byte path support +PATH_HASH_COUNT_MASK = 0x3F # bits 0-5 of encoded path_len (max encodable hop count) +PATH_HASH_SIZE_SHIFT = 6 # bits 6-7 of encoded path_len CIPHER_MAC_SIZE = 32 # SHA‑256 HMAC CIPHER_BLOCK_SIZE = 16 MAX_PACKET_PAYLOAD = 256 # firmware's default diff --git a/src/pymc_core/protocol/packet.py b/src/pymc_core/protocol/packet.py index 1fdd620..1b205f3 100644 --- a/src/pymc_core/protocol/packet.py +++ b/src/pymc_core/protocol/packet.py @@ -1,7 +1,6 @@ from typing import ByteString, Optional from .constants import ( - MAX_PATH_SIZE, MAX_SUPPORTED_PAYLOAD_VERSION, PH_ROUTE_MASK, PH_TYPE_MASK, @@ -16,7 +15,7 @@ SIGNATURE_SIZE, TIMESTAMP_SIZE, ) -from .packet_utils import PacketDataUtils, PacketHashingUtils, PacketValidationUtils +from .packet_utils import PacketDataUtils, PacketHashingUtils, PacketValidationUtils, PathUtils """ ╔═══════════════════════════════════════════════════════════════════════════╗ @@ -30,9 +29,10 @@ ║ Transport Codes ║ Two 16-bit codes (4 bytes total). Only present for ║ ║ (0 or 4 bytes) ║ TRANSPORT_FLOOD and TRANSPORT_DIRECT route types. ║ ╠════════════════════╬══════════════════════════════════════════════════════╣ -║ Path Length (1 B) ║ Number of path hops (0–15). ║ +║ Path Length (1 B) ║ Encoded: bits 0-5 = hash count (hops), bits 6-7 = ║ +║ ║ (hash_size - 1). Actual bytes = count × hash_size. ║ ╠════════════════════╬══════════════════════════════════════════════════════╣ -║ Path (N bytes) ║ List of node hashes (1 byte each), length = path_len ║ +║ Path (N bytes) ║ Node hashes (1-3 bytes each), N = count × hash_size ║ ╠════════════════════╬══════════════════════════════════════════════════════╣ ║ Payload (N bytes) ║ Actual encrypted or plain payload. Max: 254 bytes ║ ╠════════════════════╬══════════════════════════════════════════════════════╣ @@ -71,8 +71,8 @@ class Packet: Attributes: header (int): Single byte header containing packet type and flags. transport_codes (list): Two 16-bit transport codes for TRANSPORT route types. - path_len (int): Length of the path component in bytes. - path (bytearray): Variable-length path data for routing. + path_len (int): Encoded path length byte (bits 0-5 = hash count, bits 6-7 = hash size - 1). + path (bytearray): Variable-length path data for routing (hash_count × hash_size bytes). payload (bytearray): Variable-length payload data. payload_len (int): Actual length of payload data. _rssi (int): Raw RSSI signal strength value from firmware. @@ -82,8 +82,7 @@ class Packet: ```python packet = Packet() packet.header = 0x01 # Flood routing - packet.path = b"node1->node2" - packet.path_len = len(packet.path) + packet.set_path(b"\\xAA\\xBB\\xCC") # 3 hops, 1-byte hashes packet.payload = b"Hello World" packet.payload_len = len(packet.payload) data = packet.write_to() @@ -207,6 +206,73 @@ def is_route_direct(self) -> bool: route_type = self.get_route_type() return route_type == ROUTE_TYPE_TRANSPORT_DIRECT or route_type == ROUTE_TYPE_DIRECT + def get_path_hash_size(self) -> int: + """Extract per-hop hash size (1, 2, or 3) from the encoded path_len byte.""" + return PathUtils.get_path_hash_size(self.path_len) + + def get_path_hash_count(self) -> int: + """Extract hop count (0-63) from the encoded path_len byte.""" + return PathUtils.get_path_hash_count(self.path_len) + + def get_path_byte_len(self) -> int: + """Calculate actual path byte length from the encoded path_len byte.""" + return PathUtils.get_path_byte_len(self.path_len) + + def get_path_hashes(self) -> list: + """Return path as a list of per-hop hash entries (1, 2, or 3 bytes each). + + Groups the raw ``self.path`` bytearray using the hash size encoded in + ``self.path_len``. Each entry in the returned list is a ``bytes`` + object whose length equals ``get_path_hash_size()``. + + Returns: + list[bytes]: One entry per hop. Empty list when hop count is 0. + """ + hash_size = self.get_path_hash_size() + count = self.get_path_hash_count() + result = [] + for i in range(count): + start = i * hash_size + end = start + hash_size + if end <= len(self.path): + result.append(bytes(self.path[start:end])) + return result + + def get_path_hashes_hex(self) -> list: + """Return path as a list of uppercase hex strings, one per hop. + + Examples:: + + 1-byte hashes: ["B5", "A3", "F2"] + 2-byte hashes: ["B5A3", "F2C1"] + 3-byte hashes: ["B5A3F2", "C1D4E7"] + """ + return [entry.hex().upper() for entry in self.get_path_hashes()] + + def set_path( + self, + path_bytes: bytes, + path_len_encoded: int = None, + ) -> None: + """Set the routing path with optional encoded path_len. + + Args: + path_bytes: Raw path bytes to set. + path_len_encoded: Pre-encoded path_len byte. If None, assumes + 1-byte hashes and encodes len(path_bytes) as the hop count. + """ + self.path = bytearray(path_bytes) + if path_len_encoded is not None: + self.path_len = path_len_encoded + else: + hop_count = len(path_bytes) + if hop_count > 63: + raise ValueError( + f"path length {hop_count} exceeds maximum encodable hop count 63 " + "for 1-byte hashes; pass path_len_encoded explicitly or use a shorter path" + ) + self.path_len = PathUtils.encode_path_len(1, hop_count) + def get_payload(self) -> bytes: """ Get the packet payload as immutable bytes, truncated to declared length. @@ -250,7 +316,7 @@ def _validate_lengths(self) -> None: ValueError: If any declared length doesn't match the actual buffer length. """ PacketValidationUtils.validate_buffer_lengths( - self.path_len, len(self.path), self.payload_len, len(self.payload) + self.get_path_byte_len(), len(self.path), self.payload_len, len(self.payload) ) def _check_bounds(self, idx: int, required: int, data_len: int, error_msg: str) -> None: @@ -295,7 +361,7 @@ def write_to(self) -> bytes: out.extend(self.transport_codes[1].to_bytes(2, "little")) out.append(self.path_len) - out += self.path + out += self.path[: self.get_path_byte_len()] out += self.payload[: self.payload_len] return bytes(out) @@ -338,12 +404,13 @@ def read_from(self, data: ByteString) -> bool: self._check_bounds(idx, 1, data_len, "missing path_len") self.path_len = data[idx] idx += 1 - if self.path_len > MAX_PATH_SIZE: - raise ValueError("path_len too large") + if not PathUtils.is_valid_path_len(self.path_len): + raise ValueError(f"invalid path_len encoding: 0x{self.path_len:02X}") - self._check_bounds(idx, self.path_len, data_len, "truncated path") - self.path = bytearray(data[idx : idx + self.path_len]) - idx += self.path_len + path_byte_len = self.get_path_byte_len() + self._check_bounds(idx, path_byte_len, data_len, "truncated path") + self.path = bytearray(data[idx : idx + path_byte_len]) + idx += path_byte_len self.payload = bytearray(data[idx:]) self.payload_len = len(self.payload) @@ -425,7 +492,9 @@ def get_raw_length(self) -> int: Note: This matches the wire format used by write_to() and expected by read_from(). """ - base_length = 2 + self.path_len + self.payload_len # header + path_len + path + payload + base_length = ( + 2 + self.get_path_byte_len() + self.payload_len + ) # header + path_len_byte + path + payload return base_length + (4 if self.has_transport_codes() else 0) def get_snr(self) -> float: diff --git a/src/pymc_core/protocol/packet_builder.py b/src/pymc_core/protocol/packet_builder.py index 0f0c3e5..7d2d8f2 100644 --- a/src/pymc_core/protocol/packet_builder.py +++ b/src/pymc_core/protocol/packet_builder.py @@ -35,7 +35,13 @@ TELEM_PERM_LOCATION, ) from .identity import Identity, LocalIdentity -from .packet_utils import PacketDataUtils, PacketHeaderUtils, PacketValidationUtils, RouteTypeUtils +from .packet_utils import ( + PacketDataUtils, + PacketHeaderUtils, + PacketValidationUtils, + PathUtils, + RouteTypeUtils, +) logger = logging.getLogger(__name__) @@ -483,9 +489,10 @@ def create_login_packet(contact: Any, local_identity: LocalIdentity, password: s if route_type == "direct" and out_path_len > 0: out_path = getattr(contact, "out_path", b"") if out_path: - path_bytes = out_path[:MAX_PATH_SIZE] - pkt.path = bytearray(path_bytes) - pkt.path_len = len(pkt.path) + pkt.set_path( + out_path[:MAX_PATH_SIZE], + out_path_len if PathUtils.is_valid_path_len(out_path_len) else None, + ) return pkt @@ -773,8 +780,23 @@ def create_text_message( f"Path length {len(routing_path)} exceeds maximum {MAX_PATH_SIZE}, truncating" ) routing_path = routing_path[:MAX_PATH_SIZE] - pkt.path = bytearray(routing_path) - pkt.path_len = len(pkt.path) + # Preserve encoded path_len from contact when using its stored path + contact_path_len = getattr(contact, "out_path_len", -1) if contact else -1 + if ( + out_path is None + and contact_path_len >= 0 + and PathUtils.is_valid_path_len(contact_path_len) + ): + pkt.set_path(bytearray(routing_path), contact_path_len) + else: + # path_len encodes hop count in 6 bits (0-63); 64 would encode as 0 + if len(routing_path) == 64: + logger.warning( + "Path length 64 exceeds encodable hop count 63 (1-byte hashes), " + "truncating to 63 bytes" + ) + routing_path = routing_path[:63] + pkt.set_path(bytearray(routing_path)) else: pkt.path_len, pkt.path = 0, bytearray() @@ -853,9 +875,10 @@ def create_protocol_request( packet = PacketBuilder._create_packet(header, payload) if route_type == "direct" and len(out_path) > 0: - path_bytes = out_path[:MAX_PATH_SIZE] - packet.path = bytearray(path_bytes) - packet.path_len = len(packet.path) + packet.set_path( + out_path[:MAX_PATH_SIZE], + out_path_len if PathUtils.is_valid_path_len(out_path_len) else None, + ) return packet, timestamp diff --git a/src/pymc_core/protocol/packet_utils.py b/src/pymc_core/protocol/packet_utils.py index 7164192..fc0fe25 100644 --- a/src/pymc_core/protocol/packet_utils.py +++ b/src/pymc_core/protocol/packet_utils.py @@ -11,6 +11,8 @@ MAX_HASH_SIZE, MAX_PACKET_PAYLOAD, MAX_PATH_SIZE, + PATH_HASH_COUNT_MASK, + PATH_HASH_SIZE_SHIFT, PAYLOAD_TYPE_TRACE, PAYLOAD_VER_1, PH_ROUTE_MASK, @@ -101,6 +103,77 @@ def validate_payload_size(payload_len: int) -> None: raise ValueError(f"payload too large: {payload_len} > {MAX_PACKET_PAYLOAD}") +class PathUtils: + """Multi-byte path encoding/decoding matching firmware Packet.h. + + The encoded path_len byte packs both hash size and hop count: + - Bits 0-5: hash_count (number of hops, 0-63) + - Bits 6-7: (hash_size - 1), where hash_size is 1, 2, or 3 bytes + + Total path bytes on the wire = hash_count * hash_size. + For 1-byte hashes the encoded byte equals the raw hop count, + preserving backward compatibility with legacy packets. + """ + + @staticmethod + def get_path_hash_size(path_len_byte: int) -> int: + """Extract per-hop hash size (1, 2, or 3) from the encoded path_len byte.""" + return (path_len_byte >> PATH_HASH_SIZE_SHIFT) + 1 + + @staticmethod + def get_path_hash_count(path_len_byte: int) -> int: + """Extract hop count (0-63) from the encoded path_len byte.""" + return path_len_byte & PATH_HASH_COUNT_MASK + + @staticmethod + def get_path_byte_len(path_len_byte: int) -> int: + """Calculate actual path byte length from the encoded path_len byte.""" + return (path_len_byte & PATH_HASH_COUNT_MASK) * ( + (path_len_byte >> PATH_HASH_SIZE_SHIFT) + 1 + ) + + @staticmethod + def encode_path_len(hash_size: int, hash_count: int) -> int: + """Encode hash size and hop count into a single path_len byte. + + Hop count is stored in 6 bits (0-63). Values above 63 are invalid and raise. + """ + if not 0 <= hash_count <= 63: + raise ValueError(f"hop count must be 0-63 for path_len encoding, got {hash_count}") + return ((hash_size - 1) << PATH_HASH_SIZE_SHIFT) | (hash_count & PATH_HASH_COUNT_MASK) + + @staticmethod + def is_valid_path_len(path_len_byte: int) -> bool: + """Validate an encoded path_len byte. + + Returns False for hash_size == 4 (reserved) or if the total + path bytes would exceed MAX_PATH_SIZE. + """ + hash_size = (path_len_byte >> PATH_HASH_SIZE_SHIFT) + 1 + if hash_size > 3: + return False + hash_count = path_len_byte & PATH_HASH_COUNT_MASK + return hash_count * hash_size <= MAX_PATH_SIZE + + @staticmethod + def is_path_at_max_hops(path_len_byte: int) -> bool: + """True if path has reached maximum hops for its hash size (do not retransmit). + + Measures hops, not raw bytes. Max hops depend on hash size and MAX_PATH_SIZE: + - 1-byte hashes: 63 hops (63 bytes) + - 2-byte hashes: 32 hops (64 bytes) + - 3-byte hashes: 21 hops (63 bytes) + """ + if path_len_byte == 0: + return False + hash_size = PathUtils.get_path_hash_size(path_len_byte) + if hash_size > 3: + return False + hash_count = path_len_byte & PATH_HASH_COUNT_MASK + max_hops = min(PATH_HASH_COUNT_MASK, MAX_PATH_SIZE // hash_size) + return hash_count >= max_hops + + class PacketDataUtils: """Centralized data packing and unpacking utilities.""" @@ -339,15 +412,17 @@ def calc_flood_timeout_ms(packet_airtime_ms: float) -> float: return SEND_TIMEOUT_BASE_MILLIS + (FLOOD_SEND_TIMEOUT_FACTOR * packet_airtime_ms) @staticmethod - def calc_direct_timeout_ms(packet_airtime_ms: float, path_len: int) -> float: + def calc_direct_timeout_ms(packet_airtime_ms: float, path_hash_count: int) -> float: """ Calculate timeout for direct packets. - Formula: 500ms + ((airtime × 6 + 250ms) × (path_len + 1)) + Formula: 500ms + ((airtime × 6 + 250ms) × (path_hash_count + 1)) Args: packet_airtime_ms: Estimated packet airtime in milliseconds - path_len: Number of hops in the path (0 for direct) + path_hash_count: Number of hops in the path (0 for direct). + Use ``PathUtils.get_path_hash_count(path_len)`` to extract + from an encoded path_len byte. Returns: Timeout in milliseconds @@ -357,5 +432,5 @@ def calc_direct_timeout_ms(packet_airtime_ms: float, path_len: int) -> float: DIRECT_SEND_PERHOP_EXTRA_MILLIS = 250 return SEND_TIMEOUT_BASE_MILLIS + ( (packet_airtime_ms * DIRECT_SEND_PERHOP_FACTOR + DIRECT_SEND_PERHOP_EXTRA_MILLIS) - * (path_len + 1) + * (path_hash_count + 1) ) diff --git a/tests/test_companion_base.py b/tests/test_companion_base.py index 2346317..050b44d 100644 --- a/tests/test_companion_base.py +++ b/tests/test_companion_base.py @@ -2,6 +2,7 @@ import pytest +from pymc_core.companion import CompanionBridge from pymc_core.companion.companion_base import ResponseWaiter, adv_type_to_flags from pymc_core.companion.constants import ( ADV_TYPE_CHAT, @@ -9,11 +10,14 @@ ADV_TYPE_ROOM, ADV_TYPE_SENSOR, ) +from pymc_core.protocol import LocalIdentity, Packet from pymc_core.protocol.constants import ( ADVERT_FLAG_IS_CHAT_NODE, ADVERT_FLAG_IS_REPEATER, ADVERT_FLAG_IS_ROOM_SERVER, ADVERT_FLAG_IS_SENSOR, + PAYLOAD_TYPE_TRACE, + ROUTE_TYPE_DIRECT, ) # --------------------------------------------------------------------------- @@ -75,3 +79,99 @@ def test_sensor(self): def test_unknown_defaults_to_chat(self): assert adv_type_to_flags(99) == ADVERT_FLAG_IS_CHAT_NODE assert adv_type_to_flags(0) == ADVERT_FLAG_IS_CHAT_NODE + + +# --------------------------------------------------------------------------- +# _apply_path_hash_mode +# --------------------------------------------------------------------------- + + +def _make_bridge(path_hash_mode: int = 0) -> CompanionBridge: + """Create a minimal CompanionBridge for testing _apply_path_hash_mode.""" + + async def _noop_injector(pkt, wait_for_ack=False): + return True + + bridge = CompanionBridge(LocalIdentity(), _noop_injector, node_name="Test") + bridge.prefs.path_hash_mode = path_hash_mode + return bridge + + +class TestApplyPathHashMode: + def test_encodes_on_zero_hops(self): + """path_hash_mode=1 on a fresh packet (0 hops) → path_len=0x40.""" + bridge = _make_bridge(path_hash_mode=1) + pkt = Packet() + pkt.header = 0x06 + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(b"test") + pkt.payload_len = 4 + + bridge._apply_path_hash_mode(pkt) + + assert pkt.path_len == 0x40 # (1 << 6) | 0 = 0x40 + assert pkt.get_path_hash_size() == 2 + assert pkt.get_path_hash_count() == 0 + + def test_skips_nonzero_hops(self): + """Packets with existing hops (stored contact path) are untouched.""" + bridge = _make_bridge(path_hash_mode=2) + pkt = Packet() + pkt.header = 0x06 + # 3 hops with 1-byte hashes + pkt.set_path(b"\xAA\xBB\xCC") + pkt.payload = bytearray(b"test") + pkt.payload_len = 4 + + original_path_len = pkt.path_len + bridge._apply_path_hash_mode(pkt) + + # path_len unchanged — the contact path is preserved + assert pkt.path_len == original_path_len + assert pkt.get_path_hash_count() == 3 + + def test_all_modes(self): + """Verify mode 0→0x00, mode 1→0x40, mode 2→0x80 on fresh packets.""" + expected = { + 0: (0x00, 1), # (path_len, hash_size) + 1: (0x40, 2), + 2: (0x80, 3), + } + for mode, (expected_path_len, expected_hash_size) in expected.items(): + bridge = _make_bridge(path_hash_mode=mode) + pkt = Packet() + pkt.header = 0x06 + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(b"x") + pkt.payload_len = 1 + + bridge._apply_path_hash_mode(pkt) + + assert pkt.path_len == expected_path_len, ( + f"mode={mode}: expected path_len=0x{expected_path_len:02X}, " + f"got 0x{pkt.path_len:02X}" + ) + assert pkt.get_path_hash_size() == expected_hash_size, ( + f"mode={mode}: expected hash_size={expected_hash_size}, " + f"got {pkt.get_path_hash_size()}" + ) + + def test_skips_trace_packets(self): + """Trace packets use path for SNR values, not routing hashes.""" + bridge = _make_bridge(path_hash_mode=1) + pkt = Packet() + # Trace packet: payload_type=PAYLOAD_TYPE_TRACE, route_type=ROUTE_TYPE_DIRECT + pkt.header = (PAYLOAD_TYPE_TRACE << 2) | ROUTE_TYPE_DIRECT + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(b"trace_data") + pkt.payload_len = 10 + + bridge._apply_path_hash_mode(pkt) + + # path_len must stay 0 — NOT 0x40 + assert pkt.path_len == 0 + assert pkt.get_path_hash_size() == 1 + assert pkt.get_path_hash_count() == 0 diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index b4e1e59..5b1f9bd 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -7,6 +7,7 @@ from pymc_core.protocol import Packet from pymc_core.protocol.constants import PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, PAYLOAD_TYPE_TXT_MSG from pymc_core.protocol.packet_filter import PacketFilter +from pymc_core.protocol.packet_utils import PathUtils def create_test_packet(payload_type: int, payload: bytes) -> bytes: @@ -447,6 +448,58 @@ def test_callback(packet, data, analysis): assert received_packet is not None assert received_data == packet_data + @pytest.mark.asyncio + async def test_full_hop_count_packet_marked_do_not_retransmit(self, dispatcher): + """Packet with path at max hops for its encoding is marked do not retransmit.""" + received_packet = None + + def capture(packet, data, analysis): + nonlocal received_packet + received_packet = packet + + dispatcher.set_raw_packet_callback(capture) + + # 1-byte hashes: max 63 hops + pkt = Packet() + pkt.header = (1 << 6) | (PAYLOAD_TYPE_TXT_MSG << 2) + pkt.path_len = PathUtils.encode_path_len(1, 63) + pkt.path = bytearray(bytes(63)) + pkt.payload = bytearray(b"x") + pkt.payload_len = 1 + packet_data = pkt.write_to() + + await dispatcher._process_received_packet(packet_data) + + assert received_packet is not None + assert received_packet.is_marked_do_not_retransmit() is True + assert received_packet.get_path_hash_count() == 63 + + @pytest.mark.asyncio + async def test_2byte_path_at_max_hops_marked_do_not_retransmit(self, dispatcher): + """Packet with 2-byte path at 32 hops (64 bytes) is marked do not retransmit.""" + received_packet = None + + def capture(packet, data, analysis): + nonlocal received_packet + received_packet = packet + + dispatcher.set_raw_packet_callback(capture) + + pkt = Packet() + pkt.header = (1 << 6) | (PAYLOAD_TYPE_TXT_MSG << 2) + pkt.path_len = PathUtils.encode_path_len(2, 32) + pkt.path = bytearray(64) # 32 * 2 + pkt.payload = bytearray(b"x") + pkt.payload_len = 1 + packet_data = pkt.write_to() + + await dispatcher._process_received_packet(packet_data) + + assert received_packet is not None + assert received_packet.is_marked_do_not_retransmit() is True + assert received_packet.get_path_hash_count() == 32 + assert received_packet.get_path_byte_len() == 64 + @pytest.mark.asyncio async def test_async_callback(self, dispatcher): """Test async callback.""" diff --git a/tests/test_frame_server.py b/tests/test_frame_server.py index d39d5a8..d2284fd 100644 --- a/tests/test_frame_server.py +++ b/tests/test_frame_server.py @@ -112,8 +112,10 @@ def __init__(self, success: bool = True): self.calls = [] self._success = success - async def send_raw_data_direct(self, path: bytes, payload: bytes): - self.calls.append((path, payload)) + async def send_raw_data_direct( + self, path: bytes, payload: bytes, *, path_len_encoded: int = None + ): + self.calls.append((path, payload, path_len_encoded)) return SentResult(success=self._success) @@ -126,7 +128,11 @@ async def test_cmd_send_raw_data_valid_writes_ok(): server._write_err = Mock() data = bytes([1, 0x42]) + b"\x01\x02\x03\x04" await server._cmd_send_raw_data(data) - assert bridge.calls == [(b"\x42", b"\x01\x02\x03\x04")] + assert len(bridge.calls) == 1 + path, payload, path_len_enc = bridge.calls[0] + assert path == b"\x42" + assert payload == b"\x01\x02\x03\x04" + assert path_len_enc == 1 # 1-byte hash, 1 hop server._write_ok.assert_called_once() server._write_err.assert_not_called() @@ -158,6 +164,61 @@ async def test_cmd_send_raw_data_send_failure_writes_table_full(): server._write_ok.assert_not_called() +@pytest.mark.asyncio +async def test_cmd_send_raw_data_2byte_hashes(): + """CMD_SEND_RAW_DATA with 2-byte hash path encoding.""" + from pymc_core.protocol.packet_utils import PathUtils + + bridge = _MockBridgeSendRawDirect(success=True) + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + # path_len_encoded=0x42 → 2-byte hashes, 2 hops → 4 bytes of path + path_len_byte = PathUtils.encode_path_len(2, 2) # 0x42 + path_data = b"\x01\x02\x03\x04" + payload_data = b"\xAA\xBB\xCC\xDD" + data = bytes([path_len_byte]) + path_data + payload_data + await server._cmd_send_raw_data(data) + assert len(bridge.calls) == 1 + path, payload, path_len_enc = bridge.calls[0] + assert path == path_data + assert payload == payload_data + assert path_len_enc == path_len_byte + server._write_ok.assert_called_once() + + +@pytest.mark.asyncio +async def test_cmd_send_raw_data_invalid_path_encoding(): + """CMD_SEND_RAW_DATA with reserved hash_size=4 encoding → error.""" + bridge = _MockBridgeSendRawDirect() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + # 0xC1 = hash_size 4 (reserved), should fail validation + data = bytes([0xC1]) + b"\x00" * 10 + await server._cmd_send_raw_data(data) + assert len(bridge.calls) == 0 + server._write_err.assert_called_once_with(ERR_CODE_UNSUPPORTED_CMD) + + +@pytest.mark.asyncio +async def test_cmd_send_raw_data_truncated_multibyte_path(): + """CMD_SEND_RAW_DATA with not enough path bytes for 2-byte encoding → error.""" + from pymc_core.protocol.packet_utils import PathUtils + + bridge = _MockBridgeSendRawDirect() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + # 0x43 = 2-byte hashes, 3 hops → needs 6 path bytes + 4 payload = 11 total + # But only provide 8 bytes after path_len (not enough) + path_len_byte = PathUtils.encode_path_len(2, 3) # 0x43 + data = bytes([path_len_byte]) + b"\x00" * 8 # only 8 bytes, need 6+4=10 + await server._cmd_send_raw_data(data) + assert len(bridge.calls) == 0 + server._write_err.assert_called_once_with(ERR_CODE_UNSUPPORTED_CMD) + + @pytest.mark.asyncio async def test_push_trace_data_enqueues_frame(): """push_trace_data enqueues a correctly formatted trace frame.""" @@ -247,3 +308,101 @@ async def _send_sentinel(): writer.write.assert_called_once() writer.drain.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# CMD_SET_PATH_HASH_MODE tests +# --------------------------------------------------------------------------- + + +class _MockBridgePathHashMode: + """Minimal bridge for CMD_SET_PATH_HASH_MODE tests.""" + + def __init__(self): + self.calls = [] + + def set_path_hash_mode(self, mode: int) -> None: + self.calls.append(mode) + + +@pytest.mark.asyncio +async def test_cmd_set_path_hash_mode_valid(): + """Valid CMD_SET_PATH_HASH_MODE for each mode (0, 1, 2) → _write_ok.""" + for mode in (0, 1, 2): + bridge = _MockBridgePathHashMode() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + await server._cmd_set_path_hash_mode(bytes([0, mode])) + assert bridge.calls == [mode] + server._write_ok.assert_called_once() + server._write_err.assert_not_called() + + +@pytest.mark.asyncio +async def test_cmd_set_path_hash_mode_invalid_mode(): + """CMD_SET_PATH_HASH_MODE with mode >= 3 → ERR_CODE_ILLEGAL_ARG.""" + from pymc_core.companion.constants import ERR_CODE_ILLEGAL_ARG + + bridge = _MockBridgePathHashMode() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + await server._cmd_set_path_hash_mode(bytes([0, 3])) + assert len(bridge.calls) == 0 + server._write_err.assert_called_once_with(ERR_CODE_ILLEGAL_ARG) + + +@pytest.mark.asyncio +async def test_cmd_set_path_hash_mode_wrong_subtype(): + """CMD_SET_PATH_HASH_MODE with subtype != 0 → ERR_CODE_ILLEGAL_ARG.""" + from pymc_core.companion.constants import ERR_CODE_ILLEGAL_ARG + + bridge = _MockBridgePathHashMode() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + await server._cmd_set_path_hash_mode(bytes([1, 0])) + assert len(bridge.calls) == 0 + server._write_err.assert_called_once_with(ERR_CODE_ILLEGAL_ARG) + + +@pytest.mark.asyncio +async def test_cmd_set_path_hash_mode_too_short(): + """CMD_SET_PATH_HASH_MODE with only 1 byte → ERR_CODE_ILLEGAL_ARG.""" + from pymc_core.companion.constants import ERR_CODE_ILLEGAL_ARG + + bridge = _MockBridgePathHashMode() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + await server._cmd_set_path_hash_mode(bytes([0])) + assert len(bridge.calls) == 0 + server._write_err.assert_called_once_with(ERR_CODE_ILLEGAL_ARG) + + +@pytest.mark.asyncio +async def test_device_info_includes_path_hash_mode(): + """RESP_CODE_DEVICE_INFO frame includes path_hash_mode at byte [81].""" + from pymc_core.companion.constants import RESP_CODE_DEVICE_INFO + from pymc_core.companion.models import NodePrefs + + prefs = NodePrefs() + prefs.path_hash_mode = 2 # 3-byte hashes + + bridge = Mock() + bridge.get_self_info = Mock(return_value=prefs) + bridge.contacts = Mock(max_contacts=100) + bridge.channels = Mock(max_channels=8) + + server = CompanionFrameServer(bridge, "hash", port=0) + frames = [] + server._write_frame = lambda f: frames.append(f) + + await server._cmd_device_query(bytes([10])) # app_ver = 10 + + assert len(frames) == 1 + frame = frames[0] + assert frame[0] == RESP_CODE_DEVICE_INFO + assert len(frame) == 82 # 81 bytes (old) + 1 byte path_hash_mode + assert frame[81] == 2 # path_hash_mode at last byte diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 77941b6..2f82414 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -444,6 +444,59 @@ async def on_path_updated(pub: bytes, path_len: int, path_bytes_arg: bytes) -> N assert callback_calls[0][1] == path_len_byte assert callback_calls[0][2] == path_bytes + @pytest.mark.asyncio + async def test_contact_path_updated_with_2byte_hashes(self): + """PATH with 2-byte hashes decrypts and updates contact path correctly.""" + from pymc_core.companion.contact_store import ContactStore + from pymc_core.companion.models import Contact + from pymc_core.protocol.packet_utils import PathUtils + + local_identity = LocalIdentity() + peer_identity = LocalIdentity() + peer_pubkey = peer_identity.get_public_key() + contacts = ContactStore(5) + contacts.add(Contact(public_key=peer_pubkey, name="Peer")) + log_fn = MagicMock() + handler = ProtocolResponseHandler(log_fn, local_identity, contacts) + handler.set_binary_response_callback(lambda *a, **k: None) + + # 2 hops × 2-byte hashes = 4 bytes of path data + path_len_byte = PathUtils.encode_path_len(2, 2) # 0x42 + path_bytes = bytes([0x01, 0x02, 0x03, 0x04]) + extra_type = PAYLOAD_TYPE_RESPONSE + extra = bytes([0, 0, 0, 0, 0x00]) + plaintext = bytes([path_len_byte]) + path_bytes + bytes([extra_type]) + extra + + peer_id = Identity(peer_pubkey) + shared_secret = peer_id.calc_shared_secret(local_identity.get_private_key()) + aes_key = shared_secret[:16] + encrypted = CryptoUtils.encrypt_then_mac(aes_key, shared_secret, plaintext) + + our_hash = local_identity.get_public_key()[0] + src_hash = peer_pubkey[0] + payload = bytes([our_hash, src_hash]) + encrypted + + pkt = Packet() + pkt.header = (0 << 0) | (PAYLOAD_TYPE_PATH << 2) + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(payload) + pkt.payload_len = len(payload) + + callback_calls = [] + + async def on_path_updated(pub: bytes, path_len: int, path_bytes_arg: bytes) -> None: + callback_calls.append((pub, path_len, path_bytes_arg)) + + handler.set_contact_path_updated_callback(on_path_updated) + + await handler(pkt) + + assert len(callback_calls) == 1 + assert callback_calls[0][0] == peer_pubkey + assert callback_calls[0][1] == path_len_byte # encoded byte, not raw count + assert callback_calls[0][2] == path_bytes # all 4 bytes of path data + class TestTraceHandler: def setup_method(self): diff --git a/tests/test_packet.py b/tests/test_packet.py index 02e1328..c0d849d 100644 --- a/tests/test_packet.py +++ b/tests/test_packet.py @@ -1,4 +1,7 @@ +import pytest + from pymc_core.protocol import Packet +from pymc_core.protocol.packet_utils import PathUtils # Packet tests @@ -28,3 +31,290 @@ def test_packet_validation(): # Should validate successfully packet._validate_lengths() + + +# --- Multi-byte path support --- + + +class TestPacketSetPath: + """Tests for Packet.set_path() and path accessor methods.""" + + def test_set_path_default_1byte_hashes(self): + """set_path without encoded path_len assumes 1-byte hashes.""" + pkt = Packet() + pkt.header = 0x02 # ROUTE_TYPE_DIRECT + pkt.payload = bytearray(b"data") + pkt.payload_len = 4 + + pkt.set_path(b"\xAA\xBB\xCC") + + assert pkt.path == bytearray(b"\xAA\xBB\xCC") + assert pkt.path_len == 3 # encode_path_len(1, 3) == 3 + assert pkt.get_path_hash_size() == 1 + assert pkt.get_path_hash_count() == 3 + assert pkt.get_path_byte_len() == 3 + + def test_set_path_with_2byte_encoded(self): + """set_path with 2-byte hash encoded path_len.""" + pkt = Packet() + pkt.header = 0x02 + pkt.payload = bytearray(b"data") + pkt.payload_len = 4 + + # 2 hops × 2-byte hashes = 4 bytes of path data + encoded = PathUtils.encode_path_len(2, 2) # 0x42 + pkt.set_path(b"\x01\x02\x03\x04", path_len_encoded=encoded) + + assert pkt.path_len == 0x42 + assert pkt.get_path_hash_size() == 2 + assert pkt.get_path_hash_count() == 2 + assert pkt.get_path_byte_len() == 4 + assert pkt.path == bytearray(b"\x01\x02\x03\x04") + + def test_set_path_with_3byte_encoded(self): + """set_path with 3-byte hash encoded path_len.""" + pkt = Packet() + pkt.header = 0x02 + pkt.payload = bytearray(b"data") + pkt.payload_len = 4 + + # 3 hops × 3-byte hashes = 9 bytes of path data + encoded = PathUtils.encode_path_len(3, 3) # 0x83 + path_data = bytes(range(9)) + pkt.set_path(path_data, path_len_encoded=encoded) + + assert pkt.path_len == 0x83 + assert pkt.get_path_hash_size() == 3 + assert pkt.get_path_hash_count() == 3 + assert pkt.get_path_byte_len() == 9 + assert pkt.path == bytearray(path_data) + + def test_set_path_empty(self): + """set_path with empty path.""" + pkt = Packet() + pkt.header = 0x02 + pkt.payload = bytearray(b"data") + pkt.payload_len = 4 + + pkt.set_path(b"") + + assert pkt.path_len == 0 + assert pkt.get_path_hash_size() == 1 + assert pkt.get_path_hash_count() == 0 + assert pkt.get_path_byte_len() == 0 + + def test_set_path_64_bytes_raises_without_encoded(self): + """64-byte path without path_len_encoded would encode as 0 hops (64 & 0x3F).""" + pkt = Packet() + pkt.header = 0x02 + pkt.payload = bytearray(b"data") + pkt.payload_len = 4 + + with pytest.raises(ValueError, match="path length 64 exceeds maximum encodable"): + pkt.set_path(bytes(64)) + + # With explicit path_len_encoded (63 hops) path must be 63 bytes + encoded_63 = PathUtils.encode_path_len(1, 63) + pkt.set_path(bytes(63), path_len_encoded=encoded_63) + assert pkt.path_len == 0x3F + assert pkt.get_path_byte_len() == 63 + + +class TestGetPathHashes: + """Tests for Packet.get_path_hashes() and get_path_hashes_hex().""" + + def test_1byte_hashes(self): + pkt = Packet() + pkt.set_path(b"\xAA\xBB\xCC") + assert pkt.get_path_hashes() == [b"\xAA", b"\xBB", b"\xCC"] + + def test_2byte_hashes(self): + pkt = Packet() + encoded = PathUtils.encode_path_len(2, 2) + pkt.set_path(b"\xAA\xBB\xCC\xDD", path_len_encoded=encoded) + assert pkt.get_path_hashes() == [b"\xAA\xBB", b"\xCC\xDD"] + + def test_3byte_hashes(self): + pkt = Packet() + encoded = PathUtils.encode_path_len(3, 2) + pkt.set_path(b"\xAA\xBB\xCC\xDD\xEE\xFF", path_len_encoded=encoded) + assert pkt.get_path_hashes() == [b"\xAA\xBB\xCC", b"\xDD\xEE\xFF"] + + def test_hex_2byte(self): + pkt = Packet() + encoded = PathUtils.encode_path_len(2, 2) + pkt.set_path(b"\xAA\xBB\xCC\xDD", path_len_encoded=encoded) + assert pkt.get_path_hashes_hex() == ["AABB", "CCDD"] + + def test_hex_1byte_backward_compat(self): + pkt = Packet() + pkt.set_path(b"\xB5\xA3\xF2") + assert pkt.get_path_hashes_hex() == ["B5", "A3", "F2"] + + def test_empty_path(self): + pkt = Packet() + pkt.path_len = 0 + pkt.path = bytearray() + assert pkt.get_path_hashes() == [] + assert pkt.get_path_hashes_hex() == [] + + def test_zero_hops_with_hash_mode(self): + """0 hops but hash_size=2 (originated with path_hash_mode=1).""" + pkt = Packet() + pkt.path_len = PathUtils.encode_path_len(2, 0) # 0x40 + pkt.path = bytearray() + assert pkt.get_path_hashes() == [] + assert pkt.get_path_hashes_hex() == [] + + +class TestPacketRoundTrip: + """Tests for write_to → read_from round-trip with multi-byte paths.""" + + def _make_packet(self, header, path_data, path_len_encoded, payload): + """Helper to build a packet with given path encoding.""" + pkt = Packet() + pkt.header = header + pkt.set_path(path_data, path_len_encoded) + pkt.payload = bytearray(payload) + pkt.payload_len = len(payload) + return pkt + + def test_roundtrip_no_path(self): + """Round-trip with empty path.""" + pkt = self._make_packet(0x05, b"", 0, b"hello") + raw = pkt.write_to() + pkt2 = Packet() + pkt2.read_from(raw) + + assert pkt2.header == 0x05 + assert pkt2.path_len == 0 + assert pkt2.get_path_byte_len() == 0 + assert pkt2.path == bytearray() + assert pkt2.get_payload() == b"hello" + + def test_roundtrip_1byte_hashes(self): + """Round-trip with 1-byte hashes (backward compatible).""" + path = b"\xAA\xBB\xCC" + pkt = self._make_packet(0x06, path, PathUtils.encode_path_len(1, 3), b"payload") + raw = pkt.write_to() + pkt2 = Packet() + pkt2.read_from(raw) + + assert pkt2.get_path_hash_size() == 1 + assert pkt2.get_path_hash_count() == 3 + assert pkt2.get_path_byte_len() == 3 + assert pkt2.path == bytearray(path) + assert pkt2.get_payload() == b"payload" + + def test_roundtrip_2byte_hashes(self): + """Round-trip with 2-byte hashes.""" + # 4 hops × 2 bytes = 8 bytes path + path = bytes(range(8)) + encoded = PathUtils.encode_path_len(2, 4) + pkt = self._make_packet(0x06, path, encoded, b"test") + raw = pkt.write_to() + pkt2 = Packet() + pkt2.read_from(raw) + + assert pkt2.path_len == encoded + assert pkt2.get_path_hash_size() == 2 + assert pkt2.get_path_hash_count() == 4 + assert pkt2.get_path_byte_len() == 8 + assert pkt2.path == bytearray(path) + assert pkt2.get_payload() == b"test" + + def test_roundtrip_3byte_hashes(self): + """Round-trip with 3-byte hashes.""" + # 3 hops × 3 bytes = 9 bytes path + path = bytes([0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99]) + encoded = PathUtils.encode_path_len(3, 3) + pkt = self._make_packet(0x06, path, encoded, b"msg") + raw = pkt.write_to() + pkt2 = Packet() + pkt2.read_from(raw) + + assert pkt2.path_len == encoded + assert pkt2.get_path_hash_size() == 3 + assert pkt2.get_path_hash_count() == 3 + assert pkt2.get_path_byte_len() == 9 + assert pkt2.path == bytearray(path) + assert pkt2.get_payload() == b"msg" + + def test_roundtrip_max_2byte_path(self): + """Round-trip with maximum valid 2-byte hash path (32 hops × 2 = 64 bytes).""" + path = bytes(range(64)) + encoded = PathUtils.encode_path_len(2, 32) + pkt = self._make_packet(0x06, path, encoded, b"x") + raw = pkt.write_to() + pkt2 = Packet() + pkt2.read_from(raw) + + assert pkt2.get_path_hash_size() == 2 + assert pkt2.get_path_hash_count() == 32 + assert pkt2.get_path_byte_len() == 64 + assert pkt2.path == bytearray(path) + + def test_roundtrip_max_3byte_path(self): + """Round-trip with maximum valid 3-byte hash path (21 hops × 3 = 63 bytes).""" + path = bytes(range(63)) + encoded = PathUtils.encode_path_len(3, 21) + pkt = self._make_packet(0x06, path, encoded, b"x") + raw = pkt.write_to() + pkt2 = Packet() + pkt2.read_from(raw) + + assert pkt2.get_path_hash_size() == 3 + assert pkt2.get_path_hash_count() == 21 + assert pkt2.get_path_byte_len() == 63 + assert pkt2.path == bytearray(path) + + def test_roundtrip_1byte_backward_compat(self): + """Backward compatibility: 1-byte hash packets identical to legacy format.""" + # Build the packet the OLD way (direct assignment) + pkt_old = Packet() + pkt_old.header = 0x06 + pkt_old.path = bytearray(b"\x01\x02\x03") + pkt_old.path_len = 3 + pkt_old.payload = bytearray(b"test") + pkt_old.payload_len = 4 + raw_old = pkt_old.write_to() + + # Build the packet the NEW way (set_path) + pkt_new = Packet() + pkt_new.header = 0x06 + pkt_new.set_path(b"\x01\x02\x03") + pkt_new.payload = bytearray(b"test") + pkt_new.payload_len = 4 + raw_new = pkt_new.write_to() + + # Wire formats must be identical + assert raw_old == raw_new + + def test_read_from_invalid_path_len_raises(self): + """read_from rejects reserved hash_size=4 encoding.""" + # Hand-craft a raw packet with path_len 0xC1 (hash_size=4, invalid) + raw = bytes([0x06, 0xC1]) # header + invalid path_len + pkt = Packet() + with pytest.raises(ValueError, match="invalid path_len encoding"): + pkt.read_from(raw) + + def test_read_from_truncated_path_raises(self): + """read_from rejects packet with insufficient path bytes.""" + # 2 hops × 2-byte hashes = 4 bytes expected, but only provide 2 + encoded = PathUtils.encode_path_len(2, 2) # needs 4 bytes of path + raw = bytes([0x06, encoded, 0xAA, 0xBB]) # header + path_len + only 2 path bytes + pkt = Packet() + with pytest.raises(ValueError, match="truncated path"): + pkt.read_from(raw) + + def test_get_raw_length_multibyte(self): + """get_raw_length accounts for multi-byte path encoding.""" + pkt = Packet() + pkt.header = 0x06 + # 5 hops × 2-byte hashes = 10 bytes path + pkt.set_path(bytes(10), PathUtils.encode_path_len(2, 5)) + pkt.payload = bytearray(b"test") + pkt.payload_len = 4 + + # raw_length = header(1) + path_len_byte(1) + path(10) + payload(4) = 16 + assert pkt.get_raw_length() == 16 diff --git a/tests/test_packet_utils.py b/tests/test_packet_utils.py index bd6a0e6..2eff143 100644 --- a/tests/test_packet_utils.py +++ b/tests/test_packet_utils.py @@ -7,6 +7,7 @@ PacketDataUtils, PacketHashingUtils, PacketValidationUtils, + PathUtils, ) @@ -184,3 +185,155 @@ def test_hash_string_truncates_to_requested_length(self): assert truncated == expected_hex[:16] assert len(truncated) == 16 assert truncated.isupper() + + +class TestPathUtils: + """Tests for multi-byte path encoding/decoding utilities.""" + + # --- get_path_hash_size --- + + def test_hash_size_1byte(self): + """Bits 6-7 == 0b00 → hash_size = 1.""" + assert PathUtils.get_path_hash_size(0x00) == 1 + assert PathUtils.get_path_hash_size(0x05) == 1 + assert PathUtils.get_path_hash_size(0x3F) == 1 # max hop count, 1-byte hashes + + def test_hash_size_2byte(self): + """Bits 6-7 == 0b01 → hash_size = 2.""" + assert PathUtils.get_path_hash_size(0x40) == 2 + assert PathUtils.get_path_hash_size(0x45) == 2 + assert PathUtils.get_path_hash_size(0x7F) == 2 + + def test_hash_size_3byte(self): + """Bits 6-7 == 0b10 → hash_size = 3.""" + assert PathUtils.get_path_hash_size(0x80) == 3 + assert PathUtils.get_path_hash_size(0x8A) == 3 + assert PathUtils.get_path_hash_size(0xBF) == 3 + + def test_hash_size_reserved(self): + """Bits 6-7 == 0b11 → hash_size = 4 (reserved, invalid).""" + assert PathUtils.get_path_hash_size(0xC0) == 4 + + # --- get_path_hash_count --- + + def test_hash_count_extracts_lower_6_bits(self): + """Hash count is the lower 6 bits (0-63).""" + assert PathUtils.get_path_hash_count(0x00) == 0 + assert PathUtils.get_path_hash_count(0x05) == 5 + assert PathUtils.get_path_hash_count(0x3F) == 63 + # Upper bits should be masked off + assert PathUtils.get_path_hash_count(0x45) == 5 # 0b01_000101 + assert PathUtils.get_path_hash_count(0x8A) == 10 # 0b10_001010 + assert PathUtils.get_path_hash_count(0xC0) == 0 # 0b11_000000 + + # --- get_path_byte_len --- + + def test_byte_len_1byte_hashes(self): + """1-byte hashes: byte_len = hop_count * 1.""" + assert PathUtils.get_path_byte_len(0x00) == 0 + assert PathUtils.get_path_byte_len(0x01) == 1 + assert PathUtils.get_path_byte_len(0x05) == 5 + assert PathUtils.get_path_byte_len(0x3F) == 63 + + def test_byte_len_2byte_hashes(self): + """2-byte hashes: byte_len = hop_count * 2.""" + assert PathUtils.get_path_byte_len(0x40) == 0 # 0 hops × 2 + assert PathUtils.get_path_byte_len(0x41) == 2 # 1 hop × 2 + assert PathUtils.get_path_byte_len(0x45) == 10 # 5 hops × 2 + assert PathUtils.get_path_byte_len(0x60) == 64 # 32 hops × 2 = 64 + + def test_byte_len_3byte_hashes(self): + """3-byte hashes: byte_len = hop_count * 3.""" + assert PathUtils.get_path_byte_len(0x80) == 0 # 0 hops × 3 + assert PathUtils.get_path_byte_len(0x81) == 3 # 1 hop × 3 + assert PathUtils.get_path_byte_len(0x8A) == 30 # 10 hops × 3 + assert PathUtils.get_path_byte_len(0x95) == 63 # 21 hops × 3 = 63 + + # --- encode_path_len --- + + def test_encode_1byte_hashes(self): + """1-byte hashes: encoded = (0 << 6) | count = count.""" + assert PathUtils.encode_path_len(1, 0) == 0x00 + assert PathUtils.encode_path_len(1, 1) == 0x01 + assert PathUtils.encode_path_len(1, 5) == 0x05 + assert PathUtils.encode_path_len(1, 63) == 0x3F + + def test_encode_2byte_hashes(self): + """2-byte hashes: encoded = (1 << 6) | count.""" + assert PathUtils.encode_path_len(2, 0) == 0x40 + assert PathUtils.encode_path_len(2, 1) == 0x41 + assert PathUtils.encode_path_len(2, 5) == 0x45 + assert PathUtils.encode_path_len(2, 32) == 0x60 + + def test_encode_3byte_hashes(self): + """3-byte hashes: encoded = (2 << 6) | count.""" + assert PathUtils.encode_path_len(3, 0) == 0x80 + assert PathUtils.encode_path_len(3, 1) == 0x81 + assert PathUtils.encode_path_len(3, 10) == 0x8A + assert PathUtils.encode_path_len(3, 21) == 0x95 + + def test_encode_decode_roundtrip(self): + """encode → get_path_hash_size + get_path_hash_count roundtrip.""" + for hash_size in (1, 2, 3): + for count in (0, 1, 10, 63): + encoded = PathUtils.encode_path_len(hash_size, count) + assert PathUtils.get_path_hash_size(encoded) == hash_size + assert PathUtils.get_path_hash_count(encoded) == count + assert PathUtils.get_path_byte_len(encoded) == hash_size * count + + # --- is_valid_path_len --- + + def test_valid_path_len_1byte(self): + """1-byte hashes: valid up to MAX_PATH_SIZE hops.""" + assert PathUtils.is_valid_path_len(0x00) is True # 0 hops + assert PathUtils.is_valid_path_len(0x01) is True # 1 hop + assert PathUtils.is_valid_path_len(0x3F) is True # 63 hops, 63 bytes ≤ MAX_PATH_SIZE(64) + + def test_valid_path_len_2byte(self): + """2-byte hashes: valid when hop_count * 2 ≤ MAX_PATH_SIZE.""" + assert PathUtils.is_valid_path_len(0x40) is True # 0 hops + assert PathUtils.is_valid_path_len(0x41) is True # 1 hop × 2 = 2 + assert PathUtils.is_valid_path_len(0x60) is True # 32 hops × 2 = 64 ≤ MAX_PATH_SIZE + assert PathUtils.is_valid_path_len(0x61) is False # 33 hops × 2 = 66 > MAX_PATH_SIZE + + def test_valid_path_len_3byte(self): + """3-byte hashes: valid when hop_count * 3 ≤ MAX_PATH_SIZE.""" + assert PathUtils.is_valid_path_len(0x80) is True # 0 hops + assert PathUtils.is_valid_path_len(0x95) is True # 21 hops × 3 = 63 ≤ MAX_PATH_SIZE + assert PathUtils.is_valid_path_len(0x96) is False # 22 hops × 3 = 66 > MAX_PATH_SIZE + + def test_valid_path_len_reserved_hash_size(self): + """Hash size 4 (bits 6-7 == 0b11) is reserved and always invalid.""" + assert PathUtils.is_valid_path_len(0xC0) is False # hash_size=4, 0 hops + assert PathUtils.is_valid_path_len(0xC1) is False # hash_size=4, 1 hop + assert PathUtils.is_valid_path_len(0xFF) is False # hash_size=4, 63 hops + + # --- Backward compatibility --- + + def test_1byte_backward_compatible(self): + """For 1-byte hashes, encoded path_len == raw hop count == byte count.""" + for n in range(0, 64): + encoded = PathUtils.encode_path_len(1, n) + assert encoded == n + assert PathUtils.get_path_byte_len(encoded) == n + + def test_encode_path_len_rejects_hop_count_64(self): + """Hop count is 6 bits (0-63); 64 would mask to 0 and produce invalid packet.""" + with pytest.raises(ValueError, match="hop count must be 0-63"): + PathUtils.encode_path_len(1, 64) + with pytest.raises(ValueError, match="hop count must be 0-63"): + PathUtils.encode_path_len(1, 100) + + def test_is_path_at_max_hops(self): + """is_path_at_max_hops is True when path bytes/hops are at limit for hash size.""" + # No path + assert PathUtils.is_path_at_max_hops(0) is False + # 1-byte hashes: max 63 hops + assert PathUtils.is_path_at_max_hops(PathUtils.encode_path_len(1, 62)) is False + assert PathUtils.is_path_at_max_hops(PathUtils.encode_path_len(1, 63)) is True + # 2-byte hashes: max 32 hops (64 bytes) + assert PathUtils.is_path_at_max_hops(PathUtils.encode_path_len(2, 31)) is False + assert PathUtils.is_path_at_max_hops(PathUtils.encode_path_len(2, 32)) is True + # 3-byte hashes: max 21 hops (63 bytes) + assert PathUtils.is_path_at_max_hops(PathUtils.encode_path_len(3, 20)) is False + assert PathUtils.is_path_at_max_hops(PathUtils.encode_path_len(3, 21)) is True From 1dd0fbb4f677e1aa1d29ce754f3be05394702f23 Mon Sep 17 00:00:00 2001 From: agessaman Date: Thu, 5 Mar 2026 20:59:18 -0800 Subject: [PATCH 50/50] feat(companion): enhance binary parsing and node name synchronization - Added logging for optional LPP parse failures in binary parsing functions to improve debugging. - Introduced a method to retrieve the group text handler for node name synchronization in CompanionBase, CompanionBridge, and CompanionRadio classes. - Updated path length validation in Packet class to provide clearer error messages for invalid encodings. - Enhanced tests to ensure proper handling of path length constraints and error scenarios. --- src/pymc_core/companion/binary_parsing.py | 8 +- src/pymc_core/companion/companion_base.py | 11 +- src/pymc_core/companion/companion_bridge.py | 8 +- src/pymc_core/companion/companion_radio.py | 10 +- src/pymc_core/companion/frame_server.py | 144 +++++++------------- src/pymc_core/protocol/packet.py | 5 +- src/pymc_core/protocol/packet_utils.py | 6 +- tests/test_packet_path_and_hash.py | 58 +++++--- 8 files changed, 120 insertions(+), 130 deletions(-) diff --git a/src/pymc_core/companion/binary_parsing.py b/src/pymc_core/companion/binary_parsing.py index 74d3791..fc150da 100644 --- a/src/pymc_core/companion/binary_parsing.py +++ b/src/pymc_core/companion/binary_parsing.py @@ -2,10 +2,13 @@ from __future__ import annotations +import logging from typing import Optional from .constants import BinaryReqType +logger = logging.getLogger(__name__) + def parse_binary_response( request_type: int, @@ -74,7 +77,7 @@ def _parse_telemetry(data: bytes) -> dict: {"channel": d.channel, "type": d.type_id, "value": d.data} for d in frame.data ] except Exception: - pass + logger.debug("Optional LPP parse failed for telemetry", exc_info=True) return out @@ -87,7 +90,7 @@ def _parse_mma(data: bytes) -> dict: frame = LppFrame.from_bytes(data) out["mma"] = [{"channel": d.channel, "type": d.type_id, "data": d.data} for d in frame.data] except Exception: - pass + logger.debug("Optional LPP parse failed for MMA", exc_info=True) return out @@ -104,6 +107,7 @@ def _parse_owner_info(data: bytes) -> dict: "raw_text": text, } except Exception: + logger.debug("Owner info parse failed, returning fallback", exc_info=True) return {"raw_hex": data.hex(), "request_type": BinaryReqType.OWNER_INFO} diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index eda2cfb..208542b 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -338,10 +338,15 @@ def set_advert_name(self, name: str) -> None: self._save_prefs() self._sync_our_node_name_to_handlers() + def _get_group_text_handler(self) -> Optional[Any]: + """Return the group text handler for name sync, or None. Override in Radio/Bridge.""" + return None + def _sync_our_node_name_to_handlers(self) -> None: - """Sync node name to group text handler for echo detection. - No-op in base; override in Bridge/Radio.""" - pass + """Sync node name to group text handler for echo detection.""" + handler = self._get_group_text_handler() + if handler is not None: + handler.set_our_node_name(self.prefs.node_name) def set_advert_latlon(self, lat: float, lon: float) -> None: """Set the GPS coordinates included in advertisements.""" diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py index 0001627..963d5fa 100644 --- a/src/pymc_core/companion/companion_bridge.py +++ b/src/pymc_core/companion/companion_bridge.py @@ -275,11 +275,9 @@ def _get_login_response_handler(self) -> Any: def _get_text_handler(self) -> Any: return self._text_handler_ref - def _sync_our_node_name_to_handlers(self) -> None: - """Sync current node name to group text handler for echo detection.""" - handler = self._handlers.get(PAYLOAD_TYPE_GRP_TXT) - if handler is not None: - handler.set_our_node_name(self.prefs.node_name) + def _get_group_text_handler(self): + """Return the group text handler for name sync.""" + return self._handlers.get(PAYLOAD_TYPE_GRP_TXT) # ------------------------------------------------------------------------- # RX Entry Point diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py index ee03bd6..75c8dea 100644 --- a/src/pymc_core/companion/companion_radio.py +++ b/src/pymc_core/companion/companion_radio.py @@ -136,7 +136,7 @@ async def stop(self) -> None: try: self.node.dispatcher.remove_raw_packet_subscriber(self._on_raw_packet_rx_log) except Exception: - pass + logger.debug("Remove raw packet subscriber during stop failed", exc_info=True) if self._dispatcher_task: self._dispatcher_task.cancel() try: @@ -173,11 +173,9 @@ def set_advert_name(self, name: str) -> None: super().set_advert_name(name) self.node.node_name = self.prefs.node_name - def _sync_our_node_name_to_handlers(self) -> None: - """Sync current node name to group text handler for echo detection.""" - handler = getattr(self.node.dispatcher, "group_text_handler", None) - if handler is not None: - handler.set_our_node_name(self.prefs.node_name) + def _get_group_text_handler(self): + """Return the group text handler for name sync.""" + return getattr(self.node.dispatcher, "group_text_handler", None) def set_radio_params(self, freq_hz: int, bw_hz: int, sf: int, cr: int) -> bool: super().set_radio_params(freq_hz, bw_hz, sf, cr) diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 86309bb..29cdb1b 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -212,6 +212,55 @@ def __init__( self._model_bytes = (device_model.encode("utf-8") + b"\x00")[:40].ljust(40, b"\x00") self._version_bytes = (device_version.encode("utf-8") + b"\x00")[:20].ljust(20, b"\x00") + # Command dispatch registry: cmd byte -> async handler(data) + self._cmd_handlers = { + CMD_APP_START: self._cmd_app_start, + CMD_DEVICE_QUERY: self._cmd_device_query, + CMD_GET_CONTACTS: self._cmd_get_contacts, + CMD_GET_CONTACT_BY_KEY: self._cmd_get_contact_by_key, + CMD_SEND_TXT_MSG: self._cmd_send_txt_msg, + CMD_SEND_CHANNEL_TXT_MSG: self._cmd_send_channel_txt_msg, + CMD_SYNC_NEXT_MESSAGE: self._cmd_sync_next_message, + CMD_SEND_LOGIN: self._cmd_send_login, + CMD_SEND_STATUS_REQ: self._cmd_send_status_req, + CMD_SEND_TELEMETRY_REQ: self._cmd_send_telemetry_req, + CMD_SEND_SELF_ADVERT: self._cmd_send_self_advert, + CMD_SET_ADVERT_NAME: self._cmd_set_advert_name, + CMD_SET_ADVERT_LATLON: self._cmd_set_advert_latlon, + CMD_ADD_UPDATE_CONTACT: self._cmd_add_update_contact, + CMD_REMOVE_CONTACT: self._cmd_remove_contact, + CMD_RESET_PATH: self._cmd_reset_path, + CMD_GET_BATT_AND_STORAGE: self._cmd_get_batt_and_storage, + CMD_GET_STATS: self._cmd_get_stats, + CMD_GET_ADVERT_PATH: self._cmd_get_advert_path, + CMD_IMPORT_CONTACT: self._cmd_import_contact, + CMD_GET_CHANNEL: self._cmd_get_channel, + CMD_SET_CHANNEL: self._cmd_set_channel, + CMD_SEND_BINARY_REQ: self._cmd_send_binary_req, + CMD_SEND_ANON_REQ: self._cmd_send_anon_req, + CMD_SEND_PATH_DISCOVERY_REQ: self._cmd_send_path_discovery_req, + CMD_SEND_CONTROL_DATA: self._cmd_send_control_data, + CMD_SEND_TRACE_PATH: self._cmd_send_trace_path, + CMD_SET_FLOOD_SCOPE: self._cmd_set_flood_scope, + CMD_GET_DEVICE_TIME: self._cmd_get_device_time, + CMD_SET_DEVICE_TIME: self._cmd_set_device_time, + CMD_SET_RADIO_PARAMS: self._cmd_set_radio_params, + CMD_SET_RADIO_TX_POWER: self._cmd_set_tx_power, + CMD_SHARE_CONTACT: self._cmd_share_contact, + CMD_EXPORT_CONTACT: self._cmd_export_contact, + CMD_EXPORT_PRIVATE_KEY: self._cmd_export_private_key, + CMD_IMPORT_PRIVATE_KEY: self._cmd_import_private_key, + CMD_SET_TUNING_PARAMS: self._cmd_set_tuning_params, + CMD_LOGOUT: self._cmd_logout, + CMD_GET_CUSTOM_VARS: self._cmd_get_custom_vars, + CMD_SET_CUSTOM_VAR: self._cmd_set_custom_var, + CMD_SET_AUTOADD_CONFIG: self._cmd_set_autoadd_config, + CMD_GET_AUTOADD_CONFIG: self._cmd_get_autoadd_config, + CMD_SET_OTHER_PARAMS: self._cmd_set_other_params, + CMD_SEND_RAW_DATA: self._cmd_send_raw_data, + CMD_SET_PATH_HASH_MODE: self._cmd_set_path_hash_mode, + } + # ------------------------------------------------------------------------- # Lifecycle # ------------------------------------------------------------------------- @@ -802,96 +851,9 @@ async def _handle_cmd(self, payload: bytes) -> None: ) try: - if cmd == CMD_APP_START: - await self._cmd_app_start(data) - elif cmd == CMD_DEVICE_QUERY: - await self._cmd_device_query(data) - elif cmd == CMD_GET_CONTACTS: - await self._cmd_get_contacts(data) - elif cmd == CMD_GET_CONTACT_BY_KEY: - await self._cmd_get_contact_by_key(data) - elif cmd == CMD_SEND_TXT_MSG: - await self._cmd_send_txt_msg(data) - elif cmd == CMD_SEND_CHANNEL_TXT_MSG: - await self._cmd_send_channel_txt_msg(data) - elif cmd == CMD_SYNC_NEXT_MESSAGE: - await self._cmd_sync_next_message(data) - elif cmd == CMD_SEND_LOGIN: - await self._cmd_send_login(data) - elif cmd == CMD_SEND_STATUS_REQ: - await self._cmd_send_status_req(data) - elif cmd == CMD_SEND_TELEMETRY_REQ: - await self._cmd_send_telemetry_req(data) - elif cmd == CMD_SEND_SELF_ADVERT: - await self._cmd_send_self_advert(data) - elif cmd == CMD_SET_ADVERT_NAME: - await self._cmd_set_advert_name(data) - elif cmd == CMD_SET_ADVERT_LATLON: - await self._cmd_set_advert_latlon(data) - elif cmd == CMD_ADD_UPDATE_CONTACT: - await self._cmd_add_update_contact(data) - elif cmd == CMD_REMOVE_CONTACT: - await self._cmd_remove_contact(data) - elif cmd == CMD_RESET_PATH: - await self._cmd_reset_path(data) - elif cmd == CMD_GET_BATT_AND_STORAGE: - await self._cmd_get_batt_and_storage(data) - elif cmd == CMD_GET_STATS: - await self._cmd_get_stats(data) - elif cmd == CMD_GET_ADVERT_PATH: - await self._cmd_get_advert_path(data) - elif cmd == CMD_IMPORT_CONTACT: - await self._cmd_import_contact(data) - elif cmd == CMD_GET_CHANNEL: - await self._cmd_get_channel(data) - elif cmd == CMD_SET_CHANNEL: - await self._cmd_set_channel(data) - elif cmd == CMD_SEND_BINARY_REQ: - await self._cmd_send_binary_req(data) - elif cmd == CMD_SEND_ANON_REQ: - await self._cmd_send_anon_req(data) - elif cmd == CMD_SEND_PATH_DISCOVERY_REQ: - await self._cmd_send_path_discovery_req(data) - elif cmd == CMD_SEND_CONTROL_DATA: - await self._cmd_send_control_data(data) - elif cmd == CMD_SEND_TRACE_PATH: - await self._cmd_send_trace_path(data) - elif cmd == CMD_SET_FLOOD_SCOPE: - await self._cmd_set_flood_scope(data) - elif cmd == CMD_GET_DEVICE_TIME: - await self._cmd_get_device_time(data) - elif cmd == CMD_SET_DEVICE_TIME: - await self._cmd_set_device_time(data) - elif cmd == CMD_SET_RADIO_PARAMS: - await self._cmd_set_radio_params(data) - elif cmd == CMD_SET_RADIO_TX_POWER: - await self._cmd_set_tx_power(data) - elif cmd == CMD_SHARE_CONTACT: - await self._cmd_share_contact(data) - elif cmd == CMD_EXPORT_CONTACT: - await self._cmd_export_contact(data) - elif cmd == CMD_EXPORT_PRIVATE_KEY: - await self._cmd_export_private_key(data) - elif cmd == CMD_IMPORT_PRIVATE_KEY: - await self._cmd_import_private_key(data) - elif cmd == CMD_SET_TUNING_PARAMS: - await self._cmd_set_tuning_params(data) - elif cmd == CMD_LOGOUT: - await self._cmd_logout(data) - elif cmd == CMD_GET_CUSTOM_VARS: - await self._cmd_get_custom_vars(data) - elif cmd == CMD_SET_CUSTOM_VAR: - await self._cmd_set_custom_var(data) - elif cmd == CMD_SET_AUTOADD_CONFIG: - await self._cmd_set_autoadd_config(data) - elif cmd == CMD_GET_AUTOADD_CONFIG: - await self._cmd_get_autoadd_config(data) - elif cmd == CMD_SET_OTHER_PARAMS: - await self._cmd_set_other_params(data) - elif cmd == CMD_SEND_RAW_DATA: - await self._cmd_send_raw_data(data) - elif cmd == CMD_SET_PATH_HASH_MODE: - await self._cmd_set_path_hash_mode(data) + handler = self._cmd_handlers.get(cmd) + if handler is not None: + await handler(data) else: logger.warning( "Companion unsupported cmd 0x%02x (%s) len=%s", @@ -944,7 +906,7 @@ async def _cmd_app_start(self, data: bytes) -> None: self._write_frame(frame) async def _cmd_device_query(self, data: bytes) -> None: - # Layout must match MeshCore companion_radio MyMesh.cpp handleCmdFrame() CMD_DEVICE_QEURY: + # Layout must match MeshCore companion_radio MyMesh.cpp handleCmdFrame() CMD_DEVICE_QUEURY: # [0]=RESP_CODE_DEVICE_INFO, [1]=FIRMWARE_VER_CODE, [2]=MAX_CONTACTS/2, # [3]=MAX_GROUP_CHANNELS, [4..7]=ble_pin, [8..19]=build_date(12), [20..59]=manufacturer(40), # [60..79]=version(20), [80]=client_repeat, [81]=path_hash_mode (v10+). diff --git a/src/pymc_core/protocol/packet.py b/src/pymc_core/protocol/packet.py index 6bc3747..c5235b9 100644 --- a/src/pymc_core/protocol/packet.py +++ b/src/pymc_core/protocol/packet.py @@ -405,7 +405,10 @@ def read_from(self, data: ByteString) -> bool: self.path_len = data[idx] idx += 1 if not PathUtils.is_valid_path_len(self.path_len): - raise ValueError(f"invalid path_len encoding: 0x{self.path_len:02X}") + hash_size = PathUtils.get_path_hash_size(self.path_len) + if hash_size > 3: + raise ValueError(f"invalid path_len encoding: 0x{self.path_len:02X}") + raise ValueError("path_len too large") path_byte_len = self.get_path_byte_len() self._check_bounds(idx, path_byte_len, data_len, "truncated path") diff --git a/src/pymc_core/protocol/packet_utils.py b/src/pymc_core/protocol/packet_utils.py index d112fda..b9490c9 100644 --- a/src/pymc_core/protocol/packet_utils.py +++ b/src/pymc_core/protocol/packet_utils.py @@ -146,8 +146,10 @@ def encode_path_len(hash_size: int, hash_count: int) -> int: def is_valid_path_len(path_len_byte: int) -> bool: """Validate an encoded path_len byte. - Returns False for hash_size == 4 (reserved) or if the total - path bytes would exceed MAX_PATH_SIZE. + Path length in bytes must never exceed MAX_PATH_SIZE (64): at most 64 + one-byte hops, 32 two-byte hops, or 21 three-byte hops. Returns False + for hash_size == 4 (reserved) or if the total path bytes would exceed + MAX_PATH_SIZE. """ hash_size = (path_len_byte >> PATH_HASH_SIZE_SHIFT) + 1 if hash_size > 3: diff --git a/tests/test_packet_path_and_hash.py b/tests/test_packet_path_and_hash.py index b5894cd..c17907a 100644 --- a/tests/test_packet_path_and_hash.py +++ b/tests/test_packet_path_and_hash.py @@ -21,12 +21,10 @@ MAX_PACKET_PAYLOAD, MAX_PATH_SIZE, MAX_SUPPORTED_PAYLOAD_VERSION, - PATH_HASH_SIZE, PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, - PAYLOAD_TYPE_TXT_MSG, PAYLOAD_TYPE_TRACE, - PH_ROUTE_MASK, + PAYLOAD_TYPE_TXT_MSG, PH_TYPE_SHIFT, PH_VER_SHIFT, ROUTE_TYPE_DIRECT, @@ -34,17 +32,13 @@ ROUTE_TYPE_TRANSPORT_DIRECT, ROUTE_TYPE_TRANSPORT_FLOOD, ) -from pymc_core.protocol.packet_utils import ( - PacketHashingUtils, - PacketHeaderUtils, - PacketValidationUtils, -) - +from pymc_core.protocol.packet_utils import PacketHashingUtils, PacketValidationUtils # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _make_header(payload_type: int, route_type: int, version: int = 0) -> int: """Build a header byte from components (mirrors PacketHeaderUtils.make_header).""" return route_type | (payload_type << PH_TYPE_SHIFT) | (version << 6) @@ -94,6 +88,7 @@ def _build_packet( # 1. Basic hash correctness — matches reference C++ implementation # =================================================================== + class TestPacketHashBasic: """Verify hash output matches the C++ reference for various packet types.""" @@ -177,6 +172,7 @@ def test_trace_path_len_uint16_not_uint8(self): # 2. Hash excludes header, path, route type, transport codes # =================================================================== + class TestPacketHashExclusions: """Verify that fields NOT in the C++ hash don't affect the Python hash.""" @@ -193,7 +189,9 @@ def test_hash_independent_of_path_content(self): payload = b"path_test_data" pkt_no_path = _build_packet(PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD, payload) - pkt_with_path = _build_packet(PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD, payload, path=b"\xAA\xBB\xCC") + pkt_with_path = _build_packet( + PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD, payload, path=b"\xAA\xBB\xCC" + ) assert pkt_no_path.calculate_packet_hash() == pkt_with_path.calculate_packet_hash() @@ -239,6 +237,7 @@ def test_different_payload_type_gives_different_hash(self): # 3. payload_len truncation — hash must use only payload[:payload_len] # =================================================================== + class TestPayloadLenTruncation: """Verify hash uses payload[:payload_len], not the full bytearray buffer.""" @@ -307,6 +306,7 @@ def test_trace_hash_uses_payload_len_truncation_too(self): # 4. Flood forwarding: path append must not change duplicate detection # =================================================================== + class TestFloodPathAppend: """Simulate flood forwarding and verify hash stability.""" @@ -363,6 +363,7 @@ def test_max_path_flood(self): # 5. Direct forwarding: path consume must not change duplicate detection # =================================================================== + class TestDirectPathConsume: """Simulate direct forwarding path consumption and verify hash stability.""" @@ -422,6 +423,7 @@ def test_trace_hash_changes_after_path_consume(self): # 6. Serialization round-trip preserves hash # =================================================================== + class TestSerializationHashPreservation: """Verify that write_to → read_from round-trip preserves the packet hash.""" @@ -505,6 +507,7 @@ def test_roundtrip_after_flood_append_preserves_hash(self): # 7. PacketHashingUtils standalone tests # =================================================================== + class TestPacketHashingUtilsStandalone: """Test the static utility directly, confirming C++ compatibility.""" @@ -559,11 +562,12 @@ def test_hash_string_truncation(self): # 8. Edge cases and regression guards # =================================================================== -class TestEdgeCases: +class TestEdgeCases: def test_max_payload_hashes_correctly(self): """MAX_PACKET_PAYLOAD-sized payload should hash without error.""" from pymc_core.protocol.constants import MAX_PACKET_PAYLOAD + payload = bytes(range(256)) * (MAX_PACKET_PAYLOAD // 256 + 1) payload = payload[:MAX_PACKET_PAYLOAD] @@ -619,6 +623,7 @@ def test_header_0xff_sentinel_changes_payload_type(self): # 9. Bad / malformed packets — must be rejected by validation # =================================================================== + class TestBadPacketDeserialization: """Verify that read_from rejects malformed wire data.""" @@ -639,27 +644,34 @@ def test_unsupported_version_rejected(self): """Version > MAX_SUPPORTED_PAYLOAD_VERSION should be rejected.""" bad_version = MAX_SUPPORTED_PAYLOAD_VERSION + 1 # Craft header with unsupported version in bits 6-7 - header = ROUTE_TYPE_FLOOD | (PAYLOAD_TYPE_TXT_MSG << PH_TYPE_SHIFT) | (bad_version << PH_VER_SHIFT) + header = ( + ROUTE_TYPE_FLOOD + | (PAYLOAD_TYPE_TXT_MSG << PH_TYPE_SHIFT) + | (bad_version << PH_VER_SHIFT) + ) wire = bytes([header, 0]) # header + path_len=0, no payload pkt = Packet() with pytest.raises(ValueError, match="Unsupported packet version"): pkt.read_from(wire) def test_path_len_exceeds_max(self): - """path_len > MAX_PATH_SIZE (64) must be rejected.""" + """Encoded path_len that decodes to > MAX_PATH_SIZE (64) bytes must be rejected.""" + from pymc_core.protocol.packet_utils import PathUtils + header = _make_header(PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD) - bad_path_len = MAX_PATH_SIZE + 1 # 65 - wire = bytes([header, bad_path_len]) + bytes(bad_path_len) + b"payload" + # 33 hops × 2 bytes = 66 path bytes (exceeds MAX_PATH_SIZE) + bad_path_len = PathUtils.encode_path_len(hash_size=2, hash_count=33) + wire = bytes([header, bad_path_len]) + bytes(66) + b"payload" pkt = Packet() with pytest.raises(ValueError, match="path_len too large"): pkt.read_from(wire) def test_path_len_255_rejected(self): - """path_len=255 (max uint8) must be rejected.""" + """path_len=255 (reserved hash_size 4) must be rejected.""" header = _make_header(PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD) wire = bytes([header, 0xFF]) + bytes(255) pkt = Packet() - with pytest.raises(ValueError, match="path_len too large"): + with pytest.raises(ValueError, match="invalid path_len encoding"): pkt.read_from(wire) def test_truncated_path_rejected(self): @@ -696,13 +708,18 @@ def test_oversized_payload_rejected(self): pkt.read_from(wire) def test_path_len_exact_max_accepted(self): - """path_len == MAX_PATH_SIZE should be accepted (boundary test).""" + """Encoded path_len that decodes to exactly MAX_PATH_SIZE (64) bytes should be accepted.""" + from pymc_core.protocol.packet_utils import PathUtils + header = _make_header(PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD) + # 32 hops × 2 bytes = 64 path bytes (exactly MAX_PATH_SIZE) + path_len_byte = PathUtils.encode_path_len(hash_size=2, hash_count=32) path_data = bytes(MAX_PATH_SIZE) - wire = bytes([header, MAX_PATH_SIZE]) + path_data + b"ok" + wire = bytes([header, path_len_byte]) + path_data + b"ok" pkt = Packet() pkt.read_from(wire) - assert pkt.path_len == MAX_PATH_SIZE + assert pkt.path_len == path_len_byte + assert len(pkt.path) == MAX_PATH_SIZE assert pkt.payload == bytearray(b"ok") @@ -772,6 +789,7 @@ def test_random_garbage_rejected_or_benign(self): or raise ValueError — never crash with an unhandled exception. """ import random + rng = random.Random(42) # deterministic seed for _ in range(100): length = rng.randint(0, 300)