-
Notifications
You must be signed in to change notification settings - Fork 3k
/
interface.py
1133 lines (985 loc) · 46.4 KB
/
interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python
#
# Electrum - lightweight Bitcoin client
# Copyright (C) 2011 thomasv@gitorious
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
import re
import ssl
import sys
import traceback
import asyncio
import socket
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any, Sequence, Dict
from collections import defaultdict
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
import itertools
import logging
import hashlib
import functools
import aiorpcx
from aiorpcx import TaskGroup
from aiorpcx import RPCSession, Notification, NetAddress, NewlineFramer
from aiorpcx.curio import timeout_after, TaskTimeout
from aiorpcx.jsonrpc import JSONRPC, CodeMessageError
from aiorpcx.rawsocket import RSClient
import certifi
from .util import (ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy,
is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
is_int_or_float, is_non_negative_int_or_float)
from . import util
from . import x509
from . import pem
from . import version
from . import blockchain
from .blockchain import Blockchain, HEADER_SIZE
from . import bitcoin
from . import constants
from .i18n import _
from .logging import Logger
from .transaction import Transaction
if TYPE_CHECKING:
from .network import Network
from .simple_config import SimpleConfig
# CA bundle shipped with certifi; used to verify CA-signed server certificates.
ca_path = certifi.where()

# Name of the server bucket for onion hosts.
BUCKET_NAME_OF_ONION_SERVERS = 'onion'

# Default upper bound (bytes) on one incoming message; the config key
# 'network_max_incoming_msg_size' overrides it (see NotificationSession.default_framer).
MAX_INCOMING_MSG_SIZE = 1000 * 1000

# Transport protocols understood by ServerAddr: 's' = SSL/TLS, 't' = plaintext TCP.
_KNOWN_NETWORK_PROTOCOLS = {'s', 't'}
PREFERRED_NETWORK_PROTOCOL = 's'
assert PREFERRED_NETWORK_PROTOCOL in _KNOWN_NETWORK_PROTOCOLS
class NetworkTimeout:
    """Timeout presets for network requests, in seconds.

    ``Generic`` is the default preset. ``Urgent`` carries shorter values for
    callers that cannot wait long (e.g. Interface.get_block_header, which
    usually runs while network.bhi_lock is held). Each preset has three
    tiers: NORMAL < RELAXED < MOST_RELAXED.
    """

    class Generic:
        NORMAL: int = 30
        RELAXED: int = 45
        MOST_RELAXED: int = 600

    class Urgent(Generic):
        NORMAL: int = 10
        RELAXED: int = 20
        MOST_RELAXED: int = 60
# Validators for data received from servers. Each one raises RequestCorrupted
# when its check fails and otherwise returns quietly.


def assert_non_negative_integer(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a non-negative integer."""
    if is_non_negative_integer(val):
        return
    raise RequestCorrupted(f'{val!r} should be a non-negative integer')


def assert_integer(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is an integer."""
    if is_integer(val):
        return
    raise RequestCorrupted(f'{val!r} should be an integer')


def assert_int_or_float(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is an int or a float."""
    if is_int_or_float(val):
        return
    raise RequestCorrupted(f'{val!r} should be int or float')


def assert_non_negative_int_or_float(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a non-negative int or float."""
    if is_non_negative_int_or_float(val):
        return
    raise RequestCorrupted(f'{val!r} should be a non-negative int or float')


def assert_hash256_str(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a hash256 string."""
    if is_hash256_str(val):
        return
    raise RequestCorrupted(f'{val!r} should be a hash256 str')


def assert_hex_str(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a hex string."""
    if is_hex_str(val):
        return
    raise RequestCorrupted(f'{val!r} should be a hex str')


def assert_dict_contains_field(d: Any, *, field_name: str) -> Any:
    """Return ``d[field_name]``.

    Raises RequestCorrupted if *d* is not a dict or lacks *field_name*.
    """
    if not isinstance(d, dict):
        raise RequestCorrupted(f'{d!r} should be a dict')
    if field_name in d:
        return d[field_name]
    raise RequestCorrupted(f'required field {field_name!r} missing from dict')


def assert_list_or_tuple(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a list or a tuple."""
    if isinstance(val, (list, tuple)):
        return
    raise RequestCorrupted(f'{val!r} should be a list or tuple')
class NotificationSession(RPCSession):
    """aiorpcx RPC session owned by an Interface.

    Adds debug logging and an optional per-call timeout to requests, and
    implements subscriptions: the latest result per (method, params) key is
    cached, and each subscriber queue receives ``params + [result]`` when it
    subscribes and ``request.args`` for every matching server notification.
    """

    def __init__(self, *args, interface: 'Interface', **kwargs):
        super().__init__(*args, **kwargs)
        # hashable (method, params) key -> list of subscriber queues
        self.subscriptions = defaultdict(list)
        # same key -> most recent result received for it
        self.cache = {}
        self.default_timeout = NetworkTimeout.Generic.NORMAL
        # only used to tag request/response lines in the debug log
        self._msg_counter = itertools.count(start=1)
        self.interface = interface
        self.cost_hard_limit = 0  # disable aiorpcx resource limits

    async def handle_request(self, request):
        """Handle a message initiated by the server.

        Only notifications for a subscribed key are accepted: the result is
        cached and the notification args are put on every subscriber queue.
        Anything else, or any error while handling, is logged and the session
        is closed.
        """
        self.maybe_log(f"--> {request}")
        try:
            if not isinstance(request, Notification):
                raise Exception('unexpected request. not a notification')
            params, result = request.args[:-1], request.args[-1]
            key = self.get_hashable_key_for_rpc_call(request.method, params)
            if key not in self.subscriptions:
                raise Exception('unexpected notification')
            self.cache[key] = result
            for subscriber in self.subscriptions[key]:
                await subscriber.put(request.args)
        except Exception as e:
            self.interface.logger.info(f"error handling request {request}. exc: {repr(e)}")
            await self.close()

    async def send_request(self, *args, timeout=None, **kwargs):
        """Send a request and return the server's response.

        Semaphores/timeouts/backpressure are handled by aiorpcx; *timeout*
        should normally be left unset. Raises RequestTimedOut on timeout and
        re-raises CodeMessageError (after logging it).
        """
        msg_id = next(self._msg_counter)
        self.maybe_log(f"<-- {args} {kwargs} (id: {msg_id})")
        try:
            # RPCSession.send_request raises TaskTimeout on a timeout. TaskTimeout
            # subclasses CancelledError, which is *suppressed* in TaskGroups,
            # hence the translation to RequestTimedOut.
            response = await asyncio.wait_for(super().send_request(*args, **kwargs), timeout)
        except (TaskTimeout, asyncio.TimeoutError) as e:
            raise RequestTimedOut(f'request timed out: {args} (id: {msg_id})') from e
        except CodeMessageError as e:
            self.maybe_log(f"--> {repr(e)} (id: {msg_id})")
            raise
        self.maybe_log(f"--> {response} (id: {msg_id})")
        return response

    def set_default_timeout(self, timeout):
        """Use *timeout* for aiorpcx's sent_request_timeout and max_send_delay."""
        self.sent_request_timeout = timeout
        self.max_send_delay = timeout

    async def subscribe(self, method: str, params: List, queue: asyncio.Queue):
        """Register *queue* for notifications of (method, params).

        Immediately puts ``params + [result]`` on the queue. The result comes
        from the cache when possible; until the cache holds this key, every
        call may issue a network request.
        """
        key = self.get_hashable_key_for_rpc_call(method, params)
        self.subscriptions[key].append(queue)
        if key not in self.cache:
            self.cache[key] = await self.send_request(method, params)
        result = self.cache[key]
        await queue.put(params + [result])

    def unsubscribe(self, queue):
        """Unsubscribe a callback to free object references to enable GC."""
        # note: we can't unsubscribe from the server, so we keep receiving
        # subsequent notifications
        for subscribers in self.subscriptions.values():
            if queue in subscribers:
                subscribers.remove(queue)

    @classmethod
    def get_hashable_key_for_rpc_call(cls, method, params):
        """Hashable index for subscriptions and cache"""
        return f"{method!s}{params!r}"

    def maybe_log(self, msg: str) -> None:
        """Log *msg* at debug level if this interface or the network has debug enabled."""
        iface = self.interface
        if iface and (iface.debug or iface.network.debug):
            iface.logger.debug(msg)

    def default_framer(self):
        # overridden so that max_size can be customized via the config
        config = self.interface.network.config
        max_size = int(config.get('network_max_incoming_msg_size', MAX_INCOMING_MSG_SIZE))
        return NewlineFramer(max_size=max_size)
class NetworkException(Exception):
    """Base class for network-level errors raised by this module."""


class GracefulDisconnect(NetworkException):
    """Raised to close an interface cleanly; handle_disconnect logs it at ``log_level``."""

    log_level = logging.INFO

    def __init__(self, *args, log_level=None, **kwargs):
        Exception.__init__(self, *args, **kwargs)
        if log_level is not None:
            self.log_level = log_level


class RequestTimedOut(GracefulDisconnect):
    """A request to the server did not complete in time."""

    def __str__(self):
        return _("Network request timed out.")


class RequestCorrupted(Exception):
    """The server sent a malformed or inconsistent response."""


class ErrorParsingSSLCert(Exception):
    """A previously pinned certificate on disk could not be parsed."""


class ErrorGettingSSLCertFromServer(Exception):
    """Obtaining the certificate from the server on first contact failed."""


class ErrorSSLCertFingerprintMismatch(Exception):
    """The server certificate does not match the configured fingerprint."""


class InvalidOptionCombination(Exception):
    """Mutually incompatible options were used (e.g. --serverfingerprint with a CA-signed server)."""


class ConnectError(NetworkException):
    """Opening the connection failed; the underlying OSError is its __cause__."""
class _RSClient(RSClient):
    """RSClient whose connection failures surface as ConnectError."""

    async def create_connection(self):
        try:
            connection = await super().create_connection()
        except OSError as exc:
            # chaining keeps the original OSError available as ConnectError.__cause__
            raise ConnectError(exc) from exc
        return connection
class ServerAddr:
    """Address of an Electrum server: host, port and transport protocol.

    ``protocol`` is ``'s'`` (SSL/TLS) or ``'t'`` (plaintext TCP). Host and port
    are validated and canonicalised through aiorpcx's NetAddress.
    """

    def __init__(self, host: str, port: Union[int, str], *, protocol: str = None):
        """Validate and store the address.

        A bracketed IPv6 host (``[::1]``) is accepted. *protocol* defaults to 's'.

        :raises ValueError: if host is empty, host/port are invalid, or the
            protocol is not one of _KNOWN_NETWORK_PROTOCOLS.
        """
        assert isinstance(host, str), repr(host)
        if protocol is None:
            protocol = 's'
        if not host:
            raise ValueError('host must not be empty')
        if host[0] == '[' and host[-1] == ']':  # IPv6
            host = host[1:-1]
        try:
            net_addr = NetAddress(host, port)  # this validates host and port
        except Exception as e:
            raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
        if protocol not in _KNOWN_NETWORK_PROTOCOLS:
            raise ValueError(f"invalid network protocol: {protocol}")
        self.host = str(net_addr.host)  # canonical form (if e.g. IPv6 address)
        self.port = int(net_addr.port)
        self.protocol = protocol
        self._net_addr_str = str(net_addr)

    @classmethod
    def from_str(cls, s: str) -> 'ServerAddr':
        """Parse the strict ``host:port:protocol`` form, as produced by ``str()``.

        :raises ValueError: if *s* has fewer than three ':'-separated parts or
            any part is invalid.
        """
        # host might be IPv6 address, hence do rsplit:
        host, port, protocol = str(s).rsplit(':', 2)
        return ServerAddr(host=host, port=port, protocol=protocol)

    @classmethod
    def from_str_with_inference(cls, s: str) -> Optional['ServerAddr']:
        """Construct ServerAddr from str, guessing missing details.

        Accepts ``host:port`` or ``host:port:protocol``. A missing protocol
        defaults to PREFERRED_NETWORK_PROTOCOL. Bracketed IPv6 hosts such as
        ``[::1]:50002`` are understood, so the output of to_friendly_name()
        for IPv6 servers round-trips. Returns None if *s* is empty or has no
        port; raises ValueError if a part is invalid.
        Ongoing compatibility not guaranteed.
        """
        parts = cls._split_str_for_inference(s)
        if parts is None:
            return None
        host, port, protocol = parts
        if protocol is None:
            protocol = PREFERRED_NETWORK_PROTOCOL
        return ServerAddr(host=host, port=port, protocol=protocol)

    @staticmethod
    def _split_str_for_inference(s) -> Optional[Tuple[str, str, Optional[str]]]:
        """Split *s* into (host, port, protocol-or-None); None if there is no port.

        The parts are not validated here; ServerAddr.__init__ does that.
        """
        if not s:
            return None
        s = str(s)
        if s.startswith('['):
            end = s.find(']')
            rest = s[end + 1:] if end != -1 else None
            # Only treat it as a bracketed IPv6 literal if the bracket is
            # followed by nothing or by ':'; otherwise fall back to the
            # generic parsing below.
            if rest is not None and (rest == '' or rest.startswith(':')):
                if not rest:
                    return None  # no port
                port, sep, protocol = rest[1:].partition(':')
                return s[1:end], port, (protocol if sep else None)
        # host might be an (unbracketed) IPv6 address, hence do rsplit:
        items = s.rsplit(':', 2)
        if len(items) < 2:
            return None  # although maybe we could guess the port too?
        protocol = items[2] if len(items) >= 3 else None
        return items[0], items[1], protocol

    def to_friendly_name(self) -> str:
        """Like ``str(self)`` but without the trailing ':s' for SSL servers."""
        # note: this method is closely linked to from_str_with_inference
        if self.protocol == 's':  # hide trailing ":s"
            return self.net_addr_str()
        return str(self)

    def __str__(self):
        return '{}:{}'.format(self.net_addr_str(), self.protocol)

    def to_json(self) -> str:
        return str(self)

    def __repr__(self):
        return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'

    def net_addr_str(self) -> str:
        """``host:port`` as rendered by aiorpcx's NetAddress."""
        return self._net_addr_str

    def __eq__(self, other):
        if not isinstance(other, ServerAddr):
            return False
        return (self.host == other.host
                and self.port == other.port
                and self.protocol == other.protocol)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        return hash((self.host, self.port, self.protocol))
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
filename = host
try:
ip = ip_address(host)
except ValueError:
pass
else:
if isinstance(ip, IPv6Address):
filename = f"ipv6_{ip.packed.hex()}"
return os.path.join(config.path, 'certs', filename)
class Interface(Logger):
LOGGING_SHORTCUT = 'i'
def __init__(self, *, network: 'Network', server: ServerAddr, proxy: Optional[dict]):
    """Set up per-server state and schedule ``self.run()`` on the network's event loop.

    :param network: owning Network; ``network.config.path`` must be set
        (asserted), since pinned certificates are stored under it.
    :param server: the server this interface connects to.
    :param proxy: proxy settings dict or None (passed to MySocksProxy.from_proxy_dict).
    """
    # resolved in _mark_ready once the first tip header is known;
    # cancel()-ed by handle_disconnect when the connection goes down
    self.ready = asyncio.Future()
    # set by handle_disconnect
    self.got_disconnected = asyncio.Event()
    self.server = server
    # NOTE(review): Logger.__init__ runs only after self.server is set,
    # presumably because it uses diagnostic_name(), which reads self.server.
    Logger.__init__(self)
    assert network.config.path
    self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
    self.blockchain = None  # type: Optional[Blockchain]
    # chunk indices with a 'blockchain.block.headers' request in flight (see request_chunk)
    self._requested_chunks = set()  # type: Set[int]
    self.network = network
    self.proxy = MySocksProxy.from_proxy_dict(proxy)
    self.session = None  # type: Optional[NotificationSession]
    self._ipaddr_bucket = None
    # Latest block header and corresponding height, as claimed by the server.
    # Note that these values are updated before they are verified.
    # Especially during initial header sync, verification can take a long time.
    # Failing verification will get the interface closed.
    self.tip_header = None
    self.tip = 0
    # block target -> fee estimate; filled by request_fee_estimates
    self.fee_estimates_eta = {}  # type: Dict[int, int]
    # Dump network messages (only for this interface). Set at runtime from the console.
    self.debug = False
    self.taskgroup = SilentTaskGroup()
    async def spawn_task():
        task = await self.network.taskgroup.spawn(self.run())
        if sys.version_info >= (3, 8):
            # Task.set_name is only available from Python 3.8
            task.set_name(f"interface::{str(server)}")
    # run_coroutine_threadsafe lets this constructor be called from a thread
    # other than the one running network.asyncio_loop
    asyncio.run_coroutine_threadsafe(spawn_task(), self.network.asyncio_loop)
@property
def host(self):
    """Host of the server, in the canonical form stored by ServerAddr."""
    server = self.server
    return server.host
@property
def port(self):
    """TCP port of the server (an int, as stored by ServerAddr)."""
    server = self.server
    return server.port
@property
def protocol(self):
    """Transport protocol of the server: 's' (SSL/TLS) or 't' (plaintext TCP)."""
    server = self.server
    return server.protocol
def diagnostic_name(self):
    """Short identifier for this interface: the server's ``host:port`` string."""
    server = self.server
    return server.net_addr_str()
def __str__(self):
    return "<Interface {}>".format(self.diagnostic_name())
async def is_server_ca_signed(self, ca_ssl_context):
    """Probe the server using a CA-enforcing SSL context.

    Returns True if the connection can be established. Returns False if the
    handshake failed only with CERTIFICATE_VERIFY_FAILED (e.g. a self-signed
    certificate). Any other failure is re-raised.
    """
    try:
        await self.open_session(ca_ssl_context, exit_early=True)
    except ConnectError as e:
        cause = e.__cause__
        verify_failed = isinstance(cause, ssl.SSLError) and cause.reason == 'CERTIFICATE_VERIFY_FAILED'
        if not verify_failed:
            raise
        # failures due to self-signed certs are normal
        return False
    return True
async def _try_saving_ssl_cert_for_first_time(self, ca_ssl_context):
    """Create the certificate record at self.cert_path for a new server.

    CA-signed servers get an empty marker file. Otherwise the server's own
    certificate is fetched and pinned via _save_certificate(). Raises
    InvalidOptionCombination if a fingerprint is configured for a CA-signed
    server.
    """
    if not await self.is_server_ca_signed(ca_ssl_context):
        await self._save_certificate()
        return
    if self._get_expected_fingerprint():
        raise InvalidOptionCombination("cannot use --serverfingerprint with CA signed servers")
    with open(self.cert_path, 'w') as f:
        # empty file means this is CA signed, not self-signed
        f.write('')
def _is_saved_ssl_cert_available(self):
    """Tell whether a usable certificate record exists at self.cert_path.

    An empty file marks a CA-signed server. A non-empty file holds a pinned
    PEM certificate. Returns False if there is no file, or if the pinned
    certificate fails its date check (the file is then deleted).

    Raises:
        InvalidOptionCombination: a fingerprint is configured for a CA-signed server.
        ErrorParsingSSLCert: the pinned certificate cannot be parsed.
        ErrorSSLCertFingerprintMismatch: via _verify_certificate_fingerprint.
    """
    if not os.path.exists(self.cert_path):
        return False
    with open(self.cert_path, 'r') as f:
        pem_text = f.read()
    is_ca_signed_marker = (pem_text == '')
    if is_ca_signed_marker:
        if self._get_expected_fingerprint():
            raise InvalidOptionCombination("cannot use --serverfingerprint with CA signed servers")
        return True
    # otherwise: a pinned self-signed certificate
    try:
        der = pem.dePem(pem_text, 'CERTIFICATE')
    except SyntaxError as e:
        self.logger.info(f"error parsing already saved cert: {e}")
        raise ErrorParsingSSLCert(e) from e
    try:
        cert = x509.X509(der)
    except Exception as e:
        self.logger.info(f"error parsing already saved cert: {e}")
        raise ErrorParsingSSLCert(e) from e
    try:
        cert.check_date()
    except x509.CertificateError as e:
        self.logger.info(f"certificate has expired: {e}")
        os.unlink(self.cert_path)  # delete pinned cert only in this case
        return False
    self._verify_certificate_fingerprint(bytearray(der))
    return True
async def _get_ssl_context(self):
    """Build the SSL context used to connect to this server.

    Returns None for plaintext servers (protocol != 's'). On first contact
    a certificate record is created at self.cert_path. An empty record
    (CA-signed) selects a context backed by the certifi bundle. Otherwise
    the pinned certificate is the only trust anchor and hostname checking
    is disabled. Raises ErrorGettingSSLCertFromServer if the first contact
    fails.
    """
    if self.protocol != 's':
        return None  # using plaintext TCP
    ca_sslc = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_path)
    if not self._is_saved_ssl_cert_available():
        try:
            await self._try_saving_ssl_cert_for_first_time(ca_sslc)
        except (OSError, ConnectError, aiorpcx.socks.SOCKSError) as e:
            raise ErrorGettingSSLCertFromServer(e) from e
    # by now a record for this server exists in our certificate store
    if os.stat(self.cert_path).st_size == 0:
        return ca_sslc  # CA signed cert
    # pinned self-signed cert
    pinned_sslc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=self.cert_path)
    pinned_sslc.check_hostname = False
    return pinned_sslc
def handle_disconnect(func):
    """Decorator for Interface coroutines that turns disconnects into cleanup.

    GracefulDisconnect is logged at its own ``log_level``. RPCError is logged
    at WARNING, with the traceback at DEBUG. Both are swallowed, so the
    wrapped call then returns None; any other exception propagates. In every
    case the finally-block sets ``got_disconnected``, awaits
    ``network.connection_down(self)`` and cancels ``self.ready``.
    """
    @functools.wraps(func)
    async def wrapper_func(self: 'Interface', *args, **kwargs):
        try:
            return await func(self, *args, **kwargs)
        except GracefulDisconnect as e:
            self.logger.log(e.log_level, f"disconnecting due to {repr(e)}")
        except aiorpcx.jsonrpc.RPCError as e:
            self.logger.warning(f"disconnecting due to {repr(e)}")
            self.logger.debug(f"(disconnect) trace for {repr(e)}", exc_info=True)
        finally:
            self.got_disconnected.set()
            await self.network.connection_down(self)
            # if was not 'ready' yet, schedule waiting coroutines:
            # (cancel() is a no-op if the future was already resolved)
            self.ready.cancel()
    return wrapper_func
@ignore_exceptions  # do not kill network.taskgroup
@log_exceptions
@handle_disconnect
async def run(self):
    """Main task of the interface: prepare SSL, then open and serve the session.

    Certificate and connection failures are logged and end the task without
    raising; the handle_disconnect wrapper then performs the disconnect cleanup.
    """
    try:
        ssl_context = await self._get_ssl_context()
    except (ErrorParsingSSLCert, ErrorGettingSSLCertFromServer) as e:
        self.logger.info(f'disconnecting due to: {repr(e)}')
        return
    try:
        await self.open_session(ssl_context)
    except (asyncio.CancelledError, ConnectError, aiorpcx.socks.SOCKSError) as e:
        # make SSL errors for the main interface more visible
        # (to help server ops debug cert pinning issues)
        ssl_failure = isinstance(e, ConnectError) and isinstance(e.__cause__, ssl.SSLError)
        if ssl_failure and self.is_main_server() and not self.network.auto_connect:
            self.logger.warning(f'Cannot connect to main server due to SSL error '
                                f'(maybe cert changed compared to "{self.cert_path}"). Exc: {repr(e)}')
        else:
            self.logger.info(f'disconnecting due to: {repr(e)}')
def _mark_ready(self) -> None:
    """Resolve ``self.ready`` once a tip header is known.

    Selects self.blockchain: the chain the tip header checks against, or
    else the best known chain. Raises GracefulDisconnect if ``ready`` was
    already cancelled, and does nothing if it is already resolved.
    """
    if self.ready.cancelled():
        raise GracefulDisconnect('conn establishment was too slow; *ready* future was cancelled')
    if self.ready.done():
        return
    assert self.tip_header
    self.blockchain = blockchain.check_header(self.tip_header) or blockchain.get_best_chain()
    assert self.blockchain is not None
    self.logger.info(f"set blockchain with height {self.blockchain.height()}")
    self.ready.set_result(1)
async def _save_certificate(self) -> None:
    """Fetch the server's certificate and pin it (PEM) at self.cert_path.

    Does nothing if that file already exists. Tries up to 10 times, one
    second apart, and raises GracefulDisconnect if no certificate was
    obtained. The certificate is checked by _verify_certificate_fingerprint
    before it is written.
    """
    if os.path.exists(self.cert_path):
        return
    # we may need to retry this a few times, in case the handshake hasn't completed
    for _attempt in range(10):
        dercert = await self._fetch_certificate()
        if dercert:
            break
        await asyncio.sleep(1)
    else:
        raise GracefulDisconnect("could not get certificate after 10 tries")
    self.logger.info("succeeded in getting cert")
    self._verify_certificate_fingerprint(dercert)
    with open(self.cert_path, 'w') as f:
        cert = ssl.DER_cert_to_PEM_cert(dercert)
        # workaround android bug
        cert = re.sub("([^\n])-----END CERTIFICATE-----", "\\1\n-----END CERTIFICATE-----", cert)
        f.write(cert)
        # close() flushes, but fsync needs an open fd: flush moves Python's
        # buffer to the OS, then fsync writes the OS buffer to disk
        f.flush()
        os.fsync(f.fileno())
async def _fetch_certificate(self) -> Optional[bytes]:
    """Connect without certificate validation and return the server's cert.

    The certificate is returned in DER form, as given by
    ``SSLObject.getpeercert(binary_form=True)``; that call may return None.
    Validation is intentionally off because this is how a (possibly
    self-signed) certificate is obtained for pinning. The caller
    (_save_certificate) checks it against the configured fingerprint, if any.
    """
    # ssl.SSLContext() without a protocol argument is deprecated since
    # Python 3.10. PROTOCOL_TLS_CLIENT with hostname checking and cert
    # verification turned off reproduces the old context's behaviour.
    # check_hostname must be cleared before verify_mode can be CERT_NONE.
    sslc = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    sslc.check_hostname = False
    sslc.verify_mode = ssl.CERT_NONE
    async with _RSClient(session_factory=RPCSession,
                         host=self.host, port=self.port,
                         ssl=sslc, proxy=self.proxy) as session:
        asyncio_transport = session.transport._asyncio_transport  # type: asyncio.BaseTransport
        ssl_object = asyncio_transport.get_extra_info("ssl_object")  # type: ssl.SSLObject
        return ssl_object.getpeercert(binary_form=True)
def _get_expected_fingerprint(self) -> Optional[str]:
    """Return the configured 'serverfingerprint' for the main server; None otherwise."""
    if not self.is_main_server():
        return None
    return self.network.config.get("serverfingerprint")
def _verify_certificate_fingerprint(self, certificate):
    """Compare the SHA-256 of *certificate* with the expected fingerprint.

    No-op when no fingerprint is expected. The hex comparison is
    case-insensitive. On mismatch, fires the 'cert_mismatch' callback and
    raises ErrorSSLCertFingerprintMismatch.
    """
    expected = self._get_expected_fingerprint()
    if not expected:
        return
    actual = hashlib.sha256(certificate).hexdigest()
    if actual.lower() != expected.lower():
        util.trigger_callback('cert_mismatch')
        raise ErrorSSLCertFingerprintMismatch('Refusing to connect to server due to cert fingerprint mismatch')
    self.logger.info("cert fingerprint verification passed")
async def get_block_header(self, height, assert_mode):
    """Request the header at *height* and return it deserialized.

    *assert_mode* only appears in the log line. The Urgent timeout is used
    because callers usually hold network.bhi_lock.
    """
    self.logger.info(f'requesting block header {height} in mode {assert_mode}')
    urgent_timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent)
    raw_hex = await self.session.send_request('blockchain.block.header', [height], timeout=urgent_timeout)
    return blockchain.deserialize_header(bytes.fromhex(raw_hex), height)
async def request_chunk(self, height: int, tip=None, *, can_return_early=False):
    """Download the 2016-header chunk containing *height* and connect it.

    :param tip: if given, the request is shortened so it does not go past *tip*.
    :param can_return_early: return None if the same chunk is already being requested.
    :returns: ``(connect_chunk result, header count)``; the count is 0 when
        the chunk could not be connected.
    :raises RequestCorrupted: if the server's reply is malformed or inconsistent.
    """
    if not is_non_negative_integer(height):
        raise Exception(f"{repr(height)} is not a block height")
    index = height // 2016
    if can_return_early and index in self._requested_chunks:
        return
    self.logger.info(f"requesting chunk from height {height}")
    first_height = index * 2016
    size = 2016 if tip is None else max(min(2016, tip - first_height + 1), 0)
    self._requested_chunks.add(index)
    try:
        res = await self.session.send_request('blockchain.block.headers', [first_height, size])
    finally:
        self._requested_chunks.discard(index)
    for field in ('count', 'hex', 'max'):
        assert_dict_contains_field(res, field_name=field)
    assert_non_negative_integer(res['count'])
    assert_non_negative_integer(res['max'])
    assert_hex_str(res['hex'])
    if len(res['hex']) != HEADER_SIZE * 2 * res['count']:
        raise RequestCorrupted('inconsistent chunk hex and count')
    # we never request more than 2016 headers, but we enforce those fit in a single response
    if res['max'] < 2016:
        raise RequestCorrupted(f"server uses too low 'max' count for block.headers: {res['max']} < 2016")
    if res['count'] != size:
        raise RequestCorrupted(f"expected {size} headers but only got {res['count']}")
    conn = self.blockchain.connect_chunk(index, res['hex'])
    return conn, (res['count'] if conn else 0)
def is_main_server(self) -> bool:
    """True if this is the network's main interface.

    While no main interface is set, it is also True if this interface's
    server is the network's default server.
    """
    network = self.network
    if network.interface == self:
        return True
    return network.interface is None and network.default_server == self.server
async def open_session(self, sslc, exit_early=False):
    """Connect, negotiate the protocol version, then run the background tasks.

    :param sslc: SSL context, or None for plaintext TCP.
    :param exit_early: return right after the 'server.version' exchange
        (used by is_server_ca_signed to probe the certificate).

    Raises GracefulDisconnect in these cases:
    - version negotiation fails;
    - the network rejects this server for its IP bucket;
    - the background tasks fail with EXCESSIVE_RESOURCE_USAGE, SERVER_BUSY
      or METHOD_NOT_FOUND (logged at WARNING).
    Other RPC errors from the tasks propagate.
    """
    # NotificationSession needs a back-reference to this interface
    session_factory = lambda *args, iface=self, **kwargs: NotificationSession(*args, **kwargs, interface=iface)
    async with _RSClient(session_factory=session_factory,
                         host=self.host, port=self.port,
                         ssl=sslc, proxy=self.proxy) as session:
        self.session = session  # type: NotificationSession
        self.session.set_default_timeout(self.network.get_network_timeout_seconds(NetworkTimeout.Generic))
        try:
            ver = await session.send_request('server.version', [self.client_name(), version.PROTOCOL_VERSION])
        except aiorpcx.jsonrpc.RPCError as e:
            raise GracefulDisconnect(e)  # probably 'unsupported protocol version'
        if exit_early:
            return
        if ver[1] != version.PROTOCOL_VERSION:
            raise GracefulDisconnect(f'server violated protocol-version-negotiation. '
                                     f'we asked for {version.PROTOCOL_VERSION!r}, they sent {ver[1]!r}')
        if not self.network.check_interface_against_healthy_spread_of_connected_servers(self):
            raise GracefulDisconnect(f'too many connected servers already '
                                     f'in bucket {self.bucket_based_on_ipaddress()}')
        self.logger.info(f"connection established. version: {ver}")
        try:
            # The spawned tasks all loop forever, so this block only ends
            # when the task group itself ends (e.g. a task raises).
            async with self.taskgroup as group:
                await group.spawn(self.ping)
                await group.spawn(self.request_fee_estimates)
                await group.spawn(self.run_fetch_blocks)
                await group.spawn(self.monitor_connection)
        except aiorpcx.jsonrpc.RPCError as e:
            if e.code in (JSONRPC.EXCESSIVE_RESOURCE_USAGE,
                          JSONRPC.SERVER_BUSY,
                          JSONRPC.METHOD_NOT_FOUND):
                raise GracefulDisconnect(e, log_level=logging.WARNING) from e
            raise
async def monitor_connection(self):
    """Check the session once per second; raise GracefulDisconnect once it is gone or closing."""
    while True:
        await asyncio.sleep(1)
        session = self.session
        if not session or session.is_closing():
            raise GracefulDisconnect('session was closed')
async def ping(self):
    """Send 'server.ping' every 300 seconds, forever (presumably as a keep-alive)."""
    ping_interval = 300  # seconds
    while True:
        await asyncio.sleep(ping_interval)
        await self.session.send_request('server.ping')
async def request_fee_estimates(self):
    """Every 60 seconds, fetch fee estimates for all FEE_ETA_TARGETS concurrently.

    Non-negative results are stored in self.fee_estimates_eta, keyed by
    block target; negative ones are skipped. After each round,
    network.update_fee_estimates() is called.
    """
    from .simple_config import FEE_ETA_TARGETS
    while True:
        async with TaskGroup() as group:
            fee_tasks = [(target, await group.spawn(self.get_estimatefee(target)))
                         for target in FEE_ETA_TARGETS]
        # all tasks have finished once the task group exits
        for nblock_target, task in fee_tasks:
            fee = task.result()
            if fee < 0:
                continue
            assert isinstance(fee, int)
            self.fee_estimates_eta[nblock_target] = fee
        self.network.update_fee_estimates()
        await asyncio.sleep(60)
async def close(self, *, force_after: int = None):
    """Close the connection and wait for it to be closed.

    Buffered data is flushed to the wire first, so this can take some time.
    After *force_after* seconds (default 1) the connection is aborted instead.
    """
    if force_after is None:
        # Give up after a while and just abort the connection. Specifically,
        # if the server runs Fulcrum, waiting seems hopeless and the connection
        # must be aborted (see https://github.com/cculianu/Fulcrum/issues/76).
        force_after = 1  # seconds
    if self.session:
        await self.session.close(force_after=force_after)
    # monitor_connection will cancel tasks
async def run_fetch_blocks(self):
    """Follow the server's chain tip via 'blockchain.headers.subscribe'.

    For each tip notification:
    - record self.tip / self.tip_header;
    - disconnect if the tip is below the max checkpoint;
    - mark the interface ready and process the header (_process_header_at_tip);
    - fire the update callbacks and let the network reconsider which interface to use.
    """
    header_queue = asyncio.Queue()
    await self.session.subscribe('blockchain.headers.subscribe', [], header_queue)
    while True:
        raw_header = (await header_queue.get())[0]
        height = raw_header['height']
        self.tip_header = blockchain.deserialize_header(bfh(raw_header['hex']), height)
        self.tip = height
        if self.tip < constants.net.max_checkpoint():
            raise GracefulDisconnect('server tip below max checkpoint')
        self._mark_ready()
        await self._process_header_at_tip()
        # header processing done
        for event in ('blockchain_updated', 'network_updated'):
            util.trigger_callback(event)
        await self.network.switch_unwanted_fork_interface()
        await self.network.switch_lagging_interface()
async def _process_header_at_tip(self):
    """Integrate the current tip header, syncing any gap, while holding network.bhi_lock."""
    tip_height, tip_header = self.tip, self.tip_header
    async with self.network.bhi_lock:
        already_known = (self.blockchain.height() >= tip_height
                         and self.blockchain.check_header(tip_header))
        if already_known:
            # another interface amended the blockchain
            self.logger.info(f"skipping header {tip_height}")
            return
        _, next_height = await self.step(tip_height, tip_header)
        # in the simple case, next_height == self.tip+1
        if next_height <= self.tip:
            await self.sync_until(next_height)
async def sync_until(self, height, next_height=None):
    """Catch up from *height* to *next_height* (default: self.tip).

    While more than 10 headers remain, whole chunks are downloaded
    (request_chunk). Otherwise, or when a chunk does not connect, headers
    are processed one at a time (step). Returns the last (kind, height)
    pair. Raises GracefulDisconnect if a chunk at or below the max
    checkpoint does not connect.
    """
    if next_height is None:
        next_height = self.tip
    last = None
    while last is None or height <= next_height:
        # remembered so that an iteration making no progress can be detected below
        prev_last, prev_height = last, height
        if next_height > height + 10:
            could_connect, num_headers = await self.request_chunk(height, next_height)
            if not could_connect:
                if height <= constants.net.max_checkpoint():
                    raise GracefulDisconnect('server chain conflicts with checkpoints or genesis')
                # fall back to header-by-header processing
                # (note: 'continue' skips the progress assertion at the bottom)
                last, height = await self.step(height)
                continue
            util.trigger_callback('network_updated')
            # advance to just past the last header of the connected chunk
            height = (height // 2016 * 2016) + num_headers
            assert height <= next_height+1, (height, self.tip)
            last = 'catchup'
        else:
            last, height = await self.step(height)
        assert (prev_last, prev_height) != (last, height), 'had to prevent infinite loop in interface.sync_until'
    return last, height
async def step(self, height, header=None):
    """Process a single header at *height* (fetched when *header* is None).

    Returns a (kind, next_height) pair:
    - ('catchup', height+1) if the header checks against a known chain;
    - ('catchup', h+1) once a header at some height h (found by searching
      backwards if needed) can be connected;
    - otherwise, the result of binary-searching the fork point and
      _resolve_potential_chain_fork_given_forkpoint.
    Headers carrying a 'mock' entry (presumably test fixtures) supply their
    own check/connect/fork callables instead of the blockchain module.
    """
    assert 0 <= height <= self.tip, (height, self.tip)
    if header is None:
        header = await self.get_block_header(height, 'catchup')
    # presumably: is this header already part of a chain we know?
    chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
    if chain:
        self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
        # note: there is an edge case here that is not handled.
        # we might know the blockhash (enough for check_header) but
        # not have the header itself. e.g. regtest chain with only genesis.
        # this situation resolves itself on the next block
        return 'catchup', height+1
    can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
    if not can_connect:
        self.logger.info(f"can't connect {height}")
        # walk back until a header checks or connects; bad/bad_header is the
        # last probe that did neither (they are only defined on this path,
        # which is also the only path that reaches the binary search below)
        height, header, bad, bad_header = await self._search_headers_backwards(height, header)
        chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
        can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
        assert chain or can_connect
    if can_connect:
        self.logger.info(f"could connect {height}")
        height += 1
        if isinstance(can_connect, Blockchain):  # not when mocking
            self.blockchain = can_connect
            self.blockchain.save_header(header)
        return 'catchup', height
    # the header at 'height' checks but the one at 'bad' does not: find the fork point
    good, bad, bad_header = await self._search_headers_binary(height, bad, bad_header, chain)
    return await self._resolve_potential_chain_fork_given_forkpoint(good, bad, bad_header)
async def _search_headers_binary(self, height, bad, bad_header, chain):
    """Binary-search the fork point between *height* and *bad*.

    Invariant: the header at 'good' (initially *height*) checks against a
    known chain, and the one at 'bad' does not. The range is narrowed
    until good + 1 == bad. Then bad_header must be connectable onto
    self.blockchain; otherwise a plain Exception is raised.

    :returns: (good, bad, bad_header)
    """
    assert bad == bad_header['block_height']
    _assert_header_does_not_check_against_any_chain(bad_header)
    self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
    good = height
    while True:
        assert good < bad, (good, bad)
        height = (good + bad) // 2
        self.logger.info(f"binary step. good {good}, bad {bad}, height {height}")
        header = await self.get_block_header(height, 'binary')
        chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
        if chain:
            self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
            good = height
        else:
            bad = height
            bad_header = header
        if good + 1 == bad:
            break
    # NOTE(review): the mock connect callback receives the last *probed* height,
    # which at this point equals either good or bad -- confirm this is intended.
    mock = 'mock' in bad_header and bad_header['mock']['connect'](height)
    real = not mock and self.blockchain.can_connect(bad_header, check_height=False)
    if not real and not mock:
        raise Exception('unexpected bad header during binary: {}'.format(bad_header))
    _assert_header_does_not_check_against_any_chain(bad_header)
    self.logger.info(f"binary search exited. good {good}, bad {bad}")
    return good, bad, bad_header
async def _resolve_potential_chain_fork_given_forkpoint(self, good, bad, bad_header):
    """Decide between plain catch-up and a new fork once the forkpoint is known.

    Preconditions: `bad` is exactly `good` + 1, `bad_header` is the header at `bad`,
    and it does not check against any local chain. The header at `good` lives in
    self.blockchain; `bad_header` connects to it but is not in self.blockchain.

    :returns: ('no_fork', good + 1) if self.blockchain ends exactly at `good`;
        otherwise ('fork', bad + 1) after switching self.blockchain to a new fork
        created from `bad_header` (whose forkpoint must equal `bad`).
    """
    assert good + 1 == bad
    assert bad == bad_header['block_height']
    _assert_header_does_not_check_against_any_chain(bad_header)
    local_tip = self.blockchain.height()
    assert local_tip >= good, (local_tip, good)
    if local_tip != good:
        # Our chain extends past `good` along a different branch: bad_header
        # starts a fork that we do not have yet.
        next_height = bad + 1
        self.logger.info(f"new fork at bad height {bad}")
        if 'mock' in bad_header:
            make_fork = bad_header['mock']['fork']
        else:
            make_fork = self.blockchain.fork
        forked = make_fork(bad_header)  # type: Blockchain
        self.blockchain = forked
        assert forked.forkpoint == bad
        return 'fork', next_height
    # Our chain ends right at `good`, so we can simply continue from there.
    next_height = good + 1
    self.logger.info(f"catching up from {next_height}")
    return 'no_fork', next_height
async def _search_headers_backwards(self, height, header):
    """Walk backwards from `height` until a server header checks against or connects to a local chain.

    The given (height, header) is taken as the initial known-bad point. The first
    probe is at min(local_max + 1, height - 1), where local_max is the greatest
    height among the local blockchains (inf in mock mode). After each failed probe
    the distance from the server tip is doubled: height = tip - 2 * (tip - height).
    A probe at or below the max checkpoint is clamped to the checkpoint height.

    :returns: (height, header, bad, bad_header): the last probed height/header
        (which checks against or connects to a local chain) and the most recent
        height/header that did not.
    :raises GracefulDisconnect: if the header at the max checkpoint height neither
        checks against nor connects to any local chain.
    """
    async def iterate():
        """Probe the header at the current `height`; return True to keep walking back.

        Updates the enclosing `height` (clamping it up to the max checkpoint) and
        `header` (the freshly fetched header). Returns False as soon as that
        header checks against or connects to a local chain.
        """
        nonlocal height, header
        checkp = False
        if height <= constants.net.max_checkpoint():
            height = constants.net.max_checkpoint()
            checkp = True
        header = await self.get_block_header(height, 'backward')
        # Test hook: headers carrying a 'mock' entry supply their own callbacks.
        chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
        can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
        if chain or can_connect:
            return False
        if checkp:
            # Nothing further back to try: the server disagrees with our checkpoint.
            raise GracefulDisconnect("server chain conflicts with checkpoints")
        return True
    bad, bad_header = height, header
    _assert_header_does_not_check_against_any_chain(bad_header)
    # Snapshot the known chains under the lock, then start just above our highest
    # local header (or one below `height`, whichever is lower).
    with blockchain.blockchains_lock: chains = list(blockchain.blockchains.values())
    local_max = max([0] + [x.height() for x in chains]) if 'mock' not in header else float('inf')
    height = min(local_max + 1, height - 1)
    while await iterate():
        bad, bad_header = height, header
        # Exponential back-off: double the distance from the server tip each round.
        delta = self.tip - height
        height = self.tip - 2 * delta
    _assert_header_does_not_check_against_any_chain(bad_header)
    self.logger.info(f"exiting backward mode at {height}")
    return height, header, bad, bad_header
@classmethod
def client_name(cls) -> str:
    """Return the client identifier: 'electrum/' followed by the Electrum version."""
    app = 'electrum'
    return f'{app}/{version.ELECTRUM_VERSION}'
def is_tor(self):
    """Return True when the configured server host is a '.onion' address."""
    host = self.host
    return host.endswith('.onion')
def ip_addr(self) -> Optional[str]:
    """Return the peer's host as a string, or None without a session or remote address."""
    sess = self.session
    if sess:
        remote = sess.remote_address()
        if remote:
            return str(remote.host)
    return None
def bucket_based_on_ipaddress(self) -> str:
    """Return a network-diversity bucket name for this server's peer address.

    Onion servers share BUCKET_NAME_OF_ONION_SERVERS. Clearnet IPv4 peers are
    bucketed by their /16 and IPv6 peers by their /48; IPv4-mapped IPv6 addresses
    (::ffff:a.b.c.d) are treated as the IPv4 address they embed. Loopback peers and
    peers whose address is missing or unparsable yield ''.

    The result is stored in self._ipaddr_bucket; an empty result is falsy, so it is
    recomputed on the next call (presumably so a bucket can be assigned once the
    session's remote address becomes known).
    """
    def do_bucket():
        if self.is_tor():
            return BUCKET_NAME_OF_ONION_SERVERS
        try:
            ip_addr = ip_address(self.ip_addr())  # type: Union[IPv4Address, IPv6Address]
        except ValueError:
            # ip_addr() gave None or a host string that is not an IP literal.
            return ''
        if ip_addr.version == 6 and ip_addr.ipv4_mapped is not None:
            # Otherwise every IPv4-mapped peer would land in the same ::/48 bucket
            # (and mapped loopback would not be exempted below).
            ip_addr = ip_addr.ipv4_mapped
        if ip_addr.is_loopback:  # localhost is exempt
            return ''
        if ip_addr.version == 4:
            slash16 = IPv4Network(ip_addr).supernet(prefixlen_diff=32-16)
            return str(slash16)
        elif ip_addr.version == 6:
            slash48 = IPv6Network(ip_addr).supernet(prefixlen_diff=128-48)
            return str(slash48)
        return ''
    if not self._ipaddr_bucket:
        self._ipaddr_bucket = do_bucket()
    return self._ipaddr_bucket
async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict:
    """Request the Merkle branch of `tx_hash` from the server and validate its shape.

    `tx_height` is only a hint passed to the server; the returned 'block_height'
    is not required to match it. Only field presence and types are checked here
    ('block_height' and 'pos' non-negative ints, 'merkle' a list/tuple of hash256
    strings); the branch itself is not verified against any header.

    :raises Exception: if `tx_hash` is not a txid or `tx_height` is not a
        non-negative integer.
    :returns: the server's response dict, unchanged.
    """
    if not is_hash256_str(tx_hash):
        raise Exception(f"{repr(tx_hash)} is not a txid")
    if not is_non_negative_integer(tx_height):
        raise Exception(f"{repr(tx_height)} is not a block height")
    response = await self.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height])
    reported_height = assert_dict_contains_field(response, field_name='block_height')
    branch = assert_dict_contains_field(response, field_name='merkle')
    position = assert_dict_contains_field(response, field_name='pos')
    assert_non_negative_integer(reported_height)
    assert_non_negative_integer(position)
    assert_list_or_tuple(branch)
    for node_hash in branch:
        assert_hash256_str(node_hash)
    return response
async def get_transaction(self, tx_hash: str, *, timeout=None) -> str:
    """Fetch a raw transaction by txid and verify it really is that transaction.

    The response must be a hex string that deserializes as a Transaction whose
    txid equals `tx_hash`.

    :param timeout: forwarded to the session's send_request.
    :raises Exception: if `tx_hash` is not a txid.
    :raises RequestCorrupted: if the response is non-hex, cannot be deserialized,
        or hashes to a different txid.
    :returns: the raw transaction hex as received.
    """
    if not is_hash256_str(tx_hash):
        raise Exception(f"{repr(tx_hash)} is not a txid")
    raw_hex = await self.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout)
    if not is_hex_str(raw_hex):
        raise RequestCorrupted(f"received garbage (non-hex) as tx data (txid {tx_hash}): {raw_hex!r}")
    parsed = Transaction(raw_hex)
    try:
        parsed.deserialize()  # raises on malformed data
    except Exception as exc:
        raise RequestCorrupted(f"cannot deserialize received transaction (txid {tx_hash})") from exc
    received_txid = parsed.txid()
    if received_txid != tx_hash:
        raise RequestCorrupted(f"received tx does not match expected txid {tx_hash} (got {received_txid})")
    return raw_hex
async def get_history_for_scripthash(self, sh: str) -> List[dict]:
    """Fetch and validate the transaction history of a scripthash.

    Sends 'blockchain.scripthash.get_history' and checks the response:
    - it is a list/tuple of items, each with an integer 'height' and a
      hash256 'tx_hash';
    - mempool items (height -1 or 0) also carry a non-negative integer 'fee';
    - confirmed items (height >= 1) come first, in non-decreasing height order,
      and no confirmed item follows a mempool item (heights < -1 are rejected too);
    - txids are unique.

    :param sh: the scripthash, as accepted by is_hash256_str.
    :raises Exception: if `sh` is not a scripthash.
    :raises RequestCorrupted: on out-of-order heights, a confirmed item after a
        mempool item, or duplicate txids. The assert_* helpers raise on malformed
        items (exception type defined by those helpers).
    :returns: the server response, unchanged.
    """
    if not is_hash256_str(sh):
        raise Exception(f"{repr(sh)} is not a scripthash")
    # do request
    res = await self.session.send_request('blockchain.scripthash.get_history', [sh])
    # check response
    assert_list_or_tuple(res)
    # Lowest height the next confirmed item may have; confirmed heights start at 1.
    prev_height = 1
    for tx_item in res:
        height = assert_dict_contains_field(tx_item, field_name='height')
        assert_dict_contains_field(tx_item, field_name='tx_hash')
        assert_integer(height)
        assert_hash256_str(tx_item['tx_hash'])
        if height in (-1, 0):
            assert_dict_contains_field(tx_item, field_name='fee')
            assert_non_negative_integer(tx_item['fee'])
            # +inf (not -inf) makes any later confirmed height fail the
            # monotonicity check below, so confirmed txs can't follow mempool txs.
            prev_height = float("inf")
        else:
            # check monotonicity of heights (also rejects heights < -1)
            if height < prev_height:
                raise RequestCorrupted(f'heights of confirmed txs must be in increasing order')
            prev_height = height
    hashes = {item['tx_hash'] for item in res}
    if len(hashes) != len(res):
        # Either server is sending garbage... or maybe if server is race-prone
        # a recently mined tx could be included in both last block and mempool?
        # Still, it's simplest to just disregard the response.
        raise RequestCorrupted(f"server history has non-unique txids for sh={sh}")
    return res
async def listunspent_for_scripthash(self, sh: str) -> List[dict]:
    """Fetch the unspent outputs of a scripthash and validate each entry's shape.

    Every entry must contain 'tx_pos', 'value', 'tx_hash' and 'height'; the three
    numeric fields must be non-negative integers and 'tx_hash' a hash256 string.

    :raises Exception: if `sh` is not a scripthash.
    :returns: the server response, unchanged.
    """
    if not is_hash256_str(sh):
        raise Exception(f"{repr(sh)} is not a scripthash")
    utxos = await self.session.send_request('blockchain.scripthash.listunspent', [sh])
    assert_list_or_tuple(utxos)
    for utxo in utxos:
        for key in ('tx_pos', 'value', 'tx_hash', 'height'):
            assert_dict_contains_field(utxo, field_name=key)
        for key in ('tx_pos', 'value', 'height'):
            assert_non_negative_integer(utxo[key])
        assert_hash256_str(utxo['tx_hash'])
    return utxos
async def get_balance_for_scripthash(self, sh: str) -> dict:
if not is_hash256_str(sh):
raise Exception(f"{repr(sh)} is not a scripthash")
# do request
res = await self.session.send_request('blockchain.scripthash.get_balance', [sh])
# check response
assert_dict_contains_field(res, field_name='confirmed')