diff --git a/examples/extract/passing_compressed_document.py b/examples/extract/passing_compressed_document.py
index 0f7bbe1..80b83c3 100644
--- a/examples/extract/passing_compressed_document.py
+++ b/examples/extract/passing_compressed_document.py
@@ -19,11 +19,8 @@
         content_type='text/html',
         charset='utf-8',
         extraction_model='review_list',
-        is_document_compressed=False, # specify that the sent document is not compressed to compress it
         document_compression_format=CompressionFormat.GZIP # specify that compression format
-        # If both is_document_compressed and document_compression_format are ignored, the raw HTML sould be sent
-        # If is_document_compressed is set to false and CompressionFormat set to GZIP, the SDK will automatically compress the document to gzip
-        # is_document_compressed is set to false and CompressionFormat set to ZSTD or DEFLATE, the document passed to ExtractionConfig must be manually compressed
+        # If the body is not compressed, the SDK will automatically compress it based on the document_compression_format value
     )
 )
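For orientation, here is a minimal sketch of the two compression paths this example now relies on. The HTML body is a placeholder and `review_list` is the extraction model used above; the behavior follows the `extraction_config.py` changes below.

```python
import gzip

from scrapfly.extraction_config import CompressionFormat, ExtractionConfig

html = '<html><body>product reviews...</body></html>'

# Uncompressed str body + declared format: the SDK encodes and gzips it.
auto = ExtractionConfig(
    body=html,
    content_type='text/html',
    extraction_model='review_list',
    document_compression_format=CompressionFormat.GZIP,
)

# Pre-compressed bytes body: the format is detected from the magic bytes,
# so no declaration is needed.
manual = ExtractionConfig(
    body=gzip.compress(html.encode('utf-8')),
    content_type='text/html',
    extraction_model='review_list',
)

assert manual.is_document_compressed is True
assert manual.document_compression_format is CompressionFormat.GZIP
```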
diff --git a/scrapfly/extraction_config.py b/scrapfly/extraction_config.py
index 21e2d60..5aeb922 100644
--- a/scrapfly/extraction_config.py
+++ b/scrapfly/extraction_config.py
@@ -1,7 +1,7 @@
 import json
 import warnings
 from enum import Enum
-from typing import Optional, Dict
+from typing import Optional, Dict, Union
 from urllib.parse import quote_plus
 from base64 import urlsafe_b64encode
 from .api_config import BaseApiConfig
@@ -14,6 +14,7 @@ class CompressionFormat(Enum):
     Attributes:
         GZIP: gzip format.
         ZSTD: zstd format.
+        DEFLATE: deflate format.
     """
 
     GZIP = "gzip"
@@ -26,7 +27,7 @@ class ExtractionConfigError(Exception):
 
 
 class ExtractionConfig(BaseApiConfig):
-    body: str
+    body: Union[str, bytes]
     content_type: str
     url: Optional[str] = None
    charset: Optional[str] = None
@@ -45,7 +46,7 @@ class ExtractionConfig(BaseApiConfig):
 
     def __init__(
         self,
-        body: str,
+        body: Union[str, bytes],
         content_type: str,
         url: Optional[str] = None,
         charset: Optional[str] = None,
@@ -85,23 +86,51 @@ def __init__(
         self.extraction_prompt = extraction_prompt
         self.extraction_model = extraction_model
         self.is_document_compressed = is_document_compressed
-        self.document_compression_format = document_compression_format
+        self.document_compression_format = CompressionFormat(document_compression_format) if document_compression_format else None
         self.webhook = webhook
         self.raise_on_upstream_error = raise_on_upstream_error
 
-        if self.document_compression_format is not None:
-            if self.is_document_compressed is None:
-                raise ExtractionConfigError(
-                    'When declaring compression format, your must declare the is_document_compressed parameter to compress the document or skip it.'
-                )
-            if self.is_document_compressed is False:
-                if self.document_compression_format == CompressionFormat.GZIP:
-                    import gzip
-                    self.body = gzip.compress(bytes(self.body, 'utf-8'))
-                else:
+        if isinstance(body, bytes) or document_compression_format:
+            compression_format = detect_compression_format(body)
+
+            if compression_format is not None:
+                self.is_document_compressed = True
+
+                if self.document_compression_format and compression_format != self.document_compression_format:
                     raise ExtractionConfigError(
-                        f'Auto compression for {self.document_compression_format.value} format is not available. You can manually compress to {self.document_compression_format.value} or choose the gzip format for auto compression.'
+                        f'The detected compression format `{compression_format}` does not match the declared format `{self.document_compression_format}`. '
+                        f'Declare the matching compression format, or omit document_compression_format to use the detected one.'
                     )
+
+                self.document_compression_format = compression_format
+
+            else:
+                self.is_document_compressed = False
+
+        if self.is_document_compressed is False:
+            compression_format = CompressionFormat(self.document_compression_format) if self.document_compression_format else None
+
+            if isinstance(self.body, str) and compression_format:
+                self.body = self.body.encode('utf-8')
+
+            if compression_format == CompressionFormat.GZIP:
+                import gzip
+                self.body = gzip.compress(self.body)
+
+            elif compression_format == CompressionFormat.ZSTD:
+                try:
+                    import zstandard as zstd
+                except ImportError:
+                    raise ExtractionConfigError(
+                        'zstandard is not installed. You must run `pip install zstandard` '
+                        'to auto-compress into zstd, or pass an already zstd-compressed document.'
+                    )
+                self.body = zstd.compress(self.body)
+
+            elif compression_format == CompressionFormat.DEFLATE:
+                import zlib
+                compressor = zlib.compressobj(wbits=-zlib.MAX_WBITS)  # raw deflate compression
+                self.body = compressor.compress(self.body) + compressor.flush()
 
     def to_api_params(self, key: str) -> Dict:
         params = {
@@ -135,3 +164,118 @@ def to_api_params(self, key: str) -> Dict:
             params['webhook_name'] = self.webhook
 
         return params
+
+    def to_dict(self) -> Dict:
+        """
+        Export the ExtractionConfig instance to a plain dictionary.
+        """
+
+        if self.is_document_compressed is True:
+            compression_format = CompressionFormat(self.document_compression_format) if self.document_compression_format else None
+
+            if compression_format == CompressionFormat.GZIP:
+                import gzip
+                self.body = gzip.decompress(self.body)
+
+            elif compression_format == CompressionFormat.ZSTD:
+                import zstandard as zstd
+                self.body = zstd.decompress(self.body)
+
+            elif compression_format == CompressionFormat.DEFLATE:
+                import zlib
+                decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
+                self.body = decompressor.decompress(self.body) + decompressor.flush()
+
+            if isinstance(self.body, bytes):
+                self.body = self.body.decode('utf-8')
+
+            self.is_document_compressed = False
+
+        return {
+            'body': self.body,
+            'content_type': self.content_type,
+            'url': self.url,
+            'charset': self.charset,
+            'extraction_template': self.extraction_template,
+            'extraction_ephemeral_template': self.extraction_ephemeral_template,
+            'extraction_prompt': self.extraction_prompt,
+            'extraction_model': self.extraction_model,
+            'is_document_compressed': self.is_document_compressed,
+            'document_compression_format': CompressionFormat(self.document_compression_format).value if self.document_compression_format else None,
+            'webhook': self.webhook,
+            'raise_on_upstream_error': self.raise_on_upstream_error,
+        }
+
+    @staticmethod
+    def from_dict(extraction_config_dict: Dict) -> 'ExtractionConfig':
+        """Create an ExtractionConfig instance from a dictionary."""
+        body = extraction_config_dict.get('body', None)
+        content_type = extraction_config_dict.get('content_type', None)
+        url = extraction_config_dict.get('url', None)
+        charset = extraction_config_dict.get('charset', None)
+        extraction_template = extraction_config_dict.get('extraction_template', None)
+        extraction_ephemeral_template = extraction_config_dict.get('extraction_ephemeral_template', None)
+        extraction_prompt = extraction_config_dict.get('extraction_prompt', None)
+        extraction_model = extraction_config_dict.get('extraction_model', None)
+        is_document_compressed = extraction_config_dict.get('is_document_compressed', None)
+
+        document_compression_format = extraction_config_dict.get('document_compression_format', None)
+        document_compression_format = CompressionFormat(document_compression_format) if document_compression_format else None
+
+        webhook = extraction_config_dict.get('webhook', None)
+        raise_on_upstream_error = extraction_config_dict.get('raise_on_upstream_error', True)
+
+        return ExtractionConfig(
+            body=body,
+            content_type=content_type,
+            url=url,
+            charset=charset,
+            extraction_template=extraction_template,
+            extraction_ephemeral_template=extraction_ephemeral_template,
+            extraction_prompt=extraction_prompt,
+            extraction_model=extraction_model,
+            is_document_compressed=is_document_compressed,
+            document_compression_format=document_compression_format,
+            webhook=webhook,
+            raise_on_upstream_error=raise_on_upstream_error
+        )
+
+
+def detect_compression_format(data) -> Optional[CompressionFormat]:
+    """
+    Detect the compression format of the given data from its magic bytes.
+
+    Args:
+        data: The (possibly compressed) document, as bytes.
+
+    Returns:
+        The matching CompressionFormat member, or None if no known format is detected.
+    """
+
+    if len(data) < 2:
+        return None
+
+    # gzip
+    if data[0] == 0x1f and data[1] == 0x8b:
+        return CompressionFormat.GZIP
+
+    # zstd
+    zstd_magic_numbers = [
+        b'\x1e\xb5\x2f\xfd',  # v0.1
+        b'\x22\xb5\x2f\xfd',  # v0.2
+        b'\x23\xb5\x2f\xfd',  # v0.3
+        b'\x24\xb5\x2f\xfd',  # v0.4
+        b'\x25\xb5\x2f\xfd',  # v0.5
+        b'\x26\xb5\x2f\xfd',  # v0.6
+        b'\x27\xb5\x2f\xfd',  # v0.7
+        b'\x28\xb5\x2f\xfd',  # v0.8
+    ]
+    for magic in zstd_magic_numbers:
+        if data[:len(magic)] == magic:
+            return CompressionFormat.ZSTD
+
+    # deflate (zlib header)
+    if data[0] == 0x78 and data[1] in (0x01, 0x5E, 0x9C, 0xDA):
+        return CompressionFormat.DEFLATE
+
+    return None
\ No newline at end of file
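A short sketch of how the new detection helper and the dict round trip behave, per the logic above; the gzip payload is illustrative:

```python
import gzip

from scrapfly.extraction_config import (
    CompressionFormat,
    ExtractionConfig,
    detect_compression_format,
)

raw = b'<html><body>hello</body></html>'

assert detect_compression_format(gzip.compress(raw)) is CompressionFormat.GZIP
assert detect_compression_format(raw) is None  # no known magic bytes

config = ExtractionConfig(body=gzip.compress(raw), content_type='text/html')

# to_dict() decompresses and decodes the body so the dict is JSON-safe.
exported = config.to_dict()
assert exported['body'] == raw.decode('utf-8')
assert exported['document_compression_format'] == 'gzip'

# from_dict() rebuilds the config; because the dict still declares gzip,
# the constructor re-compresses the body automatically.
restored = ExtractionConfig.from_dict(exported)
assert detect_compression_format(restored.body) is CompressionFormat.GZIP
```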
+ """ + + return { + 'url': self.url, + 'retry': self.retry, + 'method': self.method, + 'country': self.country, + 'render_js': self.render_js, + 'cache': self.cache, + 'cache_clear': self.cache_clear, + 'ssl': self.ssl, + 'dns': self.dns, + 'asp': self.asp, + 'debug': self.debug, + 'raise_on_upstream_error': self.raise_on_upstream_error, + 'cache_ttl': self.cache_ttl, + 'proxy_pool': self.proxy_pool, + 'session': self.session, + 'tags': list(self.tags), + 'format': Format(self.format).value if self.format else None, + 'format_options': [FormatOption(option).value for option in self.format_options] if self.format_options else None, + 'extraction_template': self.extraction_template, + 'extraction_ephemeral_template': self.extraction_ephemeral_template, + 'extraction_prompt': self.extraction_prompt, + 'extraction_model': self.extraction_model, + 'correlation_id': self.correlation_id, + 'cookies': CaseInsensitiveDict(self.cookies), + 'body': self.body, + 'data': None if self.body else self.data, + 'headers': CaseInsensitiveDict(self.headers), + 'js': self.js, + 'rendering_wait': self.rendering_wait, + 'wait_for_selector': self.wait_for_selector, + 'session_sticky_proxy': self.session_sticky_proxy, + 'screenshots': self.screenshots, + 'screenshot_flags': [ScreenshotFlag(flag).value for flag in self.screenshot_flags] if self.screenshot_flags else None, + 'webhook': self.webhook, + 'timeout': self.timeout, + 'js_scenario': self.js_scenario, + 'extract': self.extract, + 'lang': self.lang, + 'os': self.os, + 'auto_scroll': self.auto_scroll, + 'cost_budget': self.cost_budget, + } + + @staticmethod + def from_dict(scrape_config_dict: Dict) -> 'ScrapeConfig': + """Create a ScrapeConfig instance from a dictionary.""" + url = scrape_config_dict.get('url', None) + retry = scrape_config_dict.get('retry', False) + method = scrape_config_dict.get('method', 'GET') + country = scrape_config_dict.get('country', None) + render_js = scrape_config_dict.get('render_js', False) + cache = scrape_config_dict.get('cache', False) + cache_clear = scrape_config_dict.get('cache_clear', False) + ssl = scrape_config_dict.get('ssl', False) + dns = scrape_config_dict.get('dns', False) + asp = scrape_config_dict.get('asp', False) + debug = scrape_config_dict.get('debug', False) + raise_on_upstream_error = scrape_config_dict.get('raise_on_upstream_error', True) + cache_ttl = scrape_config_dict.get('cache_ttl', None) + proxy_pool = scrape_config_dict.get('proxy_pool', None) + session = scrape_config_dict.get('session', None) + tags = scrape_config_dict.get('tags', []) + + format = scrape_config_dict.get('format', None) + format = Format(format) if format else None + + format_options = scrape_config_dict.get('format_options', None) + format_options = [FormatOption(option) for option in format_options] if format_options else None + + extraction_template = scrape_config_dict.get('extraction_template', None) + extraction_ephemeral_template = scrape_config_dict.get('extraction_ephemeral_template', None) + extraction_prompt = scrape_config_dict.get('extraction_prompt', None) + extraction_model = scrape_config_dict.get('extraction_model', None) + correlation_id = scrape_config_dict.get('correlation_id', None) + cookies = scrape_config_dict.get('cookies', {}) + body = scrape_config_dict.get('body', None) + data = scrape_config_dict.get('data', None) + headers = scrape_config_dict.get('headers', {}) + js = scrape_config_dict.get('js', None) + rendering_wait = scrape_config_dict.get('rendering_wait', None) + wait_for_selector = 
+        screenshots = scrape_config_dict.get('screenshots', [])
+
+        screenshot_flags = scrape_config_dict.get('screenshot_flags', [])
+        screenshot_flags = [ScreenshotFlag(flag) for flag in screenshot_flags] if screenshot_flags else None
+
+        session_sticky_proxy = scrape_config_dict.get('session_sticky_proxy', False)
+        webhook = scrape_config_dict.get('webhook', None)
+        timeout = scrape_config_dict.get('timeout', None)
+        js_scenario = scrape_config_dict.get('js_scenario', None)
+        extract = scrape_config_dict.get('extract', None)
+        os = scrape_config_dict.get('os', None)
+        lang = scrape_config_dict.get('lang', None)
+        auto_scroll = scrape_config_dict.get('auto_scroll', None)
+        cost_budget = scrape_config_dict.get('cost_budget', None)
+
+        return ScrapeConfig(
+            url=url,
+            retry=retry,
+            method=method,
+            country=country,
+            render_js=render_js,
+            cache=cache,
+            cache_clear=cache_clear,
+            ssl=ssl,
+            dns=dns,
+            asp=asp,
+            debug=debug,
+            raise_on_upstream_error=raise_on_upstream_error,
+            cache_ttl=cache_ttl,
+            proxy_pool=proxy_pool,
+            session=session,
+            tags=tags,
+            format=format,
+            format_options=format_options,
+            extraction_template=extraction_template,
+            extraction_ephemeral_template=extraction_ephemeral_template,
+            extraction_prompt=extraction_prompt,
+            extraction_model=extraction_model,
+            correlation_id=correlation_id,
+            cookies=cookies,
+            body=body,
+            data=data,
+            headers=headers,
+            js=js,
+            rendering_wait=rendering_wait,
+            wait_for_selector=wait_for_selector,
+            screenshots=screenshots,
+            screenshot_flags=screenshot_flags,
+            session_sticky_proxy=session_sticky_proxy,
+            webhook=webhook,
+            timeout=timeout,
+            js_scenario=js_scenario,
+            extract=extract,
+            os=os,
+            lang=lang,
+            auto_scroll=auto_scroll,
+            cost_budget=cost_budget,
+        )
diff --git a/scrapfly/scrapy/request.py b/scrapfly/scrapy/request.py
index 4725974..9afcf66 100644
--- a/scrapfly/scrapy/request.py
+++ b/scrapfly/scrapy/request.py
@@ -1,6 +1,5 @@
 from copy import deepcopy
-from functools import partial
-from typing import Dict, Optional, List
+from typing import Dict, Optional
 
 from scrapy import Request
 
@@ -32,6 +31,18 @@ def __init__(self, scrape_config:ScrapeConfig, meta:Dict={}, *args, **kwargs):
             **kwargs
         )
 
+    def to_dict(self, *, spider: Optional["scrapy.Spider"] = None) -> dict:
+        if spider is None:
+            raise ValueError("The 'spider' argument is required to serialize the request.")
+        return super().to_dict(spider=spider)
+
+    @classmethod
+    def from_dict(cls, data):
+        scrape_config_data = data['meta']['scrapfly_scrape_config'].to_dict()
+        scrape_config = ScrapeConfig.from_dict(scrape_config_data)
+        request = cls(scrape_config=scrape_config)
+        return request
+
     def replace(self, *args, **kwargs):
         for x in [
             'meta',
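The scrapy hooks above lean on the `ScrapeConfig` dict round trip added in `scrape_config.py`. A minimal sketch, with a placeholder URL and options:

```python
from scrapfly.scrape_config import ScrapeConfig

config = ScrapeConfig(
    url='https://httpbin.dev/html',
    render_js=True,
    country='us',
    tags=['example'],
)

# Enum-valued fields are exported as plain values, so the dict can be
# persisted by scrapy's disk queues or serialized externally.
data = config.to_dict()

restored = ScrapeConfig.from_dict(data)
assert restored.url == config.url
assert restored.render_js is True
assert restored.country == 'us'
```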
+ """ + return { + 'url': self.url, + 'format': Format(self.format).value if self.format else None, + 'capture': self.capture, + 'resolution': self.resolution, + 'country': self.country, + 'timeout': self.timeout, + 'rendering_wait': self.rendering_wait, + 'wait_for_selector': self.wait_for_selector, + 'options': [Options(option).value for option in self.options] if self.options else None, + 'auto_scroll': self.auto_scroll, + 'js': self.js, + 'cache': self.cache, + 'cache_ttl': self.cache_ttl, + 'cache_clear': self.cache_clear, + 'webhook': self.webhook, + 'raise_on_upstream_error': self.raise_on_upstream_error + } + + @staticmethod + def from_dict(screenshot_config_dict: Dict) -> 'ScreenshotConfig': + """Create a ScreenshotConfig instance from a dictionary.""" + url = screenshot_config_dict.get('url', None) + + format = screenshot_config_dict.get('format', None) + format = Format(format) if format else None + + capture = screenshot_config_dict.get('capture', None) + resolution = screenshot_config_dict.get('resolution', None) + country = screenshot_config_dict.get('country', None) + timeout = screenshot_config_dict.get('timeout', None) + rendering_wait = screenshot_config_dict.get('rendering_wait', None) + wait_for_selector = screenshot_config_dict.get('wait_for_selector', None) + + options = screenshot_config_dict.get('options', None) + options = [Options(option) for option in options] if options else None + + auto_scroll = screenshot_config_dict.get('auto_scroll', None) + js = screenshot_config_dict.get('js', None) + cache = screenshot_config_dict.get('cache', None) + cache_ttl = screenshot_config_dict.get('cache_ttl', None) + cache_clear = screenshot_config_dict.get('cache_clear', None) + webhook = screenshot_config_dict.get('webhook', None) + raise_on_upstream_error = screenshot_config_dict.get('raise_on_upstream_error', True) + + return ScreenshotConfig( + url=url, + format=format, + capture=capture, + resolution=resolution, + country=country, + timeout=timeout, + rendering_wait=rendering_wait, + wait_for_selector=wait_for_selector, + options=options, + auto_scroll=auto_scroll, + js=js, + cache=cache, + cache_ttl=cache_ttl, + cache_clear=cache_clear, + webhook=webhook, + raise_on_upstream_error=raise_on_upstream_error + )