Skip to content

Commit

Permalink
lnchannel: add_htlc and receive_htlc now take and return UpdateAddHtlc
Browse files Browse the repository at this point in the history
also fix undefined vars in _maybe_forward_htlc and _maybe_fulfill_htlc
in lnpeer
  • Loading branch information
SomberNight authored and ecdsa committed Aug 20, 2019
1 parent 62be0c4 commit 3a2ab14
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 67 deletions.
22 changes: 13 additions & 9 deletions electrum/lnchannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,39 +241,43 @@ def get_funding_address(self):
script = funding_output_script(self.config[LOCAL], self.config[REMOTE])
return redeem_script_to_address('p2wsh', script)

def add_htlc(self, htlc):
def add_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc:
"""
AddHTLC adds an HTLC to the state machine's local update log. This method
should be called when preparing to send an outgoing HTLC.
This docstring is from LND.
"""
assert type(htlc) is dict
self._check_can_pay(htlc['amount_msat'])
htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id)
if isinstance(htlc, dict): # legacy conversion # FIXME remove
htlc = UpdateAddHtlc(**htlc)
assert isinstance(htlc, UpdateAddHtlc)
self._check_can_pay(htlc.amount_msat)
htlc = htlc._replace(htlc_id=self.config[LOCAL].next_htlc_id)
self.hm.send_htlc(htlc)
self.print_error("add_htlc")
self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id
return htlc

def receive_htlc(self, htlc):
def receive_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc:
"""
ReceiveHTLC adds an HTLC to the state machine's remote update log. This
method should be called in response to receiving a new HTLC from the remote
party.
This docstring is from LND.
"""
assert type(htlc) is dict
htlc = UpdateAddHtlc(**htlc, htlc_id = self.config[REMOTE].next_htlc_id)
if isinstance(htlc, dict): # legacy conversion # FIXME remove
htlc = UpdateAddHtlc(**htlc)
assert isinstance(htlc, UpdateAddHtlc)
htlc = htlc._replace(htlc_id=self.config[REMOTE].next_htlc_id)
if self.available_to_spend(REMOTE) < htlc.amount_msat:
raise RemoteMisbehaving('Remote dipped below channel reserve.' +\
f' Available at remote: {self.available_to_spend(REMOTE)},' +\
f' HTLC amount: {htlc.amount_msat}')
self.hm.recv_htlc(htlc)
self.print_error("receive_htlc")
self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id
return htlc

def sign_next_commitment(self):
"""
Expand Down
91 changes: 52 additions & 39 deletions electrum/lnpeer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabled_bits, ignore_exceptions
from .transaction import Transaction, TxOutput
from .lnonion import (new_onion_packet, decode_onion_error, OnionFailureCode, calc_hops_data_for_payment,
process_onion_packet, OnionPacket, construct_onion_error, OnionRoutingFailureMessage)
process_onion_packet, OnionPacket, construct_onion_error, OnionRoutingFailureMessage,
ProcessedOnionPacket)
from .lnchannel import Channel, RevokeAndAck, htlcsum
from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc,
RemoteConfig, OnlyPubkeyKeypair, ChannelConstraints, RevocationStore,
Expand Down Expand Up @@ -841,7 +842,7 @@ async def await_local(self, chan: Channel, ctn: int):
await self._local_changed_events[chan.channel_id].wait()

async def pay(self, route: List['RouteEdge'], chan: Channel, amount_msat: int,
payment_hash: bytes, min_final_cltv_expiry: int):
payment_hash: bytes, min_final_cltv_expiry: int) -> UpdateAddHtlc:
assert chan.get_state() == "OPEN", chan.get_state()
assert amount_msat > 0, "amount_msat is not greater zero"
# create onion packet
Expand All @@ -851,22 +852,22 @@ async def pay(self, route: List['RouteEdge'], chan: Channel, amount_msat: int,
secret_key = os.urandom(32)
onion = new_onion_packet([x.node_id for x in route], secret_key, hops_data, associated_data=payment_hash)
# create htlc
htlc = {'amount_msat':amount_msat, 'payment_hash':payment_hash, 'cltv_expiry':cltv}
htlc_id = chan.add_htlc(htlc)
htlc = UpdateAddHtlc(amount_msat=amount_msat, payment_hash=payment_hash, cltv_expiry=cltv)
htlc = chan.add_htlc(htlc)
remote_ctn = chan.get_current_ctn(REMOTE)
chan.onion_keys[htlc_id] = secret_key
self.attempted_route[(chan.channel_id, htlc_id)] = route
chan.onion_keys[htlc.htlc_id] = secret_key
self.attempted_route[(chan.channel_id, htlc.htlc_id)] = route
self.print_error(f"starting payment. route: {route}")
self.send_message("update_add_htlc",
channel_id=chan.channel_id,
id=htlc_id,
cltv_expiry=cltv,
amount_msat=amount_msat,
payment_hash=payment_hash,
id=htlc.htlc_id,
cltv_expiry=htlc.cltv_expiry,
amount_msat=htlc.amount_msat,
payment_hash=htlc.payment_hash,
onion_routing_packet=onion.to_bytes())
self.remote_pending_updates[chan] = True
await self.await_remote(chan, remote_ctn)
return UpdateAddHtlc(**htlc, htlc_id=htlc_id)
return htlc

def send_revoke_and_ack(self, chan: Channel):
rev, _ = chan.revoke_current_commitment()
Expand Down Expand Up @@ -923,18 +924,29 @@ def on_update_add_htlc(self, payload):
if cltv_expiry >= 500_000_000:
pass # TODO fail the channel
# add htlc
htlc = {'amount_msat': amount_msat_htlc, 'payment_hash':payment_hash, 'cltv_expiry':cltv_expiry}
htlc_id = chan.receive_htlc(htlc)
htlc = UpdateAddHtlc(amount_msat=amount_msat_htlc, payment_hash=payment_hash, cltv_expiry=cltv_expiry)
htlc = chan.receive_htlc(htlc)
self.local_pending_updates[chan] = True
local_ctn = chan.get_current_ctn(LOCAL)
remote_ctn = chan.get_current_ctn(REMOTE)
if processed_onion.are_we_final:
asyncio.ensure_future(self._maybe_fulfill_htlc(chan, local_ctn, remote_ctn, htlc_id, htlc, payment_hash, cltv_expiry, amount_msat_htlc, processed_onion))
asyncio.ensure_future(self._maybe_fulfill_htlc(chan=chan,
htlc=htlc,
local_ctn=local_ctn,
remote_ctn=remote_ctn,
onion_packet=onion_packet,
processed_onion=processed_onion))
else:
asyncio.ensure_future(self._maybe_forward_htlc(chan, local_ctn, remote_ctn, htlc_id, htlc, payment_hash, cltv_expiry, amount_msat_htlc, processed_onion))
asyncio.ensure_future(self._maybe_forward_htlc(chan=chan,
htlc=htlc,
local_ctn=local_ctn,
remote_ctn=remote_ctn,
onion_packet=onion_packet,
processed_onion=processed_onion))

@log_exceptions
async def _maybe_forward_htlc(self, chan, local_ctn, remote_ctn, htlc_id, htlc, payment_hash, cltv_expiry, amount_msat_htlc, processed_onion):
async def _maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int,
onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket):
await self.await_local(chan, local_ctn)
await self.await_remote(chan, remote_ctn)
# Forward HTLC
Expand All @@ -945,69 +957,70 @@ async def _maybe_forward_htlc(self, chan, local_ctn, remote_ctn, htlc_id, htlc,
if next_chan is None or next_chan.get_state() != 'OPEN':
self.print_error("cannot forward htlc", next_chan.get_state() if next_chan else None)
reason = OnionRoutingFailureMessage(code=OnionFailureCode.PERMANENT_CHANNEL_FAILURE, data=b'')
await self.fail_htlc(chan, htlc_id, onion_packet, reason)
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
self.print_error('forwarding htlc to', next_chan.node_id)
next_cltv_expiry = int.from_bytes(dph.outgoing_cltv_value, 'big')
next_amount_msat_htlc = int.from_bytes(dph.amt_to_forward, 'big')
next_htlc = {'amount_msat':next_amount_msat_htlc, 'payment_hash':payment_hash, 'cltv_expiry':next_cltv_expiry}
next_htlc_id = next_chan.add_htlc(next_htlc)
next_htlc = UpdateAddHtlc(amount_msat=next_amount_msat_htlc, payment_hash=htlc.payment_hash, cltv_expiry=next_cltv_expiry)
next_htlc = next_chan.add_htlc(next_htlc)
next_remote_ctn = next_chan.get_current_ctn(REMOTE)
next_peer.send_message(
"update_add_htlc",
channel_id=next_chan.channel_id,
id=next_htlc_id,
id=next_htlc.htlc_id,
cltv_expiry=dph.outgoing_cltv_value,
amount_msat=dph.amt_to_forward,
payment_hash=payment_hash,
payment_hash=next_htlc.payment_hash,
onion_routing_packet=processed_onion.next_packet.to_bytes()
)
next_peer.remote_pending_updates[next_chan] = True
await next_peer.await_remote(next_chan, next_remote_ctn)
# wait until we get paid
preimage = await next_peer.payment_preimages[payment_hash].get()
preimage = await next_peer.payment_preimages[next_htlc.payment_hash].get()
# fulfill the original htlc
await self._fulfill_htlc(chan, htlc_id, preimage)
await self._fulfill_htlc(chan, htlc.htlc_id, preimage)
self.print_error("htlc forwarded successfully")

@log_exceptions
async def _maybe_fulfill_htlc(self, chan, local_ctn, remote_ctn, htlc_id, htlc, payment_hash, cltv_expiry, amount_msat_htlc, processed_onion):
async def _maybe_fulfill_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int,
onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket):
await self.await_local(chan, local_ctn)
await self.await_remote(chan, remote_ctn)
try:
invoice = self.lnworker.get_invoice(payment_hash)
preimage = self.lnworker.get_preimage(payment_hash)
invoice = self.lnworker.get_invoice(htlc.payment_hash)
preimage = self.lnworker.get_preimage(htlc.payment_hash)
except UnknownPaymentHash:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.UNKNOWN_PAYMENT_HASH, data=b'')
await self.fail_htlc(chan, htlc_id, onion_packet, reason)
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
expected_received_msat = int(invoice.amount * bitcoin.COIN * 1000) if invoice.amount is not None else None
if expected_received_msat is not None and \
(amount_msat_htlc < expected_received_msat or amount_msat_htlc > 2 * expected_received_msat):
(htlc.amount_msat < expected_received_msat or htlc.amount_msat > 2 * expected_received_msat):
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_PAYMENT_AMOUNT, data=b'')
await self.fail_htlc(chan, htlc_id, onion_packet, reason)
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
local_height = self.network.get_local_height()
if local_height + MIN_FINAL_CLTV_EXPIRY_ACCEPTED > cltv_expiry:
if local_height + MIN_FINAL_CLTV_EXPIRY_ACCEPTED > htlc.cltv_expiry:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_EXPIRY_TOO_SOON, data=b'')
await self.fail_htlc(chan, htlc_id, onion_packet, reason)
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
cltv_from_onion = int.from_bytes(processed_onion.hop_data.per_hop.outgoing_cltv_value, byteorder="big")
if cltv_from_onion != cltv_expiry:
if cltv_from_onion != htlc.cltv_expiry:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_CLTV_EXPIRY,
data=cltv_expiry.to_bytes(4, byteorder="big"))
await self.fail_htlc(chan, htlc_id, onion_packet, reason)
data=htlc.cltv_expiry.to_bytes(4, byteorder="big"))
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
amount_from_onion = int.from_bytes(processed_onion.hop_data.per_hop.amt_to_forward, byteorder="big")
if amount_from_onion > amount_msat_htlc:
if amount_from_onion > htlc.amount_msat:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT,
data=amount_msat_htlc.to_bytes(8, byteorder="big"))
await self.fail_htlc(chan, htlc_id, onion_packet, reason)
data=htlc.amount_msat.to_bytes(8, byteorder="big"))
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
self.network.trigger_callback('htlc_added', UpdateAddHtlc(**htlc, htlc_id=htlc_id), invoice, RECEIVED)
self.network.trigger_callback('htlc_added', htlc, invoice, RECEIVED)
if self.network.config.debug_lightning_do_not_settle:
return
await self._fulfill_htlc(chan, htlc_id, preimage)
await self._fulfill_htlc(chan, htlc.htlc_id, preimage)

async def _fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes):
chan.settle_htlc(preimage, htlc_id)
Expand Down
17 changes: 11 additions & 6 deletions electrum/lnutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,10 @@ class LnGlobalFeatures(IntFlag):
LN_GLOBAL_FEATURES_KNOWN_SET = set(LnGlobalFeatures)


class LNPeerAddr(namedtuple('LNPeerAddr', ['host', 'port', 'pubkey'])):
__slots__ = ()
class LNPeerAddr(NamedTuple):
host: str
port: int
pubkey: bytes

def __str__(self):
return '{}@{}:{}'.format(bh2u(self.pubkey), self.host, self.port)
Expand Down Expand Up @@ -663,19 +665,22 @@ def format_short_channel_id(short_channel_id: Optional[bytes]):
+ 'x' + str(int.from_bytes(short_channel_id[3:6], 'big')) \
+ 'x' + str(int.from_bytes(short_channel_id[6:], 'big'))


class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])):
"""
This whole class body is so that if you pass a hex-string as payment_hash,
it is decoded to bytes. Bytes can't be saved to disk, so we save hex-strings.
"""
# note: typing.NamedTuple cannot be used because we are overriding __new__

__slots__ = ()
def __new__(cls, *args, **kwargs):
# if you pass a hex-string as payment_hash, it is decoded to bytes.
# Bytes can't be saved to disk, so we save hex-strings.
if len(args) > 0:
args = list(args)
if type(args[1]) is str:
args[1] = bfh(args[1])
return super().__new__(cls, *args)
if type(kwargs['payment_hash']) is str:
kwargs['payment_hash'] = bfh(kwargs['payment_hash'])
if len(args) < 4 and 'htlc_id' not in kwargs:
kwargs['htlc_id'] = None
return super().__new__(cls, **kwargs)

26 changes: 13 additions & 13 deletions electrum/tests/test_lnchannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,13 @@ def setUp(self):
# First Alice adds the outgoing HTLC to her local channel's state
# update log. Then Alice sends this wire message over to Bob who adds
# this htlc to his remote state update log.
self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict)
self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict).htlc_id
self.assertNotEqual(self.alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED, 1), set())

before = self.bob_channel.balance_minus_outgoing_htlcs(REMOTE)
beforeLocal = self.bob_channel.balance_minus_outgoing_htlcs(LOCAL)

self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc_dict)
self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc_dict).htlc_id

self.assertEqual(1, self.bob_channel.hm.log[LOCAL]['ctn'] + 1)
self.assertNotEqual(self.bob_channel.hm.htlcs_by_direction(LOCAL, RECEIVED, 1), set())
Expand All @@ -230,8 +230,8 @@ def setUp(self):
def test_concurrent_reversed_payment(self):
self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02')
self.htlc_dict['amount_msat'] += 1000
bob_idx = self.bob_channel.add_htlc(self.htlc_dict)
alice_idx = self.alice_channel.receive_htlc(self.htlc_dict)
self.bob_channel.add_htlc(self.htlc_dict)
self.alice_channel.receive_htlc(self.htlc_dict)
self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment())
self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 4)

Expand Down Expand Up @@ -481,8 +481,8 @@ def com():
self.assertNotEqual(tx5, tx6)

self.htlc_dict['amount_msat'] *= 5
bob_index = bob_channel.add_htlc(self.htlc_dict)
alice_index = alice_channel.receive_htlc(self.htlc_dict)
bob_index = bob_channel.add_htlc(self.htlc_dict).htlc_id
alice_index = alice_channel.receive_htlc(self.htlc_dict).htlc_id

bob_channel.pending_commitment(REMOTE)
alice_channel.pending_commitment(LOCAL)
Expand Down Expand Up @@ -597,7 +597,7 @@ def test_AddHTLCNegativeBalance(self):
def test_sign_commitment_is_pure(self):
force_state_transition(self.alice_channel, self.bob_channel)
self.htlc_dict['payment_hash'] = bitcoin.sha256(b'\x02' * 32)
aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict)
self.alice_channel.add_htlc(self.htlc_dict)
before_signing = self.alice_channel.to_save()
self.alice_channel.sign_next_commitment()
after_signing = self.alice_channel.to_save()
Expand All @@ -622,8 +622,8 @@ def test_DesyncHTLCs(self):
'cltv_expiry' : 5,
}

alice_idx = alice_channel.add_htlc(htlc_dict)
bob_idx = bob_channel.receive_htlc(htlc_dict)
alice_idx = alice_channel.add_htlc(htlc_dict).htlc_id
bob_idx = bob_channel.receive_htlc(htlc_dict).htlc_id
force_state_transition(alice_channel, bob_channel)
bob_channel.fail_htlc(bob_idx)
alice_channel.receive_fail_htlc(alice_idx)
Expand Down Expand Up @@ -745,8 +745,8 @@ def part3(self):
'amount_msat' : int(2 * one_bitcoin_in_msat),
'cltv_expiry' : 5,
}
alice_idx = self.alice_channel.add_htlc(htlc_dict)
bob_idx = self.bob_channel.receive_htlc(htlc_dict)
alice_idx = self.alice_channel.add_htlc(htlc_dict).htlc_id
bob_idx = self.bob_channel.receive_htlc(htlc_dict).htlc_id
force_state_transition(self.alice_channel, self.bob_channel)
self.check_bals(one_bitcoin_in_msat*3\
- self.alice_channel.pending_local_fee(),
Expand Down Expand Up @@ -791,8 +791,8 @@ def test_DustLimit(self):
}

old_values = [x.value for x in bob_channel.current_commitment(LOCAL).outputs() ]
aliceHtlcIndex = alice_channel.add_htlc(htlc)
bobHtlcIndex = bob_channel.receive_htlc(htlc)
aliceHtlcIndex = alice_channel.add_htlc(htlc).htlc_id
bobHtlcIndex = bob_channel.receive_htlc(htlc).htlc_id
force_state_transition(alice_channel, bob_channel)
alice_ctx = alice_channel.current_commitment(LOCAL)
bob_ctx = bob_channel.current_commitment(LOCAL)
Expand Down

0 comments on commit 3a2ab14

Please sign in to comment.