diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 91b67396de2d..3f70560036b9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -25,6 +25,7 @@ lua_source(lua_sources lua/buffer.lua) lua_source(lua_sources lua/uuid.lua) lua_source(lua_sources lua/crypto.lua) lua_source(lua_sources lua/digest.lua) +lua_source(lua_sources lua/msgpack.lua) lua_source(lua_sources lua/msgpackffi.lua) lua_source(lua_sources lua/uri.lua) lua_source(lua_sources lua/socket.lua) diff --git a/src/lua/init.c b/src/lua/init.c index d4e008d302ed..b4a086959530 100644 --- a/src/lua/init.c +++ b/src/lua/init.c @@ -77,6 +77,7 @@ bool start_loop = true; /* contents of src/lua/ files */ extern char strict_lua[], uuid_lua[], + msgpack_lua[], msgpackffi_lua[], fun_lua[], crypto_lua[], @@ -125,6 +126,7 @@ static const char *lua_modules[] = { "string", string_lua, "table", table_lua, "buffer", buffer_lua, + "msgpack", msgpack_lua, "msgpackffi", msgpackffi_lua, "crypto", crypto_lua, "digest", digest_lua, @@ -637,4 +639,4 @@ tarantool_lua_free() rl_cleanup_after_signal(); } #endif -} \ No newline at end of file +} diff --git a/src/lua/msgpack.c b/src/lua/msgpack.c index d5f65a032926..f6bc5b96fcb4 100644 --- a/src/lua/msgpack.c +++ b/src/lua/msgpack.c @@ -454,49 +454,118 @@ static int lua_msgpack_decode(lua_State *L) { int index = lua_gettop(L); - if (index != 2 && index != 1 && lua_type(L, 1) != LUA_TSTRING) - return luaL_error(L, "msgpack.decode: a Lua string expected"); + int type = index >= 1 ? lua_type(L, 1) : LUA_TNONE; + if (type != LUA_TSTRING && type != LUA_TCDATA) + return luaL_error(L, "msgpack.decode: " + "a Lua string or cdata expected"); + const char *data; size_t data_len; - uint32_t offset = index > 1 ? lua_tointeger(L, 2) - 1 : 0; - const char *data = lua_tolstring(L, 1, &data_len); - if (offset >= data_len) - luaL_error(L, "msgpack.decode: offset is out of bounds"); - const char *end = data + data_len; + ptrdiff_t offset = 0; + uint32_t ctypeid = 0; + if (type == LUA_TSTRING) { + if (index > 1) + offset = luaL_checkinteger(L, 2) - 1; + data = lua_tolstring(L, 1, &data_len); + if (offset < 0 || (size_t)offset >= data_len) + return luaL_error(L, "msgpack.decode: " + "offset is out of bounds"); + } else { + data = *(const char **)luaL_checkcdata(L, 1, &ctypeid); + data_len = luaL_checkinteger(L, 2); + } - const char *b = data + offset; - if (mp_check(&b, end)) - return luaL_error(L, "msgpack.decode: invalid MsgPack"); + const char *p = data + offset; + if (mp_check(&p, data + data_len)) { + lua_pushnil(L); + lua_pushnil(L); + lua_pushstring(L, "msgpack.decode: invalid MsgPack"); + return 3; + } struct luaL_serializer *cfg = luaL_checkserializer(L); - b = data + offset; - luamp_decode(L, cfg, &b); - lua_pushinteger(L, b - data + 1); + p = data + offset; + luamp_decode(L, cfg, &p); + + if (type == LUA_TSTRING) { + lua_pushinteger(L, p - data + 1); + } else { + *(const char **)luaL_pushcdata(L, ctypeid) = p; + } return 2; } static int -lua_ibuf_msgpack_decode(lua_State *L) +lua_msgpack_decode_unchecked(lua_State *L) { + int index = lua_gettop(L); + int type = index >= 1 ? lua_type(L, 1) : LUA_TNONE; + if (type != LUA_TSTRING && type != LUA_TCDATA) + return luaL_error(L, "msgpack.decode_unchecked: " + "a Lua string or cdata expected"); + + const char *data; + ptrdiff_t offset = 0; uint32_t ctypeid = 0; - const char *rpos = *(const char **)luaL_checkcdata(L, 1, &ctypeid); + if (type == LUA_TSTRING) { + if (index > 1) + offset = luaL_checkinteger(L, 2) - 1; + size_t data_len; + data = lua_tolstring(L, 1, &data_len); + if (offset < 0 || (size_t)offset >= data_len) + return luaL_error(L, "msgpack.decode_unchecked: " + "offset is out of bounds"); + } else { + data = *(const char **)luaL_checkcdata(L, 1, &ctypeid); + } + struct luaL_serializer *cfg = luaL_checkserializer(L); - luamp_decode(L, cfg, &rpos); - *(const char **)luaL_pushcdata(L, ctypeid) = rpos; - lua_pushvalue(L, -2); + + const char *p = data + offset; + luamp_decode(L, cfg, &p); + + if (type == LUA_TSTRING) { + lua_pushinteger(L, p - data + 1); + } else { + *(const char **)luaL_pushcdata(L, ctypeid) = p; + } return 2; } +static int +lua_msgpack_encode_internal(lua_State *L) +{ + struct luaL_serializer *cfg = (struct luaL_serializer *) + lua_topointer(L, 1); + struct ibuf *buf = (struct ibuf *) lua_topointer(L, 2); + size_t used = ibuf_used(buf); + + struct mpstream stream; + mpstream_init(&stream, buf, ibuf_reserve_cb, ibuf_alloc_cb, + luamp_error, L); + + luamp_encode(L, cfg, &stream, 3); + mpstream_flush(&stream); + + lua_pushinteger(L, ibuf_used(buf) - used); + return 1; +} + static int lua_msgpack_new(lua_State *L); static const luaL_Reg msgpacklib[] = { { "encode", lua_msgpack_encode }, { "decode", lua_msgpack_decode }, - { "ibuf_decode", lua_ibuf_msgpack_decode }, + { "decode_unchecked", lua_msgpack_decode_unchecked }, { "new", lua_msgpack_new }, - { NULL, NULL} + { NULL, NULL } +}; + +static const luaL_Reg msgpacklib_internal[] = { + { "encode", lua_msgpack_encode_internal }, + { NULL, NULL } }; static int @@ -510,5 +579,7 @@ LUALIB_API int luaopen_msgpack(lua_State *L) { luaL_msgpack_default = luaL_newserializer(L, "msgpack", msgpacklib); + luaL_register_module(L, "msgpack.internal", msgpacklib_internal); + lua_pop(L, 1); return 1; } diff --git a/src/lua/msgpack.lua b/src/lua/msgpack.lua new file mode 100644 index 000000000000..5be64c8cf706 --- /dev/null +++ b/src/lua/msgpack.lua @@ -0,0 +1,22 @@ +-- msgpack.lua (internal file) + +local ffi = require('ffi') +local msgpack = require('msgpack') +local internal = msgpack.internal + +local ibuf_t = ffi.typeof('struct ibuf') + +msgpack.ibuf_encode = function(buf, obj) + if not ffi.istype(ibuf_t, buf) then + error("Usage: msgpack.ibuf_encode(ibuf, obj)") + end + return internal.encode(msgpack, buf, obj) +end + +-- Backward compatibility wrapper. +msgpack.ibuf_decode = function(data) + local obj, data_end = msgpack.decode_unchecked(data) + return data_end, obj +end + +return msgpack diff --git a/test/app/msgpack.result b/test/app/msgpack.result new file mode 100644 index 000000000000..1a2b6e0a643b --- /dev/null +++ b/test/app/msgpack.result @@ -0,0 +1,173 @@ +buffer = require 'buffer' +--- +... +msgpack = require 'msgpack' +--- +... +-- Arguments check. +msgpack.decode() +--- +- error: 'msgpack.decode: a Lua string or cdata expected' +... +msgpack.decode(123) +--- +- error: 'msgpack.decode: a Lua string or cdata expected' +... +msgpack.decode('test', 0) +--- +- error: 'msgpack.decode: offset is out of bounds' +... +msgpack.decode('test', 5) +--- +- error: 'msgpack.decode: offset is out of bounds' +... +msgpack.decode_unchecked() +--- +- error: 'msgpack.decode_unchecked: a Lua string or cdata expected' +... +msgpack.decode_unchecked(123) +--- +- error: 'msgpack.decode_unchecked: a Lua string or cdata expected' +... +msgpack.decode_unchecked('test', 0) +--- +- error: 'msgpack.decode_unchecked: offset is out of bounds' +... +msgpack.decode_unchecked('test', 5) +--- +- error: 'msgpack.decode_unchecked: offset is out of bounds' +... +-- Encode/decode a string. +s = msgpack.encode({1, 2, 3}) .. msgpack.encode({4, 5, 6}) +--- +... +obj, offset = msgpack.decode(s) +--- +... +obj +--- +- [1, 2, 3] +... +obj, offset = msgpack.decode(s, offset) +--- +... +obj +--- +- [4, 5, 6] +... +offset == #s + 1 +--- +- true +... +obj, offset = msgpack.decode_unchecked(s) +--- +... +obj +--- +- [1, 2, 3] +... +obj, offset = msgpack.decode_unchecked(s, offset) +--- +... +obj +--- +- [4, 5, 6] +... +offset == #s + 1 +--- +- true +... +-- Encode/decode a buffer. +buf = buffer.ibuf() +--- +... +len = msgpack.ibuf_encode(buf, {1, 2, 3}) +--- +... +len = msgpack.ibuf_encode(buf, {4, 5, 6}) + len +--- +... +buf:size() == len +--- +- true +... +orig_rpos = buf.rpos +--- +... +obj, rpos = msgpack.decode(buf.rpos, buf:size()) +--- +... +obj +--- +- [1, 2, 3] +... +buf.rpos = rpos +--- +... +obj, rpos = msgpack.decode(buf.rpos, buf:size()) +--- +... +obj +--- +- [4, 5, 6] +... +buf.rpos = rpos +--- +... +buf:size() == 0 +--- +- true +... +buf.rpos = orig_rpos +--- +... +obj, rpos = msgpack.decode_unchecked(buf.rpos, buf:size()) +--- +... +obj +--- +- [1, 2, 3] +... +buf.rpos = rpos +--- +... +obj, rpos = msgpack.decode_unchecked(buf.rpos, buf:size()) +--- +... +obj +--- +- [4, 5, 6] +... +buf.rpos = rpos +--- +... +buf:size() == 0 +--- +- true +... +-- Invalid msgpack. +s = msgpack.encode({1, 2, 3}) +--- +... +s = s:sub(1, -2) +--- +... +msgpack.decode(s) +--- +- null +- null +- 'msgpack.decode: invalid MsgPack' +... +buf = buffer.ibuf() +--- +... +msgpack.ibuf_encode(buf, {1, 2, 3}) +--- +- 4 +... +msgpack.decode(buf.rpos, buf:size() - 1) +--- +- null +- null +- 'msgpack.decode: invalid MsgPack' +... diff --git a/test/app/msgpack.test.lua b/test/app/msgpack.test.lua new file mode 100644 index 000000000000..91ffddf5830f --- /dev/null +++ b/test/app/msgpack.test.lua @@ -0,0 +1,55 @@ +buffer = require 'buffer' +msgpack = require 'msgpack' + +-- Arguments check. +msgpack.decode() +msgpack.decode(123) +msgpack.decode('test', 0) +msgpack.decode('test', 5) +msgpack.decode_unchecked() +msgpack.decode_unchecked(123) +msgpack.decode_unchecked('test', 0) +msgpack.decode_unchecked('test', 5) + +-- Encode/decode a string. +s = msgpack.encode({1, 2, 3}) .. msgpack.encode({4, 5, 6}) +obj, offset = msgpack.decode(s) +obj +obj, offset = msgpack.decode(s, offset) +obj +offset == #s + 1 +obj, offset = msgpack.decode_unchecked(s) +obj +obj, offset = msgpack.decode_unchecked(s, offset) +obj +offset == #s + 1 + +-- Encode/decode a buffer. +buf = buffer.ibuf() +len = msgpack.ibuf_encode(buf, {1, 2, 3}) +len = msgpack.ibuf_encode(buf, {4, 5, 6}) + len +buf:size() == len +orig_rpos = buf.rpos +obj, rpos = msgpack.decode(buf.rpos, buf:size()) +obj +buf.rpos = rpos +obj, rpos = msgpack.decode(buf.rpos, buf:size()) +obj +buf.rpos = rpos +buf:size() == 0 +buf.rpos = orig_rpos +obj, rpos = msgpack.decode_unchecked(buf.rpos, buf:size()) +obj +buf.rpos = rpos +obj, rpos = msgpack.decode_unchecked(buf.rpos, buf:size()) +obj +buf.rpos = rpos +buf:size() == 0 + +-- Invalid msgpack. +s = msgpack.encode({1, 2, 3}) +s = s:sub(1, -2) +msgpack.decode(s) +buf = buffer.ibuf() +msgpack.ibuf_encode(buf, {1, 2, 3}) +msgpack.decode(buf.rpos, buf:size() - 1)