Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the ability to enable memory overlap check in assignment to avoid unneeded temporary memory allocation #2768

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ target_link_libraries(xtensor INTERFACE xtl)

# User-facing build switches for optional checks and assignment behaviour.
OPTION(XTENSOR_ENABLE_ASSERT "xtensor bound check" OFF)
OPTION(XTENSOR_CHECK_DIMENSION "xtensor dimension check" OFF)
# ON by default: assignments always go through a temporary. Turn OFF to let
# the assignment run an automatic memory-overlap check instead.
OPTION(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS "xtensor force the use of temporary memory when assigning instead of an automatic overlap check" ON)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a fan of having different assign behaviors depending on an option when it's not for managing dependencies or debugging. We should probably have some pattern to dynamically decide whether we should check for memory overlap or not. This can be done in a dedicated PR, as it may not be obvious how to get it done correctly.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So should we keep the option for now?

# Optional build targets and test-framework handling; all disabled by default.
OPTION(BUILD_TESTS "xtensor test suite" OFF)
OPTION(BUILD_BENCHMARK "xtensor benchmark" OFF)
OPTION(DOWNLOAD_GTEST "build gtest from downloaded sources" OFF)
Expand All @@ -219,6 +220,10 @@ if(XTENSOR_CHECK_DIMENSION)
add_definitions(-DXTENSOR_ENABLE_CHECK_DIMENSION)
endif()

# Forward the option to the compiler as a preprocessor definition; the
# assignment operator in xsemantic.hpp selects its code path with #ifdef on it.
if(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
add_definitions(-DXTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
endif()

# When requested, define XTENSOR_DEFAULT_LAYOUT as column-major.
if(DEFAULT_COLUMN_MAJOR)
add_definitions(-DXTENSOR_DEFAULT_LAYOUT=layout_type::column_major)
endif()
Expand Down
23 changes: 23 additions & 0 deletions include/xtensor/xbroadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,29 @@ namespace xt
return linear_end(c.expression());
}

/*************************************
* overlapping_memory_checker_traits *
*************************************/

/**
 * Overlap check for xbroadcast expressions that expose no memory address.
 *
 * An empty broadcast never overlaps; otherwise the answer is whatever the
 * traits of the broadcasted (wrapped) expression report.
 */
template <class E>
struct overlapping_memory_checker_traits<
    E,
    std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xbroadcast, E>::value>>
{
    static bool check_overlap(const E& expr, const memory_range& dst_range)
    {
        // Guard clause: nothing is read from an empty expression.
        if (expr.size() == 0)
        {
            return false;
        }

        // Delegate to the expression being broadcast.
        const auto& wrapped = expr.expression();
        using wrapped_type = std::decay_t<decltype(wrapped)>;
        return overlapping_memory_checker_traits<wrapped_type>::check_overlap(wrapped, dst_range);
    }
};

/**
* @class xbroadcast
* @brief Broadcasted xexpression to a specified shape.
Expand Down
36 changes: 36 additions & 0 deletions include/xtensor/xfunction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,42 @@ namespace xt
{
};

/*************************************
* overlapping_memory_checker_traits *
*************************************/

/**
 * Overlap check for xfunction expressions that expose no memory address.
 *
 * A non-empty function overlaps the destination as soon as one of its
 * arguments does; the arguments are visited in tuple order and the walk
 * stops at the first overlapping one.
 */
template <class E>
struct overlapping_memory_checker_traits<
    E,
    std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xfunction, E>::value>>
{
    // Terminal case: index I is past the last tuple element, nothing left to test.
    template <std::size_t I = 0, class... T, std::enable_if_t<(I == sizeof...(T)), int> = 0>
    static bool check_tuple(const std::tuple<T...>&, const memory_range&)
    {
        return false;
    }

    // Recursive case: test argument I, then continue with argument I + 1.
    template <std::size_t I = 0, class... T, std::enable_if_t<(I < sizeof...(T)), int> = 0>
    static bool check_tuple(const std::tuple<T...>& t, const memory_range& dst_range)
    {
        const auto& argument = std::get<I>(t);
        using argument_type = std::decay_t<decltype(argument)>;
        if (overlapping_memory_checker_traits<argument_type>::check_overlap(argument, dst_range))
        {
            return true;
        }
        return check_tuple<I + 1>(t, dst_range);
    }

    static bool check_overlap(const E& expr, const memory_range& dst_range)
    {
        // An empty expression reads nothing, so it cannot overlap.
        return expr.size() == 0 ? false : check_tuple(expr.arguments(), dst_range);
    }
};

/*************
* xfunction *
*************/
Expand Down
15 changes: 15 additions & 0 deletions include/xtensor/xgenerator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,21 @@ namespace xt
using size_type = std::size_t;
};

/*************************************
* overlapping_memory_checker_traits *
*************************************/

/**
 * Overlap check for xgenerator expressions that expose no memory address.
 *
 * Always reports "no overlap", whatever the destination range is.
 */
template <class E>
struct overlapping_memory_checker_traits<
    E,
    std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xgenerator, E>::value>>
{
    static bool check_overlap(const E& /*expr*/, const memory_range& /*dst_range*/)
    {
        return false;
    }
};

/**
* @class xgenerator
* @brief Multidimensional function operating on indices.
Expand Down
37 changes: 37 additions & 0 deletions include/xtensor/xsemantic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,29 @@ namespace xt
/**
 * SFINAE helper: names the type R only when E does not have container
 * semantics (has_container_semantics<E>::value is false).
 */
template <class E, class R = void>
using disable_xcontainer_semantics = std::enable_if_t<!has_container_semantics<E>::value, R>;


// Forward declaration so that xview_semantic can be named by the
// overlapping_memory_checker_traits specialization that follows.
template <class D>
class xview_semantic;

/**
 * Overlap check for expressions deriving (CRTP) from xview_semantic that
 * expose no memory address.
 *
 * An empty view never overlaps; otherwise the check is forwarded to the
 * traits of the underlying expression returned by expression().
 */
template <class E>
struct overlapping_memory_checker_traits<
    E,
    std::enable_if_t<!has_memory_address<E>::value && is_crtp_base_of<xview_semantic, E>::value>>
{
    static bool check_overlap(const E& expr, const memory_range& dst_range)
    {
        // Guard clause: nothing is read from an empty expression.
        if (expr.size() == 0)
        {
            return false;
        }

        // Delegate to the expression the view is built on.
        const auto& underlying = expr.expression();
        using underlying_type = std::decay_t<decltype(underlying)>;
        return overlapping_memory_checker_traits<underlying_type>::check_overlap(underlying, dst_range);
    }
};

/**
* @class xview_semantic
* @brief Implementation of the xsemantic_base interface for
Expand Down Expand Up @@ -598,8 +621,22 @@ namespace xt
template <class E>
inline auto xsemantic_base<D>::operator=(const xexpression<E>& e) -> derived_type&
{
// Assigns the expression e to the derived object and returns it.
#ifdef XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS
// Forced-temporary mode: always materialize e into a temporary first,
// then hand the temporary over to the derived object.
temporary_type tmp(e);
return this->derived_cast().assign_temporary(std::move(tmp));
#else
// Overlap-check mode: build a checker from the destination, and only pay
// for a temporary when the checker reports that e may overlap it.
auto&& this_derived = this->derived_cast();
auto memory_checker = make_overlapping_memory_checker(this_derived);
if (memory_checker.check_overlap(e.derived_cast()))
{
temporary_type tmp(e);
return this_derived.assign_temporary(std::move(tmp));
}
else
{
// No overlap reported: assign directly, without an intermediate temporary.
return this->assign(e);
}
#endif
}

/**************************************
Expand Down
147 changes: 147 additions & 0 deletions include/xtensor/xutils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ namespace xt
using type = T;
};

/***************************************
* is_specialization_of implementation *
***************************************/

/**
 * Compile-time test: is T a specialization of the class template TT?
 *
 * The primary template answers "no"; the partial specialization below
 * matches any type of the form TT<Args...> and answers "yes".
 * Only templates whose parameters are all types can be matched.
 */
template <template <class...> class TT, class T>
struct is_specialization_of : std::integral_constant<bool, false>
{
};

template <template <class...> class TT, class... Args>
struct is_specialization_of<TT, TT<Args...>> : std::integral_constant<bool, true>
{
};

/*******************************
* remove_class implementation *
*******************************/
Expand Down Expand Up @@ -860,6 +874,139 @@ namespace xt
{
};

/*************************************
* overlapping_memory_checker_traits *
*************************************/

/**
 * Compile-time test: can the address of the first element of a T be taken?
 *
 * The trait is true exactly when the expression
 * std::addressof(*std::declval<T>().begin()) is well-formed, and false
 * otherwise (primary template).
 */
template <class T, class Enable = void>
struct has_memory_address : std::integral_constant<bool, false>
{
};

template <class T>
struct has_memory_address<T, void_t<decltype(std::addressof(*std::declval<T>().begin()))>>
    : std::integral_constant<bool, true>
{
};

/**
 * Closed interval [m_first, m_last] of memory addresses, stored as integers.
 *
 * A default-constructed range has both bounds equal to zero. The two-pointer
 * constructor accepts its arguments in either order and stores the lower
 * address in m_first and the higher one in m_last.
 */
struct memory_range
{
    // Checking pointer overlap is more correct in integer values,
    // for more explanation check https://devblogs.microsoft.com/oldnewthing/20170927-00/?p=97095
    const uintptr_t m_first = 0;
    const uintptr_t m_last = 0;

    explicit memory_range() = default;

    /**
     * Builds the range spanned by two pointers.
     * @param first one end of the range (inclusive).
     * @param last the other end of the range (inclusive).
     */
    template <class T>
    explicit memory_range(T* first, T* last)
        : m_first(lower_address(first, last))
        , m_last(upper_address(first, last))
    {
    }

    /**
     * Tells whether the closed interval delimited by the two pointers
     * (given in either order) shares at least one address with this range.
     * @return true when the two intervals intersect, false otherwise.
     */
    template <class T>
    bool overlaps(T* first, T* last) const
    {
        return lower_address(first, last) <= m_last && upper_address(first, last) >= m_first;
    }

private:

    // The pointers are converted to integers *before* being ordered:
    // relational comparison of pointers into unrelated objects is
    // unspecified, whereas comparing the integer values is always defined.
    template <class T>
    static uintptr_t lower_address(T* lhs, T* rhs)
    {
        const uintptr_t a = reinterpret_cast<uintptr_t>(lhs);
        const uintptr_t b = reinterpret_cast<uintptr_t>(rhs);
        return b < a ? b : a;
    }

    template <class T>
    static uintptr_t upper_address(T* lhs, T* rhs)
    {
        const uintptr_t a = reinterpret_cast<uintptr_t>(lhs);
        const uintptr_t b = reinterpret_cast<uintptr_t>(rhs);
        return b < a ? a : b;
    }
};

/**
 * Fallback overlap check, used when no specialization matches E.
 *
 * Nothing is known about such an expression, so it always answers
 * "overlaps".
 */
template <class E, class Enable = void>
struct overlapping_memory_checker_traits
{
    static bool check_overlap(const E& /*expr*/, const memory_range& /*dst_range*/)
    {
        return true;
    }
};

/**
 * Overlap check for expressions whose elements have an address
 * (has_memory_address<E>::value is true).
 *
 * An empty expression never overlaps; otherwise the addresses of the
 * elements referred to by begin() and rbegin() are tested against the
 * destination range.
 */
template <class E>
struct overlapping_memory_checker_traits<E, std::enable_if_t<has_memory_address<E>::value>>
{
    static bool check_overlap(const E& expr, const memory_range& dst_range)
    {
        // Guard clause: dereferencing begin()/rbegin() is only done on non-empty expressions.
        if (expr.size() == 0)
        {
            return false;
        }

        const auto front_address = std::addressof(*expr.begin());
        const auto back_address = std::addressof(*expr.rbegin());
        return dst_range.overlaps(front_address, back_address);
    }
};

/**
 * Holds the memory range of an assignment destination and answers whether
 * a given expression may overlap it.
 */
struct overlapping_memory_checker_base
{
    memory_range m_dst_range;

    explicit overlapping_memory_checker_base() = default;

    explicit overlapping_memory_checker_base(memory_range dst_memory_range)
        : m_dst_range(std::move(dst_memory_range))
    {
    }

    /**
     * @param expr expression about to be assigned to the destination.
     * @return false when either bound of the destination range is zero;
     *         otherwise the verdict of overlapping_memory_checker_traits<E>.
     */
    template <class E>
    bool check_overlap(const E& expr) const
    {
        const bool has_destination_range = m_dst_range.m_first && m_dst_range.m_last;
        if (has_destination_range)
        {
            return overlapping_memory_checker_traits<E>::check_overlap(expr, m_dst_range);
        }
        return false;
    }
};

/**
 * Checker used for destinations that do not match the has_memory_address
 * specialization: the destination argument is ignored and the base is
 * default-constructed.
 */
template <class Dst, class Enable = void>
struct overlapping_memory_checker : overlapping_memory_checker_base
{
    explicit overlapping_memory_checker(const Dst& /*dst*/)
        : overlapping_memory_checker_base()
    {
    }
};

/**
 * Checker for destinations whose elements have an address
 * (has_memory_address<Dst>::value is true).
 *
 * The destination range is default-constructed for an empty destination;
 * otherwise it is built from the addresses of the elements referred to by
 * begin() and rbegin().
 */
template <class Dst>
struct overlapping_memory_checker<Dst, std::enable_if_t<has_memory_address<Dst>::value>>
    : overlapping_memory_checker_base
{
    explicit overlapping_memory_checker(const Dst& aDst)
        : overlapping_memory_checker_base(
            // begin()/rbegin() are only dereferenced when the destination is non-empty.
            aDst.size() == 0
                ? memory_range()
                : memory_range(std::addressof(*aDst.begin()), std::addressof(*aDst.rbegin()))
        )
    {
    }
};

template <class Dst>
auto make_overlapping_memory_checker(const Dst& a_dst)
{
return overlapping_memory_checker<Dst>(a_dst);
}

/********************
* rebind_container *
********************/
Expand Down
Loading