From 9e1a35660d1f9f0cce9a355bfeaaaa0d9a5e5b1c Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Sun, 31 Mar 2024 17:25:00 -0700 Subject: [PATCH] DatastoreStorage.read_blocks_by_seq: can't use @ndb_context since it's a generator ...so handle context manually instead --- arroba/datastore_storage.py | 16 ++++++++++------ arroba/tests/test_datastore_storage.py | 7 +++++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/arroba/datastore_storage.py b/arroba/datastore_storage.py index 3fb5d57..ea16f60 100644 --- a/arroba/datastore_storage.py +++ b/arroba/datastore_storage.py @@ -424,15 +424,19 @@ def read_many(self, cids): return {cid: block.to_block() if block else None for cid, block in got} - @ndb_context + # can't use @ndb_context because this is a generator, not a normal function def read_blocks_by_seq(self, start=0): assert start >= 0 - # lexrpc event subscription handlers like subscribeRepos call this on a - # different thread, so if we're there, we need to create a new ndb context - for atp_block in AtpBlock.query(AtpBlock.seq >= start)\ - .order(AtpBlock.seq): - yield atp_block.to_block() + context = get_context(raise_context_error=False) + + with context.use() if context else self.ndb_client.context() as cm: + # lexrpc event subscription handlers like subscribeRepos call this + # on a different thread, so if we're there, we need to create a new + # ndb context + for atp_block in AtpBlock.query(AtpBlock.seq >= start)\ + .order(AtpBlock.seq): + yield atp_block.to_block() @ndb_context def has(self, cid): diff --git a/arroba/tests/test_datastore_storage.py b/arroba/tests/test_datastore_storage.py index a1ebbc4..806a11e 100644 --- a/arroba/tests/test_datastore_storage.py +++ b/arroba/tests/test_datastore_storage.py @@ -149,6 +149,13 @@ def test_read_blocks_by_seq(self): [b.cid for b in self.storage.read_blocks_by_seq(start=4)]) self.assertEqual([], [b.cid for b in self.storage.read_blocks_by_seq(start=6)]) + def test_read_blocks_by_seq_no_ndb_context(self): + AtpSequence.allocate(SUBSCRIBE_REPOS_NSID) + block = self.storage.write(repo_did='did:plc:123', obj={'foo': 2}) + + self.ndb_context.__exit__(None, None, None) + self.assertEqual([block], [b.cid for b in self.storage.read_blocks_by_seq()]) + def assert_same_seq(self, cids): """ Args: