Skip to content

Commit

Permalink
Refactor test helper method naming
Browse files Browse the repository at this point in the history
  • Loading branch information
clairemcginty committed May 30, 2024
1 parent 4214363 commit 344c3f0
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,29 @@
* under the License.
*/

package com.spotify.scio.testing
package com.spotify.scio.testing.parquet

import com.spotify.parquet.tensorflow.{
TensorflowExampleParquetReader,
TensorflowExampleParquetWriter,
TensorflowExampleReadSupport
}
import magnolify.parquet._
import _root_.magnolify.parquet.ParquetType
import org.apache.avro.Schema
import org.apache.avro.generic.GenericRecord
import org.apache.hadoop.conf.Configuration
import org.apache.parquet.avro.{AvroParquetReader, AvroParquetWriter, AvroReadSupport}
import org.apache.parquet.filter2.predicate.FilterPredicate
import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetReader, ParquetWriter}
import org.apache.parquet.io._
import org.tensorflow.metadata.{v0 => tfmd}
import org.tensorflow.proto.example.Example
import org.tensorflow.metadata.{v0 => tfmd}

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

trait ParquetTestUtils {
case class ParquetMagnolifyHelpers[T: ParquetType](records: Iterable[T]) {
def parquetFilter(filter: FilterPredicate): Iterable[T] = {
object ParquetTestUtils {
case class ParquetMagnolifyHelpers[T: ParquetType] private[testing] (records: Iterable[T]) {
def withFilter(filter: FilterPredicate): Iterable[T] = {
val pt = implicitly[ParquetType[T]]

val configuration = new Configuration()
Expand All @@ -50,15 +50,17 @@ trait ParquetTestUtils {
}
}

case class ParquetAvroHelpers[T <: GenericRecord](records: Iterable[T]) {
def parquetProject(projection: Schema): Iterable[T] = {
private[testing] case class ParquetAvroHelpers[T <: GenericRecord] private[testing] (
records: Iterable[T]
) {
def withProjection(projection: Schema): Iterable[T] = {
val configuration = new Configuration()
AvroReadSupport.setRequestedProjection(configuration, projection)

roundtripAvro(records, configuration)
}

def parquetFilter(filter: FilterPredicate): Iterable[T] = {
def withFilter(filter: FilterPredicate): Iterable[T] = {
val configuration = new Configuration()
ParquetInputFormat.setFilterPredicate(configuration, filter)

Expand All @@ -83,8 +85,8 @@ trait ParquetTestUtils {
}
}

case class ParquetExampleHelpers(records: Iterable[Example]) {
def parquetProject(schema: tfmd.Schema, projection: tfmd.Schema): Iterable[Example] = {
private[testing] case class ParquetExampleHelpers private[testing] (records: Iterable[Example]) {
def withProjection(schema: tfmd.Schema, projection: tfmd.Schema): Iterable[Example] = {
val configuration = new Configuration()
TensorflowExampleReadSupport.setExampleReadSchema(
configuration,
Expand All @@ -98,7 +100,7 @@ trait ParquetTestUtils {
roundtripExample(records, schema, configuration)
}

def parquetFilter(schema: tfmd.Schema, filter: FilterPredicate): Iterable[Example] = {
def withFilter(schema: tfmd.Schema, filter: FilterPredicate): Iterable[Example] = {
val configuration = new Configuration()
TensorflowExampleReadSupport.setExampleReadSchema(
configuration,
Expand All @@ -109,7 +111,7 @@ trait ParquetTestUtils {
roundtripExample(records, schema, configuration)
}

def roundtripExample(
private def roundtripExample(
records: Iterable[Example],
schema: tfmd.Schema,
readConfiguration: Configuration
Expand Down Expand Up @@ -194,16 +196,4 @@ trait ParquetTestUtils {
}
}
}

implicit def toParquetAvroHelpers[T <: GenericRecord](
records: Iterable[T]
): ParquetAvroHelpers[T] = ParquetAvroHelpers(records)

implicit def toParquetMagnolifyHelpers[T: ParquetType](
records: Iterable[T]
): ParquetMagnolifyHelpers[T] = ParquetMagnolifyHelpers(records)

implicit def toParquetExampleHelpers(
records: Iterable[Example]
): ParquetExampleHelpers = ParquetExampleHelpers(records)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2024 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.testing

import com.spotify.scio.testing.parquet.ParquetTestUtils._
import magnolify.parquet.ParquetType
import org.apache.avro.generic.GenericRecord
import org.tensorflow.proto.example.Example

/**
 * Opt-in implicit syntax for the Parquet test helpers defined in `ParquetTestUtils`.
 *
 * The conversions are split into one nested object per record flavour, so a test only
 * brings into scope the enrichment it needs, e.g.
 * `import com.spotify.scio.testing.parquet.avro._`.
 */
package object parquet {

  /** Syntax for collections of Avro records (anything extending `GenericRecord`). */
  object avro {

    /** Wraps `records` in a `ParquetAvroHelpers`. */
    implicit def toParquetAvroHelpers[T <: GenericRecord](
      records: Iterable[T]
    ): ParquetAvroHelpers[T] =
      new ParquetAvroHelpers[T](records)
  }

  /** Syntax for collections of types that have a Magnolify `ParquetType` instance. */
  object types {

    /** Wraps `records` in a `ParquetMagnolifyHelpers`; the `ParquetType[T]` evidence is passed along. */
    implicit def toParquetMagnolifyHelpers[T: ParquetType](
      records: Iterable[T]
    ): ParquetMagnolifyHelpers[T] =
      new ParquetMagnolifyHelpers[T](records)
  }

  /** Syntax for collections of TensorFlow `Example` records. */
  object tensorflow {

    /** Wraps `records` in a `ParquetExampleHelpers`. */
    implicit def toParquetExampleHelpers(
      records: Iterable[Example]
    ): ParquetExampleHelpers =
      new ParquetExampleHelpers(records)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
* under the License.
*/

package com.spotify.scio.testing
package com.spotify.scio.testing.parquet

import com.spotify.scio.avro.TestRecord
import com.spotify.scio.avro.{Account, AccountStatus}
import org.apache.avro.SchemaBuilder
import org.apache.avro.generic.GenericRecordBuilder
import org.apache.parquet.filter2.predicate.FilterApi
Expand All @@ -30,39 +30,38 @@ import scala.jdk.CollectionConverters._

// Minimal single-field record used by the "Case classes" test: instances are built with
// `(1 to 10).map(SomeRecord)` and filtered on the Parquet column named "intField".
case class SomeRecord(intField: Int)

class ParquetTestUtilsTest extends AnyFlatSpec with Matchers with ParquetTestUtils {
class ParquetTestUtilsTest extends AnyFlatSpec with Matchers {

"Avro SpecificRecords" should "be filterable and projectable" in {
import com.spotify.scio.testing.parquet.avro._

val records = (1 to 10).map(i =>
new TestRecord(
i,
i.toLong,
i.toFloat,
i.toDouble,
true,
"hello",
List[CharSequence]("a", "b", "c").asJava
)
Account
.newBuilder()
.setId(i)
.setName(i.toString)
.setAmount(i.toDouble)
.setType(s"Type$i")
.setAccountStatus(AccountStatus.Active)
.build()
)

val transformed = records
.parquetFilter(
FilterApi.gt(FilterApi.intColumn("int_field"), 5.asInstanceOf[java.lang.Integer])
)
.parquetProject(
SchemaBuilder.record("TestRecord").fields().optionalInt("int_field").endRecord()
)
val filter = FilterApi.gt(FilterApi.intColumn("id"), 5.asInstanceOf[java.lang.Integer])
val projection = SchemaBuilder.record("Account").fields().requiredInt("id").endRecord()

transformed.map(_.int_field) should contain theSameElementsAs Seq(6, 7, 8, 9, 10)
val transformed = records withFilter filter withProjection projection
transformed.map(_.getId) should contain theSameElementsAs Seq(6, 7, 8, 9, 10)
transformed.foreach { r =>
r.long_field shouldBe null
r.string_field shouldBe null
r.double_field shouldBe null
r.boolean_field shouldBe null
r.getName shouldBe null
r.getAmount shouldBe 0.0d
r.getAccountStatus shouldBe null
r.getType shouldBe null
}
}

"Avro GenericRecords" should "be filterable and projectable" in {
import com.spotify.scio.testing.parquet.avro._

val recordSchema = SchemaBuilder
.record("TestRecord")
.fields()
Expand All @@ -77,31 +76,32 @@ class ParquetTestUtilsTest extends AnyFlatSpec with Matchers with ParquetTestUti
.build()
)

val transformed = records
.parquetFilter(
FilterApi.gt(FilterApi.intColumn("int_field"), Int.box(5))
)
.parquetProject(
SchemaBuilder.record("Projection").fields().optionalInt("int_field").endRecord()
)
val filter = FilterApi.gt(FilterApi.intColumn("int_field"), 5.asInstanceOf[java.lang.Integer])
val projection =
SchemaBuilder.record("Projection").fields().optionalInt("int_field").endRecord()

transformed.map(_.get("int_field").toString.toInt) should contain theSameElementsAs Seq(6, 7, 8,
9, 10)
transformed.foreach { r =>
r.get("string_field") shouldBe null
}
records withFilter filter withProjection projection should contain theSameElementsAs Seq(6, 7,
8, 9, 10).map(i =>
new GenericRecordBuilder(recordSchema)
.set("int_field", i)
.set("string_field", null)
.build()
)
}

"Case classes" should "be filterable" in {
val records = (1 to 10).map(SomeRecord)
import com.spotify.scio.testing.parquet.types._

val transformed = records
.parquetFilter(FilterApi.gt(FilterApi.intColumn("intField"), Int.box(5)))
val records = (1 to 10).map(SomeRecord)

transformed.map(_.intField) should contain theSameElementsAs Seq(6, 7, 8, 9, 10)
records withFilter (
FilterApi.gt(FilterApi.intColumn("intField"), Int.box(5))
) should contain theSameElementsAs Seq(6, 7, 8, 9, 10).map(SomeRecord)
}

"TfExamples" should "be filterable and projectable" in {
import com.spotify.scio.testing.parquet.tensorflow._

val required = tfmd.ValueCount.newBuilder().setMin(1).setMax(1).build()

val schema = tfmd.Schema
Expand Down Expand Up @@ -157,27 +157,23 @@ class ParquetTestUtilsTest extends AnyFlatSpec with Matchers with ParquetTestUti
.build()
)

val transformed = records
.parquetFilter(
schema,
FilterApi.gt(FilterApi.floatColumn("float_required"), Float.box(5.5f))
)
.parquetProject(
schema,
tfmd.Schema
.newBuilder()
.addFeature(
tfmd.Feature
.newBuilder()
.setName("int64_required")
.setType(tfmd.FeatureType.INT)
.setValueCount(required)
.build()
)
.build()
)

transformed should contain theSameElementsAs (1 to 4).map(i =>
records withFilter (
schema,
FilterApi.gt(FilterApi.floatColumn("float_required"), Float.box(5.5f))
) withProjection (
schema,
tfmd.Schema
.newBuilder()
.addFeature(
tfmd.Feature
.newBuilder()
.setName("int64_required")
.setType(tfmd.FeatureType.INT)
.setValueCount(required)
.build()
)
.build()
) should contain theSameElementsAs (1 to 4).map(i =>
Example
.newBuilder()
.setFeatures(
Expand Down

0 comments on commit 344c3f0

Please sign in to comment.