Skip to content

Commit

Permalink
Add support for Trezor
Browse files Browse the repository at this point in the history
  • Loading branch information
moisses89 committed Nov 30, 2023
1 parent e8c2a0a commit a8bc569
Show file tree
Hide file tree
Showing 12 changed files with 276 additions and 76 deletions.
16 changes: 11 additions & 5 deletions safe_cli/operators/hw_accounts/hw_account.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import re
from abc import ABC, abstractmethod
from typing import Tuple

from eth_typing import ChecksumAddress

BIP32_ETH_PATTERN = r"^(m/)?44'/60'/[0-9]+'/[0-9]+/[0-9]+$"
BIP32_LEGACY_LEDGER_PATTERN = r"^(m/)?44'/60'/[0-9]+'/[0-9]+$"


class InvalidDerivationPath(Exception):
message = "The provided derivation path is not valid"


class HwAccount(ABC):
def __init__(self, derivation_path: str, address: ChecksumAddress):
self.derivation_path = derivation_path
Expand All @@ -26,10 +29,13 @@ def is_valid_derivation_path(derivation_path: str):
"""
Detect if a string is a valid derivation path
"""
return (
if not (
re.match(BIP32_ETH_PATTERN, derivation_path) is not None
or re.match(BIP32_LEGACY_LEDGER_PATTERN, derivation_path) is not None
)
):
raise InvalidDerivationPath

return True

@staticmethod
@abstractmethod
Expand All @@ -41,12 +47,12 @@ def get_address_by_derivation_path(derivation_path: str) -> ChecksumAddress:
"""

@abstractmethod
def sign_typed_hash(self, domain_hash, message_hash) -> Tuple[bytes, bytes, bytes]:
def sign_typed_hash(self, domain_hash, message_hash) -> bytes:
"""
:param domain_hash:
:param message_hash:
:return: tuple os signature v, r, s
:return: signature
"""

def __eq__(self, other):
Expand Down
16 changes: 4 additions & 12 deletions safe_cli/operators/hw_accounts/hw_account_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@

from gnosis.eth.eip712 import eip712_encode
from gnosis.safe import SafeTx
from gnosis.safe.signatures import signature_to_bytes

from safe_cli.operators.exceptions import HardwareWalletException
from safe_cli.operators.hw_accounts.hw_account import HwAccount


Expand Down Expand Up @@ -88,12 +86,9 @@ def add_account(

hw_wallet = self.get_hw_wallet(hw_wallet_type)

if hw_wallet.is_valid_derivation_path(derivation_path):
address = hw_wallet.get_address_by_derivation_path(derivation_path)
self.accounts.add(hw_wallet(derivation_path, address))
return address
else:
raise HardwareWalletException("Invalid derivation path")
address = hw_wallet.get_address_by_derivation_path(derivation_path)
self.accounts.add(hw_wallet(derivation_path, address))
return address

def delete_accounts(self, addresses: List[ChecksumAddress]) -> Set:
"""
Expand Down Expand Up @@ -130,11 +125,8 @@ def sign_eip712(self, safe_tx: SafeTx, accounts: List[HwAccount]) -> SafeTx:
)
print_formatted_text(HTML(f"Domain_hash: <b>{domain_hash.hex()}</b>"))
print_formatted_text(HTML(f"Message_hash: <b>{message_hash.hex()}</b>"))
signed_v, signed_r, signed_s = account.sign_typed_hash(
domain_hash, message_hash
)
signature = account.sign_typed_hash(domain_hash, message_hash)

signature = signature_to_bytes(signed_v, signed_r, signed_s)
# TODO should be refactored on safe_eth_py function insert_signature_sorted
# Insert signature sorted
if account.address not in safe_tx.signers:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,19 @@
from ledgereth.exceptions import (
LedgerAppNotOpened,
LedgerCancel,
LedgerError,
LedgerLocked,
LedgerNotFound,
)

from safe_cli.operators.exceptions import HardwareWalletException


class InvalidDerivationPath(LedgerError):
message = "The provided derivation path is not valid"
from safe_cli.operators.hw_accounts.hw_account import InvalidDerivationPath


class UnsupportedHwWalletException(Exception):
pass


def raise_as_hw_account_exception(function):
def raise_ledger_exception_as_hw_account_exception(function):
@functools.wraps(function)
def wrapper(*args, **kwargs):
try:
Expand Down
24 changes: 14 additions & 10 deletions safe_cli/operators/hw_accounts/ledger_manager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import Tuple

from eth_typing import ChecksumAddress
from ledgereth import sign_typed_data_draft
from ledgereth.accounts import get_account_by_path
from ledgereth.comms import init_dongle
from ledgereth.exceptions import LedgerNotFound

from safe_cli.operators.hw_accounts.exceptions import raise_as_hw_account_exception
from gnosis.safe.signatures import signature_to_bytes

from safe_cli.operators.hw_accounts.hw_account import HwAccount
from safe_cli.operators.hw_accounts.ledger_exceptions import (
raise_ledger_exception_as_hw_account_exception,
)


class LedgerManager(HwAccount):
Expand All @@ -28,14 +30,14 @@ def connect(self) -> bool:
return False

@property
@raise_as_hw_account_exception
@raise_ledger_exception_as_hw_account_exception
def connected(self) -> bool:
"""
:return: True if ledger is connected or False in other case
"""
return self.connect()

@raise_as_hw_account_exception
@raise_ledger_exception_as_hw_account_exception
def get_address_by_derivation_path(derivation_path: str) -> ChecksumAddress:
"""
Expand All @@ -44,12 +46,14 @@ def get_address_by_derivation_path(derivation_path: str) -> ChecksumAddress:
"""
if derivation_path[0:2] == "m/":
derivation_path = derivation_path.replace("m/", "")
account = get_account_by_path(derivation_path)
return account.address
if LedgerManager.is_valid_derivation_path(derivation_path):
account = get_account_by_path(derivation_path)
return account.address

@raise_as_hw_account_exception
def sign_typed_hash(self, domain_hash, message_hash) -> Tuple[bytes, bytes, bytes]:
@raise_ledger_exception_as_hw_account_exception
def sign_typed_hash(self, domain_hash, message_hash) -> bytes:
signed = sign_typed_data_draft(
domain_hash, message_hash, self.derivation_path, self.dongle
)
return (signed.v, signed.r, signed.s)

return signature_to_bytes(signed.v, signed.r, signed.s)
37 changes: 37 additions & 0 deletions safe_cli/operators/hw_accounts/trezor_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import functools

from trezorlib.exceptions import (
Cancelled,
OutdatedFirmwareError,
PinException,
TrezorFailure,
)
from trezorlib.transport import TransportException

from safe_cli.operators.exceptions import HardwareWalletException
from safe_cli.operators.hw_accounts.hw_account import InvalidDerivationPath


class UnsupportedHwWalletException(Exception):
pass


def raise_trezor_exception_as_hw_account_exception(function):
@functools.wraps(function)
def wrapper(*args, **kwargs):
try:
return function(*args, **kwargs)
except TrezorFailure as e:
raise HardwareWalletException(e.message)
except OutdatedFirmwareError:
raise HardwareWalletException("Trezor firmware version is not supported")
except PinException:
raise HardwareWalletException("Wrong PIN")
except Cancelled:
raise HardwareWalletException("Trezor operation was cancelled")
except TransportException:
raise HardwareWalletException("Trezor device is not connected")
except InvalidDerivationPath as e:
raise HardwareWalletException(e.message)

return wrapper
40 changes: 35 additions & 5 deletions safe_cli/operators/hw_accounts/trezor_manager.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,52 @@
from typing import Tuple
from functools import lru_cache

from eth_typing import ChecksumAddress
from trezorlib import tools
from trezorlib.client import TrezorClient, get_default_client
from trezorlib.ethereum import get_address, sign_typed_data_hash
from trezorlib.ui import ClickUI

from safe_cli.operators.hw_accounts.hw_account import HwAccount
from safe_cli.operators.hw_accounts.trezor_exceptions import (
raise_trezor_exception_as_hw_account_exception,
)


class TrezorManager(HwAccount):
def __init__(self, derivation_path: str, address: ChecksumAddress):
self.client = None
self.client = TrezorManager.client()
super().__init__(derivation_path, address)

# Caching client to avoid requesting passphrase between Trezor requests
@lru_cache(maxsize=None)
@staticmethod
@raise_trezor_exception_as_hw_account_exception
def client() -> TrezorClient:
"""
Return default trezor configuration that store passphrase on host.
This method is cached to share the same configuration between trezor calls while the class is not instantiated.
:return:
"""
ui = ClickUI(passphrase_on_host=True)
client = get_default_client(ui=ui)
return client

@raise_trezor_exception_as_hw_account_exception
def get_address_by_derivation_path(derivation_path: str) -> ChecksumAddress:
"""
:param derivation_path:
:return: public address for provided derivation_path
"""
raise NotImplementedError
if TrezorManager.is_valid_derivation_path(derivation_path):
client = TrezorManager.client()
address_n = tools.parse_path(derivation_path)
return get_address(client=client, n=address_n)

def sign_typed_hash(self, domain_hash, message_hash) -> Tuple[bytes, bytes, bytes]:
raise NotImplementedError
@raise_trezor_exception_as_hw_account_exception
def sign_typed_hash(self, domain_hash, message_hash) -> bytes:
address_n = tools.parse_path(self.derivation_path)
signed = sign_typed_data_hash(
self.client, n=address_n, domain_hash=domain_hash, message_hash=message_hash
)
return signed.signature
22 changes: 15 additions & 7 deletions safe_cli/operators/safe_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,14 @@ def load_cli_owners(self, keys: List[str]):
except ValueError:
print_formatted_text(HTML(f"<ansired>Cannot load key={key}</ansired>"))

def load_ledger_cli_owners(
self, derivation_path: str = None, legacy_account: bool = False
def load_hw_wallet(
self, hw_wallet_type: HwWalletType, derivation_path: str, legacy_account: bool
):
if not self.hw_account_manager.is_supported_hw_wallet(HwWalletType.LEDGER):
if not self.hw_account_manager.is_supported_hw_wallet(hw_wallet_type):
return None
if derivation_path is None:
ledger_accounts = self.hw_account_manager.get_accounts(
HwWalletType.LEDGER, legacy_account=legacy_account
hw_wallet_type, legacy_account=legacy_account
)
if len(ledger_accounts) == 0:
return None
Expand All @@ -289,9 +289,7 @@ def load_ledger_cli_owners(
return None
_, derivation_path = ledger_accounts[option]

address = self.hw_account_manager.add_account(
HwWalletType.LEDGER, derivation_path
)
address = self.hw_account_manager.add_account(hw_wallet_type, derivation_path)
balance = self.ethereum_client.get_balance(address)
print_formatted_text(
HTML(
Expand All @@ -301,6 +299,16 @@ def load_ledger_cli_owners(
)
)

def load_ledger_cli_owners(
self, derivation_path: str = None, legacy_account: bool = False
):
self.load_hw_wallet(HwWalletType.LEDGER, derivation_path, legacy_account)

def load_trezor_cli_owners(
self, derivation_path: str = None, legacy_account: bool = False
):
self.load_hw_wallet(HwWalletType.TREZOR, derivation_path, legacy_account)

def unload_cli_owners(self, owners: List[str]):
accounts_to_remove: Set[Account] = set()
for owner in owners:
Expand Down
19 changes: 19 additions & 0 deletions safe_cli/prompt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,12 @@ def load_ledger_cli_owners(args):
derivation_path=args.derivation_path, legacy_account=args.legacy_accounts
)

@safe_exception
def load_trezor_cli_owners(args):
safe_operator.load_trezor_cli_owners(
derivation_path=args.derivation_path, legacy_account=args.legacy_accounts
)

@safe_exception
def unload_cli_owners(args):
safe_operator.unload_cli_owners(args.addresses)
Expand Down Expand Up @@ -332,6 +338,19 @@ def remove_delegate(args):
)
parser_load_ledger_cli_owners.set_defaults(func=load_ledger_cli_owners)

parser_load_trezor_cli_owners = subparsers.add_parser("load_trezor_cli_owners")
parser_load_trezor_cli_owners.add_argument(
"--derivation-path",
type=str,
help="Load address for the provided derivation path",
)
parser_load_trezor_cli_owners.add_argument(
"--legacy-accounts",
action="store_true",
help="Enable search legacy accounts",
)
parser_load_trezor_cli_owners.set_defaults(func=load_trezor_cli_owners)

parser_unload_cli_owners = subparsers.add_parser("unload_cli_owners")
parser_unload_cli_owners.add_argument(
"addresses", type=check_ethereum_address, nargs="+"
Expand Down
4 changes: 4 additions & 0 deletions safe_cli/safe_completer_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"info": "(read-only)",
"load_cli_owners": "<account-private-key> [<account-private-key>...]",
"load_ledger_cli_owners": "[--legacy-accounts] [--derivation-path <str>]",
"load_trezor_cli_owners": "[--legacy-accounts] [--derivation-path <str>]",
"load_cli_owners_from_words": "<word_1> <word_2> ... <word_12>",
"refresh": "",
"remove_delegate": "<address> <signer-address>",
Expand Down Expand Up @@ -160,6 +161,9 @@
"load_ledger_cli_owners": HTML(
"Command <b>load_ledger_cli_owners</b> show a list of ledger addresses to choose between them "
),
"load_trezor_cli_owners": HTML(
"Command <b>load_trezor_cli_owners</b> show a list of trezor addresses to choose between them "
),
"load_cli_owners_from_words": HTML(
"Command <b>load_cli_owners_from_words</b> will try to load owners via"
"<u>seed_words</u>. Only relevant accounts(owners) will be loaded"
Expand Down
Loading

0 comments on commit a8bc569

Please sign in to comment.