# Unit 3: Testing FastAPI Applications with pytest


So far, we've established a strong foundation by building a basic FastAPI application and integrating a Machine Learning model for diamond price prediction. Now, we'll focus on a crucial aspect of production-ready applications: **testing**.

Testing is a vital part of developing reliable APIs, especially when they serve Machine Learning models that make critical predictions. By the end of this lesson, you'll learn how to create comprehensive tests for your diamond price prediction API using **pytest**, ensuring your endpoints function correctly and handle various scenarios appropriately.

Let's build robust tests that give us confidence in our API's functionality!

### Why Testing Matters for ML APIs

Before diving into the code, let's understand why testing is particularly important for Machine Learning APIs:

  * **Model behavior validation** - We need to ensure predictions are made correctly and consistently.
  * **Input validation** - Our API must properly handle both valid and invalid inputs.
  * **Error handling** - The system should respond appropriately when something goes wrong.
  * **Reliability** - APIs in production need to be dependable under various conditions.

In this lesson, we'll focus on **unit tests** that verify our API endpoints function correctly in isolation. We'll use mocking techniques to replace the actual ML model with controlled test doubles, so we can check that our FastAPI routes, data validation, and error handling work as expected without depending on the model implementation.

In general, **effective API tests** should:

  * be **isolated** from external dependencies (like the actual ML model);
  * cover typical usage scenarios;
  * test error conditions to ensure graceful failure;
  * verify expected responses, including status codes and data formats.

### Setting Up the Testing Environment

The **pytest** framework provides an excellent foundation for testing our API. It offers a simple syntax while providing powerful features like parametrization and fixtures. FastAPI integrates wonderfully with **pytest** through its **TestClient** class.

Let's start by setting up our test module:

```python
import pytest
from fastapi.testclient import TestClient
import numpy as np
from fastapi import HTTPException
from main import app

# Create a test client
client = TestClient(app)
```

The **TestClient** wraps our FastAPI application, allowing us to make requests directly to our endpoints without actually running a server. This client simulates HTTP requests and captures the responses, making it easy to verify that our endpoints behave as expected.

This setup establishes the foundation for all our tests. The **client** variable will be used throughout our test functions to interact with our API endpoints, just as a real client would in production. By importing our application instance directly, we're testing the exact same code that would run in a production environment.

### Testing Basic Endpoints

Let's begin by testing our basic endpoints. First, we'll test the root endpoint that welcomes users to our API:

```python
def test_root_endpoint():
    """
    Test that the root endpoint returns a welcome message.
    """
    response = client.get("/")
    assert response.status_code == 200
    assert "Welcome" in response.json()["message"]
```

This test makes a `GET` request to the root path and verifies that the response has a 200 status code (OK) and that the response JSON contains a welcome message.

Next, let's test the health check endpoint:

```python
def test_health_check():
    """
    Test the health check endpoint returns proper status.
    """
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json()["status"] == "healthy"
    assert "api_version" in response.json()
```

Health check endpoints are critical for production systems as they allow monitoring tools to verify your service is operating correctly. This test confirms our endpoint returns a 200 status code and includes both a "healthy" status and the API version in its response.

These simple tests follow a common pattern in API testing that you'll use repeatedly: make a request, check the status code, and verify the response content matches expectations. While straightforward, they provide valuable assurance that your API's foundation is solid.

### Mocking Model Predictions with Fixtures

When unit-testing prediction endpoints, we face a challenge: we don't want our tests to depend on actual model files or be affected by model behavior changes.

In pytest, **fixtures** are special functions that provide reusable test dependencies. They're a powerful way to set up preconditions for your tests, manage test resources, and inject dependencies. Fixtures help create a consistent testing environment and reduce code duplication across tests.

For our ML API testing, we'll use fixtures to implement **mocking** — creating simplified substitutes for the model that return predefined values. This approach allows us to test our API logic independently from the actual model implementation.

Let's create a fixture to handle our model mocking:

```python
@pytest.fixture
def mock_model_and_preprocessor():
    """Fixture providing mock model and preprocessor."""
    class MockModel:
        def predict(self, X):
            return np.array([1500.0])  # Mock prediction

    class MockPreprocessor:
        def transform(self, X):
            return np.array([[1.0, 2.0, 3.0]])  # Mock transformed features
    
    return MockModel(), MockPreprocessor()
```

This fixture creates two mock classes:

  * **MockModel** with a `predict` method that always returns a fixed prediction value;
  * **MockPreprocessor** with a `transform` method that returns a fixed feature array.

By using these mock objects, we can test our API's logic without depending on actual model files or behavior. This makes our tests more reliable and faster to run.
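
Hand-written classes are easy to read, but the standard library's `unittest.mock.Mock` can build equivalent doubles and also records how they were called. The sketch below swaps in that technique; `run_prediction` is a hypothetical helper that mirrors the transform-then-predict steps of the `/predict` endpoint, and plain lists stand in for NumPy arrays to keep it dependency-free.

```python
from unittest.mock import Mock


# Hypothetical helper mirroring the endpoint: preprocess, then predict.
def run_prediction(model, preprocessor, features):
    transformed = preprocessor.transform(features)
    return model.predict(transformed)[0]


# Mocks configured with fixed return values, like MockModel/MockPreprocessor.
mock_preprocessor = Mock()
mock_preprocessor.transform.return_value = [[1.0, 2.0, 3.0]]
mock_model = Mock()
mock_model.predict.return_value = [1500.0]

features = {"carat": 0.5, "cut": "Ideal"}
price = run_prediction(mock_model, mock_preprocessor, features)

# Unlike the hand-written classes, mocks can verify the interaction itself.
mock_preprocessor.transform.assert_called_once_with(features)
mock_model.predict.assert_called_once_with([[1.0, 2.0, 3.0]])
print(price)  # 1500.0
```

The trade-off is that a plain `Mock` accepts any attribute, so a typo such as `mock_model.predcit` silently returns another mock; `Mock(spec=...)` or explicit classes like the fixture above catch that kind of mistake.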

### Testing Prediction Endpoints with Mocks

With our mocking fixture in place, we can now test our prediction endpoint with realistic but controlled scenarios:

```python
@pytest.mark.parametrize("diamond_features", [
    # Valid diamond features
    {
        "carat": 0.5,
        "cut": "Ideal",
        "color": "E",
        "clarity": "VS2",
        "depth": 61.5,
        "table": 55.0,
        "x": 5.15,
        "y": 5.20,
        "z": 3.18
    }])
def test_predict_endpoint_valid_input(diamond_features, monkeypatch, mock_model_and_preprocessor):
    """
    Test the predict endpoint with valid input data.
    This test mocks the model prediction to avoid requiring an actual model.
    """
    # Get mock model and preprocessor from fixture
    model, preprocessor = mock_model_and_preprocessor

    # Apply the monkey patch to avoid loading the real model
    monkeypatch.setattr("main.get_model", lambda: (model, preprocessor))

    # Call the prediction endpoint
    response = client.post("/predict", json=diamond_features)

    # Verify the response
    assert response.status_code == 200
    result = response.json()
    assert "predicted_price" in result
    assert result["predicted_price"] == 1500.0
    assert result["diamond_features"] == diamond_features
```

This test demonstrates several powerful testing techniques:

  * **Parametrization** with `@pytest.mark.parametrize` lets you run the same test with different inputs — you could easily add more test cases to this list.
  * **Monkey patching** with the **monkeypatch** fixture temporarily replaces functions or attributes at runtime. This built-in pytest fixture allows us to modify behavior without changing the actual code. Here, we use it to replace our real `get_model` function with a lambda that returns our mock objects, eliminating the dependency on the actual model file during tests.
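
Why does patching the string target `"main.get_model"` reach the endpoint? A plain `get_model()` call inside `main.py` looks the name up in the `main` module each time it runs, so replacing that module attribute changes what gets called. The stdlib-only sketch below reproduces that lookup rule with plain `setattr` on a hypothetical `fake_main` module object; `monkeypatch.setattr` performs the same assignment and additionally undoes it when the test finishes.

```python
import types

# Hypothetical stand-in for main.py, built as a module object so the
# example is self-contained.
fake_main = types.ModuleType("fake_main")


def _real_get_model():
    return "real model"


def _predict():
    # Resolve get_model on the module at call time, which is what a plain
    # `get_model()` call inside main.py does via the module's globals.
    return fake_main.get_model()


fake_main.get_model = _real_get_model
fake_main.predict = _predict

# A reference captured up front, like `from main import get_model` would create.
copied_get_model = fake_main.get_model

original = fake_main.get_model
setattr(fake_main, "get_model", lambda: "mock model")  # the swap monkeypatch makes
patched_result = fake_main.predict()   # "mock model": the call-time lookup sees the patch
copied_result = copied_get_model()     # "real model": the old reference is untouched
setattr(fake_main, "get_model", original)  # monkeypatch restores this after the test
restored_result = fake_main.predict()  # "real model"
```

The `copied_result` value is the practical lesson: patch the name where it is looked up. If some module imports `get_model` into its own namespace, patch that module's attribute rather than `main`'s.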

After setting up these mocks, we make a `POST` request with our test diamond features and verify that the response contains the expected prediction (1500.0) and returns the original features. This approach lets us test our API logic independently from the actual Machine Learning model.
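
Exact equality with `1500.0` is safe here because the mock returns a fixed value. If you later assert on values produced by floating-point arithmetic, `pytest.approx` compares within a tolerance instead. A small sketch (the response dictionary is illustrative, not real API output):

```python
import pytest

# 0.1 + 0.2 is the classic case where float arithmetic misses the exact decimal.
total = 0.1 + 0.2
assert total != 0.3                 # exact comparison fails
assert total == pytest.approx(0.3)  # tolerance-based comparison passes

# approx also accepts dicts of numbers, e.g. a parsed JSON response body.
result = {"predicted_price": 1500.0 + 1e-9}
assert result == pytest.approx({"predicted_price": 1500.0})
```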

### Testing Error Handling Scenarios

A robust API should handle invalid inputs gracefully. Let's test how our prediction endpoint responds to various error conditions:

```python
@pytest.mark.parametrize("invalid_features, expected_status_code", [
    # Missing required field
    ({"cut": "Ideal", "color": "E", "clarity": "VS2", "depth": 61.5,
      "table": 55.0, "x": 5.15, "y": 5.20, "z": 3.18}, 422),

    # Invalid data type (string instead of float)
    ({"carat": "not_a_number", "cut": "Ideal", "color": "E", "clarity": "VS2",
      "depth": 61.5, "table": 55.0, "x": 5.15, "y": 5.20, "z": 3.18}, 422),

    # Invalid value (negative carat)
    ({"carat": -0.5, "cut": "Ideal", "color": "E", "clarity": "VS2",
      "depth": 61.5, "table": 55.0, "x": 5.15, "y": 5.20, "z": 3.18}, 422),
])
def test_predict_endpoint_invalid_input(invalid_features, expected_status_code, monkeypatch):
    """
    Test the predict endpoint with invalid input data.
    """
    # Mock the model to avoid loading failure
    monkeypatch.setattr("main.get_model", lambda: (object(), object()))

    # Call the prediction endpoint with invalid data
    response = client.post("/predict", json=invalid_features)

    # Verify the response has the expected error status code
    assert response.status_code == expected_status_code
```

This test uses parametrization to cover several error scenarios with minimal code duplication:

  * A request missing a required field (carat is omitted).
  * A request with an invalid data type (string instead of number for carat).
  * A request with an invalid value (negative carat).

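In each case, we expect a 422 Unprocessable Entity status code, which FastAPI returns automatically when request validation fails. The first two cases fail against the type annotations alone; the negative-carat case fails only if the request model also constrains `carat` to be positive (for example, with Pydantic's `Field(gt=0)`), since -0.5 is otherwise a valid `float`. By testing these scenarios, we ensure our API's data validation works correctly and prevents invalid data from reaching our model.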
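
The negative-carat case deserves a closer look, because -0.5 passes a plain `float` check. A minimal sketch, assuming a Pydantic model shaped like `DiamondFeatures` and using `Field(gt=0)` as one way to require a positive carat:

```python
from pydantic import BaseModel, Field, ValidationError


class DiamondFeatures(BaseModel):
    carat: float = Field(..., gt=0)  # must be strictly positive
    cut: str
    color: str
    clarity: str
    depth: float
    table: float
    x: float
    y: float
    z: float


valid = {"carat": 0.5, "cut": "Ideal", "color": "E", "clarity": "VS2",
         "depth": 61.5, "table": 55.0, "x": 5.15, "y": 5.20, "z": 3.18}

print(DiamondFeatures(**valid).carat)  # 0.5

try:
    DiamondFeatures(**{**valid, "carat": -0.5})
except ValidationError as exc:
    # FastAPI turns this same kind of validation failure into a 422 response.
    print("rejected:", exc.errors()[0]["loc"])  # rejected: ('carat',)
```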

### Testing Service Availability

What happens when our model isn't available? This is a critical scenario to handle properly in production environments. Let's test this case:

```python
def test_predict_endpoint_model_not_loaded(monkeypatch):
    """
    Test the predict endpoint when the model is not available.
    """
    # Define a mock function that raises an exception
    def mock_get_model():
        raise HTTPException(status_code=503, detail="Model not available")

    # Apply the monkey patch
    monkeypatch.setattr("main.get_model", mock_get_model)

    # Valid diamond features
    diamond_features = {
        "carat": 0.5,
        "cut": "Ideal",
        "color": "E",
        "clarity": "VS2",
        "depth": 61.5,
        "table": 55.0,
        "x": 5.15,
        "y": 5.20,
        "z": 3.18
    }

    # Call the prediction endpoint
    response = client.post("/predict", json=diamond_features)

    # This should fail with a 503 Service Unavailable
    assert response.status_code == 503
```

This test simulates a scenario where `get_model` fails by raising an `HTTPException`. In the `main.py` used in this unit, the `/predict` endpoint calls `get_model()` inside a `try` block and converts any exception into a 503 response, so this test verifies that a model-loading failure reaches the client as a clear 503 Service Unavailable rather than an unhandled server error.

Why is this important? In production, models might fail to load for many reasons — corrupt files, memory issues, or version incompatibilities. When this happens, your API should provide clear feedback rather than crashing or hanging. This test ensures your error handling system works correctly for this critical failure mode.

### Running Your Tests with pytest

Now that we've written a comprehensive test suite, let's explore how to actually run these tests using **pytest**. Running tests regularly is essential to catch issues early in your development process.

First, ensure your test files follow pytest's naming convention: they should be named `test_*.py` or `*_test.py` to be discovered automatically. A common convention is to place these files in a `tests` directory for better organization, but **pytest** does not require it.

To run all your tests, open a terminal in the project directory and execute:

```bash
pytest
```

This command recursively discovers and runs every matching test file, starting from the current directory. To run a specific test file, pass its path:

```bash
pytest tests/test_api.py
```

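You can also run pytest programmatically from Python code. `pytest.main()` accepts the same arguments as the command line and returns an exit code:

```python
# Running pytest programmatically, e.g. from a helper script
import sys

import pytest

if __name__ == "__main__":
    # With no arguments, pytest.main() runs all discovered tests;
    # passing a path runs only that file.
    exit_code = pytest.main(["tests/test_api.py"])
    # Propagate the result (0 means every collected test passed).
    sys.exit(exit_code)
```

Avoid calling `pytest.main()` several times in the same process: Python caches imported modules, so later calls won't pick up changes to your test files.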

### Conclusion

Excellent work! You've now learned how to create a comprehensive test suite for your FastAPI applications. These tests verify that your endpoints function correctly, validate inputs properly, and handle errors gracefully: all essential qualities for a production-ready API.

Now, time for some practice. Keep up the great work!

## Testing FastAPI with pytest

Welcome back! You've done a fantastic job so far, especially with setting up your FastAPI application. Now, it's time to dive into testing your application using pytest. This is a crucial step in ensuring your API behaves as expected.

In this exercise, your goal is to set up the testing environment by importing the app from the main module and creating a TestClient. This client will allow you to simulate HTTP requests to your API endpoints, making it possible to verify their behavior without running a server.

Here's a hint to guide you: think about how you can wrap your FastAPI application with a client that acts like a real-world user. This setup is the backbone of your testing environment, so make sure it's rock-solid.

Note: in this and the following practices, the "Run" button executes pytest, so there's no need to run it programmatically.

Once you have these components in place, you'll be ready to test your API endpoints effectively. Let's get started and ensure your testing environment is ready for action!

```python
# main.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import numpy as np

app = FastAPI()

# Define the schema for diamond features using Pydantic
class DiamondFeatures(BaseModel):
    carat: float = Field(..., gt=0)  # must be positive; -0.5 is rejected with a 422
    cut: str
    color: str
    clarity: str
    depth: float
    table: float
    x: float
    y: float
    z: float

# Dummy model and preprocessor implementations for demonstration purposes
class DummyModel:
    def predict(self, X):
        # Return a dummy prediction; in a real scenario, this would use a trained model.
        return np.array([round(np.sum(X[0]) * 10, 2)])

class DummyPreprocessor:
    def transform(self, data):
        # For demonstration, return a dummy transformed array.
        # In practice, you would process 'data' to match the model's requirements.
        return np.array([[1.0, 2.0, 3.0]])

# Function to get the model and preprocessor
def get_model():
    # In production, load and return your actual ML model and preprocessor.
    return DummyModel(), DummyPreprocessor()

# Define the root endpoint that welcomes users
@app.get("/")
def read_root():
    return {"message": "Welcome to the Diamond Price Prediction API!"}

# Health check endpoint to verify the API status
@app.get("/health")
def health_check():
    return {"status": "healthy", "api_version": "1.0.0"}

# Prediction endpoint that uses the model to predict diamond prices
@app.post("/predict")
def predict(diamond_features: DiamondFeatures):
    try:
        model, preprocessor = get_model()
    except Exception:
        raise HTTPException(status_code=503, detail="Model not available")
    
    # Preprocess the input features and generate a prediction
    transformed_features = preprocessor.transform(diamond_features.dict())
    prediction = model.predict(transformed_features)
    
    return {
        "predicted_price": prediction[0],
        "diamond_features": diamond_features.dict()
    }

```

```python
# test_main.py
"""
Test Module for Diamond Price Prediction API

This module contains tests for the FastAPI application endpoints
using the TestClient from FastAPI's testing toolkit.
"""

import pytest
import numpy as np
from fastapi import HTTPException

# TODO: Import the app from the main module

# TODO: Create a test client for the FastAPI application

def test_root_endpoint():
    """
    Test that the root endpoint returns a welcome message.
    """
    response = client.get("/")
    assert response.status_code == 200
    assert "Welcome" in response.json()["message"]

def test_health_check():
    """
    Test the health check endpoint returns proper status.
    """
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json()["status"] == "healthy"
    assert "api_version" in response.json()

@pytest.fixture
def mock_model_and_preprocessor():
    """Fixture providing mock model and preprocessor."""
    class MockModel:
        def predict(self, X):
            return np.array([1500.0])  # Mock prediction
            
    class MockPreprocessor:
        def transform(self, X):
            return np.array([[1.0, 2.0, 3.0]])  # Mock transformed features
    
    return MockModel(), MockPreprocessor()

@pytest.mark.parametrize("diamond_features", [
    # Valid diamond features
    {
        "carat": 0.5,
        "cut": "Ideal",
        "color": "E",
        "clarity": "VS2",
        "depth": 61.5,
        "table": 55.0,
        "x": 5.15,
        "y": 5.20,
        "z": 3.18
    }
])
def test_predict_endpoint_valid_input(diamond_features, monkeypatch, mock_model_and_preprocessor):
    """
    Test the predict endpoint with valid input data.
    This test mocks the model prediction to avoid requiring an actual model.
    """
    # Get mock model and preprocessor from fixture
    model, preprocessor = mock_model_and_preprocessor
    
    # Apply the monkey patch to avoid loading the real model
    monkeypatch.setattr("main.get_model", lambda: (model, preprocessor))
    
    # Call the prediction endpoint
    response = client.post("/predict", json=diamond_features)
    
    # Verify the response
    assert response.status_code == 200
    result = response.json()
    assert "predicted_price" in result
    assert result["predicted_price"] == 1500.0
    assert result["diamond_features"] == diamond_features

@pytest.mark.parametrize("invalid_features, expected_status_code", [
    # Missing required field
    ({"cut": "Ideal", "color": "E", "clarity": "VS2", "depth": 61.5, 
      "table": 55.0, "x": 5.15, "y": 5.20, "z": 3.18}, 422),
    
    # Invalid data type (string instead of float)
    ({"carat": "not_a_number", "cut": "Ideal", "color": "E", "clarity": "VS2", 
      "depth": 61.5, "table": 55.0, "x": 5.15, "y": 5.20, "z": 3.18}, 422),
    
    # Invalid value (negative carat)
    ({"carat": -0.5, "cut": "Ideal", "color": "E", "clarity": "VS2", 
      "depth": 61.5, "table": 55.0, "x": 5.15, "y": 5.20, "z": 3.18}, 422),
])
def test_predict_endpoint_invalid_input(invalid_features, expected_status_code, monkeypatch):
    """
    Test the predict endpoint with invalid input data.
    """
    # Mock the model to avoid loading failure
    monkeypatch.setattr("main.get_model", lambda: (object(), object()))
    
    # Call the prediction endpoint with invalid data
    response = client.post("/predict", json=invalid_features)
    
    # Verify the response has the expected error status code
    assert response.status_code == expected_status_code

def test_predict_endpoint_model_not_loaded(monkeypatch):
    """
    Test the predict endpoint when the model is not available.
    """
    # Define a mock function that raises an exception
    def mock_get_model():
        raise HTTPException(status_code=503, detail="Model not available")
    
    # Apply the monkey patch
    monkeypatch.setattr("main.get_model", mock_get_model)
    
    # Valid diamond features
    diamond_features = {
        "carat": 0.5,
        "cut": "Ideal",
        "color": "E",
        "clarity": "VS2",
        "depth": 61.5,
        "table": 55.0,
        "x": 5.15,
        "y": 5.20,
        "z": 3.18
    }
    
    # Call the prediction endpoint
    response = client.post("/predict", json=diamond_features)
    
    # This should fail with a 503 Service Unavailable
    assert response.status_code == 503
```

## Testing Your API's Welcome Message

## Mocking Models with pytest Fixtures

## Testing API Resilience with Invalid Inputs