Skip to content

Commit

Permalink
WIP: axis付きのBroadcastTo
Browse files Browse the repository at this point in the history
  • Loading branch information
Masato Hori committed Apr 19, 2018
1 parent 36f502e commit 7f89cf1
Showing 1 changed file with 62 additions and 1 deletion.
63 changes: 62 additions & 1 deletion src/nbla/function/broadcast_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,68 @@ void BroadcastTo<T>::forward_impl(const Variables &inputs, const Variables &outp
std::copy(y, y+ysize, z+i*ysize);
}
} else {

// copy Y depending on the axis position
NBLA_CHECK(xss >= 2, error_code::value,
"X's dimension size should be greater than 1");
switch(xss) {
case 2:
// yss is always 1
switch(_axis) {
case 0:
// X: (2,3) Y: (2) axis=0
// copy Y values vertically
break;
case 1:
// X: (2,3) Y: (3) axis=1
// copy Y values horizontally
break;
default:
NBLA_ERROR(error_code::value, "Unexpected axis value");
}
break;
case 3:
// yss maybe 1 or 2
switch(yss) {
case 1:
switch(_axis) {
case 0:
// X: (2,3,4) Y: (2) axis=0
// copy Y values vertically
break;
case 1:
// X: (2,3,4) Y: (3) axis=1
// copy Y values vertically
break;
case 2:
// X: (2,3,4) Y: (4) axis=2
// copy Y values per block
break;
default:
NBLA_ERROR(error_code::value, "Unexpected axis value");
}
break;
case 2:
switch(_axis) {
case 0:
// X: (2,3,4) Y: (2,3) axis=0
// copy each Y column value vertically
break;
case 1:
// X: (2,3,4) Y: (3,4) axis=1
// copy Y values per block
break;
default:
NBLA_ERROR(error_code::value, "Unexpected axis value");
}
break;
default:
NBLA_ERROR(error_code::value, "Unexpected Y dimension size");
}
break;
case 4:
// yss maybe 1, 2, or 3
break;
}
}
}

Expand Down

0 comments on commit 7f89cf1

Please sign in to comment.