Skip to content

Commit

Permalink
adding temporary memory switch in xsemantic_base::operator=.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mykola Vankovych committed Jan 18, 2024
1 parent af4403e commit 1f8c179
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 0 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ target_link_libraries(xtensor INTERFACE xtl)

# User-configurable build switches.
OPTION(XTENSOR_ENABLE_ASSERT "xtensor bound check" OFF)
OPTION(XTENSOR_CHECK_DIMENSION "xtensor dimension check" OFF)
# Defaults to ON: assignments always go through temporary memory instead of
# relying on the automatic overlap check.
OPTION(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS "xtensor force the use of temporary memory when assigning instead of an automatic overlap check" ON)
OPTION(BUILD_TESTS "xtensor test suite" OFF)
OPTION(BUILD_BENCHMARK "xtensor benchmark" OFF)
OPTION(DOWNLOAD_GTEST "build gtest from downloaded sources" OFF)
Expand All @@ -219,6 +220,10 @@ if(XTENSOR_CHECK_DIMENSION)
add_definitions(-DXTENSOR_ENABLE_CHECK_DIMENSION)
endif()

# Forward the option to the compiler as a preprocessor definition of the same
# name, so the C++ headers can test it with #ifdef.
if(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
add_definitions(-DXTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
endif()

# When DEFAULT_COLUMN_MAJOR is set, define XTENSOR_DEFAULT_LAYOUT as
# layout_type::column_major for the compiled sources.
if(DEFAULT_COLUMN_MAJOR)
add_definitions(-DXTENSOR_DEFAULT_LAYOUT=layout_type::column_major)
endif()
Expand Down
217 changes: 217 additions & 0 deletions include/xtensor/xsemantic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,211 @@

namespace xt
{
/* Forward declarations */

// Declared here only so the overlapping_memory_checker_traits specializations
// in namespace detail can name these templates.

// Named by the is_crtp_base_of<xview_semantic, E> specialization.
template <class D>
class xview_semantic;

// Named by the is_specialization_of<xbroadcast, E> specialization.
template <class CT, class X>
class xbroadcast;

// Named by the is_specialization_of<xgenerator, E> specialization.
template <class F, class R, class S>
class xgenerator;

namespace detail
{

// Detection trait: true when `std::addressof(*expr.begin())` is a valid
// expression for an object of type Expr, i.e. when the first element can be
// located in memory. Any type for which that expression is ill-formed falls
// back to this primary template and yields false.
template <class Expr, class = void>
struct has_memory_address : std::false_type
{
};

// Selected through SFINAE when the address-of-first-element expression
// compiles.
template <class Expr>
struct has_memory_address<Expr, void_t<decltype(std::addressof(*std::declval<Expr>().begin()))>>
    : std::true_type
{
};

/**
 * Closed interval of byte addresses [m_first, m_last] occupied by a sequence
 * of elements, built from pointers to its first and last element.
 *
 * Addresses are compared as integers rather than as pointers; see
 * https://devblogs.microsoft.com/oldnewthing/20170927-00/?p=97095
 *
 * The interval is byte-inclusive: m_last is the address of the last byte of
 * the last element, not the address at which that element starts. This way
 * an access of a different element type that lands inside the last (or only)
 * element is still reported as overlapping.
 *
 * A default-constructed range has both bounds equal to 0.
 */
struct memory_range
{
    const uintptr_t m_first = 0;
    const uintptr_t m_last = 0;

    explicit memory_range() = default;

    /**
     * @param first pointer to one end element of the sequence
     * @param last pointer to the other end element; the two pointers may be
     *             given in either order
     */
    template <class T>
    explicit memory_range(T* first, T* last)
        : m_first(lowest_byte(first, last))
        , m_last(highest_byte(first, last))
    {
    }

    /**
     * @param first pointer to one end element of the other sequence
     * @param last pointer to its other end element (either order)
     * @return true if the bytes spanned by the other sequence intersect
     *         this range
     */
    template <class T>
    bool overlaps(T* first, T* last) const
    {
        return lowest_byte(first, last) <= m_last && highest_byte(first, last) >= m_first;
    }

private:

    // Address of the first byte of whichever element sits lower in memory.
    template <class T>
    static uintptr_t lowest_byte(T* a, T* b)
    {
        const uintptr_t addr_a = reinterpret_cast<uintptr_t>(a);
        const uintptr_t addr_b = reinterpret_cast<uintptr_t>(b);
        return addr_b < addr_a ? addr_b : addr_a;
    }

    // Address of the last byte of whichever element sits higher in memory.
    template <class T>
    static uintptr_t highest_byte(T* a, T* b)
    {
        const uintptr_t addr_a = reinterpret_cast<uintptr_t>(a);
        const uintptr_t addr_b = reinterpret_cast<uintptr_t>(b);
        const uintptr_t top_element = addr_b < addr_a ? addr_a : addr_b;
        return top_element + (sizeof(T) - 1);
    }
};

// Fallback for expression types that none of the specializations recognise.
// Nothing is known about where such an expression reads from, so it is
// always reported as overlapping the destination.
template <class E, class Enable = void>
struct overlapping_memory_checker_traits
{
    static bool check_overlap(const E& /*expr*/, const memory_range& /*dst_range*/)
    {
        constexpr bool assume_overlap = true;
        return assume_overlap;
    }
};

// Expressions whose elements can be located in memory: compare the span
// between the elements referenced by begin() and rbegin() with the
// destination range. An expression of size 0 never overlaps.
template <class E>
struct overlapping_memory_checker_traits<E, std::enable_if_t<has_memory_address<E>::value>>
{
    static bool check_overlap(const E& expr, const memory_range& dst_range)
    {
        // Checked first so begin()/rbegin() are never dereferenced on an
        // empty expression.
        return expr.size() != 0
               && dst_range.overlaps(std::addressof(*expr.begin()), std::addressof(*expr.rbegin()));
    }
};

// Non-addressable expressions derived from xview_semantic: the answer is
// whatever the wrapped expression reports. An expression of size 0 never
// overlaps.
template <class E>
struct overlapping_memory_checker_traits<
    E,
    std::enable_if_t<!has_memory_address<E>::value && is_crtp_base_of<xview_semantic, E>::value>>
{
    static bool check_overlap(const E& expr, const memory_range& dst_range)
    {
        if (expr.size() == 0)
        {
            return false;
        }

        using underlying_type = std::decay_t<decltype(expr.expression())>;
        using underlying_traits = overlapping_memory_checker_traits<underlying_type>;
        return underlying_traits::check_overlap(expr.expression(), dst_range);
    }
};

// Non-addressable xbroadcast expressions: the answer is whatever the
// broadcast operand, obtained through expression(), reports. An expression
// of size 0 never overlaps.
template <class E>
struct overlapping_memory_checker_traits<
    E,
    std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xbroadcast, E>::value>>
{
    static bool check_overlap(const E& expr, const memory_range& dst_range)
    {
        if (expr.size() == 0)
        {
            return false;
        }

        using operand_type = std::decay_t<decltype(expr.expression())>;
        using operand_traits = overlapping_memory_checker_traits<operand_type>;
        return operand_traits::check_overlap(expr.expression(), dst_range);
    }
};

// Non-addressable xfunction expressions: the function overlaps the
// destination as soon as one of its arguments does. An expression of size 0
// never overlaps.
template <class E>
struct overlapping_memory_checker_traits<
    E,
    std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xfunction, E>::value>>
{
    // End of recursion: every tuple element has been visited without
    // finding an overlap.
    template <std::size_t I = 0, class... T, std::enable_if_t<(I == sizeof...(T)), int> = 0>
    static bool check_tuple(const std::tuple<T...>&, const memory_range&)
    {
        return false;
    }

    // Checks element I, then moves on to element I + 1 only if element I
    // does not overlap.
    template <std::size_t I = 0, class... T, std::enable_if_t<(I < sizeof...(T)), int> = 0>
    static bool check_tuple(const std::tuple<T...>& t, const memory_range& dst_range)
    {
        const auto& argument = std::get<I>(t);
        using argument_type = std::decay_t<decltype(argument)>;

        if (overlapping_memory_checker_traits<argument_type>::check_overlap(argument, dst_range))
        {
            return true;
        }
        return check_tuple<I + 1>(t, dst_range);
    }

    static bool check_overlap(const E& expr, const memory_range& dst_range)
    {
        return expr.size() != 0 && check_tuple(expr.arguments(), dst_range);
    }
};

// Non-addressable xgenerator expressions are always reported as not
// overlapping the destination.
template <class E>
struct overlapping_memory_checker_traits<
    E,
    std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xgenerator, E>::value>>
{
    static bool check_overlap(const E& /*expr*/, const memory_range& /*dst_range*/)
    {
        constexpr bool overlaps_destination = false;
        return overlaps_destination;
    }
};

// Holds the destination's memory range and answers, for any source
// expression, whether that expression may read from it.
struct overlapping_memory_checker_base
{
    memory_range m_dst_range;

    // Leaves m_dst_range default-constructed (both bounds 0).
    explicit overlapping_memory_checker_base() = default;

    explicit overlapping_memory_checker_base(memory_range dst_memory_range)
        : m_dst_range(std::move(dst_memory_range))
    {
    }

    // Returns false without inspecting `expr` when either bound of the
    // destination range is 0; otherwise defers to the traits for E.
    template <class E>
    bool operator()(const E& expr) const
    {
        const bool destination_known = m_dst_range.m_first != 0 && m_dst_range.m_last != 0;
        return destination_known && overlapping_memory_checker_traits<E>::check_overlap(expr, m_dst_range);
    }
};

// Checker for destination types without a specialization: the destination
// argument is ignored and the base keeps its default-constructed range.
template <class Dst, class Enable = void>
struct overlapping_memory_checker : overlapping_memory_checker_base
{
    explicit overlapping_memory_checker(const Dst& /*dst*/)
        : overlapping_memory_checker_base()
    {
    }
};

// Checker for destinations whose elements can be located in memory: the
// range is built from the elements referenced by begin() and rbegin(). A
// destination of size 0 gets a default-constructed range.
template <class Dst>
struct overlapping_memory_checker<Dst, std::enable_if_t<has_memory_address<Dst>::value>>
    : overlapping_memory_checker_base
{
    explicit overlapping_memory_checker(const Dst& aDst)
        : overlapping_memory_checker_base(
              aDst.size() == 0
                  ? memory_range()
                  : memory_range(std::addressof(*aDst.begin()), std::addressof(*aDst.rbegin()))
          )
    {
    }
};

template <class Dst>
auto make_overlapping_memory_checker(const Dst& a_dst)
{
return overlapping_memory_checker<Dst>(a_dst);
}


template <class D>
struct is_sharable
{
Expand Down Expand Up @@ -598,8 +801,22 @@ namespace xt
template <class E>
inline auto xsemantic_base<D>::operator=(const xexpression<E>& e) -> derived_type&
{
#ifndef XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS
    // Temporaries are not forced: assign directly whenever the checker
    // reports that `e` does not overlap this object's memory.
    auto&& self = this->derived_cast();
    const auto may_overlap = detail::make_overlapping_memory_checker(self);
    if (!may_overlap(e.derived_cast()))
    {
        return this->assign(e);
    }
#endif
    // Either temporaries are forced or an overlap was reported: evaluate
    // `e` into a temporary first, then hand the temporary over.
    temporary_type staged(e);
    return this->derived_cast().assign_temporary(std::move(staged));
}

/**************************************
Expand Down
14 changes: 14 additions & 0 deletions include/xtensor/xutils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ namespace xt
using type = T;
};

/***************************************
* is_specialization_of implementation *
***************************************/

// Primary template: Candidate is not of the form Template<...>.
template <template <class...> class Template, class Candidate>
struct is_specialization_of : std::false_type
{
};

// Matches exactly the types written as Template<Args...>.
template <template <class...> class Template, class... Args>
struct is_specialization_of<Template, Template<Args...>> : std::true_type
{
};

/*******************************
* remove_class implementation *
*******************************/
Expand Down

0 comments on commit 1f8c179

Please sign in to comment.