From a638b6111ea6fe227044feba94a50e90d7edcdbc Mon Sep 17 00:00:00 2001 From: matejcik Date: Fri, 26 Nov 2021 16:31:35 +0100 Subject: [PATCH] feat(python): introduce Trezor models This keeps information about vendors and USB IDs in one place, and allows us to extend with model-specific information later. By default, this should be backwards-compatible -- TrezorClient can optionally accept model information, and if not, it will try to guess based on Features. It is possible to specify which models to look for in transport enumeration. Bridge and UDP transports ignore the parameter, because they can't know what model is on the other side. supersedes #1448 and #1449 --- python/src/trezorlib/__init__.py | 7 ---- python/src/trezorlib/client.py | 28 ++++++++++---- python/src/trezorlib/models.py | 43 ++++++++++++++++++++++ python/src/trezorlib/transport/__init__.py | 19 +++++----- python/src/trezorlib/transport/bridge.py | 9 ++++- python/src/trezorlib/transport/hid.py | 15 ++++++-- python/src/trezorlib/transport/udp.py | 9 ++++- python/src/trezorlib/transport/webusb.py | 13 +++++-- 8 files changed, 108 insertions(+), 35 deletions(-) create mode 100644 python/src/trezorlib/models.py diff --git a/python/src/trezorlib/__init__.py b/python/src/trezorlib/__init__.py index c3c862f89a4..0b1e89c1476 100644 --- a/python/src/trezorlib/__init__.py +++ b/python/src/trezorlib/__init__.py @@ -15,10 +15,3 @@ # If not, see . __version__ = "0.13.0" - -# fmt: off -MINIMUM_FIRMWARE_VERSION = { - "1": (1, 8, 0), - "T": (2, 1, 0), -} -# fmt: on diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 18d860fde24..40c7e21a89d 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -21,7 +21,7 @@ from mnemonic import Mnemonic -from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages +from . import exceptions, mapping, messages, models from .log import DUMP_BYTES from .messages import Capability from .tools import expect, parse_path, session @@ -33,7 +33,6 @@ LOG = logging.getLogger(__name__) -VENDORS = ("bitcointrezor.com", "trezor.io") MAX_PASSPHRASE_LENGTH = 50 MAX_PIN_LENGTH = 50 @@ -85,6 +84,7 @@ def __init__( ui: "TrezorClientUI", session_id: Optional[bytes] = None, derive_cardano: Optional[bool] = None, + model: Optional[models.TrezorModel] = None, _init_device: bool = True, ) -> None: """Create a TrezorClient instance. @@ -101,6 +101,9 @@ def __init__( You can supply a `session_id` you might have saved in the previous session. If you do, the user might not need to enter their passphrase again. + You can provide Trezor model information. If not provided, it is detected from + the model name reported at initialization time. + By default, the instance will open a connection to the Trezor device, send an `Initialize` message, set up the `features` field from the response, and connect to a session. By specifying `_init_device=False`, this step is skipped. Notably, @@ -110,7 +113,11 @@ def __init__( might be removed at any time. """ LOG.info(f"creating client instance for device: {transport.get_path()}") - self.mapping = mapping.DEFAULT_MAPPING + self.model = model + if self.model: + self.mapping = self.model.default_mapping + else: + self.mapping = mapping.DEFAULT_MAPPING self.transport = transport self.ui = ui self.session_counter = 0 @@ -254,7 +261,14 @@ def call(self, msg: "MessageType") -> "MessageType": def _refresh_features(self, features: messages.Features) -> None: """Update internal fields based on passed-in Features message.""" - if features.vendor not in VENDORS: + + if not self.model: + # Trezor 1 bootloader does not send model name + self.model = models.by_name(features.model or "1") + if self.model is None: + raise RuntimeError("Unsupported Trezor model") + + if features.vendor not in self.model.vendors: raise RuntimeError("Unsupported device") self.features = features @@ -353,9 +367,9 @@ def init_device( def is_outdated(self) -> bool: if self.features.bootloader_mode: return False - model = self.features.model or "1" - required_version = MINIMUM_FIRMWARE_VERSION[model] - return self.version < required_version + + assert self.model is not None # should happen in _refresh_features + return self.version < self.model.minimum_version def check_firmware_version(self, warn_only: bool = False) -> None: if self.is_outdated(): diff --git a/python/src/trezorlib/models.py b/python/src/trezorlib/models.py new file mode 100644 index 00000000000..c918c4a03ea --- /dev/null +++ b/python/src/trezorlib/models.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from typing import Collection, Optional, Tuple + +from . import mapping + +UsbId = Tuple[int, int] + +VENDORS = ("bitcointrezor.com", "trezor.io") + + +@dataclass(eq=True, frozen=True) +class TrezorModel: + name: str + minimum_version: Tuple[int, int, int] + vendors: Collection[str] + usb_ids: Collection[UsbId] + default_mapping: mapping.ProtobufMapping + + +MODEL_ONE = TrezorModel( + name="1", + minimum_version=(1, 8, 0), + vendors=VENDORS, + usb_ids=((0x534C, 0x0001),), + default_mapping=mapping.DEFAULT_MAPPING, +) + +MODEL_T = TrezorModel( + name="T", + minimum_version=(2, 1, 0), + vendors=VENDORS, + usb_ids=((0x1209, 0x53C1), (0x1209, 0x53C0)), + default_mapping=mapping.DEFAULT_MAPPING, +) + +TREZORS = {MODEL_ONE, MODEL_T} + + +def by_name(name: str) -> Optional[TrezorModel]: + for model in TREZORS: + if model.name == name: + return model + return None diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index 8508e42b2fe..0828c6ed93a 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -29,17 +29,12 @@ from ..exceptions import TrezorException if TYPE_CHECKING: + from ..models import TrezorModel + T = TypeVar("T", bound="Transport") LOG = logging.getLogger(__name__) -# USB vendor/product IDs for Trezors -DEV_TREZOR1 = (0x534C, 0x0001) -DEV_TREZOR2 = (0x1209, 0x53C1) -DEV_TREZOR2_BL = (0x1209, 0x53C0) - -TREZORS = {DEV_TREZOR1, DEV_TREZOR2, DEV_TREZOR2_BL} - UDEV_RULES_STR = """ Do you have udev rules installed? https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules @@ -95,7 +90,9 @@ def find_debug(self: "T") -> "T": raise NotImplementedError @classmethod - def enumerate(cls: Type["T"]) -> Iterable["T"]: + def enumerate( + cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None + ) -> Iterable["T"]: raise NotImplementedError @classmethod @@ -126,12 +123,14 @@ def all_transports() -> Iterable[Type["Transport"]]: return set(t for t in transports if t.ENABLED) -def enumerate_devices() -> Sequence["Transport"]: +def enumerate_devices( + models: Optional[Iterable["TrezorModel"]] = None, +) -> Sequence["Transport"]: devices: List["Transport"] = [] for transport in all_transports(): name = transport.__name__ try: - found = list(transport.enumerate()) + found = list(transport.enumerate(models)) LOG.info(f"Enumerating {name}: found {len(found)} devices") devices.extend(found) except NotImplementedError: diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index e3855160715..d77e3693d6a 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -16,13 +16,16 @@ import logging import struct -from typing import Any, Dict, Iterable, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional import requests from ..log import DUMP_PACKETS from . import MessagePayload, Transport, TransportException +if TYPE_CHECKING: + from ..models import TrezorModel + LOG = logging.getLogger(__name__) TREZORD_HOST = "http://127.0.0.1:21325" @@ -135,7 +138,9 @@ def _call(self, action: str, data: Optional[str] = None) -> requests.Response: return call_bridge(uri, data=data) @classmethod - def enumerate(cls) -> Iterable["BridgeTransport"]: + def enumerate( + cls, _models: Optional[Iterable["TrezorModel"]] = None + ) -> Iterable["BridgeTransport"]: try: legacy = is_legacy_bridge() return [ diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 06e22afb8ab..8bfd98fad6b 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -17,10 +17,11 @@ import logging import sys import time -from typing import Any, Dict, Iterable, List +from typing import Any, Dict, Iterable, List, Optional from ..log import DUMP_PACKETS -from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException +from ..models import MODEL_ONE, TrezorModel +from . import UDEV_RULES_STR, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 LOG = logging.getLogger(__name__) @@ -132,11 +133,17 @@ def get_path(self) -> str: return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" @classmethod - def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]: + def enumerate( + cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False + ) -> Iterable["HidTransport"]: + if models is None: + models = {MODEL_ONE} + usb_ids = [id for model in models for id in model.usb_ids] + devices: List["HidTransport"] = [] for dev in hid.enumerate(0, 0): usb_id = (dev["vendor_id"], dev["product_id"]) - if usb_id != DEV_TREZOR1: + if usb_id not in usb_ids: continue if debug: if not is_debuglink(dev): diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 5f422594506..0bd3e43bcbb 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -17,12 +17,15 @@ import logging import socket import time -from typing import Iterable, Optional +from typing import TYPE_CHECKING, Iterable, Optional from ..log import DUMP_PACKETS from . import TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 +if TYPE_CHECKING: + from ..models import TrezorModel + SOCKET_TIMEOUT = 10 LOG = logging.getLogger(__name__) @@ -70,7 +73,9 @@ def _try_path(cls, path: str) -> "UdpTransport": d.close() @classmethod - def enumerate(cls) -> Iterable["UdpTransport"]: + def enumerate( + cls, _models: Optional[Iterable["TrezorModel"]] = None + ) -> Iterable["UdpTransport"]: default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" try: return [cls._try_path(default_path)] diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index cde54c08d15..cf71f088316 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -21,7 +21,8 @@ from typing import Iterable, List, Optional from ..log import DUMP_PACKETS -from . import TREZORS, UDEV_RULES_STR, TransportException +from ..models import TREZORS, TrezorModel +from . import UDEV_RULES_STR, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 LOG = logging.getLogger(__name__) @@ -114,15 +115,21 @@ def get_path(self) -> str: return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" @classmethod - def enumerate(cls, usb_reset: bool = False) -> Iterable["WebUsbTransport"]: + def enumerate( + cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False + ) -> Iterable["WebUsbTransport"]: if cls.context is None: cls.context = usb1.USBContext() cls.context.open() atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value] + + if models is None: + models = TREZORS + usb_ids = [id for model in models for id in model.usb_ids] devices: List["WebUsbTransport"] = [] for dev in cls.context.getDeviceIterator(skip_on_error=True): usb_id = (dev.getVendorID(), dev.getProductID()) - if usb_id not in TREZORS: + if usb_id not in usb_ids: continue if not is_vendor_class(dev): continue