Skip to content

Commit

Permalink
Revert PR #48063: Add _FusedBatchNormGrad to support side_input and a…
Browse files Browse the repository at this point in the history
…ctivation

Still hits "unsupported cudnn" error

To repro, run (from tensorflow_models repository):

```
//tensorflow_models/official/benchmark:keras_imagenet_benchmark
 -- --logtostderr --benchmarks=Resnet50KerasBenchmarkSynth.benchmark_xla_1_gpu_fp16\$
```

PiperOrigin-RevId: 372380247
Change-Id: Iae61fa8095b0c60716cfdf147544d8f7b3e3f493
  • Loading branch information
cheshire authored and tensorflower-gardener committed May 6, 2021
1 parent edf29e8 commit 3eeb732
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 905 deletions.
5 changes: 0 additions & 5 deletions tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc
Expand Up @@ -190,24 +190,19 @@ template <typename ElemType>
void RunCudnnBatchNormBackwardImpl(CudnnBatchNormBackwardParams* params,
se::Stream* stream) {
se::DeviceMemory<float> null_device_ptr(nullptr);
se::DeviceMemory<ElemType> null_elem_device_ptr(nullptr);
auto output_grad_data = se::DeviceMemory<ElemType>(params->output_grad_data);
stream->ThenBatchNormalizationBackward(
se::DeviceMemory<ElemType>(params->grad_output), //
se::DeviceMemory<ElemType>(params->common.operand), //
params->common.scale, //
/*offset=*/null_device_ptr, //
params->mean, //
params->inv_stddev, //
/*y=*/null_elem_device_ptr, //
params->common.operand_desc, //
params->common.scale_offset_desc, //
params->common.epsilon, //
se::dnn::ActivationMode::kNone, //
&output_grad_data, //
&params->output_grad_scale, //
&params->output_grad_offset, //
/*side_input_backprop=*/&null_elem_device_ptr, //
/*reserve_space_allocator=*/nullptr, //
/*workspace_allocator=*/nullptr);
}
Expand Down
246 changes: 2 additions & 244 deletions tensorflow/core/grappler/optimizers/remapper.cc
Expand Up @@ -69,7 +69,6 @@ constexpr char kFusedConv2D[] = "_FusedConv2D";
constexpr char kFusedMatMul[] = "_FusedMatMul";
constexpr char kFusedDepthwiseConv2dNative[] = "_FusedDepthwiseConv2dNative";
constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";
constexpr char kFusedBatchNormGradEx[] = "_FusedBatchNormGradEx";

constexpr char kDataFormat[] = "data_format";
constexpr char kIsTraining[] = "is_training";
Expand Down Expand Up @@ -112,17 +111,6 @@ struct FusedBatchNormEx {
int invalidated = kMissingIndex;
};

// FusedBatchNormGrad with fused side output and/or activation.
struct FusedBatchNormGradEx {
FusedBatchNormGradEx() = default;

int fused_batch_norm_grad = kMissingIndex;
int activation_grad = kMissingIndex;
int side_input_grad = kMissingIndex;
// Add node of the forward pass to access its "offset" input.
int fwd_fused_batch_norm = kMissingIndex;
};

// Contraction node followed by a BiasAdd.
struct ContractionWithBiasAdd {
ContractionWithBiasAdd() = default;
Expand Down Expand Up @@ -982,113 +970,6 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
return false;
}

bool FindFusedBatchNormGradEx(const RemapperContext& ctx, int node_index,
FusedBatchNormGradEx* matched) {
// Root of the pattern must be a FusedBatchNormGrad.
const auto* node_view = ctx.graph_view.GetNode(node_index);

// Returns true iff the node is a compatible FusedBatchNormGrad node.
const auto valid_batch_norm_grad =
[&](const utils::MutableNodeView& fused_batch_norm_grad) -> bool {
const auto* node_def = fused_batch_norm_grad.node();
if (!IsFusedBatchNormGrad(*node_def) ||
HasControlFaninOrFanout(fused_batch_norm_grad))
return false;

// We fuse FusedBatchNormGrad on GPU.
if (!NodeIsOnGpu(node_def)) return false;

// Data type must be DT_HALF.
DataType t_dtype = GetDataTypeFromAttr(*node_def, "T");
if (t_dtype != DT_HALF) return false;

// We rely on cuDNN for computing FusedBatchNormGrad with side
// outputs and activation. cuDNN only supports NHWC data layout.
string data_format;
if (!GetNodeAttr(*node_def, kDataFormat, &data_format).ok()) return false;
if (data_format != "NHWC") return false;

// Channel dimension must be a multiple of 4.
const auto& props =
ctx.graph_properties.GetInputProperties(node_def->name());
const bool valid_channel_dim = !props.empty() &&
props[0].shape().dim_size() == 4 &&
props[0].shape().dim(3).size() % 4 == 0;
if (!valid_channel_dim) return false;

// cuDNN must support CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode.
if (!BatchnormSpatialPersistentEnabled()) return false;

// FusedBatchNormV2 and V3 have an extra type parameter.
if (node_def->op() != "FusedBatchNorm" &&
!HasDataType(node_def, DT_FLOAT, "U"))
return false;

return true;
};

if (!valid_batch_norm_grad(*node_view)) return false;

if (node_view->NumRegularFanins() < 1) return false;

const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
const auto* relugrad_node_view = regular_fanin_0.node_view();
const auto* relugrad_node_def = relugrad_node_view->node();
bool is_relugrad = IsReluGrad(*relugrad_node_def);

if (!is_relugrad || HasControlFaninOrFanout(*relugrad_node_view) ||
IsInPreserveSet(ctx, relugrad_node_def))
return false;

if (relugrad_node_view->NumRegularFanins() < 1) return false;
// Find its corresponding forward node. We need the node to determine if the
// type is bn+add+act or bn+act. Also, we need to access its "offset" input.
const auto& fanin_1 = relugrad_node_view->GetRegularFanin(1);
const auto* fwd_node_view = fanin_1.node_view();
FusedBatchNormEx fwd_matched;
FindFusedBatchNormEx(ctx, fwd_node_view->node_index(), &fwd_matched);
bool fwd_bn_act_used = fwd_matched.activation != kMissingIndex &&
fwd_matched.side_input == kMissingIndex;
bool fwd_bn_add_act_used = fwd_matched.activation != kMissingIndex &&
fwd_matched.side_input != kMissingIndex;

// Check that only 1 node consumes the output of the ReluGrad node.
if (fwd_bn_act_used && relugrad_node_view->GetRegularFanout(0).size() == 1) {
matched->activation_grad = regular_fanin_0.node_index();
matched->fused_batch_norm_grad = node_index;
matched->fwd_fused_batch_norm = fwd_matched.fused_batch_norm;
return true;
}

// Check that only 2 nodes consume the output of the ReluGrad node.
if (fwd_bn_add_act_used &&
relugrad_node_view->GetRegularFanout(0).size() == 2) {
const auto& fanouts_at_port_0 = relugrad_node_view->GetRegularFanouts()[0];
const auto* fanout_0_node_view =
ctx.graph_view.GetNode(fanouts_at_port_0[0].node_view()->GetName());
const auto* fanout_1_node_view =
ctx.graph_view.GetNode(fanouts_at_port_0[1].node_view()->GetName());
const auto* fanout_0_node_def = fanout_0_node_view->node();
const auto* fanout_1_node_def = fanout_1_node_view->node();
const auto* node_def = node_view->node();

matched->activation_grad = regular_fanin_0.node_index();
matched->fused_batch_norm_grad = node_index;
matched->fwd_fused_batch_norm = fwd_matched.fused_batch_norm;

if (fanout_0_node_def == node_def) {
matched->side_input_grad = fanout_1_node_view->node_index();
return true;
}

if (fanout_1_node_def == node_def) {
matched->side_input_grad = fanout_0_node_view->node_index();
return true;
}
}

return false;
}
void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d,
const NodeDef* activation = nullptr) {
DCHECK(IsConv2D(conv2d)) << "Input node must be a Conv2D";
Expand Down Expand Up @@ -1157,27 +1038,6 @@ void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
}
}

void CopyFusedBatchNormGradAttributes(const NodeDef& fused_batch_norm_grad,
NodeDef* fused_batch_norm_grad_ex) {
DCHECK(IsFusedBatchNormGrad(fused_batch_norm_grad))
<< "Input node must be a FusedBatchNormGrad";

auto* attr = fused_batch_norm_grad_ex->mutable_attr();
auto src_attr = fused_batch_norm_grad.attr();

(*attr)["T"] = src_attr.at("T");
(*attr)["is_training"] = src_attr.at("is_training");
(*attr)["data_format"] = src_attr.at("data_format");
(*attr)["epsilon"] = src_attr.at("epsilon");

// FusedBatchNormV2 and V3 have an extra type parameter.
if (fused_batch_norm_grad.op() != "FusedBatchNormGrad") {
SetAttrValue(src_attr.at("U"), &(*attr)["U"]);
} else {
SetAttrValue(DT_FLOAT, &(*attr)["U"]);
}
}

void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul,
const NodeDef* activation = nullptr) {
DCHECK(IsMatMul(matmul)) << "Input node must be a MatMul";
Expand Down Expand Up @@ -1589,79 +1449,6 @@ Status AddFusedBatchNormExNode(RemapperContext* ctx,
return Status::OK();
}

Status AddFusedBatchNormGradExNode(RemapperContext* ctx,
const FusedBatchNormGradEx& matched,
std::vector<bool>* invalidated_nodes,
std::vector<bool>* nodes_to_delete) {
const GraphDef* graph = ctx->graph_view.graph();
const NodeDef& fused_batch_norm_grad =
graph->node(matched.fused_batch_norm_grad);
const NodeDef& activation_grad = graph->node(matched.activation_grad);
const NodeDef& fwd_fused_batch_norm =
graph->node(matched.fwd_fused_batch_norm);

VLOG(2) << "Fuse FusedBatchNormGrad with " << activation_grad.op() << ": "
<< " fused_batch_norm_grad=" << fused_batch_norm_grad.name()
<< " side_input="
<< (matched.side_input_grad != kMissingIndex
? graph->node(matched.side_input_grad).name()
: "<none>")
<< " activation=" << activation_grad.name()
<< " corresponding FusedBatchNorm=" << fwd_fused_batch_norm.name();

NodeDef fused_op;
fused_op.set_op(kFusedBatchNormGradEx);
fused_op.set_name(fused_batch_norm_grad.name());
fused_op.set_device(fused_batch_norm_grad.device());

fused_op.add_input(activation_grad.input(0)); // 0: y_backprop
fused_op.add_input(fused_batch_norm_grad.input(1)); // 1: x
fused_op.add_input(fused_batch_norm_grad.input(2)); // 2: scale
fused_op.add_input(fused_batch_norm_grad.input(3)); // 3: reserve_space_1
fused_op.add_input(fused_batch_norm_grad.input(4)); // 4: reserve_space_2
fused_op.add_input(fused_batch_norm_grad.input(5)); // 5: reserve_space_3
fused_op.add_input(fwd_fused_batch_norm.input(2)); // 6: offset
fused_op.add_input(activation_grad.input(1)); // 7: y

CopyFusedBatchNormGradAttributes(fused_batch_norm_grad, &fused_op);

auto* attrs = fused_op.mutable_attr();
// Only support Relu mode.
SetAttrValue("Relu", &(*attrs)["activation_mode"]);

if (matched.side_input_grad != kMissingIndex) {
SetAttrValue(1, &(*attrs)["num_side_inputs"]);
} else {
SetAttrValue(0, &(*attrs)["num_side_inputs"]);
}

NodeDef identity_op;
identity_op.set_op("Identity");
identity_op.set_name(activation_grad.name());
identity_op.set_device(fused_batch_norm_grad.device());
identity_op.add_input(absl::StrCat(fused_batch_norm_grad.name(), ":5"));
(*identity_op.mutable_attr())["T"] = attrs->at("T");

utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
Status status;
mutation->AddNode(std::move(fused_op), &status);
TF_RETURN_IF_ERROR(status);
if (matched.side_input_grad != kMissingIndex) {
mutation->AddNode(std::move(identity_op), &status);
TF_RETURN_IF_ERROR(status);
}
TF_RETURN_IF_ERROR(mutation->Apply());

(*invalidated_nodes)[matched.fused_batch_norm_grad] = true;
if (matched.side_input_grad != kMissingIndex) {
(*invalidated_nodes)[matched.activation_grad] = true;
} else {
(*nodes_to_delete)[matched.activation_grad] = true;
}

return Status::OK();
}

Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
const GraphDef* graph = ctx->graph_view.graph();
const NodeDef& fused_node = graph->node(matched.fused_batch_norm);
Expand Down Expand Up @@ -1909,7 +1696,6 @@ bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
// (2) Fusing side input and/or activation into FusedBatchNorm.
// (3) Fusing Conv2D biasadd and relu on GPU
// (4) INTEL_MKL specific: Conv2D -> Add or Conv2D -> BiasAdd -> Add.
// (5) Fusing side output and/or activation into FusedBatchNormGrad.
bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
// Candidate for a FusedBatchNorm splitting.
const auto* node_view = ctx.graph_view.GetNode(node_index);
Expand Down Expand Up @@ -1980,23 +1766,6 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
return false;
};

// Candidate for a FusedBatchNormGrad fusion.
const auto is_batch_norm_grad_fusion_candidate = [&]() -> bool {
if (!IsFusedBatchNormGrad(*node_def)) return false;

if (node_view->NumRegularFanins() < 1) return false;
const auto& bn_fanin_0 = node_view->GetRegularFanin(0);
const auto* bn_fanin_0_node_view = bn_fanin_0.node_view();
const auto* bn_fanin_0_node_def = bn_fanin_0_node_view->node();

if (IsReluGrad(*bn_fanin_0_node_def)) {
// ReluGrad + FusedBatchNormGrad.
return true;
}

return false;
};

// TODO(intel-tf): Clean up #ifdef.
#ifdef INTEL_MKL
(void)is_relu_biasadd_conv2d_candidate; // To fix unused variable error.
Expand All @@ -2005,12 +1774,10 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
IsContractionWithAdd(ctx, node_index);
else
return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() ||
is_batch_norm_fusion_candidate() ||
is_batch_norm_grad_fusion_candidate();
is_batch_norm_fusion_candidate();
#else
return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() ||
is_batch_norm_fusion_candidate() ||
is_batch_norm_grad_fusion_candidate();
is_batch_norm_fusion_candidate();
#endif // INTEL_MKL
}

Expand Down Expand Up @@ -2158,15 +1925,6 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
continue;
}

FusedBatchNormGradEx fused_batch_norm_grad_ex;
if (allow_non_differentiable_rewrites &&
FindFusedBatchNormGradEx(ctx, i, &fused_batch_norm_grad_ex)) {
TF_RETURN_IF_ERROR(
AddFusedBatchNormGradExNode(&ctx, fused_batch_norm_grad_ex,
&invalidated_nodes, &nodes_to_delete));
continue;
}

// During inference, most of the inputs to FusedBatchNorm are constant, and
// we can therefore replace the op with a much cheaper set of primitives.
FusedBatchNorm fused_batch_norm;
Expand Down

2 comments on commit 3eeb732

@kaixih
Copy link
Contributor

@kaixih kaixih commented on 3eeb732 May 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cheshire Could you please share more info on how to repro the "unsupported cudnn" error. Like the command lines and GPUs.

I have tried to run the tf models benchmark on my V100 but cannot repro the issue. Below is how I did it and maybe I missed something:
/home/workspace/models contains the tensorflow models repo with "r2.4" branch.

MODEL_DIR=/home/workspace/models

export PYTHONPATH=$PYTHONPATH:$MODEL_DIR

cd $MODEL_DIR
pip install -r official/requirements.txt

cd $MODEL_DIR/official/benchmark
python -u test.py

What is inside test.py:

from official.benchmark.keras_imagenet_benchmark import  Resnet50KerasBenchmarkSynth

x = Resnet50KerasBenchmarkSynth()
x.benchmark_xla_1_gpu_fp16()

@kaixih
Copy link
Contributor

@kaixih kaixih commented on 3eeb732 May 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or if possible can you help check if the issue persists with this #48893.

Please sign in to comment.