Skip to content

Commit

Permalink
Merge 1387a45 into 7cf547c
Browse files Browse the repository at this point in the history
  • Loading branch information
SomberNight committed Jan 9, 2018
2 parents 7cf547c + 1387a45 commit 74025c3
Showing 1 changed file with 51 additions and 13 deletions.
64 changes: 51 additions & 13 deletions plugins/trezor/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,27 @@
# TREZOR initialization methods
TIM_NEW, TIM_RECOVER, TIM_MNEMONIC, TIM_PRIVKEY = range(0, 4)

# script "generation"
SCRIPT_GEN_LEGACY, SCRIPT_GEN_P2SH_SEGWIT, SCRIPT_GEN_NATIVE_SEGWIT = range(0, 3)

class TrezorCompatibleKeyStore(Hardware_KeyStore):

def get_derivation(self):
return self.derivation

def is_segwit(self):
return self.derivation.startswith("m/49'/")
def get_script_gen(self):
def is_p2sh_segwit():
return self.derivation.startswith("m/49'/")

def is_native_segwit():
return self.derivation.startswith("m/84'/")

if is_native_segwit():
return SCRIPT_GEN_NATIVE_SEGWIT
elif is_p2sh_segwit():
return SCRIPT_GEN_P2SH_SEGWIT
else:
return SCRIPT_GEN_LEGACY

def get_client(self, force_pair=True):
return self.plugin.get_client(self, force_pair)
Expand Down Expand Up @@ -226,8 +240,8 @@ def sign_transaction(self, keystore, tx, prev_tx, xpub_path):
self.prev_tx = prev_tx
self.xpub_path = xpub_path
client = self.get_client(keystore)
inputs = self.tx_inputs(tx, True, keystore.is_segwit())
outputs = self.tx_outputs(keystore.get_derivation(), tx, keystore.is_segwit())
inputs = self.tx_inputs(tx, True, keystore.get_script_gen())
outputs = self.tx_outputs(keystore.get_derivation(), tx, keystore.get_script_gen())
signed_tx = client.sign_tx(self.get_coin_name(), inputs, outputs, lock_time=tx.locktime)[1]
raw = bh2u(signed_tx)
tx.update_signatures(raw)
Expand All @@ -241,11 +255,16 @@ def show_address(self, wallet, address):
derivation = wallet.keystore.derivation
address_path = "%s/%d/%d"%(derivation, change, index)
address_n = client.expand_path(address_path)
segwit = wallet.keystore.is_segwit()
script_type = self.types.InputScriptType.SPENDP2SHWITNESS if segwit else self.types.InputScriptType.SPENDADDRESS
script_gen = wallet.keystore.get_script_gen()
if script_gen == SCRIPT_GEN_NATIVE_SEGWIT:
script_type = self.types.InputScriptType.SPENDWITNESS
elif script_gen == SCRIPT_GEN_P2SH_SEGWIT:
script_type = self.types.InputScriptType.SPENDP2SHWITNESS
else:
script_type = self.types.InputScriptType.SPENDADDRESS
client.get_address(self.get_coin_name(), address_n, True, script_type=script_type)

def tx_inputs(self, tx, for_sig=False, segwit=False):
def tx_inputs(self, tx, for_sig=False, script_gen=SCRIPT_GEN_LEGACY):
inputs = []
for txin in tx.inputs():
txinputtype = self.types.TxInputType()
Expand All @@ -260,7 +279,12 @@ def tx_inputs(self, tx, for_sig=False, segwit=False):
xpub, s = parse_xpubkey(x_pubkey)
xpub_n = self.client_class.expand_path(self.xpub_path[xpub])
txinputtype._extend_address_n(xpub_n + s)
txinputtype.script_type = self.types.InputScriptType.SPENDP2SHWITNESS if segwit else self.types.InputScriptType.SPENDADDRESS
if script_gen == SCRIPT_GEN_NATIVE_SEGWIT:
txinputtype.script_type = self.types.InputScriptType.SPENDWITNESS
elif script_gen == SCRIPT_GEN_P2SH_SEGWIT:
txinputtype.script_type = self.types.InputScriptType.SPENDP2SHWITNESS
else:
txinputtype.script_type = self.types.InputScriptType.SPENDADDRESS
else:
def f(x_pubkey):
if is_xpubkey(x_pubkey):
Expand All @@ -276,7 +300,12 @@ def f(x_pubkey):
signatures=map(lambda x: bfh(x)[:-1] if x else b'', txin.get('signatures')),
m=txin.get('num_sig'),
)
script_type = self.types.InputScriptType.SPENDP2SHWITNESS if segwit else self.types.InputScriptType.SPENDMULTISIG
if script_gen == SCRIPT_GEN_NATIVE_SEGWIT:
script_type = self.types.InputScriptType.SPENDWITNESS
elif script_gen == SCRIPT_GEN_P2SH_SEGWIT:
script_type = self.types.InputScriptType.SPENDP2SHWITNESS
else:
script_type = self.types.InputScriptType.SPENDMULTISIG
txinputtype = self.types.TxInputType(
script_type=script_type,
multisig=multisig
Expand Down Expand Up @@ -308,26 +337,35 @@ def f(x_pubkey):

return inputs

def tx_outputs(self, derivation, tx, segwit=False):
def tx_outputs(self, derivation, tx, script_gen=SCRIPT_GEN_LEGACY):
outputs = []
has_change = False

for _type, address, amount in tx.outputs():
info = tx.output_info.get(address)
if info is not None and not has_change:
has_change = True # no more than one change address
addrtype, hash_160 = b58_address_to_hash160(address)
index, xpubs, m = info
if len(xpubs) == 1:
script_type = self.types.OutputScriptType.PAYTOP2SHWITNESS if segwit else self.types.OutputScriptType.PAYTOADDRESS
if script_gen == SCRIPT_GEN_NATIVE_SEGWIT:
script_type = self.types.OutputScriptType.PAYTOWITNESS
elif script_gen == SCRIPT_GEN_P2SH_SEGWIT:
script_type = self.types.OutputScriptType.PAYTOP2SHWITNESS
else:
script_type = self.types.OutputScriptType.PAYTOADDRESS
address_n = self.client_class.expand_path(derivation + "/%d/%d"%index)
txoutputtype = self.types.TxOutputType(
amount = amount,
script_type = script_type,
address_n = address_n,
)
else:
script_type = self.types.OutputScriptType.PAYTOP2SHWITNESS if segwit else self.types.OutputScriptType.PAYTOMULTISIG
if script_gen == SCRIPT_GEN_NATIVE_SEGWIT:
script_type = self.types.OutputScriptType.PAYTOWITNESS
elif script_gen == SCRIPT_GEN_P2SH_SEGWIT:
script_type = self.types.OutputScriptType.PAYTOP2SHWITNESS
else:
script_type = self.types.OutputScriptType.PAYTOMULTISIG
address_n = self.client_class.expand_path("/%d/%d"%index)
nodes = map(self.ckd_public.deserialize, xpubs)
pubkeys = [ self.types.HDNodePathType(node=node, address_n=address_n) for node in nodes]
Expand Down

0 comments on commit 74025c3

Please sign in to comment.