diff --git a/lib/thrift/generator/struct_binary_protocol.ex b/lib/thrift/generator/struct_binary_protocol.ex index a1b60f0a..4f9f15da 100644 --- a/lib/thrift/generator/struct_binary_protocol.ex +++ b/lib/thrift/generator/struct_binary_protocol.ex @@ -94,6 +94,7 @@ defmodule Thrift.Generator.StructBinaryProtocol do {acc, rest} end unquote_splicing(field_deserializers) + defp deserialize(_, _), do: :error end end @@ -154,8 +155,12 @@ defmodule Thrift.Generator.StructBinaryProtocol do def field_deserializer(:string, field, name, _file_group) do quote do defp unquote(name)(<<11, unquote(field.id)::16-signed, string_size::32-signed, rest::binary>>, acc) do - <> = rest - unquote(name)(rest, %{acc | unquote(field.name) => value}) + case rest do + <> -> + unquote(name)(rest, %{acc | unquote(field.name) => value}) + _ -> + :error + end end end end @@ -163,8 +168,12 @@ defmodule Thrift.Generator.StructBinaryProtocol do dest_module = FileGroup.dest_module(file_group, struct) quote do defp unquote(name)(<<12, unquote(field.id)::16-signed, rest::binary>>, acc) do - {value, rest} = unquote(dest_module).BinaryProtocol.deserialize(rest) - unquote(name)(rest, %{acc | unquote(field.name) => value}) + case unquote(dest_module).BinaryProtocol.deserialize(rest) do + {value, rest} -> + unquote(name)(rest, %{acc | unquote(field.name) => value}) + :error -> + :error + end end end end @@ -185,6 +194,8 @@ defmodule Thrift.Generator.StructBinaryProtocol do end unquote(map_key_deserializer(key_type, key_name, value_name, file_group)) unquote(map_value_deserializer(value_type, key_name, value_name, file_group)) + defp unquote(key_name)(_, _), do: :error + defp unquote(value_name)(_, _, _), do: :error end end def field_deserializer({:set, element_type}, field, name, file_group) do @@ -197,6 +208,7 @@ defmodule Thrift.Generator.StructBinaryProtocol do unquote(name)(rest, %{struct | unquote(field.name) => MapSet.new(Enum.reverse(list))}) end unquote(list_deserializer(element_type, sub_name, file_group)) + defp unquote(sub_name)(_, _), do: :error end end def field_deserializer({:list, element_type}, field, name, file_group) do @@ -209,6 +221,7 @@ defmodule Thrift.Generator.StructBinaryProtocol do unquote(name)(rest, %{struct | unquote(field.name) => Enum.reverse(list)}) end unquote(list_deserializer(element_type, sub_name, file_group)) + defp unquote(sub_name)(_, _), do: :error end end def field_deserializer(%StructRef{referenced_type: type}, field, name, file_group) do @@ -271,8 +284,12 @@ defmodule Thrift.Generator.StructBinaryProtocol do def map_key_deserializer(:string, key_name, value_name, _file_group) do quote do defp unquote(key_name)(<>, stack) do - <> = rest - unquote(value_name)(rest, key, stack) + case rest do + <> -> + unquote(value_name)(rest, key, stack) + _ -> + :error + end end end end @@ -280,8 +297,12 @@ defmodule Thrift.Generator.StructBinaryProtocol do dest_module = FileGroup.dest_module(file_group, struct) quote do defp unquote(key_name)(<>, stack) do - {key, rest} = unquote(dest_module).BinaryProtocol.deserialize(rest) - unquote(value_name)(rest, key, stack) + case unquote(dest_module).BinaryProtocol.deserialize(rest) do + {key, rest} -> + unquote(value_name)(rest, key, stack) + :error -> + :error + end end end end @@ -297,6 +318,8 @@ defmodule Thrift.Generator.StructBinaryProtocol do end unquote(map_key_deserializer(key_type, child_key_name, child_value_name, file_group)) unquote(map_value_deserializer(value_type, child_key_name, child_value_name, file_group)) + defp unquote(child_key_name)(_, _), do: :error + defp unquote(child_value_name)(_, _, _), do: :error end end def map_key_deserializer({:set, element_type}, key_name, value_name, file_group) do @@ -309,6 +332,7 @@ defmodule Thrift.Generator.StructBinaryProtocol do unquote(value_name)(rest, MapSet.new(Enum.reverse(key)), stack) end unquote(list_deserializer(element_type, sub_name, file_group)) + defp unquote(sub_name)(_, _), do: :error end end def map_key_deserializer({:list, element_type}, key_name, value_name, file_group) do @@ -321,6 +345,7 @@ defmodule Thrift.Generator.StructBinaryProtocol do unquote(value_name)(rest, Enum.reverse(key), stack) end unquote(list_deserializer(element_type, sub_name, file_group)) + defp unquote(sub_name)(_, _), do: :error end end def map_key_deserializer(%StructRef{referenced_type: type}, key_name, value_name, file_group) do @@ -383,8 +408,12 @@ defmodule Thrift.Generator.StructBinaryProtocol do def map_value_deserializer(:string, key_name, value_name, _file_group) do quote do defp unquote(value_name)(<>, key, [map, remaining | stack]) do - <> = rest - unquote(key_name)(rest, [Map.put(map, key, value), remaining - 1 | stack]) + case rest do + <> -> + unquote(key_name)(rest, [Map.put(map, key, value), remaining - 1 | stack]) + _ -> + :error + end end end end @@ -392,8 +421,12 @@ defmodule Thrift.Generator.StructBinaryProtocol do dest_module = FileGroup.dest_module(file_group, struct) quote do defp unquote(value_name)(<>, key, [map, remaining | stack]) do - {value, rest} = unquote(dest_module).BinaryProtocol.deserialize(rest) - unquote(key_name)(rest, [Map.put(map, key, value), remaining - 1 | stack]) + case unquote(dest_module).BinaryProtocol.deserialize(rest) do + {value, rest} -> + unquote(key_name)(rest, [Map.put(map, key, value), remaining - 1 | stack]) + :error -> + :error + end end end end @@ -409,6 +442,8 @@ defmodule Thrift.Generator.StructBinaryProtocol do end unquote(map_key_deserializer(key_type, child_key_name, child_value_name, file_group)) unquote(map_value_deserializer(value_type, child_key_name, child_value_name, file_group)) + defp unquote(child_key_name)(_, _), do: :error + defp unquote(child_value_name)(_, _, _), do: :error end end def map_value_deserializer({:set, element_type}, key_name, value_name, file_group) do @@ -421,6 +456,7 @@ defmodule Thrift.Generator.StructBinaryProtocol do unquote(key_name)(rest, [Map.put(map, key, MapSet.new(Enum.reverse(value))), remaining - 1 | stack]) end unquote(list_deserializer(element_type, sub_name, file_group)) + defp unquote(sub_name)(_, _), do: :error end end def map_value_deserializer({:list, element_type}, key_name, value_name, file_group) do @@ -433,6 +469,7 @@ defmodule Thrift.Generator.StructBinaryProtocol do unquote(key_name)(rest, [Map.put(map, key, Enum.reverse(value)), remaining - 1 | stack]) end unquote(list_deserializer(element_type, sub_name, file_group)) + defp unquote(sub_name)(_, _), do: :error end end def map_value_deserializer(%StructRef{referenced_type: type}, key_name, value_name, file_group) do @@ -495,8 +532,12 @@ defmodule Thrift.Generator.StructBinaryProtocol do def list_deserializer(:string, name, _file_group) do quote do defp unquote(name)(<>, [list, remaining | stack]) do - <> = rest - unquote(name)(rest, [[element | list], remaining - 1 | stack]) + case rest do + <> -> + unquote(name)(rest, [[element | list], remaining - 1 | stack]) + _ -> + :error + end end end end @@ -504,8 +545,12 @@ defmodule Thrift.Generator.StructBinaryProtocol do dest_module = FileGroup.dest_module(file_group, struct) quote do defp unquote(name)(<>, [list, remaining | stack]) do - {element, rest} = unquote(dest_module).BinaryProtocol.deserialize(rest) - unquote(name)(rest, [[element | list], remaining - 1 | stack]) + case unquote(dest_module).BinaryProtocol.deserialize(rest) do + {element, rest} -> + unquote(name)(rest, [[element | list], remaining - 1 | stack]) + :error -> + :error + end end end end @@ -525,6 +570,8 @@ defmodule Thrift.Generator.StructBinaryProtocol do end unquote(map_key_deserializer(key_type, key_name, value_name, file_group)) unquote(map_value_deserializer(value_type, key_name, value_name, file_group)) + defp unquote(key_name)(_, _), do: :error + defp unquote(value_name)(_, _, _), do: :error end end def list_deserializer({:set, element_type}, name, file_group) do @@ -537,6 +584,7 @@ defmodule Thrift.Generator.StructBinaryProtocol do unquote(name)(rest, [[MapSet.new(Enum.reverse(inner_list)) | list], remaining - 1 | stack]) end unquote(list_deserializer(element_type, sub_name, file_group)) + defp unquote(sub_name)(_, _), do: :error end end def list_deserializer({:list, element_type}, name, file_group) do @@ -549,6 +597,7 @@ defmodule Thrift.Generator.StructBinaryProtocol do unquote(name)(rest, [[Enum.reverse(inner_list) | list], remaining - 1 | stack]) end unquote(list_deserializer(element_type, sub_name, file_group)) + defp unquote(sub_name)(_, _), do: :error end end def list_deserializer(%StructRef{referenced_type: type}, name, file_group) do diff --git a/test/generator/binary_protocol_test.exs b/test/generator/binary_protocol_test.exs index 28f59f3f..5926a771 100644 --- a/test/generator/binary_protocol_test.exs +++ b/test/generator/binary_protocol_test.exs @@ -6,6 +6,21 @@ defmodule Thrift.Generator.BinaryProtocolTest do def assert_serializes(struct=%{__struct__: mod}, binary) do assert binary == Binary.serialize(:struct, struct) |> IO.iodata_to_binary assert {^struct, ""} = mod.deserialize(binary) + + # If we randomly mutate any byte in the binary, it may deserialize to a + # struct of the proper type, or it may return :error. But it should never + # raise. + for i <- 1..byte_size(binary) do + mutated_binary = binary + |> :binary.bin_to_list + |> List.replace_at(i - 1, :rand.uniform(256) - 1) + |> :binary.list_to_bin + + case mod.deserialize(mutated_binary) do + {%{__struct__: ^mod}, _} -> :ok + :error -> :ok + end + end end def assert_serializes(struct=%{__struct__: mod}, binary, deserialized_struct=%{__struct__: mod}) do