diff --git a/lib/thrift/generator/struct_binary_protocol.ex b/lib/thrift/generator/struct_binary_protocol.ex index 33ef1b94..4cd9fba8 100644 --- a/lib/thrift/generator/struct_binary_protocol.ex +++ b/lib/thrift/generator/struct_binary_protocol.ex @@ -769,50 +769,50 @@ defmodule Thrift.Generator.StructBinaryProtocol do end end - defp field_serializer(%Field{name: name, type: :bool, id: id, required: false}, struct_name, _file_group) do + defp field_serializer(%Field{name: name, type: :bool, id: id, required: true}, struct_name, _file_group) do quote do case unquote(Macro.var(name, nil)) do - nil -> - <<>> false -> <> true -> <> _ -> raise Thrift.InvalidValueException, - unquote("Optional boolean field #{inspect name} on #{inspect struct_name} must be true, false, or nil") + unquote("Required boolean field #{inspect name} on #{inspect struct_name} must be true or false") end end end defp field_serializer(%Field{name: name, type: :bool, id: id}, struct_name, _file_group) do quote do case unquote(Macro.var(name, nil)) do + nil -> + <<>> false -> <> true -> <> _ -> raise Thrift.InvalidValueException, - unquote("Required boolean field #{inspect name} on #{inspect struct_name} must be true or false") + unquote("Optional boolean field #{inspect name} on #{inspect struct_name} must be true, false, or nil") end end end - defp field_serializer(%Field{name: name, required: false} = field, _struct_name, file_group) do + defp field_serializer(%Field{name: name, required: true} = field, struct_name, file_group) do quote do case unquote(Macro.var(name, nil)) do nil -> - <<>> + raise Thrift.InvalidValueException, + unquote("Required field #{inspect name} on #{inspect struct_name} must not be nil") _ -> unquote(required_field_serializer(field, file_group)) end end end - defp field_serializer(%Field{name: name} = field, struct_name, file_group) do + defp field_serializer(%Field{name: name} = field, _struct_name, file_group) do quote do case unquote(Macro.var(name, nil)) do nil -> - raise Thrift.InvalidValueException, - unquote("Required field #{inspect name} on #{inspect struct_name} must not be nil") + <<>> _ -> unquote(required_field_serializer(field, file_group)) end diff --git a/test/generator/binary_protocol_test.exs b/test/generator/binary_protocol_test.exs index 738f70e8..e980b63d 100644 --- a/test/generator/binary_protocol_test.exs +++ b/test/generator/binary_protocol_test.exs @@ -359,9 +359,9 @@ defmodule Thrift.Generator.BinaryProtocolTest do struct RequiredBool { 1: required bool val } struct DefaultRequiredBool { 1: bool val } struct OptionalBool { 1: optional bool val } - struct RequiredField { 1: required string val } - struct DefaultRequiredField { 1: string val } - struct OptionalField { 1: optional string val } + struct RequiredField { 1: required i8 val } + struct DefaultRequiredField { 1: i8 val } + struct OptionalField { 1: optional i8 val } """ thrift_test "required boolean fields must not be nil during serialization" do @@ -371,15 +371,14 @@ defmodule Thrift.Generator.BinaryProtocolTest do end end - thrift_test "default required boolean fields must not be nil during serialization" do - message = "Required boolean field :val on Thrift.Generator.BinaryProtocolTest.DefaultRequiredBool must be true or false" - assert_raise Thrift.InvalidValueException, message, fn -> - DefaultRequiredBool.serialize(%DefaultRequiredBool{}) - end + thrift_test "default required boolean fields may be nil during serialization" do + assert_serializes %DefaultRequiredBool{}, <<0>> + assert_serializes %DefaultRequiredBool{val: true}, <<2, 0, 1, 1, 0>> end - thrift_test "optional boolean fields must not be nil during serialization" do - assert OptionalBool.serialize(%OptionalBool{}) + thrift_test "optional boolean fields may be nil during serialization" do + assert_serializes %OptionalBool{}, <<0>> + assert_serializes %OptionalBool{val: true}, <<2, 0, 1, 1, 0>> end thrift_test "required fields must not be nil during serialization" do @@ -389,14 +388,13 @@ defmodule Thrift.Generator.BinaryProtocolTest do end end - thrift_test "default required fields must not be nil during serialization" do - message = "Required field :val on Thrift.Generator.BinaryProtocolTest.DefaultRequiredField must not be nil" - assert_raise Thrift.InvalidValueException, message, fn -> - DefaultRequiredField.serialize(%DefaultRequiredField{}) - end + thrift_test "default required fields may be nil during serialization" do + assert_serializes %DefaultRequiredField{}, <<0>> + assert_serializes %DefaultRequiredField{val: 123}, <<3, 0, 1, 123, 0>> end - thrift_test "optional fields must not be nil during serialization" do - OptionalField.serialize(%OptionalField{}) + thrift_test "optional fields may be nil during serialization" do + assert_serializes %OptionalField{}, <<0>> + assert_serializes %OptionalField{val: 123}, <<3, 0, 1, 123, 0>> end end