diff --git a/pyproject.toml b/pyproject.toml index 1c540e2f63b66..cdb040faccea1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,9 @@ addopts = [ ] xfail_strict = true junit_family = "xunit2" +markers = [ + "asyncio: marks tests as async (deselect with '-m \"not asyncio\"')", +] filterwarnings = [ "error", 'ignore:starlette.middleware.wsgi is deprecated and will be removed in a future release\..*:DeprecationWarning:starlette', diff --git a/tests/test_background_tasks.py b/tests/test_background_tasks.py new file mode 100644 index 0000000000000..3db1b99516c80 --- /dev/null +++ b/tests/test_background_tasks.py @@ -0,0 +1,143 @@ +import pytest +from fastapi import FastAPI +from fastapi.background import BackgroundTasks +from fastapi.testclient import TestClient + + +def test_background_tasks_basic(): + app = FastAPI() + executed_tasks = [] + + def task_func(name: str, value: int): + executed_tasks.append(f"{name}:{value}") + + @app.get("/test") + def test_endpoint(background_tasks: BackgroundTasks): + background_tasks.add_task(task_func, "test", 123) + return {"message": "success"} + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 + assert response.json() == {"message": "success"} + + +def test_background_tasks_async_function(): + app = FastAPI() + executed_tasks = [] + + async def async_task(name: str): + executed_tasks.append(f"async:{name}") + + @app.get("/test") + def test_endpoint(background_tasks: BackgroundTasks): + background_tasks.add_task(async_task, "test_async") + return {"message": "success"} + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 + + +def test_background_tasks_with_kwargs(): + app = FastAPI() + + def task_with_kwargs(name: str, value: int = 42): + pass + + @app.get("/test") + def test_endpoint(background_tasks: BackgroundTasks): + background_tasks.add_task(task_with_kwargs, "test", value=100) + return {"message": "success"} + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 + + +def test_background_tasks_multiple_tasks(): + app = FastAPI() + + def task_one(): + pass + + def task_two(arg: str): + pass + + @app.get("/test") + def test_endpoint(background_tasks: BackgroundTasks): + background_tasks.add_task(task_one) + background_tasks.add_task(task_two, "argument") + return {"message": "success"} + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 + + +def test_background_tasks_inheritance(): + tasks = BackgroundTasks() + + def simple_task(): + pass + + tasks.add_task(simple_task) + + assert len(tasks.tasks) == 1 + + +def test_background_tasks_with_complex_args(): + app = FastAPI() + + def complex_task(data: dict, items: list, count: int = 1): + pass + + @app.get("/test") + def test_endpoint(background_tasks: BackgroundTasks): + background_tasks.add_task( + complex_task, + {"key": "value"}, + [1, 2, 3], + count=5 + ) + return {"message": "success"} + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 + + +def test_background_tasks_no_args(): + app = FastAPI() + + def no_args_task(): + pass + + @app.get("/test") + def test_endpoint(background_tasks: BackgroundTasks): + background_tasks.add_task(no_args_task) + return {"message": "success"} + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 + + +def test_background_tasks_lambda_function(): + app = FastAPI() + + @app.get("/test") + def test_endpoint(background_tasks: BackgroundTasks): + background_tasks.add_task(lambda: None) + return {"message": "success"} + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 200 diff --git a/tests/test_compat_new.py b/tests/test_compat_new.py new file mode 100644 index 0000000000000..622189025c9fc --- /dev/null +++ b/tests/test_compat_new.py @@ -0,0 +1,172 @@ +import pytest +from typing import List, Dict, Union, Optional +from fastapi._compat import ( + field_annotation_is_sequence, + field_annotation_is_scalar, + field_annotation_is_complex, + field_annotation_is_scalar_sequence, + value_is_sequence, + is_bytes_or_nonable_bytes_annotation, + is_uploadfile_or_nonable_uploadfile_annotation, + is_bytes_sequence_annotation, + is_uploadfile_sequence_annotation, + _annotation_is_sequence, + _annotation_is_complex, + _regenerate_error_with_loc, + _normalize_errors, +) +from pydantic import BaseModel +from starlette.datastructures import UploadFile +from tests.utils import needs_pydanticv1, needs_pydanticv2 + + +class SampleModel(BaseModel): + name: str + value: int + + +class TestFieldAnnotationFunctions: + def test_field_annotation_is_sequence(self): + assert field_annotation_is_sequence(List[str]) is True + assert field_annotation_is_sequence(list) is True + assert field_annotation_is_sequence(tuple) is True + assert field_annotation_is_sequence(set) is True + assert field_annotation_is_sequence(str) is False + assert field_annotation_is_sequence(int) is False + + def test_field_annotation_is_scalar(self): + assert field_annotation_is_scalar(str) is True + assert field_annotation_is_scalar(int) is True + assert field_annotation_is_scalar(float) is True + assert field_annotation_is_scalar(bool) is True + assert field_annotation_is_scalar(List[str]) is False + assert field_annotation_is_scalar(SampleModel) is False + + def test_field_annotation_is_complex(self): + assert field_annotation_is_complex(SampleModel) is True + assert field_annotation_is_complex(Dict[str, str]) is True + assert field_annotation_is_complex(List[str]) is True + assert field_annotation_is_complex(UploadFile) is True + assert field_annotation_is_complex(str) is False + assert field_annotation_is_complex(int) is False + + def test_field_annotation_is_scalar_sequence(self): + assert field_annotation_is_scalar_sequence(List[str]) is True + assert field_annotation_is_scalar_sequence(List[int]) is True + assert field_annotation_is_scalar_sequence(List[SampleModel]) is False + assert field_annotation_is_scalar_sequence(str) is False + + def test_union_annotations(self): + union_type = Union[str, int] + assert field_annotation_is_scalar(union_type) is True + + complex_union = Union[str, List[str]] + assert field_annotation_is_complex(complex_union) is True + + optional_str = Optional[str] + assert field_annotation_is_scalar(optional_str) is True + + def test_ellipsis_annotation(self): + assert field_annotation_is_scalar(...) is True + + +class TestValueFunctions: + def test_value_is_sequence(self): + assert value_is_sequence([1, 2, 3]) is True + assert value_is_sequence((1, 2, 3)) is True + assert value_is_sequence({1, 2, 3}) is True + assert value_is_sequence("string") is False + assert value_is_sequence(b"bytes") is False + assert value_is_sequence(123) is False + assert value_is_sequence(None) is False + + +class TestBytesAnnotations: + def test_is_bytes_annotation(self): + assert is_bytes_or_nonable_bytes_annotation(bytes) is True + assert is_bytes_or_nonable_bytes_annotation(Union[bytes, None]) is True + assert is_bytes_or_nonable_bytes_annotation(Optional[bytes]) is True + assert is_bytes_or_nonable_bytes_annotation(str) is False + assert is_bytes_or_nonable_bytes_annotation(int) is False + + def test_is_bytes_sequence_annotation(self): + assert is_bytes_sequence_annotation(List[bytes]) is True + assert is_bytes_sequence_annotation(List[Union[bytes, None]]) is True + assert is_bytes_sequence_annotation(List[str]) is False + assert is_bytes_sequence_annotation(bytes) is False + + +class TestUploadFileAnnotations: + def test_is_uploadfile_annotation(self): + assert is_uploadfile_or_nonable_uploadfile_annotation(UploadFile) is True + assert is_uploadfile_or_nonable_uploadfile_annotation(Union[UploadFile, None]) is True + assert is_uploadfile_or_nonable_uploadfile_annotation(Optional[UploadFile]) is True + assert is_uploadfile_or_nonable_uploadfile_annotation(str) is False + assert is_uploadfile_or_nonable_uploadfile_annotation(bytes) is False + + def test_is_uploadfile_sequence_annotation(self): + assert is_uploadfile_sequence_annotation(List[UploadFile]) is True + assert is_uploadfile_sequence_annotation(List[Union[UploadFile, None]]) is True + assert is_uploadfile_sequence_annotation(List[str]) is False + assert is_uploadfile_sequence_annotation(UploadFile) is False + + +class TestPrivateAnnotationFunctions: + def test_annotation_is_sequence(self): + assert _annotation_is_sequence(list) is True + assert _annotation_is_sequence(tuple) is True + assert _annotation_is_sequence(set) is True + assert _annotation_is_sequence(str) is False + assert _annotation_is_sequence(bytes) is False + assert _annotation_is_sequence(int) is False + + def test_annotation_is_complex(self): + assert _annotation_is_complex(SampleModel) is True + assert _annotation_is_complex(dict) is True + assert _annotation_is_complex(UploadFile) is True + assert _annotation_is_complex(list) is True + assert _annotation_is_complex(str) is False + assert _annotation_is_complex(int) is False + + +class TestErrorFunctions: + def test_regenerate_error_with_loc(self): + errors = [{"type": "missing", "loc": ("field",), "msg": "field required"}] + loc_prefix = ("body",) + + result = _regenerate_error_with_loc(errors=errors, loc_prefix=loc_prefix) + + assert len(result) == 1 + assert result[0]["loc"] == ("body", "field") + assert result[0]["type"] == "missing" + + def test_regenerate_error_with_empty_loc(self): + errors = [{"type": "missing", "msg": "field required"}] + loc_prefix = ("body",) + + result = _regenerate_error_with_loc(errors=errors, loc_prefix=loc_prefix) + + assert len(result) == 1 + assert result[0]["loc"] == ("body",) + + def test_normalize_errors_basic(self): + errors = [{"type": "missing", "loc": ("field",), "msg": "field required"}] + result = _normalize_errors(errors) + assert result == errors + + +class TestCompatibilityEdgeCases: + def test_none_annotations(self): + assert field_annotation_is_scalar(None) is True + assert field_annotation_is_complex(None) is False + assert field_annotation_is_sequence(None) is False + + def test_nested_union_types(self): + nested_union = Union[str, Union[int, float]] + assert field_annotation_is_scalar(nested_union) is True + + def test_complex_nested_sequences(self): + complex_sequence = List[Dict[str, Union[str, int]]] + assert field_annotation_is_sequence(complex_sequence) is True + assert field_annotation_is_complex(complex_sequence) is True + assert field_annotation_is_scalar_sequence(complex_sequence) is False diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 0000000000000..9d4d2ed22f737 --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,102 @@ +import pytest +from contextlib import contextmanager +from fastapi.concurrency import contextmanager_in_threadpool + + +class TestContextManagerInThreadpool: + @pytest.mark.anyio + async def test_successful_context_manager(self): + @contextmanager + def test_cm(): + yield "test_value" + + async with contextmanager_in_threadpool(test_cm()) as value: + assert value == "test_value" + + @pytest.mark.anyio + async def test_context_manager_with_exception_in_body(self): + @contextmanager + def test_cm(): + try: + yield "test_value" + except ValueError: + raise + + with pytest.raises(ValueError): + async with contextmanager_in_threadpool(test_cm()) as value: + assert value == "test_value" + raise ValueError("Test exception") + + @pytest.mark.anyio + async def test_context_manager_exit_handling(self): + exit_called = False + + @contextmanager + def tracking_cm(): + nonlocal exit_called + try: + yield "test_value" + finally: + exit_called = True + + async with contextmanager_in_threadpool(tracking_cm()) as value: + assert value == "test_value" + + assert exit_called is True + + @pytest.mark.anyio + async def test_context_manager_with_enter_exception(self): + @contextmanager + def failing_enter_cm(): + raise RuntimeError("Enter failed") + yield "never_reached" + + with pytest.raises(RuntimeError, match="Enter failed"): + async with contextmanager_in_threadpool(failing_enter_cm()): + pass + + @pytest.mark.anyio + async def test_context_manager_with_exit_exception(self): + @contextmanager + def failing_exit_cm(): + try: + yield "test_value" + finally: + raise RuntimeError("Exit failed") + + with pytest.raises(RuntimeError, match="Exit failed"): + async with contextmanager_in_threadpool(failing_exit_cm()) as value: + assert value == "test_value" + + @pytest.mark.anyio + async def test_context_manager_suppresses_exception(self): + @contextmanager + def suppressing_cm(): + try: + yield "test_value" + except ValueError: + pass + + async with contextmanager_in_threadpool(suppressing_cm()) as value: + assert value == "test_value" + raise ValueError("This should be suppressed") + + @pytest.mark.anyio + async def test_context_manager_with_complex_state(self): + state = {"entered": False, "exited": False, "value": None} + + @contextmanager + def stateful_cm(): + state["entered"] = True + state["value"] = "context_value" + try: + yield state["value"] + finally: + state["exited"] = True + + async with contextmanager_in_threadpool(stateful_cm()) as value: + assert state["entered"] is True + assert state["exited"] is False + assert value == "context_value" + + assert state["exited"] is True diff --git a/tests/test_fastapi_cli.py b/tests/test_fastapi_cli.py index a5c10778ad7e8..1d7db3ead007d 100644 --- a/tests/test_fastapi_cli.py +++ b/tests/test_fastapi_cli.py @@ -1,6 +1,6 @@ import subprocess import sys -from unittest.mock import patch +from unittest.mock import Mock, patch import fastapi.cli import pytest @@ -30,3 +30,45 @@ def test_fastapi_cli_not_installed(): with pytest.raises(RuntimeError) as exc_info: fastapi.cli.main() assert "To use the fastapi command, please install" in str(exc_info.value) + + +def test_fastapi_cli_help(): + result = subprocess.run( + [sys.executable, "-m", "fastapi", "--help"], + capture_output=True, + encoding="utf-8", + ) + assert result.returncode == 0 or "fastapi command" in result.stdout + + +def test_fastapi_main_function_direct(): + with patch.object(fastapi.cli, "cli_main", None): + with pytest.raises(RuntimeError) as exc_info: + fastapi.cli.main() + assert "fastapi[standard]" in str(exc_info.value) + + +def test_fastapi_main_with_mock_cli(): + mock_cli = Mock() + with patch.object(fastapi.cli, "cli_main", mock_cli): + fastapi.cli.main() + mock_cli.assert_called_once() + + +def test_fastapi_cli_import_error_message(): + with patch.object(fastapi.cli, "cli_main", None): + with pytest.raises(RuntimeError) as exc_info: + fastapi.cli.main() + error_msg = str(exc_info.value) + assert 'To use the fastapi command, please install "fastapi[standard]"' in error_msg + assert "pip install" in error_msg + + +def test_fastapi_cli_module_execution(): + result = subprocess.run( + [sys.executable, "-c", "import fastapi.cli; fastapi.cli.main()"], + capture_output=True, + encoding="utf-8", + ) + assert result.returncode != 0 + assert "fastapi[standard]" in result.stderr diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000000..072a7d14e7edd --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,201 @@ +import pytest +from fastapi import FastAPI +from fastapi.utils import ( + is_body_allowed_for_status_code, + get_path_param_names, + create_model_field, + create_cloned_field, + generate_unique_id, + deep_dict_update, + get_value_or_default, +) +from fastapi.datastructures import DefaultPlaceholder +from fastapi.routing import APIRoute +from pydantic import BaseModel +from tests.utils import needs_pydanticv1, needs_pydanticv2 + + +class TestIsBodyAllowedForStatusCode: + def test_none_status_code(self): + assert is_body_allowed_for_status_code(None) is True + + def test_default_status_code(self): + assert is_body_allowed_for_status_code("default") is True + + def test_pattern_status_codes(self): + for pattern in ["1XX", "2XX", "3XX", "4XX", "5XX"]: + assert is_body_allowed_for_status_code(pattern) is True + + def test_success_status_codes(self): + for code in [200, 201, 202, 203]: + assert is_body_allowed_for_status_code(code) is True + + def test_no_body_status_codes(self): + for code in [204, 205, 304]: + assert is_body_allowed_for_status_code(code) is False + + def test_informational_status_codes(self): + for code in [100, 101, 102]: + assert is_body_allowed_for_status_code(code) is False + + def test_string_status_codes(self): + assert is_body_allowed_for_status_code("200") is True + assert is_body_allowed_for_status_code("204") is False + + +class TestGetPathParamNames: + def test_no_params(self): + assert get_path_param_names("/users") == set() + + def test_single_param(self): + assert get_path_param_names("/users/{user_id}") == {"user_id"} + + def test_multiple_params(self): + assert get_path_param_names("/users/{user_id}/posts/{post_id}") == {"user_id", "post_id"} + + def test_complex_params(self): + assert get_path_param_names("/api/v1/{version}/users/{user_id:int}") == {"version", "user_id:int"} + + +class TestCreateModelField: + def test_basic_field_creation(self): + field = create_model_field("test_field", str, default="default_value") + assert field.name == "test_field" + assert field.type_ == str + + def test_required_field(self): + field = create_model_field("required_field", int, required=True) + assert field.required is True + + def test_field_with_alias(self): + field = create_model_field("field_name", str, alias="fieldAlias") + assert field.alias == "fieldAlias" + + def test_field_creation_with_none_type(self): + field = create_model_field("none_field", type(None)) + assert field.name == "none_field" + + +class TestCreateClonedField: + @needs_pydanticv2 + def test_clone_field_pydantic_v2(self): + original_field = create_model_field("original", str, default="test") + cloned_field = create_cloned_field(original_field) + assert cloned_field.name == original_field.name + assert cloned_field.type_ == original_field.type_ + + @needs_pydanticv1 + def test_clone_field_pydantic_v1(self): + original_field = create_model_field("original", str, default="test") + cloned_field = create_cloned_field(original_field) + assert cloned_field.name == original_field.name + assert cloned_field.type_ == original_field.type_ + + +class TestDeepDictUpdate: + def test_simple_update(self): + main_dict = {"a": 1, "b": 2} + update_dict = {"c": 3} + deep_dict_update(main_dict, update_dict) + assert main_dict == {"a": 1, "b": 2, "c": 3} + + def test_nested_dict_update(self): + main_dict = {"a": {"x": 1, "y": 2}, "b": 3} + update_dict = {"a": {"z": 3}, "c": 4} + deep_dict_update(main_dict, update_dict) + assert main_dict == {"a": {"x": 1, "y": 2, "z": 3}, "b": 3, "c": 4} + + def test_list_concatenation(self): + main_dict = {"items": [1, 2]} + update_dict = {"items": [3, 4]} + deep_dict_update(main_dict, update_dict) + assert main_dict == {"items": [1, 2, 3, 4]} + + def test_value_replacement(self): + main_dict = {"a": 1} + update_dict = {"a": 2} + deep_dict_update(main_dict, update_dict) + assert main_dict == {"a": 2} + + def test_empty_dicts(self): + main_dict = {} + update_dict = {"a": 1} + deep_dict_update(main_dict, update_dict) + assert main_dict == {"a": 1} + + def test_none_values(self): + main_dict = {"a": None} + update_dict = {"a": "value"} + deep_dict_update(main_dict, update_dict) + assert main_dict == {"a": "value"} + + +class TestGetValueOrDefault: + def test_first_non_default_returned(self): + placeholder = DefaultPlaceholder(value="default") + result = get_value_or_default(placeholder, "value1", "value2") + assert result == "value1" + + def test_all_defaults_returns_first(self): + placeholder1 = DefaultPlaceholder(value="default1") + placeholder2 = DefaultPlaceholder(value="default2") + result = get_value_or_default(placeholder1, placeholder2) + assert result == placeholder1 + + def test_single_value(self): + result = get_value_or_default("single_value") + assert result == "single_value" + + def test_no_values(self): + result = get_value_or_default("first_value") + assert result == "first_value" + + def test_mixed_values(self): + placeholder = DefaultPlaceholder(value="default") + result = get_value_or_default("first", placeholder, "third") + assert result == "first" + + +class TestGenerateUniqueId: + def test_basic_route_id(self): + app = FastAPI() + + @app.get("/test") + def test_endpoint(): + return {"test": "value"} + + route = app.routes[0] + unique_id = generate_unique_id(route) + assert isinstance(unique_id, str) + assert len(unique_id) > 0 + + def test_route_with_path_params(self): + app = FastAPI() + + @app.get("/users/{user_id}") + def get_user(user_id: int): + return {"user_id": user_id} + + route = app.routes[0] + unique_id = generate_unique_id(route) + assert isinstance(unique_id, str) + assert len(unique_id) > 0 + + def test_different_methods_different_ids(self): + app = FastAPI() + + @app.get("/test") + def get_test(): + return {"method": "GET"} + + @app.post("/test") + def post_test(): + return {"method": "POST"} + + get_route = app.routes[0] + post_route = app.routes[1] + + get_id = generate_unique_id(get_route) + post_id = generate_unique_id(post_route) + + assert get_id != post_id