diff --git a/examples/extract/passing_compressed_document.py b/examples/extract/passing_compressed_document.py index 0f7bbe1..80b83c3 100644 --- a/examples/extract/passing_compressed_document.py +++ b/examples/extract/passing_compressed_document.py @@ -19,11 +19,8 @@ content_type='text/html', charset='utf-8', extraction_model='review_list', - is_document_compressed=False, # specify that the sent document is not compressed to compress it document_compression_format=CompressionFormat.GZIP # specify that compression format - # If both is_document_compressed and document_compression_format are ignored, the raw HTML sould be sent - # If is_document_compressed is set to false and CompressionFormat set to GZIP, the SDK will automatically compress the document to gzip - # is_document_compressed is set to false and CompressionFormat set to ZSTD or DEFLATE, the document passed to ExtractionConfig must be manually compressed + # If the body is not compressed, Scrapfly will automatically compress it based on the document_compression_format value ) ) diff --git a/scrapfly/extraction_config.py b/scrapfly/extraction_config.py index 847e456..5aeb922 100644 --- a/scrapfly/extraction_config.py +++ b/scrapfly/extraction_config.py @@ -1,7 +1,7 @@ import json import warnings from enum import Enum -from typing import Optional, Dict +from typing import Optional, Dict, Union from urllib.parse import quote_plus from base64 import urlsafe_b64encode from .api_config import BaseApiConfig @@ -14,6 +14,7 @@ class CompressionFormat(Enum): Attributes: GZIP: gzip format. ZSTD: zstd format. + DEFLATE: deflate format. 
""" GZIP = "gzip" @@ -26,7 +27,7 @@ class ExtractionConfigError(Exception): class ExtractionConfig(BaseApiConfig): - body: str + body: Union[str, bytes] content_type: str url: Optional[str] = None charset: Optional[str] = None @@ -45,7 +46,7 @@ class ExtractionConfig(BaseApiConfig): def __init__( self, - body: str, + body: Union[str, bytes], content_type: str, url: Optional[str] = None, charset: Optional[str] = None, @@ -85,26 +86,51 @@ def __init__( self.extraction_prompt = extraction_prompt self.extraction_model = extraction_model self.is_document_compressed = is_document_compressed - self.document_compression_format = document_compression_format + self.document_compression_format = CompressionFormat(document_compression_format) if document_compression_format else None self.webhook = webhook self.raise_on_upstream_error = raise_on_upstream_error - if self.document_compression_format is not None: - if self.is_document_compressed is None: - raise ExtractionConfigError( - 'When declaring compression format, your must declare the is_document_compressed parameter to compress the document or skip it.' - ) - if self.is_document_compressed is False: - compression_foramt = CompressionFormat(self.document_compression_format).value if self.document_compression_format else None + if isinstance(body, bytes) or document_compression_format: + compression_format = detect_compression_format(body) - if compression_foramt == CompressionFormat.GZIP.value: - import gzip - self.body = gzip.compress(bytes(self.body, 'utf-8')) - else: + if compression_format is not None: + self.is_document_compressed = True + + if self.document_compression_format and compression_format != self.document_compression_format: raise ExtractionConfigError( - f'Auto compression for {compression_foramt} format is not available. ' - f'You can manually compress to {compression_foramt} or choose the gzip format for auto compression.' 
+ f'The detected compression format `{compression_format}` does not match declared format `{self.document_compression_format}`. ' + f'You must pass the compression format or disable compression.' ) + + self.document_compression_format = compression_format + + else: + self.is_document_compressed = False + + if self.is_document_compressed is False: + compression_foramt = CompressionFormat(self.document_compression_format) if self.document_compression_format else None + + if isinstance(self.body, str) and compression_foramt: + self.body = self.body.encode('utf-8') + + if compression_foramt == CompressionFormat.GZIP: + import gzip + self.body = gzip.compress(self.body) + + elif compression_foramt == CompressionFormat.ZSTD: + try: + import zstandard as zstd + except ImportError: + raise ExtractionConfigError( + f'zstandard is not installed. You must run pip install zstandard' + f' to auto compress into zstd or use compression formats.' + ) + self.body = zstd.compress(self.body) + + elif compression_foramt == CompressionFormat.DEFLATE: + import zlib + compressor = zlib.compressobj(wbits=-zlib.MAX_WBITS) # raw deflate compression + self.body = compressor.compress(self.body) + compressor.flush() def to_api_params(self, key: str) -> Dict: params = { @@ -143,17 +169,26 @@ def to_dict(self) -> Dict: """ Export the ExtractionConfig instance to a plain dictionary. 
""" - if self.is_document_compressed is False and self.document_compression_format: - compression_foramt = CompressionFormat(self.document_compression_format).value if self.document_compression_format else None - if compression_foramt == CompressionFormat.GZIP.value: + if self.is_document_compressed is True: + compression_foramt = CompressionFormat(self.document_compression_format) if self.document_compression_format else None + + if compression_foramt == CompressionFormat.GZIP: import gzip - self.body = gzip.decompress(self.body).decode('utf-8') - else: - raise ExtractionConfigError( - f'Auto decompression for {compression_foramt} format is not available. ' - f'You can manually decompress to {compression_foramt} or choose the gzip format for auto decompression.' - ) + self.body = gzip.decompress(self.body) + + elif compression_foramt == CompressionFormat.ZSTD: + import zstandard as zstd + self.body = zstd.decompress(self.body) + + elif compression_foramt == CompressionFormat.DEFLATE: + import zlib + decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS) + self.body = decompressor.decompress(self.body) + decompressor.flush() + + if isinstance(self.body, bytes): + self.body = self.body.decode('utf-8') + self.is_document_compressed = False return { 'body': self.body, @@ -203,3 +238,44 @@ def from_dict(extraction_config_dict: Dict) -> 'ExtractionConfig': webhook=webhook, raise_on_upstream_error=raise_on_upstream_error ) + + +def detect_compression_format(data) -> Optional[CompressionFormat]: + """ + Detects the compression type of the given data. + + Args: + data: The possibly compressed data, as bytes. + + Returns: + The matching CompressionFormat member (GZIP, ZSTD or DEFLATE), or None when no known signature is found. 
+ """ + + if len(data) < 2: + return None + + # gzip: magic bytes 0x1f 0x8b + if data[0] == 0x1f and data[1] == 0x8b: + return CompressionFormat.GZIP + + # zstd: 4-byte frame magic numbers, one per format version + zstd_magic_numbers = [ + b'\x1e\xb5\x2f\xfd', # v0.1 + b'\x22\xb5\x2f\xfd', # v0.2 + b'\x23\xb5\x2f\xfd', # v0.3 + b'\x24\xb5\x2f\xfd', # v0.4 + b'\x25\xb5\x2f\xfd', # v0.5 + b'\x26\xb5\x2f\xfd', # v0.6 + b'\x27\xb5\x2f\xfd', # v0.7 + b'\x28\xb5\x2f\xfd', # v0.8+ (current format) + ] + for magic in zstd_magic_numbers: + if data[:len(magic)] == magic: + return CompressionFormat.ZSTD + + # deflate: zlib header (0x78 followed by a standard FLG byte); raw deflate streams carry no header + if data[0] == 0x78: + if data[1] in (0x01, 0x5E, 0x9C, 0xDA): + return CompressionFormat.DEFLATE + + return None \ No newline at end of file