/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.hudi.command.procedures

import org.apache.hudi.DataSourceReadOptions
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DataTypes, Metadata, StructField, StructType}

import java.util.function.Supplier

/**
 * Procedure `copy_to_temp_view`: reads a Hudi table with the requested query type
 * (snapshot, incremental or read_optimized) and registers the resulting DataFrame
 * as a Spark temporary view.
 *
 * Parameters:
 *  - table (required): Hudi table name to read.
 *  - query_type (optional): snapshot (default), incremental or read_optimized.
 *  - view_name (required): name of the temporary view to create.
 *  - begin_instance_time / end_instance_time (optional): instant range, required
 *    (non-empty) when query_type is incremental.
 *  - as_of_instant (optional): time-travel instant for snapshot reads.
 *  - replace (optional, default false): replace an existing view of the same name.
 *  - global (optional, default false): create a global temp view (lives in the
 *    `global_temp` database and is visible across sessions).
 *
 * Returns a single row with status 0 on success.
 */
class CopyToTempView extends BaseProcedure with ProcedureBuilder with Logging {

  private val PARAMETERS = Array[ProcedureParameter](
    ProcedureParameter.required(0, "table", DataTypes.StringType, None),
    ProcedureParameter.optional(1, "query_type", DataTypes.StringType, DataSourceReadOptions.QUERY_TYPE_SNAPSHOT_OPT_VAL),
    ProcedureParameter.required(2, "view_name", DataTypes.StringType, None),
    ProcedureParameter.optional(3, "begin_instance_time", DataTypes.StringType, ""),
    ProcedureParameter.optional(4, "end_instance_time", DataTypes.StringType, ""),
    ProcedureParameter.optional(5, "as_of_instant", DataTypes.StringType, ""),
    ProcedureParameter.optional(6, "replace", DataTypes.BooleanType, false),
    ProcedureParameter.optional(7, "global", DataTypes.BooleanType, false)
  )

  private val OUTPUT_TYPE = new StructType(Array[StructField](
    StructField("status", DataTypes.IntegerType, nullable = true, Metadata.empty))
  )

  def parameters: Array[ProcedureParameter] = PARAMETERS

  def outputType: StructType = OUTPUT_TYPE

  override def call(args: ProcedureArgs): Seq[Row] = {
    super.checkArgs(PARAMETERS, args)

    val tableName = getArgValueOrDefault(args, PARAMETERS(0))
    val queryType = getArgValueOrDefault(args, PARAMETERS(1)).get.asInstanceOf[String]
    val viewName = getArgValueOrDefault(args, PARAMETERS(2)).get.asInstanceOf[String]
    val beginInstance = getArgValueOrDefault(args, PARAMETERS(3)).get.asInstanceOf[String]
    val endInstance = getArgValueOrDefault(args, PARAMETERS(4)).get.asInstanceOf[String]
    val asOfInstant = getArgValueOrDefault(args, PARAMETERS(5)).get.asInstanceOf[String]
    val replace = getArgValueOrDefault(args, PARAMETERS(6)).get.asInstanceOf[Boolean]
    val global = getArgValueOrDefault(args, PARAMETERS(7)).get.asInstanceOf[Boolean]

    val tablePath = getBasePath(tableName)

    // Build the source DataFrame according to the requested query type.
    val sourceDataFrame = queryType match {
      case DataSourceReadOptions.QUERY_TYPE_SNAPSHOT_OPT_VAL => if (asOfInstant.nonEmpty) {
        // Time-travel snapshot read when as_of_instant is supplied.
        sparkSession.read
          .format("org.apache.hudi")
          .option(DataSourceReadOptions.QUERY_TYPE.key, DataSourceReadOptions.QUERY_TYPE_SNAPSHOT_OPT_VAL)
          .option(DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT.key, asOfInstant)
          .load(tablePath)
      } else {
        sparkSession.read
          .format("org.apache.hudi")
          .option(DataSourceReadOptions.QUERY_TYPE.key, DataSourceReadOptions.QUERY_TYPE_SNAPSHOT_OPT_VAL)
          .load(tablePath)
      }
      case DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL =>
        assert(beginInstance.nonEmpty && endInstance.nonEmpty, "when the query_type is incremental, begin_instance_time and end_instance_time can not be null.")
        sparkSession.read
          .format("org.apache.hudi")
          .option(DataSourceReadOptions.QUERY_TYPE.key, DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL)
          .option(DataSourceReadOptions.BEGIN_INSTANTTIME.key, beginInstance)
          .option(DataSourceReadOptions.END_INSTANTTIME.key, endInstance)
          .load(tablePath)
      case DataSourceReadOptions.QUERY_TYPE_READ_OPTIMIZED_OPT_VAL =>
        sparkSession.read
          .format("org.apache.hudi")
          .option(DataSourceReadOptions.QUERY_TYPE.key, DataSourceReadOptions.QUERY_TYPE_READ_OPTIMIZED_OPT_VAL)
          .load(tablePath)
      case other =>
        // Fail fast with a clear message instead of an opaque scala.MatchError
        // when an unsupported query_type is supplied.
        throw new IllegalArgumentException(
          s"Unsupported query_type: $other, supported query types: " +
            s"${DataSourceReadOptions.QUERY_TYPE_SNAPSHOT_OPT_VAL}, " +
            s"${DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL}, " +
            s"${DataSourceReadOptions.QUERY_TYPE_READ_OPTIMIZED_OPT_VAL}")
    }
    // Register the view. The non-replace variants throw if a view with the
    // same name already exists, which the procedure deliberately propagates.
    if (global) {
      if (replace) {
        sourceDataFrame.createOrReplaceGlobalTempView(viewName)
      } else {
        sourceDataFrame.createGlobalTempView(viewName)
      }
    } else {
      if (replace) {
        sourceDataFrame.createOrReplaceTempView(viewName)
      } else {
        sourceDataFrame.createTempView(viewName)
      }
    }
    Seq(Row(0))
  }

  override def build = new CopyToTempView()
}

object CopyToTempView {
  val NAME = "copy_to_temp_view"

  def builder: Supplier[ProcedureBuilder] = new Supplier[ProcedureBuilder] {
    override def get() = new CopyToTempView()
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.hudi.procedure

import org.apache.spark.sql.hudi.HoodieSparkSqlTestBase

/**
 * Tests for the `copy_to_temp_view` procedure: default parameters,
 * the `replace` flag, and the `global` (cross-session) flag.
 */
class TestCopyToTempViewProcedure extends HoodieSparkSqlTestBase {

  test("Test Call copy_to_temp_view Procedure with default params") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // create table
      spark.sql(
        s"""
           |create table $tableName (
           |  id int,
           |  name string,
           |  price double,
           |  ts long
           |) using hudi
           | location '${tmp.getCanonicalPath}/$tableName'
           | tblproperties (
           |  primaryKey = 'id',
           |  preCombineField = 'ts'
           | )
       """.stripMargin)

      // insert data to table
      spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")
      spark.sql(s"insert into $tableName select 2, 'a2', 20, 1500")
      spark.sql(s"insert into $tableName select 3, 'a3', 30, 2000")
      spark.sql(s"insert into $tableName select 4, 'a4', 40, 2500")

      // Check required fields
      checkExceptionContain(s"call copy_to_temp_view(table=>'$tableName')")(s"Argument: view_name is required")

      val viewName = generateTableName

      val row = spark.sql(s"""call copy_to_temp_view(table=>'$tableName',view_name=>'$viewName')""").collectAsList()
      assert(row.size() == 1 && row.get(0).get(0) == 0)
      // The temp view must expose every inserted record.
      val copyTableCount = spark.sql(s"""select count(1) from $viewName""").collectAsList()
      assert(copyTableCount.size() == 1 && copyTableCount.get(0).get(0) == 4)
    }
  }

  test("Test Call copy_to_temp_view Procedure with replace params") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // create table
      spark.sql(
        s"""
           |create table $tableName (
           |  id int,
           |  name string,
           |  price double,
           |  ts long
           |) using hudi
           | location '${tmp.getCanonicalPath}/$tableName'
           | tblproperties (
           |  primaryKey = 'id',
           |  preCombineField = 'ts'
           | )
       """.stripMargin)

      // insert data to table
      spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")
      spark.sql(s"insert into $tableName select 2, 'a2', 20, 1500")
      spark.sql(s"insert into $tableName select 3, 'a3', 30, 2000")
      spark.sql(s"insert into $tableName select 4, 'a4', 40, 2500")

      // Check required fields
      checkExceptionContain(s"call copy_to_temp_view(table=>'$tableName')")(s"Argument: view_name is required")

      // 1: copyToTempView
      val viewName = generateTableName
      val row = spark.sql(s"""call copy_to_temp_view(table=>'$tableName',view_name=>'$viewName')""").collectAsList()
      assert(row.size() == 1 && row.get(0).get(0) == 0)
      val copyTableCount = spark.sql(s"""select count(1) from $viewName""").collectAsList()
      assert(copyTableCount.size() == 1 && copyTableCount.get(0).get(0) == 4)

      // 2: add new record to hudi table
      spark.sql(s"insert into $tableName select 5, 'a5', 40, 2500")

      // 3: copyToTempView with replace=false must fail on the existing view
      checkExceptionContain(s"""call copy_to_temp_view(table=>'$tableName',view_name=>'$viewName',replace=>false)""")(s"Temporary view '$viewName' already exists")
      // 4: copyToTempView with replace=true
      val row2 = spark.sql(s"""call copy_to_temp_view(table=>'$tableName',view_name=>'$viewName',replace=>true)""").collectAsList()
      assert(row2.size() == 1 && row2.get(0).get(0) == 0)
      // 5: the replaced view reflects the new record, count=5
      val newViewCount = spark.sql(s"""select count(1) from $viewName""").collectAsList()
      assert(newViewCount.size() == 1 && newViewCount.get(0).get(0) == 5)
    }
  }

  test("Test Call copy_to_temp_view Procedure with global params") {
    withTempDir { tmp =>
      val tableName = generateTableName
      // create table
      spark.sql(
        s"""
           |create table $tableName (
           |  id int,
           |  name string,
           |  price double,
           |  ts long
           |) using hudi
           | location '${tmp.getCanonicalPath}/$tableName'
           | tblproperties (
           |  primaryKey = 'id',
           |  preCombineField = 'ts'
           | )
       """.stripMargin)

      // insert data to table
      spark.sql(s"insert into $tableName select 1, 'a1', 10, 1000")
      spark.sql(s"insert into $tableName select 2, 'a2', 20, 1500")
      spark.sql(s"insert into $tableName select 3, 'a3', 30, 2000")
      spark.sql(s"insert into $tableName select 4, 'a4', 40, 2500")

      // Check required fields
      checkExceptionContain(s"call copy_to_temp_view(table=>'$tableName')")(s"Argument: view_name is required")

      // 1: copyToTempView with global=false
      val viewName = generateTableName
      val row = spark.sql(s"""call copy_to_temp_view(table=>'$tableName',view_name=>'$viewName',global=>false)""").collectAsList()
      assert(row.size() == 1 && row.get(0).get(0) == 0)
      val copyTableCount = spark.sql(s"""select count(1) from $viewName""").collectAsList()
      assert(copyTableCount.size() == 1 && copyTableCount.get(0).get(0) == 4)

      // 2: a session-local temp view must NOT be visible from another session
      var newSession = spark.newSession()
      var hasException = false
      val errorMsg = s"Table or view not found: $viewName"
      try {
        newSession.sql(s"""select count(1) from $viewName""")
      } catch {
        case e: Throwable if e.getMessage.contains(errorMsg) => hasException = true
        case f: Throwable => fail("Exception should contain: " + errorMsg + ", error message: " + f.getMessage, f)
      }
      assertResult(true)(hasException)
      // 3: copyToTempView with global=true,
      val row2 = spark.sql(s"""call copy_to_temp_view(table=>'$tableName',view_name=>'$viewName',global=>true,replace=>true)""").collectAsList()
      assert(row2.size() == 1 && row2.get(0).get(0) == 0)

      newSession = spark.newSession()
      // 4: query the GLOBAL view from the other session. The original test queried
      // `spark` (the original session), which silently resolved the leftover local
      // view from step 1 and never exercised the global path. Global temp views
      // live in the `global_temp` database and must be qualified accordingly.
      val newViewCount = newSession.sql(s"""select count(1) from global_temp.$viewName""").collectAsList()
      assert(newViewCount.size() == 1 && newViewCount.get(0).get(0) == 4)

    }
  }
}