Skip to content

Commit

Permalink
feat(python): introduce Trezor models
Browse files Browse the repository at this point in the history
This keeps information about vendors and USB IDs in one place, and
allows us to extend with model-specific information later.

By default, this should be backwards-compatible -- TrezorClient can
optionally accept model information, and if not, it will try to guess
based on Features.

It is possible to specify which models to look for in transport
enumeration. Bridge and UDP transports ignore the parameter, because
they can't know what model is on the other side.

supersedes #1448 and #1449
  • Loading branch information
matejcik committed Nov 29, 2021
1 parent 486d8d2 commit 49f2e69
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 43 deletions.
7 changes: 0 additions & 7 deletions python/src/trezorlib/__init__.py
Expand Up @@ -15,10 +15,3 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

__version__ = "0.13.0"

# fmt: off
MINIMUM_FIRMWARE_VERSION = {
"1": (1, 8, 0),
"T": (2, 1, 0),
}
# fmt: on
28 changes: 21 additions & 7 deletions python/src/trezorlib/client.py
Expand Up @@ -21,7 +21,7 @@

from mnemonic import Mnemonic

from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages
from . import exceptions, mapping, messages, models
from .log import DUMP_BYTES
from .messages import Capability
from .tools import expect, parse_path, session
Expand All @@ -33,7 +33,6 @@

LOG = logging.getLogger(__name__)

VENDORS = ("bitcointrezor.com", "trezor.io")
MAX_PASSPHRASE_LENGTH = 50
MAX_PIN_LENGTH = 50

Expand Down Expand Up @@ -85,6 +84,7 @@ def __init__(
ui: "TrezorClientUI",
session_id: Optional[bytes] = None,
derive_cardano: Optional[bool] = None,
model: Optional[models.TrezorModel] = None,
_init_device: bool = True,
) -> None:
"""Create a TrezorClient instance.
Expand All @@ -101,6 +101,9 @@ def __init__(
You can supply a `session_id` you might have saved in the previous session. If
you do, the user might not need to enter their passphrase again.
You can provide Trezor model information. If not provided, it is detected from
the model name reported at initialization time.
By default, the instance will open a connection to the Trezor device, send an
`Initialize` message, set up the `features` field from the response, and connect
to a session. By specifying `_init_device=False`, this step is skipped. Notably,
Expand All @@ -110,7 +113,11 @@ def __init__(
might be removed at any time.
"""
LOG.info(f"creating client instance for device: {transport.get_path()}")
self.mapping = mapping.DEFAULT_MAPPING
self.model = model
if self.model:
self.mapping = self.model.default_mapping
else:
self.mapping = mapping.DEFAULT_MAPPING
self.transport = transport
self.ui = ui
self.session_counter = 0
Expand Down Expand Up @@ -254,7 +261,14 @@ def call(self, msg: "MessageType") -> "MessageType":

def _refresh_features(self, features: messages.Features) -> None:
"""Update internal fields based on passed-in Features message."""
if features.vendor not in VENDORS:

if not self.model:
# Trezor Model One bootloader 1.8.0 or older does not send model name
self.model = models.by_name(features.model or "1")
if self.model is None:
raise RuntimeError("Unsupported Trezor model")

if features.vendor not in self.model.vendors:
raise RuntimeError("Unsupported device")

self.features = features
Expand Down Expand Up @@ -353,9 +367,9 @@ def init_device(
def is_outdated(self) -> bool:
if self.features.bootloader_mode:
return False
model = self.features.model or "1"
required_version = MINIMUM_FIRMWARE_VERSION[model]
return self.version < required_version

assert self.model is not None # should happen in _refresh_features
return self.version < self.model.minimum_version

def check_firmware_version(self, warn_only: bool = False) -> None:
if self.is_outdated():
Expand Down
43 changes: 43 additions & 0 deletions python/src/trezorlib/models.py
@@ -0,0 +1,43 @@
from dataclasses import dataclass
from typing import Collection, Optional, Tuple

from . import mapping

UsbId = Tuple[int, int]

VENDORS = ("bitcointrezor.com", "trezor.io")


@dataclass(eq=True, frozen=True)
class TrezorModel:
name: str
minimum_version: Tuple[int, int, int]
vendors: Collection[str]
usb_ids: Collection[UsbId]
default_mapping: mapping.ProtobufMapping


TREZOR_ONE = TrezorModel(
name="1",
minimum_version=(1, 8, 0),
vendors=VENDORS,
usb_ids=((0x534C, 0x0001),),
default_mapping=mapping.DEFAULT_MAPPING,
)

TREZOR_T = TrezorModel(
name="T",
minimum_version=(2, 1, 0),
vendors=VENDORS,
usb_ids=((0x1209, 0x53C1), (0x1209, 0x53C0)),
default_mapping=mapping.DEFAULT_MAPPING,
)

TREZORS = {TREZOR_ONE, TREZOR_T}


def by_name(name: str) -> Optional[TrezorModel]:
for model in TREZORS:
if model.name == name:
return model
return None
19 changes: 9 additions & 10 deletions python/src/trezorlib/transport/__init__.py
Expand Up @@ -29,17 +29,12 @@
from ..exceptions import TrezorException

if TYPE_CHECKING:
from ..models import TrezorModel

T = TypeVar("T", bound="Transport")

LOG = logging.getLogger(__name__)

# USB vendor/product IDs for Trezors
DEV_TREZOR1 = (0x534C, 0x0001)
DEV_TREZOR2 = (0x1209, 0x53C1)
DEV_TREZOR2_BL = (0x1209, 0x53C0)

TREZORS = {DEV_TREZOR1, DEV_TREZOR2, DEV_TREZOR2_BL}

UDEV_RULES_STR = """
Do you have udev rules installed?
https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
Expand Down Expand Up @@ -95,7 +90,9 @@ def find_debug(self: "T") -> "T":
raise NotImplementedError

@classmethod
def enumerate(cls: Type["T"]) -> Iterable["T"]:
def enumerate(
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["T"]:
raise NotImplementedError

@classmethod
Expand Down Expand Up @@ -126,12 +123,14 @@ def all_transports() -> Iterable[Type["Transport"]]:
return set(t for t in transports if t.ENABLED)


def enumerate_devices() -> Sequence["Transport"]:
def enumerate_devices(
models: Optional[Iterable["TrezorModel"]] = None,
) -> Sequence["Transport"]:
devices: List["Transport"] = []
for transport in all_transports():
name = transport.__name__
try:
found = list(transport.enumerate())
found = list(transport.enumerate(models))
LOG.info(f"Enumerating {name}: found {len(found)} devices")
devices.extend(found)
except NotImplementedError:
Expand Down
9 changes: 7 additions & 2 deletions python/src/trezorlib/transport/bridge.py
Expand Up @@ -16,13 +16,16 @@

import logging
import struct
from typing import Any, Dict, Iterable, Optional
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional

import requests

from ..log import DUMP_PACKETS
from . import MessagePayload, Transport, TransportException

if TYPE_CHECKING:
from ..models import TrezorModel

LOG = logging.getLogger(__name__)

TREZORD_HOST = "http://127.0.0.1:21325"
Expand Down Expand Up @@ -135,7 +138,9 @@ def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
return call_bridge(uri, data=data)

@classmethod
def enumerate(cls) -> Iterable["BridgeTransport"]:
def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["BridgeTransport"]:
try:
legacy = is_legacy_bridge()
return [
Expand Down
15 changes: 11 additions & 4 deletions python/src/trezorlib/transport/hid.py
Expand Up @@ -17,10 +17,11 @@
import logging
import sys
import time
from typing import Any, Dict, Iterable, List
from typing import Any, Dict, Iterable, List, Optional

from ..log import DUMP_PACKETS
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
from ..models import TREZOR_ONE, TrezorModel
from . import UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -132,11 +133,17 @@ def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"

@classmethod
def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]:
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
) -> Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]

devices: List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id != DEV_TREZOR1:
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
Expand Down
9 changes: 7 additions & 2 deletions python/src/trezorlib/transport/udp.py
Expand Up @@ -17,12 +17,15 @@
import logging
import socket
import time
from typing import Iterable, Optional
from typing import TYPE_CHECKING, Iterable, Optional

from ..log import DUMP_PACKETS
from . import TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1

if TYPE_CHECKING:
from ..models import TrezorModel

SOCKET_TIMEOUT = 10

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -70,7 +73,9 @@ def _try_path(cls, path: str) -> "UdpTransport":
d.close()

@classmethod
def enumerate(cls) -> Iterable["UdpTransport"]:
def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["UdpTransport"]:
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
try:
return [cls._try_path(default_path)]
Expand Down
13 changes: 10 additions & 3 deletions python/src/trezorlib/transport/webusb.py
Expand Up @@ -21,7 +21,8 @@
from typing import Iterable, List, Optional

from ..log import DUMP_PACKETS
from . import TREZORS, UDEV_RULES_STR, TransportException
from ..models import TREZORS, TrezorModel
from . import UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -114,15 +115,21 @@ def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"

@classmethod
def enumerate(cls, usb_reset: bool = False) -> Iterable["WebUsbTransport"]:
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value]

if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in TREZORS:
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
Expand Down
10 changes: 6 additions & 4 deletions tests/upgrade_tests/test_firmware_upgrades.py
Expand Up @@ -14,9 +14,11 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

import dataclasses

import pytest

from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, debuglink, device, exceptions, fido
from trezorlib import btc, debuglink, device, exceptions, fido, models
from trezorlib.messages import BackupType
from trezorlib.tools import H_

Expand All @@ -26,9 +28,9 @@
from ..emulators import ALL_TAGS, EmulatorWrapper
from . import for_all, for_tags

MINIMUM_FIRMWARE_VERSION["1"] = (1, 0, 0)
MINIMUM_FIRMWARE_VERSION["T"] = (2, 0, 0)

models.TREZOR_ONE = dataclasses.replace(models.TREZOR_ONE, minimum_version=(1, 0, 0))
models.TREZOR_T = dataclasses.replace(models.TREZOR_T, minimum_version=(2, 0, 0))
models.TREZORS = {models.TREZOR_ONE, models.TREZOR_T}

# **** COMMON DEFINITIONS ****

Expand Down
8 changes: 4 additions & 4 deletions tests/upgrade_tests/test_passphrase_consistency.py
Expand Up @@ -16,7 +16,7 @@

import pytest

from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, device, mapping, messages, protobuf
from trezorlib import btc, device, mapping, messages, models, protobuf
from trezorlib.tools import parse_path

from ..emulators import EmulatorWrapper
Expand Down Expand Up @@ -57,8 +57,8 @@ def emulator(gen, tag):


@for_all(
core_minimum_version=MINIMUM_FIRMWARE_VERSION["T"],
legacy_minimum_version=MINIMUM_FIRMWARE_VERSION["1"],
core_minimum_version=models.TREZOR_T.minimum_version,
legacy_minimum_version=models.TREZOR_ONE.minimum_version,
)
def test_passphrase_works(emulator):
"""Check that passphrase handling in trezorlib works correctly in all versions."""
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_passphrase_works(emulator):


@for_all(
core_minimum_version=MINIMUM_FIRMWARE_VERSION["T"],
core_minimum_version=models.TREZOR_T.minimum_version,
legacy_minimum_version=(1, 9, 0),
)
def test_init_device(emulator):
Expand Down

0 comments on commit 49f2e69

Please sign in to comment.