Switch to Connexion 3 framework
This is a huge PR, the result of over 100 commits
made by a number of people in apache#36052 and apache#37638. It
switches to Connexion 3 as the driving backend
implementation for both the Airflow REST APIs and the Flask
app that powers the Airflow UI. It should be largely
backwards compatible when it comes to the behaviour of
both the APIs and the Airflow Webserver views; however, due to
decisions made by the Connexion 3 maintainers, it heavily
changes the technology stack used under the hood:

1) The Connexion App is an ASGI-compatible, OpenAPI spec-first
   framework using ASGI as the interface between the webserver
   and the Python web application. ASGI is the asynchronous
   successor of WSGI.

2) Connexion itself uses Starlette to run asynchronous
   web services in Python.

3) We continue using the Gunicorn application server, which still
   uses the WSGI standard - which means we can continue using
   Flask - and we use the standard Uvicorn ASGI webserver that
   bridges the ASGI interface with Gunicorn's WSGI interface
   (see the sketch below).
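
A minimal sketch of what such a Connexion 3 setup looks like in
general (module, spec and object names here are illustrative, not the
actual Airflow wiring):

```python
# Illustrative sketch only - spec and object names are hypothetical,
# not the actual Airflow wiring.
from connexion import FlaskApp

# Connexion 3's FlaskApp is an ASGI application that wraps a Flask app.
connexion_app = FlaskApp(__name__)
connexion_app.add_api("openapi.yaml")  # OpenAPI spec-first routing

# The wrapped Flask app stays available for classic Flask views.
flask_app = connexion_app.app
```

Such an ASGI app is then typically served by Gunicorn with Uvicorn
workers, e.g. `gunicorn -k uvicorn.workers.UvicornWorker module:connexion_app`.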

Some of the problems handled in this PR:

There were two problems with session handling:

* get_session_cookie did not get the right cookie - it returned
  the "session" string. The right fix was to change cookie_jar into
  cookie.jar, because that is apparently where Starlette's TestClient
  holds the cookies (visible when you debug)

* The client does not accept a "set_cookie" method - it accepts
  cookies passed via the "cookies" dictionary - this is the usual
  httpx client behaviour - see https://www.starlette.io/testclient/ -
  so we have to set the cookie directly in the get method to try it
  out (see the sketch below)
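
A rough sketch of the pattern, assuming Starlette's TestClient (which
is an httpx client under the hood); the cookie name and the app object
are illustrative:

```python
# Sketch only - "app" and the cookie name are illustrative assumptions.
from starlette.testclient import TestClient

client = TestClient(app)
client.post("/login", data={"username": "test", "password": "test"})

# The cookies live in the httpx cookie jar, reachable via client.cookies.jar.
session_cookie = next(c for c in client.cookies.jar if c.name == "session")

# Cookies are passed per request via the "cookies" argument, not via a
# set_cookie() method on the client.
response = client.get("/home", cookies={"session": session_cookie.value})
```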

Add "flask_client_with_login" for tests that neeed flask client

Some tests require functionality not available in the Starlette test
client because they use Flask test client specific features - for
those we have an option to get the Flask test client instead of the
Starlette one.
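
A sketch of what such a fixture could look like (the fixture body,
login route and credentials are assumptions, not the actual Airflow
code):

```python
# Hypothetical sketch - the login URL and credentials are assumptions.
import pytest


@pytest.fixture
def flask_client_with_login(app):  # "app" assumed to be a Flask app fixture
    client = app.test_client()
    client.post(
        "/login/",
        data={"username": "test", "password": "test"},
        follow_redirects=True,
    )
    return client
```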

Fix error handling for the new Connexion 3 approach

Error handling for Connexion 3 integration needed to be reworked.

The way it behaves is much the same as in main:

* for API errors - we get application/problem+json responses
* for UI errors - we have rendered views
* for redirections - we have the correct location header (it had been
  missing)
* the API error handler was not added as available middleware
  in the www tests

It should fix all test_views_base.py tests, which were failing
due to the missing location header on redirection.
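
A sketch of how such a handler is registered on the Connexion 3 app
(the app object here is illustrative; the handler matches the
problem_error_handler signature from the diff below):

```python
# Sketch only - "connexion_app" is an illustrative connexion.FlaskApp instance.
from connexion import FlaskApp, ProblemException

from airflow.api_connexion.exceptions import problem_error_handler

connexion_app = FlaskApp(__name__)
# Route ProblemException instances through the handler that adds the
# documentation link to the "type" field of the problem+json response.
connexion_app.add_error_handler(ProblemException, problem_error_handler)
```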

Fix wrong response in tests_view_cluster_activity

The problem in the test was that the Starlette test client opens a new
connection and starts a new session, while the Flask test client
uses the same database session. The test did not show data because
the data was not committed and the session was not closed - which also
made local sqlite tests fail with a "database is locked" error.
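
In practice this kind of fix boils down to committing and closing the
setup session before the client call - schematically:

```python
# Schematic fragment - "session" is assumed to be a SQLAlchemy session
# fixture and "dag_run" an object created in the test setup.
session.add(dag_run)
session.commit()  # make the data visible to the client's new connection
session.close()   # release the sqlite lock held by the setup session
```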

Fix test_extra_links

The tests were failing, again because the created dagrun was not
committed and the session was not closed. This worked with the Flask
client, which accidentally used the same session, but did not work
with the test client from Starlette. It also caused "database is
locked" errors in sqlite / local tests.

Switch to non-deprecated auth manager

Fix to test_views_log.py

This PR partially fixes sessions and the request parameter for
test_views_log. Some tests are still failing, but for different
reasons - to be investigated.

Fix views_custom_user_views tests

The problem in those tests was that the check in the security manager
was based on the assumption that the security manager was shared
between the client and the test Flask application - because they
came from the same Flask app. But when we use Starlette,
the call goes to a newly started process while the user is deleted in
the database - so the shortcut of checking the security manager
did not work.

The change is that we now check whether the user is deleted by
calling /users/show (we need a new users READ permission for that)
- this way we go to the database and check whether the user was
indeed deleted (see the sketch below).
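
Schematically, the check now goes through the view (and thus the
database) instead of the in-memory security manager - the URL shape
and the assertion below are assumptions, not the actual test code:

```python
# Sketch only - URL shape and assertion are assumptions.
response = client.get(f"/users/show/{user.id}/", follow_redirects=True)
# If the user row is gone from the database, the show view can no
# longer render it, so the username should not appear in the body.
assert user.username not in response.get_data(as_text=True)
```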

Fix test_task_instance_endpoint tests

There were two reasons why the tests failed:

* when the Job was added to the task instance, the task instance was
  not merged into the session, which means that the commit did not
  store the added Job

* some of the tests expected a call with a specific session
  and failed because the session was different. Replacing the
  session with mock.ANY tells the mock assertion that this parameter
  can be anything - we will have a different session when the call is
  made via ASGI/Starlette (see the sketch below)
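
Schematically (the names "ti", "session" and "mocked_method" are
illustrative, not the actual test code):

```python
# Schematic fragment - object names are illustrative assumptions.
from unittest import mock

# The task instance was modified (a Job attached to it) after it had
# been loaded, so it must be merged back into the session before the
# commit actually persists the change.
session.merge(ti)
session.commit()

# The session differs between the Flask and the ASGI/Starlette code
# paths, so the assertion accepts any session object.
mocked_method.assert_called_once_with(ti, session=mock.ANY)
```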

Fix parameter validation

* added a default value for the limit parameter across the board.
  Connexion 3 does not like it when the parameter has no default and
  we have not provided one - even if our custom decorator was adding
  it. Adding a default value and updating our decorator to treat None
  as the default fixed a number of problems where limits were not
  passed (see the sketch after this list)

* swapped the openapi specification entries for /datasets/{uri} and
  /datasets/events. Since `{uri}` was defined first, Connexion matched
  `events` with `{uri}` and chose the parameter definitions from
  `{uri}` rather than from `events`
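
The resulting pattern on the endpoints looks roughly like this
(abridged illustration of the signatures changed in the diff below;
the security decorators and the function body are elided):

```python
# Abridged illustration: the limit is optional in the signature, and
# check_limit() falls back to the configured default when it gets None.
from __future__ import annotations

from airflow.api_connexion.parameters import check_limit, format_parameters
from airflow.utils.session import NEW_SESSION, provide_session


@format_parameters({"limit": check_limit})
@provide_session
def get_connections(*, limit: int | None = None, offset: int = 0, session=NEW_SESSION):
    ...
```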

Fix test_log_endpoint tests

The problem here was that some sessions should be committed/closed,
but also, in order to run the tests standalone, we wanted to create
log templates in the database - the tests relied implicitly on log
templates created by other tests.
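
A sketch of creating such a log template row in the test setup
(assuming the LogTemplate model from airflow.models.tasklog; the
field values are illustrative):

```python
# Sketch only - field values are illustrative assumptions.
from airflow.models.tasklog import LogTemplate

session.add(
    LogTemplate(
        filename="{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log",
        elasticsearch_id="{dag_id}-{task_id}-{execution_date}-{try_number}",
    )
)
session.commit()
```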

Fix test_views_dagrun, test_views_tasks and test_views_log

Fixed by switching to the Flask client for testing rather than the
Starlette one. The Starlette client in this case has side effects that
also impact sqlite: the session is created in a different thread and
deleted by the close_all_sessions fixture.

Co-authored-by: sudipto baral <sudiptobaral.me@gmail.com>
Co-authored-by: satoshi-sh <satoss1108@gmail.com>
Co-authored-by: Maksim Yermakou <maksimy@google.com>
Co-authored-by: Ulada Zakharava <Vlada_Zakharava@epam.com>
4 people authored and potiuk committed Apr 25, 2024
1 parent 15c2734 commit 396fc6f
Showing 124 changed files with 2,782 additions and 2,601 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/basic-tests.yml
@@ -148,6 +148,8 @@ jobs:
env:
HATCH_ENV: "test"
working-directory: ./clients/python
- name: Compile www assets
run: breeze compile-www-assets
- name: "Install Airflow in editable mode with fab for webserver tests"
run: pip install -e ".[fab]"
- name: "Install Python client"
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/connection_endpoint.py
@@ -91,7 +91,7 @@ def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> API
@provide_session
def get_connections(
*,
limit: int,
limit: int | None = None,
offset: int = 0,
order_by: str = "id",
session: Session = NEW_SESSION,
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/dag_endpoint.py
@@ -94,7 +94,7 @@ def get_dag_details(
@provide_session
def get_dags(
*,
limit: int,
limit: int | None = None,
offset: int = 0,
tags: Collection[str] | None = None,
dag_id_pattern: str | None = None,
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/dag_warning_endpoint.py
@@ -43,7 +43,7 @@
@provide_session
def get_dag_warnings(
*,
limit: int,
limit: int | None = None,
dag_id: str | None = None,
warning_type: str | None = None,
offset: int | None = None,
6 changes: 3 additions & 3 deletions airflow/api_connexion/endpoints/dataset_endpoint.py
@@ -82,7 +82,7 @@ def get_dataset(*, uri: str, session: Session = NEW_SESSION) -> APIResponse:
@provide_session
def get_datasets(
*,
limit: int,
limit: int | None = None,
offset: int = 0,
uri_pattern: str | None = None,
dag_ids: str | None = None,
@@ -113,11 +113,11 @@ def get_datasets(


@security.requires_access_dataset("GET")
@provide_session
@format_parameters({"limit": check_limit})
@provide_session
def get_dataset_events(
*,
limit: int,
limit: int | None = None,
offset: int = 0,
order_by: str = "timestamp",
dataset_id: int | None = None,
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/event_log_endpoint.py
@@ -64,7 +64,7 @@ def get_event_logs(
included_events: str | None = None,
before: str | None = None,
after: str | None = None,
limit: int,
limit: int | None = None,
offset: int | None = None,
order_by: str = "event_log_id",
session: Session = NEW_SESSION,
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/import_error_endpoint.py
@@ -77,7 +77,7 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) ->
@provide_session
def get_import_errors(
*,
limit: int,
limit: int | None = None,
offset: int | None = None,
order_by: str = "import_error_id",
session: Session = NEW_SESSION,
5 changes: 4 additions & 1 deletion airflow/api_connexion/endpoints/log_endpoint.py
@@ -107,7 +107,10 @@ def get_log(
logs = logs[0] if task_try_number is not None else logs
# we must have token here, so we can safely ignore it
token = URLSafeSerializer(key).dumps(metadata) # type: ignore[assignment]
return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs))
return Response(
logs_schema.dumps(LogResponseObject(continuation_token=token, content=logs)),
headers={"Content-Type": "application/json"},
)
# text/plain. Stream
logs = task_log_reader.read_log_stream(ti, task_try_number, metadata)

2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/pool_endpoint.py
@@ -68,7 +68,7 @@ def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse:
@provide_session
def get_pools(
*,
limit: int,
limit: int | None = None,
order_by: str = "id",
offset: int | None = None,
session: Session = NEW_SESSION,
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -307,7 +307,7 @@ def _apply_range_filter(query: Select, key: ClauseElement, value_range: tuple[T,
@provide_session
def get_task_instances(
*,
limit: int,
limit: int | None = None,
dag_id: str | None = None,
dag_run_id: str | None = None,
execution_date_gte: str | None = None,
55 changes: 23 additions & 32 deletions airflow/api_connexion/exceptions.py
@@ -19,13 +19,12 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Any

import werkzeug
from connexion import FlaskApi, ProblemException, problem
from connexion import ProblemException, problem

from airflow.utils.docs import get_docs_url

if TYPE_CHECKING:
import flask
from connexion.lifecycle import ConnexionRequest, ConnexionResponse

doc_link = get_docs_url("stable-rest-api-ref.html")

@@ -40,37 +39,29 @@
}


def common_error_handler(exception: BaseException) -> flask.Response:
def problem_error_handler(_request: ConnexionRequest, exception: ProblemException) -> ConnexionResponse:
"""Use to capture connexion exceptions and add link to the type field."""
if isinstance(exception, ProblemException):
link = EXCEPTIONS_LINK_MAP.get(exception.status)
if link:
response = problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=link,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
else:
response = problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=exception.type,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
link = EXCEPTIONS_LINK_MAP.get(exception.status)
if link:
return problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=link,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
else:
if not isinstance(exception, werkzeug.exceptions.HTTPException):
exception = werkzeug.exceptions.InternalServerError()

response = problem(title=exception.name, detail=exception.description, status=exception.code)

return FlaskApi.get_response(response)
return problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=exception.type,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)


class NotFound(ProblemException):
61 changes: 35 additions & 26 deletions airflow/api_connexion/openapi/v1.yaml
@@ -1367,6 +1367,10 @@ paths:
responses:
"204":
description: Success.
content:
text/html:
schema:
type: string
"400":
$ref: "#/components/responses/BadRequest"
"401":
@@ -1743,6 +1747,10 @@
responses:
"204":
description: Success.
content:
text/html:
schema:
type: string
"400":
$ref: "#/components/responses/BadRequest"
"401":
@@ -1885,8 +1893,8 @@ paths:
response = self.client.get(
request_url,
query_string={"token": token},
headers={"Accept": "text/plain"},
environ_overrides={"REMOTE_USER": "test"},
headers={"Accept": "text/plain","REMOTE_USER": "test"},
)
continuation_token = response.json["continuation_token"]
metadata = URLSafeSerializer(key).loads(continuation_token)
@@ -2020,7 +2028,7 @@ paths:
properties:
content:
type: string
plain/text:
text/plain:
schema:
type: string

@@ -2106,29 +2114,6 @@
"403":
$ref: "#/components/responses/PermissionDenied"

/datasets/{uri}:
parameters:
- $ref: "#/components/parameters/DatasetURI"
get:
summary: Get a dataset
description: Get a dataset by uri.
x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint
operationId: get_dataset
tags: [Dataset]
responses:
"200":
description: Success.
content:
application/json:
schema:
$ref: "#/components/schemas/Dataset"
"401":
$ref: "#/components/responses/Unauthenticated"
"403":
$ref: "#/components/responses/PermissionDenied"
"404":
$ref: "#/components/responses/NotFound"

/datasets/events:
get:
summary: Get dataset events
@@ -2186,6 +2171,30 @@
'404':
$ref: '#/components/responses/NotFound'

/datasets/{uri}:
parameters:
- $ref: "#/components/parameters/DatasetURI"
get:
summary: Get a dataset
description: Get a dataset by uri.
x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint
operationId: get_dataset
tags: [Dataset]
responses:
"200":
description: Success.
content:
application/json:
schema:
$ref: "#/components/schemas/Dataset"
"401":
$ref: "#/components/responses/Unauthenticated"
"403":
$ref: "#/components/responses/PermissionDenied"
"404":
$ref: "#/components/responses/NotFound"


/config:
get:
summary: Get current configuration
14 changes: 9 additions & 5 deletions airflow/api_connexion/parameters.py
@@ -41,7 +41,7 @@ def validate_istimezone(value: datetime) -> None:
raise BadRequest("Invalid datetime format", detail="Naive datetime is disallowed")


def format_datetime(value: str) -> datetime:
def format_datetime(value: str | None) -> datetime | None:
"""
Format datetime objects.
@@ -50,6 +50,8 @@ def format_datetime(value: str) -> datetime:
This should only be used within connection views because it raises 400
"""
if value is None:
return None
value = value.strip()
if value[-1] != "Z":
value = value.replace(" ", "+")
@@ -59,7 +61,7 @@
raise BadRequest("Incorrect datetime argument", detail=str(err))


def check_limit(value: int) -> int:
def check_limit(value: int | None) -> int:
"""
Check the limit does not exceed configured value.
@@ -68,7 +70,8 @@
"""
max_val = conf.getint("api", "maximum_page_limit") # user configured max page limit
fallback = conf.getint("api", "fallback_page_limit")

if value is None:
return fallback
if value > max_val:
log.warning(
"The limit param value %s passed in API exceeds the configured maximum page limit %s",
@@ -99,8 +102,9 @@ def format_parameters_decorator(func: T) -> T:
@wraps(func)
def wrapped_function(*args, **kwargs):
for key, formatter in params_formatters.items():
if key in kwargs:
kwargs[key] = formatter(kwargs[key])
value = formatter(kwargs.get(key))
if value:
kwargs[key] = value
return func(*args, **kwargs)

return cast(T, wrapped_function)
5 changes: 3 additions & 2 deletions airflow/auth/managers/base_auth_manager.py
@@ -34,6 +34,7 @@
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
import connexion
from flask import Blueprint
from sqlalchemy.orm import Session

@@ -81,8 +82,8 @@ def get_cli_commands() -> list[CLICommand]:
"""
return []

def get_api_endpoints(self) -> None | Blueprint:
"""Return API endpoint(s) definition for the auth manager."""
def set_api_endpoints(self, connexion_app: connexion.FlaskApp) -> None:
"""Set API endpoint(s) definition for the auth manager."""
return None

def get_user_name(self) -> str: