Skip to content

Commit

Permalink
葉にコピーするケースへ対応
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 20, 2018
1 parent 7f89cf1 commit cf6957c
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 34 deletions.
11 changes: 10 additions & 1 deletion python/test/function/test_broadcast_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,23 @@
function_tester,
list_ctx_and_func_name)

def copying_to_leaf(x, y, axis):
return (len(x.shape) - len(y.shape) - axis) == 0

def ref_broadcast_to(x, y, axis):
if axis < 0:
if axis < 0 or copying_to_leaf(x, y, axis):
return np.ones(x.shape) * y
else:
return np.ones(x.shape)


PARAMS = [
((2, 3), (3), 1),
((2, 3, 4), (4), 2),
((2, 3, 4), (3, 4), 1),
((2, 3, 4, 5), (5), 3),
((2, 3, 4, 5), (4, 5), 2),
((2, 3, 4, 5), (3, 4, 5), 1),
((2, 3, 4, 5), (5), -1),
((2, 3, 4, 5), (4, 5), -1),
#((2, 3, 4, 5), (3, 4), 1),
Expand Down
128 changes: 95 additions & 33 deletions src/nbla/function/broadcast_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,28 @@ template <typename T>
void BroadcastTo<T>::setup_impl(const Variables &inputs, const Variables &outputs) {
const Shape_t xs = inputs[0]->shape();
const Shape_t ys = inputs[1]->shape();
const Size_t xss = xs.size();
const Size_t yss = ys.size();
NBLA_CHECK(xss >= yss, error_code::value,
const Size_t xdim = xs.size();
const Size_t ydim = ys.size();
NBLA_CHECK(xdim >= ydim, error_code::value,
"BroadcastTo expects Y (variable to be broadcasted) to be smaller than or equal to X (target variable we want to fit to): %d vs %d",
yss, xss);
ydim, xdim);
if (axis_ < 0) {
// No axis was specified.
// Check if y shape can fit x shape from the tail dimension
const Size_t xofs = xss - yss;
for (Size_t i=yss-1; i>=0; i--) {
const Size_t xofs = xdim - ydim;
for (Size_t i=ydim-1; i>=0; i--) {
Size_t xds = xs[xofs+i];
Size_t yds = ys[i];
NBLA_CHECK(xds == yds, error_code::value,
"Dimension %d's size of X and Y do not match: %d vs %d",
xofs+i, xds, yds);
}
} else {
NBLA_CHECK(axis_ < xss, error_code::value,
NBLA_CHECK(axis_ < xdim, error_code::value,
"Specified axis index %d must be within the size of the actual input dimension %d",
axis_, xss);
axis_, xdim);
// Check if y shape can fit x shape from the axis index
for (Size_t i=0; i<yss; i++) {
for (Size_t i=0; i<ydim; i++) {
Size_t xds = xs[i+axis_];
Size_t yds = ys[i];
NBLA_CHECK(xds == yds, error_code::value,
Expand All @@ -69,81 +69,89 @@ void BroadcastTo<T>::setup_impl(const Variables &inputs, const Variables &output
outputs[0]->reshape(xs, true);
}

// Copy Y block to Z's tail
template <typename T>
static void copy_block_to_tail(
T* z, const T* y, const Shape_t &xs,
Size_t xdim, Size_t ydim, Size_t ysize) {
const Size_t diff = xdim - ydim;
Size_t loop_num = 1;
for (Size_t i=0; i<diff; i++) {
loop_num *= xs[i];
}
for (Size_t i=0; i<loop_num; i++) {
std::copy(y, y+ysize, z+i*ysize);
}
}

template <typename T>
void BroadcastTo<T>::forward_impl(const Variables &inputs, const Variables &outputs) {
const T *y = inputs[1]->get_data_pointer<T>(this->ctx_);
T *z = outputs[0]->cast_data_and_get_pointer<T>(this->ctx_);
const Shape_t xs = inputs[0]->shape();
const Shape_t ys = inputs[1]->shape();
const Size_t ysize = inputs[1]->size();
const Size_t xss = xs.size();
const Size_t yss = ys.size();
if (xss == yss) {
const Size_t xdim = xs.size();
const Size_t ydim = ys.size();
if (xdim == ydim) {
// X and Y have exactly the same shape
// Copy Y to Z
std::copy(y, y+ysize, z);
return;
}
if (axis_ < 0) {
// copy the whole Y block to Z per stride
const Size_t diff = xss - yss;
Size_t loop_num = 1;
for (Size_t i=0; i<diff; i++) {
loop_num *= xs[i];
}
for (Size_t i=0; i<loop_num; i++) {
std::copy(y, y+ysize, z+i*ysize);
}
copy_block_to_tail(z, y, xs, xdim, ydim, ysize);
} else {
// copy Y depending on the axis position
NBLA_CHECK(xss >= 2, error_code::value,
NBLA_CHECK(xdim >= 2, error_code::value,
"X's dimension size should be greater than 1");
switch(xss) {
switch(xdim) {
case 2:
// yss is always 1
switch(_axis) {
// Y dimension size is always 1
switch(axis_) {
case 0:
// X: (2,3) Y: (2) axis=0
// copy Y values vertically
break;
case 1:
// X: (2,3) Y: (3) axis=1
// copy Y values horizontally
copy_block_to_tail(z, y, xs, xdim, ydim, ysize);
break;
default:
NBLA_ERROR(error_code::value, "Unexpected axis value");
}
break;
case 3:
// yss maybe 1 or 2
switch(yss) {
// Y dimension size maybe 1 or 2
switch(ydim) {
case 1:
switch(_axis) {
switch(axis_) {
case 0:
// X: (2,3,4) Y: (2) axis=0
// copy Y values vertically
// copy Y values per block
break;
case 1:
// X: (2,3,4) Y: (3) axis=1
// copy Y values vertically
break;
case 2:
// X: (2,3,4) Y: (4) axis=2
// copy Y values per block
copy_block_to_tail(z, y, xs, xdim, ydim, ysize);
break;
default:
NBLA_ERROR(error_code::value, "Unexpected axis value");
}
break;
case 2:
switch(_axis) {
switch(axis_) {
case 0:
// X: (2,3,4) Y: (2,3) axis=0
// copy each Y column value vertically
break;
case 1:
// X: (2,3,4) Y: (3,4) axis=1
// copy Y values per block
copy_block_to_tail(z, y, xs, xdim, ydim, ysize);
break;
default:
NBLA_ERROR(error_code::value, "Unexpected axis value");
Expand All @@ -154,8 +162,62 @@ void BroadcastTo<T>::forward_impl(const Variables &inputs, const Variables &outp
}
break;
case 4:
// yss maybe 1, 2, or 3
// Y dimension size maybe 1, 2, or 3
switch(ydim) {
case 1:
switch(axis_) {
case 0:
// X: (2,3,4,5) Y: (2) axis=0
break;
case 1:
// X: (2,3,4,5) Y: (3) axis=1
break;
case 2:
// X: (2,3,4,5) Y: (4) axis=2
break;
case 3:
// X: (2,3,4,5) Y: (5) axis=3
copy_block_to_tail(z, y, xs, xdim, ydim, ysize);
break;
default:
NBLA_ERROR(error_code::value, "Unexpected axis value");
}
break;
case 2:
switch(axis_) {
case 0:
// X: (2,3,4,5) Y: (2,3) axis=0
break;
case 1:
// X: (2,3,4,5) Y: (3,4) axis=1
break;
case 2:
// X: (2,3,4,5) Y: (4,5) axis=2
copy_block_to_tail(z, y, xs, xdim, ydim, ysize);
break;
default:
NBLA_ERROR(error_code::value, "Unexpected axis value");
}
break;
case 3:
switch(axis_) {
case 0:
// X: (2,3,4,5) Y: (2,3,4) axis=0
break;
case 1:
// X: (2,3,4,5) Y: (3,4,5) axis=1
copy_block_to_tail(z, y, xs, xdim, ydim, ysize);
break;
default:
NBLA_ERROR(error_code::value, "Unexpected axis value");
}
break;
default:
NBLA_ERROR(error_code::value, "Unexpected Y dimension size");
}
break;
default:
NBLA_ERROR(error_code::value, "Unexpected X dimension size");
}
}
}
Expand Down

0 comments on commit cf6957c

Please sign in to comment.