Rename SparkSubmitOperator's fields' names to comply with templated fields validation (apache#38051)
shahar1 committed Mar 16, 2024
1 parent e30c93a commit baa6f08
Showing 5 changed files with 156 additions and 96 deletions.
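
Background for the change: Airflow resolves each name in template_fields as an attribute of the operator instance when rendering Jinja templates, so those entries need to line up with the public constructor arguments rather than underscore-prefixed private attributes. A minimal sketch of that mechanism (illustrative only, not Airflow's actual rendering code; MiniOperator and the str.format-based rendering are stand-ins):

class MiniOperator:
    # Entries here must name real attributes assigned in __init__.
    template_fields = ("application", "name")

    def __init__(self, application: str, name: str) -> None:
        self.application = application
        self.name = name

    def render_template_fields(self, context: dict) -> None:
        for field in self.template_fields:
            # getattr/setattr is why listing "_application" in template_fields
            # breaks once the public argument is called "application".
            rendered = getattr(self, field).format(**context)
            setattr(self, field, rendered)


op = MiniOperator(application="/jobs/{ds}/etl.py", name="etl-{ds}")
op.render_template_fields({"ds": "2024-03-16"})
assert op.application == "/jobs/2024-03-16/etl.py"
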
4 changes: 1 addition & 3 deletions .pre-commit-config.yaml
@@ -334,9 +334,7 @@ repos:
# https://github.com/apache/airflow/issues/36484
exclude: |
(?x)^(
^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service.py$|
^airflow\/providers\/apache\/spark\/operators\/spark_submit.py\.py$|
^airflow\/providers\/apache\/spark\/operators\/spark_submit\.py$|
^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service.py$
)$
- id: ruff
name: Run 'ruff' for extremely fast Python linting
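
The exclusion entries for spark_submit.py (including the stray, misspelled spark_submit.py.py pattern) are dropped here because the operator now complies with the templated-fields validation this commit targets, leaving only the cloud_storage_transfer_service.py exclusion. Conceptually, that validation only needs to confirm that every name in template_fields is also a constructor parameter; a hedged sketch of the idea (the real pre-commit hook may implement this differently, e.g. via static analysis):

import inspect

from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator


def fields_not_matching_init(operator_cls) -> list[str]:
    """Return template_fields entries that are not also __init__ parameter names."""
    params = inspect.signature(operator_cls.__init__).parameters
    return [field for field in operator_cls.template_fields if field not in params]


# Before this commit the list would contain "_application", "_conf", etc.;
# after the rename it is empty:
assert fields_not_matching_init(SparkSubmitOperator) == []
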
34 changes: 4 additions & 30 deletions airflow/providers/apache/spark/operators/spark_jdbc.py
@@ -44,14 +44,6 @@ class SparkJDBCOperator(SparkSubmitOperator):
:param spark_files: Additional files to upload to the container running the job
:param spark_jars: Additional jars to upload and add to the driver and
executor classpath
:param num_executors: number of executor to run. This should be set so as to manage
the number of connections made with the JDBC database
:param executor_cores: Number of cores per executor
:param executor_memory: Memory per executor (e.g. 1000M, 2G)
:param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G)
:param verbose: Whether to pass the verbose flag to spark-submit for debugging
:param keytab: Full path to the file that contains the keytab
:param principal: The name of the kerberos principal used for keytab
:param cmd_type: Which way the data should flow. 2 possible values:
spark_to_jdbc: data written by spark from metastore to jdbc
jdbc_to_spark: data written by spark from jdbc to metastore
@@ -60,7 +52,7 @@ class SparkJDBCOperator(SparkSubmitOperator):
:param jdbc_driver: Name of the JDBC driver to use for the JDBC connection. This
driver (usually a jar) should be passed in the 'jars' parameter
:param metastore_table: The name of the metastore table,
:param jdbc_truncate: (spark_to_jdbc only) Whether or not Spark should truncate or
:param jdbc_truncate: (spark_to_jdbc only) Whether Spark should truncate or
drop and recreate the JDBC table. This only takes effect if
'save_mode' is set to Overwrite. Also, if the schema is
different, Spark cannot truncate, and will drop and recreate
@@ -91,9 +83,7 @@ class SparkJDBCOperator(SparkSubmitOperator):
(e.g: "name CHAR(64), comments VARCHAR(1024)").
The specified types should be valid spark sql data
types.
:param use_krb5ccache: if True, configure spark to use ticket cache instead of relying
on keytab for Kerberos login
:param kwargs: kwargs passed to SparkSubmitOperator.
"""

def __init__(
@@ -105,13 +95,6 @@ def __init__(
spark_py_files: str | None = None,
spark_files: str | None = None,
spark_jars: str | None = None,
num_executors: int | None = None,
executor_cores: int | None = None,
executor_memory: str | None = None,
driver_memory: str | None = None,
verbose: bool = False,
principal: str | None = None,
keytab: str | None = None,
cmd_type: str = "spark_to_jdbc",
jdbc_table: str | None = None,
jdbc_conn_id: str = "jdbc-default",
@@ -127,7 +110,6 @@ def __init__(
lower_bound: str | None = None,
upper_bound: str | None = None,
create_table_column_types: str | None = None,
use_krb5ccache: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -137,13 +119,6 @@ def __init__(
self._spark_py_files = spark_py_files
self._spark_files = spark_files
self._spark_jars = spark_jars
self._num_executors = num_executors
self._executor_cores = executor_cores
self._executor_memory = executor_memory
self._driver_memory = driver_memory
self._verbose = verbose
self._keytab = keytab
self._principal = principal
self._cmd_type = cmd_type
self._jdbc_table = jdbc_table
self._jdbc_conn_id = jdbc_conn_id
@@ -160,7 +135,6 @@ def __init__(
self._upper_bound = upper_bound
self._create_table_column_types = create_table_column_types
self._hook: SparkJDBCHook | None = None
self._use_krb5ccache = use_krb5ccache

def execute(self, context: Context) -> None:
"""Call the SparkSubmitHook to run the provided spark job."""
@@ -186,8 +160,8 @@ def _get_hook(self) -> SparkJDBCHook:
executor_memory=self._executor_memory,
driver_memory=self._driver_memory,
verbose=self._verbose,
keytab=self._keytab,
principal=self._principal,
keytab=self.keytab,
principal=self.principal,
cmd_type=self._cmd_type,
jdbc_table=self._jdbc_table,
jdbc_conn_id=self._jdbc_conn_id,
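
Net effect on SparkJDBCOperator: the keytab, principal, and executor/driver arguments it used to redefine are removed and now flow through **kwargs to the parent SparkSubmitOperator, which stores them as the public attributes referenced above. A hedged usage sketch (connection IDs, paths, and values below are placeholders):

from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator

transfer = SparkJDBCOperator(
    task_id="hive_to_jdbc",
    cmd_type="spark_to_jdbc",
    jdbc_table="public.events",
    jdbc_conn_id="jdbc-default",
    metastore_table="default.events",
    # No longer SparkJDBCOperator parameters: forwarded via **kwargs and stored
    # by SparkSubmitOperator (e.g. self.keytab, self.principal).
    keytab="/etc/security/keytabs/airflow.keytab",
    principal="airflow@EXAMPLE.COM",
    executor_memory="2G",
)
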
90 changes: 45 additions & 45 deletions airflow/providers/apache/spark/operators/spark_submit.py
@@ -81,21 +81,21 @@ class SparkSubmitOperator(BaseOperator):
"""

template_fields: Sequence[str] = (
"_application",
"_conf",
"_files",
"_py_files",
"_jars",
"_driver_class_path",
"_packages",
"_exclude_packages",
"_keytab",
"_principal",
"_proxy_user",
"_name",
"_application_args",
"_env_vars",
"_properties_file",
"application",
"conf",
"files",
"py_files",
"jars",
"driver_class_path",
"packages",
"exclude_packages",
"keytab",
"principal",
"proxy_user",
"name",
"application_args",
"env_vars",
"properties_file",
)
ui_color = WEB_COLORS["LIGHTORANGE"]

@@ -135,32 +135,32 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self._application = application
self._conf = conf
self._files = files
self._py_files = py_files
self.application = application
self.conf = conf
self.files = files
self.py_files = py_files
self._archives = archives
self._driver_class_path = driver_class_path
self._jars = jars
self.driver_class_path = driver_class_path
self.jars = jars
self._java_class = java_class
self._packages = packages
self._exclude_packages = exclude_packages
self.packages = packages
self.exclude_packages = exclude_packages
self._repositories = repositories
self._total_executor_cores = total_executor_cores
self._executor_cores = executor_cores
self._executor_memory = executor_memory
self._driver_memory = driver_memory
self._keytab = keytab
self._principal = principal
self._proxy_user = proxy_user
self._name = name
self.keytab = keytab
self.principal = principal
self.proxy_user = proxy_user
self.name = name
self._num_executors = num_executors
self._status_poll_interval = status_poll_interval
self._application_args = application_args
self._env_vars = env_vars
self.application_args = application_args
self.env_vars = env_vars
self._verbose = verbose
self._spark_binary = spark_binary
self._properties_file = properties_file
self.properties_file = properties_file
self._queue = queue
self._deploy_mode = deploy_mode
self._hook: SparkSubmitHook | None = None
@@ -171,7 +171,7 @@ def execute(self, context: Context) -> None:
"""Call the SparkSubmitHook to run the provided spark job."""
if self._hook is None:
self._hook = self._get_hook()
self._hook.submit(self._application)
self._hook.submit(self.application)

def on_kill(self) -> None:
if self._hook is None:
@@ -180,32 +180,32 @@ def on_kill(self) -> None:

def _get_hook(self) -> SparkSubmitHook:
return SparkSubmitHook(
conf=self._conf,
conf=self.conf,
conn_id=self._conn_id,
files=self._files,
py_files=self._py_files,
files=self.files,
py_files=self.py_files,
archives=self._archives,
driver_class_path=self._driver_class_path,
jars=self._jars,
driver_class_path=self.driver_class_path,
jars=self.jars,
java_class=self._java_class,
packages=self._packages,
exclude_packages=self._exclude_packages,
packages=self.packages,
exclude_packages=self.exclude_packages,
repositories=self._repositories,
total_executor_cores=self._total_executor_cores,
executor_cores=self._executor_cores,
executor_memory=self._executor_memory,
driver_memory=self._driver_memory,
keytab=self._keytab,
principal=self._principal,
proxy_user=self._proxy_user,
name=self._name,
keytab=self.keytab,
principal=self.principal,
proxy_user=self.proxy_user,
name=self.name,
num_executors=self._num_executors,
status_poll_interval=self._status_poll_interval,
application_args=self._application_args,
env_vars=self._env_vars,
application_args=self.application_args,
env_vars=self.env_vars,
verbose=self._verbose,
spark_binary=self._spark_binary,
properties_file=self._properties_file,
properties_file=self.properties_file,
queue=self._queue,
deploy_mode=self._deploy_mode,
use_krb5ccache=self._use_krb5ccache,
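
Because the templated fields are now public attributes with matching names, Jinja expressions passed to these arguments render in place and the rendered values can be read back directly from the task, which is what the new test below exercises. A hedged DAG-level example (paths and names are placeholders):

from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

submit = SparkSubmitOperator(
    task_id="submit_etl",
    application="/jobs/{{ ds }}/etl.py",          # rendered into submit.application
    name="etl-{{ ds_nodash }}",                   # rendered into submit.name
    application_args=["--run-date", "{{ ds }}"],  # list elements are rendered too
    env_vars={"RUN_DATE": "{{ ds }}"},            # dict values are rendered too
)
# After template rendering for a 2024-03-16 run, submit.name == "etl-20240316".
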
49 changes: 47 additions & 2 deletions tests/providers/apache/spark/operators/test_spark_jdbc.py
@@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations

import pytest

from airflow.models.dag import DAG
from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator
from airflow.utils import timezone
@@ -111,8 +113,8 @@ def test_execute(self):
assert expected_dict["executor_memory"] == operator._executor_memory
assert expected_dict["driver_memory"] == operator._driver_memory
assert expected_dict["verbose"] == operator._verbose
assert expected_dict["keytab"] == operator._keytab
assert expected_dict["principal"] == operator._principal
assert expected_dict["keytab"] == operator.keytab
assert expected_dict["principal"] == operator.principal
assert expected_dict["cmd_type"] == operator._cmd_type
assert expected_dict["jdbc_table"] == operator._jdbc_table
assert expected_dict["jdbc_driver"] == operator._jdbc_driver
@@ -128,3 +130,46 @@ def test_execute(self):
assert expected_dict["upper_bound"] == operator._upper_bound
assert expected_dict["create_table_column_types"] == operator._create_table_column_types
assert expected_dict["use_krb5ccache"] == operator._use_krb5ccache

@pytest.mark.db_test
def test_templating_with_create_task_instance_of_operator(self, create_task_instance_of_operator):
ti = create_task_instance_of_operator(
SparkJDBCOperator,
# Templated fields
application="{{ 'application' }}",
conf="{{ 'conf' }}",
files="{{ 'files' }}",
py_files="{{ 'py-files' }}",
jars="{{ 'jars' }}",
driver_class_path="{{ 'driver_class_path' }}",
packages="{{ 'packages' }}",
exclude_packages="{{ 'exclude_packages' }}",
keytab="{{ 'keytab' }}",
principal="{{ 'principal' }}",
proxy_user="{{ 'proxy_user' }}",
name="{{ 'name' }}",
application_args="{{ 'application_args' }}",
env_vars="{{ 'env_vars' }}",
properties_file="{{ 'properties_file' }}",
# Other parameters
dag_id="test_template_body_templating_dag",
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
ti.render_templates()
task: SparkJDBCOperator = ti.task
assert task.application == "application"
assert task.conf == "conf"
assert task.files == "files"
assert task.py_files == "py-files"
assert task.jars == "jars"
assert task.driver_class_path == "driver_class_path"
assert task.packages == "packages"
assert task.exclude_packages == "exclude_packages"
assert task.keytab == "keytab"
assert task.principal == "principal"
assert task.proxy_user == "proxy_user"
assert task.name == "name"
assert task.application_args == "application_args"
assert task.env_vars == "env_vars"
assert task.properties_file == "properties_file"