diff --git a/contrib/requirements/requirements.txt b/contrib/requirements/requirements.txt index 7500fb28a625..7f855a67c184 100644 --- a/contrib/requirements/requirements.txt +++ b/contrib/requirements/requirements.txt @@ -13,3 +13,4 @@ pycryptodomex>=3.7 jsonrpcserver jsonrpcclient jsonpatch +attrs diff --git a/electrum/json_db.py b/electrum/json_db.py index 345a3cad482e..6a32d1b06995 100644 --- a/electrum/json_db.py +++ b/electrum/json_db.py @@ -30,59 +30,97 @@ from collections import defaultdict from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence import jsonpatch +import binascii from . import util, bitcoin from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, bfh, PR_TYPE_ONCHAIN from .keystore import bip44_derivation from .transaction import Transaction, TxOutpoint, PartialTxOutput from .logging import Logger +from .lnutil import LOCAL, REMOTE, FeeUpdate, UpdateAddHtlc, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, RevocationStore +from .lnutil import ChannelConstraints, Outpoint +from .lnutil import StoredAttr # seed_version is now used for the version of the wallet file OLD_SEED_VERSION = 4 # electrum versions < 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0 -FINAL_SEED_VERSION = 23 # electrum >= 2.7 will set this to prevent +FINAL_SEED_VERSION = 24 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format JsonDBJsonEncoder = util.MyEncoder +def decodeAll(d, local): + for k, v in d.items(): + if k == 'revocation_store': + yield (k, RevocationStore.from_json_obj(v)) + elif k.endswith("_basepoint") or k.endswith("_key"): + if local: + yield (k, Keypair(**dict(decodeAll(v, local)))) + else: + yield (k, OnlyPubkeyKeypair(**dict(decodeAll(v, local)))) + elif k in ["node_id", "channel_id", "short_channel_id", "pubkey", "privkey", "current_per_commitment_point", "next_per_commitment_point", "per_commitment_secret_seed", "current_commitment_signature", 
"current_htlc_signatures"] and v is not None: + yield (k, binascii.unhexlify(v)) + else: + yield (k, v) + + class TxFeesValue(NamedTuple): fee: Optional[int] = None is_calculated_by_us: bool = False num_inputs: Optional[int] = None + class StorageDict(dict): def __init__(self, db, path, data): - #self.update(data) self.db = db self.path = path + # recursively convert dicts to storagedict for k, v in list(data.items()): - # recursively convert dicts to storagedict self.__setitem__(k, v, patch=False) + def convert_key(self, key): + # convert int, HTLCOwner + # TODO: no need to convert address indexes anymore + return key if type(key) is str else str(int(key)) + def __setitem__(self, key, value, patch=True): + key = self.convert_key(key) is_dict = type(value) is dict is_new = key not in self # early return to prevent unnecessary disk writes if not is_new and self[key] == value: return + # set parent of StoredAttr + if isinstance(value, StoredAttr): + value.set_parent(self, key) # convert dict to StorageDict v = StorageDict(self.db, self.path+'/'+key, value) if is_dict else value + # set item dict.__setitem__(self, key, v) - if patch: + if self.db and patch: path = self.path + '/' + str(key) op = 'add' if is_new else 'replace' self.db.add_patch({'op': op, 'path': path, 'value': value}) def __delitem__(self, key): + key = self.convert_key(key) path = self.path + '/' + str(key) - self.db.add_patch({'op': 'remove', 'path': path}) + if self.db: + self.db.add_patch({'op': 'remove', 'path': path}) return dict.__delitem__(self, key) + def __getitem__(self, key): + key = self.convert_key(key) + return dict.__getitem__(self, key) + + def __contains__(self, key): + key = self.convert_key(key) + return dict.__contains__(self, key) class JsonDB(Logger): @@ -270,6 +308,7 @@ def upgrade(self): self._convert_version_21() self._convert_version_22() self._convert_version_23() + self._convert_version_24() self.put('seed_version', FINAL_SEED_VERSION) # just to be sure 
self._after_upgrade_tasks() @@ -610,6 +649,14 @@ def _convert_version_23(self): self.data['seed_version'] = 23 + def _convert_version_24(self): + if not self._is_upgrade_method_needed(23, 23): + return + channels = self.get('channels', []) + self.data['channels'] = { x['channel_id']: x for x in channels } + # unacked_local_updates2, fee_updates + self.data['seed_version'] = 24 + def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): return @@ -1030,6 +1077,24 @@ def _load_transactions(self): if invoice.get('type') == PR_TYPE_ONCHAIN: outputs = [PartialTxOutput.from_legacy_tuple(*output) for output in invoice.get('outputs')] invoice.__setitem__('outputs', outputs, patch=False) + # convert FeeUpdate and UpdateAddHtlc + for k, c in self.get_dict('channels').items(): + local_config = LocalConfig(**dict(decodeAll(c["local_config"], True))) + c.__setitem__('local_config', local_config, patch=False) + remote_config = RemoteConfig(**dict(decodeAll(c["remote_config"], False))) + c.__setitem__('remote_config', remote_config, patch=False) + constraints = ChannelConstraints(**c["constraints"]) + c.__setitem__('constraints', constraints, patch=False) + funding_outpoint = Outpoint(**dict(decodeAll(c["funding_outpoint"], False))) + c.__setitem__('funding_outpoint', funding_outpoint, patch=False) + log = c['log'] + for sub in (LOCAL, REMOTE): + d = log[sub] + for htlc_id, htlc in d['adds'].items(): + d['adds'].__setitem__(htlc_id, UpdateAddHtlc(*htlc), patch=False) + for i, fee_upd in d['fee_updates'].items(): + d['fee_updates'].__setitem__(i, FeeUpdate(*fee_upd), patch=False) + @modifier def clear_history(self): diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index 9ce8d9011294..d2be2d034934 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -53,6 +53,7 @@ if TYPE_CHECKING: from .lnworker import LNWallet + from .json_db import StorageDict # lightning channel states @@ -91,15 +92,6 @@ class peer_states(IntEnum): (cs.CLOSED, 
cs.REDEEMED), ] -class ChannelJsonEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, bytes): - return binascii.hexlify(o).decode("ascii") - if isinstance(o, RevocationStore): - return o.serialize() - if isinstance(o, set): - return list(o) - return super().default(o) RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_commitment_point"]) @@ -107,31 +99,9 @@ def default(self, o): class RemoteCtnTooFarInFuture(Exception): pass -def decodeAll(d, local): - for k, v in d.items(): - if k == 'revocation_store': - yield (k, RevocationStore.from_json_obj(v)) - elif k.endswith("_basepoint") or k.endswith("_key"): - if local: - yield (k, Keypair(**dict(decodeAll(v, local)))) - else: - yield (k, OnlyPubkeyKeypair(**dict(decodeAll(v, local)))) - elif k in ["node_id", "channel_id", "short_channel_id", "pubkey", "privkey", "current_per_commitment_point", "next_per_commitment_point", "per_commitment_secret_seed", "current_commitment_signature", "current_htlc_signatures"] and v is not None: - yield (k, binascii.unhexlify(v)) - else: - yield (k, v) - def htlcsum(htlcs): return sum([x.amount_msat for x in htlcs]) -# following two functions are used because json -# doesn't store int keys and byte string values -def str_bytes_dict_from_save(x) -> Dict[int, bytes]: - return {int(k): bfh(v) for k,v in x.items()} - -def str_bytes_dict_to_save(x) -> Dict[str, str]: - return {str(k): bh2u(v) for k, v in x.items()} - class Channel(Logger): # note: try to avoid naming ctns/ctxs/etc as "current" and "pending". 
@@ -146,42 +116,51 @@ def diagnostic_name(self): except: return super().diagnostic_name() - def __init__(self, state, *, sweep_address=None, name=None, lnworker=None, initial_feerate=None): + def __init__(self, state: 'StorageDict', *, sweep_address=None, name=None, lnworker=None, initial_feerate=None): self.name = name Logger.__init__(self) self.lnworker = lnworker # type: Optional[LNWallet] self.sweep_address = sweep_address - assert 'local_state' not in state + self.storage = state self.config = {} self.config[LOCAL] = state["local_config"] - if type(self.config[LOCAL]) is not LocalConfig: - conf = dict(decodeAll(self.config[LOCAL], True)) - self.config[LOCAL] = LocalConfig(**conf) - assert type(self.config[LOCAL].htlc_basepoint.privkey) is bytes - self.config[REMOTE] = state["remote_config"] - if type(self.config[REMOTE]) is not RemoteConfig: - conf = dict(decodeAll(self.config[REMOTE], False)) - self.config[REMOTE] = RemoteConfig(**conf) - assert type(self.config[REMOTE].htlc_basepoint.pubkey) is bytes - - self.channel_id = bfh(state["channel_id"]) if type(state["channel_id"]) not in (bytes, type(None)) else state["channel_id"] - self.constraints = ChannelConstraints(**state["constraints"]) if type(state["constraints"]) is not ChannelConstraints else state["constraints"] - self.funding_outpoint = Outpoint(**dict(decodeAll(state["funding_outpoint"], False))) if type(state["funding_outpoint"]) is not Outpoint else state["funding_outpoint"] - self.node_id = bfh(state["node_id"]) if type(state["node_id"]) not in (bytes, type(None)) else state["node_id"] # type: bytes + self.channel_id = bfh(state["channel_id"]) + self.constraints = state["constraints"] + self.funding_outpoint = state["funding_outpoint"] + self.node_id = bfh(state["node_id"]) self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"]) self.short_channel_id_predicted = self.short_channel_id - self.onion_keys = str_bytes_dict_from_save(state.get('onion_keys', {})) - 
self.data_loss_protect_remote_pcp = str_bytes_dict_from_save(state.get('data_loss_protect_remote_pcp', {})) - self.remote_update = bfh(state.get('remote_update')) if state.get('remote_update') else None - - log = state.get('log') - self.hm = HTLCManager(log=log, initial_feerate=initial_feerate) + self.onion_keys = state['onion_keys'] + self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp'] + self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate) self._state = channel_states[state['state']] self.peer_state = peer_states.DISCONNECTED self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]] self._outgoing_channel_update = None # type: Optional[bytes] + def set_onion_key(self, key, value): + self.onion_keys[key] = value + + def get_onion_key(self, key): + return self.onion_keys.get(str(key))  # StorageDict stores keys as str; dict.get does not convert + + def set_data_loss_protect_remote_pcp(self, key, value): + self.data_loss_protect_remote_pcp[key] = value + + def get_data_loss_protect_remote_pcp(self, key): + return self.data_loss_protect_remote_pcp.get(str(key))  # 'return' was missing; str() because StorageDict keys are str + + def set_remote_update(self, raw): + self.storage['remote_update'] = raw.hex() + + def get_remote_update(self): + return bfh(self.storage.get('remote_update')) if self.storage.get('remote_update') else None + + def set_short_channel_id(self, short_id): + self.short_channel_id = short_id + self.storage["short_channel_id"] = short_id + def get_feerate(self, subject, ctn): return self.hm.get_feerate(subject, ctn) @@ -211,9 +190,9 @@ def get_payments(self): return out def open_with_first_pcp(self, remote_pcp, remote_sig): - self.config[REMOTE] = self.config[REMOTE]._replace(current_per_commitment_point=remote_pcp, - next_per_commitment_point=None) - self.config[LOCAL] = self.config[LOCAL]._replace(current_commitment_signature=remote_sig) + self.config[REMOTE].current_per_commitment_point=remote_pcp + self.config[REMOTE].next_per_commitment_point=None + self.config[LOCAL].current_commitment_signature=remote_sig self.hm.channel_open_finished() 
self.peer_state = peer_states.GOOD self.set_state(channel_states.OPENING) @@ -223,8 +202,10 @@ def set_state(self, state): old_state = self._state if (old_state, state) not in state_transitions: raise Exception(f"Transition not allowed: {old_state.name} -> {state.name}") - self._state = state self.logger.debug(f'Setting channel state: {old_state.name} -> {state.name}') + self._state = state + self.storage['state'] = self._state.name + if self.lnworker: self.lnworker.save_channel(self) self.lnworker.network.trigger_callback('channel', self) @@ -397,9 +378,8 @@ def receive_new_commitment(self, sig, htlc_sigs): ctx_output_idx=ctx_output_idx) self.hm.recv_ctx() - self.config[LOCAL]=self.config[LOCAL]._replace( - current_commitment_signature=sig, - current_htlc_signatures=htlc_sigs_string) + self.config[LOCAL].current_commitment_signature=sig + self.config[LOCAL].current_htlc_signatures=htlc_sigs_string def verify_htlc(self, *, htlc: UpdateAddHtlc, htlc_sig: bytes, htlc_direction: Direction, pcp: bytes, ctx: Transaction, ctx_output_idx: int) -> None: @@ -451,13 +431,16 @@ def receive_revocation(self, revocation: RevokeAndAck): if cur_point != derived_point: raise Exception('revoked secret not for current point') - self.config[REMOTE].revocation_store.add_next_entry(revocation.per_commitment_secret) + ### Note: here we trigger RemoteConfig.__setattr__ + ### to avoid that, RevocationStore should not be a field of RemoteConfig + rev = self.config[REMOTE].revocation_store + rev.add_next_entry(revocation.per_commitment_secret) + self.config[REMOTE].revocation_store = rev + ##### start applying fee/htlc changes self.hm.recv_rev() - self.config[REMOTE]=self.config[REMOTE]._replace( - current_per_commitment_point=self.config[REMOTE].next_per_commitment_point, - next_per_commitment_point=revocation.next_per_commitment_point, - ) + self.config[REMOTE].current_per_commitment_point=self.config[REMOTE].next_per_commitment_point + 
self.config[REMOTE].next_per_commitment_point=revocation.next_per_commitment_point def balance(self, whose, *, ctx_owner=HTLCOwner.LOCAL, ctn=None): """ @@ -646,49 +629,8 @@ def update_fee(self, feerate: int, from_us: bool): else: self.hm.recv_update_fee(feerate) - def to_save(self): - to_save = { - "local_config": self.config[LOCAL], - "remote_config": self.config[REMOTE], - "channel_id": self.channel_id, - "short_channel_id": self.short_channel_id, - "constraints": self.constraints, - "funding_outpoint": self.funding_outpoint, - "node_id": self.node_id, - "log": self.hm.to_save(), - "onion_keys": str_bytes_dict_to_save(self.onion_keys), - "state": self._state.name, - "data_loss_protect_remote_pcp": str_bytes_dict_to_save(self.data_loss_protect_remote_pcp), - "remote_update": self.remote_update.hex() if self.remote_update else None - } - return to_save - def serialize(self): - namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()} - serialized_channel = {} - to_save_ref = self.to_save() - for k, v in to_save_ref.items(): - if isinstance(v, tuple): - serialized_channel[k] = namedtuples_to_dict(v) - else: - serialized_channel[k] = v - dumped = ChannelJsonEncoder().encode(serialized_channel) - roundtripped = json.loads(dumped) - reconstructed = Channel(roundtripped) - to_save_new = reconstructed.to_save() - if to_save_new != to_save_ref: - from pprint import PrettyPrinter - pp = PrettyPrinter(indent=168) - try: - from deepdiff import DeepDiff - except ImportError: - raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(to_save_ref) + "\n" + pp.pformat(to_save_new)) - else: - raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(DeepDiff(to_save_ref, to_save_new))) - return roundtripped - - def __str__(self): - return str(self.serialize()) + pass def make_commitment(self, subject, this_point, ctn) -> PartialTransaction: assert 
type(subject) is HTLCOwner @@ -741,7 +683,8 @@ def make_commitment(self, subject, this_point, ctn) -> PartialTransaction: other_revocation_pubkey, derive_pubkey(this_config.delayed_basepoint.pubkey, this_point), other_config.to_self_delay, - *self.funding_outpoint, + self.funding_outpoint.txid, + self.funding_outpoint.output_index, self.constraints.capacity, local_msat, remote_msat, diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py index 19999c12469b..da87aaa8d8cf 100644 --- a/electrum/lnhtlc.py +++ b/electrum/lnhtlc.py @@ -1,48 +1,37 @@ from copy import deepcopy -from typing import Optional, Sequence, Tuple, List, Dict +from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate from .util import bh2u, bfh +if TYPE_CHECKING: + from .json_db import StorageDict class HTLCManager: - def __init__(self, *, log=None, initial_feerate=None): - if log is None: + def __init__(self, log:'StorageDict', *, initial_feerate=None): + + if len(log) == 0: initial = { 'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}, - 'fee_updates': [], + 'fee_updates': {}, # "side who initiated fee update" -> action -> list of FeeUpdates 'revack_pending': False, 'next_htlc_id': 0, - 'ctn': -1, # oldest unrevoked ctx of sub + 'ctn': -1, # oldest unrevoked ctx of sub } - log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)} - else: - assert type(log) is dict - log = {(HTLCOwner(int(k)) if k in ("-1", "1") else k): v - for k, v in deepcopy(log).items()} - for sub in (LOCAL, REMOTE): - log[sub]['adds'] = {int(htlc_id): UpdateAddHtlc(*htlc) for htlc_id, htlc in log[sub]['adds'].items()} - coerceHtlcOwner2IntMap = lambda ctns: {HTLCOwner(int(owner)): ctn for owner, ctn in ctns.items()} - # "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn - log[sub]['locked_in'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['locked_in'].items()} - 
log[sub]['settles'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['settles'].items()} - log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()} - # "side who initiated fee update" -> action -> list of FeeUpdates - log[sub]['fee_updates'] = [FeeUpdate.from_dict(fee_upd) for fee_upd in log[sub]['fee_updates']] - if 'unacked_local_updates2' not in log: + log[LOCAL] = deepcopy(initial) + log[REMOTE] = deepcopy(initial) log['unacked_local_updates2'] = {} - log['unacked_local_updates2'] = {int(ctn): [bfh(msg) for msg in messages] - for ctn, messages in log['unacked_local_updates2'].items()} + # maybe bootstrap fee_updates if initial_feerate was provided if initial_feerate is not None: assert type(initial_feerate) is int for sub in (LOCAL, REMOTE): if not log[sub]['fee_updates']: - log[sub]['fee_updates'].append(FeeUpdate(initial_feerate, ctns={LOCAL:0, REMOTE:0})) + log[sub]['fee_updates'][0] = FeeUpdate(initial_feerate, ctn_local=0, ctn_remote=0) self.log = log def ctn_latest(self, sub: HTLCOwner) -> int: @@ -65,20 +54,6 @@ def _set_revack_pending(self, sub: HTLCOwner, pending: bool) -> None: def get_next_htlc_id(self, sub: HTLCOwner) -> int: return self.log[sub]['next_htlc_id'] - def to_save(self): - log = deepcopy(self.log) - for sub in (LOCAL, REMOTE): - # adds - d = {} - for htlc_id, htlc in log[sub]['adds'].items(): - d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:] - log[sub]['adds'] = d - # fee_updates - log[sub]['fee_updates'] = [FeeUpdate.to_dict(fee_upd) for fee_upd in log[sub]['fee_updates']] - log['unacked_local_updates2'] = {ctn: [bh2u(msg) for msg in messages] - for ctn, messages in log['unacked_local_updates2'].items()} - return log - ##### Actions on channel: def channel_open_finished(self): @@ -120,22 +95,25 @@ def recv_fail(self, htlc_id: int) -> None: def send_update_fee(self, feerate: int) -> None: fee_update = FeeUpdate(rate=feerate, - ctns={LOCAL: None, REMOTE: 
self.ctn_latest(REMOTE) + 1}) + ctn_local=None, ctn_remote=self.ctn_latest(REMOTE) + 1) self._new_feeupdate(fee_update, subject=LOCAL) def recv_update_fee(self, feerate: int) -> None: fee_update = FeeUpdate(rate=feerate, - ctns={LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}) + ctn_local=self.ctn_latest(LOCAL) + 1, ctn_remote=None) self._new_feeupdate(fee_update, subject=REMOTE) def _new_feeupdate(self, fee_update: FeeUpdate, subject: HTLCOwner) -> None: # overwrite last fee update if not yet committed to by anyone; otherwise append - last_fee_update = self.log[subject]['fee_updates'][-1] - if (last_fee_update.ctns[LOCAL] is None or last_fee_update.ctns[LOCAL] > self.ctn_latest(LOCAL)) \ - and (last_fee_update.ctns[REMOTE] is None or last_fee_update.ctns[REMOTE] > self.ctn_latest(REMOTE)): - self.log[subject]['fee_updates'][-1] = fee_update + d = self.log[subject]['fee_updates'] + #assert type(d) is StorageDict + n = len(d) + last_fee_update = d[n-1] + if (last_fee_update.ctn_local is None or last_fee_update.ctn_local > self.ctn_latest(LOCAL)) \ + and (last_fee_update.ctn_remote is None or last_fee_update.ctn_remote > self.ctn_latest(REMOTE)): + d[n-1] = fee_update else: - self.log[subject]['fee_updates'].append(fee_update) + d[n] = fee_update def send_ctx(self) -> None: assert self.ctn_latest(REMOTE) == self.ctn_oldest_unrevoked(REMOTE), (self.ctn_latest(REMOTE), self.ctn_oldest_unrevoked(REMOTE)) @@ -157,9 +135,9 @@ def send_rev(self) -> None: if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL): ctns[REMOTE] = self.ctn_latest(REMOTE) + 1 # fee updates - for fee_update in self.log[REMOTE]['fee_updates']: - if fee_update.ctns[REMOTE] is None and fee_update.ctns[LOCAL] <= self.ctn_latest(LOCAL): - fee_update.ctns[REMOTE] = self.ctn_latest(REMOTE) + 1 + for k, fee_update in list(self.log[REMOTE]['fee_updates'].items()): + if fee_update.ctn_remote is None and fee_update.ctn_local <= self.ctn_latest(LOCAL): + self.log[REMOTE]['fee_updates'][k] = 
fee_update._replace(ctn_remote = self.ctn_latest(REMOTE) + 1) def recv_rev(self) -> None: self.log[REMOTE]['ctn'] += 1 @@ -173,11 +151,12 @@ def recv_rev(self) -> None: if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE): ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 # fee updates - for fee_update in self.log[LOCAL]['fee_updates']: - if fee_update.ctns[LOCAL] is None and fee_update.ctns[REMOTE] <= self.ctn_latest(REMOTE): - fee_update.ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 + for k, fee_update in list(self.log[LOCAL]['fee_updates'].items()): + if fee_update.ctn_local is None and fee_update.ctn_remote <= self.ctn_latest(REMOTE): + self.log[LOCAL]['fee_updates'][k] = fee_update._replace(ctn_local = self.ctn_latest(LOCAL) + 1) + # no need to keep local update raw msgs anymore, they have just been ACKed. - self.log['unacked_local_updates2'].pop(self.log[REMOTE]['ctn'], None) + self.log['unacked_local_updates2'].pop(str(self.log[REMOTE]['ctn']), None) def discard_unsigned_remote_updates(self): """Discard updates sent by the remote, that the remote itself @@ -189,7 +168,7 @@ def discard_unsigned_remote_updates(self): del self.log[REMOTE]['locked_in'][htlc_id] del self.log[REMOTE]['adds'][htlc_id] if self.log[REMOTE]['locked_in']: - self.log[REMOTE]['next_htlc_id'] = max(self.log[REMOTE]['locked_in']) + 1 + self.log[REMOTE]['next_htlc_id'] = max([int(x) for x in self.log[REMOTE]['locked_in'].keys()]) + 1 else: self.log[REMOTE]['next_htlc_id'] = 0 # htlcs removed @@ -198,9 +177,9 @@ def discard_unsigned_remote_updates(self): if ctns[LOCAL] > self.ctn_latest(LOCAL): del self.log[LOCAL][log_action][htlc_id] # fee updates - for i, fee_update in enumerate(list(self.log[REMOTE]['fee_updates'])): - if fee_update.ctns[LOCAL] > self.ctn_latest(LOCAL): - del self.log[REMOTE]['fee_updates'][i] + for k, fee_update in list(self.log[REMOTE]['fee_updates'].items()): + if fee_update.ctn_local > self.ctn_latest(LOCAL): + self.log[REMOTE]['fee_updates'].pop(k) def 
store_local_update_raw_msg(self, raw_update_msg: bytes, *, is_commitment_signed: bool) -> None: """We need to be able to replay unacknowledged updates we sent to the remote @@ -212,12 +191,15 @@ def store_local_update_raw_msg(self, raw_update_msg: bytes, *, is_commitment_sig ctn_idx = self.ctn_latest(REMOTE) else: ctn_idx = self.ctn_latest(REMOTE) + 1 - if ctn_idx not in self.log['unacked_local_updates2']: - self.log['unacked_local_updates2'][ctn_idx] = [] - self.log['unacked_local_updates2'][ctn_idx].append(raw_update_msg) + ctn_idx = str(ctn_idx) + l = self.log['unacked_local_updates2'].get(ctn_idx, []) + l.append(raw_update_msg.hex()) + self.log['unacked_local_updates2'][ctn_idx] = l def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]: - return self.log['unacked_local_updates2'] + #return self.log['unacked_local_updates2'] + return {int(ctn): [bfh(msg) for msg in messages] + for ctn, messages in self.log['unacked_local_updates2'].items()} ##### Queries re HTLCs: @@ -331,7 +313,7 @@ def get_feerate(self, subject: HTLCOwner, ctn: int) -> int: right = len(fee_log) while True: i = (left + right) // 2 - ctn_at_i = fee_log[i].ctns[subject] + ctn_at_i = fee_log[i].ctn_local if subject==LOCAL else fee_log[i].ctn_remote if right - left <= 1: break if ctn_at_i is None: # Nones can only be on the right end diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index fec0348dfba5..640c631b4d55 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -221,7 +221,7 @@ def on_channel_update(self, payload): def maybe_save_remote_update(self, payload): for chan in self.channels.values(): if chan.short_channel_id == payload['short_channel_id']: - chan.remote_update = payload['raw'] + chan.set_remote_update(payload['raw']) self.logger.info("saved remote_update") def on_announcement_signatures(self, payload): @@ -582,17 +582,9 @@ async def channel_establishment_flow(self, password: Optional[str], funding_tx: funding_index = funding_tx.outputs().index(funding_output) 
# remote commitment transaction channel_id, funding_txid_bytes = channel_id_from_funding_tx(funding_txid, funding_index) - chan_dict = { - "node_id": self.pubkey, - "channel_id": channel_id, - "short_channel_id": None, - "funding_outpoint": Outpoint(funding_txid, funding_index), - "remote_config": remote_config, - "local_config": local_config, - "constraints": ChannelConstraints(capacity=funding_sat, is_initiator=True, funding_txn_minimum_depth=funding_txn_minimum_depth), - "remote_update": None, - "state": channel_states.PREOPENING.name, - } + outpoint = Outpoint(funding_txid, funding_index) + constraints = ChannelConstraints(capacity=funding_sat, is_initiator=True, funding_txn_minimum_depth=funding_txn_minimum_depth) + chan_dict = self.create_channel_storage(channel_id, outpoint, local_config, remote_config, constraints) chan = Channel(chan_dict, sweep_address=self.lnworker.sweep_address, lnworker=self.lnworker, @@ -610,6 +602,26 @@ async def channel_establishment_flow(self, password: Optional[str], funding_tx: chan.open_with_first_pcp(remote_per_commitment_point, remote_sig) return chan, funding_tx + def create_channel_storage(self, channel_id, outpoint, local_config, remote_config, constraints): + chan_dict = { + "node_id": self.pubkey.hex(), + "channel_id": channel_id.hex(), + "short_channel_id": None, + "funding_outpoint": outpoint, + "remote_config": remote_config, + "local_config": local_config, + "constraints": constraints, + "remote_update": None, + "state": channel_states.PREOPENING.name, + 'onion_keys': {}, + 'data_loss_protect_remote_pcp': {}, + "log": {}, + } + channel_id = chan_dict.get('channel_id') + channels = self.lnworker.storage.db.get_dict('channels') + channels[channel_id] = chan_dict + return channels.get(channel_id) + async def on_open_channel(self, payload): # payload['channel_flags'] if payload['chain_hash'] != constants.net.rev_genesis_bytes(): @@ -649,33 +661,26 @@ async def on_open_channel(self, payload): remote_balance_sat = 
funding_sat * 1000 - push_msat remote_dust_limit_sat = int.from_bytes(payload['dust_limit_satoshis'], byteorder='big') # TODO validate remote_reserve_sat = self.validate_remote_reserve(payload['channel_reserve_satoshis'], remote_dust_limit_sat, funding_sat) - chan_dict = { - "node_id": self.pubkey, - "channel_id": channel_id, - "short_channel_id": None, - "funding_outpoint": Outpoint(funding_txid, funding_idx), - "remote_config": RemoteConfig( - payment_basepoint=OnlyPubkeyKeypair(payload['payment_basepoint']), - multisig_key=OnlyPubkeyKeypair(payload['funding_pubkey']), - htlc_basepoint=OnlyPubkeyKeypair(payload['htlc_basepoint']), - delayed_basepoint=OnlyPubkeyKeypair(payload['delayed_payment_basepoint']), - revocation_basepoint=OnlyPubkeyKeypair(payload['revocation_basepoint']), - to_self_delay=int.from_bytes(payload['to_self_delay'], 'big'), - dust_limit_sat=remote_dust_limit_sat, - max_htlc_value_in_flight_msat=int.from_bytes(payload['max_htlc_value_in_flight_msat'], 'big'), # TODO validate - max_accepted_htlcs=int.from_bytes(payload['max_accepted_htlcs'], 'big'), # TODO validate - initial_msat=remote_balance_sat, - reserve_sat = remote_reserve_sat, - htlc_minimum_msat=int.from_bytes(payload['htlc_minimum_msat'], 'big'), # TODO validate - next_per_commitment_point=payload['first_per_commitment_point'], - current_per_commitment_point=None, - revocation_store=their_revocation_store, - ), - "local_config": local_config, - "constraints": ChannelConstraints(capacity=funding_sat, is_initiator=False, funding_txn_minimum_depth=min_depth), - "remote_update": None, - "state": channel_states.PREOPENING.name, - } + remote_config = RemoteConfig( + payment_basepoint=OnlyPubkeyKeypair(payload['payment_basepoint']), + multisig_key=OnlyPubkeyKeypair(payload['funding_pubkey']), + htlc_basepoint=OnlyPubkeyKeypair(payload['htlc_basepoint']), + delayed_basepoint=OnlyPubkeyKeypair(payload['delayed_payment_basepoint']), + 
revocation_basepoint=OnlyPubkeyKeypair(payload['revocation_basepoint']), + to_self_delay=int.from_bytes(payload['to_self_delay'], 'big'), + dust_limit_sat=remote_dust_limit_sat, + max_htlc_value_in_flight_msat=int.from_bytes(payload['max_htlc_value_in_flight_msat'], 'big'), # TODO validate + max_accepted_htlcs=int.from_bytes(payload['max_accepted_htlcs'], 'big'), # TODO validate + initial_msat=remote_balance_sat, + reserve_sat = remote_reserve_sat, + htlc_minimum_msat=int.from_bytes(payload['htlc_minimum_msat'], 'big'), # TODO validate + next_per_commitment_point=payload['first_per_commitment_point'], + current_per_commitment_point=None, + revocation_store=their_revocation_store, + ) + constraints = ChannelConstraints(capacity=funding_sat, is_initiator=False, funding_txn_minimum_depth=min_depth)  # we are the fundee here + outpoint = Outpoint(funding_txid, funding_idx) + chan_dict = self.create_channel_storage(channel_id, outpoint, local_config, remote_config, constraints) chan = Channel(chan_dict, sweep_address=self.lnworker.sweep_address, lnworker=self.lnworker, @@ -688,7 +693,7 @@ async def on_open_channel(self, payload): signature=sig_64, ) chan.open_with_first_pcp(payload['first_per_commitment_point'], remote_sig) - self.lnworker.save_channel(chan) + self.lnworker.add_channel(chan) self.lnworker.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) def validate_remote_reserve(self, payload_field: bytes, dust_limit: int, funding_sat: int) -> int: @@ -855,7 +860,7 @@ def are_datalossprotect_fields_valid() -> bool: else: if dlp_enabled and should_close_they_are_ahead: self.logger.warning(f"channel_reestablish: remote is ahead of us! luckily DLP is enabled. 
remote PCP: {bh2u(their_local_pcp)}") - chan.data_loss_protect_remote_pcp[their_next_local_ctn - 1] = their_local_pcp + chan.set_data_loss_protect_remote_pcp(their_next_local_ctn - 1, their_local_pcp) self.lnworker.save_channel(chan) if should_close_they_are_ahead: self.logger.warning(f"channel_reestablish: remote is ahead of us! trying to get them to force-close.") @@ -890,15 +895,12 @@ def on_funding_locked(self, payload): self.logger.info(f"on_funding_locked. channel: {bh2u(channel_id)}") chan = self.channels.get(channel_id) if not chan: - print(self.channels) raise Exception("Got unknown funding_locked", channel_id) if not chan.config[LOCAL].funding_locked_received: our_next_point = chan.config[REMOTE].next_per_commitment_point their_next_point = payload["next_per_commitment_point"] - new_remote_state = chan.config[REMOTE]._replace(next_per_commitment_point=their_next_point) - new_local_state = chan.config[LOCAL]._replace(funding_locked_received = True) - chan.config[REMOTE]=new_remote_state - chan.config[LOCAL]=new_local_state + chan.config[REMOTE].next_per_commitment_point = their_next_point + chan.config[LOCAL].funding_locked_received = True self.lnworker.save_channel(chan) if chan.short_channel_id: self.mark_open(chan) @@ -913,9 +915,9 @@ def on_network_update(self, chan: Channel, funding_tx_depth: int): # don't announce our channels # FIXME should this be a field in chan.local_state maybe? 
return - chan.config[LOCAL]=chan.config[LOCAL]._replace(was_announced=True) - coro = self.handle_announcements(chan) + chan.config[LOCAL].was_announced=True self.lnworker.save_channel(chan) + coro = self.handle_announcements(chan) asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) @log_exceptions @@ -1011,11 +1013,11 @@ def add_own_channel(self, chan): # peer may have sent us a channel update for the incoming direction previously pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id) if pending_channel_update: - chan.remote_update = pending_channel_update['raw'] + chan.set_remote_update(pending_channel_update['raw']) # add remote update with a fresh timestamp - if chan.remote_update: + if chan.get_remote_update(): now = int(time.time()) - remote_update_decoded = decode_msg(chan.remote_update)[1] + remote_update_decoded = decode_msg(chan.get_remote_update())[1] remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big") self.channel_db.add_channel_update(remote_update_decoded) diff --git a/electrum/lnsweep.py b/electrum/lnsweep.py index c9058e8e3449..f612877d96e0 100644 --- a/electrum/lnsweep.py +++ b/electrum/lnsweep.py @@ -299,8 +299,8 @@ def analyze_ctx(chan: 'Channel', ctx: Transaction): their_pcp = ecc.ECPrivkey(per_commitment_secret).get_public_key_bytes(compressed=True) is_revocation = True #_logger.info(f'tx for revoked: {list(txs.keys())}') - elif ctn in chan.data_loss_protect_remote_pcp: - their_pcp = chan.data_loss_protect_remote_pcp[ctn] + elif chan.get_data_loss_protect_remote_pcp(ctn): + their_pcp = chan.get_data_loss_protect_remote_pcp(ctn) is_revocation = False else: return diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 819e5a9d4678..726ad8833953 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -7,6 +7,7 @@ from collections import namedtuple from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict, Set import re +import attr from aiorpcx import 
NetAddress @@ -35,74 +36,91 @@ def ln_dummy_address(): return redeem_script_to_address('p2wsh', '') -class Keypair(NamedTuple): - pubkey: bytes - privkey: bytes - - -class OnlyPubkeyKeypair(NamedTuple): - pubkey: bytes - - -# NamedTuples cannot subclass NamedTuples :'( https://github.com/python/typing/issues/427 -class LocalConfig(NamedTuple): - # shared channel config fields (DUPLICATED code!!) - payment_basepoint: 'Keypair' - multisig_key: 'Keypair' - htlc_basepoint: 'Keypair' - delayed_basepoint: 'Keypair' - revocation_basepoint: 'Keypair' - to_self_delay: int - dust_limit_sat: int - max_htlc_value_in_flight_msat: int - max_accepted_htlcs: int - initial_msat: int - reserve_sat: int +@attr.s +class Keypair: + pubkey = attr.ib(type=bytes) + privkey = attr.ib(type=bytes) + def to_json(self): + return vars(self) + +@attr.s +class OnlyPubkeyKeypair: + pubkey = attr.ib(type=bytes) + def to_json(self): + return vars(self) + + +class StoredX: + # ancestor so that we can call super() + pass + +class StoredAttr(StoredX): + db = None + path = None + + def __setattr__(self, key, value): + super().__setattr__(key, value) + if self.db and key not in ['path', 'db']: + path = self.path + '/' + str(key) + self.db.add_patch({'op': 'replace', 'path': path, 'value': value}) + + def to_json(self): + # dict() copies the object + d = dict(vars(self)) + d.pop('path', None) + d.pop('db', None) + return d + + def set_parent(self, parent, key): + self.db = parent.db + self.path = parent.path + '/'+ key + + +@attr.s +class Config(StoredAttr): + # shared channel config fields + payment_basepoint = attr.ib(type=Union['Keypair', 'OnlyPubkeyKeypair']) + multisig_key = attr.ib(type=Union['Keypair', 'OnlyPubkeyKeypair']) + htlc_basepoint = attr.ib(type=Union['Keypair', 'OnlyPubkeyKeypair']) + delayed_basepoint = attr.ib(type=Union['Keypair', 'OnlyPubkeyKeypair']) + revocation_basepoint = attr.ib(type=Union['Keypair', 'OnlyPubkeyKeypair']) + to_self_delay = attr.ib(type=int) + dust_limit_sat = attr.ib(type=int) + max_htlc_value_in_flight_msat = 
attr.ib(type=int) + max_accepted_htlcs = attr.ib(type=int) + initial_msat = attr.ib(type=int) + reserve_sat = attr.ib(type=int) + +@attr.s +class LocalConfig(Config): # specific to "LOCAL" config - per_commitment_secret_seed: bytes - funding_locked_received: bool - was_announced: bool - current_commitment_signature: Optional[bytes] - current_htlc_signatures: bytes - - -class RemoteConfig(NamedTuple): - # shared channel config fields (DUPLICATED code!!) - payment_basepoint: Union['Keypair', 'OnlyPubkeyKeypair'] - multisig_key: Union['Keypair', 'OnlyPubkeyKeypair'] - htlc_basepoint: Union['Keypair', 'OnlyPubkeyKeypair'] - delayed_basepoint: Union['Keypair', 'OnlyPubkeyKeypair'] - revocation_basepoint: Union['Keypair', 'OnlyPubkeyKeypair'] - to_self_delay: int - dust_limit_sat: int - max_htlc_value_in_flight_msat: int - max_accepted_htlcs: int - initial_msat: int - reserve_sat: int + per_commitment_secret_seed = attr.ib(type=bytes) + funding_locked_received = attr.ib(type=bool) + was_announced = attr.ib(type=bool) + current_commitment_signature = attr.ib(type=Optional[bytes]) + current_htlc_signatures = attr.ib(type=bytes) + +@attr.s +class RemoteConfig(Config): # specific to "REMOTE" config - htlc_minimum_msat: int - next_per_commitment_point: bytes - revocation_store: 'RevocationStore' - current_per_commitment_point: Optional[bytes] + htlc_minimum_msat = attr.ib(type=int) + next_per_commitment_point = attr.ib(type=bytes) + revocation_store = attr.ib(type='RevocationStore') + current_per_commitment_point = attr.ib(type=Optional[bytes]) + class FeeUpdate(NamedTuple): rate: int # in sat/kw - ctns: Dict['HTLCOwner', Optional[int]] - - @classmethod - def from_dict(cls, d: dict) -> 'FeeUpdate': - return FeeUpdate(rate=d['rate'], - ctns={LOCAL: d['ctns'][str(int(LOCAL))], - REMOTE: d['ctns'][str(int(REMOTE))]}) + ctn_local: Optional[int] + ctn_remote: Optional[int] - def to_dict(self) -> dict: - return {'rate': self.rate, - 'ctns': {int(LOCAL): self.ctns[LOCAL], - int(REMOTE): self.ctns[REMOTE]}} - -ChannelConstraints = namedtuple("ChannelConstraints", 
["capacity", "is_initiator", "funding_txn_minimum_depth"]) +@attr.s +class ChannelConstraints(StoredAttr): + capacity = attr.ib(type=int) + is_initiator = attr.ib(type=bool) + funding_txn_minimum_depth = attr.ib(type=int) class ScriptHtlc(NamedTuple): @@ -111,7 +129,10 @@ class ScriptHtlc(NamedTuple): # FIXME duplicate of TxOutpoint in transaction.py?? -class Outpoint(NamedTuple("Outpoint", [('txid', str), ('output_index', int)])): +@attr.s +class Outpoint(StoredAttr): + txid = attr.ib(type=str) + output_index = attr.ib(type=int) def to_str(self): return "{}:{}".format(self.txid, self.output_index) @@ -202,6 +223,9 @@ def retrieve_secret(self, index: int) -> bytes: def serialize(self): return {"index": self.index, "buckets": [[bh2u(k.secret), k.index] if k is not None else None for k in self.buckets]} + def to_json(self): + return self.serialize() + @staticmethod def from_json_obj(decoded_json_obj): store = RevocationStore() diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 5807fcc44c68..36a50bb20c30 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -33,13 +33,14 @@ from .util import bh2u, bfh, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions from .util import ignore_exceptions, make_aiohttp_session from .util import timestamp_to_datetime +from .util import MyEncoder from .logging import Logger from .lntransport import LNTransport, LNResponderTransport from .lnpeer import Peer, LN_P2P_NETWORK_TIMEOUT from .lnaddr import lnencode, LnAddr, lndecode from .ecc import der_sig_from_sig_string from .ecc_fast import is_using_fast_ecc -from .lnchannel import Channel, ChannelJsonEncoder +from .lnchannel import Channel from .lnchannel import channel_states, peer_states from . 
import lnutil from .lnutil import funding_output_script @@ -105,8 +106,6 @@ LNPeerAddr(host='3.124.63.44', port=9735, pubkey=bfh('0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3')), ] -encoder = ChannelJsonEncoder() - from typing import NamedTuple @@ -348,10 +347,12 @@ def __init__(self, wallet: 'Abstract_Wallet', xprv): self.logs = defaultdict(list) # note: accessing channels (besides simple lookup) needs self.lock! - self.channels = {} # type: Dict[bytes, Channel] - for x in wallet.storage.get("channels", []): - c = Channel(x, sweep_address=self.sweep_address, lnworker=self) - self.channels[c.channel_id] = c + self.channels = {} + channels = self.storage.db.get_dict("channels") + for channel_id, c in channels.items(): + assert c['local_config'].db is not None + self.channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self) + # timestamps of opening and closing transactions self.channel_timestamps = self.storage.db.get_dict('lightning_channel_timestamps') self.pending_payments = defaultdict(asyncio.Future) @@ -602,17 +603,10 @@ def save_channel(self, chan): assert type(chan) is Channel if chan.config[REMOTE].next_per_commitment_point == chan.config[REMOTE].current_per_commitment_point: raise Exception("Tried to save channel with next_point == current_point, this should not happen") - with self.lock: - self.channels[chan.channel_id] = chan - self.save_channels() + # + self.wallet.storage.write() self.network.trigger_callback('channel', chan) - def save_channels(self): - with self.lock: - dumped = [x.serialize() for x in self.channels.values()] - self.storage.put("channels", dumped) - self.storage.write() - def save_short_chan_id(self, chan): """ Checks if Funding TX has been mined. 
If it has, save the short channel ID in chan; @@ -640,8 +634,8 @@ def save_short_chan_id(self, chan): return block_height, tx_pos = self.lnwatcher.get_txpos(chan.funding_outpoint.txid) assert tx_pos >= 0 - chan.short_channel_id = ShortChannelID.from_components( - block_height, tx_pos, chan.funding_outpoint.output_index) + chan.set_short_channel_id(ShortChannelID.from_components( + block_height, tx_pos, chan.funding_outpoint.output_index)) self.logger.info(f"save_short_channel_id: {chan.short_channel_id}") self.save_channel(chan) @@ -821,7 +815,9 @@ async def _open_channel_coroutine(self, *, connect_str: str, funding_tx: Partial funding_sat=funding_sat, push_msat=push_sat * 1000, temp_channel_id=os.urandom(32)) - self.save_channel(chan) + + self.add_channel(chan) + self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) self.network.trigger_callback('channels_updated', self.wallet) if funding_tx.is_complete(): @@ -829,6 +825,10 @@ async def _open_channel_coroutine(self, *, connect_str: str, funding_tx: Partial await asyncio.wait_for(self.network.broadcast_transaction(funding_tx), LN_P2P_NETWORK_TIMEOUT) return chan, funding_tx + def add_channel(self, chan): + with self.lock: + self.channels[chan.channel_id] = chan + @log_exceptions async def add_peer(self, connect_str: str) -> Peer: node_id, rest = extract_nodeid(connect_str) @@ -1222,6 +1222,7 @@ def get_balance(self): return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0 for chan in self.channels.values()))/1000 def list_channels(self): + encoder = MyEncoder() with self.lock: # we output the funding_outpoint instead of the channel_id because lnd uses channel_point (funding outpoint) to identify channels for channel_id, chan in self.channels.items(): @@ -1259,7 +1260,8 @@ def remove_channel(self, chan_id): assert chan.is_closed() with self.lock: self.channels.pop(chan_id) - self.save_channels() + self.storage.get('channels').pop(chan_id.hex()) + 
self.network.trigger_callback('channels_updated', self.wallet) self.network.trigger_callback('wallet_updated', self.wallet) diff --git a/electrum/tests/test_lnchannel.py b/electrum/tests/test_lnchannel.py index ffb2f4362509..192bb2f5542e 100644 --- a/electrum/tests/test_lnchannel.py +++ b/electrum/tests/test_lnchannel.py @@ -35,6 +35,7 @@ from electrum.ecc import sig_string_from_der_sig from electrum.logging import console_stderr_handler from electrum.lnchannel import channel_states +from electrum.json_db import StorageDict from . import ElectrumTestCase @@ -46,9 +47,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator, assert remote_amount > 0 channel_id, _ = lnpeer.channel_id_from_funding_tx(funding_txid, funding_index) their_revocation_store = lnpeer.RevocationStore() - - return { - "channel_id":channel_id, + state = { + "channel_id":channel_id.hex(), "short_channel_id":channel_id[:8], "funding_outpoint":lnpeer.Outpoint(funding_txid, funding_index), "remote_config":lnpeer.RemoteConfig( @@ -93,10 +93,13 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator, is_initiator=is_initiator, funding_txn_minimum_depth=3, ), - "node_id":other_node_id, + "node_id":other_node_id.hex(), 'onion_keys': {}, + 'data_loss_protect_remote_pcp': {}, 'state': 'PREOPENING', + 'log': {}, } + return StorageDict(None, '', state) def bip32(sequence): node = bip32_utils.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence) @@ -151,14 +154,16 @@ def create_test_channels(feerate=6000, local=None, remote=None): assert len(a_htlc_sigs) == 0 assert len(b_htlc_sigs) == 0 - alice.config[LOCAL] = alice.config[LOCAL]._replace(current_commitment_signature=sig_from_bob) - bob.config[LOCAL] = bob.config[LOCAL]._replace(current_commitment_signature=sig_from_alice) + alice.config[LOCAL].current_commitment_signature = sig_from_bob + bob.config[LOCAL].current_commitment_signature = sig_from_alice alice_second = 
lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX - 1), "big")) bob_second = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX - 1), "big")) - alice.config[REMOTE] = alice.config[REMOTE]._replace(next_per_commitment_point=bob_second, current_per_commitment_point=bob_first) - bob.config[REMOTE] = bob.config[REMOTE]._replace(next_per_commitment_point=alice_second, current_per_commitment_point=alice_first) + alice.config[REMOTE].next_per_commitment_point=bob_second + alice.config[REMOTE].current_per_commitment_point = bob_first + bob.config[REMOTE].next_per_commitment_point=alice_second + bob.config[REMOTE].current_per_commitment_point = alice_first alice.hm.channel_open_finished() bob.hm.channel_open_finished() @@ -663,15 +668,11 @@ def setUp(self): bob_min_reserve = 6 * one_bitcoin_in_msat // 1000 # bob min reserve was decided by alice, but applies to bob - alice_channel.config[LOCAL] =\ - alice_channel.config[LOCAL]._replace(reserve_sat=bob_min_reserve) - alice_channel.config[REMOTE] =\ - alice_channel.config[REMOTE]._replace(reserve_sat=alice_min_reserve) + alice_channel.config[LOCAL].reserve_sat = bob_min_reserve + alice_channel.config[REMOTE].reserve_sat = alice_min_reserve - bob_channel.config[LOCAL] =\ - bob_channel.config[LOCAL]._replace(reserve_sat=alice_min_reserve) - bob_channel.config[REMOTE] =\ - bob_channel.config[REMOTE]._replace(reserve_sat=bob_min_reserve) + bob_channel.config[LOCAL].reserve_sat = alice_min_reserve + bob_channel.config[REMOTE].reserve_sat = bob_min_reserve self.alice_channel = alice_channel self.bob_channel = bob_channel diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py index a606dc44b834..db24f5b9f168 100644 --- a/electrum/tests/test_lnhtlc.py +++ b/electrum/tests/test_lnhtlc.py @@ -4,18 +4,18 @@ from electrum.lnutil import RECEIVED, LOCAL, REMOTE, SENT, 
HTLCOwner, Direction from electrum.lnhtlc import HTLCManager +from electrum.json_db import StorageDict from . import ElectrumTestCase - class H(NamedTuple): owner : str htlc_id : int class TestHTLCManager(ElectrumTestCase): def test_adding_htlcs_race(self): - A = HTLCManager() - B = HTLCManager() + A = HTLCManager(StorageDict(None, '', {})) + B = HTLCManager(StorageDict(None, '', {})) A.channel_open_finished() B.channel_open_finished() ah0, bh0 = H('A', 0), H('B', 0) @@ -61,8 +61,8 @@ def test_adding_htlcs_race(self): def test_single_htlc_full_lifecycle(self): def htlc_lifecycle(htlc_success: bool): - A = HTLCManager() - B = HTLCManager() + A = HTLCManager(StorageDict(None, '', {})) + B = HTLCManager(StorageDict(None, '', {})) A.channel_open_finished() B.channel_open_finished() B.recv_htlc(A.send_htlc(H('A', 0))) @@ -134,8 +134,8 @@ def htlc_lifecycle(htlc_success: bool): def test_remove_htlc_while_owing_commitment(self): def htlc_lifecycle(htlc_success: bool): - A = HTLCManager() - B = HTLCManager() + A = HTLCManager(StorageDict(None, '', {})) + B = HTLCManager(StorageDict(None, '', {})) A.channel_open_finished() B.channel_open_finished() ah0 = H('A', 0) @@ -171,8 +171,8 @@ def htlc_lifecycle(htlc_success: bool): htlc_lifecycle(htlc_success=False) def test_adding_htlc_between_send_ctx_and_recv_rev(self): - A = HTLCManager() - B = HTLCManager() + A = HTLCManager(StorageDict(None, '', {})) + B = HTLCManager(StorageDict(None, '', {})) A.channel_open_finished() B.channel_open_finished() A.send_ctx() @@ -217,8 +217,8 @@ def test_adding_htlc_between_send_ctx_and_recv_rev(self): self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE)) def test_unacked_local_updates(self): - A = HTLCManager() - B = HTLCManager() + A = HTLCManager(StorageDict(None, '', {})) + B = HTLCManager(StorageDict(None, '', {})) A.channel_open_finished() B.channel_open_finished() self.assertEqual({}, A.get_unacked_local_updates()) diff --git a/electrum/util.py b/electrum/util.py 
index c60a29770e16..7fee07e1da77 100644 --- a/electrum/util.py +++ b/electrum/util.py @@ -271,6 +271,8 @@ def default(self, obj): return obj.isoformat(' ')[:-3] if isinstance(obj, set): return list(obj) + if isinstance(obj, bytes): # for namedtuples in lnchannel + return obj.hex() if hasattr(obj, 'to_json') and callable(obj.to_json): return obj.to_json() return super(MyEncoder, self).default(obj)