Skip to content

Commit

Permalink
pass blacklist to lnrouter.find_route, so that lnrouter is stateless …
Browse files Browse the repository at this point in the history
…(see #6778)
  • Loading branch information
ecdsa committed Jan 11, 2021
1 parent 9d7a317 commit ad91257
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 28 deletions.
32 changes: 10 additions & 22 deletions electrum/lnrouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,24 +135,12 @@ def is_fee_sane(fee_msat: int, *, payment_amount_msat: int) -> bool:
return False


BLACKLIST_DURATION = 3600

class LNPathFinder(Logger):

def __init__(self, channel_db: ChannelDB):
Logger.__init__(self)
self.channel_db = channel_db
self.blacklist = dict() # short_chan_id -> timestamp

def add_to_blacklist(self, short_channel_id: ShortChannelID):
self.logger.info(f'blacklisting channel {short_channel_id}')
now = int(time.time())
self.blacklist[short_channel_id] = now

def is_blacklisted(self, short_channel_id: ShortChannelID) -> bool:
now = int(time.time())
t = self.blacklist.get(short_channel_id, 0)
return now - t < BLACKLIST_DURATION

def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
payment_amt_msat: int, ignore_costs=False, is_mine=False, *,
Expand Down Expand Up @@ -200,10 +188,9 @@ def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes
overall_cost = base_cost + fee_msat + cltv_cost
return overall_cost, fee_msat

def get_distances(self, nodeA: bytes, nodeB: bytes,
invoice_amount_msat: int, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None
) -> Dict[bytes, PathEdge]:
def get_distances(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) -> Dict[bytes, PathEdge]:
# note: we don't lock self.channel_db, so while the path finding runs,
# the underlying graph could potentially change... (not good but maybe ~OK?)

Expand All @@ -216,7 +203,6 @@ def get_distances(self, nodeA: bytes, nodeB: bytes,
nodes_to_explore = queue.PriorityQueue()
nodes_to_explore.put((0, invoice_amount_msat, nodeB)) # order of fields (in tuple) matters!


# main loop of search
while nodes_to_explore.qsize() > 0:
dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
Expand All @@ -229,7 +215,7 @@ def get_distances(self, nodeA: bytes, nodeB: bytes,
continue
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels):
assert isinstance(edge_channel_id, bytes)
if self.is_blacklisted(edge_channel_id):
if blacklist and edge_channel_id in blacklist:
continue
channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels)
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
Expand Down Expand Up @@ -263,7 +249,8 @@ def get_distances(self, nodeA: bytes, nodeB: bytes,
@profiler
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
invoice_amount_msat: int, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) \
my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) \
-> Optional[LNPaymentPath]:
"""Return a path from nodeA to nodeB."""
assert type(nodeA) is bytes
Expand All @@ -272,7 +259,7 @@ def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
if my_channels is None:
my_channels = {}

prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels)
prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)

if nodeA not in prev_node:
return None # no path found
Expand Down Expand Up @@ -312,8 +299,9 @@ def create_route_from_path(self, path: Optional[LNPaymentPath], from_node_id: by
return route

def find_route(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *,
path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[LNPaymentRoute]:
path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None,
blacklist: Set[ShortChannelID] = None) -> Optional[LNPaymentRoute]:
if not path:
path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels)
path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist)
if path:
return self.create_route_from_path(path, nodeA, my_channels=my_channels)
16 changes: 15 additions & 1 deletion electrum/lnutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections import namedtuple, defaultdict
from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict, Set, Sequence
import re

import time
import attr
from aiorpcx import NetAddress

Expand Down Expand Up @@ -1313,3 +1313,17 @@ class OnionFailureCodeMetaFlag(IntFlag):
NODE = 0x2000
UPDATE = 0x1000


class ChannelBlackList:

def __init__(self):
self.blacklist = dict() # short_chan_id -> timestamp

def add(self, short_channel_id: ShortChannelID):
now = int(time.time())
self.blacklist[short_channel_id] = now

def get_current_list(self) -> Set[ShortChannelID]:
BLACKLIST_DURATION = 3600
now = int(time.time())
return set(k for k, t in self.blacklist.items() if now - t < BLACKLIST_DURATION)
11 changes: 7 additions & 4 deletions electrum/lnworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from decimal import Decimal
import random
import time
from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any
from typing import Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any
import threading
import socket
import aiohttp
Expand Down Expand Up @@ -540,6 +540,7 @@ async def process_gossip(self, chan_anns, node_anns, chan_upds):
if categorized_chan_upds.good:
self.logger.debug(f'on_channel_update: {len(categorized_chan_upds.good)}/{len(chan_upds_chunk)}')


class LNWallet(LNWorker):

lnwatcher: Optional['LNWalletWatcher']
Expand Down Expand Up @@ -1014,7 +1015,8 @@ async def _pay_to_route(self, route: LNPaymentRoute, lnaddr: LnAddr) -> PaymentA
except IndexError:
self.logger.info("payment destination reported error")
else:
self.network.path_finder.add_to_blacklist(short_chan_id)
self.logger.info(f'blacklisting channel {short_channel_id}')
self.network.channel_blacklist.add(short_chan_id)
else:
# probably got "update_fail_malformed_htlc". well... who to penalise now?
assert payment_attempt.failure_message is not None
Expand Down Expand Up @@ -1127,6 +1129,7 @@ def _create_route_from_invoice(self, decoded_invoice: 'LnAddr',
channels = list(self.channels.values())
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None}
blacklist = self.network.channel_blacklist.get_current_list()
for private_route in r_tags:
if len(private_route) == 0:
continue
Expand All @@ -1144,7 +1147,7 @@ def _create_route_from_invoice(self, decoded_invoice: 'LnAddr',
try:
route = self.network.path_finder.find_route(
self.node_keypair.pubkey, border_node_pubkey, amount_msat,
path=path, my_channels=scid_to_my_channels)
path=path, my_channels=scid_to_my_channels, blacklist=blacklist)
except NoChannelPolicy:
continue
if not route:
Expand Down Expand Up @@ -1186,7 +1189,7 @@ def _create_route_from_invoice(self, decoded_invoice: 'LnAddr',
if route is None:
route = self.network.path_finder.find_route(
self.node_keypair.pubkey, invoice_pubkey, amount_msat,
path=full_path, my_channels=scid_to_my_channels)
path=full_path, my_channels=scid_to_my_channels, blacklist=blacklist)
if not route:
raise NoPathFound()
if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
Expand Down
3 changes: 2 additions & 1 deletion electrum/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from .util import (log_exceptions, ignore_exceptions,
bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter,
is_hash256_str, is_non_negative_integer, MyEncoder, NetworkRetryManager)

from .bitcoin import COIN
from . import constants
from . import blockchain
Expand All @@ -60,6 +59,7 @@
from .simple_config import SimpleConfig
from .i18n import _
from .logging import get_logger, Logger
from .lnutil import ChannelBlackList

if TYPE_CHECKING:
from .channel_db import ChannelDB
Expand Down Expand Up @@ -335,6 +335,7 @@ def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None):
self._has_ever_managed_to_connect_to_server = False

# lightning network
self.channel_blacklist = ChannelBlackList()
self.channel_db = None # type: Optional[ChannelDB]
self.lngossip = None # type: Optional[LNGossip]
self.local_watchtower = None # type: Optional[WatchTower]
Expand Down
2 changes: 2 additions & 0 deletions electrum/tests/test_lnpeer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from electrum.logging import console_stderr_handler, Logger
from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID
from electrum.lnonion import OnionFailureCode
from electrum.lnutil import ChannelBlackList

from .test_lnchannel import create_test_channels
from .test_bitcoin import needs_test_with_all_chacha20_implementations
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(self, tx_queue):
self.path_finder = LNPathFinder(self.channel_db)
self.tx_queue = tx_queue
self._blockchain = MockBlockchain()
self.channel_blacklist = ChannelBlackList()

@property
def callback_lock(self):
Expand Down

0 comments on commit ad91257

Please sign in to comment.