From 7b14d30440bc3cf1d8b80fbaae4c33c35557791f Mon Sep 17 00:00:00 2001 From: Nate Karstens Date: Sat, 6 Jan 2024 17:55:19 -0600 Subject: [PATCH] Add support for USB connections Adds a new transport to manage USB connections. Signed-off-by: Nate Karstens --- pybricksdev/ble/pybricks.py | 6 + pybricksdev/connections/pybricks.py | 172 +++++++++++++++++++++++++++- 2 files changed, 175 insertions(+), 3 deletions(-) diff --git a/pybricksdev/ble/pybricks.py b/pybricksdev/ble/pybricks.py index cb058b9..d33bdf4 100644 --- a/pybricksdev/ble/pybricks.py +++ b/pybricksdev/ble/pybricks.py @@ -328,6 +328,12 @@ def _standard_uuid(short: int) -> str: .. availability:: Since Pybricks protocol v1.0.0. """ +DEVICE_NAME_UUID = _standard_uuid(0x2A00) +"""Standard Device Name UUID + +.. availability:: Since Pybricks protocol v1.0.0. +""" + FW_REV_UUID = _standard_uuid(0x2A26) """Standard Firmware Revision String characteristic UUID diff --git a/pybricksdev/connections/pybricks.py b/pybricksdev/connections/pybricks.py index e481812..e1e4013 100644 --- a/pybricksdev/connections/pybricks.py +++ b/pybricksdev/connections/pybricks.py @@ -8,6 +8,7 @@ import os import struct from typing import Awaitable, Callable, List, Optional, Tuple, TypeVar +from uuid import UUID import reactivex.operators as op import semver @@ -19,9 +20,15 @@ from tqdm.auto import tqdm from tqdm.contrib.logging import logging_redirect_tqdm +from usb.control import get_descriptor +from usb.core import Device as USBDevice +from usb.core import Endpoint, USBTimeoutError +from usb.util import ENDPOINT_IN, ENDPOINT_OUT, endpoint_direction, find_descriptor + from ..ble.lwp3.bytecodes import HubKind from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID from ..ble.pybricks import ( + DEVICE_NAME_UUID, FW_REV_UUID, PNP_ID_UUID, PYBRICKS_COMMAND_EVENT_UUID, @@ -38,6 +45,7 @@ from ..compile import compile_file, compile_multi_file from ..tools import chunk from ..tools.checksum import xor_bytes +from ..usb import LegoUsbPid from . import ConnectionState logger = logging.getLogger(__name__) @@ -138,6 +146,158 @@ def handler(_, data): await self._client.start_notify(NUS_TX_UUID, handler) +class _USBTransport(_Transport): + _device: USBDevice + _disconnected_callback: Callable + _ep_in: Endpoint + _ep_out: Endpoint + _notify_callbacks = {} + _monitor_task: asyncio.Task + _response_event = asyncio.Event() + _response: int + + _USB_PYBRICKS_MSG_COMMAND = b"\x00" + _USB_PYBRICKS_MSG_COMMAND_RESPONSE = b"\x01" + _USB_PYBRICKS_MSG_EVENT = b"\x02" + + def __init__(self, device: USBDevice): + self._device = device + self._notify_callbacks[ + self._USB_PYBRICKS_MSG_COMMAND_RESPONSE[0] + ] = self._response_handler + + async def connect(self, disconnected_callback: Callable) -> None: + self._disconnected_callback = disconnected_callback + self._device.set_configuration() + + # Save input and output endpoints + cfg = self._device.get_active_configuration() + intf = cfg[(0, 0)] + self._ep_in = find_descriptor( + intf, + custom_match=lambda e: endpoint_direction(e.bEndpointAddress) + == ENDPOINT_IN, + ) + self._ep_out = find_descriptor( + intf, + custom_match=lambda e: endpoint_direction(e.bEndpointAddress) + == ENDPOINT_OUT, + ) + + # Get length of BOS descriptor + bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0) + (ofst, bos_len) = struct.unpack(" None: + # FIXME: Need to make sure this is called when the USB cable is unplugged + self._monitor_task.cancel() + self._disconnected_callback() + + async def get_firmware_version(self) -> Version: + return self._fw_version + + async def get_protocol_version(self) -> Version: + return self._protocol_version + + async def get_hub_type(self) -> Tuple[HubKind, int]: + hub_types = { + LegoUsbPid.SPIKE_PRIME: (HubKind.TECHNIC_LARGE, 0), + LegoUsbPid.ROBOT_INVENTOR: (HubKind.TECHNIC_LARGE, 1), + LegoUsbPid.SPIKE_ESSENTIAL: (HubKind.TECHNIC_SMALL, 0), + } + + return hub_types[self._device.idProduct] + + async def get_hub_capabilities(self) -> Tuple[int, HubCapabilityFlag, int]: + return ( + self._max_write_size, + self._capability_flags, + self._max_user_program_size, + ) + + async def send_command(self, command: bytes) -> None: + self._response = None + self._response_event.clear() + self._ep_out.write(self._USB_PYBRICKS_MSG_COMMAND + command) + try: + await asyncio.wait_for(self._response_event.wait(), 1) + if self._response != 0: + print(f"Received error response for command: {self._response}") + except asyncio.TimeoutError: + print("Timed out waiting for a response") + + async def set_service_handler(self, callback: Callable) -> None: + self._notify_callbacks[self._USB_PYBRICKS_MSG_EVENT[0]] = callback + + async def _monitor_usb(self): + loop = asyncio.get_running_loop() + + while True: + msg = await loop.run_in_executor(None, self._read_usb) + + if msg is None or len(msg) == 0: + continue + + if msg[0] in self._notify_callbacks: + callback = self._notify_callbacks[msg[0]] + callback(bytes(msg[1:])) + + def _read_usb(self): + try: + msg = self._ep_in.read(self._ep_in.wMaxPacketSize) + return msg + except USBTimeoutError: + return None + + def _response_handler(self, data: bytes) -> None: + (self._response,) = struct.unpack(" None: if self._enable_line_handler: self._handle_line_data(payload) - async def connect(self, device: BLEDevice): + async def connect(self, device): """Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device` + or :meth:`usb.core.find` Args: - device: The device to connect to. + device: The device to connect to (`BLEDevice` or `USBDevice`). Raises: BleakError: if connecting failed (or old firmware without Device @@ -350,7 +511,12 @@ async def connect(self, device: BLEDevice): self.connection_state_observable.on_next, ConnectionState.DISCONNECTED ) - self._transport = _BLETransport(device) + if isinstance(device, BLEDevice): + self._transport = _BLETransport(device) + elif isinstance(device, USBDevice): + self._transport = _USBTransport(device) + else: + raise TypeError("Unsupported device type") def handle_disconnect(): logger.info("Disconnected!")