Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix checking parameter type #110

Merged
merged 1 commit into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion resultsdb/controllers/api_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from flask import Blueprint, jsonify, render_template
from flask import current_app as app
from flask_pydantic import validate
from pydantic import BaseModel

from resultsdb.models import db
from resultsdb.authorization import match_testcase_permissions, verify_authorization
Expand All @@ -21,6 +22,20 @@
api = Blueprint("api_v3", __name__)


def ensure_dict_input(cls):
"""
Wraps Pydantic model to ensure that the input type is dict.

This is a workaround for a bug in flask-pydantic that causes validation to
fail with unexpected exception.
"""

class EnsureJsonObject(BaseModel):
__root__: cls

return EnsureJsonObject


def permissions():
return app.config.get("PERMISSIONS", [])

Expand Down Expand Up @@ -68,7 +83,7 @@ def create_endpoint(params_class, oidc, provider):

@oidc.token_auth(provider)
@validate()
def create(body: params_class):
def create(body: ensure_dict_input(params_class)):
return create_result(body)

def get_schema():
Expand Down
110 changes: 110 additions & 0 deletions testing/test_api_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,113 @@ def test_api_v3_consistency(params_class, client):
assert r.status_code == 200, r.text
assert f"POST /api/v3/results/{artifact_type}s" in r.text
assert f'<a class="anchor-link" href="#results/{artifact_type}s">#</a>' in r.text


@pytest.mark.parametrize("params_class", RESULTS_PARAMS_CLASSES)
def test_api_v3_bad_param_type_int(params_class, client):
"""
Passing unexpected JSON type must propagate an error to the user.
"""
artifact_type = params_class.artifact_type()
r = client.post(f"/api/v3/results/{artifact_type}s", json=0)
assert r.status_code == 400, r.text
assert r.json == {
"validation_error": {
"body_params": [
{
"loc": ["__root__"],
"msg": "value is not a valid dict",
"type": "type_error.dict",
}
]
}
}


@pytest.mark.parametrize("params_class", RESULTS_PARAMS_CLASSES)
def test_api_v3_bad_param_type_str(params_class, client):
"""
Passing unexpected JSON type must propagate an error to the user.
"""
artifact_type = params_class.artifact_type()
r = client.post(f"/api/v3/results/{artifact_type}s", json="BAD")
assert r.status_code == 400, r.text
assert r.json == {
"validation_error": {
"body_params": [
{
"loc": ["__root__"],
"msg": "value is not a valid dict",
"type": "type_error.dict",
}
]
}
}


@pytest.mark.parametrize("params_class", RESULTS_PARAMS_CLASSES)
def test_api_v3_bad_param_type_null(params_class, client):
"""
Passing unexpected JSON type must propagate an error to the user.
"""
artifact_type = params_class.artifact_type()
r = client.post(
f"/api/v3/results/{artifact_type}s", content_type="application/json", data="null"
)
assert r.status_code == 400, r.text
assert r.json == {
"validation_error": {
"body_params": [
{
"loc": ["__root__"],
"msg": "none is not an allowed value",
"type": "type_error.none.not_allowed",
}
]
}
}


@pytest.mark.parametrize("params_class", RESULTS_PARAMS_CLASSES)
def test_api_v3_bad_param_invalid_json(params_class, client):
"""
Passing unexpected JSON type must propagate an error to the user.
"""
artifact_type = params_class.artifact_type()
r = client.post(f"/api/v3/results/{artifact_type}s", content_type="application/json", data="{")
assert r.status_code == 400, r.text
assert r.json == {"message": "Bad request"}


@pytest.mark.parametrize("params_class", RESULTS_PARAMS_CLASSES)
def test_api_v3_example(params_class, client):
"""
Passing unexpected JSON type must propagate an error to the user.
"""
artifact_type = params_class.artifact_type()
example = params_class.example().dict()
r = client.post(f"/api/v3/results/{artifact_type}s", json=example)
assert r.status_code == 201, r.text


@pytest.mark.parametrize("params_class", RESULTS_PARAMS_CLASSES)
def test_api_v3_missing_param(params_class, client):
"""
Passing unexpected JSON type must propagate an error to the user.
"""
artifact_type = params_class.artifact_type()
example = params_class.example().dict()
del example["outcome"]
r = client.post(f"/api/v3/results/{artifact_type}s", json=example)
assert r.status_code == 400, r.text
assert r.json == {
"validation_error": {
"body_params": [
{
"loc": ["__root__", "outcome"],
"msg": "field required",
"type": "value_error.missing",
}
]
}
}
Loading