Skip to content

Commit

Permalink
Merge pull request #24 from pylon/bs-signature
Browse files Browse the repository at this point in the history
add support for saved module signatures
  • Loading branch information
brentspell committed Oct 12, 2020
2 parents 317eb45 + 5c5cf7b commit 50147ef
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 52 deletions.
20 changes: 17 additions & 3 deletions c_src/tf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,11 @@ ERL_NIF_TERM nif_tf_load_saved_model (
status = CHECKALLOC(TF_NewStatus());
graph = CHECKALLOC(TF_NewGraph());
session_options = CHECKALLOC(TF_NewSessionOptions());
// initialize the metagraph output buffer
// note: while tensorflow returns this buffer to us, it maintains
// ownership of it inside the session, so we don't free it,
// unlike the other objects returned to us
TF_Buffer metagraph; memset(&metagraph, 0, sizeof(metagraph));
// retrieve the model path
ErlNifBinary path_bin; memset(&path_bin, 0, sizeof(path_bin));
CHECK(enif_inspect_binary(env, argv[0], &path_bin), "invalid_path");
Expand Down Expand Up @@ -261,7 +266,7 @@ ERL_NIF_TERM nif_tf_load_saved_model (
tag_strs,
1,
graph,
NULL,
&metagraph,
status);
tf_check_status(status);
// create the erlang resource to wrap the session
Expand All @@ -272,8 +277,17 @@ ERL_NIF_TERM nif_tf_load_saved_model (
resource->graph = graph;
session = NULL;
graph = NULL;
// relinquish the model resource to erlang
result = enif_make_resource(env, resource);
// create the erlang binary for the metagraph protobuf
ERL_NIF_TERM metagraph_bin;
void* metagraph_data = CHECKALLOC(enif_make_new_binary(
env,
metagraph.length,
&metagraph_bin));
memcpy(metagraph_data, metagraph.data, metagraph.length);
// relinquish the model resource and metagraph to erlang
result = enif_make_tuple2(env,
enif_make_resource(env, resource),
metagraph_bin);
enif_release_resource(resource);
} catch (NifError& e) {
result = e.to_term(env);
Expand Down
2 changes: 1 addition & 1 deletion lib/extensor/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ defmodule Extensor.NIF do
path :: String.t(),
tag :: String.t(),
config_pb :: binary()
) :: reference()
) :: {reference(), binary()}
def tf_load_saved_model(_path, _tag, _config_pb) do
:erlang.nif_error(:nif_library_not_loaded)
end
Expand Down
107 changes: 86 additions & 21 deletions lib/extensor/session.ex
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ defmodule Extensor.Session do
"""

alias Extensor.{NIF, Tensor}
alias Tensorflow.ConfigProto
alias Tensorflow.{ConfigProto, MetaGraphDef, SignatureDef}

@type t :: reference()
@type t :: %__MODULE__{
resource: reference(),
signatures: %{String.t() => SignatureDef.t()}
}

@default_config ConfigProto.new()

Expand Down Expand Up @@ -87,6 +90,9 @@ defmodule Extensor.Session do
}

@tft2atom Map.new(@atom2tft, fn {k, v} -> {v, k} end)
@signature_default %SignatureDef{inputs: %{}, outputs: %{}}

defstruct [:resource, signatures: %{}]

@doc "loads a custom op kernel library"
@spec load_library(name :: String.t()) :: :ok | {:error, any()}
Expand Down Expand Up @@ -142,7 +148,11 @@ defmodule Extensor.Session do
config :: %ConfigProto{}
) :: t() | no_return()
def parse_frozen_graph!(graph_pb, config \\ @default_config) do
NIF.tf_parse_frozen_graph(graph_pb, ConfigProto.encode(config))
resource = NIF.tf_parse_frozen_graph(graph_pb, ConfigProto.encode(config))

%__MODULE__{
resource: resource
}
end

@doc "loads a saved_model from a directory path"
Expand All @@ -164,17 +174,34 @@ defmodule Extensor.Session do
tag :: String.t()
) :: t() | no_return()
def load_saved_model!(path, config \\ @default_config, tag \\ "serve") do
NIF.tf_load_saved_model(path, tag, ConfigProto.encode(config))
# load the saved model directory from the nif
{resource, metagraph} =
NIF.tf_load_saved_model(path, tag, ConfigProto.encode(config))

# parse the metagraph for signatures
metagraph = MetaGraphDef.decode(metagraph)

# return the session
%__MODULE__{
resource: resource,
signatures: metagraph.signature_def
}
end

@doc "executes a tensorflow session"
@spec run(
session :: t(),
input_tensors :: %{String.t() => Tensor.t()},
output_names :: [String.t(), ...]
output_names :: [String.t(), ...] | nil,
signature :: String.t()
) :: {:ok, %{String.t() => Tensor.t()}} | {:error, any}
def run(session, input_tensors, output_names) do
{:ok, run!(session, input_tensors, output_names)}
def run(
session,
input_tensors,
output_names \\ nil,
signature \\ "serving_default"
) do
{:ok, run!(session, input_tensors, output_names, signature)}
rescue
e -> {:error, e}
end
Expand All @@ -183,24 +210,62 @@ defmodule Extensor.Session do
@spec run!(
session :: t(),
input_tensors :: %{String.t() => Tensor.t()},
output_names :: [String.t(), ...]
output_names :: [String.t(), ...] | nil,
signature :: String.t()
) :: %{String.t() => Tensor.t()}
def run!(session, input_tensors, output_names) do
input_tensors = ex2tf(input_tensors)
output_tensors = NIF.tf_run_session(session, input_tensors, output_names)
tf2ex(output_tensors)
def run!(
session,
input_tensors,
output_names \\ nil,
signature \\ "serving_default"
) do
# fetch the metadata for the requested signature and
# default the output tensors to the signature outputs
signature = session.signatures[signature] || @signature_default
output_names = output_names || Map.keys(signature.outputs)

# map the input names through the signaturedef mapping
# convert the input tensor structs to values accepted by the nif
input_tensors =
input_tensors
|> Map.new(fn {k, v} -> {sig2tensor(k, signature.inputs), ex2tf(v)} end)

# create a parallel list of mapped tensor names,
# so that we can invert the mapping on the other side
result_names =
output_names
|> Enum.map(&sig2tensor(&1, signature.outputs))

# run tensorflow inference
output_tensors =
NIF.tf_run_session(
session.resource,
input_tensors,
result_names
)

# map the output tensor names through the signaturedef mapping
# convert the nif tensor values to elixir tensor structs
Enum.zip(output_names, result_names)
|> Map.new(fn {s, t} -> {s, tf2ex(output_tensors[t])} end)
end

# convert a metagraph signature name to a tensor name
defp sig2tensor(key, map) do
case map[key] do
%{encoding: {:name, key}} -> key
_ -> key
end
end

defp ex2tf(tensors) do
Map.new(tensors, fn {k, v} ->
Tensor.validate!(v)
{k, {Map.fetch!(@atom2tft, v.type), v.shape, v.data}}
end)
# convert an elixir tensor to a tensorflow tensor tuple
defp ex2tf(tensor) do
Tensor.validate!(tensor)
{Map.fetch!(@atom2tft, tensor.type), tensor.shape, tensor.data}
end

defp tf2ex(tensors) do
Map.new(tensors, fn {k, {t, s, d}} ->
{k, %Tensor{type: Map.fetch!(@tft2atom, t), shape: s, data: d}}
end)
# convert a tensorflow tensor tuple to an elixir tensor
defp tf2ex({type, shape, data}) do
%Tensor{type: Map.fetch!(@tft2atom, type), shape: shape, data: data}
end
end
2 changes: 1 addition & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ defmodule Extensor.MixProject do
[
app: :extensor,
name: "Extensor",
version: "2.3.1",
version: "2.3.2",
elixir: "~> 1.9",
compilers: [:elixir_make] ++ Mix.compilers(),
make_cwd: "c_src",
Expand Down
6 changes: 3 additions & 3 deletions test/data/pythagoras.pb
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
dtype0*
shape:
,
b Placeholder*
dtype0*
shape:
b Placeholder*
shape:*
dtype0

SquareSquarea*
T0
Expand Down
Binary file modified test/data/pythagoras/saved_model.pb
Binary file not shown.
50 changes: 29 additions & 21 deletions test/extensor/session_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,11 @@ defmodule Extensor.SessionTest do
end

# valid protobuf
{:ok, session} = Session.parse_frozen_graph(graph_def)
assert is_reference(session)
{:ok, _session} = Session.parse_frozen_graph(graph_def)

session = Session.parse_frozen_graph!(graph_def)
assert is_reference(session)
Session.parse_frozen_graph!(graph_def)

session = Session.parse_frozen_graph!(graph_def, config)
assert is_reference(session)
Session.parse_frozen_graph!(graph_def, config)
end

test "load frozen graph file" do
Expand All @@ -65,14 +62,11 @@ defmodule Extensor.SessionTest do
end

# valid file path
{:ok, session} = Session.load_frozen_graph(graph_path)
assert is_reference(session)
{:ok, _session} = Session.load_frozen_graph(graph_path)

session = Session.load_frozen_graph!(graph_path)
assert is_reference(session)
Session.load_frozen_graph!(graph_path)

session = Session.load_frozen_graph!(graph_path, config)
assert is_reference(session)
Session.load_frozen_graph!(graph_path, config)
end

test "load saved_model directory" do
Expand All @@ -98,21 +92,17 @@ defmodule Extensor.SessionTest do
end

# valid model
{:ok, session} = Session.load_saved_model(model_path)
assert is_reference(session)
{:ok, _session} = Session.load_saved_model(model_path)

session = Session.load_saved_model!(model_path)
assert is_reference(session)
Session.load_saved_model!(model_path)

session = Session.load_saved_model!(model_path, config)
assert is_reference(session)
Session.load_saved_model!(model_path, config)

session = Session.load_saved_model!(model_path, config, "serve")
assert is_reference(session)
Session.load_saved_model!(model_path, config, "serve")
end

test "run session" do
session = Session.load_frozen_graph!("test/data/pythagoras.pb")
session = Session.load_saved_model!("test/data/pythagoras")

# missing input/output tensors
input = %{
Expand Down Expand Up @@ -196,6 +186,24 @@ defmodule Extensor.SessionTest do

output = Session.run!(session, input, ["c"])
assert Tensor.to_list(output["c"]) == [[5], [13]]

# metagraph name mapping
input = %{
"a_input" => Tensor.from_list([3]),
"b_input" => Tensor.from_list([4])
}

{:ok, output} = Session.run(session, input, ["c_output"])
assert Tensor.to_list(output["c_output"]) == [5]

# metagraph output defaults
input = %{
"a_input" => Tensor.from_list([3]),
"b_input" => Tensor.from_list([4])
}

{:ok, output} = Session.run(session, input)
assert Tensor.to_list(output["c_output"]) == [5]
end

test "global parallelism" do
Expand Down
4 changes: 2 additions & 2 deletions test/pythagoras.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@
# save in the saved_model format, which can include non-const variables
tf.saved_model.simple_save(session,
'test/data/pythagoras',
inputs={'a': a, 'b': b},
outputs={'c': c})
inputs={'a_input': a, 'b_input': b},
outputs={'c_output': c})

0 comments on commit 50147ef

Please sign in to comment.