
Commit

Refactor
clairemcginty committed Mar 25, 2024
1 parent 60d4840 commit b5064ca
Showing 1 changed file with 35 additions and 40 deletions.
@@ -23,7 +23,7 @@ import org.apache.avro.generic.GenericRecord
 import org.apache.hadoop.conf.Configuration
 import org.apache.parquet.avro.{AvroParquetReader, AvroParquetWriter, AvroReadSupport}
 import org.apache.parquet.filter2.predicate.FilterPredicate
-import org.apache.parquet.hadoop.ParquetInputFormat
+import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetReader, ParquetWriter}
 import org.apache.parquet.io._
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
@@ -41,28 +41,15 @@ trait ParquetTestUtils {
   case class ParquetMagnolifyHelpers[T: ParquetType: ClassTag](records: Iterable[T])
       extends ImplementsFilter[T] {
     override def parquetFilter(filter: FilterPredicate): Iterable[T] = {
-      val configuration = new Configuration()
-      ParquetInputFormat.setFilterPredicate(configuration, filter)
-
-      roundtrip(records, configuration)
-    }
-
-    private def roundtrip(records: Iterable[T], readConfiguration: Configuration): Iterable[T] = {
       val pt = implicitly[ParquetType[T]]
-      val baos = new ByteArrayOutputStream()
-      val writer = pt.writeBuilder(inMemoryOutputFile(baos)).build()
 
-      records.foreach(writer.write)
-      writer.close()
-
-      val reader = pt
-        .readBuilder(inMemoryInputFile(baos.toByteArray))
-        .withConf(readConfiguration)
-        .build()
+      val configuration = new Configuration()
+      ParquetInputFormat.setFilterPredicate(configuration, filter)
 
-      val roundtripped = Iterator.continually(reader.read()).takeWhile(_ != null).toSeq
-      reader.close()
-      roundtripped
+      roundtrip(
+        outputFile => pt.writeBuilder(outputFile).build(),
+        inputFile => pt.readBuilder(inputFile).withConf(configuration).build()
+      )(records)
     }
   }
 
@@ -73,46 +60,54 @@ trait ParquetTestUtils {
       val configuration = new Configuration()
       AvroReadSupport.setRequestedProjection(configuration, projection)
 
-      roundtrip(records, configuration)
+      roundtripAvro(records, configuration)
     }
 
     override def parquetFilter(filter: FilterPredicate): Iterable[T] = {
       val configuration = new Configuration()
       ParquetInputFormat.setFilterPredicate(configuration, filter)
 
-      roundtrip(records, configuration)
+      roundtripAvro(records, configuration)
     }
 
-    private def roundtrip(records: Iterable[T], readConfiguration: Configuration): Iterable[T] = {
+    private def roundtripAvro(
+      records: Iterable[T],
+      readConfiguration: Configuration
+    ): Iterable[T] = {
       records.headOption match {
         case None =>
          records // empty iterable
        case Some(head) =>
          val schema = head.getSchema
 
-          val baos = new ByteArrayOutputStream()
-          val writer = AvroParquetWriter
-            .builder[T](inMemoryOutputFile(baos))
-            .withSchema(schema)
-            .build()
-
-          records.foreach(writer.write)
-          writer.close()
-
-          val reader = AvroParquetReader
-            .builder[T](inMemoryInputFile(baos.toByteArray))
-            .withConf(readConfiguration)
-            .build()
-
-          val roundtripped = Iterator.continually(reader.read()).takeWhile(_ != null).toSeq
-          reader.close()
-          roundtripped
+          roundtrip(
+            outputFile => AvroParquetWriter.builder[T](outputFile).withSchema(schema).build(),
+            inputFile => AvroParquetReader.builder[T](inputFile).withConf(readConfiguration).build()
+          )(records)
       }
     }
   }
 
   // @Todo tensorflow helpers
 
+  private def roundtrip[T](
+    writerFn: OutputFile => ParquetWriter[T],
+    readerFn: InputFile => ParquetReader[T]
+  )(
+    records: Iterable[T]
+  ): Iterable[T] = {
+    val baos = new ByteArrayOutputStream()
+    val writer = writerFn(inMemoryOutputFile(baos))
+
+    records.foreach(writer.write)
+    writer.close()
+
+    val reader = readerFn(inMemoryInputFile(baos.toByteArray))
+    val roundtripped = Iterator.continually(reader.read()).takeWhile(_ != null).toSeq
+    reader.close()
+    roundtripped
+  }
+
   private def inMemoryOutputFile(baos: ByteArrayOutputStream): OutputFile = new OutputFile {
     override def create(blockSizeHint: Long): PositionOutputStream = newPositionOutputStream()
 
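The diff is cut off before the bodies of inMemoryOutputFile and its inMemoryInputFile counterpart. For orientation, a minimal byte-array-backed pair built on Parquet's OutputFile/InputFile interfaces could look like the sketch below. This is an illustration of the general pattern, not the file's actual implementation; DelegatingSeekableInputStream is the base class Parquet ships in org.apache.parquet.io. The refactor above threads such a pair through the shared roundtrip helper so that each format only has to supply writer/reader builder functions.

import org.apache.parquet.io._

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

// Sketch only: an in-memory OutputFile/InputFile pair for tests.
object InMemoryParquetFiles {
  def outputFile(baos: ByteArrayOutputStream): OutputFile = new OutputFile {
    override def create(blockSizeHint: Long): PositionOutputStream = newPositionOutputStream()
    override def createOrOverwrite(blockSizeHint: Long): PositionOutputStream =
      newPositionOutputStream()
    override def supportsBlockSize(): Boolean = false
    override def defaultBlockSize(): Long = 0L

    private def newPositionOutputStream(): PositionOutputStream = new PositionOutputStream {
      private var pos: Long = 0L
      override def getPos: Long = pos
      override def write(b: Int): Unit = {
        pos += 1
        baos.write(b)
      }
    }
  }

  def inputFile(bytes: Array[Byte]): InputFile = new InputFile {
    override def getLength: Long = bytes.length.toLong

    override def newStream(): SeekableInputStream = {
      val in = new ByteArrayInputStream(bytes)
      new DelegatingSeekableInputStream(in) {
        // Position is recoverable from how many bytes remain unread in the buffer
        override def getPos: Long = (bytes.length - in.available()).toLong
        override def seek(newPos: Long): Unit = {
          in.reset() // ByteArrayInputStream's mark defaults to offset 0
          in.skip(newPos)
        }
      }
    }
  }
}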

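Taken together, the refactored helpers let a test compare Parquet predicate pushdown against a plain Scala filter over the same records. A hypothetical usage from code that mixes in ParquetTestUtils; the User record, its values, and the magnolify-parquet ParquetType derivation are assumptions for illustration:

import magnolify.parquet.ParquetType
import org.apache.parquet.filter2.predicate.FilterApi

case class User(name: String, age: Int) // hypothetical test record

// magnolify-parquet derives the Parquet schema and converters for the case class
implicit val userType: ParquetType[User] = ParquetType[User]

val users = (1 to 100).map(i => User(s"user$i", i % 60))

// Pushed down to the reader via ParquetInputFormat.setFilterPredicate
val adults = FilterApi.gtEq(FilterApi.intColumn("age"), Int.box(18))

val filtered = ParquetMagnolifyHelpers(users).parquetFilter(adults)
assert(filtered.toSet == users.filter(_.age >= 18).toSet)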