diff --git a/README.md b/README.md index 866ab7018..116f152e3 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,30 @@ To run test, first install the `test` extra package pip install darwin-py[test] ``` +### Configuration + +#### Retry Configuration + +The SDK includes a retry mechanism for handling API rate limits (429) and server errors (500, 502, 503, 504). You can configure the retry behavior using the following environment variables: + +- `DARWIN_RETRY_INITIAL_WAIT`: Initial wait time in seconds between retries (default: 60) +- `DARWIN_RETRY_MAX_WAIT`: Maximum wait time in seconds between retries (default: 300) +- `DARWIN_RETRY_MAX_ATTEMPTS`: Maximum number of retry attempts (default: 10) + +Example configuration: +```bash +# Configure shorter retry intervals and fewer attempts +export DARWIN_RETRY_INITIAL_WAIT=30 +export DARWIN_RETRY_MAX_WAIT=120 +export DARWIN_RETRY_MAX_ATTEMPTS=5 +``` + +The retry mechanism will automatically handle: +- Rate limiting (HTTP 429) +- Server errors (HTTP 500, 502, 503, 504) + +For each retry attempt, you'll see a message indicating the type of error and the wait time before the next attempt. + ### Development See our development and QA environment installation recommendations [here](docs/DEV.md) diff --git a/darwin/backend_v2.py b/darwin/backend_v2.py index 69702efcc..8779fa56b 100644 --- a/darwin/backend_v2.py +++ b/darwin/backend_v2.py @@ -1,10 +1,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from urllib import parse -from requests.exceptions import HTTPError -from requests.models import Response -from tenacity import RetryCallState, retry, stop_after_attempt, wait_exponential_jitter - from darwin.datatypes import ItemId @@ -21,19 +17,6 @@ def wrapper(self, *args, **kwargs) -> Callable: return wrapper -def log_rate_limit_exceeded(retry_state: RetryCallState): - wait_time = retry_state.next_action.sleep - print(f"Rate limit exceeded. Retrying in {wait_time:.2f} seconds...") - - -def retry_if_status_code_429(retry_state: RetryCallState): - exception = retry_state.outcome.exception() - if isinstance(exception, HTTPError): - response: Response = exception.response - return response.status_code == 429 - return False - - class BackendV2: def __init__(self, client: "Client", default_team): # noqa F821 self._client = client @@ -255,12 +238,6 @@ def import_annotation( f"v2/teams/{team_slug}/items/{item_id}/import", payload=payload ) - @retry( - wait=wait_exponential_jitter(initial=60, max=300), - stop=stop_after_attempt(10), - retry=retry_if_status_code_429, - before_sleep=log_rate_limit_exceeded, - ) @inject_default_team_slug def register_items(self, payload: Dict[str, Any], team_slug: str) -> None: """ diff --git a/darwin/client.py b/darwin/client.py index febdd9520..0b847d810 100644 --- a/darwin/client.py +++ b/darwin/client.py @@ -1,16 +1,15 @@ import json import logging import os -import time import zlib from logging import Logger from pathlib import Path from typing import Dict, Iterator, List, Optional, Union, cast - +from requests.exceptions import HTTPError import requests from requests import Response from requests.adapters import HTTPAdapter - +from tenacity import RetryCallState, retry, stop_after_attempt, wait_exponential_jitter from darwin.backend_v2 import BackendV2 from darwin.config import Config from darwin.dataset.identifier import DatasetIdentifier @@ -52,6 +51,75 @@ ) from darwin.utils.get_item_count import get_item_count +INITIAL_WAIT = int(os.getenv("DARWIN_RETRY_INITIAL_WAIT", "60")) +MAX_WAIT = int(os.getenv("DARWIN_RETRY_MAX_WAIT", "300")) +MAX_RETRIES = int(os.getenv("DARWIN_RETRY_MAX_ATTEMPTS", "10")) + + +def log_rate_limit_exceeded(retry_state: RetryCallState): + wait_time = retry_state.next_action.sleep + print(f"Rate limit exceeded. Retrying in {wait_time:.2f} seconds...") + + +def retry_if_status_code_429(retry_state: RetryCallState): + exception = retry_state.outcome.exception() + if isinstance(exception, HTTPError): + response: Response = exception.response + return response.status_code == 429 + return False + + +def retry_if_status_code_429_or_5xx(retry_state: RetryCallState) -> bool: + """ + Determines if a request should be retried based on the response status code. + + Retries on: + - Rate limit (429) + - Server errors (500, 502, 503, 504) + + Parameters + ---------- + retry_state : RetryCallState + The current state of the retry mechanism + + Returns + ------- + bool + True if the request should be retried, False otherwise + """ + exception = retry_state.outcome.exception() + if isinstance(exception, HTTPError): + response: Response = exception.response + return response.status_code in { + 429, + 500, + 502, + 503, + 504, + } + return False + + +def log_retry_error(retry_state: RetryCallState) -> None: + """ + Logs information about why a request is being retried. + + Parameters + ---------- + retry_state : RetryCallState + The current state of the retry mechanism + """ + wait_time = retry_state.next_action.sleep + exception = retry_state.outcome.exception() + if isinstance(exception, HTTPError): + response: Response = exception.response + if response.status_code == 429: + print(f"Rate limit exceeded. Retrying in {wait_time:.2f} seconds...") + else: + print( + f"Server error {response.status_code}. Retrying in {wait_time:.2f} seconds..." + ) + class Client: def __init__( @@ -719,9 +787,14 @@ def default_base_url() -> str: return os.getenv("DARWIN_BASE_URL", "https://darwin.v7labs.com") def _get_headers( - self, team_slug: Optional[str] = None, compressed: bool = False + self, + team_slug: Optional[str] = None, + compressed: bool = False, + auth_token: Optional[bool] = False, ) -> Dict[str, str]: headers: Dict[str, str] = {"Content-Type": "application/json"} + if auth_token: + return headers api_key: Optional[str] = None team_config: Optional[Team] = self.config.get_team( @@ -742,15 +815,23 @@ def _get_headers( headers["User-Agent"] = f"darwin-py/{__version__}" return headers + @retry( + wait=wait_exponential_jitter(initial=INITIAL_WAIT, max=MAX_WAIT), + stop=stop_after_attempt(MAX_RETRIES), + retry=retry_if_status_code_429_or_5xx, + before_sleep=log_retry_error, + ) def _get_raw_from_full_url( self, url: str, team_slug: Optional[str] = None, - retry: bool = False, stream: bool = False, + auth_token: Optional[bool] = False, ) -> Response: response: Response = self.session.get( - url, headers=self._get_headers(team_slug), stream=stream + url, + headers=self._get_headers(team_slug, auth_token=auth_token), + stream=stream, ) self.log.debug( @@ -761,40 +842,36 @@ def _get_raw_from_full_url( ) self._raise_if_known_error(response, url) - - if not response.ok and retry: - time.sleep(10) - return self._get_raw_from_full_url( - url=url, team_slug=team_slug, retry=False, stream=stream - ) - response.raise_for_status() - return response def _get_raw( self, endpoint: str, team_slug: Optional[str] = None, - retry: bool = False, stream: bool = False, ) -> Response: return self._get_raw_from_full_url( - urljoin(self.url, endpoint), team_slug, retry=retry, stream=stream + urljoin(self.url, endpoint), team_slug, stream=stream ) def _get( - self, endpoint: str, team_slug: Optional[str] = None, retry: bool = False + self, endpoint: str, team_slug: Optional[str] = None ) -> Union[Dict[str, UnknownType], List[Dict[str, UnknownType]]]: - response = self._get_raw(endpoint, team_slug, retry) + response = self._get_raw(endpoint, team_slug) return self._decode_response(response) + @retry( + wait=wait_exponential_jitter(initial=INITIAL_WAIT, max=MAX_WAIT), + stop=stop_after_attempt(MAX_RETRIES), + retry=retry_if_status_code_429_or_5xx, + before_sleep=log_retry_error, + ) def _put_raw( self, endpoint: str, payload: Dict[str, UnknownType], team_slug: Optional[str] = None, - retry: bool = False, ) -> Response: response: requests.Response = self.session.put( urljoin(self.url, endpoint), @@ -810,13 +887,7 @@ def _put_raw( ) self._raise_if_known_error(response, urljoin(self.url, endpoint)) - - if not response.ok and retry: - time.sleep(10) - return self._put_raw(endpoint, payload=payload, retry=False) - response.raise_for_status() - return response def _put( @@ -824,17 +895,21 @@ def _put( endpoint: str, payload: Dict[str, UnknownType], team_slug: Optional[str] = None, - retry: bool = False, ) -> Union[Dict[str, UnknownType], List[Dict[str, UnknownType]]]: - response: Response = self._put_raw(endpoint, payload, team_slug, retry) + response: Response = self._put_raw(endpoint, payload, team_slug) return self._decode_response(response) + @retry( + wait=wait_exponential_jitter(initial=INITIAL_WAIT, max=MAX_WAIT), + stop=stop_after_attempt(MAX_RETRIES), + retry=retry_if_status_code_429_or_5xx, + before_sleep=log_retry_error, + ) def _post_raw( self, endpoint: str, payload: Optional[Dict[str, UnknownType]] = None, team_slug: Optional[str] = None, - retry: bool = False, ) -> Response: if payload is None: payload = {} @@ -868,13 +943,7 @@ def _post_raw( ) self._raise_if_known_error(response, urljoin(self.url, endpoint)) - - if not response.ok and retry: - time.sleep(10) - return self._post_raw(endpoint, payload=payload, retry=False) - response.raise_for_status() - return response def _post( @@ -882,17 +951,21 @@ def _post( endpoint: str, payload: Optional[Dict[str, UnknownType]] = None, team_slug: Optional[str] = None, - retry: bool = False, ) -> Union[Dict[str, UnknownType], List[Dict[str, UnknownType]]]: - response: Response = self._post_raw(endpoint, payload, team_slug, retry) + response: Response = self._post_raw(endpoint, payload, team_slug) return self._decode_response(response) + @retry( + wait=wait_exponential_jitter(initial=INITIAL_WAIT, max=MAX_WAIT), + stop=stop_after_attempt(MAX_RETRIES), + retry=retry_if_status_code_429_or_5xx, + before_sleep=log_retry_error, + ) def _delete( self, endpoint: str, payload: Optional[Dict[str, UnknownType]] = None, team_slug: Optional[str] = None, - retry: bool = False, ) -> Union[Dict[str, UnknownType], List[Dict[str, UnknownType]]]: if payload is None: payload = {} @@ -911,13 +984,7 @@ def _delete( ) self._raise_if_known_error(response, urljoin(self.url, endpoint)) - - if not response.ok and retry: - time.sleep(10) - return self._delete(endpoint, payload=payload, retry=False) - response.raise_for_status() - return self._decode_response(response) def _raise_if_known_error(self, response: Response, url: str) -> None: diff --git a/darwin/dataset/download_manager.py b/darwin/dataset/download_manager.py index 92ce6448a..e0426629e 100644 --- a/darwin/dataset/download_manager.py +++ b/darwin/dataset/download_manager.py @@ -9,13 +9,21 @@ from collections import Counter from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + TYPE_CHECKING, +) import numpy as np import orjson as json import requests from PIL import Image -from requests.adapters import HTTPAdapter, Retry from rich.console import Console import darwin.datatypes as dt @@ -34,9 +42,12 @@ parse_darwin_json, ) +if TYPE_CHECKING: + from darwin.client import Client + def download_all_images_from_annotations( - api_key: str, + client: "Client", annotations_path: Path, images_path: Path, force_replace: bool = False, @@ -48,7 +59,7 @@ def download_all_images_from_annotations( ignore_slots: bool = False, ) -> Tuple[Callable[[], Iterable[Any]], int]: """ - Downloads the all images corresponding to a project. + Downloads all the images corresponding to a project. Parameters ---------- @@ -146,7 +157,7 @@ def download_all_images_from_annotations( download_functions: List = [] for annotation_path, force_slots in annotations_to_download_path: file_download_functions = lazy_download_image_from_annotation( - api_key, + client, annotation_path, images_path, annotation_format, @@ -164,8 +175,8 @@ def download_all_images_from_annotations( def lazy_download_image_from_annotation( - api_key: str, - annotation_path: Path, + client: "Client", + annotation: AnnotationFile, images_path: Path, annotation_format: str, use_folders: bool, @@ -179,10 +190,10 @@ def lazy_download_image_from_annotation( Parameters ---------- - api_key : str - API Key of the current team - annotation_path : Path - Path where the annotation is located + client : Client + Client of the current team + annotation : AnnotationFile + Annotation file corresponding to the dataset file images_path : Path Path where to download the image annotation_format : str @@ -202,8 +213,8 @@ def lazy_download_image_from_annotation( if annotation_format == "json": return _download_image_from_json_annotation( - api_key, - annotation_path, + client, + annotation, images_path, use_folders, video_frames, @@ -217,7 +228,7 @@ def lazy_download_image_from_annotation( def _download_image_from_json_annotation( - api_key: str, + client: "Client", annotation_path: Path, image_path: Path, use_folders: bool, @@ -239,7 +250,7 @@ def _download_image_from_json_annotation( if ignore_slots: return _download_single_slot_from_json_annotation( annotation, - api_key, + client, parent_path, annotation_path, video_frames, @@ -247,12 +258,12 @@ def _download_image_from_json_annotation( ) if force_slots: return _download_all_slots_from_json_annotation( - annotation, api_key, parent_path, video_frames + annotation, client, parent_path, video_frames ) else: return _download_single_slot_from_json_annotation( annotation, - api_key, + client, parent_path, annotation_path, video_frames, @@ -263,7 +274,10 @@ def _download_image_from_json_annotation( def _download_all_slots_from_json_annotation( - annotation: dt.AnnotationFile, api_key: str, parent_path: Path, video_frames: bool + annotation: dt.AnnotationFile, + client: "Client", + parent_path: Path, + video_frames: bool, ) -> Iterable[Callable[[], None]]: generator = [] for slot in annotation.slots: @@ -280,7 +294,7 @@ def _download_all_slots_from_json_annotation( video_path: Path = slot_path video_path.mkdir(exist_ok=True, parents=True) if not slot.frame_urls: - segment_manifests = get_segment_manifests(slot, slot_path, api_key) + segment_manifests = get_segment_manifests(slot, slot_path, client) for index, manifest in enumerate(segment_manifests): if slot.segments is None: raise ValueError("No segments found") @@ -290,7 +304,7 @@ def _download_all_slots_from_json_annotation( functools.partial( _download_and_extract_video_segment, segment_url, - api_key, + client, path, manifest, ) @@ -300,7 +314,7 @@ def _download_all_slots_from_json_annotation( path = video_path / f"{i:07d}.png" generator.append( functools.partial( - _download_image, frame_url, path, api_key, slot + _download_image, frame_url, path, client, slot ) ) else: @@ -312,7 +326,7 @@ def _download_all_slots_from_json_annotation( annotation, upload.url, file_path, - api_key, + client, ) ) return generator @@ -320,7 +334,7 @@ def _download_all_slots_from_json_annotation( def _download_single_slot_from_json_annotation( annotation: dt.AnnotationFile, - api_key: str, + client: "Client", parent_path: Path, annotation_path: Path, video_frames: bool, @@ -337,7 +351,7 @@ def _download_single_slot_from_json_annotation( # Indicates it's a long video and uses the segment and manifest if not slot.frame_urls: - segment_manifests = get_segment_manifests(slot, video_path, api_key) + segment_manifests = get_segment_manifests(slot, video_path, client) for index, manifest in enumerate(segment_manifests): if slot.segments is None: raise ValueError("No segments found") @@ -347,7 +361,7 @@ def _download_single_slot_from_json_annotation( functools.partial( _download_and_extract_video_segment, segment_url, - api_key, + client, path, manifest, ) @@ -356,7 +370,7 @@ def _download_single_slot_from_json_annotation( for i, frame_url in enumerate(slot.frame_urls): path = video_path / f"{i:07d}.png" generator.append( - functools.partial(_download_image, frame_url, path, api_key, slot) + functools.partial(_download_image, frame_url, path, client, slot) ) else: if len(slot.source_files) > 0: @@ -380,7 +394,7 @@ def _download_single_slot_from_json_annotation( annotation, image_url, image_path, - api_key, + client, ) ) return generator @@ -404,7 +418,7 @@ def _update_local_path(annotation: AnnotationFile, url, local_path): def _download_image( - url: str, path: Path, api_key: str, slot: Optional[dt.Slot] = None + url: str, path: Path, client: "Client", slot: Optional[dt.Slot] = None ) -> None: if path.exists(): return @@ -416,11 +430,11 @@ def _download_image( transform_file_function = _rg16_to_grayscale while True: if "token" in url: - response: requests.Response = requests.get(url, stream=True) - else: - response = requests.get( - url, headers={"Authorization": f"ApiKey {api_key}"}, stream=True + response: requests.Response = client._get_raw_from_full_url( + url, stream=True ) + else: + response = client._get_raw_from_full_url(url, stream=True) # Correct status: download image if response.ok and has_json_content_type(response): # this branch is a workaround for edge case in V1 when video file from external storage could be registered @@ -441,8 +455,8 @@ def _download_image( time.sleep(1) -def _download_image_with_trace(annotation, image_url, image_path, api_key): - _download_image(image_url, image_path, api_key) +def _download_image_with_trace(annotation, image_url, image_path, client): + _download_image(image_url, image_path, client) _update_local_path(annotation, image_url, image_path) @@ -498,9 +512,9 @@ def _rg16_to_grayscale(path): def _download_and_extract_video_segment( - url: str, api_key: str, path: Path, manifest: dt.SegmentManifest + url: str, client: "Client", path: Path, manifest: dt.SegmentManifest ) -> None: - _download_video_segment_file(url, api_key, path) + _download_video_segment_file(url, client, path) _extract_frames_from_segment(path, manifest) path.unlink() @@ -538,17 +552,9 @@ def _extract_frames_from_segment(path: Path, manifest: dt.SegmentManifest) -> No cap.release() -def _download_video_segment_file(url: str, api_key: str, path: Path) -> None: - with requests.Session() as session: - retries = Retry( - total=5, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504] - ) - session.mount("https://", HTTPAdapter(max_retries=retries)) - if "token" in url: - response = session.get(url) - else: - session.headers = {"Authorization": f"ApiKey {api_key}"} - response = session.get(url) +def _download_video_segment_file(url: str, client: "Client", path: Path) -> None: + auth_token = "token" in url + response = client._get_raw_from_full_url(url, stream=True, auth_token=auth_token) if not response.ok or (400 <= response.status_code <= 499): raise Exception( f"Request to ({url}) failed. Status code: {response.status_code}, content:\n{get_response_content(response)}." @@ -559,41 +565,37 @@ def _download_video_segment_file(url: str, api_key: str, path: Path) -> None: file.write(chunk) -def download_manifest_txts(urls: List[str], api_key: str, folder: Path) -> List[Path]: +def download_manifest_txts( + urls: List[str], client: "Client", folder: Path +) -> List[Path]: paths = [] - with requests.Session() as session: - retries = Retry( - total=5, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504] + for index, url in enumerate(urls): + auth_token = "token" in url + response = client._get_raw_from_full_url( + url, stream=True, auth_token=auth_token ) - session.mount("https://", HTTPAdapter(max_retries=retries)) - for index, url in enumerate(urls): - if "token" in url: - response = session.get(url) - else: - session.headers = {"Authorization": f"ApiKey {api_key}"} - response = session.get(url) - if not response.ok or (400 <= response.status_code <= 499): - raise Exception( - f"Request to ({url}) failed. Status code: {response.status_code}, content:\n{get_response_content(response)}." - ) - if not response.content: - raise Exception(f"Manifest file ({url}) is empty.") - path = folder / f"manifest_{index + 1}.txt" - with open(str(path), "wb") as file: - file.write(response.content) - paths.append(path) + if not response.ok or (400 <= response.status_code <= 499): + raise Exception( + f"Request to ({url}) failed. Status code: {response.status_code}, content:\n{get_response_content(response)}." + ) + if not response.content: + raise Exception(f"Manifest file ({url}) is empty.") + path = folder / f"manifest_{index + 1}.txt" + with open(str(path), "wb") as file: + file.write(response.content) + paths.append(path) return paths def get_segment_manifests( - slot: dt.Slot, parent_path: Path, api_key: str + slot: dt.Slot, parent_path: Path, client: "Client" ) -> List[dt.SegmentManifest]: with TemporaryDirectory(dir=parent_path) as tmpdirname: tmpdir = Path(tmpdirname) if slot.frame_manifest is None: raise ValueError("No frame manifest found") frame_urls = [item["url"] for item in slot.frame_manifest] - manifest_paths = download_manifest_txts(frame_urls, api_key, tmpdir) + manifest_paths = download_manifest_txts(frame_urls, client, tmpdir) segment_manifests = _parse_manifests(manifest_paths, slot.name or "0") return segment_manifests diff --git a/darwin/dataset/remote_dataset.py b/darwin/dataset/remote_dataset.py index 303cb884d..e0ded8fc0 100644 --- a/darwin/dataset/remote_dataset.py +++ b/darwin/dataset/remote_dataset.py @@ -39,7 +39,7 @@ is_unix_like_os, make_class_lists, ) -from darwin.datatypes import AnnotationClass, AnnotationFile, ItemId, PathLike, Team +from darwin.datatypes import AnnotationClass, AnnotationFile, ItemId, PathLike from darwin.exceptions import MissingDependency, NotFound, UnsupportedExportFormat from darwin.exporter.formats.darwin import build_image_annotation from darwin.item import DatasetItem @@ -385,15 +385,9 @@ def pull( # No images will be downloaded return None, 0 - team_config: Optional[Team] = self.client.config.get_team(self.team) - if not team_config: - raise ValueError("Unable to get Team configuration.") - - api_key = team_config.api_key - # Create the generator with the download instructions progress, count = download_all_images_from_annotations( - api_key=api_key, + client=self.client, annotations_path=annotations_dir, images_path=self.local_images_path, force_replace=force_replace, diff --git a/tests/darwin/backend_v2_test.py b/tests/darwin/backend_v2_test.py deleted file mode 100644 index dc5ee6c00..000000000 --- a/tests/darwin/backend_v2_test.py +++ /dev/null @@ -1,28 +0,0 @@ -from unittest.mock import Mock, call, patch - -import pytest -from requests.exceptions import HTTPError -from requests.models import Response -from tenacity import RetryError - -from darwin.backend_v2 import BackendV2 - - -class TestBackendV2: - @patch("time.sleep", return_value=None) - def test_register_items_retries_on_429(self, mock_sleep): - mock_client = Mock() - mock_response = Mock(spec=Response) - mock_response.status_code = 429 - mock_client._post_raw.side_effect = HTTPError(response=mock_response) - - backend = BackendV2(mock_client, "team_slug") - - payload = {"key": "value"} - with pytest.raises(RetryError): - backend.register_items(payload) - - assert mock_client._post_raw.call_count == 10 - - expected_call = call("/v2/teams/team_slug/items/register_existing", payload) - assert mock_client._post_raw.call_args_list == [expected_call] * 10 diff --git a/tests/darwin/client_test.py b/tests/darwin/client_test.py index 5b97fb8db..3bb3cc070 100644 --- a/tests/darwin/client_test.py +++ b/tests/darwin/client_test.py @@ -15,6 +15,12 @@ from tests.fixtures import * # noqa: F401, F403 +from unittest.mock import Mock, patch +from requests import Response, HTTPError +from darwin.client import MAX_RETRIES +from tenacity import RetryError + + @pytest.fixture def darwin_client( darwin_config_path: Path, darwin_datasets_path: Path, team_slug_darwin_json_v2: str @@ -557,3 +563,120 @@ def test_returns_list_of_external_storage_connections( assert actual_storages[1].readonly == expected_storage_2.readonly assert actual_storages[1].provider == expected_storage_2.provider assert actual_storages[1].default == expected_storage_2.default + + +class TestClientRetry: + @pytest.fixture + def mock_config(self): + config = Mock(spec=Config) + + # Set up the mock to return different values based on the key + def get_side_effect(key, default=None): + if key == "global/api_endpoint": + return "https://darwin.v7labs.com/api/" + if key == "global/payload_compression_level": + return "0" + return default + + config.get.side_effect = get_side_effect + config.get_team.return_value = Mock(api_key="test-key", slug="test-team") + return config + + @pytest.fixture + def client(self, mock_config): + return Client(config=mock_config, default_team="test-team") + + @patch("time.sleep", return_value=None) + def test_get_retries_on_429(self, mock_sleep, client): + mock_response = Mock(spec=Response) + mock_response.status_code = 429 + mock_response.headers = {} + mock_response.raise_for_status.side_effect = HTTPError(response=mock_response) + + with patch("requests.Session.get") as mock_get: + mock_get.return_value = mock_response + + with pytest.raises(RetryError): + client._get("/test-endpoint") + + assert mock_get.call_count == MAX_RETRIES + + @patch("time.sleep", return_value=None) + def test_post_retries_on_429(self, mock_sleep, client): + mock_response = Mock(spec=Response) + mock_response.status_code = 429 + mock_response.headers = {} + mock_response.raise_for_status.side_effect = HTTPError(response=mock_response) + + with patch("requests.post") as mock_post: + mock_post.return_value = mock_response + + with pytest.raises(RetryError): + client._post("/test-endpoint", {"test": "data"}) + + assert mock_post.call_count == MAX_RETRIES + + @patch("time.sleep", return_value=None) + def test_put_retries_on_429(self, mock_sleep, client): + + mock_response = Mock(spec=Response) + mock_response.status_code = 429 + mock_response.headers = {} + mock_response.raise_for_status.side_effect = HTTPError(response=mock_response) + + with patch("requests.Session.put") as mock_put: + mock_put.return_value = mock_response + + with pytest.raises(RetryError): + client._put("/test-endpoint", {"test": "data"}) + + assert mock_put.call_count == MAX_RETRIES + + @patch("time.sleep", return_value=None) + def test_request_succeeds_after_retries(self, mock_sleep, client): + mock_429_response = Mock(spec=Response) + mock_429_response.status_code = 429 + mock_429_response.headers = {} + + mock_success_response = Mock(spec=Response) + mock_success_response.status_code = 200 + mock_success_response.json.return_value = {"success": True} + mock_success_response.headers = {} + + with patch("requests.Session.get") as mock_get: + mock_get.side_effect = [ + HTTPError(response=mock_429_response), + HTTPError(response=mock_429_response), + mock_success_response, + ] + + result = client._get("/test-endpoint") + + assert result == {"success": True} + assert mock_get.call_count == 3 + + def test_no_retry_on_other_errors(self, client): + mock_response = Mock(spec=Response) + mock_response.status_code = 404 + + with patch("requests.Session.get") as mock_get: + mock_get.side_effect = HTTPError(response=mock_response) + + with pytest.raises(HTTPError): + client._get("/test-endpoint") + + assert mock_get.call_count == 1 + + @patch("time.sleep", return_value=None) + def test_retry_respects_rate_limit_headers(self, mock_sleep, client): + mock_response = Mock(spec=Response) + mock_response.status_code = 429 + + with patch("requests.Session.get") as mock_get: + mock_get.side_effect = HTTPError(response=mock_response) + + with pytest.raises(RetryError): + client._get("/test-endpoint") + + assert mock_get.call_count == MAX_RETRIES + assert mock_sleep.called diff --git a/tests/darwin/dataset/download_manager_test.py b/tests/darwin/dataset/download_manager_test.py index db9e6f315..f72f5f0fc 100644 --- a/tests/darwin/dataset/download_manager_test.py +++ b/tests/darwin/dataset/download_manager_test.py @@ -8,6 +8,8 @@ from darwin.dataset import download_manager as dm from darwin.datatypes import AnnotationClass, AnnotationFile, Slot, SourceFile from tests.fixtures import * +from darwin.client import Client +from darwin.config import Config @pytest.fixture @@ -28,6 +30,23 @@ def slot_w_manifests() -> Slot: ) +@pytest.fixture +def darwin_client( + darwin_config_path: Path, + darwin_datasets_path: Path, + team_slug_darwin_json_v2: str, +) -> Client: + config = Config(darwin_config_path) + config.put(["global", "api_endpoint"], "http://localhost/api") + config.put(["global", "base_url"], "http://localhost") + config.put(["teams", team_slug_darwin_json_v2, "api_key"], "mock_api_key") + config.put( + ["teams", team_slug_darwin_json_v2, "datasets_dir"], + str(darwin_datasets_path), + ) + return Client(config) + + def test_parse_manifests(manifest_paths: List[Path]) -> None: segment_manifests = dm._parse_manifests(manifest_paths, "0") assert len(segment_manifests) == 4 @@ -49,15 +68,21 @@ def test_parse_manifests(manifest_paths: List[Path]) -> None: assert segment_manifests[3].items[1].visibility is True +@pytest.mark.usefixtures("file_read_write_test") def test_get_segment_manifests( - manifest_paths: List[Path], slot_w_manifests: Slot + manifest_paths: List[Path], + slot_w_manifests: Slot, + darwin_client: Client, + file_read_write_test, ) -> None: parent_path = Path("tests/darwin/dataset/data/manifest_examples") files = [open(path, "r").read() for path in manifest_paths] with responses.RequestsMock() as rsps: rsps.add(responses.GET, "http://test.com", body=files[0]) rsps.add(responses.GET, "http://test2.com", body=files[1]) - segment_manifests = dm.get_segment_manifests(slot_w_manifests, parent_path, "") + segment_manifests = dm.get_segment_manifests( + slot_w_manifests, parent_path, darwin_client + ) assert len(segment_manifests) == 4 assert len(segment_manifests[0].items) == 2 assert len(segment_manifests[1].items) == 2 diff --git a/tests/darwin/dataset/remote_dataset_test.py b/tests/darwin/dataset/remote_dataset_test.py index 0fb12d703..ccd65bc87 100644 --- a/tests/darwin/dataset/remote_dataset_test.py +++ b/tests/darwin/dataset/remote_dataset_test.py @@ -1284,8 +1284,13 @@ def test_raises_error_if_timeout_less_than_interval(self, remote_dataset): class TestPullNamingConvention: + @pytest.mark.usefixtures("file_read_write_test") def _test_pull_naming_convention( - self, file_name, use_folders, video_frames, force_slots + self, + file_name, + use_folders, + video_frames, + force_slots, ): with tempfile.TemporaryDirectory() as temp_dir: with zipfile.ZipFile("tests/data.zip") as zfile: @@ -1294,7 +1299,7 @@ def _test_pull_naming_convention( Path(temp_dir) / "v7-darwin-json-v2/pull_naming_tests" / file_name ) download_func = _download_image_from_json_annotation( - api_key="api_key", + client=darwin_client, annotation_path=file_path, image_path=Path("dataset_dir_path"), use_folders=use_folders, @@ -1823,7 +1828,10 @@ def test_omits_unavailable_releases_when_retry_is_false( assert isinstance(releases[0], Release) -def test_force_slots_true(mock_is_file_extension_allowed): +@pytest.mark.usefixtures("file_read_write_test") +def test_force_slots_true( + mock_is_file_extension_allowed, darwin_client, file_read_write_test +): mock_is_file_extension_allowed.return_value = True with tempfile.TemporaryDirectory() as tmp_dir: with zipfile.ZipFile("tests/data.zip") as zfile: @@ -1831,8 +1839,8 @@ def test_force_slots_true(mock_is_file_extension_allowed): annotations_path = Path(tmp_dir) / "v7-darwin-json-v2/force_slots" images_path = Path("images") generator, count = download_all_images_from_annotations( - api_key="api_key", annotations_path=annotations_path, + client=darwin_client, images_path=images_path, force_slots=True, force_replace=True, @@ -1857,7 +1865,10 @@ def test_force_slots_true(mock_is_file_extension_allowed): assert expected_path in planned_paths -def test_force_slots_false(mock_is_file_extension_allowed): +@pytest.mark.usefixtures("file_read_write_test") +def test_force_slots_false( + mock_is_file_extension_allowed, darwin_client, file_read_write_test +): mock_is_file_extension_allowed.return_value = True with tempfile.TemporaryDirectory() as tmp_dir: with zipfile.ZipFile("tests/data.zip") as zfile: @@ -1865,8 +1876,8 @@ def test_force_slots_false(mock_is_file_extension_allowed): annotations_path = Path(tmp_dir) / "v7-darwin-json-v2/force_slots" images_path = Path("images") generator, count = download_all_images_from_annotations( - api_key="api_key", annotations_path=annotations_path, + client=darwin_client, images_path=images_path, force_slots=False, force_replace=True, diff --git a/tests/darwin/dataset/upload_manager_test.py b/tests/darwin/dataset/upload_manager_test.py index 7e01c4251..71a6f5319 100644 --- a/tests/darwin/dataset/upload_manager_test.py +++ b/tests/darwin/dataset/upload_manager_test.py @@ -6,7 +6,7 @@ import inspect -from darwin.client import Client +from darwin.client import Client, MAX_RETRIES from darwin.config import Config from darwin.dataset import RemoteDataset from darwin.dataset.identifier import DatasetIdentifier @@ -199,15 +199,16 @@ def test_error_count_is_correct_on_signature_request( responses.add(responses.GET, sign_upload_endpoint, status=500) local_file = LocalFile(local_path=Path("test.jpg")) - with patch.object(dataset, "fetch_remote_files", return_value=[]): + with patch.object(dataset, "fetch_remote_files", return_value=[]), patch( + "time.sleep", return_value=None + ): upload_handler = UploadHandler.build(dataset, [local_file]) - - upload_handler.upload() - for file_to_upload in upload_handler.progress: - file_to_upload() + upload_handler.upload() + for file_to_upload in upload_handler.progress: + file_to_upload() responses.assert_call_count(request_upload_endpoint, 1) - responses.assert_call_count(sign_upload_endpoint, 1) + responses.assert_call_count(sign_upload_endpoint, MAX_RETRIES) responses.assert_call_count(upload_to_s3_endpoint, 0) responses.assert_call_count(confirm_upload_endpoint, 0)