Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No more blocking, part 2 #1350

Merged
merged 10 commits into from
Jun 3, 2024
2 changes: 1 addition & 1 deletion benchmarks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def client_node(request, grid, storage_nodes, number_of_nodes) -> Client:
"client_node",
needed=number_of_nodes,
happy=number_of_nodes,
total=number_of_nodes,
total=number_of_nodes + 3, # Make sure FEC does some work
)
)
print(f"Client node pid: {client_node.process.transport.pid}")
Expand Down
1 change: 1 addition & 0 deletions newsfragments/4072.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Continued work to make Tahoe-LAFS take advantage of multiple CPUs.
8 changes: 4 additions & 4 deletions src/allmydata/crypto/aes.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def encrypt_data(encryptor, plaintext):
"""

_validate_cryptor(encryptor, encrypt=True)
if not isinstance(plaintext, bytes):
raise ValueError('Plaintext must be bytes')
if not isinstance(plaintext, (bytes, memoryview)):
raise ValueError(f'Plaintext must be bytes or memoryview: {type(plaintext)}')

return encryptor.update(plaintext)

Expand Down Expand Up @@ -116,8 +116,8 @@ def decrypt_data(decryptor, plaintext):
"""

_validate_cryptor(decryptor, encrypt=False)
if not isinstance(plaintext, bytes):
raise ValueError('Plaintext must be bytes')
if not isinstance(plaintext, (bytes, memoryview)):
raise ValueError(f'Plaintext must be bytes or memoryview: {type(plaintext)}')

return decryptor.update(plaintext)

Expand Down
2 changes: 1 addition & 1 deletion src/allmydata/immutable/downloader/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def fetch_failed(self, sf, f):

def process_blocks(self, segnum, blocks):
start = now()
d = defer.maybeDeferred(self._decode_blocks, segnum, blocks)
d = self._decode_blocks(segnum, blocks)
d.addCallback(self._check_ciphertext_hash, segnum)
def _deliver(result):
log.msg(format="delivering segment(%(segnum)d)",
Expand Down
10 changes: 6 additions & 4 deletions src/allmydata/mutable/filenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
IMutableFileVersion, IWriteable
from allmydata.util import hashutil, log, consumer, deferredutil, mathutil
from allmydata.util.assertutil import precondition
from allmydata.util.cputhreadpool import defer_to_thread
from allmydata.uri import WriteableSSKFileURI, ReadonlySSKFileURI, \
WriteableMDMFFileURI, ReadonlyMDMFFileURI
from allmydata.monitor import Monitor
Expand Down Expand Up @@ -128,7 +129,8 @@ def init_from_cap(self, filecap):

return self

def create_with_keys(self, keypair, contents,
@deferredutil.async_to_deferred
async def create_with_keys(self, keypair, contents,
version=SDMF_VERSION):
"""Call this to create a brand-new mutable file. It will create the
shares, find homes for them, and upload the initial contents (created
Expand All @@ -137,8 +139,8 @@ def create_with_keys(self, keypair, contents,
use) when it completes.
"""
self._pubkey, self._privkey = keypair
self._writekey, self._encprivkey, self._fingerprint = derive_mutable_keys(
keypair,
self._writekey, self._encprivkey, self._fingerprint = await defer_to_thread(
derive_mutable_keys, keypair
)
if version == MDMF_VERSION:
self._uri = WriteableMDMFFileURI(self._writekey, self._fingerprint)
Expand All @@ -149,7 +151,7 @@ def create_with_keys(self, keypair, contents,
self._readkey = self._uri.readkey
self._storage_index = self._uri.storage_index
initial_contents = self._get_initial_contents(contents)
return self._upload(initial_contents, None)
return await self._upload(initial_contents, None)

def _get_initial_contents(self, contents):
if contents is None:
Expand Down
38 changes: 27 additions & 11 deletions src/allmydata/mutable/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from __future__ import annotations

import time

from itertools import count

from zope.interface import implementer
from twisted.internet import defer
from twisted.python import failure
Expand Down Expand Up @@ -873,11 +873,20 @@ def _decode_blocks(self, results, segnum):
shares = shares[:self._required_shares]
self.log("decoding segment %d" % segnum)
if segnum == self._num_segments - 1:
d = defer.maybeDeferred(self._tail_decoder.decode, shares, shareids)
d = self._tail_decoder.decode(shares, shareids)
else:
d = defer.maybeDeferred(self._segment_decoder.decode, shares, shareids)
def _process(buffers):
segment = b"".join(buffers)
d = self._segment_decoder.decode(shares, shareids)

# For larger shares, this can take a few milliseconds. As such, we want
# to unblock the event loop. In newer Python b"".join() will release
# the GIL: https://github.com/python/cpython/issues/80232
@deferredutil.async_to_deferred
async def _got_buffers(buffers):
return await defer_to_thread(lambda: b"".join(buffers))

d.addCallback(_got_buffers)

def _process(segment):
self.log(format="now decoding segment %(segnum)s of %(numsegs)s",
segnum=segnum,
numsegs=self._num_segments,
Expand Down Expand Up @@ -928,12 +937,20 @@ def notify_server_corruption(self, server, shnum, reason):
reason,
)


def _try_to_validate_privkey(self, enc_privkey, reader, server):
@deferredutil.async_to_deferred
async def _try_to_validate_privkey(self, enc_privkey, reader, server):
node_writekey = self._node.get_writekey()
alleged_privkey_s = decrypt_privkey(node_writekey, enc_privkey)
alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
if alleged_writekey != node_writekey:

def get_privkey():
alleged_privkey_s = decrypt_privkey(node_writekey, enc_privkey)
alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
if alleged_writekey != node_writekey:
return None
privkey, _ = rsa.create_signing_keypair_from_string(alleged_privkey_s)
return privkey

privkey = await defer_to_thread(get_privkey)
if privkey is None:
self.log("invalid privkey from %s shnum %d" %
(reader, reader.shnum),
level=log.WEIRD, umid="YIw4tA")
Expand All @@ -950,7 +967,6 @@ def _try_to_validate_privkey(self, enc_privkey, reader, server):
# it's good
self.log("got valid privkey from shnum %d on reader %s" %
(reader.shnum, reader))
privkey, _ = rsa.create_signing_keypair_from_string(alleged_privkey_s)
self._node._populate_encprivkey(enc_privkey)
self._node._populate_privkey(privkey)
self._need_privkey = False
Expand Down
Loading