Skip to content
Permalink
Browse files

DCE c2 nets (#2419)

  • Loading branch information...
jackm321 committed Feb 21, 2019
1 parent 74bfbdc commit ac7ee90524bd5ad0bad2a3af61e410434af97ed2
0 python
No changes.
@@ -0,0 +1,74 @@
# Init net fixture: four GivenTensorFill ops that create the tensors conv_w,
# conv_b, fc_w and fc_b from the literal shapes and values given below. All
# four tensors are exported through external_output at the end of the net.
name: "init"
# conv_w: shape 1 x 1 x 2 x 2, all four elements set to 1.0.
op {
output: "conv_w"
type: "GivenTensorFill"
arg {
name: "shape"
ints: 1
ints: 1
ints: 2
ints: 2
}
arg {
name: "values"
floats: 1.0
floats: 1.0
floats: 1.0
floats: 1.0
}
}
# conv_b: a single element with value 2.0.
op {
output: "conv_b"
type: "GivenTensorFill"
arg {
name: "shape"
ints: 1
}
arg {
name: "values"
floats: 2.0
}
}
# fc_w: shape 4 x 3, filled with the twelve values 1.0 through 12.0.
op {
output: "fc_w"
type: "GivenTensorFill"
arg {
name: "shape"
ints: 4
ints: 3
}
arg {
name: "values"
floats: 1.0
floats: 2.0
floats: 3.0
floats: 4.0
floats: 5.0
floats: 6.0
floats: 7.0
floats: 8.0
floats: 9.0
floats: 10.0
floats: 11.0
floats: 12.0
}
}
# fc_b: four elements, 0.1 through 0.4.
op {
output: "fc_b"
type: "GivenTensorFill"
arg {
name: "shape"
ints: 4
}
arg {
name: "values"
floats: 0.1
floats: 0.2
floats: 0.3
floats: 0.4
}
}
# Every tensor produced above is exported by this net.
external_output: "conv_w"
external_output: "conv_b"
external_output: "fc_w"
external_output: "fc_b"
@@ -0,0 +1,43 @@
# Predict net fixture for dead code elimination: one FC op and one Conv op.
# Only conv_result is listed in external_output, so nothing consumes
# fc_result and the FC op (with its inputs) is dead with respect to the
# outputs of this net.
name: "dce_test"
# FC op: its output fc_result is neither an external_output nor an input of
# any other op in this net.
op {
input: "fc_inputs"
input: "fc_w"
input: "fc_b"
output: "fc_result"
name: ""
type: "FC"
}
# Conv op: produces conv_result, the only external_output of this net.
op {
input: "conv_inputs"
input: "conv_w"
input: "conv_b"
output: "conv_result"
type: "Conv"
arg {
name: "order"
s: "NHWC"
}
arg {
name: "kernel"
i: 2
}
arg {
name: "stride"
i: 1
}
arg {
name: "group"
i: 1
}
arg {
name: "pad"
i: 1
}
}
# External inputs cover the inputs of both ops, including those of the FC op.
external_input: "fc_inputs"
external_input: "fc_w"
external_input: "fc_b"
external_input: "conv_inputs"
external_input: "conv_w"
external_input: "conv_b"
external_output: "conv_result"
@@ -0,0 +1,132 @@
from caffe2.proto import caffe2_pb2
from google.protobuf import text_format
import argparse

def read_model_from_file(path):
    """Load a Caffe2 NetDef from ``path`` and return it.

    The file is parsed as protobuf text format when ``path`` contains
    ".pbtxt" and as a serialized binary protobuf otherwise.
    """
    net = caffe2_pb2.NetDef()
    with open(path, "rb") as model_file:
        payload = model_file.read()
        is_text_format = ".pbtxt" in path
        if not is_text_format:
            net.ParseFromString(payload)
        else:
            text_format.Merge(payload, net)
    return net

def write_model_to_file(path, m):
    """Serialize the protobuf message ``m`` to the file at ``path``.

    The message is written in protobuf text format when ``path`` contains
    ".pbtxt" and as a serialized binary protobuf otherwise.

    Args:
        path: Destination file path; it is created or truncated.
        m: The protobuf message to write.
    """
    if ".pbtxt" in path:
        data = text_format.MessageToString(m)
        # MessageToString returns text (str) on Python 3, which cannot be
        # written to a file opened in binary mode; encode it first.
        if not isinstance(data, bytes):
            data = data.encode("utf-8")
    else:
        data = m.SerializeToString()

    with open(path, "wb") as f:
        f.write(data)

# Perform dead code elimination on predict_net removing any nodes that aren't
# used for producing values in predict_net.external_output. Remove any nodes in
# init_net that produce values that are no longer needed by predict_net.
def dce(init_net, predict_net):
    """Eliminate dead ops and unused tensors from a Caffe2 model, in place.

    An op of predict_net is live when at least one of its outputs is needed,
    directly or transitively, to compute predict_net.external_output. Dead
    ops are removed, then the external_input entries that nothing needs any
    more. If init_net is given, its ops that do not contribute to the
    surviving external inputs are removed as well and its external_output
    list is trimmed to the tensors that are still produced.

    The relative order of the surviving ops, external inputs and external
    outputs is preserved, and every surviving op is kept exactly once even
    when several of its outputs are live.

    Args:
        init_net: Net whose ops produce the predict net's parameters, or
            None to process predict_net only.
        predict_net: Net to prune; must expose op, external_input and
            external_output.

    Returns:
        None. Both nets are modified in place and a summary is printed.
    """
    num_predict_net_ops_original = len(predict_net.op)
    num_predict_net_inputs_original = len(predict_net.external_input)

    # Tensors needed to compute the outputs, and the ops that produce them.
    live_tensors = _live_tensors(predict_net.op, predict_net.external_output)
    live_predict_net_ops = _ops_producing(predict_net.op, live_tensors)

    # Delete all unused ops in predict_net.
    num_predict_net_ops_eliminated = (
        num_predict_net_ops_original - len(live_predict_net_ops))
    del predict_net.op[:]
    predict_net.op.extend(live_predict_net_ops)

    # Keep the external inputs that are still needed. live_tensors holds the
    # inputs of every live op plus the external outputs themselves, so an
    # external input that is forwarded directly as an output is kept too.
    live_predict_net_external_inputs = _keep_in_order(
        predict_net.external_input, live_tensors)
    num_predict_net_inputs_eliminated = (
        num_predict_net_inputs_original
        - len(live_predict_net_external_inputs))
    del predict_net.external_input[:]
    predict_net.external_input.extend(live_predict_net_external_inputs)

    print("predict_net ops eliminated: {}/{}".format(num_predict_net_ops_eliminated, num_predict_net_ops_original))
    print("predict_net external_inputs eliminated: {}/{}".format(num_predict_net_inputs_eliminated, num_predict_net_inputs_original))

    # Everything below pertains to removing unused outputs in the init_net,
    # if no init net was provided then stop here.
    if init_net is None:
        return

    num_init_net_ops_original = len(init_net.op)

    # Find the init_net ops needed by predict_net, including init_net ops
    # that only feed other live init_net ops.
    live_init_tensors = _live_tensors(
        init_net.op, live_predict_net_external_inputs)
    live_init_net_ops = _ops_producing(init_net.op, live_init_tensors)

    # Eliminate dead init_net ops.
    num_init_net_ops_eliminated = (
        num_init_net_ops_original - len(live_init_net_ops))
    del init_net.op[:]
    init_net.op.extend(live_init_net_ops)

    # Keep only the external_outputs that a surviving op still produces.
    live_init_net_op_outputs = set()
    for op in live_init_net_ops:
        live_init_net_op_outputs.update(op.output)
    live_init_net_external_outputs = _keep_in_order(
        init_net.external_output, live_init_net_op_outputs)
    del init_net.external_output[:]
    init_net.external_output.extend(live_init_net_external_outputs)

    print("init_net ops eliminated: {}/{}".format(num_init_net_ops_eliminated, num_init_net_ops_original))


def _live_tensors(ops, roots):
    """Return roots plus every tensor transitively needed to produce them."""
    live = set(roots)
    changed = True
    # Iterate to a fixed point so the result does not depend on op order.
    while changed:
        changed = False
        for op in ops:
            if any(tensor in live for tensor in op.output):
                size_before = len(live)
                live.update(op.input)
                changed = changed or len(live) != size_before
    return live


def _ops_producing(ops, live):
    """Return, in order and once each, the ops with an output in live."""
    return [op for op in ops if any(tensor in live for tensor in op.output)]


def _keep_in_order(names, wanted):
    """Return the names found in wanted, deduplicated, in original order."""
    seen = set()
    kept = []
    for name in names:
        if name in wanted and name not in seen:
            seen.add(name)
            kept.append(name)
    return kept


if __name__ == "__main__":
    # Command-line entry point: read the predict net (and optionally the init
    # net), run dead code elimination on them and write the results back out.
    # The text is passed as description=; the first positional parameter of
    # ArgumentParser is prog, which would replace the program name in usage.
    parser = argparse.ArgumentParser(
        description="Caffe2 model dead code elimination")
    parser.add_argument('--input_init_net_path', type=str)
    parser.add_argument('--input_predict_net_path', type=str, required=True)
    parser.add_argument('--output_init_net_path', type=str)
    parser.add_argument('--output_predict_net_path', type=str, required=True)

    args = parser.parse_args()

    # Without an input init net there is nothing to write to the output path;
    # fail early with a usage error instead of crashing while serializing None.
    if args.output_init_net_path is not None and args.input_init_net_path is None:
        parser.error("--output_init_net_path requires --input_init_net_path")

    predict_net = read_model_from_file(args.input_predict_net_path)

    init_net = None
    if args.input_init_net_path is not None:
        init_net = read_model_from_file(args.input_init_net_path)

    dce(init_net, predict_net)

    write_model_to_file(args.output_predict_net_path, predict_net)

    if args.output_init_net_path is not None:
        write_model_to_file(args.output_init_net_path, init_net)

0 comments on commit ac7ee90

Please sign in to comment.
You can’t perform that action at this time.