From 708ce2a25a4ab84bcb826c0013de64618d9877c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sylvain=20Veyri=C3=A9?= Date: Tue, 27 Sep 2022 22:37:44 +0200 Subject: [PATCH 01/12] Neo4J: treat value classes as such --- .../scala/magnolify/neo4j/ValueType.scala | 20 +++++++++++++------ .../magnolify/neo4j/ValueTypeSuite.scala | 16 +++++++++++++++ .../test/scala/magnolify/test/Simple.scala | 3 +++ 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/neo4j/src/main/scala/magnolify/neo4j/ValueType.scala b/neo4j/src/main/scala/magnolify/neo4j/ValueType.scala index 85987332d..f4fe80320 100644 --- a/neo4j/src/main/scala/magnolify/neo4j/ValueType.scala +++ b/neo4j/src/main/scala/magnolify/neo4j/ValueType.scala @@ -61,7 +61,10 @@ object ValueField { override def from(v: Value)(cm: CaseMapper): T = caseClass.construct { p => val field = cm.map(p.label) - try p.typeclass.from(v.get(field))(cm) + try { + val value = if (caseClass.isValueClass) v else v.get(field) + p.typeclass.from(value)(cm) + } catch { case e: ValueException => throw new RuntimeException(s"Failed to decode $field: ${e.getMessage}", e) @@ -69,11 +72,16 @@ object ValueField { } override def to(v: T)(cm: CaseMapper): Value = { - val jmap = caseClass.parameters - .foldLeft(Map.newBuilder[String, AnyRef]) { (m, p) => - m += cm.map(p.label) -> p.typeclass.to(p.dereference(v))(cm) - m - } + val jmap = if (caseClass.isValueClass) { + val p = caseClass.parameters.head + p.typeclass.to(p.dereference(v))(cm) + } + else + caseClass.parameters + .foldLeft(Map.newBuilder[String, AnyRef]) { (m, p) => + m += cm.map(p.label) -> p.typeclass.to(p.dereference(v))(cm) + m + } .result() .asJava Values.value(jmap) diff --git a/neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala b/neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala index 72e51b112..8d9aff917 100644 --- a/neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala +++ b/neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala @@ -22,10 +22,13 @@ import magnolify.test.Simple._ import magnolify.cats.auto._ import magnolify.scalacheck.auto._ import magnolify.shared.CaseMapper +import org.neo4j.driver.Value +import org.neo4j.driver.internal.value.{MapValue, StringValue} import org.scalacheck.{Arbitrary, Prop} import java.net.URI import scala.reflect.ClassTag +import scala.jdk.CollectionConverters._ class ValueTypeSuite extends MagnolifySuite { @@ -76,4 +79,17 @@ class ValueTypeSuite extends MagnolifySuite { assert(!fields.map(record.get).exists(_.isNull)) assert(!record.get("INNERFIELD").get("INNERFIRST").isNull) } + + test("AnyVal") { + val vt: ValueType[HasValueClass] = implicitly + test[HasValueClass] + + val record = vt(HasValueClass(ValueClass("String"))) + assert(record.get("vc").asString() == "String") + + val v: Value = new StringValue("Hello, world") + val a = new MapValue(Map("vc" -> v).asJava) + val c = vt.from(a) + assert(c == HasValueClass(ValueClass("Hello, world"))) + } } diff --git a/test/src/test/scala/magnolify/test/Simple.scala b/test/src/test/scala/magnolify/test/Simple.scala index d3cbf5f0c..a33f7cb85 100644 --- a/test/src/test/scala/magnolify/test/Simple.scala +++ b/test/src/test/scala/magnolify/test/Simple.scala @@ -123,4 +123,7 @@ object Simple { val fields: Seq[String] = Seq("firstField", "secondField", "innerField") val default: LowerCamel = LowerCamel("first", "second", LowerCamelInner("inner.first")) } + + case class ValueClass(str: String) extends AnyVal + case class HasValueClass(vc: ValueClass) } From f6df7de26ae163c9e9a9f5f130c3210a0baa4819 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sylvain=20Veyri=C3=A9?= Date: Wed, 5 Oct 2022 10:57:40 +0200 Subject: [PATCH 02/12] Neo4J: scalafmt --- .../main/scala/magnolify/neo4j/ValueType.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/neo4j/src/main/scala/magnolify/neo4j/ValueType.scala b/neo4j/src/main/scala/magnolify/neo4j/ValueType.scala index f4fe80320..c40a1f634 100644 --- a/neo4j/src/main/scala/magnolify/neo4j/ValueType.scala +++ b/neo4j/src/main/scala/magnolify/neo4j/ValueType.scala @@ -30,6 +30,7 @@ import scala.collection.compat._ trait ValueType[T] extends Converter[T, Value, Value] { def apply(r: Value): T = from(r) + def apply(t: T): Value = to(t) } @@ -39,12 +40,15 @@ object ValueType { def apply[T](cm: CaseMapper)(implicit f: ValueField.Record[T]): ValueType[T] = new ValueType[T] { private val caseMapper: CaseMapper = cm + override def from(v: Value): T = f.from(v)(caseMapper) + override def to(v: T): Value = f.to(v)(caseMapper) } } -sealed trait ValueField[T] extends Serializable { self => +sealed trait ValueField[T] extends Serializable { + self => def from(v: Value)(cm: CaseMapper): T @@ -64,8 +68,7 @@ object ValueField { try { val value = if (caseClass.isValueClass) v else v.get(field) p.typeclass.from(value)(cm) - } - catch { + } catch { case e: ValueException => throw new RuntimeException(s"Failed to decode $field: ${e.getMessage}", e) } @@ -75,15 +78,14 @@ object ValueField { val jmap = if (caseClass.isValueClass) { val p = caseClass.parameters.head p.typeclass.to(p.dereference(v))(cm) - } - else + } else caseClass.parameters .foldLeft(Map.newBuilder[String, AnyRef]) { (m, p) => m += cm.map(p.label) -> p.typeclass.to(p.dereference(v))(cm) m } - .result() - .asJava + .result() + .asJava Values.value(jmap) } } @@ -105,6 +107,7 @@ object ValueField { def apply[U](f: T => U)(g: U => T)(implicit af: ValueField[T]): ValueField[U] = new ValueField[U] { override def from(v: Value)(cm: CaseMapper): U = f(af.from(v)(cm)) + override def to(v: U)(cm: CaseMapper): Value = af.to(g(v))(cm) } } @@ -116,6 +119,7 @@ object ValueField { if (v.isNull) throw new ValueException("Cannot convert null value") f(v) } + override def to(v: T)(cm: CaseMapper): Value = Values.value(v) } From 33ce1528cb97b811739005335c8a9cbd7109ea6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=2E=20Veyri=C3=A9?= Date: Wed, 5 Oct 2022 11:25:59 +0200 Subject: [PATCH 03/12] Update neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala Co-authored-by: Michel Davit --- neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala b/neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala index 8d9aff917..2b8e667b4 100644 --- a/neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala +++ b/neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala @@ -81,7 +81,7 @@ class ValueTypeSuite extends MagnolifySuite { } test("AnyVal") { - val vt: ValueType[HasValueClass] = implicitly + implicit val vt: ValueType[HasValueClass] = ValueType[HasValueClass] test[HasValueClass] val record = vt(HasValueClass(ValueClass("String"))) From b588d73a75e6c9f0b2876fbc7ccc6cbd9fbfef27 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Mon, 10 Oct 2022 14:42:43 +0200 Subject: [PATCH 04/12] Add avro value class support --- .../main/scala/magnolify/avro/AvroType.scala | 117 +++++++++++------- .../magnolify/avro/test/AvroTypeSuite.scala | 10 ++ 2 files changed, 83 insertions(+), 44 deletions(-) diff --git a/avro/src/main/scala/magnolify/avro/AvroType.scala b/avro/src/main/scala/magnolify/avro/AvroType.scala index d6690f9e2..7851bf800 100644 --- a/avro/src/main/scala/magnolify/avro/AvroType.scala +++ b/avro/src/main/scala/magnolify/avro/AvroType.scala @@ -20,13 +20,14 @@ import java.nio.ByteBuffer import java.time._ import java.{util => ju} import magnolia1._ +import magnolify.avro.AvroField.{ProductRecord, ValueClassRecord} import magnolify.shared._ import magnolify.shims.FactoryCompat import org.apache.avro.generic.GenericData.EnumSymbol import org.apache.avro.generic._ import org.apache.avro.{JsonProperties, LogicalType, LogicalTypes, Schema} -import scala.annotation.{implicitNotFound, StaticAnnotation} +import scala.annotation.{StaticAnnotation, implicitNotFound} import scala.collection.concurrent import scala.language.experimental.macros import scala.language.implicitConversions @@ -48,12 +49,17 @@ object AvroType { implicit def apply[T: AvroField.Record]: AvroType[T] = AvroType(CaseMapper.identity) def apply[T](cm: CaseMapper)(implicit f: AvroField.Record[T]): AvroType[T] = { - f.schema(cm) // fail fast on bad annotations - new AvroType[T] { - private val caseMapper: CaseMapper = cm - @transient override lazy val schema: Schema = f.schema(caseMapper) - override def from(v: GenericRecord): T = f.from(v)(caseMapper) - override def to(v: T): GenericRecord = f.to(v)(caseMapper) + f match { + case pr: ProductRecord[_] => + pr.schema(cm) // fail fast on bad annotations + new AvroType[T] { + private val caseMapper: CaseMapper = cm + @transient override lazy val schema: Schema = pr.schema(caseMapper) + override def from(v: GenericRecord): T = pr.from(v)(caseMapper) + override def to(v: T): GenericRecord = pr.to(v)(caseMapper) + } + case _: ValueClassRecord[_] => + throw new IllegalArgumentException("Value classes are not valid AvroType") } } } @@ -87,53 +93,76 @@ object AvroField { override type ToT = To } - sealed trait Record[T] extends Aux[T, GenericRecord, GenericRecord] + sealed trait Record[T] extends AvroField[T] + sealed trait ValueClassRecord[T] extends Record[T] + sealed trait ProductRecord[T] extends Record[T] { + override type FromT = GenericRecord + override type ToT = GenericRecord + } // //////////////////////////////////////////////// type Typeclass[T] = AvroField[T] - def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = new Record[T] { - override protected def buildSchema(cm: CaseMapper): Schema = Schema - .createRecord( - caseClass.typeName.short, - getDoc(caseClass.annotations, caseClass.typeName.full), - caseClass.typeName.owner, - false, - caseClass.parameters.map { p => - new Schema.Field( - cm.map(p.label), - p.typeclass.schema(cm), - getDoc(p.annotations, s"${caseClass.typeName.full}#${p.label}"), - p.default - .map(d => p.typeclass.makeDefault(d)(cm)) - .getOrElse(p.typeclass.fallbackDefault) + def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = { + if (caseClass.isValueClass) { + val p = caseClass.parameters.head + val tc = p.typeclass + new ValueClassRecord[T] { + override type FromT = tc.FromT + override type ToT = tc.ToT + + override protected def buildSchema(cm: CaseMapper): Schema = tc.buildSchema(cm) + override def from(v: FromT)(cm: CaseMapper): T = + caseClass.construct(_ => tc.fromAny(v)(cm)) + override def to(v: T)(cm: CaseMapper): ToT = { + tc.to(p.dereference(v))(cm) + } + } + } else { + new ProductRecord[T] { + override protected def buildSchema(cm: CaseMapper): Schema = Schema + .createRecord( + caseClass.typeName.short, + getDoc(caseClass.annotations, caseClass.typeName.full), + caseClass.typeName.owner, + false, + caseClass.parameters.map { p => + new Schema.Field( + cm.map(p.label), + p.typeclass.schema(cm), + getDoc(p.annotations, s"${caseClass.typeName.full}#${p.label}"), + p.default + .map(d => p.typeclass.makeDefault(d)(cm)) + .getOrElse(p.typeclass.fallbackDefault) + ) + }.asJava ) - }.asJava - ) - - // `JacksonUtils.toJson` expects `Map[String, Any]` for `RECORD` defaults - override def makeDefault(d: T)(cm: CaseMapper): ju.Map[String, Any] = { - caseClass.parameters - .map { p => - val name = cm.map(p.label) - val value = p.typeclass.makeDefault(p.dereference(d))(cm) - name -> value + + // `JacksonUtils.toJson` expects `Map[String, Any]` for `RECORD` defaults + override def makeDefault(d: T)(cm: CaseMapper): ju.Map[String, Any] = { + caseClass.parameters + .map { p => + val name = cm.map(p.label) + val value = p.typeclass.makeDefault(p.dereference(d))(cm) + name -> value + } + .toMap + .asJava } - .toMap - .asJava - } - override def from(v: GenericRecord)(cm: CaseMapper): T = - caseClass.construct { p => - p.typeclass.fromAny(v.get(p.index))(cm) - } + override def from(v: GenericRecord)(cm: CaseMapper): T = + caseClass.construct { p => + p.typeclass.fromAny(v.get(p.index))(cm) + } - override def to(v: T)(cm: CaseMapper): GenericRecord = - caseClass.parameters.foldLeft(new GenericData.Record(schema(cm))) { (r, p) => - r.put(p.index, p.typeclass.to(p.dereference(v))(cm)) - r + override def to(v: T)(cm: CaseMapper): GenericRecord = + caseClass.parameters.foldLeft(new GenericData.Record(schema(cm))) { (r, p) => + r.put(p.index, p.typeclass.to(p.dereference(v))(cm)) + r + } } + } } private def getDoc(annotations: Seq[Any], name: String): String = { diff --git a/avro/src/test/scala/magnolify/avro/test/AvroTypeSuite.scala b/avro/src/test/scala/magnolify/avro/test/AvroTypeSuite.scala index 14c0894b0..af9ef84c6 100644 --- a/avro/src/test/scala/magnolify/avro/test/AvroTypeSuite.scala +++ b/avro/src/test/scala/magnolify/avro/test/AvroTypeSuite.scala @@ -127,6 +127,16 @@ class AvroTypeSuite extends MagnolifySuite { test[Custom] } + test("AnyVal") { + implicit val at: AvroType[HasValueClass] = AvroType[HasValueClass] + test[HasValueClass] + + assert(at.schema.getField("vc").schema().getType == Schema.Type.STRING) + + val record = at(HasValueClass(ValueClass("String"))) + assert(record.get("vc") == "String") + } + { implicit val eqByteArray: Eq[Array[Byte]] = Eq.by(_.toList) test[AvroTypes] From 887f084c6851d75a35068bbb898fd8ec3b8ded10 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Mon, 10 Oct 2022 15:08:37 +0200 Subject: [PATCH 05/12] Add bigquery value-class support --- .../magnolify/bigquery/TableRowType.scala | 119 +++++++++++------- .../bigquery/test/TableRowTypeSuite.scala | 10 ++ 2 files changed, 83 insertions(+), 46 deletions(-) diff --git a/bigquery/src/main/scala/magnolify/bigquery/TableRowType.scala b/bigquery/src/main/scala/magnolify/bigquery/TableRowType.scala index 75aff3295..0b81d974e 100644 --- a/bigquery/src/main/scala/magnolify/bigquery/TableRowType.scala +++ b/bigquery/src/main/scala/magnolify/bigquery/TableRowType.scala @@ -17,10 +17,10 @@ package magnolify.bigquery import java.{util => ju} - import com.google.api.services.bigquery.model.{TableFieldSchema, TableRow, TableSchema} import com.google.common.io.BaseEncoding import magnolia1._ +import magnolify.bigquery.TableRowField.{ProductRecord, ValueClassRecord} import magnolify.shared.{CaseMapper, Converter} import magnolify.shims.FactoryCompat @@ -45,14 +45,19 @@ object TableRowType { implicit def apply[T: TableRowField.Record]: TableRowType[T] = TableRowType(CaseMapper.identity) def apply[T](cm: CaseMapper)(implicit f: TableRowField.Record[T]): TableRowType[T] = { - f.fieldSchema(cm) // fail fast on bad annotations - new TableRowType[T] { - private val caseMapper: CaseMapper = cm - @transient override lazy val schema: TableSchema = - new TableSchema().setFields(f.fieldSchema(caseMapper).getFields) - override val description: String = f.fieldSchema(caseMapper).getDescription - override def from(v: TableRow): T = f.from(v)(caseMapper) - override def to(v: T): TableRow = f.to(v)(caseMapper) + f match { + case pr: ProductRecord[_] => + pr.fieldSchema(cm) // fail fast on bad annotations + new TableRowType[T] { + private val caseMapper: CaseMapper = cm + @transient override lazy val schema: TableSchema = + new TableSchema().setFields(pr.fieldSchema(caseMapper).getFields) + override val description: String = pr.fieldSchema(caseMapper).getDescription + override def from(v: TableRow): T = pr.from(v)(caseMapper) + override def to(v: T): TableRow = pr.to(v)(caseMapper) + } + case _: ValueClassRecord[_] => + throw new IllegalArgumentException("Value classes are not valid TableRowType") } } } @@ -81,52 +86,74 @@ object TableRowField { } sealed trait Generic[T] extends Aux[T, Any, Any] - sealed trait Record[T] extends Aux[T, ju.Map[String, AnyRef], TableRow] + + sealed trait Record[T] extends TableRowField[T] + sealed trait ValueClassRecord[T] extends Record[T] + sealed trait ProductRecord[T] extends Record[T] { + override type FromT = ju.Map[String, AnyRef] + override type ToT = TableRow + } // //////////////////////////////////////////////// type Typeclass[T] = TableRowField[T] - def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = new Record[T] { - override protected def buildSchema(cm: CaseMapper): TableFieldSchema = { - // do not use a scala wrapper in the schema, so clone() works - val fields = new ju.ArrayList[TableFieldSchema](caseClass.parameters.size) - caseClass.parameters.foreach { p => - val f = p.typeclass - .fieldSchema(cm) - .clone() - .setName(cm.map(p.label)) - .setDescription(getDescription(p.annotations, s"${caseClass.typeName.full}#${p.label}")) - fields.add(f) + def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = { + if (caseClass.isValueClass) { + val p = caseClass.parameters.head + val tc = p.typeclass + new ValueClassRecord[T] { + override type FromT = tc.FromT + override type ToT = tc.ToT + override protected def buildSchema(cm: CaseMapper): TableFieldSchema = tc.buildSchema(cm) + override def from(v: FromT)(cm: CaseMapper): T = caseClass.construct(_ => tc.from(v)(cm)) + override def to(v: T)(cm: CaseMapper): ToT = tc.to(p.dereference(v))(cm) } - - new TableFieldSchema() - .setType("STRUCT") - .setMode("REQUIRED") - .setDescription(getDescription(caseClass.annotations, caseClass.typeName.full)) - .setFields(fields) - } - - override def from(v: ju.Map[String, AnyRef])(cm: CaseMapper): T = - caseClass.construct { p => - val f = v.get(cm.map(p.label)) - if (f == null && p.default.isDefined) { - p.default.get - } else { - p.typeclass.fromAny(f)(cm) + } else { + new ProductRecord[T] { + override protected def buildSchema(cm: CaseMapper): TableFieldSchema = { + // do not use a scala wrapper in the schema, so clone() works + val fields = new ju.ArrayList[TableFieldSchema](caseClass.parameters.size) + caseClass.parameters.foreach { p => + val f = p.typeclass + .fieldSchema(cm) + .clone() + .setName(cm.map(p.label)) + .setDescription( + getDescription(p.annotations, s"${caseClass.typeName.full}#${p.label}") + ) + fields.add(f) + } + + new TableFieldSchema() + .setType("STRUCT") + .setMode("REQUIRED") + .setDescription(getDescription(caseClass.annotations, caseClass.typeName.full)) + .setFields(fields) } - } - override def to(v: T)(cm: CaseMapper): TableRow = - caseClass.parameters.foldLeft(new TableRow) { (tr, p) => - val f = p.typeclass.to(p.dereference(v))(cm) - if (f == null) tr else tr.set(cm.map(p.label), f) + override def from(v: ju.Map[String, AnyRef])(cm: CaseMapper): T = + caseClass.construct { p => + val f = v.get(cm.map(p.label)) + if (f == null && p.default.isDefined) { + p.default.get + } else { + p.typeclass.fromAny(f)(cm) + } + } + + override def to(v: T)(cm: CaseMapper): TableRow = + caseClass.parameters.foldLeft(new TableRow) { (tr, p) => + val f = p.typeclass.to(p.dereference(v))(cm) + if (f == null) tr else tr.set(cm.map(p.label), f) + } + + private def getDescription(annotations: Seq[Any], name: String): String = { + val descs = annotations.collect { case d: description => d.toString } + require(descs.size <= 1, s"More than one @description annotation: $name") + descs.headOption.orNull + } } - - private def getDescription(annotations: Seq[Any], name: String): String = { - val descs = annotations.collect { case d: description => d.toString } - require(descs.size <= 1, s"More than one @description annotation: $name") - descs.headOption.orNull } } diff --git a/bigquery/src/test/scala/magnolify/bigquery/test/TableRowTypeSuite.scala b/bigquery/src/test/scala/magnolify/bigquery/test/TableRowTypeSuite.scala index fa89daf4d..beef2b386 100644 --- a/bigquery/src/test/scala/magnolify/bigquery/test/TableRowTypeSuite.scala +++ b/bigquery/src/test/scala/magnolify/bigquery/test/TableRowTypeSuite.scala @@ -87,6 +87,16 @@ class TableRowTypeSuite extends MagnolifySuite { test[Custom] } + test("AnyVal") { + implicit val trt: TableRowType[HasValueClass] = TableRowType[HasValueClass] + test[HasValueClass] + + assert(trt.schema.getFields.asScala.head.getType == "STRING") + + val record = trt(HasValueClass(ValueClass("String"))) + assert(record.get("vc") == "String") + } + { implicit val arbBigDecimal: Arbitrary[BigDecimal] = Arbitrary(Gen.chooseNum(0, Int.MaxValue).map(BigDecimal(_))) From decae9745e59e6973a5cfa15e4160d47a3f29d85 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Mon, 10 Oct 2022 15:09:27 +0200 Subject: [PATCH 06/12] Add bigtable value-class support --- .../magnolify/bigtable/BigtableType.scala | 54 +++++++++++-------- .../bigtable/test/BigtableTypeSuite.scala | 8 +++ 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/bigtable/src/main/scala/magnolify/bigtable/BigtableType.scala b/bigtable/src/main/scala/magnolify/bigtable/BigtableType.scala index 5bd283767..fa7455bc1 100644 --- a/bigtable/src/main/scala/magnolify/bigtable/BigtableType.scala +++ b/bigtable/src/main/scala/magnolify/bigtable/BigtableType.scala @@ -18,7 +18,6 @@ package magnolify.bigtable import java.nio.ByteBuffer import java.util.UUID - import com.google.bigtable.v2.{Cell, Column, Family, Mutation, Row} import com.google.bigtable.v2.Mutation.SetCell import com.google.protobuf.ByteString @@ -26,9 +25,9 @@ import magnolia1._ import magnolify.shared._ import magnolify.shims._ +import java.util import scala.annotation.implicitNotFound import scala.language.experimental.macros - import scala.jdk.CollectionConverters._ import scala.collection.compat._ @@ -136,28 +135,41 @@ object BigtableField { type Typeclass[T] = BigtableField[T] - def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = new Record[T] { - private def key(prefix: String, label: String): String = - if (prefix == null) label else s"$prefix.$label" - - override def get(xs: java.util.List[Column], k: String)(cm: CaseMapper): Value[T] = { - var fallback = true - val r = caseClass.construct { p => - val cq = key(k, cm.map(p.label)) - val v = p.typeclass.get(xs, cq)(cm) - if (v.isSome) { - fallback = false + def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = { + if (caseClass.isValueClass) { + val p = caseClass.parameters.head + val tc = p.typeclass + new Record[T] { + override def get(xs: util.List[Column], k: String)(cm: CaseMapper): Value[T] = + tc.get(xs, k)(cm).map(x => caseClass.construct(_ => x)) + override def put(k: String, v: T)(cm: CaseMapper): Seq[SetCell.Builder] = + p.typeclass.put(k, p.dereference(v))(cm) + } + } else { + new Record[T] { + private def key(prefix: String, label: String): String = + if (prefix == null) label else s"$prefix.$label" + + override def get(xs: java.util.List[Column], k: String)(cm: CaseMapper): Value[T] = { + var fallback = true + val r = caseClass.construct { p => + val cq = key(k, cm.map(p.label)) + val v = p.typeclass.get(xs, cq)(cm) + if (v.isSome) { + fallback = false + } + v.getOrElse(p.default) + } + // result is default if all fields are default + if (fallback) Value.Default(r) else Value.Some(r) } - v.getOrElse(p.default) + + override def put(k: String, v: T)(cm: CaseMapper): Seq[SetCell.Builder] = + caseClass.parameters.flatMap(p => + p.typeclass.put(key(k, cm.map(p.label)), p.dereference(v))(cm) + ) } - // result is default if all fields are default - if (fallback) Value.Default(r) else Value.Some(r) } - - override def put(k: String, v: T)(cm: CaseMapper): Seq[SetCell.Builder] = - caseClass.parameters.flatMap(p => - p.typeclass.put(key(k, cm.map(p.label)), p.dereference(v))(cm) - ) } @implicitNotFound("Cannot derive BigtableField for sealed trait") diff --git a/bigtable/src/test/scala/magnolify/bigtable/test/BigtableTypeSuite.scala b/bigtable/src/test/scala/magnolify/bigtable/test/BigtableTypeSuite.scala index 9367fc8a2..500c8f6ad 100644 --- a/bigtable/src/test/scala/magnolify/bigtable/test/BigtableTypeSuite.scala +++ b/bigtable/src/test/scala/magnolify/bigtable/test/BigtableTypeSuite.scala @@ -80,6 +80,14 @@ class BigtableTypeSuite extends MagnolifySuite { test[Custom] } + test("AnyVal") { + implicit val btt: BigtableType[HasValueClass] = BigtableType[HasValueClass] + test[HasValueClass] + + val records = btt(HasValueClass(ValueClass("String")), "cf") + assert(records.head.getSetCell.getValue.toStringUtf8 == "String") + } + { implicit val arbByteString: Arbitrary[ByteString] = Arbitrary(Gen.alphaNumStr.map(ByteString.copyFromUtf8)) From 33d3776d0c7975f5889755abae2db1ffc1ec7a06 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Mon, 10 Oct 2022 15:26:22 +0200 Subject: [PATCH 07/12] Add datastore value-class support --- .../magnolify/datastore/EntityType.scala | 188 ++++++++++-------- .../datastore/test/EntityTypeSuite.scala | 8 + 2 files changed, 113 insertions(+), 83 deletions(-) diff --git a/datastore/src/main/scala/magnolify/datastore/EntityType.scala b/datastore/src/main/scala/magnolify/datastore/EntityType.scala index 0af6e96f8..e1037e49d 100644 --- a/datastore/src/main/scala/magnolify/datastore/EntityType.scala +++ b/datastore/src/main/scala/magnolify/datastore/EntityType.scala @@ -17,15 +17,15 @@ package magnolify.datastore import java.time.Instant - import com.google.datastore.v1._ import com.google.datastore.v1.client.DatastoreHelper.makeValue import com.google.protobuf.{ByteString, NullValue} import magnolia1._ +import magnolify.datastore.EntityField.{ProductRecord, ValueClassRecord} import magnolify.shared.{CaseMapper, Converter} import magnolify.shims.FactoryCompat -import scala.annotation.{implicitNotFound, StaticAnnotation} +import scala.annotation.{StaticAnnotation, implicitNotFound} import scala.language.experimental.macros import scala.jdk.CollectionConverters._ import scala.collection.compat._ @@ -43,12 +43,16 @@ class excludeFromIndexes(val exclude: Boolean = true) extends StaticAnnotation w object EntityType { implicit def apply[T: EntityField.Record]: EntityType[T] = EntityType(CaseMapper.identity) - def apply[T](cm: CaseMapper)(implicit f: EntityField.Record[T]): EntityType[T] = - new EntityType[T] { - private val caseMapper: CaseMapper = cm - override def from(v: Entity): T = f.fromEntity(v)(caseMapper) - override def to(v: T): Entity.Builder = f.toEntity(v)(caseMapper) - } + def apply[T](cm: CaseMapper)(implicit f: EntityField.Record[T]): EntityType[T] = f match { + case pr: ProductRecord[_] => + new EntityType[T] { + private val caseMapper: CaseMapper = cm + override def from(v: Entity): T = pr.fromEntity(v)(caseMapper) + override def to(v: T): Entity.Builder = pr.toEntity(v)(caseMapper) + } + case _: ValueClassRecord[_] => + throw new IllegalArgumentException("Value classes are not valid EntityType") + } } sealed trait KeyField[T] extends Serializable { self => @@ -91,7 +95,12 @@ sealed trait EntityField[T] extends Serializable { } object EntityField { - sealed trait Record[T] extends EntityField[T] { + + sealed trait Record[T] extends EntityField[T] + + sealed trait ValueClassRecord[T] extends Record[T] + + sealed trait ProductRecord[T] extends Record[T] { def fromEntity(v: Entity)(cm: CaseMapper): T def toEntity(v: T)(cm: CaseMapper): Entity.Builder @@ -104,90 +113,103 @@ object EntityField { type Typeclass[T] = EntityField[T] - def join[T: KeyField](caseClass: CaseClass[Typeclass, T]): Record[T] = new Record[T] { - private val (keyIndex, keyOpt): (Int, Option[key]) = { - val keys = caseClass.parameters - .map(p => p -> getKey(p.annotations, s"${caseClass.typeName.full}#${p.label}")) - .filter(_._2.isDefined) - require( - keys.size <= 1, - s"More than one field with @key annotation: ${caseClass.typeName.full}#[${keys.map(_._1.label).mkString(", ")}]" - ) - keys.headOption match { - case None => (-1, None) - case Some((p, k)) => + def join[T: KeyField](caseClass: CaseClass[Typeclass, T]): Record[T] = { + if (caseClass.isValueClass) { + val p = caseClass.parameters.head + val tc = p.typeclass + new ValueClassRecord[T] { + override lazy val keyField: KeyField[T] = tc.keyField.map(p.dereference) + override def from(v: Value)(cm: CaseMapper): T = caseClass.construct(_ => tc.from(v)(cm)) + override def to(v: T)(cm: CaseMapper): Value.Builder = tc.to(p.dereference(v))(cm) + } + } else { + new ProductRecord[T] { + private val (keyIndex, keyOpt): (Int, Option[key]) = { + val keys = caseClass.parameters + .map(p => p -> getKey(p.annotations, s"${caseClass.typeName.full}#${p.label}")) + .filter(_._2.isDefined) require( - !p.typeclass.keyField.isInstanceOf[KeyField.NotSupported[_]], - s"No KeyField[T] instance: ${caseClass.typeName.full}#${p.label}" + keys.size <= 1, + s"More than one field with @key annotation: ${caseClass.typeName.full}#[${keys.map(_._1.label).mkString(", ")}]" ) - (p.index, k) - } - } - - private val excludeFromIndexes: Array[Boolean] = { - val a = new Array[Boolean](caseClass.parameters.length) - caseClass.parameters.foreach { p => - a(p.index) = getExcludeFromIndexes(p.annotations, s"${caseClass.typeName.full}#${p.label}") - } - a - } - - override val keyField: KeyField[T] = implicitly[KeyField[T]] - - override def fromEntity(v: Entity)(cm: CaseMapper): T = - caseClass.construct { p => - val f = v.getPropertiesOrDefault(cm.map(p.label), null) - if (f == null && p.default.isDefined) { - p.default.get - } else { - p.typeclass.from(f)(cm) + keys.headOption match { + case None => (-1, None) + case Some((p, k)) => + require( + !p.typeclass.keyField.isInstanceOf[KeyField.NotSupported[_]], + s"No KeyField[T] instance: ${caseClass.typeName.full}#${p.label}" + ) + (p.index, k) + } } - } - override def toEntity(v: T)(cm: CaseMapper): Entity.Builder = - caseClass.parameters.foldLeft(Entity.newBuilder()) { (eb, p) => - val value = p.dereference(v) - val vb = p.typeclass.to(value)(cm) - if (vb != null) { - eb.putProperties( - cm.map(p.label), - vb.setExcludeFromIndexes(excludeFromIndexes(p.index)) - .build() - ) + private val excludeFromIndexes: Array[Boolean] = { + val a = new Array[Boolean](caseClass.parameters.length) + caseClass.parameters.foreach { p => + a(p.index) = + getExcludeFromIndexes(p.annotations, s"${caseClass.typeName.full}#${p.label}") + } + a } - if (p.index == keyIndex) { - val k = keyOpt.get - val partitionId = { - val b = PartitionId.newBuilder() - if (k.project != null) { - b.setProjectId(k.project) + + override val keyField: KeyField[T] = implicitly[KeyField[T]] + + override def fromEntity(v: Entity)(cm: CaseMapper): T = + caseClass.construct { p => + val f = v.getPropertiesOrDefault(cm.map(p.label), null) + if (f == null && p.default.isDefined) { + p.default.get + } else { + p.typeclass.from(f)(cm) } - b.setNamespaceId(if (k.namespace != null) k.namespace else caseClass.typeName.owner) } - val path = { - val b = Key.PathElement.newBuilder() - b.setKind(if (k.kind != null) k.kind else caseClass.typeName.short) - p.typeclass.keyField.setKey(b, value) + + override def toEntity(v: T)(cm: CaseMapper): Entity.Builder = + caseClass.parameters.foldLeft(Entity.newBuilder()) { (eb, p) => + val value = p.dereference(v) + val vb = p.typeclass.to(value)(cm) + if (vb != null) { + eb.putProperties( + cm.map(p.label), + vb.setExcludeFromIndexes(excludeFromIndexes(p.index)) + .build() + ) + } + if (p.index == keyIndex) { + val k = keyOpt.get + val partitionId = { + val b = PartitionId.newBuilder() + if (k.project != null) { + b.setProjectId(k.project) + } + b.setNamespaceId(if (k.namespace != null) k.namespace else caseClass.typeName.owner) + } + val path = { + val b = Key.PathElement.newBuilder() + b.setKind(if (k.kind != null) k.kind else caseClass.typeName.short) + p.typeclass.keyField.setKey(b, value) + } + val kb = Key + .newBuilder() + .setPartitionId(partitionId) + .addPath(path) + eb.setKey(kb) + } + eb } - val kb = Key - .newBuilder() - .setPartitionId(partitionId) - .addPath(path) - eb.setKey(kb) - } - eb - } - private def getKey(annotations: Seq[Any], name: String): Option[key] = { - val keys = annotations.collect { case k: key => k } - require(keys.size <= 1, s"More than one @key annotation: $name") - keys.headOption - } + private def getKey(annotations: Seq[Any], name: String): Option[key] = { + val keys = annotations.collect { case k: key => k } + require(keys.size <= 1, s"More than one @key annotation: $name") + keys.headOption + } - private def getExcludeFromIndexes(annotations: Seq[Any], name: String): Boolean = { - val excludes = annotations.collect { case e: excludeFromIndexes => e.exclude } - require(excludes.size <= 1, s"More than one @excludeFromIndexes annotation: $name") - excludes.headOption.getOrElse(false) + private def getExcludeFromIndexes(annotations: Seq[Any], name: String): Boolean = { + val excludes = annotations.collect { case e: excludeFromIndexes => e.exclude } + require(excludes.size <= 1, s"More than one @excludeFromIndexes annotation: $name") + excludes.headOption.getOrElse(false) + } + } } } diff --git a/datastore/src/test/scala/magnolify/datastore/test/EntityTypeSuite.scala b/datastore/src/test/scala/magnolify/datastore/test/EntityTypeSuite.scala index 0b95dc888..2f90bac2f 100644 --- a/datastore/src/test/scala/magnolify/datastore/test/EntityTypeSuite.scala +++ b/datastore/src/test/scala/magnolify/datastore/test/EntityTypeSuite.scala @@ -77,6 +77,14 @@ class EntityTypeSuite extends MagnolifySuite { test[Custom] } + test("AnyVal") { + implicit val et: EntityType[HasValueClass] = EntityType[HasValueClass] + test[HasValueClass] + + val record = et(HasValueClass(ValueClass("String"))) + assert(record.getPropertiesOrThrow("vc").getStringValue == "String") + } + { implicit val arbByteString: Arbitrary[ByteString] = Arbitrary(Gen.alphaNumStr.map(ByteString.copyFromUtf8)) From e1f032b8f570e0c07a64a5b229da67e136c1f6fa Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Mon, 10 Oct 2022 16:24:20 +0200 Subject: [PATCH 08/12] Add funnel value-class support --- .../guava/semiauto/FunnelDerivation.scala | 5 ++- .../guava/test/FunnelDerivationSuite.scala | 31 ++++++++++++++----- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/guava/src/main/scala/magnolify/guava/semiauto/FunnelDerivation.scala b/guava/src/main/scala/magnolify/guava/semiauto/FunnelDerivation.scala index f4982bdfb..9fa16d55e 100644 --- a/guava/src/main/scala/magnolify/guava/semiauto/FunnelDerivation.scala +++ b/guava/src/main/scala/magnolify/guava/semiauto/FunnelDerivation.scala @@ -27,7 +27,10 @@ object FunnelDerivation { def join[T](caseClass: ReadOnlyCaseClass[Typeclass, T]): Typeclass[T] = new Funnel[T] { override def funnel(from: T, into: PrimitiveSink): Unit = - if (caseClass.parameters.isEmpty) { + if (caseClass.isValueClass) { + val p = caseClass.parameters.head + p.typeclass.funnel(p.dereference(from), into) + } else if (caseClass.parameters.isEmpty) { into.putString(caseClass.typeName.short, Charsets.UTF_8) } else { caseClass.parameters.foreach { p => diff --git a/guava/src/test/scala/magnolify/guava/test/FunnelDerivationSuite.scala b/guava/src/test/scala/magnolify/guava/test/FunnelDerivationSuite.scala index f204d813c..5bb01b3b5 100644 --- a/guava/src/test/scala/magnolify/guava/test/FunnelDerivationSuite.scala +++ b/guava/src/test/scala/magnolify/guava/test/FunnelDerivationSuite.scala @@ -16,12 +16,11 @@ package magnolify.guava.test -import java.io.{ByteArrayOutputStream, ObjectOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} import java.net.URI import java.nio.ByteBuffer -import java.nio.charset.Charset +import java.nio.charset.{Charset, StandardCharsets} import java.time.Duration - import com.google.common.hash.{Funnel, PrimitiveSink} import magnolify.guava.auto._ import magnolify.scalacheck.auto._ @@ -74,6 +73,19 @@ class FunnelDerivationSuite extends MagnolifySuite { test[Custom] } + test("AnyVal") { + implicit val f: Funnel[HasValueClass] = FunnelDerivation[HasValueClass] + test[HasValueClass] + + val sink = new BytesSink() + f.funnel(HasValueClass(ValueClass("String")), sink) + + val ois = new ObjectInputStream(new ByteArrayInputStream(sink.toBytes)) + assert(ois.readInt() == 0) + assert(ois.readUTF() == "String") + assert(ois.available() == 0) + } + test[Node] test[GNode[Int]] test[Shape] @@ -98,16 +110,16 @@ class BytesSink extends PrimitiveSink { } override def putBytes(bytes: Array[Byte]): PrimitiveSink = { - baos.write(bytes) + oos.write(bytes) this } override def putBytes(bytes: Array[Byte], off: Int, len: Int): PrimitiveSink = { - baos.write(bytes, off, len) + oos.write(bytes, off, len) this } override def putBytes(bytes: ByteBuffer): PrimitiveSink = { - baos.write(bytes.array(), bytes.position(), bytes.limit()) + oos.write(bytes.array(), bytes.position(), bytes.limit()) this } @@ -151,6 +163,9 @@ class BytesSink extends PrimitiveSink { this } - override def putString(charSequence: CharSequence, charset: Charset): PrimitiveSink = - putBytes(charset.encode(charSequence.toString)) + override def putString(charSequence: CharSequence, charset: Charset): PrimitiveSink = { + require(charset == StandardCharsets.UTF_8) + oos.writeUTF(charSequence.toString) + this + } } From 599801b4fddd1f2805ecc30a898a1e014459eb0e Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Mon, 10 Oct 2022 17:00:38 +0200 Subject: [PATCH 09/12] Add parquet value-class support --- .../scala/magnolify/parquet/ParquetType.scala | 102 +++++++++++------- .../parquet/test/ParquetTypeSuite.scala | 14 ++- 2 files changed, 79 insertions(+), 37 deletions(-) diff --git a/parquet/src/main/scala/magnolify/parquet/ParquetType.scala b/parquet/src/main/scala/magnolify/parquet/ParquetType.scala index 152c1f1c6..7281b8dee 100644 --- a/parquet/src/main/scala/magnolify/parquet/ParquetType.scala +++ b/parquet/src/main/scala/magnolify/parquet/ParquetType.scala @@ -225,52 +225,82 @@ sealed trait ParquetField[T] extends Serializable { object ParquetField { type Typeclass[T] = ParquetField[T] - sealed trait Record[T] extends ParquetField[T] { + sealed trait Record[T] extends ParquetField[T] + sealed trait ValueClassRecord[T] extends Record[T] + sealed trait ProductRecord[T] extends Record[T] { override protected val isGroup: Boolean = true override protected def isEmpty(v: T): Boolean = false } - def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = new Record[T] { - override def buildSchema(cm: CaseMapper): Type = - caseClass.parameters - .foldLeft(Types.requiredGroup()) { (g, p) => - g.addField(Schema.rename(p.typeclass.schema(cm), cm.map(p.label))) - } - .named(caseClass.typeName.full) - - override val hasAvroArray: Boolean = caseClass.parameters.exists(_.typeclass.hasAvroArray) - - override def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit = { - caseClass.parameters.foreach { p => - val x = p.dereference(v) - if (!p.typeclass.isEmpty(x)) { - val name = cm.map(p.label) - c.startField(name, p.index) - p.typeclass.writeGroup(c, x)(cm) - c.endField(name, p.index) + def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = { + if (caseClass.isValueClass) { + val p = caseClass.parameters.head + val tc = p.typeclass + new ValueClassRecord[T] { + override protected def buildSchema(cm: CaseMapper): Type = tc.buildSchema(cm) + override protected def isEmpty(v: T): Boolean = tc.isEmpty(p.dereference(v)) + override def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit = + tc.writeGroup(c, p.dereference(v))(cm) + override def newConverter: TypeConverter[T] = { + val buffered = tc + .newConverter + .asInstanceOf[TypeConverter.Buffered[p.PType]] + new TypeConverter.Delegate[p.PType, T](buffered) { + override def get: T = inner.get(b => caseClass.construct(_ => b.head)) + } } } - } - - override def newConverter: TypeConverter[T] = - new GroupConverter with TypeConverter.Buffered[T] { - private val fieldConverters = caseClass.parameters.map(_.typeclass.newConverter) - override def isPrimitive: Boolean = false - override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex) - override def start(): Unit = () - override def end(): Unit = { - val value = caseClass.construct { p => - try { - fieldConverters(p.index).get - } catch { - case e: IllegalArgumentException => - val field = s"${caseClass.typeName.full}#${p.label}" - throw new ParquetDecodingException(s"Failed to decode $field: ${e.getMessage}", e) + } else { + new ProductRecord[T] { + override def buildSchema(cm: CaseMapper): Type = + caseClass.parameters + .foldLeft(Types.requiredGroup()) { (g, p) => + g.addField(Schema.rename(p.typeclass.schema(cm), cm.map(p.label))) + } + .named(caseClass.typeName.full) + + override val hasAvroArray: Boolean = caseClass.parameters.exists(_.typeclass.hasAvroArray) + + override def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit = { + caseClass.parameters.foreach { p => + val x = p.dereference(v) + if (!p.typeclass.isEmpty(x)) { + val name = cm.map(p.label) + c.startField(name, p.index) + p.typeclass.writeGroup(c, x)(cm) + c.endField(name, p.index) } } - addValue(value) } + + override def newConverter: TypeConverter[T] = + new GroupConverter with TypeConverter.Buffered[T] { + private val fieldConverters = caseClass.parameters.map(_.typeclass.newConverter) + + override def isPrimitive: Boolean = false + + override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex) + + override def start(): Unit = () + + override def end(): Unit = { + val value = caseClass.construct { p => + try { + fieldConverters(p.index).get + } catch { + case e: IllegalArgumentException => + val field = s"${caseClass.typeName.full}#${p.label}" + throw new ParquetDecodingException( + s"Failed to decode $field: ${e.getMessage}", + e + ) + } + } + addValue(value) + } + } } + } } @implicitNotFound("Cannot derive ParquetType for sealed trait") diff --git a/parquet/src/test/scala/magnolify/parquet/test/ParquetTypeSuite.scala b/parquet/src/test/scala/magnolify/parquet/test/ParquetTypeSuite.scala index 94ae637b1..e7ae70eb0 100644 --- a/parquet/src/test/scala/magnolify/parquet/test/ParquetTypeSuite.scala +++ b/parquet/src/test/scala/magnolify/parquet/test/ParquetTypeSuite.scala @@ -20,7 +20,6 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.net.URI import java.time._ import java.util.UUID - import cats._ import magnolify.cats.auto._ import magnolify.parquet._ @@ -30,7 +29,9 @@ import magnolify.shared.CaseMapper import magnolify.test.Simple._ import magnolify.test.Time._ import magnolify.test._ +import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.io._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.scalacheck._ import scala.reflect.ClassTag @@ -88,6 +89,17 @@ class ParquetTypeSuite extends MagnolifySuite { test[Custom] } + test("AnyVal") { + implicit val pt: ParquetType[HasValueClass] = ParquetType[HasValueClass] + test[HasValueClass] + + val schema = pt.schema + val index = schema.getFieldIndex("vc") + val field = schema.getFields.get(index) + assert(field.isPrimitive) + assert(field.asPrimitiveType().getPrimitiveTypeName == PrimitiveTypeName.BINARY) + } + { implicit val eqByteArray: Eq[Array[Byte]] = Eq.by(_.toList) test[ParquetTypes] From 182a489560286f310b06e5f0a19350838cbf73a1 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Mon, 10 Oct 2022 17:22:52 +0200 Subject: [PATCH 10/12] Add protobuf value-class support --- .../magnolify/protobuf/ProtobufType.scala | 203 ++++++++++-------- .../protobuf/test/ProtobufTypeSuite.scala | 6 + 2 files changed, 120 insertions(+), 89 deletions(-) diff --git a/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala b/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala index 76dc0ad95..9f0ed6eaa 100644 --- a/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala +++ b/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala @@ -68,31 +68,35 @@ object ProtobufType { f: ProtobufField.Record[T], ct: ClassTag[MsgT], po: ProtobufOption - ): ProtobufType[T, MsgT] = - new ProtobufType[T, MsgT] { - { - val descriptor = ct.runtimeClass - .getMethod("getDescriptor") - .invoke(null) - .asInstanceOf[Descriptor] - if (f.hasOptional) { - po.check(f, descriptor.getFile.getSyntax) + ): ProtobufType[T, MsgT] = f match { + case pr: ProtobufField.ProductRecord[_] => + new ProtobufType[T, MsgT] { + { + val descriptor = ct.runtimeClass + .getMethod("getDescriptor") + .invoke(null) + .asInstanceOf[Descriptor] + if (pr.hasOptional) { + po.check(pr, descriptor.getFile.getSyntax) + } + pr.checkDefaults(descriptor)(cm) } - f.checkDefaults(descriptor)(cm) - } - @transient private var _newBuilder: Method = _ - private def newBuilder: Message.Builder = { - if (_newBuilder == null) { - _newBuilder = ct.runtimeClass.getMethod("newBuilder") + @transient private var _newBuilder: Method = _ + private def newBuilder: Message.Builder = { + if (_newBuilder == null) { + _newBuilder = ct.runtimeClass.getMethod("newBuilder") + } + _newBuilder.invoke(null).asInstanceOf[Message.Builder] } - _newBuilder.invoke(null).asInstanceOf[Message.Builder] - } - private val caseMapper: CaseMapper = cm - override def from(v: MsgT): T = f.from(v)(caseMapper) - override def to(v: T): MsgT = f.to(v, newBuilder)(caseMapper).asInstanceOf[MsgT] - } + private val caseMapper: CaseMapper = cm + override def from(v: MsgT): T = pr.from(v)(caseMapper) + override def to(v: T): MsgT = pr.to(v, newBuilder)(caseMapper).asInstanceOf[MsgT] + } + case _: ProtobufField.ValueClassRecord[_] => + throw new IllegalArgumentException("Value classes are not valid ProtobufType") + } } sealed trait ProtobufField[T] extends Serializable { @@ -115,7 +119,11 @@ object ProtobufField { override type ToT = To } - sealed trait Record[T] extends Aux[T, Message, Message] { + sealed trait Record[T] extends ProtobufField[T] + sealed trait ValueClassRecord[T] extends Record[T] + sealed trait ProductRecord[T] extends Record[T] { + override type FromT = Message + override type ToT = Message override val default: Option[T] = None } @@ -123,79 +131,96 @@ object ProtobufField { type Typeclass[T] = ProtobufField[T] - def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = new Record[T] { - // One Record[T] instance may be used for multiple Message types - @transient private lazy val fieldsCache: concurrent.Map[String, Array[FieldDescriptor]] = - concurrent.TrieMap.empty - - private def getFields(descriptor: Descriptor)(cm: CaseMapper): Array[FieldDescriptor] = - fieldsCache.getOrElseUpdate( - descriptor.getFullName, { - val fields = new Array[FieldDescriptor](caseClass.parameters.size) - caseClass.parameters.foreach(p => - fields(p.index) = descriptor.findFieldByName(cm.map(p.label)) - ) - fields - } - ) - - override val hasOptional: Boolean = caseClass.parameters.exists(_.typeclass.hasOptional) - - override def checkDefaults(descriptor: Descriptor)(cm: CaseMapper): Unit = { - val syntax = descriptor.getFile.getSyntax - val fields = getFields(descriptor)(cm) - caseClass.parameters.foreach { p => - val field = fields(p.index) - val protoDefault = if (syntax == Syntax.PROTO2 && field.hasDefaultValue) { - Some(p.typeclass.fromAny(field.getDefaultValue)(cm)) - } else { - p.typeclass.default - } - p.default.foreach { d => - require( - protoDefault.contains(d), - s"Default mismatch ${caseClass.typeName.full}#${p.label}: $d != ${protoDefault.orNull}" - ) - } - if (field.getType == FieldDescriptor.Type.MESSAGE) { - p.typeclass.checkDefaults(field.getMessageType)(cm) - } + def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = { + if (caseClass.isValueClass) { + val p = caseClass.parameters.head + val tc = p.typeclass + new ValueClassRecord[T] { + override type FromT = tc.FromT + override type ToT = tc.ToT + override val hasOptional: Boolean = tc.hasOptional + override val default: Option[T] = tc.default.map(x => caseClass.construct(_ => x)) + override def from(v: FromT)(cm: CaseMapper): T = caseClass.construct(_ => tc.from(v)(cm)) + override def to(v: T, b: Message.Builder)(cm: CaseMapper): ToT = + tc.to(p.dereference(v), b)(cm) } - } - override def from(v: Message)(cm: CaseMapper): T = { - val descriptor = v.getDescriptorForType - val syntax = descriptor.getFile.getSyntax - val fields = getFields(descriptor)(cm) - - caseClass.construct { p => - val field = fields(p.index) - // hasField behaves correctly on PROTO2 optional fields - val value = if (syntax == Syntax.PROTO2 && field.isOptional && !v.hasField(field)) { - null - } else { - v.getField(field) + } else { + new ProductRecord[T] { + // One Record[T] instance may be used for multiple Message types + @transient private lazy val fieldsCache: concurrent.Map[String, Array[FieldDescriptor]] = + concurrent.TrieMap.empty + + private def getFields(descriptor: Descriptor)(cm: CaseMapper): Array[FieldDescriptor] = + fieldsCache.getOrElseUpdate( + descriptor.getFullName, { + val fields = new Array[FieldDescriptor](caseClass.parameters.size) + caseClass.parameters.foreach(p => + fields(p.index) = descriptor.findFieldByName(cm.map(p.label)) + ) + fields + } + ) + + override val hasOptional: Boolean = caseClass.parameters.exists(_.typeclass.hasOptional) + + override def checkDefaults(descriptor: Descriptor)(cm: CaseMapper): Unit = { + val syntax = descriptor.getFile.getSyntax + val fields = getFields(descriptor)(cm) + caseClass.parameters.foreach { p => + val field = fields(p.index) + val protoDefault = if (syntax == Syntax.PROTO2 && field.hasDefaultValue) { + Some(p.typeclass.fromAny(field.getDefaultValue)(cm)) + } else { + p.typeclass.default + } + p.default.foreach { d => + require( + protoDefault.contains(d), + s"Default mismatch ${caseClass.typeName.full}#${p.label}: $d != ${protoDefault.orNull}" + ) + } + if (field.getType == FieldDescriptor.Type.MESSAGE) { + p.typeclass.checkDefaults(field.getMessageType)(cm) + } + } } - p.typeclass.fromAny(value)(cm) - } - } - override def to(v: T, bu: Message.Builder)(cm: CaseMapper): Message = { - val fields = getFields(bu.getDescriptorForType)(cm) - - caseClass.parameters - .foldLeft(bu) { (b, p) => - val field = fields(p.index) - val value = if (field.getType == FieldDescriptor.Type.MESSAGE) { - // nested records - p.typeclass.to(p.dereference(v), b.newBuilderForField(field))(cm) - } else { - // non-nested - p.typeclass.to(p.dereference(v), null)(cm) + override def from(v: Message)(cm: CaseMapper): T = { + val descriptor = v.getDescriptorForType + val syntax = descriptor.getFile.getSyntax + val fields = getFields(descriptor)(cm) + + caseClass.construct { p => + val field = fields(p.index) + // hasField behaves correctly on PROTO2 optional fields + val value = if (syntax == Syntax.PROTO2 && field.isOptional && !v.hasField(field)) { + null + } else { + v.getField(field) + } + p.typeclass.fromAny(value)(cm) } - if (value == null) b else b.setField(field, value) } - .build() + + override def to(v: T, bu: Message.Builder)(cm: CaseMapper): Message = { + val fields = getFields(bu.getDescriptorForType)(cm) + + caseClass.parameters + .foldLeft(bu) { (b, p) => + val field = fields(p.index) + val value = if (field.getType == FieldDescriptor.Type.MESSAGE) { + // nested records + p.typeclass.to(p.dereference(v), b.newBuilderForField(field))(cm) + } else { + // non-nested + p.typeclass.to(p.dereference(v), null)(cm) + } + if (value == null) b else b.setField(field, value) + } + .build() + } + } } } diff --git a/protobuf/src/test/scala/magnolify/protobuf/test/ProtobufTypeSuite.scala b/protobuf/src/test/scala/magnolify/protobuf/test/ProtobufTypeSuite.scala index c1e4b7641..b2363fde5 100644 --- a/protobuf/src/test/scala/magnolify/protobuf/test/ProtobufTypeSuite.scala +++ b/protobuf/src/test/scala/magnolify/protobuf/test/ProtobufTypeSuite.scala @@ -95,6 +95,10 @@ class ProtobufTypeSuite extends BaseProtobufTypeSuite { test[Collections, CollectionP3] test[MoreCollections, MoreCollectionP3] } + + test("AnyVal") { + test[ProtoHasValueClass, IntegersP2] + } } // Workaround for "Method too large: magnolify/protobuf/test/ProtobufTypeSuite. ()V" @@ -219,6 +223,8 @@ object Proto3Enums { ProtobufField.enum[ADT.Color, EnumsP3.ScalaEnums] } +case class ProtoValueClass(value: Long) extends AnyVal +case class ProtoHasValueClass(i: Int, l: ProtoValueClass) case class UnsafeByte(i: Byte, l: Long) case class UnsafeChar(i: Char, l: Long) case class UnsafeShort(i: Short, l: Long) From 598e90234434f8990b85043c60ba3f300c1cc915 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Mon, 10 Oct 2022 17:33:30 +0200 Subject: [PATCH 11/12] Ad tensorflow value-class support --- .../magnolify/tensorflow/ExampleType.scala | 109 ++++++++++-------- .../tensorflow/test/ExampleTypeSuite.scala | 14 +++ 2 files changed, 76 insertions(+), 47 deletions(-) diff --git a/tensorflow/src/main/scala/magnolify/tensorflow/ExampleType.scala b/tensorflow/src/main/scala/magnolify/tensorflow/ExampleType.scala index c0ed9bfec..87f29bf8d 100644 --- a/tensorflow/src/main/scala/magnolify/tensorflow/ExampleType.scala +++ b/tensorflow/src/main/scala/magnolify/tensorflow/ExampleType.scala @@ -66,7 +66,7 @@ sealed trait ExampleField[T] extends Serializable { } object ExampleField { - trait Primitive[T] extends ExampleField[T] { + sealed trait Primitive[T] extends ExampleField[T] { type ValueT def fromFeature(v: Feature): ju.List[T] def toFeature(v: Iterable[T]): Feature @@ -91,60 +91,75 @@ object ExampleField { def featureSchema(cm: CaseMapper): FeatureSchema } - trait Record[T] extends ExampleField[T] + sealed trait Record[T] extends ExampleField[T] // //////////////////////////////////////////////// type Typeclass[T] = ExampleField[T] - def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = new Record[T] { - private def key(prefix: String, label: String): String = - if (prefix == null) label else s"$prefix.$label" - - override def get(f: Features, k: String)(cm: CaseMapper): Value[T] = { - var fallback = true - val r = caseClass.construct { p => - val fieldKey = key(k, cm.map(p.label)) - val fieldValue = p.typeclass.get(f, fieldKey)(cm) - if (fieldValue.isSome) { - fallback = false - } - fieldValue.getOrElse(p.default) - } - // result is default if all fields are default - if (fallback) Value.Default(r) else Value.Some(r) - } - - override def put(f: Features.Builder, k: String, v: T)(cm: CaseMapper): Features.Builder = - caseClass.parameters.foldLeft(f) { (f, p) => - val fieldKey = key(k, cm.map(p.label)) - val fieldValue = p.dereference(v) - p.typeclass.put(f, fieldKey, fieldValue)(cm) - f + def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = { + if (caseClass.isValueClass) { + val p = caseClass.parameters.head + val tc = p.typeclass + new Record[T] { + override protected def buildSchema(cm: CaseMapper): Schema = + tc.buildSchema(cm) + override def get(f: Features, k: String)(cm: CaseMapper): Value[T] = + tc.get(f, k)(cm).map(x => caseClass.construct(_ => x)) + override def put(f: Features.Builder, k: String, v: T)(cm: CaseMapper): Features.Builder = + tc.put(f, k, p.dereference(v))(cm) } + } else { + new Record[T] { + private def key(prefix: String, label: String): String = + if (prefix == null) label else s"$prefix.$label" + + override def get(f: Features, k: String)(cm: CaseMapper): Value[T] = { + var fallback = true + val r = caseClass.construct { p => + val fieldKey = key(k, cm.map(p.label)) + val fieldValue = p.typeclass.get(f, fieldKey)(cm) + if (fieldValue.isSome) { + fallback = false + } + fieldValue.getOrElse(p.default) + } + // result is default if all fields are default + if (fallback) Value.Default(r) else Value.Some(r) + } - override protected def buildSchema(cm: CaseMapper): Schema = { - val sb = Schema.newBuilder() - getDoc(caseClass.annotations, caseClass.typeName.full).foreach(sb.setAnnotation) - caseClass.parameters.foldLeft(sb) { (b, p) => - val fieldNane = cm.map(p.label) - val fieldSchema = p.typeclass.schema(cm) - val fieldFeatures = fieldSchema.getFeatureList.asScala.map { f => - val fb = f.toBuilder - // if schema does not have a name (eg. primitive), use the fieldNane - // otherwise prepend to the feature name (eg. nested records) - val fieldKey = if (f.hasName) key(fieldNane, f.getName) else fieldNane - fb.setName(fieldKey) - // if field already has a doc, keep it - // otherwise use the parameter annotation - val fieldDoc = getDoc(p.annotations, s"${caseClass.typeName.full}#$fieldKey") - if (!f.hasAnnotation) fieldDoc.foreach(fb.setAnnotation) - fb.build() - }.asJava - b.addAllFeature(fieldFeatures) - b + override def put(f: Features.Builder, k: String, v: T)(cm: CaseMapper): Features.Builder = + caseClass.parameters.foldLeft(f) { (f, p) => + val fieldKey = key(k, cm.map(p.label)) + val fieldValue = p.dereference(v) + p.typeclass.put(f, fieldKey, fieldValue)(cm) + f + } + + override protected def buildSchema(cm: CaseMapper): Schema = { + val sb = Schema.newBuilder() + getDoc(caseClass.annotations, caseClass.typeName.full).foreach(sb.setAnnotation) + caseClass.parameters.foldLeft(sb) { (b, p) => + val fieldNane = cm.map(p.label) + val fieldSchema = p.typeclass.schema(cm) + val fieldFeatures = fieldSchema.getFeatureList.asScala.map { f => + val fb = f.toBuilder + // if schema does not have a name (eg. primitive), use the fieldNane + // otherwise prepend to the feature name (eg. nested records) + val fieldKey = if (f.hasName) key(fieldNane, f.getName) else fieldNane + fb.setName(fieldKey) + // if field already has a doc, keep it + // otherwise use the parameter annotation + val fieldDoc = getDoc(p.annotations, s"${caseClass.typeName.full}#$fieldKey") + if (!f.hasAnnotation) fieldDoc.foreach(fb.setAnnotation) + fb.build() + }.asJava + b.addAllFeature(fieldFeatures) + b + } + sb.build() + } } - sb.build() } } diff --git a/tensorflow/src/test/scala/magnolify/tensorflow/test/ExampleTypeSuite.scala b/tensorflow/src/test/scala/magnolify/tensorflow/test/ExampleTypeSuite.scala index c29765713..d8f171837 100644 --- a/tensorflow/src/test/scala/magnolify/tensorflow/test/ExampleTypeSuite.scala +++ b/tensorflow/src/test/scala/magnolify/tensorflow/test/ExampleTypeSuite.scala @@ -82,6 +82,20 @@ class ExampleTypeSuite extends MagnolifySuite { test[Custom] } + test("AnyVal") { + implicit val et: ExampleType[HasValueClass] = ExampleType[HasValueClass] + test[HasValueClass] + + val schema = et.schema + val feature = schema.getFeatureList.asScala.find(_.getName == "vc").get + assert(feature.getType == FeatureType.BYTES) + + val record = et(HasValueClass(ValueClass("String"))) + val value = record.getFeatures.getFeatureMap.get("vc") + assert(value.hasBytesList) + assert(value.getBytesList.getValue(0).toStringUtf8 == "String") + } + { implicit val arbByteString: Arbitrary[ByteString] = Arbitrary(Gen.alphaNumStr.map(ByteString.copyFromUtf8)) From 521773d9af7b8bbed885141095af53f212311c41 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Tue, 11 Oct 2022 09:22:38 +0200 Subject: [PATCH 12/12] Make magnolia macro return a normal field --- .../main/scala/magnolify/avro/AvroType.scala | 49 +++++------- .../magnolify/bigquery/TableRowType.scala | 30 +++---- .../magnolify/bigtable/BigtableType.scala | 28 ++++--- .../magnolify/datastore/EntityType.scala | 33 ++++---- .../scala/magnolify/neo4j/ValueType.scala | 78 ++++++++++--------- .../magnolify/neo4j/ValueTypeSuite.scala | 24 ++---- .../scala/magnolify/parquet/ParquetType.scala | 45 +++++------ .../magnolify/parquet/unsafe/package.scala | 5 +- .../parquet/test/ParquetTypeSuite.scala | 1 - .../magnolify/protobuf/ProtobufType.scala | 36 ++++----- .../magnolify/protobuf/unsafe/package.scala | 5 +- .../scala/magnolify/shared/Converter.scala | 6 ++ .../magnolify/tensorflow/ExampleType.scala | 31 ++++---- 13 files changed, 180 insertions(+), 191 deletions(-) diff --git a/avro/src/main/scala/magnolify/avro/AvroType.scala b/avro/src/main/scala/magnolify/avro/AvroType.scala index 7851bf800..f785965ae 100644 --- a/avro/src/main/scala/magnolify/avro/AvroType.scala +++ b/avro/src/main/scala/magnolify/avro/AvroType.scala @@ -20,14 +20,13 @@ import java.nio.ByteBuffer import java.time._ import java.{util => ju} import magnolia1._ -import magnolify.avro.AvroField.{ProductRecord, ValueClassRecord} import magnolify.shared._ import magnolify.shims.FactoryCompat import org.apache.avro.generic.GenericData.EnumSymbol import org.apache.avro.generic._ import org.apache.avro.{JsonProperties, LogicalType, LogicalTypes, Schema} -import scala.annotation.{StaticAnnotation, implicitNotFound} +import scala.annotation.{implicitNotFound, StaticAnnotation} import scala.collection.concurrent import scala.language.experimental.macros import scala.language.implicitConversions @@ -46,20 +45,20 @@ sealed trait AvroType[T] extends Converter[T, GenericRecord, GenericRecord] { } object AvroType { - implicit def apply[T: AvroField.Record]: AvroType[T] = AvroType(CaseMapper.identity) + implicit def apply[T: AvroField]: AvroType[T] = AvroType(CaseMapper.identity) - def apply[T](cm: CaseMapper)(implicit f: AvroField.Record[T]): AvroType[T] = { + def apply[T](cm: CaseMapper)(implicit f: AvroField[T]): AvroType[T] = { f match { - case pr: ProductRecord[_] => - pr.schema(cm) // fail fast on bad annotations + case r: AvroField.Record[_] => + r.schema(cm) // fail fast on bad annotations new AvroType[T] { private val caseMapper: CaseMapper = cm - @transient override lazy val schema: Schema = pr.schema(caseMapper) - override def from(v: GenericRecord): T = pr.from(v)(caseMapper) - override def to(v: T): GenericRecord = pr.to(v)(caseMapper) + @transient override lazy val schema: Schema = r.schema(caseMapper) + override def from(v: GenericRecord): T = r.from(v)(caseMapper) + override def to(v: T): GenericRecord = r.to(v)(caseMapper) } - case _: ValueClassRecord[_] => - throw new IllegalArgumentException("Value classes are not valid AvroType") + case _ => + throw new IllegalArgumentException(s"AvroType can only be created from Record. Got $f") } } } @@ -92,35 +91,25 @@ object AvroField { override type FromT = From override type ToT = To } - - sealed trait Record[T] extends AvroField[T] - sealed trait ValueClassRecord[T] extends Record[T] - sealed trait ProductRecord[T] extends Record[T] { - override type FromT = GenericRecord - override type ToT = GenericRecord - } + sealed trait Record[T] extends Aux[T, GenericRecord, GenericRecord] // //////////////////////////////////////////////// type Typeclass[T] = AvroField[T] - def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = { + def join[T](caseClass: CaseClass[Typeclass, T]): AvroField[T] = { if (caseClass.isValueClass) { val p = caseClass.parameters.head val tc = p.typeclass - new ValueClassRecord[T] { + new AvroField[T] { override type FromT = tc.FromT override type ToT = tc.ToT - override protected def buildSchema(cm: CaseMapper): Schema = tc.buildSchema(cm) - override def from(v: FromT)(cm: CaseMapper): T = - caseClass.construct(_ => tc.fromAny(v)(cm)) - override def to(v: T)(cm: CaseMapper): ToT = { - tc.to(p.dereference(v))(cm) - } + override def from(v: FromT)(cm: CaseMapper): T = caseClass.construct(_ => tc.fromAny(v)(cm)) + override def to(v: T)(cm: CaseMapper): ToT = tc.to(p.dereference(v))(cm) } } else { - new ProductRecord[T] { + new Record[T] { override protected def buildSchema(cm: CaseMapper): Schema = Schema .createRecord( caseClass.typeName.short, @@ -173,9 +162,9 @@ object AvroField { @implicitNotFound("Cannot derive AvroField for sealed trait") private sealed trait Dispatchable[T] - def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): Record[T] = ??? + def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): AvroField[T] = ??? - implicit def gen[T]: Record[T] = macro Magnolia.gen[T] + implicit def gen[T]: AvroField[T] = macro Magnolia.gen[T] // //////////////////////////////////////////////// @@ -226,7 +215,7 @@ object AvroField { override def to(v: Array[Byte])(cm: CaseMapper): ByteBuffer = ByteBuffer.wrap(v) } - implicit def afEnum[T](implicit et: EnumType[T]): AvroField[T] = + implicit def afEnum[T](implicit et: EnumType[T], lp: shapeless.LowPriority): AvroField[T] = // Avro 1.9+ added a type parameter for `GenericEnumSymbol`, breaking 1.8 compatibility // Some reader, i.e. `AvroParquetReader` reads enums as `Utf8` new Aux[T, AnyRef, EnumSymbol] { diff --git a/bigquery/src/main/scala/magnolify/bigquery/TableRowType.scala b/bigquery/src/main/scala/magnolify/bigquery/TableRowType.scala index 0b81d974e..1475b4eca 100644 --- a/bigquery/src/main/scala/magnolify/bigquery/TableRowType.scala +++ b/bigquery/src/main/scala/magnolify/bigquery/TableRowType.scala @@ -20,7 +20,6 @@ import java.{util => ju} import com.google.api.services.bigquery.model.{TableFieldSchema, TableRow, TableSchema} import com.google.common.io.BaseEncoding import magnolia1._ -import magnolify.bigquery.TableRowField.{ProductRecord, ValueClassRecord} import magnolify.shared.{CaseMapper, Converter} import magnolify.shims.FactoryCompat @@ -42,11 +41,11 @@ sealed trait TableRowType[T] extends Converter[T, TableRow, TableRow] { } object TableRowType { - implicit def apply[T: TableRowField.Record]: TableRowType[T] = TableRowType(CaseMapper.identity) + implicit def apply[T: TableRowField]: TableRowType[T] = TableRowType(CaseMapper.identity) - def apply[T](cm: CaseMapper)(implicit f: TableRowField.Record[T]): TableRowType[T] = { + def apply[T](cm: CaseMapper)(implicit f: TableRowField[T]): TableRowType[T] = { f match { - case pr: ProductRecord[_] => + case pr: TableRowField.Record[_] => pr.fieldSchema(cm) // fail fast on bad annotations new TableRowType[T] { private val caseMapper: CaseMapper = cm @@ -56,8 +55,8 @@ object TableRowType { override def from(v: TableRow): T = pr.from(v)(caseMapper) override def to(v: T): TableRow = pr.to(v)(caseMapper) } - case _: ValueClassRecord[_] => - throw new IllegalArgumentException("Value classes are not valid TableRowType") + case _ => + throw new IllegalArgumentException(s"TableRowType can only be created from Record. Got $f") } } } @@ -86,23 +85,16 @@ object TableRowField { } sealed trait Generic[T] extends Aux[T, Any, Any] - - sealed trait Record[T] extends TableRowField[T] - sealed trait ValueClassRecord[T] extends Record[T] - sealed trait ProductRecord[T] extends Record[T] { - override type FromT = ju.Map[String, AnyRef] - override type ToT = TableRow - } + sealed trait Record[T] extends Aux[T, ju.Map[String, AnyRef], TableRow] // //////////////////////////////////////////////// - type Typeclass[T] = TableRowField[T] - def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = { + def join[T](caseClass: CaseClass[Typeclass, T]): TableRowField[T] = { if (caseClass.isValueClass) { val p = caseClass.parameters.head val tc = p.typeclass - new ValueClassRecord[T] { + new TableRowField[T] { override type FromT = tc.FromT override type ToT = tc.ToT override protected def buildSchema(cm: CaseMapper): TableFieldSchema = tc.buildSchema(cm) @@ -110,7 +102,7 @@ object TableRowField { override def to(v: T)(cm: CaseMapper): ToT = tc.to(p.dereference(v))(cm) } } else { - new ProductRecord[T] { + new Record[T] { override protected def buildSchema(cm: CaseMapper): TableFieldSchema = { // do not use a scala wrapper in the schema, so clone() works val fields = new ju.ArrayList[TableFieldSchema](caseClass.parameters.size) @@ -159,9 +151,9 @@ object TableRowField { @implicitNotFound("Cannot derive TableRowField for sealed trait") private sealed trait Dispatchable[T] - def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): Record[T] = ??? + def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): TableRowField[T] = ??? - implicit def gen[T]: Record[T] = macro Magnolia.gen[T] + implicit def gen[T]: TableRowField[T] = macro Magnolia.gen[T] // //////////////////////////////////////////////// diff --git a/bigtable/src/main/scala/magnolify/bigtable/BigtableType.scala b/bigtable/src/main/scala/magnolify/bigtable/BigtableType.scala index fa7455bc1..ecca3bef7 100644 --- a/bigtable/src/main/scala/magnolify/bigtable/BigtableType.scala +++ b/bigtable/src/main/scala/magnolify/bigtable/BigtableType.scala @@ -49,14 +49,18 @@ sealed trait BigtableType[T] extends Converter[T, java.util.List[Column], Seq[Se } object BigtableType { - implicit def apply[T: BigtableField.Record]: BigtableType[T] = BigtableType(CaseMapper.identity) - - def apply[T](cm: CaseMapper)(implicit f: BigtableField.Record[T]): BigtableType[T] = - new BigtableType[T] { - private val caseMapper: CaseMapper = cm - override def from(xs: java.util.List[Column]): T = f.get(xs, null)(caseMapper).get - override def to(v: T): Seq[SetCell.Builder] = f.put(null, v)(caseMapper) - } + implicit def apply[T: BigtableField]: BigtableType[T] = BigtableType(CaseMapper.identity) + + def apply[T](cm: CaseMapper)(implicit f: BigtableField[T]): BigtableType[T] = f match { + case r: BigtableField.Record[_] => + new BigtableType[T] { + private val caseMapper: CaseMapper = cm + override def from(xs: java.util.List[Column]): T = r.get(xs, null)(caseMapper).get + override def to(v: T): Seq[SetCell.Builder] = r.put(null, v)(caseMapper) + } + case _ => + throw new IllegalArgumentException(s"BigtableType can only be created from Record. Got $f") + } def mutationsToRow(key: ByteString, mutations: Seq[Mutation]): Row = { val families = mutations @@ -135,11 +139,11 @@ object BigtableField { type Typeclass[T] = BigtableField[T] - def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = { + def join[T](caseClass: CaseClass[Typeclass, T]): BigtableField[T] = { if (caseClass.isValueClass) { val p = caseClass.parameters.head val tc = p.typeclass - new Record[T] { + new BigtableField[T] { override def get(xs: util.List[Column], k: String)(cm: CaseMapper): Value[T] = tc.get(xs, k)(cm).map(x => caseClass.construct(_ => x)) override def put(k: String, v: T)(cm: CaseMapper): Seq[SetCell.Builder] = @@ -174,9 +178,9 @@ object BigtableField { @implicitNotFound("Cannot derive BigtableField for sealed trait") private sealed trait Dispatchable[T] - def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): Record[T] = ??? + def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): BigtableField[T] = ??? - implicit def gen[T]: Record[T] = macro Magnolia.gen[T] + implicit def gen[T]: BigtableField[T] = macro Magnolia.gen[T] def apply[T](implicit f: BigtableField[T]): BigtableField[T] = f diff --git a/datastore/src/main/scala/magnolify/datastore/EntityType.scala b/datastore/src/main/scala/magnolify/datastore/EntityType.scala index e1037e49d..66383c2d2 100644 --- a/datastore/src/main/scala/magnolify/datastore/EntityType.scala +++ b/datastore/src/main/scala/magnolify/datastore/EntityType.scala @@ -21,11 +21,10 @@ import com.google.datastore.v1._ import com.google.datastore.v1.client.DatastoreHelper.makeValue import com.google.protobuf.{ByteString, NullValue} import magnolia1._ -import magnolify.datastore.EntityField.{ProductRecord, ValueClassRecord} import magnolify.shared.{CaseMapper, Converter} import magnolify.shims.FactoryCompat -import scala.annotation.{StaticAnnotation, implicitNotFound} +import scala.annotation.{implicitNotFound, StaticAnnotation} import scala.language.experimental.macros import scala.jdk.CollectionConverters._ import scala.collection.compat._ @@ -41,17 +40,17 @@ class key(val project: String = null, val namespace: String = null, val kind: St class excludeFromIndexes(val exclude: Boolean = true) extends StaticAnnotation with Serializable object EntityType { - implicit def apply[T: EntityField.Record]: EntityType[T] = EntityType(CaseMapper.identity) + implicit def apply[T: EntityField]: EntityType[T] = EntityType(CaseMapper.identity) - def apply[T](cm: CaseMapper)(implicit f: EntityField.Record[T]): EntityType[T] = f match { - case pr: ProductRecord[_] => + def apply[T](cm: CaseMapper)(implicit f: EntityField[T]): EntityType[T] = f match { + case r: EntityField.Record[_] => new EntityType[T] { private val caseMapper: CaseMapper = cm - override def from(v: Entity): T = pr.fromEntity(v)(caseMapper) - override def to(v: T): Entity.Builder = pr.toEntity(v)(caseMapper) + override def from(v: Entity): T = r.fromEntity(v)(caseMapper) + override def to(v: T): Entity.Builder = r.toEntity(v)(caseMapper) } - case _: ValueClassRecord[_] => - throw new IllegalArgumentException("Value classes are not valid EntityType") + case _ => + throw new IllegalArgumentException(s"EntityType can only be created from Record. Got $f") } } @@ -96,11 +95,7 @@ sealed trait EntityField[T] extends Serializable { object EntityField { - sealed trait Record[T] extends EntityField[T] - - sealed trait ValueClassRecord[T] extends Record[T] - - sealed trait ProductRecord[T] extends Record[T] { + sealed trait Record[T] extends EntityField[T] { def fromEntity(v: Entity)(cm: CaseMapper): T def toEntity(v: T)(cm: CaseMapper): Entity.Builder @@ -113,17 +108,17 @@ object EntityField { type Typeclass[T] = EntityField[T] - def join[T: KeyField](caseClass: CaseClass[Typeclass, T]): Record[T] = { + def join[T: KeyField](caseClass: CaseClass[Typeclass, T]): EntityField[T] = { if (caseClass.isValueClass) { val p = caseClass.parameters.head val tc = p.typeclass - new ValueClassRecord[T] { + new EntityField[T] { override lazy val keyField: KeyField[T] = tc.keyField.map(p.dereference) override def from(v: Value)(cm: CaseMapper): T = caseClass.construct(_ => tc.from(v)(cm)) override def to(v: T)(cm: CaseMapper): Value.Builder = tc.to(p.dereference(v))(cm) } } else { - new ProductRecord[T] { + new Record[T] { private val (keyIndex, keyOpt): (Int, Option[key]) = { val keys = caseClass.parameters .map(p => p -> getKey(p.annotations, s"${caseClass.typeName.full}#${p.label}")) @@ -215,9 +210,9 @@ object EntityField { @implicitNotFound("Cannot derive EntityField for sealed trait") private sealed trait Dispatchable[T] - def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): Record[T] = ??? + def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): EntityField[T] = ??? - implicit def gen[T]: Record[T] = macro Magnolia.gen[T] + implicit def gen[T]: EntityField[T] = macro Magnolia.gen[T] // //////////////////////////////////////////////// diff --git a/neo4j/src/main/scala/magnolify/neo4j/ValueType.scala b/neo4j/src/main/scala/magnolify/neo4j/ValueType.scala index c40a1f634..323861a04 100644 --- a/neo4j/src/main/scala/magnolify/neo4j/ValueType.scala +++ b/neo4j/src/main/scala/magnolify/neo4j/ValueType.scala @@ -30,20 +30,22 @@ import scala.collection.compat._ trait ValueType[T] extends Converter[T, Value, Value] { def apply(r: Value): T = from(r) - def apply(t: T): Value = to(t) } object ValueType { - implicit def apply[T: ValueField.Record]: ValueType[T] = ValueType(CaseMapper.identity) - - def apply[T](cm: CaseMapper)(implicit f: ValueField.Record[T]): ValueType[T] = new ValueType[T] { - private val caseMapper: CaseMapper = cm - - override def from(v: Value): T = f.from(v)(caseMapper) + implicit def apply[T: ValueField]: ValueType[T] = ValueType(CaseMapper.identity) - override def to(v: T): Value = f.to(v)(caseMapper) + def apply[T](cm: CaseMapper)(implicit f: ValueField[T]): ValueType[T] = f match { + case r: ValueField.Record[_] => + new ValueType[T] { + private val caseMapper: CaseMapper = cm + override def from(v: Value): T = r.from(v)(caseMapper) + override def to(v: T): Value = r.to(v)(caseMapper) + } + case _ => + throw new IllegalArgumentException(s"ValueType can only be created from Record. Got $f") } } @@ -61,41 +63,47 @@ object ValueField { // //////////////////////////////////////////////// type Typeclass[T] = ValueField[T] - def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = new Record[T] { - override def from(v: Value)(cm: CaseMapper): T = - caseClass.construct { p => - val field = cm.map(p.label) - try { - val value = if (caseClass.isValueClass) v else v.get(field) - p.typeclass.from(value)(cm) - } catch { - case e: ValueException => - throw new RuntimeException(s"Failed to decode $field: ${e.getMessage}", e) - } + def join[T](caseClass: CaseClass[Typeclass, T]): ValueField[T] = { + if (caseClass.isValueClass) { + val p = caseClass.parameters.head + val tc = p.typeclass + new ValueField[T] { + override def from(v: Value)(cm: CaseMapper): T = caseClass.construct(_ => tc.from(v)(cm)) + override def to(v: T)(cm: CaseMapper): Value = tc.to(p.dereference(v))(cm) } - - override def to(v: T)(cm: CaseMapper): Value = { - val jmap = if (caseClass.isValueClass) { - val p = caseClass.parameters.head - p.typeclass.to(p.dereference(v))(cm) - } else - caseClass.parameters - .foldLeft(Map.newBuilder[String, AnyRef]) { (m, p) => - m += cm.map(p.label) -> p.typeclass.to(p.dereference(v))(cm) - m + } else { + new Record[T] { + override def from(v: Value)(cm: CaseMapper): T = + caseClass.construct { p => + val field = cm.map(p.label) + try { + p.typeclass.from(v.get(field))(cm) + } catch { + case e: ValueException => + throw new RuntimeException(s"Failed to decode $field: ${e.getMessage}", e) + } } - .result() - .asJava - Values.value(jmap) + + override def to(v: T)(cm: CaseMapper): Value = { + val jmap = caseClass.parameters + .foldLeft(Map.newBuilder[String, AnyRef]) { (m, p) => + m += cm.map(p.label) -> p.typeclass.to(p.dereference(v))(cm) + m + } + .result() + .asJava + Values.value(jmap) + } + } } } @implicitNotFound("Cannot derive AvroField for sealed trait") private sealed trait Dispatchable[T] - def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): Record[T] = ??? + def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): ValueField[T] = ??? - implicit def gen[T]: Record[T] = macro Magnolia.gen[T] + implicit def gen[T]: ValueField[T] = macro Magnolia.gen[T] // //////////////////////////////////////////////// @@ -107,7 +115,6 @@ object ValueField { def apply[U](f: T => U)(g: U => T)(implicit af: ValueField[T]): ValueField[U] = new ValueField[U] { override def from(v: Value)(cm: CaseMapper): U = f(af.from(v)(cm)) - override def to(v: U)(cm: CaseMapper): Value = af.to(g(v))(cm) } } @@ -119,7 +126,6 @@ object ValueField { if (v.isNull) throw new ValueException("Cannot convert null value") f(v) } - override def to(v: T)(cm: CaseMapper): Value = Values.value(v) } diff --git a/neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala b/neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala index 2b8e667b4..1004ae498 100644 --- a/neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala +++ b/neo4j/src/test/scala/magnolify/neo4j/ValueTypeSuite.scala @@ -22,13 +22,10 @@ import magnolify.test.Simple._ import magnolify.cats.auto._ import magnolify.scalacheck.auto._ import magnolify.shared.CaseMapper -import org.neo4j.driver.Value -import org.neo4j.driver.internal.value.{MapValue, StringValue} import org.scalacheck.{Arbitrary, Prop} import java.net.URI import scala.reflect.ClassTag -import scala.jdk.CollectionConverters._ class ValueTypeSuite extends MagnolifySuite { @@ -69,6 +66,14 @@ class ValueTypeSuite extends MagnolifySuite { test[Custom] } + test("AnyVal") { + implicit val vt: ValueType[HasValueClass] = ValueType[HasValueClass] + test[HasValueClass] + + val record = vt(HasValueClass(ValueClass("String"))) + assert(record.get("vc").asString() == "String") + } + test("LowerCamel mapping") { implicit val vt: ValueType[LowerCamel] = ValueType[LowerCamel](CaseMapper(_.toUpperCase)) test[LowerCamel] @@ -79,17 +84,4 @@ class ValueTypeSuite extends MagnolifySuite { assert(!fields.map(record.get).exists(_.isNull)) assert(!record.get("INNERFIELD").get("INNERFIRST").isNull) } - - test("AnyVal") { - implicit val vt: ValueType[HasValueClass] = ValueType[HasValueClass] - test[HasValueClass] - - val record = vt(HasValueClass(ValueClass("String"))) - assert(record.get("vc").asString() == "String") - - val v: Value = new StringValue("Hello, world") - val a = new MapValue(Map("vc" -> v).asJava) - val c = vt.from(a) - assert(c == HasValueClass(ValueClass("Hello, world"))) - } } diff --git a/parquet/src/main/scala/magnolify/parquet/ParquetType.scala b/parquet/src/main/scala/magnolify/parquet/ParquetType.scala index 7281b8dee..8f5dad3bb 100644 --- a/parquet/src/main/scala/magnolify/parquet/ParquetType.scala +++ b/parquet/src/main/scala/magnolify/parquet/ParquetType.scala @@ -85,19 +85,23 @@ sealed trait ParquetType[T] extends Serializable { object ParquetType { private val logger = LoggerFactory.getLogger(this.getClass) - implicit def apply[T](implicit f: ParquetField.Record[T], pa: ParquetArray): ParquetType[T] = + implicit def apply[T](implicit f: ParquetField[T], pa: ParquetArray): ParquetType[T] = ParquetType(CaseMapper.identity) def apply[T]( cm: CaseMapper - )(implicit f: ParquetField.Record[T], pa: ParquetArray): ParquetType[T] = - new ParquetType[T] { - override def schema: MessageType = Schema.message(f.schema(cm)) - - override val avroCompat: Boolean = pa == ParquetArray.AvroCompat.avroCompat || f.hasAvroArray - override def write(c: RecordConsumer, v: T): Unit = f.write(c, v)(cm) - override def newConverter: TypeConverter[T] = f.newConverter - } + )(implicit f: ParquetField[T], pa: ParquetArray): ParquetType[T] = f match { + case r: ParquetField.Record[_] => + new ParquetType[T] { + override def schema: MessageType = Schema.message(r.schema(cm)) + override val avroCompat: Boolean = + pa == ParquetArray.AvroCompat.avroCompat || f.hasAvroArray + override def write(c: RecordConsumer, v: T): Unit = r.write(c, v)(cm) + override def newConverter: TypeConverter[T] = r.newConverter + } + case _ => + throw new IllegalArgumentException(s"ParquetType can only be created from Record. Got $f") + } val ReadTypeKey = "parquet.type.read.type" val WriteTypeKey = "parquet.type.write.type" @@ -223,27 +227,25 @@ sealed trait ParquetField[T] extends Serializable { } object ParquetField { - type Typeclass[T] = ParquetField[T] - - sealed trait Record[T] extends ParquetField[T] - sealed trait ValueClassRecord[T] extends Record[T] - sealed trait ProductRecord[T] extends Record[T] { + sealed trait Record[T] extends ParquetField[T] { override protected val isGroup: Boolean = true override protected def isEmpty(v: T): Boolean = false } - def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = { + // //////////////////////////////////////////// + type Typeclass[T] = ParquetField[T] + + def join[T](caseClass: CaseClass[Typeclass, T]): ParquetField[T] = { if (caseClass.isValueClass) { val p = caseClass.parameters.head val tc = p.typeclass - new ValueClassRecord[T] { + new ParquetField[T] { override protected def buildSchema(cm: CaseMapper): Type = tc.buildSchema(cm) override protected def isEmpty(v: T): Boolean = tc.isEmpty(p.dereference(v)) override def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit = tc.writeGroup(c, p.dereference(v))(cm) override def newConverter: TypeConverter[T] = { - val buffered = tc - .newConverter + val buffered = tc.newConverter .asInstanceOf[TypeConverter.Buffered[p.PType]] new TypeConverter.Delegate[p.PType, T](buffered) { override def get: T = inner.get(b => caseClass.construct(_ => b.head)) @@ -251,7 +253,7 @@ object ParquetField { } } } else { - new ProductRecord[T] { + new Record[T] { override def buildSchema(cm: CaseMapper): Type = caseClass.parameters .foldLeft(Types.requiredGroup()) { (g, p) => @@ -305,9 +307,8 @@ object ParquetField { @implicitNotFound("Cannot derive ParquetType for sealed trait") private sealed trait Dispatchable[T] - def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): Typeclass[T] = ??? - - implicit def apply[T]: Record[T] = macro Magnolia.gen[T] + def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): ParquetField[T] = ??? + implicit def apply[T]: ParquetField[T] = macro Magnolia.gen[T] // //////////////////////////////////////////////// diff --git a/parquet/src/main/scala/magnolify/parquet/unsafe/package.scala b/parquet/src/main/scala/magnolify/parquet/unsafe/package.scala index e73d25a0e..78a4f7480 100644 --- a/parquet/src/main/scala/magnolify/parquet/unsafe/package.scala +++ b/parquet/src/main/scala/magnolify/parquet/unsafe/package.scala @@ -21,6 +21,9 @@ import magnolify.shared._ package object unsafe { implicit val pfChar = ParquetField.from[Int](_.toChar)(_.toInt) - implicit def pfUnsafeEnum[T](implicit et: EnumType[T]): ParquetField[UnsafeEnum[T]] = + implicit def pfUnsafeEnum[T](implicit + et: EnumType[T], + lp: shapeless.LowPriority + ): ParquetField[UnsafeEnum[T]] = ParquetField.from[String](UnsafeEnum.from(_))(UnsafeEnum.to(_)) } diff --git a/parquet/src/test/scala/magnolify/parquet/test/ParquetTypeSuite.scala b/parquet/src/test/scala/magnolify/parquet/test/ParquetTypeSuite.scala index e7ae70eb0..55b46dbc9 100644 --- a/parquet/src/test/scala/magnolify/parquet/test/ParquetTypeSuite.scala +++ b/parquet/src/test/scala/magnolify/parquet/test/ParquetTypeSuite.scala @@ -29,7 +29,6 @@ import magnolify.shared.CaseMapper import magnolify.test.Simple._ import magnolify.test.Time._ import magnolify.test._ -import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.io._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.scalacheck._ diff --git a/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala b/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala index 9f0ed6eaa..78bb62488 100644 --- a/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala +++ b/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala @@ -60,26 +60,26 @@ object ProtobufOption { } object ProtobufType { - implicit def apply[T: ProtobufField.Record, MsgT <: Message: ClassTag](implicit + implicit def apply[T: ProtobufField, MsgT <: Message: ClassTag](implicit po: ProtobufOption ): ProtobufType[T, MsgT] = ProtobufType(CaseMapper.identity) def apply[T, MsgT <: Message](cm: CaseMapper)(implicit - f: ProtobufField.Record[T], + f: ProtobufField[T], ct: ClassTag[MsgT], po: ProtobufOption ): ProtobufType[T, MsgT] = f match { - case pr: ProtobufField.ProductRecord[_] => + case r: ProtobufField.Record[_] => new ProtobufType[T, MsgT] { { val descriptor = ct.runtimeClass .getMethod("getDescriptor") .invoke(null) .asInstanceOf[Descriptor] - if (pr.hasOptional) { - po.check(pr, descriptor.getFile.getSyntax) + if (r.hasOptional) { + po.check(r, descriptor.getFile.getSyntax) } - pr.checkDefaults(descriptor)(cm) + r.checkDefaults(descriptor)(cm) } @transient private var _newBuilder: Method = _ @@ -91,11 +91,11 @@ object ProtobufType { } private val caseMapper: CaseMapper = cm - override def from(v: MsgT): T = pr.from(v)(caseMapper) - override def to(v: T): MsgT = pr.to(v, newBuilder)(caseMapper).asInstanceOf[MsgT] + override def from(v: MsgT): T = r.from(v)(caseMapper) + override def to(v: T): MsgT = r.to(v, newBuilder)(caseMapper).asInstanceOf[MsgT] } - case _: ProtobufField.ValueClassRecord[_] => - throw new IllegalArgumentException("Value classes are not valid ProtobufType") + case _ => + throw new IllegalArgumentException(s"ProtobufType can only be created from Record. Got $f") } } @@ -119,11 +119,7 @@ object ProtobufField { override type ToT = To } - sealed trait Record[T] extends ProtobufField[T] - sealed trait ValueClassRecord[T] extends Record[T] - sealed trait ProductRecord[T] extends Record[T] { - override type FromT = Message - override type ToT = Message + sealed trait Record[T] extends Aux[T, Message, Message] { override val default: Option[T] = None } @@ -131,11 +127,11 @@ object ProtobufField { type Typeclass[T] = ProtobufField[T] - def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = { + def join[T](caseClass: CaseClass[Typeclass, T]): ProtobufField[T] = { if (caseClass.isValueClass) { val p = caseClass.parameters.head val tc = p.typeclass - new ValueClassRecord[T] { + new ProtobufField[T] { override type FromT = tc.FromT override type ToT = tc.ToT override val hasOptional: Boolean = tc.hasOptional @@ -146,7 +142,7 @@ object ProtobufField { } } else { - new ProductRecord[T] { + new Record[T] { // One Record[T] instance may be used for multiple Message types @transient private lazy val fieldsCache: concurrent.Map[String, Array[FieldDescriptor]] = concurrent.TrieMap.empty @@ -226,9 +222,9 @@ object ProtobufField { @implicitNotFound("Cannot derive ProtobufField for sealed trait") private sealed trait Dispatchable[T] - def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): Record[T] = ??? + def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): ProtobufField[T] = ??? - implicit def gen[T]: Record[T] = macro Magnolia.gen[T] + implicit def gen[T]: ProtobufField[T] = macro Magnolia.gen[T] // //////////////////////////////////////////////// diff --git a/protobuf/src/main/scala/magnolify/protobuf/unsafe/package.scala b/protobuf/src/main/scala/magnolify/protobuf/unsafe/package.scala index 6ce396969..46b2b5633 100644 --- a/protobuf/src/main/scala/magnolify/protobuf/unsafe/package.scala +++ b/protobuf/src/main/scala/magnolify/protobuf/unsafe/package.scala @@ -27,7 +27,10 @@ package object unsafe { implicit val proto3Option: ProtobufOption = new ProtobufOption.Proto3Option } - implicit def pfUnsafeEnum[T](implicit et: EnumType[T]): ProtobufField[UnsafeEnum[T]] = + implicit def pfUnsafeEnum[T](implicit + et: EnumType[T], + lp: shapeless.LowPriority + ): ProtobufField[UnsafeEnum[T]] = ProtobufField .from[String](s => if (s == null || s.isEmpty) null else UnsafeEnum.from(s))(UnsafeEnum.to(_)) } diff --git a/shared/src/main/scala/magnolify/shared/Converter.scala b/shared/src/main/scala/magnolify/shared/Converter.scala index bf51193c9..6e1df1088 100644 --- a/shared/src/main/scala/magnolify/shared/Converter.scala +++ b/shared/src/main/scala/magnolify/shared/Converter.scala @@ -32,6 +32,12 @@ sealed trait Value[+T] { def isSome: Boolean = this.isInstanceOf[Value.Some[_]] def isEmpty: Boolean = this eq Value.None + def map[U](f: T => U): Value[U] = this match { + case Value.Some(x) => Value.Some(f(x)) + case Value.Default(x) => Value.Default(f(x)) + case Value.None => Value.None + } + def getOrElse[U](fallback: Option[U])(implicit ev: T <:< U): U = (this, fallback) match { case (Value.Some(x), _) => x case (Value.Default(_), Some(x)) => x diff --git a/tensorflow/src/main/scala/magnolify/tensorflow/ExampleType.scala b/tensorflow/src/main/scala/magnolify/tensorflow/ExampleType.scala index 87f29bf8d..28114e53d 100644 --- a/tensorflow/src/main/scala/magnolify/tensorflow/ExampleType.scala +++ b/tensorflow/src/main/scala/magnolify/tensorflow/ExampleType.scala @@ -41,16 +41,19 @@ sealed trait ExampleType[T] extends Converter[T, Example, Example.Builder] { } object ExampleType { - implicit def apply[T: ExampleField.Record]: ExampleType[T] = ExampleType(CaseMapper.identity) - - def apply[T](cm: CaseMapper)(implicit f: ExampleField.Record[T]): ExampleType[T] = - new ExampleType[T] { - private val caseMapper: CaseMapper = cm - @transient override lazy val schema: Schema = f.schema(caseMapper) - override def from(v: Example): T = f.get(v.getFeatures, null)(caseMapper).get - override def to(v: T): Example.Builder = - Example.newBuilder().setFeatures(f.put(Features.newBuilder(), null, v)(caseMapper)) - } + implicit def apply[T: ExampleField]: ExampleType[T] = ExampleType(CaseMapper.identity) + + def apply[T](cm: CaseMapper)(implicit f: ExampleField[T]): ExampleType[T] = f match { + case r: ExampleField.Record[_] => + new ExampleType[T] { + @transient override lazy val schema: Schema = r.schema(cm) + override def from(v: Example): T = r.get(v.getFeatures, null)(cm).get + override def to(v: T): Example.Builder = + Example.newBuilder().setFeatures(r.put(Features.newBuilder(), null, v)(cm)) + } + case _ => + throw new IllegalArgumentException(s"ExampleType can only be created from Record. Got $f") + } } sealed trait ExampleField[T] extends Serializable { @@ -97,11 +100,11 @@ object ExampleField { type Typeclass[T] = ExampleField[T] - def join[T](caseClass: CaseClass[Typeclass, T]): Record[T] = { + def join[T](caseClass: CaseClass[Typeclass, T]): ExampleField[T] = { if (caseClass.isValueClass) { val p = caseClass.parameters.head val tc = p.typeclass - new Record[T] { + new ExampleField[T] { override protected def buildSchema(cm: CaseMapper): Schema = tc.buildSchema(cm) override def get(f: Features, k: String)(cm: CaseMapper): Value[T] = @@ -171,9 +174,9 @@ object ExampleField { @implicitNotFound("Cannot derive ExampleField for sealed trait") private sealed trait Dispatchable[T] - def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): Record[T] = ??? + def split[T: Dispatchable](sealedTrait: SealedTrait[Typeclass, T]): ExampleField[T] = ??? - implicit def gen[T]: Record[T] = macro Magnolia.gen[T] + implicit def gen[T]: ExampleField[T] = macro Magnolia.gen[T] // ////////////////////////////////////////////////