diff --git a/pyproject.toml b/pyproject.toml index d1adf9a..2914683 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "tinybird-python-sdk" -version = "0.3.4" +version = "0.3.5" description = "Python SDK for Tinybird" readme = "README.md" authors = [ diff --git a/tb/a/api.py b/tb/a/api.py index c871be2..52f6d91 100644 --- a/tb/a/api.py +++ b/tb/a/api.py @@ -35,6 +35,8 @@ def __init__( self.retry_total = retry_total self.token_error = TOKEN_ERROR + self._shutdown = False + self._pending_requests = 0 def ui_url(self) -> str: return self.api_url.replace("api", "ui") @@ -48,8 +50,27 @@ async def _get_session(self) -> aiohttp.ClientSession: async def close(self) -> None: """Close the aiohttp session.""" - if self._session and not self._session.closed: - await self._session.close() + try: + # Set shutdown flag to prevent new operations + self._shutdown = True + + # Wait for any pending requests to complete + if self._pending_requests > 0: + logging.info( + f"Waiting for {self._pending_requests} pending requests to complete..." + ) + # Give a short time for requests to complete + for _ in range(5): # Try for up to 5 seconds + if self._pending_requests == 0: + break + await asyncio.sleep(1) + + # Close the session + if self._session and not self._session.closed: + await self._session.close() + except asyncio.CancelledError: + # If we're cancelled during close, just propagate the error + raise async def __aenter__(self): """Support for async context manager.""" @@ -61,84 +82,127 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def _handle_rate_limit(self) -> None: """Handle rate limiting by waiting if necessary.""" - if self.rate_limit_remaining == 0: - time_to_sleep = min((self.rate_limit_reset - time.time()), 10) - time_to_sleep = max(time_to_sleep, 1) + 1 - logging.info(f"Waiting {str(time_to_sleep)} seconds before retrying...") - await asyncio.sleep(time_to_sleep) - logging.info("Retrying") + try: + if self.rate_limit_remaining == 0: + time_to_sleep = min((self.rate_limit_reset - time.time()), 10) + time_to_sleep = max(time_to_sleep, 1) + 1 + logging.info(f"Waiting {str(time_to_sleep)} seconds before retrying...") + await asyncio.sleep(time_to_sleep) + logging.info("Retrying") + except asyncio.CancelledError: + self._shutdown = True + raise def _set_rate_limit(self, response: aiohttp.ClientResponse) -> None: """Update rate limit information from response headers.""" - headers = response.headers - if "X-Ratelimit-Limit" in headers: - self.rate_limit_points = int(headers.get("X-Ratelimit-Limit")) - self.rate_limit_remaining = int(headers.get("X-Ratelimit-Remaining")) - self.rate_limit_reset = int(headers.get("X-Ratelimit-Reset")) - self.retry_after = int(headers.get("Retry-After", "0")) + try: + headers = response.headers + if "X-Ratelimit-Limit" in headers: + self.rate_limit_points = int(headers.get("X-Ratelimit-Limit")) + self.rate_limit_remaining = int(headers.get("X-Ratelimit-Remaining")) + self.rate_limit_reset = int(headers.get("X-Ratelimit-Reset")) + self.retry_after = int(headers.get("Retry-After", "0")) + except asyncio.CancelledError: + self._shutdown = True + raise async def send(self, path: str, method: str = "POST", **kwargs): @backoff.on_exception( backoff.expo, (RateLimitError,), max_tries=self.retry_total ) async def _send(): - session = await self._get_session() - headers = {"Authorization": f"Bearer {self.token}"} - - if "headers" in kwargs: - kwargs["headers"].update(headers) - else: - kwargs["headers"] = headers - - url = f"{self.api_url}/{self.version}/{path.lstrip('/')}" - - while True: - if method == "POST": - response = await session.post(url, **kwargs) - elif method == "DELETE": - response = await session.delete(url, **kwargs) - else: - response = await session.get(url, **kwargs) - - self._set_rate_limit(response) - - if response.status == 429: - logging.warning( - f"Too many requests, you can do {self.rate_limit_points} requests per minute..." - ) - raise RateLimitError() - else: - break - - if response.status == 403: - logging.error(self.token_error) - - response.raise_for_status() - return response + try: + # If we're shutting down, don't start new requests + if self._shutdown: + raise asyncio.CancelledError() + + # Increment pending requests counter + self._pending_requests += 1 + + try: + session = await self._get_session() + headers = {"Authorization": f"Bearer {self.token}"} + + if "headers" in kwargs: + kwargs["headers"].update(headers) + else: + kwargs["headers"] = headers + + url = f"{self.api_url}/{self.version}/{path.lstrip('/')}" + + while True: + if method == "POST": + response = await session.post(url, **kwargs) + elif method == "DELETE": + response = await session.delete(url, **kwargs) + else: + response = await session.get(url, **kwargs) + + self._set_rate_limit(response) + + if response.status == 429: + logging.warning( + f"Too many requests, you can do {self.rate_limit_points} requests per minute..." + ) + raise RateLimitError() + else: + break + + if response.status == 403: + logging.error(self.token_error) + + response.raise_for_status() + return response + finally: + # Decrement pending requests counter + self._pending_requests -= 1 + except asyncio.CancelledError: + self._shutdown = True + raise return await _send() async def post(self, path: str, **kwargs) -> aiohttp.ClientResponse: """Send a POST request to the Tinybird API.""" - return await self.send(path, method="POST", **kwargs) + try: + return await self.send(path, method="POST", **kwargs) + except asyncio.CancelledError: + self._shutdown = True + raise async def get(self, path: str, **kwargs) -> aiohttp.ClientResponse: """Send a GET request to the Tinybird API.""" - return await self.send(path, method="GET", **kwargs) + try: + return await self.send(path, method="GET", **kwargs) + except asyncio.CancelledError: + self._shutdown = True + raise async def delete(self, path: str, **kwargs) -> aiohttp.ClientResponse: """Send a DELETE request to the Tinybird API.""" - return await self.send(path, method="DELETE", **kwargs) + try: + return await self.send(path, method="DELETE", **kwargs) + except asyncio.CancelledError: + self._shutdown = True + raise async def get_json(self, path: str, **kwargs) -> Dict[str, Any]: """Send a GET request and return the JSON response.""" - response = await self.get(path, **kwargs) - return await response.json() + try: + response = await self.get(path, **kwargs) + return await response.json() + except asyncio.CancelledError: + self._shutdown = True + raise async def post_json(self, path: str, **kwargs) -> Dict[str, Any]: """Send a POST request and return the JSON response.""" - response = await self.post(path, **kwargs) - return await response.json() + try: + response = await self.post(path, **kwargs) + return await response.json() + except asyncio.CancelledError: + self._shutdown = True + raise async def initialize(self) -> None: """Initialize the API by checking the token validity.""" @@ -148,3 +212,6 @@ async def initialize(self) -> None: if e.status == 403: logging.error(self.token_error) sys.exit(-1) + except asyncio.CancelledError: + self._shutdown = True + raise diff --git a/tb/a/datasource.py b/tb/a/datasource.py index 713f163..252c093 100644 --- a/tb/a/datasource.py +++ b/tb/a/datasource.py @@ -23,45 +23,77 @@ def __init__( self.timer_start = None self.sink = None self._lock = asyncio.Lock() + self._shutdown = False + self._pending_flush = False async def append(self): - async with self._lock: - while self.sink and self.sink.wait: - logging.info("Waiting while flushing...") - await asyncio.sleep(0.1) - - self.records += 1 - if max(self.records % self.max_wait_records / 100, 10) == 0: - logging.info( - f"Buffering {self.records} records and {bytes2human(self.sink.tell())} bytes" - ) - - if ( - self.records < self.max_wait_records - and self.sink.tell() < self.max_wait_bytes - ): - if not self.timer_task or self.timer_task.done(): - self.timer_start = asyncio.get_event_loop().time() - self.timer_task = asyncio.create_task(self._timer_callback()) - else: - await self.flush() - - async def _timer_callback(self): - await asyncio.sleep(self.max_wait_seconds) - await self.flush() - - async def flush(self): - async with self._lock: + try: + async with self._lock: + # If we're already shutting down or flushing, don't add more records + if self._shutdown or self._pending_flush: + return + + # If we're waiting for a flush to complete, wait a bit + while self.sink and self.sink.wait and not self._shutdown: + logging.info("Waiting while flushing...") + await asyncio.sleep(0.1) + + if self._shutdown: + return + + self.records += 1 + if max(self.records % self.max_wait_records / 100, 10) == 0: + logging.info( + f"Buffering {self.records} records and {bytes2human(self.sink.tell())} bytes" + ) + + if ( + self.records < self.max_wait_records + and self.sink.tell() < self.max_wait_bytes + ): + if not self.timer_task or self.timer_task.done(): + self.timer_start = asyncio.get_event_loop().time() + self.timer_task = asyncio.create_task(self._timer_callback()) + else: + await self.flush() + except asyncio.CancelledError: + self._shutdown = True if self.timer_task and not self.timer_task.done(): self.timer_task.cancel() - self.timer_task = None - self.timer_start = None + raise - if not self.records or not self.sink: - return + async def _timer_callback(self): + try: + await asyncio.sleep(self.max_wait_seconds) + await self.flush() + except asyncio.CancelledError: + self._shutdown = True + raise - await self.sink.flush() - self.records = 0 + async def flush(self): + try: + async with self._lock: + # If we're already shutting down or there's nothing to flush, return early + if self._shutdown or not self.records or not self.sink: + return + + # Cancel any pending timer task + if self.timer_task and not self.timer_task.done(): + self.timer_task.cancel() + self.timer_task = None + self.timer_start = None + + # Mark that we're about to flush to prevent new records from being added + self._pending_flush = True + + try: + await self.sink.flush() + self.records = 0 + finally: + self._pending_flush = False + except asyncio.CancelledError: + self._shutdown = True + raise class AsyncDatasource: @@ -82,18 +114,28 @@ def __init__( self.buffer.sink = self self.wait = False self._lock = asyncio.Lock() + self._shutdown = False + self._pending_flush = False def reset(self): self.chunk = StringIO() async def append(self, value: Union[str, bytes, Dict[str, Any]]): - async with self._lock: - if isinstance(value, bytes): - value = value.decode("utf-8") - if not isinstance(value, str): - value = json.dumps(value) - self.chunk.write(value + "\n") - await self.buffer.append() + try: + async with self._lock: + # If we're already shutting down or flushing, don't add more records + if self._shutdown or self._pending_flush: + return + + if isinstance(value, bytes): + value = value.decode("utf-8") + if not isinstance(value, str): + value = json.dumps(value) + self.chunk.write(value + "\n") + await self.buffer.append() + except asyncio.CancelledError: + self._shutdown = True + raise def tell(self): return self.chunk.tell() @@ -113,18 +155,41 @@ async def __lshift__(self, row): return self async def close(self): - await self.buffer.flush() - await self.api.close() + try: + # Set shutdown flag to prevent new operations + self._shutdown = True + + # Only try to flush if we have data + if self.buffer.records > 0: + await self.buffer.flush() + + # Close the API connection + await self.api.close() + except asyncio.CancelledError: + # If we're cancelled during close, just propagate the error + raise async def flush(self): - async with self._lock: - try: - logging.info( - f"Flushing {self.buffer.records} records and {bytes2human(self.tell())} bytes to {self.datasource_name}" - ) - self.wait = True - data = self.chunk.getvalue() - self.reset() - await self.api.post(self.path, data=data) - finally: - self.wait = False + try: + async with self._lock: + # If we're already shutting down or there's nothing to flush, return early + if self._shutdown or self.buffer.records == 0: + return + + # Mark that we're about to flush to prevent new records from being added + self._pending_flush = True + + try: + logging.info( + f"Flushing {self.buffer.records} records and {bytes2human(self.tell())} bytes to {self.datasource_name}" + ) + self.wait = True + data = self.chunk.getvalue() + self.reset() + await self.api.post(self.path, data=data) + finally: + self.wait = False + self._pending_flush = False + except asyncio.CancelledError: + self._shutdown = True + raise