Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix cwise dimension overflow issue again.
If resulting dimensions overflow an int32, we were seeing an overflow and
crash due to size mismatch during broadcast assignment.  The cause is a simple
dimension type mismatch.

Note that actual tests for this are currently impractical, since successful
operations require more than 2^32 elements and OOM on most machines.

PiperOrigin-RevId: 479336566
  • Loading branch information
cantonios authored and tensorflower-gardener committed Oct 6, 2022
1 parent e6e7ca5 commit c5b3037
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions tensorflow/core/kernels/cwise_ops_common.h
Expand Up @@ -450,13 +450,15 @@ struct BinaryFunctor<CPUDevice, Functor, 2, false> {
Assign(d, out, in.unaryExpr(Unary(scalar.data())));
}

inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
Eigen::IndexList<int, Eigen::type2index<1>> ret;
inline Eigen::IndexList<Eigen::DenseIndex, Eigen::type2index<1>> NByOne(
Eigen::DenseIndex n) {
Eigen::IndexList<Eigen::DenseIndex, Eigen::type2index<1>> ret;
ret.set(0, n);
return ret;
}
inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
Eigen::IndexList<Eigen::type2index<1>, int> ret;
inline Eigen::IndexList<Eigen::type2index<1>, Eigen::DenseIndex> OneByM(
Eigen::DenseIndex m) {
Eigen::IndexList<Eigen::type2index<1>, Eigen::DenseIndex> ret;
ret.set(1, m);
return ret;
}
Expand Down Expand Up @@ -487,10 +489,10 @@ struct BinaryFunctor<CPUDevice, Functor, 2, false> {
// use_broadcast_optimization<T> are compile-time constant, gcc
// does a decent job avoiding generating code when conditions
// are not met.
const int a = in0.dimension(0); // in0 is shape [a, b]
const int b = in0.dimension(1);
const int c = in1.dimension(0); // in1 is shape [c, d]
const int d = in1.dimension(1);
const Eigen::DenseIndex a = in0.dimension(0); // in0 is shape [a, b]
const Eigen::DenseIndex b = in0.dimension(1);
const Eigen::DenseIndex c = in1.dimension(0); // in1 is shape [c, d]
const Eigen::DenseIndex d = in1.dimension(1);
if ((a == 1) && (d == 1)) {
auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
Expand Down

0 comments on commit c5b3037

Please sign in to comment.