diff --git a/README.md b/README.md index 13d3079..88f5565 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,10 @@ the information about the Safe using: ``` > refresh ``` -## Ledger module +## Hardware wallets support +**NOTE**: before signing anything ensure that the data showing on your hardware wallet device is the same as the safe-cli data. + +### Ledger Ledger module is an optional feature of safe-cli to sign transactions with the help of [ledgereth](https://github.com/mikeshultz/ledger-eth-lib) library based on [ledgerblue](https://github.com/LedgerHQ/blue-loader-python). To enable, safe-cli must be installed as follows: @@ -155,7 +158,19 @@ SUBSYSTEMS=="usb", ATTRS{idVendor}=="2c97", ATTRS{idProduct}=="0004", MODE="0660 Safe-cli Ledger commands: - `load_ledger_cli_owners [--legacy-accounts] [--derivation-path ]`: show a list of the first 5 accounts (--legacy-accounts search using ledger legacy derivation) or load an account from provided derivation path. -**NOTE**: before signing anything ensure that the data showing on your ledger is the same as the safe-cli data. +### Trezor +Trezor module is an optional feature of safe-cli to sign transactions from Trezor hardware wallet using the [trezor](https://pypi.org/project/trezor/) library. + +To enable, safe-cli must be installed as follows: +``` +pip install safe-cli[trezor] +``` + +### Enable multiple hardware wallets +``` +pip install safe-cli[ledger, trezor] +``` + ## Creating a new Safe Use `safe-creator --owners --threshold --salt-nonce `. diff --git a/requirements.txt b/requirements.txt index e6b623c..ca0c51f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ pygments==2.17.2 requests==2.31.0 safe-eth-py==6.0.0b8 tabulate==0.9.0 +trezor==0.13.8 web3==6.11.4 diff --git a/safe_cli/operators/hw_accounts/ledger_manager.py b/safe_cli/operators/hw_accounts/ledger_manager.py deleted file mode 100644 index 643d4ed..0000000 --- a/safe_cli/operators/hw_accounts/ledger_manager.py +++ /dev/null @@ -1,141 +0,0 @@ -from typing import List, Optional, Set, Tuple - -from eth_typing import ChecksumAddress -from ledgereth import sign_typed_data_draft -from ledgereth.accounts import LedgerAccount, get_account_by_path -from ledgereth.comms import init_dongle -from ledgereth.exceptions import LedgerNotFound -from ledgereth.utils import is_bip32_path -from prompt_toolkit import HTML, print_formatted_text - -from gnosis.eth.eip712 import eip712_encode -from gnosis.safe import SafeTx -from gnosis.safe.signatures import signature_to_bytes - -from safe_cli.operators.hw_accounts.exceptions import ( - InvalidDerivationPath, - raise_as_hw_account_exception, -) - - -class LedgerManager: - def __init__(self): - self.dongle = None - self.accounts: Set[LedgerAccount] = set() - self.connect() - - def connect(self) -> bool: - """ - Connect with ledger - :return: True if connection was successful or False in other case - """ - try: - self.dongle = init_dongle(self.dongle) - return True - except LedgerNotFound: - return False - - @property - @raise_as_hw_account_exception - def connected(self) -> bool: - """ - :return: True if ledger is connected or False in other case - """ - return self.connect() - - @raise_as_hw_account_exception - def get_accounts( - self, legacy_account: Optional[bool] = False, number_accounts: Optional[int] = 5 - ) -> List[Tuple[ChecksumAddress, str]]: - """ - Request to ledger device the first n accounts - - :param legacy_account: - :param number_accounts: number of accounts requested to ledger - :return: a list of tuples with address and derivation path - """ - accounts = [] - for i in range(number_accounts): - if legacy_account: - path_string = f"44'/60'/0'/{i}" - else: - path_string = f"44'/60'/{i}'/0/0" - - account = get_account_by_path(path_string, self.dongle) - accounts.append((account.address, account.path)) - return accounts - - @raise_as_hw_account_exception - def add_account(self, derivation_path: str) -> ChecksumAddress: - """ - Add an account to ledger manager set and return the added address - - :param derivation_path: - :return: - """ - # we should accept m/ or m'/ starting derivation paths - if derivation_path[0:2] == "m/": - derivation_path = derivation_path.replace("m/", "") - - if not is_bip32_path(derivation_path): - raise InvalidDerivationPath() - - account = get_account_by_path(derivation_path, self.dongle) - self.accounts.add(LedgerAccount(account.path, account.address)) - return account.address - - def delete_accounts(self, addresses: List[ChecksumAddress]) -> Set: - """ - Remove ledger accounts from address - - :param accounts: - :return: list with the delete accounts - """ - accounts_to_remove = set() - for address in addresses: - for account in self.accounts: - if account.address == address: - accounts_to_remove.add(account) - self.accounts = self.accounts.difference(accounts_to_remove) - return accounts_to_remove - - @raise_as_hw_account_exception - def sign_eip712(self, safe_tx: SafeTx, accounts: List[LedgerAccount]) -> SafeTx: - """ - Call ledger ethereum app method to sign eip712 hashes with a ledger account - - :param domain_hash: - :param message_hash: - :param account: ledger account - :return: bytes of signature - """ - encode_hash = eip712_encode(safe_tx.eip712_structured_data) - domain_hash = encode_hash[1] - message_hash = encode_hash[2] - for account in accounts: - print_formatted_text( - HTML( - "Make sure in your ledger before signing that domain_hash and message_hash are both correct" - ) - ) - print_formatted_text(HTML(f"Domain_hash: {domain_hash.hex()}")) - print_formatted_text(HTML(f"Message_hash: {message_hash.hex()}")) - signed = sign_typed_data_draft( - domain_hash, message_hash, account.path, self.dongle - ) - - signature = signature_to_bytes(signed.v, signed.r, signed.s) - # TODO should be refactored on safe_eth_py function insert_signature_sorted - # Insert signature sorted - if account.address not in safe_tx.signers: - new_owners = safe_tx.signers + [account.address] - new_owner_pos = sorted(new_owners, key=lambda x: int(x, 16)).index( - account.address - ) - safe_tx.signatures = ( - safe_tx.signatures[: 65 * new_owner_pos] - + signature - + safe_tx.signatures[65 * new_owner_pos :] - ) - - return safe_tx diff --git a/safe_cli/operators/hw_accounts/__init__.py b/safe_cli/operators/hw_wallets/__init__.py similarity index 100% rename from safe_cli/operators/hw_accounts/__init__.py rename to safe_cli/operators/hw_wallets/__init__.py diff --git a/safe_cli/operators/hw_wallets/constants.py b/safe_cli/operators/hw_wallets/constants.py new file mode 100644 index 0000000..6c90557 --- /dev/null +++ b/safe_cli/operators/hw_wallets/constants.py @@ -0,0 +1,2 @@ +BIP32_ETH_PATTERN = r"^44'/60'/[0-9]+'/[0-9]+/[0-9]+$" +BIP32_LEGACY_LEDGER_PATTERN = r"^44'/60'/[0-9]+'/[0-9]+$" diff --git a/safe_cli/operators/hw_wallets/exceptions.py b/safe_cli/operators/hw_wallets/exceptions.py new file mode 100644 index 0000000..3b1507e --- /dev/null +++ b/safe_cli/operators/hw_wallets/exceptions.py @@ -0,0 +1,2 @@ +class InvalidDerivationPath(Exception): + message = "The provided derivation path is not valid" diff --git a/safe_cli/operators/hw_wallets/hw_wallet.py b/safe_cli/operators/hw_wallets/hw_wallet.py new file mode 100644 index 0000000..ab79338 --- /dev/null +++ b/safe_cli/operators/hw_wallets/hw_wallet.py @@ -0,0 +1,59 @@ +import re +from abc import ABC, abstractmethod + +from .constants import BIP32_ETH_PATTERN, BIP32_LEGACY_LEDGER_PATTERN +from .exceptions import InvalidDerivationPath + + +class HwWallet(ABC): + def __init__(self, derivation_path: str): + derivation_path = derivation_path.replace("m/", "") + if self._is_valid_derivation_path(derivation_path): + self.derivation_path = derivation_path + self.address = self.get_address() + + @property + def get_derivation_path(self): + return self.derivation_path + + @abstractmethod + def get_address(self): + """ + + :return: + """ + + def _is_valid_derivation_path(self, derivation_path: str): + """ + Detect if a string is a valid derivation path + """ + if not ( + re.match(BIP32_ETH_PATTERN, derivation_path) is not None + or re.match(BIP32_LEGACY_LEDGER_PATTERN, derivation_path) is not None + ): + raise InvalidDerivationPath() + + return True + + @abstractmethod + def sign_typed_hash(self, domain_hash: bytes, message_hash: bytes) -> bytes: + """ + + :param domain_hash: + :param message_hash: + :return: signature bytes + """ + + def __str__(self): + return f"{self.__class__.__name__} device with address {self.address}" + + def __eq__(self, other): + if isinstance(other, HwWallet): + return ( + self.derivation_path == other.derivation_path + and self.address == other.address + ) + return False + + def __hash__(self): + return hash((self.derivation_path, self.address)) diff --git a/safe_cli/operators/hw_wallets/hw_wallet_manager.py b/safe_cli/operators/hw_wallets/hw_wallet_manager.py new file mode 100644 index 0000000..a828859 --- /dev/null +++ b/safe_cli/operators/hw_wallets/hw_wallet_manager.py @@ -0,0 +1,138 @@ +from enum import Enum +from functools import lru_cache +from typing import Dict, List, Optional, Set, Tuple + +from eth_typing import ChecksumAddress +from prompt_toolkit import HTML, print_formatted_text + +from gnosis.eth.eip712 import eip712_encode +from gnosis.safe import SafeTx + +from .hw_wallet import HwWallet + + +class HwWalletType(Enum): + TREZOR = 0 + LEDGER = 1 + + +@lru_cache(maxsize=None) +def get_hw_wallet_manager(): + return HwWalletManager() + + +class HwWalletManager: + def __init__(self): + self.wallets: Set[HwWallet] = set() + self.supported_hw_wallet_types: Dict[str, HwWallet] = {} + try: + from .ledger_wallet import LedgerWallet + + self.supported_hw_wallet_types[HwWalletType.LEDGER] = LedgerWallet + except (ImportError): + pass + + try: + from .trezor_wallet import TrezorWallet + + self.supported_hw_wallet_types[HwWalletType.TREZOR] = TrezorWallet + except (ImportError): + pass + + def is_supported_hw_wallet(self, hw_wallet_type: HwWalletType) -> bool: + return hw_wallet_type in self.supported_hw_wallet_types + + def get_hw_wallet(self, hw_wallet_type: HwWalletType) -> Optional[HwWallet]: + if hw_wallet_type in self.supported_hw_wallet_types: + return self.supported_hw_wallet_types[hw_wallet_type] + + def get_accounts( + self, + hw_wallet_type: HwWalletType, + legacy_account: Optional[bool] = False, + number_accounts: Optional[int] = 5, + ) -> List[Tuple[ChecksumAddress, str]]: + """ + + :param hw_wallet: Trezor or Ledger + :param legacy_account: + :param number_accounts: number of accounts requested to ledger + :return: a list of tuples with address and derivation path + """ + accounts = [] + hw_wallet = self.get_hw_wallet(hw_wallet_type) + for i in range(number_accounts): + if legacy_account: + path_string = f"44'/60'/0'/{i}" + else: + path_string = f"44'/60'/{i}'/0/0" + + accounts.append((hw_wallet(path_string).address, path_string)) + return accounts + + def add_account( + self, hw_wallet_type: HwWalletType, derivation_path: str + ) -> ChecksumAddress: + """ + Add an account to ledger manager set and return the added address + + :param derivation_path: + :return: + """ + + hw_wallet = self.get_hw_wallet(hw_wallet_type) + + address = hw_wallet(derivation_path).address + self.wallets.add(hw_wallet(derivation_path)) + return address + + def delete_accounts(self, addresses: List[ChecksumAddress]) -> Set: + """ + Remove ledger accounts from address + + :param accounts: + :return: list with the delete accounts + """ + accounts_to_remove = set() + for address in addresses: + for account in self.wallets: + if account.address == address: + accounts_to_remove.add(account) + self.wallets = self.wallets.difference(accounts_to_remove) + return accounts_to_remove + + def sign_eip712(self, safe_tx: SafeTx, wallets: List[HwWallet]) -> SafeTx: + """ + Sign a safeTx EIP-712 hashes with supported hw wallet devices + + :param domain_hash: + :param message_hash: + :param wallets: list of HwWallet + :return: signed safeTx + """ + encode_hash = eip712_encode(safe_tx.eip712_structured_data) + domain_hash = encode_hash[1] + message_hash = encode_hash[2] + for wallet in wallets: + print_formatted_text( + HTML( + f"Make sure before signing in your {wallet} that the domain_hash and message_hash are both correct" + ) + ) + print_formatted_text(HTML(f"Domain_hash: {domain_hash.hex()}")) + print_formatted_text(HTML(f"Message_hash: {message_hash.hex()}")) + signature = wallet.sign_typed_hash(domain_hash, message_hash) + + # Insert signature sorted + if wallet.address not in safe_tx.signers: + new_owners = safe_tx.signers + [wallet.address] + new_owner_pos = sorted(new_owners, key=lambda x: int(x, 16)).index( + wallet.address + ) + safe_tx.signatures = ( + safe_tx.signatures[: 65 * new_owner_pos] + + signature + + safe_tx.signatures[65 * new_owner_pos :] + ) + + return safe_tx diff --git a/safe_cli/operators/hw_accounts/exceptions.py b/safe_cli/operators/hw_wallets/ledger_exceptions.py similarity index 80% rename from safe_cli/operators/hw_accounts/exceptions.py rename to safe_cli/operators/hw_wallets/ledger_exceptions.py index b12dcb7..7d4b129 100644 --- a/safe_cli/operators/hw_accounts/exceptions.py +++ b/safe_cli/operators/hw_wallets/ledger_exceptions.py @@ -3,19 +3,15 @@ from ledgereth.exceptions import ( LedgerAppNotOpened, LedgerCancel, - LedgerError, LedgerLocked, LedgerNotFound, ) -from safe_cli.operators.exceptions import HardwareWalletException +from ..exceptions import HardwareWalletException +from .exceptions import InvalidDerivationPath -class InvalidDerivationPath(LedgerError): - message = "The provided derivation path is not valid" - - -def raise_as_hw_account_exception(function): +def raise_ledger_exception_as_hw_wallet_exception(function): @functools.wraps(function) def wrapper(*args, **kwargs): try: diff --git a/safe_cli/operators/hw_wallets/ledger_wallet.py b/safe_cli/operators/hw_wallets/ledger_wallet.py new file mode 100644 index 0000000..31e01d2 --- /dev/null +++ b/safe_cli/operators/hw_wallets/ledger_wallet.py @@ -0,0 +1,51 @@ +from typing import Optional + +from eth_typing import ChecksumAddress +from ledgerblue.Dongle import Dongle +from ledgereth import sign_typed_data_draft +from ledgereth.accounts import get_account_by_path +from ledgereth.comms import init_dongle + +from gnosis.safe.signatures import signature_to_bytes + +from .hw_wallet import HwWallet +from .ledger_exceptions import raise_ledger_exception_as_hw_wallet_exception + + +class LedgerWallet(HwWallet): + @raise_ledger_exception_as_hw_wallet_exception + def __init__(self, derivation_path: str): + self.dongle: Optional[Dongle] = None + self.connect() + super().__init__(derivation_path) + + @raise_ledger_exception_as_hw_wallet_exception + def connect(self) -> bool: + """ + Connect with ledger + :return: True if connection was successful or False in other case + """ + self.dongle = init_dongle(self.dongle) + + @raise_ledger_exception_as_hw_wallet_exception + def get_address(self) -> ChecksumAddress: + """ + + :return: public address for provided derivation_path + """ + account = get_account_by_path(self.derivation_path) + return account.address + + @raise_ledger_exception_as_hw_wallet_exception + def sign_typed_hash(self, domain_hash: bytes, message_hash: bytes) -> bytes: + """ + + :param domain_hash: + :param message_hash: + :return: signature bytes + """ + signed = sign_typed_data_draft( + domain_hash, message_hash, self.derivation_path, self.dongle + ) + + return signature_to_bytes(signed.v, signed.r, signed.s) diff --git a/safe_cli/operators/hw_wallets/trezor_exceptions.py b/safe_cli/operators/hw_wallets/trezor_exceptions.py new file mode 100644 index 0000000..2148ece --- /dev/null +++ b/safe_cli/operators/hw_wallets/trezor_exceptions.py @@ -0,0 +1,33 @@ +import functools + +from trezorlib.exceptions import ( + Cancelled, + OutdatedFirmwareError, + PinException, + TrezorFailure, +) +from trezorlib.transport import TransportException + +from ..exceptions import HardwareWalletException +from .exceptions import InvalidDerivationPath + + +def raise_trezor_exception_as_hw_wallet_exception(function): + @functools.wraps(function) + def wrapper(*args, **kwargs): + try: + return function(*args, **kwargs) + except TrezorFailure as e: + raise HardwareWalletException(e.message) + except OutdatedFirmwareError: + raise HardwareWalletException("Trezor firmware version is not supported") + except PinException: + raise HardwareWalletException("Wrong PIN") + except Cancelled: + raise HardwareWalletException("Trezor operation was cancelled") + except TransportException: + raise HardwareWalletException("Trezor device is not connected") + except InvalidDerivationPath as e: + raise HardwareWalletException(e.message) + + return wrapper diff --git a/safe_cli/operators/hw_wallets/trezor_wallet.py b/safe_cli/operators/hw_wallets/trezor_wallet.py new file mode 100644 index 0000000..0ac2ed8 --- /dev/null +++ b/safe_cli/operators/hw_wallets/trezor_wallet.py @@ -0,0 +1,51 @@ +from functools import lru_cache + +from eth_typing import ChecksumAddress +from trezorlib import tools +from trezorlib.client import TrezorClient, get_default_client +from trezorlib.ethereum import get_address, sign_typed_data_hash +from trezorlib.ui import ClickUI + +from .hw_wallet import HwWallet +from .trezor_exceptions import raise_trezor_exception_as_hw_wallet_exception + + +@lru_cache(maxsize=None) +@raise_trezor_exception_as_hw_wallet_exception +def get_trezor_client() -> TrezorClient: + """ + Return default trezor configuration that store passphrase on host. + This method is cached to share the same configuration between trezor calls while the class is not instantiated. + :return: + """ + ui = ClickUI(passphrase_on_host=True) + client = get_default_client(ui=ui) + return client + + +class TrezorWallet(HwWallet): + def __init__(self, derivation_path: str): + self.client: TrezorClient = get_trezor_client() + super().__init__(derivation_path) + + @raise_trezor_exception_as_hw_wallet_exception + def get_address(self) -> ChecksumAddress: + """ + :return: public address for derivation_path + """ + address_n = tools.parse_path(self.derivation_path) + return get_address(client=self.client, n=address_n) + + @raise_trezor_exception_as_hw_wallet_exception + def sign_typed_hash(self, domain_hash: bytes, message_hash: bytes) -> bytes: + """ + + :param domain_hash: + :param message_hash: + :return: signature bytes + """ + address_n = tools.parse_path(self.derivation_path) + signed = sign_typed_data_hash( + self.client, n=address_n, domain_hash=domain_hash, message_hash=message_hash + ) + return signed.signature diff --git a/safe_cli/operators/safe_operator.py b/safe_cli/operators/safe_operator.py index 0217cfc..5ca756b 100644 --- a/safe_cli/operators/safe_operator.py +++ b/safe_cli/operators/safe_operator.py @@ -66,6 +66,7 @@ from safe_cli.utils import choose_option_from_list, get_erc_20_list, yes_or_no_question from ..contracts import safe_to_l2_migration +from .hw_wallets.hw_wallet_manager import HwWalletType, get_hw_wallet_manager @dataclasses.dataclass @@ -123,19 +124,6 @@ def decorated(self, *args, **kwargs): return decorated -def load_ledger_manager(): - """ - Load ledgerManager if dependencies are installed - :return: LedgerManager or None - """ - try: - from safe_cli.operators.hw_accounts.ledger_manager import LedgerManager - - return LedgerManager() - except (ModuleNotFoundError, IOError): - return None - - class SafeOperator: address: ChecksumAddress node_url: str @@ -183,7 +171,7 @@ def __init__(self, address: ChecksumAddress, node_url: str): self.require_all_signatures = ( True # Require all signatures to be present to send a tx ) - self.ledger_manager = load_ledger_manager() + self.hw_wallet_manager = get_hw_wallet_manager() @cached_property def last_default_fallback_handler_address(self) -> ChecksumAddress: @@ -282,14 +270,14 @@ def load_cli_owners(self, keys: List[str]): except ValueError: print_formatted_text(HTML(f"Cannot load key={key}")) - def load_ledger_cli_owners( - self, derivation_path: str = None, legacy_account: bool = False + def load_hw_wallet( + self, hw_wallet_type: HwWalletType, derivation_path: str, legacy_account: bool ): - if not self.ledger_manager: + if not self.hw_wallet_manager.is_supported_hw_wallet(hw_wallet_type): return None if derivation_path is None: - ledger_accounts = self.ledger_manager.get_accounts( - legacy_account=legacy_account + ledger_accounts = self.hw_wallet_manager.get_accounts( + hw_wallet_type, legacy_account=legacy_account ) if len(ledger_accounts) == 0: return None @@ -301,7 +289,7 @@ def load_ledger_cli_owners( return None _, derivation_path = ledger_accounts[option] - address = self.ledger_manager.add_account(derivation_path) + address = self.hw_wallet_manager.add_account(hw_wallet_type, derivation_path) balance = self.ethereum_client.get_balance(address) print_formatted_text( HTML( @@ -311,6 +299,16 @@ def load_ledger_cli_owners( ) ) + def load_ledger_cli_owners( + self, derivation_path: str = None, legacy_account: bool = False + ): + self.load_hw_wallet(HwWalletType.LEDGER, derivation_path, legacy_account) + + def load_trezor_cli_owners( + self, derivation_path: str = None, legacy_account: bool = False + ): + self.load_hw_wallet(HwWalletType.TREZOR, derivation_path, legacy_account) + def unload_cli_owners(self, owners: List[str]): accounts_to_remove: Set[Account] = set() for owner in owners: @@ -322,9 +320,9 @@ def unload_cli_owners(self, owners: List[str]): break self.accounts = self.accounts.difference(accounts_to_remove) # Check if there are ledger owners - if self.ledger_manager and len(accounts_to_remove) < len(owners): + if self.hw_wallet_manager.wallets and len(accounts_to_remove) < len(owners): accounts_to_remove = ( - accounts_to_remove | self.ledger_manager.delete_accounts(owners) + accounts_to_remove | self.hw_wallet_manager.delete_accounts(owners) ) if accounts_to_remove: @@ -335,11 +333,7 @@ def unload_cli_owners(self, owners: List[str]): print_formatted_text(HTML("No account was deleted")) def show_cli_owners(self): - accounts = ( - self.accounts | self.ledger_manager.accounts - if self.ledger_manager - else self.accounts - ) + accounts = self.accounts | self.hw_wallet_manager.wallets if not accounts: print_formatted_text(HTML("No accounts loaded")) else: @@ -745,25 +739,33 @@ def print_info(self): ) ) - if not self.ledger_manager: + if not self.hw_wallet_manager.is_supported_hw_wallet(HwWalletType.LEDGER): print_formatted_text( HTML( "Ledger=" "Disabled Optional ledger library is not installed, run pip install safe-cli[ledger] " ) ) - elif self.ledger_manager.connected: + else: print_formatted_text( HTML( "Ledger=" - "Connected" + "supported" + ) + ) + + if not self.hw_wallet_manager.is_supported_hw_wallet(HwWalletType.TREZOR): + print_formatted_text( + HTML( + "Trezor=" + "Disabled Optional trezor library is not installed, run pip install safe-cli[trezor] " ) ) else: print_formatted_text( HTML( - "Ledger=" - "disconnected" + "Trezor=" + "supported" ) ) @@ -962,8 +964,8 @@ def sign_transaction(self, safe_tx: SafeTx) -> SafeTx: break # If still pending required signatures continue with ledger owners selected_ledger_accounts = [] - if threshold > 0 and self.ledger_manager: - for ledger_account in self.ledger_manager.accounts: + if threshold > 0 and self.hw_wallet_manager.wallets: + for ledger_account in self.hw_wallet_manager.wallets: if ledger_account.address in permitted_signers: selected_ledger_accounts.append(ledger_account) threshold -= 1 @@ -978,7 +980,9 @@ def sign_transaction(self, safe_tx: SafeTx) -> SafeTx: # Sign with ledger if len(selected_ledger_accounts) > 0: - safe_tx = self.ledger_manager.sign_eip712(safe_tx, selected_ledger_accounts) + safe_tx = self.hw_wallet_manager.sign_eip712( + safe_tx, selected_ledger_accounts + ) return safe_tx diff --git a/safe_cli/operators/safe_tx_service_operator.py b/safe_cli/operators/safe_tx_service_operator.py index f493f58..1fad945 100644 --- a/safe_cli/operators/safe_tx_service_operator.py +++ b/safe_cli/operators/safe_tx_service_operator.py @@ -100,13 +100,13 @@ def submit_signatures(self, safe_tx_hash: bytes) -> bool: if account.address in owners: safe_tx.sign(account.key) # Check if there are ledger signers - if self.ledger_manager: + if self.hw_wallet_manager.wallets: selected_ledger_accounts = [] - for ledger_account in self.ledger_manager.accounts: + for ledger_account in self.hw_wallet_manager.wallets: if ledger_account.address in owners: selected_ledger_accounts.append(ledger_account) if len(selected_ledger_accounts) > 0: - safe_tx = self.ledger_manager.sign_eip712( + safe_tx = self.hw_wallet_manager.sign_eip712( safe_tx, selected_ledger_accounts ) diff --git a/safe_cli/prompt_parser.py b/safe_cli/prompt_parser.py index 5673b8d..2ecdd43 100644 --- a/safe_cli/prompt_parser.py +++ b/safe_cli/prompt_parser.py @@ -175,6 +175,12 @@ def load_ledger_cli_owners(args): derivation_path=args.derivation_path, legacy_account=args.legacy_accounts ) + @safe_exception + def load_trezor_cli_owners(args): + safe_operator.load_trezor_cli_owners( + derivation_path=args.derivation_path, legacy_account=args.legacy_accounts + ) + @safe_exception def unload_cli_owners(args): safe_operator.unload_cli_owners(args.addresses) @@ -328,10 +334,23 @@ def remove_delegate(args): parser_load_ledger_cli_owners.add_argument( "--legacy-accounts", action="store_true", - help="Enable search legacy accounts", + help="Search for legacy accounts", ) parser_load_ledger_cli_owners.set_defaults(func=load_ledger_cli_owners) + parser_load_trezor_cli_owners = subparsers.add_parser("load_trezor_cli_owners") + parser_load_trezor_cli_owners.add_argument( + "--derivation-path", + type=str, + help="Load address for the provided derivation path", + ) + parser_load_trezor_cli_owners.add_argument( + "--legacy-accounts", + action="store_true", + help="Search for legacy accounts", + ) + parser_load_trezor_cli_owners.set_defaults(func=load_trezor_cli_owners) + parser_unload_cli_owners = subparsers.add_parser("unload_cli_owners") parser_unload_cli_owners.add_argument( "addresses", type=check_ethereum_address, nargs="+" diff --git a/safe_cli/safe_completer_constants.py b/safe_cli/safe_completer_constants.py index 6ade998..da4f300 100644 --- a/safe_cli/safe_completer_constants.py +++ b/safe_cli/safe_completer_constants.py @@ -25,6 +25,7 @@ "info": "(read-only)", "load_cli_owners": " [...]", "load_ledger_cli_owners": "[--legacy-accounts] [--derivation-path ]", + "load_trezor_cli_owners": "[--legacy-accounts] [--derivation-path ]", "load_cli_owners_from_words": " ... ", "refresh": "", "remove_delegate": "
", @@ -158,7 +159,10 @@ "<account-private-key>." ), "load_ledger_cli_owners": HTML( - "Command load_ledger_cli_owners show a list of ledger addresses to choose between them " + "Command load_ledger_cli_owners show a list of Ledger hardware wallet addresses to choose between them " + ), + "load_trezor_cli_owners": HTML( + "Command load_trezor_cli_owners show a list of Trezor hardware wallet addresses to choose between them " ), "load_cli_owners_from_words": HTML( "Command load_cli_owners_from_words will try to load owners via" diff --git a/setup.py b/setup.py index 8389fbc..5872775 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ "safe-eth-py==6.0.0b8", "tabulate>=0.8", ], - extras_require={"ledger": ["ledgereth==0.9.1"]}, + extras_require={"ledger": ["ledgereth==0.9.1"], "trezor": ["trezor==0.13.8"]}, packages=setuptools.find_packages(), entry_points={ "console_scripts": [ diff --git a/tests/test_hw_wallet_manager.py b/tests/test_hw_wallet_manager.py new file mode 100644 index 0000000..6bfe29c --- /dev/null +++ b/tests/test_hw_wallet_manager.py @@ -0,0 +1,135 @@ +import unittest +from unittest import mock +from unittest.mock import MagicMock + +from eth_account import Account +from ledgerblue.Dongle import Dongle + +from gnosis.safe.tests.safe_test_case import SafeTestCaseMixin + +from safe_cli.operators.hw_wallets.hw_wallet_manager import ( + HwWalletManager, + HwWalletType, + get_hw_wallet_manager, +) +from safe_cli.operators.hw_wallets.ledger_wallet import LedgerWallet + + +class Testledger_wallet(SafeTestCaseMixin, unittest.TestCase): + def test_setup_hw_wallet_manager(self): + # Should support Treezor and Ledger + hw_wallet_manager = get_hw_wallet_manager() + self.assertTrue(hw_wallet_manager.is_supported_hw_wallet(HwWalletType.TREZOR)) + self.assertTrue(hw_wallet_manager.is_supported_hw_wallet(HwWalletType.LEDGER)) + self.assertEqual(len(hw_wallet_manager.wallets), 0) + + # Should get the same instance + other_hw_wallet_manager = get_hw_wallet_manager() + self.assertEqual(other_hw_wallet_manager, hw_wallet_manager) + + @mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.init_dongle", + autospec=True, + return_value=Dongle(), + ) + @mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.LedgerWallet.get_address", + autospec=True, + ) + def test_get_accounts( + self, mock_get_address: MagicMock, mock_init_dongle: MagicMock + ): + hw_wallet_manager = HwWalletManager() + addresses = [Account.create().address, Account.create().address] + derivation_paths = ["44'/60'/0'/0/0", "44'/60'/1'/0/0"] + mock_get_address.side_effect = addresses + # Choosing LEDGER because function is mocked for LEDGER + hw_wallets = hw_wallet_manager.get_accounts( + HwWalletType.LEDGER, number_accounts=2 + ) + self.assertEqual(len(hw_wallets), 2) + for hw_wallet, expected_address, expected_derivation_path in zip( + hw_wallets, addresses, derivation_paths + ): + address, derivation_path = hw_wallet + self.assertEqual(expected_address, address) + self.assertEqual(expected_derivation_path, derivation_path) + + @mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.init_dongle", + autospec=True, + return_value=Dongle(), + ) + @mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.LedgerWallet.get_address", + autospec=True, + ) + def test_add_account( + self, mock_get_address: MagicMock, mock_init_dongle: MagicMock + ): + hw_wallet_manager = HwWalletManager() + derivation_path = "44'/60'/0'/0" + account_address = Account.create().address + mock_get_address.return_value = account_address + + self.assertEqual(len(hw_wallet_manager.wallets), 0) + # Choosing LEDGER because function is mocked for LEDGER + self.assertEqual( + hw_wallet_manager.add_account(HwWalletType.LEDGER, derivation_path), + account_address, + ) + + self.assertEqual(len(hw_wallet_manager.wallets), 1) + ledger_wallet = list(hw_wallet_manager.wallets)[0] + self.assertEqual(ledger_wallet.address, account_address) + self.assertEqual(ledger_wallet.derivation_path, derivation_path) + # Shouldn't duplicate accounts + self.assertEqual( + hw_wallet_manager.add_account(HwWalletType.LEDGER, derivation_path), + account_address, + ) + self.assertEqual(len(hw_wallet_manager.wallets), 1) + + # Should accept derivation paths starting with master + master_derivation_path = "m/44'/60'/0'/0" + self.assertEqual( + hw_wallet_manager.add_account(HwWalletType.LEDGER, master_derivation_path), + account_address, + ) + + @mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.init_dongle", + autospec=True, + return_value=Dongle(), + ) + @mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.LedgerWallet.get_address", + autospec=True, + ) + def test_delete_account( + self, mock_get_address: MagicMock, mock_init_dongle: MagicMock + ): + hw_wallet_manager = HwWalletManager() + random_address = Account.create().address + random_address_2 = Account.create().address + self.assertEqual(len(hw_wallet_manager.wallets), 0) + self.assertEqual(len(hw_wallet_manager.delete_accounts([random_address])), 0) + + mock_get_address.return_value = random_address_2 + hw_wallet_manager.wallets.add(LedgerWallet("44'/60'/0'/0")) + self.assertEqual(len(hw_wallet_manager.delete_accounts([random_address])), 0) + self.assertEqual(len(hw_wallet_manager.wallets), 1) + self.assertEqual(len(hw_wallet_manager.delete_accounts([])), 0) + + mock_get_address.return_value = random_address + hw_wallet_manager.wallets.add(LedgerWallet("44'/60'/0'/1")) + self.assertEqual(len(hw_wallet_manager.wallets), 2) + self.assertEqual(len(hw_wallet_manager.delete_accounts([random_address])), 1) + self.assertEqual(len(hw_wallet_manager.wallets), 1) + hw_wallet_manager.wallets.add(LedgerWallet("44'/60'/0'/1")) + self.assertEqual(len(hw_wallet_manager.wallets), 2) + self.assertEqual( + len(hw_wallet_manager.delete_accounts([random_address, random_address_2])), + 2, + ) + self.assertEqual(len(hw_wallet_manager.wallets), 0) diff --git a/tests/test_ledger_manager.py b/tests/test_ledger_manager.py deleted file mode 100644 index c8993d8..0000000 --- a/tests/test_ledger_manager.py +++ /dev/null @@ -1,246 +0,0 @@ -import unittest -from unittest import mock -from unittest.mock import MagicMock - -from eth_account import Account -from ledgerblue.Dongle import Dongle -from ledgereth.accounts import LedgerAccount -from ledgereth.exceptions import ( - LedgerAppNotOpened, - LedgerCancel, - LedgerLocked, - LedgerNotFound, -) - -from gnosis.eth.eip712 import eip712_encode -from gnosis.safe import SafeTx -from gnosis.safe.signatures import signature_split -from gnosis.safe.tests.safe_test_case import SafeTestCaseMixin - -from safe_cli.operators.exceptions import HardwareWalletException -from safe_cli.operators.hw_accounts.ledger_manager import LedgerManager - - -class TestLedgerManager(SafeTestCaseMixin, unittest.TestCase): - def test_setup_ledger_manager(self): - ledger_manager = LedgerManager() - self.assertIsNone(ledger_manager.dongle) - self.assertEqual(len(ledger_manager.accounts), 0) - self.assertEqual(ledger_manager.connected, False) - - @mock.patch("safe_cli.operators.hw_accounts.ledger_manager.init_dongle") - @mock.patch("safe_cli.operators.hw_accounts.ledger_manager.get_account_by_path") - def test_connected( - self, mock_get_account_by_path: MagicMock, mock_init_dongle: MagicMock - ): - ledger_manager = LedgerManager() - mock_init_dongle.side_effect = LedgerNotFound() - - self.assertEqual(ledger_manager.connected, False) - - mock_init_dongle.side_effect = None - mock_init_dongle.return_value = Dongle() - mock_get_account_by_path.side_effect = LedgerLocked() - - self.assertEqual(ledger_manager.connected, True) - - @mock.patch( - "safe_cli.operators.hw_accounts.ledger_manager.sign_typed_data_draft", - autospec=True, - ) - @mock.patch( - "safe_cli.operators.hw_accounts.ledger_manager.get_account_by_path", - autospec=True, - ) - def test_hw_device_exception(self, mock_ledger_fn: MagicMock, mock_sign: MagicMock): - ledger_manager = LedgerManager() - - derivation_path = "44'/60'/0'/0" - ledger_account = LedgerAccount(derivation_path, Account.create().address) - safe = self.deploy_test_safe( - owners=[Account.create().address], - threshold=1, - initial_funding_wei=self.w3.to_wei(0.1, "ether"), - ) - safe_tx = SafeTx( - self.ethereum_client, - safe.address, - Account.create().address, - 10, - b"", - 0, - 200000, - 200000, - self.gas_price, - None, - None, - safe_nonce=0, - ) - - mock_ledger_fn.side_effect = LedgerNotFound - mock_sign.side_effect = LedgerNotFound - with self.assertRaises(HardwareWalletException): - ledger_manager.get_accounts() - with self.assertRaises(HardwareWalletException): - ledger_manager.add_account(derivation_path) - with self.assertRaises(HardwareWalletException): - ledger_manager.sign_eip712(safe_tx, [ledger_account]) - - mock_ledger_fn.side_effect = LedgerLocked - mock_sign.side_effect = LedgerLocked - with self.assertRaises(HardwareWalletException): - ledger_manager.get_accounts() - with self.assertRaises(HardwareWalletException): - ledger_manager.add_account(derivation_path) - with self.assertRaises(HardwareWalletException): - ledger_manager.sign_eip712(safe_tx, [ledger_account]) - - mock_ledger_fn.side_effect = LedgerAppNotOpened - mock_sign.side_effect = LedgerAppNotOpened - with self.assertRaises(HardwareWalletException): - ledger_manager.get_accounts() - with self.assertRaises(HardwareWalletException): - ledger_manager.add_account(derivation_path) - with self.assertRaises(HardwareWalletException): - ledger_manager.sign_eip712(safe_tx, [ledger_account]) - - mock_ledger_fn.side_effect = LedgerCancel - mock_sign.side_effect = LedgerCancel - with self.assertRaises(HardwareWalletException): - ledger_manager.get_accounts() - with self.assertRaises(HardwareWalletException): - ledger_manager.add_account(derivation_path) - with self.assertRaises(HardwareWalletException): - ledger_manager.sign_eip712(safe_tx, [ledger_account]) - - @mock.patch( - "safe_cli.operators.hw_accounts.ledger_manager.get_account_by_path", - autospec=True, - ) - def test_get_accounts(self, mock_get_account_by_path: MagicMock): - ledger_manager = LedgerManager() - addresses = [Account.create().address, Account.create().address] - derivation_paths = ["44'/60'/0'/0", "44'/60'/0'/1"] - mock_get_account_by_path.side_effect = [ - LedgerAccount(derivation_paths[0], addresses[0]), - LedgerAccount(derivation_paths[1], addresses[1]), - ] - ledger_accounts = ledger_manager.get_accounts(number_accounts=2) - self.assertEqual(len(ledger_accounts), 2) - for ledger_account, expected_address, expected_derivation_path in zip( - ledger_accounts, addresses, derivation_paths - ): - ledger_address, ledger_path = ledger_account - self.assertEqual(expected_address, ledger_address) - self.assertEqual(expected_derivation_path, ledger_path) - - @mock.patch( - "safe_cli.operators.hw_accounts.ledger_manager.get_account_by_path", - autospec=True, - ) - def test_add_account(self, mock_get_account_by_path: MagicMock): - ledger_manager = LedgerManager() - derivation_path = "44'/60'/0'/0" - account_address = Account.create().address - mock_get_account_by_path.return_value = LedgerAccount( - derivation_path, account_address - ) - self.assertEqual(len(ledger_manager.accounts), 0) - - self.assertEqual(ledger_manager.add_account(derivation_path), account_address) - - self.assertEqual(len(ledger_manager.accounts), 1) - ledger_account = list(ledger_manager.accounts)[0] - self.assertEqual(ledger_account.address, account_address) - self.assertEqual(ledger_account.path, derivation_path) - # Shouldn't duplicate accounts - self.assertEqual(ledger_manager.add_account(derivation_path), account_address) - self.assertEqual(len(ledger_manager.accounts), 1) - - # Should accept derivation paths starting with master - master_derivation_path = "m/44'/60'/0'/0" - self.assertEqual( - ledger_manager.add_account(master_derivation_path), account_address - ) - - def test_delete_account(self): - ledger_manager = LedgerManager() - random_address = Account.create().address - random_address_2 = Account.create().address - self.assertEqual(len(ledger_manager.accounts), 0) - self.assertEqual(len(ledger_manager.delete_accounts([random_address])), 0) - ledger_manager.accounts.add(LedgerAccount("44'/60'/0'/0", random_address_2)) - self.assertEqual(len(ledger_manager.delete_accounts([random_address])), 0) - self.assertEqual(len(ledger_manager.accounts), 1) - self.assertEqual(len(ledger_manager.delete_accounts([])), 0) - ledger_manager.accounts.add(LedgerAccount("44'/60'/0'/1", random_address)) - self.assertEqual(len(ledger_manager.accounts), 2) - self.assertEqual(len(ledger_manager.delete_accounts([random_address])), 1) - self.assertEqual(len(ledger_manager.accounts), 1) - ledger_manager.accounts.add(LedgerAccount("44'/60'/0'/1", random_address)) - self.assertEqual(len(ledger_manager.accounts), 2) - self.assertEqual( - len(ledger_manager.delete_accounts([random_address, random_address_2])), 2 - ) - self.assertEqual(len(ledger_manager.accounts), 0) - - @mock.patch( - "safe_cli.operators.hw_accounts.ledger_manager.init_dongle", - autospec=True, - return_value=Dongle(), - ) - def test_sign_eip712(self, mock_init_dongle: MagicMock): - ledger_manager = LedgerManager() - owner = Account.create() - to = Account.create() - ledger_account = LedgerAccount("44'/60'/0'/0", owner.address) - safe = self.deploy_test_safe( - owners=[owner.address], - threshold=1, - initial_funding_wei=self.w3.to_wei(0.1, "ether"), - ) - safe_tx = SafeTx( - self.ethereum_client, - safe.address, - to.address, - 10, - b"", - 0, - 200000, - 200000, - self.gas_price, - None, - None, - safe_nonce=0, - ) - encode_hash = eip712_encode(safe_tx.eip712_structured_data) - expected_signature = safe_tx.sign(owner.key) - # We need to split to change the bytes signature order to v + r + s like ledger return signature - v, r, s = signature_split(expected_signature) - - ledger_return_signature = ( - v.to_bytes(1, byteorder="big") - + r.to_bytes(32, byteorder="big") - + s.to_bytes(32, byteorder="big") - ) - mock_init_dongle.return_value.exchange = MagicMock( - return_value=ledger_return_signature - ) - safe_tx = ledger_manager.sign_eip712(safe_tx, [ledger_account]) - self.assertEqual(safe_tx.signatures, expected_signature) - - # Check that dongle exchange is called with the expected payload - # https://github.com/LedgerHQ/app-ethereum/blob/master/doc/ethapp.adoc#sign-eth-eip-712 - command = "e00c0000" + "51" # command + payload length - payload = ( - "04" + "8000002c8000003c8000000000000000" - ) # number of derivations + 44'/60'/0'/0 - expected_exchange_payload = ( - bytes.fromhex(command) - + bytes.fromhex(payload) - + encode_hash[1] - + encode_hash[2] - ) - mock_init_dongle.return_value.exchange.assert_called_once_with( - expected_exchange_payload - ) diff --git a/tests/test_ledger_wallet.py b/tests/test_ledger_wallet.py new file mode 100644 index 0000000..bc9e0d4 --- /dev/null +++ b/tests/test_ledger_wallet.py @@ -0,0 +1,169 @@ +import os +import unittest +from unittest import mock +from unittest.mock import MagicMock + +from eth_account import Account +from ledgerblue.Dongle import Dongle +from ledgereth.exceptions import ( + LedgerAppNotOpened, + LedgerCancel, + LedgerLocked, + LedgerNotFound, +) +from ledgereth.objects import LedgerAccount + +from gnosis.eth.eip712 import eip712_encode +from gnosis.safe import SafeTx +from gnosis.safe.signatures import signature_split +from gnosis.safe.tests.safe_test_case import SafeTestCaseMixin + +from safe_cli.operators.exceptions import HardwareWalletException +from safe_cli.operators.hw_wallets.ledger_wallet import LedgerWallet + + +class Testledger_wallet(SafeTestCaseMixin, unittest.TestCase): + @mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.init_dongle", + return_value=Dongle(), + ) + def test_setup_ledger_wallet(self, mock_init_dongle: MagicMock): + derivation_path = "44'/60'/0'/0" + address = Account.create().address + with self.assertRaises(HardwareWalletException): + LedgerWallet(derivation_path) + with mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.get_account_by_path", + return_value=LedgerAccount(derivation_path, address), + ): + ledger_wallet = LedgerWallet(derivation_path) + self.assertIsNotNone(ledger_wallet.dongle) + self.assertEqual(ledger_wallet.address, address) + + @mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.sign_typed_data_draft", + autospec=True, + ) + @mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.get_account_by_path", + autospec=True, + ) + @mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.init_dongle", + return_value=Dongle(), + ) + def test_hw_device_exception( + self, + mock_init_dongle: MagicMock, + mock_get_account_by_path: MagicMock, + mock_sign: MagicMock, + ): + derivation_path = "44'/60'/0'/0" + address = Account.create().address + random_domain_bytes = os.urandom(32) + random_message_bytes = os.urandom(32) + mock_get_account_by_path.side_effect = LedgerNotFound + with self.assertRaises(HardwareWalletException): + LedgerWallet(derivation_path) + + mock_get_account_by_path.side_effect = LedgerLocked + with self.assertRaises(HardwareWalletException): + LedgerWallet(derivation_path) + + mock_get_account_by_path.side_effect = LedgerAppNotOpened + with self.assertRaises(HardwareWalletException): + LedgerWallet(derivation_path) + + # Test sign exceptions + mock_get_account_by_path.side_effect = None + mock_get_account_by_path.return_value = LedgerAccount(derivation_path, address) + mock_sign.side_effect = LedgerNotFound + with self.assertRaises(HardwareWalletException): + ledger_wallet = LedgerWallet(derivation_path) + ledger_wallet.sign_typed_hash(random_domain_bytes, random_message_bytes) + + mock_sign.side_effect = LedgerLocked + with self.assertRaises(HardwareWalletException): + ledger_wallet = LedgerWallet(derivation_path) + ledger_wallet.sign_typed_hash(random_domain_bytes, random_message_bytes) + + mock_sign.side_effect = LedgerAppNotOpened + with self.assertRaises(HardwareWalletException): + ledger_wallet = LedgerWallet(derivation_path) + ledger_wallet.sign_typed_hash(random_domain_bytes, random_message_bytes) + + mock_sign.side_effect = LedgerCancel + with self.assertRaises(HardwareWalletException): + ledger_wallet = LedgerWallet(derivation_path) + ledger_wallet.sign_typed_hash(random_domain_bytes, random_message_bytes) + + @mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.get_account_by_path", + autospec=True, + ) + @mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.init_dongle", + autospec=True, + return_value=Dongle(), + ) + def test_sign_typed_hash( + self, mock_init_dongle: MagicMock, mock_get_account_by_path: MagicMock + ): + owner = Account.create() + to = Account.create() + derivation_path = "44'/60'/0'/0" + mock_get_account_by_path.return_value = LedgerAccount( + derivation_path, owner.address + ) + ledger_wallet = LedgerWallet(derivation_path) + + safe = self.deploy_test_safe( + owners=[owner.address], + threshold=1, + initial_funding_wei=self.w3.to_wei(0.1, "ether"), + ) + safe_tx = SafeTx( + self.ethereum_client, + safe.address, + to.address, + 10, + b"", + 0, + 200000, + 200000, + self.gas_price, + None, + None, + safe_nonce=0, + ) + encode_hash = eip712_encode(safe_tx.eip712_structured_data) + expected_signature = safe_tx.sign(owner.key) + # We need to split to change the bytes signature order to v + r + s like ledger return signature + v, r, s = signature_split(expected_signature) + + ledger_return_signature = ( + v.to_bytes(1, byteorder="big") + + r.to_bytes(32, byteorder="big") + + s.to_bytes(32, byteorder="big") + ) + mock_init_dongle.return_value.exchange = MagicMock( + return_value=ledger_return_signature + ) + signature = ledger_wallet.sign_typed_hash(encode_hash[1], encode_hash[2]) + self.assertEqual(expected_signature, signature) + + # Check that dongle exchange is called with the expected payload + # https://github.com/LedgerHQ/app-ethereum/blob/master/doc/ethapp.adoc#sign-eth-eip-712 + command = "e00c0000" + "51" # command + payload length + payload = ( + "04" + "8000002c8000003c8000000000000000" + ) # number of derivations + 44'/60'/0'/0 + expected_exchange_payload = ( + bytes.fromhex(command) + + bytes.fromhex(payload) + + encode_hash[1] + + encode_hash[2] + ) + mock_init_dongle.return_value.exchange.assert_called_once_with( + expected_exchange_payload + ) diff --git a/tests/test_safe_operator.py b/tests/test_safe_operator.py index 1d9e9a7..c6a0502 100644 --- a/tests/test_safe_operator.py +++ b/tests/test_safe_operator.py @@ -5,6 +5,7 @@ from eth_account import Account from eth_typing import ChecksumAddress +from ledgerblue.Dongle import Dongle from ledgereth.objects import LedgerAccount from web3 import Web3 @@ -81,29 +82,28 @@ def test_load_cli_owner(self, get_contract_mock: MagicMock): # Test unload cli owner safe_operator.default_sender = random_accounts[0] number_of_accounts = len(safe_operator.accounts) - ledger_random_address = Account.create().address - safe_operator.ledger_manager.accounts.add( - LedgerAccount("44'/60'/0'/1", ledger_random_address) - ) - self.assertEqual(len(safe_operator.ledger_manager.accounts), 1) - safe_operator.unload_cli_owners( - ["aloha", random_accounts[0].address, "bye", ledger_random_address] - ) + + safe_operator.unload_cli_owners(["aloha", random_accounts[0].address, "bye"]) self.assertEqual(len(safe_operator.accounts), number_of_accounts - 1) self.assertFalse(safe_operator.default_sender) - self.assertEqual(len(safe_operator.ledger_manager.accounts), 0) - @mock.patch("safe_cli.operators.hw_accounts.ledger_manager.get_account_by_path") - def test_load_ledger_cli_owner(self, mock_get_account_by_path: MagicMock): + @mock.patch( + "safe_cli.operators.hw_wallets.ledger_wallet.init_dongle", + return_value=Dongle(), + ) + @mock.patch("safe_cli.operators.hw_wallets.ledger_wallet.get_account_by_path") + def test_load_ledger_cli_owner( + self, mock_get_account_by_path: MagicMock, mock_init_dongle: MagicMock + ): owner_address = Account.create().address safe_address = self.deploy_test_safe(owners=[owner_address]).address safe_operator = SafeOperator(safe_address, self.ethereum_node_url) - safe_operator.ledger_manager.get_accounts = MagicMock(return_value=[]) + safe_operator.hw_wallet_manager.get_accounts = MagicMock(return_value=[]) safe_operator.load_ledger_cli_owners() - self.assertEqual(len(safe_operator.ledger_manager.accounts), 0) + self.assertEqual(len(safe_operator.hw_wallet_manager.wallets), 0) random_address = Account.create().address other_random_address = Account.create().address - safe_operator.ledger_manager.get_accounts.return_value = [ + safe_operator.hw_wallet_manager.get_accounts.return_value = [ (random_address, "44'/60'/0'/0/0"), (other_random_address, "44'/60'/0'/0/1"), ] @@ -112,9 +112,9 @@ def test_load_ledger_cli_owner(self, mock_get_account_by_path: MagicMock): "44'/60'/0'/0/0", random_address ) safe_operator.load_ledger_cli_owners() - self.assertEqual(len(safe_operator.ledger_manager.accounts), 1) + self.assertEqual(len(safe_operator.hw_wallet_manager.wallets), 1) self.assertEqual( - safe_operator.ledger_manager.accounts.pop().address, random_address + safe_operator.hw_wallet_manager.wallets.pop().address, random_address ) # Only accept ethereum derivation paths @@ -125,10 +125,19 @@ def test_load_ledger_cli_owner(self, mock_get_account_by_path: MagicMock): "44'/60'/0'/0/0", owner_address ) safe_operator.load_ledger_cli_owners(derivation_path="44'/60'/0'/0/0") - self.assertEqual(len(safe_operator.ledger_manager.accounts), 1) + self.assertEqual(len(safe_operator.hw_wallet_manager.wallets), 1) self.assertEqual( - safe_operator.ledger_manager.accounts.pop().address, owner_address + safe_operator.hw_wallet_manager.wallets.pop().address, owner_address + ) + + # test unload ledger owner + ledger_random_address = Account.create().address + safe_operator.hw_wallet_manager.wallets.add( + LedgerAccount("44'/60'/0'/1", ledger_random_address) ) + self.assertEqual(len(safe_operator.hw_wallet_manager.wallets), 1) + safe_operator.unload_cli_owners([ledger_random_address]) + self.assertEqual(len(safe_operator.hw_wallet_manager.wallets), 0) def test_approve_hash(self): safe_address = self.deploy_test_safe( diff --git a/tests/test_safe_tx_service_operator.py b/tests/test_safe_tx_service_operator.py index c9d4891..cf61c3b 100644 --- a/tests/test_safe_tx_service_operator.py +++ b/tests/test_safe_tx_service_operator.py @@ -139,17 +139,17 @@ def test_submit_signatures( with mock.patch.object( SafeTx, "signers", return_value=["signer"], new_callable=mock.PropertyMock ) as mock_safe_tx: - safe_operator.ledger_manager.sign_eip712 = MagicMock( + safe_operator.hw_wallet_manager.sign_eip712 = MagicMock( return_value=mock_safe_tx ) get_safe_transaction_mock.return_value = GetMultisigTxRequestMock( executed=False ) - safe_operator.ledger_manager.accounts.add( + safe_operator.hw_wallet_manager.wallets.add( LedgerAccount("44'/60'/0'/0", Account.create().address) ) get_permitted_signers_mock.return_value = { - list(safe_operator.ledger_manager.accounts)[0].address + list(safe_operator.hw_wallet_manager.wallets)[0].address } self.assertTrue(safe_operator.submit_signatures(safe_tx_hash)) diff --git a/tests/test_trezor_wallet.py b/tests/test_trezor_wallet.py new file mode 100644 index 0000000..65f4611 --- /dev/null +++ b/tests/test_trezor_wallet.py @@ -0,0 +1,151 @@ +import os +import unittest +from unittest import mock +from unittest.mock import MagicMock + +from eth_account import Account +from trezorlib.client import TrezorClient +from trezorlib.exceptions import Cancelled, OutdatedFirmwareError, PinException +from trezorlib.messages import EthereumTypedDataSignature +from trezorlib.transport import TransportException +from trezorlib.ui import ClickUI + +from gnosis.eth.eip712 import eip712_encode +from gnosis.safe import SafeTx +from gnosis.safe.tests.safe_test_case import SafeTestCaseMixin + +from safe_cli.operators.exceptions import HardwareWalletException +from safe_cli.operators.hw_wallets.trezor_wallet import TrezorWallet + + +class TestTrezorManager(SafeTestCaseMixin, unittest.TestCase): + @mock.patch( + "safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_client", + return_value=None, + ) + @mock.patch( + "safe_cli.operators.hw_wallets.trezor_wallet.get_address", + return_value=None, + ) + def test_setup_trezor_wallet( + self, mock_trezor_client: MagicMock, mock_get_address: MagicMock + ): + trezor_wallet = TrezorWallet("44'/60'/0'/0") + self.assertIsNone(trezor_wallet.client) + + @mock.patch( + "safe_cli.operators.hw_wallets.trezor_wallet.sign_typed_data_hash", + autospec=True, + ) + @mock.patch( + "safe_cli.operators.hw_wallets.trezor_wallet.get_address", + autospec=True, + ) + @mock.patch( + "safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_client", + autospec=True, + ) + def test_hw_device_exception( + self, + mock_trezor_client: MagicMock, + mock_trezor_get_address: MagicMock, + mock_trezor_sign: MagicMock, + ): + derivation_path = "44'/60'/0'/0" + transport_mock = MagicMock(auto_spec=True) + mock_trezor_client.return_value = TrezorClient( + transport_mock, ui=ClickUI(), _init_device=False + ) + mock_trezor_client.return_value.is_outdated = MagicMock(return_value=False) + random_domain_bytes = os.urandom(32) + random_message_bytes = os.urandom(32) + + mock_trezor_get_address.side_effect = TransportException + with self.assertRaises(HardwareWalletException): + TrezorWallet(derivation_path) + + mock_trezor_get_address.side_effect = PinException + with self.assertRaises(HardwareWalletException): + TrezorWallet(derivation_path) + + mock_trezor_get_address.side_effect = Cancelled + with self.assertRaises(HardwareWalletException): + TrezorWallet(derivation_path) + + mock_trezor_get_address.side_effect = OutdatedFirmwareError + with self.assertRaises(HardwareWalletException): + TrezorWallet(derivation_path) + + mock_trezor_get_address.side_effect = None + mock_trezor_get_address.return_value = Account.create().address + mock_trezor_sign.side_effect = TransportException + with self.assertRaises(HardwareWalletException): + trezor_wallet = TrezorWallet(derivation_path) + trezor_wallet.sign_typed_hash(random_domain_bytes, random_message_bytes) + + mock_trezor_sign.side_effect = PinException + with self.assertRaises(HardwareWalletException): + trezor_wallet = TrezorWallet(derivation_path) + trezor_wallet.sign_typed_hash(random_domain_bytes, random_message_bytes) + + mock_trezor_sign.side_effect = Cancelled + with self.assertRaises(HardwareWalletException): + trezor_wallet = TrezorWallet(derivation_path) + trezor_wallet.sign_typed_hash(random_domain_bytes, random_message_bytes) + + mock_trezor_sign.side_effect = OutdatedFirmwareError + with self.assertRaises(HardwareWalletException): + trezor_wallet = TrezorWallet(derivation_path) + trezor_wallet.sign_typed_hash(random_domain_bytes, random_message_bytes) + + @mock.patch( + "safe_cli.operators.hw_wallets.trezor_wallet.get_address", + autospec=True, + ) + @mock.patch( + "safe_cli.operators.hw_wallets.trezor_wallet.get_trezor_client", + autospec=True, + ) + def test_sign_typed_hash( + self, mock_trezor_client: MagicMock, mock_get_address: MagicMock + ): + owner = Account.create() + to = Account.create() + transport_mock = MagicMock(auto_spec=True) + mock_trezor_client.return_value = TrezorClient( + transport_mock, ui=ClickUI(), _init_device=False + ) + mock_trezor_client.return_value.is_outdated = MagicMock(return_value=False) + mock_get_address.return_value = owner.address + trezor_wallet = TrezorWallet("44'/60'/0'/0") + + safe = self.deploy_test_safe( + owners=[owner.address], + threshold=1, + initial_funding_wei=self.w3.to_wei(0.1, "ether"), + ) + safe_tx = SafeTx( + self.ethereum_client, + safe.address, + to.address, + 10, + b"", + 0, + 200000, + 200000, + self.gas_price, + None, + None, + safe_nonce=0, + ) + encode_hash = eip712_encode(safe_tx.eip712_structured_data) + expected_signature = safe_tx.sign(owner.key) + + trezor_return_signature = EthereumTypedDataSignature( + signature=expected_signature + ) + mock_trezor_client.return_value.call = MagicMock( + return_value=trezor_return_signature + ) + signature = trezor_wallet.sign_typed_hash(encode_hash[1], encode_hash[2]) + self.assertEqual(expected_signature, signature)