diff --git a/CHANGES.md b/CHANGES.md index 69628f4..9d46bdb 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,10 @@ ## [Unreleased] +### Fixed + +- Parsing of `CORS_ORIGINS`, `CORS_HEADERS`, and `CORS_METHODS` from environment variables ([#313](https://github.com/stac-utils/stac-fastapi-pgstac/pull/313)) + ### Changed - Docker container runs as non-root user diff --git a/stac_fastapi/pgstac/config.py b/stac_fastapi/pgstac/config.py index 9af6513..e0daecc 100644 --- a/stac_fastapi/pgstac/config.py +++ b/stac_fastapi/pgstac/config.py @@ -1,11 +1,12 @@ """Postgres API configuration.""" +import json import warnings from typing import Annotated, Any, List, Optional, Sequence, Type from urllib.parse import quote_plus as quote from pydantic import BaseModel, BeforeValidator, Field, model_validator -from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict from stac_fastapi.types.config import ApiSettings from typing_extensions import Self @@ -158,8 +159,12 @@ def connection_string(self): def str_to_list(value: Any) -> Any: if isinstance(value, str): - return [v.strip() for v in value.split(",")] - return value + if value.startswith("["): + return json.loads(value) + else: + return [v.strip() for v in value.split(",")] + else: + return value class Settings(ApiSettings): @@ -201,15 +206,17 @@ class Settings(ApiSettings): Implies that the `Transactions` extension is enabled. """ - cors_origins: Annotated[Sequence[str], BeforeValidator(str_to_list)] = ("*",) + cors_origins: Annotated[Sequence[str], BeforeValidator(str_to_list), NoDecode] = ( + "*", + ) cors_origin_regex: Optional[str] = None - cors_methods: Annotated[Sequence[str], BeforeValidator(str_to_list)] = ( + cors_methods: Annotated[Sequence[str], BeforeValidator(str_to_list), NoDecode] = ( "GET", "POST", "OPTIONS", ) cors_credentials: bool = False - cors_headers: Annotated[Sequence[str], BeforeValidator(str_to_list)] = ( + cors_headers: Annotated[Sequence[str], BeforeValidator(str_to_list), NoDecode] = ( "Content-Type", ) diff --git a/tests/test_config.py b/tests/test_config.py index e05eb5e..195b776 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,8 +4,9 @@ import pytest from pydantic import ValidationError +from pytest import MonkeyPatch -from stac_fastapi.pgstac.config import PostgresSettings +from stac_fastapi.pgstac.config import PostgresSettings, Settings async def test_pg_settings_with_env(monkeypatch): @@ -74,3 +75,60 @@ async def test_pg_settings_attributes(monkeypatch): postgres_dbname="pgstac", _env_file=None, ) + + +@pytest.mark.parametrize( + "cors_origins", + [ + "http://stac-fastapi-pgstac.test,http://stac-fastapi.test", + '["http://stac-fastapi-pgstac.test","http://stac-fastapi.test"]', + ], +) +def test_cors_origins(monkeypatch: MonkeyPatch, cors_origins: str) -> None: + monkeypatch.setenv( + "CORS_ORIGINS", + cors_origins, + ) + settings = Settings() + assert settings.cors_origins == [ + "http://stac-fastapi-pgstac.test", + "http://stac-fastapi.test", + ] + + +@pytest.mark.parametrize( + "cors_methods", + [ + "GET,POST", + '["GET","POST"]', + ], +) +def test_cors_methods(monkeypatch: MonkeyPatch, cors_methods: str) -> None: + monkeypatch.setenv( + "CORS_METHODS", + cors_methods, + ) + settings = Settings() + assert settings.cors_methods == [ + "GET", + "POST", + ] + + +@pytest.mark.parametrize( + "cors_headers", + [ + "Content-Type,X-Foo", + '["Content-Type","X-Foo"]', + ], +) +def test_cors_headers(monkeypatch: MonkeyPatch, cors_headers: str) -> None: + monkeypatch.setenv( + "CORS_HEADERS", + cors_headers, + ) + settings = Settings() + assert settings.cors_headers == [ + "Content-Type", + "X-Foo", + ]