Skip to content

Commit

Permalink
Add support for USB connections
Browse files Browse the repository at this point in the history
Adds a new transport to manage USB connections.

Signed-off-by: Nate Karstens <nate.karstens@gmail.com>
  • Loading branch information
nkarstens committed Feb 18, 2024
1 parent e6b6b9a commit 7b14d30
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 3 deletions.
6 changes: 6 additions & 0 deletions pybricksdev/ble/pybricks.py
Expand Up @@ -328,6 +328,12 @@ def _standard_uuid(short: int) -> str:
.. availability:: Since Pybricks protocol v1.0.0.
"""

DEVICE_NAME_UUID = _standard_uuid(0x2A00)
"""Standard Device Name UUID
.. availability:: Since Pybricks protocol v1.0.0.
"""

FW_REV_UUID = _standard_uuid(0x2A26)
"""Standard Firmware Revision String characteristic UUID
Expand Down
172 changes: 169 additions & 3 deletions pybricksdev/connections/pybricks.py
Expand Up @@ -8,6 +8,7 @@
import os
import struct
from typing import Awaitable, Callable, List, Optional, Tuple, TypeVar
from uuid import UUID

import reactivex.operators as op
import semver
Expand All @@ -19,9 +20,15 @@
from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

from usb.control import get_descriptor
from usb.core import Device as USBDevice
from usb.core import Endpoint, USBTimeoutError
from usb.util import ENDPOINT_IN, ENDPOINT_OUT, endpoint_direction, find_descriptor

from ..ble.lwp3.bytecodes import HubKind
from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID
from ..ble.pybricks import (
DEVICE_NAME_UUID,
FW_REV_UUID,
PNP_ID_UUID,
PYBRICKS_COMMAND_EVENT_UUID,
Expand All @@ -38,6 +45,7 @@
from ..compile import compile_file, compile_multi_file
from ..tools import chunk
from ..tools.checksum import xor_bytes
from ..usb import LegoUsbPid
from . import ConnectionState

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -138,6 +146,158 @@ def handler(_, data):
await self._client.start_notify(NUS_TX_UUID, handler)


class _USBTransport(_Transport):
_device: USBDevice
_disconnected_callback: Callable
_ep_in: Endpoint
_ep_out: Endpoint
_notify_callbacks = {}
_monitor_task: asyncio.Task
_response_event = asyncio.Event()
_response: int

_USB_PYBRICKS_MSG_COMMAND = b"\x00"
_USB_PYBRICKS_MSG_COMMAND_RESPONSE = b"\x01"
_USB_PYBRICKS_MSG_EVENT = b"\x02"

def __init__(self, device: USBDevice):
self._device = device
self._notify_callbacks[
self._USB_PYBRICKS_MSG_COMMAND_RESPONSE[0]
] = self._response_handler

async def connect(self, disconnected_callback: Callable) -> None:
self._disconnected_callback = disconnected_callback
self._device.set_configuration()

# Save input and output endpoints
cfg = self._device.get_active_configuration()
intf = cfg[(0, 0)]
self._ep_in = find_descriptor(
intf,
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
== ENDPOINT_IN,
)
self._ep_out = find_descriptor(
intf,
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
== ENDPOINT_OUT,
)

# Get length of BOS descriptor
bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0)
(ofst, bos_len) = struct.unpack("<BxHx", bos_descriptor)

# Get full BOS descriptor
bos_descriptor = get_descriptor(self._device, bos_len, 0x0F, 0)

while ofst < bos_len:
(len, desc_type, cap_type) = struct.unpack_from(
"<BBB", bos_descriptor, offset=ofst
)

if desc_type != 0x10:
raise Exception("Expected Device Capability descriptor")

# Look for platform descriptors
if cap_type == 0x05:
uuid_bytes = bos_descriptor[ofst + 4 : ofst + 4 + 16]
uuid_str = str(UUID(bytes_le=bytes(uuid_bytes)))

if uuid_str == DEVICE_NAME_UUID:
self._device_name = bytes(
bos_descriptor[ofst + 20 : ofst + len]
).decode()
print("Connected to hub '" + self._device_name + "'")

elif uuid_str == FW_REV_UUID:
fw_version = bytes(bos_descriptor[ofst + 20 : ofst + len])
self._fw_version = Version(fw_version.decode())

elif uuid_str == SW_REV_UUID:
protocol_version = bytes(bos_descriptor[ofst + 20 : ofst + len])
self._protocol_version = semver.VersionInfo.parse(
protocol_version.decode()
)

elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID:
caps = bytes(bos_descriptor[ofst + 20 : ofst + len])
(
self._max_write_size,
self._capability_flags,
self._max_user_program_size,
) = unpack_hub_capabilities(caps)

ofst += len

self._monitor_task = asyncio.create_task(self._monitor_usb())

async def disconnect(self) -> None:
# FIXME: Need to make sure this is called when the USB cable is unplugged
self._monitor_task.cancel()
self._disconnected_callback()

async def get_firmware_version(self) -> Version:
return self._fw_version

async def get_protocol_version(self) -> Version:
return self._protocol_version

async def get_hub_type(self) -> Tuple[HubKind, int]:
hub_types = {
LegoUsbPid.SPIKE_PRIME: (HubKind.TECHNIC_LARGE, 0),
LegoUsbPid.ROBOT_INVENTOR: (HubKind.TECHNIC_LARGE, 1),
LegoUsbPid.SPIKE_ESSENTIAL: (HubKind.TECHNIC_SMALL, 0),
}

return hub_types[self._device.idProduct]

async def get_hub_capabilities(self) -> Tuple[int, HubCapabilityFlag, int]:
return (
self._max_write_size,
self._capability_flags,
self._max_user_program_size,
)

async def send_command(self, command: bytes) -> None:
self._response = None
self._response_event.clear()
self._ep_out.write(self._USB_PYBRICKS_MSG_COMMAND + command)
try:
await asyncio.wait_for(self._response_event.wait(), 1)
if self._response != 0:
print(f"Received error response for command: {self._response}")
except asyncio.TimeoutError:
print("Timed out waiting for a response")

async def set_service_handler(self, callback: Callable) -> None:
self._notify_callbacks[self._USB_PYBRICKS_MSG_EVENT[0]] = callback

async def _monitor_usb(self):
loop = asyncio.get_running_loop()

while True:
msg = await loop.run_in_executor(None, self._read_usb)

if msg is None or len(msg) == 0:
continue

if msg[0] in self._notify_callbacks:
callback = self._notify_callbacks[msg[0]]
callback(bytes(msg[1:]))

def _read_usb(self):
try:
msg = self._ep_in.read(self._ep_in.wMaxPacketSize)
return msg
except USBTimeoutError:
return None

def _response_handler(self, data: bytes) -> None:
(self._response,) = struct.unpack("<I", data)
self._response_event.set()


class PybricksHub:
EOL = b"\r\n" # MicroPython EOL

Expand Down Expand Up @@ -326,11 +486,12 @@ def _pybricks_service_handler(self, data: bytes) -> None:
if self._enable_line_handler:
self._handle_line_data(payload)

async def connect(self, device: BLEDevice):
async def connect(self, device):
"""Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
or :meth:`usb.core.find`
Args:
device: The device to connect to.
device: The device to connect to (`BLEDevice` or `USBDevice`).
Raises:
BleakError: if connecting failed (or old firmware without Device
Expand All @@ -350,7 +511,12 @@ async def connect(self, device: BLEDevice):
self.connection_state_observable.on_next, ConnectionState.DISCONNECTED
)

self._transport = _BLETransport(device)
if isinstance(device, BLEDevice):
self._transport = _BLETransport(device)
elif isinstance(device, USBDevice):
self._transport = _USBTransport(device)
else:
raise TypeError("Unsupported device type")

def handle_disconnect():
logger.info("Disconnected!")
Expand Down

0 comments on commit 7b14d30

Please sign in to comment.