Add support for EMR Notebook Execution (apache#14962)
rliuamzn committed Jul 13, 2021
1 parent f2bd15d commit ef48150
Showing 9 changed files with 734 additions and 2 deletions.
@@ -0,0 +1,77 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is an example dag for a AWS EMR Pipeline.
Start a notebook execution then check the notebook execution until it finishes.
"""
import os
from datetime import timedelta

from airflow import DAG
from airflow.providers.amazon.aws.operators.emr_start_notebook_execution import (
    EmrStartNotebookExecutionOperator,
)
from airflow.providers.amazon.aws.sensors.emr_notebook_execution import EmrNotebookExecutionSensor
from airflow.utils.dates import days_ago

# [START howto_operator_emr_notebook_execution_env_variables]
NOTEBOOK_ID = os.getenv("NOTEBOOK_ID", "e-MYDEMONOTEBOOKT0ACS9KN5UT")
NOTEBOOK_FILE = os.getenv("NOTEBOOK_FILE", "test.ipynb")
NOTEBOOK_EXECUTION_PARAMS = os.getenv("NOTEBOOK_EXECUTION_PARAMS", '{"PARAM_1": 10}')
EMR_CLUSTER_ID = os.getenv("EMR_CLUSTER_ID", "j-123456ABCDEFG")
# [END howto_operator_emr_notebook_execution_env_variables]


DEFAULT_ARGS = {
    'owner': 'airflow',
    'depends_on_past': False,
    'email': ['airflow@example.com'],
    'email_on_failure': False,
    'email_on_retry': False,
}

with DAG(
    dag_id='emr_notebook_execution_dag',
    default_args=DEFAULT_ARGS,
    dagrun_timeout=timedelta(hours=2),
    start_date=days_ago(1),
    schedule_interval="@once",
    tags=["emr_notebook_execution", "example"],
) as dag:

    # [START howto_operator_emr_notebook_execution_tasks]
    notebook_execution_adder = EmrStartNotebookExecutionOperator(
        task_id='start_notebook_execution',
        aws_conn_id='aws_default',
        editor_id=NOTEBOOK_ID,
        relative_path=NOTEBOOK_FILE,
        notebook_execution_name='test-emr-notebook-execution-airflow-operator',
        notebook_params=NOTEBOOK_EXECUTION_PARAMS,
        execution_engine={'Id': EMR_CLUSTER_ID, 'Type': 'EMR'},
        service_role='EMR_Notebooks_DefaultRole',
        tags=[{"Key": "create-by", "Value": "airflow-operator"}],
    )

    notebook_execution_checker = EmrNotebookExecutionSensor(
        task_id='watch_notebook_execution',
        aws_conn_id='aws_default',
        notebook_execution_id="{{ task_instance.xcom_pull('start_notebook_execution', key='return_value') }}",
    )

    notebook_execution_adder >> notebook_execution_checker
    # [END howto_operator_emr_notebook_execution_tasks]
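The commit also adds an EmrStopNotebookExecutionOperator (shown further below) that this example DAG does not exercise. As a minimal sketch, assuming the DAG above, it could be wired in to cancel the run when the watcher task fails; the task id and trigger rule are illustrative, not part of the commit:

# Hypothetical cleanup task, placed inside the `with DAG(...) as dag:` block above:
# request a stop if the watcher task fails.
from airflow.providers.amazon.aws.operators.emr_stop_notebook_execution import (
    EmrStopNotebookExecutionOperator,
)

notebook_execution_stopper = EmrStopNotebookExecutionOperator(
    task_id='stop_notebook_execution',
    aws_conn_id='aws_default',
    notebook_execution_id="{{ task_instance.xcom_pull('start_notebook_execution', key='return_value') }}",
    trigger_rule='one_failed',  # run only when an upstream task has failed
)

notebook_execution_checker >> notebook_execution_stopper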
125 changes: 125 additions & 0 deletions airflow/providers/amazon/aws/operators/emr_start_notebook_execution.py
@@ -0,0 +1,125 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrHook


class EmrStartNotebookExecutionOperator(BaseOperator):
    """
    An operator that starts a notebook execution on an existing EMR job flow.

    :param editor_id: Id of the EMR notebook to run (with prefix e-). (templated)
    :type editor_id: str
    :param relative_path: Relative path of the notebook file to run. (templated)
    :type relative_path: str
    :param notebook_execution_name: The name for the execution to be created. (templated)
    :type notebook_execution_name: Optional[str]
    :param notebook_params: Input parameters, as a JSON string, passed to the notebook
        during execution. (templated)
    :type notebook_params: Optional[str]
    :param execution_engine: The execution engine (EMR cluster) to run the notebook on,
        e.g. {'Id': 'j-123', 'Type': 'EMR'}. (templated)
    :type execution_engine: Dict[str, str]
    :param service_role: The service role name or ARN needed by the EMR service
        to operate on your behalf. (templated)
    :type service_role: str
    :param notebook_instance_security_group_id: Security group ID for the EC2 instance that runs
        the notebook server. If not supplied, a default security group will be used. (templated)
    :type notebook_instance_security_group_id: Optional[str]
    :param tags: A list of tags to be applied to the execution. (templated)
    :type tags: Optional[List[Dict[str, str]]]
    :param aws_conn_id: aws connection to use
    :type aws_conn_id: str
    :param do_xcom_push: if True, notebook_execution_id is pushed to XCom with key notebook_execution_id.
    :type do_xcom_push: bool
    """

    template_fields = [
        'editor_id',
        'relative_path',
        'notebook_execution_name',
        'notebook_params',
        'execution_engine',
        'service_role',
        'notebook_instance_security_group_id',
        'tags',
    ]
    template_ext = ('.json',)
    ui_color = '#f9c915'

    def __init__(
        self,
        *,
        editor_id: str,
        relative_path: str,
        notebook_execution_name: Optional[str] = None,
        notebook_params: Optional[str] = None,
        execution_engine: Dict[str, str],
        service_role: str,
        notebook_instance_security_group_id: Optional[str] = None,
        tags: Optional[List[Dict[str, str]]] = None,
        aws_conn_id: str = 'aws_default',
        **kwargs,
    ):
        if kwargs.get('xcom_push') is not None:
            raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
        super().__init__(**kwargs)
        self.aws_conn_id = aws_conn_id

        self.editor_id = editor_id
        self.relative_path = relative_path
        self.notebook_execution_name = notebook_execution_name
        self.notebook_params = notebook_params
        self.execution_engine = execution_engine
        self.service_role = service_role
        self.notebook_instance_security_group_id = notebook_instance_security_group_id
        self.tags = tags

    def execute(self, context: Dict[str, Any]) -> str:
        emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
        emr = emr_hook.get_conn()

        inputs = {
            "EditorId": self.editor_id,
            "RelativePath": self.relative_path,
            "ExecutionEngine": self.execution_engine,
            "ServiceRole": self.service_role,
        }

        if self.notebook_execution_name:
            inputs["NotebookExecutionName"] = self.notebook_execution_name

        if self.notebook_params:
            inputs["NotebookParams"] = self.notebook_params

        if self.notebook_instance_security_group_id:
            inputs["NotebookInstanceSecurityGroupId"] = self.notebook_instance_security_group_id

        if self.tags:
            inputs["Tags"] = self.tags

        response = emr.start_notebook_execution(**inputs)

        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(f'Starting notebook execution failed: {response}')

        notebook_execution_id = response['NotebookExecutionId']
        self.log.info('Started a notebook execution %s', notebook_execution_id)
        if self.do_xcom_push:
            context['ti'].xcom_push(key='notebook_execution_id', value=notebook_execution_id)
        return notebook_execution_id
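Because template_ext includes '.json', any of the templated fields can also be supplied as a path to a .json file that Airflow renders with Jinja before use. A minimal sketch under that assumption; the file name and task id are illustrative, not part of the commit:

# Hypothetical: keep the notebook parameters in a Jinja-templated JSON file
# (e.g. notebook_params.json next to the DAG, containing {"run_date": "{{ ds }}"}).
# Because the value ends in '.json' (listed in template_ext), Airflow loads and
# renders the file, then passes its contents as NotebookParams.
start_with_file_params = EmrStartNotebookExecutionOperator(
    task_id='start_notebook_execution_file_params',
    editor_id=NOTEBOOK_ID,
    relative_path=NOTEBOOK_FILE,
    execution_engine={'Id': EMR_CLUSTER_ID, 'Type': 'EMR'},
    service_role='EMR_Notebooks_DefaultRole',
    notebook_params='notebook_params.json',
)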
54 changes: 54 additions & 0 deletions airflow/providers/amazon/aws/operators/emr_stop_notebook_execution.py
@@ -0,0 +1,54 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrHook


class EmrStopNotebookExecutionOperator(BaseOperator):
    """
    Operator to stop an EMR Notebook execution.

    :param notebook_execution_id: id of the EMR Notebook execution to stop. (templated)
    :type notebook_execution_id: str
    :param aws_conn_id: aws connection to use
    :type aws_conn_id: str
    """

    template_fields = ['notebook_execution_id']
    template_ext = ()
    ui_color = '#f9c915'

    def __init__(self, *, notebook_execution_id: str, aws_conn_id: str = 'aws_default', **kwargs):
        super().__init__(**kwargs)
        self.notebook_execution_id = notebook_execution_id
        self.aws_conn_id = aws_conn_id

    def execute(self, context: Dict[str, Any]) -> None:
        emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()

        self.log.info('Requesting to stop Notebook execution: %s', self.notebook_execution_id)
        response = emr.stop_notebook_execution(NotebookExecutionId=self.notebook_execution_id)

        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(f'Failed requesting to stop the Notebook execution: {response}')

        self.log.info('Requested to stop Notebook execution %s', self.notebook_execution_id)
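stop_notebook_execution only requests the stop; EMR winds the execution down asynchronously. To confirm the execution actually reached a stopped state, the sensor added in this commit could be reused with non-default states. A minimal sketch; the task id is illustrative:

# Hypothetical follow-up task: wait until the execution is actually STOPPED.
# STOPPED is one of the sensor's default failed_states, so it is moved into
# target_states here and only FAILED is treated as a failure.
wait_for_stop = EmrNotebookExecutionSensor(
    task_id='wait_for_notebook_stop',
    aws_conn_id='aws_default',
    notebook_execution_id="{{ task_instance.xcom_pull('start_notebook_execution', key='return_value') }}",
    target_states=['STOPPED'],
    failed_states=['FAILED'],
)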
120 changes: 120 additions & 0 deletions airflow/providers/amazon/aws/sensors/emr_notebook_execution.py
@@ -0,0 +1,120 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict, Iterable, Optional

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor


class EmrNotebookExecutionSensor(EmrBaseSensor):
    """
    Polls the state of the notebook execution until it reaches a terminal state.

    If the execution fails, the sensor errors, failing the task.

    :param notebook_execution_id: id of the notebook execution to check the state of (templated)
    :type notebook_execution_id: str
    :param target_states: the target states; the sensor waits until the
        notebook execution reaches any of these states
    :type target_states: Optional[Iterable[str]]
    :param failed_states: the failure states; the sensor fails when the
        notebook execution reaches any of these states
    :type failed_states: Optional[Iterable[str]]
    """

    template_fields = ['notebook_execution_id', 'target_states', 'failed_states']
    template_ext = ()

    def __init__(
        self,
        *,
        notebook_execution_id: str,
        target_states: Optional[Iterable[str]] = None,
        failed_states: Optional[Iterable[str]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.notebook_execution_id = notebook_execution_id
        self.target_states = target_states or ['FINISHED']
        self.failed_states = failed_states or ['FAILED', 'STOPPED']

    def poke(self, context):
        response = self.get_emr_response()

        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            self.log.info('Bad HTTP response: %s', response)
            return False

        state = self.state_from_response(response)
        self.log.info('Notebook execution is %s', state)

        if state in self.target_states:
            return True

        if state in self.failed_states:
            final_message = 'Notebook execution failed'
            failure_message = self.failure_message_from_response(response)
            if failure_message:
                final_message = failure_message
            raise AirflowException(final_message)

        return False

    def get_emr_response(self):
        """
        Make an API call with boto3 and get notebook execution details.

        .. seealso::
            https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#EMR.Client.describe_notebook_execution

        :return: response
        :rtype: dict[str, Any]
        """
        emr = self.get_hook().get_conn()
        self.log.info('Poking notebook execution %s', self.notebook_execution_id)
        return emr.describe_notebook_execution(NotebookExecutionId=self.notebook_execution_id)

    @staticmethod
    def state_from_response(response: Dict[str, Any]) -> str:
        """
        Get state from response dictionary.

        :param response: response from AWS API
        :type response: dict[str, Any]
        :return: current state of the notebook execution
        :rtype: str
        """
        return response['NotebookExecution']['Status']

    @staticmethod
    def failure_message_from_response(response):
        """
        Get failure message from response dictionary.

        :param response: response from AWS API
        :type response: dict[str, Any]
        :return: failure message
        :rtype: Optional[str]
        """
        state = response['NotebookExecution']['Status']
        if state not in ['FAILED', 'STOPPED']:
            return None
        # LastStateChangeReason is not always present in the API response.
        last_state_change_reason = response['NotebookExecution'].get('LastStateChangeReason')
        if last_state_change_reason:
            return 'Execution failed with reason: ' + last_state_change_reason
        return None
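For reference, a trimmed sketch of the describe_notebook_execution response shape that poke() and the two static helpers consume; the values are illustrative, not real API output:

# Illustrative response shape consumed by the sensor (values are made up):
example_response = {
    'NotebookExecution': {
        'NotebookExecutionId': 'ex-ABC123',
        'Status': 'FAILED',
        'LastStateChangeReason': 'Execution failed. See the notebook output.',
    },
    'ResponseMetadata': {'HTTPStatusCode': 200},
}

# Smoke-checking the parsing helpers against that shape:
assert EmrNotebookExecutionSensor.state_from_response(example_response) == 'FAILED'
assert EmrNotebookExecutionSensor.failure_message_from_response(example_response) == (
    'Execution failed with reason: Execution failed. See the notebook output.'
)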
3 changes: 3 additions & 0 deletions airflow/providers/amazon/provider.yaml
@@ -178,6 +178,8 @@ operators:
- airflow.providers.amazon.aws.operators.emr_create_job_flow
- airflow.providers.amazon.aws.operators.emr_modify_cluster
- airflow.providers.amazon.aws.operators.emr_terminate_job_flow
- airflow.providers.amazon.aws.operators.emr_start_notebook_execution
- airflow.providers.amazon.aws.operators.emr_stop_notebook_execution
- integration-name: Amazon Glacier
python-modules:
- airflow.providers.amazon.aws.operators.glacier
@@ -232,6 +234,7 @@ sensors:
- airflow.providers.amazon.aws.sensors.emr_base
- airflow.providers.amazon.aws.sensors.emr_job_flow
- airflow.providers.amazon.aws.sensors.emr_step
- airflow.providers.amazon.aws.sensors.emr_notebook_execution
- integration-name: Amazon Glacier
python-modules:
- airflow.providers.amazon.aws.sensors.glacier
