diff --git a/docs/docs/companion.md b/docs/docs/companion.md new file mode 100644 index 0000000..6cadf40 --- /dev/null +++ b/docs/docs/companion.md @@ -0,0 +1,1042 @@ +# Companion Module + +The companion module provides a high-level Python interface to the MeshCore companion radio protocol. It manages contacts, messaging, channels, advertisements, path routing, telemetry, cryptographic signing, and device configuration on top of pyMC_core's `MeshNode`. + +Three main classes are provided: + +| Class | Owns Radio | Use Case | +|---|---|---| +| `CompanionRadio` | Yes | Standalone companion — wraps a hardware radio and `MeshNode` | +| `CompanionBridge` | No | Repeater-integrated companion — shares an existing dispatcher via a packet injector callback | +| `CompanionFrameServer` | No | TCP server implementing the MeshCore companion frame protocol (the binary wire format used by companion apps) | + +`CompanionRadio` and `CompanionBridge` both inherit from `CompanionBase` (an abstract base class), which holds all shared stores, event handling, device configuration logic, and unified TX methods. Subclasses implement transport via the abstract `_send_packet` method. + +`CompanionFrameServer` wraps a `CompanionBridge` (or any `CompanionBase` subclass) and exposes it over TCP using the same binary frame protocol as the MeshCore firmware companion radio. 
+ +--- + +## Architecture + +``` +CompanionBase (ABC) +├── ContactStore (in-memory contacts, max 1000) +├── ChannelStore (group channels, max 40) +├── MessageQueue (offline FIFO, max 512) +├── PathCache (recent advert paths, max 16) +├── StatsCollector (TX/RX counters, uptime) +├── NodePrefs (radio params, name, location) +│ +│ Persistence hooks (no-op by default, override for SQLite/JSON): +│ _save_prefs, _load_prefs +│ +│ Unified methods (use abstract _send_packet): +│ advertise, share_contact, send_text_message, +│ send_channel_message, send_binary_req, +│ send_path_discovery_req, send_trace_path, +│ send_login, send_logout, sync_next_message +│ +├─► CompanionRadio (owns MeshNode + hardware radio) +└─► CompanionBridge (packet_injector callback, no radio) + │ + └─► CompanionFrameServer (TCP binary frame protocol server) + │ + │ Persistence hooks (no-op by default): + │ _persist_companion_message, _sync_next_from_persistence, + │ _save_contacts, _save_channels, _get_batt_and_storage + │ + └─► (your subclass, e.g. SQLite-backed repeater) +``` + +--- + +## Installation + +```bash +pip install pymc_core # core only +pip install pymc_core[hardware] # SX1262 direct radio support +pip install pymc_core[all] # everything +``` + +--- + +## CompanionRadio + +`CompanionRadio` is a standalone companion that owns a hardware radio and a `MeshNode`. It is the typical entry point for building a chat application, sensor gateway, or automation tool that participates in a MeshCore network. 
+ +### Quick Start + +```python +import asyncio +from pymc_core import LocalIdentity +from pymc_core.companion import ( + CompanionRadio, + ADV_TYPE_CHAT, + STATS_TYPE_PACKETS, +) + +async def main(): + # --- Setup --- + from pymc_core.hardware import KissModemWrapper + + radio = KissModemWrapper("/dev/ttyUSB0") + radio.connect() + + identity = LocalIdentity() # generates a new Ed25519 keypair + companion = CompanionRadio( + radio=radio, + identity=identity, + node_name="myNode", + adv_type=ADV_TYPE_CHAT, + ) + + # --- Register callbacks before starting --- + companion.on_message_received(on_msg) + companion.on_advert_received(on_advert) + companion.on_channel_message_received(on_chan_msg) + companion.on_send_confirmed(on_ack) + + await companion.start() + + # --- Advertise presence --- + await companion.advertise(flood=True) + + # --- Send a direct message --- + dest = bytes.fromhex("ab" * 32) # 32-byte public key + result = await companion.send_text_message(dest, "Hello mesh!") + print(f"Sent: success={result.success}, flood={result.is_flood}") + + # --- Keep running --- + try: + while True: + await asyncio.sleep(1) + finally: + await companion.stop() + + +# --- Callbacks --- +def on_msg(sender_key, text, timestamp, txt_type, *args): + print(f"DM from {sender_key[:8].hex()}: {text}") + +def on_advert(contact): + print(f"Discovered: {contact.name} (type={contact.adv_type})") + +def on_chan_msg(channel_name, sender_name, text, timestamp, path_len, channel_idx, *args): + print(f"[{channel_name}] {sender_name}: {text}") + +def on_ack(ack_crc): + print(f"ACK confirmed: {ack_crc:#x}") + + +asyncio.run(main()) +``` + +### Constructor + +```python +CompanionRadio( + radio, # hardware radio wrapper + identity: LocalIdentity, # Ed25519 identity + node_name: str = "pyMC", + adv_type: int = ADV_TYPE_CHAT, # 1=chat, 2=repeater, 3=room, 4=sensor + max_contacts: int = 1000, + max_channels: int = 40, + offline_queue_size: int = 512, + radio_config: dict | None = None, + 
initial_contacts: Iterable[Contact] | None = None, # optional bulk load on boot +) + +### Lifecycle + +```python +await companion.start() # start dispatcher task +await companion.stop() # cancel dispatcher +companion.is_running # bool property +``` + +### Messaging + +```python +# Direct text message +result = await companion.send_text_message(pub_key, "hello", txt_type=0, attempt=1) +# result: SentResult(success, is_flood, expected_ack, timeout_ms) + +# Group channel message +ok = await companion.send_channel_message(channel_idx=0, text="hello group") + +# Pop oldest queued offline message +msg = companion.sync_next_message() # -> QueuedMessage | None + +# Raw binary data (direct path only) +result = await companion.send_raw_data(dest_key, data=b"\x01\x02", path=None) +``` + +### Advertisements + +```python +await companion.advertise(flood=True) # broadcast presence +await companion.share_contact(pub_key) # share a contact via direct advert +``` + +### Contact Management + +```python +from pymc_core.companion import Contact + +# List / lookup +contacts = companion.get_contacts(since=0) +contact = companion.get_contact_by_key(pub_key_bytes) +contact = companion.get_contact_by_name("Alice") + +# Add / update / remove +companion.add_update_contact(Contact(public_key=key, name="Bob")) +companion.remove_contact(pub_key_bytes) + +# Populate on boot: pass initial_contacts into the constructor (CompanionRadio or CompanionBridge). +# Replaces the need to call the store after construction. +contacts_from_prefs = [Contact(public_key=k, name=name) for ...] # e.g. from _load_prefs or a file +companion = CompanionRadio(radio, identity, node_name="myNode", initial_contacts=contacts_from_prefs) +await companion.start() + +# If your data is dicts (e.g. 
JSON/DB), load after construction: +companion.contacts.load_from_dicts([{"public_key": key_hex, "name": "Bob"}, ...]) + +# Reset routing path (force re-discovery) +companion.reset_path(pub_key_bytes) + +# Serialize for sharing +blob = companion.export_contact(pub_key) # bytes (73-byte binary packet) +ok = companion.import_contact(blob) # bool +``` + +### Channel Management + +```python +from pymc_core.companion import Channel + +companion.set_channel(0, name="General", secret=b"sixteen_byte_key")  # 16-byte PSK +ch = companion.get_channel(0) # -> Channel | None +``` + +### Path Discovery & Tracing + +```python +# Path discovery (returns SentResult, fires on_path_discovery_response callback) +result = await companion.send_path_discovery_req(pub_key) + +# Trace path through the network +ok = await companion.send_trace_path(pub_key, tag=42, auth_code=0, flags=0) + +# Get recently heard advert path +advert_path = companion.get_advert_path(pub_key_prefix_7bytes) +``` + +### Repeater Interaction + +```python +# Login to a repeater +resp = await companion.send_login(repeater_key, password="secret") + +# Logout from a repeater +ok = await companion.send_logout(repeater_key) + +# Request repeater status +resp = await companion.send_status_request(repeater_key) + +# Request telemetry +resp = await companion.send_telemetry_request( + repeater_key, + want_base=True, + want_location=True, + want_environment=False, + timeout=10.0, +) + +# Send a text command to a repeater +resp = await companion.send_repeater_command(repeater_key, command="status") +``` + +### Binary Requests + +The generic binary request/response mechanism uses random 4-byte tags for matching. 
+ +```python +# Send and wait for response +result = await companion.send_binary_req(pub_key, data=b"\x01", timeout_seconds=10) + +# Register a callback for responses +companion.on_binary_response( + lambda tag, data, parsed, req_type: print(f"Response: {parsed}") +) +``` + +### Device Configuration + +All preference-mutating methods automatically call `_save_prefs()` (a no-op by default; override in subclasses for persistence). + +```python +companion.set_advert_name("NewName") # max 31 chars +companion.set_advert_latlon(37.7749, -122.4194) # GPS coordinates +companion.set_radio_params(915_000_000, 250_000, 10, 5) # freq, bw, SF, CR +companion.set_tx_power(22) # dBm +companion.set_tuning_params(rx_delay=0.0, airtime_factor=0.0) + +# Fetch current radio configuration (frequency, bandwidth, SF, CR, TX power, tuning) +radio_params = companion.get_radio_params() +# {'frequency_hz': 915000000, 'bandwidth_hz': 250000, 'spreading_factor': 10, ...} + +# Time management (transient, not persisted) +device_time = companion.get_time() # Unix timestamp +ok = companion.set_time(1700000000) # returns False if in the past + +# Custom variables (key:value string pairs, max 140 chars total) +custom_vars = companion.get_custom_vars() # -> dict[str, str] +ok = companion.set_custom_var("key", "value") # -> bool + +# Auto-add configuration +config = companion.get_autoadd_config() # -> int bitmask +companion.set_autoadd_config(AUTOADD_CHAT | AUTOADD_REPEATER) + +# Location sharing in adverts +from pymc_core.companion import ADVERT_LOC_SHARE +companion.set_other_params( + manual_add=0, + telemetry_modes=0, + advert_loc_policy=ADVERT_LOC_SHARE, + multi_acks=0, +) + +prefs = companion.get_self_info() # -> NodePrefs (copy) +``` + +### Flood Scope (Regions) + +Constrain flood packets to a specific region using transport key scoping. +Nodes outside the region will ignore scoped flood packets. 
+ +```python +from pymc_core.protocol.transport_keys import get_auto_key_for + +# Set region by name (auto-derives transport key via SHA-256) +companion.set_flood_region("usa") # '#' prefix added automatically +companion.set_flood_region("#europe") # explicit '#' also works + +# Or set directly with a raw 16-byte transport key +key = get_auto_key_for("#usa") +companion.set_flood_scope(key) + +# Clear scope (flood to all nodes) +companion.set_flood_region(None) +``` + +When a flood scope is active, all flood packets are tagged with a 16-bit transport code +(HMAC-SHA256 derived) and sent as `ROUTE_TYPE_TRANSPORT_FLOOD`. Direct-routed packets +are unaffected. + +### Cryptographic Signing + +```python +buf_size = companion.sign_start() # returns max buffer size (8192) +companion.sign_data(b"data to sign...") +signature = companion.sign_finish() # -> 64-byte Ed25519 signature +``` + +### Statistics + +```python +from pymc_core.companion import STATS_TYPE_CORE, STATS_TYPE_RADIO, STATS_TYPE_PACKETS + +stats = companion.get_stats(STATS_TYPE_CORE) +# {'uptime': 3600, 'queue_len': 2, 'contacts_count': 15, 'channels_count': 3} + +stats = companion.get_stats(STATS_TYPE_PACKETS) +# {'flood_tx': 42, 'flood_rx': 108, 'direct_tx': 5, 'direct_rx': 12, ...} +``` + +### CompanionRadio-Specific Overrides + +`CompanionRadio` overrides several `CompanionBase` methods to also configure the physical radio hardware: + +| Method | Base Behavior | Radio Override | +|---|---|---| +| `set_radio_params()` | Updates `prefs` fields | Also calls `radio.configure_radio()` | +| `set_tx_power()` | Updates `prefs.tx_power_dbm` | Also calls `radio.set_tx_power()` | +| `set_advert_name()` | Updates `prefs.node_name` | Also syncs `node.node_name` | +| `set_flood_scope()` | Stores transport key | Also syncs to `node.dispatcher` | +| `set_flood_region()` | Derives key from name | Also syncs to `node.dispatcher` | + +--- + +## CompanionBridge + +`CompanionBridge` is designed for repeater integration. 
It does not own a radio or `MeshNode` — instead, the repeater host feeds received packets in via `process_received_packet()`, and all outbound packets go through a `packet_injector` callback you provide. This lets a companion identity coexist alongside a repeater on the same radio. + +### Quick Start + +```python +import asyncio +from pymc_core import LocalIdentity +from pymc_core.companion import CompanionBridge, ADV_TYPE_CHAT + +async def main(): + identity = LocalIdentity() + + async def packet_injector(pkt, wait_for_ack=False): + """Forward packet to the repeater's radio.""" + return await my_repeater.send_packet(pkt) + + def authenticate(user_hash, password): + """Validate login attempts. Return (success, acl_bits).""" + if user_hash == expected_hash: + return (True, 0x01) + return (False, 0) + + bridge = CompanionBridge( + identity=identity, + packet_injector=packet_injector, + node_name="myBridge", + adv_type=ADV_TYPE_CHAT, + authenticate_callback=authenticate, + ) + + bridge.on_message_received( + lambda key, text, ts, tt, *args: print(f"Bridge msg: {text}") + ) + + await bridge.start() + + # Feed packets from the repeater's dispatcher + async def on_repeater_rx(packet): + await bridge.process_received_packet(packet) + + # ... register on_repeater_rx with your repeater ... 
+ +asyncio.run(main()) +``` + +### Constructor + +```python +CompanionBridge( + identity: LocalIdentity, + packet_injector: Callable, # async (pkt, wait_for_ack=False) -> bool + node_name: str = "pyMC", + adv_type: int = ADV_TYPE_CHAT, + max_contacts: int = 1000, + max_channels: int = 40, + offline_queue_size: int = 512, + radio_config: dict | None = None, + authenticate_callback: Callable | None = None, # (hash, pw) -> (bool, int) + initial_contacts: Iterable[Contact] | None = None, # optional bulk load on boot +) +``` + +### RX Entry Point + +```python +# Called by the repeater host for every received packet +await bridge.process_received_packet(packet) +``` + +The bridge registers internal handlers for these payload types: + +| Payload Type | Handler | +|---|---| +| ACK | Bridge ACK handler (matches pending CRCs) | +| TXT_MSG | TextMessageHandler | +| ADVERT | AdvertHandler | +| PATH | PathHandler | +| ANON_REQ | LoginServerHandler | +| GRP_TXT | GroupTextHandler | +| RESPONSE | LoginResponseHandler | + +### All Other APIs + +`CompanionBridge` exposes the same messaging, contact, channel, path, signing, stats, and configuration APIs as `CompanionRadio` (inherited from `CompanionBase`). The only behavioral difference is that all TX goes through the `packet_injector` instead of an owned radio. + +Note that **CompanionBridge does not own the radio**. `set_radio_params()` and `set_tx_power()` update in-memory prefs only; there is no physical radio to configure. `get_radio_params()` and `get_self_info()` return those in-memory prefs, not the repeater's actual hardware configuration. + +### Avoiding doubled messages + +When you have **multiple bridges** on the same repeater, the other bridge can receive the same logical message twice: once from local fan-out when one bridge injects a packet, and again when the repeater's radio receives that packet over the air. 
Deduplication uses a packet hash; if the data used for local delivery differs from the bytes sent over the air, the hashes differ and both copies appear. + +**Use one canonical byte representation** for both TX and local delivery: + +1. When your `packet_injector` is called with `pkt`, apply any in-place changes (e.g. flood scope) **before** serializing. +2. Serialize once: `raw = pkt.write_to()`. +3. Send `raw` on the radio. +4. For local delivery to the other bridge, build a new packet from those **same** bytes and pass it: + `pkt2 = Packet(); pkt2.read_from(raw); await other_bridge.process_received_packet(pkt2)`. + +Then both local and OTA deliveries share the same packet hash and companion-side dedup collapses the duplicate. + +**Optional:** Feed the same `raw` bytes into the repeater's dispatcher RX path (the same entry the radio uses) instead of calling `other_bridge.process_received_packet(pkt)` directly. The dispatcher will track the packet hash; when the same bytes arrive over the air they are dropped as duplicates, and you deliver to bridges only from the dispatcher (single path, no double delivery). + +Example injector with serialize-once local delivery to another bridge: + +```python +from pymc_core.protocol import Packet + +async def packet_injector(pkt, wait_for_ack=False): + raw = pkt.write_to() # after any in-place changes (e.g. flood scope) + ok = await repeater.send_raw(raw) # or dispatcher.send_packet with pre-serialized bytes + if ok and other_bridge: + pkt2 = Packet() + pkt2.read_from(raw) + await other_bridge.process_received_packet(pkt2) + return ok +``` + +--- + +## CompanionFrameServer + +`CompanionFrameServer` implements the MeshCore companion radio TCP frame protocol — the same binary wire format used by the C++ firmware (`examples/companion_radio/`). It wraps a `CompanionBase` subclass (typically a `CompanionBridge`) and exposes it to companion apps (e.g. MeshCore Android/iOS) over a TCP socket. 
+ +### Frame Format + +All frames use a simple length-prefixed format: + +| Direction | Prefix | Length | Data | +|---|---|---|---| +| App → Radio | `<` (0x3C) | 2-byte LE | Command byte + payload | +| Radio → App | `>` (0x3E) | 2-byte LE | Response/push byte + payload | + +Maximum frame size: 172 bytes (matches firmware; BLE MTU). + +### Quick Start + +```python +from pymc_core import LocalIdentity +from pymc_core.companion import CompanionBridge, CompanionFrameServer + +identity = LocalIdentity() +bridge = CompanionBridge(identity=identity, packet_injector=my_injector) + +server = CompanionFrameServer( + bridge=bridge, + companion_hash="abcd1234", # identifier for this companion + port=5000, + device_model="pyMC-Companion", + device_version="1.0.0", +) + +await server.start() # starts listening on TCP port +# ... companion app connects and sends commands ... +await server.stop() +``` + +### Constructor + +```python +CompanionFrameServer( + bridge: CompanionBase, # the companion to wrap + companion_hash: str, # unique identifier + port: int = 5000, + bind_address: str = "0.0.0.0", + *, + device_model: str = "pyMC-Companion", + device_version: str = "1.0.0", + build_date: str = "", + local_hash: int | None = None, + stats_getter: Callable | None = None, + control_handler: Any | None = None, + heartbeat_interval: int = 15, # seconds between keepalive frames + client_idle_timeout_sec: int = 120, # no data from client → disconnect and free slot +) +``` + +**Connection management:** Only one client is allowed at a time. If a new connection arrives while one is already active, the server closes the existing connection and accepts the new one (same as firmware). If the client disappears without closing (e.g. kill, network drop), the slot is freed after no data is received for `client_idle_timeout_sec` seconds (default 120). Operators can tune this timeout to avoid dropping slow but live clients. 
+ +### Supported Commands + +The frame server handles the following companion radio protocol commands: + +| CMD | Code | Description | +|---|---|---| +| `CMD_APP_START` | 1 | Initialize connection, return device info | +| `CMD_SEND_TXT_MSG` | 2 | Send a direct text message | +| `CMD_SEND_CHANNEL_TXT_MSG` | 3 | Send a channel message | +| `CMD_GET_CONTACTS` | 4 | Retrieve contact list (paginated) | +| `CMD_GET_DEVICE_TIME` | 5 | Get current device time | +| `CMD_SET_DEVICE_TIME` | 6 | Set device time | +| `CMD_SEND_SELF_ADVERT` | 7 | Broadcast self advertisement | +| `CMD_SET_ADVERT_NAME` | 8 | Set advertised node name | +| `CMD_ADD_UPDATE_CONTACT` | 9 | Add or update a contact | +| `CMD_SYNC_NEXT_MESSAGE` | 10 | Pop next queued message | +| `CMD_SET_RADIO_PARAMS` | 11 | Set frequency, bandwidth, SF, CR | +| `CMD_SET_RADIO_TX_POWER` | 12 | Set transmit power | +| `CMD_RESET_PATH` | 13 | Reset routing path for a contact | +| `CMD_SET_ADVERT_LATLON` | 14 | Set GPS coordinates | +| `CMD_REMOVE_CONTACT` | 15 | Remove a contact | +| `CMD_SHARE_CONTACT` | 16 | Share a contact to the mesh | +| `CMD_EXPORT_CONTACT` | 17 | Export contact as 73-byte blob | +| `CMD_IMPORT_CONTACT` | 18 | Import contact from blob | +| `CMD_EXPORT_PRIVATE_KEY` | 23 | Export private/signing key (64-byte MeshCore format) | +| `CMD_IMPORT_PRIVATE_KEY` | 24 | Import private key (stub/no-op; key set from config) | +| `CMD_SEND_RAW_DATA` | 25 | Send raw payload on given direct path | +| `CMD_GET_BATT_AND_STORAGE` | 20 | Get battery/storage info | +| `CMD_SET_TUNING_PARAMS` | 21 | Set RX delay and airtime factor | +| `CMD_DEVICE_QUERY` | 22 | Return device model/version | +| `CMD_SEND_LOGIN` | 26 | Login to a repeater | +| `CMD_SEND_STATUS_REQ` | 27 | Request repeater status | +| `CMD_LOGOUT` | 29 | Logout from a repeater | +| `CMD_GET_CONTACT_BY_KEY` | 30 | Look up contact by public key | +| `CMD_GET_CHANNEL` | 31 | Get a channel by index | +| `CMD_SET_CHANNEL` | 32 | Set a channel | +| 
`CMD_SEND_TRACE_PATH` | 36 | Send trace path request | +| `CMD_SEND_TELEMETRY_REQ` | 39 | Request telemetry data | +| `CMD_GET_CUSTOM_VARS` | 40 | Get custom variables | +| `CMD_SET_CUSTOM_VAR` | 41 | Set a custom variable | +| `CMD_GET_ADVERT_PATH` | 42 | Get cached advert path | +| `CMD_SEND_BINARY_REQ` | 50 | Send binary request | +| `CMD_SEND_PATH_DISCOVERY_REQ` | 52 | Send path discovery | +| `CMD_SET_FLOOD_SCOPE` | 54 | Set flood scope transport key | +| `CMD_SEND_CONTROL_DATA` | 55 | Send control data | +| `CMD_GET_STATS` | 56 | Get statistics | +| `CMD_SET_AUTOADD_CONFIG` | 58 | Set auto-add configuration | +| `CMD_GET_AUTOADD_CONFIG` | 59 | Get auto-add configuration | + +### Push Notifications + +The frame server sends unsolicited push frames to the companion app when events occur: + +| Push Code | Value | Description | +|---|---|---| +| `PUSH_CODE_ADVERT` | 0x80 | Contact advertisement received | +| `PUSH_CODE_MSG_WAITING` | 0x83 | New message queued | +| `PUSH_CODE_SEND_CONFIRMED` | 0x82 | ACK received for a sent message | +| `PUSH_CODE_RAW_DATA` | 0x84 | Raw custom payload received (SNR, RSSI, 0xFF, payload) | +| `PUSH_CODE_PATH_UPDATED` | 0x81 | Contact path updated | +| `PUSH_CODE_LOG_RX_DATA` | 0x88 | Raw RX packet (diagnostics) | +| `PUSH_CODE_TRACE_DATA` | 0x89 | Trace path response | +| `PUSH_CODE_NEW_ADVERT` | 0x8A | New (previously unknown) contact discovered | +| `PUSH_CODE_CONTROL_DATA` | 0x8E | Control data received | +| `PUSH_CODE_LOGIN_SUCCESS` | 0x91 | Repeater login succeeded | +| `PUSH_CODE_LOGIN_FAIL` | 0x92 | Repeater login failed | +| `PUSH_CODE_STATUS_RESPONSE` | 0x93 | Repeater status response | +| `PUSH_CODE_TELEMETRY_RESPONSE` | 0x8B | Telemetry response | +| `PUSH_CODE_BINARY_RESPONSE` | 0x95 | Binary request response | +| `PUSH_CODE_PATH_DISCOVERY_RESPONSE` | 0x96 | Path discovery response | + +### Host-Callable Push Methods + +The frame server exposes methods for the host application to push data to the connected companion 
app: + +```python +# Push trace data from the repeater (await for backpressure) +await server.push_trace_data( + path_len=3, flags=0, tag=42, auth_code=0, + path_hashes=b"...", path_snrs=b"...", final_snr_byte=0 +) + +# Push raw RX packet for diagnostics logging (sync: schedules send, works without await) +server.push_rx_raw(snr=-5.0, rssi=-100, raw=b"...") +# Or from async code with backpressure: +await server.push_rx_raw_async(snr=-5.0, rssi=-100, raw=b"...") + +# Push control data +await server.push_control_data( + snr=-5.0, rssi=-100, path_len=2, + path_bytes=b"...", payload=b"..." +) +``` + +### Persistence Hooks + +`CompanionFrameServer` provides no-op hooks that subclasses override for persistent storage (e.g. SQLite): + +```python +class MyFrameServer(CompanionFrameServer): + async def _persist_companion_message(self, msg_dict: dict) -> None: + """Called when a message is received. Save to database.""" + await self.db.save_message(msg_dict) + + def _sync_next_from_persistence(self) -> QueuedMessage | None: + """Called when the in-memory queue is empty. Pop from database.""" + return self.db.pop_oldest_message() + + def _save_contacts(self) -> None: + """Called after contact list changes. Sync to database.""" + self.db.save_contacts(self.bridge.contacts.to_dicts()) + + def _save_channels(self) -> None: + """Called after channel changes. Sync to database.""" + self.db.save_channels(...) + + def _get_batt_and_storage(self) -> tuple[int, int, int]: + """Return (millivolts, used_kb, total_kb) for CMD_GET_BATT_AND_STORAGE.""" + return (4200, 128, 1024) +``` + +--- + +## Persistence Hooks + +The companion module uses a "no-op hook" pattern for persistence: base classes define empty methods that subclasses override to save/load state from their storage backend. + +### CompanionBase Hooks (Preferences) + +```python +class CompanionBase: + def _save_prefs(self) -> None: + """Persist self.prefs. 
Called after any pref-mutating method.""" + + def _load_prefs(self) -> None: + """Restore self.prefs on startup. Called at end of _init_companion_stores().""" +``` + +`_save_prefs()` is called automatically by: `set_radio_params`, `set_tx_power`, `set_tuning_params`, `set_autoadd_config`, `set_other_params`, `set_advert_name`, `set_advert_latlon`. + +Note: `set_time()` does **not** call `_save_prefs()` — the time offset is a transient runtime correction, not a persistent preference. + +### CompanionFrameServer Hooks (Messages, Contacts, Channels) + +```python +class CompanionFrameServer: + async def _persist_companion_message(self, msg_dict: dict) -> None: ... + def _sync_next_from_persistence(self) -> QueuedMessage | None: ... + def _save_contacts(self) -> None: ... + def _save_channels(self) -> None: ... + def _get_batt_and_storage(self) -> tuple[int, int, int]: ... +``` + +--- + +## Use Cases + +### 1. Chat Application + +Build a terminal or GUI chat client that discovers peers and exchanges messages. + +```python +companion = CompanionRadio(radio, identity, node_name="ChatApp") + +companion.on_message_received(display_message) +companion.on_advert_received(add_to_contact_list) + +await companion.start() +await companion.advertise() + +# User picks a contact and sends +contact = companion.get_contact_by_name("Alice") +await companion.send_text_message(contact.public_key, user_input) +``` + +### 2. Sensor Gateway + +Collect telemetry from sensor nodes in the mesh, forward to a database or MQTT. 
+ +```python +companion = CompanionRadio(radio, identity, node_name="Gateway", adv_type=ADV_TYPE_CHAT) + +async def on_telemetry(event_data): + # event_data contains parsed CayenneLPP sensor readings + publish_to_mqtt(event_data) + +companion.on_telemetry_response(on_telemetry) + +await companion.start() + +# Periodically poll known sensors +for sensor in companion.get_contacts(): + if sensor.adv_type == ADV_TYPE_SENSOR: + await companion.send_telemetry_request( + sensor.public_key, want_base=True, + want_location=True, want_environment=True, timeout=15, + ) +``` + +### 3. Repeater Companion (Bridge Mode) + +Add a companion identity to an existing repeater without a second radio. + +```python +bridge = CompanionBridge( + identity=identity, + packet_injector=repeater.inject_packet, + node_name="RepeaterBot", + authenticate_callback=auth_check, +) + +bridge.on_message_received(handle_bot_command) +await bridge.start() + +# In the repeater's RX loop: +async def repeater_on_rx(pkt): + await bridge.process_received_packet(pkt) + # ... also handle repeater logic ... +``` + +### 4. Companion Frame Server (TCP Protocol) + +Expose a companion over TCP so standard companion apps can connect. + +```python +bridge = CompanionBridge( + identity=identity, + packet_injector=repeater.inject_packet, + node_name="pyMC-Server", +) + +server = CompanionFrameServer( + bridge=bridge, + companion_hash="abcd1234", + port=5000, + device_model="pyMC-Companion", + device_version="1.0.0", +) + +await bridge.start() +await server.start() +# Companion apps (Android/iOS) can now connect on port 5000 +``` + +### 5. Network Diagnostics Tool + +Trace paths and discover topology. 
+ +```python +companion.on_trace_received(lambda data: print(f"Trace: {data}")) +companion.on_path_discovery_response( + lambda tag, key, out_path, in_path: print(f"Path to {key.hex()[:8]}: out={out_path}, in={in_path}") +) + +# Trace route to a node +await companion.send_trace_path(target_key, tag=1, auth_code=0) + +# Discover paths +await companion.send_path_discovery_req(target_key) +``` + +### 6. Group Chat / Channels + +```python +# Channel secrets are 16-byte pre-shared keys (see the Channel model) +companion.set_channel(0, name="Emergency", secret=b"channel0_secret_") +companion.set_channel(1, name="General", secret=b"channel1_secret_") + +companion.on_channel_message_received( + lambda ch_name, sender, text, ts, path_len, idx, *args: + print(f"[{ch_name}] {sender}: {text}") +) + +await companion.send_channel_message(0, "Emergency broadcast") +``` + +--- + +## Push Callbacks Reference + +Register callbacks to receive asynchronous events. Both sync and async functions are supported. +Callbacks for `on_message_received` and `on_channel_message_received` receive optional trailing args +`(packet_hash, snr, rssi)` when available; use `*args` to ignore them. 
+ +| Registration Method | Callback Signature | +|---|---| +| `on_message_received` | `(sender_key: bytes, text: str, timestamp: int, txt_type: int [, packet_hash, snr, rssi])` | +| `on_channel_message_received` | `(channel_name: str, sender_name: str, text: str, timestamp: int, path_len: int, channel_idx: int [, packet_hash, snr, rssi])` | +| `on_advert_received` | `(contact: Contact)` | +| `on_contact_path_updated` | `(contact: Contact)` | +| `on_send_confirmed` | `(ack_crc: int)` | +| `on_trace_received` | `(trace_data)` | +| `on_node_discovered` | `(event_data)` | +| `on_login_result` | `(result_data)` | +| `on_telemetry_response` | `(event_data)` | +| `on_status_response` | `(status_data)` | +| `on_raw_data_received` | `(payload: bytes, snr: float, rssi: int)` — raw custom packet received | +| `on_rx_log_data` | `(snr: float, rssi: int, raw_bytes: bytes)` — **CompanionRadio only**; same data as PUSH 0x88 LOG_RX_DATA | +| `on_binary_response` | `(tag: bytes, data: bytes, parsed: dict\|None, request_type: int\|None)` | +| `on_path_discovery_response` | `(tag: bytes, contact_pubkey: bytes, out_path: bytes, in_path: bytes)` | + +--- + +## Models Reference + +### Contact + +```python +@dataclass +class Contact: + public_key: bytes # 32-byte Ed25519 public key + name: str = "" # up to 32 characters + adv_type: int = 0 # ADV_TYPE_CHAT / REPEATER / ROOM / SENSOR + flags: int = 0 + out_path_len: int = -1 # -1=unknown, 0=direct, >0=multi-hop + out_path: bytes = b"" + last_advert_timestamp: int = 0 + lastmod: int = 0 + gps_lat: float = 0.0 + gps_lon: float = 0.0 + sync_since: int = 0 +``` + +### Channel + +```python +@dataclass +class Channel: + name: str # up to 32 characters + secret: bytes # 16-byte pre-shared key +``` + +### NodePrefs + +```python +@dataclass +class NodePrefs: + node_name: str = "pyMC" + adv_type: int = 1 # ADV_TYPE_CHAT + tx_power_dbm: int = 20 + frequency_hz: int = 915000000 + bandwidth_hz: int = 250000 + spreading_factor: int = 10 + coding_rate: 
int = 5 + latitude: float = 0.0 + longitude: float = 0.0 + advert_loc_policy: int = 0 # ADVERT_LOC_NONE + multi_acks: int = 0 + telemetry_mode_base: int = 0 # TELEM_MODE_DENY + telemetry_mode_location: int = 0 + telemetry_mode_environment: int = 0 + manual_add_contacts: int = 0 + autoadd_config: int = 0 + rx_delay_base: float = 0.0 + airtime_factor: float = 0.0 + client_repeat: int = 0 # reported in CMD_DEVICE_QUERY device info frame (byte 80) +``` + +### SentResult + +```python +@dataclass +class SentResult: + success: bool + is_flood: bool = False + expected_ack: int | None = None + timeout_ms: int | None = None +``` + +### QueuedMessage + +```python +@dataclass +class QueuedMessage: + sender_key: bytes + txt_type: int = 0 # TXT_TYPE_PLAIN / CLI_DATA / SIGNED_PLAIN + timestamp: int = 0 + text: str = "" + is_channel: bool = False + channel_idx: int = 0 + path_len: int = 0 +``` + +### AdvertPath + +```python +@dataclass +class AdvertPath: + public_key_prefix: bytes # 7-byte prefix + name: str = "" + path_len: int = 0 + path: bytes = b"" + recv_timestamp: int = 0 +``` + +### PacketStats + +```python +@dataclass +class PacketStats: + flood_tx: int = 0 + flood_rx: int = 0 + direct_tx: int = 0 + direct_rx: int = 0 + tx_errors: int = 0 +``` + +--- + +## Constants + +```python +# Advertisement types +ADV_TYPE_CHAT = 1 +ADV_TYPE_REPEATER = 2 +ADV_TYPE_ROOM = 3 +ADV_TYPE_SENSOR = 4 + +# Text message types +TXT_TYPE_PLAIN = 0 +TXT_TYPE_CLI_DATA = 1 +TXT_TYPE_SIGNED_PLAIN = 2 + +# Telemetry modes +TELEM_MODE_DENY = 0 +TELEM_MODE_ALLOW_FLAGS = 1 +TELEM_MODE_ALLOW_ALL = 2 + +# Location policy +ADVERT_LOC_NONE = 0 +ADVERT_LOC_SHARE = 1 + +# Auto-add config bitmask +AUTOADD_OVERWRITE_OLDEST = 0x01 +AUTOADD_CHAT = 0x02 +AUTOADD_REPEATER = 0x04 +AUTOADD_ROOM = 0x08 +AUTOADD_SENSOR = 0x10 + +# Stats types +STATS_TYPE_CORE = 0 +STATS_TYPE_RADIO = 1 +STATS_TYPE_PACKETS = 2 + +# Binary request types (IntEnum) +BinaryReqType.STATUS # 0x01 +BinaryReqType.KEEP_ALIVE # 0x02 
+BinaryReqType.TELEMETRY # 0x03 +BinaryReqType.MMA # 0x04 +BinaryReqType.ACL # 0x05 +BinaryReqType.NEIGHBOURS # 0x06 + +# Protocol codes +PROTOCOL_CODE_RAW_DATA = 0x00 +PROTOCOL_CODE_BINARY_REQ = 0x02 +PROTOCOL_CODE_ANON_REQ = 0x07 + +# Frame format +FRAME_OUTBOUND_PREFIX = 0x3E # '>' (radio → app) +FRAME_INBOUND_PREFIX = 0x3C # '<' (app → radio) +MAX_FRAME_SIZE = 172 +MAX_PAYLOAD_SIZE = 169 # MAX_FRAME_SIZE - 3 (prefix + 2-byte length) + +# Error codes (returned by frame server) +ERR_CODE_UNSUPPORTED_CMD = 1 +ERR_CODE_NOT_FOUND = 2 +ERR_CODE_TABLE_FULL = 3 +ERR_CODE_BAD_STATE = 4 +ERR_CODE_FILE_IO_ERROR = 5 +ERR_CODE_ILLEGAL_ARG = 6 + +# Defaults +DEFAULT_RESPONSE_TIMEOUT_MS = 10000 +DEFAULT_MAX_CONTACTS = 1000 +DEFAULT_MAX_CHANNELS = 40 +DEFAULT_OFFLINE_QUEUE_SIZE = 512 +PUB_KEY_SIZE = 32 +MAX_PATH_SIZE = 64 +``` + +--- + +## Unimplemented MeshCore Companion Features + +The following protocol-level features from the MeshCore companion radio firmware (`examples/companion_radio/`) are **not yet implemented** in pyMC_core. CMD_SEND_RAW_DATA (25) and PUSH_CODE_RAW_DATA (0x84) for raw custom packets are implemented. 
+ +| Feature | Firmware Reference | Description | +|---|---|---| +| Has connection | `CMD_HAS_CONNECTION` (0x1C) | Check if active connection exists to a contact | +| Push: contact deleted | `PUSH_CODE_CONTACT_DELETED` (0x8F) | Notification when a contact is overwritten by auto-add | +| Push: contacts full | `PUSH_CODE_CONTACTS_FULL` (0x90) | Notification when contact storage is full | +| Keep-alive mechanism | Server-driven keep-alive | Periodic keep-alive packets for active server connections | diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index bf450d3..c688b1b 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -96,6 +96,7 @@ markdown_extensions: nav: - Home: index.md - Node Usage Guide: node.md + - Companion Guide: companion.md - API Reference: - Core: api/core.md - Protocol: api/protocol.md diff --git a/examples/discover_nodes.py b/examples/discover_nodes.py index 5e69845..196b583 100644 --- a/examples/discover_nodes.py +++ b/examples/discover_nodes.py @@ -177,9 +177,7 @@ def main(): if args.radio_type == "kiss-tnc": print(f"Serial port: {args.serial_port}") - asyncio.run( - discover_nodes(args.radio_type, args.serial_port, args.timeout, args.filter) - ) + asyncio.run(discover_nodes(args.radio_type, args.serial_port, args.timeout, args.filter)) if __name__ == "__main__": diff --git a/scripts/test_modem_crypto.py b/scripts/test_modem_crypto.py index cb4bbf4..d8f18f8 100644 --- a/scripts/test_modem_crypto.py +++ b/scripts/test_modem_crypto.py @@ -9,14 +9,14 @@ - Encryption/decryption """ -import sys import os +import sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) -from pymc_core.hardware.kiss_modem_wrapper import KissModemWrapper -from pymc_core.protocol.identity import Identity, LocalIdentity -from pymc_core.protocol.crypto import CryptoUtils +from pymc_core.hardware.kiss_modem_wrapper import KissModemWrapper # noqa: E402 +from pymc_core.protocol.crypto import CryptoUtils # noqa: E402 +from pymc_core.protocol.identity import 
Identity, LocalIdentity # noqa: E402 def test_modem_crypto(port: str = "/dev/cu.usbmodem1101"): @@ -149,7 +149,7 @@ def test_modem_crypto(port: str = "/dev/cu.usbmodem1101"): modem_decrypted = modem.decrypt_data(key, mac, ciphertext) if modem_decrypted: # Trim padding (modem pads to block size) - modem_decrypted = modem_decrypted[:len(plaintext)] + modem_decrypted = modem_decrypted[: len(plaintext)] print(f"Modem decrypted: {modem_decrypted}") if modem_decrypted == plaintext: @@ -178,7 +178,7 @@ def test_modem_crypto(port: str = "/dev/cu.usbmodem1101"): # Decrypt with modem modem_decrypted2 = modem.decrypt_data(key, python_mac, python_ciphertext) if modem_decrypted2: - modem_decrypted2 = modem_decrypted2[:len(plaintext)] + modem_decrypted2 = modem_decrypted2[: len(plaintext)] print(f"Modem decrypted Python ciphertext: {modem_decrypted2}") if modem_decrypted2 == plaintext: diff --git a/src/pymc_core/__init__.py b/src/pymc_core/__init__.py index 0d1e9ab..04d3f34 100644 --- a/src/pymc_core/__init__.py +++ b/src/pymc_core/__init__.py @@ -21,5 +21,17 @@ "__version__", ] +# Conditional import for CompanionRadio +try: + from .companion.companion_radio import CompanionRadio + + _COMPANION_AVAILABLE = True +except ImportError: + _COMPANION_AVAILABLE = False + CompanionRadio = None + +if _COMPANION_AVAILABLE: + __all__.append("CompanionRadio") + # End of mesh package exports diff --git a/src/pymc_core/companion/__init__.py b/src/pymc_core/companion/__init__.py new file mode 100644 index 0000000..91fbf3d --- /dev/null +++ b/src/pymc_core/companion/__init__.py @@ -0,0 +1,103 @@ +""" +MeshCore Companion Radio - Python-native implementation. + +Provides contact management, messaging with offline queue, advertisement +broadcasting, channel management, path tracking, signing, telemetry, +statistics, and device configuration on top of MeshNode. 
+""" + +from .channel_store import ChannelStore +from .companion_bridge import CompanionBridge +from .companion_radio import CompanionRadio +from .constants import ( + ADV_TYPE_CHAT, + ADV_TYPE_REPEATER, + ADV_TYPE_ROOM, + ADV_TYPE_SENSOR, + ADVERT_LOC_NONE, + ADVERT_LOC_SHARE, + AUTOADD_CHAT, + AUTOADD_OVERWRITE_OLDEST, + AUTOADD_REPEATER, + AUTOADD_ROOM, + AUTOADD_SENSOR, + DEFAULT_MAX_CHANNELS, + DEFAULT_MAX_CONTACTS, + DEFAULT_OFFLINE_QUEUE_SIZE, + MSG_SEND_FAILED, + MSG_SEND_SENT_DIRECT, + MSG_SEND_SENT_FLOOD, + STATS_TYPE_CORE, + STATS_TYPE_PACKETS, + STATS_TYPE_RADIO, + TELEM_MODE_ALLOW_ALL, + TELEM_MODE_ALLOW_FLAGS, + TELEM_MODE_DENY, + TXT_TYPE_CLI_DATA, + TXT_TYPE_PLAIN, + TXT_TYPE_SIGNED_PLAIN, + BinaryReqType, +) +from .contact_store import ContactStore +from .frame_server import CompanionFrameServer +from .message_queue import MessageQueue +from .models import AdvertPath, Channel, Contact, NodePrefs, PacketStats, QueuedMessage, SentResult +from .path_cache import PathCache +from .stats_collector import StatsCollector + +__all__ = [ + # Main classes + "CompanionRadio", + "CompanionBridge", + "CompanionFrameServer", + # Stores + "ContactStore", + "ChannelStore", + "MessageQueue", + "PathCache", + "StatsCollector", + # Models + "Contact", + "Channel", + "NodePrefs", + "SentResult", + "PacketStats", + "AdvertPath", + "QueuedMessage", + # ADV Types + "ADV_TYPE_CHAT", + "ADV_TYPE_REPEATER", + "ADV_TYPE_ROOM", + "ADV_TYPE_SENSOR", + # Text Types + "TXT_TYPE_PLAIN", + "TXT_TYPE_CLI_DATA", + "TXT_TYPE_SIGNED_PLAIN", + # Telemetry Modes + "TELEM_MODE_DENY", + "TELEM_MODE_ALLOW_FLAGS", + "TELEM_MODE_ALLOW_ALL", + # Location Policy + "ADVERT_LOC_NONE", + "ADVERT_LOC_SHARE", + # Auto-Add Config + "AUTOADD_OVERWRITE_OLDEST", + "AUTOADD_CHAT", + "AUTOADD_REPEATER", + "AUTOADD_ROOM", + "AUTOADD_SENSOR", + # Message Send Result + "MSG_SEND_FAILED", + "MSG_SEND_SENT_FLOOD", + "MSG_SEND_SENT_DIRECT", + # Binary request types + "BinaryReqType", + # Stats Types + 
"STATS_TYPE_CORE", + "STATS_TYPE_RADIO", + "STATS_TYPE_PACKETS", + # Defaults + "DEFAULT_MAX_CONTACTS", + "DEFAULT_MAX_CHANNELS", + "DEFAULT_OFFLINE_QUEUE_SIZE", +] diff --git a/src/pymc_core/companion/binary_parsing.py b/src/pymc_core/companion/binary_parsing.py new file mode 100644 index 0000000..fc150da --- /dev/null +++ b/src/pymc_core/companion/binary_parsing.py @@ -0,0 +1,150 @@ +"""Parse binary response payloads by request type (BinaryReqType).""" + +from __future__ import annotations + +import logging +from typing import Optional + +from .constants import BinaryReqType + +logger = logging.getLogger(__name__) + + +def parse_binary_response( + request_type: int, + data: bytes, + pubkey_prefix: str = "", + context: Optional[dict] = None, +) -> Optional[dict]: + """Parse response_data by request_type. Returns dict or None.""" + if request_type == BinaryReqType.STATUS and len(data) >= 52: + return _parse_status(data, pubkey_prefix=pubkey_prefix or None) + if request_type == BinaryReqType.TELEMETRY and len(data) >= 0: + return _parse_telemetry(data) + if request_type == BinaryReqType.MMA and len(data) >= 4: + return _parse_mma(data[4:]) # skip 4-byte header + if request_type == BinaryReqType.ACL: + return _parse_acl(data) + if request_type == BinaryReqType.NEIGHBOURS: + return _parse_neighbours(data, context or {}) + if request_type == BinaryReqType.OWNER_INFO and len(data) >= 4: + return _parse_owner_info(data) + return {"raw_hex": data.hex(), "request_type": request_type} + + +def _parse_status(data: bytes, pubkey_prefix: Optional[str] = None, offset: int = 0) -> dict: + """Parse status response (52 bytes).""" + res = {} + if pubkey_prefix is None and len(data) >= 8: + res["pubkey_pre"] = data[2:8].hex() + offset = 8 + else: + res["pubkey_pre"] = pubkey_prefix or "" + res["bat"] = int.from_bytes(data[offset : offset + 2], byteorder="little") + res["tx_queue_len"] = int.from_bytes(data[offset + 2 : offset + 4], byteorder="little") + res["noise_floor"] = 
int.from_bytes( + data[offset + 4 : offset + 6], byteorder="little", signed=True + ) + res["last_rssi"] = int.from_bytes( + data[offset + 6 : offset + 8], byteorder="little", signed=True + ) + res["nb_recv"] = int.from_bytes(data[offset + 8 : offset + 12], byteorder="little") + res["nb_sent"] = int.from_bytes(data[offset + 12 : offset + 16], byteorder="little") + res["airtime"] = int.from_bytes(data[offset + 16 : offset + 20], byteorder="little") + res["uptime"] = int.from_bytes(data[offset + 20 : offset + 24], byteorder="little") + res["sent_flood"] = int.from_bytes(data[offset + 24 : offset + 28], byteorder="little") + res["sent_direct"] = int.from_bytes(data[offset + 28 : offset + 32], byteorder="little") + res["recv_flood"] = int.from_bytes(data[offset + 32 : offset + 36], byteorder="little") + res["recv_direct"] = int.from_bytes(data[offset + 36 : offset + 40], byteorder="little") + res["full_evts"] = int.from_bytes(data[offset + 40 : offset + 42], byteorder="little") + res["last_snr"] = ( + int.from_bytes(data[offset + 42 : offset + 44], byteorder="little", signed=True) / 4 + ) + res["direct_dups"] = int.from_bytes(data[offset + 44 : offset + 46], byteorder="little") + res["flood_dups"] = int.from_bytes(data[offset + 46 : offset + 48], byteorder="little") + res["rx_airtime"] = int.from_bytes(data[offset + 48 : offset + 52], byteorder="little") + return res + + +def _parse_telemetry(data: bytes) -> dict: + """Telemetry: Cayenne LPP or raw. 
Dict has raw_hex; optional LPP if cayennelpp available.""" + out: dict = {"raw_hex": data.hex()} + try: + from cayennelpp import LppFrame + + frame = LppFrame.from_bytes(data) + out["lpp"] = [ + {"channel": d.channel, "type": d.type_id, "value": d.data} for d in frame.data + ] + except Exception: + logger.debug("Optional LPP parse failed for telemetry", exc_info=True) + return out + + +def _parse_mma(data: bytes) -> dict: + """MMA: LPP min/max/avg or raw.""" + out: dict = {"raw_hex": data.hex()} + try: + from cayennelpp import LppFrame + + frame = LppFrame.from_bytes(data) + out["mma"] = [{"channel": d.channel, "type": d.type_id, "data": d.data} for d in frame.data] + except Exception: + logger.debug("Optional LPP parse failed for MMA", exc_info=True) + return out + + +def _parse_owner_info(data: bytes) -> dict: + """Parse GET_OWNER_INFO response: tag(4) + 'version\\nname\\nowner' (variable).""" + try: + text = data[4:].decode("utf-8", errors="replace").strip() + parts = text.split("\n", 2) + return { + "tag": int.from_bytes(data[:4], "little"), + "version": parts[0] if len(parts) > 0 else "", + "node_name": parts[1] if len(parts) > 1 else "", + "owner_info": parts[2] if len(parts) > 2 else "", + "raw_text": text, + } + except Exception: + logger.debug("Owner info parse failed, returning fallback", exc_info=True) + return {"raw_hex": data.hex(), "request_type": BinaryReqType.OWNER_INFO} + + +def _parse_acl(buf: bytes) -> dict: + """ACL: 7-byte entries (key 6 + perm 1).""" + res = [] + i = 0 + while i + 7 <= len(buf): + key = buf[i : i + 6].hex() + perm = buf[i + 6] + if key != "000000000000": + res.append({"key": key, "perm": perm}) + i += 7 + return {"acl": res} + + +def _parse_neighbours(data: bytes, context: dict) -> dict: + """Neighbours: count(2) + results_count(2) + entries (pubkey_prefix + secs_ago(4) + snr(1)).""" + if len(data) < 4: + return {"raw_hex": data.hex()} + pk_plen = context.get("pubkey_prefix_length", 6) + neighbours_count = 
int.from_bytes(data[0:2], "little", signed=True) + results_count = int.from_bytes(data[2:4], "little", signed=True) + neighbours_list = [] + i = 4 + for _ in range(results_count): + if i + pk_plen + 4 + 1 > len(data): + break + pubkey = data[i : i + pk_plen].hex() + i += pk_plen + secs_ago = int.from_bytes(data[i : i + 4], "little", signed=True) + i += 4 + snr = int.from_bytes(data[i : i + 1], "little", signed=True) / 4 + i += 1 + neighbours_list.append({"pubkey": pubkey, "secs_ago": secs_ago, "snr": snr}) + return { + "neighbours_count": neighbours_count, + "results_count": results_count, + "neighbours": neighbours_list, + } diff --git a/src/pymc_core/companion/channel_store.py b/src/pymc_core/companion/channel_store.py new file mode 100644 index 0000000..81fd03a --- /dev/null +++ b/src/pymc_core/companion/channel_store.py @@ -0,0 +1,85 @@ +"""In-memory channel storage compatible with MeshNode's channel_db interface.""" + +from __future__ import annotations + +from typing import Optional + +from .constants import DEFAULT_MAX_CHANNELS +from .models import Channel + + +class ChannelStore: + """In-memory channel storage compatible with MeshNode's channel_db interface. + + Provides both the interface expected by GroupTextHandler (get_channels returning + list of dicts) and companion radio operations (get/set/remove by index). + """ + + def __init__(self, max_channels: int = DEFAULT_MAX_CHANNELS): + self._channels: list[Optional[Channel]] = [None] * max_channels + self._max_channels = max_channels + + @property + def max_channels(self) -> int: + """Maximum number of channels (read-only). Used by companion protocol device info.""" + return self._max_channels + + # ------------------------------------------------------------------ + # Interface expected by GroupTextHandler / PacketBuilder + # ------------------------------------------------------------------ + + def get_channels(self) -> list[dict]: + """Return channels as list of dicts with 'name' and 'secret' keys. 
+ + The secret is returned as a hex string, which is what the existing + GroupTextHandler and PacketBuilder expect. + """ + result = [] + for ch in self._channels: + if ch is not None: + result.append( + { + "name": ch.name, + "secret": ch.secret.hex(), + } + ) + return result + + # ------------------------------------------------------------------ + # Companion radio methods + # ------------------------------------------------------------------ + + def get(self, idx: int) -> Optional[Channel]: + """Get a channel by index. Returns None if index invalid or empty.""" + if 0 <= idx < self._max_channels: + return self._channels[idx] + return None + + def set(self, idx: int, channel: Channel) -> bool: + """Set a channel at the given index. Returns False if index out of range.""" + if 0 <= idx < self._max_channels: + self._channels[idx] = channel + return True + return False + + def remove(self, idx: int) -> bool: + """Remove a channel at the given index. Returns False if index invalid or already empty.""" + if 0 <= idx < self._max_channels and self._channels[idx] is not None: + self._channels[idx] = None + return True + return False + + def find_by_name(self, name: str) -> Optional[int]: + """Find a channel index by name. Returns None if not found.""" + for idx, ch in enumerate(self._channels): + if ch is not None and ch.name == name: + return idx + return None + + def get_count(self) -> int: + """Return the number of configured channels.""" + return sum(1 for ch in self._channels if ch is not None) + + def clear(self) -> None: + """Remove all channels.""" + self._channels = [None] * self._max_channels diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py new file mode 100644 index 0000000..208542b --- /dev/null +++ b/src/pymc_core/companion/companion_base.py @@ -0,0 +1,1828 @@ +""" +CompanionBase - Shared logic for CompanionRadio and CompanionBridge. 
+ +Provides stores, event handling, contact management, device configuration, +and push callbacks. Subclasses implement TX via MeshNode or packet_injector. +""" + +from __future__ import annotations + +import asyncio +import copy +import logging +import random +import struct +import time +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Any, Callable, Iterable, Optional + +from ..node.events import EventService, EventSubscriber, MeshEvents +from ..protocol import LocalIdentity, Packet, PacketBuilder +from ..protocol.constants import ( + ADVERT_FLAG_HAS_LOCATION, + ADVERT_FLAG_HAS_NAME, + ADVERT_FLAG_IS_CHAT_NODE, + ADVERT_FLAG_IS_REPEATER, + ADVERT_FLAG_IS_ROOM_SERVER, + ADVERT_FLAG_IS_SENSOR, + MAX_PACKET_PAYLOAD, + MAX_PATH_SIZE, + PAYLOAD_TYPE_CONTROL, + PAYLOAD_TYPE_TRACE, + REQ_TYPE_GET_STATUS, + REQ_TYPE_GET_TELEMETRY_DATA, + ROUTE_TYPE_FLOOD, + ROUTE_TYPE_TRANSPORT_FLOOD, + TELEM_PERM_BASE, +) +from ..protocol.packet_utils import PathUtils +from ..protocol.transport_keys import calc_transport_code, get_auto_key_for +from .channel_store import ChannelStore +from .constants import ( + ADV_TYPE_CHAT, + ADV_TYPE_REPEATER, + ADV_TYPE_ROOM, + ADV_TYPE_SENSOR, + ADVERT_LOC_SHARE, + AUTOADD_CHAT, + AUTOADD_OVERWRITE_OLDEST, + AUTOADD_REPEATER, + AUTOADD_ROOM, + AUTOADD_SENSOR, + DEFAULT_MAX_CHANNELS, + DEFAULT_MAX_CONTACTS, + DEFAULT_OFFLINE_QUEUE_SIZE, + DEFAULT_RESPONSE_TIMEOUT_MS, + MAX_PENDING_ACK_CRCS, + MAX_SIGN_DATA_SIZE, + PROTOCOL_CODE_ANON_REQ, + PROTOCOL_CODE_BINARY_REQ, + PROTOCOL_CODE_RAW_DATA, + PUSH_CODE_TELEMETRY_RESPONSE, + STATS_TYPE_CORE, + STATS_TYPE_PACKETS, + STATS_TYPE_RADIO, + TXT_TYPE_PLAIN, +) +from .contact_store import ContactStore +from .message_queue import MessageQueue +from .models import AdvertPath, Channel, Contact, NodePrefs, QueuedMessage, SentResult +from .path_cache import PathCache +from .stats_collector import StatsCollector + +logger = logging.getLogger("CompanionBase") + 
+PUSH_CALLBACK_KEYS = [ + "message_received", + "channel_message_received", + "advert_received", + "contact_path_updated", + "send_confirmed", + "trace_received", + "node_discovered", + "login_result", + "telemetry_response", + "status_response", + "raw_data_received", + "rx_log_data", # raw RX with SNR/RSSI (CompanionRadio only; matches PUSH 0x88) + "binary_response", + "path_discovery_response", + "contact_deleted", + "contacts_full", + "channel_updated", +] + + +class ResponseWaiter: + """Helper for awaiting async protocol/login responses.""" + + def __init__(self) -> None: + self.event = asyncio.Event() + self.data: dict = {"success": False, "text": None, "parsed": {}} + + def callback( + self, + success: bool, + text: str, + parsed_data: Optional[dict] = None, + ) -> None: + self.data["success"] = success + self.data["text"] = text + self.data["parsed"] = parsed_data or {} + self.event.set() + + async def wait(self, timeout: float = 10.0) -> dict: + try: + await asyncio.wait_for(self.event.wait(), timeout=timeout) + return self.data + except asyncio.TimeoutError: + return {**self.data, "timeout": True} + + +class _CompanionEventSubscriber(EventSubscriber): + """Bridges event service to companion push callbacks.""" + + def __init__(self, companion: CompanionBase) -> None: + self._companion = companion + + async def handle_event(self, event_type: str, data: dict) -> None: + await self._companion._handle_mesh_event(event_type, data) + + +def adv_type_to_flags(adv_type: int) -> int: + """Convert ADV_TYPE_* constant to advertisement flags byte.""" + if adv_type == ADV_TYPE_CHAT: + return ADVERT_FLAG_IS_CHAT_NODE + elif adv_type == ADV_TYPE_REPEATER: + return ADVERT_FLAG_IS_REPEATER + elif adv_type == ADV_TYPE_ROOM: + return ADVERT_FLAG_IS_ROOM_SERVER + elif adv_type == ADV_TYPE_SENSOR: + return ADVERT_FLAG_IS_SENSOR + return ADVERT_FLAG_IS_CHAT_NODE + + +class CompanionBase(ABC): + """Abstract base class for companion implementations. 
+ + Provides shared stores, event handling, contact management, device config, + and push callbacks. Subclasses implement TX (via node or packet_injector). + """ + + def _init_companion_stores( + self, + identity: LocalIdentity, + node_name: str = "pyMC", + adv_type: int = ADV_TYPE_CHAT, + max_contacts: int = DEFAULT_MAX_CONTACTS, + max_channels: int = DEFAULT_MAX_CHANNELS, + offline_queue_size: int = DEFAULT_OFFLINE_QUEUE_SIZE, + radio_config: Optional[dict] = None, + initial_contacts: Optional[Iterable[Contact]] = None, + ) -> None: + """Initialize shared stores, prefs, event service, and push callbacks.""" + self._identity = identity + self._radio_config = radio_config or {} + self._running = False + + self.contacts = ContactStore(max_contacts) + self.channels = ChannelStore(max_channels) + self.message_queue = MessageQueue(offline_queue_size) + self.path_cache = PathCache() + self.stats = StatsCollector() + + self.prefs = NodePrefs( + node_name=node_name, + adv_type=adv_type, + tx_power_dbm=self._radio_config.get("power", self._radio_config.get("tx_power", 20)), + frequency_hz=self._radio_config.get("frequency", 915000000), + bandwidth_hz=self._radio_config.get("bandwidth", 250000), + spreading_factor=self._radio_config.get("spreading_factor", 10), + coding_rate=self._radio_config.get("coding_rate", 5), + ) + + self._custom_vars: dict[str, str] = {} + self._sign_buffer: Optional[bytearray] = None + self._flood_transport_key: Optional[bytes] = None + self._time_offset: float = 0.0 + + self._event_service = EventService() + self._event_subscriber = _CompanionEventSubscriber(self) + self._event_service.subscribe_all(self._event_subscriber) + + self._push_callbacks: dict[str, list[Callable]] = {k: [] for k in PUSH_CALLBACK_KEYS} + + # Pending binary requests by tag (hex) for matching responses + self._pending_binary_requests: dict[str, dict] = {} + # Pending path discovery tags for matching responses + self._pending_discovery_tags: set[int] = set() + # Pending ACK 
CRCs for send_confirmed (Bridge and Radio) + self._pending_ack_crcs: set[int] = set() + + # GRP_TXT dedup by packet hash: match Mesh.cpp (!_tables->hasSeen(pkt)); + # companion queues one frame per logical message like the firmware. + self._seen_grp_txt: OrderedDict[str, float] = OrderedDict() + self._seen_grp_txt_ttl = 300 + self._seen_grp_txt_max = 1000 + # TXT_MSG (direct) dedup by packet hash so reconnects don't re-queue same packet. + self._seen_txt: OrderedDict[str, float] = OrderedDict() + self._seen_txt_ttl = 300 + self._seen_txt_max = 1000 + + # Allow subclasses to restore persisted preferences on startup. + self._load_prefs() + + # Optional bulk load of contacts (e.g. from persistence on boot). + if initial_contacts is not None: + self.contacts.load_from(initial_contacts) + + # ------------------------------------------------------------------------- + # Preference Persistence Hooks + # ------------------------------------------------------------------------- + + def _save_prefs(self) -> None: + """Hook: persist the current :attr:`prefs` to stable storage. + + The default implementation is a no-op — preferences live only in + memory. Subclasses that need persistence (e.g. backed by SQLite or + a JSON file) should override this method. + + Called automatically after any preference-mutating method + (``set_radio_params``, ``set_tx_power``, ``set_tuning_params``, + ``set_autoadd_config``, ``set_other_params``, + ``set_advert_name``, ``set_advert_latlon``). + """ + + def _load_prefs(self) -> None: + """Hook: restore :attr:`prefs` from stable storage on startup. + + The default implementation is a no-op. Subclasses should override + to populate :attr:`self.prefs` fields from their persistence layer. + + Called once at the end of :meth:`_init_companion_stores`. 
+ """ + + # ------------------------------------------------------------------------- + # Contact Management + # ------------------------------------------------------------------------- + + def get_contacts(self, since: int = 0) -> list[Contact]: + """Return all contacts, optionally filtered by modification time.""" + return self.contacts.get_all(since=since) + + def get_contact_by_key(self, pub_key: bytes) -> Optional[Contact]: + """Look up a contact by its full 32-byte public key.""" + return self.contacts.get_by_key(pub_key) + + def get_contact_by_name(self, name: str) -> Optional[Contact]: + """Look up a contact by name, returning the full Contact or None.""" + proxy = self.contacts.get_by_name(name) + if proxy: + return self.contacts.get_by_key(bytes.fromhex(proxy.public_key)) + return None + + def add_update_contact(self, contact: Contact) -> bool: + """Add or update a contact, setting lastmod if unset.""" + if contact.lastmod == 0: + contact.lastmod = int(time.time()) + return self.contacts.add(contact) + + def remove_contact(self, pub_key: bytes) -> bool: + """Remove a contact by public key.""" + return self.contacts.remove(pub_key) + + def export_contact(self, pub_key: Optional[bytes] = None) -> Optional[bytes]: + """Export a contact (or self) as a 73-byte binary packet.""" + if pub_key is None: + key = self._identity.get_public_key() + name = self.prefs.node_name.encode("utf-8")[:32] + name = name + b"\x00" * (32 - len(name)) + lat = int(self.prefs.latitude * 1e6) + lon = int(self.prefs.longitude * 1e6) + return struct.pack( + "<32sB32sii", + key, + self.prefs.adv_type, + name, + lat, + lon, + ) + contact = self.contacts.get_by_key(pub_key) + if not contact: + return None + name = contact.name.encode("utf-8")[:32] + name = name + b"\x00" * (32 - len(name)) + lat = int(contact.gps_lat * 1e6) + lon = int(contact.gps_lon * 1e6) + return struct.pack( + "<32sB32sii", + contact.public_key, + contact.adv_type, + name, + lat, + lon, + ) + + def 
import_contact(self, packet_data: bytes) -> bool: + """Import a contact from a 73-byte binary packet.""" + if len(packet_data) < 73: + logger.warning(f"Import data too short: {len(packet_data)} bytes") + return False + try: + pub_key = packet_data[:32] + adv_type = packet_data[32] + name_raw = packet_data[33:65] + lat, lon = struct.unpack_from(" None: + """Set the node's advertised name (max 31 chars).""" + self.prefs.node_name = name[:31] + self._save_prefs() + self._sync_our_node_name_to_handlers() + + def _get_group_text_handler(self) -> Optional[Any]: + """Return the group text handler for name sync, or None. Override in Radio/Bridge.""" + return None + + def _sync_our_node_name_to_handlers(self) -> None: + """Sync node name to group text handler for echo detection.""" + handler = self._get_group_text_handler() + if handler is not None: + handler.set_our_node_name(self.prefs.node_name) + + def set_advert_latlon(self, lat: float, lon: float) -> None: + """Set the GPS coordinates included in advertisements.""" + if not (-90.0 <= lat <= 90.0): + raise ValueError(f"Latitude out of range: {lat}") + if not (-180.0 <= lon <= 180.0): + raise ValueError(f"Longitude out of range: {lon}") + self.prefs.latitude = lat + self.prefs.longitude = lon + self._save_prefs() + + def set_radio_params(self, freq_hz: int, bw_hz: int, sf: int, cr: int) -> bool: + """Set radio parameters (frequency, bandwidth, SF, CR).""" + if not (5 <= sf <= 12): + raise ValueError(f"Spreading factor out of range: {sf}") + if not (5 <= cr <= 8): + raise ValueError(f"Coding rate out of range: {cr}") + self.prefs.frequency_hz = freq_hz + self.prefs.bandwidth_hz = bw_hz + self.prefs.spreading_factor = sf + self.prefs.coding_rate = cr + self._save_prefs() + return True + + def set_tx_power(self, power_dbm: int) -> bool: + """Set the transmit power in dBm.""" + self.prefs.tx_power_dbm = power_dbm + self._save_prefs() + return True + + def set_tuning_params(self, rx_delay: float, airtime_factor: float) -> 
None: + """Set RX delay and airtime factor tuning parameters.""" + self.prefs.rx_delay_base = rx_delay + self.prefs.airtime_factor = airtime_factor + self._save_prefs() + + def get_tuning_params(self) -> tuple[float, float]: + """Return the current (rx_delay, airtime_factor) tuning parameters.""" + return (self.prefs.rx_delay_base, self.prefs.airtime_factor) + + def get_radio_params(self) -> dict: + """Return current radio configuration (frequency, bandwidth, SF, CR, TX power, tuning). + + Use this to fetch the radio configuration details. Keys match the arguments + to set_radio_params/set_tx_power/set_tuning_params: frequency_hz, bandwidth_hz, + spreading_factor, coding_rate, tx_power_dbm, rx_delay_base, airtime_factor. + """ + return { + "frequency_hz": self.prefs.frequency_hz, + "bandwidth_hz": self.prefs.bandwidth_hz, + "spreading_factor": self.prefs.spreading_factor, + "coding_rate": self.prefs.coding_rate, + "tx_power_dbm": self.prefs.tx_power_dbm, + "rx_delay_base": self.prefs.rx_delay_base, + "airtime_factor": self.prefs.airtime_factor, + } + + def get_time(self) -> int: + """Return the current device time as a Unix timestamp.""" + return int(time.time() + self._time_offset) + + def set_time(self, secs: int) -> bool: + """Set the device time. 
Returns False if *secs* is in the past.""" + current = self.get_time() + if secs < current: + return False + self._time_offset = secs - time.time() + return True + + def set_other_params( + self, + manual_add: int, + telemetry_modes: int, + advert_loc_policy: int, + multi_acks: int, + ) -> None: + """Set additional node parameters (manual add, telemetry, location, multi-acks).""" + self.prefs.manual_add_contacts = manual_add + self.prefs.telemetry_mode_base = telemetry_modes & 0x03 + self.prefs.telemetry_mode_location = (telemetry_modes >> 2) & 0x03 + self.prefs.telemetry_mode_environment = (telemetry_modes >> 4) & 0x03 + self.prefs.advert_loc_policy = advert_loc_policy + self.prefs.multi_acks = multi_acks + self._save_prefs() + + def set_path_hash_mode(self, mode: int) -> None: + """Set path hash encoding mode (0=1-byte, 1=2-byte, 2=3-byte hashes).""" + self.prefs.path_hash_mode = mode + self._save_prefs() + + def get_self_info(self) -> NodePrefs: + """Return a copy of the current node preferences.""" + return copy.copy(self.prefs) + + def get_public_key(self) -> bytes: + """Return this node's 32-byte Ed25519 public key.""" + return self._identity.get_public_key() + + # ------------------------------------------------------------------------- + # Path & Routing + # ------------------------------------------------------------------------- + + def reset_path(self, pub_key: bytes) -> bool: + """Reset the outbound routing path for a contact.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return False + contact.out_path_len = -1 + contact.out_path = b"" + self.contacts.update(contact) + return True + + def get_advert_path(self, pub_key_prefix: bytes) -> Optional[AdvertPath]: + """Look up a cached advert path by public key prefix.""" + return self.path_cache.get_by_prefix(pub_key_prefix) + + # ------------------------------------------------------------------------- + # Channel Management + # 
------------------------------------------------------------------------- + + def get_channel(self, idx: int) -> Optional[Channel]: + """Return the channel at the given index, or None.""" + return self.channels.get(idx) + + def set_channel(self, idx: int, name: str, secret: bytes) -> bool: + """Set a channel at the given index with name and 32-byte secret.""" + # MeshCore DataStore uses 32-byte secret; GroupTextHandler uses up to 32 for HMAC + if len(secret) < 32: + secret = secret + b"\x00" * (32 - len(secret)) + elif len(secret) > 32: + secret = secret[:32] + ok = self.channels.set(idx, Channel(name=name[:32], secret=secret)) + if ok: + ch = self.channels.get(idx) + self._schedule_fire_callbacks("channel_updated", idx, ch) + return ok + + def remove_channel(self, idx: int) -> bool: + """Remove the channel at the given index. Fires on_channel_updated(idx, None).""" + ok = self.channels.remove(idx) + if ok: + self._schedule_fire_callbacks("channel_updated", idx, None) + return ok + + # ------------------------------------------------------------------------- + # Signing Pipeline + # ------------------------------------------------------------------------- + + def sign_start(self) -> int: + """Begin a signing session; returns the maximum sign buffer size.""" + self._sign_buffer = bytearray() + return MAX_SIGN_DATA_SIZE + + def sign_data(self, data: bytes) -> bool: + """Append data to the signing buffer.""" + if self._sign_buffer is None: + logger.warning("sign_data called without sign_start") + return False + if len(self._sign_buffer) + len(data) > MAX_SIGN_DATA_SIZE: + logger.warning("Sign data would overflow buffer") + return False + self._sign_buffer.extend(data) + return True + + def sign_finish(self) -> Optional[bytes]: + if self._sign_buffer is None: + logger.warning("sign_finish called without sign_start") + return None + try: + return self._identity.sign(bytes(self._sign_buffer)) + except Exception as e: + logger.error(f"Signing error: {e}") + return None + 
        finally:
            # Session always ends, even on signing failure.
            self._sign_buffer = None

    # -------------------------------------------------------------------------
    # Key Management
    # -------------------------------------------------------------------------

    def export_private_key(self) -> bytes:
        """Return the raw signing key bytes for backup/export."""
        return self._identity.get_signing_key_bytes()

    # -------------------------------------------------------------------------
    # Flood Scope
    # -------------------------------------------------------------------------

    def set_flood_scope(self, transport_key: Optional[bytes] = None) -> None:
        """Set or clear the flood transport key for scoped flooding.

        Keys shorter than 16 bytes (or None/empty) clear the scope; longer
        keys are truncated to 16 bytes.
        """
        if transport_key and len(transport_key) >= 16:
            self._flood_transport_key = transport_key[:16]
        else:
            self._flood_transport_key = None

    def set_flood_region(self, region_name: Optional[str] = None) -> None:
        """Set flood scope from a region name (e.g., ``'#usa'``) or clear it.

        Derives the 16-byte transport key automatically via SHA-256 of the
        region name. A leading ``#`` is added if not already present.
        Pass ``None`` to clear the scope (flood to all).
        """
        if region_name:
            if not region_name.startswith("#"):
                region_name = f"#{region_name}"
            self._flood_transport_key = get_auto_key_for(region_name)
        else:
            self._flood_transport_key = None

    def _apply_flood_scope(self, pkt: Packet) -> None:
        """Apply flood scope transport codes to a packet in-place.

        If ``_flood_transport_key`` is set and the packet uses flood routing,
        calculates the transport code, attaches it to the packet, and changes
        the route type to ``ROUTE_TYPE_TRANSPORT_FLOOD``.

        Matches firmware ``sendFloodScoped()`` in ``BaseChatMesh.cpp``.
        """
        if self._flood_transport_key is None:
            return
        route_type = pkt.get_route_type()
        if route_type != ROUTE_TYPE_FLOOD:
            return  # only scope flood packets, not direct
        code = calc_transport_code(self._flood_transport_key, pkt)
        pkt.transport_codes[0] = code
        pkt.transport_codes[1] = 0  # reserved for home region (firmware TODO)
        # Switch route type from FLOOD -> TRANSPORT_FLOOD
        pkt.header = (pkt.header & ~0x03) | ROUTE_TYPE_TRANSPORT_FLOOD

    def _apply_path_hash_mode(self, pkt: Packet) -> None:
        """Encode the device's path_hash_mode in originated packets.

        When a packet has 0 hops (freshly originated), sets bits 6-7 of
        ``path_len`` to encode the hash size from ``prefs.path_hash_mode``.
        Packets with existing hops (stored contact paths) are untouched.
        Trace packets are excluded because the repeater's trace handler uses
        ``path``/``path_len`` to store SNR values, not routing hashes.

        Mirrors firmware ``sendFlood(pkt, delay, _prefs.path_hash_mode + 1)``
        which calls ``pkt->setPathHashSizeAndCount(hash_size, 0)``.
        """
        if pkt.get_payload_type() == PAYLOAD_TYPE_TRACE:
            return
        if pkt.get_path_hash_count() == 0:
            # path_hash_mode 0/1/2 maps to 1/2/3-byte hashes.
            hash_size = self.prefs.path_hash_mode + 1
            pkt.path_len = PathUtils.encode_path_len(hash_size, 0)

    # -------------------------------------------------------------------------
    # Statistics (subclasses may override _get_radio_stats for STATS_TYPE_RADIO)
    # -------------------------------------------------------------------------

    def get_stats(self, stats_type: int = STATS_TYPE_PACKETS) -> dict:
        """Return statistics of the requested type (core, radio, or packets)."""
        if stats_type == STATS_TYPE_CORE:
            return {
                "uptime_secs": self.stats.get_uptime_secs(),
                "queue_len": self.message_queue.count,
                "contacts_count": self.contacts.get_count(),
                "channels_count": self.channels.get_count(),
            }
        elif stats_type == STATS_TYPE_RADIO:
            return self._get_radio_stats()
        return self.stats.get_totals()

    def _get_radio_stats(self) -> dict:
        """Override in CompanionRadio for hardware RSSI/SNR. Default: prefs only."""
        return {
            "frequency_hz": self.prefs.frequency_hz,
            "bandwidth_hz": self.prefs.bandwidth_hz,
            "spreading_factor": self.prefs.spreading_factor,
            "coding_rate": self.prefs.coding_rate,
            "tx_power_dbm": self.prefs.tx_power_dbm,
        }

    # -------------------------------------------------------------------------
    # Custom Variables
    # -------------------------------------------------------------------------

    def get_custom_vars(self) -> dict[str, str]:
        """Return a copy of all custom variables."""
        return dict(self._custom_vars)

    def set_custom_var(self, name: str, value: str) -> bool:
        """Set a custom variable by name. Always returns True."""
        self._custom_vars[name] = value
        return True

    # -------------------------------------------------------------------------
    # Auto-Add Configuration
    # -------------------------------------------------------------------------

    def get_autoadd_config(self) -> int:
        """Return the current auto-add configuration bitmask."""
        return self.prefs.autoadd_config

    def set_autoadd_config(self, config: int) -> None:
        """Set the auto-add configuration bitmask."""
        self.prefs.autoadd_config = config
        self._save_prefs()

    # Map ADV_TYPE_* → AUTOADD_* bitmask bits (mirrors C++ shouldAutoAddContactType)
    _AUTOADD_TYPE_MAP: dict[int, int] = {
        ADV_TYPE_CHAT: AUTOADD_CHAT,  # 1 → 0x02
        ADV_TYPE_REPEATER: AUTOADD_REPEATER,  # 2 → 0x04
        ADV_TYPE_ROOM: AUTOADD_ROOM,  # 3 → 0x08
        ADV_TYPE_SENSOR: AUTOADD_SENSOR,  # 4 → 0x10
    }

    def should_auto_add_contact_type(self, contact_type: int) -> bool:
        """Check if a contact type should be auto-added based on current preferences.

        Mirrors C++ MyMesh::shouldAutoAddContactType (MyMesh.cpp:281-304).
        """
        # manual_add_contacts bit 0 == 0 → auto-add ALL types
        if (self.prefs.manual_add_contacts & 1) == 0:
            return True
        # Selective mode: check the type-specific bit in autoadd_config
        type_bit = self._AUTOADD_TYPE_MAP.get(contact_type, 0)
        return bool(self.prefs.autoadd_config & type_bit) if type_bit else False

    def should_overwrite_when_full(self) -> bool:
        """Check if overwrite-oldest is enabled. Mirrors C++ shouldOverwriteWhenFull."""
        return bool(self.prefs.autoadd_config & AUTOADD_OVERWRITE_OLDEST)

    async def _apply_advert_to_stores(
        self,
        contact: Contact,
        inbound_path: Optional[bytes] = None,
        *,
        path_len_encoded: Optional[int] = None,
    ) -> Optional[Contact]:
        """Apply advert to ContactStore and PathCache. Shared by Bridge and NODE_DISCOVERED.

        Mirrors C++ BaseChatMesh::onAdvertRecv (existing update, auto-add filter,
        overwrite when full). Returns the Contact if added or updated, None otherwise.
        Path cache is updated for all valid contacts (pub_key >= 7, name non-empty).

        Args:
            path_len_encoded: Encoded path_len byte from the packet. If None,
                falls back to len(inbound_path) (assumes 1-byte hashes).
        """
        try:
            # Validity gate: need at least a 7-byte key prefix and a name.
            if len(contact.public_key) < 7 or not contact.name:
                return None
            inbound_path = inbound_path or b""
            advert_path_len = (
                path_len_encoded if path_len_encoded is not None else len(inbound_path)
            )
            self.path_cache.update(
                AdvertPath(
                    public_key_prefix=contact.public_key[:7],
                    name=contact.name,
                    path_len=advert_path_len,
                    path=inbound_path,
                    recv_timestamp=int(time.time()),
                )
            )
            existing = self.contacts.get_by_key(contact.public_key)
            if existing is not None:
                # Existing contact: preserve routing state and local flags,
                # update everything else from the fresh advert.
                contact.out_path_len = existing.out_path_len
                contact.out_path = existing.out_path
                contact.flags = existing.flags
                contact.sync_since = existing.sync_since
                self.contacts.update(contact)
                return contact
            if not self.should_auto_add_contact_type(contact.adv_type):
                logger.debug("Auto-add filtered: type %d not allowed", contact.adv_type)
                return None
            if self.should_overwrite_when_full() and self.contacts.is_full():
                ok, overwritten = self.contacts.add_or_overwrite(contact)
                if ok and overwritten:
                    await self._fire_callbacks("contact_deleted", overwritten)
                elif not ok:
                    await self._fire_callbacks("contacts_full")
                return contact if ok else None
            added = self.contacts.add(contact)
            if not added and self.contacts.is_full():
                await self._fire_callbacks("contacts_full")
            return contact if added else None
        except Exception as e:
            logger.error("Error applying advert to stores: %s", e)
            return None

    # -------------------------------------------------------------------------
    # Push Callbacks
    # -------------------------------------------------------------------------

    def clear_push_callbacks(self) -> None:
        """Remove all registered push callbacks.

        Called by FrameServer between client connections so that stale
        closures from a previous connection are not invoked on the next one.
        """
        for key in self._push_callbacks:
            self._push_callbacks[key].clear()

    def on_message_received(self, callback: Callable) -> None:
        """Register callback for direct text messages."""
        self._push_callbacks["message_received"].append(callback)

    def on_channel_message_received(self, callback: Callable) -> None:
        """Register callback for group channel messages."""
        self._push_callbacks["channel_message_received"].append(callback)

    def on_advert_received(self, callback: Callable) -> None:
        """Register callback for received adverts."""
        self._push_callbacks["advert_received"].append(callback)

    def on_contact_path_updated(self, callback: Callable) -> None:
        """Register callback for contact out-path updates. Callback(contact)."""
        self._push_callbacks["contact_path_updated"].append(callback)

    async def _on_contact_path_updated(self, pub: bytes, path_len: int, path_bytes: bytes) -> None:
        """Called by ProtocolResponseHandler when contact's out_path is updated from a PATH packet.

        Matches companion firmware behaviour: PATH updates are only applied
        (and pushed to the client) for contacts that already exist in the
        store. Unknown public keys are silently ignored.
        """
        contact = self.get_contact_by_key(pub)
        if contact is None:
            return  # Firmware does not send PATH for non-contacts
        contact.out_path_len = path_len
        contact.out_path = path_bytes
        self.contacts.update(contact)
        await self._fire_callbacks("contact_path_updated", contact)

    def on_send_confirmed(self, callback: Callable) -> None:
        """Register callback for ACK/send-confirmed events."""
        self._push_callbacks["send_confirmed"].append(callback)

    def on_trace_received(self, callback: Callable) -> None:
        """Register callback for trace responses."""
        self._push_callbacks["trace_received"].append(callback)

    def on_node_discovered(self, callback: Callable) -> None:
        """Register callback for newly discovered nodes."""
        self._push_callbacks["node_discovered"].append(callback)

    def on_login_result(self, callback: Callable) -> None:
        """Register callback for login results."""
        self._push_callbacks["login_result"].append(callback)

    def on_telemetry_response(self, callback: Callable) -> None:
        """Register callback for telemetry responses."""
        self._push_callbacks["telemetry_response"].append(callback)

    def on_status_response(self, callback: Callable) -> None:
        """Register callback for status responses."""
        self._push_callbacks["status_response"].append(callback)

    def on_raw_data_received(self, callback: Callable) -> None:
        """Register callback for raw data packets."""
        self._push_callbacks["raw_data_received"].append(callback)

    def on_rx_log_data(self, callback: Callable) -> None:
        """Register callback for raw RX with SNR/RSSI (CompanionRadio only).

        Callback(snr: float, rssi: int, raw_bytes: bytes). Same data as
        PUSH_CODE_LOG_RX_DATA (0x88). Only fired when using CompanionRadio;
        CompanionBridge does not own the radio.
        """
        self._push_callbacks["rx_log_data"].append(callback)

    def on_binary_response(self, callback: Callable) -> None:
        """Register callback for PUSH 0x8C. Callback(tag_bytes, response_data)."""
        self._push_callbacks["binary_response"].append(callback)

    def on_path_discovery_response(self, callback: Callable) -> None:
        """Register callback for path discovery 0x8D. (tag_bytes, pubkey, out_path, in_path)."""
        self._push_callbacks["path_discovery_response"].append(callback)

    def on_contact_deleted(self, callback: Callable) -> None:
        """Register callback for PUSH 0x8F (contact overwritten). Callback(pub_key_bytes)."""
        self._push_callbacks["contact_deleted"].append(callback)

    def on_contacts_full(self, callback: Callable) -> None:
        """Register callback for PUSH 0x90 (contacts store full). Callback()."""
        self._push_callbacks["contacts_full"].append(callback)

    def on_channel_updated(self, callback: Callable) -> None:
        """Register callback for channel set/remove. Callback(idx: int, channel_or_none)."""
        self._push_callbacks["channel_updated"].append(callback)

    def register_binary_request(
        self,
        tag_hex: str,
        request_type: int,
        timeout_seconds: float,
        pubkey_prefix: str = "",
        context: Optional[dict] = None,
    ) -> None:
        """Register a pending binary request.
        Call cleanup_expired_binary_requests first."""
        self._pending_binary_requests[tag_hex] = {
            "request_type": request_type,
            "pubkey_prefix": pubkey_prefix,
            "expires_at": time.time() + timeout_seconds,
            "context": context or {},
        }

    def cleanup_expired_binary_requests(self) -> None:
        """Remove expired entries from _pending_binary_requests."""
        now = time.time()
        expired = [
            tag for tag, info in self._pending_binary_requests.items() if now > info["expires_at"]
        ]
        for tag in expired:
            del self._pending_binary_requests[tag]

    async def _on_binary_response(
        self,
        tag_bytes: bytes,
        response_data: bytes,
        path_info: Optional[tuple] = None,
    ) -> None:
        """Called when binary response (tag + data, optional path) received.

        Matches the tag against pending binary requests; parses the payload
        for known request types and fires the binary_response callbacks.
        """
        if path_info is not None:
            # Path-return packets may satisfy a pending path discovery instead.
            if await self._try_handle_path_discovery(tag_bytes, path_info):
                return
        self.cleanup_expired_binary_requests()
        tag_hex = tag_bytes.hex()
        info = self._pending_binary_requests.pop(tag_hex, None)
        if not info:
            # Skip log for small payloads (e.g. login response handled elsewhere)
            if len(response_data) >= 20:
                logger.debug(f"Binary response for unknown tag {tag_hex}")
            await self._fire_callbacks("binary_response", tag_bytes, response_data)
            return
        request_type = info["request_type"]
        pubkey_prefix = info.get("pubkey_prefix", "")
        context = info.get("context", {})
        parsed = None
        try:
            from . import binary_parsing

            parsed = binary_parsing.parse_binary_response(
                request_type, response_data, pubkey_prefix=pubkey_prefix, context=context
            )
        except Exception as e:
            logger.debug(f"Binary response parse for type {request_type}: {e}")
        await self._fire_callbacks(
            "binary_response", tag_bytes, response_data, parsed, request_type
        )

    async def _try_handle_path_discovery(self, tag_bytes: bytes, path_info: tuple) -> bool:
        """If tag is pending path discovery, fire path_discovery_response and return True."""
        out_path, in_path, contact_pubkey = path_info
        tag_int = int.from_bytes(tag_bytes, "little")
        if tag_int not in self._pending_discovery_tags:
            return False
        self._pending_discovery_tags.discard(tag_int)
        await self._fire_callbacks(
            "path_discovery_response",
            tag_bytes,
            contact_pubkey,
            out_path,
            in_path,
        )
        return True

    # -------------------------------------------------------------------------
    # Abstract methods (subclasses must implement)
    # -------------------------------------------------------------------------

    @abstractmethod
    async def _send_packet(self, pkt: Packet, wait_for_ack: bool = False) -> bool:
        """Send a packet via the subclass transport (radio or packet_injector)."""

    @abstractmethod
    async def start(self) -> None:
        """Start the companion."""

    @abstractmethod
    async def stop(self) -> None:
        """Stop the companion."""

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Return whether the companion is currently running."""

    @abstractmethod
    def import_private_key(self, key: bytes) -> bool:
        """Import a private key and rebuild the identity."""

    def _get_protocol_response_handler(self) -> Any:
        """Return the protocol response handler, or ``None``.

        Subclasses that support request/response methods (telemetry, status,
        binary request, etc.) must override this to return their handler.
        """
        return None

    def _get_login_response_handler(self) -> Any:
        """Return the login response handler, or ``None``."""
        return None

    def _get_text_handler(self) -> Any:
        """Return the text message handler, or ``None``."""
        return None

    # -------------------------------------------------------------------------
    # Unified TX methods (shared between Radio and Bridge)
    # -------------------------------------------------------------------------

    async def advertise(self, flood: bool = True) -> bool:
        """Broadcast an advertisement packet.

        Location is included only when advert_loc_policy permits sharing and
        the stored coordinates are non-zero.
        """
        flags = adv_type_to_flags(self.prefs.adv_type)
        flags |= ADVERT_FLAG_HAS_NAME
        lat, lon = 0.0, 0.0
        if self.prefs.advert_loc_policy == ADVERT_LOC_SHARE:
            lat, lon = self.prefs.latitude, self.prefs.longitude
            if lat != 0.0 or lon != 0.0:
                flags |= ADVERT_FLAG_HAS_LOCATION
        route = "flood" if flood else "direct"
        pkt = PacketBuilder.create_advert(
            local_identity=self._identity,
            name=self.prefs.node_name,
            lat=lat,
            lon=lon,
            flags=flags,
            route_type=route,
        )
        self._apply_flood_scope(pkt)
        self._apply_path_hash_mode(pkt)
        success = await self._send_packet(pkt, wait_for_ack=False)
        if success:
            self.stats.record_tx(is_flood=flood)
        else:
            self.stats.record_tx_error()
        return success

    async def share_contact(self, pub_key: bytes) -> bool:
        """Share a contact's advert to the mesh. Returns False if contact unknown."""
        contact = self.contacts.get_by_key(pub_key)
        if not contact:
            return False
        try:
            pkt = PacketBuilder.create_advert(
                local_identity=self._identity,
                name=contact.name,
                flags=adv_type_to_flags(contact.adv_type) | ADVERT_FLAG_HAS_NAME,
                route_type="direct",
            )
            self._apply_flood_scope(pkt)
            self._apply_path_hash_mode(pkt)
            return await self._send_packet(pkt, wait_for_ack=False)
        except Exception as e:
            logger.error(f"Error sharing contact: {e}")
            return False

    async def send_trace_path_raw(
        self,
        tag: int,
        auth_code: int,
        flags: int,
        path_bytes: bytes,
    ) -> bool:
        """Send a trace packet with an explicit path."""
        try:
            path_list = list(path_bytes)
            pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=path_list)
            self._apply_flood_scope(pkt)
            self._apply_path_hash_mode(pkt)
            return await self._send_packet(pkt, wait_for_ack=False)
        except Exception as e:
            logger.error(f"Error sending trace (raw path): {e}")
            return False

    async def send_binary_req(
        self, pub_key: bytes, data: bytes, timeout_seconds: float = 15.0
    ) -> SentResult:
        """Send binary request (CMD_SEND_BINARY_REQ).

        data = request_type(1) + optional payload.
        Returns SentResult with expected_ack (4-byte tag as int) and timeout_ms.
        """
        contact = self.contacts.get_by_key(pub_key)
        if not contact:
            return SentResult(success=False)
        proxy = self.contacts.get_by_name(contact.name)
        if not proxy:
            return SentResult(success=False)
        request_type = data[0] if len(data) >= 1 else 0
        # C++ companion pattern (BaseChatMesh::sendRequest):
        #   tag = getRTCClock()->getCurrentTimeUnique()
        #   memcpy(temp, &tag, 4); memcpy(&temp[4], req_data, data_len);
        # create_protocol_request packs: timestamp(4) + protocol_code(1) + extra_data.
        # The repeater echoes sender_timestamp (bytes 0-3) in the response.
        # So the timestamp IS the tag — we capture it from create_protocol_request.
        protocol_code = request_type
        req_payload = data[1:]  # request params only; timestamp provides uniqueness
        self.cleanup_expired_binary_requests()
        try:
            pkt, timestamp = PacketBuilder.create_protocol_request(
                contact=proxy,
                local_identity=self._identity,
                protocol_code=protocol_code,
                data=req_payload,
            )
            # Use the timestamp as the tag — matches what the repeater echoes back
            tag_int = timestamp
            tag_bytes = tag_int.to_bytes(4, "little")
            tag_hex = tag_bytes.hex()
            self.register_binary_request(
                tag_hex,
                request_type=request_type,
                timeout_seconds=timeout_seconds,
                pubkey_prefix=pub_key[:6].hex(),
            )
            self._apply_flood_scope(pkt)
            self._apply_path_hash_mode(pkt)
            success = await self._send_packet(pkt, wait_for_ack=False)
        except Exception as e:
            logger.error(f"Binary request send error: {e}")
            # Unregister the pending request if it was registered before the failure.
            if "tag_hex" in locals():
                self._pending_binary_requests.pop(tag_hex, None)
            return SentResult(success=False)
        if not success:
            self._pending_binary_requests.pop(tag_hex, None)
            return SentResult(success=False)
        return SentResult(
            success=True,
            is_flood=contact.out_path_len <= 0,
            expected_ack=tag_int,
            timeout_ms=DEFAULT_RESPONSE_TIMEOUT_MS,
        )

    async def send_anon_req(
        self, pub_key: bytes, data: bytes, timeout_seconds: float = 15.0
    ) -> SentResult:
        """Send anonymous request (CMD_SEND_ANON_REQ), e.g. owner info.

        data = request payload (e.g. [0x07] for GET_OWNER_INFO). Response is
        delivered via on_binary_response (PUSH_CODE_BINARY_RESPONSE) like binary req.
        """
        contact = self.contacts.get_by_key(pub_key)
        if not contact:
            return SentResult(success=False)
        proxy = self.contacts.get_by_name(contact.name)
        if not proxy:
            return SentResult(success=False)
        request_type = PROTOCOL_CODE_ANON_REQ
        req_payload = data  # no random tag; timestamp provides uniqueness
        self.cleanup_expired_binary_requests()
        try:
            pkt, timestamp = PacketBuilder.create_protocol_request(
                contact=proxy,
                local_identity=self._identity,
                protocol_code=PROTOCOL_CODE_ANON_REQ,
                data=req_payload,
            )
            # Use the timestamp as the tag — matches what the repeater echoes back
            tag_int = timestamp
            tag_bytes = tag_int.to_bytes(4, "little")
            tag_hex = tag_bytes.hex()
            self.register_binary_request(
                tag_hex,
                request_type=request_type,
                timeout_seconds=timeout_seconds,
                pubkey_prefix=pub_key[:6].hex(),
            )
            self._apply_flood_scope(pkt)
            self._apply_path_hash_mode(pkt)
            success = await self._send_packet(pkt, wait_for_ack=False)
        except Exception as e:
            logger.error(f"Anon request send error: {e}")
            if "tag_hex" in locals():
                self._pending_binary_requests.pop(tag_hex, None)
            return SentResult(success=False)
        if not success:
            self._pending_binary_requests.pop(tag_hex, None)
            return SentResult(success=False)
        return SentResult(
            success=True,
            is_flood=contact.out_path_len <= 0,
            expected_ack=tag_int,
            timeout_ms=DEFAULT_RESPONSE_TIMEOUT_MS,
        )

    async def send_path_discovery(self, pub_key: bytes) -> bool:
        """Legacy: send path discovery without returning tag. Prefer send_path_discovery_req."""
        result = await self.send_path_discovery_req(pub_key)
        return result.success

    async def send_path_discovery_req(self, pub_key: bytes) -> SentResult:
        """Send path discovery (flood telemetry request with tag).

        Returns SentResult for RESP_CODE_SENT. When path return arrives with
        matching tag, path_discovery_response is fired (PUSH 0x8D).
        """
        contact = self.contacts.get_by_key(pub_key)
        if not contact:
            return SentResult(success=False)
        proxy = self.contacts.get_by_name(contact.name)
        if not proxy:
            return SentResult(success=False)
        tag_int = random.randint(0, 0xFFFFFFFF)
        tag_bytes = tag_int.to_bytes(4, "little")
        inv_perm = 0xFF & ~TELEM_PERM_BASE
        req_payload = tag_bytes + bytes([REQ_TYPE_GET_TELEMETRY_DATA, inv_perm, 0, 0, 0])
        # Temporarily clear the contact's path so the request floods;
        # the finally-block restores the old path if no new one arrived.
        old_path_len = contact.out_path_len
        old_path = contact.out_path
        contact.out_path_len = -1
        contact.out_path = b""
        self.contacts.update(contact)
        try:
            pkt, _ = PacketBuilder.create_protocol_request(
                contact=proxy,
                local_identity=self._identity,
                protocol_code=REQ_TYPE_GET_TELEMETRY_DATA,
                data=req_payload,
            )
            self._apply_flood_scope(pkt)
            self._apply_path_hash_mode(pkt)
            success = await self._send_packet(pkt, wait_for_ack=False)
            if success:
                self._pending_discovery_tags.add(tag_int)
            return SentResult(
                success=success,
                is_flood=True,
                expected_ack=tag_int,
                timeout_ms=DEFAULT_RESPONSE_TIMEOUT_MS,
            )
        except Exception as e:
            logger.error(f"Error in path discovery: {e}")
            return SentResult(success=False)
        finally:
            current = self.contacts.get_by_key(pub_key)
            if current and current.out_path_len == -1:
                current.out_path_len = old_path_len
                current.out_path = old_path
                self.contacts.update(current)

    async def send_text_message(
        self,
        pub_key: bytes,
        text: str,
        txt_type: int = TXT_TYPE_PLAIN,
        attempt: int = 1,
        wait_for_ack: bool = True,
    ) -> SentResult:
        """Send a direct text message to a contact.

        When wait_for_ack is True (default), blocks until ACK or timeout.
        When wait_for_ack is False, returns as soon as the packet is handed off;
        ACK (if any) is still tracked and will trigger send_confirmed later.
        """
        contact = self.contacts.get_by_key(pub_key)
        if not contact:
            logger.warning(f"Contact not found for key {pub_key.hex()[:12]}...")
            return SentResult(success=False)
        proxy = self.contacts.get_by_name(contact.name)
        if not proxy:
            return SentResult(success=False)
        try:
            # No known out-path → flood routing.
            is_flood = proxy.out_path_len < 0
            msg_type = "flood" if is_flood else "direct"
            pkt, ack_crc = PacketBuilder.create_text_message(
                contact=proxy,
                local_identity=self._identity,
                message=text,
                attempt=attempt,
                message_type=msg_type,
            )
            self._apply_flood_scope(pkt)
            self._apply_path_hash_mode(pkt)
            self._track_pending_ack(ack_crc)
            if wait_for_ack:
                success = await self._send_packet(pkt, wait_for_ack=True)
                if success:
                    self.stats.record_tx(is_flood=is_flood)
                else:
                    self.stats.record_tx_error()
                return SentResult(
                    success=success,
                    is_flood=is_flood,
                    expected_ack=ack_crc,
                    timeout_ms=None,
                )
            success = await self._send_packet(pkt, wait_for_ack=False)
            if success:
                self.stats.record_tx(is_flood=is_flood)
            else:
                self.stats.record_tx_error()
            return SentResult(
                success=success,
                is_flood=is_flood,
                expected_ack=ack_crc,
                timeout_ms=DEFAULT_RESPONSE_TIMEOUT_MS,
            )
        except Exception as e:
            logger.error(f"Error sending text message: {e}")
            self.stats.record_tx_error()
            return SentResult(success=False)

    async def send_channel_message(self, channel_idx: int, text: str) -> bool:
        """Send a message to a channel."""
        channel = self.channels.get(channel_idx)
        if not channel:
            logger.warning(f"Channel {channel_idx} not found")
            return False
        try:
            pkt = PacketBuilder.create_group_datagram(
                group_name=channel.name,
                local_identity=self._identity,
                message=text,
                sender_name=self.prefs.node_name,
                channels_config=self.channels.get_channels(),
            )
            self._apply_flood_scope(pkt)
            self._apply_path_hash_mode(pkt)
            success = await self._send_packet(pkt, wait_for_ack=False)
            if success:
                self.stats.record_tx(is_flood=True)
            else:
                self.stats.record_tx_error()
            return success
        except Exception as e:
            logger.error(f"Error sending channel message: {e}")
            self.stats.record_tx_error()
            return False

    async def send_raw_data(
        self,
        dest_key: bytes,
        data: bytes,
        path: Optional[bytes] = None,
    ) -> SentResult:
        """Send raw data to a contact via a protocol request."""
        contact = self.contacts.get_by_key(dest_key)
        if not contact:
            return SentResult(success=False)
        proxy = self.contacts.get_by_name(contact.name)
        if not proxy:
            return SentResult(success=False)
        try:
            pkt, _ = PacketBuilder.create_protocol_request(
                contact=proxy,
                local_identity=self._identity,
                protocol_code=PROTOCOL_CODE_RAW_DATA,
                data=data,
            )
            self._apply_path_hash_mode(pkt)
            success = await self._send_packet(pkt, wait_for_ack=False)
            return SentResult(success=success)
        except Exception as e:
            logger.error(f"Error sending raw data: {e}")
            return SentResult(success=False)

    async def send_raw_data_direct(
        self, path: bytes, payload: bytes, *, path_len_encoded: Optional[int] = None
    ) -> SentResult:
        """Send a raw custom packet (PAYLOAD_TYPE_RAW_CUSTOM) on the given direct path.

        No encryption or contact lookup; path and payload are supplied by the caller.
        Matches firmware CMD_SEND_RAW_DATA behaviour.

        Args:
            path_len_encoded: Encoded path_len byte. If None, assumes 1-byte hashes.
        """
        # Bounds checks: payload must be 4..MAX_PACKET_PAYLOAD bytes,
        # path must fit in MAX_PATH_SIZE.
        if len(payload) < 4:
            return SentResult(success=False)
        if len(path) > MAX_PATH_SIZE:
            return SentResult(success=False)
        if len(payload) > MAX_PACKET_PAYLOAD:
            return SentResult(success=False)
        try:
            pkt = PacketBuilder.create_raw_data(payload)
            pkt.set_path(path, path_len_encoded)
            success = await self._send_packet(pkt, wait_for_ack=False)
            if success:
                self.stats.record_tx(is_flood=False)
            else:
                self.stats.record_tx_error()
            return SentResult(success=success)
        except Exception as e:
            logger.error(f"Error sending raw data direct: {e}")
            return SentResult(success=False)

    async def send_trace_path(
        self,
        pub_key: bytes,
        tag: int,
        auth_code: int,
        flags: int = 0,
    ) -> bool:
        """Send a trace path request to a contact."""
        contact = self.contacts.get_by_key(pub_key)
        if not contact:
            return False
        path = list(contact.out_path) if contact.out_path else []
        if not path:
            # No stored path: fall back to the contact's first pub-key byte as hash.
            path = [contact.public_key[0]]
        try:
            pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=path)
            self._apply_path_hash_mode(pkt)
            return await self._send_packet(pkt, wait_for_ack=False)
        except Exception as e:
            logger.error(f"Error sending trace: {e}")
            return False

    async def send_control_data(self, data: Any = None) -> bool:
        """Send a CONTROL packet (e.g. discovery request).

        If *data* is provided it must be 1-254 bytes with the first byte having
        the 0x80 bit set (e.g. ``DISCOVER_REQ``). Returns ``False`` for
        invalid payloads.

        When called with no *data* (or ``None``), a default discovery request
        is sent for backward compatibility.
        """
        try:
            if data and len(data) <= 254 and (data[0] & 0x80) != 0:
                pkt = Packet()
                pkt.header = PacketBuilder._create_header(PAYLOAD_TYPE_CONTROL, route_type="direct")
                pkt.path_len = 0
                pkt.path = bytearray()
                pkt.payload = bytearray(data)
                pkt.payload_len = len(data)
                self._apply_path_hash_mode(pkt)
                return await self._send_packet(pkt, wait_for_ack=False)
            elif data is not None:
                # data was provided but invalid
                return False
            # No data: send default discovery request
            tag = random.randint(0, 0xFFFFFFFF)
            pkt = PacketBuilder.create_discovery_request(tag, filter_mask=0x04)
            self._apply_path_hash_mode(pkt)
            return await self._send_packet(pkt, wait_for_ack=False)
        except Exception as e:
            logger.error(f"Error sending control data: {e}")
            return False

    async def send_login(self, pub_key: bytes, password: str) -> dict:
        """Send a login request to a repeater and wait for the response.

        Returns a result dict with at least ``success`` and ``reason`` keys;
        on a response it also carries admin/keep-alive/permission fields.
        Waits up to 10 seconds for the login response.
        """
        contact = self.contacts.get_by_key(pub_key)
        if not contact:
            return {"success": False, "reason": "Contact not found"}
        proxy = self.contacts.get_by_name(contact.name)
        if not proxy:
            return {"success": False, "reason": "Contact not found"}
        login_handler = self._get_login_response_handler()
        if not login_handler:
            return {"success": False, "reason": "Login handler not available"}
        # NOTE(review): proxy.public_key is treated as a hex string here,
        # while other methods index it as bytes — confirm the proxy type.
        dest_hash = bytes.fromhex(proxy.public_key)[0]
        login_handler.store_login_password(dest_hash, password)
        login_result: dict = {"success": False, "data": {}}
        login_event = asyncio.Event()

        def _login_cb(success: bool, data: dict) -> None:
            # Capture the handler's result and wake the waiter.
            login_result["success"] = success
            login_result["data"] = data
            login_event.set()

        login_handler.set_login_callback(_login_cb)
        try:
            pkt = PacketBuilder.create_login_packet(
                contact=proxy, local_identity=self._identity, password=password
            )
            self._apply_path_hash_mode(pkt)
            await self._send_packet(pkt, wait_for_ack=False)
            try:
                await asyncio.wait_for(login_event.wait(), timeout=10.0)
            except asyncio.TimeoutError:
                return {"success": False, "reason": "Login response timeout"}
            data = login_result["data"]
            return {
                "success": login_result["success"],
                "repeater": contact.name,
                "is_admin": data.get("is_admin", False),
                "keep_alive_interval": data.get("keep_alive_interval", 0),
                "tag": data.get("timestamp", 0),
                "acl_permissions": data.get("reserved", data.get("permissions", 0)),
                "firmware_ver_level": data.get("firmware_ver_level"),
                "reason": "Login successful" if login_result["success"] else "Login failed",
            }
        except Exception as e:
            logger.error(f"Login error: {e}")
            return {"success": False, "reason": str(e)}
        finally:
            # Always detach the callback and forget the stored password.
            login_handler.set_login_callback(None)
            login_handler.clear_login_password(dest_hash)

    async def send_logout(self, pub_key: bytes) -> bool:
        """Send a logout / disconnect to a repeater contact."""
        contact = self.contacts.get_by_key(pub_key)
        if not contact:
            return False
        try:
            pkt, _ = PacketBuilder.create_logout_packet(
                contact=contact, local_identity=self._identity
            )
            self._apply_path_hash_mode(pkt)
            await self._send_packet(pkt, wait_for_ack=False)
            return True
        except Exception as e:
            logger.error(f"Logout error: {e}")
            return False

    async def _wait_for_path_propagation(self, proxy: Any, request_type: str) -> None:
        """Wait for reciprocal PATH to propagate through the mesh for multi-hop contacts.

        After login, pyMC sends a reciprocal PATH so the remote repeater learns
        the return route. Each mesh hop adds ~500ms (airtime + processing).
        Without this delay, the first REQ may arrive before the reciprocal PATH,
        causing the remote to fall back to sendFlood() — which gets dropped by
        intermediate repeaters due to transport-code region filtering.
        """
        out_path_len = getattr(proxy, "out_path_len", -1)
        if out_path_len > 0:
            hop_count = PathUtils.get_path_hash_count(out_path_len)
            propagation_delay = hop_count * 0.5  # e.g.
3 hops → 1.5s + logger.debug( + f"Multi-hop {request_type}: waiting {propagation_delay:.1f}s for " + f"reciprocal PATH propagation ({hop_count} hops)" + ) + await asyncio.sleep(propagation_delay) + + async def send_status_request(self, pub_key: bytes, timeout: float = 15.0) -> dict: + """Send a protocol request for repeater status/stats.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + proto_handler = self._get_protocol_response_handler() + if not proto_handler: + return {"success": False, "reason": "Protocol handler not available"} + contact_hash = bytes.fromhex(proxy.public_key)[0] + waiter = ResponseWaiter() + proto_handler.set_response_callback(contact_hash, waiter.callback) + try: + await self._wait_for_path_propagation(proxy, "stats request") + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=REQ_TYPE_GET_STATUS, + data=b"", + ) + self._apply_path_hash_mode(pkt) + await self._send_packet(pkt, wait_for_ack=False) + result = await waiter.wait(timeout) + return { + "success": result.get("success", False), + "repeater": contact.name, + "stats": result.get("parsed", {}), + "response_text": result.get("text"), + "reason": "Stats received" if result.get("success") else "Stats request failed", + } + except Exception as e: + logger.error(f"Status request error: {e}") + return {"success": False, "reason": str(e)} + finally: + proto_handler.clear_response_callback(contact_hash) + + async def send_telemetry_request( + self, + pub_key: bytes, + want_base: bool = True, + want_location: bool = True, + want_environment: bool = True, + timeout: float = 10.0, + ) -> dict: + """Send a telemetry request to a contact and wait for the response.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + 
return {"success": False, "reason": "Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + proto_handler = self._get_protocol_response_handler() + if not proto_handler: + return {"success": False, "reason": "Protocol handler not available"} + contact_hash = bytes.fromhex(proxy.public_key)[0] + waiter = ResponseWaiter() + proto_handler.set_response_callback(contact_hash, waiter.callback) + try: + await self._wait_for_path_propagation(proxy, "telemetry request") + inv = PacketBuilder._compute_inverse_perm_mask( + want_base, want_location, want_environment + ) + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, + data=bytes([inv]), + ) + self._apply_path_hash_mode(pkt) + await self._send_packet(pkt, wait_for_ack=False) + result = await waiter.wait(timeout) + telemetry_data = dict(result.get("parsed", {})) + raw_bytes = telemetry_data.get("raw_bytes", b"") + if raw_bytes and len(pub_key) >= 6: + # Companion-style frame: 0x8B + reserved + 6-byte pubkey prefix + LPP + telemetry_data["frame_bytes"] = ( + bytes([PUSH_CODE_TELEMETRY_RESPONSE, 0]) + pub_key[:6] + raw_bytes + ) + return { + "success": result.get("success", False), + "contact": contact.name, + "telemetry_data": telemetry_data, + "response_text": result.get("text"), + "reason": ("Telemetry received" if result.get("success") else "Telemetry failed"), + } + except Exception as e: + logger.error(f"Telemetry error: {e}") + return {"success": False, "reason": str(e)} + finally: + proto_handler.clear_response_callback(contact_hash) + + async def send_binary_request(self, pub_key: bytes, data: bytes) -> dict: + """Legacy: send binary request and wait. + + Prefer ``send_binary_req`` + ``on_binary_response``. 
+ """ + return await self._send_protocol_request(pub_key, PROTOCOL_CODE_BINARY_REQ, data) + + async def send_anon_request(self, pub_key: bytes, data: bytes) -> dict: + """Send an anonymous request to a contact and wait for the response.""" + return await self._send_protocol_request(pub_key, PROTOCOL_CODE_ANON_REQ, data) + + async def _send_protocol_request(self, pub_key: bytes, protocol_code: int, data: bytes) -> dict: + """Build and send a protocol request, waiting for the response.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": "Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + proto_handler = self._get_protocol_response_handler() + if not proto_handler: + return {"success": False, "reason": "Protocol handler not available"} + contact_hash = bytes.fromhex(proxy.public_key)[0] + waiter = ResponseWaiter() + proto_handler.set_response_callback(contact_hash, waiter.callback) + try: + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=protocol_code, + data=data, + ) + self._apply_path_hash_mode(pkt) + await self._send_packet(pkt, wait_for_ack=False) + result = await waiter.wait(10.0) + return { + "success": result.get("success", False), + "response": result.get("text"), + "parsed_data": result.get("parsed", {}), + "reason": "Success" if result.get("success") else "Failed", + } + except Exception as e: + logger.error(f"Protocol request error: {e}") + return {"success": False, "reason": str(e)} + finally: + proto_handler.clear_response_callback(contact_hash) + + async def send_repeater_command( + self, pub_key: bytes, command: str, parameters: Optional[str] = None + ) -> dict: + """Send a text-based command to a repeater and wait for the response.""" + contact = self.contacts.get_by_key(pub_key) + if not contact: + return {"success": False, "reason": 
"Contact not found"} + proxy = self.contacts.get_by_name(contact.name) + if not proxy: + return {"success": False, "reason": "Contact not found"} + text_handler = self._get_text_handler() + if not text_handler: + return {"success": False, "reason": "Text handler not available"} + full_command = command + if parameters: + full_command += f" {parameters}" + response_data: dict = {"text": None, "success": False} + response_event = asyncio.Event() + + def _response_cb(message_text: str, sender_contact: Any) -> None: + response_data["text"] = message_text + response_data["success"] = True + response_event.set() + + text_handler.set_command_response_callback(_response_cb) + try: + msg_type = "flood" if proxy.out_path_len < 0 else "direct" + pkt, ack_crc = PacketBuilder.create_text_message( + contact=proxy, + local_identity=self._identity, + message=full_command, + attempt=1, + message_type=msg_type, + ) + self._apply_path_hash_mode(pkt) + await self._send_packet(pkt, wait_for_ack=True) + try: + await asyncio.wait_for(response_event.wait(), timeout=15.0) + except asyncio.TimeoutError: + pass + return { + "success": response_data["success"], + "repeater": contact.name, + "command": command, + "response": response_data["text"], + "reason": ("Command successful" if response_data["success"] else "No response"), + } + except Exception as e: + logger.error(f"Repeater command error: {e}") + return {"success": False, "reason": str(e)} + finally: + text_handler.set_command_response_callback(None) + + def _track_pending_ack(self, ack_crc: int) -> None: + """Track pending ACK CRC for send_confirmed (capped).""" + if len(self._pending_ack_crcs) < MAX_PENDING_ACK_CRCS: + self._pending_ack_crcs.add(ack_crc) + + async def _try_confirm_send(self, crc: int) -> bool: + """If CRC is pending, discard it and fire send_confirmed. 
Returns True if fired.""" + if crc not in self._pending_ack_crcs: + return False + self._pending_ack_crcs.discard(crc) + await self._fire_callbacks("send_confirmed", crc) + return True + + def sync_next_message(self) -> Optional[QueuedMessage]: + """Pop and return the next queued message, or None.""" + return self.message_queue.pop() + + # ------------------------------------------------------------------------- + # Dedup Helper + # ------------------------------------------------------------------------- + + def _check_dedup(self, cache: OrderedDict, key: str, ttl: float, max_size: int) -> bool: + """Return True if *key* is a duplicate. Evicts expired entries.""" + now = time.time() + if key in cache: + return True + expired = [k for k, ts in cache.items() if now - ts > ttl] + for k in expired: + del cache[k] + cache[key] = now + if len(cache) > max_size: + cache.popitem(last=False) + return False + + # ------------------------------------------------------------------------- + # Event Handling (shared) + # ------------------------------------------------------------------------- + + async def _handle_mesh_event(self, event_type: str, data: dict) -> None: + try: + if event_type == MeshEvents.NEW_MESSAGE: + await self._handle_new_message(data) + elif event_type == MeshEvents.NEW_CHANNEL_MESSAGE: + await self._handle_new_channel_message(data) + elif event_type == MeshEvents.NEW_CONTACT: + await self._fire_callbacks("node_discovered", data) + elif event_type == MeshEvents.CONTACT_UPDATED: + pass + elif event_type == MeshEvents.NODE_DISCOVERED: + # Advert pipeline (single path): all adverts applied here; one event + # -> one store update and at most one advert_received (Bridge and Radio). 
+ now = int(time.time()) + contact = Contact.from_dict(data, now=now) + if len(contact.public_key) >= 7 and contact.name: + inbound_path = data.get("inbound_path") + path_len_encoded = data.get("path_len_encoded") + applied = await self._apply_advert_to_stores( + contact, inbound_path, path_len_encoded=path_len_encoded + ) + if applied is not None: + await self._fire_callbacks("advert_received", applied) + await self._fire_callbacks("node_discovered", data) + elif event_type == MeshEvents.TELEMETRY_UPDATED: + await self._fire_callbacks("telemetry_response", data) + except Exception as e: + logger.error(f"Error handling mesh event {event_type}: {e}") + + async def _handle_new_message(self, data: dict) -> None: + # Deduplicate by packet hash so reconnects don't queue the same packet multiple times. + pkt_hash = data.get("packet_hash") + if pkt_hash and self._check_dedup( + self._seen_txt, pkt_hash, self._seen_txt_ttl, self._seen_txt_max + ): + return + + sender_key_hex = data.get("contact_pubkey", "") + sender_key = bytes.fromhex(sender_key_hex) if sender_key_hex else b"" + # Handler publishes "message_text"; accept "text" for compatibility + message_text = (data.get("message_text") or data.get("text") or "").rstrip("\x00") + # Extract SNR/RSSI from network info if available (same as channel path) + network_info = data.get("network_info", {}) + snr = network_info.get("snr") + rssi = network_info.get("rssi") + msg = QueuedMessage( + sender_key=sender_key, + txt_type=data.get("txt_type", data.get("flags", 0)), + timestamp=data.get("timestamp", int(time.time())), + text=message_text, + is_channel=False, + path_len=0, + snr=snr if snr is not None else 0.0, + rssi=rssi if rssi is not None else 0, + ) + self.message_queue.push(msg) + await self._fire_callbacks( + "message_received", + sender_key, + message_text, + msg.timestamp, + msg.txt_type, + pkt_hash, + snr if snr is not None else 0.0, + rssi if rssi is not None else 0, + ) + + async def 
_handle_new_channel_message(self, data: dict) -> None: + # Do not push our own (outgoing) channel messages to the client as incoming. + if data.get("is_outgoing"): + return + + # Deduplicate by packet hash so we queue one frame per logical message, matching + # firmware: Mesh.cpp only calls onChannelMessageRecv when !_tables->hasSeen(pkt). + pkt_hash = data.get("packet_hash") + if pkt_hash and self._check_dedup( + self._seen_grp_txt, pkt_hash, self._seen_grp_txt_ttl, self._seen_grp_txt_max + ): + return + + path_len = data.get("path_len", 0) + channel_name = data.get("channel_name", "") + # Resolve channel index so sync_next_message returns correct channel_idx in the frame + channel_idx = 0 + if getattr(self, "channels", None) and hasattr(self.channels, "find_by_name"): + idx = self.channels.find_by_name(channel_name) + if idx is not None: + channel_idx = idx + # MeshCore client expects "SenderName: Message" format in text field; it parses to show + # sender and message separately. Use full_content (not message_text) so client can split. + # Strip trailing nulls so frame matches firmware (exact string length, no padding). 
+ display_text = (data.get("full_content", data.get("message_text", "")) or "").rstrip("\x00") + # Extract SNR/RSSI from network info if available + network_info = data.get("network_info", {}) + snr = network_info.get("snr") + rssi = network_info.get("rssi") + + msg = QueuedMessage( + sender_key=b"", + txt_type=0, + timestamp=data.get("timestamp", int(time.time())), + text=display_text, + is_channel=True, + channel_idx=channel_idx, + path_len=path_len, + snr=snr if snr is not None else 0.0, + rssi=rssi if rssi is not None else 0, + ) + self.message_queue.push(msg) + + await self._fire_callbacks( + "channel_message_received", + data.get("channel_name", ""), + data.get("sender_name", ""), + display_text, + msg.timestamp, + path_len, + channel_idx, + pkt_hash, + snr, + rssi, + ) + + async def _fire_callbacks(self, event_name: str, *args: Any) -> None: + for callback in self._push_callbacks.get(event_name, []): + try: + if asyncio.iscoroutinefunction(callback): + await callback(*args) + else: + callback(*args) + except Exception as e: + logger.error(f"Error in {event_name} callback: {e}") + + def _schedule_fire_callbacks(self, event_name: str, *args: Any) -> None: + """Schedule _fire_callbacks from sync code (e.g. set_channel). No-op if no running loop.""" + try: + loop = asyncio.get_running_loop() + loop.create_task(self._fire_callbacks(event_name, *args)) + except RuntimeError: + pass diff --git a/src/pymc_core/companion/companion_bridge.py b/src/pymc_core/companion/companion_bridge.py new file mode 100644 index 0000000..963d5fa --- /dev/null +++ b/src/pymc_core/companion/companion_bridge.py @@ -0,0 +1,343 @@ +""" +CompanionBridge - Repeater-integrated companion mode. + +Provides the same API as CompanionRadio but uses a shared dispatcher via +packet_injector. No radio ownership; host (repeater) injects packets via +process_received_packet and TX goes through packet_injector. 
+""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Callable, Iterable, Optional + +from ..node.handlers import create_core_handlers +from ..node.handlers.login_server import LoginServerHandler +from ..protocol import LocalIdentity, Packet +from ..protocol.constants import ( + MAX_PATH_SIZE, + PAYLOAD_TYPE_ACK, + PAYLOAD_TYPE_ADVERT, + PAYLOAD_TYPE_ANON_REQ, + PAYLOAD_TYPE_GRP_TXT, + PAYLOAD_TYPE_PATH, + PAYLOAD_TYPE_RAW_CUSTOM, + PAYLOAD_TYPE_RESPONSE, + PAYLOAD_TYPE_TXT_MSG, + ROUTE_TYPE_FLOOD, + ROUTE_TYPE_TRANSPORT_FLOOD, +) +from .companion_base import CompanionBase +from .constants import ( + ADV_TYPE_CHAT, + DEFAULT_MAX_CHANNELS, + DEFAULT_MAX_CONTACTS, + DEFAULT_OFFLINE_QUEUE_SIZE, +) + +logger = logging.getLogger("CompanionBridge") + + +# --------------------------------------------------------------------------- +# Bridge ACK handler: fires send_confirmed when ACK CRC matches a pending send +# --------------------------------------------------------------------------- + + +class _BridgeAckHandler: + """Handles ACK packets (discrete and PATH-carried). + Fires send_confirmed when ACK CRC matches.""" + + def __init__(self, bridge: "CompanionBridge") -> None: + self._bridge = bridge + + @staticmethod + def payload_type() -> int: + return PAYLOAD_TYPE_ACK + + async def __call__(self, packet: Packet) -> None: + if not packet.payload or len(packet.payload) != 4: + return + crc = int.from_bytes(packet.payload, "little") + await self._apply_ack(crc) + + async def _apply_ack(self, crc: int) -> None: + """If CRC is pending, clear it and fire send_confirmed.""" + await self._bridge._try_confirm_send(crc) + + async def process_path_ack_variants(self, packet: Packet) -> Optional[int]: + """Decrypt PATH payload and return ACK CRC if present. + + Path update and contact_path_updated are handled by ProtocolResponseHandler; + this only extracts ACK for send_confirmed. 
+ """ + from ..protocol import CryptoUtils, Identity + + payload = packet.payload + if not payload or len(payload) < 2 + 6: + return None + dest_hash = payload[0] + src_hash = payload[1] + our_hash = self._bridge._identity.get_public_key()[0] + if dest_hash != our_hash: + return None + encrypted = bytes(payload[2:]) + # Try each contact with matching src_hash until decryption succeeds + contacts_tried = 0 + for contact in self._bridge.contacts.contacts: + try: + pk = contact.public_key + pub = bytes.fromhex(pk) if isinstance(pk, str) else bytes(pk) + if len(pub) != 32 or pub[0] != src_hash: + continue + contacts_tried += 1 + peer_id = Identity(pub) + shared_secret = peer_id.calc_shared_secret(self._bridge._identity.get_private_key()) + aes_key = shared_secret[:16] + decrypted = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, encrypted) + except Exception as e: + logger.debug( + "process_path_ack_variants: decrypt failed for src=0x%02x " "contact=%s: %s", + src_hash, + getattr(contact, "name", "?"), + e, + ) + continue + if len(decrypted) < 2: + logger.debug( + "process_path_ack_variants: decrypted too short (%d) for src=0x%02x", + len(decrypted), + src_hash, + ) + continue + path_len = min(decrypted[0], MAX_PATH_SIZE) + if 1 + path_len > len(decrypted): + logger.debug( + "process_path_ack_variants: path_len=%d exceeds decrypted len=%d " + "for src=0x%02x", + path_len, + len(decrypted), + src_hash, + ) + continue + # Path update and contact_path_updated are handled by ProtocolResponseHandler + # If this PATH carries an ACK, return it so send_confirmed can fire + extra_start = 1 + path_len + if len(decrypted) >= extra_start + 1 + 4 and decrypted[extra_start] == PAYLOAD_TYPE_ACK: + return int.from_bytes(decrypted[extra_start + 1 : extra_start + 5], "little") + return None + if contacts_tried > 0: + logger.debug( + "process_path_ack_variants: no contact decrypted successfully for src=0x%02x " + "(tried %d)", + src_hash, + contacts_tried, + ) + return None + + 
async def _notify_ack_received(self, crc: int) -> None: + """Called by path handler when PATH packet contained an ACK.""" + await self._apply_ack(crc) + + +# --------------------------------------------------------------------------- +# Raw custom payload handler: fires raw_data_received (PUSH 0x84) +# --------------------------------------------------------------------------- + + +class _RawCustomHandler: + """Handles PAYLOAD_TYPE_RAW_CUSTOM packets; fires raw_data_received(payload, snr, rssi).""" + + def __init__(self, bridge: "CompanionBridge") -> None: + self._bridge = bridge + + @staticmethod + def payload_type() -> int: + return PAYLOAD_TYPE_RAW_CUSTOM + + async def __call__(self, packet: Packet) -> None: + payload_bytes = bytes(packet.payload) if packet.payload else b"" + snr = packet.get_snr() if hasattr(packet, "get_snr") else getattr(packet, "_snr", 0) + rssi = packet.rssi if hasattr(packet, "rssi") else getattr(packet, "_rssi", 0) + await self._bridge._fire_callbacks("raw_data_received", payload_bytes, snr, rssi) + + +# --------------------------------------------------------------------------- +# Main CompanionBridge class +# --------------------------------------------------------------------------- + + +class CompanionBridge(CompanionBase): + """Repeater-integrated companion: shared dispatcher, packet_injector for TX. + + No MeshNode, no radio. Host calls process_received_packet when packets + destined for this companion arrive. All TX goes through packet_injector. 
+ """ + + def __init__( + self, + identity: LocalIdentity, + packet_injector: Callable[..., Any], + node_name: str = "pyMC", + adv_type: int = ADV_TYPE_CHAT, + max_contacts: int = DEFAULT_MAX_CONTACTS, + max_channels: int = DEFAULT_MAX_CHANNELS, + offline_queue_size: int = DEFAULT_OFFLINE_QUEUE_SIZE, + radio_config: Optional[dict] = None, + authenticate_callback: Optional[Callable[..., tuple[bool, int]]] = None, + initial_contacts: Optional[Iterable[Any]] = None, + ) -> None: + """Initialise the companion bridge.""" + self._init_companion_stores( + identity=identity, + node_name=node_name, + adv_type=adv_type, + max_contacts=max_contacts, + max_channels=max_channels, + offline_queue_size=offline_queue_size, + radio_config=radio_config, + initial_contacts=initial_contacts, + ) + self._packet_injector = packet_injector + + async def _handler_send_packet(pkt: Packet, wait_for_ack: bool = False) -> bool: + return await self._packet_injector(pkt, wait_for_ack=wait_for_ack) + + def _login_send_callback(pkt: Packet, delay_ms: int) -> None: + async def _delayed_send() -> None: + await asyncio.sleep(delay_ms / 1000.0) + await self._packet_injector(pkt, wait_for_ack=False) + + asyncio.create_task(_delayed_send()) + + def _log(msg: str) -> None: + logger.debug(f"[CompanionBridge] {msg}") + + ack_handler = _BridgeAckHandler(self) + + # Use shared factory for the core protocol handlers + core = create_core_handlers( + identity=identity, + contacts=self.contacts, + channels=self.channels, + event_service=self._event_service, + send_packet_fn=_handler_send_packet, + log_fn=_log, + node_name=node_name, + radio_config=self._radio_config, + ack_handler=ack_handler, + ) + + # Bridge-specific: LoginServerHandler for incoming login requests + auth_cb = authenticate_callback + if auth_cb is None: + + def _reject_all(*args, **kwargs) -> tuple[bool, int]: + return (False, 0) + + auth_cb = _reject_all + + login_server_handler = LoginServerHandler( + identity, _log, 
authenticate_callback=auth_cb, is_room_server=False + ) + login_server_handler.set_send_packet_callback(_login_send_callback) + + self._handlers: dict[int, Any] = { + PAYLOAD_TYPE_ACK: ack_handler, + PAYLOAD_TYPE_TXT_MSG: core.text_handler, + PAYLOAD_TYPE_ADVERT: core.advert_handler, + PAYLOAD_TYPE_PATH: core.path_handler, + PAYLOAD_TYPE_ANON_REQ: login_server_handler, + PAYLOAD_TYPE_GRP_TXT: core.group_text_handler, + PAYLOAD_TYPE_RESPONSE: core.login_response_handler, + PAYLOAD_TYPE_RAW_CUSTOM: _RawCustomHandler(self), + } + + self._protocol_response_handler = core.protocol_response_handler + self._login_response_handler = core.login_response_handler + self._text_handler_ref = core.text_handler + core.protocol_response_handler.set_binary_response_callback(self._on_binary_response) + core.protocol_response_handler.set_packet_injector(self._packet_injector) + core.protocol_response_handler.set_contact_path_updated_callback( + self._on_contact_path_updated + ) + + # ------------------------------------------------------------------------- + # Handler accessors (used by CompanionBase concrete send methods) + # ------------------------------------------------------------------------- + + def _get_protocol_response_handler(self) -> Any: + return self._protocol_response_handler + + def _get_login_response_handler(self) -> Any: + return self._login_response_handler + + def _get_text_handler(self) -> Any: + return self._text_handler_ref + + def _get_group_text_handler(self): + """Return the group text handler for name sync.""" + return self._handlers.get(PAYLOAD_TYPE_GRP_TXT) + + # ------------------------------------------------------------------------- + # RX Entry Point + # ------------------------------------------------------------------------- + + async def process_received_packet(self, packet: Packet) -> None: + """Process a packet destined for this companion.""" + ptype = packet.header >> 2 & 0x0F + route_type = packet.header & 0x03 + is_flood = route_type in 
(ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD) + self.stats.record_rx(is_flood=is_flood) + + handler = self._handlers.get(ptype) + if handler: + try: + await handler(packet) + except Exception as e: + logger.error(f"Handler error for type {ptype:02X}: {e}") + + # NOTE: PATH packets are already delivered to protocol_response_handler + # via PathHandler.__call__ (path.py), which runs as the handler above. + # No duplicate call here — it would cause double decryption and could + # deliver the result to response waiters twice. + + # ------------------------------------------------------------------------- + # Abstract method implementations + # ------------------------------------------------------------------------- + + async def _send_packet(self, pkt: Packet, wait_for_ack: bool = False) -> bool: + """Send a packet via the packet_injector.""" + return await self._packet_injector(pkt, wait_for_ack=wait_for_ack) + + # ------------------------------------------------------------------------- + # Lifecycle + # ------------------------------------------------------------------------- + + async def start(self) -> None: + self._running = True + logger.info( + f"CompanionBridge started: name={self.prefs.node_name}, " + f"key={self._identity.get_public_key().hex()[:16]}..." 
+ ) + + async def stop(self) -> None: + self._running = False + logger.info("CompanionBridge stopped") + + @property + def is_running(self) -> bool: + return self._running + + # ------------------------------------------------------------------------- + # Key Management + # ------------------------------------------------------------------------- + + def import_private_key(self, key: bytes) -> bool: + try: + self._identity = LocalIdentity(seed=key) + logger.info(f"Imported new identity: {self._identity.get_public_key().hex()[:16]}...") + return True + except Exception as e: + logger.error(f"Error importing private key: {e}") + return False diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py new file mode 100644 index 0000000..75c8dea --- /dev/null +++ b/src/pymc_core/companion/companion_radio.py @@ -0,0 +1,290 @@ +""" +MeshCore Companion Radio - Python-native implementation. + +Provides the same feature set as the MeshCore companion radio firmware +(meshcore-dev/MeshCore/examples/companion_radio), implemented as a +high-level wrapper around MeshNode with in-memory contact, channel, +message queue, path cache, and statistics management. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Iterable, Optional + +from ..node.node import MeshNode +from ..protocol import LocalIdentity, Packet +from ..protocol.constants import ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD +from .companion_base import CompanionBase +from .constants import ( + ADV_TYPE_CHAT, + DEFAULT_MAX_CHANNELS, + DEFAULT_MAX_CONTACTS, + DEFAULT_OFFLINE_QUEUE_SIZE, +) + +logger = logging.getLogger("CompanionRadio") + + +class CompanionRadio(CompanionBase): + """Python-native MeshCore companion radio. 
+ + Wraps MeshNode and augments it with application-layer state and services + that the C++ companion radio firmware provides: contact management, + messaging with offline queue, advertisement broadcasting, channel + management, path tracking, signing, telemetry, statistics, and device + configuration. + + Example: + ```python + from pymc_core import CompanionRadio, LocalIdentity + from pymc_core.hardware import KissModemWrapper + + radio = KissModemWrapper("/dev/ttyUSB0") + radio.connect() + identity = LocalIdentity() + companion = CompanionRadio(radio, identity, node_name="myNode") + + async def main(): + await companion.start() + print(f"Key: {companion.get_public_key().hex()}") + await companion.advertise() + await companion.stop() + + asyncio.run(main()) + ``` + """ + + def __init__( + self, + radio: Any, + identity: LocalIdentity, + node_name: str = "pyMC", + adv_type: int = ADV_TYPE_CHAT, + max_contacts: int = DEFAULT_MAX_CONTACTS, + max_channels: int = DEFAULT_MAX_CHANNELS, + offline_queue_size: int = DEFAULT_OFFLINE_QUEUE_SIZE, + radio_config: Optional[dict] = None, + initial_contacts: Optional[Iterable[Any]] = None, + ) -> None: + """Initialise the companion radio.""" + self._init_companion_stores( + identity=identity, + node_name=node_name, + adv_type=adv_type, + max_contacts=max_contacts, + max_channels=max_channels, + offline_queue_size=offline_queue_size, + radio_config=radio_config, + initial_contacts=initial_contacts, + ) + self._radio = radio + self._dispatcher_task: Optional[asyncio.Task] = None + + self.node = MeshNode( + radio=radio, + local_identity=identity, + config={ + "node": {"name": node_name}, + "radio": self._radio_config, + }, + contacts=self.contacts, + channel_db=self.channels, + event_service=self._event_service, + ) + self._setup_packet_callbacks() + + # ------------------------------------------------------------------------- + # Abstract method implementations + # 
    # -------------------------------------------------------------------------

    async def _send_packet(self, pkt: Packet, wait_for_ack: bool = False) -> bool:
        """Send a packet via the MeshNode dispatcher."""
        return await self.node.dispatcher.send_packet(pkt, wait_for_ack=wait_for_ack)

    # -------------------------------------------------------------------------
    # Handler accessors (used by CompanionBase concrete send methods)
    # -------------------------------------------------------------------------

    def _get_protocol_response_handler(self) -> Any:
        # May be None if the dispatcher was built without this handler.
        return getattr(self.node.dispatcher, "protocol_response_handler", None)

    def _get_login_response_handler(self) -> Any:
        # May be None if the dispatcher was built without this handler.
        return getattr(self.node.dispatcher, "login_response_handler", None)

    def _get_text_handler(self) -> Any:
        # May be None if the dispatcher was built without this handler.
        return getattr(self.node.dispatcher, "text_message_handler", None)

    # -------------------------------------------------------------------------
    # Lifecycle
    # -------------------------------------------------------------------------

    async def start(self) -> None:
        """Start the MeshNode event loop as a background task.

        Idempotent: a second call while running only logs a warning.
        """
        if self._running:
            logger.warning("CompanionRadio already running")
            return
        self._running = True
        self._dispatcher_task = asyncio.create_task(self.node.start())
        logger.info(
            f"CompanionRadio started: name={self.prefs.node_name}, "
            f"key={self._identity.get_public_key().hex()[:16]}..."
        )

    async def stop(self) -> None:
        """Stop the companion: unsubscribe, cancel the node task, stop the node.

        Teardown order matters: the raw-packet subscriber is removed first so
        no callback fires mid-shutdown, then the background task is cancelled
        and awaited, then the node itself is stopped.
        """
        self._running = False
        try:
            self.node.dispatcher.remove_raw_packet_subscriber(self._on_raw_packet_rx_log)
        except Exception:
            # Best-effort: the subscriber may never have been registered.
            logger.debug("Remove raw packet subscriber during stop failed", exc_info=True)
        if self._dispatcher_task:
            self._dispatcher_task.cancel()
            try:
                await self._dispatcher_task
            except asyncio.CancelledError:
                pass
            self._dispatcher_task = None
        self.node.stop()
        logger.info("CompanionRadio stopped")

    @property
    def is_running(self) -> bool:
        """True between start() and stop()."""
        return self._running

    # -------------------------------------------------------------------------
    # Flood Scope (sync to dispatcher)
    # -------------------------------------------------------------------------

    def set_flood_scope(self, transport_key: Optional[bytes] = None) -> None:
        """Set or clear flood scope and propagate to the dispatcher."""
        super().set_flood_scope(transport_key)
        self.node.dispatcher.flood_transport_key = self._flood_transport_key

    def set_flood_region(self, region_name: Optional[str] = None) -> None:
        """Set flood region and propagate to the dispatcher."""
        super().set_flood_region(region_name)
        self.node.dispatcher.flood_transport_key = self._flood_transport_key

    # -------------------------------------------------------------------------
    # Device Configuration (overrides for radio hardware)
    # -------------------------------------------------------------------------

    def set_advert_name(self, name: str) -> None:
        """Set the advertised name and mirror it onto the MeshNode."""
        super().set_advert_name(name)
        self.node.node_name = self.prefs.node_name

    def _get_group_text_handler(self):
        """Return the group text handler for name sync."""
        return getattr(self.node.dispatcher, "group_text_handler", None)

    def set_radio_params(self, freq_hz: int, bw_hz: int, sf: int, cr: int) -> bool:
        """Update radio parameters in prefs and, if supported, on the hardware.

        Returns True on success (or when the radio has no ``configure_radio``),
        False if the hardware rejected the configuration.

        NOTE(review): prefs are updated via super() *before* the hardware
        attempt and are not rolled back on failure — prefs and hardware can
        diverge when configure_radio raises. Confirm this is intended.
        """
        super().set_radio_params(freq_hz, bw_hz, sf, cr)
        if hasattr(self._radio, "configure_radio"):
            try:
                self._radio.configure_radio(
                    frequency=freq_hz,
                    bandwidth=bw_hz,
                    spreading_factor=sf,
                    coding_rate=cr,
                )
                return True
            except Exception as e:
                logger.error(f"Error configuring radio: {e}")
                return False
        return True

    def set_tx_power(self, power_dbm: int) -> bool:
        """Update TX power in prefs and, if supported, on the hardware.

        Same caveat as set_radio_params: prefs are not rolled back if the
        hardware call fails.
        """
        super().set_tx_power(power_dbm)
        if hasattr(self._radio, "set_tx_power"):
            try:
                self._radio.set_tx_power(power_dbm)
                return True
            except Exception as e:
                logger.error(f"Error setting TX power: {e}")
                return False
        return True

    # -------------------------------------------------------------------------
    # Key Management
    # -------------------------------------------------------------------------

    def import_private_key(self, key: bytes) -> bool:
        """Replace the local identity and rebuild the MeshNode around it.

        Clears pending ACK CRCs (they were tied to the old identity) and
        re-registers all packet callbacks on the new dispatcher.

        NOTE(review): the previous MeshNode is discarded without an explicit
        stop(); if it owns tasks or radio state, confirm it is safe to drop.
        """
        try:
            self._identity = LocalIdentity(seed=key)
            self._pending_ack_crcs.clear()
            self.node = MeshNode(
                radio=self._radio,
                local_identity=self._identity,
                config={
                    "node": {"name": self.prefs.node_name},
                    "radio": self._radio_config,
                },
                contacts=self.contacts,
                channel_db=self.channels,
                event_service=self._event_service,
            )
            self._setup_packet_callbacks()
            logger.info(f"Imported new identity: {self._identity.get_public_key().hex()[:16]}...")
            return True
        except Exception as e:
            logger.error(f"Error importing private key: {e}")
            return False

    # -------------------------------------------------------------------------
    # Statistics (override for radio hardware)
    # -------------------------------------------------------------------------

    def _get_radio_stats(self) -> dict:
        """Extend base stats with last RSSI/SNR when the driver exposes them."""
        radio_stats = super()._get_radio_stats()
        if hasattr(self._radio, "get_last_rssi"):
            radio_stats["last_rssi"] = self._radio.get_last_rssi()
        if hasattr(self._radio, "get_last_snr"):
            radio_stats["last_snr"] = self._radio.get_last_snr()
        return radio_stats

    # -------------------------------------------------------------------------
    # Internal
    # -------------------------------------------------------------------------

    def _setup_packet_callbacks(self) -> None:
        """Wire companion callbacks into the (possibly new) dispatcher."""
        dispatcher = self.node.dispatcher
        dispatcher.set_packet_received_callback(self._on_packet_received)
        dispatcher.set_packet_sent_callback(self._on_packet_sent)
        dispatcher.set_ack_received_listener(self._on_ack_received)
        dispatcher.add_raw_packet_subscriber(self._on_raw_packet_rx_log)
        dispatcher.raw_data_received_callback = self._on_raw_custom_received
        # Protocol-response handler is optional; only hook it when present.
        if (
            hasattr(dispatcher, "protocol_response_handler")
            and dispatcher.protocol_response_handler
        ):
            dispatcher.protocol_response_handler.set_binary_response_callback(
                self._on_binary_response
            )
            dispatcher.protocol_response_handler.set_contact_path_updated_callback(
                self._on_contact_path_updated
            )

    async def _on_packet_received(self, pkt: Any) -> None:
        """Count an RX in stats, flagging flood-routed packets."""
        route_type = pkt.get_route_type()
        is_flood = route_type in (ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD)
        self.stats.record_rx(is_flood=is_flood)

    async def _on_raw_packet_rx_log(self, pkt: Any, data: bytes, analysis: Any) -> None:
        """Dispatcher raw-packet subscriber: fire rx_log_data(snr, rssi, raw_bytes)."""
        snr = getattr(pkt, "snr", getattr(pkt, "_snr", 0.0))
        rssi = getattr(pkt, "rssi", getattr(pkt, "_rssi", 0))
        await self._fire_callbacks("rx_log_data", snr, rssi, data)

    async def _on_ack_received(self, crc: int) -> None:
        """Called by dispatcher when an ACK CRC is received; fire send_confirmed if pending."""
        await self._try_confirm_send(crc)

    async def _on_raw_custom_received(self, pkt: Packet) -> None:
        """Dispatcher RAW_CUSTOM handler: fire raw_data_received(payload, snr, rssi)."""
        payload = bytes(pkt.payload) if pkt.payload else b""
        # NOTE(review): SNR uses get_snr() while RSSI reads the attribute
        # directly — confirm the asymmetry is intentional on Packet.
        snr = pkt.get_snr() if hasattr(pkt, "get_snr") else getattr(pkt, "_snr", 0)
        rssi = pkt.rssi if hasattr(pkt, "rssi") else getattr(pkt, "_rssi", 0)
        await self._fire_callbacks("raw_data_received", payload, snr, rssi)

    async def _on_packet_sent(self, pkt: Any) -> None:
        # Intentionally a no-op: TX stats are recorded elsewhere.
        pass
diff --git a/src/pymc_core/companion/constants.py new file
mode 100644 index 0000000..f914f47 --- /dev/null +++ b/src/pymc_core/companion/constants.py @@ -0,0 +1,236 @@
"""Companion radio constants for application-layer mesh networking features."""

from __future__ import annotations

import base64
from enum import IntEnum

# ---------------------------------------------------------------------------
# ADV Types (contact/node classification)
# ---------------------------------------------------------------------------
ADV_TYPE_CHAT = 1
ADV_TYPE_REPEATER = 2
ADV_TYPE_ROOM = 3
ADV_TYPE_SENSOR = 4

# ---------------------------------------------------------------------------
# Text Types
# ---------------------------------------------------------------------------
TXT_TYPE_PLAIN = 0
TXT_TYPE_CLI_DATA = 1
TXT_TYPE_SIGNED_PLAIN = 2

# ---------------------------------------------------------------------------
# Telemetry Modes
# ---------------------------------------------------------------------------
TELEM_MODE_DENY = 0
TELEM_MODE_ALLOW_FLAGS = 1
TELEM_MODE_ALLOW_ALL = 2

# ---------------------------------------------------------------------------
# Advert Location Policy
# ---------------------------------------------------------------------------
ADVERT_LOC_NONE = 0
ADVERT_LOC_SHARE = 1

# ---------------------------------------------------------------------------
# Auto-Add Config Bitmask
# ---------------------------------------------------------------------------
AUTOADD_OVERWRITE_OLDEST = 0x01
AUTOADD_CHAT = 0x02
AUTOADD_REPEATER = 0x04
AUTOADD_ROOM = 0x08
AUTOADD_SENSOR = 0x10

# ---------------------------------------------------------------------------
# Message Send Result
# ---------------------------------------------------------------------------
MSG_SEND_FAILED = 0
MSG_SEND_SENT_FLOOD = 1
MSG_SEND_SENT_DIRECT = 2

# ---------------------------------------------------------------------------
# Stats Types
# ---------------------------------------------------------------------------
STATS_TYPE_CORE = 0
STATS_TYPE_RADIO = 1
STATS_TYPE_PACKETS = 2


# ---------------------------------------------------------------------------
# Binary request types (CMD_SEND_BINARY_REQ / PUSH_CODE_BINARY_RESPONSE)
# ---------------------------------------------------------------------------
class BinaryReqType(IntEnum):
    """Binary request type codes (companion frame protocol).

    Values are wire-level codes; IntEnum so they pack directly into frames.
    """

    STATUS = 0x01
    KEEP_ALIVE = 0x02
    TELEMETRY = 0x03
    MMA = 0x04
    ACL = 0x05
    NEIGHBOURS = 0x06
    OWNER_INFO = 0x07  # REQ_TYPE_GET_OWNER_INFO: variable "version\nname\nowner"


# ---------------------------------------------------------------------------
# Protocol Codes (used in create_protocol_request / send_protocol_request)
# ---------------------------------------------------------------------------
PROTOCOL_CODE_RAW_DATA = 0x00
PROTOCOL_CODE_BINARY_REQ = 0x02
PROTOCOL_CODE_ANON_REQ = 0x07

# ---------------------------------------------------------------------------
# Default configuration
# ---------------------------------------------------------------------------
DEFAULT_RESPONSE_TIMEOUT_MS = 10000
DEFAULT_MAX_CONTACTS = 1000
DEFAULT_OFFLINE_QUEUE_SIZE = 512
DEFAULT_MAX_CHANNELS = 40
CONTACT_NAME_SIZE = 32
MAX_SIGN_DATA_SIZE = 8192  # 8KB signing buffer (matches firmware)
MAX_PENDING_ACK_CRCS = 64

# ===========================================================================
# Frame Protocol Constants (MeshCore Companion Radio Protocol)
# ===========================================================================

# Protocol version reported in RESP_CODE_DEVICE_INFO; phone uses 9+ to infer
# CMD_SEND_ANON_REQ (owner requests, etc.) is supported.
# 10+ provides support for multi-byte path lengths.
FIRMWARE_VER_CODE = 10

# ---------------------------------------------------------------------------
# Commands (app -> radio)
# Note: 44-49, 53 and 60 are not defined here (gaps in the numbering).
# ---------------------------------------------------------------------------
CMD_APP_START = 1
CMD_SEND_TXT_MSG = 2
CMD_SEND_CHANNEL_TXT_MSG = 3
CMD_GET_CONTACTS = 4
CMD_GET_DEVICE_TIME = 5
CMD_SET_DEVICE_TIME = 6
CMD_SEND_SELF_ADVERT = 7
CMD_SET_ADVERT_NAME = 8
CMD_ADD_UPDATE_CONTACT = 9
CMD_SYNC_NEXT_MESSAGE = 10
CMD_SET_RADIO_PARAMS = 11
CMD_SET_RADIO_TX_POWER = 12
CMD_RESET_PATH = 13
CMD_SET_ADVERT_LATLON = 14
CMD_REMOVE_CONTACT = 15
CMD_SHARE_CONTACT = 16
CMD_EXPORT_CONTACT = 17
CMD_IMPORT_CONTACT = 18
CMD_REBOOT = 19
CMD_GET_BATT_AND_STORAGE = 20
CMD_SET_TUNING_PARAMS = 21
CMD_DEVICE_QUERY = 22
CMD_EXPORT_PRIVATE_KEY = 23
CMD_IMPORT_PRIVATE_KEY = 24
CMD_SEND_RAW_DATA = 25
CMD_SEND_LOGIN = 26
CMD_SEND_STATUS_REQ = 27
CMD_HAS_CONNECTION = 28
CMD_LOGOUT = 29
CMD_GET_CONTACT_BY_KEY = 30
CMD_GET_CHANNEL = 31
CMD_SET_CHANNEL = 32
CMD_SIGN_START = 33
CMD_SIGN_DATA = 34
CMD_SIGN_FINISH = 35
CMD_SEND_TRACE_PATH = 36
CMD_SET_DEVICE_PIN = 37
CMD_SET_OTHER_PARAMS = 38
CMD_SEND_TELEMETRY_REQ = 39
CMD_GET_CUSTOM_VARS = 40
CMD_SET_CUSTOM_VAR = 41
CMD_GET_ADVERT_PATH = 42
CMD_GET_TUNING_PARAMS = 43
CMD_SEND_BINARY_REQ = 50
CMD_FACTORY_RESET = 51
CMD_SEND_PATH_DISCOVERY_REQ = 52
CMD_SET_FLOOD_SCOPE = 54
CMD_SEND_CONTROL_DATA = 55
CMD_GET_STATS = 56
CMD_SEND_ANON_REQ = 57
CMD_SET_AUTOADD_CONFIG = 58
CMD_GET_AUTOADD_CONFIG = 59
CMD_SET_PATH_HASH_MODE = 61

# ---------------------------------------------------------------------------
# Response codes (radio -> app)
# ---------------------------------------------------------------------------
RESP_CODE_OK = 0
RESP_CODE_ERR = 1
RESP_CODE_CONTACTS_START = 2
RESP_CODE_CONTACT = 3
RESP_CODE_END_OF_CONTACTS = 4
RESP_CODE_SELF_INFO = 5
RESP_CODE_SENT = 6
RESP_CODE_CONTACT_MSG_RECV = 7
RESP_CODE_CHANNEL_MSG_RECV = 8
RESP_CODE_CURR_TIME = 9
RESP_CODE_NO_MORE_MESSAGES = 10
RESP_CODE_EXPORT_CONTACT = 11
RESP_CODE_BATT_AND_STORAGE = 12
RESP_CODE_DEVICE_INFO = 13
RESP_CODE_PRIVATE_KEY = 14
RESP_CODE_DISABLED = 15
RESP_CODE_CONTACT_MSG_RECV_V3 = 16
RESP_CODE_CHANNEL_MSG_RECV_V3 = 17
RESP_CODE_CHANNEL_INFO = 18
RESP_CODE_SIGN_START = 19
RESP_CODE_SIGNATURE = 20
RESP_CODE_CUSTOM_VARS = 21
RESP_CODE_ADVERT_PATH = 22
RESP_CODE_TUNING_PARAMS = 23
RESP_CODE_STATS = 24
RESP_CODE_AUTOADD_CONFIG = 25

# ---------------------------------------------------------------------------
# Push codes (radio -> app, unsolicited)
# ---------------------------------------------------------------------------
PUSH_CODE_ADVERT = 0x80
PUSH_CODE_PATH_UPDATED = 0x81
PUSH_CODE_SEND_CONFIRMED = 0x82
PUSH_CODE_MSG_WAITING = 0x83
PUSH_CODE_RAW_DATA = 0x84
PUSH_CODE_LOGIN_SUCCESS = 0x85
PUSH_CODE_LOGIN_FAIL = 0x86
PUSH_CODE_STATUS_RESPONSE = 0x87
PUSH_CODE_LOG_RX_DATA = 0x88
PUSH_CODE_TRACE_DATA = 0x89
PUSH_CODE_NEW_ADVERT = 0x8A
PUSH_CODE_TELEMETRY_RESPONSE = 0x8B
PUSH_CODE_BINARY_RESPONSE = 0x8C
PUSH_CODE_PATH_DISCOVERY_RESPONSE = 0x8D
PUSH_CODE_CONTROL_DATA = 0x8E
PUSH_CODE_CONTACT_DELETED = 0x8F
PUSH_CODE_CONTACTS_FULL = 0x90

# ---------------------------------------------------------------------------
# Error codes
# ---------------------------------------------------------------------------
ERR_CODE_UNSUPPORTED_CMD = 1
ERR_CODE_NOT_FOUND = 2
ERR_CODE_TABLE_FULL = 3
ERR_CODE_BAD_STATE = 4
ERR_CODE_FILE_IO_ERROR = 5
ERR_CODE_ILLEGAL_ARG = 6

# ---------------------------------------------------------------------------
# Frame delimiters (USB/TCP: > = outbound, < = inbound)
# ---------------------------------------------------------------------------
FRAME_OUTBOUND_PREFIX = 0x3E  # '>'
FRAME_INBOUND_PREFIX = 0x3C  # '<'
# Match firmware: writeFrame() refuses to send if len > MAX_FRAME_SIZE; BLE MTU
# is set to this (e.g. BLEDevice::setMTU(MAX_FRAME_SIZE)).
Frame = prefix(1) + len(2) + payload. +MAX_FRAME_SIZE = 172 +MAX_PAYLOAD_SIZE = MAX_FRAME_SIZE - 3 # max bytes after prefix + 2-byte length +PUB_KEY_SIZE = 32 +MAX_PATH_SIZE = 64 + +# --------------------------------------------------------------------------- +# Default public channel PSK (from firmware MeshCore companion_radio example) +# --------------------------------------------------------------------------- +PUBLIC_GROUP_PSK = b"izOH6cXN6mrJ5e26oRXNcg==" +DEFAULT_PUBLIC_CHANNEL_SECRET = base64.b64decode(PUBLIC_GROUP_PSK) diff --git a/src/pymc_core/companion/contact_store.py b/src/pymc_core/companion/contact_store.py new file mode 100644 index 0000000..2a3c395 --- /dev/null +++ b/src/pymc_core/companion/contact_store.py @@ -0,0 +1,268 @@ +"""In-memory contact storage compatible with MeshNode's contacts interface.""" + +from __future__ import annotations + +from typing import Iterable, Iterator, Optional, Tuple + +from .constants import DEFAULT_MAX_CONTACTS +from .models import Contact + + +class ContactProxy: + """Wraps a Contact to provide the interface expected by MeshNode handlers. 
+ + The existing handlers expect contacts with: + - public_key as a hex string (not bytes) + - name as a string + - out_path as a list + - type as an int + """ + + def __init__(self, contact: Contact): + self._contact = contact + self.public_key = contact.public_key.hex() + self.name = contact.name + self.type = contact.adv_type + self.flags = contact.flags + self.out_path = list(contact.out_path) if contact.out_path else [] + self.out_path_len = contact.out_path_len + self.sync_since = contact.sync_since + self.last_advert_timestamp = contact.last_advert_timestamp + self.lastmod = contact.lastmod + self.gps_lat = contact.gps_lat + self.gps_lon = contact.gps_lon + + def _sync_from_contact(self) -> None: + """Update proxy fields from the underlying Contact.""" + c = self._contact + self.public_key = c.public_key.hex() + self.name = c.name + self.type = c.adv_type + self.flags = c.flags + self.out_path = list(c.out_path) if c.out_path else [] + self.out_path_len = c.out_path_len + self.sync_since = c.sync_since + self.last_advert_timestamp = c.last_advert_timestamp + self.lastmod = c.lastmod + self.gps_lat = c.gps_lat + self.gps_lon = c.gps_lon + + +class ContactStore: + """In-memory contact storage compatible with MeshNode's contacts interface. + + Provides both the interface expected by MeshNode/Dispatcher (contacts property, + get_by_name, list_contacts) and companion radio CRUD operations (add, update, + remove, get_by_key, etc.). + + The store can be populated from external sources using load_from() or + load_from_dicts() for easy integration with databases and configuration files. + """ + + def __init__(self, max_contacts: int = DEFAULT_MAX_CONTACTS): + self._contacts: dict[bytes, Contact] = {} # keyed by public_key bytes + self._proxies: dict[bytes, ContactProxy] = {} # cached proxies + self._max_contacts = max_contacts + + @property + def max_contacts(self) -> int: + """Maximum number of contacts (read-only). 
Used by companion protocol device info.""" + return self._max_contacts + + # ------------------------------------------------------------------ + # Interface expected by MeshNode/Dispatcher/Handlers + # ------------------------------------------------------------------ + + @property + def contacts(self) -> list: + """Return contacts as list of proxy objects with hex public_key attribute.""" + return list(self._proxies.values()) + + def list_contacts(self) -> list: + """Return contacts list (used by ProtocolResponseHandler).""" + return self.contacts + + def get_by_name(self, name: str) -> Optional[ContactProxy]: + """Lookup by name (required by MeshNode._get_contact_or_raise).""" + for proxy in self._proxies.values(): + if proxy.name == name: + return proxy + return None + + # ------------------------------------------------------------------ + # Companion radio CRUD operations + # ------------------------------------------------------------------ + + def add(self, contact: Contact) -> bool: + """Add a new contact. Returns False if store is full or key already exists.""" + if contact.public_key in self._contacts: + return self.update(contact) + if len(self._contacts) >= self._max_contacts: + return False + self._contacts[contact.public_key] = contact + self._proxies[contact.public_key] = ContactProxy(contact) + return True + + def add_or_overwrite(self, contact: Contact) -> Tuple[bool, Optional[bytes]]: + """Add a contact, overwriting the oldest non-favourite if store is full. + + Mirrors C++ BaseChatMesh::allocateContactSlot (BaseChatMesh.cpp:70-90). 
+ + Returns: + (success, overwritten_pubkey_or_None) + """ + if contact.public_key in self._contacts: + return self.update(contact), None + if len(self._contacts) >= self._max_contacts: + # Find oldest non-favourite (flags bit 0 = favourite) + oldest_key: Optional[bytes] = None + oldest_lastmod = 0xFFFFFFFF + for key, c in self._contacts.items(): + if (c.flags & 0x01) == 0 and c.lastmod < oldest_lastmod: + oldest_lastmod = c.lastmod + oldest_key = key + if oldest_key is None: + return False, None # all contacts are favourites + overwritten = oldest_key + self.remove(oldest_key) + self._contacts[contact.public_key] = contact + self._proxies[contact.public_key] = ContactProxy(contact) + return True, overwritten + self._contacts[contact.public_key] = contact + self._proxies[contact.public_key] = ContactProxy(contact) + return True, None + + def update(self, contact: Contact) -> bool: + """Update an existing contact. Returns False if not found.""" + if contact.public_key not in self._contacts: + return self.add(contact) + self._contacts[contact.public_key] = contact + self._proxies[contact.public_key] = ContactProxy(contact) + return True + + def remove(self, public_key: bytes) -> bool: + """Remove a contact by public key. 
Returns False if not found.""" + if public_key not in self._contacts: + return False + del self._contacts[public_key] + del self._proxies[public_key] + return True + + def get_by_key(self, public_key: bytes) -> Optional[Contact]: + """Lookup a contact by full 32-byte public key.""" + return self._contacts.get(public_key) + + def get_by_key_prefix(self, prefix: bytes) -> Optional[Contact]: + """Lookup a contact by public key prefix (1-32 bytes).""" + for key, contact in self._contacts.items(): + if key[: len(prefix)] == prefix: + return contact + return None + + def get_all(self, since: int = 0) -> list[Contact]: + """Get all contacts, optionally filtered by lastmod >= since.""" + if since == 0: + return list(self._contacts.values()) + return [c for c in self._contacts.values() if c.lastmod >= since] + + def get_count(self) -> int: + """Return the number of stored contacts.""" + return len(self._contacts) + + def is_full(self) -> bool: + """Check if the contact store is at capacity.""" + return len(self._contacts) >= self._max_contacts + + def clear(self) -> None: + """Remove all contacts.""" + self._contacts.clear() + self._proxies.clear() + + # ------------------------------------------------------------------ + # Bulk loading from external sources + # ------------------------------------------------------------------ + + def load_from(self, contacts: Iterable[Contact]) -> None: + """Bulk-load contacts from any iterable of Contact objects. + + Replaces all existing contacts. + """ + self.clear() + for contact in contacts: + if len(self._contacts) >= self._max_contacts: + break + self._contacts[contact.public_key] = contact + self._proxies[contact.public_key] = ContactProxy(contact) + + def load_from_dicts(self, records: Iterable[dict]) -> None: + """Bulk-load contacts from dicts. + + Each dict must have 'public_key' (hex string or bytes) and 'name' keys. 
+ Optional keys: 'adv_type', 'flags', 'out_path', 'out_path_len', + 'last_advert_timestamp', 'lastmod', 'gps_lat', 'gps_lon', 'sync_since'. + + Replaces all existing contacts. + """ + self.clear() + for rec in records: + if len(self._contacts) >= self._max_contacts: + break + + pub_key = rec["public_key"] + if isinstance(pub_key, str): + pub_key = bytes.fromhex(pub_key) + + out_path = rec.get("out_path", b"") + if isinstance(out_path, str): + out_path = bytes.fromhex(out_path) + elif isinstance(out_path, list): + out_path = bytes(out_path) + + contact = Contact( + public_key=pub_key, + name=rec.get("name", ""), + adv_type=rec.get("adv_type", 0), + flags=rec.get("flags", 0), + out_path_len=-1 + if rec.get("out_path_len", -1) in (-1, 255) + else rec.get("out_path_len", -1), + out_path=out_path, + last_advert_timestamp=rec.get("last_advert_timestamp", 0), + lastmod=rec.get("lastmod", 0), + gps_lat=rec.get("gps_lat", 0.0), + gps_lon=rec.get("gps_lon", 0.0), + sync_since=rec.get("sync_since", 0), + ) + self._contacts[pub_key] = contact + self._proxies[pub_key] = ContactProxy(contact) + + def to_dicts(self) -> list[dict]: + """Export all contacts as a list of plain dicts for serialization.""" + result = [] + for c in self._contacts.values(): + result.append( + { + "public_key": c.public_key.hex(), + "name": c.name, + "adv_type": c.adv_type, + "flags": c.flags, + "out_path_len": c.out_path_len, + "out_path": c.out_path.hex() if c.out_path else "", + "last_advert_timestamp": c.last_advert_timestamp, + "lastmod": c.lastmod, + "gps_lat": c.gps_lat, + "gps_lon": c.gps_lon, + "sync_since": c.sync_since, + } + ) + return result + + # ------------------------------------------------------------------ + # Iterator (matches firmware's iterator pattern) + # ------------------------------------------------------------------ + + def iterate(self, since: int = 0) -> Iterator[Contact]: + """Iterate over contacts, optionally filtered by lastmod >= since.""" + for contact in 
self._contacts.values(): + if since == 0 or contact.lastmod >= since: + yield contact diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py new file mode 100644 index 0000000..29cdb1b --- /dev/null +++ b/src/pymc_core/companion/frame_server.py @@ -0,0 +1,1800 @@ +""" +CompanionFrameServer - Standard MeshCore Companion Radio Protocol over TCP. + +Implements the full companion frame protocol: command dispatch, push callbacks, +and contact/message/channel management. Persistence is handled through +overridable hook methods so the base class works standalone (in-memory only) +while subclasses can add SQLite or other storage backends. + +Frame format: + Outbound (radio → app): ``>`` (0x3E) + 2-byte LE length + data + Inbound (app → radio): ``<`` (0x3C) + 2-byte LE length + data +""" + +import asyncio +import logging +import socket +import struct +import sys +import time +from typing import Any, Callable, Optional + +from ..protocol import CryptoUtils +from ..protocol.packet_utils import PathUtils +from .constants import ( + ADV_TYPE_CHAT, + CMD_ADD_UPDATE_CONTACT, + CMD_APP_START, + CMD_DEVICE_QUERY, + CMD_EXPORT_CONTACT, + CMD_EXPORT_PRIVATE_KEY, + CMD_GET_ADVERT_PATH, + CMD_GET_AUTOADD_CONFIG, + CMD_GET_BATT_AND_STORAGE, + CMD_GET_CHANNEL, + CMD_GET_CONTACT_BY_KEY, + CMD_GET_CONTACTS, + CMD_GET_CUSTOM_VARS, + CMD_GET_DEVICE_TIME, + CMD_GET_STATS, + CMD_IMPORT_CONTACT, + CMD_IMPORT_PRIVATE_KEY, + CMD_LOGOUT, + CMD_REMOVE_CONTACT, + CMD_RESET_PATH, + CMD_SEND_ANON_REQ, + CMD_SEND_BINARY_REQ, + CMD_SEND_CHANNEL_TXT_MSG, + CMD_SEND_CONTROL_DATA, + CMD_SEND_LOGIN, + CMD_SEND_PATH_DISCOVERY_REQ, + CMD_SEND_RAW_DATA, + CMD_SEND_SELF_ADVERT, + CMD_SEND_STATUS_REQ, + CMD_SEND_TELEMETRY_REQ, + CMD_SEND_TRACE_PATH, + CMD_SEND_TXT_MSG, + CMD_SET_ADVERT_LATLON, + CMD_SET_ADVERT_NAME, + CMD_SET_AUTOADD_CONFIG, + CMD_SET_CHANNEL, + CMD_SET_CUSTOM_VAR, + CMD_SET_DEVICE_TIME, + CMD_SET_FLOOD_SCOPE, + CMD_SET_OTHER_PARAMS, + 
CMD_SET_PATH_HASH_MODE, + CMD_SET_RADIO_PARAMS, + CMD_SET_RADIO_TX_POWER, + CMD_SET_TUNING_PARAMS, + CMD_SHARE_CONTACT, + CMD_SYNC_NEXT_MESSAGE, + ERR_CODE_BAD_STATE, + ERR_CODE_ILLEGAL_ARG, + ERR_CODE_NOT_FOUND, + ERR_CODE_TABLE_FULL, + ERR_CODE_UNSUPPORTED_CMD, + FIRMWARE_VER_CODE, + FRAME_INBOUND_PREFIX, + FRAME_OUTBOUND_PREFIX, + MAX_FRAME_SIZE, + MAX_PATH_SIZE, + MAX_PAYLOAD_SIZE, + PUB_KEY_SIZE, + PUSH_CODE_ADVERT, + PUSH_CODE_BINARY_RESPONSE, + PUSH_CODE_CONTACT_DELETED, + PUSH_CODE_CONTACTS_FULL, + PUSH_CODE_CONTROL_DATA, + PUSH_CODE_LOG_RX_DATA, + PUSH_CODE_LOGIN_FAIL, + PUSH_CODE_LOGIN_SUCCESS, + PUSH_CODE_MSG_WAITING, + PUSH_CODE_NEW_ADVERT, + PUSH_CODE_PATH_DISCOVERY_RESPONSE, + PUSH_CODE_PATH_UPDATED, + PUSH_CODE_RAW_DATA, + PUSH_CODE_SEND_CONFIRMED, + PUSH_CODE_STATUS_RESPONSE, + PUSH_CODE_TELEMETRY_RESPONSE, + PUSH_CODE_TRACE_DATA, + RESP_CODE_ADVERT_PATH, + RESP_CODE_AUTOADD_CONFIG, + RESP_CODE_BATT_AND_STORAGE, + RESP_CODE_CHANNEL_INFO, + RESP_CODE_CHANNEL_MSG_RECV, + RESP_CODE_CHANNEL_MSG_RECV_V3, + RESP_CODE_CONTACT, + RESP_CODE_CONTACT_MSG_RECV, + RESP_CODE_CONTACT_MSG_RECV_V3, + RESP_CODE_CONTACTS_START, + RESP_CODE_CURR_TIME, + RESP_CODE_CUSTOM_VARS, + RESP_CODE_DEVICE_INFO, + RESP_CODE_END_OF_CONTACTS, + RESP_CODE_ERR, + RESP_CODE_EXPORT_CONTACT, + RESP_CODE_NO_MORE_MESSAGES, + RESP_CODE_OK, + RESP_CODE_PRIVATE_KEY, + RESP_CODE_SELF_INFO, + RESP_CODE_SENT, + RESP_CODE_STATS, + STATS_TYPE_CORE, + STATS_TYPE_PACKETS, + STATS_TYPE_RADIO, +) +from .models import Contact, QueuedMessage + +logger = logging.getLogger("CompanionFrameServer") + + +def _build_advert_push_frames(contact: Contact) -> tuple[bytes, Optional[bytes]]: + """Build PUSH_CODE_ADVERT short frame and optional PUSH_CODE_NEW_ADVERT + full frame from contact. 
Thread-safe for ``asyncio.to_thread``.""" + pubkey_b = contact.public_key + if isinstance(pubkey_b, bytes): + pubkey_b = pubkey_b[:32].ljust(32, b"\x00") + else: + pubkey_b = b"\x00" * 32 + short = bytes([PUSH_CODE_ADVERT]) + pubkey_b + if not contact.name: + return (short, None) + op = contact.out_path if isinstance(contact.out_path, bytes) else bytes(contact.out_path or []) + op = op[:MAX_PATH_SIZE].ljust(MAX_PATH_SIZE, b"\x00") + nb = ( + contact.name.encode("utf-8", errors="replace") + if isinstance(contact.name, str) + else (contact.name if isinstance(contact.name, bytes) else b"") + )[:32].ljust(32, b"\x00") + opl_byte = 0xFF if contact.out_path_len < 0 else min(contact.out_path_len, 255) + full = ( + bytes([PUSH_CODE_NEW_ADVERT]) + + pubkey_b + + bytes([contact.adv_type, contact.flags, opl_byte]) + + op + + nb + + struct.pack(" async handler(data) + self._cmd_handlers = { + CMD_APP_START: self._cmd_app_start, + CMD_DEVICE_QUERY: self._cmd_device_query, + CMD_GET_CONTACTS: self._cmd_get_contacts, + CMD_GET_CONTACT_BY_KEY: self._cmd_get_contact_by_key, + CMD_SEND_TXT_MSG: self._cmd_send_txt_msg, + CMD_SEND_CHANNEL_TXT_MSG: self._cmd_send_channel_txt_msg, + CMD_SYNC_NEXT_MESSAGE: self._cmd_sync_next_message, + CMD_SEND_LOGIN: self._cmd_send_login, + CMD_SEND_STATUS_REQ: self._cmd_send_status_req, + CMD_SEND_TELEMETRY_REQ: self._cmd_send_telemetry_req, + CMD_SEND_SELF_ADVERT: self._cmd_send_self_advert, + CMD_SET_ADVERT_NAME: self._cmd_set_advert_name, + CMD_SET_ADVERT_LATLON: self._cmd_set_advert_latlon, + CMD_ADD_UPDATE_CONTACT: self._cmd_add_update_contact, + CMD_REMOVE_CONTACT: self._cmd_remove_contact, + CMD_RESET_PATH: self._cmd_reset_path, + CMD_GET_BATT_AND_STORAGE: self._cmd_get_batt_and_storage, + CMD_GET_STATS: self._cmd_get_stats, + CMD_GET_ADVERT_PATH: self._cmd_get_advert_path, + CMD_IMPORT_CONTACT: self._cmd_import_contact, + CMD_GET_CHANNEL: self._cmd_get_channel, + CMD_SET_CHANNEL: self._cmd_set_channel, + CMD_SEND_BINARY_REQ: 
self._cmd_send_binary_req, + CMD_SEND_ANON_REQ: self._cmd_send_anon_req, + CMD_SEND_PATH_DISCOVERY_REQ: self._cmd_send_path_discovery_req, + CMD_SEND_CONTROL_DATA: self._cmd_send_control_data, + CMD_SEND_TRACE_PATH: self._cmd_send_trace_path, + CMD_SET_FLOOD_SCOPE: self._cmd_set_flood_scope, + CMD_GET_DEVICE_TIME: self._cmd_get_device_time, + CMD_SET_DEVICE_TIME: self._cmd_set_device_time, + CMD_SET_RADIO_PARAMS: self._cmd_set_radio_params, + CMD_SET_RADIO_TX_POWER: self._cmd_set_tx_power, + CMD_SHARE_CONTACT: self._cmd_share_contact, + CMD_EXPORT_CONTACT: self._cmd_export_contact, + CMD_EXPORT_PRIVATE_KEY: self._cmd_export_private_key, + CMD_IMPORT_PRIVATE_KEY: self._cmd_import_private_key, + CMD_SET_TUNING_PARAMS: self._cmd_set_tuning_params, + CMD_LOGOUT: self._cmd_logout, + CMD_GET_CUSTOM_VARS: self._cmd_get_custom_vars, + CMD_SET_CUSTOM_VAR: self._cmd_set_custom_var, + CMD_SET_AUTOADD_CONFIG: self._cmd_set_autoadd_config, + CMD_GET_AUTOADD_CONFIG: self._cmd_get_autoadd_config, + CMD_SET_OTHER_PARAMS: self._cmd_set_other_params, + CMD_SEND_RAW_DATA: self._cmd_send_raw_data, + CMD_SET_PATH_HASH_MODE: self._cmd_set_path_hash_mode, + } + + # ------------------------------------------------------------------------- + # Lifecycle + # ------------------------------------------------------------------------- + + async def start(self) -> None: + """Start the TCP server.""" + self._server = await asyncio.start_server( + self._handle_client, + self.bind_address, + self.port, + ) + addr = ( + self._server.sockets[0].getsockname() + if self._server.sockets + else (self.bind_address, self.port) + ) + # Repeater passes hash as hex (first byte of pubkey, e.g. "f5"); accept decimal or hex. 
+ try: + hash_int = int(self.companion_hash) + except ValueError: + hash_int = int(self.companion_hash, 16) + logger.info( + "Companion frame server listening on %s:%s (hash=0x%02x)", + addr[0], + addr[1], + hash_int, + ) + + async def stop(self) -> None: + """Stop the TCP server and disconnect any client.""" + # Signal writer task to stop and wait for it + if self._write_queue is not None: + try: + self._write_queue.put_nowait(None) # Sentinel + except asyncio.QueueFull: + pass + if self._writer_task is not None: + self._writer_task.cancel() + try: + await self._writer_task + except asyncio.CancelledError: + pass + self._writer_task = None + self._write_queue = None + if self._client_writer: + try: + self._client_writer.close() + await self._client_writer.wait_closed() + except Exception: + pass + self._client_writer = None + self._client_reader = None + if self._server: + self._server.close() + await self._server.wait_closed() + self._server = None + logger.info("Companion frame server stopped (port=%s)", self.port) + + # ------------------------------------------------------------------------- + # Persistence hooks (override in subclasses for SQLite, etc.) + # ------------------------------------------------------------------------- + + async def _persist_companion_message(self, msg_dict: dict) -> None: + """Hook: persist a received message. Default is a no-op — the message + stays in the bridge's in-memory queue for ``sync_next_message``.""" + + def _sync_next_from_persistence(self) -> Optional[QueuedMessage]: + """Hook: pop a persisted message when the bridge queue is empty. + Default returns ``None``.""" + return None + + async def _persist_contact(self, contact) -> None: + """Hook: persist a single contact. Default is a no-op. + + Subclasses should override to do a fast single-row upsert rather + than rewriting the entire contact list. + """ + + async def _save_contacts(self) -> None: + """Hook: persist the full contact list. 
Default is a no-op.""" + + async def _save_channels(self) -> None: + """Hook: persist the full channel list. Default is a no-op.""" + + def _get_batt_and_storage(self) -> tuple[int, int, int]: + """Hook: return (millivolts, used_kb, total_kb). Default: all zeros.""" + return (0, 0, 0) + + # ------------------------------------------------------------------------- + # Push callbacks + # ------------------------------------------------------------------------- + + def _setup_push_callbacks(self) -> None: + """Subscribe to bridge events and send PUSH frames to connected client.""" + # Clear any callbacks registered by a previous connection so they + # don't accumulate across reconnections. + self.bridge.clear_push_callbacks() + + def _write_push(data: bytes) -> None: + """Enqueue a push frame. Sync, non-blocking.""" + self._enqueue_frame(data) + + async def on_message_received( + sender_key, text, timestamp, txt_type, packet_hash=None, snr=None, rssi=None + ): + msg_dict = { + "sender_key": sender_key, + "text": text, + "timestamp": timestamp, + "txt_type": txt_type, + "is_channel": False, + "channel_idx": 0, + "path_len": 0, + "packet_hash": packet_hash, + "snr": snr, + "rssi": rssi, + } + await self._persist_companion_message(msg_dict) + _write_push(bytes([PUSH_CODE_MSG_WAITING])) + + def on_send_confirmed(crc): + data = struct.pack( + "= 32 + ): + return + if not self.bridge.contacts.get_by_key(contact.public_key): + return + _write_push(bytes([PUSH_CODE_PATH_UPDATED]) + contact.public_key[:32]) + try: + await self._persist_contact(contact) + except Exception as e: + logger.warning("Persist contact after path update failed: %s", e) + + async def on_channel_message_received( + channel_name, + sender_name, + message_text, + timestamp, + path_len=0, + channel_idx=0, + packet_hash=None, + snr=None, + rssi=None, + ): + msg_dict = { + "sender_key": b"", + "text": message_text, + "timestamp": timestamp, + "txt_type": 0, + "is_channel": True, + "channel_idx": channel_idx, 
+ "path_len": path_len, + "packet_hash": packet_hash, + "snr": snr, + "rssi": rssi, + } + await self._persist_companion_message(msg_dict) + _write_push(bytes([PUSH_CODE_MSG_WAITING])) + + def on_binary_response(tag_bytes, response_data, parsed=None, request_type=None): + frame = ( + bytes([PUSH_CODE_BINARY_RESPONSE, 0]) + + (tag_bytes if isinstance(tag_bytes, bytes) else struct.pack("= 32: + _write_push(bytes([PUSH_CODE_CONTACT_DELETED]) + pub_key[:32]) + + def on_contacts_full(): + _write_push(bytes([PUSH_CODE_CONTACTS_FULL])) + + def on_raw_data_received(payload_bytes: bytes, snr: float, rssi: int) -> None: + """Push PUSH_CODE_RAW_DATA (0x84): code, SNR byte, RSSI byte, 0xFF, payload.""" + snr_byte = max(-128, min(127, int(round(snr * 4)))) + rssi_byte = max(-128, min(127, int(rssi))) + payload_len = min(len(payload_bytes), MAX_PAYLOAD_SIZE - 4) + data = ( + bytes([PUSH_CODE_RAW_DATA]) + + struct.pack(" None: + """Push PUSH_CODE_TRACE_DATA (0x89) to client. Matches firmware + ``onTraceRecv()`` frame format. + + Kept as ``async def`` for backward-compatible call sites that + ``await`` it, but the body is synchronous (just enqueues). + """ + if self._write_queue is None: + return + path_sz = flags & 0x03 + expected_snr_len = path_len >> path_sz + if len(path_snrs) != expected_snr_len: + logger.debug( + "push_trace_data: path_snrs len %s != expected %s", + len(path_snrs), + expected_snr_len, + ) + return + data = ( + bytes([PUSH_CODE_TRACE_DATA, 0, path_len, flags]) + + struct.pack(" None: + """Push raw RX packet to client (PUSH_CODE_LOG_RX_DATA 0x88). + + Sync, non-blocking. Safe to call from any context (async or sync). 
+ """ + if self._write_queue is None: + return + snr_byte = max(-128, min(127, int(round(snr * 4)))) + rssi_byte = max(-128, min(127, int(rssi))) + if snr_byte < 0: + snr_byte += 256 + if rssi_byte < 0: + rssi_byte += 256 + payload_len = min(len(raw), MAX_PAYLOAD_SIZE - 3) # 3 = code + snr + rssi + data = bytes([PUSH_CODE_LOG_RX_DATA, snr_byte & 0xFF, rssi_byte & 0xFF]) + raw[:payload_len] + self._enqueue_frame(data) + + async def push_rx_raw_async(self, snr: float, rssi: int, raw: bytes) -> None: + """Push raw RX packet to client. Async wrapper for backward compatibility.""" + self.push_rx_raw(snr, rssi, raw) + + async def push_control_data( + self, + snr: float, + rssi: int, + path_len: int, + path_bytes: bytes, + payload: bytes, + ) -> None: + """Push CONTROL packet to client (PUSH_CODE_CONTROL_DATA 0x8E). + + Kept as ``async def`` for backward-compatible call sites that + ``await`` it, but the body is synchronous (just enqueues). + """ + if self._write_queue is None: + logger.warning("Push control data skipped: no client connection") + return + # Discovery response (0x90): clear the no-op callback + if self._control_handler and len(payload) >= 6 and (payload[0] & 0xF0) == 0x90: + tag = struct.unpack(" None: + """Build an outbound frame and enqueue it for the writer task. + + Sync, non-blocking. On ``QueueFull`` the frame is dropped with a + warning — this provides natural backpressure shedding. 
+ """ + if self._write_queue is None: + return + if len(data) > MAX_PAYLOAD_SIZE: + logger.warning( + "Outbound frame payload truncated from %s to %s (MAX_FRAME_SIZE=%s)", + len(data), + MAX_PAYLOAD_SIZE, + MAX_FRAME_SIZE, + ) + data = data[:MAX_PAYLOAD_SIZE] + frame = bytes([FRAME_OUTBOUND_PREFIX]) + struct.pack(" None: + """Alias for ``_enqueue_frame``; retained for subclass compatibility.""" + self._enqueue_frame(data) + + def _write_ok(self) -> None: + self._write_frame(bytes([RESP_CODE_OK])) + + def _write_err(self, err_code: int) -> None: + self._write_frame(bytes([RESP_CODE_ERR, err_code])) + + # ------------------------------------------------------------------------- + # Writer task + # ------------------------------------------------------------------------- + + # Must exceed DEFAULT_MAX_CONTACTS (+2 for START/END) so that + # _cmd_get_contacts can enqueue the full contact dump without drops. + _WRITE_QUEUE_MAXSIZE = 2048 + _DRAIN_BATCH = 10 + + async def _writer_loop(self, writer: asyncio.StreamWriter) -> None: + """Single writer task: pull frames from the queue, write to the + ``StreamWriter``, and drain periodically. + + Integrates heartbeat via timeout on :pymethod:`asyncio.Queue.get` — + when no frames arrive within ``_heartbeat_interval`` seconds a + ``RESP_CODE_CURR_TIME`` heartbeat frame is generated automatically, + eliminating the need for a separate heartbeat task. + + On any write/drain error the writer is closed, which causes the read + loop in :pymethod:`_handle_client` to receive EOF → clean disconnect. 
+ """ + frames_since_drain = 0 + try: + while True: + # Wait for a frame, or timeout for heartbeat --------- + try: + frame = await asyncio.wait_for( + self._write_queue.get(), + timeout=self._heartbeat_interval, + ) + except asyncio.TimeoutError: + # Heartbeat: send RESP_CODE_CURR_TIME + now = self.bridge.get_time() + hb_data = bytes([RESP_CODE_CURR_TIME]) + struct.pack("= self._DRAIN_BATCH: + await writer.drain() + frames_since_drain = 0 + except asyncio.CancelledError: + pass + except (ConnectionResetError, BrokenPipeError, OSError) as e: + logger.warning("Writer loop connection lost: %s", e) + except Exception as e: + logger.error("Writer loop error: %s", e, exc_info=True) + finally: + try: + if not writer.is_closing(): + writer.close() + except Exception: + pass + + # ------------------------------------------------------------------------- + # Client handling + # ------------------------------------------------------------------------- + + @staticmethod + def _configure_socket(writer: asyncio.StreamWriter) -> None: + """Configure TCP keepalive and low-latency options on the underlying socket.""" + sock = writer.get_extra_info("socket") + if sock is None: + return + try: + # Disable Nagle's algorithm for real-time frame delivery (important + # over VPN/Tailscale where latency is higher and small-write + # coalescing can compound delays). 
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + except OSError as e: + logger.debug("Could not set TCP_NODELAY: %s", e) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if sys.platform == "linux": + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 15) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 5) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3) + elif sys.platform == "darwin": + # TCP_KEEPALIVE is the macOS equivalent of TCP_KEEPIDLE + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, 15) + try: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 5) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3) + except (AttributeError, OSError): + pass # older macOS may lack KEEPINTVL/KEEPCNT + except OSError as e: + logger.debug("Could not set TCP keepalive: %s", e) + + async def _handle_client( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + """Handle a new client connection. One client at a time. + If a client is already connected, the existing connection is closed + and the new one is accepted (eviction). An idle read timeout also + frees the slot when no data is received for client_idle_timeout_sec. 
+ """ + if self._client_writer: + logger.info( + "Companion already has a client; evicting previous connection (port=%s)", + self.port, + ) + old_writer = self._client_writer + try: + old_writer.close() + await old_writer.wait_closed() + except Exception: + pass + + self._client_reader = reader + self._client_writer = writer + self._configure_socket(writer) + self._write_queue = asyncio.Queue(maxsize=self._WRITE_QUEUE_MAXSIZE) + self._setup_push_callbacks() + logger.info("Companion client connected (port=%s)", self.port) + + self._writer_task = asyncio.create_task(self._writer_loop(writer)) + disconnect_reason: Optional[str] = None + try: + while True: + try: + prefix = await asyncio.wait_for( + reader.read(1), timeout=self._client_idle_timeout_sec + ) + except asyncio.TimeoutError: + disconnect_reason = "idle_timeout" + break + if not prefix: + disconnect_reason = "empty_read" + break + if prefix[0] != FRAME_INBOUND_PREFIX: + logger.warning("Invalid frame prefix: 0x%02x", prefix[0]) + continue + len_bytes = await reader.readexactly(2) + frame_len = struct.unpack(" MAX_FRAME_SIZE: + logger.warning("Frame too long: %s", frame_len) + disconnect_reason = "frame_too_long" + break + payload = await reader.readexactly(frame_len) + await self._handle_cmd(payload) + if self._writer_task.done(): + disconnect_reason = "writer_failed" + break + except asyncio.IncompleteReadError: + disconnect_reason = "incomplete_read" + except (ConnectionResetError, BrokenPipeError) as e: + disconnect_reason = type(e).__name__ + except Exception as e: + disconnect_reason = f"other: {type(e).__name__}: {e}" + logger.error("Client handler error: %s", e, exc_info=True) + finally: + if self._write_queue is not None: + try: + self._write_queue.put_nowait(None) # Sentinel + except asyncio.QueueFull: + pass + if self._writer_task is not None: + self._writer_task.cancel() + try: + await self._writer_task + except asyncio.CancelledError: + pass + self._writer_task = None + self._write_queue = None + 
if self._client_writer is writer: + self._client_writer = None + self._client_reader = None + logger.info( + "Companion client disconnected (port=%s): %s", + self.port, + disconnect_reason or "unknown", + ) + + # ------------------------------------------------------------------------- + # Command dispatch + # ------------------------------------------------------------------------- + + async def _handle_cmd(self, payload: bytes) -> None: + """Dispatch command to handler.""" + if not payload: + return + cmd = payload[0] + data = payload[1:] + logger.info("Companion cmd 0x%02x (%s) len=%s", cmd, cmd, len(payload)) + if cmd in (CMD_GET_CHANNEL, CMD_SET_CHANNEL): + logger.debug( + "Companion cmd 0x%02x (%s), payload_len=%s", + cmd, + "GET_CHANNEL" if cmd == CMD_GET_CHANNEL else "SET_CHANNEL", + len(payload), + ) + + try: + handler = self._cmd_handlers.get(cmd) + if handler is not None: + await handler(data) + else: + logger.warning( + "Companion unsupported cmd 0x%02x (%s) len=%s", + cmd, + cmd, + len(payload), + ) + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + except Exception as e: + logger.error("Cmd 0x%02x error: %s", cmd, e, exc_info=True) + self._write_err(ERR_CODE_ILLEGAL_ARG) + + # ------------------------------------------------------------------------- + # Command handlers + # ------------------------------------------------------------------------- + + async def _cmd_app_start(self, data: bytes) -> None: + if len(data) >= 1: + self._app_target_ver = data[0] + prefs = self.bridge.get_self_info() + pubkey = self.bridge.get_public_key() + name = prefs.node_name.encode("utf-8", errors="replace") + lat = int(getattr(prefs, "latitude", 0) * 1e6) + lon = int(getattr(prefs, "longitude", 0) * 1e6) + frame = ( + bytes([RESP_CODE_SELF_INFO, ADV_TYPE_CHAT, prefs.tx_power_dbm, 22]) + + pubkey + + struct.pack(" None: + # Layout must match MeshCore companion_radio MyMesh.cpp handleCmdFrame() CMD_DEVICE_QUEURY: + # [0]=RESP_CODE_DEVICE_INFO, [1]=FIRMWARE_VER_CODE, 
[2]=MAX_CONTACTS/2, + # [3]=MAX_GROUP_CHANNELS, [4..7]=ble_pin, [8..19]=build_date(12), [20..59]=manufacturer(40), + # [60..79]=version(20), [80]=client_repeat, [81]=path_hash_mode (v10+). + if len(data) >= 1: + self._app_target_ver = data[0] + firmware_ver = FIRMWARE_VER_CODE + max_contacts = getattr(getattr(self.bridge, "contacts", None), "max_contacts", 1000) + max_channels_val = getattr(getattr(self.bridge, "channels", None), "max_channels", 40) + max_contacts_div_2 = min(max_contacts // 2, 255) + max_channels = min(max_channels_val, 255) + ble_pin = 0 + try: + prefs = self.bridge.get_self_info() + client_repeat = getattr(prefs, "client_repeat", 0) & 0xFF + path_hash_mode = getattr(prefs, "path_hash_mode", 0) & 0xFF + except Exception: + client_repeat = 0 + path_hash_mode = 0 + frame = ( + bytes( + [ + RESP_CODE_DEVICE_INFO, + firmware_ver, + max_contacts_div_2, + max_channels, + ] + ) + + struct.pack(" None: + since = struct.unpack("= 4 else 0 + contacts = self.bridge.get_contacts(since=since) + self._write_frame(bytes([RESP_CODE_CONTACTS_START]) + struct.pack(" None: + """Encode and write a single RESP_CODE_CONTACT frame.""" + pubkey = c.public_key if isinstance(c.public_key, bytes) else bytes.fromhex(c.public_key) + name = (c.name.encode("utf-8")[:32] if isinstance(c.name, str) else c.name[:32]).ljust( + 32, b"\x00" + ) + opl_byte = 0xFF if c.out_path_len < 0 else min(c.out_path_len, 255) + frame = ( + bytes([RESP_CODE_CONTACT, *pubkey, c.adv_type, c.flags, opl_byte]) + + (c.out_path[:MAX_PATH_SIZE] if c.out_path else b"").ljust(MAX_PATH_SIZE, b"\x00") + + name + + struct.pack(" None: + if len(data) < PUB_KEY_SIZE: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:PUB_KEY_SIZE] + contact = ( + self.bridge.contacts.get_by_key(pubkey) + if hasattr(self.bridge.contacts, "get_by_key") + else None + ) + if not contact: + self._write_err(ERR_CODE_NOT_FOUND) + return + self._write_contact_frame(contact) + + async def _cmd_send_txt_msg(self, data: 
bytes) -> None: + if len(data) < 12: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + txt_type = data[0] + attempt = data[1] + pubkey_prefix = data[6:12] + text = data[12:].decode("utf-8", errors="replace").rstrip("\x00") + contact = self.bridge.contacts.get_by_key_prefix(pubkey_prefix) + if not contact: + self._write_err(ERR_CODE_NOT_FOUND) + return + pubkey = ( + contact.public_key + if isinstance(contact.public_key, bytes) + else bytes.fromhex(contact.public_key) + ) + result = await self.bridge.send_text_message( + pubkey, text, txt_type=txt_type, attempt=attempt + 1, wait_for_ack=False + ) + if result.success: + ack = result.expected_ack or 0 + timeout = result.timeout_ms or 5000 + frame = bytes([RESP_CODE_SENT, 1 if result.is_flood else 0]) + struct.pack( + " None: + if len(data) < 6: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + txt_type = data[0] + channel_idx = data[1] + text = data[6:].decode("utf-8", errors="replace").rstrip("\x00") + if txt_type != 0: + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + if self.bridge.get_channel(channel_idx) is None: + self._write_err(ERR_CODE_NOT_FOUND) + return + ok = await self.bridge.send_channel_message(channel_idx, text) + if ok: + self._write_ok() + else: + self._write_err(ERR_CODE_BAD_STATE) + + async def _cmd_send_binary_req(self, data: bytes) -> None: + if len(data) < 33: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:32] + req_data = data[32:] + send_binary_req = getattr(self.bridge, "send_binary_req", None) + if not send_binary_req: + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + try: + result = await send_binary_req(pubkey, req_data) + except Exception as e: + logger.error("send_binary_req error: %s", e, exc_info=True) + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + if not result.success: + self._write_err(ERR_CODE_NOT_FOUND) + return + tag = result.expected_ack if result.expected_ack is not None else 0 + timeout_ms = result.timeout_ms if result.timeout_ms is not None 
else 10000 + frame = bytes([RESP_CODE_SENT, 1 if result.is_flood else 0]) + struct.pack( + " None: + if len(data) < 33: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:32] + req_data = data[32:] + send_anon_req = getattr(self.bridge, "send_anon_req", None) + if not send_anon_req: + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + try: + result = await send_anon_req(pubkey, req_data) + except Exception as e: + logger.error("send_anon_req error: %s", e, exc_info=True) + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + if not result.success: + self._write_err(ERR_CODE_NOT_FOUND) + return + tag = result.expected_ack if result.expected_ack is not None else 0 + timeout_ms = result.timeout_ms if result.timeout_ms is not None else 10000 + frame = bytes([RESP_CODE_SENT, 1 if result.is_flood else 0]) + struct.pack( + " None: + if len(data) < 2: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + if (data[0] & 0x80) == 0: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + # Discovery request: register a no-op response callback + if self._control_handler and len(data) >= 6 and (data[0] & 0xF0) == 0x80: + tag = struct.unpack(" None: + logger.info( + "Path discovery request received (cmd 52), data_len=%s", + len(data), + ) + if len(data) < 33: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pub_key = data[1:33] + send_req = getattr(self.bridge, "send_path_discovery_req", None) + if not send_req: + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + try: + result = await send_req(pub_key) + except Exception as e: + logger.error("send_path_discovery_req error: %s", e, exc_info=True) + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + if not result.success: + self._write_err(ERR_CODE_NOT_FOUND) + return + tag = result.expected_ack if result.expected_ack is not None else 0 + timeout_ms = result.timeout_ms if result.timeout_ms is not None else 10000 + frame = bytes([RESP_CODE_SENT, 1 if result.is_flood else 0]) + struct.pack( + " None: + if len(data) < 10: + 
self._write_err(ERR_CODE_ILLEGAL_ARG) + return + tag = struct.unpack_from("> path_sz) > MAX_PATH_SIZE or (path_len % (1 << path_sz)) != 0: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + send_raw = getattr(self.bridge, "send_trace_path_raw", None) + if not send_raw: + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + try: + ok = await send_raw(tag, auth_code, flags, path_bytes) + except Exception as e: + logger.error("send_trace_path error: %s", e, exc_info=True) + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + if not ok: + self._write_err(ERR_CODE_TABLE_FULL) + return + est_timeout_ms = 5000 + (path_len * 200) + frame = bytes([RESP_CODE_SENT, 0]) + struct.pack("> path_sz + path_snrs = bytes(snr_len) + final_snr_byte = 0 + await self.push_trace_data( + path_len, + flags, + tag, + auth_code, + path_bytes, + path_snrs, + final_snr_byte, + ) + + def _build_message_frame(self, msg: "QueuedMessage") -> bytes: + """Encode a QueuedMessage into a response frame (shared by base and subclasses).""" + snr_byte = max(-128, min(127, int(round(getattr(msg, "snr", 0) * 4)))) + if snr_byte < 0: + snr_byte += 256 + if msg.is_channel: + path_len_byte = msg.path_len if msg.path_len < 256 else 0xFF + txt_type = 0 + text_bytes = (msg.text or "").rstrip("\x00").encode("utf-8", errors="replace") + if self._app_target_ver >= 3: + return ( + bytes( + [ + RESP_CODE_CHANNEL_MSG_RECV_V3, + snr_byte & 0xFF, + 0, + 0, + msg.channel_idx, + path_len_byte, + txt_type, + ] + ) + + struct.pack("= 6 else msg.sender_key.ljust(6, b"\x00") + ) + path_len_byte = msg.path_len if msg.path_len < 256 else 0xFF + text_bytes = msg.text.encode("utf-8", errors="replace") + if self._app_target_ver >= 3: + return ( + bytes([RESP_CODE_CONTACT_MSG_RECV_V3, snr_byte & 0xFF, 0, 0]) + + prefix + + bytes([path_len_byte, msg.txt_type]) + + struct.pack(" None: + msg = self.bridge.sync_next_message() + if msg is None: + msg = await asyncio.to_thread(self._sync_next_from_persistence) + if msg is None: + 
self._write_frame(bytes([RESP_CODE_NO_MORE_MESSAGES])) + return + self._write_frame(self._build_message_frame(msg)) + + async def _cmd_send_login(self, data: bytes) -> None: + if len(data) < 32: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:32] + password = ( + data[32:].decode("utf-8", errors="replace").rstrip("\x00") if len(data) > 32 else "" + ) + self._write_frame(bytes([RESP_CODE_SENT, 1]) + struct.pack("= 2 for owner info + self._write_frame( + bytes( + [ + PUSH_CODE_LOGIN_SUCCESS, + 1 if result.get("is_admin") else 0, + ] + ) + + pubkey[:6] + + struct.pack(" None: + if len(data) < 32: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[0:32] + self._write_frame(bytes([RESP_CODE_SENT, 0]) + struct.pack(" None: + # Protocol: CMD_SEND_TELEMETRY_REQ has reserved bytes(3) then pub_key bytes(32). + # See MeshCore Companion-Radio-Protocol: CMD_SEND_TELEMETRY_REQ frame format. + if len(data) < 35: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[3:35] + flags = 0x07 # request all: base + location + environment + want_base = bool(flags & 0x01) + want_location = bool(flags & 0x02) + want_environment = bool(flags & 0x04) + self._write_frame(bytes([RESP_CODE_SENT, 0]) + struct.pack(" None: + flood = len(data) >= 1 and data[0] == 1 + ok = await self.bridge.advertise(flood=flood) + self._write_ok() if ok else self._write_err(ERR_CODE_BAD_STATE) + + async def _cmd_set_advert_name(self, data: bytes) -> None: + name = data.decode("utf-8", errors="replace").rstrip("\x00") + self.bridge.set_advert_name(name) + self._write_ok() + + async def _cmd_set_advert_latlon(self, data: bytes) -> None: + if len(data) < 8: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + lat, lon = struct.unpack_from(" None: + if len(data) < 36: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[0:32] + adv_type = data[32] + flags = data[33] + out_path_len = struct.unpack_from("= out_path_end: + out_path = 
data[35:out_path_end].rstrip(b"\x00") + else: + out_path = data[35 : len(data)].rstrip(b"\x00") if len(data) > 35 else b"" + name_start = 35 + MAX_PATH_SIZE + name_end = name_start + 32 + if len(data) >= name_end: + name_raw = data[name_start:name_end] + elif len(data) > name_start: + name_raw = data[name_start : len(data)].ljust(32, b"\x00") + else: + name_raw = b"\x00" * 32 + name = name_raw.split(b"\x00")[0].decode("utf-8", errors="replace") + last_advert = 0 + if len(data) >= name_end + 4: + last_advert = struct.unpack_from("= name_end + 4 + 8: + gps_lat = struct.unpack_from("= name_end + 4 + 12: + lastmod = struct.unpack_from(" 255 else out_path_len + out_path_padded = (out_path[:MAX_PATH_SIZE] if out_path else b"").ljust( + MAX_PATH_SIZE, b"\x00" + ) + name_padded = (name.encode("utf-8")[:32] if isinstance(name, str) else name[:32]).ljust( + 32, b"\x00" + ) + contact_frame = ( + bytes([RESP_CODE_CONTACT]) + + pubkey + + bytes([adv_type, flags, opl_byte]) + + out_path_padded + + name_padded + + struct.pack(" None: + if len(data) < 32: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:32] + ok = self.bridge.remove_contact(pubkey) + if ok: + try: + await self._save_contacts() + except Exception as e: + logger.warning("Save contacts after remove failed: %s", e) + self._write_ok() if ok else self._write_err(ERR_CODE_NOT_FOUND) + + async def _cmd_reset_path(self, data: bytes) -> None: + if len(data) < 32: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:32] + ok = self.bridge.reset_path(pubkey) + self._write_ok() if ok else self._write_err(ERR_CODE_NOT_FOUND) + + async def _cmd_get_batt_and_storage(self, data: bytes) -> None: + millivolts, used_kb, total_kb = self._get_batt_and_storage() + frame = ( + bytes([RESP_CODE_BATT_AND_STORAGE]) + + struct.pack(" None: + stats_type = data[0] if len(data) >= 1 else STATS_TYPE_PACKETS + if stats_type not in ( + STATS_TYPE_CORE, + STATS_TYPE_RADIO, + STATS_TYPE_PACKETS, + ): + 
self._write_err(ERR_CODE_ILLEGAL_ARG) + return + if self.stats_getter: + if asyncio.iscoroutinefunction(self.stats_getter): + stats = await self.stats_getter(stats_type) + else: + stats = await asyncio.to_thread(self.stats_getter, stats_type) + else: + stats = None + stats = stats or self.bridge.get_stats(stats_type) + frame = bytes([RESP_CODE_STATS, stats_type]) + if stats_type == STATS_TYPE_CORE: + battery_mv = int(stats.get("battery_mv", 0)) + uptime_secs = int(stats.get("uptime_secs", 0)) + errors = int(stats.get("errors", 0)) + queue_len = min(255, max(0, int(stats.get("queue_len", 0)))) + frame += struct.pack(" None: + if len(data) < 1 + PUB_KEY_SIZE: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pub_key = data[1 : 1 + PUB_KEY_SIZE] + prefix = pub_key[:7] + # Bridge methods used from command handlers must not block the event loop; + # if a subclass adds sync I/O here, run it via asyncio.to_thread(). + found = ( + self.bridge.get_advert_path(prefix) + if getattr(self.bridge, "get_advert_path", None) + else None + ) + if not found: + self._write_err(ERR_CODE_NOT_FOUND) + return + path_bytes = getattr(found, "path", None) or b"" + if not isinstance(path_bytes, bytes): + path_bytes = bytes(path_bytes) + path_len_encoded = getattr(found, "path_len", 0) or 0 + path_byte_len = PathUtils.get_path_byte_len(path_len_encoded) + recv_ts = getattr(found, "recv_timestamp", 0) + frame = ( + bytes([RESP_CODE_ADVERT_PATH]) + + struct.pack(" None: + ok = self.bridge.import_contact(data) + self._write_ok() if ok else self._write_err(ERR_CODE_ILLEGAL_ARG) + + async def _cmd_get_channel(self, data: bytes) -> None: + get_full_list = len(data) == 0 + channel_idx = data[0] if not get_full_list else 0 + max_channels_val = getattr(getattr(self.bridge, "channels", None), "max_channels", 40) + + def _channel_info_frame(idx: int, ch) -> bytes: + if ch is None: + name = b"\x00" * 32 + secret = b"\x00" * 16 + else: + name = ch.name.encode("utf-8", errors="replace")[:32].ljust(32, 
b"\x00") + secret = (ch.secret[:16] if ch.secret else b"\x00" * 16).ljust(16, b"\x00") + return bytes([RESP_CODE_CHANNEL_INFO, idx]) + name + secret + + if get_full_list: + for idx in range(max_channels_val): + ch = self.bridge.get_channel(idx) + frame = _channel_info_frame(idx, ch) + self._write_frame(frame) + return + + if channel_idx < 0 or channel_idx >= max_channels_val: + self._write_err(ERR_CODE_NOT_FOUND) + return + ch = self.bridge.get_channel(channel_idx) + frame = _channel_info_frame(channel_idx, ch) + self._write_frame(frame) + + async def _cmd_set_channel(self, data: bytes) -> None: + if len(data) < 34: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + channel_idx = data[0] + name_raw = data[1:33] + name = name_raw.split(b"\x00")[0].decode("utf-8", errors="replace").strip() + if len(data) >= 97: + try: + secret = bytes.fromhex(data[33:97].decode("ascii")) + except (ValueError, UnicodeDecodeError): + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + elif len(data) >= 65: + secret = data[33:65] + elif len(data) >= 49: + secret = data[33:49] + else: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + ok = self.bridge.set_channel(channel_idx, name, secret) + if ok: + try: + await self._save_channels() + except Exception as e: + logger.warning("Save channels after set failed: %s", e) + self._write_ok() if ok else self._write_err(ERR_CODE_NOT_FOUND) + + async def _cmd_set_flood_scope(self, data: bytes) -> None: + """Delegate flood scope to the bridge.""" + if len(data) >= 16: + self.bridge.set_flood_scope(data[:16]) + else: + self.bridge.set_flood_scope(None) + self._write_ok() + + # ------------------------------------------------------------------------- + # Time, radio, tuning, share/export, logout, custom vars, autoadd + # ------------------------------------------------------------------------- + + async def _cmd_get_device_time(self, data: bytes) -> None: + now = self.bridge.get_time() + self._write_frame(bytes([RESP_CODE_CURR_TIME]) + struct.pack(" 
None: + if len(data) < 4: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + secs = struct.unpack(" None: + if len(data) < 10: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + # Frequency in kHz (match firmware self-info; client sends same encoding) + freq_khz = struct.unpack_from(" None: + if len(data) < 1: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + power = struct.unpack_from("= 30: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + self.bridge.set_tx_power(power) + self._write_ok() + + async def _cmd_share_contact(self, data: bytes) -> None: + if len(data) < PUB_KEY_SIZE: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:PUB_KEY_SIZE] + ok = await self.bridge.share_contact(pubkey) + self._write_ok() if ok else self._write_err(ERR_CODE_NOT_FOUND) + + async def _cmd_export_contact(self, data: bytes) -> None: + if len(data) < PUB_KEY_SIZE: + raw = self.bridge.export_contact(None) + else: + raw = self.bridge.export_contact(data[:PUB_KEY_SIZE]) + if raw is None: + self._write_err(ERR_CODE_NOT_FOUND) + return + self._write_frame(bytes([RESP_CODE_EXPORT_CONTACT]) + raw) + + async def _cmd_export_private_key(self, data: bytes) -> None: + """Export private/signing key as 64-byte MeshCore format (RESP_CODE_PRIVATE_KEY + 64 bytes). + + For PyNaCl 32-byte seeds we expand to MeshCore 64-byte format (SHA-512 + clamp) so + the client's ed25519_derive_pub yields the same public key and signing works. 
+ """ + identity = self.bridge._identity + key_bytes = identity.get_signing_key_bytes() + if len(key_bytes) == 32: + key_bytes = CryptoUtils.ed25519_expand_seed_to_meshcore_64(key_bytes) + elif len(key_bytes) < 64: + key_bytes = key_bytes.ljust(64, b"\x00") + else: + key_bytes = key_bytes[:64] + self._write_frame(bytes([RESP_CODE_PRIVATE_KEY]) + key_bytes) + + async def _cmd_import_private_key(self, data: bytes) -> None: + """Stub/no-op: private key is set from config; dynamic import may be supported later.""" + self._write_ok() + + async def _cmd_set_tuning_params(self, data: bytes) -> None: + if len(data) < 8: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + rx_ms = struct.unpack_from(" None: + if len(data) < PUB_KEY_SIZE: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + pubkey = data[:PUB_KEY_SIZE] + await self.bridge.send_logout(pubkey) + self._write_ok() + + async def _cmd_get_custom_vars(self, data: bytes) -> None: + custom_vars = self.bridge.get_custom_vars() + parts = [f"{k}:{v}" for k, v in custom_vars.items()] + csv = ",".join(parts)[:140] + self._write_frame(bytes([RESP_CODE_CUSTOM_VARS]) + csv.encode("utf-8", errors="replace")) + + async def _cmd_set_custom_var(self, data: bytes) -> None: + if len(data) < 3: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + text = data.split(b"\x00")[0].decode("utf-8", errors="replace") + sep = text.find(":") + if sep < 1: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + name = text[:sep] + value = text[sep + 1 :] + ok = self.bridge.set_custom_var(name, value) + self._write_ok() if ok else self._write_err(ERR_CODE_ILLEGAL_ARG) + + async def _cmd_set_autoadd_config(self, data: bytes) -> None: + if len(data) < 1: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + self.bridge.set_autoadd_config(data[0]) + self._write_ok() + + async def _cmd_get_autoadd_config(self, data: bytes) -> None: + config = self.bridge.get_autoadd_config() + self._write_frame(bytes([RESP_CODE_AUTOADD_CONFIG, config & 0xFF])) + + async def 
_cmd_set_other_params(self, data: bytes) -> None: + """Handle CMD_SET_OTHER_PARAMS (0x26). Mirrors MyMesh.cpp:1290-1305.""" + if len(data) < 1: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + manual_add = data[0] + telemetry_modes = data[1] if len(data) >= 2 else 0 + advert_loc_policy = data[2] if len(data) >= 3 else 0 + multi_acks = data[3] if len(data) >= 4 else 0 + self.bridge.set_other_params(manual_add, telemetry_modes, advert_loc_policy, multi_acks) + self._write_ok() + + async def _cmd_send_raw_data(self, data: bytes) -> None: + """Handle CMD_SEND_RAW_DATA (25). + Format: [path_len_encoded][path][payload] (min 4-byte payload).""" + if len(data) < 6: + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + path_len_byte = data[0] + if not PathUtils.is_valid_path_len(path_len_byte): + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + path_byte_len = PathUtils.get_path_byte_len(path_len_byte) + if 1 + path_byte_len + 4 > len(data): + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + path = data[1 : 1 + path_byte_len] + payload = data[1 + path_byte_len :] + result = await self.bridge.send_raw_data_direct( + path, payload, path_len_encoded=path_len_byte + ) + if result.success: + self._write_ok() + else: + self._write_err(ERR_CODE_TABLE_FULL) + + async def _cmd_set_path_hash_mode(self, data: bytes) -> None: + """Handle CMD_SET_PATH_HASH_MODE (61). Format: [subtype(0), mode(0-2)]. + + Mirrors MyMesh.cpp:1320-1327. Subtype byte must be 0; mode values + 0, 1, 2 select 1-byte, 2-byte, 3-byte path hashes respectively. 
+ """ + if len(data) < 2 or data[0] != 0: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + mode = data[1] + if mode >= 3: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + self.bridge.set_path_hash_mode(mode) + self._write_ok() diff --git a/src/pymc_core/companion/message_queue.py b/src/pymc_core/companion/message_queue.py new file mode 100644 index 0000000..283c32c --- /dev/null +++ b/src/pymc_core/companion/message_queue.py @@ -0,0 +1,66 @@ +"""Fixed-size offline message queue for companion radio.""" + +from __future__ import annotations + +from collections import deque +from typing import Optional + +from .constants import DEFAULT_OFFLINE_QUEUE_SIZE +from .models import QueuedMessage + + +class MessageQueue: + """Fixed-size offline message queue (FIFO). + + Stores incoming messages that arrive when no consumer is actively + reading. Matches the firmware's offline_queue behaviour with a + configurable maximum size. When full, the oldest messages are + silently dropped (deque maxlen behaviour). + """ + + def __init__(self, max_size: int = DEFAULT_OFFLINE_QUEUE_SIZE): + self._queue: deque[QueuedMessage] = deque(maxlen=max_size) + self._max_size = max_size + + def push(self, msg: QueuedMessage) -> bool: + """Add a message to the queue. Returns True on success. + + If the queue is at capacity the oldest message is silently dropped. 
+ """ + self._queue.append(msg) + return True + + def pop(self) -> Optional[QueuedMessage]: + """Remove and return the oldest message, or None if empty.""" + if self._queue: + return self._queue.popleft() + return None + + def pop_last(self) -> Optional[QueuedMessage]: + """Remove and return the most recently pushed message, or None if empty.""" + if self._queue: + return self._queue.pop() + return None + + def peek(self) -> Optional[QueuedMessage]: + """Return the oldest message without removing it, or None if empty.""" + if self._queue: + return self._queue[0] + return None + + def is_empty(self) -> bool: + """Check if the queue has no messages.""" + return len(self._queue) == 0 + + def is_full(self) -> bool: + """Check if the queue is at capacity.""" + return len(self._queue) >= self._max_size + + @property + def count(self) -> int: + """Return the number of messages in the queue.""" + return len(self._queue) + + def clear(self) -> None: + """Remove all messages from the queue.""" + self._queue.clear() diff --git a/src/pymc_core/companion/models.py b/src/pymc_core/companion/models.py new file mode 100644 index 0000000..13b54a0 --- /dev/null +++ b/src/pymc_core/companion/models.py @@ -0,0 +1,177 @@ +"""Data models for companion radio state objects.""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import Any, Optional + + +@dataclass +class Contact: + """Represents a mesh network contact.""" + + public_key: bytes # 32 bytes (Ed25519) + name: str = "" # up to 32 chars + adv_type: int = 0 # ADV_TYPE_CHAT/REPEATER/ROOM/SENSOR + flags: int = 0 # bitfield + out_path_len: int = -1 # -1 = unknown, 0 = direct, >0 = multi-hop + out_path: bytes = b"" # routing path bytes + last_advert_timestamp: int = 0 # remote timestamp + lastmod: int = 0 # local modification timestamp + gps_lat: float = 0.0 # degrees + gps_lon: float = 0.0 # degrees + sync_since: int = 0 # for filtered iteration + + @classmethod + def from_dict( + 
cls, + d: dict[str, Any], + *, + now: Optional[int] = None, + ) -> "Contact": + """Build a Contact from a dict (event_data, advert_data, or serialized Contact). + + event_data uses: public_key, name, contact_type (id), lat, lon, + advert_timestamp, timestamp. + advert_data uses: public_key, name, contact_type_id, latitude, longitude, + flags, advert_timestamp, timestamp. + Serialized Contact dicts (ContactStore.to_dicts) use the same keys as the + dataclass: gps_lat, gps_lon, last_advert_timestamp, lastmod, out_path, + out_path_len, sync_since. + """ + if now is None: + now = int(time.time()) + pub = d.get("public_key", b"") + if isinstance(pub, str): + pub = bytes.fromhex(pub) if pub else b"" + elif not isinstance(pub, bytes): + pub = b"" + pub = pub[:32].ljust(32, b"\x00") + name = (d.get("name") or "") or "" + adv_type_raw = d.get("contact_type_id", d.get("adv_type", d.get("contact_type", 0))) + if isinstance(adv_type_raw, int): + adv_type = adv_type_raw + elif adv_type_raw is None: + adv_type = 0 + else: + try: + adv_type = int(adv_type_raw) + except (TypeError, ValueError): + adv_type = 0 + gps_lat = float(d.get("lat", d.get("latitude", d.get("gps_lat", 0.0)))) + gps_lon = float(d.get("lon", d.get("longitude", d.get("gps_lon", 0.0)))) + last_advert_ts = d.get("advert_timestamp", d.get("last_advert_timestamp", 0)) + last_advert_ts = int(last_advert_ts) if last_advert_ts is not None else 0 + if last_advert_ts > now: + last_advert_ts = now + lastmod_val = d.get("timestamp", d.get("lastmod", now)) + lastmod_val = int(lastmod_val) if lastmod_val is not None else now + flags_val = d.get("flags", 0) + flags_val = int(flags_val) if flags_val is not None else 0 + out_path = d.get("out_path", b"") + if isinstance(out_path, str): + out_path = bytes.fromhex(out_path) if out_path else b"" + elif isinstance(out_path, (list, bytearray)): + out_path = bytes(out_path) + else: + out_path = bytes(out_path) if out_path else b"" + out_path_len_val = d.get("out_path_len", -1) + 
out_path_len_val = int(out_path_len_val) if out_path_len_val is not None else -1 + sync_since_val = d.get("sync_since", 0) + sync_since_val = int(sync_since_val) if sync_since_val is not None else 0 + return cls( + public_key=pub, + name=name, + adv_type=adv_type, + flags=flags_val, + out_path_len=out_path_len_val, + out_path=out_path, + last_advert_timestamp=last_advert_ts, + lastmod=lastmod_val, + gps_lat=gps_lat, + gps_lon=gps_lon, + sync_since=sync_since_val, + ) + + +@dataclass +class Channel: + """Represents a group communication channel.""" + + name: str # up to 32 chars + secret: bytes # 16-byte PSK + + +@dataclass +class NodePrefs: + """Node configuration preferences (equivalent to firmware NodePrefs).""" + + node_name: str = "pyMC" + adv_type: int = 1 # ADV_TYPE_CHAT + tx_power_dbm: int = 20 + frequency_hz: int = 915000000 + bandwidth_hz: int = 250000 + spreading_factor: int = 10 + coding_rate: int = 5 + latitude: float = 0.0 + longitude: float = 0.0 + advert_loc_policy: int = 0 # ADVERT_LOC_NONE + multi_acks: int = 0 + telemetry_mode_base: int = 0 # TELEM_MODE_DENY + telemetry_mode_location: int = 0 + telemetry_mode_environment: int = 0 + manual_add_contacts: int = 0 + autoadd_config: int = 0 + rx_delay_base: float = 0.0 + airtime_factor: float = 0.0 + # Reported in CMD_DEVICE_QUERY device info frame (byte 80). 
+ client_repeat: int = 0 + path_hash_mode: int = 0 # 0=1-byte, 1=2-byte, 2=3-byte hashes + + +@dataclass +class SentResult: + """Result of a message send operation.""" + + success: bool + is_flood: bool = False + expected_ack: Optional[int] = None + timeout_ms: Optional[int] = None + + +@dataclass +class PacketStats: + """Packet transmission/reception statistics.""" + + flood_tx: int = 0 + flood_rx: int = 0 + direct_tx: int = 0 + direct_rx: int = 0 + tx_errors: int = 0 + + +@dataclass +class AdvertPath: + """Recently heard advertiser path information.""" + + public_key_prefix: bytes # 7 bytes + name: str = "" + path_len: int = 0 + path: bytes = b"" + recv_timestamp: int = 0 + + +@dataclass +class QueuedMessage: + """A message stored in the offline queue.""" + + sender_key: bytes # 32 bytes + txt_type: int = 0 + timestamp: int = 0 + text: str = "" + is_channel: bool = False + channel_idx: int = 0 # only meaningful if is_channel + path_len: int = 0 + snr: float = 0.0 + rssi: int = 0 diff --git a/src/pymc_core/companion/path_cache.py b/src/pymc_core/companion/path_cache.py new file mode 100644 index 0000000..dbaf6c9 --- /dev/null +++ b/src/pymc_core/companion/path_cache.py @@ -0,0 +1,59 @@ +"""Path cache for tracking recently heard advertiser paths.""" + +from __future__ import annotations + +from collections import deque +from typing import Optional + +from .models import AdvertPath + + +class PathCache: + """Tracks recently heard advertiser paths. + + Stores path information received from advertisements and path updates, + matching the firmware's advert_paths table. Paths are keyed by public + key prefix and updated on each new advertisement. + """ + + def __init__(self, max_entries: int = 16): + self._paths: deque[AdvertPath] = deque() + self._max = max_entries + + def update(self, advert_path: AdvertPath) -> None: + """Add or update a path entry. 
+ + If a path with the same public key prefix already exists, it is + removed and the new entry is appended (LRU refresh). If the cache + is full, the oldest entry is evicted. + """ + # Remove existing entry with same prefix (LRU refresh to tail) + for existing in self._paths: + if existing.public_key_prefix == advert_path.public_key_prefix: + self._paths.remove(existing) + break + + # Evict oldest if full + if len(self._paths) >= self._max: + self._paths.popleft() + self._paths.append(advert_path) + + def get_by_prefix(self, prefix: bytes) -> Optional[AdvertPath]: + """Lookup a path by public key prefix. + + Args: + prefix: Public key prefix to search for (matches the start + of stored public_key_prefix fields). + """ + for path in self._paths: + if path.public_key_prefix[: len(prefix)] == prefix: + return path + return None + + def get_all(self) -> list[AdvertPath]: + """Return all cached paths.""" + return list(self._paths) + + def clear(self) -> None: + """Remove all cached paths.""" + self._paths.clear() diff --git a/src/pymc_core/companion/stats_collector.py b/src/pymc_core/companion/stats_collector.py new file mode 100644 index 0000000..a4f6bc7 --- /dev/null +++ b/src/pymc_core/companion/stats_collector.py @@ -0,0 +1,59 @@ +"""Packet and radio statistics collector for companion radio.""" + +from __future__ import annotations + +import time + +from .models import PacketStats + + +class StatsCollector: + """Collects packet transmission/reception statistics. + + Tracks flood vs direct packet counts, errors, and uptime. + Matches the firmware's statistics reporting via CMD_GET_STATS. 
+ """ + + def __init__(self) -> None: + self.packets = PacketStats() + self._start_time = time.time() + + def record_tx(self, is_flood: bool) -> None: + """Record a successful transmission.""" + if is_flood: + self.packets.flood_tx += 1 + else: + self.packets.direct_tx += 1 + + def record_rx(self, is_flood: bool) -> None: + """Record a successful reception.""" + if is_flood: + self.packets.flood_rx += 1 + else: + self.packets.direct_rx += 1 + + def record_tx_error(self) -> None: + """Record a transmission error.""" + self.packets.tx_errors += 1 + + def get_uptime_secs(self) -> int: + """Return the number of seconds since the collector was created.""" + return int(time.time() - self._start_time) + + def get_totals(self) -> dict: + """Return a summary of all statistics.""" + return { + "flood_tx": self.packets.flood_tx, + "flood_rx": self.packets.flood_rx, + "direct_tx": self.packets.direct_tx, + "direct_rx": self.packets.direct_rx, + "tx_errors": self.packets.tx_errors, + "total_tx": self.packets.flood_tx + self.packets.direct_tx, + "total_rx": self.packets.flood_rx + self.packets.direct_rx, + "uptime_secs": self.get_uptime_secs(), + } + + def reset(self) -> None: + """Reset all counters and restart uptime.""" + self.packets = PacketStats() + self._start_time = time.time() diff --git a/src/pymc_core/hardware/kiss_modem_wrapper.py b/src/pymc_core/hardware/kiss_modem_wrapper.py index d93e187..17180c1 100644 --- a/src/pymc_core/hardware/kiss_modem_wrapper.py +++ b/src/pymc_core/hardware/kiss_modem_wrapper.py @@ -18,6 +18,11 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Dict, Optional, Union +import serial + +from ..protocol.packet_utils import PacketTimingUtils +from .base import LoRaRadio + # RX callback: (data) for backward compat, or (data, rssi, snr) for per-packet metrics RxCallback = Union[ Callable[[bytes], None], @@ -42,9 +47,6 @@ def _invoke_rx_callback( else: callback(data) -import serial - -from .base import LoRaRadio 
# KISS Protocol Constants (shared with standard KISS) KISS_FEND = 0xC0 # Frame End @@ -261,6 +263,15 @@ def __init__( self.radio_config = radio_config or {} self.is_configured = False + # Radio configuration — instance attributes matching SX1262Wrapper + # convention. Seeded from the config dict; updated by configure_radio(). + self.frequency = self.radio_config.get("frequency", int(869.618 * 1000000)) + self.tx_power = self.radio_config.get("power", self.radio_config.get("tx_power", 22)) + self.spreading_factor = self.radio_config.get("spreading_factor", 8) + self.bandwidth = self.radio_config.get("bandwidth", int(62500)) + self.coding_rate = self.radio_config.get("coding_rate", 8) + self.preamble_length = self.radio_config.get("preamble_length", 17) + self.serial_conn: Optional[serial.Serial] = None self.is_connected = False @@ -572,9 +583,18 @@ def _query_modem_info(self): except Exception as e: logger.warning(f"Failed to query modem info: {e}") - def configure_radio(self) -> bool: - """ - Configure radio parameters + def configure_radio( + self, + frequency: Optional[int] = None, + bandwidth: Optional[int] = None, + spreading_factor: Optional[int] = None, + coding_rate: Optional[int] = None, + ) -> bool: + """Configure radio parameters. + + When called with keyword arguments (e.g. from CompanionRadio), those + values take precedence. When called with no arguments the values are + read from ``self.radio_config`` (populated from config.yaml at init). 
Returns: True if configuration successful, False otherwise @@ -584,12 +604,23 @@ def configure_radio(self) -> bool: return False try: - # Extract configuration parameters with defaults - # Support both "power" and "tx_power" for compatibility with different config styles - frequency_hz = self.radio_config.get("frequency", int(869.618 * 1000000)) - bandwidth_hz = self.radio_config.get("bandwidth", int(62500)) - sf = self.radio_config.get("spreading_factor", 8) - cr = self.radio_config.get("coding_rate", 8) + # Explicit kwargs take precedence, then radio_config dict, then defaults + frequency_hz = ( + frequency + if frequency is not None + else self.radio_config.get("frequency", int(869.618 * 1000000)) + ) + bandwidth_hz = ( + bandwidth + if bandwidth is not None + else self.radio_config.get("bandwidth", int(62500)) + ) + sf = ( + spreading_factor + if spreading_factor is not None + else self.radio_config.get("spreading_factor", 8) + ) + cr = coding_rate if coding_rate is not None else self.radio_config.get("coding_rate", 8) power = self.radio_config.get("power", self.radio_config.get("tx_power", 22)) # Set radio parameters (frequency, bandwidth, SF, CR) @@ -608,6 +639,13 @@ def configure_radio(self) -> bool: # Note: Sync word is configured at firmware build time, not at runtime + # Sync instance attributes to match what was applied to hardware + self.frequency = frequency_hz + self.bandwidth = bandwidth_hz + self.spreading_factor = sf + self.coding_rate = cr + self.tx_power = power + self.is_configured = True logger.info( f"Radio configured: {frequency_hz / 1000000:.3f} MHz, " @@ -701,9 +739,7 @@ def _send_command( self._pending_response = None # SetHardware frame: type 0x06, payload = sub_cmd (1 byte) + data - kiss_frame = self._encode_kiss_frame( - KISS_CMD_SETHARDWARE, bytes([sub_cmd]) + data - ) + kiss_frame = self._encode_kiss_frame(KISS_CMD_SETHARDWARE, bytes([sub_cmd]) + data) if not self._write_frame(kiss_frame): logger.warning("SetHardware frame write 
failed") @@ -735,6 +771,27 @@ def get_radio_config(self) -> Optional[Dict[str, Any]]: } return None + def set_tx_power(self, power: int) -> bool: + """Set TX power in dBm. + + Sends the command to the modem and updates the instance attribute + on success, matching the SX1262Wrapper interface. + """ + if not self.is_connected: + logger.error("Cannot set TX power: not connected") + return False + try: + resp = self._send_command(CMD_SET_TX_POWER, bytes([power])) + if not resp or resp[0] == RESP_ERROR: + logger.error("Failed to set TX power") + return False + self.tx_power = power + logger.info(f"TX power set to {power} dBm") + return True + except Exception as e: + logger.error(f"Error setting TX power: {e}") + return False + def get_tx_power(self) -> Optional[int]: """Get current TX power in dBm""" resp = self._send_command(CMD_GET_TX_POWER) @@ -760,17 +817,21 @@ def is_channel_busy(self) -> bool: return resp[1][0] == 0x01 return False - def get_airtime(self, packet_length: int) -> Optional[int]: + def get_airtime(self, packet_length: int, timeout: Optional[float] = None) -> Optional[int]: """ - Get estimated airtime for a packet + Get estimated airtime for a packet from the modem. Args: packet_length: Length of packet in bytes + timeout: Response timeout in seconds (default: RESPONSE_TIMEOUT). + Use a shorter value (e.g. 1.0) in the TX path to avoid + blocking when the modem is busy or unresponsive. 
Returns: - Airtime in milliseconds or None on error + Airtime in milliseconds or None on error/timeout """ - resp = self._send_command(CMD_GET_AIRTIME, bytes([packet_length])) + t = timeout if timeout is not None else RESPONSE_TIMEOUT + resp = self._send_command(CMD_GET_AIRTIME, bytes([packet_length]), timeout=t) if resp and resp[0] == RESP_AIRTIME and len(resp[1]) >= 4: return struct.unpack(" Optional[Dict[str, Any]]: if not success: raise Exception("Failed to send frame via KISS modem") - airtime = self.get_airtime(len(data)) + # Use short timeout for GET_AIRTIME so TX path is not blocked if modem + # is busy or unresponsive (avoids 5s stall and subsequent bad state). + airtime = self.get_airtime(len(data), timeout=1.0) + if airtime is None: + airtime = int(PacketTimingUtils.estimate_airtime_ms(len(data), self.radio_config)) return { - "airtime_ms": airtime if airtime is not None else 0, + "airtime_ms": airtime, "lbt_attempts": len(lbt_backoff_delays), "lbt_backoff_delays_ms": lbt_backoff_delays, "lbt_channel_busy": len(lbt_backoff_delays) > 0, @@ -1168,8 +1233,12 @@ def get_status(self) -> Dict[str, Any]: status: Dict[str, Any] = { "initialized": self.is_connected, "frequency": cfg["frequency"] if cfg else self.radio_config.get("frequency", 0), - "tx_power": tx_power if tx_power is not None else self.radio_config.get("tx_power", self.radio_config.get("power", 0)), - "spreading_factor": cfg["spreading_factor"] if cfg else self.radio_config.get("spreading_factor", 0), + "tx_power": tx_power + if tx_power is not None + else self.radio_config.get("tx_power", self.radio_config.get("power", 0)), + "spreading_factor": cfg["spreading_factor"] + if cfg + else self.radio_config.get("spreading_factor", 0), "bandwidth": cfg["bandwidth"] if cfg else self.radio_config.get("bandwidth", 0), "coding_rate": cfg["coding_rate"] if cfg else self.radio_config.get("coding_rate", 0), "last_rssi": self.stats.get("last_rssi", -999), @@ -1247,9 +1316,7 @@ def _decode_kiss_byte(self, byte: 
int): if len(self.rx_frame_buffer) >= MAX_FRAME_SIZE: # Frame too long (e.g. lost FEND); reset and resync at next FEND self.stats["frame_errors"] += 1 - logger.warning( - "KISS frame exceeded max size (%d), resyncing", MAX_FRAME_SIZE - ) + logger.warning("KISS frame exceeded max size (%d), resyncing", MAX_FRAME_SIZE) self.rx_frame_buffer.clear() self.in_frame = False else: diff --git a/src/pymc_core/node/dispatcher.py b/src/pymc_core/node/dispatcher.py index 9cd5a3e..5d6474c 100644 --- a/src/pymc_core/node/dispatcher.py +++ b/src/pymc_core/node/dispatcher.py @@ -3,28 +3,27 @@ import asyncio import enum import logging -from typing import Any, Awaitable, Callable, Optional +from typing import Any, Awaitable, Callable, List, Optional from ..protocol import Packet from ..protocol.constants import ( # Payload types PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, PH_TYPE_SHIFT, + ROUTE_TYPE_FLOOD, + ROUTE_TYPE_TRANSPORT_FLOOD, ) +from ..protocol.packet_utils import PathUtils +from ..protocol.transport_keys import calc_transport_code from ..protocol.utils import PAYLOAD_TYPES, ROUTE_TYPES, format_packet_info # Import handler classes from .handlers import ( AckHandler, - AdvertHandler, AnonReqResponseHandler, ControlHandler, - GroupTextHandler, - LoginResponseHandler, - PathHandler, - ProtocolResponseHandler, - TextMessageHandler, TraceHandler, + create_core_handlers, ) ACK_TIMEOUT = 5.0 # seconds to wait for an ACK @@ -68,8 +67,17 @@ def __init__( self.packet_received_callback: Optional[Callable[[Packet], Awaitable[None] | None]] = None self.packet_sent_callback: Optional[Callable[[Packet], Awaitable[None] | None]] = None - # Add raw packet callback for detailed logging + # Optional listener for ACK received (e.g. 
companion send_confirmed) + self._ack_received_listener: Optional[Callable[[int], Awaitable[None] | None]] = None + + # Optional callback for PAYLOAD_TYPE_RAW_CUSTOM (companion raw_data_received) + self.raw_data_received_callback: Optional[Callable[[Packet], Awaitable[None]]] = None + + # Raw packet callbacks: single callback (legacy) and list of subscribers (after parse) self.raw_packet_callback: Optional[Callable[[Packet, bytes], Awaitable[None] | None]] = None + self._raw_packet_subscribers: List[Callable[..., Any]] = [] + # Raw RX subscribers: notified for every reception (data, rssi, snr) before duplicate/parse + self._raw_rx_subscribers: List[Callable[..., Any]] = [] self._handlers: dict[int, Any] = {} # Keep track of packet handlers self._handler_instances: dict[ @@ -82,6 +90,11 @@ def __init__( # Contact book for decrypting messages (set by the node later) self.contact_book = None + # Flood scope: 16-byte transport key for region-scoped flooding. + # When set, flood packets are tagged with a transport code and sent + # as ROUTE_TYPE_TRANSPORT_FLOOD. Set via companion set_flood_scope(). 
+ self.flood_transport_key: Optional[bytes] = None + self._logger = logging.getLogger("Dispatcher") self._current_expected_crc: Optional[int] = None self._recent_acks: dict[int, float] = {} # {crc: timestamp} @@ -156,96 +169,70 @@ def register_default_handlers( # Keep our identity handy for detecting our own packets self.local_identity = local_identity - # Set up ACK handler with callback to us + # --- ACK handler (dispatcher-specific wiring) --- ack_handler = AckHandler(self._log, self) ack_handler.set_ack_received_callback(self._register_ack_received) - - # Register all the standard handlers - self.register_handler( - AdvertHandler.payload_type(), - AdvertHandler(self._log), - ) self.register_handler(AckHandler.payload_type(), ack_handler) - # Text message handler - needs to send ACKs back through us - text_message_handler = TextMessageHandler( - local_identity, - contacts, - self._log, - self.send_packet, - event_service, - radio_config, - ) - # Keep a reference so the node can use it - self.text_message_handler = text_message_handler - self.register_handler( - TextMessageHandler.payload_type(), - text_message_handler, - ) - # Group text handler with channel database - self.register_handler( - GroupTextHandler.payload_type(), - GroupTextHandler( - local_identity, - contacts, - self._log, - self.send_packet, - channel_db, - event_service, - node_name, - ), - ) - # Protocol response handler for encrypted responses (including telemetry) - protocol_response_handler = ProtocolResponseHandler(self._log, local_identity, contacts) - # Keep a reference for the node - self.protocol_response_handler = protocol_response_handler - - # Login response handler for PAYLOAD_TYPE_RESPONSE packets - login_response_handler = LoginResponseHandler(local_identity, contacts, self._log) - # Connect protocol response handler for forwarding telemetry - login_response_handler.set_protocol_response_handler(protocol_response_handler) - # Keep references for backward compatibility - # Note: 
telemetry now uses protocol_response_handler, login uses PAYLOAD_TYPE_RESPONSE - self.login_response_handler = login_response_handler - # For backward compatibility, point telemetry handler to protocol response handler - self.telemetry_response_handler = protocol_response_handler - - # PATH handler - for route discovery packets, with ACK and protocol response processing - path_handler = PathHandler( - self._log, ack_handler, protocol_response_handler, login_response_handler - ) - self.register_handler(PathHandler.payload_type(), path_handler) - - # Login response handler for PAYLOAD_TYPE_RESPONSE packets - self.register_handler( - LoginResponseHandler.payload_type(), - login_response_handler, + # --- Core handlers via shared factory --- + core = create_core_handlers( + identity=local_identity, + contacts=contacts, + channels=channel_db, + event_service=event_service, + send_packet_fn=self.send_packet, + log_fn=self._log, + node_name=node_name, + radio_config=radio_config, + ack_handler=ack_handler, ) - # Anonymous request response handler for login responses that come as ANON_REQ + # Keep references for companion layer access + self.text_message_handler = core.text_handler + self.protocol_response_handler = core.protocol_response_handler + self.login_response_handler = core.login_response_handler + self.group_text_handler = core.group_text_handler + # Backward compat alias + self.telemetry_response_handler = core.protocol_response_handler + + # Register core handlers by payload type + from .handlers import AdvertHandler as _Adv + from .handlers import GroupTextHandler as _Grp + from .handlers import LoginResponseHandler as _Login + from .handlers import PathHandler as _Path + from .handlers import TextMessageHandler as _Txt + + self.register_handler(_Adv.payload_type(), core.advert_handler) + self.register_handler(_Txt.payload_type(), core.text_handler) + self.register_handler(_Grp.payload_type(), core.group_text_handler) + 
self.register_handler(_Path.payload_type(), core.path_handler) + self.register_handler(_Login.payload_type(), core.login_response_handler) + + # --- Dispatcher-only handlers --- self.register_handler( AnonReqResponseHandler.payload_type(), AnonReqResponseHandler(local_identity, contacts, self._log), ) - # TRACE handler for diagnostics and routing analysis - trace_handler = TraceHandler(self._log, protocol_response_handler) - self.register_handler( - TraceHandler.payload_type(), - trace_handler, - ) - # Keep a reference for the node + trace_handler = TraceHandler(self._log, core.protocol_response_handler) + self.register_handler(TraceHandler.payload_type(), trace_handler) self.trace_handler = trace_handler - # CONTROL handler for node discovery control_handler = ControlHandler(self._log) - self.register_handler( - ControlHandler.payload_type(), - control_handler, - ) - # Keep a reference for the node + self.register_handler(ControlHandler.payload_type(), control_handler) self.control_handler = control_handler + # --- RAW_CUSTOM handler: deliver to companion if direct and callback set --- + from ..protocol.constants import PAYLOAD_TYPE_RAW_CUSTOM + + async def raw_custom_handler(pkt: Packet) -> None: + if not pkt.is_route_direct(): + return + if self.raw_data_received_callback: + await self._invoke_callback(self.raw_data_received_callback, pkt) + + self.register_handler(PAYLOAD_TYPE_RAW_CUSTOM, raw_custom_handler) + self._logger.info("Default handlers registered.") # Set up a fallback handler for unknown packet types @@ -295,12 +282,49 @@ def set_packet_sent_callback( ) -> None: self.packet_sent_callback = callback + def set_ack_received_listener( + self, + callback: Optional[Callable[[int], Awaitable[None] | None]], + ) -> None: + """Set optional listener for ACK CRCs (e.g. 
companion send_confirmed).""" + self._ack_received_listener = callback + def set_raw_packet_callback( self, callback: Callable[[Packet, bytes], Awaitable[None] | None] ) -> None: """Set callback for raw packet data (includes both parsed packet and raw bytes).""" self.raw_packet_callback = callback + def add_raw_packet_subscriber(self, callback: Callable[..., Any]) -> None: + """Subscribe to every raw packet. Callback (pkt, data) or (pkt, data, analysis). + Forward raw RX to clients to track repeats by packet hash. + """ + if callback not in self._raw_packet_subscribers: + self._raw_packet_subscribers.append(callback) + + def remove_raw_packet_subscriber(self, callback: Callable[..., Any]) -> None: + """Unsubscribe from raw packet notifications (after parse).""" + try: + self._raw_packet_subscribers.remove(callback) + except ValueError: + pass + + def add_raw_rx_subscriber( + self, callback: Callable[[bytes, int, float], Awaitable[None] | None] + ) -> None: + """Subscribe to every incoming raw RX. Callback receives (data, rssi, snr). + Called before duplicate/blacklist so clients get every repeat. + """ + if callback not in self._raw_rx_subscribers: + self._raw_rx_subscribers.append(callback) + + def remove_raw_rx_subscriber(self, callback: Callable[..., Any]) -> None: + """Unsubscribe from raw RX notifications.""" + try: + self._raw_rx_subscribers.remove(callback) + except ValueError: + pass + def _on_packet_received( self, data: bytes, @@ -320,9 +344,31 @@ async def _process_received_packet( rssi: Optional[int] = None, snr: Optional[float] = None, ) -> None: - """Process a received packet from the radio callback. rssi/snr are per-packet when provided.""" + """Process received packet. 
rssi/snr are per-packet when provided.""" self._log(f"[RX DEBUG] Processing packet: {len(data)} bytes, data: {data.hex()[:32]}...") + # Notify raw RX subscribers so clients can track repeats + if rssi is not None: + rssi_val = rssi + elif hasattr(self.radio, "get_last_rssi"): + rssi_val = self.radio.get_last_rssi() + else: + rssi_val = 0 + if snr is not None: + snr_val = snr + elif hasattr(self.radio, "get_last_snr"): + snr_val = self.radio.get_last_snr() + else: + snr_val = 0.0 + for cb in self._raw_rx_subscribers: + try: + if asyncio.iscoroutinefunction(cb): + await cb(data, rssi_val, snr_val) + else: + cb(data, rssi_val, snr_val) + except Exception as e: + self._log(f"Raw RX subscriber error: {e}") + # Generate packet hash for deduplication and blacklist checking packet_hash = self.packet_filter.generate_hash(data) @@ -350,6 +396,10 @@ async def _process_received_packet( self._log(f"Blacklisted malformed packet (hash: {packet_hash})") return + # Packets at max hops for their path encoding must not be retransmitted + if PathUtils.is_path_at_max_hops(pkt.path_len): + pkt.mark_do_not_retransmit() + ptype = pkt.header >> PH_TYPE_SHIFT self._log(f"[RX DEBUG] Packet type: {ptype:02X}") @@ -361,8 +411,6 @@ async def _process_received_packet( # Let the node know about this packet for analysis (statistics, caching, etc.) if self.packet_analysis_callback: try: - import asyncio - if asyncio.iscoroutinefunction(self.packet_analysis_callback): await self.packet_analysis_callback(pkt, data) else: @@ -371,9 +419,13 @@ async def _process_received_packet( except Exception as e: self._log(f"Error in packet analysis callback: {e}") - # Always call raw packet callback first for logging (regardless of source) + # Notify raw packet subscribers (e.g. 
companion clients for PUSH_CODE_LOG_RX_DATA) + analysis = {} + for callback in self._raw_packet_subscribers: + await self._invoke_enhanced_raw_callback(callback, pkt, data, analysis) if self.raw_packet_callback: await self._invoke_enhanced_raw_callback(self.raw_packet_callback, pkt, data, {}) + if self._raw_packet_subscribers or self.raw_packet_callback: self._log("[RX DEBUG] Raw packet callback completed") # Check if this is our own packet before processing handlers @@ -395,6 +447,23 @@ async def _process_received_packet( # Public interface - sending and receiving packets # ------------------------------------------------------------------ + def _apply_flood_scope(self, pkt: Packet) -> None: + """Apply flood scope transport codes to a packet in-place. + + If ``flood_transport_key`` is set and the packet uses flood routing, + calculates the transport code, attaches it to the packet, and + switches the route type to ``ROUTE_TYPE_TRANSPORT_FLOOD``. + """ + if self.flood_transport_key is None: + return + route_type = pkt.get_route_type() + if route_type != ROUTE_TYPE_FLOOD: + return + code = calc_transport_code(self.flood_transport_key, pkt) + pkt.transport_codes[0] = code + pkt.transport_codes[1] = 0 # reserved for home region + pkt.header = (pkt.header & ~0x03) | ROUTE_TYPE_TRANSPORT_FLOOD + async def send_packet( self, packet: Packet, @@ -411,6 +480,7 @@ async def send_packet( expected_crc: The expected CRC for ACK matching. If None, will be calculated from packet. """ + self._apply_flood_scope(packet) async with self._tx_lock: # Wait our turn return await self._send_packet_immediate(packet, wait_for_ack, expected_crc) @@ -538,7 +608,7 @@ async def _dispatch(self, pkt: Packet) -> None: # All ACK processing logic is delegated to the AckHandler. 
# ------------------------------------------------------------------ - def _register_ack_received(self, crc: int) -> None: + async def _register_ack_received(self, crc: int) -> None: """Record that an ACK with the given CRC was received.""" ts = asyncio.get_running_loop().time() self._recent_acks[crc] = ts @@ -548,6 +618,9 @@ def _register_ack_received(self, crc: int) -> None: self._log(f"ACK matched! CRC {crc:08X}") evt.set() + if self._ack_received_listener: + await self._invoke_ack_listener(crc) + async def run_forever(self) -> None: """Run the dispatcher maintenance loop indefinitely (call this in an asyncio task).""" health_check_counter = 0 @@ -590,6 +663,16 @@ async def _invoke_callback(self, cb, pkt: Packet) -> None: else: cb(pkt) + async def _invoke_ack_listener(self, crc: int) -> None: + """Invoke ack-received listener (sync or async).""" + cb = self._ack_received_listener + if cb is None: + return + if asyncio.iscoroutinefunction(cb): + await cb(crc) + else: + cb(crc) + async def _invoke_enhanced_raw_callback( self, callback, pkt: Packet, data: bytes, analysis: dict ) -> None: diff --git a/src/pymc_core/node/handlers/__init__.py b/src/pymc_core/node/handlers/__init__.py index 27fc5c5..6bad086 100644 --- a/src/pymc_core/node/handlers/__init__.py +++ b/src/pymc_core/node/handlers/__init__.py @@ -10,6 +10,7 @@ from .login_response import AnonReqResponseHandler, LoginResponseHandler from .path import PathHandler from .protocol_response import ProtocolResponseHandler +from .registry import CoreHandlers, create_core_handlers from .text import TextMessageHandler from .trace import TraceHandler @@ -25,4 +26,6 @@ "AnonReqResponseHandler", "TraceHandler", "ControlHandler", + "CoreHandlers", + "create_core_handlers", ] diff --git a/src/pymc_core/node/handlers/ack.py b/src/pymc_core/node/handlers/ack.py index 7d260cb..e47070d 100644 --- a/src/pymc_core/node/handlers/ack.py +++ b/src/pymc_core/node/handlers/ack.py @@ -1,7 +1,9 @@ -from typing import Callable, 
Optional +import asyncio +from typing import Awaitable, Callable, Optional from ...protocol import Packet from ...protocol.constants import PAYLOAD_TYPE_ACK +from ...protocol.packet_utils import PathUtils from .base import BaseHandler @@ -20,9 +22,11 @@ def payload_type() -> int: def __init__(self, log_fn, dispatcher=None): self.log = log_fn self.dispatcher = dispatcher - self._ack_received_callback: Optional[Callable[[int], None]] = None + self._ack_received_callback: Optional[Callable[[int], Awaitable[None] | None]] = None - def set_ack_received_callback(self, callback: Callable[[int], None]): + def set_ack_received_callback( + self, callback: Optional[Callable[[int], Awaitable[None] | None]] + ): """Set callback to notify dispatcher when ACK is received.""" self._ack_received_callback = callback @@ -138,14 +142,15 @@ async def _process_bundled_ack_in_path(self, payload: bytes) -> Optional[int]: return None path_length = payload[0] + path_byte_len = PathUtils.get_path_byte_len(path_length) - # Check if we have enough data for: path_length + path + extra_type + extra - min_required = 1 + path_length + 1 + 4 # +4 for ACK CRC + # Check if we have enough data for: path_byte_len + path + extra_type + extra + min_required = 1 + path_byte_len + 1 + 4 # +4 for ACK CRC if len(payload) < min_required: return None # Extract extra section - extra_start = 1 + path_length + extra_start = 1 + path_byte_len extra_type = payload[extra_start] extra_payload = payload[extra_start + 1 :] @@ -162,5 +167,8 @@ async def _process_bundled_ack_in_path(self, payload: bytes) -> Optional[int]: async def _notify_ack_received(self, crc: int): """Notify the dispatcher that an ACK was received.""" if self._ack_received_callback: - # Call the callback directly since _register_ack_received is synchronous - self._ack_received_callback(crc) + cb = self._ack_received_callback + if asyncio.iscoroutinefunction(cb): + await cb(crc) + else: + cb(crc) diff --git a/src/pymc_core/node/handlers/advert.py 
b/src/pymc_core/node/handlers/advert.py index bda80c1..44b4b50 100644 --- a/src/pymc_core/node/handlers/advert.py +++ b/src/pymc_core/node/handlers/advert.py @@ -1,16 +1,18 @@ +import struct import time from typing import Any, Dict, Optional -from ...protocol import Identity, Packet, decode_appdata +from ...protocol import Identity, Packet, decode_appdata, parse_advert_payload from ...protocol.constants import ( MAX_ADVERT_DATA_SIZE, PAYLOAD_TYPE_ADVERT, PUB_KEY_SIZE, SIGNATURE_SIZE, - TIMESTAMP_SIZE, describe_advert_flags, ) +from ...protocol.packet_utils import PathUtils from ...protocol.utils import determine_contact_type_from_flags, get_contact_type_name +from ..events import MeshEvents from .base import BaseHandler @@ -19,33 +21,9 @@ class AdvertHandler(BaseHandler): def payload_type() -> int: return PAYLOAD_TYPE_ADVERT - def __init__(self, log_fn): + def __init__(self, log_fn, event_service=None): self.log = log_fn - - def _extract_advert_components(self, packet: Packet): - """Extract and validate advert packet components.""" - payload = packet.get_payload() - header_len = PUB_KEY_SIZE + TIMESTAMP_SIZE + SIGNATURE_SIZE - if len(payload) < header_len: - self.log( - f"Advert payload too short ({len(payload)} bytes, expected at least {header_len})" - ) - return None - - sig_offset = PUB_KEY_SIZE + TIMESTAMP_SIZE - pubkey = payload[:PUB_KEY_SIZE] - timestamp = payload[PUB_KEY_SIZE:sig_offset] - signature = payload[sig_offset : sig_offset + SIGNATURE_SIZE] - appdata = payload[sig_offset + SIGNATURE_SIZE :] - - if len(appdata) > MAX_ADVERT_DATA_SIZE: - self.log( - f"Advert appdata too large ({len(appdata)} bytes). 
" - f"Truncating to {MAX_ADVERT_DATA_SIZE}" - ) - appdata = appdata[:MAX_ADVERT_DATA_SIZE] - - return pubkey, timestamp, signature, appdata + self.event_service = event_service def _verify_advert_signature( self, pubkey: bytes, timestamp: bytes, appdata: bytes, signature: bytes @@ -85,13 +63,27 @@ def _verify_advert_signature( async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: """Process advert packet and return parsed data with signature verification.""" try: - # Extract and validate packet components - components = self._extract_advert_components(packet) - if not components: + payload = packet.get_payload() + if not payload: + return None + try: + parsed = parse_advert_payload(payload) + except ValueError as e: + self.log(f"Advert payload parse error: {e}") return None - pubkey_bytes, timestamp_bytes, signature_bytes, appdata = components - pubkey_hex = pubkey_bytes.hex() + pubkey_bytes = bytes.fromhex(parsed["pubkey"]) + pubkey_hex = parsed["pubkey"] + advert_timestamp = parsed["timestamp"] + timestamp_bytes = struct.pack(" MAX_ADVERT_DATA_SIZE: + self.log( + f"Advert appdata too large ({len(appdata)} bytes), " + f"truncating to {MAX_ADVERT_DATA_SIZE}" + ) + appdata = appdata[:MAX_ADVERT_DATA_SIZE] # Verify cryptographic signature if not self._verify_advert_signature( @@ -102,7 +94,7 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: self.log(f"Processing advert for pubkey: {pubkey_hex[:16]}...") - # Decode application data + # Decode application data (protocol.utils.decode_appdata) decoded = decode_appdata(appdata) # Extract name from decoded data @@ -119,6 +111,11 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: contact_type_id = determine_contact_type_from_flags(flags_int) contact_type = get_contact_type_name(contact_type_id) + # Clamp to current time if remote clock is ahead (avoid "future" last-advert in UI) + now = int(time.time()) + if advert_timestamp > now: + advert_timestamp = now + # 
Build parsed advert data advert_data = { "public_key": pubkey_hex, @@ -129,6 +126,7 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: "flags_description": flags_description, "contact_type_id": contact_type_id, "contact_type": contact_type, + "advert_timestamp": advert_timestamp, "timestamp": int(time.time()), "snr": packet._snr if hasattr(packet, "_snr") else 0.0, "rssi": packet._rssi if hasattr(packet, "_rssi") else 0, @@ -136,6 +134,31 @@ async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: } self.log(f"Parsed advert: {name} ({contact_type})") + + # Publish so companion/app receives node-discovered and advert_received callbacks + if self.event_service: + try: + path_len = getattr(packet, "path_len", 0) or 0 + path = getattr(packet, "path", bytearray()) or bytearray() + path_byte_len = PathUtils.get_path_byte_len(path_len) + inbound_path = bytes(path[:path_byte_len]) if path_byte_len > 0 else b"" + event_data = { + "public_key": pubkey_hex, + "name": name, + "contact_type": contact_type_id, + "lat": lat, + "lon": lon, + "advert_timestamp": advert_timestamp, + "timestamp": int(time.time()), + "snr": advert_data["snr"], + "rssi": advert_data["rssi"], + "inbound_path": inbound_path, + "path_len_encoded": path_len, + } + self.event_service.publish_sync(MeshEvents.NODE_DISCOVERED, event_data) + except Exception as e: + self.log(f"Failed to publish NODE_DISCOVERED event: {e}") + return advert_data except Exception as e: diff --git a/src/pymc_core/node/handlers/control.py b/src/pymc_core/node/handlers/control.py index 93148b4..0f7dd25 100644 --- a/src/pymc_core/node/handlers/control.py +++ b/src/pymc_core/node/handlers/control.py @@ -24,13 +24,22 @@ class ControlHandler: This handler processes incoming discovery requests and responses. 
""" - def __init__(self, log_fn: Callable[[str], None]): + def __init__( + self, + log_fn: Callable[[str], None], + debug_log_fn: Optional[Callable[[str], None]] = None, + ): """Initialize control handler. - + Args: - log_fn: Logging function + log_fn: Logging function for normal messages. + debug_log_fn: Optional logging function for verbose messages (e.g. callback + presence). If set, "No callback waiting" and "Called response callback" + use this instead of log_fn, so callers can use logger.debug to avoid noise + when forwarding discovery to companions. """ self._log = log_fn + self._debug_log = debug_log_fn if debug_log_fn is not None else log_fn # Callbacks for discovery responses self._response_callbacks: Dict[int, Callable[[Dict[str, Any]], None]] = {} @@ -40,9 +49,7 @@ def __init__(self, log_fn: Callable[[str], None]): def payload_type() -> int: return PAYLOAD_TYPE_CONTROL - def set_response_callback( - self, tag: int, callback: Callable[[Dict[str, Any]], None] - ) -> None: + def set_response_callback(self, tag: int, callback: Callable[[Dict[str, Any]], None]) -> None: """Set callback for discovery responses with a specific tag.""" self._response_callbacks[tag] = callback @@ -50,9 +57,7 @@ def clear_response_callback(self, tag: int) -> None: """Clear callback for discovery responses with a specific tag.""" self._response_callbacks.pop(tag, None) - def set_request_callback( - self, callback: Callable[[Dict[str, Any]], None] - ) -> None: + def set_request_callback(self, callback: Callable[[Dict[str, Any]], None]) -> None: """Set callback for discovery requests (for logging/monitoring).""" self._request_callbacks[0] = callback @@ -69,9 +74,7 @@ async def __call__(self, pkt: Packet) -> Optional[Dict[str, Any]]: # Check if this is a zero-hop packet (path_len must be 0) if pkt.path_len != 0: - self._log( - f"[ControlHandler] Non-zero path length ({pkt.path_len}), ignoring" - ) + self._log(f"[ControlHandler] Non-zero path length ({pkt.path_len}), ignoring") return 
None # Extract control type (upper 4 bits of first byte) @@ -82,9 +85,7 @@ async def __call__(self, pkt: Packet) -> Optional[Dict[str, Any]]: elif control_type == CTL_TYPE_NODE_DISCOVER_RESP: return await self._handle_discovery_response(pkt) else: - self._log( - f"[ControlHandler] Unknown control type: 0x{control_type:02X}" - ) + self._log(f"[ControlHandler] Unknown control type: 0x{control_type:02X}") return None except Exception as e: @@ -93,7 +94,7 @@ async def __call__(self, pkt: Packet) -> Optional[Dict[str, Any]]: async def _handle_discovery_request(self, pkt: Packet) -> Optional[Dict[str, Any]]: """Handle node discovery request packet and return parsed data. - + Expected format: - byte 0: type (0x80) + flags (bit 0: prefix_only) - byte 1: filter (bitfield of node types to respond) @@ -147,7 +148,7 @@ async def _handle_discovery_request(self, pkt: Packet) -> Optional[Dict[str, Any async def _handle_discovery_response(self, pkt: Packet) -> Optional[Dict[str, Any]]: """Handle node discovery response packet and return parsed data. 
- + Response format: - byte 0: type (0x90) + node_type (lower 4 bits) - byte 1: SNR of our request (int8_t, multiplied by 4) @@ -179,8 +180,8 @@ async def _handle_discovery_response(self, pkt: Packet) -> Optional[Dict[str, An "tag": tag, "node_type": node_type, "inbound_snr": inbound_snr, # SNR of our request at their end - "response_snr": pkt._snr, # SNR of their response at our end - "rssi": pkt._rssi, # RSSI of their response at our end + "response_snr": pkt._snr, # SNR of their response at our end + "rssi": pkt._rssi, # RSSI of their response at our end "pub_key": pub_key.hex(), "pub_key_bytes": bytes(pub_key), "timestamp": time.time(), @@ -192,19 +193,14 @@ async def _handle_discovery_response(self, pkt: Packet) -> Optional[Dict[str, An callback = self._response_callbacks[tag] if callback: callback(response_data) - self._log( + self._debug_log( f"[ControlHandler] Called response callback for tag 0x{tag:08X}" ) else: - self._log( - f"[ControlHandler] No callback waiting for tag 0x{tag:08X}" - ) + self._debug_log(f"[ControlHandler] No callback waiting for tag 0x{tag:08X}") return response_data except Exception as e: self._log(f"[ControlHandler] Error handling discovery response: {e}") return None - - except Exception as e: - self._log(f"[ControlHandler] Error handling discovery response: {e}") diff --git a/src/pymc_core/node/handlers/group_text.py b/src/pymc_core/node/handlers/group_text.py index 404c9b1..35dafcc 100644 --- a/src/pymc_core/node/handlers/group_text.py +++ b/src/pymc_core/node/handlers/group_text.py @@ -1,7 +1,7 @@ from typing import Optional from ...protocol import Packet -from ...protocol.constants import PAYLOAD_TYPE_GRP_TXT +from ...protocol.constants import PAYLOAD_TYPE_GRP_TXT, ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD from ...protocol.crypto import CryptoUtils from .base import BaseHandler @@ -29,53 +29,71 @@ def __init__( self.event_service = event_service self.our_node_name = our_node_name # Store our node name for echo detection + def 
set_our_node_name(self, name: str | None) -> None: + """Update the node name used for echo detection (e.g. after set_advert_name).""" + self.our_node_name = name + def _get_channel_by_hash(self, channel_hash: int) -> Optional[dict]: - """Find a channel by its hash (first byte of public key) from database.""" + """Find a channel by its hash (first byte of SHA256) from database. + + Returns the first matching channel. See also + :meth:`_get_channels_by_hash` which returns *all* matches (needed + because the hash is only 1 byte and collisions are expected). + """ + matches = self._get_channels_by_hash(channel_hash) + return matches[0] if matches else None + + def _get_channels_by_hash(self, channel_hash: int) -> list[dict]: + """Return **all** channels whose derived hash matches *channel_hash*. + + The channel hash is only 1 byte, so collisions between channels + with different PSKs are expected (~0.4 % per foreign channel). + The firmware handles this by trying each match until HMAC validates; + we do the same. + """ if not self.channel_db: self.log("No channel database available") - return None + return [] try: - # Get all channels from database (live query) channels = self.channel_db.get_channels() + matches = [] for channel in channels: if "secret" in channel: - # Use consistent channel hash derivation calculated_hash = self._derive_channel_hash(channel["secret"]) if calculated_hash == channel_hash: - return channel - return None + matches.append(channel) + return matches except Exception as e: self.log(f"Error querying channel database: {e}") - return None - - def _derive_channel_hash(self, channel_secret: str) -> int: - """Derive a consistent channel hash from the secret.""" - import hashlib + return [] - # Convert hex secret to bytes, then derive key + def _secret_bytes_for_hash(self, channel_secret: str) -> bytes: + """Normalize secret to bytes used for channel hash (match MeshCore firmware). 
+ Firmware hashes only first 16 bytes when second 16 are zero (128-bit key).""" try: secret_bytes = bytes.fromhex(channel_secret) except ValueError: - # If not hex, treat as UTF-8 string secret_bytes = channel_secret.encode("utf-8") + if len(secret_bytes) >= 32 and secret_bytes[16:32] == b"\x00" * 16: + return secret_bytes[:16] + if len(secret_bytes) > 32: + return secret_bytes[:32] + return secret_bytes - # Simple SHA256 derivation (no salt) to match official spec - channel_key = hashlib.sha256(secret_bytes).digest() + def _derive_channel_hash(self, channel_secret: str) -> int: + """Derive channel hash (first byte of SHA256) to match MeshCore firmware.""" + import hashlib - # Return first byte as the channel hash + secret_bytes = self._secret_bytes_for_hash(channel_secret) + channel_key = hashlib.sha256(secret_bytes).digest() return channel_key[0] def _derive_channel_keys(self, channel_secret: str) -> tuple: """Derive all necessary keys from channel secret.""" import hashlib - try: - secret_bytes = bytes.fromhex(channel_secret) - except ValueError: - secret_bytes = channel_secret.encode("utf-8") - - # Simple SHA256 derivation to match official spec + secret_bytes = self._secret_bytes_for_hash(channel_secret) master_key = hashlib.sha256(secret_bytes).digest() # Split into different keys @@ -88,7 +106,12 @@ def _derive_channel_keys(self, channel_secret: str) -> tuple: def _decrypt_channel_message( self, channel_secret: str, mac: bytes, ciphertext: bytes ) -> Optional[bytes]: - """Decrypt a channel message using the channel secret.""" + """Attempt to decrypt a channel message using *channel_secret*. + + Returns the plaintext on success, or ``None`` if the HMAC does not + validate (which is expected during candidate iteration when multiple + channels share the same 1-byte hash). 
+ """ try: # Convert hex secret to bytes try: @@ -98,22 +121,19 @@ def _decrypt_channel_message( # Ensure we have PUB_KEY_SIZE (32 bytes) for the secret if len(secret_bytes) < 32: - # Pad with zeros if needed secret_bytes = secret_bytes + b"\x00" * (32 - len(secret_bytes)) elif len(secret_bytes) > 32: - # Truncate if too long secret_bytes = secret_bytes[:32] expected_mac = CryptoUtils._hmac_sha256(secret_bytes, ciphertext)[:2] if mac != expected_mac: - raise ValueError("Invalid HMAC") + return None # HMAC mismatch — normal during candidate iteration - plaintext = CryptoUtils._aes_decrypt(secret_bytes[:16], ciphertext) - return plaintext + return CryptoUtils._aes_decrypt(secret_bytes[:16], ciphertext) except Exception as e: - self.log(f"Channel message decryption failed: {e}") + self.log(f"Channel message decryption error: {e}") return None def _parse_plaintext_message(self, plaintext: bytes) -> Optional[dict]: @@ -124,7 +144,9 @@ def _parse_plaintext_message(self, plaintext: bytes) -> Optional[dict]: try: timestamp = int.from_bytes(plaintext[:4], "little") flags = plaintext[4] - message_content = plaintext[5:].decode("utf-8", errors="replace") + # Decode and strip trailing null (AES decrypt is block-aligned) + raw = plaintext[5:].decode("utf-8", errors="replace") + message_content = raw.rstrip("\x00") # Parse message flags according to spec message_type = "unknown" @@ -137,7 +159,8 @@ def _parse_plaintext_message(self, plaintext: bytes) -> Optional[dict]: # For signed messages, first two bytes are sender prefix if len(plaintext) >= 7: # sender_prefix = plaintext[5:7] # Unused for now - message_content = plaintext[7:].decode("utf-8", errors="replace") + raw = plaintext[7:].decode("utf-8", errors="replace") + message_content = raw.rstrip("\x00") return { "timestamp": timestamp, @@ -191,21 +214,36 @@ async def __call__(self, packet: Packet) -> None: cipher_mac = payload[1:3] ciphertext = payload[3:] - # Find the channel configuration - channel = 
self._get_channel_by_hash(channel_hash) - if not channel: + # Find all channels whose 1-byte hash matches (collisions are + # expected; the firmware tries up to 4 candidates). + candidates = self._get_channels_by_hash(channel_hash) + if not candidates: self.log(f"Unknown channel hash: {channel_hash:02X}") return + # Try each candidate until HMAC validates (matches firmware behaviour). + channel = None + plaintext = None + for candidate in candidates: + result = self._decrypt_channel_message(candidate["secret"], cipher_mac, ciphertext) + if result is not None: + channel = candidate + plaintext = result + break + + if channel is None or plaintext is None: + # No candidate validated — the packet is for a channel we + # don't have the key for (hash collision with 1-byte hash). + self.log( + f"GRP_TXT hash {channel_hash:02X} matched " + f"{len(candidates)} local channel(s) but HMAC failed " + f"for all — unknown channel" + ) + return + channel_name = channel.get("name", f"Channel-{channel_hash:02X}") self.log(f"Received group message for channel: {channel_name}") - # Decrypt the message - plaintext = self._decrypt_channel_message(channel["secret"], cipher_mac, ciphertext) - if not plaintext: - self.log("Failed to decrypt channel message") - return - # Parse the decrypted message parsed_message = self._parse_plaintext_message(plaintext) if not parsed_message: @@ -231,7 +269,9 @@ async def __call__(self, packet: Packet) -> None: # Check if this message is from ourselves using sender name (echo detection) is_own = self._is_own_message(packet) if is_own: - self.log(f"Own echo detected (will publish for heard-count): {sender_name}: {message_body}") + self.log( + f"Own echo detected (skip publish to client): {sender_name}: {message_body}" + ) # Log the group message self.log(f"<<< Channel [{channel_name}] {sender_name}: {message_body} >>>") @@ -259,7 +299,12 @@ async def _save_and_broadcast_group_message( ): """Save the group message to database and broadcast via 
WebSocket.""" try: - message_id = packet.get_packet_hash_hex(16) + message_id = packet.get_packet_hash_hex(16) + + # Do not publish NEW_CHANNEL_MESSAGE for our own messages (inject + echoes). + # The client already has the sent message; publishing per echo would spam the event. + if is_outgoing: + return # Publish channel message event if available if self.event_service: @@ -270,6 +315,12 @@ async def _save_and_broadcast_group_message( # Extract path from packet (list of node hashes) path = list(packet.path) if hasattr(packet, "path") and packet.path else None + # path_len: flood packets use actual path length; direct uses 0xFF + route_type = packet.header & 0x03 + if route_type in (ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD): + path_len = getattr(packet, "path_len", 0) or len(packet.path or []) + else: + path_len = 0xFF # Use a custom message type for single channel message addition message_data = { @@ -281,6 +332,8 @@ async def _save_and_broadcast_group_message( "timestamp": timestamp, "message_type": "group_text", "flags": 0, + "path_len": path_len, + "packet_hash": packet.calculate_packet_hash().hex().upper(), "full_content": packet.decrypted.get("group_text_data", {}).get( "full_content" ), @@ -295,8 +348,8 @@ async def _save_and_broadcast_group_message( }, } - # Publish channel message event - self.event_service.publish_sync(MeshEvents.NEW_CHANNEL_MESSAGE, message_data) + # Publish channel message event (await so queued and MSG_WAITING sent) + await self.event_service.publish(MeshEvents.NEW_CHANNEL_MESSAGE, message_data) self.log("Published group message event") except Exception as publish_error: self.log(f"Failed to publish group message event: {publish_error}") diff --git a/src/pymc_core/node/handlers/login_response.py b/src/pymc_core/node/handlers/login_response.py index a695a97..f242d39 100644 --- a/src/pymc_core/node/handlers/login_response.py +++ b/src/pymc_core/node/handlers/login_response.py @@ -3,7 +3,8 @@ from typing import Callable, Optional from 
...protocol import CryptoUtils, Identity, Packet -from ...protocol.constants import PAYLOAD_TYPE_ANON_REQ, PAYLOAD_TYPE_RESPONSE +from ...protocol.constants import PAYLOAD_TYPE_ANON_REQ, PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE +from ...protocol.packet_utils import PathUtils from .base import BaseHandler # Response codes from C++ server @@ -87,16 +88,15 @@ async def __call__(self, packet: Packet) -> None: if dest_hash != our_hash and src_hash != our_hash: return - # Find stored password and matching contact + # Find stored password and matching contact(s) if lookup_hash not in self._active_login_passwords: # This might be a telemetry response, not a login response - # Forward to protocol response handler if available + # Forward to protocol response handler if available (only for RESPONSE packets; + # PATH packets are already handled by PathHandler before we are called). if self._protocol_response_handler: - # Create a fake PATH packet format that - # ProtocolResponseHandler expects - # PATH format: dest_hash(1) + src_hash(1) + encrypted_data - # RESPONSE format is already: dest_hash(1) + src_hash(1) + encrypted_data - # So we can directly forward the packet to the protocol response handler + pkt_type = (packet.header >> 2) & 0x0F + if pkt_type == PAYLOAD_TYPE_PATH: + return # PathHandler already invoked protocol_response_handler for this packet try: await self._protocol_response_handler(packet) return @@ -106,23 +106,38 @@ async def __call__(self, packet: Packet) -> None: ) return - matched_contact = None - + # Collect all contacts whose public_key first byte matches (hash collision / multiple peers) + candidates = [] for contact in self.contacts.contacts: try: - contact_pubkey = bytes.fromhex(contact.public_key) - if len(contact_pubkey) > 0 and contact_pubkey[0] == lookup_hash: - matched_contact = contact - break + pk = contact.public_key + contact_pubkey = pk if isinstance(pk, bytes) else bytes.fromhex(pk) + if len(contact_pubkey) == 32 and contact_pubkey[0] == 
lookup_hash: + candidates.append(contact) except Exception: continue - if not matched_contact: + if not candidates: + if self.login_callback: + await self._safe_callback( + False, + { + "error": "No contact found for login response (src_hash=0x%02x)" + % lookup_hash + }, + ) return - # Decrypt and process response - response_data = await self._decrypt_response(packet, matched_contact, encrypted_start) - if response_data: + # Try each candidate until one decrypts successfully (same shared-secret as firmware) + response_data = None + matched_contact = None + for contact in candidates: + response_data = await self._decrypt_response(packet, contact, encrypted_start) + if response_data: + matched_contact = contact + break + + if response_data and matched_contact: await self._process_login_response(response_data, matched_contact) self.clear_login_password(lookup_hash) elif self.login_callback: @@ -137,7 +152,8 @@ async def _decrypt_response( encrypted_data = packet.payload[encrypted_start:] # Calculate X25519 ECDH shared secret - contact_pubkey = bytes.fromhex(contact.public_key) + pk = contact.public_key + contact_pubkey = pk if isinstance(pk, bytes) else bytes.fromhex(pk) contact_identity = Identity(contact_pubkey) shared_secret = contact_identity.calc_shared_secret( self.local_identity.get_private_key() @@ -147,16 +163,33 @@ async def _decrypt_response( aes_key = shared_secret[:16] plaintext = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, encrypted_data) - if not plaintext or len(plaintext) < 12: + if not plaintext: + return None + + # If this is a PATH packet, unwrap the path-return envelope to get + # the inner response. 
PATH format after decryption: + # path_len(1) + path(N) + extra_type(1) + extra_data(M) + pkt_type = (packet.header >> 2) & 0x0F + if pkt_type == PAYLOAD_TYPE_PATH and len(plaintext) >= 2: + path_len_byte = plaintext[0] + path_byte_len = PathUtils.get_path_byte_len(path_len_byte) + inner_offset = 1 + path_byte_len + 1 # skip path_len + path + extra_type + if PathUtils.is_valid_path_len(path_len_byte) and len(plaintext) >= inner_offset: + extra_type = plaintext[1 + path_byte_len] & 0x0F + if extra_type == PAYLOAD_TYPE_RESPONSE and len(plaintext) > inner_offset: + plaintext = plaintext[inner_offset:] + + if len(plaintext) < 12: return None - # Parse the C++ response format: + # Parse the C++ response format (handleLoginReq reply_data): # timestamp(4) + response_code(1) + keep_alive(1) + is_admin(1) + - # reserved(1) + random(4) + # permissions(1) + random(4) + [firmware_ver_level(1) at index 12] timestamp, response_code, keep_alive, is_admin, reserved = struct.unpack( "= 13 else None return { "timestamp": timestamp, @@ -165,6 +198,7 @@ async def _decrypt_response( "is_admin": bool(is_admin), "reserved": reserved, "random_blob": random_blob, + "firmware_ver_level": firmware_ver_level, "contact": contact, } diff --git a/src/pymc_core/node/handlers/login_server.py b/src/pymc_core/node/handlers/login_server.py index 4841a5d..ff551c0 100644 --- a/src/pymc_core/node/handlers/login_server.py +++ b/src/pymc_core/node/handlers/login_server.py @@ -89,7 +89,7 @@ async def __call__(self, packet: Packet) -> None: """Handle ANON_REQ login packet from client.""" try: # Debug: Log packet routing info - path_data = list(packet.path[: packet.path_len]) if packet.path_len > 0 else [] + path_data = packet.get_path_hashes_hex() if packet.path_len > 0 else [] self.log( f"[LoginServer] Packet route flood: {packet.is_route_flood()}, " f"path_len: {packet.path_len}, path: {path_data}" @@ -144,22 +144,28 @@ async def __call__(self, packet: Packet) -> None: self.log("[LoginServer] Room server 
packet too short for sync_since field") return sync_since = struct.unpack(" 0 else '(empty)'}") + self.log( + f"[LoginServer] Room server: sync_since={sync_since}, " + f"password from byte 8 to {null_idx}" + ) + self.log( + f"[LoginServer] Password hex: " + f"{password_bytes.hex() if password_bytes else '(empty)'}" + ) else: # Repeater format: password only # Find null terminator after timestamp (starting from byte 4) null_idx = plaintext.find(b"\x00", 4) if null_idx == -1: null_idx = len(plaintext) - + password_bytes = plaintext[4:null_idx] self.log(f"[LoginServer] Repeater format: password from byte 4 to {null_idx}") @@ -247,7 +253,7 @@ async def _send_login_response( client_hash = client_identity.get_public_key()[0] server_hash = self.local_identity.get_public_key()[0] path_list = ( - list(original_packet.path[: original_packet.path_len]) + list(original_packet.path[: original_packet.get_path_byte_len()]) if original_packet and original_packet.path_len > 0 else [] ) diff --git a/src/pymc_core/node/handlers/path.py b/src/pymc_core/node/handlers/path.py index 225190c..4348e48 100644 --- a/src/pymc_core/node/handlers/path.py +++ b/src/pymc_core/node/handlers/path.py @@ -65,52 +65,28 @@ async def __call__(self, pkt: Packet) -> None: # Optional PATH packet analysis if analyzer is available try: - # Try to use any available packet analyzer through callback - if hasattr(self, "_dispatcher") and hasattr( - self._dispatcher, "packet_analysis_callback" + if ( + hasattr(self, "_dispatcher") + and hasattr(self._dispatcher, "packet_analysis_callback") + and self._dispatcher.packet_analysis_callback ): - if self._dispatcher.packet_analysis_callback: - self._dispatcher.packet_analysis_callback(pkt) - self._log("PATH packet analysis delegated to app") - else: - self._log("PATH packet received - hop analysis requires app-level analyzer") - + self._dispatcher.packet_analysis_callback(pkt) except Exception as e: self._log(f"PATH packet analysis failed: {e}") - # Extract and log 
key PATH information directly from packet + # Single summary line for PATH packet try: payload = pkt.get_payload() - hop_count = pkt.path_len + hop_count = pkt.get_path_hash_count() if len(payload) >= 2: dest_hash = payload[0] src_hash = payload[1] self._log( f"PATH packet: hop_count={hop_count}, " - f"dest=0x{dest_hash:02X}, src=0x{src_hash:02X}, " - f"payload_len={len(payload)}" + f"dest=0x{dest_hash:02X}, src=0x{src_hash:02X}, payload_len={len(payload)}" ) - if hop_count > 0: - self._log(f"Path contains {hop_count} hops") - else: - self._log("Direct PATH (no intermediate hops)") else: - self._log("PATH packet received with minimal payload") - - # Log basic routing behavior based on header - try: - # These constants are already imported at the top - # from ...protocol.constants import ( - # ROUTE_TYPE_DIRECT, - # ROUTE_TYPE_FLOOD, - # ) - - # Extract route type from packet header if possible - # This is a simplified version without full analysis - self._log("PATH packet routing analysis requires app-level analyzer") - except ImportError: - pass - + self._log("PATH packet: minimal payload") except Exception as e: self._log(f"Error extracting PATH information: {e}") diff --git a/src/pymc_core/node/handlers/protocol_request.py b/src/pymc_core/node/handlers/protocol_request.py index 6541635..a5b42f2 100644 --- a/src/pymc_core/node/handlers/protocol_request.py +++ b/src/pymc_core/node/handlers/protocol_request.py @@ -5,11 +5,12 @@ """ import struct -from typing import Optional, Callable, Any +from typing import Callable, Optional -from pymc_core.protocol.constants import PAYLOAD_TYPE_REQ, PAYLOAD_TYPE_RESPONSE -from pymc_core.protocol.crypto import CryptoUtils from pymc_core.protocol import PacketBuilder +from pymc_core.protocol.constants import MAX_PATH_SIZE, PAYLOAD_TYPE_REQ, PAYLOAD_TYPE_RESPONSE +from pymc_core.protocol.crypto import CryptoUtils +from pymc_core.protocol.packet_utils import PathUtils # Request type codes (matching C++ implementation) 
REQ_TYPE_GET_STATUS = 0x01 @@ -17,6 +18,7 @@ REQ_TYPE_GET_TELEMETRY_DATA = 0x03 REQ_TYPE_GET_ACCESS_LIST = 0x05 REQ_TYPE_GET_NEIGHBOURS = 0x06 +REQ_TYPE_GET_OWNER_INFO = 0x07 # Variable-length: tag(4) + "version\nname\nowner" # Response delay (matching C++ SERVER_RESPONSE_DELAY) SERVER_RESPONSE_DELAY_MS = 500 @@ -25,17 +27,17 @@ class ProtocolRequestHandler: """ Handler for protocol request packets (PAYLOAD_TYPE_REQ). - + Processes encrypted request packets from authenticated clients and sends appropriate RESPONSE packets. Request handling is delegated to callbacks for application-specific logic. """ - + @staticmethod def payload_type(): """Return the payload type this handler processes.""" return PAYLOAD_TYPE_REQ - + def __init__( self, local_identity, @@ -46,7 +48,7 @@ def __init__( ): """ Initialize protocol request handler. - + Args: local_identity: LocalIdentity for this handler contacts: Contact manager or wrapper providing client lookup @@ -59,120 +61,131 @@ def __init__( self.get_client_fn = get_client_fn self.request_handlers = request_handlers or {} self.log = log_fn if log_fn else lambda msg: None - + async def __call__(self, packet): """ Process a protocol request packet. 
- + Args: packet: Packet instance with REQ payload - + Returns: Packet: RESPONSE packet to send, or None """ try: if len(packet.payload) < 2: return None - + dest_hash = packet.payload[0] src_hash = packet.payload[1] - + # Verify this packet is for us our_hash = self.local_identity.get_public_key()[0] if dest_hash != our_hash: return None - + self.log(f"Processing REQ from 0x{src_hash:02X}") - + # Get client info client = self._get_client(src_hash) if not client: self.log(f"REQ from unknown client 0x{src_hash:02X}") return None - + # Get shared secret shared_secret = self._get_shared_secret(client) if not shared_secret: self.log(f"No shared secret for client 0x{src_hash:02X}") return None - + # Decrypt request encrypted_data = packet.payload[2:] aes_key = shared_secret[:16] - + try: - plaintext = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, bytes(encrypted_data)) + plaintext = CryptoUtils.mac_then_decrypt( + aes_key, shared_secret, bytes(encrypted_data) + ) except Exception as e: self.log(f"Failed to decrypt REQ: {e}") return None - + # Parse request if len(plaintext) < 5: self.log("REQ packet too short") return None - - timestamp = struct.unpack(' 5 else b'' - + req_data = plaintext[5:] if len(plaintext) > 5 else b"" + self.log(f"REQ type=0x{req_type:02X}, timestamp={timestamp}") - + # Handle request response_data = await self._handle_request(client, timestamp, req_type, req_data) - + if response_data: return self._build_response(packet, client, response_data, shared_secret) - + return None - + except Exception as e: self.log(f"Error processing REQ: {e}") return None - + def _get_client(self, src_hash: int): """Get client info by source hash.""" if self.get_client_fn: return self.get_client_fn(src_hash) - + # Fallback: search in contacts - if hasattr(self.contacts, 'contacts'): + if hasattr(self.contacts, "contacts"): for contact in self.contacts.contacts: - if hasattr(contact, 'public_key'): - pk = bytes.fromhex(contact.public_key) if 
isinstance(contact.public_key, str) else contact.public_key + if hasattr(contact, "public_key"): + pk = ( + bytes.fromhex(contact.public_key) + if isinstance(contact.public_key, str) + else contact.public_key + ) if pk[0] == src_hash: return contact - + return None - + def _get_shared_secret(self, client): """Get shared secret for client.""" - if hasattr(client, 'shared_secret'): + if hasattr(client, "shared_secret"): return client.shared_secret - - if hasattr(client, 'public_key'): - pk = bytes.fromhex(client.public_key) if isinstance(client.public_key, str) else client.public_key + + if hasattr(client, "public_key"): + pk = ( + bytes.fromhex(client.public_key) + if isinstance(client.public_key, str) + else client.public_key + ) from pymc_core.protocol.identity import Identity + identity = Identity(pk) return identity.calc_shared_secret(self.local_identity.get_private_key()) - + return None - + async def _handle_request(self, client, timestamp: int, req_type: int, req_data: bytes): """ Handle request and generate response. 
- + Args: client: Client info object timestamp: Request timestamp req_type: Request type code req_data: Request payload - + Returns: bytes: Response data (timestamp + payload) or None """ # Build response with reflected timestamp - response = bytearray(struct.pack('= 0 and len(client.out_path) > 0: - reply_packet.path = bytearray(client.out_path[:client.out_path_len]) - reply_packet.path_len = client.out_path_len - - self.log(f"RESPONSE built for 0x{client_identity.get_public_key()[0]:02X} via {route_type.upper()}") - + reply_packet.set_path( + client.out_path[:MAX_PATH_SIZE], + client.out_path_len + if PathUtils.is_valid_path_len(client.out_path_len) + else None, + ) + + self.log( + f"RESPONSE built for 0x{client_identity.get_public_key()[0]:02X} " + f"via {route_type.upper()}" + ) + return reply_packet - + except Exception as e: self.log(f"Error building RESPONSE: {e}") return None diff --git a/src/pymc_core/node/handlers/protocol_response.py b/src/pymc_core/node/handlers/protocol_response.py index c82ca4f..79d09b6 100644 --- a/src/pymc_core/node/handlers/protocol_response.py +++ b/src/pymc_core/node/handlers/protocol_response.py @@ -4,12 +4,140 @@ back as PATH packets with encrypted payloads. 
""" +import asyncio import struct from typing import Any, Callable, Dict, Optional -from ...hardware.signal_utils import snr_register_to_db from ...protocol import CryptoUtils, Identity, Packet -from ...protocol.constants import PAYLOAD_TYPE_PATH +from ...protocol.constants import PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE, ROUTE_TYPE_DIRECT +from ...protocol.crypto import CIPHER_BLOCK_SIZE, CIPHER_MAC_SIZE +from ...protocol.packet_builder import PacketBuilder +from ...protocol.packet_utils import PathUtils + +# --------------------------------------------------------------------------- +# Built-in CayenneLPP decoder (no external dependency) +# Spec: https://docs.mydevices.com/docs/lorawan/cayenne-lpp +# Each record: channel(1) + type_id(1) + value(N) +# --------------------------------------------------------------------------- + +_LPP_TYPES: Dict[int, tuple] = { + # type_id: (name, value_size_bytes, divisor, signed) + # --- Original LPPv1 types --- + 0x00: ("Digital Input", 1, 1, False), + 0x01: ("Digital Output", 1, 1, False), + 0x02: ("Analog Input", 2, 100, True), + 0x03: ("Analog Output", 2, 100, True), + # --- Extended types (from CayenneLPP.h) --- + 0x64: ("Generic Sensor", 4, 1, False), # LPP_GENERIC_SENSOR = 100 + 0x65: ("Illuminance", 2, 1, False), # LPP_LUMINOSITY = 101 + 0x66: ("Presence", 1, 1, False), # LPP_PRESENCE = 102 + 0x67: ("Temperature", 2, 10, True), # LPP_TEMPERATURE = 103 + 0x68: ("Humidity", 1, 2, False), # LPP_RELATIVE_HUMIDITY = 104 + 0x71: ("Accelerometer", 6, 1000, True), # LPP_ACCELEROMETER = 113, 3×int16 + 0x73: ("Barometer", 2, 10, False), # LPP_BAROMETRIC_PRESSURE = 115 + 0x74: ("Voltage", 2, 100, False), # LPP_VOLTAGE = 116, 0.01V + 0x75: ("Current", 2, 1000, True), # LPP_CURRENT = 117, 0.001A signed + 0x76: ("Frequency", 4, 1, False), # LPP_FREQUENCY = 118, 1Hz + 0x78: ("Percentage", 1, 1, False), # LPP_PERCENTAGE = 120, 1-100% + 0x79: ("Altitude", 2, 1, True), # LPP_ALTITUDE = 121, 1m signed + 0x7D: ("Concentration", 2, 1, 
False), # LPP_CONCENTRATION = 125, 1ppm + 0x80: ("Power", 2, 1, False), # LPP_POWER = 128, 1W + 0x82: ("Distance", 4, 1000, False), # LPP_DISTANCE = 130, 0.001m + 0x83: ("Energy", 4, 1000, False), # LPP_ENERGY = 131, 0.001kWh + 0x84: ("Direction", 2, 1, False), # LPP_DIRECTION = 132, 1deg + 0x85: ("Unix Time", 4, 1, False), # LPP_UNIXTIME = 133 + 0x86: ("Gyroscope", 6, 100, True), # LPP_GYROMETER = 134, 3×int16 + 0x87: ("Colour", 3, 1, False), # LPP_COLOUR = 135, RGB + 0x88: ("GPS", 9, 1, True), # LPP_GPS = 136, lat(3)+lon(3)+alt(3), mult 10000/100 + 0x8E: ("Switch", 1, 1, False), # LPP_SWITCH = 142, 0/1 + # LPP_POLYLINE 240: variable size; min 8 bytes (size+delta+lon+lat). Skip min to continue. + 0xF0: ("Polyline", 8, 1, False), # LPP_POLYLINE = 240 +} + + +def _decode_cayenne_lpp(data: bytes) -> list: + """Decode CayenneLPP binary payload into a list of sensor dicts.""" + sensors: list = [] + idx = 0 + while idx + 2 <= len(data): + channel = data[idx] + type_id = data[idx + 1] + # Channel 0 is never used by MeshCore firmware (channels start at + # TELEM_CHANNEL_SELF=1). A channel=0 byte is AES zero-padding — stop. 
+ if channel == 0: + break + idx += 2 + spec = _LPP_TYPES.get(type_id) + if spec is None: + break # unknown type → stop (remaining bytes may be padding) + name, size, divisor, signed = spec + if idx + size > len(data): + break + raw = data[idx : idx + size] + idx += size + + if type_id == 0x88: + # GPS: lat(3, signed, /10000) + lon(3, signed, /10000) + alt(3, signed, /100) + lat = int.from_bytes(raw[0:3], "big", signed=True) / 10000 + lon = int.from_bytes(raw[3:6], "big", signed=True) / 10000 + alt = int.from_bytes(raw[6:9], "big", signed=True) / 100 + sensors.append( + { + "channel": channel, + "type": name, + "type_id": type_id, + "value": {"latitude": lat, "longitude": lon, "altitude": alt}, + "raw_value": raw.hex(), + } + ) + elif size == 6 and type_id in (0x71, 0x86): + # 3-axis: x(2) + y(2) + z(2), all signed + x = int.from_bytes(raw[0:2], "big", signed=True) / divisor + y = int.from_bytes(raw[2:4], "big", signed=True) / divisor + z = int.from_bytes(raw[4:6], "big", signed=True) / divisor + sensors.append( + { + "channel": channel, + "type": name, + "type_id": type_id, + "value": {"x": x, "y": y, "z": z}, + "raw_value": raw.hex(), + } + ) + elif type_id == 0x87: + # Colour: R(1) + G(1) + B(1) + sensors.append( + { + "channel": channel, + "type": name, + "type_id": type_id, + "value": {"r": raw[0], "g": raw[1], "b": raw[2]}, + "raw_value": raw.hex(), + } + ) + elif type_id == 0xF0: + # Polyline: variable size; we only consume minimum 8 bytes (MeshCore skipData). 
+ sensors.append( + { + "channel": channel, + "type": name, + "type_id": type_id, + "value": raw.hex(), + "raw_value": raw.hex(), + } + ) + else: + val = int.from_bytes(raw, "big", signed=signed) + sensors.append( + { + "channel": channel, + "type": name, + "type_id": type_id, + "value": val / divisor if divisor != 1 else val, + "raw_value": raw.hex(), + } + ) + return sensors class ProtocolResponseHandler: @@ -28,6 +156,41 @@ def __init__(self, log_fn: Callable[[str], None], local_identity, contact_book): # Callbacks for protocol responses self._response_callbacks: Dict[int, Callable[[bool, str, Dict[str, Any]], None]] = {} + # Optional: decrypted payloads with tag+data (and optional path) passed as binary response. + # Signature: (tag_bytes, response_data, path_info=None). + self._binary_response_callback: Optional[Callable[..., Any]] = None + # Reference to LoginResponseHandler for state-based login detection + self._login_response_handler: Optional[Any] = None + # Packet injector for sending reciprocal PATH packets (mirrors C++ Mesh.cpp:168-169) + self._packet_injector: Optional[Callable] = None + # Optional: notify when contact out_path is updated from decrypted PATH + # (e.g. companion persist). + self._contact_path_updated_callback: Optional[Callable[..., Any]] = None + + def set_contact_path_updated_callback(self, callback: Optional[Callable[..., Any]]) -> None: + """Set callback when contact out_path is updated from a decrypted PATH packet. + + Signature: (contact_pubkey: bytes, path_len: int, path_bytes: bytes) + -> None | Awaitable[None]. + Called after _update_contact_path when the contact was found and updated. 
+ """ + self._contact_path_updated_callback = callback + + def set_login_response_handler(self, handler: Any) -> None: + """Set login handler ref for checking active login state.""" + self._login_response_handler = handler + + def set_packet_injector(self, injector: Optional[Callable]) -> None: + """Set packet injector for sending reciprocal PATH packets. + + When the companion receives a flooded PATH from a remote repeater, + the C++ firmware sends a reciprocal PATH back so the remote repeater + learns the route to us (Mesh.cpp:168-169). Without this, the remote + repeater has no out_path for us and must fall back to plain FLOOD for + responses — which intermediate repeaters may drop due to transport-code + region filtering. + """ + self._packet_injector = injector @staticmethod def payload_type() -> int: @@ -43,99 +206,425 @@ def clear_response_callback(self, contact_hash: int) -> None: """Clear callback for protocol responses from a specific contact.""" self._response_callbacks.pop(contact_hash, None) + def set_binary_response_callback(self, callback: Callable[..., Any]) -> None: + """Set callback for binary responses. (tag_bytes, response_data, path_info=None). 
+ path_info = (out_path, in_path, contact_pubkey) for path-return format.""" + self._binary_response_callback = callback + async def __call__(self, pkt: Packet) -> None: - """Handle incoming PATH packet that might be a protocol response.""" + """Handle incoming PATH or RESPONSE packet that might be a protocol response.""" try: # Check if this looks like an encrypted protocol response if len(pkt.payload) < 4: return # Too short for protocol response - # PATH packet structure: + # Both PATH and RESPONSE packets share the same structure: # dest_hash(1) + src_hash(1) + encrypted_data(N) src_hash = pkt.payload[1] + pkt_type = (pkt.header >> 2) & 0x0F + route_label = "FLOOD" if pkt.is_route_flood() else "DIRECT" + if pkt_type == PAYLOAD_TYPE_RESPONSE: + self._log( + f"[ProtocolResponse] Received RESPONSE (0x01) from 0x{src_hash:02X} " + f"({route_label}, {len(pkt.payload)}B)" + ) - # Check if we have a callback waiting for this source - if src_hash not in self._response_callbacks: - return # Not waiting for response from this source - - self._log( - "[ProtocolResponse] Processing potential protocol response " - f"from 0x{src_hash:02X}" - ) + # Proceed if we have a callback for this source or the binary (path-discovery) callback + if src_hash not in self._response_callbacks and self._binary_response_callback is None: + return # Try to decrypt the response - success, decoded_text, parsed_data = await self._decrypt_protocol_response( - pkt, src_hash - ) + ( + success, + decoded_text, + parsed_data, + raw_decrypted, + ) = await self._decrypt_protocol_response(pkt, src_hash) + + # If an explicit response callback is waiting for this source (e.g. telemetry, + # stats, repeater command), deliver there first. The binary/path-discovery + # callback is a generic fallback for unsolicited binary responses. + # + # Guard: only skip when this is a login response (13 bytes, response_code at [4] + # 0x00/0x01). 
A broad "decrypted_len < 20" would drop valid PATH-wrapped stats + # or other short responses and delay stats load after login. + if src_hash in self._response_callbacks: + if not success: + return + if self._is_login_response(pkt, raw_decrypted): + # Login responses are handled by LoginResponseHandler; do not deliver to + # stats/telemetry waiter. + return + callback = self._response_callbacks[src_hash] + if callback: + if parsed_data.get("type") == "telemetry": + self._log( + f"[ProtocolResponse] Delivering telemetry to waiter " + f"(src=0x{src_hash:02X}, {parsed_data.get('sensor_count', 0)} sensors)" + ) + callback(success, decoded_text, parsed_data) + return + + # If binary response callback set, parse and invoke (tag+data or path-return) + if ( + success + and self._binary_response_callback is not None + and raw_decrypted is not None + and len(raw_decrypted) >= 4 + ): + path_info = None + pkt_type = (pkt.header >> 2) & 0x0F + + if pkt_type == PAYLOAD_TYPE_PATH: + # PATH packet: decrypted is path_len(1)+path(N)+extra_type(1)+extra + # Extract inner response from path-return structure + path_len_byte = raw_decrypted[0] + path_byte_len = PathUtils.get_path_byte_len(path_len_byte) + inner_offset = 1 + path_byte_len + 1 + if ( + PathUtils.is_valid_path_len(path_len_byte) + and len(raw_decrypted) >= inner_offset + 4 + ): + out_path = bytes(raw_decrypted[1 : 1 + path_byte_len]) + extra_type = raw_decrypted[1 + path_byte_len] & 0x0F + extra = raw_decrypted[inner_offset:] + if extra_type == PAYLOAD_TYPE_RESPONSE and len(extra) >= 4: + tag_bytes = extra[:4] + response_data = extra[4:] + in_path = bytes(pkt.path) if pkt.path else b"" + contact = self._find_contact_by_hash(src_hash) + if contact: + contact_pubkey = bytes.fromhex(contact.public_key) + path_info = (out_path, in_path, contact_pubkey) + else: + tag_bytes = raw_decrypted[:4] + response_data = raw_decrypted[4:] + else: + tag_bytes = raw_decrypted[:4] + response_data = raw_decrypted[4:] + else: + # RESPONSE 
packet: decrypted is tag(4)+data directly + tag_bytes = raw_decrypted[:4] + response_data = raw_decrypted[4:] - # Call the waiting callback - callback = self._response_callbacks[src_hash] - if callback: - callback(success, decoded_text, parsed_data) + # Do not deliver login responses to the binary callback; they are + # handled by LoginResponseHandler. Login response format is + # tag(4) + response_code(1) + keep_alive(1) + is_admin(1) + ... + # = 13 bytes total, with response_code 0x00 or 0x01. + if len(response_data) == 9 and response_data[0] in (0x00, 0x01): + return + + try: + cb_result = self._binary_response_callback(tag_bytes, response_data, path_info) + if asyncio.iscoroutine(cb_result): + await cb_result + except Exception as e: + self._log(f"[ProtocolResponse] Binary response callback error: {e}") + return except Exception as e: self._log(f"[ProtocolResponse] Error processing protocol response: {e}") - async def _decrypt_protocol_response( - self, pkt: Packet, src_hash: int - ) -> tuple[bool, str, Dict[str, Any]]: - """Decrypt and parse a protocol response packet.""" + def _is_login_response(self, pkt: Packet, raw_decrypted: Optional[bytes]) -> bool: + """True if a login is currently pending for the source contact. + + Mirrors the C++ companion firmware pattern: classify responses by + pending-request state rather than payload content. The previous + content-based check (``inner[4] in (0x00, 0x01)``) falsely matched + CayenneLPP telemetry whose first byte is channel 0x01. 
+ """ + if not self._login_response_handler: + return False + passwords = getattr(self._login_response_handler, "_active_login_passwords", {}) + if not passwords: + return False + if len(pkt.payload) < 2: + return False + src_hash = pkt.payload[1] + return src_hash in passwords + + def _update_contact_path( + self, + contact_pubkey: bytes, + src_hash: int, + path_len_byte: int, + decrypted: bytes, + ) -> bool: + """Update contact out_path from decrypted PATH data (firmware onContactPathRecv pattern). + + When a PATH packet is successfully decrypted, store the return path + on the contact so that subsequent requests use sendDirect() instead + of sendFlood(). This mirrors C++ ``BaseChatMesh::onContactPathRecv``. + + Returns True if the contact was found and updated, False otherwise. + """ try: - # Find the contact by hash - contact = self._find_contact_by_hash(src_hash) - if not contact: - return False, f"Unknown contact for hash 0x{src_hash:02X}", {} + if not PathUtils.is_valid_path_len(path_len_byte): + return False + path_byte_len = PathUtils.get_path_byte_len(path_len_byte) + out_path_bytes = bytes(decrypted[1 : 1 + path_byte_len]) + contact_obj = self._contact_book.get_by_key(contact_pubkey) + if contact_obj is not None: + contact_obj.out_path_len = path_len_byte + contact_obj.out_path = out_path_bytes + self._contact_book.update(contact_obj) + self._log( + f"[ProtocolResponse] Updated out_path for 0x{src_hash:02X}: " + f"path_len={path_len_byte}" + ) + return True + else: + self._log( + f"[ProtocolResponse] Cannot update out_path for 0x{src_hash:02X}: " + f"contact not found by key" + ) + return False + except Exception as e: + self._log(f"[ProtocolResponse] Failed to update out_path: {e}") + return False + + async def _send_reciprocal_path( + self, + src_hash: int, + shared_secret: bytes, + pkt: Packet, + decrypted: bytes, + path_len_byte: int, + ) -> None: + """Send a reciprocal PATH back to the sender so it learns the route to us. 
- # Get encryption keys - contact_pubkey = bytes.fromhex(contact.public_key) - peer_id = Identity(contact_pubkey) - shared_secret = peer_id.calc_shared_secret(self._local_identity.get_private_key()) - aes_key = shared_secret[:16] + Mirrors C++ firmware behaviour (Mesh.cpp lines 166-169): - # Extract encrypted data (skip dest_hash(1) + src_hash(1)) - encrypted_data = pkt.payload[2:] + mesh::Packet* rpath = createPathReturn( + &src_hash, secret, pkt->path, pkt->path_len, 0, NULL, 0); + if (rpath) sendDirect(rpath, path, path_len, 500); - # Decrypt the payload - decrypted = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, encrypted_data) + - ``pkt.path`` is the flood accumulation path on the received PATH + (the inbound route, e.g. [hash_X, hash_B]). This is placed inside + the reciprocal's encrypted payload so the remote repeater stores it + as *its* ``out_path`` — the route from itself back to us. + - The reciprocal is sent **DIRECT** using the inner ``out_path`` + extracted from the decrypted data (e.g. [hash_B, hash_X]), which + routes through the mesh to reach the remote repeater. + """ + if self._packet_injector is None: + return + try: + our_hash = self._local_identity.get_public_key()[0] + # The inbound flood path (pkt.path) tells the remote repeater + # "to reach me, go through these intermediate hops". + in_path = list(pkt.path) if pkt.path else [] + + # Build the reciprocal PATH packet. create_path_return produces a + # FLOOD PATH by default; we convert it to DIRECT below. + reciprocal = PacketBuilder.create_path_return( + dest_hash=src_hash, + src_hash=our_hash, + secret=shared_secret, + path=in_path, + extra_type=0xFF, # no extra payload (dummy, same as C++ NULL/0) + extra=b"", + ) - self._log(f"[ProtocolResponse] Successfully decrypted {len(decrypted)} bytes") + # Convert to DIRECT routing using the inner out_path (the route + # from us to the remote repeater). 
+ path_byte_len = PathUtils.get_path_byte_len(path_len_byte) + out_path_bytes = bytes(decrypted[1 : 1 + path_byte_len]) + reciprocal.header = (reciprocal.header & ~0x03) | ROUTE_TYPE_DIRECT + reciprocal.set_path(out_path_bytes, path_len_byte) - # Parse based on content type - return self._parse_protocol_response(decrypted) + # Await injection so the reciprocal PATH is serialized through the + # radio TX pipeline before this method returns. This ensures the + # login callback doesn't fire until the reciprocal PATH is in flight, + # preventing the app's first stats REQ from racing ahead of it. + await self._packet_injector(reciprocal) + self._log( + f"[ProtocolResponse] Sending reciprocal PATH to 0x{src_hash:02X} " + f"via DIRECT (out_path_len={path_len_byte}, in_path_len={len(in_path)})" + ) except Exception as e: - self._log(f"[ProtocolResponse] Decryption failed: {e}") - return False, f"Decryption failed: {e}", {} + self._log(f"[ProtocolResponse] Failed to send reciprocal PATH: {e}") + + async def _decrypt_protocol_response( + self, pkt: Packet, src_hash: int + ) -> tuple[bool, str, Dict[str, Any], Optional[bytes]]: + """Decrypt and parse protocol response. Returns (success, text, parsed_data, raw_decrypted). + + Handles both packet types: + - RESPONSE (0x01): direct → tag(4)+data + - PATH (0x08): path_len+path(N)+extra_type+extra + + Both use same wire payload layout: dest_hash(1) + src_hash(1) + MAC(2) + ciphertext. + """ + payload = pkt.get_payload() + if len(payload) < 2 + 4: # need dest+src + at least MAC(2)+min ciphertext + return False, "Payload too short", {}, None + encrypted_data = payload[2:] + # MAC(2) + ciphertext. Ciphertext may be block-aligned or truncated (e.g. long PATH + # packets lose one byte to header size; telemetry PATH 63 bytes). Allow MAC + 15 bytes + # minimum so we can pad to one block and attempt decrypt. 
+ enc_len = len(encrypted_data) + min_enc = CIPHER_MAC_SIZE + (CIPHER_BLOCK_SIZE - 1) # 17: MAC(2) + 15 ciphertext + if enc_len < min_enc: + self._log( + f"[ProtocolResponse] Payload too short for hash 0x{src_hash:02X}: " + f"encrypted_data={enc_len}B (need MAC(2)+≥15 bytes ciphertext)" + ) + return False, "Payload too short", {}, None + pkt_type = (pkt.header >> 2) & 0x0F + + # Try every contact matching src_hash (same “try all hash matches” as TXT_MSG and PATH ACK). + # Repeaters use the same ECDH shared secret as login (createPathReturn(..., secret, ...)). + # Firmware: ed25519_key_exchange uses first 32B of priv (clamped) and (y+1)/(1-y) for peer + # pub; we match via libsodium ed25519_pk_to_curve25519 + scalarmult. + contacts_tried = list(self._contacts_by_hash(src_hash)) + for contact in contacts_tried: + try: + pk = contact.public_key + contact_pubkey = pk if isinstance(pk, bytes) else bytes.fromhex(pk) + if len(contact_pubkey) != 32: + continue + peer_id = Identity(contact_pubkey) + shared_secret = peer_id.calc_shared_secret(self._local_identity.get_private_key()) + aes_key = shared_secret[:16] + decrypted = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, encrypted_data) + except Exception: + continue + + # Determine the actual response data based on packet type. 
+ response_data = decrypted + if pkt_type == PAYLOAD_TYPE_PATH: + if len(decrypted) >= 2: + path_len_byte = decrypted[0] + path_byte_len = PathUtils.get_path_byte_len(path_len_byte) + inner_offset = 1 + path_byte_len + 1 + if ( + PathUtils.is_valid_path_len(path_len_byte) + and len(decrypted) >= inner_offset + ): + extra_type = decrypted[1 + path_byte_len] & 0x0F + if extra_type == PAYLOAD_TYPE_RESPONSE and len(decrypted) > inner_offset: + response_data = decrypted[inner_offset:] + elif extra_type != PAYLOAD_TYPE_RESPONSE: + self._log( + f"[ProtocolResponse] PATH format: extra_type=0x{extra_type:02X}, " + f"not RESPONSE" + ) + + # Firmware pattern (onContactPathRecv): update contact out_path + # so subsequent requests use sendDirect() instead of sendFlood(). + out_path_bytes = bytes(decrypted[1 : 1 + path_byte_len]) + if self._update_contact_path( + contact_pubkey, src_hash, path_len_byte, decrypted + ): + if self._contact_path_updated_callback is not None: + cb_result = self._contact_path_updated_callback( + contact_pubkey, path_len_byte, out_path_bytes + ) + if asyncio.iscoroutine(cb_result): + await cb_result + + # Firmware pattern (Mesh.cpp:168-169): send reciprocal PATH back + # to the sender so it learns the route to us. Without this, the + # remote repeater has no out_path for us and must fall back to + # plain FLOOD for responses — which intermediate repeaters may + # drop due to transport-code region filtering. 
+ if pkt.is_route_flood(): + await self._send_reciprocal_path( + src_hash, + shared_secret, + pkt, + decrypted, + path_len_byte, + ) + + success, text, parsed = self._parse_protocol_response(response_data) + return success, text, parsed, decrypted + + # Log once per packet: no contact or HMAC failed for every matching contact + if not contacts_tried: + self._log( + f"[ProtocolResponse] No contact for hash 0x{src_hash:02X}, " + "cannot decrypt PATH/RESPONSE" + ) + else: + self._log( + f"[ProtocolResponse] HMAC failed for hash 0x{src_hash:02X} " + f"(tried {len(contacts_tried)} contact(s). Repeater PATH uses same ECDH as login)" + ) + return False, "Decryption failed: Invalid HMAC", {}, None def _parse_protocol_response(self, data: bytes) -> tuple[bool, str, Dict[str, Any]]: - """Parse decrypted protocol response data.""" + """Parse decrypted protocol response data. + + Parse order: + 0. Login response (13 bytes, response_code at [4] 0x00/0x01) → binary, + for LoginResponseHandler. + 1. Telemetry (reflected_timestamp + valid CayenneLPP signature byte check) + 2. Stats (RepeaterStats struct, ≥52 bytes, only when not telemetry) + 3. Text / status (UTF-8 printable after stripping tag + nulls) + 4. Binary fallback + + Telemetry is checked first because CayenneLPP data can be ≥56 bytes for + sensors with many readings, which would otherwise be misidentified as stats. + The telemetry signature check (channel=1, type=0x74) is cheap and reliable. + """ try: - # Check if this looks like a stats response (protocol 0x01) - if len(data) >= 48: - # Try parsing as RepeaterStats struct - stats_result = self._parse_stats_response(data) - if stats_result: - return True, stats_result["formatted"], stats_result["raw"] + # 0. Login responses are 13 bytes (tag(4) + response_code(1) + keep_alive(1) + ...). + # Do not parse as telemetry/stats; LoginResponseHandler will handle them. 
+ if len(data) == 13 and data[4] in (0x00, 0x01): + return ( + True, + "Binary response: " + data.hex(), + {"type": "binary", "hex": data.hex()}, + ) - # Check if this looks like a telemetry response (protocol 0x03) - if len(data) >= 4: # At minimum need some telemetry data + # 1. Check if this looks like a telemetry response (protocol 0x03). + # MeshCore always starts telemetry with addVoltage(TELEM_CHANNEL_SELF=1, ...) + # which produces LPP channel=0x01, type=0x74 (LPP_VOLTAGE) as first record. + # This signature reliably distinguishes telemetry from stats/text responses. + if len(data) >= 8: # tag(4) + at least one LPP record (ch+type+val = 3+) telemetry_result = self._parse_telemetry_response(data) - if telemetry_result: + if telemetry_result and telemetry_result.get("sensor_count", 0) > 0: return True, telemetry_result["formatted"], telemetry_result - # Try parsing as text response - try: - text_response = data.rstrip(b"\x00").decode("utf-8") - if text_response.strip(): - return ( - True, - text_response, - {"type": "text", "content": text_response}, + # 2. Check if this looks like a stats response (protocol 0x01). + # RepeaterStats is 48-56 bytes + 4-byte tag. Older firmware + # omits n_recv_errors (52 B struct → 56 total); PATH-wrapped + # responses may also lose trailing bytes to AES block alignment. + # Only reached if telemetry signature check above failed. + if len(data) >= 56: + stats_result = self._parse_stats_response(data) + if stats_result: + # Include raw_bytes in the parsed dict so callers can + # forward the binary RepeaterStats to companion apps. 
+ result_dict = stats_result["raw"] + result_dict["type"] = "stats" + result_dict["raw_bytes"] = stats_result["raw_bytes"] + self._log( + f"[ProtocolResponse] STATS: batt={result_dict['batt_milli_volts']}mV, " + f"rssi={result_dict['last_rssi']}, snr={result_dict['last_snr']}, " + f"raw={len(result_dict['raw_bytes'])}B" ) - except UnicodeDecodeError: - pass + return True, stats_result["formatted"], result_dict - # Fall back to hex representation + # 3. Try parsing as text/status response. + # Status responses are tag(4) + UTF-8 text. Strip the 4-byte + # tag that prefixes every response, then check for printable text. + if len(data) > 4: + try: + text_candidate = data[4:].rstrip(b"\x00").decode("utf-8") + if text_candidate.strip() and text_candidate.strip().isprintable(): + return ( + True, + text_candidate.strip(), + {"type": "text", "content": text_candidate.strip()}, + ) + except UnicodeDecodeError: + pass + + # 4. Fall back to hex representation hex_response = data.hex() return ( True, @@ -147,37 +636,108 @@ def _parse_protocol_response(self, data: bytes) -> tuple[bool, str, Dict[str, An return False, f"Parse error: {e}", {} def _parse_stats_response(self, data: bytes) -> Optional[Dict[str, Any]]: - """Parse RepeaterStats struct response (protocol 0x01).""" + """Parse RepeaterStats struct response (protocol 0x01). 
+ + RepeaterStats layout (from simple_repeater/MyMesh.h): + uint16_t batt_milli_volts; // offset 0 + uint16_t curr_tx_queue_len; // offset 2 + int16_t noise_floor; // offset 4 + int16_t last_rssi; // offset 6 + uint32_t n_packets_recv; // offset 8 + uint32_t n_packets_sent; // offset 12 + uint32_t total_air_time_secs; // offset 16 + uint32_t total_up_time_secs; // offset 20 + uint32_t n_sent_flood; // offset 24 + uint32_t n_sent_direct; // offset 28 + uint32_t n_recv_flood; // offset 32 + uint32_t n_recv_direct; // offset 36 + uint16_t err_events; // offset 40 + int16_t last_snr; // ×4 // offset 42 + uint16_t n_direct_dups; // offset 44 + uint16_t n_flood_dups; // offset 46 + uint32_t total_rx_air_time_secs; // offset 48 + uint32_t n_recv_errors; // offset 52 + Total: 56 bytes + """ try: - # Skip 4-byte header as per C++ code: memcpy(&reply_data[4], &stats, sizeof(stats)) - if len(data) < 52: # 4 header + 48 struct = 52 minimum + # Skip 4-byte reflected timestamp/tag + # memcpy(&reply_data[4], &stats, sizeof(stats)) + if len(data) < 56: # 4 tag + 52 struct minimum (without n_recv_errors) return None - stats_data = data[4:] # Skip the 4-byte header - - # Parse as all 16-bit values - this gives correct results - parsed = struct.unpack("<24H", stats_data[:48]) + stats_data = data[4:] # Skip the 4-byte tag + + # Pad to 56 bytes so struct.unpack always succeeds. Older firmware + # or PATH-wrapped responses with AES block alignment may yield fewer + # than 56 bytes; missing trailing fields default to zero. 
+ if len(stats_data) < 56: + stats_data = stats_data + b"\x00" * (56 - len(stats_data)) + + # Parse with correct field types matching C++ struct + ( + batt_milli_volts, # uint16 offset 0 + curr_tx_queue_len, # uint16 offset 2 + noise_floor, # int16 offset 4 + last_rssi, # int16 offset 6 + n_packets_recv, # uint32 offset 8 + n_packets_sent, # uint32 offset 12 + total_air_time_secs, # uint32 offset 16 + total_up_time_secs, # uint32 offset 20 + n_sent_flood, # uint32 offset 24 + n_sent_direct, # uint32 offset 28 + n_recv_flood, # uint32 offset 32 + n_recv_direct, # uint32 offset 36 + err_events, # uint16 offset 40 + last_snr_raw, # int16 offset 42 + n_direct_dups, # uint16 offset 44 + n_flood_dups, # uint16 offset 46 + total_rx_air_time_secs, # uint32 offset 48 + n_recv_errors, # uint32 offset 52 + ) = struct.unpack(" 10000: # > 10V is unreasonable + return None + if last_rssi < -200 or last_rssi > 0: # RSSI always negative, > -200 dBm + return None - # Map to meaningful field names based on observed values raw_stats = { - "batt_milli_volts": parsed[1], # Battery voltage in mV - "curr_tx_queue_len": parsed[2], # Current TX queue length - "last_rssi": self._convert_signed_16bit(parsed[3]), # Last RSSI in dBm - "n_packets_recv": parsed[5], # Total packets received - "n_packets_sent": parsed[7], # Total packets sent - "n_recv_flood": parsed[9], # Flood packets received - "total_up_time_secs": parsed[11], # Uptime in seconds - "total_air_time_secs": parsed[13], # Air time in seconds - "err_events": parsed[17], # Error events count - "last_snr": snr_register_to_db(parsed[19], bits=16), - "n_flood_dups": parsed[22], # Flood duplicate packets - "n_direct_dups": parsed[23], # Direct duplicate packets + "batt_milli_volts": batt_milli_volts, + "curr_tx_queue_len": curr_tx_queue_len, + "noise_floor": noise_floor, + "last_rssi": last_rssi, + "n_packets_recv": n_packets_recv, + "n_packets_sent": n_packets_sent, + "total_air_time_secs": total_air_time_secs, + "total_up_time_secs": 
total_up_time_secs, + "n_sent_flood": n_sent_flood, + "n_sent_direct": n_sent_direct, + "n_recv_flood": n_recv_flood, + "n_recv_direct": n_recv_direct, + "err_events": err_events, + "last_snr": last_snr_raw / 4.0, # firmware stores SNR × 4 + "n_direct_dups": n_direct_dups, + "n_flood_dups": n_flood_dups, + "total_rx_air_time_secs": total_rx_air_time_secs, + "n_recv_errors": n_recv_errors, } # Format as human-readable string formatted = self._format_stats(raw_stats) - return {"raw": raw_stats, "formatted": formatted, "type": "stats"} + # Include raw bytes after the 4-byte tag so callers can forward + # the binary RepeaterStats struct to companion apps verbatim. + # Pad to 56 bytes if shorter (companion app expects full struct). + raw_bytes_after_tag = bytes(stats_data[:56]) + + return { + "raw": raw_stats, + "formatted": formatted, + "type": "stats", + "raw_bytes": raw_bytes_after_tag, + } except Exception as e: self._log(f"[ProtocolResponse] Stats parsing failed: {e}") @@ -189,123 +749,53 @@ def _parse_telemetry_response(self, data: bytes) -> Optional[Dict[str, Any]]: Expected format: - reflected_timestamp (4 bytes, little-endian) - CayenneLPP data (remaining bytes) + + Returns None if no valid CayenneLPP sensors can be decoded, allowing + the caller to fall back to other response types. 
""" try: - if len(data) < 4: - self._log( - "[ProtocolResponse] Telemetry data too short: " - f"{len(data)} bytes (need at least 4 for timestamp)" - ) + if len(data) < 8: + # Need at least tag(4) + one minimal LPP record (ch+type+val = 3) return None - self._log( - "[ProtocolResponse] Parsing " f"{len(data)} bytes telemetry data: {data.hex()}" - ) - - # Parse according to MeshCore TelemetryResponseData structure - # First 4 bytes: reflected timestamp (little-endian) + # First 4 bytes: reflected timestamp / tag (little-endian) reflected_timestamp = struct.unpack("= 0 and lpp_data[last_nonzero] == 0: - last_nonzero -= 1 - - if last_nonzero < len(lpp_data) - 1: - lpp_data = lpp_data[: last_nonzero + 1] - - # Try using the cayenne_lpp_helpers function (now without trailing zeros) if available - try: - try: - from utils.cayenne_lpp_helpers import decode_cayenne_lpp_payload - - helper_result = decode_cayenne_lpp_payload(lpp_data.hex()) - except ImportError: - # Utils not available in lightweight mode - helper_result = {"error": "cayenne_lpp_helpers not available"} - - if "error" not in helper_result and helper_result.get("sensor_count", 0) > 0: - self._log( - "[ProtocolResponse] CayenneLPP parsing succeeded: " - f"{helper_result['sensor_count']} sensors" - ) + if len(lpp_data) < 3: + # Not enough for even one LPP record (channel + type + 1-byte value) + return None - # Convert to our expected format - converted_sensors = [] - for sensor in helper_result["sensors"]: - converted_sensor = { - "channel": sensor["channel"], - "type": sensor["type"], - "type_id": sensor["type_id"], - "value": sensor["value"], - "raw_value": sensor["raw_value"], - } - converted_sensors.append(converted_sensor) - - return { - "type": "telemetry", - "formatted": ( - f"Telemetry ({len(converted_sensors)} sensors, " - f"ts:{reflected_timestamp})" - ), - "reflected_timestamp": reflected_timestamp, - "sensor_count": len(converted_sensors), - "sensors": converted_sensors, - } - else: - self._log( - 
"[ProtocolResponse] CayenneLPP parsing failed: " - f"{helper_result.get('error', 'no sensors found')}" - ) + # Sanity check: MeshCore telemetry always starts with + # addVoltage(TELEM_CHANNEL_SELF=1, battery_volts) which produces + # channel=1, type=0x74 (LPP_VOLTAGE). Require this signature to + # distinguish telemetry from other response types that happen to + # decrypt to >= 8 bytes. + if lpp_data[0] != 0x01 or lpp_data[1] != 0x74: + return None - except Exception as e: - self._log(f"[ProtocolResponse] CayenneLPP parsing exception: {e}") + sensors = _decode_cayenne_lpp(lpp_data) + if not sensors: + return None - # All parsing methods failed - self._log("[ProtocolResponse] CayenneLPP parsing failed") + self._log( + f"[ProtocolResponse] CayenneLPP decoded {len(sensors)} sensor(s) " + f"from {len(lpp_data)} bytes: {lpp_data.hex()}" + ) return { "type": "telemetry", - "formatted": ( - f"Unknown telemetry LPP data ({len(lpp_data)} bytes, " - f"ts:{reflected_timestamp})" - ), + "formatted": (f"Telemetry ({len(sensors)} sensors, " f"ts:{reflected_timestamp})"), "reflected_timestamp": reflected_timestamp, - "sensor_count": 0, - "sensors": [], + "sensor_count": len(sensors), + "sensors": sensors, + "raw_bytes": bytes(data[4:]), # LPP data after tag for verbatim forwarding } except Exception as e: self._log(f"[ProtocolResponse] Telemetry parsing failed: {e}") - return { - "type": "telemetry", - "length": len(data), - "hex": data.hex(), - "format": "error", - "formatted": f"Telemetry parsing error: {e}", - "error": str(e), - } - - def _convert_signed_16bit(self, value: int) -> int: - """Convert unsigned 16-bit to signed if needed.""" - return value - 65536 if value > 32767 else value + return None def _format_stats(self, stats: Dict[str, Any]) -> str: """Format stats as human-readable string.""" @@ -321,11 +811,17 @@ def _format_stats(self, stats: Dict[str, Any]) -> str: # Signal quality result.append(f"RSSI: {stats['last_rssi']}dBm") result.append(f"SNR: 
{stats['last_snr']:.1f}dB") + result.append(f"NF: {stats['noise_floor']}dB") # Packet counts - result.append(f"RX: {stats['n_packets_recv']}") - result.append(f"TX: {stats['n_packets_sent']}") - result.append(f"Flood RX: {stats['n_recv_flood']}") + result.append( + f"TX: {stats['n_packets_sent']} " + f"(F:{stats['n_sent_flood']}/D:{stats['n_sent_direct']})" + ) + result.append( + f"RX: {stats['n_packets_recv']} " + f"(F:{stats['n_recv_flood']}/D:{stats['n_recv_direct']})" + ) # Uptime formatting uptime = stats["total_up_time_secs"] @@ -341,30 +837,38 @@ def _format_stats(self, stats: Dict[str, Any]) -> str: result.append(f"Up: {days}d{hours}h") # Air time - result.append(f"Air: {stats['total_air_time_secs']}s") + result.append(f"TxAir: {stats['total_air_time_secs']}s") + if stats.get("total_rx_air_time_secs"): + result.append(f"RxAir: {stats['total_rx_air_time_secs']}s") # Error events (only if > 0) if stats["err_events"] > 0: result.append(f"Err: {stats['err_events']}") + # RX errors (only if > 0) + if stats.get("n_recv_errors", 0) > 0: + result.append(f"RxErr: {stats['n_recv_errors']}") + # Duplicates (only if > 0) if stats["n_direct_dups"] > 0 or stats["n_flood_dups"] > 0: - result.append(f"Dups: {stats['n_direct_dups']}/{stats['n_flood_dups']}") + result.append(f"Dups: D:{stats['n_direct_dups']}/F:{stats['n_flood_dups']}") return " | ".join(result) def _find_contact_by_hash(self, contact_hash: int): - """Find contact by hash value.""" - if not self._contact_book: - return None + """Find first contact by hash value.""" + for contact in self._contacts_by_hash(contact_hash): + return contact + return None - # Search through contacts to find one with matching hash + def _contacts_by_hash(self, contact_hash: int): + """Yield all contacts whose public_key first byte matches contact_hash.""" + if not self._contact_book: + return for contact in self._contact_book.list_contacts(): try: contact_pubkey = bytes.fromhex(contact.public_key) if contact_pubkey[0] == 
contact_hash: - return contact + yield contact except (ValueError, IndexError): continue - - return None diff --git a/src/pymc_core/node/handlers/registry.py b/src/pymc_core/node/handlers/registry.py new file mode 100644 index 0000000..6e57613 --- /dev/null +++ b/src/pymc_core/node/handlers/registry.py @@ -0,0 +1,105 @@ +"""Handler registry for creating and wiring standard MeshCore protocol handlers. + +Both the :class:`Dispatcher` and :class:`CompanionBridge` need the same core +set of handlers — this module provides a shared factory so handler creation +and cross-wiring only lives in one place. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Optional + +from .advert import AdvertHandler +from .group_text import GroupTextHandler +from .login_response import LoginResponseHandler +from .path import PathHandler +from .protocol_response import ProtocolResponseHandler +from .text import TextMessageHandler + + +@dataclass +class CoreHandlers: + """Bundle of the core protocol handlers shared by Dispatcher and Bridge.""" + + text_handler: TextMessageHandler + advert_handler: AdvertHandler + group_text_handler: GroupTextHandler + protocol_response_handler: ProtocolResponseHandler + login_response_handler: LoginResponseHandler + path_handler: PathHandler + + +def create_core_handlers( + *, + identity: Any, + contacts: Any, + channels: Any, + event_service: Any, + send_packet_fn: Callable, + log_fn: Callable, + node_name: str, + radio_config: Optional[dict] = None, + ack_handler: Any = None, +) -> CoreHandlers: + """Create and wire the standard set of MeshCore protocol handlers. + + This is the single source of truth for handler construction. Both + :meth:`Dispatcher.register_default_handlers` and + :class:`CompanionBridge.__init__` delegate here. + + Args: + identity: The local identity for encryption/signing. + contacts: Contact storage. + channels: Channel database. 
+ event_service: Event service for broadcasting mesh events. + send_packet_fn: Async callable to send a packet (the transport). + log_fn: Logging callable (``str -> None``). + node_name: Human-readable node name. + radio_config: Optional radio configuration dict. + ack_handler: ACK handler instance (varies between Dispatcher and + Bridge). If ``None``, the :class:`PathHandler` is constructed + without ACK forwarding. + """ + protocol_response_handler = ProtocolResponseHandler(log_fn, identity, contacts) + + login_response_handler = LoginResponseHandler(identity, contacts, log_fn) + login_response_handler.set_protocol_response_handler(protocol_response_handler) + protocol_response_handler.set_login_response_handler(login_response_handler) + + path_handler = PathHandler( + log_fn, + ack_handler, + protocol_response_handler, + login_response_handler, + ) + + text_handler = TextMessageHandler( + identity, + contacts, + log_fn, + send_packet_fn, + event_service, + radio_config, + ) + + advert_handler = AdvertHandler(log_fn, event_service=event_service) + + group_text_handler = GroupTextHandler( + identity, + contacts, + log_fn, + send_packet_fn, + channels, + event_service, + node_name, + ) + + return CoreHandlers( + text_handler=text_handler, + advert_handler=advert_handler, + group_text_handler=group_text_handler, + protocol_response_handler=protocol_response_handler, + login_response_handler=login_response_handler, + path_handler=path_handler, + ) diff --git a/src/pymc_core/node/handlers/text.py b/src/pymc_core/node/handlers/text.py index 174b1df..43dc908 100644 --- a/src/pymc_core/node/handlers/text.py +++ b/src/pymc_core/node/handlers/text.py @@ -31,34 +31,56 @@ def set_command_response_callback(self, callback): """Set callback function for command responses.""" self.command_response_callback = callback + def _contact_pubkey_bytes(self, contact) -> bytes: + """Return contact's public key as 32 bytes (handles hex str or bytes).""" + pk = contact.public_key + return 
bytes.fromhex(pk) if isinstance(pk, str) else bytes(pk) + async def __call__(self, packet: Packet) -> None: if len(packet.payload) < 4: self.log("TXT_MSG payload too short to decrypt") return src_hash = packet.payload[1] - matched_contact = None + # Collect all contacts whose public key first byte matches src_hash (hash collision + # possible) + candidates = [] for contact in self.contacts.contacts: try: - if bytes.fromhex(contact.public_key)[0] == src_hash: - matched_contact = contact - break + pk = self._contact_pubkey_bytes(contact) + if len(pk) >= 1 and pk[0] == src_hash: + candidates.append(contact) except Exception as err: self.log(f"Error reading contact key: {err}") - if not matched_contact: + if not candidates: self.log(f"No contact found for src hash: {src_hash:02X}") return - peer_id = Identity(bytes.fromhex(matched_contact.public_key)) - shared_secret = peer_id.calc_shared_secret(self.local_identity.get_private_key()) - aes_key = shared_secret[:16] payload = packet.payload[2:] # Skip dest_hash and src_hash - - try: - decrypted = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, payload) - except Exception as err: - self.log(f"Decryption failed: {err}") + matched_contact = None + decrypted = None + shared_secret = None + for contact in candidates: + try: + pubkey_bytes = self._contact_pubkey_bytes(contact) + if len(pubkey_bytes) != 32: + continue + peer_id = Identity(pubkey_bytes) + ss = peer_id.calc_shared_secret(self.local_identity.get_private_key()) + aes_key = ss[:16] + decrypted = CryptoUtils.mac_then_decrypt(aes_key, ss, payload) + matched_contact = contact + shared_secret = ss + break + except Exception: + continue + + if matched_contact is None or decrypted is None: + self.log( + f"Decryption failed: Invalid HMAC for all {len(candidates)} contact(s) " + f"with src hash {src_hash:02X}" + ) return if len(decrypted) < 5: # timestamp(4) + flags(1) minimum @@ -72,7 +94,7 @@ async def __call__(self, packet: Packet) -> None: txt_type = (flags >> 2) & 
0x3F # Upper 6 bits are txt_type message_body = decrypted[5:] # Rest is the message content - pubkey = bytes.fromhex(matched_contact.public_key) + pubkey = self._contact_pubkey_bytes(matched_contact) timestamp_int = int.from_bytes(timestamp, "little") # Determine message routing type from packet header @@ -87,8 +109,8 @@ async def __call__(self, packet: Packet) -> None: # Skip ACK for TXT_TYPE_CLI_DATA (0x01) - CLI commands don't need ACKs # Following C++ pattern: only TXT_TYPE_PLAIN (0x00) gets ACKs TXT_TYPE_PLAIN = 0x00 - TXT_TYPE_CLI_DATA = 0x01 - send_ack = (txt_type == TXT_TYPE_PLAIN) + TXT_TYPE_CLI_DATA = 0x01 # noqa: F841 + send_ack = txt_type == TXT_TYPE_PLAIN if send_ack: # Create appropriate ACK response @@ -191,6 +213,7 @@ async def send_delayed_ack(): "contact_name": matched_contact.name, "contact_pubkey": matched_contact.public_key, "message_text": decoded_msg, + "txt_type": txt_type, "is_outgoing": False, "timestamp": message_timestamp, "delivery_status": "received", @@ -201,6 +224,7 @@ async def send_delayed_ack(): }, "sender_name": matched_contact.name, "is_read": False, + "packet_hash": packet.calculate_packet_hash().hex().upper(), } # Publish new message event for app to handle database storage diff --git a/src/pymc_core/node/node.py b/src/pymc_core/node/node.py index 266be07..9d124d6 100644 --- a/src/pymc_core/node/node.py +++ b/src/pymc_core/node/node.py @@ -17,14 +17,18 @@ class MeshNode: - """Represents a node in a mesh network for radio communication. + """Thin transport layer for mesh radio communication. - Manages radio communication, message routing, and protocol handling - within a mesh network. Provides high-level APIs for sending messages, - telemetry requests, and commands to other nodes and repeaters. + Owns a radio interface and a :class:`Dispatcher` that handles raw packet + I/O (TX lock, ACK management, handler dispatch). 
Application-layer + concerns — contact lookup, message building, response waiting — belong in + the companion layer (:class:`CompanionBase` and its subclasses). - The node integrates with various components like contact storage, - channel databases, and event services for comprehensive mesh functionality. + Typical usage:: + + node = MeshNode(radio, identity, config={...}) + await node.start() # blocks in dispatcher.run_forever() + node.stop() """ def __init__( @@ -40,9 +44,6 @@ def __init__( ) -> None: """Initialise a mesh network node instance. - Sets up the node's core components including radio interface, - identity management, and communication handlers. - Args: radio: Radio hardware interface for transmission/reception. local_identity: Node's cryptographic identity for secure communication. @@ -80,73 +81,41 @@ def __init__( node_name=self.node_name, radio_config=self.radio_config, ) - # Store reference to text handler for command response callbacks - self._text_handler = None - - # Helper Methods - def _find_and_call_handler_method(self, method_name: str, *args, **kwargs) -> bool: - """Find and call a method on any handler that has it. 
Returns True if called.""" - found = False - - if hasattr(self.dispatcher, "_handler_instances"): - for handler in self.dispatcher._handler_instances.values(): - if hasattr(handler, method_name): - getattr(handler, method_name)(*args, **kwargs) - found = True - else: - for attr_name in dir(self.dispatcher): - if attr_name.endswith("_handler"): - handler = getattr(self.dispatcher, attr_name, None) - if handler and hasattr(handler, method_name): - getattr(handler, method_name)(*args, **kwargs) - found = True - - return found - def _get_contact_or_raise(self, contact_name: str): - """Get contact by name or raise RuntimeError if not found.""" - contact = self.contacts.get_by_name(contact_name) if self.contacts else None - if not contact: - raise RuntimeError(f"No contact '{contact_name}'") - return contact + # ------------------------------------------------------------------------- + # Lifecycle + # ------------------------------------------------------------------------- - class _ResponseWaiter: - """Helper class for managing asynchronous response callbacks. + async def start(self) -> None: + """Start the mesh node and begin processing radio communications. - Provides a synchronisation mechanism for waiting on responses - from remote nodes with timeout support. + Enters the dispatcher's main event loop for handling incoming/outgoing + messages. This method blocks until the node is stopped. 
""" + await self.dispatcher.run_forever() - def __init__(self): - self.event = asyncio.Event() - self.data = {"success": False, "text": None, "parsed": {}} - - def callback(self, success: bool, text: str, parsed_data: Optional[dict] = None): - """Standard callback for response handlers.""" - self.data["success"] = success - self.data["text"] = text - self.data["parsed"] = parsed_data or {} - self.event.set() + def stop(self): + """Stop the mesh node and clean up associated services.""" + try: + self.logger.info("Node stopped") + except Exception as e: + self.logger.error(f"Error stopping node: {e}") - async def wait(self, timeout: float = 10.0) -> dict: - """Wait for response with timeout. Returns the response data.""" - try: - await asyncio.wait_for(self.event.wait(), timeout=timeout) - return self.data - except asyncio.TimeoutError: - return {"success": False, "text": None, "parsed": {}, "timeout": True} + # ------------------------------------------------------------------------- + # Transport + # ------------------------------------------------------------------------- - def _time_operation(self): - """Context manager for timing operations and calculating RTT.""" - import time - from contextlib import contextmanager + async def send_packet(self, pkt: Any, *, wait_for_ack: bool = False, **kwargs) -> bool: + """Send a raw packet via the dispatcher. - @contextmanager - def timer(): - start_time = time.time() - yield lambda: (time.time() - start_time) * 1000 # RTT in milliseconds + This is the single transport entry point. All message-building and + response-waiting logic lives in the companion layer. 
+ """ + return await self.dispatcher.send_packet(pkt, wait_for_ack=wait_for_ack, **kwargs) - return timer() + # ------------------------------------------------------------------------- + # Event service propagation + # ------------------------------------------------------------------------- def set_event_service(self, event_service): """Set the event service for broadcasting mesh events.""" @@ -165,837 +134,33 @@ def set_event_service(self, event_service): if handler and hasattr(handler, "event_service"): handler.event_service = event_service - async def start(self) -> None: - """Start the mesh node and begin processing radio communications. - - Initialises the radio interface and dispatcher, then enters the main - event loop for handling incoming/outgoing messages. This method blocks - until the node is stopped. - - Note: - This is an asynchronous operation that runs indefinitely until - cancelled or the node is stopped. - """ - await self.dispatcher.run_forever() - - async def send_text( - self, - contact_name: str, - message: str, - attempt: int = 1, - message_type: str = "direct", - out_path: Optional[list] = None, - ) -> dict: - """Send a text message to a specified contact. - - Transmits a text message to another node in the mesh network, - with optional routing and retry configuration. - - Args: - contact_name: Name of the target contact in the contact book. - message: Text content to send. - attempt: Message attempt number for retry logic (default: 1). - message_type: Routing type - "direct" or other supported types. - out_path: Optional list of intermediate nodes for routing. - - Returns: - Dictionary containing transmission results including success status, - signal strength metrics (SNR/RSSI), and routing information. - - Raises: - RuntimeError: If the specified contact is not found. 
- - Example: - ```python - result = await node.send_text("alice", "Hello, world!") - print(result["success"]) # True if message sent successfully - ``` - """ - from ..protocol import PacketBuilder - - contact = self._get_contact_or_raise(contact_name) - - # Create the text message packet using PacketBuilder - pkt, ack_crc = PacketBuilder.create_text_message( - contact=contact, - local_identity=self.identity, - message=message, - attempt=attempt, - message_type=message_type, - out_path=out_path, - ) - - # Log packet details with routing info - routing_info = f"type={message_type}" - if out_path: - routing_info += f", path={' -> '.join(str(hop) for hop in out_path)}" - - self.logger.debug( - f"[send_text] -> {contact_name} msg='{message}' CRC={ack_crc:08X} ({routing_info})" - ) - - # Send packet with the expected ACK CRC - success = await self.dispatcher.send_packet(pkt, wait_for_ack=True, expected_crc=ack_crc) - if not success: - self.logger.warning(f"No ACK received for CRC {ack_crc:08X}") - - # Extract signal strength information from radio - snr = getattr(pkt, "snr", None) if "pkt" in locals() else None - rssi = None - - # Get current signal strength from radio for outgoing messages - if hasattr(self.radio, "get_last_rssi"): - rssi = self.radio.get_last_rssi() - if hasattr(self.radio, "get_last_snr") and snr is None: - snr = self.radio.get_last_snr() - - return { - "success": success, - "attempt": attempt, - "message_type": message_type, - "out_path": out_path, - "snr": snr, - "rssi": rssi, - "crc": ack_crc, - } - - async def send_telemetry_request( - self, - contact_name: str, - want_base: bool = True, - want_location: bool = True, - want_environment: bool = True, - timeout: float = 10.0, - ) -> dict: - """Request telemetry data from a contact node. - - Sends a telemetry request and waits for the target node to respond - with requested sensor data including base metrics, location, and - environmental readings. 
- - Args: - contact_name: Name of the contact to query. - want_base: Include basic telemetry metrics in request. - want_location: Include GPS/location data in request. - want_environment: Include environmental sensors in request. - timeout: Maximum time to wait for response in seconds. - - Returns: - Dictionary with request results, telemetry data, and performance - metrics including round-trip time. - - Raises: - RuntimeError: If contact not found or protocol handler unavailable. - - Example: - ```python - result = await node.send_telemetry_request("sensor_node") - if result["success"]: - print(f"Temperature: {result['telemetry_data'].get('temp')}") - ``` - """ - from ..protocol import PacketBuilder - from ..protocol.constants import REQ_TYPE_GET_TELEMETRY_DATA - - contact = self._get_contact_or_raise(contact_name) - - with self._time_operation() as get_rtt: - contact_hash = bytes.fromhex(contact.public_key)[0] - - # Set up response waiting - waiter = self._ResponseWaiter() + # ------------------------------------------------------------------------- + # Backwards-compatible utilities (deprecated — prefer companion layer) + # ------------------------------------------------------------------------- - # Register callback with protocol response handler - if not hasattr(self.dispatcher, "protocol_response_handler"): - raise RuntimeError("Protocol response handler not available") - - self.dispatcher.protocol_response_handler.set_response_callback( - contact_hash, waiter.callback - ) - - try: - # Build and send telemetry request - inv = PacketBuilder._compute_inverse_perm_mask( - want_base, want_location, want_environment - ) - - pkt, _ = PacketBuilder.create_protocol_request( - contact=contact, - local_identity=self.identity, - protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, - data=bytes([inv]), - ) - - self.logger.debug( - f"[send_telemetry_request] -> {contact_name} " - f"base={want_base}, location={want_location}, " - f"environment={want_environment}" - ) - - await 
self.dispatcher.send_packet(pkt, wait_for_ack=False) - - # Wait for response - result = await waiter.wait(timeout) - rtt = get_rtt() - - if result.get("timeout"): - self.logger.warning( - f"Timeout waiting for telemetry response from {contact_name}" - ) - return { - "success": False, - "contact": contact_name, - "requested": { - "base": want_base, - "location": want_location, - "environment": want_environment, - }, - "telemetry_data": None, - "rtt_ms": round(rtt, 2), - "reason": f"Telemetry response timeout after {timeout}s", - } - - self.logger.info( - f"[send_telemetry_request] Response from {contact_name}: '{result['text']}'" - ) - - return { - "success": result.get("success", False), - "contact": contact_name, - "requested": { - "base": want_base, - "location": want_location, - "environment": want_environment, - }, - "telemetry_data": result["parsed"], - "response_text": result["text"], - "rtt_ms": round(rtt, 2), - "reason": ( - "Telemetry response received" - if result.get("success") - else "Telemetry request failed" - ), - } - - finally: - self.dispatcher.protocol_response_handler.clear_response_callback(contact_hash) - - def stop(self): - """Stop the mesh node and clean up associated services. - - Terminates radio communications and shuts down all active handlers. - This method is synchronous and should be called to gracefully - shut down the node. - """ - try: - self.logger.info("Node stopped") - except Exception as e: - self.logger.error(f"Error stopping node: {e}") - - async def send_group_text(self, group_name: str, message: str) -> dict: - """Broadcast a text message to all members of a group. - - Sends a group datagram that will be received by all nodes configured - for the specified group. Group messages are fire-and-forget with no - acknowledgements expected. - - Args: - group_name: Name of the group to broadcast to. - message: Text content to broadcast. - - Returns: - Dictionary with transmission results and signal metrics. 
- Note: Group messages don't wait for acknowledgements. - - Example: - ```python - result = await node.send_group_text("team_alpha", "Meeting at 15:00") - print(f"Broadcast to {result['group']}: {result['success']}") - ``` - """ - from ..protocol import PacketBuilder - - # Get channels from database (live query) - try: - channels_config = self.channel_db.get_channels() if self.channel_db else [] - except Exception as e: - self.logger.error(f"Failed to get channels from database: {e}") - channels_config = [] - - # Create the group text message packet using PacketBuilder - pkt = PacketBuilder.create_group_datagram( - group_name=group_name, - local_identity=self.identity, - message=message, - sender_name=self.node_name, - channels_config=channels_config, - ) - - # Log packet details (no CRC for group messages - they don't use ACKs) - self.logger.debug(f"[send_group_text] -> {group_name} msg='{message}'") - - # Send packet without waiting for ACK (group messages are unverified) - success = await self.dispatcher.send_packet(pkt, wait_for_ack=False) - - # Extract signal strength information from radio - snr = getattr(pkt, "snr", None) if "pkt" in locals() else None - rssi = None - - # Get current signal strength from radio for outgoing messages - if hasattr(self.radio, "get_last_rssi"): - rssi = self.radio.get_last_rssi() - if hasattr(self.radio, "get_last_snr") and snr is None: - snr = self.radio.get_last_snr() - - # Note: Unlike text messages, we don't publish events here - # Let the app level handle outgoing message events if needed - # This prevents duplicate events when the message is received back - - return { - "success": success, - "snr": snr, - "rssi": rssi, - "group": group_name, - } - - async def send_login(self, repeater_name: str, password: str) -> dict: - """Authenticate with a repeater node. - - Sends login credentials to a repeater and waits for authentication - response. Successful login may grant administrative privileges. 
- - Args: - repeater_name: Name of the repeater to authenticate with. - password: Authentication password for the repeater. - - Returns: - Dictionary with login results including success status, - admin privileges, and keep-alive intervals. - - Raises: - RuntimeError: If repeater contact not found. - - Example: - ```python - result = await node.send_login("repeater_01", "secret123") - if result["success"] and result["is_admin"]: - print("Admin access granted") - ``` - """ - from ..protocol import PacketBuilder - - contact = self._get_contact_or_raise(repeater_name) - - with self._time_operation() as get_rtt: - contact_pubkey = bytes.fromhex(contact.public_key) - dest_hash = contact_pubkey[0] if len(contact_pubkey) > 0 else 0 - - # Store password in login handlers - self._find_and_call_handler_method("store_login_password", dest_hash, password) - - # Create and send login packet - pkt = PacketBuilder.create_login_packet( - contact=contact, local_identity=self.identity, password=password - ) - - self.logger.debug(f"[send_login] -> {repeater_name}") - - # Set up login response waiting - login_result = {"success": False, "data": {}} - login_event = asyncio.Event() - - def login_response_callback(success: bool, response_data: dict): - login_result["success"] = success - login_result["data"] = response_data - login_event.set() - - # Set callback on login response handlers - self._find_and_call_handler_method("set_login_callback", login_response_callback) - - try: - await self.dispatcher.send_packet(pkt, wait_for_ack=False) - - # Wait for login response - try: - await asyncio.wait_for(login_event.wait(), timeout=10.0) - except asyncio.TimeoutError: - self.logger.warning(f"Login timeout for repeater '{repeater_name}'") - return { - "success": False, - "repeater": repeater_name, - "command": "login", - "rtt_ms": round(get_rtt(), 2), - "reason": "Login response timeout", - } - - rtt = get_rtt() - success = login_result["success"] - response_data = login_result["data"] - - if 
success: - self.logger.info( - f"Login successful to '{repeater_name}' " - f"(admin: {response_data.get('is_admin', False)})" - ) - reason = ( - f"Login successful - Admin: " - f"{'Yes' if response_data.get('is_admin') else 'No'}" - ) - else: - error_msg = response_data.get("error", "Login failed") - self.logger.warning(f"Login failed to '{repeater_name}': {error_msg}") - reason = f"Login failed: {error_msg}" - - return { - "success": success, - "repeater": repeater_name, - "command": "login", - "rtt_ms": round(rtt, 2), - "is_admin": response_data.get("is_admin", False), - "keep_alive_interval": response_data.get("keep_alive_interval", 0), - "reason": reason, - } - - finally: - # Clear callbacks - self._find_and_call_handler_method("set_login_callback", None) - - async def send_logout(self, repeater_name: str) -> dict: - """Terminate authentication session with a repeater. - - Sends a logout command to end the current session with a repeater. - This should be called when finished with repeater operations. - - Args: - repeater_name: Name of the repeater to logout from. - - Returns: - Dictionary with logout results and performance metrics. - - Raises: - RuntimeError: If repeater contact not found. 
- - Example: - ```python - result = await node.send_logout("repeater_01") - print(f"Logout {'successful' if result['success'] else 'failed'}") - ``` - """ - from ..protocol import PacketBuilder - - contact = self._get_contact_or_raise(repeater_name) - - with self._time_operation() as get_rtt: - # Create the logout packet using PacketBuilder - pkt, ack_crc = PacketBuilder.create_logout_packet( - contact=contact, local_identity=self.identity - ) - - self.logger.debug(f"[send_logout] -> {repeater_name} with CRC={ack_crc:08X}") - - # Send packet and wait for ACK - success = await self.dispatcher.send_packet( - pkt, wait_for_ack=True, expected_crc=ack_crc - ) - rtt = get_rtt() - - if not success: - self.logger.warning(f"No ACK received for logout CRC {ack_crc:08X}") - - return { - "success": success, - "repeater": repeater_name, - "command": "logout", - "rtt_ms": round(rtt, 2), - "crc": ack_crc, - "reason": "Logout successful" if success else "No ACK received", - } - - async def send_status_request(self, repeater_name: str) -> dict: - """Request status information from a repeater. - - Queries a repeater for its current operational status and configuration. - This is a convenience method that uses the text command interface. - - Args: - repeater_name: Name of the repeater to query. - - Returns: - Dictionary with status information and response metrics. - - Raises: - RuntimeError: If repeater contact not found. - - Example: - ```python - status = await node.send_status_request("repeater_01") - if status["success"]: - print(f"Status: {status['response']}") - ``` - """ - # Use the simple text command approach instead of protocol packets - return await self.send_repeater_command(repeater_name, "status") - - async def send_protocol_request( - self, repeater_name: str, protocol_code: int, data: bytes = b"" - ) -> dict: - """Send a protocol-specific request to a repeater. - - Transmits a custom protocol request with optional data payload - and waits for the repeater's response. 
- - Args: - repeater_name: Name of the repeater to send request to. - protocol_code: Protocol operation code (0-255). - data: Optional binary data payload for the request. - - Returns: - Dictionary with protocol response, parsed data, and timing metrics. - - Raises: - RuntimeError: If repeater contact or protocol handler not found. - - Example: - ```python - result = await node.send_protocol_request("repeater_01", 0x10, b"config") - if result["success"]: - print(f"Response: {result['response']}") - ``` - """ - from ..protocol import PacketBuilder - - contact = self._get_contact_or_raise(repeater_name) - - with self._time_operation() as get_rtt: - contact_hash = bytes.fromhex(contact.public_key)[0] - - # Set up response waiting - waiter = self._ResponseWaiter() - - if not hasattr(self.dispatcher, "protocol_response_handler"): - raise RuntimeError("Protocol response handler not available") - - self.dispatcher.protocol_response_handler.set_response_callback( - contact_hash, waiter.callback - ) - - try: - pkt, _ = PacketBuilder.create_protocol_request( - contact=contact, - local_identity=self.identity, - protocol_code=protocol_code, - data=data, - ) - - self.logger.debug( - f"[send_protocol_request] -> {repeater_name}: protocol 0x{protocol_code:02X}" - ) - - await self.dispatcher.send_packet(pkt, wait_for_ack=False) - self.logger.debug("[send_protocol_request] Packet sent, waiting for response...") - - result = await waiter.wait(10.0) - rtt = get_rtt() - - if result.get("timeout"): - self.logger.warning( - f"Timeout waiting for protocol response from {repeater_name}" - ) - return { - "success": False, - "repeater": repeater_name, - "command": f"protocol_0x{protocol_code:02X}", - "protocol_code": protocol_code, - "response": None, - "parsed_data": {}, - "rtt_ms": round(rtt, 2), - "ack_received": False, - "reason": f"Protocol 0x{protocol_code:02X} timeout", - } - - self.logger.info( - f"[send_protocol_request] Response from {repeater_name}: '{result['text']}'" - ) - - 
return { - "success": result["success"], - "repeater": repeater_name, - "command": f"protocol_0x{protocol_code:02X}", - "protocol_code": protocol_code, - "response": result["text"], - "parsed_data": result["parsed"], - "rtt_ms": round(rtt, 2), - "ack_received": False, - "reason": ( - f"Protocol 0x{protocol_code:02X} " - f"{'successful' if result['success'] else 'failed'}" - ), - } - - finally: - self.dispatcher.protocol_response_handler.clear_response_callback(contact_hash) - - async def send_trace_packet( - self, - contact_name: str, - tag: int, - auth_code: int, - flags: int = 0, - path: Optional[list] = None, - timeout: float = 5.0, - ) -> dict: - """Send a diagnostic trace packet for network analysis. - - Transmits a trace packet to analyse routing paths and network - performance. Always expects a response with trace data, signal - metrics, and routing information. - - Args: - contact_name: Name of the target contact for tracing. - tag: Unique identifier for this trace operation. - auth_code: Authentication code for the trace request. - flags: Optional flags to modify trace behaviour. - path: Optional custom routing path for the trace. - timeout: Maximum time to wait for trace response. - - Returns: - Dictionary with trace results, routing data, and signal metrics. - - Raises: - RuntimeError: If contact not found or trace handler unavailable. 
- - Example: - ```python - trace = await node.send_trace_packet("target_node", 0x12345678, 0xABCD) - if trace["success"]: - print(f"RTT: {trace['rtt_ms']}ms") - ``` - """ - from ..protocol import PacketBuilder - - contact = self._get_contact_or_raise(contact_name) - path = path or [] - - # Get target node ID from contact's public key - target_node_id = bytes.fromhex(contact.public_key)[0] - - # Use provided path or create simple direct path - trace_path = path if path else [target_node_id] - - with self._time_operation() as get_rtt: - # Create trace packet with path included - pkt = PacketBuilder.create_trace(tag, auth_code, flags, path=trace_path) - - self.logger.debug( - f"[send_trace_packet] -> {contact_name} tag=0x{tag:08X} path={trace_path}" - ) - - # Send trace packet and wait for response - try: - handler = self.dispatcher.trace_handler - contact_hash = bytes.fromhex(contact.public_key)[0] - waiter = self._ResponseWaiter() - - if handler: - handler.set_response_callback(contact_hash, waiter.callback) - - try: - await self.dispatcher.send_packet(pkt, wait_for_ack=False) - result = await waiter.wait(timeout) - rtt = get_rtt() - - if result.get("timeout"): - self.logger.warning(f"No trace response from {contact_name}") - return { - "success": False, - "contact": contact_name, - "trace_data": { - "tag": tag, - "auth_code": auth_code, - "flags": flags, - "path": trace_path, - }, - "response": None, - "rtt_ms": round(rtt, 2), - "reason": f"Timeout after {timeout}s", - } - - self.logger.info( - f"Trace response from {contact_name}: '{result.get('text', '')}'" - ) - return { - "success": result.get("success", True), - "contact": contact_name, - "trace_data": { - "tag": tag, - "auth_code": auth_code, - "flags": flags, - "path": trace_path, - }, - "response": result.get("text"), - "parsed_data": result.get("parsed", {}), - "rtt_ms": round(rtt, 2), - "reason": "Response received", - } - - finally: - handler.clear_response_callback(contact_hash) - else: - 
self.logger.error(f"No trace handler for {contact_name}") - return { - "success": False, - "contact": contact_name, - "reason": "No trace handler available", - } - - except Exception as e: - rtt = get_rtt() - self.logger.error(f"Trace error: {e}") - return { - "success": False, - "contact": contact_name, - "trace_data": { - "tag": tag, - "auth_code": auth_code, - "flags": flags, - "path": trace_path, - }, - "response": None, - "rtt_ms": round(rtt, 2), - "reason": f"Error: {str(e)}", - } - - async def send_repeater_command( - self, repeater_name: str, command: str, parameters: Optional[str] = None - ) -> dict: - """Send a text-based command to a repeater and await response. - - Transmits a command string to a repeater using the text message - protocol and waits for a response. Useful for administrative - operations and status queries. + class _ResponseWaiter: + """Synchronisation helper for async response callbacks. - Args: - repeater_name: Name of the repeater to send command to. - command: Command string to execute on the repeater. - parameters: Optional parameters for the command. - - Returns: - Dictionary with command results, response text, and timing data. - - Raises: - RuntimeError: If repeater contact not found. - - Example: - ```python - result = await node.send_repeater_command("repeater_01", "status") - if result["success"]: - print(f"Response: {result['response']}") - ``` + .. deprecated:: + Use :class:`~pymc_core.companion.models.ResponseWaiter` from the + companion layer instead. 
""" - from ..protocol import PacketBuilder - - contact = self._get_contact_or_raise(repeater_name) - - with self._time_operation() as get_rtt: - # Build full command string - full_command = command - if parameters: - full_command += f" {parameters}" - - # Set up response capture - response_event = asyncio.Event() - response_data = {"text": None, "success": False} - def response_callback(message_text: str, sender_contact): - response_data["text"] = message_text - response_data["success"] = True - response_event.set() + def __init__(self): + self.event = asyncio.Event() + self.data = {"success": False, "text": None, "parsed": {}} - # Set response callback - self.dispatcher.text_message_handler.set_command_response_callback(response_callback) + def callback(self, success: bool, text: str, parsed_data: Optional[dict] = None): + """Standard callback for response handlers.""" + self.data["success"] = success + self.data["text"] = text + self.data["parsed"] = parsed_data or {} + self.event.set() + async def wait(self, timeout: float = 10.0) -> dict: + """Wait for response with timeout. 
Returns the response data.""" try: - # Create and send packet - pkt, ack_crc = PacketBuilder.create_text_message( - contact=contact, - local_identity=self.identity, - message=full_command, - attempt=1, - message_type="command", - ) - - # Send packet and get ACK result - ack_success = await self.dispatcher.send_packet( - pkt, wait_for_ack=True, expected_crc=ack_crc - ) - - # Wait for response (regardless of ACK result) - try: - await asyncio.wait_for(response_event.wait(), timeout=15.0) - response_received = True - except asyncio.TimeoutError: - response_received = False - - # Calculate RTT - rtt = get_rtt() - response_text = response_data["text"] - - # Return result based on what we got - if response_received: - return { - "success": True, - "repeater": repeater_name, - "command": command, - "parameters": parameters, - "full_command": full_command, - "response": response_text, - "rtt_ms": round(rtt, 2), - "crc": ack_crc, - "ack_received": ack_success, - "reason": f"Command '{command}' successful with response" - + ("" if ack_success else " (no ACK)"), - } - elif ack_success: - return { - "success": True, - "repeater": repeater_name, - "command": command, - "parameters": parameters, - "full_command": full_command, - "response": "Command sent successfully (no response received)", - "rtt_ms": round(rtt, 2), - "crc": ack_crc, - "ack_received": True, - "reason": f"Command '{command}' sent but no response received", - } - else: - return { - "success": False, - "repeater": repeater_name, - "command": command, - "parameters": parameters, - "full_command": full_command, - "response": None, - "rtt_ms": round(rtt, 2), - "crc": ack_crc, - "ack_received": False, - "reason": f"No ACK or response received for command '{command}'", - } - - except Exception as e: - rtt = get_rtt() - return { - "success": False, - "repeater": repeater_name, - "command": command, - "parameters": parameters, - "full_command": full_command, - "response": None, - "rtt_ms": round(rtt, 2), - "crc": None, 
- "ack_received": False, - "reason": f"Error sending command: {e}", - } - finally: - # Always clear the callback - self.dispatcher.text_message_handler.set_command_response_callback(None) + await asyncio.wait_for(self.event.wait(), timeout=timeout) + return self.data + except asyncio.TimeoutError: + return {"success": False, "text": None, "parsed": {}, "timeout": True} diff --git a/src/pymc_core/protocol/__init__.py b/src/pymc_core/protocol/__init__.py index 75cdddf..816d4fc 100644 --- a/src/pymc_core/protocol/__init__.py +++ b/src/pymc_core/protocol/__init__.py @@ -42,6 +42,8 @@ PH_VER_MASK, PH_VER_SHIFT, PUB_KEY_SIZE, + REQ_TYPE_GET_OWNER_INFO, + REQ_TYPE_GET_STATUS, REQ_TYPE_GET_TELEMETRY_DATA, ROUTE_TYPE_DIRECT, ROUTE_TYPE_FLOOD, @@ -70,6 +72,7 @@ PacketHeaderUtils, PacketTimingUtils, PacketValidationUtils, + PathUtils, RouteTypeUtils, ) from .transport_keys import calc_transport_code, get_auto_key_for @@ -96,6 +99,7 @@ "PacketHeaderUtils", "PacketHashingUtils", "RouteTypeUtils", + "PathUtils", "PacketTimingUtils", # Header constants "PH_ROUTE_MASK", @@ -148,8 +152,10 @@ "CONTACT_TYPE_REPEATER", "CONTACT_TYPE_ROOM_SERVER", "CONTACT_TYPE_HYBRID", - # Telemetry + # Protocol request types + "REQ_TYPE_GET_STATUS", "REQ_TYPE_GET_TELEMETRY_DATA", + "REQ_TYPE_GET_OWNER_INFO", "TELEM_PERM_BASE", "TELEM_PERM_LOCATION", "TELEM_PERM_ENVIRONMENT", diff --git a/src/pymc_core/protocol/constants.py b/src/pymc_core/protocol/constants.py index d5b23d9..a3421ed 100644 --- a/src/pymc_core/protocol/constants.py +++ b/src/pymc_core/protocol/constants.py @@ -48,7 +48,9 @@ MAX_ADVERT_DATA_SIZE = 96 PUB_KEY_SIZE = 32 SIGNATURE_SIZE = 64 -PATH_HASH_SIZE = 1 +PATH_HASH_SIZE = 1 # Legacy default; see PathUtils for multi-byte path support +PATH_HASH_COUNT_MASK = 0x3F # bits 0-5 of encoded path_len (max encodable hop count) +PATH_HASH_SIZE_SHIFT = 6 # bits 6-7 of encoded path_len CIPHER_MAC_SIZE = 32 # SHA‑256 HMAC CIPHER_BLOCK_SIZE = 16 MAX_PACKET_PAYLOAD = 256 # firmware's default @@ 
-66,6 +68,7 @@ ADVERT_FLAG_IS_CHAT_NODE = 0x01 ADVERT_FLAG_IS_REPEATER = 0x02 ADVERT_FLAG_IS_ROOM_SERVER = 0x03 +ADVERT_FLAG_IS_SENSOR = 0x04 ADVERT_FLAG_HAS_LOCATION = 0x10 ADVERT_FLAG_HAS_FEATURE1 = 0x20 ADVERT_FLAG_HAS_FEATURE2 = 0x40 @@ -107,9 +110,10 @@ def describe_advert_flags(flags: int) -> str: CONTACT_TYPE_HYBRID = 4 -# Telemetry Permissions - -REQ_TYPE_GET_TELEMETRY_DATA = 0x03 +# Protocol Request Types +REQ_TYPE_GET_STATUS = 0x01 # Get repeater stats (RepeaterStats struct) +REQ_TYPE_GET_TELEMETRY_DATA = 0x03 # Get telemetry data (CayenneLPP) +REQ_TYPE_GET_OWNER_INFO = 0x07 # Variable-length: tag(4) + "version\nname\nowner" (simple_repeater) TELEM_PERM_BASE = 0x01 TELEM_PERM_LOCATION = 0x02 TELEM_PERM_ENVIRONMENT = 0x04 diff --git a/src/pymc_core/protocol/crypto.py b/src/pymc_core/protocol/crypto.py index 29a1aa4..c2bb516 100644 --- a/src/pymc_core/protocol/crypto.py +++ b/src/pymc_core/protocol/crypto.py @@ -5,6 +5,7 @@ from nacl.bindings import ( crypto_scalarmult, crypto_scalarmult_base, + crypto_scalarmult_ed25519_base_noclamp, crypto_sign_ed25519_pk_to_curve25519, crypto_sign_ed25519_sk_to_curve25519, ) @@ -36,10 +37,17 @@ def _aes_encrypt(key: bytes, data: bytes) -> bytes: @staticmethod def _aes_decrypt(key: bytes, ciphertext: bytes) -> bytes: cipher = AES.new(key, AES.MODE_ECB) - return b"".join( + orig_len = len(ciphertext) + if orig_len % CIPHER_BLOCK_SIZE != 0: + # Firmware may send non-block-aligned ciphertext (e.g. telemetry 63 bytes). + # Pad to next block for decrypt, then return only original length. 
+ pad_len = CIPHER_BLOCK_SIZE - (orig_len % CIPHER_BLOCK_SIZE) + ciphertext = ciphertext + (b"\x00" * pad_len) + out = b"".join( cipher.decrypt(ciphertext[i : i + CIPHER_BLOCK_SIZE]) for i in range(0, len(ciphertext), CIPHER_BLOCK_SIZE) ) + return out[:orig_len] @staticmethod def encrypt_then_mac(key_aes: bytes, shared_secret: bytes, plaintext: bytes) -> bytes: @@ -69,6 +77,32 @@ def mac_then_decrypt(aes_key: bytes, shared_secret: bytes, data: bytes) -> bytes return decrypted + @staticmethod + def x25519_clamp_scalar(scalar: bytes) -> bytes: + """Clamp a 32-byte scalar for X25519 (matches MeshCore key_exchange.c).""" + if len(scalar) != 32: + raise ValueError("scalar must be 32 bytes") + s = bytearray(scalar) + s[0] &= 248 + s[31] &= 63 + s[31] |= 64 + return bytes(s) + + @staticmethod + def ed25519_expand_seed_to_meshcore_64(seed: bytes) -> bytes: + """Expand a 32-byte Ed25519 seed to 64-byte MeshCore/orlp format.""" + if len(seed) != 32: + raise ValueError("seed must be 32 bytes") + h = hashlib.sha512(seed).digest() + return CryptoUtils.x25519_clamp_scalar(h[:32]) + h[32:64] + + @staticmethod + def ed25519_public_from_meshcore_64(meshcore_private_64: bytes) -> bytes: + """Derive Ed25519 public key from 64-byte MeshCore private key format.""" + if len(meshcore_private_64) != 64: + raise ValueError("meshcore_private_64 must be 64 bytes") + return crypto_scalarmult_ed25519_base_noclamp(meshcore_private_64[:32]) + @staticmethod def scalarmult(private_key: bytes, public_key: bytes) -> bytes: """ECDH shared secret calculation (X25519).""" diff --git a/src/pymc_core/protocol/identity.py b/src/pymc_core/protocol/identity.py index a5c8101..bd6a23f 100644 --- a/src/pymc_core/protocol/identity.py +++ b/src/pymc_core/protocol/identity.py @@ -100,17 +100,21 @@ def __init__(self, seed: Optional[bytes] = None): if seed and len(seed) == 64: from nacl.bindings import crypto_scalarmult_ed25519_base_noclamp - # MeshCore format: [32-byte clamped scalar][32-byte nonce] + # MeshCore 
format: [32-byte scalar][32-byte nonce]. Identity.cpp readFrom(64) calls + # ed25519_derive_pub(pub_key, prv_key) which uses first 32 bytes as-is (no clamp). + # key_exchange.c uses first 32 bytes clamped for ECDH. We must match both. self._firmware_key = seed self.signing_key = None - # Derive public key from scalar - scalar = seed[:32] - ed25519_pub = crypto_scalarmult_ed25519_base_noclamp(scalar) + scalar_first32 = seed[:32] + # Ed25519 public: derive without clamping so get_public_key() matches firmware + ed25519_pub = crypto_scalarmult_ed25519_base_noclamp(scalar_first32) self.verify_key = VerifyKey(ed25519_pub) - # Build ed25519_sk for X25519 conversion (use reconstructed format) - ed25519_sk = scalar + ed25519_pub + # ECDH: use clamped scalar so shared secret matches firmware's ed25519_key_exchange() + clamped = CryptoUtils.x25519_clamp_scalar(scalar_first32) + self._x25519_private = clamped + self._x25519_public = CryptoUtils.scalarmult_base(clamped) else: # Standard 32-byte seed or None self._firmware_key = None @@ -121,9 +125,9 @@ def __init__(self, seed: Optional[bytes] = None): ed25519_pub = self.verify_key.encode() ed25519_sk = self.signing_key.encode() + ed25519_pub - # X25519 keypair for ECDH - self._x25519_private = CryptoUtils.ed25519_sk_to_x25519(ed25519_sk) - self._x25519_public = CryptoUtils.scalarmult_base(self._x25519_private) + # X25519 keypair for ECDH (libsodium conversion) + self._x25519_private = CryptoUtils.ed25519_sk_to_x25519(ed25519_sk) + self._x25519_public = CryptoUtils.scalarmult_base(self._x25519_private) # Initialise base class with Ed25519 pubkey super().__init__(ed25519_pub) @@ -159,10 +163,10 @@ def get_shared_public_key(self) -> bytes: def get_signing_key_bytes(self) -> bytes: """ Get the signing key bytes for this identity. - + For standard keys, returns the 32-byte Ed25519 seed. For firmware keys, returns the 64-byte expanded key format [scalar||nonce]. 
- + Returns: The signing key bytes (32 or 64 bytes depending on key type). """ diff --git a/src/pymc_core/protocol/modem_identity.py b/src/pymc_core/protocol/modem_identity.py index e8551ee..ab54ec2 100644 --- a/src/pymc_core/protocol/modem_identity.py +++ b/src/pymc_core/protocol/modem_identity.py @@ -188,8 +188,7 @@ def get_signing_key_bytes(self) -> bytes: RuntimeError: Always, as signing key is not accessible """ raise RuntimeError( - "ModemIdentity does not expose signing keys. " - "Use sign() for signing operations." + "ModemIdentity does not expose signing keys. " "Use sign() for signing operations." ) # Additional modem-specific methods diff --git a/src/pymc_core/protocol/packet.py b/src/pymc_core/protocol/packet.py index 4f5a781..c5235b9 100644 --- a/src/pymc_core/protocol/packet.py +++ b/src/pymc_core/protocol/packet.py @@ -1,7 +1,6 @@ from typing import ByteString, Optional from .constants import ( - MAX_PATH_SIZE, MAX_SUPPORTED_PAYLOAD_VERSION, PH_ROUTE_MASK, PH_TYPE_MASK, @@ -16,7 +15,7 @@ SIGNATURE_SIZE, TIMESTAMP_SIZE, ) -from .packet_utils import PacketDataUtils, PacketHashingUtils, PacketValidationUtils +from .packet_utils import PacketDataUtils, PacketHashingUtils, PacketValidationUtils, PathUtils """ ╔═══════════════════════════════════════════════════════════════════════════╗ @@ -30,9 +29,10 @@ ║ Transport Codes ║ Two 16-bit codes (4 bytes total). Only present for ║ ║ (0 or 4 bytes) ║ TRANSPORT_FLOOD and TRANSPORT_DIRECT route types. ║ ╠════════════════════╬══════════════════════════════════════════════════════╣ -║ Path Length (1 B) ║ Number of path hops (0–15). ║ +║ Path Length (1 B) ║ Encoded: bits 0-5 = hash count (hops), bits 6-7 = ║ +║ ║ (hash_size - 1). Actual bytes = count × hash_size. 
║ ╠════════════════════╬══════════════════════════════════════════════════════╣ -║ Path (N bytes) ║ List of node hashes (1 byte each), length = path_len ║ +║ Path (N bytes) ║ Node hashes (1-3 bytes each), N = count × hash_size ║ ╠════════════════════╬══════════════════════════════════════════════════════╣ ║ Payload (N bytes) ║ Actual encrypted or plain payload. Max: 254 bytes ║ ╠════════════════════╬══════════════════════════════════════════════════════╣ @@ -71,8 +71,8 @@ class Packet: Attributes: header (int): Single byte header containing packet type and flags. transport_codes (list): Two 16-bit transport codes for TRANSPORT route types. - path_len (int): Length of the path component in bytes. - path (bytearray): Variable-length path data for routing. + path_len (int): Encoded path length byte (bits 0-5 = hash count, bits 6-7 = hash size - 1). + path (bytearray): Variable-length path data for routing (hash_count × hash_size bytes). payload (bytearray): Variable-length payload data. payload_len (int): Actual length of payload data. _rssi (int): Raw RSSI signal strength value from firmware. 
@@ -82,8 +82,7 @@ class Packet: ```python packet = Packet() packet.header = 0x01 # Flood routing - packet.path = b"node1->node2" - packet.path_len = len(packet.path) + packet.set_path(b"\\xAA\\xBB\\xCC") # 3 hops, 1-byte hashes packet.payload = b"Hello World" packet.payload_len = len(packet.payload) data = packet.write_to() @@ -207,6 +206,73 @@ def is_route_direct(self) -> bool: route_type = self.get_route_type() return route_type == ROUTE_TYPE_TRANSPORT_DIRECT or route_type == ROUTE_TYPE_DIRECT + def get_path_hash_size(self) -> int: + """Extract per-hop hash size (1, 2, or 3) from the encoded path_len byte.""" + return PathUtils.get_path_hash_size(self.path_len) + + def get_path_hash_count(self) -> int: + """Extract hop count (0-63) from the encoded path_len byte.""" + return PathUtils.get_path_hash_count(self.path_len) + + def get_path_byte_len(self) -> int: + """Calculate actual path byte length from the encoded path_len byte.""" + return PathUtils.get_path_byte_len(self.path_len) + + def get_path_hashes(self) -> list: + """Return path as a list of per-hop hash entries (1, 2, or 3 bytes each). + + Groups the raw ``self.path`` bytearray using the hash size encoded in + ``self.path_len``. Each entry in the returned list is a ``bytes`` + object whose length equals ``get_path_hash_size()``. + + Returns: + list[bytes]: One entry per hop. Empty list when hop count is 0. + """ + hash_size = self.get_path_hash_size() + count = self.get_path_hash_count() + result = [] + for i in range(count): + start = i * hash_size + end = start + hash_size + if end <= len(self.path): + result.append(bytes(self.path[start:end])) + return result + + def get_path_hashes_hex(self) -> list: + """Return path as a list of uppercase hex strings, one per hop. 
+ + Examples:: + + 1-byte hashes: ["B5", "A3", "F2"] + 2-byte hashes: ["B5A3", "F2C1"] + 3-byte hashes: ["B5A3F2", "C1D4E7"] + """ + return [entry.hex().upper() for entry in self.get_path_hashes()] + + def set_path( + self, + path_bytes: bytes, + path_len_encoded: int = None, + ) -> None: + """Set the routing path with optional encoded path_len. + + Args: + path_bytes: Raw path bytes to set. + path_len_encoded: Pre-encoded path_len byte. If None, assumes + 1-byte hashes and encodes len(path_bytes) as the hop count. + """ + self.path = bytearray(path_bytes) + if path_len_encoded is not None: + self.path_len = path_len_encoded + else: + hop_count = len(path_bytes) + if hop_count > 63: + raise ValueError( + f"path length {hop_count} exceeds maximum encodable hop count 63 " + "for 1-byte hashes; pass path_len_encoded explicitly or use a shorter path" + ) + self.path_len = PathUtils.encode_path_len(1, hop_count) + def get_payload(self) -> bytes: """ Get the packet payload as immutable bytes, truncated to declared length. @@ -250,7 +316,7 @@ def _validate_lengths(self) -> None: ValueError: If any declared length doesn't match the actual buffer length. 
""" PacketValidationUtils.validate_buffer_lengths( - self.path_len, len(self.path), self.payload_len, len(self.payload) + self.get_path_byte_len(), len(self.path), self.payload_len, len(self.payload) ) def _check_bounds(self, idx: int, required: int, data_len: int, error_msg: str) -> None: @@ -295,7 +361,7 @@ def write_to(self) -> bytes: out.extend(self.transport_codes[1].to_bytes(2, "little")) out.append(self.path_len) - out += self.path + out += self.path[: self.get_path_byte_len()] out += self.payload[: self.payload_len] return bytes(out) @@ -338,12 +404,16 @@ def read_from(self, data: ByteString) -> bool: self._check_bounds(idx, 1, data_len, "missing path_len") self.path_len = data[idx] idx += 1 - if self.path_len > MAX_PATH_SIZE: + if not PathUtils.is_valid_path_len(self.path_len): + hash_size = PathUtils.get_path_hash_size(self.path_len) + if hash_size > 3: + raise ValueError(f"invalid path_len encoding: 0x{self.path_len:02X}") raise ValueError("path_len too large") - self._check_bounds(idx, self.path_len, data_len, "truncated path") - self.path = bytearray(data[idx : idx + self.path_len]) - idx += self.path_len + path_byte_len = self.get_path_byte_len() + self._check_bounds(idx, path_byte_len, data_len, "truncated path") + self.path = bytearray(data[idx : idx + path_byte_len]) + idx += path_byte_len self.payload = bytearray(data[idx:]) self.payload_len = len(self.payload) @@ -425,7 +495,9 @@ def get_raw_length(self) -> int: Note: This matches the wire format used by write_to() and expected by read_from(). 
""" - base_length = 2 + self.path_len + self.payload_len # header + path_len + path + payload + base_length = ( + 2 + self.get_path_byte_len() + self.payload_len + ) # header + path_len_byte + path + payload return base_length + (4 if self.has_transport_codes() else 0) def get_snr(self) -> float: diff --git a/src/pymc_core/protocol/packet_builder.py b/src/pymc_core/protocol/packet_builder.py index 64fc012..7d2d8f2 100644 --- a/src/pymc_core/protocol/packet_builder.py +++ b/src/pymc_core/protocol/packet_builder.py @@ -23,6 +23,7 @@ PAYLOAD_TYPE_GRP_DATA, PAYLOAD_TYPE_GRP_TXT, PAYLOAD_TYPE_PATH, + PAYLOAD_TYPE_RAW_CUSTOM, PAYLOAD_TYPE_REQ, PAYLOAD_TYPE_RESPONSE, PAYLOAD_TYPE_TRACE, @@ -34,7 +35,13 @@ TELEM_PERM_LOCATION, ) from .identity import Identity, LocalIdentity -from .packet_utils import PacketDataUtils, PacketHeaderUtils, PacketValidationUtils, RouteTypeUtils +from .packet_utils import ( + PacketDataUtils, + PacketHeaderUtils, + PacketValidationUtils, + PathUtils, + RouteTypeUtils, +) logger = logging.getLogger(__name__) @@ -469,10 +476,26 @@ def create_login_packet(contact: Any, local_identity: LocalIdentity, password: s contact_identity = Identity(contact_pubkey) shared_secret = contact_identity.calc_shared_secret(local_identity.get_private_key()) - return PacketBuilder.create_anon_req( - contact_identity, local_identity, shared_secret, plaintext, "direct" + out_path_len = getattr(contact, "out_path_len", -1) + if out_path_len < 0: + route_type = "flood" + else: + route_type = "direct" + + pkt = PacketBuilder.create_anon_req( + contact_identity, local_identity, shared_secret, plaintext, route_type ) + if route_type == "direct" and out_path_len > 0: + out_path = getattr(contact, "out_path", b"") + if out_path: + pkt.set_path( + out_path[:MAX_PATH_SIZE], + out_path_len if PathUtils.is_valid_path_len(out_path_len) else None, + ) + + return pkt + @staticmethod def create_group_datagram( group_name: str, @@ -523,9 +546,19 @@ def create_group_datagram( 
secret_bytes = ( bytes.fromhex(channel["secret"]) if isinstance(channel["secret"], str) - else channel["secret"].encode("utf-8") + else ( + channel["secret"] + if isinstance(channel["secret"], bytes) + else channel["secret"].encode("utf-8") + ) + ) + # Same channel hash as GroupTextHandler (hash first 16 when key has second 16 zero) + hash_input = ( + secret_bytes[:16] + if len(secret_bytes) >= 32 and secret_bytes[16:32] == b"\x00" * 16 + else (secret_bytes[:32] if len(secret_bytes) > 32 else secret_bytes) ) - channel_hash = hashlib.sha256(secret_bytes).digest()[0] + channel_hash = hashlib.sha256(hash_input).digest()[0] secret_bytes = (secret_bytes + b"\x00" * 32)[:32] timestamp, flags = PacketBuilder._get_timestamp(), 0x00 @@ -618,6 +651,21 @@ def create_trace( pkt.payload_len = len(payload) return pkt + @staticmethod + def create_raw_data(data: bytes) -> Packet: + """ + Create a raw custom packet (PAYLOAD_TYPE_RAW_CUSTOM) with no encryption. + + Route type is always DIRECT (consistent with firmware CMD_SEND_RAW_DATA). + Caller must set pkt.path and pkt.path_len for direct routing. 
+ """ + if len(data) > MAX_PACKET_PAYLOAD: + raise ValueError( + f"Raw data length {len(data)} exceeds MAX_PACKET_PAYLOAD ({MAX_PACKET_PAYLOAD})" + ) + header = PacketBuilder._create_header(PAYLOAD_TYPE_RAW_CUSTOM, route_type="direct") + return PacketBuilder._create_packet(header, data) + @staticmethod def create_path_return( dest_hash: int, @@ -732,8 +780,23 @@ def create_text_message( f"Path length {len(routing_path)} exceeds maximum {MAX_PATH_SIZE}, truncating" ) routing_path = routing_path[:MAX_PATH_SIZE] - pkt.path = bytearray(routing_path) - pkt.path_len = len(pkt.path) + # Preserve encoded path_len from contact when using its stored path + contact_path_len = getattr(contact, "out_path_len", -1) if contact else -1 + if ( + out_path is None + and contact_path_len >= 0 + and PathUtils.is_valid_path_len(contact_path_len) + ): + pkt.set_path(bytearray(routing_path), contact_path_len) + else: + # path_len encodes hop count in 6 bits (0-63); 64 would encode as 0 + if len(routing_path) == 64: + logger.warning( + "Path length 64 exceeds encodable hop count 63 (1-byte hashes), " + "truncating to 63 bytes" + ) + routing_path = routing_path[:63] + pkt.set_path(bytearray(routing_path)) else: pkt.path_len, pkt.path = 0, bytearray() @@ -801,8 +864,22 @@ def create_protocol_request( contact, local_identity, plaintext ) - header = PacketBuilder._create_header(PAYLOAD_TYPE_REQ) + out_path_len = getattr(contact, "out_path_len", -1) + out_path = getattr(contact, "out_path", b"") or b"" + if out_path_len <= 0 or not out_path: + route_type = "flood" + else: + route_type = "direct" + + header = PacketBuilder._create_header(PAYLOAD_TYPE_REQ, route_type) packet = PacketBuilder._create_packet(header, payload) + + if route_type == "direct" and len(out_path) > 0: + packet.set_path( + out_path[:MAX_PATH_SIZE], + out_path_len if PathUtils.is_valid_path_len(out_path_len) else None, + ) + return packet, timestamp @staticmethod diff --git a/src/pymc_core/protocol/packet_utils.py 
b/src/pymc_core/protocol/packet_utils.py index fd30505..b9490c9 100644 --- a/src/pymc_core/protocol/packet_utils.py +++ b/src/pymc_core/protocol/packet_utils.py @@ -11,6 +11,8 @@ MAX_HASH_SIZE, MAX_PACKET_PAYLOAD, MAX_PATH_SIZE, + PATH_HASH_COUNT_MASK, + PATH_HASH_SIZE_SHIFT, PAYLOAD_TYPE_TRACE, PAYLOAD_VER_1, PH_ROUTE_MASK, @@ -101,6 +103,79 @@ def validate_payload_size(payload_len: int) -> None: raise ValueError(f"payload too large: {payload_len} > {MAX_PACKET_PAYLOAD}") +class PathUtils: + """Multi-byte path encoding/decoding matching firmware Packet.h. + + The encoded path_len byte packs both hash size and hop count: + - Bits 0-5: hash_count (number of hops, 0-63) + - Bits 6-7: (hash_size - 1), where hash_size is 1, 2, or 3 bytes + + Total path bytes on the wire = hash_count * hash_size. + For 1-byte hashes the encoded byte equals the raw hop count, + preserving backward compatibility with legacy packets. + """ + + @staticmethod + def get_path_hash_size(path_len_byte: int) -> int: + """Extract per-hop hash size (1, 2, or 3) from the encoded path_len byte.""" + return (path_len_byte >> PATH_HASH_SIZE_SHIFT) + 1 + + @staticmethod + def get_path_hash_count(path_len_byte: int) -> int: + """Extract hop count (0-63) from the encoded path_len byte.""" + return path_len_byte & PATH_HASH_COUNT_MASK + + @staticmethod + def get_path_byte_len(path_len_byte: int) -> int: + """Calculate actual path byte length from the encoded path_len byte.""" + return (path_len_byte & PATH_HASH_COUNT_MASK) * ( + (path_len_byte >> PATH_HASH_SIZE_SHIFT) + 1 + ) + + @staticmethod + def encode_path_len(hash_size: int, hash_count: int) -> int: + """Encode hash size and hop count into a single path_len byte. + + Hop count is stored in 6 bits (0-63). Values above 63 are invalid and raise. 
+ """ + if not 0 <= hash_count <= 63: + raise ValueError(f"hop count must be 0-63 for path_len encoding, got {hash_count}") + return ((hash_size - 1) << PATH_HASH_SIZE_SHIFT) | (hash_count & PATH_HASH_COUNT_MASK) + + @staticmethod + def is_valid_path_len(path_len_byte: int) -> bool: + """Validate an encoded path_len byte. + + Path length in bytes must never exceed MAX_PATH_SIZE (64): at most 64 + one-byte hops, 32 two-byte hops, or 21 three-byte hops. Returns False + for hash_size == 4 (reserved) or if the total path bytes would exceed + MAX_PATH_SIZE. + """ + hash_size = (path_len_byte >> PATH_HASH_SIZE_SHIFT) + 1 + if hash_size > 3: + return False + hash_count = path_len_byte & PATH_HASH_COUNT_MASK + return hash_count * hash_size <= MAX_PATH_SIZE + + @staticmethod + def is_path_at_max_hops(path_len_byte: int) -> bool: + """True if path has reached maximum hops for its hash size (do not retransmit). + + Measures hops, not raw bytes. Max hops depend on hash size and MAX_PATH_SIZE: + - 1-byte hashes: 63 hops (63 bytes) + - 2-byte hashes: 32 hops (64 bytes) + - 3-byte hashes: 21 hops (63 bytes) + """ + if path_len_byte == 0: + return False + hash_size = PathUtils.get_path_hash_size(path_len_byte) + if hash_size > 3: + return False + hash_count = path_len_byte & PATH_HASH_COUNT_MASK + max_hops = min(PATH_HASH_COUNT_MASK, MAX_PATH_SIZE // hash_size) + return hash_count >= max_hops + + class PacketDataUtils: """Centralized data packing and unpacking utilities.""" @@ -341,15 +416,17 @@ def calc_flood_timeout_ms(packet_airtime_ms: float) -> float: return SEND_TIMEOUT_BASE_MILLIS + (FLOOD_SEND_TIMEOUT_FACTOR * packet_airtime_ms) @staticmethod - def calc_direct_timeout_ms(packet_airtime_ms: float, path_len: int) -> float: + def calc_direct_timeout_ms(packet_airtime_ms: float, path_hash_count: int) -> float: """ Calculate timeout for direct packets. 
- Formula: 500ms + ((airtime × 6 + 250ms) × (path_len + 1)) + Formula: 500ms + ((airtime × 6 + 250ms) × (path_hash_count + 1)) Args: packet_airtime_ms: Estimated packet airtime in milliseconds - path_len: Number of hops in the path (0 for direct) + path_hash_count: Number of hops in the path (0 for direct). + Use ``PathUtils.get_path_hash_count(path_len)`` to extract + from an encoded path_len byte. Returns: Timeout in milliseconds @@ -359,5 +436,5 @@ def calc_direct_timeout_ms(packet_airtime_ms: float, path_len: int) -> float: DIRECT_SEND_PERHOP_EXTRA_MILLIS = 250 return SEND_TIMEOUT_BASE_MILLIS + ( (packet_airtime_ms * DIRECT_SEND_PERHOP_FACTOR + DIRECT_SEND_PERHOP_EXTRA_MILLIS) - * (path_len + 1) + * (path_hash_count + 1) ) diff --git a/src/pymc_core/protocol/transport_keys.py b/src/pymc_core/protocol/transport_keys.py index 9518f3e..d174196 100644 --- a/src/pymc_core/protocol/transport_keys.py +++ b/src/pymc_core/protocol/transport_keys.py @@ -7,43 +7,44 @@ """ import struct + from .crypto import CryptoUtils def get_auto_key_for(name: str) -> bytes: """ Generate 128-bit transport key from region name. - + Matches C++ implementation: void TransportKeyStore::getAutoKeyFor(uint16_t id, const char* name, TransportKey& dest) - + Args: name: Region name including '#' (e.g., "#usa") - + Returns: bytes: 16-byte transport key """ if not name: raise ValueError("Region name cannot be empty") - if not name.startswith('#'): + if not name.startswith("#"): raise ValueError("Region name must start with '#'") if len(name) > 64: raise ValueError("Region name is too long (max 64 characters)") - key_hash = CryptoUtils.sha256(name.encode('ascii')) + key_hash = CryptoUtils.sha256(name.encode("ascii")) return key_hash[:16] # First 16 bytes (128 bits) def calc_transport_code(key: bytes, packet) -> int: """ Calculate transport code for a packet. 
- + Matches C++ implementation: uint16_t TransportKey::calcTransportCode(const mesh::Packet* packet) const - + Args: key: 16-byte transport key packet: Packet with payload_type and payload - + Returns: int: 16-bit transport code """ @@ -51,20 +52,20 @@ def calc_transport_code(key: bytes, packet) -> int: raise ValueError(f"Transport key must be 16 bytes, got {len(key)}") payload_type = packet.get_payload_type() payload_data = packet.get_payload() - + # HMAC input: payload_type (1 byte) + payload hmac_data = bytes([payload_type]) + payload_data - + # Calculate HMAC-SHA256 hmac_digest = CryptoUtils._hmac_sha256(key, hmac_data) - + # Extract first 2 bytes as little-endian uint16 (matches Arduino platform endianness) - code = struct.unpack(' CompanionBridge: + """Create a minimal CompanionBridge for testing _apply_path_hash_mode.""" + + async def _noop_injector(pkt, wait_for_ack=False): + return True + + bridge = CompanionBridge(LocalIdentity(), _noop_injector, node_name="Test") + bridge.prefs.path_hash_mode = path_hash_mode + return bridge + + +class TestApplyPathHashMode: + def test_encodes_on_zero_hops(self): + """path_hash_mode=1 on a fresh packet (0 hops) → path_len=0x40.""" + bridge = _make_bridge(path_hash_mode=1) + pkt = Packet() + pkt.header = 0x06 + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(b"test") + pkt.payload_len = 4 + + bridge._apply_path_hash_mode(pkt) + + assert pkt.path_len == 0x40 # (1 << 6) | 0 = 0x40 + assert pkt.get_path_hash_size() == 2 + assert pkt.get_path_hash_count() == 0 + + def test_skips_nonzero_hops(self): + """Packets with existing hops (stored contact path) are untouched.""" + bridge = _make_bridge(path_hash_mode=2) + pkt = Packet() + pkt.header = 0x06 + # 3 hops with 1-byte hashes + pkt.set_path(b"\xAA\xBB\xCC") + pkt.payload = bytearray(b"test") + pkt.payload_len = 4 + + original_path_len = pkt.path_len + bridge._apply_path_hash_mode(pkt) + + # path_len unchanged — the contact path is preserved + assert 
pkt.path_len == original_path_len + assert pkt.get_path_hash_count() == 3 + + def test_all_modes(self): + """Verify mode 0→0x00, mode 1→0x40, mode 2→0x80 on fresh packets.""" + expected = { + 0: (0x00, 1), # (path_len, hash_size) + 1: (0x40, 2), + 2: (0x80, 3), + } + for mode, (expected_path_len, expected_hash_size) in expected.items(): + bridge = _make_bridge(path_hash_mode=mode) + pkt = Packet() + pkt.header = 0x06 + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(b"x") + pkt.payload_len = 1 + + bridge._apply_path_hash_mode(pkt) + + assert pkt.path_len == expected_path_len, ( + f"mode={mode}: expected path_len=0x{expected_path_len:02X}, " + f"got 0x{pkt.path_len:02X}" + ) + assert pkt.get_path_hash_size() == expected_hash_size, ( + f"mode={mode}: expected hash_size={expected_hash_size}, " + f"got {pkt.get_path_hash_size()}" + ) + + def test_skips_trace_packets(self): + """Trace packets use path for SNR values, not routing hashes.""" + bridge = _make_bridge(path_hash_mode=1) + pkt = Packet() + # Trace packet: payload_type=PAYLOAD_TYPE_TRACE, route_type=ROUTE_TYPE_DIRECT + pkt.header = (PAYLOAD_TYPE_TRACE << 2) | ROUTE_TYPE_DIRECT + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(b"trace_data") + pkt.payload_len = 10 + + bridge._apply_path_hash_mode(pkt) + + # path_len must stay 0 — NOT 0x40 + assert pkt.path_len == 0 + assert pkt.get_path_hash_size() == 1 + assert pkt.get_path_hash_count() == 0 diff --git a/tests/test_companion_bridge.py b/tests/test_companion_bridge.py new file mode 100644 index 0000000..5263b9b --- /dev/null +++ b/tests/test_companion_bridge.py @@ -0,0 +1,540 @@ +"""Tests for CompanionBridge (repeater-integrated companion with packet_injector).""" + +import asyncio + +import pytest + +from pymc_core.companion import CompanionBridge +from pymc_core.companion.constants import ADV_TYPE_CHAT, AUTOADD_CHAT +from pymc_core.companion.models import Contact +from pymc_core.node.events import MeshEvents +from 
pymc_core.protocol import CryptoUtils, Identity, LocalIdentity, Packet +from pymc_core.protocol.constants import ( + PAYLOAD_TYPE_ADVERT, + PAYLOAD_TYPE_PATH, + PAYLOAD_TYPE_RAW_CUSTOM, + PAYLOAD_TYPE_RESPONSE, + PAYLOAD_TYPE_TXT_MSG, + ROUTE_TYPE_FLOOD, +) + + +def _make_peer_contact(name: str) -> Contact: + """Return a contact with a valid Ed25519 public key (required for packet encryption).""" + peer = LocalIdentity() + return Contact(public_key=peer.get_public_key(), name=name) + + +class MockPacketInjector: + """Records injected packets and returns True by default.""" + + def __init__(self): + self.calls: list[tuple] = [] + + async def __call__(self, pkt: Packet, wait_for_ack: bool = False) -> bool: + self.calls.append((pkt, wait_for_ack)) + return True + + +# --------------------------------------------------------------------------- +# Init +# --------------------------------------------------------------------------- + + +class TestCompanionBridgeInit: + def test_init_creates_stores(self): + injector = MockPacketInjector() + identity = LocalIdentity() + bridge = CompanionBridge(identity, injector, node_name="BridgeNode") + assert bridge.contacts is not None + assert bridge.contacts.get_count() == 0 + assert bridge.channels is not None + assert bridge.stats is not None + assert bridge.prefs.node_name == "BridgeNode" + assert bridge.get_public_key() == identity.get_public_key() + assert injector.calls == [] + + def test_init_with_authenticate_callback(self): + def auth_cb(*args, **kwargs): + return (True, 0) + + injector = MockPacketInjector() + bridge = CompanionBridge( + LocalIdentity(), + injector, + authenticate_callback=auth_cb, + ) + assert bridge._handlers is not None + + +# --------------------------------------------------------------------------- +# Lifecycle +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeLifecycle: + async def test_start_stop(self): + injector = 
MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + assert bridge.is_running is False + await bridge.start() + assert bridge.is_running is True + await bridge.stop() + assert bridge.is_running is False + + +# --------------------------------------------------------------------------- +# Channel updated callback +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeChannelUpdated: + async def test_set_channel_and_remove_channel_fire_channel_updated(self): + """set_channel and remove_channel fire on_channel_updated(idx, channel_or_none).""" + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + events = [] + + def on_channel_updated(idx: int, ch) -> None: + events.append((idx, ch)) + + bridge.on_channel_updated(on_channel_updated) + await bridge.start() + + ok = bridge.set_channel(0, "General", b"secret_________________________") + assert ok is True + await asyncio.sleep(0) + assert len(events) == 1 + assert events[0][0] == 0 + assert events[0][1] is not None + assert events[0][1].name == "General" + + ok = bridge.remove_channel(0) + assert ok is True + await asyncio.sleep(0) + assert len(events) == 2 + assert events[1] == (0, None) + + await bridge.stop() + + +# --------------------------------------------------------------------------- +# process_received_packet +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeProcessReceivedPacket: + async def test_process_packet_records_rx_stats(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + await bridge.start() + pkt = Packet() + pkt.header = (ROUTE_TYPE_FLOOD << 0) | (PAYLOAD_TYPE_ADVERT << 2) + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray() + pkt.payload_len = 0 + await bridge.process_received_packet(pkt) + tot = bridge.stats.get_totals() + 
assert tot["flood_rx"] == 1 + await bridge.stop() + + async def test_process_unknown_type_no_crash(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + pkt = Packet() + pkt.header = (ROUTE_TYPE_FLOOD << 0) | (15 << 2) + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray() + pkt.payload_len = 0 + await bridge.process_received_packet(pkt) + assert True + + async def test_process_received_packet_fires_raw_data_received(self): + """CompanionBridge fires on_raw_data_received(payload, snr, rssi) for RAW_CUSTOM packets.""" + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + raw_calls = [] + + def on_raw(payload: bytes, snr, rssi) -> None: + raw_calls.append((payload, snr, rssi)) + + bridge.on_raw_data_received(on_raw) + await bridge.start() + + pkt = Packet() + pkt.header = (1 << 6) | (PAYLOAD_TYPE_RAW_CUSTOM << 2) + pkt.payload = bytearray(b"\x01\x02\x03\x04") + pkt.payload_len = 4 + pkt.path_len = 0 + pkt._snr = 6.0 + pkt._rssi = -75 + + await bridge.process_received_packet(pkt) + await bridge.stop() + + assert len(raw_calls) == 1 + payload_bytes, snr, rssi = raw_calls[0] + assert payload_bytes == b"\x01\x02\x03\x04" + assert snr == 6.0 + assert rssi == -75 + + +# --------------------------------------------------------------------------- +# Advertise +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeAdvertise: + async def test_advertise_injects_packet(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.advertise(flood=True) + assert result is True + assert len(injector.calls) == 1 + pkt, wait_for_ack = injector.calls[0] + assert pkt is not None + assert (pkt.header >> 2) & 0x0F == PAYLOAD_TYPE_ADVERT + assert wait_for_ack is False + assert bridge.stats.get_totals()["flood_tx"] == 1 + + +# 
--------------------------------------------------------------------------- +# Send text, share contact +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeSendAndShare: + async def test_send_text_message_no_contact(self, caplog): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.send_text_message(b"\x00" * 32, "Hi") + assert result.success is False + assert len(injector.calls) == 0 + + async def test_send_text_message_with_contact_injects_packet(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + contact = _make_peer_contact("Alice") + bridge.contacts.add(contact) + await bridge.send_text_message(contact.public_key, "Hello") + assert len(injector.calls) >= 1 + pkt, _ = injector.calls[0] + assert (pkt.header >> 2) & 0x0F == PAYLOAD_TYPE_TXT_MSG + + async def test_share_contact_not_found(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.share_contact(b"\x00" * 32) + assert result is False + assert len(injector.calls) == 0 + + async def test_share_contact_success(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + key = b"\x22" * 32 + bridge.contacts.add(Contact(public_key=key, name="Bob")) + result = await bridge.share_contact(key) + assert result is True + assert len(injector.calls) == 1 + + async def test_send_raw_data_direct_injects_packet(self): + """send_raw_data_direct builds RAW_CUSTOM packet and sends via injector.""" + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + await bridge.start() + path = b"\x42" + payload = b"\x01\x02\x03\x04" + result = await bridge.send_raw_data_direct(path, payload) + await bridge.stop() + assert result.success is True + assert len(injector.calls) == 1 + pkt, wait_for_ack = injector.calls[0] + 
assert (pkt.header >> 2) & 0x0F == PAYLOAD_TYPE_RAW_CUSTOM + assert pkt.path == bytearray(path) + assert pkt.path_len == len(path) + assert bytes(pkt.payload) == payload + assert wait_for_ack is False + + +# --------------------------------------------------------------------------- +# Path discovery, trace, control data +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgePathAndControl: + async def test_send_path_discovery_req_no_contact(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.send_path_discovery_req(b"\x00" * 32) + assert result.success is False + + async def test_send_path_discovery_req_success(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + contact = _make_peer_contact("Target") + bridge.contacts.add(contact) + result = await bridge.send_path_discovery_req(contact.public_key) + assert result.success is True + assert len(injector.calls) == 1 + assert result.timeout_ms == 10000 + + async def test_send_trace_path_raw(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.send_trace_path_raw(0x12345678, 0xABCD, 0, bytes([0x01, 0x02])) + assert result is True + assert len(injector.calls) == 1 + + async def test_send_control_data_valid_payload(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.send_control_data(bytes([0x80, 0x01])) + assert result is True + assert len(injector.calls) == 1 + pkt, _ = injector.calls[0] + assert pkt.payload_len == 2 + assert list(pkt.payload) == [0x80, 0x01] + + async def test_send_control_data_rejects_no_high_bit(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.send_control_data(bytes([0x00, 0x01])) + assert result is False + assert 
len(injector.calls) == 0 + + +# --------------------------------------------------------------------------- +# Binary request +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeBinaryReq: + async def test_send_binary_req_no_contact(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + result = await bridge.send_binary_req(b"\x00" * 32, bytes([0x01])) + assert result.success is False + + async def test_send_binary_req_with_contact(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + contact = _make_peer_contact("Rpt") + bridge.contacts.add(contact) + result = await bridge.send_binary_req( + contact.public_key, bytes([0x01]), timeout_seconds=5.0 + ) + assert result.success is True + assert result.expected_ack is not None + assert len(injector.calls) == 1 + + +# --------------------------------------------------------------------------- +# NODE_DISCOVERED -> advert pipeline (contact store + advert_received) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeNodeDiscoveredAdvertPipeline: + async def test_node_discovered_adds_contact_and_fires_advert_received(self): + """Single path: NODE_DISCOVERED event drives store + advert_received (Bridge and Radio).""" + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + peer = LocalIdentity() + pub_key_hex = peer.get_public_key().hex() + event_data = { + "public_key": pub_key_hex, + "name": "DiscoveredNode", + "contact_type": ADV_TYPE_CHAT, + "lat": 52.0, + "lon": -1.0, + "advert_timestamp": 1000, + "timestamp": 1001, + "snr": 5.0, + "rssi": -80, + } + advert_received_calls = [] + + def on_advert(c): + advert_received_calls.append(c) + + bridge.on_advert_received(on_advert) + await bridge._handle_mesh_event(MeshEvents.NODE_DISCOVERED, event_data) + 
assert bridge.contacts.get_count() == 1 + assert len(advert_received_calls) == 1 + assert advert_received_calls[0].name == "DiscoveredNode" + assert advert_received_calls[0].public_key == peer.get_public_key() + + async def test_one_node_discovered_event_produces_exactly_one_advert_received(self): + """Single-path guarantee: one NODE_DISCOVERED event yields exactly one + advert_received callback (no duplicate path, no duplicate push frames). + """ + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + peer = LocalIdentity() + event_data = { + "public_key": peer.get_public_key().hex(), + "name": "SinglePathNode", + "contact_type": ADV_TYPE_CHAT, + "lat": 0.0, + "lon": 0.0, + "advert_timestamp": 1000, + "timestamp": 1000, + "snr": 0.0, + "rssi": 0, + } + advert_received_calls = [] + bridge.on_advert_received(advert_received_calls.append) + await bridge._handle_mesh_event(MeshEvents.NODE_DISCOVERED, event_data) + assert len(advert_received_calls) == 1 + assert advert_received_calls[0].name == "SinglePathNode" + + async def test_path_packet_updates_contact_path_and_fires_contact_path_updated_once(self): + """PATH packet that decrypts updates contact out_path and fires contact_path_updated.""" + injector = MockPacketInjector() + local_identity = LocalIdentity() + peer_identity = LocalIdentity() + peer_pubkey = peer_identity.get_public_key() + bridge = CompanionBridge(local_identity, injector) + bridge.contacts.add(Contact(public_key=peer_pubkey, name="Peer")) + + path_len_byte = 2 + path_bytes = bytes([0x01, 0x02]) + extra_type = PAYLOAD_TYPE_RESPONSE + extra = bytes([0, 0, 0, 0, 0x00]) + plaintext = bytes([path_len_byte]) + path_bytes + bytes([extra_type]) + extra + peer_id = Identity(peer_pubkey) + shared_secret = peer_id.calc_shared_secret(local_identity.get_private_key()) + aes_key = shared_secret[:16] + encrypted = CryptoUtils.encrypt_then_mac(aes_key, shared_secret, plaintext) + our_hash = local_identity.get_public_key()[0] + 
src_hash = peer_pubkey[0] + payload = bytes([our_hash, src_hash]) + encrypted + + pkt = Packet() + pkt.header = (ROUTE_TYPE_FLOOD << 0) | (PAYLOAD_TYPE_PATH << 2) + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(payload) + pkt.payload_len = len(payload) + + path_updated_calls = [] + + async def on_path_updated(contact): + path_updated_calls.append(contact) + + bridge.on_contact_path_updated(on_path_updated) + await bridge.process_received_packet(pkt) + + assert len(path_updated_calls) == 1 + assert path_updated_calls[0].public_key == peer_pubkey + assert path_updated_calls[0].out_path_len == path_len_byte + assert path_updated_calls[0].out_path == path_bytes + contact = bridge.contacts.get_by_key(peer_pubkey) + assert contact is not None + assert contact.out_path_len == path_len_byte + assert contact.out_path == path_bytes + + async def test_node_discovered_fires_node_discovered_even_when_filtered(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + bridge.prefs.manual_add_contacts = 1 + bridge.prefs.autoadd_config = AUTOADD_CHAT + peer = LocalIdentity() + event_data = { + "public_key": peer.get_public_key().hex(), + "name": "RepeaterNode", + "contact_type": 2, + "lat": 0.0, + "lon": 0.0, + "advert_timestamp": 1000, + "timestamp": 1000, + "snr": 0.0, + "rssi": 0, + } + node_discovered_calls = [] + advert_received_calls = [] + + def on_node(data): + node_discovered_calls.append(data) + + def on_advert(c): + advert_received_calls.append(c) + + bridge.on_node_discovered(on_node) + bridge.on_advert_received(on_advert) + await bridge._handle_mesh_event(MeshEvents.NODE_DISCOVERED, event_data) + assert bridge.contacts.get_count() == 0 + assert len(advert_received_calls) == 0 + assert len(node_discovered_calls) == 1 + assert node_discovered_calls[0]["name"] == "RepeaterNode" + + async def test_node_discovered_event_path_adds_contact_and_fires_advert_received(self): + """Event path with optional inbound_path: 
store updated, advert_received fired once.""" + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + peer = LocalIdentity() + pub_key_hex = peer.get_public_key().hex() + event_data = { + "public_key": pub_key_hex, + "name": "AdvertNode", + "contact_type": ADV_TYPE_CHAT, + "lat": 0.0, + "lon": 0.0, + "advert_timestamp": 1000, + "timestamp": 1000, + "snr": 0.0, + "rssi": 0, + "inbound_path": b"\x01\x02\x03", + } + advert_received_calls = [] + + def on_advert(c): + advert_received_calls.append(c) + + bridge.on_advert_received(on_advert) + await bridge._handle_mesh_event(MeshEvents.NODE_DISCOVERED, event_data) + assert bridge.contacts.get_count() == 1 + assert len(advert_received_calls) == 1 + assert advert_received_calls[0].name == "AdvertNode" + # Second event (same contact): update, still one contact, advert_received again + await bridge._handle_mesh_event(MeshEvents.NODE_DISCOVERED, event_data) + assert bridge.contacts.get_count() == 1 + assert len(advert_received_calls) == 2 + + +# --------------------------------------------------------------------------- +# Deduplication (direct messages by packet_hash) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionBridgeDeduplication: + async def test_direct_message_deduplicated_by_packet_hash(self): + injector = MockPacketInjector() + bridge = CompanionBridge(LocalIdentity(), injector) + key_hex = LocalIdentity().get_public_key().hex() + same_hash = "A1B2C3D4E5F6" + data = { + "contact_pubkey": key_hex, + "message_text": "Hello", + "timestamp": 1000, + "txt_type": 0, + "packet_hash": same_hash, + } + await bridge._handle_mesh_event(MeshEvents.NEW_MESSAGE, data) + await bridge._handle_mesh_event(MeshEvents.NEW_MESSAGE, data) + await bridge._handle_mesh_event(MeshEvents.NEW_MESSAGE, data) + assert bridge.message_queue.count == 1 + msg = bridge.sync_next_message() + assert msg is not None + assert msg.text == "Hello" 
+ assert bridge.sync_next_message() is None diff --git a/tests/test_companion_radio.py b/tests/test_companion_radio.py new file mode 100644 index 0000000..38ef497 --- /dev/null +++ b/tests/test_companion_radio.py @@ -0,0 +1,375 @@ +"""Tests for CompanionRadio (stand-alone companion with radio).""" + +import pytest + +from pymc_core.companion import CompanionRadio +from pymc_core.companion.constants import ADV_TYPE_CHAT +from pymc_core.companion.models import Contact +from pymc_core.protocol import LocalIdentity, Packet +from pymc_core.protocol.constants import PAYLOAD_TYPE_ACK + + +def _make_peer_contact(name: str) -> Contact: + """Return a contact with a valid Ed25519 public key (required for packet encryption).""" + peer = LocalIdentity() + return Contact(public_key=peer.get_public_key(), name=name) + + +class MockRadio: + """Mock radio for CompanionRadio: set_rx_callback, send, optional RSSI/SNR.""" + + def __init__(self): + self.rx_callback = None + self.sent: list[bytes] = [] + + def set_rx_callback(self, callback): + self.rx_callback = callback + + async def send(self, data: bytes) -> bool: + self.sent.append(data) + return True + + def get_last_rssi(self): + return -70 + + def get_last_snr(self): + return 5 + + +# --------------------------------------------------------------------------- +# Init and lifecycle +# --------------------------------------------------------------------------- + + +class TestCompanionRadioInit: + def test_init_creates_stores(self): + radio = MockRadio() + identity = LocalIdentity() + comp = CompanionRadio(radio, identity, node_name="TestNode") + assert comp.contacts is not None + assert comp.contacts.get_count() == 0 + assert comp.channels is not None + assert comp.message_queue is not None + assert comp.path_cache is not None + assert comp.stats is not None + assert comp.prefs.node_name == "TestNode" + assert comp.prefs.adv_type == ADV_TYPE_CHAT + assert comp.get_public_key() == identity.get_public_key() + assert comp.node is not 
None + assert comp.node.dispatcher is not None + + def test_init_passes_contacts_to_node(self): + radio = MockRadio() + identity = LocalIdentity() + comp = CompanionRadio(radio, identity) + comp.contacts.add(Contact(public_key=b"\x01" * 32, name="Alice")) + assert comp.node.contacts is comp.contacts + assert comp.node.contacts.get_by_name("Alice") is not None + + def test_initial_contacts_populates_store_on_boot(self): + radio = MockRadio() + identity = LocalIdentity() + alice = _make_peer_contact("Alice") + bob = _make_peer_contact("Bob") + comp = CompanionRadio(radio, identity, node_name="TestNode", initial_contacts=[alice, bob]) + assert comp.contacts.get_count() == 2 + assert comp.get_contact_by_name("Alice") is not None + assert comp.get_contact_by_name("Bob") is not None + + +@pytest.mark.asyncio +class TestCompanionRadioLifecycle: + async def test_start_stop(self): + radio = MockRadio() + identity = LocalIdentity() + comp = CompanionRadio(radio, identity) + assert comp.is_running is False + await comp.start() + assert comp.is_running is True + await comp.stop() + assert comp.is_running is False + + async def test_start_idempotent_warning(self, caplog): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + await comp.start() + await comp.start() + await comp.stop() + assert "already running" in caplog.text.lower() or True + + async def test_rx_log_data_callback_fired_on_raw_packet(self): + """CompanionRadio fires on_rx_log_data(snr, rssi, raw_bytes) for each RX.""" + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + log_calls = [] + + def on_log(snr: float, rssi: int, raw: bytes) -> None: + log_calls.append((snr, rssi, raw)) + + comp.on_rx_log_data(on_log) + await comp.start() + + # Build minimal valid packet (ACK) so dispatcher parses and notifies raw subscribers + pkt = Packet() + pkt.header = (1 << 6) | (PAYLOAD_TYPE_ACK << 2) + pkt.payload = bytearray(b"\x01\x02\x03\x04") + pkt.payload_len = 4 + pkt.path_len = 0 + raw 
= pkt.write_to() + + await comp.node.dispatcher._process_received_packet(raw, rssi=-75, snr=6.0) + await comp.stop() + + assert len(log_calls) == 1 + snr, rssi, data = log_calls[0] + assert snr == 6.0 + assert rssi == -75 + assert data == raw + + +# --------------------------------------------------------------------------- +# Contact management (base API via radio) +# --------------------------------------------------------------------------- + + +class TestCompanionRadioContacts: + def test_add_and_get_contact(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + key = b"\x02" * 32 + comp.add_update_contact(Contact(public_key=key, name="Bob")) + assert comp.get_contact_by_key(key) is not None + assert comp.get_contact_by_key(key).name == "Bob" + assert comp.get_contact_by_name("Bob") is not None + + def test_import_contact_packet_data(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + # 73 bytes: 32 key + 1 adv_type + 32 name (padded) + 4 lat + 4 lon + name_padded = b"Charlie\x00" * 4 # 32 bytes + packet_data = b"\x03" * 32 + bytes([1]) + name_padded + (0).to_bytes(4, "little") * 2 + assert comp.import_contact(packet_data) is True + contacts = comp.get_contacts() + assert len(contacts) == 1 + assert contacts[0].name.startswith("Charlie") + + def test_export_contact_self(self): + radio = MockRadio() + identity = LocalIdentity() + comp = CompanionRadio(radio, identity, node_name="Me") + data = comp.export_contact(None) + assert data is not None + assert len(data) >= 73 + assert data[:32] == identity.get_public_key() + + +# --------------------------------------------------------------------------- +# Advertise +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionRadioAdvertise: + async def test_advertise_sends_packet(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.advertise(flood=True) + 
assert result is True + assert len(radio.sent) == 1 + assert comp.stats.get_totals()["flood_tx"] == 1 + + async def test_advertise_direct(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + await comp.advertise(flood=False) + assert len(radio.sent) == 1 + assert comp.stats.get_totals()["direct_tx"] == 1 + + +# --------------------------------------------------------------------------- +# Send text (requires contact) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionRadioSendText: + async def test_send_text_message_no_contact(self, caplog): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_text_message(b"\x00" * 32, "Hi") + assert result.success is False + assert "contact not found" in caplog.text.lower() or "Contact not found" in caplog.text + + async def test_send_text_message_with_contact_sends_packet(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + contact = _make_peer_contact("Alice") + comp.contacts.add(contact) + result = await comp.send_text_message(contact.public_key, "Hello") + assert len(radio.sent) >= 1 + # success may be False if no ACK (mock radio doesn't echo ACK) + assert result.success is False or result.success is True + + +# --------------------------------------------------------------------------- +# Share contact, channel message, sync message +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionRadioMisc: + async def test_share_contact_not_found(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.share_contact(b"\x00" * 32) + assert result is False + + async def test_share_contact_success(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + key = b"\x22" * 32 + comp.contacts.add(Contact(public_key=key, name="Bob")) + result = 
await comp.share_contact(key) + assert result is True + assert len(radio.sent) == 1 + + async def test_sync_next_message_empty(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + assert comp.sync_next_message() is None + + async def test_send_channel_message_no_channel(self, caplog): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_channel_message(0, "Hi") + assert result is False + + +# --------------------------------------------------------------------------- +# Path discovery, trace, control data +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionRadioPathAndControl: + async def test_send_path_discovery_no_contact(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_path_discovery(b"\x00" * 32) + assert result is False + + async def test_send_path_discovery_req_sends(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + contact = _make_peer_contact("Target") + comp.contacts.add(contact) + result = await comp.send_path_discovery_req(contact.public_key) + assert result.success is True + assert len(radio.sent) == 1 + + async def test_send_trace_path_raw(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_trace_path_raw(0x12345678, 0xABCD, 0, bytes([0x01, 0x02])) + assert result is True + assert len(radio.sent) == 1 + + async def test_send_control_data_default_discovery(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_control_data() + assert result is True + assert len(radio.sent) == 1 + + async def test_send_control_data_raw_payload(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_control_data(bytes([0x80, 0x04])) + assert result is True + assert len(radio.sent) == 1 + + async def 
test_contact_path_updated_fired_when_handler_callback_invoked(self): + """Radio wires protocol_response_handler contact_path_updated to _fire_callbacks.""" + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + path_updated_calls = [] + + async def on_path_updated(contact): + path_updated_calls.append(contact) + + comp.on_contact_path_updated(on_path_updated) + proto = comp.node.dispatcher.protocol_response_handler + assert proto is not None + assert proto._contact_path_updated_callback is not None + + pub = b"\x22" * 32 + # Contact must exist in the store; path updates for unknown contacts + # are silently dropped (matches companion firmware behaviour). + comp.contacts.add(Contact(public_key=pub, name="test")) + + path_len = 2 + path_bytes = bytes([0x01, 0x02]) + cb_result = proto._contact_path_updated_callback(pub, path_len, path_bytes) + if hasattr(cb_result, "__await__"): + await cb_result + + assert len(path_updated_calls) == 1 + assert path_updated_calls[0].public_key == pub + assert path_updated_calls[0].out_path_len == path_len + assert path_updated_calls[0].out_path == path_bytes + + +# --------------------------------------------------------------------------- +# Stats and config +# --------------------------------------------------------------------------- + + +class TestCompanionRadioStats: + def test_get_stats_core(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + comp.contacts.add(Contact(public_key=b"\x01" * 32, name="A")) + core = comp.get_stats(0) + assert "contacts_count" in core + assert core["contacts_count"] == 1 + assert "queue_len" in core + assert "uptime_secs" in core + + def test_get_stats_packets(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + tot = comp.get_stats(2) + assert "flood_tx" in tot + assert "direct_rx" in tot + assert "tx_errors" in tot + + +# --------------------------------------------------------------------------- +# Binary request and repeater 
command (delegate to node) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCompanionRadioBinaryAndRepeater: + async def test_send_binary_req_no_contact(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + result = await comp.send_binary_req(b"\x00" * 32, bytes([0x01])) + assert result.success is False + + async def test_send_binary_req_with_contact(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + contact = _make_peer_contact("Rpt") + comp.contacts.add(contact) + result = await comp.send_binary_req(contact.public_key, bytes([0x01]), timeout_seconds=5.0) + assert result.success is True + assert result.expected_ack is not None + assert len(radio.sent) == 1 + + async def test_send_repeater_command_no_contact(self): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + out = await comp.send_repeater_command(b"\x00" * 32, "status") + assert out["success"] is False + assert "not found" in out["reason"].lower() diff --git a/tests/test_companion_regions.py b/tests/test_companion_regions.py new file mode 100644 index 0000000..2c935f3 --- /dev/null +++ b/tests/test_companion_regions.py @@ -0,0 +1,245 @@ +"""Tests for companion flood-scope / region support.""" + +from __future__ import annotations + +import pytest + +from pymc_core.companion import CompanionRadio +from pymc_core.protocol import LocalIdentity, Packet, PacketBuilder +from pymc_core.protocol.constants import ( + ROUTE_TYPE_DIRECT, + ROUTE_TYPE_FLOOD, + ROUTE_TYPE_TRANSPORT_FLOOD, +) +from pymc_core.protocol.transport_keys import calc_transport_code, get_auto_key_for + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_flood_packet() -> Packet: + """Create a minimal flood-routed advert packet for testing.""" + identity = LocalIdentity() + 
return PacketBuilder.create_advert( + local_identity=identity, + name="test", + route_type="flood", + ) + + +def _make_direct_packet() -> Packet: + """Create a minimal direct-routed advert packet for testing.""" + identity = LocalIdentity() + return PacketBuilder.create_advert( + local_identity=identity, + name="test", + route_type="direct", + ) + + +class MockRadio: + """Minimal mock radio for CompanionRadio.""" + + def __init__(self): + self.rx_callback = None + self.sent: list[bytes] = [] + + def set_rx_callback(self, callback): + self.rx_callback = callback + + async def send(self, data: bytes) -> bool: + self.sent.append(data) + return True + + +def _make_companion() -> CompanionRadio: + """Create a CompanionRadio with a mock radio for testing.""" + radio = MockRadio() + identity = LocalIdentity() + return CompanionRadio(radio=radio, identity=identity, node_name="test") + + +# --------------------------------------------------------------------------- +# _apply_flood_scope unit tests +# --------------------------------------------------------------------------- + + +class TestApplyFloodScope: + def test_sets_transport_codes_on_flood_packet(self): + companion = _make_companion() + key = get_auto_key_for("#usa") + companion.set_flood_scope(key) + pkt = _make_flood_packet() + + companion._apply_flood_scope(pkt) + + assert pkt.get_route_type() == ROUTE_TYPE_TRANSPORT_FLOOD + assert pkt.transport_codes[0] != 0 + assert pkt.transport_codes[1] == 0 + + def test_transport_code_matches_calc(self): + companion = _make_companion() + key = get_auto_key_for("#test-region") + companion.set_flood_scope(key) + pkt = _make_flood_packet() + + expected_code = calc_transport_code(key, pkt) + companion._apply_flood_scope(pkt) + + assert pkt.transport_codes[0] == expected_code + + def test_noop_when_no_key_set(self): + companion = _make_companion() + pkt = _make_flood_packet() + original_header = pkt.header + + companion._apply_flood_scope(pkt) + + assert pkt.header == 
original_header + assert pkt.get_route_type() == ROUTE_TYPE_FLOOD + assert pkt.transport_codes == [0, 0] + + def test_noop_on_direct_packet(self): + companion = _make_companion() + key = get_auto_key_for("#usa") + companion.set_flood_scope(key) + pkt = _make_direct_packet() + original_header = pkt.header + + companion._apply_flood_scope(pkt) + + assert pkt.header == original_header + assert pkt.get_route_type() == ROUTE_TYPE_DIRECT + assert pkt.transport_codes == [0, 0] + + +# --------------------------------------------------------------------------- +# set_flood_region tests +# --------------------------------------------------------------------------- + + +class TestSetFloodRegion: + def test_derives_key_with_hash_prefix(self): + companion = _make_companion() + companion.set_flood_region("#usa") + assert companion._flood_transport_key == get_auto_key_for("#usa") + + def test_auto_adds_hash_prefix(self): + companion = _make_companion() + companion.set_flood_region("usa") + assert companion._flood_transport_key == get_auto_key_for("#usa") + + def test_clear_with_none(self): + companion = _make_companion() + companion.set_flood_region("usa") + assert companion._flood_transport_key is not None + companion.set_flood_region(None) + assert companion._flood_transport_key is None + + def test_same_key_with_or_without_prefix(self): + c1 = _make_companion() + c2 = _make_companion() + c1.set_flood_region("europe") + c2.set_flood_region("#europe") + assert c1._flood_transport_key == c2._flood_transport_key + + +# --------------------------------------------------------------------------- +# set_flood_scope tests +# --------------------------------------------------------------------------- + + +class TestSetFloodScope: + def test_stores_16_byte_key(self): + companion = _make_companion() + key = b"\x01" * 16 + companion.set_flood_scope(key) + assert companion._flood_transport_key == key + + def test_truncates_longer_key(self): + companion = _make_companion() + key = b"\x02" * 
32 + companion.set_flood_scope(key) + assert companion._flood_transport_key == b"\x02" * 16 + + def test_clear_with_none(self): + companion = _make_companion() + companion.set_flood_scope(b"\x01" * 16) + companion.set_flood_scope(None) + assert companion._flood_transport_key is None + + +# --------------------------------------------------------------------------- +# CompanionRadio dispatcher sync +# --------------------------------------------------------------------------- + + +class TestRadioDispatcherSync: + def test_set_flood_scope_syncs_to_dispatcher(self): + companion = _make_companion() + key = get_auto_key_for("#test") + companion.set_flood_scope(key) + assert companion.node.dispatcher.flood_transport_key == key + + def test_set_flood_region_syncs_to_dispatcher(self): + companion = _make_companion() + companion.set_flood_region("test") + expected = get_auto_key_for("#test") + assert companion.node.dispatcher.flood_transport_key == expected + + def test_clear_syncs_to_dispatcher(self): + companion = _make_companion() + companion.set_flood_scope(b"\x01" * 16) + assert companion.node.dispatcher.flood_transport_key is not None + companion.set_flood_scope(None) + assert companion.node.dispatcher.flood_transport_key is None + + +# --------------------------------------------------------------------------- +# Integration: advertise with flood scope +# --------------------------------------------------------------------------- + + +class TestAdvertiseWithFloodScope: + @pytest.mark.asyncio + async def test_advertise_flood_with_scope_sends_transport_flood(self): + radio = MockRadio() + identity = LocalIdentity() + companion = CompanionRadio(radio=radio, identity=identity, node_name="scoped") + companion.set_flood_region("usa") + + await companion.start() + try: + await companion.advertise(flood=True) + finally: + await companion.stop() + + # Verify the sent packet has transport codes + assert len(radio.sent) > 0 + raw = radio.sent[-1] + pkt = Packet() + 
pkt.read_from(raw) + assert pkt.get_route_type() == ROUTE_TYPE_TRANSPORT_FLOOD + assert pkt.transport_codes[0] != 0 + assert pkt.transport_codes[1] == 0 + + @pytest.mark.asyncio + async def test_advertise_flood_without_scope_sends_normal_flood(self): + radio = MockRadio() + identity = LocalIdentity() + companion = CompanionRadio(radio=radio, identity=identity, node_name="noscope") + # No flood scope set + + await companion.start() + try: + await companion.advertise(flood=True) + finally: + await companion.stop() + + assert len(radio.sent) > 0 + raw = radio.sent[-1] + pkt = Packet() + pkt.read_from(raw) + assert pkt.get_route_type() == ROUTE_TYPE_FLOOD + assert pkt.transport_codes == [0, 0] diff --git a/tests/test_companion_stores.py b/tests/test_companion_stores.py new file mode 100644 index 0000000..c778df1 --- /dev/null +++ b/tests/test_companion_stores.py @@ -0,0 +1,422 @@ +"""Tests for companion stores and models: ContactStore, ChannelStore, MessageQueue, PathCache.""" + +from pymc_core.companion import ChannelStore, ContactStore, MessageQueue, PathCache, StatsCollector +from pymc_core.companion.models import ( + AdvertPath, + Channel, + Contact, + NodePrefs, + QueuedMessage, + SentResult, +) + +# --------------------------------------------------------------------------- +# Models +# --------------------------------------------------------------------------- + + +class TestContact: + def test_contact_defaults(self): + key = b"\x00" * 32 + c = Contact(public_key=key, name="alice") + assert c.public_key == key + assert c.name == "alice" + assert c.adv_type == 0 + assert c.out_path_len == -1 + assert c.out_path == b"" + assert c.gps_lat == 0.0 + assert c.gps_lon == 0.0 + + def test_contact_with_path(self): + c = Contact( + public_key=b"\x01" * 32, + name="bob", + out_path_len=3, + out_path=bytes([0xAA, 0xBB, 0xCC]), + ) + assert c.out_path_len == 3 + assert c.out_path == bytes([0xAA, 0xBB, 0xCC]) + + +class TestSentResult: + def test_sent_result_minimal(self): + 
r = SentResult(success=False) + assert r.success is False + assert r.is_flood is False + assert r.expected_ack is None + assert r.timeout_ms is None + + def test_sent_result_full(self): + r = SentResult(success=True, is_flood=True, expected_ack=0x1234, timeout_ms=5000) + assert r.success is True + assert r.is_flood is True + assert r.expected_ack == 0x1234 + assert r.timeout_ms == 5000 + + +class TestNodePrefs: + def test_node_prefs_defaults(self): + p = NodePrefs() + assert p.node_name == "pyMC" + assert p.adv_type == 1 + assert p.tx_power_dbm == 20 + assert p.frequency_hz == 915000000 + + def test_node_prefs_custom(self): + p = NodePrefs(node_name="TestNode", adv_type=2) + assert p.node_name == "TestNode" + assert p.adv_type == 2 + + +class TestQueuedMessage: + def test_queued_message_direct(self): + key = b"\x02" * 32 + msg = QueuedMessage(sender_key=key, text="Hello", timestamp=1000) + assert msg.sender_key == key + assert msg.text == "Hello" + assert msg.timestamp == 1000 + assert msg.is_channel is False + assert msg.channel_idx == 0 + + def test_queued_message_channel(self): + msg = QueuedMessage( + sender_key=b"", + text="Sender: Hi", + is_channel=True, + channel_idx=2, + ) + assert msg.is_channel is True + assert msg.channel_idx == 2 + + +class TestAdvertPath: + def test_advert_path(self): + prefix = b"\x03" * 7 + ap = AdvertPath( + public_key_prefix=prefix, + name="sensor1", + path_len=2, + path=bytes([1, 2]), + recv_timestamp=12345, + ) + assert ap.public_key_prefix == prefix + assert ap.name == "sensor1" + assert ap.path_len == 2 + assert ap.path == bytes([1, 2]) + assert ap.recv_timestamp == 12345 + + +# --------------------------------------------------------------------------- +# ContactStore +# --------------------------------------------------------------------------- + + +class TestContactStore: + def test_empty_store(self): + store = ContactStore(max_contacts=10) + assert store.get_count() == 0 + assert store.get_all() == [] + assert 
store.get_by_key(b"\x00" * 32) is None + assert store.get_by_name("nobody") is None + assert store.is_full() is False + assert store.contacts == [] + assert store.list_contacts() == [] + + def test_add_and_get_by_key(self): + store = ContactStore(max_contacts=5) + key = b"\x11" * 32 + contact = Contact(public_key=key, name="Alice") + assert store.add(contact) is True + assert store.get_count() == 1 + assert store.get_by_key(key) is contact + assert store.get_by_key(b"\x22" * 32) is None + + def test_add_and_get_by_name(self): + store = ContactStore(max_contacts=5) + key = b"\x11" * 32 + contact = Contact(public_key=key, name="Bob") + store.add(contact) + proxy = store.get_by_name("Bob") + assert proxy is not None + assert proxy.name == "Bob" + assert proxy.public_key == key.hex() + assert store.get_by_name("Charlie") is None + + def test_update_existing(self): + store = ContactStore(max_contacts=5) + key = b"\x11" * 32 + store.add(Contact(public_key=key, name="Alice")) + updated = Contact(public_key=key, name="AliceUpdated", gps_lat=1.0) + assert store.update(updated) is True + c = store.get_by_key(key) + assert c.name == "AliceUpdated" + assert c.gps_lat == 1.0 + + def test_remove(self): + store = ContactStore(max_contacts=5) + key = b"\x11" * 32 + store.add(Contact(public_key=key, name="Alice")) + assert store.remove(key) is True + assert store.get_count() == 0 + assert store.get_by_key(key) is None + assert store.remove(key) is False + + def test_max_contacts(self): + store = ContactStore(max_contacts=2) + store.add(Contact(public_key=b"\x01" * 32, name="A")) + store.add(Contact(public_key=b"\x02" * 32, name="B")) + assert store.add(Contact(public_key=b"\x03" * 32, name="C")) is False + assert store.get_count() == 2 + assert store.is_full() is True + + def test_get_all_since(self): + store = ContactStore(max_contacts=10) + store.add(Contact(public_key=b"\x01" * 32, name="A", lastmod=100)) + store.add(Contact(public_key=b"\x02" * 32, name="B", lastmod=200)) + 
store.add(Contact(public_key=b"\x03" * 32, name="C", lastmod=150)) + all_c = store.get_all() + assert len(all_c) == 3 + since_150 = store.get_all(since=150) + assert len(since_150) == 2 + + def test_clear(self): + store = ContactStore(max_contacts=5) + store.add(Contact(public_key=b"\x01" * 32, name="A")) + store.clear() + assert store.get_count() == 0 + assert store.get_by_name("A") is None + + def test_load_from(self): + store = ContactStore(max_contacts=10) + contacts = [Contact(public_key=bytes([i] * 32), name=f"C{i}") for i in range(3)] + store.load_from(contacts) + assert store.get_count() == 3 + assert store.get_by_name("C1").name == "C1" + + def test_load_from_dicts(self): + store = ContactStore(max_contacts=10) + store.load_from_dicts( + [ + {"public_key": "a1" * 32, "name": "DictAlice"}, + {"public_key": "b2" * 32, "name": "DictBob"}, + ] + ) + assert store.get_count() == 2 + assert store.get_by_name("DictAlice") is not None + assert store.get_by_name("DictBob") is not None + + def test_to_dicts(self): + store = ContactStore(max_contacts=5) + store.add(Contact(public_key=b"\xaa" * 32, name="Export", adv_type=1)) + dicts = store.to_dicts() + assert len(dicts) == 1 + assert dicts[0]["name"] == "Export" + assert dicts[0]["public_key"] == "aa" * 32 + assert dicts[0]["adv_type"] == 1 + + def test_get_by_key_prefix(self): + store = ContactStore(max_contacts=5) + key = b"\x11\x22\x33" + b"\x00" * 29 + store.add(Contact(public_key=key, name="Prefix")) + assert store.get_by_key_prefix(b"\x11\x22") is not None + assert store.get_by_key_prefix(b"\x11\x22\x33").name == "Prefix" + assert store.get_by_key_prefix(b"\xff\xff") is None + + +# --------------------------------------------------------------------------- +# ChannelStore +# --------------------------------------------------------------------------- + + +class TestChannelStore: + def test_empty_channels(self): + store = ChannelStore(max_channels=8) + assert store.get_count() == 0 + assert store.get(0) is None + 
assert store.get_channels() == [] + assert store.find_by_name("any") is None + + def test_set_and_get(self): + store = ChannelStore(max_channels=8) + ch = Channel(name="general", secret=b"\x11" * 16) + assert store.set(0, ch) is True + assert store.get(0) is ch + assert store.get_count() == 1 + assert store.get_channels() == [{"name": "general", "secret": "11" * 16}] + + def test_find_by_name(self): + store = ChannelStore(max_channels=8) + store.set(0, Channel(name="alpha", secret=b"\x00" * 16)) + store.set(1, Channel(name="beta", secret=b"\x01" * 16)) + assert store.find_by_name("alpha") == 0 + assert store.find_by_name("beta") == 1 + assert store.find_by_name("gamma") is None + + def test_remove(self): + store = ChannelStore(max_channels=8) + store.set(0, Channel(name="x", secret=b"\x00" * 16)) + assert store.remove(0) is True + assert store.get(0) is None + assert store.remove(0) is False + assert store.remove(99) is False + + def test_clear(self): + store = ChannelStore(max_channels=8) + store.set(0, Channel(name="a", secret=b"\x00" * 16)) + store.clear() + assert store.get_count() == 0 + assert store.get(0) is None + + def test_out_of_range(self): + store = ChannelStore(max_channels=4) + ch = Channel(name="x", secret=b"\x00" * 16) + assert store.set(-1, ch) is False + assert store.set(4, ch) is False + assert store.get(4) is None + + +# --------------------------------------------------------------------------- +# MessageQueue +# --------------------------------------------------------------------------- + + +class TestMessageQueue: + def test_empty_queue(self): + q = MessageQueue(max_size=5) + assert q.count == 0 + assert q.is_empty() is True + assert q.is_full() is False + assert q.pop() is None + assert q.peek() is None + + def test_push_and_pop(self): + q = MessageQueue(max_size=5) + msg = QueuedMessage(sender_key=b"\x00" * 32, text="Hi") + q.push(msg) + assert q.count == 1 + assert q.peek() is msg + assert q.pop() is msg + assert q.count == 0 + assert 
q.pop() is None + + def test_maxlen_drops_oldest(self): + q = MessageQueue(max_size=2) + q.push(QueuedMessage(sender_key=b"\x01" * 32, text="1")) + q.push(QueuedMessage(sender_key=b"\x02" * 32, text="2")) + q.push(QueuedMessage(sender_key=b"\x03" * 32, text="3")) + assert q.count == 2 + first = q.pop() + assert first.text == "2" + assert q.pop().text == "3" + + def test_clear(self): + q = MessageQueue(max_size=5) + q.push(QueuedMessage(sender_key=b"\x00" * 32, text="x")) + q.clear() + assert q.count == 0 + assert q.pop() is None + + +# --------------------------------------------------------------------------- +# PathCache +# --------------------------------------------------------------------------- + + +class TestPathCache: + def test_empty_cache(self): + cache = PathCache(max_entries=8) + assert cache.get_all() == [] + assert cache.get_by_prefix(b"\x00" * 7) is None + + def test_update_and_get(self): + cache = PathCache(max_entries=8) + ap = AdvertPath( + public_key_prefix=b"\x01" * 7, + name="n1", + path_len=2, + path=bytes([1, 2]), + recv_timestamp=100, + ) + cache.update(ap) + assert len(cache.get_all()) == 1 + found = cache.get_by_prefix(b"\x01" * 5) + assert found is not None + assert found.name == "n1" + assert found.path == bytes([1, 2]) + + def test_update_replaces_same_prefix(self): + cache = PathCache(max_entries=8) + prefix = b"\x02" * 7 + cache.update(AdvertPath(public_key_prefix=prefix, name="v1", path=bytes([1]))) + cache.update(AdvertPath(public_key_prefix=prefix, name="v2", path=bytes([2, 2]))) + assert len(cache.get_all()) == 1 + assert cache.get_by_prefix(prefix).name == "v2" + assert cache.get_by_prefix(prefix).path == bytes([2, 2]) + + def test_eviction_when_full(self): + cache = PathCache(max_entries=2) + cache.update(AdvertPath(public_key_prefix=b"\x01" * 7, name="1", path=b"")) + cache.update(AdvertPath(public_key_prefix=b"\x02" * 7, name="2", path=b"")) + cache.update(AdvertPath(public_key_prefix=b"\x03" * 7, name="3", path=b"")) + assert 
len(cache.get_all()) == 2 + assert cache.get_by_prefix(b"\x01" * 7) is None + assert cache.get_by_prefix(b"\x02" * 7) is not None + assert cache.get_by_prefix(b"\x03" * 7) is not None + + def test_clear(self): + cache = PathCache(max_entries=8) + cache.update(AdvertPath(public_key_prefix=b"\x01" * 7, name="x", path=b"")) + cache.clear() + assert cache.get_all() == [] + assert cache.get_by_prefix(b"\x01" * 7) is None + + +# --------------------------------------------------------------------------- +# StatsCollector +# --------------------------------------------------------------------------- + + +class TestStatsCollector: + def test_initial_state(self): + s = StatsCollector() + assert s.packets.flood_tx == 0 + assert s.packets.direct_rx == 0 + assert s.packets.tx_errors == 0 + assert s.get_uptime_secs() >= 0 + + def test_record_tx_rx(self): + s = StatsCollector() + s.record_tx(is_flood=True) + s.record_tx(is_flood=True) + s.record_tx(is_flood=False) + s.record_rx(is_flood=False) + s.record_rx(is_flood=True) + assert s.packets.flood_tx == 2 + assert s.packets.direct_tx == 1 + assert s.packets.direct_rx == 1 + assert s.packets.flood_rx == 1 + + def test_record_tx_error(self): + s = StatsCollector() + s.record_tx_error() + s.record_tx_error() + assert s.packets.tx_errors == 2 + + def test_get_totals(self): + s = StatsCollector() + s.record_tx(is_flood=True) + s.record_rx(is_flood=False) + tot = s.get_totals() + assert tot["flood_tx"] == 1 + assert tot["direct_rx"] == 1 + assert tot["total_tx"] == 1 + assert tot["total_rx"] == 1 + assert "uptime_secs" in tot + + def test_reset(self): + s = StatsCollector() + s.record_tx(is_flood=True) + s.record_tx_error() + s.reset() + assert s.packets.flood_tx == 0 + assert s.packets.tx_errors == 0 diff --git a/tests/test_crypto.py b/tests/test_crypto.py index 7a8d1fc..6a5a701 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -56,3 +56,69 @@ def test_crypto_utils_key_exchange(): # Both should compute the same shared 
secret assert alice_shared == bob_shared assert len(alice_shared) == 32 + + +def test_ed25519_expand_seed_to_meshcore_64_same_public_key(): + """ed25519_expand_seed_to_meshcore_64 produces 64-byte key that derives to + same public key as PyNaCl. + + MeshCore firmware (and meshcore-keygen) derives the public key by taking + the first 32 bytes of the 64-byte private key and calling + ed25519_derive_pub / crypto_scalarmult_ed25519_base_noclamp. + So our expanded form must yield the same first-32-bytes scalar so the + client gets the same pub key. + """ + from nacl.bindings import crypto_scalarmult_ed25519_base_noclamp + + from pymc_core import LocalIdentity + + # Fixed seed for reproducible test + seed = bytes.fromhex("70" + "65e18fd9fabb70c1ed90dca19907de" "698c88b709ea146eafd93d9b830c7b60") + assert len(seed) == 32 + + # Expand to 64-byte MeshCore format (SHA-512 + clamp) + expanded = CryptoUtils.ed25519_expand_seed_to_meshcore_64(seed) + assert len(expanded) == 64 + + # Derive public key the way MeshCore firmware/keygen does: first 32 bytes = + # scalar + pub_from_expanded = crypto_scalarmult_ed25519_base_noclamp(expanded[:32]) + + # PyNaCl LocalIdentity(seed=seed) derives pub via the same Ed25519 expansion internally + identity = LocalIdentity(seed=seed) + pub_from_pynacl = identity.get_public_key() + + assert pub_from_expanded == pub_from_pynacl, ( + "Expanded seed must yield same public key as PyNaCl so MeshCore " "clients derive correctly" + ) + + +# Sample keypair generated by meshcore-keygen (MeshCore firmware–compatible +# format). Used to verify our derivation matches: +# pub = crypto_scalarmult_ed25519_base_noclamp(private_64[:32]). 
+MESHCORE_SAMPLE_PRIVATE_HEX = ( + "a8cdb06cef221ef0ee4e5e9d9d0829499cd304ca1bba47af2ea17d83b316726b" + "e232b1da0388a2eb142d9a16ed66d4994a7ac40339ec1fbc3937f3b81f5dc62b" +) +MESHCORE_SAMPLE_PUBLIC_HEX = "8dfd59f1102e74845758096baf54d2b3a91413813e7c239baf75d1e9a0f04bf4" + + +def test_ed25519_expand_seed_matches_meshcore_keygen_format(): + """Deriving pub from first 32 bytes of a meshcore-keygen 64-byte private + key matches expected pub. + + Uses a sample keypair generated by meshcore-keygen to verify our + derivation (crypto_scalarmult_ed25519_base_noclamp(priv_64[:32])) + matches the firmware/keygen format. + """ + from nacl.bindings import crypto_scalarmult_ed25519_base_noclamp + + private_bytes = bytes.fromhex(MESHCORE_SAMPLE_PRIVATE_HEX) + expected_public = bytes.fromhex(MESHCORE_SAMPLE_PUBLIC_HEX) + assert len(private_bytes) == 64 + assert len(expected_public) == 32 + + derived_public = crypto_scalarmult_ed25519_base_noclamp(private_bytes[:32]) + assert derived_public == expected_public, ( + "MeshCore format: pub must be derived from first 32 bytes of " "64-byte private key" + ) diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index a805815..5b1f9bd 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -7,6 +7,7 @@ from pymc_core.protocol import Packet from pymc_core.protocol.constants import PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, PAYLOAD_TYPE_TXT_MSG from pymc_core.protocol.packet_filter import PacketFilter +from pymc_core.protocol.packet_utils import PathUtils def create_test_packet(payload_type: int, payload: bytes) -> bytes: @@ -243,7 +244,7 @@ async def test_ack_waiting_and_receipt(self, dispatcher): dispatcher._waiting_acks[crc] = ack_event # Simulate receiving ACK - dispatcher._register_ack_received(crc) + await dispatcher._register_ack_received(crc) # Event should be set assert ack_event.is_set() @@ -447,6 +448,58 @@ def test_callback(packet, data, analysis): assert received_packet is not None assert received_data == 
packet_data + @pytest.mark.asyncio + async def test_full_hop_count_packet_marked_do_not_retransmit(self, dispatcher): + """Packet with path at max hops for its encoding is marked do not retransmit.""" + received_packet = None + + def capture(packet, data, analysis): + nonlocal received_packet + received_packet = packet + + dispatcher.set_raw_packet_callback(capture) + + # 1-byte hashes: max 63 hops + pkt = Packet() + pkt.header = (1 << 6) | (PAYLOAD_TYPE_TXT_MSG << 2) + pkt.path_len = PathUtils.encode_path_len(1, 63) + pkt.path = bytearray(bytes(63)) + pkt.payload = bytearray(b"x") + pkt.payload_len = 1 + packet_data = pkt.write_to() + + await dispatcher._process_received_packet(packet_data) + + assert received_packet is not None + assert received_packet.is_marked_do_not_retransmit() is True + assert received_packet.get_path_hash_count() == 63 + + @pytest.mark.asyncio + async def test_2byte_path_at_max_hops_marked_do_not_retransmit(self, dispatcher): + """Packet with 2-byte path at 32 hops (64 bytes) is marked do not retransmit.""" + received_packet = None + + def capture(packet, data, analysis): + nonlocal received_packet + received_packet = packet + + dispatcher.set_raw_packet_callback(capture) + + pkt = Packet() + pkt.header = (1 << 6) | (PAYLOAD_TYPE_TXT_MSG << 2) + pkt.path_len = PathUtils.encode_path_len(2, 32) + pkt.path = bytearray(64) # 32 * 2 + pkt.payload = bytearray(b"x") + pkt.payload_len = 1 + packet_data = pkt.write_to() + + await dispatcher._process_received_packet(packet_data) + + assert received_packet is not None + assert received_packet.is_marked_do_not_retransmit() is True + assert received_packet.get_path_hash_count() == 32 + assert received_packet.get_path_byte_len() == 64 + @pytest.mark.asyncio async def test_async_callback(self, dispatcher): """Test async callback.""" diff --git a/tests/test_frame_server.py b/tests/test_frame_server.py new file mode 100644 index 0000000..d2284fd --- /dev/null +++ b/tests/test_frame_server.py @@ -0,0 +1,408 
@@ +"""Tests for CompanionFrameServer and advert push frame construction.""" + +import asyncio +import struct +from unittest.mock import AsyncMock, Mock + +import pytest + +from pymc_core.companion.constants import ( + ERR_CODE_TABLE_FULL, + ERR_CODE_UNSUPPORTED_CMD, + MAX_PATH_SIZE, + PUB_KEY_SIZE, + PUSH_CODE_ADVERT, + PUSH_CODE_NEW_ADVERT, +) +from pymc_core.companion.frame_server import CompanionFrameServer, _build_advert_push_frames +from pymc_core.companion.models import Contact, SentResult + + +def test_build_advert_push_frames_short_only_when_no_name(): + """Contact with empty name yields only short frame; full is None.""" + pubkey = bytes(range(32)) + contact = Contact(public_key=pubkey, name="") + short, full = _build_advert_push_frames(contact) + assert full is None + assert len(short) == 1 + PUB_KEY_SIZE + assert short[0] == PUSH_CODE_ADVERT + assert short[1:33] == pubkey + + +def test_build_advert_push_frames_short_and_full_when_has_name(): + """Contact with name yields short frame and full NEW_ADVERT frame.""" + pubkey = bytes(range(32)) + contact = Contact( + public_key=pubkey, + name="Alice", + adv_type=1, + flags=2, + out_path_len=0, + out_path=b"", + last_advert_timestamp=1000, + lastmod=2000, + gps_lat=52.5, + gps_lon=-1.7, + ) + short, full = _build_advert_push_frames(contact) + assert full is not None + # Short frame + assert len(short) == 1 + PUB_KEY_SIZE + assert short[0] == PUSH_CODE_ADVERT + assert short[1:33] == pubkey + # Full frame: code(1) + pubkey(32) + adv_type,flags,opl(3) + path(64) + name(32) + # + last_advert(4) + gps_lat(4) + gps_lon(4) + lastmod(4) + expected_full_len = 1 + 32 + 3 + MAX_PATH_SIZE + 32 + 4 + 4 + 4 + 4 + assert len(full) == expected_full_len + assert full[0] == PUSH_CODE_NEW_ADVERT + assert full[1:33] == pubkey + assert full[33] == 1 # adv_type + assert full[34] == 2 # flags + assert full[35] == 0 # opl_byte (out_path_len 0) + out_path = full[36 : 36 + MAX_PATH_SIZE] + assert out_path == b"\x00" * MAX_PATH_SIZE + 
name_b = full[36 + MAX_PATH_SIZE : 36 + MAX_PATH_SIZE + 32] + assert name_b.startswith(b"Alice") + assert name_b.rstrip(b"\x00") == b"Alice" + offset = 36 + MAX_PATH_SIZE + 32 + assert struct.unpack(" _write_ok.""" + bridge = _MockBridgeSendRawDirect(success=True) + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + data = bytes([1, 0x42]) + b"\x01\x02\x03\x04" + await server._cmd_send_raw_data(data) + assert len(bridge.calls) == 1 + path, payload, path_len_enc = bridge.calls[0] + assert path == b"\x42" + assert payload == b"\x01\x02\x03\x04" + assert path_len_enc == 1 # 1-byte hash, 1 hop + server._write_ok.assert_called_once() + server._write_err.assert_not_called() + + +@pytest.mark.asyncio +async def test_cmd_send_raw_data_invalid_len_writes_unsupported(): + """Invalid CMD_SEND_RAW_DATA len < 6 -> ERR_CODE_UNSUPPORTED_CMD.""" + bridge = _MockBridgeSendRawDirect() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + await server._cmd_send_raw_data(b"\x00\x00\x00") + assert len(bridge.calls) == 0 + server._write_err.assert_called_once_with(ERR_CODE_UNSUPPORTED_CMD) + server._write_ok.assert_not_called() + + +@pytest.mark.asyncio +async def test_cmd_send_raw_data_send_failure_writes_table_full(): + """send_raw_data_direct returns False -> ERR_CODE_TABLE_FULL.""" + bridge = _MockBridgeSendRawDirect(success=False) + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + data = bytes([1, 0x42]) + b"\x01\x02\x03\x04" + await server._cmd_send_raw_data(data) + assert len(bridge.calls) == 1 + server._write_err.assert_called_once_with(ERR_CODE_TABLE_FULL) + server._write_ok.assert_not_called() + + +@pytest.mark.asyncio +async def test_cmd_send_raw_data_2byte_hashes(): + """CMD_SEND_RAW_DATA with 2-byte hash path encoding.""" + from pymc_core.protocol.packet_utils import PathUtils + + bridge = 
_MockBridgeSendRawDirect(success=True) + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + # path_len_encoded=0x42 → 2-byte hashes, 2 hops → 4 bytes of path + path_len_byte = PathUtils.encode_path_len(2, 2) # 0x42 + path_data = b"\x01\x02\x03\x04" + payload_data = b"\xAA\xBB\xCC\xDD" + data = bytes([path_len_byte]) + path_data + payload_data + await server._cmd_send_raw_data(data) + assert len(bridge.calls) == 1 + path, payload, path_len_enc = bridge.calls[0] + assert path == path_data + assert payload == payload_data + assert path_len_enc == path_len_byte + server._write_ok.assert_called_once() + + +@pytest.mark.asyncio +async def test_cmd_send_raw_data_invalid_path_encoding(): + """CMD_SEND_RAW_DATA with reserved hash_size=4 encoding → error.""" + bridge = _MockBridgeSendRawDirect() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + # 0xC1 = hash_size 4 (reserved), should fail validation + data = bytes([0xC1]) + b"\x00" * 10 + await server._cmd_send_raw_data(data) + assert len(bridge.calls) == 0 + server._write_err.assert_called_once_with(ERR_CODE_UNSUPPORTED_CMD) + + +@pytest.mark.asyncio +async def test_cmd_send_raw_data_truncated_multibyte_path(): + """CMD_SEND_RAW_DATA with not enough path bytes for 2-byte encoding → error.""" + from pymc_core.protocol.packet_utils import PathUtils + + bridge = _MockBridgeSendRawDirect() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + # 0x43 = 2-byte hashes, 3 hops → needs 6 path bytes + 4 payload = 11 total + # But only provide 8 bytes after path_len (not enough) + path_len_byte = PathUtils.encode_path_len(2, 3) # 0x43 + data = bytes([path_len_byte]) + b"\x00" * 8 # only 8 bytes, need 6+4=10 + await server._cmd_send_raw_data(data) + assert len(bridge.calls) == 0 + server._write_err.assert_called_once_with(ERR_CODE_UNSUPPORTED_CMD) + + 
+@pytest.mark.asyncio +async def test_push_trace_data_enqueues_frame(): + """push_trace_data enqueues a correctly formatted trace frame.""" + bridge = _MockBridgeSendRawDirect() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_queue = asyncio.Queue(maxsize=256) + + await server.push_trace_data( + path_len=1, + flags=0, + tag=1, + auth_code=0, + path_hashes=b"\x00", + path_snrs=b"\x00", + final_snr_byte=0, + ) + assert not server._write_queue.empty() + frame = server._write_queue.get_nowait() + # Frame format: FRAME_OUTBOUND_PREFIX + 2-byte LE length + payload + assert frame[0] == 0x3E # FRAME_OUTBOUND_PREFIX + _ = struct.unpack(" None: + self.calls.append(mode) + + +@pytest.mark.asyncio +async def test_cmd_set_path_hash_mode_valid(): + """Valid CMD_SET_PATH_HASH_MODE for each mode (0, 1, 2) → _write_ok.""" + for mode in (0, 1, 2): + bridge = _MockBridgePathHashMode() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + await server._cmd_set_path_hash_mode(bytes([0, mode])) + assert bridge.calls == [mode] + server._write_ok.assert_called_once() + server._write_err.assert_not_called() + + +@pytest.mark.asyncio +async def test_cmd_set_path_hash_mode_invalid_mode(): + """CMD_SET_PATH_HASH_MODE with mode >= 3 → ERR_CODE_ILLEGAL_ARG.""" + from pymc_core.companion.constants import ERR_CODE_ILLEGAL_ARG + + bridge = _MockBridgePathHashMode() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + await server._cmd_set_path_hash_mode(bytes([0, 3])) + assert len(bridge.calls) == 0 + server._write_err.assert_called_once_with(ERR_CODE_ILLEGAL_ARG) + + +@pytest.mark.asyncio +async def test_cmd_set_path_hash_mode_wrong_subtype(): + """CMD_SET_PATH_HASH_MODE with subtype != 0 → ERR_CODE_ILLEGAL_ARG.""" + from pymc_core.companion.constants import ERR_CODE_ILLEGAL_ARG + + bridge = _MockBridgePathHashMode() + server = 
CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + await server._cmd_set_path_hash_mode(bytes([1, 0])) + assert len(bridge.calls) == 0 + server._write_err.assert_called_once_with(ERR_CODE_ILLEGAL_ARG) + + +@pytest.mark.asyncio +async def test_cmd_set_path_hash_mode_too_short(): + """CMD_SET_PATH_HASH_MODE with only 1 byte → ERR_CODE_ILLEGAL_ARG.""" + from pymc_core.companion.constants import ERR_CODE_ILLEGAL_ARG + + bridge = _MockBridgePathHashMode() + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + await server._cmd_set_path_hash_mode(bytes([0])) + assert len(bridge.calls) == 0 + server._write_err.assert_called_once_with(ERR_CODE_ILLEGAL_ARG) + + +@pytest.mark.asyncio +async def test_device_info_includes_path_hash_mode(): + """RESP_CODE_DEVICE_INFO frame includes path_hash_mode at byte [81].""" + from pymc_core.companion.constants import RESP_CODE_DEVICE_INFO + from pymc_core.companion.models import NodePrefs + + prefs = NodePrefs() + prefs.path_hash_mode = 2 # 3-byte hashes + + bridge = Mock() + bridge.get_self_info = Mock(return_value=prefs) + bridge.contacts = Mock(max_contacts=100) + bridge.channels = Mock(max_channels=8) + + server = CompanionFrameServer(bridge, "hash", port=0) + frames = [] + server._write_frame = lambda f: frames.append(f) + + await server._cmd_device_query(bytes([10])) # app_ver = 10 + + assert len(frames) == 1 + frame = frames[0] + assert frame[0] == RESP_CODE_DEVICE_INFO + assert len(frame) == 82 # 81 bytes (old) + 1 byte path_hash_mode + assert frame[81] == 2 # path_hash_mode at last byte diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 835aaf3..2f82414 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -14,7 +14,7 @@ TextMessageHandler, TraceHandler, ) -from pymc_core.protocol import LocalIdentity, Packet, PacketBuilder +from pymc_core.protocol import CryptoUtils, Identity, 
LocalIdentity, Packet, PacketBuilder from pymc_core.protocol.constants import ( PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, @@ -259,9 +259,50 @@ def setup_method(self): self.send_packet_fn = AsyncMock() self.event_service = MockEventService() self.handler = GroupTextHandler( - self.local_identity, self.contacts, self.log_fn, self.send_packet_fn + self.local_identity, + self.contacts, + self.log_fn, + self.send_packet_fn, + channel_db=None, + event_service=self.event_service, + our_node_name="InitialName", ) - # GroupTextHandler doesn't take event_service in constructor + + def test_set_our_node_name_updates_stored_name(self): + """set_our_node_name updates the name used for echo detection.""" + assert self.handler.our_node_name == "InitialName" + self.handler.set_our_node_name("NewName") + assert self.handler.our_node_name == "NewName" + self.handler.set_our_node_name(None) + assert self.handler.our_node_name is None + + def test_is_own_message_uses_current_name_after_set_our_node_name(self): + """_is_own_message uses the current our_node_name after it is updated.""" + self.handler.set_our_node_name("Howl 🏝️") + packet = Packet() + packet.decrypted = {"group_text_data": {"sender_name": "Howl 🏝️"}} + assert self.handler._is_own_message(packet) is True + packet.decrypted = {"group_text_data": {"sender_name": "Howl 🧱"}} + assert self.handler._is_own_message(packet) is False + # After updating name, old name no longer matches + self.handler.set_our_node_name("Howl 🧱") + assert self.handler._is_own_message(packet) is True + + def test_is_own_message_false_when_sender_name_missing(self): + """_is_own_message returns False when packet has no sender_name in group_text_data.""" + self.handler.set_our_node_name("Me") + packet = Packet() + packet.decrypted = {} + assert self.handler._is_own_message(packet) is False + packet.decrypted = {"group_text_data": {}} + assert self.handler._is_own_message(packet) is False + + def test_is_own_message_false_when_no_match(self): + 
"""_is_own_message returns False when sender name differs from our_node_name.""" + self.handler.set_our_node_name("Me") + packet = Packet() + packet.decrypted = {"group_text_data": {"sender_name": "Other"}} + assert self.handler._is_own_message(packet) is False def test_payload_type(self): """Test group text handler payload type.""" @@ -273,7 +314,7 @@ def test_group_text_handler_initialization(self): assert self.handler.contacts == self.contacts assert self.handler.log == self.log_fn assert self.handler.send_packet == self.send_packet_fn - # GroupTextHandler doesn't store event_service + assert self.handler.our_node_name == "InitialName" # Login Response Handler Tests @@ -316,8 +357,147 @@ def test_protocol_response_handler_initialization(self): assert self.handler._log == self.log_fn assert self.handler._local_identity == self.local_identity + def test_parse_telemetry_response_tag_plus_lpp(self): + """Parse tag(4) + CayenneLPP matches repeater firmware format; raw_bytes is LPP only.""" + # Repeater sends: tag(4) + LPP. Tag is 4-byte reflected_timestamp (little-endian). + # MeshCore first record: addVoltage(TELEM_CHANNEL_SELF=1, v) + # → channel=1, type=0x74 (LPP_VOLTAGE), 2 bytes 0.01V big-endian. 
3.7V → 370 → 0x01 0x72 + tag = b"\x01\x00\x00\x00" # LE 1 + lpp = bytes([0x01, 0x74, 0x01, 0x72]) # ch 1, Voltage, 370 (3.70 V) + data = tag + lpp + result = self.handler._parse_telemetry_response(data) + assert result is not None + assert result["type"] == "telemetry" + assert result["reflected_timestamp"] == 1 + assert result["raw_bytes"] == lpp + assert result["sensor_count"] == 1 + sensor = result["sensors"][0] + assert sensor["channel"] == 1 + assert sensor["type"] == "Voltage" + assert sensor["type_id"] == 0x74 + assert abs(sensor["value"] - 3.7) < 0.001 + + def test_parse_telemetry_response_rejects_non_telemetry(self): + """Payload without channel=1, type=0x74 signature is not classified as telemetry.""" + tag = b"\x00\x00\x00\x00" # LE 0 + # Not starting with 0x01 0x74 + data = tag + bytes([0x01, 0x67, 0x00, 0x00]) # ch 1, Temperature, 0°C + result = self.handler._parse_telemetry_response(data) + assert result is None + + def test_set_contact_path_updated_callback(self): + """set_contact_path_updated_callback stores the callback.""" + cb = MagicMock() + self.handler.set_contact_path_updated_callback(cb) + assert self.handler._contact_path_updated_callback is cb + self.handler.set_contact_path_updated_callback(None) + assert self.handler._contact_path_updated_callback is None + + @pytest.mark.asyncio + async def test_contact_path_updated_callback_invoked_on_path_update(self): + """PATH decrypts and updates contact path; contact_path_updated callback is invoked.""" + from pymc_core.companion.contact_store import ContactStore + from pymc_core.companion.models import Contact + + local_identity = LocalIdentity() + peer_identity = LocalIdentity() + peer_pubkey = peer_identity.get_public_key() + contacts = ContactStore(5) + contacts.add(Contact(public_key=peer_pubkey, name="Peer")) + log_fn = MagicMock() + handler = ProtocolResponseHandler(log_fn, local_identity, contacts) + handler.set_binary_response_callback(lambda *a, **k: None) + + path_len_byte = 2 + 
path_bytes = bytes([0x01, 0x02]) + extra_type = PAYLOAD_TYPE_RESPONSE + extra = bytes([0, 0, 0, 0, 0x00]) # tag(4) + 1 byte (not login response) + plaintext = bytes([path_len_byte]) + path_bytes + bytes([extra_type]) + extra + + peer_id = Identity(peer_pubkey) + shared_secret = peer_id.calc_shared_secret(local_identity.get_private_key()) + aes_key = shared_secret[:16] + encrypted = CryptoUtils.encrypt_then_mac(aes_key, shared_secret, plaintext) + + our_hash = local_identity.get_public_key()[0] + src_hash = peer_pubkey[0] + payload = bytes([our_hash, src_hash]) + encrypted + + pkt = Packet() + pkt.header = (0 << 0) | (PAYLOAD_TYPE_PATH << 2) + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(payload) + pkt.payload_len = len(payload) + + callback_calls = [] + + async def on_path_updated(pub: bytes, path_len: int, path_bytes_arg: bytes) -> None: + callback_calls.append((pub, path_len, path_bytes_arg)) + + handler.set_contact_path_updated_callback(on_path_updated) + + await handler(pkt) + + assert len(callback_calls) == 1 + assert callback_calls[0][0] == peer_pubkey + assert callback_calls[0][1] == path_len_byte + assert callback_calls[0][2] == path_bytes + + @pytest.mark.asyncio + async def test_contact_path_updated_with_2byte_hashes(self): + """PATH with 2-byte hashes decrypts and updates contact path correctly.""" + from pymc_core.companion.contact_store import ContactStore + from pymc_core.companion.models import Contact + from pymc_core.protocol.packet_utils import PathUtils + + local_identity = LocalIdentity() + peer_identity = LocalIdentity() + peer_pubkey = peer_identity.get_public_key() + contacts = ContactStore(5) + contacts.add(Contact(public_key=peer_pubkey, name="Peer")) + log_fn = MagicMock() + handler = ProtocolResponseHandler(log_fn, local_identity, contacts) + handler.set_binary_response_callback(lambda *a, **k: None) + + # 2 hops × 2-byte hashes = 4 bytes of path data + path_len_byte = PathUtils.encode_path_len(2, 2) # 0x42 + 
path_bytes = bytes([0x01, 0x02, 0x03, 0x04]) + extra_type = PAYLOAD_TYPE_RESPONSE + extra = bytes([0, 0, 0, 0, 0x00]) + plaintext = bytes([path_len_byte]) + path_bytes + bytes([extra_type]) + extra + + peer_id = Identity(peer_pubkey) + shared_secret = peer_id.calc_shared_secret(local_identity.get_private_key()) + aes_key = shared_secret[:16] + encrypted = CryptoUtils.encrypt_then_mac(aes_key, shared_secret, plaintext) + + our_hash = local_identity.get_public_key()[0] + src_hash = peer_pubkey[0] + payload = bytes([our_hash, src_hash]) + encrypted + + pkt = Packet() + pkt.header = (0 << 0) | (PAYLOAD_TYPE_PATH << 2) + pkt.path_len = 0 + pkt.path = bytearray() + pkt.payload = bytearray(payload) + pkt.payload_len = len(payload) + + callback_calls = [] + + async def on_path_updated(pub: bytes, path_len: int, path_bytes_arg: bytes) -> None: + callback_calls.append((pub, path_len, path_bytes_arg)) + + handler.set_contact_path_updated_callback(on_path_updated) + + await handler(pkt) + + assert len(callback_calls) == 1 + assert callback_calls[0][0] == peer_pubkey + assert callback_calls[0][1] == path_len_byte # encoded byte, not raw count + assert callback_calls[0][2] == path_bytes # all 4 bytes of path data + -# Trace Handler Tests class TestTraceHandler: def setup_method(self): self.log_fn = MagicMock() diff --git a/tests/test_kiss_modem_wrapper.py b/tests/test_kiss_modem_wrapper.py index 77c8e72..a7397a6 100644 --- a/tests/test_kiss_modem_wrapper.py +++ b/tests/test_kiss_modem_wrapper.py @@ -13,23 +13,14 @@ from pymc_core.hardware.kiss_modem_wrapper import ( CMD_DATA, - CMD_ENCRYPT_DATA, - CMD_GET_AIRTIME, CMD_GET_BATTERY, - CMD_GET_IDENTITY, - CMD_GET_NOISE_FLOOR, CMD_GET_RADIO, - CMD_GET_RANDOM, CMD_GET_STATS, - CMD_GET_TX_POWER, CMD_GET_VERSION, - CMD_HASH, - CMD_KEY_EXCHANGE, CMD_PING, CMD_SET_RADIO, CMD_SET_TX_POWER, CMD_SIGN_DATA, - CMD_VERIFY_SIGNATURE, HW_CMD_GET_DEVICE_NAME, HW_CMD_GET_MCU_TEMP, HW_CMD_GET_SIGNAL_REPORT, @@ -39,6 +30,7 @@ HW_RESP_DEVICE_NAME, 
HW_RESP_MCU_TEMP, HW_RESP_OK, + HW_RESP_RX_META, HW_RESP_SIGNAL_REPORT, KISS_CMD_FULLDUPLEX, KISS_CMD_PERSISTENCE, @@ -49,25 +41,15 @@ KISS_FESC, KISS_TFEND, KISS_TFESC, - RESP_AIRTIME, RESP_BATTERY, - RESP_ENCRYPTED, RESP_ERROR, - RESP_HASH, RESP_IDENTITY, - RESP_NOISE_FLOOR, RESP_OK, RESP_PONG, RESP_RADIO, - RESP_RANDOM, - RESP_SHARED_SECRET, RESP_SIGNATURE, RESP_STATS, RESP_TX_DONE, - RESP_TX_POWER, - RESP_VERIFY, - RESP_VERSION, - HW_RESP_RX_META, KissModemWrapper, ) @@ -141,7 +123,9 @@ def test_decode_simple_frame(self): # Data frame: FEND + 0x00 + raw_packet + FEND (no in-frame metadata) data_frame = bytes([KISS_FEND, CMD_DATA, 0x01, 0x02, 0x03, KISS_FEND]) # RxMeta: FEND + 0x06 + 0xF9 + SNR + RSSI + FEND (sent immediately after Data) - rx_meta_frame = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_RX_META, 0x10, 0xB0, KISS_FEND]) + rx_meta_frame = bytes( + [KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_RX_META, 0x10, 0xB0, KISS_FEND] + ) for byte in data_frame: modem._decode_kiss_byte(byte) @@ -160,10 +144,10 @@ def test_decode_frame_with_escapes(self): modem.on_frame_received = lambda data: received_frames.append(data) # Data frame: payload is escaped 0xC0 (FESC + TFEND) - data_frame = bytes( - [KISS_FEND, CMD_DATA, KISS_FESC, KISS_TFEND, KISS_FEND] + data_frame = bytes([KISS_FEND, CMD_DATA, KISS_FESC, KISS_TFEND, KISS_FEND]) + rx_meta_frame = bytes( + [KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_RX_META, 0x10, 0xB0, KISS_FEND] ) - rx_meta_frame = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_RX_META, 0x10, 0xB0, KISS_FEND]) for byte in data_frame: modem._decode_kiss_byte(byte) @@ -180,7 +164,9 @@ def test_decode_extracts_rssi_snr(self): data_frame = bytes([KISS_FEND, CMD_DATA, 0xAA, 0xBB, KISS_FEND]) # RxMeta: SNR=0x10 (4.0 dB), RSSI=0xB0 (-80) - rx_meta_frame = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_RX_META, 0x10, 0xB0, KISS_FEND]) + rx_meta_frame = bytes( + [KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_RX_META, 0x10, 0xB0, KISS_FEND] + ) for byte in 
data_frame: modem._decode_kiss_byte(byte) @@ -196,6 +182,7 @@ def test_rx_callback_receives_per_packet_rssi_snr(self): modem.is_connected = True received = [] + def capture(data, rssi, snr): received.append((data, rssi, snr)) @@ -274,7 +261,7 @@ def test_send_command_encodes_correctly(self): assert written_frame[0] == KISS_FEND assert written_frame[1] == KISS_CMD_SETHARDWARE # type SetHardware - assert written_frame[2] == HW_CMD_GET_VERSION # sub_cmd GetVersion + assert written_frame[2] == HW_CMD_GET_VERSION # sub_cmd GetVersion assert written_frame[-1] == KISS_FEND def test_response_parsing_identity(self): @@ -334,7 +321,7 @@ class TestRadioConfiguration: def test_radio_config_struct_format(self): """Test that radio config is packed correctly""" - modem = KissModemWrapper(port="/dev/null", auto_configure=False) + KissModemWrapper(port="/dev/null", auto_configure=False) freq_hz = 869618000 bw_hz = 62500 @@ -1007,7 +994,8 @@ def test_context_manager_calls_connect_disconnect(self): with patch.object(KissModemWrapper, "connect", return_value=True) as mock_connect: with patch.object(KissModemWrapper, "disconnect") as mock_disconnect: with KissModemWrapper(port="/dev/null", auto_configure=False) as modem: - pass + pass # keep reference so __del__ doesn't run before assert mock_connect.assert_called_once() mock_disconnect.assert_called_once() + _ = modem # hold ref so __del__ runs after assert, not before diff --git a/tests/test_modem_identity.py b/tests/test_modem_identity.py index f0a28dd..414672f 100644 --- a/tests/test_modem_identity.py +++ b/tests/test_modem_identity.py @@ -12,7 +12,6 @@ from pymc_core.protocol.modem_identity import ModemIdentity - # Generate a valid Ed25519 keypair for testing _TEST_SIGNING_KEY = SigningKey.generate() _TEST_PUBKEY = bytes(_TEST_SIGNING_KEY.verify_key) diff --git a/tests/test_packet.py b/tests/test_packet.py index 02e1328..c0d849d 100644 --- a/tests/test_packet.py +++ b/tests/test_packet.py @@ -1,4 +1,7 @@ +import pytest + from 
pymc_core.protocol import Packet +from pymc_core.protocol.packet_utils import PathUtils # Packet tests @@ -28,3 +31,290 @@ def test_packet_validation(): # Should validate successfully packet._validate_lengths() + + +# --- Multi-byte path support --- + + +class TestPacketSetPath: + """Tests for Packet.set_path() and path accessor methods.""" + + def test_set_path_default_1byte_hashes(self): + """set_path without encoded path_len assumes 1-byte hashes.""" + pkt = Packet() + pkt.header = 0x02 # ROUTE_TYPE_DIRECT + pkt.payload = bytearray(b"data") + pkt.payload_len = 4 + + pkt.set_path(b"\xAA\xBB\xCC") + + assert pkt.path == bytearray(b"\xAA\xBB\xCC") + assert pkt.path_len == 3 # encode_path_len(1, 3) == 3 + assert pkt.get_path_hash_size() == 1 + assert pkt.get_path_hash_count() == 3 + assert pkt.get_path_byte_len() == 3 + + def test_set_path_with_2byte_encoded(self): + """set_path with 2-byte hash encoded path_len.""" + pkt = Packet() + pkt.header = 0x02 + pkt.payload = bytearray(b"data") + pkt.payload_len = 4 + + # 2 hops × 2-byte hashes = 4 bytes of path data + encoded = PathUtils.encode_path_len(2, 2) # 0x42 + pkt.set_path(b"\x01\x02\x03\x04", path_len_encoded=encoded) + + assert pkt.path_len == 0x42 + assert pkt.get_path_hash_size() == 2 + assert pkt.get_path_hash_count() == 2 + assert pkt.get_path_byte_len() == 4 + assert pkt.path == bytearray(b"\x01\x02\x03\x04") + + def test_set_path_with_3byte_encoded(self): + """set_path with 3-byte hash encoded path_len.""" + pkt = Packet() + pkt.header = 0x02 + pkt.payload = bytearray(b"data") + pkt.payload_len = 4 + + # 3 hops × 3-byte hashes = 9 bytes of path data + encoded = PathUtils.encode_path_len(3, 3) # 0x83 + path_data = bytes(range(9)) + pkt.set_path(path_data, path_len_encoded=encoded) + + assert pkt.path_len == 0x83 + assert pkt.get_path_hash_size() == 3 + assert pkt.get_path_hash_count() == 3 + assert pkt.get_path_byte_len() == 9 + assert pkt.path == bytearray(path_data) + + def test_set_path_empty(self): + 
"""set_path with empty path.""" + pkt = Packet() + pkt.header = 0x02 + pkt.payload = bytearray(b"data") + pkt.payload_len = 4 + + pkt.set_path(b"") + + assert pkt.path_len == 0 + assert pkt.get_path_hash_size() == 1 + assert pkt.get_path_hash_count() == 0 + assert pkt.get_path_byte_len() == 0 + + def test_set_path_64_bytes_raises_without_encoded(self): + """64-byte path without path_len_encoded would encode as 0 hops (64 & 0x3F).""" + pkt = Packet() + pkt.header = 0x02 + pkt.payload = bytearray(b"data") + pkt.payload_len = 4 + + with pytest.raises(ValueError, match="path length 64 exceeds maximum encodable"): + pkt.set_path(bytes(64)) + + # With explicit path_len_encoded (63 hops) path must be 63 bytes + encoded_63 = PathUtils.encode_path_len(1, 63) + pkt.set_path(bytes(63), path_len_encoded=encoded_63) + assert pkt.path_len == 0x3F + assert pkt.get_path_byte_len() == 63 + + +class TestGetPathHashes: + """Tests for Packet.get_path_hashes() and get_path_hashes_hex().""" + + def test_1byte_hashes(self): + pkt = Packet() + pkt.set_path(b"\xAA\xBB\xCC") + assert pkt.get_path_hashes() == [b"\xAA", b"\xBB", b"\xCC"] + + def test_2byte_hashes(self): + pkt = Packet() + encoded = PathUtils.encode_path_len(2, 2) + pkt.set_path(b"\xAA\xBB\xCC\xDD", path_len_encoded=encoded) + assert pkt.get_path_hashes() == [b"\xAA\xBB", b"\xCC\xDD"] + + def test_3byte_hashes(self): + pkt = Packet() + encoded = PathUtils.encode_path_len(3, 2) + pkt.set_path(b"\xAA\xBB\xCC\xDD\xEE\xFF", path_len_encoded=encoded) + assert pkt.get_path_hashes() == [b"\xAA\xBB\xCC", b"\xDD\xEE\xFF"] + + def test_hex_2byte(self): + pkt = Packet() + encoded = PathUtils.encode_path_len(2, 2) + pkt.set_path(b"\xAA\xBB\xCC\xDD", path_len_encoded=encoded) + assert pkt.get_path_hashes_hex() == ["AABB", "CCDD"] + + def test_hex_1byte_backward_compat(self): + pkt = Packet() + pkt.set_path(b"\xB5\xA3\xF2") + assert pkt.get_path_hashes_hex() == ["B5", "A3", "F2"] + + def test_empty_path(self): + pkt = Packet() + 
pkt.path_len = 0 + pkt.path = bytearray() + assert pkt.get_path_hashes() == [] + assert pkt.get_path_hashes_hex() == [] + + def test_zero_hops_with_hash_mode(self): + """0 hops but hash_size=2 (originated with path_hash_mode=1).""" + pkt = Packet() + pkt.path_len = PathUtils.encode_path_len(2, 0) # 0x40 + pkt.path = bytearray() + assert pkt.get_path_hashes() == [] + assert pkt.get_path_hashes_hex() == [] + + +class TestPacketRoundTrip: + """Tests for write_to → read_from round-trip with multi-byte paths.""" + + def _make_packet(self, header, path_data, path_len_encoded, payload): + """Helper to build a packet with given path encoding.""" + pkt = Packet() + pkt.header = header + pkt.set_path(path_data, path_len_encoded) + pkt.payload = bytearray(payload) + pkt.payload_len = len(payload) + return pkt + + def test_roundtrip_no_path(self): + """Round-trip with empty path.""" + pkt = self._make_packet(0x05, b"", 0, b"hello") + raw = pkt.write_to() + pkt2 = Packet() + pkt2.read_from(raw) + + assert pkt2.header == 0x05 + assert pkt2.path_len == 0 + assert pkt2.get_path_byte_len() == 0 + assert pkt2.path == bytearray() + assert pkt2.get_payload() == b"hello" + + def test_roundtrip_1byte_hashes(self): + """Round-trip with 1-byte hashes (backward compatible).""" + path = b"\xAA\xBB\xCC" + pkt = self._make_packet(0x06, path, PathUtils.encode_path_len(1, 3), b"payload") + raw = pkt.write_to() + pkt2 = Packet() + pkt2.read_from(raw) + + assert pkt2.get_path_hash_size() == 1 + assert pkt2.get_path_hash_count() == 3 + assert pkt2.get_path_byte_len() == 3 + assert pkt2.path == bytearray(path) + assert pkt2.get_payload() == b"payload" + + def test_roundtrip_2byte_hashes(self): + """Round-trip with 2-byte hashes.""" + # 4 hops × 2 bytes = 8 bytes path + path = bytes(range(8)) + encoded = PathUtils.encode_path_len(2, 4) + pkt = self._make_packet(0x06, path, encoded, b"test") + raw = pkt.write_to() + pkt2 = Packet() + pkt2.read_from(raw) + + assert pkt2.path_len == encoded + assert 
pkt2.get_path_hash_size() == 2 + assert pkt2.get_path_hash_count() == 4 + assert pkt2.get_path_byte_len() == 8 + assert pkt2.path == bytearray(path) + assert pkt2.get_payload() == b"test" + + def test_roundtrip_3byte_hashes(self): + """Round-trip with 3-byte hashes.""" + # 3 hops × 3 bytes = 9 bytes path + path = bytes([0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99]) + encoded = PathUtils.encode_path_len(3, 3) + pkt = self._make_packet(0x06, path, encoded, b"msg") + raw = pkt.write_to() + pkt2 = Packet() + pkt2.read_from(raw) + + assert pkt2.path_len == encoded + assert pkt2.get_path_hash_size() == 3 + assert pkt2.get_path_hash_count() == 3 + assert pkt2.get_path_byte_len() == 9 + assert pkt2.path == bytearray(path) + assert pkt2.get_payload() == b"msg" + + def test_roundtrip_max_2byte_path(self): + """Round-trip with maximum valid 2-byte hash path (32 hops × 2 = 64 bytes).""" + path = bytes(range(64)) + encoded = PathUtils.encode_path_len(2, 32) + pkt = self._make_packet(0x06, path, encoded, b"x") + raw = pkt.write_to() + pkt2 = Packet() + pkt2.read_from(raw) + + assert pkt2.get_path_hash_size() == 2 + assert pkt2.get_path_hash_count() == 32 + assert pkt2.get_path_byte_len() == 64 + assert pkt2.path == bytearray(path) + + def test_roundtrip_max_3byte_path(self): + """Round-trip with maximum valid 3-byte hash path (21 hops × 3 = 63 bytes).""" + path = bytes(range(63)) + encoded = PathUtils.encode_path_len(3, 21) + pkt = self._make_packet(0x06, path, encoded, b"x") + raw = pkt.write_to() + pkt2 = Packet() + pkt2.read_from(raw) + + assert pkt2.get_path_hash_size() == 3 + assert pkt2.get_path_hash_count() == 21 + assert pkt2.get_path_byte_len() == 63 + assert pkt2.path == bytearray(path) + + def test_roundtrip_1byte_backward_compat(self): + """Backward compatibility: 1-byte hash packets identical to legacy format.""" + # Build the packet the OLD way (direct assignment) + pkt_old = Packet() + pkt_old.header = 0x06 + pkt_old.path = bytearray(b"\x01\x02\x03") + 
pkt_old.path_len = 3 + pkt_old.payload = bytearray(b"test") + pkt_old.payload_len = 4 + raw_old = pkt_old.write_to() + + # Build the packet the NEW way (set_path) + pkt_new = Packet() + pkt_new.header = 0x06 + pkt_new.set_path(b"\x01\x02\x03") + pkt_new.payload = bytearray(b"test") + pkt_new.payload_len = 4 + raw_new = pkt_new.write_to() + + # Wire formats must be identical + assert raw_old == raw_new + + def test_read_from_invalid_path_len_raises(self): + """read_from rejects reserved hash_size=4 encoding.""" + # Hand-craft a raw packet with path_len 0xC1 (hash_size=4, invalid) + raw = bytes([0x06, 0xC1]) # header + invalid path_len + pkt = Packet() + with pytest.raises(ValueError, match="invalid path_len encoding"): + pkt.read_from(raw) + + def test_read_from_truncated_path_raises(self): + """read_from rejects packet with insufficient path bytes.""" + # 2 hops × 2-byte hashes = 4 bytes expected, but only provide 2 + encoded = PathUtils.encode_path_len(2, 2) # needs 4 bytes of path + raw = bytes([0x06, encoded, 0xAA, 0xBB]) # header + path_len + only 2 path bytes + pkt = Packet() + with pytest.raises(ValueError, match="truncated path"): + pkt.read_from(raw) + + def test_get_raw_length_multibyte(self): + """get_raw_length accounts for multi-byte path encoding.""" + pkt = Packet() + pkt.header = 0x06 + # 5 hops × 2-byte hashes = 10 bytes path + pkt.set_path(bytes(10), PathUtils.encode_path_len(2, 5)) + pkt.payload = bytearray(b"test") + pkt.payload_len = 4 + + # raw_length = header(1) + path_len_byte(1) + path(10) + payload(4) = 16 + assert pkt.get_raw_length() == 16 diff --git a/tests/test_packet_builder.py b/tests/test_packet_builder.py index 58f453f..fc60a0b 100644 --- a/tests/test_packet_builder.py +++ b/tests/test_packet_builder.py @@ -1,5 +1,10 @@ from pymc_core import LocalIdentity -from pymc_core.protocol.constants import PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT +from pymc_core.protocol.constants import ( + MAX_PACKET_PAYLOAD, + PAYLOAD_TYPE_ACK, + 
PAYLOAD_TYPE_ADVERT, + PAYLOAD_TYPE_RAW_CUSTOM, +) from pymc_core.protocol.packet_builder import PacketBuilder @@ -51,3 +56,24 @@ def test_packet_builder_create_direct_advert(): assert direct_advert is not None assert direct_advert.get_payload_type() == PAYLOAD_TYPE_ADVERT + + +def test_packet_builder_create_raw_data(): + """Test creating raw custom packets (PAYLOAD_TYPE_RAW_CUSTOM).""" + data = b"\x01\x02\x03\x04" + pkt = PacketBuilder.create_raw_data(data) + assert pkt is not None + assert pkt.get_payload_type() == PAYLOAD_TYPE_RAW_CUSTOM + assert pkt.payload == bytearray(data) + assert pkt.payload_len == len(data) + assert pkt.path_len == 0 + assert pkt.path == bytearray() + + +def test_packet_builder_create_raw_data_too_large_raises(): + """Test create_raw_data raises when data exceeds MAX_PACKET_PAYLOAD.""" + import pytest + + data = bytes(MAX_PACKET_PAYLOAD + 1) + with pytest.raises(ValueError, match="exceeds MAX_PACKET_PAYLOAD"): + PacketBuilder.create_raw_data(data) diff --git a/tests/test_packet_path_and_hash.py b/tests/test_packet_path_and_hash.py index b5894cd..c17907a 100644 --- a/tests/test_packet_path_and_hash.py +++ b/tests/test_packet_path_and_hash.py @@ -21,12 +21,10 @@ MAX_PACKET_PAYLOAD, MAX_PATH_SIZE, MAX_SUPPORTED_PAYLOAD_VERSION, - PATH_HASH_SIZE, PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, - PAYLOAD_TYPE_TXT_MSG, PAYLOAD_TYPE_TRACE, - PH_ROUTE_MASK, + PAYLOAD_TYPE_TXT_MSG, PH_TYPE_SHIFT, PH_VER_SHIFT, ROUTE_TYPE_DIRECT, @@ -34,17 +32,13 @@ ROUTE_TYPE_TRANSPORT_DIRECT, ROUTE_TYPE_TRANSPORT_FLOOD, ) -from pymc_core.protocol.packet_utils import ( - PacketHashingUtils, - PacketHeaderUtils, - PacketValidationUtils, -) - +from pymc_core.protocol.packet_utils import PacketHashingUtils, PacketValidationUtils # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _make_header(payload_type: int, route_type: int, version: int = 0) -> int: 
"""Build a header byte from components (mirrors PacketHeaderUtils.make_header).""" return route_type | (payload_type << PH_TYPE_SHIFT) | (version << 6) @@ -94,6 +88,7 @@ def _build_packet( # 1. Basic hash correctness — matches reference C++ implementation # =================================================================== + class TestPacketHashBasic: """Verify hash output matches the C++ reference for various packet types.""" @@ -177,6 +172,7 @@ def test_trace_path_len_uint16_not_uint8(self): # 2. Hash excludes header, path, route type, transport codes # =================================================================== + class TestPacketHashExclusions: """Verify that fields NOT in the C++ hash don't affect the Python hash.""" @@ -193,7 +189,9 @@ def test_hash_independent_of_path_content(self): payload = b"path_test_data" pkt_no_path = _build_packet(PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD, payload) - pkt_with_path = _build_packet(PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD, payload, path=b"\xAA\xBB\xCC") + pkt_with_path = _build_packet( + PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD, payload, path=b"\xAA\xBB\xCC" + ) assert pkt_no_path.calculate_packet_hash() == pkt_with_path.calculate_packet_hash() @@ -239,6 +237,7 @@ def test_different_payload_type_gives_different_hash(self): # 3. payload_len truncation — hash must use only payload[:payload_len] # =================================================================== + class TestPayloadLenTruncation: """Verify hash uses payload[:payload_len], not the full bytearray buffer.""" @@ -307,6 +306,7 @@ def test_trace_hash_uses_payload_len_truncation_too(self): # 4. Flood forwarding: path append must not change duplicate detection # =================================================================== + class TestFloodPathAppend: """Simulate flood forwarding and verify hash stability.""" @@ -363,6 +363,7 @@ def test_max_path_flood(self): # 5. 
Direct forwarding: path consume must not change duplicate detection # =================================================================== + class TestDirectPathConsume: """Simulate direct forwarding path consumption and verify hash stability.""" @@ -422,6 +423,7 @@ def test_trace_hash_changes_after_path_consume(self): # 6. Serialization round-trip preserves hash # =================================================================== + class TestSerializationHashPreservation: """Verify that write_to → read_from round-trip preserves the packet hash.""" @@ -505,6 +507,7 @@ def test_roundtrip_after_flood_append_preserves_hash(self): # 7. PacketHashingUtils standalone tests # =================================================================== + class TestPacketHashingUtilsStandalone: """Test the static utility directly, confirming C++ compatibility.""" @@ -559,11 +562,12 @@ def test_hash_string_truncation(self): # 8. Edge cases and regression guards # =================================================================== -class TestEdgeCases: +class TestEdgeCases: def test_max_payload_hashes_correctly(self): """MAX_PACKET_PAYLOAD-sized payload should hash without error.""" from pymc_core.protocol.constants import MAX_PACKET_PAYLOAD + payload = bytes(range(256)) * (MAX_PACKET_PAYLOAD // 256 + 1) payload = payload[:MAX_PACKET_PAYLOAD] @@ -619,6 +623,7 @@ def test_header_0xff_sentinel_changes_payload_type(self): # 9. 
Bad / malformed packets — must be rejected by validation # =================================================================== + class TestBadPacketDeserialization: """Verify that read_from rejects malformed wire data.""" @@ -639,27 +644,34 @@ def test_unsupported_version_rejected(self): """Version > MAX_SUPPORTED_PAYLOAD_VERSION should be rejected.""" bad_version = MAX_SUPPORTED_PAYLOAD_VERSION + 1 # Craft header with unsupported version in bits 6-7 - header = ROUTE_TYPE_FLOOD | (PAYLOAD_TYPE_TXT_MSG << PH_TYPE_SHIFT) | (bad_version << PH_VER_SHIFT) + header = ( + ROUTE_TYPE_FLOOD + | (PAYLOAD_TYPE_TXT_MSG << PH_TYPE_SHIFT) + | (bad_version << PH_VER_SHIFT) + ) wire = bytes([header, 0]) # header + path_len=0, no payload pkt = Packet() with pytest.raises(ValueError, match="Unsupported packet version"): pkt.read_from(wire) def test_path_len_exceeds_max(self): - """path_len > MAX_PATH_SIZE (64) must be rejected.""" + """Encoded path_len that decodes to > MAX_PATH_SIZE (64) bytes must be rejected.""" + from pymc_core.protocol.packet_utils import PathUtils + header = _make_header(PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD) - bad_path_len = MAX_PATH_SIZE + 1 # 65 - wire = bytes([header, bad_path_len]) + bytes(bad_path_len) + b"payload" + # 33 hops × 2 bytes = 66 path bytes (exceeds MAX_PATH_SIZE) + bad_path_len = PathUtils.encode_path_len(hash_size=2, hash_count=33) + wire = bytes([header, bad_path_len]) + bytes(66) + b"payload" pkt = Packet() with pytest.raises(ValueError, match="path_len too large"): pkt.read_from(wire) def test_path_len_255_rejected(self): - """path_len=255 (max uint8) must be rejected.""" + """path_len=255 (reserved hash_size 4) must be rejected.""" header = _make_header(PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD) wire = bytes([header, 0xFF]) + bytes(255) pkt = Packet() - with pytest.raises(ValueError, match="path_len too large"): + with pytest.raises(ValueError, match="invalid path_len encoding"): pkt.read_from(wire) def test_truncated_path_rejected(self): 
@@ -696,13 +708,18 @@ def test_oversized_payload_rejected(self): pkt.read_from(wire) def test_path_len_exact_max_accepted(self): - """path_len == MAX_PATH_SIZE should be accepted (boundary test).""" + """Encoded path_len that decodes to exactly MAX_PATH_SIZE (64) bytes should be accepted.""" + from pymc_core.protocol.packet_utils import PathUtils + header = _make_header(PAYLOAD_TYPE_TXT_MSG, ROUTE_TYPE_FLOOD) + # 32 hops × 2 bytes = 64 path bytes (exactly MAX_PATH_SIZE) + path_len_byte = PathUtils.encode_path_len(hash_size=2, hash_count=32) path_data = bytes(MAX_PATH_SIZE) - wire = bytes([header, MAX_PATH_SIZE]) + path_data + b"ok" + wire = bytes([header, path_len_byte]) + path_data + b"ok" pkt = Packet() pkt.read_from(wire) - assert pkt.path_len == MAX_PATH_SIZE + assert pkt.path_len == path_len_byte + assert len(pkt.path) == MAX_PATH_SIZE assert pkt.payload == bytearray(b"ok") @@ -772,6 +789,7 @@ def test_random_garbage_rejected_or_benign(self): or raise ValueError — never crash with an unhandled exception. 
""" import random + rng = random.Random(42) # deterministic seed for _ in range(100): length = rng.randint(0, 300) diff --git a/tests/test_packet_utils.py b/tests/test_packet_utils.py index c5bb384..2eff143 100644 --- a/tests/test_packet_utils.py +++ b/tests/test_packet_utils.py @@ -7,6 +7,7 @@ PacketDataUtils, PacketHashingUtils, PacketValidationUtils, + PathUtils, ) @@ -149,7 +150,9 @@ class TestPacketHashingUtils: def test_hash_string_returns_full_uppercase_hex(self): payload_type = 0x05 path_len = 0 - payload = bytes.fromhex("D9BA8E4EA9444822AC56B4D52AC3C0044C6AE402997BB9805CCB331EC3378DCE339F2D") + payload = bytes.fromhex( + "D9BA8E4EA9444822AC56B4D52AC3C0044C6AE402997BB9805CCB331EC3378DCE339F2D" + ) expected_hex = "887B9BE6056D0B0517AF3A04AC2478EDFC2AB731936DEA525041500E7ADE74D3" @@ -166,7 +169,9 @@ def test_hash_string_returns_full_uppercase_hex(self): def test_hash_string_truncates_to_requested_length(self): payload_type = 0x05 path_len = 1 - payload = bytes.fromhex("D9BA8E4EA9444822AC56B4D52AC3C0044C6AE402997BB9805CCB331EC3378DCE339F2D") + payload = bytes.fromhex( + "D9BA8E4EA9444822AC56B4D52AC3C0044C6AE402997BB9805CCB331EC3378DCE339F2D" + ) expected_hex = "887B9BE6056D0B05" @@ -180,3 +185,155 @@ def test_hash_string_truncates_to_requested_length(self): assert truncated == expected_hex[:16] assert len(truncated) == 16 assert truncated.isupper() + + +class TestPathUtils: + """Tests for multi-byte path encoding/decoding utilities.""" + + # --- get_path_hash_size --- + + def test_hash_size_1byte(self): + """Bits 6-7 == 0b00 → hash_size = 1.""" + assert PathUtils.get_path_hash_size(0x00) == 1 + assert PathUtils.get_path_hash_size(0x05) == 1 + assert PathUtils.get_path_hash_size(0x3F) == 1 # max hop count, 1-byte hashes + + def test_hash_size_2byte(self): + """Bits 6-7 == 0b01 → hash_size = 2.""" + assert PathUtils.get_path_hash_size(0x40) == 2 + assert PathUtils.get_path_hash_size(0x45) == 2 + assert PathUtils.get_path_hash_size(0x7F) == 2 + + def 
test_hash_size_3byte(self): + """Bits 6-7 == 0b10 → hash_size = 3.""" + assert PathUtils.get_path_hash_size(0x80) == 3 + assert PathUtils.get_path_hash_size(0x8A) == 3 + assert PathUtils.get_path_hash_size(0xBF) == 3 + + def test_hash_size_reserved(self): + """Bits 6-7 == 0b11 → hash_size = 4 (reserved, invalid).""" + assert PathUtils.get_path_hash_size(0xC0) == 4 + + # --- get_path_hash_count --- + + def test_hash_count_extracts_lower_6_bits(self): + """Hash count is the lower 6 bits (0-63).""" + assert PathUtils.get_path_hash_count(0x00) == 0 + assert PathUtils.get_path_hash_count(0x05) == 5 + assert PathUtils.get_path_hash_count(0x3F) == 63 + # Upper bits should be masked off + assert PathUtils.get_path_hash_count(0x45) == 5 # 0b01_000101 + assert PathUtils.get_path_hash_count(0x8A) == 10 # 0b10_001010 + assert PathUtils.get_path_hash_count(0xC0) == 0 # 0b11_000000 + + # --- get_path_byte_len --- + + def test_byte_len_1byte_hashes(self): + """1-byte hashes: byte_len = hop_count * 1.""" + assert PathUtils.get_path_byte_len(0x00) == 0 + assert PathUtils.get_path_byte_len(0x01) == 1 + assert PathUtils.get_path_byte_len(0x05) == 5 + assert PathUtils.get_path_byte_len(0x3F) == 63 + + def test_byte_len_2byte_hashes(self): + """2-byte hashes: byte_len = hop_count * 2.""" + assert PathUtils.get_path_byte_len(0x40) == 0 # 0 hops × 2 + assert PathUtils.get_path_byte_len(0x41) == 2 # 1 hop × 2 + assert PathUtils.get_path_byte_len(0x45) == 10 # 5 hops × 2 + assert PathUtils.get_path_byte_len(0x60) == 64 # 32 hops × 2 = 64 + + def test_byte_len_3byte_hashes(self): + """3-byte hashes: byte_len = hop_count * 3.""" + assert PathUtils.get_path_byte_len(0x80) == 0 # 0 hops × 3 + assert PathUtils.get_path_byte_len(0x81) == 3 # 1 hop × 3 + assert PathUtils.get_path_byte_len(0x8A) == 30 # 10 hops × 3 + assert PathUtils.get_path_byte_len(0x95) == 63 # 21 hops × 3 = 63 + + # --- encode_path_len --- + + def test_encode_1byte_hashes(self): + """1-byte hashes: encoded = (0 << 6) | count 
= count.""" + assert PathUtils.encode_path_len(1, 0) == 0x00 + assert PathUtils.encode_path_len(1, 1) == 0x01 + assert PathUtils.encode_path_len(1, 5) == 0x05 + assert PathUtils.encode_path_len(1, 63) == 0x3F + + def test_encode_2byte_hashes(self): + """2-byte hashes: encoded = (1 << 6) | count.""" + assert PathUtils.encode_path_len(2, 0) == 0x40 + assert PathUtils.encode_path_len(2, 1) == 0x41 + assert PathUtils.encode_path_len(2, 5) == 0x45 + assert PathUtils.encode_path_len(2, 32) == 0x60 + + def test_encode_3byte_hashes(self): + """3-byte hashes: encoded = (2 << 6) | count.""" + assert PathUtils.encode_path_len(3, 0) == 0x80 + assert PathUtils.encode_path_len(3, 1) == 0x81 + assert PathUtils.encode_path_len(3, 10) == 0x8A + assert PathUtils.encode_path_len(3, 21) == 0x95 + + def test_encode_decode_roundtrip(self): + """encode → get_path_hash_size + get_path_hash_count roundtrip.""" + for hash_size in (1, 2, 3): + for count in (0, 1, 10, 63): + encoded = PathUtils.encode_path_len(hash_size, count) + assert PathUtils.get_path_hash_size(encoded) == hash_size + assert PathUtils.get_path_hash_count(encoded) == count + assert PathUtils.get_path_byte_len(encoded) == hash_size * count + + # --- is_valid_path_len --- + + def test_valid_path_len_1byte(self): + """1-byte hashes: valid up to MAX_PATH_SIZE hops.""" + assert PathUtils.is_valid_path_len(0x00) is True # 0 hops + assert PathUtils.is_valid_path_len(0x01) is True # 1 hop + assert PathUtils.is_valid_path_len(0x3F) is True # 63 hops, 63 bytes ≤ MAX_PATH_SIZE(64) + + def test_valid_path_len_2byte(self): + """2-byte hashes: valid when hop_count * 2 ≤ MAX_PATH_SIZE.""" + assert PathUtils.is_valid_path_len(0x40) is True # 0 hops + assert PathUtils.is_valid_path_len(0x41) is True # 1 hop × 2 = 2 + assert PathUtils.is_valid_path_len(0x60) is True # 32 hops × 2 = 64 ≤ MAX_PATH_SIZE + assert PathUtils.is_valid_path_len(0x61) is False # 33 hops × 2 = 66 > MAX_PATH_SIZE + + def test_valid_path_len_3byte(self): + """3-byte 
hashes: valid when hop_count * 3 ≤ MAX_PATH_SIZE.""" + assert PathUtils.is_valid_path_len(0x80) is True # 0 hops + assert PathUtils.is_valid_path_len(0x95) is True # 21 hops × 3 = 63 ≤ MAX_PATH_SIZE + assert PathUtils.is_valid_path_len(0x96) is False # 22 hops × 3 = 66 > MAX_PATH_SIZE + + def test_valid_path_len_reserved_hash_size(self): + """Hash size 4 (bits 6-7 == 0b11) is reserved and always invalid.""" + assert PathUtils.is_valid_path_len(0xC0) is False # hash_size=4, 0 hops + assert PathUtils.is_valid_path_len(0xC1) is False # hash_size=4, 1 hop + assert PathUtils.is_valid_path_len(0xFF) is False # hash_size=4, 63 hops + + # --- Backward compatibility --- + + def test_1byte_backward_compatible(self): + """For 1-byte hashes, encoded path_len == raw hop count == byte count.""" + for n in range(0, 64): + encoded = PathUtils.encode_path_len(1, n) + assert encoded == n + assert PathUtils.get_path_byte_len(encoded) == n + + def test_encode_path_len_rejects_hop_count_64(self): + """Hop count is 6 bits (0-63); 64 would mask to 0 and produce invalid packet.""" + with pytest.raises(ValueError, match="hop count must be 0-63"): + PathUtils.encode_path_len(1, 64) + with pytest.raises(ValueError, match="hop count must be 0-63"): + PathUtils.encode_path_len(1, 100) + + def test_is_path_at_max_hops(self): + """is_path_at_max_hops is True when path bytes/hops are at limit for hash size.""" + # No path + assert PathUtils.is_path_at_max_hops(0) is False + # 1-byte hashes: max 63 hops + assert PathUtils.is_path_at_max_hops(PathUtils.encode_path_len(1, 62)) is False + assert PathUtils.is_path_at_max_hops(PathUtils.encode_path_len(1, 63)) is True + # 2-byte hashes: max 32 hops (64 bytes) + assert PathUtils.is_path_at_max_hops(PathUtils.encode_path_len(2, 31)) is False + assert PathUtils.is_path_at_max_hops(PathUtils.encode_path_len(2, 32)) is True + # 3-byte hashes: max 21 hops (63 bytes) + assert PathUtils.is_path_at_max_hops(PathUtils.encode_path_len(3, 20)) is False + assert 
PathUtils.is_path_at_max_hops(PathUtils.encode_path_len(3, 21)) is True