Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

### Added

- Added retry with back-off logic for Redis related functions. [#528](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/528)
- Added nanosecond precision datetime filtering that ensures nanosecond precision support in filtering by datetime. This is configured via the `USE_DATETIME_NANOS` environment variable, while maintaining microseconds compatibility for datetime precision. [#529](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/529)

### Changed
Expand Down
1 change: 1 addition & 0 deletions stac_fastapi/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
"jsonschema~=4.0.0",
"slowapi~=0.1.9",
"redis==6.4.0",
"retry==0.9.2",
]

[project.urls]
Expand Down
184 changes: 97 additions & 87 deletions stac_fastapi/core/stac_fastapi/core/redis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,39 @@

import json
import logging
from typing import List, Optional, Tuple
from functools import wraps
from typing import Callable, List, Optional, Tuple, cast
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from redis import asyncio as aioredis
from redis.asyncio.sentinel import Sentinel
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import TimeoutError as RedisTimeoutError
from retry import retry # type: ignore

logger = logging.getLogger(__name__)


class RedisSentinelSettings(BaseSettings):
"""Configuration for connecting to Redis Sentinel."""
class RedisCommonSettings(BaseSettings):
"""Common configuration for Redis Sentinel and Redis Standalone."""

REDIS_SENTINEL_HOSTS: str = ""
REDIS_SENTINEL_PORTS: str = "26379"
REDIS_SENTINEL_MASTER_NAME: str = "master"
REDIS_DB: int = 15

REDIS_MAX_CONNECTIONS: Optional[int] = None
REDIS_RETRY_TIMEOUT: bool = True
REDIS_DECODE_RESPONSES: bool = True
REDIS_CLIENT_NAME: str = "stac-fastapi-app"
REDIS_HEALTH_CHECK_INTERVAL: int = Field(default=30, gt=0)
REDIS_SELF_LINK_TTL: int = 1800

REDIS_QUERY_RETRIES_NUM: int = Field(default=3, gt=0)
REDIS_QUERY_INITIAL_DELAY: float = Field(default=1.0, gt=0)
REDIS_QUERY_BACKOFF: float = Field(default=2.0, gt=1)

@field_validator("REDIS_DB")
@classmethod
def validate_db_sentinel(cls, v: int) -> int:
def validate_db(cls, v: int) -> int:
"""Validate REDIS_DB is not negative integer."""
if v < 0:
raise ValueError("REDIS_DB must be a positive integer")
Expand All @@ -46,12 +50,20 @@ def validate_max_connections(cls, v):

@field_validator("REDIS_SELF_LINK_TTL")
@classmethod
def validate_self_link_ttl_sentinel(cls, v: int) -> int:
"""Validate REDIS_SELF_LINK_TTL is not a negative integer."""
def validate_self_link_ttl(cls, v: int) -> int:
"""Validate REDIS_SELF_LINK_TTL is negative."""
if v < 0:
raise ValueError("REDIS_SELF_LINK_TTL must be a positive integer")
return v


class RedisSentinelSettings(RedisCommonSettings):
"""Configuration for connecting to Redis Sentinel."""

REDIS_SENTINEL_HOSTS: str = ""
REDIS_SENTINEL_PORTS: str = "26379"
REDIS_SENTINEL_MASTER_NAME: str = "master"

def get_sentinel_hosts(self) -> List[str]:
"""Parse Redis Sentinel hosts from string to list."""
if not self.REDIS_SENTINEL_HOSTS:
Expand Down Expand Up @@ -96,19 +108,11 @@ def get_sentinel_nodes(self) -> List[Tuple[str, int]]:
return [(str(host), int(port)) for host, port in zip(hosts, ports)]


class RedisSettings(BaseSettings):
class RedisSettings(RedisCommonSettings):
"""Configuration for connecting Redis."""

REDIS_HOST: str = ""
REDIS_PORT: int = 6379
REDIS_DB: int = 15

REDIS_MAX_CONNECTIONS: Optional[int] = None
REDIS_RETRY_TIMEOUT: bool = True
REDIS_DECODE_RESPONSES: bool = True
REDIS_CLIENT_NAME: str = "stac-fastapi-app"
REDIS_HEALTH_CHECK_INTERVAL: int = Field(default=30, gt=0)
REDIS_SELF_LINK_TTL: int = 1800

@field_validator("REDIS_PORT")
@classmethod
Expand All @@ -118,89 +122,93 @@ def validate_port_standalone(cls, v: int) -> int:
raise ValueError("REDIS_PORT must be a positive integer")
return v

@field_validator("REDIS_DB")
@classmethod
def validate_db_standalone(cls, v: int) -> int:
"""Validate REDIS_DB is not a negative integer."""
if v < 0:
raise ValueError("REDIS_DB must be a positive integer")
return v

@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
@classmethod
def validate_max_connections(cls, v):
"""Handle empty/None values for REDIS_MAX_CONNECTIONS."""
if v in ["", "null", "Null", "NULL", "none", "None", "NONE", None]:
return None
return v

@field_validator("REDIS_SELF_LINK_TTL")
@classmethod
def validate_self_link_ttl_standalone(cls, v: int) -> int:
"""Validate REDIS_SELF_LINK_TTL is negative."""
if v < 0:
raise ValueError("REDIS_SELF_LINK_TTL must be a positive integer")
return v


# Configure only one Redis configuration
sentinel_settings = RedisSentinelSettings()
standalone_settings = RedisSettings()
settings: RedisCommonSettings = cast(
RedisCommonSettings,
sentinel_settings if sentinel_settings.REDIS_SENTINEL_HOSTS else RedisSettings(),
)


def redis_retry(func: Callable) -> Callable:
"""Retry with back-off decorator for Redis connections."""

@wraps(func)
@retry(
exceptions=(RedisConnectionError, RedisTimeoutError),
tries=settings.REDIS_QUERY_RETRIES_NUM,
delay=settings.REDIS_QUERY_INITIAL_DELAY,
backoff=settings.REDIS_QUERY_BACKOFF,
logger=logger,
)
async def wrapper(*args, **kwargs):
return await func(*args, **kwargs)

return wrapper

async def connect_redis() -> Optional[aioredis.Redis]:

@redis_retry
async def _connect_redis_internal() -> Optional[aioredis.Redis]:
"""Return a Redis connection Redis or Redis Sentinel."""
try:
if sentinel_settings.REDIS_SENTINEL_HOSTS:
sentinel_nodes = sentinel_settings.get_sentinel_nodes()
sentinel = Sentinel(
sentinel_nodes,
decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
)
if sentinel_settings.REDIS_SENTINEL_HOSTS:
sentinel_nodes = settings.get_sentinel_nodes()
sentinel = Sentinel(
sentinel_nodes,
decode_responses=settings.REDIS_DECODE_RESPONSES,
)

redis = sentinel.master_for(
service_name=sentinel_settings.REDIS_SENTINEL_MASTER_NAME,
db=sentinel_settings.REDIS_DB,
decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
retry_on_timeout=sentinel_settings.REDIS_RETRY_TIMEOUT,
client_name=sentinel_settings.REDIS_CLIENT_NAME,
max_connections=sentinel_settings.REDIS_MAX_CONNECTIONS,
health_check_interval=sentinel_settings.REDIS_HEALTH_CHECK_INTERVAL,
)
logger.info("Connected to Redis Sentinel")

elif standalone_settings.REDIS_HOST:
pool = aioredis.ConnectionPool(
host=standalone_settings.REDIS_HOST,
port=standalone_settings.REDIS_PORT,
db=standalone_settings.REDIS_DB,
max_connections=standalone_settings.REDIS_MAX_CONNECTIONS,
decode_responses=standalone_settings.REDIS_DECODE_RESPONSES,
retry_on_timeout=standalone_settings.REDIS_RETRY_TIMEOUT,
health_check_interval=standalone_settings.REDIS_HEALTH_CHECK_INTERVAL,
)
redis = aioredis.Redis(
connection_pool=pool, client_name=standalone_settings.REDIS_CLIENT_NAME
)
logger.info("Connected to Redis")
else:
logger.warning("No Redis configuration found")
return None
redis = sentinel.master_for(
service_name=settings.REDIS_SENTINEL_MASTER_NAME,
db=settings.REDIS_DB,
decode_responses=settings.REDIS_DECODE_RESPONSES,
retry_on_timeout=settings.REDIS_RETRY_TIMEOUT,
client_name=settings.REDIS_CLIENT_NAME,
max_connections=settings.REDIS_MAX_CONNECTIONS,
health_check_interval=settings.REDIS_HEALTH_CHECK_INTERVAL,
)
logger.info("Connected to Redis Sentinel")

elif settings.REDIS_HOST:
pool = aioredis.ConnectionPool(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
max_connections=settings.REDIS_MAX_CONNECTIONS,
decode_responses=settings.REDIS_DECODE_RESPONSES,
retry_on_timeout=settings.REDIS_RETRY_TIMEOUT,
health_check_interval=settings.REDIS_HEALTH_CHECK_INTERVAL,
)
redis = aioredis.Redis(
connection_pool=pool, client_name=settings.REDIS_CLIENT_NAME
)
logger.info("Connected to Redis")
else:
logger.warning("No Redis configuration found")
return None

return redis

return redis

async def connect_redis() -> Optional[aioredis.Redis]:
"""Handle Redis connection."""
try:
return await _connect_redis_internal()
except (
aioredis.ConnectionError,
aioredis.TimeoutError,
) as e:
logger.error(f"Redis connection failed after retries: {e}")
except aioredis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
return None
except aioredis.AuthenticationError as e:
logger.error(f"Redis authentication error: {e}")
return None
except aioredis.TimeoutError as e:
logger.error(f"Redis timeout error: {e}")
return None
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
return None
return None


def get_redis_key(url: str, token: str) -> str:
Expand Down Expand Up @@ -230,19 +238,21 @@ def build_url_with_token(base_url: str, token: str) -> str:
)


@redis_retry
async def save_prev_link(
redis: aioredis.Redis, next_url: str, current_url: str, next_token: str
) -> None:
"""Save the current page as the previous link for the next URL."""
if next_url and next_token:
if sentinel_settings.REDIS_SENTINEL_HOSTS:
ttl_seconds = sentinel_settings.REDIS_SELF_LINK_TTL
elif standalone_settings.REDIS_HOST:
ttl_seconds = standalone_settings.REDIS_SELF_LINK_TTL
ttl_seconds = settings.REDIS_SELF_LINK_TTL
elif settings.REDIS_HOST:
ttl_seconds = settings.REDIS_SELF_LINK_TTL
key = get_redis_key(next_url, next_token)
await redis.setex(key, ttl_seconds, current_url)


@redis_retry
async def get_prev_link(
redis: aioredis.Redis, current_url: str, current_token: str
) -> Optional[str]:
Expand Down
90 changes: 90 additions & 0 deletions stac_fastapi/tests/redis/test_redis_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
from redis.exceptions import ConnectionError as RedisConnectionError

import stac_fastapi.core.redis_utils as redis_utils
from stac_fastapi.core.redis_utils import connect_redis, get_prev_link, save_prev_link


Expand Down Expand Up @@ -46,3 +48,91 @@ async def test_redis_utils_functions():
redis, "http://mywebsite.com/search", "non_existent_token"
)
assert non_existent is None


@pytest.mark.asyncio
async def test_redis_retry_retries_until_success(monkeypatch):
monkeypatch.setattr(
redis_utils.settings, "REDIS_QUERY_RETRIES_NUM", 3, raising=False
)
monkeypatch.setattr(
redis_utils.settings, "REDIS_QUERY_INITIAL_DELAY", 0, raising=False
)
monkeypatch.setattr(redis_utils.settings, "REDIS_QUERY_BACKOFF", 2.0, raising=False)

captured_kwargs = {}

def fake_retry(**kwargs):
captured_kwargs.update(kwargs)

def decorator(func):
async def wrapped(*args, **inner_kwargs):
attempts = 0
while True:
try:
attempts += 1
return await func(*args, **inner_kwargs)
except kwargs["exceptions"] as exc:
if attempts >= kwargs["tries"]:
raise exc
continue

return wrapped

return decorator

monkeypatch.setattr(redis_utils, "retry", fake_retry)

call_counter = {"count": 0}

@redis_utils.redis_retry
async def flaky() -> str:
call_counter["count"] += 1
if call_counter["count"] < 3:
raise RedisConnectionError("transient failure")
return "success"

result = await flaky()

assert result == "success"
assert call_counter["count"] == 3
assert captured_kwargs["tries"] == redis_utils.settings.REDIS_QUERY_RETRIES_NUM
assert captured_kwargs["delay"] == redis_utils.settings.REDIS_QUERY_INITIAL_DELAY
assert captured_kwargs["backoff"] == redis_utils.settings.REDIS_QUERY_BACKOFF


@pytest.mark.asyncio
async def test_redis_retry_raises_after_exhaustion(monkeypatch):
monkeypatch.setattr(
redis_utils.settings, "REDIS_QUERY_RETRIES_NUM", 3, raising=False
)
monkeypatch.setattr(
redis_utils.settings, "REDIS_QUERY_INITIAL_DELAY", 0, raising=False
)
monkeypatch.setattr(redis_utils.settings, "REDIS_QUERY_BACKOFF", 2.0, raising=False)

def fake_retry(**kwargs):
def decorator(func):
async def wrapped(*args, **inner_kwargs):
attempts = 0
while True:
try:
attempts += 1
return await func(*args, **inner_kwargs)
except kwargs["exceptions"] as exc:
if attempts >= kwargs["tries"]:
raise exc
continue

return wrapped

return decorator

monkeypatch.setattr(redis_utils, "retry", fake_retry)

@redis_utils.redis_retry
async def always_fail() -> str:
raise RedisConnectionError("pernament failure")

with pytest.raises(RedisConnectionError):
await always_fail()