Skip to content

Commit

Permalink
[MOD] Wrap lua_State* within iris_lua_t.
Browse files Browse the repository at this point in the history
  • Loading branch information
paintdream committed May 13, 2023
1 parent ad0d730 commit 9734c43
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 72 deletions.
72 changes: 36 additions & 36 deletions src/optional/iris_lua.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ namespace iris {
struct iris_lua_t : enable_read_write_fence_t<> {
// borrow from an existing state
explicit iris_lua_t(lua_State* L) noexcept : state(L) {}
iris_lua_t(iris_lua_t&& rhs) noexcept : state(rhs.state) { rhs.state = nullptr; }
iris_lua_t(const iris_lua_t& rhs) noexcept : state(rhs.state) {}

operator lua_State* () const noexcept {
return get_state();
Expand Down Expand Up @@ -261,19 +263,18 @@ namespace iris {
};

template <typename... args_t>
static auto refguard(lua_State* L, args_t&... args) noexcept {
return refguard_t<sizeof...(args_t)>(L, args...);
auto refguard(args_t&... args) noexcept {
return refguard_t<sizeof...(args_t)>(state, args...);
}

typedef void (*registar_t)(lua_State*);
template <typename type_t>
struct has_registar {
template <typename> static std::false_type test(...);
template <typename impl_t> static auto test(int)
-> decltype(std::declval<impl_t>().lua_registar(std::declval<lua_State*>()), std::true_type());
-> decltype(std::declval<impl_t>().lua_registar(std::declval<iris_lua_t>()), std::true_type());
static constexpr bool value = std::is_same<decltype(test<type_t>(0)), std::true_type>::value;
static void default_registar(lua_State*) {}
static constexpr registar_t get_registar() {
static void default_registar(iris_lua_t) {}
static constexpr auto get_registar() {
if constexpr (value) {
return &type_t::lua_registar;
} else {
Expand All @@ -283,7 +284,7 @@ namespace iris {
};

// register a new type, taking registar from &type_t::lua_registar by default, and you could also specify your own registar.
template <typename type_t, int user_value_count = 0, registar_t registar = has_registar<type_t>::get_registar(), typename... args_t>
template <typename type_t, int user_value_count = 0, auto registar = has_registar<type_t>::get_registar(), typename... args_t>
ref_t make_type(std::string_view name, args_t&... args) {
auto guard = write_fence();

Expand Down Expand Up @@ -319,14 +320,15 @@ namespace iris {
lua_rawset(L, -3);

// call custom registar if needed
registar(L);
registar(iris_lua_t(L));
return ref_t(luaL_ref(L, LUA_REGISTRYINDEX));
}

// quick way for registering a type into global space
template <typename type_t, int user_value_count = 0, registar_t registar = has_registar<type_t>::get_registar(), typename... args_t>
template <typename type_t, int user_value_count = 0, auto registar = has_registar<type_t>::get_registar(), typename... args_t>
void register_type(std::string_view name, args_t&... args) {
ref_t r = make_type<type_t, user_value_count, registar, args_t...>(name, args...);

auto guard = write_fence();
lua_State* L = state;
stack_guard_t stack_guard(L);
Expand All @@ -338,7 +340,10 @@ namespace iris {

// define a variable by value
template <typename key_t, typename value_t>
static void define(lua_State* L, key_t&& key, value_t&& value) {
void define(key_t&& key, value_t&& value) {
auto guard = write_fence();

lua_State* L = state;
stack_guard_t stack_guard(L);

push_variable(L, std::forward<key_t>(key));
Expand All @@ -348,7 +353,10 @@ namespace iris {

// define a bound member function/property
template <auto ptr, typename key_t>
static void define(lua_State* L, key_t&& key) {
void define(key_t&& key) {
auto guard = write_fence();

lua_State* L = state;
stack_guard_t stack_guard(L);

if constexpr (std::is_member_function_pointer_v<decltype(ptr)>) {
Expand All @@ -368,7 +376,7 @@ namespace iris {
stack_guard_t stack_guard(L);
lua_newtable(L);

func(L);
func(iris_lua_t(L));
return ref_t(luaL_ref(L, LUA_REGISTRYINDEX));
}

Expand All @@ -378,27 +386,12 @@ namespace iris {
deref(state, r);
}

static void deref(lua_State* L, ref_t& r) noexcept {
if (r.value != 0) {
luaL_unref(L, LUA_REGISTRYINDEX, r.value);
r.value = 0;
}
}

// call function in protect mode
template <typename return_t, typename callable_t, typename... args_t>
return_t call(callable_t&& reference, args_t&&... args) {
auto guard = write_fence();

if constexpr (!std::is_void_v<return_t>) {
return call<return_t>(std::forward<callable_t>(reference), std::forward<args_t>(args)...);
} else {
call<return_t>(std::forward<callable_t>(reference), std::forward<args_t>(args)...);
}
}

template <typename return_t, typename callable_t, typename... args_t>
static return_t call(lua_State* L, callable_t&& reference, args_t&&... args) {
lua_State* L = state;
stack_guard_t stack_guard(L);
push_variable(L, std::forward<callable_t>(reference));
push_arguments(L, std::forward<args_t>(args)...);
Expand Down Expand Up @@ -429,6 +422,13 @@ namespace iris {
}

protected:
static void deref(lua_State* L, ref_t& r) noexcept {
if (r.value != 0) {
luaL_unref(L, LUA_REGISTRYINDEX, r.value);
r.value = 0;
}
}

// a guard for checking stack balance
struct stack_guard_t {
#ifdef _DEBUG
Expand Down Expand Up @@ -516,7 +516,7 @@ namespace iris {
struct has_finalize {
template <typename> static std::false_type test(...);
template <typename impl_t> static auto test(int)
-> decltype(std::declval<impl_t>().lua_finalize(std::declval<lua_State*>()), std::true_type());
-> decltype(std::declval<impl_t>().lua_finalize(std::declval<iris_lua_t>()), std::true_type());
static constexpr bool value = std::is_same<decltype(test<type_t>(0)), std::true_type>::value;
};

Expand All @@ -527,7 +527,7 @@ namespace iris {

// call lua_finalize if needed
if constexpr (has_finalize<type_t>::value) {
p->lua_finalize(L);
p->lua_finalize(iris_lua_t(L));
}

assert(p != nullptr);
Expand Down Expand Up @@ -690,8 +690,8 @@ namespace iris {
template <auto function, int index, typename return_t, typename tuple_t, typename... params_t>
static int function_invoke(lua_State* L, int stack_index, params_t&&... params) {
if constexpr (index < std::tuple_size_v<tuple_t>) {
if constexpr (std::is_same_v<std::tuple_element_t<index, tuple_t>, lua_State*>) {
return function_invoke<function, index + 1, return_t, tuple_t>(L, stack_index, std::forward<params_t>(params)..., L);
if constexpr (std::is_same_v<iris_lua_t, std::remove_volatile_t<std::remove_const_t<std::remove_reference_t<std::tuple_element_t<index, tuple_t>>>>>) {
return function_invoke<function, index + 1, return_t, tuple_t>(L, stack_index, std::forward<params_t>(params)..., iris_lua_t(L));
} else {
return function_invoke<function, index + 1, return_t, tuple_t>(L, stack_index + 1, std::forward<params_t>(params)..., get_variable<std::tuple_element_t<index, tuple_t>>(L, stack_index));
}
Expand Down Expand Up @@ -724,7 +724,7 @@ namespace iris {
}
}

check_required_parameters<index + (std::is_same_v<first_t, lua_State*> ? 0 : 1), args_t...>(L);
check_required_parameters<index + (std::is_same_v<iris_lua_t, std::remove_volatile_t<std::remove_const_t<std::remove_reference_t<first_t>>>> ? 0 : 1), args_t...>(L);
}

template <auto function, typename return_t, typename... args_t>
Expand All @@ -736,8 +736,8 @@ namespace iris {
template <auto function, int index, typename coroutine_t, typename tuple_t, typename... params_t>
static bool function_coroutine_invoke(lua_State* L, int stack_index, params_t&&... params) {
if constexpr (index < std::tuple_size_v<tuple_t>) {
if constexpr (std::is_same_v<std::tuple_element_t<index, tuple_t>, lua_State*>) {
return function_coroutine_invoke<function, index + 1, coroutine_t, tuple_t>(L, stack_index, std::forward<params_t>(params)..., L);
if constexpr (std::is_same_v<iris_lua_t, std::remove_volatile_t<std::remove_const_t<std::remove_reference_t<std::tuple_element_t<index, tuple_t>>>>>) {
return function_coroutine_invoke<function, index + 1, coroutine_t, tuple_t>(L, stack_index, std::forward<params_t>(params)..., iris_lua_t(L));
} else {
return function_coroutine_invoke<function, index + 1, coroutine_t, tuple_t>(L, stack_index + 1, std::forward<params_t>(params)..., get_variable<std::tuple_element_t<index, tuple_t>>(L, stack_index));
}
Expand Down Expand Up @@ -854,7 +854,7 @@ namespace iris {
} else if constexpr (iris_lua_convert_t<value_t>::value) {
iris_lua_convert_t<value_t>::to_lua(L, std::forward<type_t>(variable));
} else if constexpr (std::is_same_v<value_t, void*> || std::is_same_v<value_t, const void*>) {
lua_pushlightuserdata(L, variable);
lua_pushlightuserdata(L, const_cast<value_t>(variable));
} else if constexpr (std::is_same_v<value_t, bool>) {
lua_pushboolean(L, static_cast<bool>(variable));
} else if constexpr (std::is_integral_v<value_t> || std::is_enum_v<value_t>) {
Expand Down
73 changes: 37 additions & 36 deletions test/iris_lua_demo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,32 +47,33 @@ struct iris::iris_lua_convert_t<vector3> : std::true_type {
};

struct example_t {
static void lua_registar(lua_State* L) {
iris_lua_t::define<&example_t::value>(L, "value");
iris_lua_t::define<&example_t::const_value>(L, "const_value");
iris_lua_t::define<&example_t::accum_value>(L, "accum_value");
iris_lua_t::define<&example_t::join_value>(L, "join_value");
iris_lua_t::define<&example_t::join_value_required>(L, "join_value_required");
iris_lua_t::define<&example_t::join_value_refptr>(L, "join_value_refptr");
iris_lua_t::define<&example_t::join_value_required_refptr>(L, "join_value_required_refptr");
iris_lua_t::define<&example_t::get_value>(L, "get_value");
iris_lua_t::define<&example_t::call>(L, "call");
iris_lua_t::define<&example_t::forward_pair>(L, "forward_pair");
iris_lua_t::define<&example_t::forward_tuple>(L, "forward_tuple");
iris_lua_t::define<&example_t::forward_map>(L, "forward_map");
iris_lua_t::define<&example_t::forward_vector>(L, "forward_vector");
iris_lua_t::define<&example_t::prime>(L, "prime");
iris_lua_t::define<&example_t::get_vector3>(L, "get_vector3");
static void lua_registar(iris_lua_t lua) {
lua.define<&example_t::value>("value");
lua.define<&example_t::const_value>("const_value");
lua.define<&example_t::accum_value>("accum_value");
lua.define<&example_t::join_value>("join_value");
lua.define<&example_t::join_value_required>("join_value_required");
lua.define<&example_t::join_value_refptr>("join_value_refptr");
lua.define<&example_t::join_value_required_refptr>("join_value_required_refptr");
lua.define<&example_t::get_value>("get_value");
lua.define<&example_t::call>("call");
lua.define<&example_t::forward_pair>("forward_pair");
lua.define<&example_t::forward_tuple>("forward_tuple");
lua.define<&example_t::forward_map>("forward_map");
lua.define<&example_t::forward_vector>("forward_vector");
lua.define<&example_t::prime>("prime");
lua.define<&example_t::get_vector3>("get_vector3");
#if USE_LUA_COROUTINE
iris_lua_t::define<&example_t::coro_get_int>(L, "coro_get_int");
iris_lua_t::define<&example_t::coro_get_none>(L, "coro_get_none");
iris_lua_t::define<&example_t::mem_coro_get_int>(L, "mem_coro_get_int");
lua.define<&example_t::coro_get_int>("coro_get_int");
lua.define<&example_t::coro_get_none>("coro_get_none");
lua.define<&example_t::mem_coro_get_int>("mem_coro_get_int");
#endif
}

int call(lua_State* L, iris_lua_t::ref_t&& r, int value) {
int result = iris_lua_t::call<int>(L, r, value);
iris_lua_t::deref(L, r);
int call(iris_lua_t lua, iris_lua_t::ref_t&& r, int value) {
int result = lua.call<int>(r, value);
lua.deref(r);

return result;
}

Expand All @@ -91,14 +92,14 @@ struct example_t {
value += rhs.get()->value;
}

void join_value_required_refptr(lua_State* L, iris_lua_t::required_t<iris_lua_t::refptr_t<example_t>>&& rhs) noexcept {
void join_value_required_refptr(iris_lua_t&& lua, iris_lua_t::required_t<iris_lua_t::refptr_t<example_t>>&& rhs) noexcept {
printf("Required ptr!\n");
value += rhs.get().get()->value;
iris_lua_t::deref(L, rhs.get());
lua.deref(rhs.get());
}

void join_value_refptr(lua_State* L, iris_lua_t::refptr_t<example_t>&& rhs) noexcept {
auto guard = iris_lua_t::refguard(L, rhs);
void join_value_refptr(iris_lua_t&& lua, iris_lua_t::refptr_t<example_t>&& rhs) noexcept {
auto guard = lua.refguard(rhs);
if (rhs != nullptr) {
value += rhs->value;
}
Expand Down Expand Up @@ -132,12 +133,12 @@ struct example_t {
return std::move(v);
}

iris_lua_t::ref_t prime(lua_State* L) const {
return iris_lua_t(L).make_table([](lua_State* L) noexcept {
iris_lua_t::define(L, "name", "prime");
iris_lua_t::define(L, 1, 2);
iris_lua_t::define(L, 2, 3);
iris_lua_t::define(L, 3, 5);
iris_lua_t::ref_t prime(iris_lua_t lua) const {
return lua.make_table([](iris_lua_t lua) noexcept {
lua.define("name", "prime");
lua.define(1, 2);
lua.define(2, 3);
lua.define(3, 5);
});
}

Expand Down Expand Up @@ -224,10 +225,10 @@ int main(void) {
end\n\
return true\n");
assert(ret);
auto tab = lua.make_table([](lua_State* L) {
iris_lua_t::define(L, "key", "value");
iris_lua_t::define(L, 1, "number");
iris_lua_t::define(L, 2, 2);
auto tab = lua.make_table([](iris_lua_t&& lua) {
lua.define("key", "value");
lua.define(1, "number");
lua.define(2, 2);
});

tab.set(lua, "set", "newvalue");
Expand Down

0 comments on commit 9734c43

Please sign in to comment.