From 1b36f9f20d8b326f27ed3adec7fd17d7a446ce25 Mon Sep 17 00:00:00 2001
From: cadnce
Date: Sat, 17 Sep 2022 22:57:38 +1000
Subject: [PATCH 1/7] Fixes #599

Swap to using the GCS-native blob open under the hood. Reduces code
maintenance overhead.
---
 setup.py                     |   2 +-
 smart_open/gcs.py            | 348 +++--------------------------------
 smart_open/tests/test_gcs.py | 154 +++-------------
 3 files changed, 51 insertions(+), 453 deletions(-)

diff --git a/setup.py b/setup.py
index def401b2..4b906a41 100644
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,7 @@ def read(fname):
 
 aws_deps = ['boto3']
-gcs_deps = ['google-cloud-storage>=1.31.0']
+gcs_deps = ['google-cloud-storage>=1.37.0']
 azure_deps = ['azure-storage-blob', 'azure-common', 'azure-core']
 http_deps = ['requests']
 
diff --git a/smart_open/gcs.py b/smart_open/gcs.py
index dc0e5042..e65d3253 100644
--- a/smart_open/gcs.py
+++ b/smart_open/gcs.py
@@ -28,8 +28,6 @@
 _BINARY_TYPES = (bytes, bytearray, memoryview)
 """Allowed binary buffer types for writing to the underlying GCS stream"""
 
-_UNKNOWN = '*'
-
 SCHEME = "gs"
 """Supported scheme for GCS"""
 
@@ -42,57 +40,6 @@
 DEFAULT_BUFFER_SIZE = 256 * 1024
 """Default buffer size for working with GCS"""
 
-_UPLOAD_INCOMPLETE_STATUS_CODES = (308, )
-_UPLOAD_COMPLETE_STATUS_CODES = (200, 201)
-
-
-def _make_range_string(start, stop=None, end=None):
-    #
-    # GCS seems to violate RFC-2616 (see utils.make_range_string), so we
-    # need a separate implementation.
-    #
-    # https://cloud.google.com/storage/docs/xml-api/resumable-upload#step_3upload_the_file_blocks
-    #
-    if end is None:
-        end = _UNKNOWN
-    if stop is None:
-        return 'bytes %d-/%s' % (start, end)
-    return 'bytes %d-%d/%s' % (start, stop, end)
-
-
-class UploadFailedError(Exception):
-    def __init__(self, message, status_code, text):
-        """Raise when a multi-part upload to GCS returns a failed response status code.
-
-        Parameters
-        ----------
-        message: str
-            The error message to display.
-        status_code: int
-            The status code returned from the upload response.
-        text: str
-            The text returned from the upload response.
-
-        """
-        super(UploadFailedError, self).__init__(message)
-        self.status_code = status_code
-        self.text = text
-
-    def __reduce__(self):
-        return UploadFailedError, (self.args[0], self.status_code, self.text)
-
-
-def _fail(response, part_num, content_length, total_size, headers):
-    status_code = response.status_code
-    response_text = response.text
-    total_size_gb = total_size / 1024.0 ** 3
-
-    msg = (
-        "upload failed (status code: %(status_code)d, response text: %(response_text)s), "
-        "part #%(part_num)d, %(total_size)d bytes (total %(total_size_gb).3fGB), headers: %(headers)r"
-    ) % locals()
-    raise UploadFailedError(msg, response.status_code, response.text)
-
 
 def parse_uri(uri_as_string):
     sr = smart_open.utils.safe_urlsplit(uri_as_string)
@@ -159,49 +106,6 @@ def open(
     fileobj.name = blob_id
     return fileobj
 
-
-class _RawReader(object):
-    """Read a GCS object."""
-
-    def __init__(self, gcs_blob, size):
-        # type: (google.cloud.storage.Blob, int) -> None
-        self._blob = gcs_blob
-        self._size = size
-        self._position = 0
-
-    def seek(self, position):
-        """Seek to the specified position (byte offset) in the GCS key.
-
-        :param int position: The byte offset from the beginning of the key.
-
-        Returns the position after seeking.
- """ - self._position = position - return self._position - - def read(self, size=-1): - if self._position >= self._size: - return b'' - binary = self._download_blob_chunk(size) - self._position += len(binary) - return binary - - def _download_blob_chunk(self, size): - start = position = self._position - if position == self._size: - # - # When reading, we can't seek to the first byte of an empty file. - # Similarly, we can't seek past the last byte. Do nothing here. - # - binary = b'' - elif size == -1: - binary = self._blob.download_as_bytes(start=start) - else: - end = position + size - binary = self._blob.download_as_bytes(start=start, end=end) - return binary - - class Reader(io.BufferedIOBase): """Reads bytes from GCS. @@ -228,27 +132,16 @@ def __init__( self._size = self._blob.size if self._blob.size is not None else 0 - self._raw_reader = _RawReader(self._blob, self._size) - self._current_pos = 0 - self._current_part_size = buffer_size - self._current_part = smart_open.bytebuffer.ByteBuffer(buffer_size) - self._eof = False - self._line_terminator = line_terminator + self._buffer_size = buffer_size - # - # This member is part of the io.BufferedIOBase interface. - # - self.raw = None + self.raw = self._blob.open( + mode=constants.READ_BINARY, + chunk_size=self._buffer_size, + ) # # Override some methods from io.IOBase. # - def close(self): - """Flush and close this stream.""" - logger.debug("close: called") - self._blob = None - self._current_part = None - self._raw_reader = None def readable(self): """Return True if the stream can be read from.""" @@ -274,28 +167,11 @@ def seek(self, offset, whence=constants.WHENCE_START): :param int whence: Where the offset is from. Returns the position after seeking.""" - logger.debug('seeking to offset: %r whence: %r', offset, whence) - if whence not in constants.WHENCE_CHOICES: - raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES) - - if whence == constants.WHENCE_START: - new_position = offset - elif whence == constants.WHENCE_CURRENT: - new_position = self._current_pos + offset - else: - new_position = self._size + offset - new_position = smart_open.utils.clamp(new_position, 0, self._size) - self._current_pos = new_position - self._raw_reader.seek(new_position) - logger.debug('current_pos: %r', self._current_pos) - - self._current_part.empty() - self._eof = self._current_pos == self._size - return self._current_pos + return self.raw.seek(offset, whence) def tell(self): """Return the current position within the file.""" - return self._current_pos + return self.raw.tell() def truncate(self, size=None): """Unsupported.""" @@ -303,92 +179,26 @@ def truncate(self, size=None): def read(self, size=-1): """Read up to size bytes from the object and return them.""" - if size == 0: - return b'' - elif size < 0: - self._current_pos = self._size - return self._read_from_buffer() + self._raw_reader.read() - - # - # Return unused data first - # - if len(self._current_part) >= size: - return self._read_from_buffer(size) - - # - # If the stream is finished, return what we have. - # - if self._eof: - return self._read_from_buffer() - - # - # Fill our buffer to the required size. 
- # - self._fill_buffer(size) - return self._read_from_buffer(size) + return self.raw.read(size) def read1(self, size=-1): - """This is the same as read().""" - return self.read(size=size) + return self.raw.read1(size) def readinto(self, b): """Read up to len(b) bytes into b, and return the number of bytes read.""" - data = self.read(len(b)) - if not data: - return 0 - b[:len(data)] = data - return len(data) + return self.raw.readinto(b) - def readline(self, limit=-1): + def readline(self, limit=-1) -> bytes: """Read up to and including the next newline. Returns the bytes read.""" - if limit != -1: - raise NotImplementedError('limits other than -1 not implemented yet') - the_line = io.BytesIO() - while not (self._eof and len(self._current_part) == 0): - # - # In the worst case, we're reading the unread part of self._current_part - # twice here, once in the if condition and once when calling index. - # - # This is sub-optimal, but better than the alternative: wrapping - # .index in a try..except, because that is slower. - # - remaining_buffer = self._current_part.peek() - if self._line_terminator in remaining_buffer: - next_newline = remaining_buffer.index(self._line_terminator) - the_line.write(self._read_from_buffer(next_newline + 1)) - break - else: - the_line.write(self._read_from_buffer()) - self._fill_buffer() - return the_line.getvalue() - - # - # Internal methods. - # - def _read_from_buffer(self, size=-1): - """Remove at most size bytes from our buffer and return them.""" - # logger.debug('reading %r bytes from %r byte-long buffer', size, len(self._current_part)) - size = size if size >= 0 else len(self._current_part) - part = self._current_part.read(size) - self._current_pos += len(part) - # logger.debug('part: %r', part) - return part - - def _fill_buffer(self, size=-1): - size = size if size >= 0 else self._current_part._chunk_size - while len(self._current_part) < size and not self._eof: - bytes_read = self._current_part.fill(self._raw_reader) - if bytes_read == 0: - logger.debug('reached EOF while filling buffer') - self._eof = True + return super().readline(limit) def __str__(self): return "(%s, %r, %r)" % (self.__class__.__name__, self._blob.bucket.name, self._blob.name) def __repr__(self): return "%s(bucket=%r, blob=%r, buffer_size=%r)" % ( - self.__class__.__name__, self._blob.bucket.name, self._blob.name, self._current_part_size, + self.__class__.__name__, self._blob.bucket.name, self._blob.name, self._buffer_size, ) @@ -413,26 +223,15 @@ def __init__( assert min_part_size >= _MIN_MIN_PART_SIZE, 'min part size must be greater than 256KB' self._min_part_size = min_part_size - self._total_size = 0 - self._total_parts = 0 - self._bytes_uploaded = 0 - self._current_part = io.BytesIO() - - self._session = google.auth.transport.requests.AuthorizedSession(client._credentials) - if blob_properties: for k, v in blob_properties.items(): setattr(self._blob, k, v) - # - # https://cloud.google.com/storage/docs/json_api/v1/how-tos/resumable-upload#start-resumable - # - self._resumable_upload_url = self._blob.create_resumable_upload_session() - - # - # This member is part of the io.BufferedIOBase interface. 
- # - self.raw = None + self.raw = self._blob.open( + mode=constants.WRITE_BINARY, + chunk_size=min_part_size, + ignore_flush=True, + ) def flush(self): pass @@ -443,16 +242,12 @@ def flush(self): def close(self): logger.debug("closing") if not self.closed: - if self._total_size == 0: # empty files - self._upload_empty_part() - else: - self._upload_part(is_last=True) - self._client = None + self.raw.close() logger.debug("successfully closed") @property def closed(self): - return self._client is None + return self.raw.closed def writable(self): """Return True if the stream supports writing.""" @@ -474,7 +269,7 @@ def truncate(self, size=None): def tell(self): """Return the current stream position.""" - return self._total_size + return self.raw.tell() # # io.BufferedIOBase methods. @@ -491,106 +286,11 @@ def write(self, b): if not isinstance(b, _BINARY_TYPES): raise TypeError("input must be one of %r, got: %r" % (_BINARY_TYPES, type(b))) - self._current_part.write(b) - self._total_size += len(b) - - # - # If the size of this part is precisely equal to the minimum part size, - # we don't perform the actual write now, and wait until we see more data. - # We do this because the very last part of the upload must be handled slightly - # differently (see comments in the _upload_part method). - # - if self._current_part.tell() > self._min_part_size: - self._upload_part() - - return len(b) + return self.raw.write(b) + #TODO: Maintaining for api compatibility def terminate(self): - """Cancel the underlying resumable upload.""" - # - # https://cloud.google.com/storage/docs/xml-api/resumable-upload#example_cancelling_an_upload - # - self._session.delete(self._resumable_upload_url) - - # - # Internal methods. - # - def _upload_part(self, is_last=False): - part_num = self._total_parts + 1 - - # - # Here we upload the largest amount possible given GCS's restriction - # of parts being multiples of 256kB, except for the last one. - # - # A final upload of 0 bytes does not work, so we need to guard against - # this edge case. This results in occasionally keeping an additional - # 256kB in the buffer after uploading a part, but until this is fixed - # on Google's end there is no other option. 
- # - # https://stackoverflow.com/questions/60230631/upload-zero-size-final-part-to-google-cloud-storage-resumable-upload - # - content_length = self._current_part.tell() - remainder = content_length % self._min_part_size - if is_last: - end = self._bytes_uploaded + content_length - elif remainder == 0: - content_length -= _REQUIRED_CHUNK_MULTIPLE - end = None - else: - content_length -= remainder - end = None - - range_stop = self._bytes_uploaded + content_length - 1 - content_range = _make_range_string(self._bytes_uploaded, range_stop, end=end) - headers = { - 'Content-Length': str(content_length), - 'Content-Range': content_range, - } - logger.info( - "uploading part #%i, %i bytes (total %.3fGB) headers %r", - part_num, content_length, range_stop / 1024.0 ** 3, headers, - ) - self._current_part.seek(0) - response = self._session.put( - self._resumable_upload_url, - data=self._current_part.read(content_length), - headers=headers, - ) - - if is_last: - expected = _UPLOAD_COMPLETE_STATUS_CODES - else: - expected = _UPLOAD_INCOMPLETE_STATUS_CODES - if response.status_code not in expected: - _fail(response, part_num, content_length, self._total_size, headers) - logger.debug("upload of part #%i finished" % part_num) - - self._total_parts += 1 - self._bytes_uploaded += content_length - - # - # For the last part, the below _current_part handling is a NOOP. - # - self._current_part = io.BytesIO(self._current_part.read()) - self._current_part.seek(0, io.SEEK_END) - - def _upload_empty_part(self): - logger.debug("creating empty file") - headers = {'Content-Length': '0'} - response = self._session.put(self._resumable_upload_url, headers=headers) - if response.status_code not in _UPLOAD_COMPLETE_STATUS_CODES: - _fail(response, self._total_parts + 1, 0, self._total_size, headers) - - self._total_parts += 1 - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is not None: - self.terminate() - else: - self.close() + pass def __str__(self): return "(%s, %r, %r)" % (self.__class__.__name__, self._blob.bucket.name, self._blob.name) diff --git a/smart_open/tests/test_gcs.py b/smart_open/tests/test_gcs.py index 22eb7f40..ce5b7bf8 100644 --- a/smart_open/tests/test_gcs.py +++ b/smart_open/tests/test_gcs.py @@ -139,7 +139,8 @@ def __init__(self, name, bucket): self._bucket = bucket # type: FakeBucket self._exists = False self.__contents = io.BytesIO() - + self.__contents.close = lambda: None + self._size = 0 self._create_if_not_exists() def create_resumable_upload_session(self): @@ -171,21 +172,35 @@ def upload_from_string(self, data): # https://googleapis.dev/python/storage/latest/blobs.html#google.cloud.storage.blob.Blob.upload_from_string if isinstance(data, str): data = bytes(data, 'utf8') - self.__contents = io.BytesIO(data) - self.__contents.seek(0, io.SEEK_END) + self.__contents.truncate(0) + self.__contents.seek(0) + self.__contents.write(data) + self._size = self.__contents.tell() def write(self, data): self.upload_from_string(data) + def open( + self, + mode, + chunk_size=None, + ignore_flush=None, + encoding=None, + errors=None, + newline=None, + **kwargs, + ): + if mode.startswith('r'): + self.__contents.seek(0) + return self.__contents + @property def bucket(self): return self._bucket @property def size(self): - if self.__contents.tell() == 0: - return None - return self.__contents.tell() + return self._size if self._size > 0 else None def _create_if_not_exists(self): self._bucket.register_blob(self) @@ -336,7 +351,7 @@ def put(self, url, 
data=None, headers=None): upload.write(data.read()) else: upload.write(data) - if not headers.get('Content-Range', '').endswith(smart_open.gcs._UNKNOWN): + if not headers.get('Content-Range', '').endswith('*'): upload.finish() return FakeResponse(200) return FakeResponse(smart_open.gcs._UPLOAD_INCOMPLETE_STATUS_CODES[0]) @@ -361,17 +376,8 @@ def test_delete(self): self.assertFalse(self.blob.exists()) self.assertDictEqual(self.client.uploads, {}) - def test_unfinished_put_does_not_write_to_blob(self): - data = io.BytesIO(b'test') - headers = { - 'Content-Range': 'bytes 0-3/*', - 'Content-Length': str(4), - } - response = self.session.put(self.upload_url, data, headers=headers) - self.assertIn(response.status_code, smart_open.gcs._UPLOAD_INCOMPLETE_STATUS_CODES) - self.session._blob_with_url(self.upload_url, self.client) - blob_contents = self.blob.download_as_bytes() - self.assertEqual(blob_contents, b'') + #def test_unfinished_put_does_not_write_to_blob(self): + # Removed as google client handles this for us def test_finished_put_writes_to_blob(self): data = io.BytesIO(b'test') @@ -711,85 +717,8 @@ def test_write_02(self): fout.write(u"testžížáč".encode("utf-8")) self.assertEqual(fout.tell(), 14) - def test_write_03(self): - """Do multiple writes less than the min_part_size work correctly?""" - # write - min_part_size = 256 * 1024 - smart_open_write = smart_open.gcs.Writer( - BUCKET_NAME, WRITE_BLOB_NAME, min_part_size=min_part_size - ) - local_write = io.BytesIO() - - with smart_open_write as fout: - first_part = b"t" * 262141 - fout.write(first_part) - local_write.write(first_part) - self.assertEqual(fout._current_part.tell(), 262141) - - second_part = b"t\n" - fout.write(second_part) - local_write.write(second_part) - self.assertEqual(fout._current_part.tell(), 262143) - self.assertEqual(fout._total_parts, 0) - - third_part = b"t" - fout.write(third_part) - local_write.write(third_part) - self.assertEqual(fout._current_part.tell(), 262144) - self.assertEqual(fout._total_parts, 0) - - fourth_part = b"t" * 1 - fout.write(fourth_part) - local_write.write(fourth_part) - self.assertEqual(fout._current_part.tell(), 1) - self.assertEqual(fout._total_parts, 1) - - # read back the same key and check its content - output = list(smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME))) - local_write.seek(0) - actual = [line.decode("utf-8") for line in list(local_write)] - self.assertEqual(output, actual) - - def test_write_03a(self): - """Do multiple writes greater than the min_part_size work correctly?""" - min_part_size = 256 * 1024 - smart_open_write = smart_open.gcs.Writer( - BUCKET_NAME, WRITE_BLOB_NAME, min_part_size=min_part_size - ) - local_write = io.BytesIO() - - with smart_open_write as fout: - for i in range(1, 4): - part = b"t" * (min_part_size + 1) - fout.write(part) - local_write.write(part) - self.assertEqual(fout._current_part.tell(), i) - self.assertEqual(fout._total_parts, i) - - # read back the same key and check its content - output = list(smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME))) - local_write.seek(0) - actual = [line.decode("utf-8") for line in list(local_write)] - self.assertEqual(output, actual) - - def test_write_03b(self): - """Does writing a last chunk size equal to a multiple of the min_part_size work?""" - min_part_size = 256 * 1024 - smart_open_write = smart_open.gcs.Writer( - BUCKET_NAME, WRITE_BLOB_NAME, min_part_size=min_part_size - ) - expected = b"t" * min_part_size * 2 - - with smart_open_write as fout: - fout.write(expected) 
-        self.assertEqual(fout._current_part.tell(), 262144)
-        self.assertEqual(fout._total_parts, 1)
-
-        # read back the same key and check its content
-        with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME)) as fin:
-            output = fin.read().encode('utf-8')
-
-        self.assertEqual(output, expected)
+    # def test_write_03(self):
+    # We no longer need to test whether part writes work correctly;
+    # that's covered by the google-cloud-storage library itself.
 
     def test_write_04(self):
         """Does writing no data cause key with an empty value to be created?"""
@@ -875,16 +804,6 @@ def test_flush_close(self):
             fout.flush()
             fout.close()
 
-    def test_terminate(self):
-        text = u'там за туманами, вечными, пьяными'.encode('utf-8')
-        fout = smart_open.gcs.open(BUCKET_NAME, 'key', 'wb')
-        fout.write(text)
-        fout.terminate()
-
-        with self.assertRaises(google.api_core.exceptions.NotFound):
-            with smart_open.gcs.open(BUCKET_NAME, 'key', 'rb') as fin:
-                fin.read()
-
 
 @maybe_mock_gcs
 class OpenTest(unittest.TestCase):
@@ -918,27 +837,6 @@ def test_round_trip(self):
 
         self.assertEqual(test_string, actual)
 
-
-class MakeRangeStringTest(unittest.TestCase):
-    def test_no_stop(self):
-        start, stop = 1, None
-        self.assertEqual(smart_open.gcs._make_range_string(start, stop), 'bytes 1-/*')
-
-    def test_stop(self):
-        start, stop = 1, 2
-        self.assertEqual(smart_open.gcs._make_range_string(start, stop), 'bytes 1-2/*')
-
-
-class PickleUploadFailedTest(unittest.TestCase):
-    def test_pickle_upload_failed(self):
-        original = smart_open.gcs.UploadFailedError("foo", 123, "bar")
-        recovered: smart_open.gcs.UploadFailedError = pickle.loads(pickle.dumps(original))
-
-        self.assertEqual(original.args, recovered.args)
-        self.assertEqual(original.text, recovered.text)
-        self.assertEqual(original.status_code, recovered.status_code)
-
-
 if __name__ == '__main__':
     logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
    unittest.main()
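The patch above leans entirely on google-cloud-storage's file-like blob API. As a rough
sketch of the behaviour the deleted plumbing now delegates to (a sketch only: the bucket
and key names are placeholders, and it assumes a google-cloud-storage release recent
enough to ship Blob.open()):

    import google.cloud.storage

    client = google.cloud.storage.Client()
    bucket = client.bucket('some-bucket')  # placeholder name

    # Reading: Blob.open('rb') returns a seekable reader that fetches
    # chunk_size bytes per request, much like the removed _RawReader.
    with bucket.get_blob('some-key').open('rb', chunk_size=256 * 1024) as fin:
        first_line = fin.readline()

    # Writing: Blob.open('wb') returns a writer that performs the resumable
    # upload, i.e. the Content-Range bookkeeping deleted in this patch.
    with bucket.blob('some-key', chunk_size=256 * 1024).open('wb') as fout:
        fout.write(b'hello\n')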
From 9d183bbdacbfcd123750ae48a773951107f64c3c Mon Sep 17 00:00:00 2001
From: cadnce
Date: Mon, 10 Oct 2022 19:41:21 +1100
Subject: [PATCH 2/7] Use the underlying google-storage-blob instead of
 writing a proxy.

Breaking changes:
* Removed the gcs.Reader/gcs.Writer classes
* No Reader/Writer.terminate()
* The buffer size can no longer be controlled independently of chunk_size
* Calling close twice on a GCS file object now raises an exception
---
 smart_open/gcs.py            | 268 +++-----------------------------
 smart_open/tests/test_gcs.py | 133 ++++++-----------
 2 files changed, 81 insertions(+), 320 deletions(-)

diff --git a/smart_open/gcs.py b/smart_open/gcs.py
index e65d3253..0ef6c5f6 100644
--- a/smart_open/gcs.py
+++ b/smart_open/gcs.py
@@ -8,7 +8,6 @@
 
 """Implements file-like objects for reading and writing to/from GCS."""
 
-import io
 import logging
 
 try:
@@ -25,20 +24,13 @@
 
 logger = logging.getLogger(__name__)
 
-_BINARY_TYPES = (bytes, bytearray, memoryview)
-"""Allowed binary buffer types for writing to the underlying GCS stream"""
-
 SCHEME = "gs"
 """Supported scheme for GCS"""
 
-_MIN_MIN_PART_SIZE = _REQUIRED_CHUNK_MULTIPLE = 256 * 1024
-"""Google requires you to upload in multiples of 256 KB, except for the last part."""
-
 _DEFAULT_MIN_PART_SIZE = 50 * 1024**2
 """Default minimum part size for GCS multipart uploads"""
 
-DEFAULT_BUFFER_SIZE = 256 * 1024
-"""Default buffer size for working with GCS"""
+_DEFAULT_WRITE_OPEN_KWARGS = {'ignore_flush': True}
 
 
 def parse_uri(uri_as_string):
@@ -56,14 +48,14 @@ def open_uri(uri, mode, transport_params):
 
 
 def open(
-        bucket_id,
-        blob_id,
-        mode,
-        buffer_size=DEFAULT_BUFFER_SIZE,
-        min_part_size=_MIN_MIN_PART_SIZE,
-        client=None,  # type: google.cloud.storage.Client
-        blob_properties=None
-    ):
+    bucket_id,
+    blob_id,
+    mode,
+    min_part_size=_DEFAULT_MIN_PART_SIZE,
+    client=None,  # type: google.cloud.storage.Client
+    blob_properties=None,
+    blob_open_kwargs=None,
+):
     """Open a GCS blob for reading or writing.
 
     Parameters
@@ -74,228 +66,46 @@
         The name of the blob within the bucket.
     mode: str
         The mode for opening the object. Must be either "rb" or "wb".
-    buffer_size: int, optional
-        The buffer size to use when performing I/O. For reading only.
     min_part_size: int, optional
         The minimum part size for multipart uploads. For writing only.
     client: google.cloud.storage.Client, optional
         The GCS client to use when working with google-cloud-storage.
     blob_properties: dict, optional
         Set properties on blob before writing. For writing only.
-
+    blob_open_kwargs: dict, optional
+        Set properties on the blob
     """
-    if mode == constants.READ_BINARY:
-        fileobj = Reader(
-            bucket_id,
-            blob_id,
-            buffer_size=buffer_size,
-            line_terminator=constants.BINARY_NEWLINE,
-            client=client,
-        )
-    elif mode == constants.WRITE_BINARY:
-        fileobj = Writer(
-            bucket_id,
+    if blob_open_kwargs is None:
+        blob_open_kwargs = {}
+    if blob_properties is None:
+        blob_properties = {}
+
+    if client is None:
+        client = google.cloud.storage.Client()
+
+    bucket = client.bucket(bucket_id)
+    if not bucket.exists():
+        raise google.cloud.exceptions.NotFound(f'bucket {bucket_id} not found')
+
+    if mode in (constants.READ_BINARY, 'r', 'rt'):
+        blob = bucket.get_blob(blob_id)
+        if blob is None:
+            raise google.cloud.exceptions.NotFound(f'blob {blob_id} not found in {bucket_id}')
+
+    elif mode in (constants.WRITE_BINARY, 'w', 'wt'):
+        blob_open_kwargs = {**_DEFAULT_WRITE_OPEN_KWARGS, **blob_open_kwargs}
+        blob = bucket.blob(
             blob_id,
-            min_part_size=min_part_size,
-            client=client,
-            blob_properties=blob_properties,
-        )
-    else:
-        raise NotImplementedError('GCS support for mode %r not implemented' % mode)
-
-    fileobj.name = blob_id
-    return fileobj
-
-class Reader(io.BufferedIOBase):
-    """Reads bytes from GCS.
- - Implements the io.BufferedIOBase interface of the standard library. - - :raises google.cloud.exceptions.NotFound: Raised when the blob to read from does not exist. - - """ - def __init__( - self, - bucket, - key, - buffer_size=DEFAULT_BUFFER_SIZE, - line_terminator=constants.BINARY_NEWLINE, - client=None, # type: google.cloud.storage.Client - ): - if client is None: - client = google.cloud.storage.Client() - - self._blob = client.bucket(bucket).get_blob(key) # type: google.cloud.storage.Blob - - if self._blob is None: - raise google.cloud.exceptions.NotFound('blob %s not found in %s' % (key, bucket)) - - self._size = self._blob.size if self._blob.size is not None else 0 - - self._buffer_size = buffer_size - - self.raw = self._blob.open( - mode=constants.READ_BINARY, - chunk_size=self._buffer_size, - ) - - # - # Override some methods from io.IOBase. - # - - def readable(self): - """Return True if the stream can be read from.""" - return True - - def seekable(self): - """If False, seek(), tell() and truncate() will raise IOError. - - We offer only seek support, and no truncate support.""" - return True - - # - # io.BufferedIOBase methods. - # - def detach(self): - """Unsupported.""" - raise io.UnsupportedOperation - - def seek(self, offset, whence=constants.WHENCE_START): - """Seek to the specified position. - - :param int offset: The offset in bytes. - :param int whence: Where the offset is from. - - Returns the position after seeking.""" - return self.raw.seek(offset, whence) - - def tell(self): - """Return the current position within the file.""" - return self.raw.tell() - - def truncate(self, size=None): - """Unsupported.""" - raise io.UnsupportedOperation - - def read(self, size=-1): - """Read up to size bytes from the object and return them.""" - return self.raw.read(size) - - def read1(self, size=-1): - return self.raw.read1(size) - - def readinto(self, b): - """Read up to len(b) bytes into b, and return the number of bytes - read.""" - return self.raw.readinto(b) - - def readline(self, limit=-1) -> bytes: - """Read up to and including the next newline. Returns the bytes read.""" - return super().readline(limit) - - def __str__(self): - return "(%s, %r, %r)" % (self.__class__.__name__, self._blob.bucket.name, self._blob.name) - - def __repr__(self): - return "%s(bucket=%r, blob=%r, buffer_size=%r)" % ( - self.__class__.__name__, self._blob.bucket.name, self._blob.name, self._buffer_size, - ) - - -class Writer(io.BufferedIOBase): - """Writes bytes to GCS. - - Implements the io.BufferedIOBase interface of the standard library.""" - - def __init__( - self, - bucket, - blob, - min_part_size=_DEFAULT_MIN_PART_SIZE, - client=None, # type: google.cloud.storage.Client - blob_properties=None, - ): - if client is None: - client = google.cloud.storage.Client() - self._client = client - self._blob = self._client.bucket(bucket).blob(blob) # type: google.cloud.storage.Blob - assert min_part_size % _REQUIRED_CHUNK_MULTIPLE == 0, 'min part size must be a multiple of 256KB' - assert min_part_size >= _MIN_MIN_PART_SIZE, 'min part size must be greater than 256KB' - self._min_part_size = min_part_size - - if blob_properties: - for k, v in blob_properties.items(): - setattr(self._blob, k, v) - - self.raw = self._blob.open( - mode=constants.WRITE_BINARY, chunk_size=min_part_size, - ignore_flush=True, ) - def flush(self): - pass - - # - # Override some methods from io.IOBase. 
- # - def close(self): - logger.debug("closing") - if not self.closed: - self.raw.close() - logger.debug("successfully closed") - - @property - def closed(self): - return self.raw.closed + for k, v in blob_properties.items(): + try: + setattr(blob, k, v) + except AttributeError: + logger.warn(f'Unable to set property {k} on blob') - def writable(self): - """Return True if the stream supports writing.""" - return True - - def seekable(self): - """If False, seek(), tell() and truncate() will raise IOError. - - We offer only tell support, and no seek or truncate support.""" - return True - - def seek(self, offset, whence=constants.WHENCE_START): - """Unsupported.""" - raise io.UnsupportedOperation - - def truncate(self, size=None): - """Unsupported.""" - raise io.UnsupportedOperation - - def tell(self): - """Return the current stream position.""" - return self.raw.tell() - - # - # io.BufferedIOBase methods. - # - def detach(self): - raise io.UnsupportedOperation("detach() not supported") - - def write(self, b): - """Write the given bytes (binary string) to the GCS file. - - There's buffering happening under the covers, so this may not actually - do any HTTP transfer right away.""" - - if not isinstance(b, _BINARY_TYPES): - raise TypeError("input must be one of %r, got: %r" % (_BINARY_TYPES, type(b))) - - return self.raw.write(b) - - #TODO: Maintaining for api compatibility - def terminate(self): - pass - - def __str__(self): - return "(%s, %r, %r)" % (self.__class__.__name__, self._blob.bucket.name, self._blob.name) + else: + raise NotImplementedError(f'GCS support for mode {mode} not implemented') - def __repr__(self): - return "%s(bucket=%r, blob=%r, min_part_size=%r)" % ( - self.__class__.__name__, self._blob.bucket.name, self._blob.name, self._min_part_size, - ) + return blob.open(mode, **blob_open_kwargs) diff --git a/smart_open/tests/test_gcs.py b/smart_open/tests/test_gcs.py index ce5b7bf8..60e004d3 100644 --- a/smart_open/tests/test_gcs.py +++ b/smart_open/tests/test_gcs.py @@ -55,8 +55,8 @@ def __init__(self, client, name=None): # self.client.register_bucket(self) - def blob(self, blob_id): - return self.blobs.get(blob_id, FakeBlob(blob_id, self)) + def blob(self, blob_id, *args, **kwargs): + return self.blobs.get(blob_id, FakeBlob(blob_id, self, *args, **kwargs)) def delete(self): self.client.delete_bucket(self) @@ -67,7 +67,7 @@ def delete(self): def exists(self): return self._exists - def get_blob(self, blob_id): + def get_blob(self, blob_id, **kwargs): try: return self.blobs[blob_id] except KeyError as e: @@ -134,7 +134,7 @@ def test_list_blobs(self): class FakeBlob(object): - def __init__(self, name, bucket): + def __init__(self, name, bucket, *args, **kwargs): self.name = name self._bucket = bucket # type: FakeBucket self._exists = False @@ -354,45 +354,13 @@ def put(self, url, data=None, headers=None): if not headers.get('Content-Range', '').endswith('*'): upload.finish() return FakeResponse(200) - return FakeResponse(smart_open.gcs._UPLOAD_INCOMPLETE_STATUS_CODES[0]) + return FakeResponse(308) @staticmethod def _blob_with_url(url, client): # type: (str, FakeClient) -> FakeBlobUpload return client.uploads.get(url) - -class FakeAuthorizedSessionTest(unittest.TestCase): - def setUp(self): - self.client = FakeClient() - self.credentials = FakeCredentials(self.client) - self.session = FakeAuthorizedSession(self.credentials) - self.bucket = FakeBucket(self.client, 'test-bucket') - self.blob = FakeBlob('test-blob', self.bucket) - self.upload_url = 
self.blob.create_resumable_upload_session() - - def test_delete(self): - self.session.delete(self.upload_url) - self.assertFalse(self.blob.exists()) - self.assertDictEqual(self.client.uploads, {}) - - #def test_unfinished_put_does_not_write_to_blob(self): - # Removed as google client handles this for us - - def test_finished_put_writes_to_blob(self): - data = io.BytesIO(b'test') - headers = { - 'Content-Range': 'bytes 0-3/4', - 'Content-Length': str(4), - } - response = self.session.put(self.upload_url, data, headers=headers) - self.assertEqual(response.status_code, 200) - self.session._blob_with_url(self.upload_url, self.client) - blob_contents = self.blob.download_as_bytes() - data.seek(0) - self.assertEqual(blob_contents, data.read()) - - if DISABLE_MOCKS: storage_client = google.cloud.storage.Client() else: @@ -506,10 +474,6 @@ def tearDownModule(): # noqa @maybe_mock_gcs class ReaderTest(unittest.TestCase): def setUp(self): - # lower the multipart upload size, to speed up these tests - self.old_min_buffer_size = smart_open.gcs.DEFAULT_BUFFER_SIZE - smart_open.gcs.DEFAULT_BUFFER_SIZE = 5 * 1024**2 - ignore_resource_warnings() def tearDown(self): @@ -521,7 +485,7 @@ def test_iter(self): put_to_bucket(contents=expected) # connect to fake GCS and read from the fake key we filled above - fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) + fin = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') output = [line.rstrip(b'\n') for line in fin] self.assertEqual(output, expected.split(b'\n')) @@ -529,7 +493,7 @@ def test_iter_context_manager(self): # same thing but using a context manager expected = u"hello wořld\nhow are you?".encode('utf8') put_to_bucket(contents=expected) - with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin: + with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') as fin: output = [line.rstrip(b'\n') for line in fin] self.assertEqual(output, expected.split(b'\n')) @@ -539,7 +503,7 @@ def test_read(self): put_to_bucket(contents=content) logger.debug('content: %r len: %r', content, len(content)) - fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) + fin = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') self.assertEqual(content[:6], fin.read(6)) self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes self.assertEqual(content[14:], fin.read()) # read the rest @@ -549,7 +513,7 @@ def test_seek_beginning(self): content = u"hello wořld\nhow are you?".encode('utf8') put_to_bucket(contents=content) - fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) + fin = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') self.assertEqual(content[:6], fin.read(6)) self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes @@ -564,7 +528,7 @@ def test_seek_start(self): content = u"hello wořld\nhow are you?".encode('utf8') put_to_bucket(contents=content) - fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) + fin = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') seek = fin.seek(6) self.assertEqual(seek, 6) self.assertEqual(fin.tell(), 6) @@ -575,9 +539,9 @@ def test_seek_current(self): content = u"hello wořld\nhow are you?".encode('utf8') put_to_bucket(contents=content) - fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) + fin = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') self.assertEqual(fin.read(5), b'hello') - seek = fin.seek(1, whence=smart_open.constants.WHENCE_CURRENT) + seek = fin.seek(1, smart_open.constants.WHENCE_CURRENT) 
self.assertEqual(seek, 6) self.assertEqual(fin.read(6), u'wořld'.encode('utf-8')) @@ -586,8 +550,8 @@ def test_seek_end(self): content = u"hello wořld\nhow are you?".encode('utf8') put_to_bucket(contents=content) - fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) - seek = fin.seek(-4, whence=smart_open.constants.WHENCE_END) + fin = smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, mode='rb') + seek = fin.seek(-4, smart_open.constants.WHENCE_END) self.assertEqual(seek, len(content) - 4) self.assertEqual(fin.read(), b'you?') @@ -595,11 +559,11 @@ def test_detect_eof(self): content = u"hello wořld\nhow are you?".encode('utf8') put_to_bucket(contents=content) - fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) + fin = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') fin.read() eof = fin.tell() self.assertEqual(eof, len(content)) - fin.seek(0, whence=smart_open.constants.WHENCE_END) + fin.seek(0, smart_open.constants.WHENCE_END) self.assertEqual(eof, fin.tell()) def test_read_gzip(self): @@ -613,7 +577,7 @@ def test_read_gzip(self): # # Make sure we're reading things correctly. # - with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin: + with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') as fin: self.assertEqual(fin.read(), buf.getvalue()) # @@ -624,7 +588,7 @@ def test_read_gzip(self): self.assertEqual(zipfile.read(), expected) logger.debug('starting actual test') - with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin: + with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') as fin: with gzip.GzipFile(fileobj=fin) as zipfile: actual = zipfile.read() @@ -634,7 +598,7 @@ def test_readline(self): content = b'englishman\nin\nnew\nyork\n' put_to_bucket(contents=content) - with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin: + with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') as fin: fin.readline() self.assertEqual(fin.tell(), content.index(b'\n')+1) @@ -645,21 +609,11 @@ def test_readline(self): expected = [b'englishman\n', b'in\n', b'new\n', b'york\n'] self.assertEqual(expected, actual) - def test_readline_tiny_buffer(self): - content = b'englishman\nin\nnew\nyork\n' - put_to_bucket(contents=content) - - with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME, buffer_size=8) as fin: - actual = list(fin) - - expected = [b'englishman\n', b'in\n', b'new\n', b'york\n'] - self.assertEqual(expected, actual) - def test_read0_does_not_return_data(self): content = b'englishman\nin\nnew\nyork\n' put_to_bucket(contents=content) - with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin: + with smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, mode='rb') as fin: data = fin.read(0) self.assertEqual(data, b'') @@ -668,7 +622,7 @@ def test_read_past_end(self): content = b'englishman\nin\nnew\nyork\n' put_to_bucket(contents=content) - with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin: + with smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, mode='rb') as fin: data = fin.read(100) self.assertEqual(data, content) @@ -690,7 +644,7 @@ def test_write_01(self): """Does writing into GCS work correctly?""" test_string = u"žluťoučký koníček".encode('utf8') - with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fout: + with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb') as fout: fout.write(test_string) with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), "rb") as fin: @@ -701,7 +655,7 @@ def test_write_01(self): def test_incorrect_input(self): """Does gcs write 
fail on incorrect input?""" try: - with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fin: + with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb') as fin: fin.write(None) except TypeError: pass @@ -710,19 +664,16 @@ def test_incorrect_input(self): def test_write_02(self): """Does gcs write unicode-utf8 conversion work?""" - smart_open_write = smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) + smart_open_write = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb') smart_open_write.tell() logger.info("smart_open_write: %r", smart_open_write) with smart_open_write as fout: fout.write(u"testžížáč".encode("utf-8")) self.assertEqual(fout.tell(), 14) - # def test_write_03(self): - # we no longer need to test if part writes work correctly as thats covered by the google cloud storage functionality - def test_write_04(self): """Does writing no data cause key with an empty value to be created?""" - smart_open_write = smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) + smart_open_write = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb') with smart_open_write as fout: # noqa pass @@ -731,25 +682,27 @@ def test_write_04(self): self.assertEqual(output, []) - def test_write_05(self): - """Do blob_properties get applied?""" - smart_open_write = smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME, - blob_properties={ - "content_type": "random/x-test", - "content_encoding": "coded" - } - ) - with smart_open_write as fout: # noqa - assert fout._blob.content_type == "random/x-test" - assert fout._blob.content_encoding == "coded" + # def test_write_05(self): + # """Do blob_properties get applied?""" + # smart_open_write = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb', + # blob_properties={ + # "content_type": "random/x-test", + # "content_encoding": "coded" + # } + # ) + + # # TODO: Mock + assert calls to set content_type + content_encoding + # with smart_open_write as fout: # noqa + # assert fout.content_type == "random/x-test" + # assert fout.content_encoding == "coded" def test_gzip(self): expected = u'а не спеть ли мне песню... о любви'.encode('utf-8') - with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fout: + with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb') as fout: with gzip.GzipFile(fileobj=fout, mode='w') as zipfile: zipfile.write(expected) - with smart_open.gcs.Reader(BUCKET_NAME, WRITE_BLOB_NAME) as fin: + with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='rb') as fin: with gzip.GzipFile(fileobj=fin) as zipfile: actual = zipfile.read() @@ -762,11 +715,11 @@ def test_buffered_writer_wrapper_works(self): """ expected = u'не думай о секундах свысока' - with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fout: + with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb') as fout: with io.BufferedWriter(fout) as sub_out: sub_out.write(expected.encode('utf-8')) - with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), 'rb') as fin: + with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), mode='rb') as fin: with io.TextIOWrapper(fin, encoding='utf-8') as text: actual = text.read() @@ -817,11 +770,9 @@ def test_read_never_returns_none(self): """read should never return None.""" test_string = u"ветер по морю гуляет..." 
with smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, "wb") as fout: - self.assertEqual(fout.name, BLOB_NAME) fout.write(test_string.encode('utf8')) r = smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, "rb") - self.assertEqual(r.name, BLOB_NAME) self.assertEqual(r.read(), test_string.encode("utf-8")) self.assertEqual(r.read(), b"") self.assertEqual(r.read(), b"") From d7025996dff8697d2d9ec6889b2703194a729609 Mon Sep 17 00:00:00 2001 From: cadnce Date: Mon, 10 Oct 2022 19:58:44 +1100 Subject: [PATCH 3/7] Fix lint issues --- smart_open/gcs.py | 4 ++-- smart_open/tests/test_gcs.py | 23 ++++++++++++----------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/smart_open/gcs.py b/smart_open/gcs.py index 0ef6c5f6..2649e2ec 100644 --- a/smart_open/gcs.py +++ b/smart_open/gcs.py @@ -5,7 +5,6 @@ # This code is distributed under the terms and conditions # from the MIT License (MIT). # - """Implements file-like objects for reading and writing to/from GCS.""" import logging @@ -73,7 +72,8 @@ def open( blob_properties: dict, optional Set properties on blob before writing. For writing only. blob_open_kwargs: dict, optional - Set properties on the blob + Set properties for opening the blob, passed through directly to + the google-cloud-storage library """ if blob_open_kwargs is None: blob_open_kwargs = {} diff --git a/smart_open/tests/test_gcs.py b/smart_open/tests/test_gcs.py index 60e004d3..44442cc6 100644 --- a/smart_open/tests/test_gcs.py +++ b/smart_open/tests/test_gcs.py @@ -10,7 +10,6 @@ import io import logging import os -import pickle import time import uuid import unittest @@ -172,7 +171,7 @@ def upload_from_string(self, data): # https://googleapis.dev/python/storage/latest/blobs.html#google.cloud.storage.blob.Blob.upload_from_string if isinstance(data, str): data = bytes(data, 'utf8') - self.__contents.truncate(0) + self.__contents.truncate(0) self.__contents.seek(0) self.__contents.write(data) self._size = self.__contents.tell() @@ -181,15 +180,15 @@ def write(self, data): self.upload_from_string(data) def open( - self, - mode, - chunk_size=None, - ignore_flush=None, - encoding=None, - errors=None, - newline=None, - **kwargs, - ): + self, + mode, + chunk_size=None, + ignore_flush=None, + encoding=None, + errors=None, + newline=None, + **kwargs, + ): if mode.startswith('r'): self.__contents.seek(0) return self.__contents @@ -361,6 +360,7 @@ def _blob_with_url(url, client): # type: (str, FakeClient) -> FakeBlobUpload return client.uploads.get(url) + if DISABLE_MOCKS: storage_client = google.cloud.storage.Client() else: @@ -788,6 +788,7 @@ def test_round_trip(self): self.assertEqual(test_string, actual) + if __name__ == '__main__': logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) unittest.main() From ebff7cbf30723bef13e332db8fc90106893e126d Mon Sep 17 00:00:00 2001 From: cadnce Date: Thu, 13 Oct 2022 11:57:34 +1100 Subject: [PATCH 4/7] Restore original apis --- smart_open/gcs.py | 107 ++++++++++++++++++++++++++++------- smart_open/tests/test_gcs.py | 80 +++++++++++++------------- 2 files changed, 124 insertions(+), 63 deletions(-) diff --git a/smart_open/gcs.py b/smart_open/gcs.py index 2649e2ec..b663d937 100644 --- a/smart_open/gcs.py +++ b/smart_open/gcs.py @@ -32,6 +32,10 @@ _DEFAULT_WRITE_OPEN_KWARGS = {'ignore_flush': True} +def __noop(): + pass + + def parse_uri(uri_as_string): sr = smart_open.utils.safe_urlsplit(uri_as_string) assert sr.scheme == SCHEME @@ -50,6 +54,7 @@ def open( bucket_id, blob_id, mode, + buffer_size=None, 
min_part_size=_DEFAULT_MIN_PART_SIZE, client=None, # type: google.cloud.storage.Client blob_properties=None, @@ -65,6 +70,8 @@ def open( The name of the blob within the bucket. mode: str The mode for opening the object. Must be either "rb" or "wb". + buffer_size: + deprecated min_part_size: int, optional The minimum part size for multipart uploads. For writing only. client: google.cloud.storage.Client, optional @@ -74,38 +81,96 @@ def open( blob_open_kwargs: dict, optional Set properties for opening the blob, passed through directly to the google-cloud-storage library + """ if blob_open_kwargs is None: blob_open_kwargs = {} - if blob_properties is None: - blob_properties = {} if client is None: client = google.cloud.storage.Client() - bucket = client.bucket(bucket_id) - if not bucket.exists(): - raise google.cloud.exceptions.NotFound(f'bucket {bucket_id} not found') - if mode in (constants.READ_BINARY, 'r', 'rt'): - blob = bucket.get_blob(blob_id) - if blob is None: - raise google.cloud.exceptions.NotFound(f'blob {blob_id} not found in {bucket_id}') + _blob = Reader(bucket=bucket_id, + key=blob_id, + client=client, + blob_open_kwargs=blob_open_kwargs) elif mode in (constants.WRITE_BINARY, 'w', 'wt'): - blob_open_kwargs = {**_DEFAULT_WRITE_OPEN_KWARGS, **blob_open_kwargs} - blob = bucket.blob( - blob_id, - chunk_size=min_part_size, - ) - - for k, v in blob_properties.items(): - try: - setattr(blob, k, v) - except AttributeError: - logger.warn(f'Unable to set property {k} on blob') + _blob = Writer(bucket=bucket_id, + blob=blob_id, + min_part_size=min_part_size, + client=client, + blob_properties=blob_properties, + blob_open_kwargs=blob_open_kwargs, + ) else: raise NotImplementedError(f'GCS support for mode {mode} not implemented') - return blob.open(mode, **blob_open_kwargs) + return _blob + + +def Reader(bucket, + key, + buffer_size=None, + line_terminator=None, + client=None, + blob_open_kwargs=None, + ): + + if blob_open_kwargs is None: + blob_open_kwargs = {} + if client is None: + client = google.cloud.storage.Client() + + bkt = client.bucket(bucket) + blob = bkt.get_blob(key) + + if blob is None: + raise google.cloud.exceptions.NotFound(f'blob {key} not found in {bucket}') + + return blob.open('rb', **blob_open_kwargs) + + +def Writer(bucket, + blob, + min_part_size=None, + client=None, + blob_properties=None, + blob_open_kwargs=None, + ): + + if blob_open_kwargs is None: + blob_open_kwargs = {} + if blob_properties is None: + blob_properties = {} + if client is None: + client = google.cloud.storage.Client() + + blob_open_kwargs = {**_DEFAULT_WRITE_OPEN_KWARGS, **blob_open_kwargs} + + g_bucket = client.bucket(bucket) + if not g_bucket.exists(): + raise google.cloud.exceptions.NotFound(f'bucket {bucket} not found') + + g_blob = g_bucket.blob( + blob, + chunk_size=min_part_size, + ) + + for k, v in blob_properties.items(): + try: + setattr(g_blob, k, v) + except AttributeError: + logger.warn(f'Unable to set property {k} on blob') + + _blob = g_blob.open('wb', **blob_open_kwargs) + + if hasattr(_blob, 'terminate'): + raise RuntimeWarning( + 'Unexpected incompatibility between dependency and google-cloud-storage dependency.' 
+ 'Things may not work as expected' + ) + _blob.terminate = __noop + + return _blob diff --git a/smart_open/tests/test_gcs.py b/smart_open/tests/test_gcs.py index 44442cc6..35ca3497 100644 --- a/smart_open/tests/test_gcs.py +++ b/smart_open/tests/test_gcs.py @@ -54,8 +54,8 @@ def __init__(self, client, name=None): # self.client.register_bucket(self) - def blob(self, blob_id, *args, **kwargs): - return self.blobs.get(blob_id, FakeBlob(blob_id, self, *args, **kwargs)) + def blob(self, blob_id, **kwargs): + return self.blobs.get(blob_id, FakeBlob(blob_id, self, **kwargs)) def delete(self): self.client.delete_bucket(self) @@ -66,7 +66,7 @@ def delete(self): def exists(self): return self._exists - def get_blob(self, blob_id, **kwargs): + def get_blob(self, blob_id): try: return self.blobs[blob_id] except KeyError as e: @@ -133,13 +133,12 @@ def test_list_blobs(self): class FakeBlob(object): - def __init__(self, name, bucket, *args, **kwargs): + def __init__(self, name, bucket, **kwargs): self.name = name self._bucket = bucket # type: FakeBucket self._exists = False self.__contents = io.BytesIO() self.__contents.close = lambda: None - self._size = 0 self._create_if_not_exists() def create_resumable_upload_session(self): @@ -174,7 +173,6 @@ def upload_from_string(self, data): self.__contents.truncate(0) self.__contents.seek(0) self.__contents.write(data) - self._size = self.__contents.tell() def write(self, data): self.upload_from_string(data) @@ -199,7 +197,9 @@ def bucket(self): @property def size(self): - return self._size if self._size > 0 else None + if self.__contents.tell() == 0: + return None + return self.__contents.tell() def _create_if_not_exists(self): self._bucket.register_blob(self) @@ -485,7 +485,7 @@ def test_iter(self): put_to_bucket(contents=expected) # connect to fake GCS and read from the fake key we filled above - fin = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') + fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) output = [line.rstrip(b'\n') for line in fin] self.assertEqual(output, expected.split(b'\n')) @@ -493,7 +493,7 @@ def test_iter_context_manager(self): # same thing but using a context manager expected = u"hello wořld\nhow are you?".encode('utf8') put_to_bucket(contents=expected) - with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') as fin: + with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin: output = [line.rstrip(b'\n') for line in fin] self.assertEqual(output, expected.split(b'\n')) @@ -503,7 +503,7 @@ def test_read(self): put_to_bucket(contents=content) logger.debug('content: %r len: %r', content, len(content)) - fin = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') + fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) self.assertEqual(content[:6], fin.read(6)) self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes self.assertEqual(content[14:], fin.read()) # read the rest @@ -513,7 +513,7 @@ def test_seek_beginning(self): content = u"hello wořld\nhow are you?".encode('utf8') put_to_bucket(contents=content) - fin = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') + fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) self.assertEqual(content[:6], fin.read(6)) self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes @@ -528,7 +528,7 @@ def test_seek_start(self): content = u"hello wořld\nhow are you?".encode('utf8') put_to_bucket(contents=content) - fin = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') + fin = 
smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) seek = fin.seek(6) self.assertEqual(seek, 6) self.assertEqual(fin.tell(), 6) @@ -539,7 +539,7 @@ def test_seek_current(self): content = u"hello wořld\nhow are you?".encode('utf8') put_to_bucket(contents=content) - fin = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') + fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) self.assertEqual(fin.read(5), b'hello') seek = fin.seek(1, smart_open.constants.WHENCE_CURRENT) self.assertEqual(seek, 6) @@ -550,7 +550,7 @@ def test_seek_end(self): content = u"hello wořld\nhow are you?".encode('utf8') put_to_bucket(contents=content) - fin = smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, mode='rb') + fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) seek = fin.seek(-4, smart_open.constants.WHENCE_END) self.assertEqual(seek, len(content) - 4) self.assertEqual(fin.read(), b'you?') @@ -559,7 +559,7 @@ def test_detect_eof(self): content = u"hello wořld\nhow are you?".encode('utf8') put_to_bucket(contents=content) - fin = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') + fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) fin.read() eof = fin.tell() self.assertEqual(eof, len(content)) @@ -577,7 +577,7 @@ def test_read_gzip(self): # # Make sure we're reading things correctly. # - with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') as fin: + with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin: self.assertEqual(fin.read(), buf.getvalue()) # @@ -588,7 +588,7 @@ def test_read_gzip(self): self.assertEqual(zipfile.read(), expected) logger.debug('starting actual test') - with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') as fin: + with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin: with gzip.GzipFile(fileobj=fin) as zipfile: actual = zipfile.read() @@ -598,7 +598,7 @@ def test_readline(self): content = b'englishman\nin\nnew\nyork\n' put_to_bucket(contents=content) - with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=BLOB_NAME, mode='rb') as fin: + with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin: fin.readline() self.assertEqual(fin.tell(), content.index(b'\n')+1) @@ -613,7 +613,7 @@ def test_read0_does_not_return_data(self): content = b'englishman\nin\nnew\nyork\n' put_to_bucket(contents=content) - with smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, mode='rb') as fin: + with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin: data = fin.read(0) self.assertEqual(data, b'') @@ -622,7 +622,7 @@ def test_read_past_end(self): content = b'englishman\nin\nnew\nyork\n' put_to_bucket(contents=content) - with smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, mode='rb') as fin: + with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin: data = fin.read(100) self.assertEqual(data, content) @@ -644,7 +644,7 @@ def test_write_01(self): """Does writing into GCS work correctly?""" test_string = u"žluťoučký koníček".encode('utf8') - with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb') as fout: + with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fout: fout.write(test_string) with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), "rb") as fin: @@ -655,7 +655,7 @@ def test_write_01(self): def test_incorrect_input(self): """Does gcs write fail on incorrect input?""" try: - with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb') as fin: + with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fin: fin.write(None) except TypeError: 
pass @@ -664,7 +664,7 @@ def test_incorrect_input(self): def test_write_02(self): """Does gcs write unicode-utf8 conversion work?""" - smart_open_write = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb') + smart_open_write = smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) smart_open_write.tell() logger.info("smart_open_write: %r", smart_open_write) with smart_open_write as fout: @@ -673,7 +673,7 @@ def test_write_02(self): def test_write_04(self): """Does writing no data cause key with an empty value to be created?""" - smart_open_write = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb') + smart_open_write = smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) with smart_open_write as fout: # noqa pass @@ -682,27 +682,13 @@ def test_write_04(self): self.assertEqual(output, []) - # def test_write_05(self): - # """Do blob_properties get applied?""" - # smart_open_write = smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb', - # blob_properties={ - # "content_type": "random/x-test", - # "content_encoding": "coded" - # } - # ) - - # # TODO: Mock + assert calls to set content_type + content_encoding - # with smart_open_write as fout: # noqa - # assert fout.content_type == "random/x-test" - # assert fout.content_encoding == "coded" - def test_gzip(self): expected = u'а не спеть ли мне песню... о любви'.encode('utf-8') - with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb') as fout: + with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fout: with gzip.GzipFile(fileobj=fout, mode='w') as zipfile: zipfile.write(expected) - with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='rb') as fin: + with smart_open.gcs.Reader(BUCKET_NAME, WRITE_BLOB_NAME) as fin: with gzip.GzipFile(fileobj=fin) as zipfile: actual = zipfile.read() @@ -715,11 +701,11 @@ def test_buffered_writer_wrapper_works(self): """ expected = u'не думай о секундах свысока' - with smart_open.gcs.open(bucket_id=BUCKET_NAME, blob_id=WRITE_BLOB_NAME, mode='wb') as fout: + with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fout: with io.BufferedWriter(fout) as sub_out: sub_out.write(expected.encode('utf-8')) - with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), mode='rb') as fin: + with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), 'rb') as fin: with io.TextIOWrapper(fin, encoding='utf-8') as text: actual = text.read() @@ -757,6 +743,16 @@ def test_flush_close(self): fout.flush() fout.close() + def test_terminate(self): + text = u'там за туманами, вечными, пьяными'.encode('utf-8') + fout = smart_open.gcs.open(BUCKET_NAME, 'key', 'wb') + fout.write(text) + fout.terminate() + + with self.assertRaises(google.api_core.exceptions.NotFound): + with smart_open.gcs.open(BUCKET_NAME, 'key', 'rb') as fin: + fin.read() + @maybe_mock_gcs class OpenTest(unittest.TestCase): From 34a311174273746605c87bb61a8c9f9defa782e1 Mon Sep 17 00:00:00 2001 From: cadnce Date: Thu, 13 Oct 2022 16:27:54 +1100 Subject: [PATCH 5/7] Maintain all original interfaces --- smart_open/gcs.py | 8 +- smart_open/tests/test_gcs.py | 615 +++++------------------------------ 2 files changed, 82 insertions(+), 541 deletions(-) diff --git a/smart_open/gcs.py b/smart_open/gcs.py index b663d937..dffa68cb 100644 --- a/smart_open/gcs.py +++ b/smart_open/gcs.py @@ -32,9 +32,6 @@ _DEFAULT_WRITE_OPEN_KWARGS = {'ignore_flush': True} -def __noop(): - pass - def parse_uri(uri_as_string): sr = 
From 34a311174273746605c87bb61a8c9f9defa782e1 Mon Sep 17 00:00:00 2001
From: cadnce
Date: Thu, 13 Oct 2022 16:27:54 +1100
Subject: [PATCH 5/7] Maintain all original interfaces

---
 smart_open/gcs.py            |   8 +-
 smart_open/tests/test_gcs.py | 615 +++++------------------------------
 2 files changed, 82 insertions(+), 541 deletions(-)

diff --git a/smart_open/gcs.py b/smart_open/gcs.py
index b663d937..dffa68cb 100644
--- a/smart_open/gcs.py
+++ b/smart_open/gcs.py
@@ -32,9 +32,6 @@
 
 _DEFAULT_WRITE_OPEN_KWARGS = {'ignore_flush': True}
 
-def __noop():
-    pass
-
 
 def parse_uri(uri_as_string):
     sr = smart_open.utils.safe_urlsplit(uri_as_string)
@@ -86,9 +83,6 @@ def open(
     if blob_open_kwargs is None:
         blob_open_kwargs = {}
 
-    if client is None:
-        client = google.cloud.storage.Client()
-
     if mode in (constants.READ_BINARY, 'r', 'rt'):
         _blob = Reader(bucket=bucket_id,
                        key=blob_id,
@@ -171,6 +165,6 @@ def Writer(bucket,
             'Unexpected incompatibility between dependency and google-cloud-storage dependency.'
             'Things may not work as expected'
         )
-    _blob.terminate = __noop
+    _blob.terminate = lambda: None
 
     return _blob
diff --git a/smart_open/tests/test_gcs.py b/smart_open/tests/test_gcs.py
index 35ca3497..bf4d6550 100644
--- a/smart_open/tests/test_gcs.py
+++ b/smart_open/tests/test_gcs.py
@@ -5,12 +5,10 @@
 # This code is distributed under the terms and conditions
 # from the MIT License (MIT).
 #
-import gzip
-import inspect
+
 import io
 import logging
 import os
-import time
 import uuid
 import unittest
 from unittest import mock
@@ -28,13 +26,6 @@
 WRITE_BLOB_NAME = 'test-write-blob'
 DISABLE_MOCKS = os.environ.get('SO_DISABLE_GCS_MOCKS') == "1"
 
-RESUMABLE_SESSION_URI_TEMPLATE = (
-    'https://www.googleapis.com/upload/storage/v1/b/'
-    '%(bucket)s'
-    '/o?uploadType=resumable&upload_id='
-    '%(upload_id)s'
-)
-
 logger = logging.getLogger(__name__)
 
 
@@ -141,56 +132,23 @@ def __init__(self, name, bucket, **kwargs):
         self.__contents.close = lambda: None
         self._create_if_not_exists()
 
-    def create_resumable_upload_session(self):
-        resumeable_upload_url = RESUMABLE_SESSION_URI_TEMPLATE % dict(
-            bucket=self._bucket.name,
-            upload_id=str(uuid.uuid4()),
-        )
-        upload = FakeBlobUpload(resumeable_upload_url, self)
-        self._bucket.register_upload(upload)
-        return resumeable_upload_url
+        self.open = mock.Mock(side_effect=self._mock_open)
 
+    def _mock_open(self, mode, *args, **kwargs):
+        if mode.startswith('r'):
+            self.__contents.seek(0)
+        return self.__contents
+
     def delete(self):
         self._bucket.delete_blob(self)
         self._exists = False
 
-    def download_as_bytes(self, start=0, end=None):
-        # mimics Google's API by returning bytes
-        # https://googleapis.dev/python/storage/latest/blobs.html#google.cloud.storage.blob.Blob.download_as_bytes
-        if end is None:
-            end = self.__contents.tell()
-        self.__contents.seek(start)
-        return self.__contents.read(end - start)
-
     def exists(self, client=None):
         return self._exists
 
-    def upload_from_string(self, data):
-        # mimics Google's API by accepting bytes or str, despite the method name
-        # https://googleapis.dev/python/storage/latest/blobs.html#google.cloud.storage.blob.Blob.upload_from_string
-        if isinstance(data, str):
-            data = bytes(data, 'utf8')
-        self.__contents.truncate(0)
-        self.__contents.seek(0)
-        self.__contents.write(data)
-
     def write(self, data):
         self.upload_from_string(data)
 
-    def open(
-        self,
-        mode,
-        chunk_size=None,
-        ignore_flush=None,
-        encoding=None,
-        errors=None,
-        newline=None,
-        **kwargs,
-    ):
-        if mode.startswith('r'):
-            self.__contents.seek(0)
-        return self.__contents
-
     @property
     def bucket(self):
         return self._bucket
@@ -206,52 +164,8 @@ def _create_if_not_exists(self):
         self._exists = True
 
 
-class FakeBlobTest(unittest.TestCase):
-    def setUp(self):
-        self.client = FakeClient()
-        self.bucket = FakeBucket(self.client, 'test-bucket')
-
-    def test_create_resumable_upload_session(self):
-        blob = FakeBlob('fake-blob', self.bucket)
-        resumable_upload_url = blob.create_resumable_upload_session()
-        self.assertTrue(resumable_upload_url in self.client.uploads)
-
-    def test_delete(self):
-        blob = FakeBlob('fake-blob', self.bucket)
-        blob.delete()
-        self.assertFalse(blob.exists())
-        self.assertEqual(self.bucket.list_blobs(), [])
-
-    def test_upload_download(self):
-        blob = FakeBlob('fake-blob', self.bucket)
-        contents = b'test'
-        blob.upload_from_string(contents)
-        self.assertEqual(blob.download_as_bytes(), b'test')
-        self.assertEqual(blob.download_as_bytes(start=2), b'st')
-        self.assertEqual(blob.download_as_bytes(end=2), b'te')
-        self.assertEqual(blob.download_as_bytes(start=2, end=3), b's')
-
-    def test_size(self):
-        blob = FakeBlob('fake-blob', self.bucket)
-        self.assertEqual(blob.size, None)
-        blob.upload_from_string(b'test')
-        self.assertEqual(blob.size, 4)
-
-
-class FakeCredentials(object):
-    def __init__(self, client):
-        self.client = client  # type: FakeClient
-
-    def before_request(self, *args, **kwargs):
-        pass
-
-
 class FakeClient(object):
-    def __init__(self, credentials=None):
-        if credentials is None:
-            credentials = FakeCredentials(self)
-        self._credentials = credentials  # type: FakeCredentials
-        self.uploads = OrderedDict()
+    def __init__(self):
         self.__buckets = OrderedDict()
 
     def bucket(self, bucket_id):
@@ -306,461 +220,38 @@ def test_create_bucket(self):
         self.assertEqual(actual, bucket)
 
 
-class FakeBlobUpload(object):
-    def __init__(self, url, blob):
-        self.url = url
-        self.blob = blob  # type: FakeBlob
-        self._finished = False
-        self.__contents = io.BytesIO()
-
-    def write(self, data):
-        self.__contents.write(data)
-
-    def finish(self):
-        if not self._finished:
-            self.__contents.seek(0)
-            data = self.__contents.read()
-            self.blob.upload_from_string(data)
-            self._finished = True
-
-    def terminate(self):
-        self.blob.delete()
-        self.__contents = None
-
+def get_test_bucket(client):
+    return client.bucket(BUCKET_NAME)
 
-class FakeResponse(object):
-    def __init__(self, status_code=200, text=None):
-        self.status_code = status_code
-        self.text = text
-
-class FakeAuthorizedSession(object):
-    def __init__(self, credentials):
-        self._credentials = credentials  # type: FakeCredentials
-
-    def delete(self, upload_url):
-        upload = self._credentials.client.uploads.pop(upload_url)
-        upload.terminate()
-
-    def put(self, url, data=None, headers=None):
-        upload = self._credentials.client.uploads[url]
-
-        if data is not None:
-            if hasattr(data, 'read'):
-                upload.write(data.read())
-            else:
-                upload.write(data)
-        if not headers.get('Content-Range', '').endswith('*'):
-            upload.finish()
-            return FakeResponse(200)
-        return FakeResponse(308)
-
-    @staticmethod
-    def _blob_with_url(url, client):
-        # type: (str, FakeClient) -> FakeBlobUpload
-        return client.uploads.get(url)
-
-
-if DISABLE_MOCKS:
-    storage_client = google.cloud.storage.Client()
-else:
-    storage_client = FakeClient()
-
-
-def get_bucket():
-    return storage_client.bucket(BUCKET_NAME)
-
-
-def get_blob():
-    bucket = get_bucket()
-    return bucket.blob(BLOB_NAME)
-
-
-def cleanup_bucket():
-    bucket = get_bucket()
+def cleanup_test_bucket(client):
+    bucket = get_test_bucket(client)
     blobs = bucket.list_blobs()
     for blob in blobs:
         blob.delete()
 
 
-def put_to_bucket(contents, num_attempts=12, sleep_time=5):
-    logger.debug('%r', locals())
-
-    #
-    # In real life, it can take a few seconds for the bucket to become ready.
-    # If we try to write to the key while the bucket while it isn't ready, we
-    # will get a StorageError: NotFound.
-    #
-    for attempt in range(num_attempts):
-        try:
-            blob = get_blob()
-            blob.upload_from_string(contents)
-            return
-        except google.cloud.exceptions.NotFound as err:
-            logger.error('caught %r, retrying', err)
-            time.sleep(sleep_time)
-
-    assert False, 'failed to create bucket %s after %d attempts' % (BUCKET_NAME, num_attempts)
-
-
-def mock_gcs(class_or_func):
-    """Mock all methods of a class or a function."""
-    if inspect.isclass(class_or_func):
-        for attr in class_or_func.__dict__:
-            if callable(getattr(class_or_func, attr)):
-                setattr(class_or_func, attr, mock_gcs_func(getattr(class_or_func, attr)))
-        return class_or_func
-    else:
-        return mock_gcs_func(class_or_func)
-
-
-def mock_gcs_func(func):
-    """Mock the function and provide additional required arguments."""
-    assert callable(func), '%r is not a callable function' % func
-
-    def inner(*args, **kwargs):
-        #
-        # Is it a function or a method? The latter requires a self parameter.
-        #
-        signature = inspect.signature(func)
-
-        fake_session = FakeAuthorizedSession(storage_client._credentials)
-        patched_client = mock.patch(
-            'google.cloud.storage.Client',
-            return_value=storage_client,
-        )
-        patched_session = mock.patch(
-            'google.auth.transport.requests.AuthorizedSession',
-            return_value=fake_session,
-        )
-
-        with patched_client, patched_session:
-            if not hasattr(signature, 'self'):
-                return func(*args, **kwargs)
-            else:
-                return func(signature.self, *args, **kwargs)
-
-    return inner
-
-
-def maybe_mock_gcs(func):
-    if DISABLE_MOCKS:
-        return func
-    else:
-        return mock_gcs(func)
-
-
-@maybe_mock_gcs
-def setUpModule():  # noqa
-    """Called once by unittest when initializing this module.  Set up the
-    test GCS bucket.
-    """
-    storage_client.create_bucket(BUCKET_NAME)
-
-
-@maybe_mock_gcs
-def tearDownModule():  # noqa
-    """Called once by unittest when tearing down this module.  Empty and
-    removes the test GCS bucket.
-    """
-    try:
-        bucket = get_bucket()
-        bucket.delete()
-    except google.cloud.exceptions.NotFound:
-        pass
-
-
-@maybe_mock_gcs
-class ReaderTest(unittest.TestCase):
+class OpenTest(unittest.TestCase):
     def setUp(self):
-        ignore_resource_warnings()
-
-    def tearDown(self):
-        cleanup_bucket()
-
-    def test_iter(self):
-        """Are GCS files iterated over correctly?"""
-        expected = u"hello wořld\nhow are you?".encode('utf8')
-        put_to_bucket(contents=expected)
-
-        # connect to fake GCS and read from the fake key we filled above
-        fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
-        output = [line.rstrip(b'\n') for line in fin]
-        self.assertEqual(output, expected.split(b'\n'))
-
-    def test_iter_context_manager(self):
-        # same thing but using a context manager
-        expected = u"hello wořld\nhow are you?".encode('utf8')
-        put_to_bucket(contents=expected)
-        with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin:
-            output = [line.rstrip(b'\n') for line in fin]
-            self.assertEqual(output, expected.split(b'\n'))
-
-    def test_read(self):
-        """Are GCS files read correctly?"""
-        content = u"hello wořld\nhow are you?".encode('utf8')
-        put_to_bucket(contents=content)
-        logger.debug('content: %r len: %r', content, len(content))
-
-        fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
-        self.assertEqual(content[:6], fin.read(6))
-        self.assertEqual(content[6:14], fin.read(8))  # ř is 2 bytes
-        self.assertEqual(content[14:], fin.read())  # read the rest
-
-    def test_seek_beginning(self):
-        """Does seeking to the beginning of GCS files work correctly?"""
-        content = u"hello wořld\nhow are you?".encode('utf8')
-        put_to_bucket(contents=content)
-
-        fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
-        self.assertEqual(content[:6], fin.read(6))
-        self.assertEqual(content[6:14], fin.read(8))  # ř is 2 bytes
-
-        fin.seek(0)
-        self.assertEqual(content, fin.read())  # no size given => read whole file
-
-        fin.seek(0)
-        self.assertEqual(content, fin.read(-1))  # same thing
-
-    def test_seek_start(self):
-        """Does seeking from the start of GCS files work correctly?"""
-        content = u"hello wořld\nhow are you?".encode('utf8')
-        put_to_bucket(contents=content)
-
-        fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
-        seek = fin.seek(6)
-        self.assertEqual(seek, 6)
-        self.assertEqual(fin.tell(), 6)
-        self.assertEqual(fin.read(6), u'wořld'.encode('utf-8'))
-
-    def test_seek_current(self):
-        """Does seeking from the middle of GCS files work correctly?"""
-        content = u"hello wořld\nhow are you?".encode('utf8')
-        put_to_bucket(contents=content)
-
-        fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
-        self.assertEqual(fin.read(5), b'hello')
-        seek = fin.seek(1, smart_open.constants.WHENCE_CURRENT)
-        self.assertEqual(seek, 6)
-        self.assertEqual(fin.read(6), u'wořld'.encode('utf-8'))
-
-    def test_seek_end(self):
-        """Does seeking from the end of GCS files work correctly?"""
-        content = u"hello wořld\nhow are you?".encode('utf8')
-        put_to_bucket(contents=content)
-
-        fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
-        seek = fin.seek(-4, smart_open.constants.WHENCE_END)
-        self.assertEqual(seek, len(content) - 4)
-        self.assertEqual(fin.read(), b'you?')
-
-    def test_detect_eof(self):
-        content = u"hello wořld\nhow are you?".encode('utf8')
-        put_to_bucket(contents=content)
-
-        fin = smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME)
-        fin.read()
-        eof = fin.tell()
-        self.assertEqual(eof, len(content))
-        fin.seek(0, smart_open.constants.WHENCE_END)
-        self.assertEqual(eof, fin.tell())
-
-    def test_read_gzip(self):
-        expected = u'раcцветали яблони и груши, поплыли туманы над рекой...'.encode('utf-8')
-        buf = io.BytesIO()
-        buf.close = lambda: None  # keep buffer open so that we can .getvalue()
-        with gzip.GzipFile(fileobj=buf, mode='w') as zipfile:
-            zipfile.write(expected)
-        put_to_bucket(contents=buf.getvalue())
-
-        #
-        # Make sure we're reading things correctly.
-        #
-        with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin:
-            self.assertEqual(fin.read(), buf.getvalue())
-
-        #
-        # Make sure the buffer we wrote is legitimate gzip.
-        #
-        sanity_buf = io.BytesIO(buf.getvalue())
-        with gzip.GzipFile(fileobj=sanity_buf) as zipfile:
-            self.assertEqual(zipfile.read(), expected)
-
-        logger.debug('starting actual test')
-        with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin:
-            with gzip.GzipFile(fileobj=fin) as zipfile:
-                actual = zipfile.read()
-
-        self.assertEqual(expected, actual)
-
-    def test_readline(self):
-        content = b'englishman\nin\nnew\nyork\n'
-        put_to_bucket(contents=content)
-
-        with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin:
-            fin.readline()
-            self.assertEqual(fin.tell(), content.index(b'\n')+1)
-
-            fin.seek(0)
-            actual = list(fin)
-            self.assertEqual(fin.tell(), len(content))
-
-        expected = [b'englishman\n', b'in\n', b'new\n', b'york\n']
-        self.assertEqual(expected, actual)
-
-    def test_read0_does_not_return_data(self):
-        content = b'englishman\nin\nnew\nyork\n'
-        put_to_bucket(contents=content)
-
-        with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin:
-            data = fin.read(0)
-
-        self.assertEqual(data, b'')
-
-    def test_read_past_end(self):
-        content = b'englishman\nin\nnew\nyork\n'
-        put_to_bucket(contents=content)
-
-        with smart_open.gcs.Reader(BUCKET_NAME, BLOB_NAME) as fin:
-            data = fin.read(100)
-
-        self.assertEqual(data, content)
-
+        if DISABLE_MOCKS:
+            self.client = google.cloud.storage.Client()
+        else:
+            self.client = FakeClient()
+            self.mock_gcs = mock.patch('smart_open.gcs.google.cloud.storage.Client').start()
+            self.mock_gcs.return_value = self.client
 
-
-@maybe_mock_gcs
-class WriterTest(unittest.TestCase):
-    """
-    Test writing into GCS files.
+        self.client.create_bucket(BUCKET_NAME)
 
-    """
-    def setUp(self):
         ignore_resource_warnings()
 
     def tearDown(self):
-        cleanup_bucket()
-
-    def test_write_01(self):
-        """Does writing into GCS work correctly?"""
-        test_string = u"žluťoučký koníček".encode('utf8')
-
-        with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fout:
-            fout.write(test_string)
-
-        with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), "rb") as fin:
-            output = list(fin)
-
-        self.assertEqual(output, [test_string])
-
-    def test_incorrect_input(self):
-        """Does gcs write fail on incorrect input?"""
-        try:
-            with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fin:
-                fin.write(None)
-        except TypeError:
-            pass
-        else:
-            self.fail()
-
-    def test_write_02(self):
-        """Does gcs write unicode-utf8 conversion work?"""
-        smart_open_write = smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME)
-        smart_open_write.tell()
-        logger.info("smart_open_write: %r", smart_open_write)
-        with smart_open_write as fout:
-            fout.write(u"testžížáč".encode("utf-8"))
-            self.assertEqual(fout.tell(), 14)
-
-    def test_write_04(self):
-        """Does writing no data cause key with an empty value to be created?"""
-        smart_open_write = smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME)
-        with smart_open_write as fout:  # noqa
-            pass
-
-        # read back the same key and check its content
-        output = list(smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME)))
-
-        self.assertEqual(output, [])
-
-    def test_gzip(self):
-        expected = u'а не спеть ли мне песню... о любви'.encode('utf-8')
-        with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fout:
-            with gzip.GzipFile(fileobj=fout, mode='w') as zipfile:
-                zipfile.write(expected)
-
-        with smart_open.gcs.Reader(BUCKET_NAME, WRITE_BLOB_NAME) as fin:
-            with gzip.GzipFile(fileobj=fin) as zipfile:
-                actual = zipfile.read()
-
-        self.assertEqual(expected, actual)
-
-    def test_buffered_writer_wrapper_works(self):
-        """
-        Ensure that we can wrap a smart_open gcs stream in a BufferedWriter, which
-        passes a memoryview object to the underlying stream in python >= 2.7
-        """
-        expected = u'не думай о секундах свысока'
-
-        with smart_open.gcs.Writer(BUCKET_NAME, WRITE_BLOB_NAME) as fout:
-            with io.BufferedWriter(fout) as sub_out:
-                sub_out.write(expected.encode('utf-8'))
-
-        with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), 'rb') as fin:
-            with io.TextIOWrapper(fin, encoding='utf-8') as text:
-                actual = text.read()
-
-        self.assertEqual(expected, actual)
-
-    def test_binary_iterator(self):
-        expected = u"выйду ночью в поле с конём".encode('utf-8').split(b' ')
-        put_to_bucket(contents=b"\n".join(expected))
-        with smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, 'rb') as fin:
-            actual = [line.rstrip() for line in fin]
-        self.assertEqual(expected, actual)
-
-    def test_nonexisting_bucket(self):
-        expected = u"выйду ночью в поле с конём".encode('utf-8')
-        with self.assertRaises(google.api_core.exceptions.NotFound):
-            with smart_open.gcs.open('thisbucketdoesntexist', 'mykey', 'wb') as fout:
-                fout.write(expected)
-
-    def test_read_nonexisting_key(self):
-        with self.assertRaises(google.api_core.exceptions.NotFound):
-            with smart_open.gcs.open(BUCKET_NAME, 'my_nonexisting_key', 'rb') as fin:
-                fin.read()
-
-    def test_double_close(self):
-        text = u'там за туманами, вечными, пьяными'.encode('utf-8')
-        fout = smart_open.gcs.open(BUCKET_NAME, 'key', 'wb')
-        fout.write(text)
-        fout.close()
-        fout.close()
-
-    def test_flush_close(self):
-        text = u'там за туманами, вечными, пьяными'.encode('utf-8')
-        fout = smart_open.gcs.open(BUCKET_NAME, 'key', 'wb')
-        fout.write(text)
-        fout.flush()
-        fout.close()
-
-    def test_terminate(self):
-        text = u'там за туманами, вечными, пьяными'.encode('utf-8')
-        fout = smart_open.gcs.open(BUCKET_NAME, 'key', 'wb')
-        fout.write(text)
-        fout.terminate()
-
-        with self.assertRaises(google.api_core.exceptions.NotFound):
-            with smart_open.gcs.open(BUCKET_NAME, 'key', 'rb') as fin:
-                fin.read()
-
-
-@maybe_mock_gcs
-class OpenTest(unittest.TestCase):
-    def setUp(self):
-        ignore_resource_warnings()
+        cleanup_test_bucket(self.client)
+        bucket = get_test_bucket(self.client)
+        bucket.delete()
 
-    def tearDown(self):
-        cleanup_bucket()
+        if not DISABLE_MOCKS:
+            self.mock_gcs.stop()
 
     def test_read_never_returns_none(self):
         """read should never return None."""
@@ -784,6 +275,62 @@ def test_round_trip(self):
 
         self.assertEqual(test_string, actual)
 
+class WriterTest(unittest.TestCase):
+    def setUp(self):
+        self.client = FakeClient()
+        self.mock_gcs = mock.patch('smart_open.gcs.google.cloud.storage.Client').start()
+        self.mock_gcs.return_value = self.client
+
+        self.client.create_bucket(BUCKET_NAME)
+
+    def tearDown(self):
+        cleanup_test_bucket(self.client)
+        bucket = get_test_bucket(self.client)
+        bucket.delete()
+        self.mock_gcs.stop()
+
+    def test_property_passthrough(self):
+        blob_properties = {'content_type': 'text/utf-8'}
+
+        smart_open.gcs.Writer(BUCKET_NAME, BLOB_NAME, blob_properties=blob_properties)
+
+        b = self.client.bucket(BUCKET_NAME).get_blob(BLOB_NAME)
+
+        for k, v in blob_properties.items():
+            self.assertEqual(getattr(b, k), v)
+
+    def test_default_open_kwargs(self):
+        smart_open.gcs.Writer(BUCKET_NAME, BLOB_NAME)
+
+        self.client.bucket(BUCKET_NAME).get_blob(BLOB_NAME) \
+            .open.assert_called_once_with('wb', **smart_open.gcs._DEFAULT_WRITE_OPEN_KWARGS)
+
+    def test_open_kwargs_passthrough(self):
+        open_kwargs = {'ignore_flush': True, 'property': 'value', 'something': 2}
+
+        smart_open.gcs.Writer(BUCKET_NAME, BLOB_NAME, blob_open_kwargs=open_kwargs)
+
+        self.client.bucket(BUCKET_NAME).get_blob(BLOB_NAME) \
+            .open.assert_called_once_with('wb', **open_kwargs)
+
+    def test_non_existing_bucket(self):
+        with self.assertRaises(google.cloud.exceptions.NotFound):
+            smart_open.gcs.Writer('unknown_bucket', BLOB_NAME)
+
+    def test_will_warn_for_conflict(self):
+        # Add a terminate() to simulate it being added to the underlying google-cloud-storage library
+        original_mo = FakeBlob._mock_open
+
+        def fake_open_with_terminate(*args, **kwargs):
+            original_output = original_mo(*args, **kwargs)
+            original_output.terminate = lambda: None
+            return original_output
+
+        FakeBlob._mock_open = fake_open_with_terminate
+
+        with self.assertRaises(RuntimeWarning):
+            smart_open.gcs.Writer(BUCKET_NAME, BLOB_NAME)
+
 
 if __name__ == '__main__':
     logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
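The call shape this patch preserves, as a usage sketch (not part of the patch; the bucket and blob names are illustrative). Writer() now hands back the blob's own writable handle, with the requested blob properties applied up front and a backward-compatible no-op terminate() attached:

    import smart_open.gcs

    with smart_open.gcs.Writer(
        'example-bucket',
        'example-key',
        blob_properties={'content_type': 'text/plain'},
    ) as fout:
        fout.write(b'hello')
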
From a70c64b28066f95f9a595b6af5bc93a3314165f3 Mon Sep 17 00:00:00 2001
From: cadnce
Date: Thu, 13 Oct 2022 16:28:41 +1100
Subject: [PATCH 6/7] Add deprecation warning to unused parameters

---
 smart_open/gcs.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/smart_open/gcs.py b/smart_open/gcs.py
index dffa68cb..d486867f 100644
--- a/smart_open/gcs.py
+++ b/smart_open/gcs.py
@@ -8,6 +8,7 @@
 """Implements file-like objects for reading and writing to/from GCS."""
 
 import logging
+import warnings
 
 try:
     import google.cloud.exceptions
@@ -46,6 +47,9 @@ def open_uri(uri, mode, transport_params):
     kwargs = smart_open.utils.check_kwargs(open, transport_params)
     return open(parsed_uri['bucket_id'], parsed_uri['blob_id'], mode, **kwargs)
 
+def warn_deprecated(parameter_name):
+    message = f"Parameter {parameter_name} is deprecated and no longer has any effect"
+    warnings.warn(message, UserWarning)
 
 def open(
     bucket_id,
@@ -83,6 +87,9 @@ def open(
     if blob_open_kwargs is None:
         blob_open_kwargs = {}
 
+    if buffer_size is not None:
+        warn_deprecated('buffer_size')
+
     if mode in (constants.READ_BINARY, 'r', 'rt'):
         _blob = Reader(bucket=bucket_id,
                        key=blob_id,
@@ -116,6 +123,10 @@ def Reader(bucket,
         blob_open_kwargs = {}
     if client is None:
         client = google.cloud.storage.Client()
+    if buffer_size is not None:
+        warn_deprecated('buffer_size')
+    if line_terminator is not None:
+        warn_deprecated('line_terminator')
 
     bkt = client.bucket(bucket)
     blob = bkt.get_blob(key)

From 12e22b7e2cccb2cb481345d24932de053802ecbd Mon Sep 17 00:00:00 2001
From: cadnce
Date: Thu, 13 Oct 2022 16:45:26 +1100
Subject: [PATCH 7/7] Fix linting

---
 smart_open/gcs.py            | 3 ++-
 smart_open/tests/test_gcs.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/smart_open/gcs.py b/smart_open/gcs.py
index d486867f..260067e7 100644
--- a/smart_open/gcs.py
+++ b/smart_open/gcs.py
@@ -33,7 +33,6 @@
 
 _DEFAULT_WRITE_OPEN_KWARGS = {'ignore_flush': True}
 
-
 def parse_uri(uri_as_string):
     sr = smart_open.utils.safe_urlsplit(uri_as_string)
     assert sr.scheme == SCHEME
@@ -47,10 +46,12 @@ def open_uri(uri, mode, transport_params):
     kwargs = smart_open.utils.check_kwargs(open, transport_params)
     return open(parsed_uri['bucket_id'], parsed_uri['blob_id'], mode, **kwargs)
 
+
 def warn_deprecated(parameter_name):
     message = f"Parameter {parameter_name} is deprecated and no longer has any effect"
     warnings.warn(message, UserWarning)
 
+
 def open(
     bucket_id,
     blob_id,
diff --git a/smart_open/tests/test_gcs.py b/smart_open/tests/test_gcs.py
index bf4d6550..52db1069 100644
--- a/smart_open/tests/test_gcs.py
+++ b/smart_open/tests/test_gcs.py
@@ -138,7 +138,7 @@ def _mock_open(self, mode, *args, **kwargs):
         if mode.startswith('r'):
             self.__contents.seek(0)
         return self.__contents
-    
+
     def delete(self):
         self._bucket.delete_blob(self)
         self._exists = False
@@ -275,6 +275,7 @@ def test_round_trip(self):
 
         self.assertEqual(test_string, actual)
 
+
 class WriterTest(unittest.TestCase):
     def setUp(self):
         self.client = FakeClient()
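With the series applied, the deprecation path can be exercised end to end. This is a sketch, assuming legacy keyword arguments are still forwarded through transport_params and using a hypothetical bucket:

    import warnings

    import smart_open

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        with smart_open.open(
            'gs://example-bucket/example-key',
            'rb',
            transport_params={'buffer_size': 256 * 1024},
        ) as fin:
            fin.read()

    # warn_deprecated() emits a UserWarning naming the ignored parameter.
    assert any('buffer_size' in str(w.message) for w in caught)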