Write schema_diff function
```sql
spark-sql> SELECT * FROM schema_diff('lakefs', 'main~2', 'main', 'db.allstar_games');
+       2022-23 9876    AXL     1       2       3       4       5       6       7       99      Axo Lotl
+       2023-24 100     AXL     2       3       4       5       6       7       8       123     Axo Lotl
Time taken: 15.993 seconds, Fetched 2 row(s)
```

and

```sql
spark-sql> SELECT * FROM schema_diff('lakefs', 'main~', 'main', 'db.allstar_games');
+       2022-23 9876    AXL     1       2       3       4       5       6       7       99      Axo Lotl
-       2022-23 99      AXL     1       2       3       4       5       6       7       99      Axo Lotl
Time taken: 16.058 seconds, Fetched 2 row(s)
```

(Probably slow because it's running on my laptop and the lakeFS instance is in the cloud and I'm out of memory and ...).
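
For anyone trying this outside `spark-sql`, here is a minimal sketch (not part of this commit) of wiring the extension into a session through Spark's standard `spark.sql.extensions` setting and running the same query. The object name `SchemaDiffDemo` is invented for illustration, and it assumes a `lakefs` Iceberg catalog is already configured for the session; those catalog settings are omitted here.

```scala
import org.apache.spark.sql.SparkSession

// Sketch only (not part of this commit): register the extension via
// spark.sql.extensions, then call schema_diff as in the examples above.
// Assumes the `lakefs` Iceberg catalog is already configured elsewhere.
object SchemaDiffDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("schema-diff-demo")
      .config("spark.sql.extensions",
        "io.lakefs.iceberg.extension.LakeFSSparkSessionExtensions")
      .getOrCreate()

    spark.sql(
      "SELECT * FROM schema_diff('lakefs', 'main~', 'main', 'db.allstar_games')"
    ).show()
  }
}
```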
arielshaqed committed Jul 31, 2023
1 parent e95d191 commit f4118b1
Showing 1 changed file with 34 additions and 18 deletions.
src/main/scala/io/lakefs/iceberg/extension/Extension.scala (34 additions, 18 deletions)
```diff
@@ -8,14 +8,9 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
 import org.apache.spark.sql.catalyst.expressions.StringLiteral
 
 
-// A table-valued function that adds a column to a table.
-//
-// NEVER USE THIS, it allows trivial SQL injections!
-object WithColumn {
-  private def sql(tableName: String, columnName: String, columnExpression: String) =
-    // BUG(ariels): Dangerous, allows SQL injections!
-    s"SELECT *, $columnExpression $columnName FROM $tableName"
-
+// A table-valued function to compute the difference between the same table
+// at two schemas.
+object SchemaDiff {
   private def computeString(e: Expression): String = {
     val literalValue = StringLiteral.unapply(e)
     literalValue match {
@@ -24,28 +19,49 @@ object WithColumn {
     }
   }
 
+  private def tableAt(prefix: String, schema: String, suffix: String): String = {
+    val spark = SparkSession.getActiveSession match {
+      case None => throw new RuntimeException("Whoops! No Spark session...")
+      case Some(spark) => spark
+    }
+    val parseId = spark.sessionState.sqlParser.parseMultipartIdentifier(_)
+    val parts = parseId(prefix) ++ Seq(schema) ++ parseId(suffix)
+    parts.map(str => s"`$str`").mkString(".")
+  }
+
+  private def sql(prefix: String, fromSchema: String, toSchema: String, suffix: String) = {
+    val fromTableName = tableAt(prefix, fromSchema, suffix)
+    val toTableName = tableAt(prefix, toSchema, suffix)
+
+    s"""
+       (SELECT '+', * FROM (SELECT * FROM $toTableName EXCEPT SELECT * FROM $fromTableName))
+       UNION ALL
+       (SELECT '-', * FROM (SELECT * FROM $fromTableName EXCEPT SELECT * FROM $toTableName))
+       """
+  }
+
   private def tdfBuilder(e: Seq[Expression]): LogicalPlan = {
     val spark = SparkSession.getActiveSession match {
       case None => throw new RuntimeException("Whoops: No spark session!")
       case Some(spark) => spark
     }
-    if (e.size != 3) {
-      throw new RuntimeException(s"Need exactly 3 arguments <tableName, columnName, columnExpression>, got $e")
+    if (e.size != 4) {
+      throw new RuntimeException(s"Need exactly 4 arguments <tablePrefix, fromSchema, toSchema, tableSuffix>, got $e")
     }
-    val Seq(tableName, columnName, columnExpression) = e.map(computeString)
-    val sqlString = sql(tableName, columnName, columnExpression)
+    val Seq(tablePrefix, fromSchema, toSchema, tableSuffix) = e.map(computeString)
+    val sqlString = sql(tablePrefix, fromSchema, toSchema, tableSuffix)
     spark.sql(sqlString).queryExecution.logical
   }
 
-  val function = (FunctionIdentifier("with_column"),
-    new ExpressionInfo("io.lakefs.iceberg.extension.WithColumn$",
-      "", "with_column", "with_column('TABLE', 'NEW_COLUMN', 'NEW_COLUMN_EXPRESSION')", "with_column('TABLE', 'NEW_COLUMN', 'NEW_COLUMN_EXPRESSION')"),
+  val function = (FunctionIdentifier("schema_diff"),
+    new ExpressionInfo("io.lakefs.iceberg.extension.SchemaDiff$",
+      "", "schema_diff", "schema_diff('TABLE_PREFIX', 'FROM_SCHEMA', 'TO_SCHEMA', 'TABLE_SUFFIX')",
+      "schema_diff('TABLE_PREFIX', 'FROM_SCHEMA', 'TO_SCHEMA', 'TABLE_SUFFIX')"),
     tdfBuilder _)
 }
 
-class FooSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
+class LakeFSSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
   override def apply(extensions: SparkSessionExtensions): Unit = {
-    println("*** Go register FooSparkSessionExtensions ***!")
-    extensions.injectTableFunction(WithColumn.function)
+    extensions.injectTableFunction(SchemaDiff.function)
  }
 }
```
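
To make the generated query concrete, the following is an illustrative, standalone sketch (not part of the commit) of the string that `SchemaDiff.sql` builds for the example in the commit message. The names `SchemaDiffSketch` and `diffSql` are invented for illustration, and a plain split on `.` stands in for Spark's `parseMultipartIdentifier`.

```scala
// Illustrative sketch: reproduce the diff query SchemaDiff builds for
// schema_diff('lakefs', 'main~', 'main', 'db.allstar_games').
object SchemaDiffSketch {
  // Quote each identifier part with backticks and join with dots,
  // mimicking tableAt (a dot-split replaces parseMultipartIdentifier).
  private def tableAt(prefix: String, schema: String, suffix: String): String =
    (prefix.split('.').toSeq ++ Seq(schema) ++ suffix.split('.').toSeq)
      .map(part => s"`$part`")
      .mkString(".")

  def diffSql(prefix: String, fromSchema: String, toSchema: String, suffix: String): String = {
    val fromTable = tableAt(prefix, fromSchema, suffix) // e.g. `lakefs`.`main~`.`db`.`allstar_games`
    val toTable = tableAt(prefix, toSchema, suffix)     // e.g. `lakefs`.`main`.`db`.`allstar_games`
    s"""
(SELECT '+', * FROM (SELECT * FROM $toTable EXCEPT SELECT * FROM $fromTable))
UNION ALL
(SELECT '-', * FROM (SELECT * FROM $fromTable EXCEPT SELECT * FROM $toTable))
"""
  }

  def main(args: Array[String]): Unit =
    println(diffSql("lakefs", "main~", "main", "db.allstar_games"))
}
```

Rows marked `+` exist only at the `to` schema (`main`), and rows marked `-` only at the `from` schema (`main~`), which matches the two-row outputs shown in the commit message above.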
