Skip to content

Commit

Permalink
Refactor hw__account_manager
Browse files Browse the repository at this point in the history
  • Loading branch information
moisses89 committed Nov 29, 2023
1 parent 443cd5f commit e8c2a0a
Show file tree
Hide file tree
Showing 12 changed files with 451 additions and 274 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ pygments==2.17.1
requests==2.31.0
safe-eth-py==6.0.0b8
tabulate==0.9.0
trezor==0.13.8
web3==6.11.3
4 changes: 4 additions & 0 deletions safe_cli/operators/hw_accounts/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ class InvalidDerivationPath(LedgerError):
message = "The provided derivation path is not valid"


class UnsupportedHwWalletException(Exception):
pass


def raise_as_hw_account_exception(function):
@functools.wraps(function)
def wrapper(*args, **kwargs):
Expand Down
61 changes: 61 additions & 0 deletions safe_cli/operators/hw_accounts/hw_account.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import re
from abc import ABC, abstractmethod
from typing import Tuple

from eth_typing import ChecksumAddress

BIP32_ETH_PATTERN = r"^(m/)?44'/60'/[0-9]+'/[0-9]+/[0-9]+$"
BIP32_LEGACY_LEDGER_PATTERN = r"^(m/)?44'/60'/[0-9]+'/[0-9]+$"


class HwAccount(ABC):
def __init__(self, derivation_path: str, address: ChecksumAddress):
self.derivation_path = derivation_path
self.address = address

@property
def get_derivation_path(self):
return self.derivation_path

@property
def get_address(self):
return self.address

@staticmethod
def is_valid_derivation_path(derivation_path: str):
"""
Detect if a string is a valid derivation path
"""
return (
re.match(BIP32_ETH_PATTERN, derivation_path) is not None
or re.match(BIP32_LEGACY_LEDGER_PATTERN, derivation_path) is not None
)

@staticmethod
@abstractmethod
def get_address_by_derivation_path(derivation_path: str) -> ChecksumAddress:
"""
:param derivation_path:
:return: public address for provided derivation_path
"""

@abstractmethod
def sign_typed_hash(self, domain_hash, message_hash) -> Tuple[bytes, bytes, bytes]:
"""
:param domain_hash:
:param message_hash:
:return: tuple os signature v, r, s
"""

def __eq__(self, other):
if isinstance(other, HwAccount):
return (
self.derivation_path == other.derivation_path
and self.address == other.address
)
return False

def __hash__(self):
return hash((self.derivation_path, self.address))
151 changes: 151 additions & 0 deletions safe_cli/operators/hw_accounts/hw_account_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple

from eth_typing import ChecksumAddress
from prompt_toolkit import HTML, print_formatted_text

from gnosis.eth.eip712 import eip712_encode
from gnosis.safe import SafeTx
from gnosis.safe.signatures import signature_to_bytes

from safe_cli.operators.exceptions import HardwareWalletException
from safe_cli.operators.hw_accounts.hw_account import HwAccount


class HwWalletType(Enum):
TREZOR = "Trezor"
LEDGER = "Ledger"


class HwAccountManager:
def __new__(cls):
if not hasattr(cls, "instance"):
cls.instance = super(HwAccountManager, cls).__new__(cls)
return cls.instance

def __init__(self):
self.accounts: Set[HwAccount] = set()
self.supported_hw_wallets: Dict[str, HwAccount] = {}
try:
from safe_cli.operators.hw_accounts.ledger_manager import LedgerManager

self.supported_hw_wallets[HwWalletType.LEDGER] = LedgerManager
except (ModuleNotFoundError, IOError):
print("Exception")
pass

print("Continue")
try:
from safe_cli.operators.hw_accounts.trezor_manager import TrezorManager

print("Continue")
self.supported_hw_wallets[HwWalletType.TREZOR] = TrezorManager
except (ModuleNotFoundError, IOError):
pass

def is_supported_hw_wallet(self, hw_wallet_type: HwWalletType):
return hw_wallet_type in self.supported_hw_wallets

def get_hw_wallet(self, hw_wallet_type: HwWalletType):
if hw_wallet_type in self.supported_hw_wallets:
return self.supported_hw_wallets[hw_wallet_type]
# TODO add unsupported exception

def get_accounts(
self,
hw_wallet_type: HwWalletType,
legacy_account: Optional[bool] = False,
number_accounts: Optional[int] = 5,
) -> Tuple[ChecksumAddress, str]:
"""
:param hw_wallet: Trezor or Ledger
:param legacy_account:
:param number_accounts: number of accounts requested to ledger
:return: a list of tuples with address and derivation path
"""
accounts = []
hw_wallet = self.get_hw_wallet(hw_wallet_type)
for i in range(number_accounts):
if legacy_account:
path_string = f"44'/60'/0'/{i}"
else:
path_string = f"44'/60'/{i}'/0/0"

address = hw_wallet.get_address_by_derivation_path(path_string)
accounts.append((address, path_string))
return accounts

def add_account(
self, hw_wallet_type: HwWalletType, derivation_path: str
) -> ChecksumAddress:
"""
Add an account to ledger manager set and return the added address
:param derivation_path:
:return:
"""

hw_wallet = self.get_hw_wallet(hw_wallet_type)

if hw_wallet.is_valid_derivation_path(derivation_path):
address = hw_wallet.get_address_by_derivation_path(derivation_path)
self.accounts.add(hw_wallet(derivation_path, address))
return address
else:
raise HardwareWalletException("Invalid derivation path")

def delete_accounts(self, addresses: List[ChecksumAddress]) -> Set:
"""
Remove ledger accounts from address
:param accounts:
:return: list with the delete accounts
"""
accounts_to_remove = set()
for address in addresses:
for account in self.accounts:
if account.address == address:
accounts_to_remove.add(account)
self.accounts = self.accounts.difference(accounts_to_remove)
return accounts_to_remove

def sign_eip712(self, safe_tx: SafeTx, accounts: List[HwAccount]) -> SafeTx:
"""
Call ledger ethereum app method to sign eip712 hashes with a ledger account
:param domain_hash:
:param message_hash:
:param account: ledger account
:return: bytes of signature
"""
encode_hash = eip712_encode(safe_tx.eip712_structured_data)
domain_hash = encode_hash[1]
message_hash = encode_hash[2]
for account in accounts:
print_formatted_text(
HTML(
"<ansired>Make sure in your ledger before signing that domain_hash and message_hash are both correct</ansired>"
)
)
print_formatted_text(HTML(f"Domain_hash: <b>{domain_hash.hex()}</b>"))
print_formatted_text(HTML(f"Message_hash: <b>{message_hash.hex()}</b>"))
signed_v, signed_r, signed_s = account.sign_typed_hash(
domain_hash, message_hash
)

signature = signature_to_bytes(signed_v, signed_r, signed_s)
# TODO should be refactored on safe_eth_py function insert_signature_sorted
# Insert signature sorted
if account.address not in safe_tx.signers:
new_owners = safe_tx.signers + [account.address]
new_owner_pos = sorted(new_owners, key=lambda x: int(x, 16)).index(
account.address
)
safe_tx.signatures = (
safe_tx.signatures[: 65 * new_owner_pos]
+ signature
+ safe_tx.signatures[65 * new_owner_pos :]
)

return safe_tx
116 changes: 15 additions & 101 deletions safe_cli/operators/hw_accounts/ledger_manager.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,20 @@
from typing import List, Optional, Set, Tuple
from typing import Tuple

from eth_typing import ChecksumAddress
from ledgereth import sign_typed_data_draft
from ledgereth.accounts import LedgerAccount, get_account_by_path
from ledgereth.accounts import get_account_by_path
from ledgereth.comms import init_dongle
from ledgereth.exceptions import LedgerNotFound
from ledgereth.utils import is_bip32_path
from prompt_toolkit import HTML, print_formatted_text

from gnosis.eth.eip712 import eip712_encode
from gnosis.safe import SafeTx
from gnosis.safe.signatures import signature_to_bytes
from safe_cli.operators.hw_accounts.exceptions import raise_as_hw_account_exception
from safe_cli.operators.hw_accounts.hw_account import HwAccount

from safe_cli.operators.hw_accounts.exceptions import (
InvalidDerivationPath,
raise_as_hw_account_exception,
)


class LedgerManager:
def __init__(self):
class LedgerManager(HwAccount):
def __init__(self, derivation_path: str, address: ChecksumAddress):
self.dongle = None
self.accounts: Set[LedgerAccount] = set()
self.connect()
super().__init__(derivation_path, address)

def connect(self) -> bool:
"""
Expand All @@ -44,98 +36,20 @@ def connected(self) -> bool:
return self.connect()

@raise_as_hw_account_exception
def get_accounts(
self, legacy_account: Optional[bool] = False, number_accounts: Optional[int] = 5
) -> List[Tuple[ChecksumAddress, str]]:
"""
Request to ledger device the first n accounts
:param legacy_account:
:param number_accounts: number of accounts requested to ledger
:return: a list of tuples with address and derivation path
def get_address_by_derivation_path(derivation_path: str) -> ChecksumAddress:
"""
accounts = []
for i in range(number_accounts):
if legacy_account:
path_string = f"44'/60'/0'/{i}"
else:
path_string = f"44'/60'/{i}'/0/0"

account = get_account_by_path(path_string, self.dongle)
accounts.append((account.address, account.path))
return accounts

@raise_as_hw_account_exception
def add_account(self, derivation_path: str) -> ChecksumAddress:
"""
Add an account to ledger manager set and return the added address
:param derivation_path:
:return:
:return: public address for provided derivation_path
"""
# we should accept m/ or m'/ starting derivation paths
if derivation_path[0:2] == "m/":
derivation_path = derivation_path.replace("m/", "")

if not is_bip32_path(derivation_path):
raise InvalidDerivationPath()

account = get_account_by_path(derivation_path, self.dongle)
self.accounts.add(LedgerAccount(account.path, account.address))
account = get_account_by_path(derivation_path)
return account.address

def delete_accounts(self, addresses: List[ChecksumAddress]) -> Set:
"""
Remove ledger accounts from address
:param accounts:
:return: list with the delete accounts
"""
accounts_to_remove = set()
for address in addresses:
for account in self.accounts:
if account.address == address:
accounts_to_remove.add(account)
self.accounts = self.accounts.difference(accounts_to_remove)
return accounts_to_remove

@raise_as_hw_account_exception
def sign_eip712(self, safe_tx: SafeTx, accounts: List[LedgerAccount]) -> SafeTx:
"""
Call ledger ethereum app method to sign eip712 hashes with a ledger account
:param domain_hash:
:param message_hash:
:param account: ledger account
:return: bytes of signature
"""
encode_hash = eip712_encode(safe_tx.eip712_structured_data)
domain_hash = encode_hash[1]
message_hash = encode_hash[2]
for account in accounts:
print_formatted_text(
HTML(
"<ansired>Make sure in your ledger before signing that domain_hash and message_hash are both correct</ansired>"
)
)
print_formatted_text(HTML(f"Domain_hash: <b>{domain_hash.hex()}</b>"))
print_formatted_text(HTML(f"Message_hash: <b>{message_hash.hex()}</b>"))
signed = sign_typed_data_draft(
domain_hash, message_hash, account.path, self.dongle
)

signature = signature_to_bytes(signed.v, signed.r, signed.s)
# TODO should be refactored on safe_eth_py function insert_signature_sorted
# Insert signature sorted
if account.address not in safe_tx.signers:
new_owners = safe_tx.signers + [account.address]
new_owner_pos = sorted(new_owners, key=lambda x: int(x, 16)).index(
account.address
)
safe_tx.signatures = (
safe_tx.signatures[: 65 * new_owner_pos]
+ signature
+ safe_tx.signatures[65 * new_owner_pos :]
)

return safe_tx
def sign_typed_hash(self, domain_hash, message_hash) -> Tuple[bytes, bytes, bytes]:
signed = sign_typed_data_draft(
domain_hash, message_hash, self.derivation_path, self.dongle
)
return (signed.v, signed.r, signed.s)
22 changes: 22 additions & 0 deletions safe_cli/operators/hw_accounts/trezor_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Tuple

from eth_typing import ChecksumAddress

from safe_cli.operators.hw_accounts.hw_account import HwAccount


class TrezorManager(HwAccount):
def __init__(self, derivation_path: str, address: ChecksumAddress):
self.client = None
super().__init__(derivation_path, address)

def get_address_by_derivation_path(derivation_path: str) -> ChecksumAddress:
"""
:param derivation_path:
:return: public address for provided derivation_path
"""
raise NotImplementedError

def sign_typed_hash(self, domain_hash, message_hash) -> Tuple[bytes, bytes, bytes]:
raise NotImplementedError
Loading

0 comments on commit e8c2a0a

Please sign in to comment.