Skip to content

Commit

Permalink
Feat: scala3 enumeration support (#1068)
Browse files Browse the repository at this point in the history
* addscala 3 enumeration decoding/encoding
  • Loading branch information
ThijsBroersen committed Jun 14, 2024
1 parent 608bb6e commit 5773f6c
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 61 deletions.
17 changes: 17 additions & 0 deletions docs/decoding.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,23 @@ Decoding fail because 'Pear' is not a valid value

Almost all of the standard library data types are supported as fields on the case class, and it is easy to add support if one is missing.

### Sealed families and enums for Scala 3
Sealed families where all members are only objects, or a Scala 3 enum with all cases parameterless are interpreted as enumerations and will encode 1:1 with their value-names.
```scala
enum Foo derives JsonDecoder:
case Bar
case Baz
case Qux
```
or
```scala
sealed trait Foo derives JsonDecoder
object Foo:
case object Bar extends Foo
case object Baz extends Foo
case object Qux extends Foo
```

## Manual instances

Sometimes it is easier to reuse an existing `JsonDecoder` rather than generate a new one. This can be accomplished using convenience methods on the `JsonDecoder` typeclass to *derive* new decoders
Expand Down
17 changes: 17 additions & 0 deletions docs/encoding.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,23 @@ apple.toJson

Almost all of the standard library data types are supported as fields on the case class, and it is easy to add support if one is missing.

### Sealed families and enums for Scala 3
Sealed families where all members are only objects, or a Scala 3 enum with all cases parameterless are interpreted as enumerations and will encode 1:1 with their value-names.
```scala
enum Foo derives JsonEncoder:
case Bar
case Baz
case Qux
```
or
```scala
sealed trait Foo derives JsonEncoder
object Foo:
case object Bar extends Foo
case object Baz extends Foo
case object Qux extends Foo
```

## Manual instances

Sometimes it is easier to reuse an existing `JsonEncoder` rather than generate a new one. This can be accomplished using convenience methods on the `JsonEncoder` typeclass to *derive* new decoders:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ object GoldenSpec extends ZIOSpecDefault {
sealed trait SumType

object SumType {
case object Case1 extends SumType
case object Case2 extends SumType
case object Case3 extends SumType
case object Case1 extends SumType
case object Case2 extends SumType
case class Case3() extends SumType

implicit val jsonCodec: JsonCodec[SumType] = DeriveJsonCodec.gen
}
Expand Down
112 changes: 88 additions & 24 deletions zio-json/shared/src/main/scala-3/zio/json/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -207,21 +207,7 @@ final class jsonNoExtraFields extends Annotation
*/
final class jsonExclude extends Annotation

// TODO: implement same configuration as for Scala 2 once this issue is resolved: https://github.com/softwaremill/magnolia/issues/296
object DeriveJsonDecoder extends Derivation[JsonDecoder] { self =>
def join[A](ctx: CaseClass[Typeclass, A]): JsonDecoder[A] = {
val (transformNames, nameTransform): (Boolean, String => String) =
ctx.annotations.collectFirst { case jsonMemberNames(format) => format }
.map(true -> _)
.getOrElse(false -> identity)

val no_extra = ctx
.annotations
.collectFirst { case _: jsonNoExtraFields => () }
.isDefined

if (ctx.params.isEmpty) {
new JsonDecoder[A] {
private class CaseObjectDecoder[Typeclass[*], A](val ctx: CaseClass[Typeclass, A], no_extra: Boolean) extends JsonDecoder[A] {
def unsafeDecode(trace: List[JsonError], in: RetractReader): A = {
if (no_extra) {
Lexer.char(trace, in, '{')
Expand All @@ -239,6 +225,22 @@ object DeriveJsonDecoder extends Derivation[JsonDecoder] { self =>
case _ => throw UnsafeJson(JsonError.Message("Not an object") :: trace)
}
}

// TODO: implement same configuration as for Scala 2 once this issue is resolved: https://github.com/softwaremill/magnolia/issues/296
object DeriveJsonDecoder extends Derivation[JsonDecoder] { self =>
def join[A](ctx: CaseClass[Typeclass, A]): JsonDecoder[A] = {
val (transformNames, nameTransform): (Boolean, String => String) =
ctx.annotations.collectFirst { case jsonMemberNames(format) => format }
.map(true -> _)
.getOrElse(false -> identity)

val no_extra = ctx
.annotations
.collectFirst { case _: jsonNoExtraFields => () }
.isDefined

if (ctx.params.isEmpty) {
new CaseObjectDecoder(ctx, no_extra)
} else {
new JsonDecoder[A] {
val (names, aliases): (Array[String], Array[(String, Int)]) = {
Expand Down Expand Up @@ -400,9 +402,35 @@ object DeriveJsonDecoder extends Derivation[JsonDecoder] { self =>
lazy val namesMap: Map[String, Int] =
names.zipWithIndex.toMap

def isEnumeration =
(ctx.isEnum && ctx.subtypes.forall(_.typeclass.isInstanceOf[CaseObjectDecoder[?, ?]])) || (
!ctx.isEnum && ctx.subtypes.forall(_.isObject)
)

def discrim = ctx.annotations.collectFirst { case jsonDiscriminator(n) => n }

if (discrim.isEmpty) {
if (isEnumeration) {
new JsonDecoder[A] {
def unsafeDecode(trace: List[JsonError], in: RetractReader): A = {
val typeName = Lexer.string(trace, in).toString()
namesMap.find(_._1 == typeName) match {
case Some((_, idx)) => tcs(idx).asInstanceOf[CaseObjectDecoder[JsonDecoder, A]].ctx.rawConstruct(Nil)
case None => throw UnsafeJson(JsonError.Message(s"Invalid enumeration value $typeName") :: trace)
}
}

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A = {
json match {
case Json.Str(typeName) =>
ctx.subtypes.find(_.typeInfo.short == typeName) match {
case Some(sub) => sub.typeclass.asInstanceOf[CaseObjectDecoder[JsonDecoder, A]].ctx.rawConstruct(Nil)
case None => throw UnsafeJson(JsonError.Message(s"Invalid enumeration value $typeName") :: trace)
}
case _ => throw UnsafeJson(JsonError.Message("Not a string") :: trace)
}
}
}
} else if (discrim.isEmpty) {
// We're not allowing extra fields in this encoding
new JsonDecoder[A] {
val spans: Array[JsonError] = names.map(JsonError.ObjectAccess(_))
Expand Down Expand Up @@ -506,16 +534,18 @@ object DeriveJsonDecoder extends Derivation[JsonDecoder] { self =>
}
}

private lazy val caseObjectEncoder = new JsonEncoder[Any] {
def unsafeEncode(a: Any, indent: Option[Int], out: Write): Unit =
out.write("{}")

override final def toJsonAST(a: Any): Either[String, Json] =
Right(Json.Obj(Chunk.empty))
}

object DeriveJsonEncoder extends Derivation[JsonEncoder] { self =>
def join[A](ctx: CaseClass[Typeclass, A]): JsonEncoder[A] =
if (ctx.params.isEmpty) {
new JsonEncoder[A] {
def unsafeEncode(a: A, indent: Option[Int], out: Write): Unit =
out.write("{}")

override final def toJsonAST(a: A): Either[String, Json] =
Right(Json.Obj(Chunk.empty))
}
caseObjectEncoder.narrow[A]
} else {
new JsonEncoder[A] {
val (transformNames, nameTransform): (Boolean, String => String) =
Expand Down Expand Up @@ -612,15 +642,49 @@ object DeriveJsonEncoder extends Derivation[JsonEncoder] { self =>
}

def split[A](ctx: SealedTrait[JsonEncoder, A]): JsonEncoder[A] = {
val isEnumeration =
(ctx.isEnum && ctx.subtypes.forall(_.typeclass == caseObjectEncoder)) || (
!ctx.isEnum && ctx.subtypes.forall(_.isObject)
)

val jsonHintFormat: JsonMemberFormat =
ctx.annotations.collectFirst { case jsonHintNames(format) => format }.getOrElse(IdentityFormat)

val discrim = ctx
.annotations
.collectFirst {
case jsonDiscriminator(n) => n
}

if (discrim.isEmpty) {
if (isEnumeration) {
new JsonEncoder[A] {
def unsafeEncode(a: A, indent: Option[Int], out: Write): Unit = {
val typeName = ctx.choose(a) { sub =>
sub
.annotations
.collectFirst {
case jsonHint(name) => name
}.getOrElse(sub.typeInfo.short)
}

JsonEncoder.string.unsafeEncode(typeName, indent, out)
}

override final def toJsonAST(a: A): Either[String, Json] = {
ctx.choose(a) { sub =>
Right(
Json.Str(
sub
.annotations
.collectFirst {
case jsonHint(name) => name
}.getOrElse(sub.typeInfo.short)
)
)
}
}
}
} else if (discrim.isEmpty) {
new JsonEncoder[A] {
def unsafeEncode(a: A, indent: Option[Int], out: Write): Unit = {
ctx.choose(a) { sub =>
Expand Down
51 changes: 34 additions & 17 deletions zio-json/shared/src/test/scala-3/zio/json/DerivedDecoderSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,42 @@ object DerivedDecoderSpec extends ZIOSpecDefault {

val spec = suite("DerivedDecoderSpec")(
test("Derives for a product type") {
assertZIO(typeCheck {
"""
case class Foo(bar: String) derives JsonDecoder
case class Foo(bar: String) derives JsonDecoder

"{\"bar\": \"hello\"}".fromJson[Foo]
"""
})(isRight(anything))
val result = "{\"bar\": \"hello\"}".fromJson[Foo]

assertTrue(result == Right(Foo("hello")))
},
test("Derives for a sum enum Enumeration type") {
enum Foo derives JsonDecoder:
case Bar
case Baz
case Qux

val result = "\"Qux\"".fromJson[Foo]

assertTrue(result == Right(Foo.Qux))
},
test("Derives for a sum type") {
assertZIO(typeCheck {
"""
enum Foo derives JsonDecoder:
case Bar
case Baz(baz: String)
case Qux(foo: Foo)
"{\"Qux\":{\"foo\":{\"Bar\":{}}}}".fromJson[Foo]
"""
})(isRight(anything))
test("Derives for a sum sealed trait Enumeration type") {
sealed trait Foo derives JsonDecoder
object Foo:
case object Bar extends Foo
case object Baz extends Foo
case object Qux extends Foo

val result = "\"Qux\"".fromJson[Foo]

assertTrue(result == Right(Foo.Qux))
},
test("Derives for a sum ADT type") {
enum Foo derives JsonDecoder:
case Bar
case Baz(baz: String)
case Qux(foo: Foo)

val result = "{\"Qux\":{\"foo\":{\"Bar\":{}}}}".fromJson[Foo]

assertTrue(result == Right(Foo.Qux(Foo.Bar)))
},
test("Derives and decodes for a union of string-based literals") {
case class Foo(aOrB: "A" | "B", optA: Option["A"]) derives JsonDecoder
Expand Down
51 changes: 34 additions & 17 deletions zio-json/shared/src/test/scala-3/zio/json/DerivedEncoderSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,42 @@ import zio.test._
object DerivedEncoderSpec extends ZIOSpecDefault {
val spec = suite("DerivedEncoderSpec")(
test("Derives for a product type") {
assertZIO(typeCheck {
"""
case class Foo(bar: String) derives JsonEncoder
case class Foo(bar: String) derives JsonEncoder

Foo("bar").toJson
"""
})(isRight(anything))
val json = Foo("bar").toJson

assertTrue(json == """{"bar":"bar"}""")
},
test("Derives for a sum type") {
assertZIO(typeCheck {
"""
enum Foo derives JsonEncoder:
case Bar
case Baz(baz: String)
case Qux(foo: Foo)
(Foo.Qux(Foo.Bar): Foo).toJson
"""
})(isRight(anything))
test("Derives for a sum enum Enumeration type") {
enum Foo derives JsonEncoder:
case Bar
case Baz
case Qux

val json = (Foo.Qux: Foo).toJson

assertTrue(json == """"Qux"""")
},
test("Derives for a sum sealed trait Enumeration type") {
sealed trait Foo derives JsonEncoder
object Foo:
case object Bar extends Foo
case object Baz extends Foo
case object Qux extends Foo

val json = (Foo.Qux: Foo).toJson

assertTrue(json == """"Qux"""")
},
test("Derives for a sum ADT type") {
enum Foo derives JsonEncoder:
case Bar
case Baz(baz: String)
case Qux(foo: Foo)

val json = (Foo.Qux(Foo.Bar): Foo).toJson

assertTrue(json == """{"Qux":{"foo":{"Bar":{}}}}""")
},
test("Derives and encodes for a union of string-based literals") {
case class Foo(aOrB: "A" | "B", optA: Option["A"]) derives JsonEncoder
Expand Down

0 comments on commit 5773f6c

Please sign in to comment.