diff --git a/src/main/boilerplate/spray/json/ProductFormatsInstances.scala.template b/src/main/boilerplate/spray/json/ProductFormatsInstances.scala.template index 3d29b582..f52b2e47 100644 --- a/src/main/boilerplate/spray/json/ProductFormatsInstances.scala.template +++ b/src/main/boilerplate/spray/json/ProductFormatsInstances.scala.template @@ -16,7 +16,10 @@ package spray.json +import scala.collection.mutable.HashSet + trait ProductFormatsInstances { self: ProductFormats with StandardFormats => + def allowExtraKeys: Boolean = true [# // Case classes with 1 parameters def jsonFormat1[[#P1 :JF#], T <: Product :ClassManifest](construct: ([#P1#]) => T): RootJsonFormat[T] = { @@ -24,6 +27,10 @@ trait ProductFormatsInstances { self: ProductFormats with StandardFormats => jsonFormat(construct, [#p1#]) } def jsonFormat[[#P1 :JF#], T <: Product](construct: ([#P1#]) => T, [#fieldName1: String#]): RootJsonFormat[T] = new RootJsonFormat[T]{ + val knownFields = HashSet[String]() + [# knownFields.add(fieldName1)# + ] + def write(p: T) = { val fields = new collection.mutable.ListBuffer[(String, JsValue)] fields.sizeHint(1 * 2) @@ -34,6 +41,14 @@ trait ProductFormatsInstances { self: ProductFormats with StandardFormats => def read(value: JsValue) = { [#val p1V = fromField[P1](value, fieldName1)# ] + if (!allowExtraKeys) { + val jsObject = value.asJsObject() + val keySet = jsObject.fields.keys.toSet + val keySetDiff = keySet.diff(knownFields) + if (!keySetDiff.isEmpty) { + throw new DeserializationException(s"${keySetDiff.head} is not a known key", null, keySetDiff.toList) + } + } construct([#p1V#]) } }# diff --git a/src/main/scala/spray/json/ProductFormats.scala b/src/main/scala/spray/json/ProductFormats.scala index 7d6c63e2..35c0d28a 100644 --- a/src/main/scala/spray/json/ProductFormats.scala +++ b/src/main/scala/spray/json/ProductFormats.scala @@ -37,7 +37,7 @@ trait ProductFormats extends ProductFormatsInstances { } // helpers - + protected def productElement2Field[T](fieldName: String, p: Product, ix: Int, rest: List[JsField] = Nil) (implicit writer: JsonWriter[T]): List[JsField] = { val value = p.productElement(ix).asInstanceOf[T] @@ -152,3 +152,14 @@ trait NullOptions extends ProductFormats { (fieldName, writer.write(value)) :: rest } } + +/** + * This trait changes the behavior for reading JSON values. + * If you mix in this trait into your custom JsonProtocol, JSON deserialization + * will throw an error if the input contains keys that are not a field of the + * target case class. + */ +trait ExtraKeysOptions extends ProductFormats { + this: StandardFormats => + override def allowExtraKeys = false +} diff --git a/src/test/scala/spray/json/ProductFormatsSpec.scala b/src/test/scala/spray/json/ProductFormatsSpec.scala index 30582a8f..6ad7ea64 100644 --- a/src/test/scala/spray/json/ProductFormatsSpec.scala +++ b/src/test/scala/spray/json/ProductFormatsSpec.scala @@ -44,6 +44,8 @@ class ProductFormatsSpec extends Specification { object TestProtocol1 extends DefaultJsonProtocol with TestProtocol object TestProtocol2 extends DefaultJsonProtocol with TestProtocol with NullOptions + object TestProtocol3 extends DefaultJsonProtocol with TestProtocol with ExtraKeysOptions + "A JsonFormat created with `jsonFormat`, for a case class with 2 elements," should { import TestProtocol1._ val obj = Test2(42, Some(4.2)) @@ -61,6 +63,9 @@ class ProductFormatsSpec extends Specification { "not require the presence of optional fields for deserialization" in { JsObject("a" -> JsNumber(42)).convertTo[Test2] mustEqual Test2(42, None) } + "allow extra fields during deserialization" in { + JsObject("a" -> JsNumber(42), "extra_key" -> JsNumber(43)).convertTo[Test2] mustEqual Test2(42, None) + } "not render `None` members during serialization" in { Test2(42, None).toJson mustEqual JsObject("a" -> JsNumber(42)) } @@ -92,6 +97,15 @@ class ProductFormatsSpec extends Specification { } } + "A JsonProtocol mixing in ExtraKeysOptions" should { + "not allow extra keys to be read" in { + import TestProtocol3._ + JsObject("a" -> JsNumber(42), "extra_key" -> JsNumber(43)).convertTo[Test2] must throwA[DeserializationException].like { + case DeserializationException(_, _, fieldNames) => fieldNames mustEqual "extra_key" :: Nil + } + } + } + "A JsonFormat for a generic case class and created with `jsonFormat`" should { import TestProtocol1._ val obj = Test3(42 :: 43 :: Nil, "x" :: "y" :: "z" :: Nil)