forked from apache/airflow
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for EMR Notebook Execution (apache#14962)
- Loading branch information
Showing
9 changed files
with
734 additions
and
2 deletions.
There are no files selected for viewing
77 changes: 77 additions & 0 deletions
77
airflow/providers/amazon/aws/example_dags/example_emr_notebook_execution.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
""" | ||
This is an example DAG for an AWS EMR Notebook execution. | ||
Start a notebook execution then check the notebook execution until it finishes. | ||
""" | ||
import os | ||
from datetime import timedelta | ||
|
||
from airflow import DAG | ||
from airflow.providers.amazon.aws.operators.emr_start_notebook_execution import ( | ||
EmrStartNotebookExecutionOperator, | ||
) | ||
from airflow.providers.amazon.aws.sensors.emr_notebook_execution import EmrNotebookExecutionSensor | ||
from airflow.utils.dates import days_ago | ||
|
||
# [START howto_operator_emr_notebook_execution_env_variables]
# Each identifier can be overridden through an environment variable of the
# same name; the fallbacks below are example values only.
NOTEBOOK_ID = os.getenv("NOTEBOOK_ID", "e-MYDEMONOTEBOOKT0ACS9KN5UT")
NOTEBOOK_FILE = os.getenv("NOTEBOOK_FILE", "test.ipynb")
# Notebook parameters are handed to the operator as a JSON-formatted string.
NOTEBOOK_EXECUTION_PARAMS = os.getenv("NOTEBOOK_EXECUTION_PARAMS", '{"PARAM_1":10}')
EMR_CLUSTER_ID = os.getenv("EMR_CLUSTER_ID", "j-123456ABCDEFG")
# [END howto_operator_emr_notebook_execution_env_variables]

# Arguments applied to every task of the DAG below.
DEFAULT_ARGS = {
    'owner': 'airflow',
    'depends_on_past': False,
    'email': ['airflow@example.com'],
    'email_on_failure': False,
    'email_on_retry': False,
}

with DAG(
    dag_id='emr_notebook_execution_dag',
    # Without this, DEFAULT_ARGS would be defined but never take effect.
    default_args=DEFAULT_ARGS,
    dagrun_timeout=timedelta(hours=2),
    start_date=days_ago(1),
    schedule_interval="@once",
    tags=["emr_notebook_execution", "example"],
) as dag:

    # [START howto_operator_emr_notebook_execution_tasks]
    # Start the notebook execution; the operator returns the execution id,
    # which the sensor below pulls from XCom.
    notebook_execution_adder = EmrStartNotebookExecutionOperator(
        task_id='start_notebook_execution',
        aws_conn_id='aws_default',
        editor_id=NOTEBOOK_ID,
        relative_path=NOTEBOOK_FILE,
        notebook_execution_name='test-emr-notebook-execution-airflow-operator',
        notebook_params=NOTEBOOK_EXECUTION_PARAMS,
        execution_engine={'Id': EMR_CLUSTER_ID, 'Type': 'EMR'},
        service_role='EMR_Notebooks_DefaultRole',
        tags=[{"Key": "create-by", "Value": "airflow-operator"}],
    )

    # Poll the execution started above until it reaches a target or failed state.
    notebook_execution_checker = EmrNotebookExecutionSensor(
        task_id='watch_notebook_execution',
        aws_conn_id='aws_default',
        notebook_execution_id="{{ task_instance.xcom_pull('start_notebook_execution', key='return_value') }}",
    )

    notebook_execution_adder >> notebook_execution_checker
    # [END howto_operator_emr_notebook_execution_tasks]
125 changes: 125 additions & 0 deletions
125
airflow/providers/amazon/aws/operators/emr_start_notebook_execution.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from typing import Any, Dict, List, Optional | ||
|
||
from airflow.exceptions import AirflowException | ||
from airflow.models import BaseOperator | ||
from airflow.providers.amazon.aws.hooks.emr import EmrHook | ||
|
||
|
||
class EmrStartNotebookExecutionOperator(BaseOperator):
    """
    An operator that starts an EMR Notebook execution on an existing EMR cluster.

    :param editor_id: Id of the emr notebook to run (with prefix e-). (templated)
    :type editor_id: str
    :param relative_path: Relative path of the notebook file to run. (templated)
    :type relative_path: str
    :param notebook_execution_name: The name for the execution to be created. (templated)
    :type notebook_execution_name: Optional[str]
    :param notebook_params: Parameters passed to the notebook execution, as a
        JSON-formatted string, e.g. '{"PARAM_1":10}'. (templated)
    :type notebook_params: Optional[str]
    :param execution_engine: A dict that specifies the EMR cluster to run on,
        e.g. {'Id': 'j-123', 'Type': 'EMR'}. (templated)
    :type execution_engine: Dict[str, str]
    :param service_role: The service role name or ARN needed by the EMR service
        to operate on your behalf. (templated)
    :type service_role: str
    :param notebook_instance_security_group_id: Security group ID for the EC2 instance that runs
        the notebook server. If not supplied, a default security group will be used. (templated)
    :type notebook_instance_security_group_id: Optional[str]
    :param tags: A list of tags to be applied to the execution. (templated)
    :type tags: Optional[List[Dict[str, str]]]
    :param aws_conn_id: aws connection to use
    :type aws_conn_id: str
    :param do_xcom_push: if True, notebook_execution_id is pushed to XCom with key notebook_execution_id.
    :type do_xcom_push: bool
    """

    template_fields = [
        'editor_id',
        'relative_path',
        'notebook_execution_name',
        'notebook_params',
        'execution_engine',
        'service_role',
        'notebook_instance_security_group_id',
        'tags',
    ]
    template_ext = ('.json',)
    ui_color = '#f9c915'

    def __init__(
        self,
        *,
        editor_id: str,
        relative_path: str,
        notebook_execution_name: Optional[str] = None,
        notebook_params: Optional[str] = None,
        execution_engine: Dict[str, str],
        service_role: str,
        notebook_instance_security_group_id: Optional[str] = None,
        tags: Optional[List[Dict[str, str]]] = None,
        aws_conn_id: str = 'aws_default',
        **kwargs,
    ):
        # Reject the legacy keyword explicitly so callers get an actionable message.
        if kwargs.get('xcom_push') is not None:
            raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
        super().__init__(**kwargs)
        self.aws_conn_id = aws_conn_id

        self.editor_id = editor_id
        self.relative_path = relative_path
        self.notebook_execution_name = notebook_execution_name
        self.notebook_params = notebook_params
        self.execution_engine = execution_engine
        self.service_role = service_role
        self.notebook_instance_security_group_id = notebook_instance_security_group_id
        self.tags = tags

    def execute(self, context: Dict[str, Any]) -> str:
        """
        Start the notebook execution and return its id.

        :param context: task context; ``context['ti']`` is used to push the
            execution id to XCom when ``do_xcom_push`` is set
        :return: the ``NotebookExecutionId`` taken from the API response
        :raises AirflowException: if the response HTTP status code is not 200
        """
        emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()

        inputs = {
            "EditorId": self.editor_id,
            "RelativePath": self.relative_path,
            "ExecutionEngine": self.execution_engine,
            "ServiceRole": self.service_role,
        }

        # Optional request fields are only sent when they hold a truthy value.
        optional_inputs = {
            "NotebookExecutionName": self.notebook_execution_name,
            "NotebookParams": self.notebook_params,
            "NotebookInstanceSecurityGroupId": self.notebook_instance_security_group_id,
            "Tags": self.tags,
        }
        inputs.update({key: value for key, value in optional_inputs.items() if value})

        response = emr.start_notebook_execution(**inputs)

        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(f'Starting notebook execution failed: {response}')

        notebook_execution_id = response['NotebookExecutionId']
        self.log.info('Started a notebook execution %s', notebook_execution_id)
        if self.do_xcom_push:
            context['ti'].xcom_push(key='notebook_execution_id', value=notebook_execution_id)
        return notebook_execution_id
54 changes: 54 additions & 0 deletions
54
airflow/providers/amazon/aws/operators/emr_stop_notebook_execution.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from typing import Any, Dict | ||
|
||
from airflow.exceptions import AirflowException | ||
from airflow.models import BaseOperator | ||
from airflow.providers.amazon.aws.hooks.emr import EmrHook | ||
|
||
|
||
class EmrStopNotebookExecutionOperator(BaseOperator):
    """
    Request that an EMR Notebook execution be stopped.

    :param notebook_execution_id: id of the EMR Notebook execution to stop. (templated)
    :type notebook_execution_id: str
    :param aws_conn_id: aws connection to use
    :type aws_conn_id: str
    """

    template_fields = ['notebook_execution_id']
    template_ext = ()
    ui_color = '#f9c915'

    def __init__(self, *, notebook_execution_id: str, aws_conn_id: str = 'aws_default', **kwargs):
        super().__init__(**kwargs)
        self.aws_conn_id = aws_conn_id
        self.notebook_execution_id = notebook_execution_id

    def execute(self, context: Dict[str, Any]) -> None:
        """
        Send the stop request for the configured notebook execution.

        :param context: task context (not used by this operator)
        :raises AirflowException: if the response HTTP status code is not 200
        """
        hook = EmrHook(aws_conn_id=self.aws_conn_id)
        client = hook.get_conn()
        execution_id = self.notebook_execution_id

        self.log.info('Requesting to stop Notebook execution: %s', execution_id)
        response = client.stop_notebook_execution(NotebookExecutionId=execution_id)

        status_code = response['ResponseMetadata']['HTTPStatusCode']
        if status_code != 200:
            raise AirflowException(f'Failed requesting to stop the Notebook execution: {response}')

        self.log.info('Requested to stop Notebook execution %s ', execution_id)
120 changes: 120 additions & 0 deletions
120
airflow/providers/amazon/aws/sensors/emr_notebook_execution.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from typing import Any, Dict, Iterable, Optional | ||
|
||
from airflow.exceptions import AirflowException | ||
from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor | ||
|
||
|
||
class EmrNotebookExecutionSensor(EmrBaseSensor):
    """
    Asks for the state of the NotebookExecution until it reaches a target state.
    If it reaches a failed state the sensor errors, failing the task.

    :param notebook_execution_id: id of the notebook execution to check the state of. (templated)
    :type notebook_execution_id: str
    :param target_states: the target states, sensor waits until
        notebook execution reaches any of these states. Defaults to ``['FINISHED']``. (templated)
    :type target_states: list[str]
    :param failed_states: the failure states, sensor fails when
        notebook execution reaches any of these states.
        Defaults to ``['FAILED', 'STOPPED']``. (templated)
    :type failed_states: list[str]
    """

    template_fields = ['notebook_execution_id', 'target_states', 'failed_states']
    template_ext = ()

    def __init__(
        self,
        *,
        notebook_execution_id: str,
        target_states: Optional[Iterable[str]] = None,
        failed_states: Optional[Iterable[str]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.notebook_execution_id = notebook_execution_id
        self.target_states = target_states or ['FINISHED']
        self.failed_states = failed_states or ['FAILED', 'STOPPED']

    def poke(self, context: Dict[str, Any]) -> bool:
        """
        Check the notebook execution state once.

        :param context: task context (not used)
        :return: True when the state is one of ``target_states``; False when the
            response status is not 200 or the state is neither a target nor a failed state
        :raises AirflowException: when the state is one of ``failed_states``
        """
        response = self.get_emr_response()

        # A bad HTTP status is logged and reported as "not done yet" rather than
        # raised, so the sensor simply pokes again.
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            self.log.info('Bad HTTP response: %s', response)
            return False

        state = self.state_from_response(response)
        self.log.info('Notebook execution is %s', state)

        if state in self.target_states:
            return True

        if state in self.failed_states:
            # Prefer the reason reported by the service over the generic message.
            failure_message = self.failure_message_from_response(response)
            raise AirflowException(failure_message or 'Notebook execution failed')

        return False

    def get_emr_response(self) -> Dict[str, Any]:
        """
        Make an API call with boto3 and get notebook execution details.

        .. seealso::
            https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#EMR.Client.describe_notebook_execution

        :return: response
        :rtype: dict[str, Any]
        """
        emr = self.get_hook().get_conn()
        self.log.info('Poking notebook execution %s', self.notebook_execution_id)
        return emr.describe_notebook_execution(NotebookExecutionId=self.notebook_execution_id)

    @staticmethod
    def state_from_response(response: Dict[str, Any]) -> str:
        """
        Get state from response dictionary.

        :param response: response from AWS API
        :type response: dict[str, Any]
        :return: current state of the notebook execution
        :rtype: str
        """
        return response['NotebookExecution']['Status']

    @staticmethod
    def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]:
        """
        Get failure message from response dictionary.

        :param response: response from AWS API
        :type response: dict[str, Any]
        :return: failure message, or None when the execution is not FAILED/STOPPED
            or the response carries no state change reason
        :rtype: Optional[str]
        """
        notebook_execution = response['NotebookExecution']
        if notebook_execution['Status'] not in ('FAILED', 'STOPPED'):
            return None
        # The reason may be missing from the response; a direct lookup would
        # raise KeyError and hide the actual failure from the task log.
        last_state_change_reason = notebook_execution.get('LastStateChangeReason')
        if last_state_change_reason:
            return 'Execution failed with reason: ' + last_state_change_reason
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.