(fix)[avro] Use correct DatumWriter constructor (#5371)
RustedBones committed May 17, 2024
1 parent d235906 commit e1687eb
Showing 8 changed files with 142 additions and 12 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/ci.yml
@@ -358,6 +358,39 @@ jobs:
        with:
          token: ${{ secrets.CODECOV_TOKEN }}

  avro-latest:
    name: Test Latest Avro
    if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
    strategy:
      matrix:
        os: [ubuntu-latest]
        scala: [2.13]
        java: [corretto@11]
    runs-on: ${{ matrix.os }}
    steps:
      - name: Checkout current branch (full)
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Setup Java (corretto@11)
        id: setup-java-corretto-11
        if: matrix.java == 'corretto@11'
        uses: actions/setup-java@v4
        with:
          distribution: corretto
          java-version: 11
          cache: sbt

      - name: sbt update
        if: matrix.java == 'corretto@11' && steps.setup-java-corretto-11.outputs.cache-hit == 'false'
        run: sbt +update

      - name: Test
        env:
          JAVA_OPTS: '-Davro.version=1.11.3'
        run: sbt '++ ${{ matrix.scala }}' scio-avro/test

  it-test:
    name: Integration Test
    if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
25 changes: 23 additions & 2 deletions build.sbt
@@ -58,7 +58,7 @@ val googleApiServicesPubsubVersion = s"v1-rev20220904-$googleClientsVersion"
val googleApiServicesStorageVersion = s"v1-rev20240311-$googleClientsVersion"
// beam tested versions
val zetasketchVersion = "0.1.0" // sdks/java/extensions/zetasketch/build.gradle
val avroVersion = "1.8.2" // sdks/java/extensions/avro/build.gradle
val avroVersion = avroCompilerVersion // sdks/java/extensions/avro/build.gradle
val flinkVersion = "1.17.0" // runners/flink/1.17/build.gradle
val hadoopVersion = "3.2.4" // sdks/java/io/parquet/build.gradle
val sparkVersion = "3.5.0" // runners/spark/3/build.gradle
@@ -311,6 +311,22 @@ ThisBuild / githubWorkflowAddedJobs ++= Seq(
    scalas = List(CrossVersion.binaryScalaVersion(scalaDefault)),
    javas = List(javaDefault)
  ),
  WorkflowJob(
    "avro-latest",
    "Test Latest Avro",
    WorkflowStep.CheckoutFull ::
      WorkflowStep.SetupJava(List(javaDefault)) :::
      List(
        WorkflowStep.Sbt(
          List("scio-avro/test"),
          env = Map("JAVA_OPTS" -> "-Davro.version=1.11.3"),
          name = Some("Test")
        )
      ),
    cond = Some(Seq(condSkipPR, condIsMain).mkString(" && ")),
    scalas = List(CrossVersion.binaryScalaVersion(scalaDefault)),
    javas = List(javaDefault)
  ),
  WorkflowJob(
    "it-test",
    "Integration Test",
@@ -685,6 +701,7 @@ lazy val `scio-core` = project
      "org.apache.avro" % "avro" % avroVersion % Test,
      "org.apache.beam" % "beam-runners-direct-java" % beamVersion % Test,
      "org.apache.beam" % "beam-sdks-java-core" % beamVersion % Test,
      "org.codehaus.jackson" % "jackson-mapper-asl" % "1.9.13" % Test,
      "org.hamcrest" % "hamcrest" % hamcrestVersion % Test,
      "org.scalacheck" %% "scalacheck" % scalacheckVersion % Test,
      "org.scalactic" %% "scalactic" % scalatestVersion % Test,
@@ -823,7 +840,11 @@ lazy val `scio-avro` = project
      "org.scalatestplus" %% "scalacheck-1-17" % scalatestplusVersion % Test,
      "org.slf4j" % "slf4j-simple" % slf4jVersion % Test,
      "org.typelevel" %% "cats-core" % catsVersion % Test
    )
    ),
    Test / unmanagedSourceDirectories += {
      val base = (Test / sourceDirectory).value
      if (avroVersion.startsWith("1.8")) base / "scala-avro-legacy" else base / "scala-avro-latest"
    }
  )

lazy val `scio-google-cloud-platform` = project
3 changes: 2 additions & 1 deletion project/plugins.sbt
@@ -18,7 +18,8 @@ addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.5.2")
addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.0.12")
addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.7")

val avroVersion = sys.props.get("avro.version").getOrElse("1.8.2")
libraryDependencies ++= Seq(
"org.apache.avro" % "avro-compiler" % "1.8.2",
"org.apache.avro" % "avro-compiler" % avroVersion,
"org.typelevel" %% "scalac-options" % "0.1.4"
)
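
For context (not part of the commit): together with the ci.yml and build.sbt changes above, this is the path the version override takes. The CI job exports JAVA_OPTS='-Davro.version=1.11.3', the sbt JVM exposes it as a system property, project/plugins.sbt resolves a matching avro-compiler for the meta-build, and build.sbt now derives avroVersion from the compiler version instead of hard-coding 1.8.2. As an illustration only, a hypothetical helper task along the following lines could be added to build.sbt to show which Avro version a given invocation resolved; nothing in the commit defines it.

// Hypothetical debugging aid (not in the commit): log the Avro version the
// build resolved, making the effect of -Davro.version visible. The property
// must be set before sbt starts, e.g. via JAVA_OPTS as the CI job does.
lazy val printAvroVersion = taskKey[Unit]("Logs the Avro version in use")
printAvroVersion := streams.value.log.info(s"Avro version: $avroVersion")

Running sbt printAvroVersion with and without the property set would then report the overridden version versus the 1.8.2 default.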
@@ -64,15 +64,9 @@ private[scio] object GenericRecordDatumFactory extends AvroDatumFactory.GenericD
private[scio] class SpecificRecordDatumFactory[T <: SpecificRecord](recordType: Class[T])
    extends AvroDatumFactory.SpecificDatumFactory[T](recordType) {
  import SpecificRecordDatumFactory._
  private class ScioSpecificDatumReader extends SpecificDatumReader[T](recordType) {
    override def findStringClass(schema: Schema): Class[_] = super.findStringClass(schema) match {
      case cls if cls == classOf[CharSequence] => classOf[String]
      case cls => cls
    }
  }

  override def apply(writer: Schema): DatumWriter[T] = {
    val datumWriter = new SpecificDatumWriter[T]()
    val datumWriter = new SpecificDatumWriter(recordType)
    // avro 1.8 generated code does not add conversions to the data
    if (runtimeAvroVersion.exists(_.startsWith("1.8."))) {
      addLogicalTypeConversions(datumWriter.getData.asInstanceOf[SpecificData], writer)
@@ -81,6 +75,14 @@ private[scio] class SpecificRecordDatumFactory[T <: SpecificRecord](recordType:
    datumWriter
  }

  // TODO move this to companion object
  private class ScioSpecificDatumReader extends SpecificDatumReader[T](recordType) {
    override def findStringClass(schema: Schema): Class[_] = super.findStringClass(schema) match {
      case cls if cls == classOf[CharSequence] => classOf[String]
      case cls => cls
    }
  }

  override def apply(writer: Schema, reader: Schema): DatumReader[T] = {
    val datumReader = new ScioSpecificDatumReader()
    // avro 1.8 generated code does not add conversions to the data
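
For context (not part of the commit message): the fix is the DatumWriter constructor above. Avro's no-argument SpecificDatumWriter[T]() is backed by the global SpecificData.get() model, which has none of the logical-type conversions that Avro 1.9+ generated classes declare on their own model; constructing the writer from recordType lets Avro pick up the class's model, so those conversions come along automatically, and the explicit addLogicalTypeConversions call remains only for Avro 1.8, whose generated code does not register them (as the in-line comment notes). A rough sketch of the difference, assuming Avro 1.9+ behaviour and that the generated LogicalTypesTest class from the new test below is in scope:

import org.apache.avro.LogicalTypes
import org.apache.avro.specific.SpecificDatumWriter

object DatumWriterConstructorSketch {
  // No-arg constructor: the writer falls back to the global SpecificData model,
  // so no conversion is registered for timestamp-millis.
  val bare = new SpecificDatumWriter[LogicalTypesTest]()
  val missing = bare.getData.getConversionFor(LogicalTypes.timestampMillis()) // expected: null

  // Class constructor: the writer uses the model declared by the generated
  // class, which carries its logical-type conversions.
  val fixed = new SpecificDatumWriter(classOf[LogicalTypesTest])
  val found = fixed.getData.getConversionFor(LogicalTypes.timestampMillis()) // expected: a TimestampMillisConversion
}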
@@ -0,0 +1,65 @@
/*
* Copyright 2024 Spotify AB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.spotify.scio.avro

import org.apache.avro.LogicalTypes
import org.apache.avro.data.TimeConversions
import org.apache.avro.specific.{SpecificDatumReader, SpecificDatumWriter}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class AvroDatumFactoryTest extends AnyFlatSpec with Matchers {

  "SpecificRecordDatumFactory" should "load model with conversions" in {
    val factory = new SpecificRecordDatumFactory(classOf[LogicalTypesTest])
    val schema = LogicalTypesTest.getClassSchema

    {
      val writer = factory(schema)
      val data = writer.asInstanceOf[SpecificDatumWriter[LogicalTypesTest]].getData
      // top-level
      val timestamp = data.getConversionFor(LogicalTypes.timestampMillis())
      timestamp shouldBe a[TimeConversions.TimestampMillisConversion]
      // nested-level
      val date = data.getConversionFor(LogicalTypes.date())
      date shouldBe a[TimeConversions.DateConversion]
      val time = data.getConversionFor(LogicalTypes.timeMillis())
      time shouldBe a[TimeConversions.TimeMillisConversion]
    }

    {
      val reader = factory(schema, schema)
      val data = reader.asInstanceOf[SpecificDatumReader[LogicalTypesTest]].getData
      // top-level
      val timestamp = data.getConversionFor(LogicalTypes.timestampMillis())
      timestamp shouldBe a[TimeConversions.TimestampMillisConversion]
      // nested-level
      val date = data.getConversionFor(LogicalTypes.date())
      date shouldBe a[TimeConversions.DateConversion]
      val time = data.getConversionFor(LogicalTypes.timeMillis())
      time shouldBe a[TimeConversions.TimeMillisConversion]
    }
  }

  it should "allow classes with 'conversions' field" in {
    val f = new SpecificRecordDatumFactory(classOf[NameConflict])
    val schema = LogicalTypesTest.getClassSchema
    noException shouldBe thrownBy(f(schema))
    noException shouldBe thrownBy(f(schema, schema))
  }

}
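
The test above exercises the factory's two apply overloads; here is a hedged round-trip sketch that is not part of the commit. It uses GenericRecordDatumFactory so a record can be built inline without generated code; the schema and object name are made up, and the package clause is only there because the factories are private[scio].

package com.spotify.scio.avro

import java.io.ByteArrayOutputStream

import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.generic.{GenericRecord, GenericRecordBuilder}
import org.apache.avro.io.{DecoderFactory, EncoderFactory}

object DatumFactoryRoundTripSketch {
  // Made-up schema for the illustration; not taken from the commit.
  val schema: Schema =
    SchemaBuilder.record("Example").fields().requiredString("name").endRecord()

  def roundTrip(): GenericRecord = {
    val record = new GenericRecordBuilder(schema).set("name", "scio").build()

    // apply(writerSchema) yields the DatumWriter used for encoding
    val out = new ByteArrayOutputStream()
    val encoder = EncoderFactory.get().binaryEncoder(out, null)
    GenericRecordDatumFactory(schema).write(record, encoder)
    encoder.flush()

    // apply(writerSchema, readerSchema) yields the DatumReader used for decoding
    val decoder = DecoderFactory.get().binaryDecoder(out.toByteArray, null)
    GenericRecordDatumFactory(schema, schema).read(null, decoder)
  }
}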
@@ -27,6 +27,7 @@ import org.scalactic.Equality
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import scala.annotation.tailrec
import scala.jdk.CollectionConverters._

final class AvroCoderTest extends AnyFlatSpec with Matchers {
@@ -44,10 +45,17 @@ final class AvroCoderTest extends AnyFlatSpec with Matchers {
  }

  it should "support not Avro's SpecificRecord if a concrete type is not provided" in {
    @tailrec
    def rootCause(e: Throwable): Throwable =
      Option(e.getCause) match {
        case Some(cause) => rootCause(cause)
        case None => e
      }

    val caught = intercept[RuntimeException] {
      Avro.user.asInstanceOf[SpecificRecord] coderShould notFallback()
    }
    val cause = caught.getCause.getCause
    val cause = rootCause(caught)
    cause shouldBe a[AvroRuntimeException]
    cause.getMessage shouldBe "Not a Specific class: interface org.apache.avro.specific.SpecificRecord"
  }
@@ -46,7 +46,7 @@ case class Pair(name: String, size: Int)
case class CaseClassWithGenericRecord(name: String, size: Int, record: GenericRecord)
case class CaseClassWithSpecificRecord(name: String, size: Int, record: TestRecord)

// additional kryo registrar for avro schema null value
// additional kryo registrar for avro 1.8 schema null value
// deserializes null to the singleton instance for schema equality
@KryoRegistrar
class TestKryoRegistrar extends IKryoRegistrar {
