Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix segfault on OOM in Conv2D.
PiperOrigin-RevId: 404655317
Change-Id: I33588dbd3f5d0fef980e3c908bf5515a9ee09ce7
  • Loading branch information
reedwm authored and tensorflower-gardener committed Oct 20, 2021
1 parent e021452 commit e7f4975
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions tensorflow/core/kernels/conv_ops.cc
Expand Up @@ -183,20 +183,29 @@ struct LaunchGrouped {
auto on_shuffled = [&]() { shuffles_completed.DecrementCount(); };

// Shuffle input into temporary tensor.
Tensor input_shuffled(input.dtype(), TensorShape(post_shuffle(input)));
Tensor input_shuffled;
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(input.dtype(), TensorShape(post_shuffle(input)),
&input_shuffled));
input_shuffled.tensor<T, 5>().device(device, on_shuffled) =
input.shaped<T, 5>(pre_shuffle(input)).shuffle(shuffle);

// Shuffle filter into temporary tensor.
Tensor filter_shuffled(filter.dtype(), TensorShape(post_shuffle(filter)));
Tensor filter_shuffled;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(filter.dtype(),
TensorShape(post_shuffle(filter)),
&filter_shuffled));
filter_shuffled.tensor<T, 5>().device(device, on_shuffled) =
filter.shaped<T, 5>(pre_shuffle(filter)).shuffle(shuffle);

// Wait for the completion of input/filter shuffles.
shuffles_completed.Wait();

// Write group convolution results into temporary output tensor.
Tensor output_shuffled(output->dtype(), TensorShape(post_shuffle(*output)));
Tensor output_shuffled;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(output->dtype(),
TensorShape(post_shuffle(*output)),
&output_shuffled));

for (int64_t i = 0; i < num_groups; ++i) {
// TODO(ezhulenev): Run this loop using `parallelFor` (regular parallelFor
Expand Down

0 comments on commit e7f4975

Please sign in to comment.