Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix thread pool #9260

Merged
merged 1 commit into from Nov 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 43 additions & 17 deletions Utilities/Thread.cpp
Expand Up @@ -6,6 +6,7 @@
#include "Emu/Cell/lv2/sys_mmapper.h"
#include "Emu/Cell/lv2/sys_event.h"
#include "Thread.h"
#include "Utilities/JIT.h"
#include "sysinfo.h"
#include <typeinfo>
#include <thread>
Expand Down Expand Up @@ -39,6 +40,7 @@
#endif
#ifdef __linux__
#include <sys/timerfd.h>
#include <unistd.h>
#endif

#if defined(__APPLE__) || defined(__DragonFly__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)
Expand Down Expand Up @@ -72,7 +74,6 @@
# endif
#endif

#include "sync.h"
#include "util/vm.hpp"
#include "util/logs.hpp"
#include "util/asm.hpp"
Expand Down Expand Up @@ -1847,7 +1848,7 @@ static atomic_t<u128, 64> s_thread_bits{0};

static atomic_t<thread_base**> s_thread_pool[128]{};

void thread_base::start(native_entry entry)
void thread_base::start(native_entry entry, void(*trampoline)())
{
for (u128 bits = s_thread_bits.load(); bits; bits &= bits - 1)
{
Expand All @@ -1866,12 +1867,12 @@ void thread_base::start(native_entry entry)
}

// Send "this" and entry point
m_thread = reinterpret_cast<u64>(entry);
m_thread = reinterpret_cast<u64>(trampoline);
atomic_storage<thread_base*>::release(*tls, this);
s_thread_pool[pos].notify_all();

// Wait for actual "m_thread" in return
while (m_thread == reinterpret_cast<u64>(entry))
while (m_thread == reinterpret_cast<u64>(trampoline))
{
busy_wait(300);
}
Expand Down Expand Up @@ -2026,16 +2027,15 @@ u64 thread_base::finalize(thread_state result_state) noexcept

void thread_base::finalize(u64 _self) noexcept
{
atomic_wait_engine::set_wait_callback(nullptr);
g_tls_log_prefix = []() -> std::string { return {}; };
thread_ctrl::g_tls_this_thread = nullptr;

if (!_self)
{
// Don't even need to clean these values for detached threads
return;
}

atomic_wait_engine::set_wait_callback(nullptr);
g_tls_log_prefix = []() -> std::string { return {}; };
thread_ctrl::g_tls_this_thread = nullptr;

// Try to add self to thread pool
const auto [bits, ok] = s_thread_bits.fetch_op([](u128& bits)
{
Expand All @@ -2052,9 +2052,10 @@ void thread_base::finalize(u64 _self) noexcept
if (!ok)
{
#ifdef _WIN32
CloseHandle(reinterpret_cast<HANDLE>(_self));
_endthread();
#else
pthread_detach(reinterpret_cast<pthread_t>(_self));
pthread_exit(0);
#endif
return;
}
Expand Down Expand Up @@ -2082,7 +2083,26 @@ void thread_base::finalize(u64 _self) noexcept
const auto entry = _this->m_thread.exchange(_self);
_this->m_thread.notify_one();

reinterpret_cast<native_entry>(entry)(_this);
// Hack return address to avoid tail call
#ifdef _MSC_VER
*static_cast<u64*>(_AddressOfReturnAddress()) = entry;
#else
static_cast<u64*>(__builtin_frame_address(0))[1] = entry;
#endif
//reinterpret_cast<native_entry>(entry)(_this);
}

void (*thread_base::make_trampoline(native_entry entry))()
{
return build_function_asm<void(*)()>([&](asmjit::X86Assembler& c, auto& args)
{
using namespace asmjit;

// Revert effect of ret instruction (fix stack alignment)
c.mov(x86::rax, imm_ptr(entry));
c.sub(x86::rsp, 8);
c.jmp(x86::rax);
});
}

void thread_ctrl::_wait_for(u64 usec, bool alert /* true */)
Expand Down Expand Up @@ -2168,15 +2188,14 @@ thread_base::thread_base(std::string_view name)

thread_base::~thread_base()
{
if (m_thread)
if (u64 handle = m_thread.exchange(0))
{
#ifdef _WIN32
CloseHandle(reinterpret_cast<HANDLE>(m_thread.raw()));
CloseHandle(reinterpret_cast<HANDLE>(handle));
#else
pthread_detach(reinterpret_cast<pthread_t>(m_thread.raw()));
pthread_detach(reinterpret_cast<pthread_t>(handle));
#endif
}

}

bool thread_base::join() const
Expand Down Expand Up @@ -2260,16 +2279,23 @@ void thread_ctrl::emergency_exit(std::string_view reason)
{
g_tls_error_callback();

if (_this->finalize(thread_state::errored))
u64 _self = _this->finalize(thread_state::errored);

if (!_self)
{
delete _this;
}

thread_base::finalize(0);

#ifdef _WIN32
_endthreadex(0);
_endthread();
#else
if (_self)
{
pthread_detach(reinterpret_cast<pthread_t>(_self));
}

pthread_exit(0);
#endif
}
Expand Down
13 changes: 9 additions & 4 deletions Utilities/Thread.h
Expand Up @@ -119,7 +119,7 @@ class thread_base
atomic_t<u64> m_cycles = 0;

// Start thread
void start(native_entry);
void start(native_entry, void(*)());

// Called at the thread start
void initialize(void (*error_cb)(), bool(*wait_cb)(const void*));
Expand All @@ -136,6 +136,9 @@ class thread_base
// Set name for debugger
static void set_name(std::string);

// Make trampoline with stack fix
static void(*make_trampoline(native_entry))();

friend class thread_ctrl;

template <class Context>
Expand Down Expand Up @@ -359,6 +362,8 @@ class named_thread final : public Context, result_storage_t<Context>, thread_bas
return thread::finalize(thread_state::finished);
}

static inline void(*trampoline)() = thread::make_trampoline(entry_point);

friend class thread_ctrl;

public:
Expand All @@ -368,7 +373,7 @@ class named_thread final : public Context, result_storage_t<Context>, thread_bas
: Context()
, thread(Context::thread_name)
{
thread::start(&named_thread::entry_point);
thread::start(&named_thread::entry_point, trampoline);
}

// Normal forwarding constructor
Expand All @@ -377,15 +382,15 @@ class named_thread final : public Context, result_storage_t<Context>, thread_bas
: Context(std::forward<Args>(args)...)
, thread(name)
{
thread::start(&named_thread::entry_point);
thread::start(&named_thread::entry_point, trampoline);
}

// Lambda constructor, also the implicit deduction guide candidate
named_thread(std::string_view name, Context&& f)
: Context(std::forward<Context>(f))
, thread(name)
{
thread::start(&named_thread::entry_point);
thread::start(&named_thread::entry_point, trampoline);
}

named_thread(const named_thread&) = delete;
Expand Down