Skip to content

Commit

Permalink
Merge pull request #5721 from SomberNight/201910_psbt
Browse files Browse the repository at this point in the history
integrate PSBT support natively. WIP
  • Loading branch information
ecdsa committed Nov 7, 2019
2 parents 6d12eba + cd49839 commit 707b74d
Show file tree
Hide file tree
Showing 62 changed files with 4,171 additions and 3,420 deletions.
94 changes: 44 additions & 50 deletions electrum/address_synchronizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence

from . import bitcoin
from .bitcoin import COINBASE_MATURITY, TYPE_ADDRESS, TYPE_PUBKEY
from .bitcoin import COINBASE_MATURITY
from .util import profiler, bfh, TxMinedInfo
from .transaction import Transaction, TxOutput
from .transaction import Transaction, TxOutput, TxInput, PartialTxInput, TxOutpoint, PartialTransaction
from .synchronizer import Synchronizer
from .verifier import SPV
from .blockchain import hash_header
Expand Down Expand Up @@ -125,27 +125,21 @@ def get_address_history_len(self, addr: str) -> int:
"""Return number of transactions where address is involved."""
return len(self._history_local.get(addr, ()))

def get_txin_address(self, txi) -> Optional[str]:
addr = txi.get('address')
if addr and addr != "(pubkey)":
return addr
prevout_hash = txi.get('prevout_hash')
prevout_n = txi.get('prevout_n')
def get_txin_address(self, txin: TxInput) -> Optional[str]:
if isinstance(txin, PartialTxInput):
if txin.address:
return txin.address
prevout_hash = txin.prevout.txid.hex()
prevout_n = txin.prevout.out_idx
for addr in self.db.get_txo_addresses(prevout_hash):
l = self.db.get_txo_addr(prevout_hash, addr)
for n, v, is_cb in l:
if n == prevout_n:
return addr
return None

def get_txout_address(self, txo: TxOutput):
if txo.type == TYPE_ADDRESS:
addr = txo.address
elif txo.type == TYPE_PUBKEY:
addr = bitcoin.public_key_to_p2pkh(bfh(txo.address))
else:
addr = None
return addr
def get_txout_address(self, txo: TxOutput) -> Optional[str]:
return txo.address

def load_unverified_transactions(self):
# review transactions that are in the history
Expand Down Expand Up @@ -183,7 +177,7 @@ def add_address(self, address):
if self.synchronizer:
self.synchronizer.add(address)

def get_conflicting_transactions(self, tx_hash, tx, include_self=False):
def get_conflicting_transactions(self, tx_hash, tx: Transaction, include_self=False):
"""Returns a set of transaction hashes from the wallet history that are
directly conflicting with tx, i.e. they have common outpoints being
spent with tx.
Expand All @@ -194,10 +188,10 @@ def get_conflicting_transactions(self, tx_hash, tx, include_self=False):
conflicting_txns = set()
with self.transaction_lock:
for txin in tx.inputs():
if txin['type'] == 'coinbase':
if txin.is_coinbase():
continue
prevout_hash = txin['prevout_hash']
prevout_n = txin['prevout_n']
prevout_hash = txin.prevout.txid.hex()
prevout_n = txin.prevout.out_idx
spending_tx_hash = self.db.get_spent_outpoint(prevout_hash, prevout_n)
if spending_tx_hash is None:
continue
Expand All @@ -213,7 +207,7 @@ def get_conflicting_transactions(self, tx_hash, tx, include_self=False):
conflicting_txns -= {tx_hash}
return conflicting_txns

def add_transaction(self, tx_hash, tx, allow_unrelated=False) -> bool:
def add_transaction(self, tx_hash, tx: Transaction, allow_unrelated=False) -> bool:
"""Returns whether the tx was successfully added to the wallet history."""
assert tx_hash, tx_hash
assert tx, tx
Expand All @@ -226,7 +220,7 @@ def add_transaction(self, tx_hash, tx, allow_unrelated=False) -> bool:
# BUT we track is_mine inputs in a txn, and during subsequent calls
# of add_transaction tx, we might learn of more-and-more inputs of
# being is_mine, as we roll the gap_limit forward
is_coinbase = tx.inputs()[0]['type'] == 'coinbase'
is_coinbase = tx.inputs()[0].is_coinbase()
tx_height = self.get_tx_height(tx_hash).height
if not allow_unrelated:
# note that during sync, if the transactions are not properly sorted,
Expand Down Expand Up @@ -277,11 +271,11 @@ def add_value_from_prev_output():
self._get_addr_balance_cache.pop(addr, None) # invalidate cache
return
for txi in tx.inputs():
if txi['type'] == 'coinbase':
if txi.is_coinbase():
continue
prevout_hash = txi['prevout_hash']
prevout_n = txi['prevout_n']
ser = prevout_hash + ':%d' % prevout_n
prevout_hash = txi.prevout.txid.hex()
prevout_n = txi.prevout.out_idx
ser = txi.prevout.to_str()
self.db.set_spent_outpoint(prevout_hash, prevout_n, tx_hash)
add_value_from_prev_output()
# add outputs
Expand Down Expand Up @@ -310,10 +304,10 @@ def remove_from_spent_outpoints():
if tx is not None:
# if we have the tx, this branch is faster
for txin in tx.inputs():
if txin['type'] == 'coinbase':
if txin.is_coinbase():
continue
prevout_hash = txin['prevout_hash']
prevout_n = txin['prevout_n']
prevout_hash = txin.prevout.txid.hex()
prevout_n = txin.prevout.out_idx
self.db.remove_spent_outpoint(prevout_hash, prevout_n)
else:
# expensive but always works
Expand Down Expand Up @@ -572,7 +566,7 @@ def get_local_height(self) -> int:
return cached_local_height
return self.network.get_local_height() if self.network else self.db.get('stored_height', 0)

def add_future_tx(self, tx, num_blocks):
def add_future_tx(self, tx: Transaction, num_blocks):
with self.lock:
self.add_transaction(tx.txid(), tx)
self.future_tx[tx.txid()] = num_blocks
Expand Down Expand Up @@ -649,13 +643,15 @@ def get_wallet_delta(self, tx: Transaction):
if self.is_mine(addr):
is_mine = True
is_relevant = True
d = self.db.get_txo_addr(txin['prevout_hash'], addr)
d = self.db.get_txo_addr(txin.prevout.txid.hex(), addr)
for n, v, cb in d:
if n == txin['prevout_n']:
if n == txin.prevout.out_idx:
value = v
break
else:
value = None
if value is None:
value = txin.value_sats()
if value is None:
is_pruned = True
else:
Expand Down Expand Up @@ -736,23 +732,19 @@ def get_addr_io(self, address):
sent[txi] = height
return received, sent

def get_addr_utxo(self, address):
def get_addr_utxo(self, address: str) -> Dict[TxOutpoint, PartialTxInput]:
coins, spent = self.get_addr_io(address)
for txi in spent:
coins.pop(txi)
out = {}
for txo, v in coins.items():
for prevout_str, v in coins.items():
tx_height, value, is_cb = v
prevout_hash, prevout_n = txo.split(':')
x = {
'address':address,
'value':value,
'prevout_n':int(prevout_n),
'prevout_hash':prevout_hash,
'height':tx_height,
'coinbase':is_cb
}
out[txo] = x
prevout = TxOutpoint.from_str(prevout_str)
utxo = PartialTxInput(prevout=prevout)
utxo._trusted_address = address
utxo._trusted_value_sats = value
utxo.block_height = tx_height
out[prevout] = utxo
return out

# return the total amount ever received by an address
Expand Down Expand Up @@ -799,7 +791,8 @@ def get_addr_balance(self, address, *, excluded_coins: Set[str] = None):

@with_local_height_cached
def get_utxos(self, domain=None, *, excluded_addresses=None,
mature_only: bool = False, confirmed_only: bool = False, nonlocal_only: bool = False):
mature_only: bool = False, confirmed_only: bool = False,
nonlocal_only: bool = False) -> Sequence[PartialTxInput]:
coins = []
if domain is None:
domain = self.get_addresses()
Expand All @@ -809,14 +802,15 @@ def get_utxos(self, domain=None, *, excluded_addresses=None,
mempool_height = self.get_local_height() + 1 # height of next block
for addr in domain:
utxos = self.get_addr_utxo(addr)
for x in utxos.values():
if confirmed_only and x['height'] <= 0:
for utxo in utxos.values():
if confirmed_only and utxo.block_height <= 0:
continue
if nonlocal_only and x['height'] == TX_HEIGHT_LOCAL:
if nonlocal_only and utxo.block_height == TX_HEIGHT_LOCAL:
continue
if mature_only and x['coinbase'] and x['height'] + COINBASE_MATURITY > mempool_height:
if (mature_only and utxo.prevout.is_coinbase()
and utxo.block_height + COINBASE_MATURITY > mempool_height):
continue
coins.append(x)
coins.append(utxo)
continue
return coins

Expand Down
9 changes: 6 additions & 3 deletions electrum/base_wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from . import bitcoin
from . import keystore
from . import mnemonic
from .bip32 import is_bip32_derivation, xpub_type, normalize_bip32_derivation
from .bip32 import is_bip32_derivation, xpub_type, normalize_bip32_derivation, BIP32Node
from .keystore import bip44_derivation, purpose48_derivation
from .wallet import (Imported_Wallet, Standard_Wallet, Multisig_Wallet,
wallet_types, Wallet, Abstract_Wallet)
Expand Down Expand Up @@ -230,7 +230,7 @@ def on_import(self, text):
assert bitcoin.is_private_key(pk)
txin_type, pubkey = k.import_privkey(pk, None)
addr = bitcoin.pubkey_to_address(txin_type, pubkey)
self.data['addresses'][addr] = {'type':txin_type, 'pubkey':pubkey, 'redeem_script':None}
self.data['addresses'][addr] = {'type':txin_type, 'pubkey':pubkey}
self.keystores.append(k)
else:
return self.terminate()
Expand Down Expand Up @@ -394,7 +394,7 @@ def derivation_and_script_type_dialog(self, f):
# For segwit, a custom path is used, as there is no standard at all.
default_choice_idx = 2
choices = [
('standard', 'legacy multisig (p2sh)', "m/45'/0"),
('standard', 'legacy multisig (p2sh)', normalize_bip32_derivation("m/45'/0")),
('p2wsh-p2sh', 'p2sh-segwit multisig (p2wsh-p2sh)', purpose48_derivation(0, xtype='p2wsh-p2sh')),
('p2wsh', 'native segwit multisig (p2wsh)', purpose48_derivation(0, xtype='p2wsh')),
]
Expand All @@ -420,16 +420,19 @@ def on_hw_derivation(self, name, device_info, derivation, xtype):
from .keystore import hardware_keystore
try:
xpub = self.plugin.get_xpub(device_info.device.id_, derivation, xtype, self)
root_xpub = self.plugin.get_xpub(device_info.device.id_, 'm', 'standard', self)
except ScriptTypeNotSupported:
raise # this is handled in derivation_dialog
except BaseException as e:
self.logger.exception('')
self.show_error(e)
return
xfp = BIP32Node.from_xkey(root_xpub).calc_fingerprint_of_this_node().hex().lower()
d = {
'type': 'hardware',
'hw_type': name,
'derivation': derivation,
'root_fingerprint': xfp,
'xpub': xpub,
'label': device_info.label,
}
Expand Down
71 changes: 65 additions & 6 deletions electrum/bip32.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# file LICENCE or http://www.opensource.org/licenses/mit-license.php

import hashlib
from typing import List, Tuple, NamedTuple, Union, Iterable
from typing import List, Tuple, NamedTuple, Union, Iterable, Sequence, Optional

from .util import bfh, bh2u, BitcoinException
from . import constants
Expand Down Expand Up @@ -116,7 +116,7 @@ class BIP32Node(NamedTuple):
eckey: Union[ecc.ECPubkey, ecc.ECPrivkey]
chaincode: bytes
depth: int = 0
fingerprint: bytes = b'\x00'*4
fingerprint: bytes = b'\x00'*4 # as in serialized format, this is the *parent's* fingerprint
child_number: bytes = b'\x00'*4

@classmethod
Expand Down Expand Up @@ -161,7 +161,18 @@ def from_rootseed(cls, seed: bytes, *, xtype: str) -> 'BIP32Node':
eckey=ecc.ECPrivkey(master_k),
chaincode=master_c)

@classmethod
def from_bytes(cls, b: bytes) -> 'BIP32Node':
if len(b) != 78:
raise Exception(f"unexpected xkey raw bytes len {len(b)} != 78")
xkey = EncodeBase58Check(b)
return cls.from_xkey(xkey)

def to_xprv(self, *, net=None) -> str:
payload = self.to_xprv_bytes(net=net)
return EncodeBase58Check(payload)

def to_xprv_bytes(self, *, net=None) -> bytes:
if not self.is_private():
raise Exception("cannot serialize as xprv; private key missing")
payload = (xprv_header(self.xtype, net=net) +
Expand All @@ -172,24 +183,34 @@ def to_xprv(self, *, net=None) -> str:
bytes([0]) +
self.eckey.get_secret_bytes())
assert len(payload) == 78, f"unexpected xprv payload len {len(payload)}"
return EncodeBase58Check(payload)
return payload

def to_xpub(self, *, net=None) -> str:
payload = self.to_xpub_bytes(net=net)
return EncodeBase58Check(payload)

def to_xpub_bytes(self, *, net=None) -> bytes:
payload = (xpub_header(self.xtype, net=net) +
bytes([self.depth]) +
self.fingerprint +
self.child_number +
self.chaincode +
self.eckey.get_public_key_bytes(compressed=True))
assert len(payload) == 78, f"unexpected xpub payload len {len(payload)}"
return EncodeBase58Check(payload)
return payload

def to_xkey(self, *, net=None) -> str:
if self.is_private():
return self.to_xprv(net=net)
else:
return self.to_xpub(net=net)

def to_bytes(self, *, net=None) -> bytes:
if self.is_private():
return self.to_xprv_bytes(net=net)
else:
return self.to_xpub_bytes(net=net)

def convert_to_public(self) -> 'BIP32Node':
if not self.is_private():
return self
Expand Down Expand Up @@ -248,6 +269,12 @@ def subkey_at_public_derivation(self, path: Union[str, Iterable[int]]) -> 'BIP32
fingerprint=fingerprint,
child_number=child_number)

def calc_fingerprint_of_this_node(self) -> bytes:
"""Returns the fingerprint of this node.
Note that self.fingerprint is of the *parent*.
"""
return hash_160(self.eckey.get_public_key_bytes(compressed=True))[0:4]


def xpub_type(x):
return BIP32Node.from_xkey(x).xtype
Expand Down Expand Up @@ -308,7 +335,7 @@ def convert_bip32_path_to_list_of_uint32(n: str) -> List[int]:
return path


def convert_bip32_intpath_to_strpath(path: List[int]) -> str:
def convert_bip32_intpath_to_strpath(path: Sequence[int]) -> str:
s = "m/"
for child_index in path:
if not isinstance(child_index, int):
Expand Down Expand Up @@ -336,8 +363,40 @@ def is_bip32_derivation(s: str) -> bool:
return True


def normalize_bip32_derivation(s: str) -> str:
def normalize_bip32_derivation(s: Optional[str]) -> Optional[str]:
if s is None:
return None
if not is_bip32_derivation(s):
raise ValueError(f"invalid bip32 derivation: {s}")
ints = convert_bip32_path_to_list_of_uint32(s)
return convert_bip32_intpath_to_strpath(ints)


def is_all_public_derivation(path: Union[str, Iterable[int]]) -> bool:
"""Returns whether all levels in path use non-hardened derivation."""
if isinstance(path, str):
path = convert_bip32_path_to_list_of_uint32(path)
for child_index in path:
if child_index < 0:
raise ValueError('the bip32 index needs to be non-negative')
if child_index & BIP32_PRIME:
return False
return True


def root_fp_and_der_prefix_from_xkey(xkey: str) -> Tuple[Optional[str], Optional[str]]:
"""Returns the root bip32 fingerprint and the derivation path from the
root to the given xkey, if they can be determined. Otherwise (None, None).
"""
node = BIP32Node.from_xkey(xkey)
derivation_prefix = None
root_fingerprint = None
assert node.depth >= 0, node.depth
if node.depth == 0:
derivation_prefix = 'm'
root_fingerprint = node.calc_fingerprint_of_this_node().hex().lower()
elif node.depth == 1:
child_number_int = int.from_bytes(node.child_number, 'big')
derivation_prefix = convert_bip32_intpath_to_strpath([child_number_int])
root_fingerprint = node.fingerprint.hex()
return root_fingerprint, derivation_prefix
Loading

0 comments on commit 707b74d

Please sign in to comment.