Skip to content

Commit

Permalink
Implement 2D convolution by FFT
Browse files Browse the repository at this point in the history
  • Loading branch information
wichtounet committed Apr 27, 2015
1 parent 84109cd commit 0487543
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 127 deletions.
3 changes: 3 additions & 0 deletions include/etl/conv_expr.hpp
Expand Up @@ -308,6 +308,9 @@ using conv2_same_expr = basic_conv_expr<T, 2, conv_type::SAME, detail::conv2_sam
template<typename T>
using conv2_full_expr = basic_conv_expr<T, 2, conv_type::FULL, detail::conv2_full_impl>;

template<typename T>
using fft_conv2_full_expr = basic_conv_expr<T, 2, conv_type::FULL, detail::fft_conv2_full_impl>;

//>2D convolutions

template<typename T, std::size_t D>
Expand Down
14 changes: 14 additions & 0 deletions include/etl/fast_expr.hpp
Expand Up @@ -805,6 +805,20 @@ auto conv_2d_full(A&& a, B&& b, C&& c) -> forced_temporary_binary_helper<A, B, C
return {a, b, c};
}

template<typename A, typename B>
auto fft_conv_2d_full(A&& a, B&& b) -> temporary_binary_helper<A, B, fft_conv2_full_expr> {
static_assert(is_etl_expr<A>::value && is_etl_expr<B>::value, "Convolution only supported for ETL expressions");

return {a, b};
}

template<typename A, typename B, typename C>
auto fft_conv_2d_full(A&& a, B&& b, C&& c) -> forced_temporary_binary_helper<A, B, C, fft_conv2_full_expr> {
static_assert(is_etl_expr<A>::value && is_etl_expr<B>::value && is_etl_expr<C>::value, "Convolution only supported for ETL expressions");

return {a, b, c};
}

template<typename A, typename B>
auto conv_deep_valid(A&& a, B&& b) -> dim_temporary_binary_helper<A, B, conv_deep_valid_expr, decay_traits<A>::dimensions()> {
static_assert(is_etl_expr<A>::value && is_etl_expr<B>::value, "Convolution only supported for ETL expressions");
Expand Down
138 changes: 138 additions & 0 deletions include/etl/impl/blas/fft.hpp
Expand Up @@ -156,6 +156,32 @@ inline void zfft2_kernel(const std::complex<double>* in, std::size_t d1, std::si
status = DftiFreeDescriptor(&descriptor); //Free the descriptor
}

inline void inplace_cfft2_kernel(std::complex<float>* in, std::size_t d1, std::size_t d2){
DFTI_DESCRIPTOR_HANDLE descriptor;
MKL_LONG status;
MKL_LONG dim[]{static_cast<long>(d1), static_cast<long>(d2)};

auto* in_ptr = const_cast<void*>(static_cast<const void*>(in));

status = DftiCreateDescriptor(&descriptor, DFTI_SINGLE, DFTI_COMPLEX, 2, dim); //Specify size and precision
status = DftiCommitDescriptor(descriptor); //Finalize the descriptor
status = DftiComputeForward(descriptor, in_ptr); //Compute the Forward FFT
status = DftiFreeDescriptor(&descriptor); //Free the descriptor
}

inline void inplace_zfft2_kernel(std::complex<double>* in, std::size_t d1, std::size_t d2){
DFTI_DESCRIPTOR_HANDLE descriptor;
MKL_LONG status;
MKL_LONG dim[]{static_cast<long>(d1), static_cast<long>(d2)};

auto* in_ptr = const_cast<void*>(static_cast<const void*>(in));

status = DftiCreateDescriptor(&descriptor, DFTI_DOUBLE, DFTI_COMPLEX, 2, dim); //Specify size and precision
status = DftiCommitDescriptor(descriptor); //Finalize the descriptor
status = DftiComputeForward(descriptor, in_ptr); //Compute the Forward FFT
status = DftiFreeDescriptor(&descriptor); //Free the descriptor
}

inline void cifft2_kernel(const std::complex<float>* in, std::size_t d1, std::size_t d2, std::complex<float>* out){
DFTI_DESCRIPTOR_HANDLE descriptor;
MKL_LONG status;
Expand Down Expand Up @@ -186,6 +212,34 @@ inline void zifft2_kernel(const std::complex<double>* in, std::size_t d1, std::s
status = DftiFreeDescriptor(&descriptor); //Free the descriptor
}

inline void inplace_cifft2_kernel(std::complex<float>* in, std::size_t d1, std::size_t d2){
DFTI_DESCRIPTOR_HANDLE descriptor;
MKL_LONG status;
MKL_LONG dim[]{static_cast<long>(d1), static_cast<long>(d2)};

auto* in_ptr = const_cast<void*>(static_cast<const void*>(in));

status = DftiCreateDescriptor(&descriptor, DFTI_SINGLE, DFTI_COMPLEX, 2, dim); //Specify size and precision
status = DftiSetValue(descriptor, DFTI_BACKWARD_SCALE, 1.0f / (d1 * d2)); //Scale down the output
status = DftiCommitDescriptor(descriptor); //Finalize the descriptor
status = DftiComputeBackward(descriptor, in_ptr); //Compute the Forward FFT
status = DftiFreeDescriptor(&descriptor); //Free the descriptor
}

inline void inplace_zifft2_kernel(std::complex<double>* in, std::size_t d1, std::size_t d2){
DFTI_DESCRIPTOR_HANDLE descriptor;
MKL_LONG status;
MKL_LONG dim[]{static_cast<long>(d1), static_cast<long>(d2)};

auto* in_ptr = const_cast<void*>(static_cast<const void*>(in));

status = DftiCreateDescriptor(&descriptor, DFTI_DOUBLE, DFTI_COMPLEX, 2, dim); //Specify size and precision
status = DftiSetValue(descriptor, DFTI_BACKWARD_SCALE, 1.0 / (d1 * d2)); //Scale down the output
status = DftiCommitDescriptor(descriptor); //Finalize the descriptor
status = DftiComputeBackward(descriptor, in_ptr); //Compute the Forward FFT
status = DftiFreeDescriptor(&descriptor); //Free the descriptor
}

} //End of namespace detail

template<typename A, typename C>
Expand Down Expand Up @@ -364,6 +418,84 @@ void zifft2_real(A&& a, C&& c){
}
};

template<typename A, typename B, typename C>
void sfft2_convolve(A&& a, B&& b, C&& c){
const auto m1 = etl::dim<0>(a);
const auto n1= etl::dim<0>(b);
const auto s1 = m1 + n1 - 1;

const auto m2 = etl::dim<1>(a);
const auto n2= etl::dim<1>(b);
const auto s2 = m2 + n2 - 1;

auto a_padded = allocate<std::complex<float>>(c.size());
auto b_padded = allocate<std::complex<float>>(c.size());

for(std::size_t i = 0; i < m1; ++i){
for(std::size_t j = 0; j < m2; ++j){
a_padded[i * s2 + j] = a(i,j);
}
}

for(std::size_t i = 0; i < n1; ++i){
for(std::size_t j = 0; j < n2; ++j){
b_padded[i * s2 + j] = b(i,j);
}
}

detail::inplace_cfft2_kernel(a_padded.get(), s1, s2);
detail::inplace_cfft2_kernel(b_padded.get(), s1, s2);

for(std::size_t i = 0; i < c.size(); ++i){
a_padded[i] *= b_padded[i];
}

detail::inplace_cifft2_kernel(a_padded.get(), s1, s2);

for(std::size_t i = 0; i < c.size(); ++i){
c[i] = a_padded[i].real();
}
}

template<typename A, typename B, typename C>
void dfft2_convolve(A&& a, B&& b, C&& c){
const auto m1 = etl::dim<0>(a);
const auto n1= etl::dim<0>(b);
const auto s1 = m1 + n1 - 1;

const auto m2 = etl::dim<1>(a);
const auto n2= etl::dim<1>(b);
const auto s2 = m2 + n2 - 1;

auto a_padded = allocate<std::complex<double>>(c.size());
auto b_padded = allocate<std::complex<double>>(c.size());

for(std::size_t i = 0; i < m1; ++i){
for(std::size_t j = 0; j < m2; ++j){
a_padded[i * s2 + j] = a(i,j);
}
}

for(std::size_t i = 0; i < n1; ++i){
for(std::size_t j = 0; j < n2; ++j){
b_padded[i * s2 + j] = b(i,j);
}
}

detail::inplace_zfft2_kernel(a_padded.get(), s1, s2);
detail::inplace_zfft2_kernel(b_padded.get(), s1, s2);

for(std::size_t i = 0; i < c.size(); ++i){
a_padded[i] *= b_padded[i];
}

detail::inplace_zifft2_kernel(a_padded.get(), s1, s2);

for(std::size_t i = 0; i < c.size(); ++i){
c[i] = a_padded[i].real();
}
}

#else

template<typename A, typename C>
Expand Down Expand Up @@ -420,6 +552,12 @@ void cifft2_real(A&&, C&&);
template<typename A, typename C>
void zifft2_real(A&&, C&&);

template<typename A, typename B, typename C>
void sfft2_convolve(A&&, C&&);

template<typename A, typename B, typename C>
void dfft2_convolve(A&&, B&&, C&&);

#endif

} //end of namespace blas
Expand Down
17 changes: 17 additions & 0 deletions include/etl/impl/fft.hpp
Expand Up @@ -40,6 +40,9 @@ struct ifft2_real_impl;
template<typename A, typename B, typename C, typename Enable = void>
struct fft_conv1_full_impl;

template<typename A, typename B, typename C, typename Enable = void>
struct fft_conv2_full_impl;

template<typename A, typename C>
struct is_blas_dfft : cpp::and_c<is_mkl_enabled, is_double_precision<A>, is_dma_2<A, C>> {};

Expand Down Expand Up @@ -184,6 +187,20 @@ struct ifft2_real_impl<A, C, std::enable_if_t<is_blas_zfft<A,C>::value>> {
}
};

template<typename A, typename B, typename C>
struct fft_conv2_full_impl<A, B, C, std::enable_if_t<is_blas_sfft_convolve<A,B,C>::value>> {
static void apply(A&& a, B&& b, C&& c){
etl::impl::blas::sfft2_convolve(std::forward<A>(a), std::forward<B>(b), std::forward<C>(c));
}
};

template<typename A, typename B, typename C>
struct fft_conv2_full_impl<A, B, C, std::enable_if_t<is_blas_dfft_convolve<A,B,C>::value>> {
static void apply(A&& a, B&& b, C&& c){
etl::impl::blas::dfft2_convolve(std::forward<A>(a), std::forward<B>(b), std::forward<C>(c));
}
};

} //end of namespace detail

} //end of namespace etl
Expand Down
4 changes: 4 additions & 0 deletions test/conv_test.hpp
Expand Up @@ -64,9 +64,12 @@ CONV_FUNCTOR( std_conv2_valid, etl::impl::standard::conv2_valid(a, b, c) )

#ifdef ETL_MKL_MODE
CONV_FUNCTOR( fft_conv1_full, c = etl::fft_conv_1d_full(a, b) )
CONV_FUNCTOR( fft_conv2_full, c = etl::fft_conv_2d_full(a, b) )
#define CONV1_FULL_TEST_CASE_SECTION_FFT CONV_TEST_CASE_SECTIONS( fft_conv1_full, fft_conv1_full )
#define CONV2_FULL_TEST_CASE_SECTION_FFT CONV_TEST_CASE_SECTIONS( fft_conv2_full, fft_conv2_full )
#else
#define CONV1_FULL_TEST_CASE_SECTION_FFT
#define CONV2_FULL_TEST_CASE_SECTION_FFT
#endif

#ifdef TEST_SSE
Expand Down Expand Up @@ -193,6 +196,7 @@ CONV_FUNCTOR( avx_conv2_valid_double, etl::impl::avx::dconv2_valid(a, b, c) )
CONV2_FULL_TEST_CASE_SECTION_DEFAULT \
CONV2_FULL_TEST_CASE_SECTION_STD \
CONV2_FULL_TEST_CASE_SECTION_REDUC \
CONV2_FULL_TEST_CASE_SECTION_FFT \
CONV2_FULL_TEST_CASE_SECTION_SSE \
CONV2_FULL_TEST_CASE_SECTION_AVX \
} \
Expand Down

0 comments on commit 0487543

Please sign in to comment.