From b015c88b92c7e24e97d93cc78068c44e20501a92 Mon Sep 17 00:00:00 2001
From: Filipe Regadas
Date: Thu, 4 Oct 2018 16:53:49 -0400
Subject: [PATCH] Add better support for parameterised queries in @BigQueryType.fromQuery

---
 .../scio/bigquery/types/BigQueryTypeIT.scala  | 26 ++++++
 .../scio/bigquery/types/BigQueryType.scala    |  6 +-
 .../scio/bigquery/types/TypeProvider.scala    | 80 ++++++++++++++----
 .../scala/com/spotify/scio/coders/Foo.scala   | 52 ++++++++++++
 .../src/main/scala/com/spotify/scio/Foo.scala |  6 ++
 5 files changed, 155 insertions(+), 15 deletions(-)
 create mode 100644 scio-coders-macros/src/main/scala/com/spotify/scio/coders/Foo.scala
 create mode 100644 scio-core/src/main/scala/com/spotify/scio/Foo.scala

diff --git a/scio-bigquery/src/it/scala/com/spotify/scio/bigquery/types/BigQueryTypeIT.scala b/scio-bigquery/src/it/scala/com/spotify/scio/bigquery/types/BigQueryTypeIT.scala
index b40a5086b2..178f45335d 100644
--- a/scio-bigquery/src/it/scala/com/spotify/scio/bigquery/types/BigQueryTypeIT.scala
+++ b/scio-bigquery/src/it/scala/com/spotify/scio/bigquery/types/BigQueryTypeIT.scala
@@ -44,6 +44,10 @@ object BigQueryTypeIT {
     "SELECT word, word_count FROM `data-integration-test.partition_a.table_%s`", "$LATEST")
   class SqlLatestT
 
+  @BigQueryType.fromQuery(
+    "SELECT word, word_count FROM `data-integration-test.partition_a.table_%s` LIMIT %d", "$LATEST", 1)
+  class SqlLatestTWithMultiArgs
+
   @BigQueryType.fromTable("data-integration-test:partition_a.table_%s", "$LATEST")
   class FromTableLatestT
 
@@ -141,6 +145,28 @@ class BigQueryTypeIT extends FlatSpec with Matchers {
     BigQueryType[SqlLatestT].query shouldBe Some(sqlLatestQuery)
   }
 
+  it should "have query fn" in {
+    """LegacyLatestT.query("TABLE")""" should compile
+    """SqlLatestT.query("TABLE")""" should compile
+  }
+
+  it should "have query fn with only 1 argument" in {
+    """LegacyLatestT.query("TABLE", 1)""" shouldNot typeCheck
+    """SqlLatestT.query("TABLE", 1)""" shouldNot typeCheck
+  }
+
+  it should "have query fn with multiple arguments" in {
+    """SqlLatestTWithMultiArgs.query("TABLE", 1)""" should compile
+    """SqlLatestTWithMultiArgs.query(1, "TABLE")""" shouldNot typeCheck
+  }
+
+  it should "format query" in {
+    LegacyLatestT.query("TABLE") shouldBe legacyLatestQuery.format("TABLE")
+    SqlLatestT.query("TABLE") shouldBe sqlLatestQuery.format("TABLE")
+    SqlLatestTWithMultiArgs.query("TABLE", 1) shouldBe
+      "SELECT word, word_count FROM `data-integration-test.partition_a.table_TABLE` LIMIT 1"
+  }
+
   "fromTable" should "work" in {
     val bqt = BigQueryType[FromTableT]
     bqt.isQuery shouldBe false
diff --git a/scio-bigquery/src/main/scala/com/spotify/scio/bigquery/types/BigQueryType.scala b/scio-bigquery/src/main/scala/com/spotify/scio/bigquery/types/BigQueryType.scala
index 383d0c4d01..3fee025a8e 100644
--- a/scio-bigquery/src/main/scala/com/spotify/scio/bigquery/types/BigQueryType.scala
+++ b/scio-bigquery/src/main/scala/com/spotify/scio/bigquery/types/BigQueryType.scala
@@ -216,7 +216,11 @@ object BigQueryType {
    * behavior, start the query string with `#legacysql` or `#standardsql`.
+   *
+   * Format arguments may be `String`, `Int`, `Long`, `Float` or `Double` literals. When
+   * arguments are given, every `%s`, `%d` or `%f` specifier also becomes a typed parameter
+   * of a generated `query(...)` method, e.g. `MyType.query("table", 1)`.
    * @group annotation
    */
-  class fromQuery(query: String, args: String*) extends StaticAnnotation {
+  class fromQuery(query: String, args: Any*) extends StaticAnnotation {
     def macroTransform(annottees: Any*): Any = macro TypeProvider.queryImpl
   }
 
diff --git a/scio-bigquery/src/main/scala/com/spotify/scio/bigquery/types/TypeProvider.scala b/scio-bigquery/src/main/scala/com/spotify/scio/bigquery/types/TypeProvider.scala
index a4d4a9a8b2..2521791141 100644
--- a/scio-bigquery/src/main/scala/com/spotify/scio/bigquery/types/TypeProvider.scala
+++ b/scio-bigquery/src/main/scala/com/spotify/scio/bigquery/types/TypeProvider.scala
@@ -36,7 +36,7 @@ import com.spotify.scio.bigquery.{
 import org.slf4j.LoggerFactory
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.{Map => MMap}
+import scala.collection.mutable.{Map => MMap, Stack => MStack}
 import scala.reflect.macros._
 
 // scalastyle:off line.size.limit
@@ -44,17 +44,20 @@ private[types] object TypeProvider {
 
   private[this] val logger = LoggerFactory.getLogger(this.getClass)
   private lazy val bigquery: BigQueryClient = BigQueryClient.defaultInstance()
+  private[this] val FormatSpecifierRegex =
+    "(%(\\d+\\$)?([-#+ 0,(\\<]*)?(\\d+)?(\\.\\d+)?([tT])?([a-zA-Z%]))".r
 
   def tableImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
 
     val args = extractStrings(c, "Missing table specification")
+    val tableFormat = args.head.asInstanceOf[String]
     val tableSpec =
       BigQueryPartitionUtil.latestTable(bigquery, formatString(args))
     val schema = bigquery.getTableSchema(tableSpec)
     val traits = List(tq"${p(c, SType)}.HasTable")
 
-    val tableDef = q"override def table: _root_.java.lang.String = ${args.head}"
+    val tableDef = q"override def table: _root_.java.lang.String = $tableFormat"
 
     val ta =
       annottees.map(_.tree) match {
@@ -74,24 +77,66 @@ private[types] object TypeProvider {
   }
 
   def schemaImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
-    val schemaString = extractStrings(c, "Missing schema").head
+    val schemaString = extractStrings(c, "Missing schema").head.asInstanceOf[String]
     val schema = BigQueryUtil.parseSchema(schemaString)
     schemaToType(c)(schema, annottees, Nil, Nil)
   }
 
+  // scalastyle:off cyclomatic.complexity
+  /** Macro implementation of `@BigQueryType.fromQuery`. */
   def queryImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
 
     val args = extractStrings(c, "Missing query")
+    val (queryFormat: String) :: tail = args
+    val argsStack = MStack[Any](tail: _*)
     val query = BigQueryPartitionUtil.latestQuery(bigquery, formatString(args))
     val schema = bigquery.getQuerySchema(query)
     val traits = List(tq"${p(c, SType)}.HasQuery")
 
-    val queryDef = q"override def query: _root_.java.lang.String = ${args.head}"
+    val queryDef =
+      q"override def query: _root_.java.lang.String = $queryFormat"
+
+    // One (type, parameter name) pair per argument-consuming specifier. Queries without
+    // arguments are not scanned, so a literal `%` in them cannot pop the empty stack.
+    val formatTerms =
+      if (tail.isEmpty) {
+        Nil
+      } else {
+        FormatSpecifierRegex
+          .findAllMatchIn(queryFormat)
+          .map(_.matched.last)
+          // `%%` and `%n` are literals for `format`: they must not consume an argument
+          .filterNot(conversion => conversion == '%' || conversion == 'n')
+          .map { conversion =>
+            if (argsStack.isEmpty) {
+              c.abort(c.enclosingPosition, "missing argument for format specifier")
+            } else {
+              (TermName(c.freshName("queryArg$")), conversion, argsStack.pop())
+            }
+          }
+          .map {
+            case (termName, 's', _: String) => typeOf[String] -> termName
+            case (termName, 'd', _: Int) => typeOf[Int] -> termName
+            case (termName, 'd', _: Long) => typeOf[Long] -> termName
+            case (termName, 'f', _: Float) => typeOf[Float] -> termName
+            case (termName, 'f', _: Double) => typeOf[Double] -> termName
+            case _ =>
+              c.abort(c.enclosingPosition, "format specifier not supported")
+          }
+          .toList
+      }
+
+    val queryFnDef = if (formatTerms.nonEmpty) {
+      val typesQ = formatTerms.map { case (tpt, termName) => q"$termName: $tpt" }
+      Some(q"def query(..$typesQ): String = $queryFormat.format(..${formatTerms.map(_._2)})")
+    } else {
+      None
+    }
 
     val qa =
       annottees.map(_.tree) match {
-        case (q"class $cName") :: tail =>
+        case q"class $cName" :: _ =>
           List(q"""
             implicit def bqQuery: ${p(c, SType)}.Query[$cName] =
               new ${p(c, SType)}.Query[$cName]{
@@ -101,10 +146,11 @@ private[types] object TypeProvider {
         case _ => Nil
       }
 
-    val overrides = List(queryDef) ++ qa
+    val overrides = queryFnDef.toList ++ (queryDef :: qa)
 
     schemaToType(c)(schema, annottees, traits, overrides)
   }
+  // scalastyle:on cyclomatic.complexity
 
   private def getTableDescription(c: blackbox.Context)(
       cd: c.universe.ClassDef): List[c.universe.Tree] = {
@@ -136,13 +182,12 @@ private[types] object TypeProvider {
     val traits = (if (fields.size <= 22) Seq(fnTrait) else Seq()) ++ defTblDesc
       .map(_ => tq"${p(c, SType)}.HasTableDescription")
     val taggedFields = fields.map {
-      case ValDef(m, n, tpt, rhs) => {
+      case ValDef(m, n, tpt, rhs) =>
         provider.initializeToTable(c)(m, n, tpt)
         c.universe.ValDef(c.universe.Modifiers(m.flags, m.privateWithin, m.annotations),
                           n,
                           tq"$tpt @${typeOf[BigQueryTag]}",
                           rhs)
-      }
     }
     val caseClassTree =
       q"""${caseClass(c)(mods, cName, taggedFields, body)}"""
@@ -273,12 +318,12 @@ private[types] object TypeProvider {
   // scalastyle:on method.length
 
-  /** Extract string from annotation. */
-  private def extractStrings(c: blackbox.Context, errorMessage: String): List[String] = {
+  /** Extract the literal arguments (strings or numbers) from annotation. */
+  private def extractStrings(c: blackbox.Context, errorMessage: String): List[Any] = {
     import c.universe._
 
     def str(tree: c.Tree) = tree match {
-      // "string literal"
-      case Literal(Constant(s: String)) => s
+      // "argument literal"
+      case Literal(Constant(arg @ (_: String | _: Float | _: Double | _: Int | _: Long))) => arg
       // "string literal".stripMargin
       case Select(Literal(Constant(s: String)), TermName("stripMargin")) =>
         s.stripMargin
@@ -296,8 +341,15 @@ private[types] object TypeProvider {
     }
   }
 
-  private def formatString(xs: List[String]): String =
-    if (xs.tail.isEmpty) xs.head else xs.head.format(xs.tail: _*)
+  /** Apply the format arguments (`xs.tail`) to the template (`xs.head`). */
+  private def formatString(xs: List[Any]): String = xs match {
+    // A lone template is returned verbatim, as before this change: `format` would reject or
+    // rewrite literal `%` characters (e.g. `LIKE '%foo'`) in a query that takes no arguments.
+    case (template: String) :: Nil => template
+    case (template: String) :: args => template.format(args: _*)
+    case _ =>
+      throw new IllegalArgumentException(s"Expected a format string and its arguments, got: $xs")
+  }
 
   /** Generate a case class. */
   private def caseClass(c: blackbox.Context)(mods: c.Modifiers,
diff --git a/scio-coders-macros/src/main/scala/com/spotify/scio/coders/Foo.scala b/scio-coders-macros/src/main/scala/com/spotify/scio/coders/Foo.scala
new file mode 100644
index 0000000000..09f01eb576
--- /dev/null
+++ b/scio-coders-macros/src/main/scala/com/spotify/scio/coders/Foo.scala
@@ -0,0 +1,52 @@
+package com.spotify.scio.coders
+
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+import scala.language.experimental.macros
+import scala.reflect.macros._
+
+// NOTE(review): this object and `com.spotify.scio.Foo` look like exploratory scratch code
+// (compile-time println, commented-out block) -- confirm they are meant to be part of this change.
+object FooMacros {
+  def printf(format: String, params: Any*): Unit = macro printfImpl
+
+  def printfImpl(c: whitebox.Context)(format: c.Expr[String],
+                                      params: c.Expr[Any]*): c.Expr[Unit] = {
+    import c.universe._
+    // A non-literal format cannot be inspected here: report it instead of failing with a MatchError
+    val s_format = format.tree match {
+      case Literal(Constant(s: String)) => s
+      case _ => c.abort(c.enclosingPosition, "format must be a string literal")
+    }
+
+    val evals = ListBuffer[Tree]()
+    def precompute(value: Tree, tpe: Type): Ident = {
+      val freshName = TermName(c.freshName("query$"))
+      evals += q"$freshName: $tpe"
+      Ident(freshName)
+    }
+
+    val paramsStack = mutable.Stack[Tree](params.map(_.tree): _*)
+    // Fewer parameters than specifiers is a usage error, not a NoSuchElementException
+    def nextParam(): Tree =
+      if (paramsStack.nonEmpty) paramsStack.pop()
+      else c.abort(c.enclosingPosition, "missing parameter for format specifier")
+
+    val refs = s_format.split("(?<=%[\\w%])|(?=%[\\w%])") map {
+      case "%d" => precompute(nextParam(), typeOf[Int])
+      case "%s" => precompute(nextParam(), typeOf[String])
+      case "%%" => Literal(Constant("%"))
+      case part => Literal(Constant(part))
+    }
+
+    println(q"""
+      def asd(..$evals): String = ""
+      """)
+
+//    val stats = evals ++ refs.map(ref => reify(print(c.Expr[Any](ref).splice)).tree)
+//    val asd = c.Expr[Unit](Block(stats.toList, Literal(Constant(()))))
+//    println(asd.toString())
+
+    c.Expr[Unit](Block(List.empty, Literal(Constant(()))))
+  }
+}
diff --git a/scio-core/src/main/scala/com/spotify/scio/Foo.scala b/scio-core/src/main/scala/com/spotify/scio/Foo.scala
new file mode 100644
index 0000000000..4af5276437
--- /dev/null
+++ b/scio-core/src/main/scala/com/spotify/scio/Foo.scala
@@ -0,0 +1,6 @@
+package com.spotify.scio
+
+object Foo extends App {
+  import com.spotify.scio.coders.FooMacros._
+  printf("hello %s %s!", "asd", "qwe")
+}