Skip to content

Commit

Permalink
Add better support for parameterised queries in @BigQueryType.fromQuery
Browse files Browse the repository at this point in the history
  • Loading branch information
regadas committed Oct 4, 2018
1 parent 3283fe6 commit b015c88
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 14 deletions.
Expand Up @@ -44,6 +44,10 @@ object BigQueryTypeIT {
"SELECT word, word_count FROM `data-integration-test.partition_a.table_%s`", "$LATEST")
class SqlLatestT

@BigQueryType.fromQuery(
"SELECT word, word_count FROM `data-integration-test.partition_a.table_%s` LIMIT %d", "$LATEST", 1)
class SqlLatestTWithMultiArgs

@BigQueryType.fromTable("data-integration-test:partition_a.table_%s", "$LATEST")
class FromTableLatestT

Expand Down Expand Up @@ -141,6 +145,26 @@ class BigQueryTypeIT extends FlatSpec with Matchers {
BigQueryType[SqlLatestT].query shouldBe Some(sqlLatestQuery)
}

it should "have query fn" in {
"""LegacyLatestT.query("TABLE")""" should compile
"""SqlLatestT.query("TABLE")""" should compile
}

it should "have query fn with only 1 argument" in {
"""LegacyLatestT.query("TABLE", 1)""" shouldNot typeCheck
"""SqlLatestT.query("TABLE", 1)""" shouldNot typeCheck
}

it should "have query fn with multiple arguments" in {
"""SqlLatestTWithMultiArgs.query("TABLE", 1)""" should compile
"""SqlLatestTWithMultiArgs.query(1, "TABLE")""" shouldNot typeCheck
}

it should "format query" in {
LegacyLatestT.query("TABLE") shouldBe legacyLatestQuery.format("TABLE")
SqlLatestT.query("TABLE") shouldBe sqlLatestQuery.format("TABLE")
}

"fromTable" should "work" in {
val bqt = BigQueryType[FromTableT]
bqt.isQuery shouldBe false
Expand Down
Expand Up @@ -216,7 +216,7 @@ object BigQueryType {
* behavior, start the query string with `#legacysql` or `#standardsql`.
* @group annotation
*/
class fromQuery(query: String, args: String*) extends StaticAnnotation {
class fromQuery(query: String, args: Any*) extends StaticAnnotation {
def macroTransform(annottees: Any*): Any = macro TypeProvider.queryImpl
}

Expand Down
Expand Up @@ -36,25 +36,28 @@ import com.spotify.scio.bigquery.{
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._
import scala.collection.mutable.{Map => MMap}
import scala.collection.mutable.{Map => MMap, Stack => MStack}
import scala.reflect.macros._

// scalastyle:off line.size.limit
private[types] object TypeProvider {

private[this] val logger = LoggerFactory.getLogger(this.getClass)
private lazy val bigquery: BigQueryClient = BigQueryClient.defaultInstance()
private[this] val FormatSpecifierRegex =
"(%(\\d+\\$)?([-#+ 0,(\\<]*)?(\\d+)?(\\.\\d+)?([tT])?([a-zA-Z%]))".r

def tableImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._

val args = extractStrings(c, "Missing table specification")
val query = args.head.asInstanceOf[String]
val tableSpec =
BigQueryPartitionUtil.latestTable(bigquery, formatString(args))
val schema = bigquery.getTableSchema(tableSpec)
val traits = List(tq"${p(c, SType)}.HasTable")

val tableDef = q"override def table: _root_.java.lang.String = ${args.head}"
val tableDef = q"override def table: _root_.java.lang.String = $query"

val ta =
annottees.map(_.tree) match {
Expand All @@ -74,24 +77,49 @@ private[types] object TypeProvider {
}

def schemaImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
val schemaString = extractStrings(c, "Missing schema").head
val schemaString = extractStrings(c, "Missing schema").head.asInstanceOf[String]
val schema = BigQueryUtil.parseSchema(schemaString)
schemaToType(c)(schema, annottees, Nil, Nil)
}

// scalastyle:off cyclomatic.complexity
def queryImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._

val args = extractStrings(c, "Missing query")
val (queryFormat: String) :: tail = args
val argsStack = MStack[Any](tail: _*)
val query = BigQueryPartitionUtil.latestQuery(bigquery, formatString(args))
val schema = bigquery.getQuerySchema(query)
val traits = List(tq"${p(c, SType)}.HasQuery")

val queryDef = q"override def query: _root_.java.lang.String = ${args.head}"
val queryDef =
q"override def query: _root_.java.lang.String = $queryFormat"

val formatTerms = FormatSpecifierRegex
.findAllMatchIn(queryFormat)
.map(m => (TermName(c.freshName("queryArg$")), m.matched.last, argsStack.pop()))
.collect {
case (termName, 's', _: String) => typeOf[String] -> termName
case (termName, 'd', _: Int) => typeOf[Int] -> termName
case (termName, 'd', _: Long) => typeOf[Long] -> termName
case (termName, 'f', _: Float) => typeOf[Float] -> termName
case (termName, 'f', _: Double) => typeOf[Double] -> termName
case _ =>
c.abort(c.enclosingPosition, "format specifier not supported")
}
.toList

val queryFnDef = if (formatTerms.nonEmpty) {
val typesQ = formatTerms.map { case (tpt, termName) => q"$termName: $tpt" }
Some(q"def query(..$typesQ): String = $queryFormat.format(..${formatTerms.map(_._2)})")
} else {
None
}

val qa =
annottees.map(_.tree) match {
case (q"class $cName") :: tail =>
case q"class $cName" :: _ =>
List(q"""
implicit def bqQuery: ${p(c, SType)}.Query[$cName] =
new ${p(c, SType)}.Query[$cName]{
Expand All @@ -101,10 +129,11 @@ private[types] object TypeProvider {
case _ =>
Nil
}
val overrides = List(queryDef) ++ qa
val overrides = queryFnDef.getOrElse(EmptyTree) :: queryDef :: qa

schemaToType(c)(schema, annottees, traits, overrides)
}
// scalastyle:on cyclomatic.complexity

private def getTableDescription(c: blackbox.Context)(
cd: c.universe.ClassDef): List[c.universe.Tree] = {
Expand Down Expand Up @@ -136,13 +165,12 @@ private[types] object TypeProvider {
val traits = (if (fields.size <= 22) Seq(fnTrait) else Seq()) ++ defTblDesc
.map(_ => tq"${p(c, SType)}.HasTableDescription")
val taggedFields = fields.map {
case ValDef(m, n, tpt, rhs) => {
case ValDef(m, n, tpt, rhs) =>
provider.initializeToTable(c)(m, n, tpt)
c.universe.ValDef(c.universe.Modifiers(m.flags, m.privateWithin, m.annotations),
n,
tq"$tpt @${typeOf[BigQueryTag]}",
rhs)
}
}
val caseClassTree =
q"""${caseClass(c)(mods, cName, taggedFields, body)}"""
Expand Down Expand Up @@ -273,12 +301,12 @@ private[types] object TypeProvider {
// scalastyle:on method.length

/** Extract string from annotation. */
private def extractStrings(c: blackbox.Context, errorMessage: String): List[String] = {
private def extractStrings(c: blackbox.Context, errorMessage: String): List[Any] = {
import c.universe._

def str(tree: c.Tree) = tree match {
// "string literal"
case Literal(Constant(s: String)) => s
// "argument literal"
case Literal(Constant(arg @ (_: String | _: Float | _: Double | _: Int | _: Long))) => arg
// "string literal".stripMargin
case Select(Literal(Constant(s: String)), TermName("stripMargin")) =>
s.stripMargin
Expand All @@ -296,8 +324,8 @@ private[types] object TypeProvider {
}
}

private def formatString(xs: List[String]): String =
if (xs.tail.isEmpty) xs.head else xs.head.format(xs.tail: _*)
private def formatString(xs: List[Any]): String =
xs.head.asInstanceOf[String].format(xs.tail: _*)

/** Generate a case class. */
private def caseClass(c: blackbox.Context)(mods: c.Modifiers,
Expand Down
@@ -0,0 +1,41 @@
package com.spotify.scio.coders

import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.language.experimental.macros
import scala.reflect.macros._

object FooMacros {
def printf(format: String, params: Any*): Unit = macro printfImpl

def printfImpl(c: whitebox.Context)(format: c.Expr[String],
params: c.Expr[Any]*): c.Expr[Unit] = {
import c.universe._
val Literal(Constant(s_format: String)) = format.tree

val evals = ListBuffer[Tree]()
def precompute(value: Tree, tpe: Type): Ident = {
val freshName = TermName(c.freshName("query$"))
evals += q"$freshName: $tpe"
Ident(freshName)
}

val paramsStack = mutable.Stack[Tree](params.map(_.tree): _*)
val refs = s_format.split("(?<=%[\\w%])|(?=%[\\w%])") map {
case "%d" => precompute(paramsStack.pop, typeOf[Int])
case "%s" => precompute(paramsStack.pop, typeOf[String])
case "%%" => Literal(Constant("%"))
case part => Literal(Constant(part))
}

println(q"""
def asd(..$evals): String = ""
""")

// val stats = evals ++ refs.map(ref => reify(print(c.Expr[Any](ref).splice)).tree)
// val asd = c.Expr[Unit](Block(stats.toList, Literal(Constant(()))))
// println(asd.toString())

c.Expr[Unit](Block(List.empty, Literal(Constant(()))))
}
}
6 changes: 6 additions & 0 deletions scio-core/src/main/scala/com/spotify/scio/Foo.scala
@@ -0,0 +1,6 @@
package com.spotify.scio

object Foo extends App {
import com.spotify.scio.coders.FooMacros._
printf("hello %s %s!", "asd", "qwe")
}

0 comments on commit b015c88

Please sign in to comment.