diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 8743d72b887e1..3b7bd2a71d2d2 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -121,8 +121,8 @@ test - mysql - mysql-connector-java + org.mariadb.jdbc + mariadb-java-client test diff --git a/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh b/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh new file mode 100755 index 0000000000000..00885a3b62327 --- /dev/null +++ b/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +dpkg-divert --add /bin/systemctl && ln -sT /bin/true /bin/systemctl +apt update +apt install -y mariadb-plugin-gssapi-server +echo "gssapi_keytab_path=/docker-entrypoint-initdb.d/mariadb.keytab" >> /etc/mysql/mariadb.conf.d/auth_gssapi.cnf +echo "gssapi_principal_name=mariadb/__IP_ADDRESS_REPLACE_ME__@EXAMPLE.COM" >> /etc/mysql/mariadb.conf.d/auth_gssapi.cnf +docker-entrypoint.sh mysqld diff --git a/external/docker-integration-tests/src/test/resources/mariadb_krb_setup.sh b/external/docker-integration-tests/src/test/resources/mariadb_krb_setup.sh new file mode 100755 index 0000000000000..e97be805b4592 --- /dev/null +++ b/external/docker-integration-tests/src/test/resources/mariadb_krb_setup.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+mysql -u root -p'rootpass' -e 'CREATE USER "mariadb/__IP_ADDRESS_REPLACE_ME__@EXAMPLE.COM" IDENTIFIED WITH gssapi;'
+mysql -u root -p'rootpass' -D mysql -e 'GRANT ALL PRIVILEGES ON *.* TO "mariadb/__IP_ADDRESS_REPLACE_ME__@EXAMPLE.COM";'
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala
index cd26fb3628151..376dd4646608c 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala
@@ -58,10 +58,19 @@ abstract class DatabaseOnDocker {
    */
  def getJdbcUrl(ip: String, port: Int): String

+  /**
+   * Optional entry point to run when the container starts.
+   *
+   * The startup process is passed to the entry point as a parameter, which the entry point may
+   * or may not use. Prefer an entry point over a startup process when a command must always be
+   * executed or when the initialization order has to change.
+   */
+  def getEntryPoint: Option[String] = None
+
  /**
   * Optional process to run when container starts
   */
-  def getStartupProcessName: Option[String]
+  def getStartupProcessName: Option[String] = None

  /**
   * Optional step before container starts
@@ -77,6 +86,7 @@ abstract class DockerJDBCIntegrationSuite extends SharedSparkSession with Eventu
  val db: DatabaseOnDocker

  private var docker: DockerClient = _
+  protected var externalPort: Int = _
  private var containerId: String = _
  protected var jdbcUrl: String = _

@@ -101,7 +111,7 @@ abstract class DockerJDBCIntegrationSuite extends SharedSparkSession with Eventu
      docker.pull(db.imageName)
    }
    // Configure networking (necessary for boot2docker / Docker Machine)
-    val externalPort: Int = {
+    externalPort = {
      val sock = new ServerSocket(0)
      val port = sock.getLocalPort
      sock.close()
@@ -118,9 +128,11 @@ abstract class DockerJDBCIntegrationSuite extends SharedSparkSession with Eventu
      .networkDisabled(false)
      .env(db.env.map { case (k, v) => s"$k=$v" }.toSeq.asJava)
      .exposedPorts(s"${db.jdbcPort}/tcp")
-    if(db.getStartupProcessName.isDefined) {
-      containerConfigBuilder
-        .cmd(db.getStartupProcessName.get)
+    if (db.getEntryPoint.isDefined) {
+      containerConfigBuilder.entrypoint(db.getEntryPoint.get)
+    }
+    if (db.getStartupProcessName.isDefined) {
+      containerConfigBuilder.cmd(db.getStartupProcessName.get)
    }
    db.beforeContainerStart(hostConfigBuilder, containerConfigBuilder)
    containerConfigBuilder.hostConfig(hostConfigBuilder.build())
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerKrbJDBCIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerKrbJDBCIntegrationSuite.scala
index 583d8108c716c..009b4a2b1b32e 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerKrbJDBCIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerKrbJDBCIntegrationSuite.scala
@@ -18,17 +18,22 @@ package org.apache.spark.sql.jdbc

 import java.io.{File, FileInputStream, FileOutputStream}
+import java.sql.Connection
+import java.util.Properties
 import javax.security.auth.login.Configuration

 import scala.io.Source

 import org.apache.hadoop.minikdc.MiniKdc

+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.StringType
 import org.apache.spark.util.{SecurityUtils,
Utils} abstract class DockerKrbJDBCIntegrationSuite extends DockerJDBCIntegrationSuite { private var kdc: MiniKdc = _ - protected var workDir: File = _ + protected var entryPointDir: File = _ + protected var initDbDir: File = _ protected val userName: String protected var principal: String = _ protected val keytabFileName: String @@ -46,8 +51,9 @@ abstract class DockerKrbJDBCIntegrationSuite extends DockerJDBCIntegrationSuite principal = s"$userName@${kdc.getRealm}" - workDir = Utils.createTempDir() - val keytabFile = new File(workDir, keytabFileName) + entryPointDir = Utils.createTempDir() + initDbDir = Utils.createTempDir() + val keytabFile = new File(initDbDir, keytabFileName) keytabFullPath = keytabFile.getAbsolutePath kdc.createPrincipal(keytabFile, userName) logInfo(s"Created keytab file: $keytabFullPath") @@ -62,6 +68,7 @@ abstract class DockerKrbJDBCIntegrationSuite extends DockerJDBCIntegrationSuite try { if (kdc != null) { kdc.stop() + kdc = null } Configuration.setConfiguration(null) SecurityUtils.setGlobalKrbDebug(false) @@ -71,7 +78,7 @@ abstract class DockerKrbJDBCIntegrationSuite extends DockerJDBCIntegrationSuite } protected def copyExecutableResource( - fileName: String, dir: File, processLine: String => String) = { + fileName: String, dir: File, processLine: String => String = identity) = { val newEntry = new File(dir.getAbsolutePath, fileName) newEntry.createNewFile() Utils.tryWithResource( @@ -91,4 +98,64 @@ abstract class DockerKrbJDBCIntegrationSuite extends DockerJDBCIntegrationSuite logInfo(s"Created executable resource file: ${newEntry.getAbsolutePath}") newEntry } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE TABLE bar (c0 text)").executeUpdate() + conn.prepareStatement("INSERT INTO bar VALUES ('hello')").executeUpdate() + } + + test("Basic read test in query option") { + // This makes sure Spark must do authentication + Configuration.setConfiguration(null) + + val expectedResult = Set("hello").map(Row(_)) + + val query = "SELECT c0 FROM bar" + // query option to pass on the query string. + val df = spark.read.format("jdbc") + .option("url", jdbcUrl) + .option("keytab", keytabFullPath) + .option("principal", principal) + .option("query", query) + .load() + assert(df.collect().toSet === expectedResult) + } + + test("Basic read test in create table path") { + // This makes sure Spark must do authentication + Configuration.setConfiguration(null) + + val expectedResult = Set("hello").map(Row(_)) + + val query = "SELECT c0 FROM bar" + // query option in the create table path. 
+ sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$jdbcUrl', query '$query', keytab '$keytabFullPath', principal '$principal') + """.stripMargin.replaceAll("\n", " ")) + assert(sql("select c0 from queryOption").collect().toSet === expectedResult) + } + + test("Basic write test") { + // This makes sure Spark must do authentication + Configuration.setConfiguration(null) + + val props = new Properties + props.setProperty("keytab", keytabFullPath) + props.setProperty("principal", principal) + + val tableName = "write_test" + sqlContext.createDataFrame(Seq(("foo", "bar"))) + .write.jdbc(jdbcUrl, tableName, props) + val df = sqlContext.read.jdbc(jdbcUrl, tableName, props) + + val schema = df.schema + assert(schema.map(_.dataType).toSeq === Seq(StringType, StringType)) + val rows = df.collect() + assert(rows.length === 1) + assert(rows(0).getString(0) === "foo") + assert(rows(0).getString(1) === "bar") + } } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala new file mode 100644 index 0000000000000..7c1adc990bab3 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc + +import javax.security.auth.login.Configuration + +import com.spotify.docker.client.messages.{ContainerConfig, HostConfig} + +import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnectionProvider +import org.apache.spark.tags.DockerTest + +@DockerTest +class MariaDBKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { + override protected val userName = s"mariadb/$dockerIp" + override protected val keytabFileName = "mariadb.keytab" + + override val db = new DatabaseOnDocker { + override val imageName = "mariadb:10.4" + override val env = Map( + "MYSQL_ROOT_PASSWORD" -> "rootpass" + ) + override val usesIpc = false + override val jdbcPort = 3306 + + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:mysql://$ip:$port/mysql?user=$principal" + + override def getEntryPoint: Option[String] = + Some("/docker-entrypoint/mariadb_docker_entrypoint.sh") + + override def beforeContainerStart( + hostConfigBuilder: HostConfig.Builder, + containerConfigBuilder: ContainerConfig.Builder): Unit = { + def replaceIp(s: String): String = s.replace("__IP_ADDRESS_REPLACE_ME__", dockerIp) + copyExecutableResource("mariadb_docker_entrypoint.sh", entryPointDir, replaceIp) + copyExecutableResource("mariadb_krb_setup.sh", initDbDir, replaceIp) + + hostConfigBuilder.appendBinds( + HostConfig.Bind.from(entryPointDir.getAbsolutePath) + .to("/docker-entrypoint").readOnly(true).build(), + HostConfig.Bind.from(initDbDir.getAbsolutePath) + .to("/docker-entrypoint-initdb.d").readOnly(true).build() + ) + } + } + + override protected def setAuthentication(keytabFile: String, principal: String): Unit = { + val config = new SecureConnectionProvider.JDBCConfiguration( + Configuration.getConfiguration, "Krb5ConnectorContext", keytabFile, principal) + Configuration.setConfiguration(config) + } +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index 5738307095933..42d64873c44d9 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -37,8 +37,6 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite { override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:sqlserver://$ip:$port;user=sa;password=Sapass123;" - - override def getStartupProcessName: Option[String] = None } override def dataPreparation(conn: Connection): Unit = { diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index bba1b5275269b..4cbcb59e02de1 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -35,7 +35,6 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { override val jdbcPort: Int = 3306 override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass" - override def getStartupProcessName: Option[String] = None } override def dataPreparation(conn: Connection): Unit = { diff --git 
a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 6faa888cf18ed..24c3adb9c0153 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -66,7 +66,6 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSpark override val jdbcPort: Int = 1521 override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:oracle:thin:system/oracle@//$ip:$port/xe" - override def getStartupProcessName: Option[String] = None } override def dataPreparation(conn: Connection): Unit = { diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 599f00def0750..6611bc2d19ed8 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -37,7 +37,6 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override val jdbcPort = 5432 override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" - override def getStartupProcessName: Option[String] = None } override def dataPreparation(conn: Connection): Unit = { diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala index 721a4882b986a..adf30fbdc1e12 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala @@ -17,15 +17,11 @@ package org.apache.spark.sql.jdbc -import java.sql.Connection -import java.util.Properties import javax.security.auth.login.Configuration import com.spotify.docker.client.messages.{ContainerConfig, HostConfig} -import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.jdbc.connection.PostgresConnectionProvider -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnectionProvider import org.apache.spark.tags.DockerTest @DockerTest @@ -44,86 +40,22 @@ class PostgresKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:postgresql://$ip:$port/postgres?user=$principal&gsslib=gssapi" - override def getStartupProcessName: Option[String] = None - override def beforeContainerStart( hostConfigBuilder: HostConfig.Builder, containerConfigBuilder: ContainerConfig.Builder): Unit = { def replaceIp(s: String): String = s.replace("__IP_ADDRESS_REPLACE_ME__", dockerIp) - copyExecutableResource("postgres_krb_setup.sh", workDir, replaceIp) + copyExecutableResource("postgres_krb_setup.sh", initDbDir, replaceIp) hostConfigBuilder.appendBinds( - HostConfig.Bind.from(workDir.getAbsolutePath) + HostConfig.Bind.from(initDbDir.getAbsolutePath) 
.to("/docker-entrypoint-initdb.d").readOnly(true).build() ) } } override protected def setAuthentication(keytabFile: String, principal: String): Unit = { - val config = new PostgresConnectionProvider.PGJDBCConfiguration( + val config = new SecureConnectionProvider.JDBCConfiguration( Configuration.getConfiguration, "pgjdbc", keytabFile, principal) Configuration.setConfiguration(config) } - - override def dataPreparation(conn: Connection): Unit = { - conn.prepareStatement("CREATE DATABASE foo").executeUpdate() - conn.setCatalog("foo") - conn.prepareStatement("CREATE TABLE bar (c0 text)").executeUpdate() - conn.prepareStatement("INSERT INTO bar VALUES ('hello')").executeUpdate() - } - - test("Basic read test in query option") { - // This makes sure Spark must do authentication - Configuration.setConfiguration(null) - - val expectedResult = Set("hello").map(Row(_)) - - val query = "SELECT c0 FROM bar" - // query option to pass on the query string. - val df = spark.read.format("jdbc") - .option("url", jdbcUrl) - .option("keytab", keytabFullPath) - .option("principal", principal) - .option("query", query) - .load() - assert(df.collect().toSet === expectedResult) - } - - test("Basic read test in create table path") { - // This makes sure Spark must do authentication - Configuration.setConfiguration(null) - - val expectedResult = Set("hello").map(Row(_)) - - val query = "SELECT c0 FROM bar" - // query option in the create table path. - sql( - s""" - |CREATE OR REPLACE TEMPORARY VIEW queryOption - |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$jdbcUrl', query '$query', keytab '$keytabFullPath', principal '$principal') - """.stripMargin.replaceAll("\n", " ")) - assert(sql("select c0 from queryOption").collect().toSet === expectedResult) - } - - test("Basic write test") { - // This makes sure Spark must do authentication - Configuration.setConfiguration(null) - - val props = new Properties - props.setProperty("keytab", keytabFullPath) - props.setProperty("principal", principal) - - val tableName = "write_test" - sqlContext.createDataFrame(Seq(("foo", "bar"))) - .write.jdbc(jdbcUrl, tableName, props) - val df = sqlContext.read.jdbc(jdbcUrl, tableName, props) - - val schema = df.schema - assert(schema.map(_.dataType).toSeq === Seq(StringType, StringType)) - val rows = df.collect() - assert(rows.length === 1) - assert(rows(0).getString(0) === "foo") - assert(rows(0).getString(1) === "bar") - } } diff --git a/pom.xml b/pom.xml index cc48ee794ea04..cd85db6e03264 100644 --- a/pom.xml +++ b/pom.xml @@ -951,6 +951,12 @@ 5.1.38 test + + org.mariadb.jdbc + mariadb-java-client + 2.5.4 + test + org.postgresql postgresql diff --git a/sql/core/pom.xml b/sql/core/pom.xml index c95fe3ce1c120..e97c7fd3280be 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -131,8 +131,8 @@ test - mysql - mysql-connector-java + org.mariadb.jdbc + mariadb-java-client test diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala index ccaff0d6ca7d4..c864f1f52fcce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala @@ -28,6 +28,9 @@ import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions * the parameters. 
 */
private[jdbc] trait ConnectionProvider {
+  /**
+   * Opens a connection to the database.
+   */
  def getConnection(): Connection
}

@@ -43,6 +46,10 @@ private[jdbc] object ConnectionProvider extends Logging {
      logDebug("Postgres connection provider found")
      new PostgresConnectionProvider(driver, options)

+    case MariaDBConnectionProvider.driverClass =>
+      logDebug("MariaDB connection provider found")
+      new MariaDBConnectionProvider(driver, options)
+
    case _ =>
      throw new IllegalArgumentException(s"Driver ${options.driverClass} does not support " +
        "Kerberos authentication")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala
new file mode 100644
index 0000000000000..eb2f0f78022ba
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc.connection
+
+import java.sql.Driver
+import javax.security.auth.login.Configuration
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+
+private[jdbc] class MariaDBConnectionProvider(driver: Driver, options: JDBCOptions)
+  extends SecureConnectionProvider(driver, options) {
+  override val appEntry: String = {
+    "Krb5ConnectorContext"
+  }
+
+  override def setAuthenticationConfigIfNeeded(): Unit = {
+    val parent = Configuration.getConfiguration
+    val configEntry = parent.getAppConfigurationEntry(appEntry)
+    /**
+     * A couple of things to note here:
+     * 1. MariaDB doesn't support JAAS application name configuration
+     * 2.
MariaDB sets a default JAAS config if "java.security.auth.login.config" is not set + */ + val entryUsesKeytab = configEntry != null && + configEntry.exists(_.getOptions().get("useKeyTab") == "true") + if (configEntry == null || configEntry.isEmpty || !entryUsesKeytab) { + val config = new SecureConnectionProvider.JDBCConfiguration( + parent, appEntry, options.keytab, options.principal) + logDebug("Adding database specific security configuration") + Configuration.setConfiguration(config) + } + } +} + +private[sql] object MariaDBConnectionProvider { + val driverClass = "org.mariadb.jdbc.Driver" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala index e793c4dfd780e..14911fc75ebc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala @@ -17,66 +17,32 @@ package org.apache.spark.sql.execution.datasources.jdbc.connection -import java.sql.{Connection, Driver} +import java.sql.Driver import java.util.Properties -import javax.security.auth.login.{AppConfigurationEntry, Configuration} - -import scala.collection.JavaConverters._ +import javax.security.auth.login.Configuration import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions -import org.apache.spark.sql.execution.datasources.jdbc.connection.PostgresConnectionProvider.PGJDBCConfiguration -import org.apache.spark.util.SecurityUtils private[jdbc] class PostgresConnectionProvider(driver: Driver, options: JDBCOptions) - extends BasicConnectionProvider(driver, options) { - val appEntry: String = { + extends SecureConnectionProvider(driver, options) { + override val appEntry: String = { val parseURL = driver.getClass.getMethod("parseURL", classOf[String], classOf[Properties]) val properties = parseURL.invoke(driver, options.url, null).asInstanceOf[Properties] properties.getProperty("jaasApplicationName", "pgjdbc") } - def setAuthenticationConfigIfNeeded(): Unit = { + override def setAuthenticationConfigIfNeeded(): Unit = { val parent = Configuration.getConfiguration val configEntry = parent.getAppConfigurationEntry(appEntry) if (configEntry == null || configEntry.isEmpty) { - val config = new PGJDBCConfiguration(parent, appEntry, options.keytab, options.principal) + val config = new SecureConnectionProvider.JDBCConfiguration( + parent, appEntry, options.keytab, options.principal) + logDebug("Adding database specific security configuration") Configuration.setConfiguration(config) } } - - override def getConnection(): Connection = { - setAuthenticationConfigIfNeeded() - super.getConnection() - } } private[sql] object PostgresConnectionProvider { - class PGJDBCConfiguration( - parent: Configuration, - appEntry: String, - keytab: String, - principal: String) extends Configuration { - private val entry = - new AppConfigurationEntry( - SecurityUtils.getKrb5LoginModuleName(), - AppConfigurationEntry.LoginModuleControlFlag.REQUIRED, - Map[String, Object]( - "useTicketCache" -> "false", - "useKeyTab" -> "true", - "keyTab" -> keytab, - "principal" -> principal, - "debug" -> "true" - ).asJava - ) - - override def getAppConfigurationEntry(name: String): Array[AppConfigurationEntry] = { - if (name.equals(appEntry)) { - Array(entry) - } else { - 
parent.getAppConfigurationEntry(name)
-      }
-    }
-  }
-
  val driverClass = "org.postgresql.Driver"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala
new file mode 100644
index 0000000000000..ff192d71e6f33
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc.connection
+
+import java.sql.{Connection, Driver}
+import javax.security.auth.login.{AppConfigurationEntry, Configuration}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.util.SecurityUtils
+
+private[jdbc] abstract class SecureConnectionProvider(driver: Driver, options: JDBCOptions)
+  extends BasicConnectionProvider(driver, options) with Logging {
+  override def getConnection(): Connection = {
+    setAuthenticationConfigIfNeeded()
+    super.getConnection()
+  }
+
+  /**
+   * Returns the JAAS application name. This is sometimes configurable at the JDBC driver level.
+   */
+  val appEntry: String
+
+  /**
+   * Sets the database-specific authentication configuration when needed. If the configuration
+   * is already set, later calls must be no-ops.
+ */ + def setAuthenticationConfigIfNeeded(): Unit +} + +object SecureConnectionProvider { + class JDBCConfiguration( + parent: Configuration, + appEntry: String, + keytab: String, + principal: String) extends Configuration { + val entry = + new AppConfigurationEntry( + SecurityUtils.getKrb5LoginModuleName(), + AppConfigurationEntry.LoginModuleControlFlag.REQUIRED, + Map[String, Object]( + "useTicketCache" -> "false", + "useKeyTab" -> "true", + "keyTab" -> keytab, + "principal" -> principal, + "debug" -> "true" + ).asJava + ) + + override def getAppConfigurationEntry(name: String): Array[AppConfigurationEntry] = { + if (name.equals(appEntry)) { + Array(entry) + } else { + parent.getAppConfigurationEntry(name) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuiteBase.scala new file mode 100644 index 0000000000000..d18a3088c4f2f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuiteBase.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.jdbc.connection + +import java.sql.{Driver, DriverManager} +import javax.security.auth.login.Configuration + +import scala.collection.JavaConverters._ + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions} + +abstract class ConnectionProviderSuiteBase extends SparkFunSuite with BeforeAndAfterEach { + protected def registerDriver(driverClass: String): Driver = { + DriverRegistry.register(driverClass) + DriverManager.getDrivers.asScala.collectFirst { + case d if d.getClass.getCanonicalName == driverClass => d + }.get + } + + protected def options(url: String) = new JDBCOptions(Map[String, String]( + JDBCOptions.JDBC_URL -> url, + JDBCOptions.JDBC_TABLE_NAME -> "table", + JDBCOptions.JDBC_KEYTAB -> "/path/to/keytab", + JDBCOptions.JDBC_PRINCIPAL -> "principal" + )) + + override def afterEach(): Unit = { + try { + Configuration.setConfiguration(null) + } finally { + super.afterEach() + } + } + + protected def testSecureConnectionProvider(provider: SecureConnectionProvider): Unit = { + // Make sure no authentication for the database is set + assert(Configuration.getConfiguration.getAppConfigurationEntry(provider.appEntry) == null) + + // Make sure the first call sets authentication properly + val savedConfig = Configuration.getConfiguration + provider.setAuthenticationConfigIfNeeded() + val config = Configuration.getConfiguration + assert(savedConfig != config) + val appEntry = config.getAppConfigurationEntry(provider.appEntry) + assert(appEntry != null) + + // Make sure a second call is not modifying the existing authentication + provider.setAuthenticationConfigIfNeeded() + assert(config.getAppConfigurationEntry(provider.appEntry) === appEntry) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProviderSuite.scala new file mode 100644 index 0000000000000..70cad2097eb43 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProviderSuite.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.jdbc.connection + +class MariaDBConnectionProviderSuite extends ConnectionProviderSuiteBase { + test("setAuthenticationConfigIfNeeded must set authentication if not set") { + val driver = registerDriver(MariaDBConnectionProvider.driverClass) + val provider = new MariaDBConnectionProvider(driver, options("jdbc:mysql://localhost/mysql")) + + testSecureConnectionProvider(provider) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProviderSuite.scala index 59ff1c79bd064..8cef7652f9c54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProviderSuite.scala @@ -17,69 +17,16 @@ package org.apache.spark.sql.execution.datasources.jdbc.connection -import java.sql.{Driver, DriverManager} -import javax.security.auth.login.Configuration - -import scala.collection.JavaConverters._ - -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions} - -class PostgresConnectionProviderSuite extends SparkFunSuite with BeforeAndAfterEach { - private def options(url: String) = new JDBCOptions(Map[String, String]( - JDBCOptions.JDBC_URL -> url, - JDBCOptions.JDBC_TABLE_NAME -> "table", - JDBCOptions.JDBC_KEYTAB -> "/path/to/keytab", - JDBCOptions.JDBC_PRINCIPAL -> "principal" - )) - - override def afterEach(): Unit = { - try { - Configuration.setConfiguration(null) - } finally { - super.afterEach() - } - } - +class PostgresConnectionProviderSuite extends ConnectionProviderSuiteBase { test("setAuthenticationConfigIfNeeded must set authentication if not set") { - DriverRegistry.register(PostgresConnectionProvider.driverClass) - val driver = DriverManager.getDrivers.asScala.collectFirst { - case d if d.getClass.getCanonicalName == PostgresConnectionProvider.driverClass => d - }.get + val driver = registerDriver(PostgresConnectionProvider.driverClass) val defaultProvider = new PostgresConnectionProvider( driver, options("jdbc:postgresql://localhost/postgres")) val customProvider = new PostgresConnectionProvider( driver, options(s"jdbc:postgresql://localhost/postgres?jaasApplicationName=custompgjdbc")) assert(defaultProvider.appEntry !== customProvider.appEntry) - - // Make sure no authentication for postgres is set - assert(Configuration.getConfiguration.getAppConfigurationEntry( - defaultProvider.appEntry) == null) - assert(Configuration.getConfiguration.getAppConfigurationEntry( - customProvider.appEntry) == null) - - // Make sure the first call sets authentication properly - val savedConfig = Configuration.getConfiguration - defaultProvider.setAuthenticationConfigIfNeeded() - val defaultConfig = Configuration.getConfiguration - assert(savedConfig != defaultConfig) - val defaultAppEntry = defaultConfig.getAppConfigurationEntry(defaultProvider.appEntry) - assert(defaultAppEntry != null) - customProvider.setAuthenticationConfigIfNeeded() - val customConfig = Configuration.getConfiguration - assert(savedConfig != customConfig) - assert(defaultConfig != customConfig) - val customAppEntry = customConfig.getAppConfigurationEntry(customProvider.appEntry) - assert(customAppEntry != 
null) - - // Make sure a second call is not modifying the existing authentication - defaultProvider.setAuthenticationConfigIfNeeded() - customProvider.setAuthenticationConfigIfNeeded() - assert(customConfig == Configuration.getConfiguration) - assert(defaultConfig.getAppConfigurationEntry(defaultProvider.appEntry) === defaultAppEntry) - assert(customConfig.getAppConfigurationEntry(customProvider.appEntry) === customAppEntry) + testSecureConnectionProvider(defaultProvider) + testSecureConnectionProvider(customProvider) } }
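
Reviewer note (not part of the patch): the JAAS overlay used by SecureConnectionProvider.JDBCConfiguration can be hard to see through the diff, so below is a self-contained Scala sketch of the same idea. A synthetic Configuration answers for exactly one application entry with a keytab-based Kerberos login and delegates every other lookup to the parent. The keytab path and principal are placeholders, and the login module name is hard-coded for Oracle/OpenJDK JVMs; the patch itself resolves the name via SecurityUtils.getKrb5LoginModuleName() so that IBM JVMs work too.

import javax.security.auth.login.{AppConfigurationEntry, Configuration}

import scala.collection.JavaConverters._

object JaasOverlaySketch {
  def main(args: Array[String]): Unit = {
    val appEntry = "Krb5ConnectorContext"        // the name the MariaDB driver looks up
    val parent = Configuration.getConfiguration  // whatever configuration was already installed

    val overlay = new Configuration {
      private val entry = new AppConfigurationEntry(
        "com.sun.security.auth.module.Krb5LoginModule",  // assumption: non-IBM JVM
        AppConfigurationEntry.LoginModuleControlFlag.REQUIRED,
        Map[String, Object](
          "useTicketCache" -> "false",
          "useKeyTab" -> "true",
          "keyTab" -> "/path/to/client.keytab",      // placeholder
          "principal" -> "client/host@EXAMPLE.COM"   // placeholder
        ).asJava)

      // Answer for our entry; delegate every other lookup to the parent configuration.
      override def getAppConfigurationEntry(name: String): Array[AppConfigurationEntry] =
        if (name == appEntry) Array(entry) else parent.getAppConfigurationEntry(name)
    }

    Configuration.setConfiguration(overlay)

    // The JDBC driver can now resolve the entry by name; a repeated
    // setAuthenticationConfigIfNeeded() call would find it and change nothing.
    assert(Configuration.getConfiguration.getAppConfigurationEntry(appEntry) != null)
  }
}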
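
A second note, on the container changes: with the spotify docker-client API used in DockerJDBCIntegrationSuite, an entry point and a startup process (cmd) can be set together. Docker always runs the entry point and passes the cmd vector to it as arguments, which the entry point script may use or ignore; MariaDBKrbIntegrationSuite relies on this by installing the GSSAPI plugin in its entry point and then handing control back with "docker-entrypoint.sh mysqld". A minimal sketch, with image name and paths illustrative only:

import com.spotify.docker.client.messages.ContainerConfig

object EntryPointSketch {
  val config: ContainerConfig = ContainerConfig.builder()
    .image("mariadb:10.4")
    .entrypoint("/docker-entrypoint/mariadb_docker_entrypoint.sh")  // always executed first
    .cmd("mysqld")  // delivered to the entry point as its argument vector
    .build()
}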