Skip to content

Commit

Permalink
[ADD] Add optional warp checking for iris_lua_t.
Browse files Browse the repository at this point in the history
  • Loading branch information
paintdream committed Jun 4, 2023
1 parent 62af9d4 commit a7f19ed
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 17 deletions.
34 changes: 27 additions & 7 deletions src/iris_lua.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ namespace iris {
struct iris_lua_convert_t : std::false_type {};

// A simple lua binding with C++17
template <typename warp_t = void>
struct iris_lua_t : enable_read_write_fence_t<> {
// borrow from an existing state
explicit iris_lua_t(lua_State* L) noexcept : state(L) {}
Expand Down Expand Up @@ -196,7 +197,7 @@ namespace iris {
stack_guard_t stack_guard(L);

push_variable(L, *this);
size_t len = lua_rawlen(L, -1);
size_t len = static_cast<size_t>(lua_rawlen(L, -1));
lua_pop(L, 1);

return len;
Expand All @@ -219,7 +220,7 @@ namespace iris {
using internal_type_t = type_t*;

operator bool() const noexcept {
return value != 0 && ptr != nullptr;
return ref_t::value != 0 && ptr != nullptr;
}

operator type_t* () const noexcept {
Expand All @@ -235,6 +236,7 @@ namespace iris {
}

friend struct iris_lua_t;

protected:
type_t* ptr;
};
Expand Down Expand Up @@ -757,22 +759,40 @@ namespace iris {
using return_t = typename coroutine_t::return_type_t;
auto coroutine = function(std::forward<params_t>(params)...);

#ifdef _DEBUG
warp_t* current_warp = nullptr;
if constexpr (!std::is_void_v<warp_t>) {
current_warp = warp_t::get_current_warp();
}
#endif
void* yield_mark = static_cast<void*>(&coroutine);

if constexpr (!std::is_void_v<return_t>) {
coroutine.complete([L, p = static_cast<void*>(&coroutine)](return_t&& value) {
coroutine.complete([=](return_t&& value) {
#ifdef _DEBUG
if constexpr (!std::is_void_v<warp_t>) {
assert(current_warp == warp_t::get_current_warp());
}
#endif
push_variable(L, std::move(value));
push_variable(L, p);
push_variable(L, yield_mark);
coroutine_continuation(L);
}).run();
} else {
coroutine.complete([L, p = static_cast<void*>(&coroutine)]() {
coroutine.complete([=]() {
#ifdef _DEBUG
if constexpr (!std::is_void_v<warp_t>) {
assert(current_warp == warp_t::get_current_warp());
}
#endif
lua_pushnil(L);
push_variable(L, p);
push_variable(L, yield_mark);
coroutine_continuation(L);
}).run();
}

// already completed?
if (lua_touserdata(L, -1) == &coroutine) {
if (lua_touserdata(L, -1) == yield_mark) {
lua_pop(L, 1);
return true;
} else {
Expand Down
24 changes: 14 additions & 10 deletions test/iris_lua_demo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ using warp_t = iris_warp_t<worker_t>;
static warp_t* warpptr = nullptr;
static warp_t* warpptr2 = nullptr;
static worker_t* workerptr = nullptr;
#else
using warp_t = void;
#endif

using lua_t = iris_lua_t<warp_t>;

struct vector3 {
float x, y, z;
};
Expand Down Expand Up @@ -47,7 +51,7 @@ struct iris::iris_lua_convert_t<vector3> : std::true_type {
};

struct example_t {
static void lua_registar(iris_lua_t lua) {
static void lua_registar(lua_t lua) {
lua.define<&example_t::value>("value");
lua.define<&example_t::const_value>("const_value");
lua.define<&example_t::accum_value>("accum_value");
Expand All @@ -70,7 +74,7 @@ struct example_t {
#endif
}

int call(iris_lua_t lua, iris_lua_t::ref_t&& r, int value) {
int call(lua_t lua, lua_t::ref_t&& r, int value) {
int result = lua.call<int>(r, value);
lua.deref(r);

Expand All @@ -87,18 +91,18 @@ struct example_t {
}
}

void join_value_required(iris_lua_t::required_t<example_t*>&& rhs) noexcept {
void join_value_required(lua_t::required_t<example_t*>&& rhs) noexcept {
printf("Required!\n");
value += rhs.get()->value;
}

void join_value_required_refptr(iris_lua_t&& lua, iris_lua_t::required_t<iris_lua_t::refptr_t<example_t>>&& rhs) noexcept {
void join_value_required_refptr(lua_t&& lua, lua_t::required_t<lua_t::refptr_t<example_t>>&& rhs) noexcept {
printf("Required ptr!\n");
value += rhs.get().get()->value;
lua.deref(rhs.get());
}

void join_value_refptr(iris_lua_t&& lua, iris_lua_t::refptr_t<example_t>&& rhs) noexcept {
void join_value_refptr(lua_t&& lua, lua_t::refptr_t<example_t>&& rhs) noexcept {
auto guard = lua.ref_guard(rhs);
if (rhs != nullptr) {
value += rhs->value;
Expand Down Expand Up @@ -133,8 +137,8 @@ struct example_t {
return std::move(v);
}

iris_lua_t::ref_t prime(iris_lua_t lua) const {
return lua.make_table([](iris_lua_t lua) noexcept {
lua_t::ref_t prime(lua_t lua) const {
return lua.make_table([](lua_t lua) noexcept {
lua.define("name", "prime");
lua.define(1, 2);
lua.define(2, 3);
Expand Down Expand Up @@ -171,7 +175,7 @@ int main(void) {
lua_State* L = luaL_newstate();
luaL_openlibs(L);

iris_lua_t lua(L);
lua_t lua(L);
lua.register_type<example_t>("example_t");

#if USE_LUA_COROUTINE
Expand Down Expand Up @@ -226,15 +230,15 @@ int main(void) {
end\n\
return true\n");
assert(ret);
auto tab = lua.make_table([](iris_lua_t&& lua) {
auto tab = lua.make_table([](lua_t&& lua) {
lua.define("key", "value");
lua.define(1, "number");
lua.define(2, 2);
});

tab.set(lua, "set", "newvalue");
assert(tab.get<int>(lua, 2) == 2);
auto r = tab.get<iris_lua_t::ref_t>(lua, 2);
auto r = tab.get<lua_t::ref_t>(lua, 2);
assert(r.as<int>(lua) == 2);
assert(tab.size(lua) == 2);
lua.deref(r);
Expand Down

0 comments on commit a7f19ed

Please sign in to comment.