Support untagged ADTs #435

Open
carymrobbins opened this issue Dec 3, 2019 · 2 comments

@carymrobbins

Is it possible to derive a codec that supports writing and reading untagged ADTs? Something like Aeson's UntaggedValue, with the caveat, of course, that the encodings must be disjoint. I am deserializing data from another source and don't have control over the JSON format.

I'm also open to other ideas, e.g. deriving my own codec by combining derived codecs as alternatives. I'm not sure how feasible this is. (I've found writing the codecs by hand pretty cumbersome, but that's acceptable as long as we find some way to make this work.)
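The "alternatives of derived codecs" idea can be modelled without any JSON library. This is a minimal sketch, assuming the JSON object has already been parsed into a Map[String, Any] (a stand-in, not a jsoniter-scala type); Decoder, oneOf and the strict per-type decoders are hypothetical names. Each alternative is tried in order and the first success wins, which is why the encodings need to be disjoint:

```scala
// Hypothetical names throughout; this is not jsoniter-scala API.
object UntaggedAlternation {
  sealed trait Foo
  final case class Bar(a: String) extends Foo
  final case class Baz(b: Int) extends Foo

  // Stand-in for an already-parsed JSON object (assumption of this sketch).
  type Obj = Map[String, Any]
  // A decoder either recognizes the object or declines with None.
  type Decoder[A] = Obj => Option[A]

  // Strict decoders: accept only the exact key set and value type of one sub-type.
  val barDecoder: Decoder[Bar] = obj =>
    obj.get("a") match {
      case Some(a: String) if obj.size == 1 => Some(Bar(a))
      case _                                => None
    }

  val bazDecoder: Decoder[Baz] = obj =>
    obj.get("b") match {
      case Some(b: Int) if obj.size == 1 => Some(Baz(b))
      case _                             => None
    }

  // Alternation: try each decoder in order and keep the first success.
  def oneOf[A](decoders: Decoder[A]*): Decoder[A] = obj =>
    decoders.iterator.map(d => d(obj)).collectFirst { case Some(a) => a }

  val fooDecoder: Decoder[Foo] = oneOf[Foo](barDecoder, bazDecoder)
}
```

With strict decoders like these, an object that matches neither shape (or carries both keys) yields None. A streaming reader would additionally have to rewind the input between attempts.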

Ideally it would look something like the following. Again, if this is currently achievable in some other way, I'd be fine with that.

```scala
sealed trait Foo
final case class Bar(a: String) extends Foo
final case class Baz(b: Int) extends Foo

// withUntaggedAdt is the configuration option being requested in this issue
implicit val fooCodec: JsonValueCodec[Foo] =
  JsonCodecMaker.make(CodecMakerConfig.withUntaggedAdt)

writeToString[Foo](Bar("yo"))
// {"a":"yo"}
writeToString[Foo](Baz(12))
// {"b":12}
readFromString[Foo](""" {"a":"yo"} """)
// Bar("yo")
readFromString[Foo](""" {"b":12} """)
// Baz(12)
```
@plokhotnyuk
Owner

plokhotnyuk commented Dec 4, 2019

@carymrobbins

Hello, Cary!
Thanks for reaching out!

Currently, it is achievable only with manually written custom codecs or 3rd-party derivation.

Let's start with your simplified example. Codecs for Bar and Baz are derived, but for Foo we can write a custom one:

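```scala
import com.github.plokhotnyuk.jsoniter_scala.core._
import com.github.plokhotnyuk.jsoniter_scala.macros._

// Foo, Bar and Baz are defined as in the issue above.
// Derived codecs for the concrete sub-types:
implicit val barCodec: JsonValueCodec[Bar] = JsonCodecMaker.make(CodecMakerConfig)
implicit val bazCodec: JsonValueCodec[Baz] = JsonCodecMaker.make(CodecMakerConfig)
// Hand-written codec for Foo that dispatches on the first key of the object:
implicit val fooCodec: JsonValueCodec[Foo] = new JsonValueCodec[Foo] {
  override def decodeValue(in: JsonReader, default: Foo): Foo = {
    in.setMark() // remember the current position (before '{') so the object can be re-read
    if (in.isNextToken('{')) {
      val l = in.readKeyAsCharBuf() // peek at the first key only
      if (in.isCharBufEqualsTo(l, "a")) {
        in.rollbackToMark() // rewind, then let the derived codec parse the whole object
        barCodec.decodeValue(in, barCodec.nullValue)
      } else if (in.isCharBufEqualsTo(l, "b")) {
        in.rollbackToMark()
        bazCodec.decodeValue(in, bazCodec.nullValue)
      } else in.unexpectedKeyError(l)
    } else in.readNullOrTokenError(default, '{') // null yields the default, anything else is an error
  }
  override def encodeValue(x: Foo, out: JsonWriter): Unit = x match {
    // No discriminator field is written: each sub-type is encoded as its own object.
    case x: Bar => barCodec.encodeValue(x, out)
    case x: Baz => bazCodec.encodeValue(x, out)
    case null => out.writeNull()
  }
  override val nullValue: Foo = null
}

println(writeToString[Foo](Bar("yo")))
println(writeToString[Foo](Baz(12)))
println(readFromString[Foo](""" {"a":"yo"} """))
println(readFromString[Foo](""" {"b":12} """))
```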

The output of this code should be:

```
{"a":"yo"}
{"b":12}
Bar(yo)
Baz(12)
```

This solution can easily be extended as long as every field name belongs to a single sub-type, since it dispatches on the first key of the object only.
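Whether a set of sub-types meets that condition can be checked up front. A minimal stdlib sketch, assuming the field names of each sub-type are listed by hand (FirstKeyDispatch, fieldsBySubtype and dispatchTable are hypothetical names, not jsoniter-scala API): it maps each field name to its owning sub-type, or returns the names shared by several sub-types.

```scala
// Hypothetical helper, not part of jsoniter-scala.
object FirstKeyDispatch {
  // Field names of each sub-type, listed by hand for the Bar/Baz example.
  val fieldsBySubtype: Map[String, Set[String]] = Map(
    "Bar" -> Set("a"),
    "Baz" -> Set("b")
  )

  // Map every field name to the sub-type that owns it, or report the names
  // shared by more than one sub-type (these make first-key dispatch ambiguous).
  def dispatchTable(fields: Map[String, Set[String]]): Either[Set[String], Map[String, String]] = {
    val owners: Map[String, Iterable[String]] =
      fields.toSeq
        .flatMap { case (subtype, names) => names.map(name => name -> subtype) }
        .groupBy(_._1)
        .map { case (name, pairs) => name -> pairs.map(_._2) }
    val shared = owners.collect { case (name, subtypes) if subtypes.size > 1 => name }.toSet
    if (shared.nonEmpty) Left(shared)
    else Right(owners.map { case (name, subtypes) => name -> subtypes.head })
  }
}
```

For Bar and Baz this yields a -> Bar and b -> Baz, which is exactly the key-to-codec dispatch hard-coded in the codec above.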

If some field names can be shared between sub-types, you can parse all keys, accumulating bits (one per distinct required field name) and skipping the paired values. Then, after reaching the end of the JSON object (the } character), roll back to the marked position and switch to parsing the detected type:

```scala
import com.github.plokhotnyuk.jsoniter_scala.core._
import com.github.plokhotnyuk.jsoniter_scala.macros._

sealed trait Foo
final case class Bar(a: String, x: Option[String]) extends Foo
final case class Baz(y: Option[Int], b: Int) extends Foo
final case class Qux(a: Int, z: Seq[Double], b: String) extends Foo

implicit val barCodec: JsonValueCodec[Bar] = JsonCodecMaker.make(CodecMakerConfig)
implicit val bazCodec: JsonValueCodec[Baz] = JsonCodecMaker.make(CodecMakerConfig)
implicit val quxCodec: JsonValueCodec[Qux] = JsonCodecMaker.make(CodecMakerConfig)
implicit val fooCodec: JsonValueCodec[Foo] = new JsonValueCodec[Foo] {
  override def decodeValue(in: JsonReader, default: Foo): Foo = {
    in.setMark()
    if (in.isNextToken('{')) {
      // One bit per required key name: bit 1 for "a", bit 2 for "b".
      // A bit is cleared when its key is seen, so p0 records which required keys are present.
      var p0 = 3
      do {
        val l = in.readKeyAsCharBuf()
        if (in.isCharBufEqualsTo(l, "a")) {
          if ((p0 & 1) != 0) p0 ^= 1
          else in.duplicatedKeyError(l)
        } else if (in.isCharBufEqualsTo(l, "b")) {
          if ((p0 & 2) != 0) p0 ^= 2
          else in.duplicatedKeyError(l)
        }
        in.skip() // skip the value without decoding it
      } while (in.isNextToken(','))
      in.rollbackToMark() // rewind to '{' so the selected codec parses the whole object
      p0 match {
        case 0 => quxCodec.decodeValue(in, quxCodec.nullValue) // both "a" and "b"
        case 1 => bazCodec.decodeValue(in, bazCodec.nullValue) // only "b"
        case 2 => barCodec.decodeValue(in, barCodec.nullValue) // only "a"
        case _ => in.decodeError("missing required field(s)")
      }
    } else in.readNullOrTokenError(default, '{')
  }
  override def encodeValue(x: Foo, out: JsonWriter): Unit = x match {
    case x: Bar => barCodec.encodeValue(x, out)
    case x: Baz => bazCodec.encodeValue(x, out)
    case x: Qux => quxCodec.encodeValue(x, out)
    case null => out.writeNull()
  }
  override val nullValue: Foo = null
}

println(writeToString[Foo](Bar("yo", None)))
println(writeToString[Foo](Baz(None, 12)))
println(writeToString[Foo](Qux(12, Seq(), "yo")))
println(writeToString[Foo](Bar("yo", Some("lo"))))
println(writeToString[Foo](Baz(Some(42), 12)))
println(writeToString[Foo](Qux(12, Seq(1.0, 2.0), "yo")))
println(readFromString[Foo](""" {"a":"yo"} """))
println(readFromString[Foo](""" {"b":12} """))
println(readFromString[Foo](""" {"a":12,"b":"yo"} """))
println(readFromString[Foo](""" {"a":"yo","x":"lo"} """))
println(readFromString[Foo](""" {"y":42,"b":12} """))
println(readFromString[Foo](""" {"a":12,"z":[1.0,2.0],"b":"yo"} """))
```

The expected output is:

```
{"a":"yo"}
{"b":12}
{"a":12,"b":"yo"}
{"a":"yo","x":"lo"}
{"y":42,"b":12}
{"a":12,"z":[1.0,2.0],"b":"yo"}
Bar(yo,None)
Baz(None,12)
Qux(12,List(),yo)
Bar(yo,Some(lo))
Baz(Some(42),12)
Qux(12,List(1.0, 2.0),yo)
```

This solution works fine as long as sub-types don't need to be told apart by keys of optional fields (including fields with collection/array types, which are optional by default), and the set of required keys is different for each sub-type. Some of those ambiguities can be overcome by trying to distinguish sub-types by parsing the values and handling errors instead of skipping them.
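The bit bookkeeping in the codec above can be derived mechanically from the required keys of each sub-type, which also makes the "different set of required keys" condition checkable. A minimal stdlib sketch, assuming the required field names are listed by hand (RequiredKeyMasks, bitFor, remainingMask and dispatchTable are hypothetical names, not jsoniter-scala API); for Bar, Baz and Qux it reproduces the 0/1/2 cases of the match in that codec.

```scala
// Hypothetical helper, not part of jsoniter-scala.
object RequiredKeyMasks {
  // Required field names per sub-type for the Bar/Baz/Qux example, listed by hand
  // (optional and collection fields are left out, as in the codec above).
  val requiredBySubtype: Map[String, Set[String]] = Map(
    "Bar" -> Set("a"),
    "Baz" -> Set("b"),
    "Qux" -> Set("a", "b")
  )

  // Assign one bit per distinct required key name (sorted, so "a" -> 1, "b" -> 2).
  def bitFor(required: Map[String, Set[String]]): Map[String, Int] =
    required.values.flatten.toSeq.distinct.sorted.zipWithIndex
      .map { case (name, i) => name -> (1 << i) }
      .toMap

  // Accumulator value after reading an object whose required keys are `keys`:
  // start with every bit set and clear the bit of each key seen.
  def remainingMask(keys: Set[String], bits: Map[String, Int]): Int = {
    val all = bits.values.foldLeft(0)(_ | _)
    keys.foldLeft(all)((mask, key) => mask & ~bits(key))
  }

  // Table from final accumulator value to sub-type; Left if two sub-types
  // would produce the same value (i.e. have the same required key set).
  def dispatchTable(required: Map[String, Set[String]]): Either[String, Map[Int, String]] = {
    val bits = bitFor(required)
    val byMask = required.toSeq.groupBy { case (_, keys) => remainingMask(keys, bits) }
    val clashes = byMask.values.filter(_.size > 1).map(_.map(_._1).sorted.mkString(", "))
    if (clashes.nonEmpty) Left(s"ambiguous sub-types: ${clashes.mkString("; ")}")
    else Right(byMask.map { case (mask, subtypes) => mask -> subtypes.head._1 })
  }
}
```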

If the number of keys is greater than 8, matching by key hash code can be used for better efficiency (with any hash collisions subsequently resolved by exact comparison).
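Hash-based matching with exact comparison as the tie-breaker can be sketched in plain Scala. This is a minimal illustration, assuming String keys rather than the reader's char buffer, so it only shows the dispatch logic; HashKeyMatch, keys, byHash and bitFor are hypothetical names. "Aa" and "BB" are included because they have the same String.hashCode.

```scala
// Hypothetical helper, not part of jsoniter-scala.
object HashKeyMatch {
  // Key names to recognize and the bit assigned to each; "Aa" and "BB"
  // are example keys chosen because their String.hashCode values collide.
  val keys: Seq[(String, Int)] = Seq("a" -> 1, "b" -> 2, "Aa" -> 4, "BB" -> 8)

  // Bucket the keys by hash code once, so a lookup only compares the few
  // candidates whose hash matches instead of every known key.
  private val byHash: Map[Int, Seq[(String, Int)]] = keys.groupBy(_._1.hashCode)

  // Bit for a key, or 0 for an unknown key; exact string comparison
  // resolves keys that share a hash code.
  def bitFor(key: String): Int =
    byHash.getOrElse(key.hashCode, Nil)
      .collectFirst { case (name, bit) if name == key => bit }
      .getOrElse(0)
}
```

A generated codec could instead switch on precomputed hash constants; the Map here just keeps the sketch short.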

BTW, you can see the code of derived codecs by turning on the -Xmacro-settings:print-codecs option for scalac. Or you can just open the following link to see build logs with the code of the codecs used in benchmarks, including ADTs: https://plokhotnyuk.github.io/jsoniter-scala/openjdk-13.txt
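For an sbt build, that option could be passed as follows (a minimal sketch assuming sbt is the build tool; other build tools pass scalac options differently):

```scala
// build.sbt: ask the jsoniter-scala macros to print the generated codec code during compilation
scalacOptions += "-Xmacro-settings:print-codecs"
```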

So, implementing a withUntaggedAdt option for automatic derivation looks possible, but I'm not sure it can easily be done securely and efficiently for all supported data types.

@carymrobbins
Author

I did something similar, but not quite as efficient:

https://gist.github.com/carymrobbins/d0b900257cadb458b3de9b1b532cb2b3

I played with a few approaches and this one seems to work best. However, I'm not sure exactly how the mark stuff works or whether I did any of this correctly.
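The codecs in this thread call in.setMark() before reading, inspect one or more keys, and then call in.rollbackToMark() so the delegated codec re-reads the object from its opening brace. That idea can be modelled with a toy cursor. Cursor and firstKey are hypothetical and deliberately far simpler than jsoniter-scala's JsonReader; a real JSON reader also has to deal with whitespace, escapes and buffering.

```scala
// Toy model of mark/rollback; hypothetical names, not jsoniter-scala's JsonReader.
object MarkModel {
  final class Cursor(input: String) {
    private var pos = 0
    private var mark = -1

    def next(): Char = { val c = input.charAt(pos); pos += 1; c }
    // Remember the current position so reading can resume from here later.
    def setMark(): Unit = { mark = pos }
    // Return to the remembered position and clear the mark.
    def rollbackToMark(): Unit = {
      if (mark < 0) throw new IllegalStateException("no mark set")
      pos = mark
      mark = -1
    }
  }

  // Peek at the first key of an object, then rewind so a full decoder
  // can re-read the object starting from '{'.
  def firstKey(c: Cursor): String = {
    c.setMark()
    require(c.next() == '{', "expected '{'")
    require(c.next() == '"', "expected '\"'")
    val sb = new StringBuilder
    var ch = c.next()
    while (ch != '"') { sb.append(ch); ch = c.next() }
    c.rollbackToMark()
    sb.toString
  }
}
```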
