Added parameter transpose to dot product action, which makes it possible to do the BP computation using the transpose of a weights matrix
pakozm committed Jan 20, 2013
1 parent a50271e commit ba6bf01
Showing 13 changed files with 680 additions and 565 deletions.
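As context for the diff below, a minimal sketch of the two modes the new flag selects (per-sample, column-vector notation; the symbols W, M, x, y, delta and eta are illustrative and not taken from the code). In the default mode the action owns a weights matrix W of size outputs x inputs; with transpose = true it reads an existing matrix M of size inputs x outputs as its transpose, which is what lets a BP computation reuse a weights matrix transposed, as the commit message states:

\begin{align*}
\text{default:}   \quad & y = W x,        & \delta_x &= W^{\top}\,\delta_y, & \Delta W &= -\eta\,\delta_y\,x^{\top},\\
\text{transpose:} \quad & y = M^{\top} x, & \delta_x &= M\,\delta_y,        & \Delta M &= -\eta\,x\,\delta_y^{\top}.
\end{align*}

Here x is the input activation, y the output, delta_y the error arriving at the outputs, delta_x the error propagated back to the inputs, and eta the normalized learning rate.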
6 changes: 6 additions & 0 deletions AUTHORS.txt
@@ -0,0 +1,6 @@
The following people have worked on this project:
- Salvador España Boquera
- Jorge Gorbe Moya
- Adrián Palacios Corella
- Joan Pastor Pellicer
- Francisco Zamora Martínez
16 changes: 13 additions & 3 deletions packages/ann/ann_base/binding/bind_ann_base.lua.cc
@@ -107,6 +107,8 @@ using namespace Functions;
obj = new RealActivationUnits(size, ann->getConfReference(),
strcmp(type, "inputs") != 0);
ann->registerActivationUnits(obj);
if (strcmp(type, "inputs") == 0) ann->registerInput(obj);
else if (strcmp(type, "outputs") == 0) ann->registerOutput(obj);
LUABIND_RETURN(RealActivationUnits, obj);
}
//BIND_END
@@ -242,6 +244,7 @@ using namespace Functions;

obj = new ForwardBiasAction(ann->getConfReference(),
output, conn);
ann->registerAction(obj);
LUABIND_RETURN(ForwardBiasAction, obj);
}
//BIND_END
@@ -255,20 +258,25 @@ using namespace Functions;
//BIND_CONSTRUCTOR DotProductAction
{
LUABIND_CHECK_ARGN(==,1);
check_table_fields(L, 1, "ann", "input", "output", "connections", 0);
check_table_fields(L, 1, "ann", "input", "output", "connections",
"transpose", 0);

ActivationUnits *input;
ActivationUnits *output;
Connections *conn;
ANNBase *ann;
bool transpose;

LUABIND_GET_TABLE_PARAMETER(1, input, ActivationUnits, input);
LUABIND_GET_TABLE_PARAMETER(1, output, ActivationUnits, output);
LUABIND_GET_TABLE_PARAMETER(1, connections, Connections, conn);
LUABIND_GET_TABLE_PARAMETER(1, ann, ANNBase, ann);
LUABIND_GET_TABLE_OPTIONAL_PARAMETER(1, transpose, bool, transpose, false);

obj = new DotProductAction(ann->getConfReference(),
input, output, conn);
input, output, conn,
transpose);
ann->registerAction(obj);
LUABIND_RETURN(DotProductAction, obj);
}
//BIND_END
@@ -294,6 +302,7 @@ using namespace Functions;

obj = new ActivationsAction(ann->getConfReference(),
output, actfunc);
ann->registerAction(obj);
LUABIND_RETURN(ActivationsAction, obj);
}
//BIND_END
@@ -386,7 +395,8 @@ using namespace Functions;
LUABIND_GET_TABLE_PARAMETER(1, w, MatrixFloat, w);
LUABIND_GET_TABLE_OPTIONAL_PARAMETER(1, oldw, MatrixFloat, oldw, w);
LUABIND_GET_TABLE_OPTIONAL_PARAMETER(1, first_pos, uint, first_pos, 0);
LUABIND_GET_TABLE_OPTIONAL_PARAMETER(1, column_size, uint, column_size, 1);
LUABIND_GET_TABLE_OPTIONAL_PARAMETER(1, column_size, uint, column_size,
obj->getNumInputs());

LUABIND_RETURN(uint, obj->loadWeights(w, oldw, first_pos, column_size));
}
3 changes: 1 addition & 2 deletions packages/ann/ann_base/c_src/all_all_connection.cc
@@ -25,8 +25,7 @@ namespace ANN {

AllAllConnections::AllAllConnections(unsigned int num_inputs,
unsigned int num_outputs) :
Connections(num_inputs*num_outputs),
num_inputs(num_inputs), num_outputs(num_outputs) {
Connections(num_inputs*num_outputs, num_inputs, num_outputs) {
}

bool AllAllConnections::checkInputOutputSizes(ActivationUnits *input,
1 change: 0 additions & 1 deletion packages/ann/ann_base/c_src/all_all_connection.h
@@ -26,7 +26,6 @@

namespace ANN {
class AllAllConnections : public Connections {
unsigned int num_inputs, num_outputs;
public:
AllAllConnections(unsigned int num_inputs,
unsigned int num_outputs);
14 changes: 5 additions & 9 deletions packages/ann/ann_base/c_src/bias_connection.cc
@@ -25,7 +25,7 @@
namespace ANN {

BiasConnections::BiasConnections(unsigned int bias_size) :
Connections(bias_size) {
Connections(bias_size, 1, bias_size) {
}

bool BiasConnections::checkInputOutputSizes(ActivationUnits *input,
@@ -57,14 +57,10 @@ namespace ANN {

float *w = weights->getPPALForReadAndWrite();
float *prev_w = prev_weights->getPPALForReadAndWrite();

for (unsigned int j=0; j<num_outputs; ++j) {
unsigned int k = j;
for (unsigned int i=0; i<num_inputs; ++i) {
rnd_weight(w[k]);
prev_w[k] = w[k];
k += num_outputs;
}

for (unsigned int j=0; j<total_size; ++j) {
rnd_weight(w[j]);
prev_w[j] = w[j];
}
}

4 changes: 3 additions & 1 deletion packages/ann/ann_base/c_src/connection.cc
@@ -26,10 +26,12 @@
namespace ANN {
const double Connections::weightnearzero = 1e-10;

Connections::Connections(unsigned int total_size) :
Connections::Connections(unsigned int total_size,
unsigned int num_inputs, unsigned int num_outputs) :
Referenced(),
weights(0), prev_weights(0),
total_size(total_size),
num_inputs(num_inputs), num_outputs(num_outputs),
num_references(0), update_weights_calls(0) {
weights = new FloatGPUMirroredMemoryBlock(total_size);
prev_weights = new FloatGPUMirroredMemoryBlock(total_size);
10 changes: 9 additions & 1 deletion packages/ann/ann_base/c_src/connection.h
@@ -51,7 +51,8 @@ namespace ANN {
public:
static const double weightnearzero;

Connections(unsigned int total_size);
Connections(unsigned int total_size,
unsigned int num_inputs, unsigned int num_outputs);
virtual ~Connections();

// we count how many times this object is referenced, so we know whether
@@ -98,6 +99,13 @@ namespace ANN {
return total_size;
}

virtual unsigned int getNumInputs() const {
return num_inputs;
}
virtual unsigned int getNumOutputs() const {
return num_outputs;
}

};
}
#endif
187 changes: 130 additions & 57 deletions packages/ann/ann_base/c_src/dot_product_action.cc
@@ -31,7 +31,8 @@ namespace ANN {
DotProductAction::DotProductAction(const ANNConfiguration &conf,
ActivationUnits *inputs,
ActivationUnits *outputs,
Connections *weights_matrix) :
Connections *weights_matrix,
bool transpose_weights) :
Action(conf),
inputs(inputs), outputs(outputs), weights_matrix(weights_matrix),
num_inputs(inputs->numNeurons()),
@@ -40,9 +41,15 @@ namespace ANN {
learning_rate(-1.0f),
momentum(0.0f),
weight_decay(0.0f),
c_weight_decay(1.0f) {
if (!weights_matrix->checkInputOutputSizes(inputs, outputs))
ERROR_EXIT(256, "The input/output sizes are not correct.\n");
c_weight_decay(1.0f),
transpose_weights(transpose_weights) {
if (!transpose_weights) {
if (!weights_matrix->checkInputOutputSizes(inputs, outputs))
ERROR_EXIT(256, "The input/output sizes are not correct.\n");
}
else
if (!weights_matrix->checkInputOutputSizes(outputs, inputs))
ERROR_EXIT(256, "The input/output sizes are not correct.\n");
weights_matrix->countReference();
IncRef(inputs);
IncRef(outputs);
@@ -61,28 +68,51 @@ namespace ANN {
FloatGPUMirroredMemoryBlock *output_ptr = outputs->getPtr();
FloatGPUMirroredMemoryBlock *weights_mat_ptr = weights_matrix->getPtr();

if (conf.cur_bunch_size == 1)
if (conf.cur_bunch_size == 1) {
// vector x matrix product
doSgemv(CblasColMajor, CblasNoTrans,
num_outputs, num_inputs,
1.0f, weights_mat_ptr, num_outputs,
input_ptr, conf.max_bunch_size,
1.0f, output_ptr, conf.max_bunch_size,
0, inputs->getOffset(), outputs->getOffset(),
conf.use_cuda_flag);
else
if (!transpose_weights)
doSgemv(CblasColMajor, CblasNoTrans,
num_outputs, num_inputs,
1.0f, weights_mat_ptr, num_outputs,
input_ptr, conf.max_bunch_size,
1.0f, output_ptr, conf.max_bunch_size,
0, inputs->getOffset(), outputs->getOffset(),
conf.use_cuda_flag);
else
doSgemv(CblasColMajor, CblasTrans,
num_inputs, num_outputs,
1.0f, weights_mat_ptr, num_inputs,
input_ptr, conf.max_bunch_size,
1.0f, output_ptr, conf.max_bunch_size,
0, inputs->getOffset(), outputs->getOffset(),
conf.use_cuda_flag);
}
else {
// matrix x matrix product
// C = \alpha op(A) op(B) + \beta C
// input * weights = output
doSgemm(CblasColMajor, CblasNoTrans, CblasTrans,
conf.cur_bunch_size, num_outputs, num_inputs,
1.0f, input_ptr, conf.max_bunch_size,
weights_mat_ptr, num_outputs,
// beta = 1.0f, C matrix contains BIAS and probably other layer
// computations
1.0f, output_ptr, conf.max_bunch_size,
inputs->getOffset(), 0, outputs->getOffset(),
conf.use_cuda_flag);
if (!transpose_weights)
doSgemm(CblasColMajor, CblasNoTrans, CblasTrans,
conf.cur_bunch_size, num_outputs, num_inputs,
1.0f, input_ptr, conf.max_bunch_size,
weights_mat_ptr, num_outputs,
// beta = 1.0f, C matrix contains BIAS and probably other layer
// computations
1.0f, output_ptr, conf.max_bunch_size,
inputs->getOffset(), 0, outputs->getOffset(),
conf.use_cuda_flag);
else
doSgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
conf.cur_bunch_size, num_outputs, num_inputs,
1.0f,
input_ptr, conf.max_bunch_size,
weights_mat_ptr, num_inputs,
// beta = 1.0f, C matrix contains BIAS and probably other layer
// computations
1.0f, output_ptr, conf.max_bunch_size,
inputs->getOffset(), 0, outputs->getOffset(),
conf.use_cuda_flag);
}
}
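As a self-contained illustration of why the matrix dimensions swap in the transposed branch of the forward pass above, here is a plain CBLAS sketch. It uses cblas_sgemv from a standard BLAS installation, not the project's doSgemv wrapper (which additionally takes per-block offsets and a CUDA flag); all sizes and values are made up for the example:

#include <cblas.h>
#include <cstdio>

int main() {
  // Shared matrix M with num_inputs = 2 rows and num_outputs = 3 columns,
  // stored column-major: element (i,j) lives at M[j*2 + i].
  float M[2 * 3] = {1, 2,   // column 0
                    3, 4,   // column 1
                    5, 6};  // column 2
  float x[2] = {1, 1};      // input activations
  float y[3] = {0, 0, 0};   // output activations (beta = 1 accumulates into them)
  // y = 1.0 * M^T * x + 1.0 * y: the dimensions are passed as
  // (rows, cols) = (num_inputs, num_outputs) and lda = num_inputs,
  // the same argument order used in the transpose_weights branch.
  cblas_sgemv(CblasColMajor, CblasTrans,
              2, 3,
              1.0f, M, 2,
              x, 1,
              1.0f, y, 1);
  printf("%g %g %g\n", y[0], y[1], y[2]);  // prints: 3 7 11
  return 0;
}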

void DotProductAction::
@@ -94,22 +124,42 @@ namespace ANN {
if (output_error != 0) {
if (conf.cur_bunch_size > 1) {
// C = alpha * A * B + beta * C
doSgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
conf.cur_bunch_size, num_inputs, num_outputs,
1.0f, input_error, conf.max_bunch_size,
weights_mat_ptr, num_outputs,
1.0f, output_error, conf.max_bunch_size,
input_error_shift, 0, output_error_shift,
conf.use_cuda_flag);
if (!transpose_weights)
doSgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
conf.cur_bunch_size, num_inputs, num_outputs,
1.0f, input_error, conf.max_bunch_size,
weights_mat_ptr, num_outputs,
1.0f, output_error, conf.max_bunch_size,
input_error_shift, 0, output_error_shift,
conf.use_cuda_flag);
else
doSgemm(CblasColMajor, CblasNoTrans, CblasTrans,
conf.cur_bunch_size, num_inputs, num_outputs,
1.0f, input_error, conf.max_bunch_size,
weights_mat_ptr, num_inputs,
1.0f, output_error, conf.max_bunch_size,
input_error_shift, 0, output_error_shift,
conf.use_cuda_flag);
}
else {
doSgemv(CblasColMajor, CblasNoTrans,
num_inputs, num_outputs,
1.0f, weights_mat_ptr, num_inputs,
input_error, conf.max_bunch_size,
1.0f, output_error, conf.max_bunch_size,
0, input_error_shift, output_error_shift,
conf.use_cuda_flag);
// FIXME: I'm not sure about these two calls... please review them
if (!transpose_weights)
doSgemv(CblasColMajor, CblasTrans,
num_outputs, num_inputs,
1.0f, weights_mat_ptr, num_outputs,
input_error, conf.max_bunch_size,
1.0f, output_error, conf.max_bunch_size,
0, input_error_shift, output_error_shift,
conf.use_cuda_flag);
else {
doSgemv(CblasColMajor, CblasNoTrans,
num_inputs, num_outputs,
1.0f, weights_mat_ptr, num_inputs,
input_error, conf.max_bunch_size,
1.0f, output_error, conf.max_bunch_size,
0, input_error_shift, output_error_shift,
conf.use_cuda_flag);
}
}
}
}
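In the same spirit, a standalone cblas_sgemm sketch of the bunched, non-transposed error propagation above: with one sample per row, column-major storage and leading dimension max_bunch_size, the error at the inputs is dX = dY * W. Again, this is plain CBLAS with invented sizes, not the project's doSgemm wrapper (which also takes offsets and a CUDA flag):

#include <cblas.h>
#include <cstdio>

int main() {
  const int cur_bunch = 2, max_bunch = 4;    // bunch rows used / allocated
  const int num_outputs = 3, num_inputs = 2;
  // Error at the action's outputs, cur_bunch x num_outputs, column-major,
  // leading dimension max_bunch (rows 2..3 are padding):
  float dY[max_bunch * num_outputs] = {1,0,0,0,  0,1,0,0,  1,1,0,0};
  // Weights W, num_outputs x num_inputs, column-major, lda = num_outputs:
  float W[num_outputs * num_inputs] = {1,3,5,  2,4,6};
  // Error propagated to the inputs, cur_bunch x num_inputs, same bunch layout:
  float dX[max_bunch * num_inputs] = {0};
  // dX = dY * W (per row-vector sample this is delta_x = W^T delta_y),
  // matching the non-transposed sgemm call in backpropagateErrors:
  cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
              cur_bunch, num_inputs, num_outputs,
              1.0f, dY, max_bunch,
              W, num_outputs,
              1.0f, dX, max_bunch);
  printf("%g %g | %g %g\n", dX[0], dX[4], dX[1], dX[5]);  // prints: 6 8 | 8 10
  return 0;
}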
@@ -130,33 +180,56 @@ namespace ANN {
-(1.0f/sqrtf(static_cast<float>(references))) *
learning_rate;

if (conf.cur_bunch_size > 1)
doSgemm(CblasColMajor, CblasTrans, CblasNoTrans, // transpositions
num_outputs, num_inputs, conf.cur_bunch_size, // dimensions
norm_learn_rate, // alpha
input_error, // A
conf.max_bunch_size, // A stride
input, // B
conf.max_bunch_size, // B stride
beta, // beta
prev_weights_mat_ptr, // C
num_outputs, // C stride
input_error_shift, input_shift, 0, // offsets
conf.use_cuda_flag);
if (conf.cur_bunch_size > 1) {
if (!transpose_weights)
doSgemm(CblasColMajor, CblasTrans, CblasNoTrans,
num_outputs, num_inputs, conf.cur_bunch_size, // dimensions
norm_learn_rate, // alpha
input_error, // A
conf.max_bunch_size, // A stride
input, // B
conf.max_bunch_size, // B stride
beta, // beta
prev_weights_mat_ptr, // C
num_outputs, // C stride
input_error_shift, input_shift, 0, // offsets
conf.use_cuda_flag);
else
doSgemm(CblasColMajor, CblasTrans, CblasNoTrans,
num_inputs, num_outputs, conf.cur_bunch_size, // dimensions
norm_learn_rate, // alpha
input, // B
conf.max_bunch_size, // B stride
input_error, // A
conf.max_bunch_size, // A stride
beta, // beta
prev_weights_mat_ptr, // C
num_inputs, // C stride
input_shift, input_error_shift, 0, // offsets
conf.use_cuda_flag);
}
else {
if (beta < 1.0f)
doSscal((num_inputs * num_outputs),
beta,
prev_weights_mat_ptr, 0, 1,
conf.use_cuda_flag);

doSger(CblasColMajor,
num_outputs, num_inputs,
norm_learn_rate,
input_error, input_error_shift, conf.max_bunch_size,
input, input_shift, conf.max_bunch_size,
prev_weights_mat_ptr, 0, num_outputs,
conf.use_cuda_flag);
if (!transpose_weights)
doSger(CblasColMajor,
num_outputs, num_inputs,
norm_learn_rate,
input_error, input_error_shift, conf.max_bunch_size,
input, input_shift, conf.max_bunch_size,
prev_weights_mat_ptr, 0, num_outputs,
conf.use_cuda_flag);
else
doSger(CblasColMajor,
num_inputs, num_outputs,
norm_learn_rate,
input, input_shift, conf.max_bunch_size,
input_error, input_error_shift, conf.max_bunch_size,
prev_weights_mat_ptr, 0, num_inputs,
conf.use_cuda_flag);
}
}
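Finally, a small cblas_sger sketch of the new transposed, single-sample weight update above: the rank-1 outer product accumulated into the shared matrix swaps the roles of the input and the error vector relative to the default branch. Momentum (the beta scaling) and the bunch strides/offsets handled by the project's doSger wrapper are left out; sizes and the learning-rate value are assumptions for the example:

#include <cblas.h>
#include <cstdio>

int main() {
  const int num_inputs = 2, num_outputs = 3;
  float x[num_inputs]   = {1, 2};          // layer input
  float dY[num_outputs] = {1, 0, 2};       // error at the layer output
  float M[num_inputs * num_outputs] = {0}; // shared matrix, num_inputs x num_outputs, column-major
  float norm_learn_rate = -0.1f;           // sign convention as in updateWeights
  // Transposed mode: the rank-1 update accumulated into M is
  //   M += norm_learn_rate * x * dY^T,
  // i.e. input and input_error trade places w.r.t. the default doSger call.
  cblas_sger(CblasColMajor, num_inputs, num_outputs,
             norm_learn_rate,
             x, 1,
             dY, 1,
             M, num_inputs);
  // Printed row by row; expected: -0.1 0 -0.2 / -0.2 0 -0.4
  printf("%g %g %g\n%g %g %g\n", M[0], M[2], M[4], M[1], M[3], M[5]);
  return 0;
}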
