diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 7752fe075d50c..6a1d5b07a67b6 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -3369,6 +3369,7 @@ def write_bigquery(
         self,
         project_id: str,
         dataset: str,
+        max_retry_cnt: int = 10,
         ray_remote_args: Dict[str, Any] = None,
     ) -> None:
         """Write the dataset to a BigQuery dataset table.
@@ -3397,6 +3398,10 @@ def write_bigquery(
             dataset: The name of the dataset in the format of ``dataset_id.table_id``.
                 The dataset is created if it doesn't already exist. The table_id is
                 overwritten if it exists.
+            max_retry_cnt: The maximum number of times an individual block write
+                is retried due to BigQuery rate limiting errors. This isn't
+                related to Ray fault tolerance retries. The default number of
+                retries is 10.
             ray_remote_args: Kwargs passed to ray.remote in the write tasks.
         """  # noqa: E501
         if ray_remote_args is None:
@@ -3412,7 +3417,7 @@ def write_bigquery(
         else:
             ray_remote_args["max_retries"] = 0

-        datasink = _BigQueryDatasink(project_id, dataset)
+        datasink = _BigQueryDatasink(project_id, dataset, max_retry_cnt=max_retry_cnt)
         self.write_datasink(datasink, ray_remote_args=ray_remote_args)

     @Deprecated
diff --git a/python/ray/data/datasource/bigquery_datasink.py b/python/ray/data/datasource/bigquery_datasink.py
index c0af8a16688cf..61d490c74c517 100644
--- a/python/ray/data/datasource/bigquery_datasink.py
+++ b/python/ray/data/datasource/bigquery_datasink.py
@@ -15,18 +15,24 @@

 logger = logging.getLogger(__name__)

-MAX_RETRY_CNT = 10
+DEFAULT_MAX_RETRY_CNT = 10
 RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11


 class _BigQueryDatasink(Datasink):
-    def __init__(self, project_id: str, dataset: str) -> None:
+    def __init__(
+        self,
+        project_id: str,
+        dataset: str,
+        max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
+    ) -> None:
         _check_import(self, module="google.cloud", package="bigquery")
         _check_import(self, module="google.cloud", package="bigquery_storage")
         _check_import(self, module="google.api_core", package="exceptions")

         self.project_id = project_id
         self.dataset = dataset
+        self.max_retry_cnt = max_retry_cnt

     def on_write_start(self) -> None:
         from google.api_core import exceptions
@@ -71,25 +77,35 @@ def _write_single_block(block: Block, project_id: str, dataset: str) -> None:
                 pq.write_table(block, fp, compression="SNAPPY")

                 retry_cnt = 0
-                while retry_cnt < MAX_RETRY_CNT:
+                while retry_cnt <= self.max_retry_cnt:
                     with open(fp, "rb") as source_file:
                         job = client.load_table_from_file(
                             source_file, dataset, job_config=job_config
                         )
-                    retry_cnt += 1
                     try:
                         logger.info(job.result())
                         break
                     except exceptions.Forbidden as e:
-                        logger.info("Rate limit exceeded... Sleeping to try again")
-                        logger.debug(e)
+                        retry_cnt += 1
+                        if retry_cnt > self.max_retry_cnt:
+                            break
+                        logger.info(
+                            "A block write encountered a rate limit exceeded error"
+                            + f" {retry_cnt} time(s). Sleeping to try again."
+                        )
+                        logger.debug(e)
                         time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)

-                # Raise exception if retry_cnt hits MAX_RETRY_CNT
-                if retry_cnt >= MAX_RETRY_CNT:
+                # Raise exception if retry_cnt exceeds max_retry_cnt
+                if retry_cnt > self.max_retry_cnt:
+                    logger.info(
+                        f"Maximum ({self.max_retry_cnt}) retry count exceeded. Ray"
+                        + " will attempt to retry the block write via fault tolerance."
+                    )
                     raise RuntimeError(
-                        f"Write failed due to {MAX_RETRY_CNT} repeated"
-                        + " API rate limit exceeded responses"
+                        f"Write failed due to {retry_cnt}"
+                        + " repeated API rate limit exceeded responses. Consider"
+                        + " setting the max_retry_cnt kwarg to a higher value."
                     )

         _write_single_block = cached_remote_fn(_write_single_block)
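
For context, a minimal usage sketch (not part of the diff) of how the new `max_retry_cnt` argument is passed through `Dataset.write_bigquery`: the project ID, dataset name, sample data, and retry value below are placeholder assumptions, and running it requires valid GCP credentials plus the `google-cloud-bigquery` and `google-cloud-bigquery-storage` packages.

```python
import ray

# Build a small in-memory dataset; the schema here is purely illustrative.
ds = ray.data.from_items([{"id": i, "value": i * 2} for i in range(100)])

# Retry each individual block write up to 20 times on BigQuery rate-limit
# (Forbidden) errors before raising; this is separate from Ray's task-level
# fault tolerance retries. Project and dataset names are placeholders.
ds.write_bigquery(
    project_id="my-gcp-project",
    dataset="my_dataset.my_table",
    max_retry_cnt=20,
)
```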