Skip to content

Commit

Permalink
[NNVM] Add symbol squeezenet (apache#1436)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored and tqchen committed Jul 15, 2018
1 parent 6b8d0c0 commit c19cf6f
Show file tree
Hide file tree
Showing 10 changed files with 235 additions and 7 deletions.
1 change: 1 addition & 0 deletions nnvm/python/nnvm/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from . import mlp
from . import resnet
from . import vgg
from . import squeezenet
from . import dcgan
from . import dqn
from . import yolo2_detection
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/testing/mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype=
The batch size used in the model
num_classes : int, optional
Number of claseses
Number of classes
image_shape : tuple, optional
The input image shape
Expand Down
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/testing/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def get_workload(batch_size=1, num_classes=1000, num_layers=18,
The batch size used in the model
num_classes : int, optional
Number of claseses
Number of classes
num_layers : int, optional
Number of layers
Expand Down
132 changes: 132 additions & 0 deletions nnvm/python/nnvm/testing/squeezenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=unused-argument

"""
Symbol of SqueezeNet
Reference:
Iandola, Forrest N., et al.
"SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size." (2016).
"""

from .. import symbol as sym
from . utils import create_workload

# Helpers
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
    """Append one Fire module to ``net``.

    The input first goes through a 1x1 "squeeze" convolution. The squeezed
    result then feeds two parallel "expand" convolutions (1x1 with no padding,
    3x3 with padding 1) whose outputs are concatenated along axis 1.

    Parameters
    ----------
    net : symbol
        The input of the Fire module
    squeeze_channels : int
        Number of output channels of the squeeze convolution
    expand1x1_channels : int
        Number of output channels of the 1x1 expand convolution
    expand3x3_channels : int
        Number of output channels of the 3x3 expand convolution

    Returns
    -------
    net : symbol
        The concatenation of both expand branches
    """
    squeezed = _make_fire_conv(net, squeeze_channels, 1, 0)

    expand_1x1 = _make_fire_conv(squeezed, expand1x1_channels, 1, 0)
    expand_3x3 = _make_fire_conv(squeezed, expand3x3_channels, 3, 1)

    # NOTE: concatenating on axis 1 assumes the channel axis of an NCHW layout.
    return sym.concatenate(expand_1x1, expand_3x3, axis=1)

def _make_fire_conv(net, channels, kernel_size, padding=0):
    """Apply a square convolution followed by ReLU.

    Parameters
    ----------
    net : symbol
        The input of the convolution
    channels : int
        Number of output channels
    kernel_size : int
        Side length of the square kernel
    padding : int, optional
        Padding applied on both spatial dimensions

    Returns
    -------
    net : symbol
        The ReLU-activated convolution output
    """
    conv = sym.conv2d(
        net,
        channels=channels,
        kernel_size=(kernel_size,) * 2,
        padding=(padding,) * 2,
    )
    return sym.relu(conv)

# Net
def get_symbol(num_classes=1000, version='1.0', **kwargs):
    """Get symbol of SqueezeNet

    Parameters
    ----------
    num_classes : int, optional
        The number of classification results
    version : str, optional
        "1.0" or "1.1" of SqueezeNet
    kwargs : dict
        Extra arguments; accepted for interface compatibility and not used here

    Returns
    -------
    net : symbol
        The symbol produced by the final softmax
    """
    assert version in ['1.0', '1.1'], (
        "Unsupported SqueezeNet version {version}: "
        "1.0 or 1.1 expected".format(version=version))

    # Both versions share the same skeleton: a strided stem convolution, then
    # three stages that each start with a 3x3/stride-2 max pool followed by
    # Fire modules. Only the stem and the placement of the Fire modules differ.
    # Each stage lists the squeeze channels of its Fire modules; both expand
    # branches always use 4x the squeeze channels (16 -> 64, ..., 64 -> 256).
    if version == '1.0':
        stem = dict(channels=96, kernel_size=(7, 7), strides=(2, 2), padding=(3, 3))
        stages = [(16, 16, 32), (32, 48, 48, 64), (64,)]
    else:
        stem = dict(channels=64, kernel_size=(3, 3), strides=(2, 2), padding=(1, 1))
        stages = [(16, 16), (32, 32), (48, 48, 64, 64)]

    net = sym.Variable("data")
    net = sym.conv2d(net, **stem)
    net = sym.relu(net)
    for squeeze_sizes in stages:
        net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
        for squeeze in squeeze_sizes:
            net = _make_fire(net, squeeze, 4 * squeeze, 4 * squeeze)

    # Classifier head: a 1x1 convolution yields one feature map per class,
    # which is globally average-pooled into the class scores.
    net = sym.dropout(net, rate=0.5)
    net = sym.conv2d(net, channels=num_classes, kernel_size=(1, 1))
    net = sym.relu(net)
    net = sym.global_avg_pool2d(net)
    net = sym.flatten(net)
    return sym.softmax(net)

def get_workload(batch_size=1, num_classes=1000, version='1.0',
                 image_shape=(3, 224, 224), dtype="float32", **kwargs):
    """Get benchmark workload for SqueezeNet

    Parameters
    ----------
    batch_size : int
        The batch size used in the model
    num_classes : int, optional
        Number of classes
    version : str, optional
        "1.0" or "1.1" of SqueezeNet
    image_shape : tuple, optional
        The input image shape
    dtype : str, optional
        The data type
    kwargs : dict
        Extra arguments, forwarded to ``get_symbol``

    Returns
    -------
    net : nnvm.Symbol
        The computational graph
    params : dict of str to NDArray
        The parameters.
    """
    symbol = get_symbol(num_classes=num_classes, version=version, **kwargs)
    workload = create_workload(symbol, batch_size, image_shape, dtype)
    return workload
12 changes: 10 additions & 2 deletions nnvm/tests/python/frontend/mxnet/model_zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""MXNet and NNVM model zoo."""
from __future__ import absolute_import
from . import mlp, resnet, vgg, dqn, dcgan
from . import mlp, resnet, vgg, dqn, dcgan, squeezenet
import nnvm.testing

__all__ = ['mx_mlp', 'nnvm_mlp', 'mx_resnet', 'nnvm_resnet', 'mx_vgg', 'nnvm_vgg']
__all__ = ['mx_mlp', 'nnvm_mlp', 'mx_resnet', 'nnvm_resnet', 'mx_vgg', 'nnvm_vgg',
'mx_squeezenet', 'nnvm_squeezenet']

_num_class = 1000

Expand All @@ -27,6 +28,13 @@
nnvm_vgg[num_layer] = nnvm.testing.vgg.get_workload(
1, _num_class, num_layers=num_layer)[0]

# squeezenet
mx_squeezenet = {}
nnvm_squeezenet = {}
for version in ['1.0', '1.1']:
mx_squeezenet[version] = squeezenet.get_symbol(version=version)
nnvm_squeezenet[version] = nnvm.testing.squeezenet.get_workload(1, version=version)[0]

# dqn
mx_dqn = dqn.get_symbol()
nnvm_dqn = nnvm.testing.dqn.get_workload(1)[0]
Expand Down
76 changes: 76 additions & 0 deletions nnvm/tests/python/frontend/mxnet/model_zoo/squeezenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Symbol of SqueezeNet
Reference:
Iandola, Forrest N., et al.
"SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size." (2016).
"""

import mxnet as mx

# Helpers
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
    """Append one Fire module to ``net`` (MXNet symbol version).

    A 1x1 squeeze convolution feeds two parallel expand convolutions
    (1x1 with no padding, 3x3 with padding 1); their outputs are concatenated
    along dim 1.

    Parameters
    ----------
    net : symbol
        The input of the Fire module
    squeeze_channels : int
        Number of filters of the squeeze convolution
    expand1x1_channels : int
        Number of filters of the 1x1 expand convolution
    expand3x3_channels : int
        Number of filters of the 3x3 expand convolution

    Returns
    -------
    net : symbol
        The concatenation of both expand branches
    """
    squeezed = _make_fire_conv(net, squeeze_channels, 1, 0)

    branches = [
        _make_fire_conv(squeezed, expand1x1_channels, 1, 0),
        _make_fire_conv(squeezed, expand3x3_channels, 3, 1),
    ]
    # NOTE: concatenating on dim 1 assumes the channel axis of an NCHW layout.
    return mx.sym.concat(*branches, dim=1)

def _make_fire_conv(net, channels, kernel_size, padding=0):
    """Apply a square convolution followed by ReLU (MXNet symbol version).

    Parameters
    ----------
    net : symbol
        The input of the convolution
    channels : int
        Number of filters
    kernel_size : int
        Side length of the square kernel
    padding : int, optional
        Padding applied on both spatial dimensions

    Returns
    -------
    net : symbol
        The ReLU-activated convolution output
    """
    conv = mx.sym.Convolution(
        net,
        num_filter=channels,
        kernel=(kernel_size,) * 2,
        pad=(padding,) * 2,
    )
    return mx.sym.Activation(conv, act_type='relu')

# Net
def get_symbol(num_classes=1000, version='1.0', **kwargs):
    """Get symbol of SqueezeNet

    Parameters
    ----------
    num_classes : int, optional
        The number of classification results
    version : str, optional
        "1.0" or "1.1" of SqueezeNet
    kwargs : dict
        Extra arguments; accepted for interface compatibility and not used here

    Returns
    -------
    net : symbol
        The symbol produced by the final softmax
    """
    assert version in ['1.0', '1.1'], (
        "Unsupported SqueezeNet version {version}: "
        "1.0 or 1.1 expected".format(version=version))

    def max_pool(data):
        # Every down-sampling step in both versions is a 3x3/stride-2 max pool.
        return mx.sym.Pooling(data=data, kernel=(3, 3), pool_type='max', stride=(2, 2))

    # Both versions share the same skeleton: a strided stem convolution, then
    # three stages that each start with a max pool followed by Fire modules.
    # Each stage lists the squeeze channels of its Fire modules; both expand
    # branches always use 4x the squeeze channels (16 -> 64, ..., 64 -> 256).
    if version == '1.0':
        stem = dict(num_filter=96, kernel=(7, 7), stride=(2, 2), pad=(3, 3))
        stages = [(16, 16, 32), (32, 48, 48, 64), (64,)]
    else:
        stem = dict(num_filter=64, kernel=(3, 3), stride=(2, 2), pad=(1, 1))
        stages = [(16, 16), (32, 32), (48, 48, 64, 64)]

    net = mx.sym.Variable("data")
    net = mx.sym.Convolution(net, **stem)
    net = mx.sym.Activation(net, act_type='relu')
    for squeeze_sizes in stages:
        net = max_pool(net)
        for squeeze in squeeze_sizes:
            net = _make_fire(net, squeeze, 4 * squeeze, 4 * squeeze)

    # Classifier head: a 1x1 convolution yields one feature map per class,
    # which is globally average-pooled into the class scores.
    net = mx.sym.Dropout(net, p=0.5)
    net = mx.sym.Convolution(net, num_filter=num_classes, kernel=(1, 1))
    net = mx.sym.Activation(net, act_type='relu')
    net = mx.sym.Pooling(data=net, global_pool=True, kernel=(13, 13), pool_type='avg')
    net = mx.sym.flatten(net)
    return mx.sym.softmax(net)
8 changes: 8 additions & 0 deletions nnvm/tests/python/frontend/mxnet/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ def test_resnet():
nnvm_sym = model_zoo.nnvm_resnet[n]
compare_graph(from_mx_sym, nnvm_sym)

def test_squeezenet():
    """Check that each MXNet SqueezeNet converts to the matching NNVM graph."""
    for version in ('1.0', '1.1'):
        expected = model_zoo.nnvm_squeezenet[version]
        # from_mxnet returns a pair; only the converted symbol is compared.
        converted, _ = nnvm.frontend.from_mxnet(model_zoo.mx_squeezenet[version])
        compare_graph(converted, expected)

def test_dqn():
mx_sym = model_zoo.mx_dqn
from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
Expand Down Expand Up @@ -62,3 +69,4 @@ def compose(F, **kwargs):
test_multi_outputs()
test_dqn()
test_dcgan()
test_squeezenet()
2 changes: 1 addition & 1 deletion tutorials/autotvm/tune_cuda_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# ---------------------------------
# There are plenty of useful schedule primitives in tvm. You can also find
# some tutorials that describe them in more details, such as
# (1). :doc:``Optimizing Conv2d on NVIDIA GPU <../optimize/opt_conv_cuda>`
# (1). :ref:`opt-conv-gpu`
# (2). `Optimizing DepthwiseConv on NVIDIA GPU <https://tvm.ai/2017/08/22/Optimize-Deep-Learning-GPU-Operators-with-TVM-A-Depthwise-Convolution-Example.html>`_
#
# However, their implementations are manually tuned for some special input
Expand Down
2 changes: 1 addition & 1 deletion tutorials/nnvm/imagenet_inference_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# To get the maximum performance, we need to enable nvcc's compiler hook.
# This usually gives better performance than nvrtc mode.

@tvm.register_func
@tvm.register_func("tvm_callback_cuda_compile", override=True)
def tvm_callback_cuda_compile(code):
    """Compile ``code`` with ``nvcc.compile_cuda`` targeting "ptx" and return the result."""
    return nvcc.compile_cuda(code, target="ptx")
Expand Down
5 changes: 4 additions & 1 deletion tutorials/optimize/opt_conv_cuda.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""How to optimize convolution on GPU
"""
.. _opt-conv-gpu:
How to optimize convolution on GPU
==================================
**Author**: `Haichen Shen <https://homes.cs.washington.edu/~haichen/>`_
Expand Down

0 comments on commit c19cf6f

Please sign in to comment.