From 5343cf343964b3a069770c8add87ac0ca57506dd Mon Sep 17 00:00:00 2001 From: RachitSharma2001 Date: Sat, 24 Dec 2022 10:28:33 -0800 Subject: [PATCH] Fix s3.iter_bucket failure when botocore_session passed in --- smart_open/s3.py | 22 ++++++++++++++-------- smart_open/tests/test_s3.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/smart_open/s3.py b/smart_open/s3.py index b2111637..4d87e85c 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -1176,17 +1176,18 @@ def iter_bucket( pass total_size, key_no = 0, -1 + global session + session = boto3.session.Session(**session_kwargs) key_iterator = _list_bucket( bucket_name, prefix=prefix, accept_key=accept_key, - **session_kwargs) + create_session=False) download_key = functools.partial( _download_key, bucket_name=bucket_name, retries=retries, - **session_kwargs) - + create_session=False) with smart_open.concurrency.create_pool(processes=workers) as pool: result_iterator = pool.imap_unordered(download_key, key_iterator) key_no = 0 @@ -1221,9 +1222,12 @@ def _list_bucket( bucket_name, prefix='', accept_key=lambda k: True, + create_session=True, **session_kwargs): - session = boto3.session.Session(**session_kwargs) - client = session.client('s3') + if create_session: + client = boto3.session.Session(**session_kwargs).client('s3') + else: + client = session.client('s3') ctoken = None while True: @@ -1248,15 +1252,17 @@ def _list_bucket( break -def _download_key(key_name, bucket_name=None, retries=3, **session_kwargs): +def _download_key(key_name, bucket_name=None, retries=3, create_session=True, **session_kwargs): if bucket_name is None: raise ValueError('bucket_name may not be None') # # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/resources.html#multithreading-or-multiprocessing-with-resources # - session = boto3.session.Session(**session_kwargs) - s3 = session.resource('s3') + if create_session: + s3 = boto3.session.Session(**session_kwargs).resource('s3') + else: + s3 = session.resource('s3') bucket = s3.Bucket(bucket_name) # Sometimes, https://github.com/boto/boto/issues/2409 can happen diff --git a/smart_open/tests/test_s3.py b/smart_open/tests/test_s3.py index a91a731e..189b3071 100644 --- a/smart_open/tests/test_s3.py +++ b/smart_open/tests/test_s3.py @@ -681,6 +681,34 @@ def test_iter_bucket(self): results = list(smart_open.s3.iter_bucket(BUCKET_NAME)) self.assertEqual(len(results), 10) + @pytest.mark.skipif(condition=sys.platform == 'win32', reason="does not run on windows") + @pytest.mark.xfail( + condition=sys.platform == 'darwin', + reason="MacOS uses spawn rather than fork for multiprocessing", + ) + def test_iter_bucket_passed_in_session_multiprocessing_true(self): + populate_bucket() + smart_open.concurrency._MULTIPROCESSING = True + my_session = botocore.session.Session() + results = list(smart_open.s3.iter_bucket(bucket_name=BUCKET_NAME, + botocore_session=my_session, + workers=1)) + self.assertEqual(len(results), 10) + + @pytest.mark.skipif(condition=sys.platform == 'win32', reason="does not run on windows") + @pytest.mark.xfail( + condition=sys.platform == 'darwin', + reason="MacOS uses spawn rather than fork for multiprocessing", + ) + def test_iter_bucket_passed_in_session_multiprocessing_false(self): + populate_bucket() + smart_open.concurrency._MULTIPROCESSING = False + my_session = botocore.session.Session() + results = list(smart_open.s3.iter_bucket(bucket_name=BUCKET_NAME, + botocore_session=my_session, + workers=1)) + self.assertEqual(len(results), 10) + @pytest.mark.skipif(condition=sys.platform == 'win32', reason="does not run on windows") @pytest.mark.xfail( condition=sys.platform == 'darwin',