Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Treat value classes as such #610

Merged
merged 14 commits into from
Oct 14, 2022
114 changes: 66 additions & 48 deletions avro/src/main/scala/magnolify/avro/AvroType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,20 @@ sealed trait AvroType[T] extends Converter[T, GenericRecord, GenericRecord] {
}

object AvroType {
implicit def apply[T: AvroField.Record]: AvroType[T] = AvroType(CaseMapper.identity)

def apply[T](cm: CaseMapper)(implicit f: AvroField.Record[T]): AvroType[T] = {
f.schema(cm) // fail fast on bad annotations
new AvroType[T] {
private val caseMapper: CaseMapper = cm
@transient override lazy val schema: Schema = f.schema(caseMapper)
override def from(v: GenericRecord): T = f.from(v)(caseMapper)
override def to(v: T): GenericRecord = f.to(v)(caseMapper)
implicit def apply[T: AvroField]: AvroType[T] = AvroType(CaseMapper.identity)

def apply[T](cm: CaseMapper)(implicit f: AvroField[T]): AvroType[T] = {
f match {
case r: AvroField.Record[_] =>
r.schema(cm) // fail fast on bad annotations
new AvroType[T] {
private val caseMapper: CaseMapper = cm
@transient override lazy val schema: Schema = r.schema(caseMapper)
override def from(v: GenericRecord): T = r.from(v)(caseMapper)
override def to(v: T): GenericRecord = r.to(v)(caseMapper)
}
case _ =>
throw new IllegalArgumentException(s"AvroType can only be created from Record. Got $f")
}
}
}
Expand Down Expand Up @@ -86,54 +91,67 @@ object AvroField {
override type FromT = From
override type ToT = To
}

sealed trait Record[T] extends Aux[T, GenericRecord, GenericRecord]

// ////////////////////////////////////////////////

type Typeclass[T] = AvroField[T]

def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = new Record[T] {
override protected def buildSchema(cm: CaseMapper): Schema = Schema
.createRecord(
caseClass.typeName.short,
getDoc(caseClass.annotations, caseClass.typeName.full),
caseClass.typeName.owner,
false,
caseClass.parameters.map { p =>
new Schema.Field(
cm.map(p.label),
p.typeclass.schema(cm),
getDoc(p.annotations, s"${caseClass.typeName.full}#${p.label}"),
p.default
.map(d => p.typeclass.makeDefault(d)(cm))
.getOrElse(p.typeclass.fallbackDefault)
def join[T](caseClass: CaseClass[Typeclass, T]): AvroField[T] = {
if (caseClass.isValueClass) {
val p = caseClass.parameters.head
val tc = p.typeclass
new AvroField[T] {
override type FromT = tc.FromT
override type ToT = tc.ToT
override protected def buildSchema(cm: CaseMapper): Schema = tc.buildSchema(cm)
override def from(v: FromT)(cm: CaseMapper): T = caseClass.construct(_ => tc.fromAny(v)(cm))
override def to(v: T)(cm: CaseMapper): ToT = tc.to(p.dereference(v))(cm)
}
} else {
new Record[T] {
override protected def buildSchema(cm: CaseMapper): Schema = Schema
.createRecord(
caseClass.typeName.short,
getDoc(caseClass.annotations, caseClass.typeName.full),
caseClass.typeName.owner,
false,
caseClass.parameters.map { p =>
new Schema.Field(
cm.map(p.label),
p.typeclass.schema(cm),
getDoc(p.annotations, s"${caseClass.typeName.full}#${p.label}"),
p.default
.map(d => p.typeclass.makeDefault(d)(cm))
.getOrElse(p.typeclass.fallbackDefault)
)
}.asJava
)
}.asJava
)

// `JacksonUtils.toJson` expects `Map[String, Any]` for `RECORD` defaults
override def makeDefault(d: T)(cm: CaseMapper): ju.Map[String, Any] = {
caseClass.parameters
.map { p =>
val name = cm.map(p.label)
val value = p.typeclass.makeDefault(p.dereference(d))(cm)
name -> value

// `JacksonUtils.toJson` expects `Map[String, Any]` for `RECORD` defaults
override def makeDefault(d: T)(cm: CaseMapper): ju.Map[String, Any] = {
caseClass.parameters
.map { p =>
val name = cm.map(p.label)
val value = p.typeclass.makeDefault(p.dereference(d))(cm)
name -> value
}
.toMap
.asJava
}
.toMap
.asJava
}

override def from(v: GenericRecord)(cm: CaseMapper): T =
caseClass.construct { p =>
p.typeclass.fromAny(v.get(p.index))(cm)
}
override def from(v: GenericRecord)(cm: CaseMapper): T =
caseClass.construct { p =>
p.typeclass.fromAny(v.get(p.index))(cm)
}

override def to(v: T)(cm: CaseMapper): GenericRecord =
caseClass.parameters.foldLeft(new GenericData.Record(schema(cm))) { (r, p) =>
r.put(p.index, p.typeclass.to(p.dereference(v))(cm))
r
override def to(v: T)(cm: CaseMapper): GenericRecord =
caseClass.parameters.foldLeft(new GenericData.Record(schema(cm))) { (r, p) =>
r.put(p.index, p.typeclass.to(p.dereference(v))(cm))
r
}
}
}
}

private def getDoc(annotations: Seq[Any], name: String): String = {
Expand All @@ -144,9 +162,9 @@ object AvroField {

@implicitNotFound("Cannot derive AvroField for sealed trait")
private sealed trait Dispatchable[T]
def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): Record[T] = ???
def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): AvroField[T] = ???

implicit def gen[T]: Record[T] = macro Magnolia.gen[T]
implicit def gen[T]: AvroField[T] = macro Magnolia.gen[T]

// ////////////////////////////////////////////////

Expand Down
10 changes: 10 additions & 0 deletions avro/src/test/scala/magnolify/avro/test/AvroTypeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ class AvroTypeSuite extends MagnolifySuite {
test[Custom]
}

test("AnyVal") {
implicit val at: AvroType[HasValueClass] = AvroType[HasValueClass]
test[HasValueClass]

assert(at.schema.getField("vc").schema().getType == Schema.Type.STRING)

val record = at(HasValueClass(ValueClass("String")))
assert(record.get("vc") == "String")
}

{
implicit val eqByteArray: Eq[Array[Byte]] = Eq.by(_.toList)
test[AvroTypes]
Expand Down
121 changes: 70 additions & 51 deletions bigquery/src/main/scala/magnolify/bigquery/TableRowType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package magnolify.bigquery

import java.{util => ju}

import com.google.api.services.bigquery.model.{TableFieldSchema, TableRow, TableSchema}
import com.google.common.io.BaseEncoding
import magnolia1._
Expand All @@ -42,17 +41,22 @@ sealed trait TableRowType[T] extends Converter[T, TableRow, TableRow] {
}

object TableRowType {
implicit def apply[T: TableRowField.Record]: TableRowType[T] = TableRowType(CaseMapper.identity)

def apply[T](cm: CaseMapper)(implicit f: TableRowField.Record[T]): TableRowType[T] = {
f.fieldSchema(cm) // fail fast on bad annotations
new TableRowType[T] {
private val caseMapper: CaseMapper = cm
@transient override lazy val schema: TableSchema =
new TableSchema().setFields(f.fieldSchema(caseMapper).getFields)
override val description: String = f.fieldSchema(caseMapper).getDescription
override def from(v: TableRow): T = f.from(v)(caseMapper)
override def to(v: T): TableRow = f.to(v)(caseMapper)
implicit def apply[T: TableRowField]: TableRowType[T] = TableRowType(CaseMapper.identity)

def apply[T](cm: CaseMapper)(implicit f: TableRowField[T]): TableRowType[T] = {
f match {
case pr: TableRowField.Record[_] =>
pr.fieldSchema(cm) // fail fast on bad annotations
new TableRowType[T] {
private val caseMapper: CaseMapper = cm
@transient override lazy val schema: TableSchema =
new TableSchema().setFields(pr.fieldSchema(caseMapper).getFields)
override val description: String = pr.fieldSchema(caseMapper).getDescription
override def from(v: TableRow): T = pr.from(v)(caseMapper)
override def to(v: T): TableRow = pr.to(v)(caseMapper)
}
case _ =>
throw new IllegalArgumentException(s"TableRowType can only be created from Record. Got $f")
}
}
}
Expand Down Expand Up @@ -84,57 +88,72 @@ object TableRowField {
sealed trait Record[T] extends Aux[T, ju.Map[String, AnyRef], TableRow]

// ////////////////////////////////////////////////

type Typeclass[T] = TableRowField[T]

def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = new Record[T] {
override protected def buildSchema(cm: CaseMapper): TableFieldSchema = {
// do not use a scala wrapper in the schema, so clone() works
val fields = new ju.ArrayList[TableFieldSchema](caseClass.parameters.size)
caseClass.parameters.foreach { p =>
val f = p.typeclass
.fieldSchema(cm)
.clone()
.setName(cm.map(p.label))
.setDescription(getDescription(p.annotations, s"${caseClass.typeName.full}#${p.label}"))
fields.add(f)
def join[T](caseClass: CaseClass[Typeclass, T]): TableRowField[T] = {
if (caseClass.isValueClass) {
val p = caseClass.parameters.head
val tc = p.typeclass
new TableRowField[T] {
override type FromT = tc.FromT
override type ToT = tc.ToT
override protected def buildSchema(cm: CaseMapper): TableFieldSchema = tc.buildSchema(cm)
override def from(v: FromT)(cm: CaseMapper): T = caseClass.construct(_ => tc.from(v)(cm))
override def to(v: T)(cm: CaseMapper): ToT = tc.to(p.dereference(v))(cm)
}

new TableFieldSchema()
.setType("STRUCT")
.setMode("REQUIRED")
.setDescription(getDescription(caseClass.annotations, caseClass.typeName.full))
.setFields(fields)
}

override def from(v: ju.Map[String, AnyRef])(cm: CaseMapper): T =
caseClass.construct { p =>
val f = v.get(cm.map(p.label))
if (f == null && p.default.isDefined) {
p.default.get
} else {
p.typeclass.fromAny(f)(cm)
} else {
new Record[T] {
override protected def buildSchema(cm: CaseMapper): TableFieldSchema = {
// do not use a scala wrapper in the schema, so clone() works
val fields = new ju.ArrayList[TableFieldSchema](caseClass.parameters.size)
caseClass.parameters.foreach { p =>
val f = p.typeclass
.fieldSchema(cm)
.clone()
.setName(cm.map(p.label))
.setDescription(
getDescription(p.annotations, s"${caseClass.typeName.full}#${p.label}")
)
fields.add(f)
}

new TableFieldSchema()
.setType("STRUCT")
.setMode("REQUIRED")
.setDescription(getDescription(caseClass.annotations, caseClass.typeName.full))
.setFields(fields)
}
}

override def to(v: T)(cm: CaseMapper): TableRow =
caseClass.parameters.foldLeft(new TableRow) { (tr, p) =>
val f = p.typeclass.to(p.dereference(v))(cm)
if (f == null) tr else tr.set(cm.map(p.label), f)
override def from(v: ju.Map[String, AnyRef])(cm: CaseMapper): T =
caseClass.construct { p =>
val f = v.get(cm.map(p.label))
if (f == null && p.default.isDefined) {
p.default.get
} else {
p.typeclass.fromAny(f)(cm)
}
}

override def to(v: T)(cm: CaseMapper): TableRow =
caseClass.parameters.foldLeft(new TableRow) { (tr, p) =>
val f = p.typeclass.to(p.dereference(v))(cm)
if (f == null) tr else tr.set(cm.map(p.label), f)
}

private def getDescription(annotations: Seq[Any], name: String): String = {
val descs = annotations.collect { case d: description => d.toString }
require(descs.size <= 1, s"More than one @description annotation: $name")
descs.headOption.orNull
}
}

private def getDescription(annotations: Seq[Any], name: String): String = {
val descs = annotations.collect { case d: description => d.toString }
require(descs.size <= 1, s"More than one @description annotation: $name")
descs.headOption.orNull
}
}

@implicitNotFound("Cannot derive TableRowField for sealed trait")
private sealed trait Dispatchable[T]
def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): Record[T] = ???
def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): TableRowField[T] = ???

implicit def gen[T]: Record[T] = macro Magnolia.gen[T]
implicit def gen[T]: TableRowField[T] = macro Magnolia.gen[T]

// ////////////////////////////////////////////////

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ class TableRowTypeSuite extends MagnolifySuite {
test[Custom]
}

test("AnyVal") {
implicit val trt: TableRowType[HasValueClass] = TableRowType[HasValueClass]
test[HasValueClass]

assert(trt.schema.getFields.asScala.head.getType == "STRING")

val record = trt(HasValueClass(ValueClass("String")))
assert(record.get("vc") == "String")
}

{
implicit val arbBigDecimal: Arbitrary[BigDecimal] =
Arbitrary(Gen.chooseNum(0, Int.MaxValue).map(BigDecimal(_)))
Expand Down