Skip to content

Commit

Permalink
mxnet scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
soumith committed Jan 5, 2016
1 parent c4dfa52 commit 4de3904
Show file tree
Hide file tree
Showing 3 changed files with 355 additions and 0 deletions.
57 changes: 57 additions & 0 deletions mxnet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
### Install via:


```
sudo apt-get update
sudo apt-get install -y build-essential git libatlas-base-dev libopencv-dev
git clone --recursive https://github.com/dmlc/mxnet
make -j12 USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda
cd python; python setup.py install
cd ../../
```

### Run benchmarks

```
CUDA_VISIBLE_DEVICES=2 MXNET_GPU_WORKER_NTHREADS=2 MXNET_EXEC_NUM_TEMP=1 python alexnet.py | tee out_alexnet.log
```

### Notes from antinucleon

We choose to block the dynamic thread engine to get fair result, definitely there will be some costs,
and I do think it is not worth for these microseconds but the most important thing is learn spirit
from each tools.
The first epoch will make some delay for lazy allocation, but we think it is not a problem.
Also there is a magic number of 4GB threshold for dynamic memory recycle,
we didn't change it although dynamic memory recycle will hurt performance too much.


One import thing about MXNet is chose parallelism level.
Basically, less parallelism, fewer memory cost.
For example, on Titan X with 12 GB memory, train on GoogLeNet v1,

```MXNET_GPU_WORKER_NTHREADS=2 MXNET_EXEC_NUM_TEMP=1 python3 gnet.py``` allows training in batch of 256,

but

```MXNET_GPU_WORKER_NTHREADS=4 MXNET_EXEC_NUM_TEMP=4 python3 gnet.py```

will be oom for batch of 256 (guess still saving a little more than other library but not tested)


Various of setting can be found at: https://mxnet.readthedocs.org/en/latest/env_var.html

In my feeling because currently hardware is bottleneck, dynamic data flow engine and
multi-execution doesn't show its advantage on single card, but in multi-gpu
and distributed case, it makes problem much easier.


BTW. Do you have plan to benchmark multi-gpu or distributed convolution net?
We have collected some result already.

https://github.com/dmlc/mxnet/tree/master/example/distributed-training


142 changes: 142 additions & 0 deletions mxnet/alexnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# In[1]:

import mxnet as mx
import numpy as np
import time


# In[2]:

# Basic Info
dev = mx.gpu()
batch_size = 128
dshape = (batch_size, 3, 224, 224)
lshape = (batch_size)
num_epoch = 100

# Mock data iterator
tmp_data = np.random.uniform(-1, 1, dshape).astype("float32")

train_iter = mx.io.NDArrayIter(data=tmp_data, batch_size=batch_size, shuffle=False, last_batch_handle='pad')



# In[5]:

def get_alexnet_symbol():
## define alexnet
input_data = mx.symbol.Variable(name="data")
# stage 1
conv1 = mx.symbol.Convolution(
data=input_data, kernel=(11, 11), stride=(4, 4), num_filter=64)
relu1 = mx.symbol.Activation(data=conv1, act_type="relu")
pool1 = mx.symbol.Pooling(
data=relu1, pool_type="max", kernel=(3, 3), stride=(2,2))
# lrn1 = mx.symbol.LRN(data=pool1, alpha=0.0001, beta=0.75, knorm=1, nsize=5)
# stage 2
conv2 = mx.symbol.Convolution(
data=pool1, kernel=(5, 5), pad=(2, 2), num_filter=192)
relu2 = mx.symbol.Activation(data=conv2, act_type="relu")
pool2 = mx.symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2), pool_type="max")
# lrn2 = mx.symbol.LRN(data=pool2, alpha=0.0001, beta=0.75, knorm=1, nsize=5)
# stage 3
conv3 = mx.symbol.Convolution(
data=pool2, kernel=(3, 3), pad=(1, 1), num_filter=384)
relu3 = mx.symbol.Activation(data=conv3, act_type="relu")
conv4 = mx.symbol.Convolution(
data=relu3, kernel=(3, 3), pad=(1, 1), num_filter=256)
relu4 = mx.symbol.Activation(data=conv4, act_type="relu")
conv5 = mx.symbol.Convolution(
data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256)
relu5 = mx.symbol.Activation(data=conv5, act_type="relu")
pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2), pool_type="max")
# stage 4
flatten = mx.symbol.Flatten(data=pool3)
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096)
relu6 = mx.symbol.Activation(data=fc1, act_type="relu")
# stage 5
fc2 = mx.symbol.FullyConnected(data=relu6, num_hidden=4096)
relu7 = mx.symbol.Activation(data=fc2, act_type="relu")
# stage 6
fc3 = mx.symbol.FullyConnected(data=relu7, num_hidden=1000)
return fc3

# In[6]:

# bind to get executor
# This is what happened behind mx.model.Feedforward
fc3 = get_alexnet_symbol()
alex_exec = fc3.simple_bind(ctx=dev, grad_req="write", data=dshape)
print("Temp space: ", alex_exec.debug_str().split('\n')[-3])
# Find where to set data


# In[7]:

# some useful structure
# data structues
arg_names = fc3.list_arguments()
arg_map = dict(zip(arg_names, alex_exec.arg_arrays))
grad_map = dict(zip(arg_names, alex_exec.grad_arrays))


param_blocks = [(i, arg_map[arg_names[i]], grad_map[arg_names[i]]) for i in range(len(arg_names)) if grad_map[arg_names[i]] != None]
input_ndarray = arg_map["data"]
grad = mx.nd.zeros((batch_size, 1000), ctx=mx.gpu())
param_len = len(param_blocks)


# In[8]:

#init
for i in range(param_len):
param_blocks[i][1][:] = mx.rnd.uniform(-0.01, 0.01, param_blocks[i][1].shape)
param_blocks[i][2][:] = 0.
# Set data
train_iter.reset()
dbatch = train_iter.next()
dbatch.data[0].copyto(input_ndarray)
# block all async all
mx.nd.waitall()


# In[12]:

# Test forward
def test_forward(model, epoch):
tic = time.time()
for i in range(epoch):
model.forward(is_train=True)
# Note: This command will force thread engine block, which hurts performance a lot
# Remove it will bring parallelism bias
# model.outputs[0].wait_to_read()
model.outputs[0].wait_to_read()
toc = time.time()
return (toc - tic) / epoch

print("Avg forward per batch: ", test_forward(alex_exec, num_epoch))


# In[13]:

# Test full path
def test_full(model, epoch):
tic = time.time()
for i in range(epoch):
model.forward(is_train=True)
model.backward([grad])
#model.outputs[0].wait_to_read()
# mx.nd.waitall()
# mock update
for i in range(param_len):
param_blocks[i][1][:] -= 0.0 * param_blocks[i][2][:]
# Note: This command will force thread engine block, which hurts performance a lot
mx.nd.waitall()
toc = time.time()
return (toc - tic) / epoch

print("Avg fullpath per batch: ", test_full(alex_exec, num_epoch))


# In[ ]:

156 changes: 156 additions & 0 deletions mxnet/gnetv1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# coding: utf-8

# # Before start
#
# There is many important [environment variables](https://mxnet.readthedocs.org/en/latest/env_var.html) which will influence the performance. Change these variable will change the parallelism, memory cost.
#
# sample command:
# ```
# MXNET_GPU_WORKER_NTHREADS=4 MXNET_EXEC_NUM_TEMP=4 python3 googlenet.py
# ```
#
# Speed and memory cost may change due to different level of parallelism

# In[1]:

import mxnet as mx
import numpy as np
import time


# In[2]:

# Basic Info
dev = mx.gpu()
batch_size = 128
dshape = (batch_size, 3, 224, 224)
lshape = (batch_size)
num_epoch = 100

# Mock data iterator
tmp_data = np.random.uniform(-128, 128, dshape).astype("float32")
tmp_label = np.random.uniform(0, 1000, lshape).astype("int").astype("float32")

train_iter = mx.io.NDArrayIter(data=tmp_data, label=tmp_label, batch_size=batch_size, shuffle=False, last_batch_handle='pad')



# GoogLeNet V1: Converted from [Caffe](https://github.com/BVLC/caffe/blob/master/models/bvlc_googlenet/deploy.prototxt) directly

def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), name=None, suffix=''):
conv = mx.symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' %(name, suffix))
act = mx.symbol.Activation(data=conv, act_type='relu', name='relu_%s%s' %(name, suffix))
return act

def InceptionFactory(data, num_1x1, num_3x3red, num_3x3, num_d5x5red, num_d5x5, pool, proj, name):
# 1x1
c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))
# 3x3 reduce + 3x3
c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))
# double 3x3 reduce + double 3x3
cd5x5r = ConvFactory(data=data, num_filter=num_d5x5red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
cd5x5 = ConvFactory(data=cd5x5r, num_filter=num_d5x5, kernel=(5, 5), pad=(2, 2), name=('%s_double_3x3_1' % name))
# pool + proj
pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name))
# concat
concat = mx.symbol.Concat(*[c1x1, c3x3, cd5x5, cproj], name='ch_concat_%s_chconcat' % name)
return concat

data = mx.sym.Variable("data")
conv1 = ConvFactory(data, 64, kernel=(7, 7), stride=(2,2), pad=(3, 3))
pool1 = mx.sym.Pooling(conv1, kernel=(3, 3), stride=(2, 2), pool_type="max")
conv2 = ConvFactory(pool1, 64, kernel=(1, 1), stride=(1,1))
conv3 = ConvFactory(conv2, 192, kernel=(3, 3), stride=(1, 1), pad=(1,1))
pool3 = mx.sym.Pooling(conv3, kernel=(3, 3), stride=(2, 2), pool_type="max")

in3a = InceptionFactory(pool3, 64, 96, 128, 16, 32, "max", 32, name="in3a")
in3b = InceptionFactory(in3a, 128, 128, 192, 32, 96, "max", 64, name="in3b")
pool4 = mx.sym.Pooling(in3b, kernel=(3, 3), stride=(2, 2), pool_type="max")
in4a = InceptionFactory(pool4, 192, 96, 208, 16, 48, "max", 64, name="in4a")
in4b = InceptionFactory(in4a, 160, 112, 224, 24, 64, "max", 64, name="in4b")
in4c = InceptionFactory(in4b, 128, 128, 256, 24, 64, "max", 64, name="in4c")
in4d = InceptionFactory(in4c, 112, 144, 288, 32, 64, "max", 64, name="in4d")
in4e = InceptionFactory(in4d, 256, 160, 320, 32, 128, "max", 128, name="in4e")
pool5 = mx.sym.Pooling(in4e, kernel=(3, 3), stride=(2, 2), pool_type="max")
in5a = InceptionFactory(pool5, 256, 160, 320, 32, 128, "max", 128, name="in5a")
in5b = InceptionFactory(in5a, 384, 192, 384, 48, 128, "max", 128, name="in5b")
pool6 = mx.sym.Pooling(in5b, kernel=(7, 7), stride=(1,1), pool_type="avg")
flatten = mx.sym.Flatten(data=pool6)
loss3_classifier = mx.sym.FullyConnected(data=flatten, num_hidden=1000)


# In[4]:

# bind to get executor
# This is what happened behind mx.model.Feedforward
g_exec = loss3_classifier.simple_bind(ctx=dev, grad_req="write", data=dshape)
print("Temp Space: ", g_exec.debug_str().split('\n')[-3])
# Find where to set data


# In[5]:

# data structues
arg_names = loss3_classifier.list_arguments()
arg_map = dict(zip(arg_names, g_exec.arg_arrays))
grad_map = dict(zip(arg_names, g_exec.grad_arrays))


param_blocks = [(i, arg_map[arg_names[i]], grad_map[arg_names[i]]) for i in range(len(arg_names)) if grad_map[arg_names[i]] != None]
input_ndarray = arg_map["data"]
#label_ndarray = arg_map["prob_label"]
grad = mx.nd.zeros((batch_size, 1000), ctx=mx.gpu())
param_len = len(param_blocks)


# In[6]:

#init
for i in range(param_len):
param_blocks[i][1][:] = mx.rnd.uniform(-0.01, 0.01, param_blocks[i][1].shape)
param_blocks[i][2][:] = 0.
# Set data
train_iter.reset()
dbatch = train_iter.next()
dbatch.data[0].copyto(input_ndarray)
#dbatch.label[0].copyto(label_ndarray)
# block all async all
mx.nd.waitall()


# In[ ]:

# Test forward
def test_forward(model, epoch):
tic = time.time()
for i in range(epoch):
model.forward(is_train=True)
# Note: This command will force thread engine block, which hurts performance a lot
# Remove it will bring parallelism bias
model.outputs[0].wait_to_read()
toc = time.time()
return (toc - tic) / epoch

print("Avg forward per batch: ", test_forward(g_exec, num_epoch))


# In[ ]:

# Test full path
def test_full(model, epoch):
tic = time.time()
for i in range(epoch):
model.forward(is_train=True)
model.backward([grad])
# mock update, prevent NaN
for i in range(param_len):
param_blocks[i][1][:] -= 0.0 * param_blocks[i][2]
# Note: This command will force thread engine block, which hurts performance a lot
mx.nd.waitall()
toc = time.time()
return (toc - tic) / epoch

print("Avg fullpath per batch: ", test_full(g_exec, num_epoch))

0 comments on commit 4de3904

Please sign in to comment.