Skip to content

Commit

Permalink
sync changes related to adding support for 3D convolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
deven-amd committed Jul 2, 2019
1 parent 619f8bc commit 1e24623
Showing 1 changed file with 44 additions and 43 deletions.
87 changes: 44 additions & 43 deletions tensorflow/stream_executor/rocm/rocm_dnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ class MIOpenHandle {
namespace wrap {

#ifdef PLATFORM_GOOGLE

#define STREAM_EXECUTOR_MIOPEN_WRAP(__name) \
struct WrapperShim__##__name { \
template <typename... Args> \
Expand Down Expand Up @@ -162,6 +161,7 @@ namespace wrap {
__macro(miopenBatchNormalizationForwardInference) \
__macro(miopenBatchNormalizationForwardTraining) \
__macro(miopenGetConvolutionForwardOutputDim) \
__macro(miopenGetConvolutionNdForwardOutputDim) \
__macro(miopenFindConvolutionForwardAlgorithm) \
__macro(miopenCreateTensorDescriptor) \
__macro(miopenDestroyTensorDescriptor) \
Expand All @@ -183,7 +183,9 @@ namespace wrap {
__macro(miopenConvolutionBackwardBias) \
__macro(miopenConvolutionForwardGetWorkSpaceSize) \
__macro(miopenInitConvolutionDescriptor) \
__macro(miopenInitConvolutionNdDescriptor) \
__macro(miopenGetConvolutionDescriptor) \
__macro(miopenGetConvolutionNdDescriptor) \
__macro(miopenSetConvolutionGroupCount) \
__macro(miopenSet4dTensorDescriptor) \
__macro(miopenGetTensorDescriptor) \
Expand Down Expand Up @@ -282,28 +284,29 @@ uint64 GetHashValue(miopenTensorDescriptor_t tensor_desc) {

uint64 GetHashValue(miopenConvolutionDescriptor_t conv_desc) {
miopenConvolutionMode_t c_mode = miopenConvolution;
int pad_h = 0, pad_w = 0, u = 0, v = 0, dilation_h = 0, dilation_w = 0;
wrap::miopenGetConvolutionDescriptor(conv_desc, &c_mode, &pad_h, &pad_w, &u,
&v, &dilation_h, &dilation_w);
int nd = 0;
wrap::miopenGetConvolutionNdDescriptor(conv_desc, 0, &nd, nullptr, nullptr,
nullptr, &c_mode);

std::vector<int> stride(nd);
std::vector<int> pad(nd);
std::vector<int> dilation(nd);

wrap::miopenGetConvolutionNdDescriptor(
conv_desc, nd, &nd, pad.data(), stride.data(), dilation.data(), &c_mode);

uint64 hash_value = tensorflow::hash<int>()(c_mode);
hash_value =
tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(pad_h));
hash_value =
tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(pad_w));
hash_value =
tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(u));
hash_value =
tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(v));
hash_value = tensorflow::Hash64Combine(hash_value,
tensorflow::hash<int>()(dilation_h));
hash_value = tensorflow::Hash64Combine(hash_value,
tensorflow::hash<int>()(dilation_w));
auto hash64Combine = [&hash_value](int element) {
tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(element));
};
std::for_each(pad.begin(), pad.end(), hash64Combine);
std::for_each(stride.begin(), stride.end(), hash64Combine);
std::for_each(dilation.begin(), dilation.end(), hash64Combine);

return hash_value;
}

// Class to implement a cache of compiled fusion plans.
// Class to implement a cache of compiled fusion plans
class CachedFusionPlans {
public:
// Check if we already have a fusion_plan corresponding to the given hash
Expand Down Expand Up @@ -340,7 +343,7 @@ class CachedFusionPlans {
return found_cached_plan;
}

// Need to figure out the right place to call this routine.
// Need to figure out the right place to call this routine
static void Clear() {
absl::MutexLock lock{&cached_plans_mutex};

Expand All @@ -357,24 +360,24 @@ class CachedFusionPlans {
unsupported_plans.clear();
}

// Is the Fusion plan corresponding to this hash unsupported.
// Is the Fusion plan corresponding to this hash unsupported
static bool IsUnsupportedFusionPlan(uint64 hash) {
absl::MutexLock lock{&cached_plans_mutex};
return unsupported_plans.count(hash) > 0;
}

// Mark the given hash value as corresponding to an unsupported fusion plan.
// Mark the given hash value as corresponding to an unsupported fusion plan
static void MarkFusionPlanUnsupported(uint64 hash) {
absl::MutexLock lock{&cached_plans_mutex};
unsupported_plans.insert(hash);
}

private:
// Mutex to guard access to all data within this class.
// Mutex to guard access to all data within this class
static absl::Mutex cached_plans_mutex;

// Map of hash-value to MIOpen Fusion plan descriptors.
// Need to be able share this across more than one stream and hence static.
// Map of hash-value to MIOpen Fusion plan descriptors
// Need to be able share this across more than one stream and hence static
static std::map<uint64, miopenFusionPlanDescriptor_t> cached_plans;

// Set of hash-values that correspond to MIOpen Fusion plans that will fail
Expand All @@ -386,6 +389,10 @@ absl::Mutex CachedFusionPlans::cached_plans_mutex;
std::map<uint64, miopenFusionPlanDescriptor_t> CachedFusionPlans::cached_plans;
std::set<uint64> CachedFusionPlans::unsupported_plans;

} // namespace

namespace {

miopenHandle_t ToHandle(void* opaque_handle) {
return static_cast<miopenHandle_t>(opaque_handle);
}
Expand Down Expand Up @@ -538,10 +545,6 @@ class ScopedTensorDescriptor {
case dnn::DataLayout::kBatchYXDepth:
case dnn::DataLayout::kBatchDepthYX: {
const int nd = batch_descriptor.ndims() + 2;
if (nd != 4) {
LOG(FATAL) << "miopen only supports 4D tensors, dim=" << nd
<< " not allowed";
}

// MIOpen requires the strides and dims to be ordered as BDYX.
std::vector<int64> strides64 =
Expand All @@ -556,8 +559,8 @@ class ScopedTensorDescriptor {
&CheckedNarrowing<int64, int>);
std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
&CheckedNarrowing<int64, int>);
status = wrap::miopenSet4dTensorDescriptor(handle_, elem_type, dims[0],
dims[1], dims[2], dims[3]);
status = wrap::miopenSetTensorDescriptor(handle_, elem_type, nd,
dims.data(), strides.data());

if (status != miopenStatusSuccess) {
LOG(FATAL) << "could not convert BatchDescriptor "
Expand Down Expand Up @@ -604,19 +607,14 @@ class ScopedFilterDescriptor {

const int nd = batch_descriptor.ndims() + 2;

if (nd != 4) {
LOG(FATAL) << "miopen only supports 4D filters, dim=" << nd
<< "not allowed" << ToString(status);
}

std::vector<int> dims(2 + filter_descriptor.ndims());
dims[0] = filter_descriptor.output_feature_map_count();
dims[1] = filter_descriptor.input_feature_map_count();
const auto& spatial_dims = filter_descriptor.input_filter_dims();
std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);

status = wrap::miopenSet4dTensorDescriptor(handle_, elem_type, dims[0],
dims[1], dims[2], dims[3]);
status = wrap::miopenSetTensorDescriptor(handle_, elem_type, nd,
dims.data(), nullptr);
if (status != miopenStatusSuccess) {
LOG(FATAL) << "could not set miopen filter descriptor: "
<< ToString(status);
Expand Down Expand Up @@ -667,11 +665,15 @@ class ScopedConvolutionDescriptor {
&CheckedNarrowing<int64, int>);
std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
&CheckedNarrowing<int64, int>);
std::vector<int> upscale(convolution_descriptor.ndims(), 1);

status = wrap::miopenInitConvolutionDescriptor(
handle_, miopenConvolution, padding[0], padding[1], strides[0],
strides[1], upscale[0], upscale[1]);
std::vector<int> upscale(convolution_descriptor.ndims());
const auto& dilations64 = convolution_descriptor.dilations();
std::transform(dilations64.cbegin(), dilations64.cend(), upscale.begin(),
&CheckedNarrowing<int64, int>);

status = wrap::miopenInitConvolutionNdDescriptor(
handle_, convolution_descriptor.ndims(), padding.data(), strides.data(),
upscale.data(), miopenConvolution);
if (status != miopenStatusSuccess) {
LOG(FATAL) << "could not set miopen convolution descriptor: "
<< ToString(status);
Expand Down Expand Up @@ -4003,9 +4005,8 @@ bool MIOpenSupport::DeriveOutputBatchDescriptor(

int dn = batch_descriptor.ndims() + 2;
std::vector<int> dims(dn); // in BDYX
auto status = wrap::miopenGetConvolutionForwardOutputDim(
conv.handle(), input_nd.handle(), filter.handle(), &dims[0], &dims[1],
&dims[2], &dims[3]);
auto status = wrap::miopenGetConvolutionNdForwardOutputDim(
conv.handle(), input_nd.handle(), filter.handle(), &dn, dims.data());
if (status != miopenStatusSuccess) {
LOG(ERROR) << "could not get output tensor for convolution: "
<< ToString(status);
Expand Down

0 comments on commit 1e24623

Please sign in to comment.