Skip to content

POC implementation for task_group dynamic dependencies - part 1 - task_tracker #1682

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 189 additions & 9 deletions include/oneapi/tbb/detail/_task_handle.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2020-2024 Intel Corporation
Copyright (c) 2020-2025 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,7 +32,82 @@ namespace d2 {

class task_handle;

class task_handle_task : public d1::task {
#if __TBB_PREVIEW_TASK_GROUP_EXTENSIONS

class task_dynamic_state {
public:
task_dynamic_state(d1::small_object_allocator& alloc)
: m_num_references(0)
, m_allocator(alloc)
{}

void reserve() { ++m_num_references; }

void release() {
if (--m_num_references == 0) {
m_allocator.delete_object(this);
}
}

void complete_task() {
}
private:
std::atomic<std::size_t> m_num_references;
d1::small_object_allocator m_allocator;
};

class task_with_dynamic_state : public d1::task {
public:
task_with_dynamic_state() : m_state(nullptr) {}

virtual ~task_with_dynamic_state() {
task_dynamic_state* current_state = m_state.load(std::memory_order_relaxed);
if (current_state != nullptr) {
current_state->release();
}
}

// Create dynamic state if task_tracker is created or a first successor is added to task_handle
task_dynamic_state* get_dynamic_state() {
task_dynamic_state* current_state = m_state.load(std::memory_order_acquire);

if (current_state == nullptr) {
d1::small_object_allocator alloc;

task_dynamic_state* new_state = alloc.new_object<task_dynamic_state>(alloc);

if (m_state.compare_exchange_strong(current_state, new_state)) {
// Reserve a task co-ownership for dynamic_state
new_state->reserve();
current_state = new_state;
} else {
// Other thread created the dynamic state
alloc.delete_object(new_state);
}
}

__TBB_ASSERT(current_state != nullptr, "Failed to create dynamic state");
return current_state;
}

void complete_task() {
task_dynamic_state* current_state = m_state.load(std::memory_order_relaxed);
if (current_state != nullptr) {
current_state->complete_task();
}
}
private:
std::atomic<task_dynamic_state*> m_state;
};
#endif // __TBB_PREVIEW_TASK_GROUP_EXTENSIONS

class task_handle_task
#if __TBB_PREVIEW_TASK_GROUP_EXTENSIONS
: public task_with_dynamic_state
#else
: public d1::task
#endif
{
std::uint64_t m_version_and_traits{};
d1::wait_tree_vertex_interface* m_wait_tree_vertex;
d1::task_group_context& m_ctx;
Expand Down Expand Up @@ -84,21 +159,26 @@ class task_handle {

private:
friend struct task_handle_accessor;
#if __TBB_PREVIEW_TASK_GROUP_EXTENSIONS
friend class task_tracker;
#endif

task_handle(task_handle_task* t) : m_handle {t}{};
task_handle(task_handle_task* t) : m_handle {t}{}

d1::task* release() {
return m_handle.release();
}
};

struct task_handle_accessor {
static task_handle construct(task_handle_task* t) { return {t}; }
static d1::task* release(task_handle& th) { return th.release(); }
static d1::task_group_context& ctx_of(task_handle& th) {
__TBB_ASSERT(th.m_handle, "ctx_of does not expect empty task_handle.");
return th.m_handle->ctx();
}
static task_handle construct(task_handle_task* t) { return {t}; }

static d1::task* release(task_handle& th) { return th.release(); }

static d1::task_group_context& ctx_of(task_handle& th) {
__TBB_ASSERT(th.m_handle, "ctx_of does not expect empty task_handle.");
return th.m_handle->ctx();
}
};

inline bool operator==(task_handle const& th, std::nullptr_t) noexcept {
Expand All @@ -116,6 +196,106 @@ inline bool operator!=(std::nullptr_t, task_handle const& th) noexcept {
return th.m_handle != nullptr;
}

#if __TBB_PREVIEW_TASK_GROUP_EXTENSIONS
class task_tracker {
public:
task_tracker() : m_task_state(nullptr) {}

task_tracker(const task_tracker& other)
: m_task_state(other.m_task_state)
{
// Register one more co-owner of the dynamic state
if (m_task_state) m_task_state->reserve();
}
task_tracker(task_tracker&& other)
: m_task_state(other.m_task_state)
{
other.m_task_state = nullptr;
}

task_tracker(const task_handle& th)
: m_task_state(th ? th.m_handle->get_dynamic_state() : nullptr)
{
// Register new co-owner of the dynamic state
if (m_task_state) m_task_state->reserve();
}

~task_tracker() {
if (m_task_state) m_task_state->release();
}

task_tracker& operator=(const task_tracker& other) {
if (this != &other) {
// Release co-ownership on the previously tracked dynamic state
if (m_task_state) m_task_state->release();

m_task_state = other.m_task_state;

// Register new co-owner of the new dynamic state
if (m_task_state) m_task_state->reserve();
}
return *this;
}

task_tracker& operator=(task_tracker&& other) {
if (this != &other) {
// Release co-ownership on the previously tracked dynamic state
if (m_task_state) m_task_state->release();

m_task_state = other.m_task_state;
other.m_task_state = nullptr;
}
return *this;
}

task_tracker& operator=(const task_handle& th) {
// Release co-ownership on the previously tracked dynamic state
if (m_task_state) m_task_state->release();

if (th) {
m_task_state = th.m_handle->get_dynamic_state();

// Reserve co-ownership on the new dynamic state
__TBB_ASSERT(m_task_state != nullptr, "No state in the non-empty task_handle");
m_task_state->reserve();
} else {
m_task_state = nullptr;
}
return *this;
}

explicit operator bool() const noexcept { return m_task_state != nullptr; }
private:
friend bool operator==(const task_tracker& t, std::nullptr_t) noexcept {
return t.m_task_state == nullptr;
}

friend bool operator==(const task_tracker& lhs, const task_tracker& rhs) noexcept {
return lhs.m_task_state == rhs.m_task_state;
}

#if !__TBB_CPP20_COMPARISONS_PRESENT
friend bool operator==(std::nullptr_t, const task_tracker& t) noexcept {
return t == nullptr;
}

friend bool operator!=(const task_tracker& t, std::nullptr_t) noexcept {
return !(t == nullptr);
}

friend bool operator!=(std::nullptr_t, const task_tracker& t) noexcept {
return !(t == nullptr);
}

friend bool operator!=(const task_tracker& lhs, const task_tracker& rhs) noexcept {
return !(lhs == rhs);
}
#endif // !__TBB_CPP20_COMPARISONS_PRESENT

task_dynamic_state* m_task_state;
};
#endif

} // namespace d2
} // namespace detail
} // namespace tbb
Expand Down
19 changes: 17 additions & 2 deletions include/oneapi/tbb/task_group.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2005-2024 Intel Corporation
Copyright (c) 2005-2025 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -87,6 +87,9 @@ class function_task : public task_handle_task {
d1::task* execute(d1::execution_data& ed) override {
__TBB_ASSERT(ed.context == &this->ctx(), "The task group context should be used for all tasks");
task* res = task_ptr_or_nullptr(m_func);
#if __TBB_PREVIEW_TASK_GROUP_EXTENSIONS
this->complete_task();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No-op in this part, would be used in part 2 for bypassing the successor task

#endif
finalize(&ed);
return res;
}
Expand Down Expand Up @@ -440,7 +443,13 @@ class isolated_task_group;
#endif

template <typename F>
class function_stack_task : public d1::task {
class function_stack_task
#if __TBB_PREVIEW_TASK_GROUP_EXTENSIONS
: public task_with_dynamic_state
#else
: public d1::task
#endif
{
const F& m_func;
d1::wait_tree_vertex_interface* m_wait_tree_vertex;

Expand All @@ -449,6 +458,9 @@ class function_stack_task : public d1::task {
}
task* execute(d1::execution_data&) override {
task* res = d2::task_ptr_or_nullptr(m_func);
#if __TBB_PREVIEW_TASK_GROUP_EXTENSIONS
this->complete_task();
#endif
finalize();
return res;
}
Expand Down Expand Up @@ -701,6 +713,9 @@ using detail::d1::is_current_task_group_canceling;
using detail::r1::missing_wait;

using detail::d2::task_handle;
#if __TBB_PREVIEW_TASK_GROUP_EXTENSIONS
using detail::d2::task_tracker;
#endif
}

} // namespace tbb
Expand Down
Loading