[XLA:SPACE_TO_BATCH] correctly propagate on dot
PiperOrigin-RevId: 633322160
blakehechtman authored and tensorflower-gardener committed May 13, 2024
1 parent 11beb9d commit 696e681
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 10 deletions.
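
In brief, the change lets the space-to-batch converter keep propagating through a dot whose LHS is the space-to-batch'd convolution and whose contraction removes the feature dimension. A minimal sketch of the now-supported pattern, with shapes borrowed from the updated test below (illustrative only):

  %convolution = bf16[1,256,256,32] convolution(%p0, %p1),
      window={size=3x3}, dim_labels=b01f_01io->b01f
  ROOT %dot = bf16[1,256,256,32] dot(%convolution, %p2),
      lhs_contracting_dims={3}, rhs_contracting_dims={0}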
101 changes: 97 additions & 4 deletions third_party/xla/xla/service/space_to_batch_converter.cc
@@ -16,6 +16,7 @@ limitations under the License.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <map>
#include <memory>
@@ -112,6 +113,8 @@ class ConvolutionVisitor {
// This function checks if the HLO instruction supports propagation.
bool SupportedOpForPropagation(HloInstruction* consumer,
HloInstruction* producer);
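// Checks whether a dot consumer is supported for space-to-batch propagation
// from the given producer; see the definition below.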
bool SupportedDotForPropagation(HloInstruction* consumer,
HloInstruction* producer);

// Method that checks validity of Broadcast propagation.
bool IsBroadcastPropagatable(HloInstruction* broadcast,
@@ -1561,13 +1564,55 @@ bool ConvolutionVisitor::IsOpcodeNonPropagatable(HloInstruction* consumer) {
  switch (consumer->opcode()) {
    case HloOpcode::kCustomCall:
      return true;
    case HloOpcode::kDot:
      // Dots block propagation only when it is not explicitly enabled.
      return !ctrl_.enable_propagations_on_dots;
    default:
      return false;
  }
}

bool ConvolutionVisitor::SupportedDotForPropagation(HloInstruction* consumer,
                                                    HloInstruction* producer) {
  if (consumer->opcode() != HloOpcode::kDot) {
    return false;
  }
  auto operand = consumer->mutable_operand(0);
  if (operand != producer || !instr_to_dim_map_.contains(operand)) {
    return false;
  }
  const auto& dnums = consumer->dot_dimension_numbers();
  const auto& contracting_dims = dnums.lhs_contracting_dimensions();
  const auto& batch_dims = dnums.lhs_batch_dimensions();
  auto result = instr_to_dim_map_[operand];
  const int64_t old_batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
  const int64_t old_space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)];
  const int64_t old_feature_dim =
      result[DimMapper(SpaceToBatchDimMap::kFeature)];
  // If every RHS dimension is a dot batch or contracting dimension, the dot
  // output would have no feature dimension; do not propagate.
  if (consumer->operand(1)->shape().rank() ==
      batch_dims.size() + contracting_dims.size()) {
    return false;
  }
  // If the convolution's space or batch dimension is a contracting or batch
  // dimension of the dot, do not propagate; the dot must instead contract
  // away the convolution's feature dimension.
  bool found = false;
  for (auto dim : contracting_dims) {
    if (dim == old_batch_dim || dim == old_space_dim) {
      return false;
    }
    if (dim == old_feature_dim) {
      found = true;
    }
  }
  if (!found) {
    return false;
  }
  for (auto dim : batch_dims) {
    if (dim == old_batch_dim || dim == old_space_dim) {
      return false;
    }
  }
  return true;
}
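
// Example, mirroring the updated test below: the producer is a b01f
// convolution result bf16[1,256,256,32] (batch dim 0, space dim 1, feature
// dim 3), and the dot contracts LHS dim 3 against a bf16[32,32] RHS. The
// feature dimension is contracted away, the batch and space dimensions stay
// free, and the RHS free dimension supplies a new feature dimension, so
// propagation is allowed.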

bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
HloInstruction* producer) {
if (IsOpcodeNonPropagatable(consumer)) {
@@ -1682,6 +1727,10 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
return true;
}

  if (SupportedDotForPropagation(consumer, producer)) {
    return true;
  }

if (consumer->opcode() == HloOpcode::kReduce) {
// Support only the trivial case where both the batch and the split spatial
// dimension are being reduced.
@@ -1964,6 +2013,50 @@ absl::StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
return true;
}

  if (consumer->opcode() == HloOpcode::kDot) {
    auto dim_map_val = instr_to_dim_map_[producer];
    const int64_t old_batch_dim =
        dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)];
    const int64_t old_space_dim =
        dim_map_val[DimMapper(SpaceToBatchDimMap::kSpace0)];
    int64_t new_batch_dim = -1;
    int64_t new_space_dim = -1;
    // Walk the LHS free (non-batch, non-contracting) dimensions in order;
    // in the dot output they appear right after the dot batch dimensions.
    int64_t outer = 0;
    for (int64_t i = 0; i < producer->shape().rank(); ++i) {
      if (absl::c_linear_search(
              consumer->dot_dimension_numbers().lhs_batch_dimensions(), i) ||
          absl::c_linear_search(
              consumer->dot_dimension_numbers().lhs_contracting_dimensions(),
              i)) {
        continue;
      }
      if (i == old_batch_dim) {
        new_batch_dim =
            outer +
            consumer->dot_dimension_numbers().lhs_batch_dimensions_size();
      }
      if (i == old_space_dim) {
        new_space_dim =
            outer +
            consumer->dot_dimension_numbers().lhs_batch_dimensions_size();
      }
      ++outer;
    }
    std::vector<int64_t> dim_map(NumMappedDims());
    dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] = new_batch_dim;
    dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] = new_space_dim;
    // The RHS free dimension becomes the last output dimension and serves
    // as the new feature dimension.
    dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] =
        consumer->shape().rank() - 1;
    instr_to_dim_map_[consumer] = dim_map;
    auto new_consumer = computation->AddInstruction(consumer->Clone());
    new_consumer->mutable_shape()->mutable_dimensions()[new_batch_dim] =
        producer->shape().dimensions(old_batch_dim);
    new_consumer->mutable_shape()->mutable_dimensions()[new_space_dim] =
        producer->shape().dimensions(old_space_dim);
    old_to_new_instrs_[consumer] = new_consumer;
    return true;
  }
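
// For the dot in the updated test (no dot batch dims, LHS dim 3 contracted),
// the LHS free dims {0, 1, 2} keep their positions in the output, so
// new_batch_dim = 0, new_space_dim = 1, and the new feature dimension is the
// last output dim, rank - 1 = 3, which comes from the RHS.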

// TODO(b/189500737): Consider a common way of propagation for
// slice/pad/reduce-window.
if (consumer->opcode() == HloOpcode::kPad) {
@@ -3619,7 +3712,8 @@ ConvolutionVisitor::DoesConvolutionFeedReduceWindowOrSelectAndScatter(
// Stop the search if these ops are encountered.
if (user->opcode() == HloOpcode::kConvolution ||
user->opcode() == HloOpcode::kPad ||
-    user->opcode() == HloOpcode::kTranspose) {
+    user->opcode() == HloOpcode::kTranspose ||
+    user->opcode() == HloOpcode::kDot) {
continue;
}
auto ret =
@@ -3986,7 +4080,6 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
}
TF_CHECK_OK(PropagateOnUsers(original_conv));


return OkStatus();
}

Expand Down
13 changes: 7 additions & 6 deletions third_party/xla/xla/service/space_to_batch_converter_test.cc
@@ -247,7 +247,7 @@ ENTRY computation {
EXPECT_GT(previous_reshape->operand(0)->shape().dimensions(batch_dim), 4);
}

-TEST_F(SpaceToBatchConverterTest, NoPropagateThroughDot) {
+TEST_F(SpaceToBatchConverterTest, PropagateThroughDot) {
std::string hlo_string = R"(
HloModule module
@@ -256,9 +256,10 @@ TEST_F(SpaceToBatchConverterTest, NoPropagateThroughDot) {
%p1 = bf16[3,3,32,32] parameter(1)
%convolution = bf16[1,256,256,32] convolution(%p0, %p1), window={size=3x3},
dim_labels=b01f_01io->b01f
-  %p2 = bf16[1,256,256,32] parameter(2)
-  ROOT %dot.5010 = bf16[1,256,32,32] dot(%convolution, %p2), lhs_batch_dims={0,1},
-    lhs_contracting_dims={2}, rhs_batch_dims={0,2}, rhs_contracting_dims={1}
+  %p2 = bf16[32,32] parameter(2)
+  ROOT %dot.5010 = bf16[1,256,256,32] dot(%convolution, %p2),
+    lhs_contracting_dims={3},
+    rhs_contracting_dims={0}
}
)";

SpaceToBatchConverter converter(
SpaceToBatchController{true, true, true, true, 8});
-  // Test that we do not start space-to-batch on conv->dot chains
-  ASSERT_FALSE(converter.Run(module.get()).value());
+  // Test that space-to-batch propagation continues through conv->dot chains.
+  ASSERT_TRUE(converter.Run(module.get()).value());
}

} // namespace