
Commit d741ab1

multiple fix with path/prefix
zhreshold committed Jun 26, 2017
1 parent ef7a773 commit d741ab1
Showing 7 changed files with 76 additions and 25 deletions.
5 changes: 4 additions & 1 deletion demo.py
@@ -111,7 +111,10 @@ def parse_class_names(class_names):
 
     network = None if args.deploy_net else args.network
     class_names = parse_class_names(args.class_names)
-    prefix = args.prefix + args.network + '_' + str(args.data_shape)
+    if args.prefix.endswith('_'):
+        prefix = args.prefix + args.network + '_' + str(args.data_shape)
+    else:
+        prefix = args.prefix
     detector = get_detector(network, prefix, args.epoch,
                             args.data_shape,
                             (args.mean_r, args.mean_g, args.mean_b),
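
In effect, a --prefix ending in '_' keeps the old auto-naming behavior (the network name and data shape are appended), while any other value is now taken as a literal checkpoint prefix. A minimal sketch of the rule; the helper name resolve_prefix is hypothetical, not part of the commit:

    def resolve_prefix(prefix, network, data_shape):
        # Hypothetical helper mirroring the branch added in demo.py.
        if prefix.endswith('_'):
            # Auto-naming stem, e.g. 'model/ssd_' -> 'model/ssd_vgg16_reduced_300'
            return prefix + network + '_' + str(data_shape)
        # User supplied a full checkpoint prefix: use it verbatim.
        return prefix

    resolve_prefix('model/ssd_', 'vgg16_reduced', 300)     # 'model/ssd_vgg16_reduced_300'
    resolve_prefix('model/my_model', 'vgg16_reduced', 300)  # 'model/my_model'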
9 changes: 6 additions & 3 deletions deploy.py
@@ -30,13 +30,16 @@ def parse_args():
 
 if __name__ == '__main__':
     args = parse_args()
-    net = get_symbol(args.network).get_symbol(args.network, args.data_shape,
+    net = get_symbol(args.network, args.data_shape,
         num_classes=args.num_classes, nms_thresh=args.nms_thresh,
         force_suppress=args.force_nms, nms_topk=args.nms_topk)
-    prefix = args.prefix + args.network + '_' + str(args.data_shape)
+    if args.prefix.endswith('_'):
+        prefix = args.prefix + args.network + '_' + str(args.data_shape)
+    else:
+        prefix = args.prefix
     _, arg_params, aux_params = mx.model.load_checkpoint(prefix, args.epoch)
     # new name
-    tmp = args.prefix.rsplit('/', 1)
+    tmp = prefix.rsplit('/', 1)
     save_prefix = '/deploy_'.join(tmp)
     mx.model.save_checkpoint(save_prefix, args.epoch, net, arg_params, aux_params)
     print("Saved model: {}-{:04d}.param".format(save_prefix, args.epoch))
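
Note the third fix in this file: the rename step now splits the resolved prefix rather than args.prefix, so an auto-expanded name carries through to the deployed checkpoint. The rsplit/join idiom inserts 'deploy_' before the file name; roughly (path illustrative):

    prefix = 'model/ssd_vgg16_reduced_300'
    save_prefix = '/deploy_'.join(prefix.rsplit('/', 1))
    # save_prefix == 'model/deploy_ssd_vgg16_reduced_300'
    # Caveat: a prefix containing no '/' passes through unchanged,
    # because rsplit then yields a single element.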
6 changes: 5 additions & 1 deletion evaluate.py
@@ -76,9 +76,13 @@ def parse_args():
     class_names = None
 
     network = None if args.deploy_net else args.network
+    if args.prefix.endswith('_'):
+        prefix = args.prefix + args.network
+    else:
+        prefix = args.prefix
     evaluate_net(network, args.rec_path, num_class,
                  (args.mean_r, args.mean_g, args.mean_b), args.data_shape,
-                 args.prefix + args.network, args.epoch, ctx, batch_size=args.batch_size,
+                 prefix, args.epoch, ctx, batch_size=args.batch_size,
                  path_imglist=args.list_path, nms_thresh=args.nms_thresh,
                  force_nms=args.force_nms, ovp_thresh=args.overlap_thresh,
                  use_difficult=args.use_difficult, class_names=class_names,
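
evaluate.py gets the same guard, but appends only the network name when the prefix ends in '_' (no data shape), presumably matching how the training checkpoints it evaluates are named. A tiny sketch (values illustrative):

    args_prefix, args_network = 'model/ssd_', 'vgg16_reduced'
    prefix = args_prefix + args_network if args_prefix.endswith('_') else args_prefix
    # prefix == 'model/ssd_vgg16_reduced'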
39 changes: 39 additions & 0 deletions symbol/common.py
@@ -37,6 +37,45 @@ def conv_act_layer(from_layer, name, num_filter, kernel=(1,1), pad=(0,0), \
         name="{}_{}".format(name, act_type))
     return relu
 
+def legacy_conv_act_layer(from_layer, name, num_filter, kernel=(1,1), pad=(0,0), \
+    stride=(1,1), act_type="relu", use_batchnorm=False):
+    """
+    wrapper for a small Convolution group
+
+    Parameters:
+    ----------
+    from_layer : mx.symbol
+        continue on which layer
+    name : str
+        base name of the new layers
+    num_filter : int
+        how many filters to use in Convolution layer
+    kernel : tuple (int, int)
+        kernel size (h, w)
+    pad : tuple (int, int)
+        padding size (h, w)
+    stride : tuple (int, int)
+        stride size (h, w)
+    act_type : str
+        activation type, can be relu...
+    use_batchnorm : bool
+        whether to use batch normalization
+
+    Returns:
+    ----------
+    (conv, relu) mx.Symbols
+    """
+    assert not use_batchnorm, "batchnorm not yet supported"
+    bias = mx.symbol.Variable(name="conv{}_bias".format(name),
+        init=mx.init.Constant(0.0), attr={'__lr_mult__': '2.0'})
+    conv = mx.symbol.Convolution(data=from_layer, bias=bias, kernel=kernel, pad=pad, \
+        stride=stride, num_filter=num_filter, name="conv{}".format(name))
+    relu = mx.symbol.Activation(data=conv, act_type=act_type, \
+        name="{}{}".format(act_type, name))
+    if use_batchnorm:
+        relu = mx.symbol.BatchNorm(data=relu, name="bn{}".format(name))
+    return conv, relu
+
 def multi_layer_feature(body, from_layers, num_filters, strides, pads, min_filter=128):
     """Wrapper function to extract features from base network, attaching extra
     layers and SSD specific layers
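
The new helper reinstates the old naming scheme ('conv8_1', 'relu8_1', plus an explicit 'conv8_1_bias' variable with a doubled learning-rate multiplier), which the refactored conv_act_layer above no longer uses; the legacy VGG16 symbols below presumably need exactly these parameter names to stay loadable from existing checkpoints. A minimal usage sketch (input shape illustrative):

    import mxnet as mx
    from common import legacy_conv_act_layer

    data = mx.symbol.Variable('data')
    # Creates symbols named 'conv8_1' and 'relu8_1'.
    conv8_1, relu8_1 = legacy_conv_act_layer(data, "8_1", 256, kernel=(1,1),
        pad=(0,0), stride=(1,1), act_type="relu", use_batchnorm=False)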
18 changes: 9 additions & 9 deletions symbol/legacy_vgg16_ssd_300.py
@@ -1,5 +1,5 @@
 import mxnet as mx
-from common import conv_act_layer
+from common import legacy_conv_act_layer
 from common import multibox_layer
 
 def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False,
@@ -97,21 +97,21 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False,
     # drop7 = mx.symbol.Dropout(data=relu7, p=0.5, name="drop7")
 
     ### ssd extra layers ###
-    conv8_1, relu8_1 = conv_act_layer(relu7, "8_1", 256, kernel=(1,1), pad=(0,0), \
+    conv8_1, relu8_1 = legacy_conv_act_layer(relu7, "8_1", 256, kernel=(1,1), pad=(0,0), \
         stride=(1,1), act_type="relu", use_batchnorm=False)
-    conv8_2, relu8_2 = conv_act_layer(relu8_1, "8_2", 512, kernel=(3,3), pad=(1,1), \
+    conv8_2, relu8_2 = legacy_conv_act_layer(relu8_1, "8_2", 512, kernel=(3,3), pad=(1,1), \
         stride=(2,2), act_type="relu", use_batchnorm=False)
-    conv9_1, relu9_1 = conv_act_layer(relu8_2, "9_1", 128, kernel=(1,1), pad=(0,0), \
+    conv9_1, relu9_1 = legacy_conv_act_layer(relu8_2, "9_1", 128, kernel=(1,1), pad=(0,0), \
         stride=(1,1), act_type="relu", use_batchnorm=False)
-    conv9_2, relu9_2 = conv_act_layer(relu9_1, "9_2", 256, kernel=(3,3), pad=(1,1), \
+    conv9_2, relu9_2 = legacy_conv_act_layer(relu9_1, "9_2", 256, kernel=(3,3), pad=(1,1), \
         stride=(2,2), act_type="relu", use_batchnorm=False)
-    conv10_1, relu10_1 = conv_act_layer(relu9_2, "10_1", 128, kernel=(1,1), pad=(0,0), \
+    conv10_1, relu10_1 = legacy_conv_act_layer(relu9_2, "10_1", 128, kernel=(1,1), pad=(0,0), \
         stride=(1,1), act_type="relu", use_batchnorm=False)
-    conv10_2, relu10_2 = conv_act_layer(relu10_1, "10_2", 256, kernel=(3,3), pad=(0,0), \
+    conv10_2, relu10_2 = legacy_conv_act_layer(relu10_1, "10_2", 256, kernel=(3,3), pad=(0,0), \
         stride=(1,1), act_type="relu", use_batchnorm=False)
-    conv11_1, relu11_1 = conv_act_layer(relu10_2, "11_1", 128, kernel=(1,1), pad=(0,0), \
+    conv11_1, relu11_1 = legacy_conv_act_layer(relu10_2, "11_1", 128, kernel=(1,1), pad=(0,0), \
         stride=(1,1), act_type="relu", use_batchnorm=False)
-    conv11_2, relu11_2 = conv_act_layer(relu11_1, "11_2", 256, kernel=(3,3), pad=(0,0), \
+    conv11_2, relu11_2 = legacy_conv_act_layer(relu11_1, "11_2", 256, kernel=(3,3), pad=(0,0), \
         stride=(1,1), act_type="relu", use_batchnorm=False)
 
     # specific parameters for VGG16 network
22 changes: 11 additions & 11 deletions symbol/legacy_vgg16_ssd_512.py
@@ -1,5 +1,5 @@
 import mxnet as mx
-from common import conv_act_layer
+from common import legacy_conv_act_layer
 from common import multibox_layer
 
 def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_topk=400):
@@ -96,25 +96,25 @@ def get_symbol_train(num_classes=20, nms_thresh=0.5, force_suppress=False, nms_topk=400):
     # drop7 = mx.symbol.Dropout(data=relu7, p=0.5, name="drop7")
 
     ### ssd extra layers ###
-    conv8_1, relu8_1 = conv_act_layer(relu7, "8_1", 256, kernel=(1,1), pad=(0,0), \
+    conv8_1, relu8_1 = legacy_conv_act_layer(relu7, "8_1", 256, kernel=(1,1), pad=(0,0), \
         stride=(1,1), act_type="relu", use_batchnorm=False)
-    conv8_2, relu8_2 = conv_act_layer(relu8_1, "8_2", 512, kernel=(3,3), pad=(1,1), \
+    conv8_2, relu8_2 = legacy_conv_act_layer(relu8_1, "8_2", 512, kernel=(3,3), pad=(1,1), \
         stride=(2,2), act_type="relu", use_batchnorm=False)
-    conv9_1, relu9_1 = conv_act_layer(relu8_2, "9_1", 128, kernel=(1,1), pad=(0,0), \
+    conv9_1, relu9_1 = legacy_conv_act_layer(relu8_2, "9_1", 128, kernel=(1,1), pad=(0,0), \
         stride=(1,1), act_type="relu", use_batchnorm=False)
-    conv9_2, relu9_2 = conv_act_layer(relu9_1, "9_2", 256, kernel=(3,3), pad=(1,1), \
+    conv9_2, relu9_2 = legacy_conv_act_layer(relu9_1, "9_2", 256, kernel=(3,3), pad=(1,1), \
         stride=(2,2), act_type="relu", use_batchnorm=False)
-    conv10_1, relu10_1 = conv_act_layer(relu9_2, "10_1", 128, kernel=(1,1), pad=(0,0), \
+    conv10_1, relu10_1 = legacy_conv_act_layer(relu9_2, "10_1", 128, kernel=(1,1), pad=(0,0), \
         stride=(1,1), act_type="relu", use_batchnorm=False)
-    conv10_2, relu10_2 = conv_act_layer(relu10_1, "10_2", 256, kernel=(3,3), pad=(1,1), \
+    conv10_2, relu10_2 = legacy_conv_act_layer(relu10_1, "10_2", 256, kernel=(3,3), pad=(1,1), \
         stride=(2,2), act_type="relu", use_batchnorm=False)
-    conv11_1, relu11_1 = conv_act_layer(relu10_2, "11_1", 128, kernel=(1,1), pad=(0,0), \
+    conv11_1, relu11_1 = legacy_conv_act_layer(relu10_2, "11_1", 128, kernel=(1,1), pad=(0,0), \
         stride=(1,1), act_type="relu", use_batchnorm=False)
-    conv11_2, relu11_2 = conv_act_layer(relu11_1, "11_2", 256, kernel=(3,3), pad=(1,1), \
+    conv11_2, relu11_2 = legacy_conv_act_layer(relu11_1, "11_2", 256, kernel=(3,3), pad=(1,1), \
         stride=(2,2), act_type="relu", use_batchnorm=False)
-    conv12_1, relu12_1 = conv_act_layer(relu11_2, "12_1", 128, kernel=(1,1), pad=(0,0), \
+    conv12_1, relu12_1 = legacy_conv_act_layer(relu11_2, "12_1", 128, kernel=(1,1), pad=(0,0), \
         stride=(1,1), act_type="relu", use_batchnorm=False)
-    conv12_2, relu12_2 = conv_act_layer(relu12_1, "12_2", 256, kernel=(4,4), pad=(1,1), \
+    conv12_2, relu12_2 = legacy_conv_act_layer(relu12_1, "12_2", 256, kernel=(4,4), pad=(1,1), \
         stride=(1,1), act_type="relu", use_batchnorm=False)
 
     # specific parameters for VGG16 network
2 changes: 2 additions & 0 deletions symbol/symbol_factory.py
@@ -96,6 +96,7 @@ def get_symbol_train(network, data_shape, **kwargs):
     see symbol_builder.get_symbol_train for more details
     """
     if network.startswith('legacy'):
+        logging.warn('Using legacy model.')
         return symbol_builder.import_module(network).get_symbol_train(**kwargs)
     config = get_config(network, data_shape, **kwargs).copy()
     config.update(kwargs)
@@ -114,6 +115,7 @@ def get_symbol(network, data_shape, **kwargs):
     see symbol_builder.get_symbol for more details
     """
     if network.startswith('legacy'):
+        logging.warn('Using legacy model.')
         return symbol_builder.import_module(network).get_symbol(**kwargs)
     config = get_config(network, data_shape, **kwargs).copy()
     config.update(kwargs)
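
With these two guards, any network whose name starts with 'legacy' now logs a warning before dispatching to the old symbol modules. This assumes logging is already imported in symbol_factory.py (the diff adds no import); note also that logging.warn is a deprecated alias of logging.warning. A rough usage sketch (import path illustrative):

    from symbol.symbol_factory import get_symbol
    # Logs 'Using legacy model.' and dispatches to symbol/legacy_vgg16_ssd_300.py.
    net = get_symbol('legacy_vgg16_ssd_300', 300, num_classes=20)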
