From e35e184262e369311635f47bbf8cb88036e669c5 Mon Sep 17 00:00:00 2001 From: Attila Szegedi Date: Thu, 22 Sep 2022 18:05:00 +0000 Subject: [PATCH] scrooge: Allow null values for nullable fields in StructBuilder Problem ------- `c.t.scrooge.StructBuilder` methods `setField` and `setAllFields` don't accept `null` as a valid value for nullable fields, making it impossible to use `StructBuilder` to construct otherwise valid instances that have `null` as a value of one or more such fields. Solution -------- Enhance `StructBuilder` so it allows setting `null` values for nullable fields. JIRA Issues: CSL-12312 Differential Revision: https://phabricator.twitter.biz/D976737 --- CHANGELOG.rst | 2 + .../com/twitter/scrooge/StructBuilder.scala | 41 +++++++++++++- .../gold/thriftscala/AnotherException.scala | 10 ++-- .../test/gold/thriftscala/CollectionId.scala | 10 ++-- .../test/gold/thriftscala/GoldService.scala | 56 +++++++++++-------- .../thriftscala/OverCapacityException.scala | 10 ++-- .../gold/thriftscala/PlatinumService.scala | 24 ++++---- .../test/gold/thriftscala/Recursive.scala | 10 ++-- .../test/gold/thriftscala/Request.scala | 30 +++++----- .../gold/thriftscala/RequestException.scala | 14 +++-- .../test/gold/thriftscala/Response.scala | 14 +++-- .../scrooge/backend/ScalaGeneratorSpec.scala | 31 ++++++++++ .../main/resources/scalagen/struct.mustache | 14 +++-- .../twitter/scrooge/backend/Generator.scala | 21 ++++--- .../scrooge/backend/StructTemplate.scala | 3 +- 15 files changed, 197 insertions(+), 93 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c826b7f2..362f7e95 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -39,6 +39,8 @@ Runtime Behavior Changes * scrooge: Use the util-mock for testing which depends on the 3.12.4 "org.mockito", remove the "org.jmock" dependencies. ``PHAB_ID=D973946`` +* scrooge: `c.t.scrooge.StructBuilder` now accepts `null` values for non-primitive default-required fields. ``PHAB_ID=D976737`` + 22.7.0 ------ diff --git a/scrooge-core/src/main/scala/com/twitter/scrooge/StructBuilder.scala b/scrooge-core/src/main/scala/com/twitter/scrooge/StructBuilder.scala index 7e900cdf..41a6c923 100644 --- a/scrooge-core/src/main/scala/com/twitter/scrooge/StructBuilder.scala +++ b/scrooge-core/src/main/scala/com/twitter/scrooge/StructBuilder.scala @@ -1,6 +1,7 @@ package com.twitter.scrooge import com.twitter.util.Memoize +import org.apache.thrift.protocol.TType import scala.reflect.ClassTag /** @@ -8,9 +9,12 @@ import scala.reflect.ClassTag * ThriftStruct or statically from T. * * We pass in a list of [[ClassTag]]s which describe each of the struct's field types - * so that we can validate the values we are setting at runtime. + * so that we can validate the values we are setting at runtime. We also pass in a set + * of indices of non-primitive default-required fields; those can have null as valid value. */ -abstract class StructBuilder[T <: ThriftStruct](fieldTypes: IndexedSeq[ClassTag[_]]) { +abstract class StructBuilder[T <: ThriftStruct]( + fieldTypes: IndexedSeq[ClassTag[_]], + nullableIndices: Set[Int]) { protected val fieldArray: Array[Any] = new Array[Any](fieldTypes.size) /** @@ -23,6 +27,7 @@ abstract class StructBuilder[T <: ThriftStruct](fieldTypes: IndexedSeq[ClassTag[ private[this] def addOrUpdateFieldArray[A](index: Int, v: Any)(implicit tag: ClassTag[A]): Unit = v match { case inputValue: A => fieldArray(index) = inputValue + case null if nullableIndices(index) => fieldArray(index) = StructBuilder.ExplicitNull case _ => throw new IllegalArgumentException(s"value at index $index must be of type $tag") } @@ -84,12 +89,44 @@ abstract class StructBuilder[T <: ThriftStruct](fieldTypes: IndexedSeq[ClassTag[ * This object provides operations to obtain `StructBuilder` instances. */ object StructBuilder { + // Used as an explicitly set null value for default-required fields + private[StructBuilder] object ExplicitNull + + def unwrapExplicitNull(v: Any): Any = + v match { + case ExplicitNull => null + case _ => v + } + private[this] val memoizeBuilderMethod: Class[_] => () => Any = Memoize.classValue { clazz => val thriftCodec = ThriftStructCodec.forStructClass(clazz.asSubclass(classOf[ThriftStruct])) val m = thriftCodec.getClass.getMethod("newBuilder") () => m.invoke(thriftCodec) } + /** + * Given a sequence of struct field information, returns a set of indices of non-primitive + * default-required fields. + * @param fieldInfos a sequence of field informations + * @return a set of indices for non-primitive default-required fields + */ + final def nullableIndices(fieldInfos: Seq[ThriftStructFieldInfo]): Set[Int] = { + fieldInfos.zipWithIndex.collect { + case (f, i) if isNullable(f) => i + }.toSet + } + + private[this] def isNullable(f: ThriftStructFieldInfo): Boolean = { + if (f.isOptional || f.isRequired) { + false + } else { + f.tfield.`type` match { + case TType.BOOL | TType.BYTE | TType.I16 | TType.I32 | TType.I64 | TType.DOUBLE => false + case _ => true + } + } + } + /** * For a given scrooge-generated thrift struct or union class, returns its StructBuilder. * This can be used for building a new ThriftStruct object. diff --git a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/AnotherException.scala b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/AnotherException.scala index f82fdb3c..5fe9c532 100644 --- a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/AnotherException.scala +++ b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/AnotherException.scala @@ -51,6 +51,8 @@ object AnotherException extends ValidatingThriftStructCodec3[AnotherException] w ) + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + val structAnnotations: immutable$Map[String, String] = immutable$Map.empty[String, String] @@ -111,7 +113,7 @@ object AnotherException extends ValidatingThriftStructCodec3[AnotherException] w ) } - def newBuilder(): StructBuilder[AnotherException] = new AnotherExceptionStructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[AnotherException] = new AnotherExceptionStructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: AnotherException, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -348,11 +350,11 @@ class AnotherException( flags ) - def newBuilder(): StructBuilder[AnotherException] = new AnotherExceptionStructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[AnotherException] = new AnotherExceptionStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } -private[thriftscala] class AnotherExceptionStructBuilder(instance: _root_.scala.Option[AnotherException], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[AnotherException](fieldTypes) { +private[thriftscala] class AnotherExceptionStructBuilder(instance: _root_.scala.Option[AnotherException], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[AnotherException](fieldTypes, nullableIndices) { def build(): AnotherException = { val _fieldArray = fieldArray // shadow variable diff --git a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/CollectionId.scala b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/CollectionId.scala index e1cb90b0..f5914715 100644 --- a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/CollectionId.scala +++ b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/CollectionId.scala @@ -54,6 +54,8 @@ object CollectionId extends ValidatingThriftStructCodec3[CollectionId] with Stru ) + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + val structAnnotations: immutable$Map[String, String] = immutable$Map.empty[String, String] @@ -114,7 +116,7 @@ object CollectionId extends ValidatingThriftStructCodec3[CollectionId] with Stru ) } - def newBuilder(): StructBuilder[CollectionId] = new CollectionIdStructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[CollectionId] = new CollectionIdStructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: CollectionId, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -444,11 +446,11 @@ trait CollectionId def _codec: ValidatingThriftStructCodec3[CollectionId] = CollectionId - def newBuilder(): StructBuilder[CollectionId] = new CollectionIdStructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[CollectionId] = new CollectionIdStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } -private[thriftscala] class CollectionIdStructBuilder(instance: _root_.scala.Option[CollectionId], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[CollectionId](fieldTypes) { +private[thriftscala] class CollectionIdStructBuilder(instance: _root_.scala.Option[CollectionId], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[CollectionId](fieldTypes, nullableIndices) { def build(): CollectionId = { val _fieldArray = fieldArray // shadow variable diff --git a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/GoldService.scala b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/GoldService.scala index 46b881a4..4f7e0c2a 100644 --- a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/GoldService.scala +++ b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/GoldService.scala @@ -329,6 +329,8 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ ) + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + val structAnnotations: immutable$Map[String, String] = immutable$Map.empty[String, String] @@ -425,7 +427,7 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ ) } - def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: Args, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -589,27 +591,27 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ def _codec: ValidatingThriftStructCodec3[Args] = Args - def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } - private[thriftscala] class ArgsStructBuilder(instance: _root_.scala.Option[Args], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[Args](fieldTypes) { + private[thriftscala] class ArgsStructBuilder(instance: _root_.scala.Option[Args], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[Args](fieldTypes, nullableIndices) { def build(): Args = { val _fieldArray = fieldArray // shadow variable if (instance.isDefined) { val instanceValue = instance.get Args( - if (_fieldArray(0) == null) instanceValue.request else _fieldArray(0).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request], - if (_fieldArray(1) == null) instanceValue.unionRequest else _fieldArray(1).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestUnion], - if (_fieldArray(2) == null) instanceValue.exceptionRequest else _fieldArray(2).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestException] + if (_fieldArray(0) == null) instanceValue.request else StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request], + if (_fieldArray(1) == null) instanceValue.unionRequest else StructBuilder.unwrapExplicitNull(_fieldArray(1)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestUnion], + if (_fieldArray(2) == null) instanceValue.exceptionRequest else StructBuilder.unwrapExplicitNull(_fieldArray(2)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestException] ) } else { if (genericArrayOps(_fieldArray).contains(null)) throw new InvalidFieldsException(structBuildError("Args")) Args( - _fieldArray(0).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request], - _fieldArray(1).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestUnion], - _fieldArray(2).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestException] + StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request], + StructBuilder.unwrapExplicitNull(_fieldArray(1)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestUnion], + StructBuilder.unwrapExplicitNull(_fieldArray(2)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.RequestException] ) } } @@ -655,6 +657,8 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ ) + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + val structAnnotations: immutable$Map[String, String] = immutable$Map.empty[String, String] @@ -739,7 +743,7 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ ) } - def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: Result, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -886,11 +890,11 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ def _codec: ValidatingThriftStructCodec3[Result] = Result - def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } - private[thriftscala] class ResultStructBuilder(instance: _root_.scala.Option[Result], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[Result](fieldTypes) { + private[thriftscala] class ResultStructBuilder(instance: _root_.scala.Option[Result], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[Result](fieldTypes, nullableIndices) { def build(): Result = { val _fieldArray = fieldArray // shadow variable @@ -971,6 +975,8 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ ) + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + val structAnnotations: immutable$Map[String, String] = immutable$Map.empty[String, String] @@ -1035,7 +1041,7 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ ) } - def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: Args, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -1156,23 +1162,23 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ def _codec: ValidatingThriftStructCodec3[Args] = Args - def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } - private[thriftscala] class ArgsStructBuilder(instance: _root_.scala.Option[Args], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[Args](fieldTypes) { + private[thriftscala] class ArgsStructBuilder(instance: _root_.scala.Option[Args], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[Args](fieldTypes, nullableIndices) { def build(): Args = { val _fieldArray = fieldArray // shadow variable if (instance.isDefined) { val instanceValue = instance.get Args( - if (_fieldArray(0) == null) instanceValue.request else _fieldArray(0).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request] + if (_fieldArray(0) == null) instanceValue.request else StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request] ) } else { if (genericArrayOps(_fieldArray).contains(null)) throw new InvalidFieldsException(structBuildError("Args")) Args( - _fieldArray(0).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request] + StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request] ) } } @@ -1204,6 +1210,8 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ ) + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + val structAnnotations: immutable$Map[String, String] = immutable$Map.empty[String, String] @@ -1270,7 +1278,7 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ ) } - def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: Result, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -1393,11 +1401,11 @@ object GoldService extends _root_.com.twitter.finagle.thrift.GeneratedThriftServ def _codec: ValidatingThriftStructCodec3[Result] = Result - def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } - private[thriftscala] class ResultStructBuilder(instance: _root_.scala.Option[Result], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[Result](fieldTypes) { + private[thriftscala] class ResultStructBuilder(instance: _root_.scala.Option[Result], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[Result](fieldTypes, nullableIndices) { def build(): Result = { val _fieldArray = fieldArray // shadow variable diff --git a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/OverCapacityException.scala b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/OverCapacityException.scala index 91b8bfbc..b50b2956 100644 --- a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/OverCapacityException.scala +++ b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/OverCapacityException.scala @@ -53,6 +53,8 @@ object OverCapacityException extends ValidatingThriftStructCodec3[OverCapacityEx ) + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + lazy val structAnnotations: immutable$Map[String, String] = immutable$Map[String, String]( ("e.annotation", "true") @@ -115,7 +117,7 @@ object OverCapacityException extends ValidatingThriftStructCodec3[OverCapacityEx ) } - def newBuilder(): StructBuilder[OverCapacityException] = new OverCapacityExceptionStructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[OverCapacityException] = new OverCapacityExceptionStructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: OverCapacityException, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -352,11 +354,11 @@ class OverCapacityException( flags ) - def newBuilder(): StructBuilder[OverCapacityException] = new OverCapacityExceptionStructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[OverCapacityException] = new OverCapacityExceptionStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } -private[thriftscala] class OverCapacityExceptionStructBuilder(instance: _root_.scala.Option[OverCapacityException], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[OverCapacityException](fieldTypes) { +private[thriftscala] class OverCapacityExceptionStructBuilder(instance: _root_.scala.Option[OverCapacityException], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[OverCapacityException](fieldTypes, nullableIndices) { def build(): OverCapacityException = { val _fieldArray = fieldArray // shadow variable diff --git a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/PlatinumService.scala b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/PlatinumService.scala index d21acba4..ba01dbb2 100644 --- a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/PlatinumService.scala +++ b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/PlatinumService.scala @@ -336,6 +336,8 @@ object PlatinumService extends _root_.com.twitter.finagle.thrift.GeneratedThrift ) + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + val structAnnotations: immutable$Map[String, String] = immutable$Map.empty[String, String] @@ -400,7 +402,7 @@ object PlatinumService extends _root_.com.twitter.finagle.thrift.GeneratedThrift ) } - def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: Args, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -521,23 +523,23 @@ object PlatinumService extends _root_.com.twitter.finagle.thrift.GeneratedThrift def _codec: ValidatingThriftStructCodec3[Args] = Args - def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[Args] = new ArgsStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } - private[thriftscala] class ArgsStructBuilder(instance: _root_.scala.Option[Args], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[Args](fieldTypes) { + private[thriftscala] class ArgsStructBuilder(instance: _root_.scala.Option[Args], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[Args](fieldTypes, nullableIndices) { def build(): Args = { val _fieldArray = fieldArray // shadow variable if (instance.isDefined) { val instanceValue = instance.get Args( - if (_fieldArray(0) == null) instanceValue.request else _fieldArray(0).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request] + if (_fieldArray(0) == null) instanceValue.request else StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request] ) } else { if (genericArrayOps(_fieldArray).contains(null)) throw new InvalidFieldsException(structBuildError("Args")) Args( - _fieldArray(0).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request] + StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.Request] ) } } @@ -597,6 +599,8 @@ object PlatinumService extends _root_.com.twitter.finagle.thrift.GeneratedThrift ) + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + val structAnnotations: immutable$Map[String, String] = immutable$Map.empty[String, String] @@ -693,7 +697,7 @@ object PlatinumService extends _root_.com.twitter.finagle.thrift.GeneratedThrift ) } - def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: Result, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -859,11 +863,11 @@ object PlatinumService extends _root_.com.twitter.finagle.thrift.GeneratedThrift def _codec: ValidatingThriftStructCodec3[Result] = Result - def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[Result] = new ResultStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } - private[thriftscala] class ResultStructBuilder(instance: _root_.scala.Option[Result], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[Result](fieldTypes) { + private[thriftscala] class ResultStructBuilder(instance: _root_.scala.Option[Result], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[Result](fieldTypes, nullableIndices) { def build(): Result = { val _fieldArray = fieldArray // shadow variable diff --git a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/Recursive.scala b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/Recursive.scala index 913f9eec..8dbb2dad 100644 --- a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/Recursive.scala +++ b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/Recursive.scala @@ -68,6 +68,8 @@ object Recursive extends ValidatingThriftStructCodec3[Recursive] with StructBuil ) + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + val structAnnotations: immutable$Map[String, String] = immutable$Map.empty[String, String] @@ -146,7 +148,7 @@ object Recursive extends ValidatingThriftStructCodec3[Recursive] with StructBuil ) } - def newBuilder(): StructBuilder[Recursive] = new RecursiveStructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[Recursive] = new RecursiveStructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: Recursive, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -522,11 +524,11 @@ trait Recursive def _codec: ValidatingThriftStructCodec3[Recursive] = Recursive - def newBuilder(): StructBuilder[Recursive] = new RecursiveStructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[Recursive] = new RecursiveStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } -private[thriftscala] class RecursiveStructBuilder(instance: _root_.scala.Option[Recursive], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[Recursive](fieldTypes) { +private[thriftscala] class RecursiveStructBuilder(instance: _root_.scala.Option[Recursive], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[Recursive](fieldTypes, nullableIndices) { def build(): Recursive = { val _fieldArray = fieldArray // shadow variable diff --git a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/Request.scala b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/Request.scala index 837b5410..80867039 100644 --- a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/Request.scala +++ b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/Request.scala @@ -275,6 +275,8 @@ object Request extends ValidatingThriftStructCodec3[Request] with StructBuilderF ) + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + lazy val structAnnotations: immutable$Map[String, String] = immutable$Map[String, String]( ("s.annotation.one", "something"), @@ -544,7 +546,7 @@ object Request extends ValidatingThriftStructCodec3[Request] with StructBuilderF ) } - def newBuilder(): StructBuilder[Request] = new RequestStructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[Request] = new RequestStructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: Request, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -1799,23 +1801,23 @@ trait Request def _codec: ValidatingThriftStructCodec3[Request] = Request - def newBuilder(): StructBuilder[Request] = new RequestStructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[Request] = new RequestStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } -private[thriftscala] class RequestStructBuilder(instance: _root_.scala.Option[Request], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[Request](fieldTypes) { +private[thriftscala] class RequestStructBuilder(instance: _root_.scala.Option[Request], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[Request](fieldTypes, nullableIndices) { def build(): Request = { val _fieldArray = fieldArray // shadow variable if (instance.isDefined) { val instanceValue = instance.get Request( - if (_fieldArray(0) == null) instanceValue.aList else _fieldArray(0).asInstanceOf[_root_.scala.collection.Seq[String]], - if (_fieldArray(1) == null) instanceValue.aSet else _fieldArray(1).asInstanceOf[_root_.scala.collection.Set[Int]], - if (_fieldArray(2) == null) instanceValue.aMap else _fieldArray(2).asInstanceOf[_root_.scala.collection.Map[Long, Long]], + if (_fieldArray(0) == null) instanceValue.aList else StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[_root_.scala.collection.Seq[String]], + if (_fieldArray(1) == null) instanceValue.aSet else StructBuilder.unwrapExplicitNull(_fieldArray(1)).asInstanceOf[_root_.scala.collection.Set[Int]], + if (_fieldArray(2) == null) instanceValue.aMap else StructBuilder.unwrapExplicitNull(_fieldArray(2)).asInstanceOf[_root_.scala.collection.Map[Long, Long]], if (_fieldArray(3) == null) instanceValue.aRequest else _fieldArray(3).asInstanceOf[_root_.scala.Option[com.twitter.scrooge.test.gold.thriftscala.Request]], - if (_fieldArray(4) == null) instanceValue.subRequests else _fieldArray(4).asInstanceOf[_root_.scala.collection.Seq[com.twitter.scrooge.test.gold.thriftscala.Request]], - if (_fieldArray(5) == null) instanceValue._default else _fieldArray(5).asInstanceOf[String], + if (_fieldArray(4) == null) instanceValue.subRequests else StructBuilder.unwrapExplicitNull(_fieldArray(4)).asInstanceOf[_root_.scala.collection.Seq[com.twitter.scrooge.test.gold.thriftscala.Request]], + if (_fieldArray(5) == null) instanceValue._default else StructBuilder.unwrapExplicitNull(_fieldArray(5)).asInstanceOf[String], if (_fieldArray(6) == null) instanceValue.noComment else _fieldArray(6).asInstanceOf[_root_.scala.Option[Long]], if (_fieldArray(7) == null) instanceValue.doubleSlashComment else _fieldArray(7).asInstanceOf[_root_.scala.Option[Long]], if (_fieldArray(8) == null) instanceValue.hashtagComment else _fieldArray(8).asInstanceOf[_root_.scala.Option[Long]], @@ -1830,12 +1832,12 @@ private[thriftscala] class RequestStructBuilder(instance: _root_.scala.Option[Re } else { if (genericArrayOps(_fieldArray).contains(null)) throw new InvalidFieldsException(structBuildError("Request")) Request( - _fieldArray(0).asInstanceOf[_root_.scala.collection.Seq[String]], - _fieldArray(1).asInstanceOf[_root_.scala.collection.Set[Int]], - _fieldArray(2).asInstanceOf[_root_.scala.collection.Map[Long, Long]], + StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[_root_.scala.collection.Seq[String]], + StructBuilder.unwrapExplicitNull(_fieldArray(1)).asInstanceOf[_root_.scala.collection.Set[Int]], + StructBuilder.unwrapExplicitNull(_fieldArray(2)).asInstanceOf[_root_.scala.collection.Map[Long, Long]], _fieldArray(3).asInstanceOf[_root_.scala.Option[com.twitter.scrooge.test.gold.thriftscala.Request]], - _fieldArray(4).asInstanceOf[_root_.scala.collection.Seq[com.twitter.scrooge.test.gold.thriftscala.Request]], - _fieldArray(5).asInstanceOf[String], + StructBuilder.unwrapExplicitNull(_fieldArray(4)).asInstanceOf[_root_.scala.collection.Seq[com.twitter.scrooge.test.gold.thriftscala.Request]], + StructBuilder.unwrapExplicitNull(_fieldArray(5)).asInstanceOf[String], _fieldArray(6).asInstanceOf[_root_.scala.Option[Long]], _fieldArray(7).asInstanceOf[_root_.scala.Option[Long]], _fieldArray(8).asInstanceOf[_root_.scala.Option[Long]], diff --git a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/RequestException.scala b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/RequestException.scala index 968a1594..8c37b10b 100644 --- a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/RequestException.scala +++ b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/RequestException.scala @@ -53,6 +53,8 @@ object RequestException extends ValidatingThriftStructCodec3[RequestException] w ) + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + val structAnnotations: immutable$Map[String, String] = immutable$Map.empty[String, String] @@ -113,7 +115,7 @@ object RequestException extends ValidatingThriftStructCodec3[RequestException] w ) } - def newBuilder(): StructBuilder[RequestException] = new RequestExceptionStructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[RequestException] = new RequestExceptionStructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: RequestException, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -358,23 +360,23 @@ class RequestException( flags ) - def newBuilder(): StructBuilder[RequestException] = new RequestExceptionStructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[RequestException] = new RequestExceptionStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } -private[thriftscala] class RequestExceptionStructBuilder(instance: _root_.scala.Option[RequestException], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[RequestException](fieldTypes) { +private[thriftscala] class RequestExceptionStructBuilder(instance: _root_.scala.Option[RequestException], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[RequestException](fieldTypes, nullableIndices) { def build(): RequestException = { val _fieldArray = fieldArray // shadow variable if (instance.isDefined) { val instanceValue = instance.get RequestException( - if (_fieldArray(0) == null) instanceValue.message else _fieldArray(0).asInstanceOf[String] + if (_fieldArray(0) == null) instanceValue.message else StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[String] ) } else { if (genericArrayOps(_fieldArray).contains(null)) throw new InvalidFieldsException(structBuildError("RequestException")) RequestException( - _fieldArray(0).asInstanceOf[String] + StructBuilder.unwrapExplicitNull(_fieldArray(0)).asInstanceOf[String] ) } } diff --git a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/Response.scala b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/Response.scala index 94f4d884..0414a46f 100644 --- a/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/Response.scala +++ b/scrooge-generator-tests/src/test/resources/gold_file_output_scala/com/twitter/scrooge/test/gold/thriftscala/Response.scala @@ -68,6 +68,8 @@ object Response extends ValidatingThriftStructCodec3[Response] with StructBuilde ) + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + lazy val structAnnotations: immutable$Map[String, String] = immutable$Map[String, String]( ("com.twitter.scrooge.scala.generateStructProxy", "true") @@ -146,7 +148,7 @@ object Response extends ValidatingThriftStructCodec3[Response] with StructBuilde ) } - def newBuilder(): StructBuilder[Response] = new ResponseStructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[Response] = new ResponseStructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: Response, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -533,11 +535,11 @@ trait Response def _codec: ValidatingThriftStructCodec3[Response] = Response - def newBuilder(): StructBuilder[Response] = new ResponseStructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[Response] = new ResponseStructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } -private[thriftscala] class ResponseStructBuilder(instance: _root_.scala.Option[Response], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[Response](fieldTypes) { +private[thriftscala] class ResponseStructBuilder(instance: _root_.scala.Option[Response], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[Response](fieldTypes, nullableIndices) { def build(): Response = { val _fieldArray = fieldArray // shadow variable @@ -545,13 +547,13 @@ private[thriftscala] class ResponseStructBuilder(instance: _root_.scala.Option[R val instanceValue = instance.get Response( if (_fieldArray(0) == null) instanceValue.statusCode else _fieldArray(0).asInstanceOf[Int], - if (_fieldArray(1) == null) instanceValue.responseUnion else _fieldArray(1).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.ResponseUnion] + if (_fieldArray(1) == null) instanceValue.responseUnion else StructBuilder.unwrapExplicitNull(_fieldArray(1)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.ResponseUnion] ) } else { if (genericArrayOps(_fieldArray).contains(null)) throw new InvalidFieldsException(structBuildError("Response")) Response( _fieldArray(0).asInstanceOf[Int], - _fieldArray(1).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.ResponseUnion] + StructBuilder.unwrapExplicitNull(_fieldArray(1)).asInstanceOf[com.twitter.scrooge.test.gold.thriftscala.ResponseUnion] ) } } diff --git a/scrooge-generator-tests/src/test/scala/com/twitter/scrooge/backend/ScalaGeneratorSpec.scala b/scrooge-generator-tests/src/test/scala/com/twitter/scrooge/backend/ScalaGeneratorSpec.scala index fc9f7668..85f74d24 100644 --- a/scrooge-generator-tests/src/test/scala/com/twitter/scrooge/backend/ScalaGeneratorSpec.scala +++ b/scrooge-generator-tests/src/test/scala/com/twitter/scrooge/backend/ScalaGeneratorSpec.scala @@ -20,6 +20,7 @@ import com.twitter.scrooge.validation.MissingConstructionRequiredField import com.twitter.scrooge.validation.MissingRequiredField import com.twitter.util.Await import com.twitter.util.Future +import includes.a.thriftscala.CityState import inheritance.aaa.Aaa import inheritance.aaa.Box import inheritance.bbb.Bbb @@ -1626,6 +1627,36 @@ class ScalaGeneratorSpec extends JMockSpec with EvalHelper { staticBuilder.build() } } + + "create new struct with nullable fields" in { _ => + val builder = CityState.newBuilder().setField(0, null).setField(1, null) + assert(builder.build() == CityState(null, null)) + } + + "throw an exception if nullable fields aren't explicitly set" in { _ => + val builder = CityState.newBuilder().setField(0, null) + + intercept[InvalidFieldsException] { + // Not all fields were set + builder.build() + } + } + + "refuse null in non-nullable fields" in { _ => + val boxBuilder = Box.newBuilder() + boxBuilder.setField(0, 42) // can set it to an int value + intercept[IllegalArgumentException] { + // primitive fields aren't nullable + boxBuilder.setField(0, null) + } + + val ooeBuilder = OneOfEachOptional.newBuilder() + ooeBuilder.setField(6, Some("foo")) // can set it to option-of-string + intercept[IllegalArgumentException] { + // optional non-primitive fields aren't nullable either + ooeBuilder.setField(6, null) + } + } } } diff --git a/scrooge-generator/src/main/resources/scalagen/struct.mustache b/scrooge-generator/src/main/resources/scalagen/struct.mustache index 707f347b..af045c1e 100644 --- a/scrooge-generator/src/main/resources/scalagen/struct.mustache +++ b/scrooge-generator/src/main/resources/scalagen/struct.mustache @@ -95,6 +95,8 @@ object {{StructName}} extends ValidatingThriftStructCodec3[{{StructName}}] with ) {{/hasFields}} + lazy val nullableIndices: Set[Int] = StructBuilder.nullableIndices(fieldInfos) + {{#structAnnotations}}lazy {{/structAnnotations}}val structAnnotations: immutable$Map[String, String] = {{#structAnnotations}} immutable$Map[String, String]( @@ -211,7 +213,7 @@ object {{StructName}} extends ValidatingThriftStructCodec3[{{StructName}}] with ) } - def newBuilder(): StructBuilder[{{StructName}}] = new {{StructName}}StructBuilder(_root_.scala.None, fieldTypes) + def newBuilder(): StructBuilder[{{StructName}}] = new {{StructName}}StructBuilder(_root_.scala.None, fieldTypes, nullableIndices) override def encode(_item: {{StructName}}, _oproto: TProtocol): Unit = { _item.write(_oproto) @@ -931,11 +933,11 @@ class {{StructName}}( ) {{/hasFailureFlags}} - def newBuilder(): StructBuilder[{{StructName}}] = new {{StructName}}StructBuilder(_root_.scala.Some(this), fieldTypes) + def newBuilder(): StructBuilder[{{StructName}}] = new {{StructName}}StructBuilder(_root_.scala.Some(this), fieldTypes, nullableIndices) } -private[{{packageName}}] class {{StructName}}StructBuilder(instance: _root_.scala.Option[{{StructName}}], fieldTypes: IndexedSeq[ClassTag[_]]) - extends StructBuilder[{{StructName}}](fieldTypes) { +private[{{packageName}}] class {{StructName}}StructBuilder(instance: _root_.scala.Option[{{StructName}}], fieldTypes: IndexedSeq[ClassTag[_]], nullableIndices: Set[Int]) + extends StructBuilder[{{StructName}}](fieldTypes, nullableIndices) { def build(): {{StructName}} = { {{#hasFields}} @@ -944,14 +946,14 @@ private[{{packageName}}] class {{StructName}}StructBuilder(instance: _root_.scal val instanceValue = instance.get {{StructName}}( {{#fields}} - if (_fieldArray({{index}}) == null) instanceValue.{{fieldName}}{{#constructionRequired}}.get{{/constructionRequired}} else _fieldArray({{index}}).asInstanceOf[{{>constructionOptionalType}}] + if (_fieldArray({{index}}) == null) instanceValue.{{fieldName}}{{#constructionRequired}}.get{{/constructionRequired}} else {{#nullValid}}StructBuilder.unwrapExplicitNull({{/nullValid}}_fieldArray({{index}}){{#nullValid}}){{/nullValid}}.asInstanceOf[{{>constructionOptionalType}}] {{/fields|,}} ) } else { if (genericArrayOps(_fieldArray).contains(null)) throw new InvalidFieldsException(structBuildError("{{StructName}}")) {{StructName}}( {{#fields}} - _fieldArray({{index}}).asInstanceOf[{{>constructionOptionalType}}] + {{#nullValid}}StructBuilder.unwrapExplicitNull({{/nullValid}}_fieldArray({{index}}){{#nullValid}}){{/nullValid}}.asInstanceOf[{{>constructionOptionalType}}] {{/fields|,}} ) } diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/Generator.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/Generator.scala index 45a96c1f..84e5eb4f 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/Generator.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/Generator.scala @@ -270,14 +270,19 @@ abstract class TemplateGenerator(val resolvedDoc: ResolvedDocument) def quote(str: String): String = "\"" + str + "\"" def quoteKeyword(str: String): String - def isNullableType(t: FieldType, isOptional: Boolean = false): Boolean = { - !isOptional && ( - t match { - case TBool | TByte | TI16 | TI32 | TI64 | TDouble => false - case _ => true - } - ) - } + + def isNullableType(f: Field): Boolean = + !(f.requiredness.isOptional || isPrimitiveField(f)) + + // True if null is a valid value for the field + def isNullValid(f: Field): Boolean = + f.requiredness.isDefault && !isPrimitiveField(f) + + private[this] def isPrimitiveField(f: Field): Boolean = + f.fieldType match { + case TBool | TByte | TI16 | TI32 | TI64 | TDouble => true + case _ => false + } def getServiceParentID(parent: ServiceParent): Identifier = { val identifier: Identifier = parent.filename match { diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/StructTemplate.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/StructTemplate.scala index c64406f2..09d93b92 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/StructTemplate.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/StructTemplate.scala @@ -442,7 +442,8 @@ trait StructTemplate { self: TemplateGenerator => ), "required" -> v(field.requiredness.isRequired), "optional" -> v(field.requiredness.isOptional), - "nullable" -> v(isNullableType(field.fieldType, field.requiredness.isOptional)), + "nullable" -> v(isNullableType(field)), + "nullValid" -> v(isNullValid(field)), "constructionOptional" -> v( !isConstructionRequiredField(field) && field.requiredness.isOptional ),