Skip to content

Commit

Permalink
リファクタ
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 20, 2018
1 parent 8a7b3c6 commit dfb35aa
Showing 1 changed file with 23 additions and 18 deletions.
41 changes: 23 additions & 18 deletions src/nbla/function/broadcast_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,25 @@ static void copy_value_vertically_to_block(
}
}

template <typename T>
static void copy_buf_vertically_to_block(
T* z, const T* y,
Size_t block_num,
Size_t y_buf_width,
Size_t z_block_height,
Size_t z_block_width) {
const Size_t z_block_size = z_block_height*z_block_width;
for (Size_t b=0; b<block_num; b++) {
const T* yrow = &y[b*y_buf_width];
T* zblock = &z[b*z_block_size];
for (Size_t v=0; v<z_block_height; v++) {
T val = yrow[v];
T* zrow = &zblock[v*z_block_width];
std::fill(zrow, zrow+z_block_width, val);
}
}
}

template <typename T>
void BroadcastTo<T>::forward_impl(const Variables &inputs, const Variables &outputs) {
const T *y = inputs[1]->get_data_pointer<T>(this->ctx_);
Expand Down Expand Up @@ -174,24 +193,10 @@ void BroadcastTo<T>::forward_impl(const Variables &inputs, const Variables &outp
case 2:
switch(axis_) {
case 0:
{
// X: (2,3,4) Y: (2,3) axis=0
// copy Y values vertically per channel
Size_t x_chan = xs[0];
Size_t x_height = xs[1];
Size_t x_width = xs[2];
Size_t y_width = ys[1];
const Size_t x_size = x_height*x_width;
for (Size_t c=0; c<x_chan; c++) {
const T* ychan = &y[c*y_width];
T* zchan = &z[c*x_size];
for (Size_t v=0; v<x_height; v++) {
T val = ychan[v];
T* zrow = &zchan[v*x_width];
std::fill(zrow, zrow+x_width, val);
}
}
}
// X: (2,3,4) Y: (2,3) axis=0
copy_buf_vertically_to_block(z, y,
xs[0], ys[1],
xs[1], xs[2]);
break;
case 1:
// X: (2,3,4) Y: (3,4) axis=1
Expand Down

0 comments on commit dfb35aa

Please sign in to comment.