Skip to content

Commit

Permalink
split tensor operator files (apache#4150)
Browse files Browse the repository at this point in the history
* fix examples

* fix

* split tensor operator files
  • Loading branch information
piiswrong committed Dec 29, 2016
1 parent 2be08f0 commit adb7cb1
Show file tree
Hide file tree
Showing 25 changed files with 312 additions and 215 deletions.
9 changes: 5 additions & 4 deletions example/nce-loss/lstm_word.py
@@ -1,4 +1,5 @@
# pylint:skip-file
import logging
import sys, random, time, math
sys.path.insert(0, "../../python")
import mxnet as mx
Expand Down Expand Up @@ -182,6 +183,9 @@ def reset(self):
pass

if __name__ == '__main__':
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

parser = OptionParser()
parser.add_option("-g", "--gpu", action = "store_true", dest = "gpu", default = False,
help = "use gpu")
Expand All @@ -195,6 +199,7 @@ def reset(self):
init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
init_states = init_c + init_h


data_train = DataIter("./data/text8", batch_size, seq_len, num_label,
init_states)

Expand All @@ -210,10 +215,6 @@ def reset(self):
momentum = 0.9,
wd = 0.0000,
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))

import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

metric = NceLSTMAuc()
model.fit(X = data_train,
Expand Down
7 changes: 4 additions & 3 deletions example/nce-loss/toy_nce.py
@@ -1,4 +1,5 @@
# pylint:skip-file
import logging
import sys, random, time
sys.path.insert(0, "../../python")
import mxnet as mx
Expand Down Expand Up @@ -83,6 +84,9 @@ def reset(self):
pass

if __name__ == '__main__':
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

batch_size = 128
vocab_size = 10000
feature_size = 100
Expand All @@ -100,9 +104,6 @@ def reset(self):
momentum = 0.9,
wd = 0.00001,
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

metric = NceAccuracy()
model.fit(X = data_train, eval_data = data_test,
Expand Down
9 changes: 5 additions & 4 deletions example/nce-loss/toy_softmax.py
@@ -1,4 +1,5 @@
# pylint:skip-file
import logging
import sys, random, time
sys.path.insert(0, "../../python")
import mxnet as mx
Expand Down Expand Up @@ -72,11 +73,14 @@ def reset(self):
pass

if __name__ == '__main__':
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

batch_size = 128
vocab_size = 10000
feature_size = 100
num_label = 6

data_train = DataIter(100000, batch_size, vocab_size, num_label, feature_size)
data_test = DataIter(1000, batch_size, vocab_size, num_label, feature_size)

Expand All @@ -89,9 +93,6 @@ def reset(self):
momentum = 0.9,
wd = 0.0000,
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

model.fit(X = data_train, eval_data = data_test,
batch_end_callback = mx.callback.Speedometer(batch_size, 50),)
Expand Down
7 changes: 4 additions & 3 deletions example/nce-loss/wordvec.py
@@ -1,4 +1,5 @@
# pylint:skip-file
import logging
import sys, random, time, math
sys.path.insert(0, "../../python")
import mxnet as mx
Expand Down Expand Up @@ -116,6 +117,9 @@ def reset(self):
pass

if __name__ == '__main__':
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

parser = OptionParser()
parser.add_option("-g", "--gpu", action = "store_true", dest = "gpu", default = False,
help = "use gpu")
Expand All @@ -138,9 +142,6 @@ def reset(self):
wd = 0.0000,
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))

import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

metric = NceAuc()
model.fit(X = data_train,
Expand Down
9 changes: 4 additions & 5 deletions example/nce-loss/wordvec_subwords.py
@@ -1,4 +1,5 @@
# pylint:skip-file
import logging
import sys, random, time, math
import mxnet as mx
import numpy as np
Expand Down Expand Up @@ -247,6 +248,9 @@ def reset(self):


if __name__ == '__main__':
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

parser = OptionParser()
parser.add_option("-g", "--gpu", action="store_true", dest="gpu", default=False,
help="use gpu")
Expand All @@ -271,11 +275,6 @@ def reset(self):
wd=0.0000,
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))

import logging

head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

metric = NceAuc()
model.fit(X=data_train,
eval_metric=metric,
Expand Down
14 changes: 14 additions & 0 deletions example/rnn-time-major/get_ptb_data.sh
@@ -0,0 +1,14 @@
#!/usr/bin/env bash

RNN_DIR=$(cd `dirname $0`; pwd)
DATA_DIR="${RNN_DIR}/data/"

if [[ ! -d "${DATA_DIR}" ]]; then
echo "${DATA_DIR} doesn't exist, will create one";
mkdir -p ${DATA_DIR}
fi

wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt;
wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.valid.txt;
wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt;
wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt;
3 changes: 3 additions & 0 deletions example/rnn/README.md
Expand Up @@ -2,6 +2,9 @@ RNN Example
===========
This folder contains RNN examples using low level symbol interface.

## Data
Run `get_ptb_data.sh` to download PenTreeBank data.

## Python

- [lstm.py](lstm.py) Functions for building a LSTM Network
Expand Down
2 changes: 1 addition & 1 deletion nnvm
Submodule nnvm updated 1 files
+1 −2 src/core/symbolic.cc
5 changes: 3 additions & 2 deletions python/mxnet/model.py
Expand Up @@ -428,8 +428,9 @@ def __init__(self, symbol, ctx=None,
allow_extra_params=False,
begin_epoch=0,
**kwargs):
logging.warning('[Deprecation Warning] mxnet.model.FeedForward has been deprecated. ' + \
'Please use mxnet.mod.Module instead.')
logging.warning(
'\033[91m[Deprecation Warning] mxnet.model.FeedForward has been deprecated. ' + \
'Please use mxnet.mod.Module instead.\033[0m')

if isinstance(symbol, sym.Symbol):
self.symbol = symbol
Expand Down
3 changes: 2 additions & 1 deletion python/mxnet/symbol.py
Expand Up @@ -44,8 +44,9 @@ class Symbol(SymbolBase):

def __repr__(self):
"""Get a string representation of the symbol."""
name = self.name
return '<%s %s>' % (self.__class__.__name__,
self.name)
'Grouped' if name is None else name)

def __add__(self, other):
if isinstance(other, Symbol):
Expand Down
4 changes: 2 additions & 2 deletions src/operator/elemwise_op_common.h
Expand Up @@ -74,7 +74,7 @@ template<int n_in, int n_out>
inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), n_in) << attrs.name << in_attrs->size() << n_in;
CHECK_EQ(in_attrs->size(), n_in) << " in operator " << attrs.name;
CHECK_EQ(out_attrs->size(), n_out);
return ElemwiseAttr<TShape, shape_is_none, true>(
attrs, in_attrs, out_attrs);
Expand All @@ -88,7 +88,7 @@ template<int n_in, int n_out>
inline bool ElemwiseType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), n_in) << attrs.name << in_attrs->size() << n_in;
CHECK_EQ(in_attrs->size(), n_in) << " in operator " << attrs.name;
CHECK_EQ(out_attrs->size(), n_out);
return ElemwiseAttr<int, type_is_none, true>(
attrs, in_attrs, out_attrs);
Expand Down
2 changes: 1 addition & 1 deletion src/operator/loss_binary_op.cc
Expand Up @@ -13,7 +13,7 @@ NNVM_REGISTER_OP(softmax_cross_entropy)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxCrossEntropyShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
Expand Down
34 changes: 34 additions & 0 deletions src/operator/tensor/broadcast_reduce_op_index.cc
@@ -0,0 +1,34 @@
/*!
* Copyright (c) 2016 by Contributors
* \file broadcast_reduce_op.cc
* \brief CPU Implementation of broadcast and reduce functions.
*/
#include "./broadcast_reduce_op.h"

namespace mxnet {
namespace op {
MXNET_OPERATOR_REGISTER_REDUCE_AXIS(argmax)
.MXNET_DESCRIBE("Compute argmax")
.set_attr<FCompute>("FCompute<cpu>", SearchAxisCompute<cpu, mshadow::red::maximum>);

MXNET_OPERATOR_REGISTER_REDUCE_AXIS(argmin)
.MXNET_DESCRIBE("Compute argmin")
.set_attr<FCompute>("FCompute<cpu>", SearchAxisCompute<cpu, mshadow::red::minimum>);

// Legacy support
NNVM_REGISTER_OP(argmax_channel)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
ReduceAxisParam param;
param.axis = 1;
param.keepdims = false;
attrs->parsed = param;
})
.set_attr<nnvm::FInferShape>("FInferShape", ReduceAxisShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", SearchAxisCompute<cpu, mshadow::red::maximum>)
.add_argument("src", "NDArray", "Source input");

} // namespace op
} // namespace mxnet
21 changes: 21 additions & 0 deletions src/operator/tensor/broadcast_reduce_op_index.cu
@@ -0,0 +1,21 @@
/*!
* Copyright (c) 2016 by Contributors
* \file broadcast_reduce_op.cu
* \brief GPU Implementation of broadcast and reduce functions.
*/
#include "./broadcast_reduce_op.h"

namespace mxnet {
namespace op {
NNVM_REGISTER_OP(argmax)
.set_attr<FCompute>("FCompute<gpu>", SearchAxisCompute<gpu, mshadow::red::maximum>);

NNVM_REGISTER_OP(argmin)
.set_attr<FCompute>("FCompute<gpu>", SearchAxisCompute<gpu, mshadow::red::minimum>);

// Legacy support
NNVM_REGISTER_OP(argmax_channel)
.set_attr<FCompute>("FCompute<gpu>", SearchAxisCompute<gpu, mshadow::red::maximum>);

} // namespace op
} // namespace mxnet
Expand Up @@ -62,29 +62,6 @@ NNVM_REGISTER_OP(_broadcast_backward)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", ReduceAxesCompute<cpu, mshadow::red::sum>);

MXNET_OPERATOR_REGISTER_REDUCE_AXIS(argmax)
.MXNET_DESCRIBE("Compute argmax")
.set_attr<FCompute>("FCompute<cpu>", SearchAxisCompute<cpu, mshadow::red::maximum>);

MXNET_OPERATOR_REGISTER_REDUCE_AXIS(argmin)
.MXNET_DESCRIBE("Compute argmin")
.set_attr<FCompute>("FCompute<cpu>", SearchAxisCompute<cpu, mshadow::red::minimum>);

// Legacy support
NNVM_REGISTER_OP(argmax_channel)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
ReduceAxisParam param;
param.axis = 1;
param.keepdims = false;
attrs->parsed = param;
})
.set_attr<nnvm::FInferShape>("FInferShape", ReduceAxisShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", SearchAxisCompute<cpu, mshadow::red::maximum>)
.add_argument("src", "NDArray", "Source input");

NNVM_REGISTER_OP(norm)
.set_num_inputs(1)
.set_num_outputs(1)
Expand Down
Expand Up @@ -34,16 +34,6 @@ NNVM_REGISTER_OP(broadcast_to)
NNVM_REGISTER_OP(_broadcast_backward)
.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::sum>);

NNVM_REGISTER_OP(argmax)
.set_attr<FCompute>("FCompute<gpu>", SearchAxisCompute<gpu, mshadow::red::maximum>);

NNVM_REGISTER_OP(argmin)
.set_attr<FCompute>("FCompute<gpu>", SearchAxisCompute<gpu, mshadow::red::minimum>);

// Legacy support
NNVM_REGISTER_OP(argmax_channel)
.set_attr<FCompute>("FCompute<gpu>", SearchAxisCompute<gpu, mshadow::red::maximum>);

NNVM_REGISTER_OP(norm)
.set_attr<FCompute>("FCompute<gpu>", L2NormCompute<gpu>);

Expand Down

0 comments on commit adb7cb1

Please sign in to comment.