diff --git a/src/lua/msgpack.c b/src/lua/msgpack.c index d5f65a032926..9721348262ef 100644 --- a/src/lua/msgpack.c +++ b/src/lua/msgpack.c @@ -91,6 +91,9 @@ mpstream_reset(struct mpstream *stream) } +static uint32_t CTID_CHAR_PTR; +static uint32_t CTID_STRUCT_IBUF; + struct luaL_serializer *luaL_msgpack_default = NULL; static enum mp_type @@ -430,14 +433,24 @@ static int lua_msgpack_encode(lua_State *L) { int index = lua_gettop(L); - if (index != 1) - luaL_error(L, "msgpack.encode: a Lua object expected"); + if (index < 1) + return luaL_error(L, "msgpack.encode: a Lua object expected"); + + struct ibuf *buf; + if (index > 1) { + uint32_t ctypeid; + buf = luaL_checkcdata(L, 2, &ctypeid); + if (ctypeid != CTID_STRUCT_IBUF) + return luaL_error(L, "msgpack.encode: argument 2 " + "must be of type 'struct ibuf'"); + } else { + buf = tarantool_lua_ibuf; + ibuf_reset(buf); + } + size_t used = ibuf_used(buf); struct luaL_serializer *cfg = luaL_checkserializer(L); - struct ibuf *buf = tarantool_lua_ibuf; - ibuf_reset(buf); - struct mpstream stream; mpstream_init(&stream, buf, ibuf_reserve_cb, ibuf_alloc_cb, luamp_error, L); @@ -445,37 +458,92 @@ lua_msgpack_encode(lua_State *L) luamp_encode(L, cfg, &stream, 1); mpstream_flush(&stream); - lua_pushlstring(L, buf->buf, ibuf_used(buf)); - ibuf_reinit(buf); + if (index > 1) { + lua_pushinteger(L, ibuf_used(buf) - used); + } else { + lua_pushlstring(L, buf->buf, ibuf_used(buf)); + ibuf_reinit(buf); + } return 1; } static int -lua_msgpack_decode(lua_State *L) +lua_msgpack_decode_cdata(lua_State *L, bool check) { - int index = lua_gettop(L); - if (index != 2 && index != 1 && lua_type(L, 1) != LUA_TSTRING) - return luaL_error(L, "msgpack.decode: a Lua string expected"); + uint32_t ctypeid; + const char *data = *(const char **)luaL_checkcdata(L, 1, &ctypeid); + if (ctypeid != CTID_CHAR_PTR) { + return luaL_error(L, "msgpack.decode: " + "a Lua string or 'char *' expected"); + } + if (check) { + size_t data_len = luaL_checkinteger(L, 2); + const char *p = data; + if (mp_check(&p, data + data_len) != 0) + return luaL_error(L, "msgpack.decode: invalid MsgPack"); + } + struct luaL_serializer *cfg = luaL_checkserializer(L); + luamp_decode(L, cfg, &data); + *(const char **)luaL_pushcdata(L, ctypeid) = data; + return 2; +} +static int +lua_msgpack_decode_string(lua_State *L, bool check) +{ + ptrdiff_t offset = 0; size_t data_len; - uint32_t offset = index > 1 ? lua_tointeger(L, 2) - 1 : 0; const char *data = lua_tolstring(L, 1, &data_len); - if (offset >= data_len) - luaL_error(L, "msgpack.decode: offset is out of bounds"); - const char *end = data + data_len; - - const char *b = data + offset; - if (mp_check(&b, end)) - return luaL_error(L, "msgpack.decode: invalid MsgPack"); - + if (lua_gettop(L) > 1) { + offset = luaL_checkinteger(L, 2) - 1; + if (offset < 0 || (size_t)offset >= data_len) + return luaL_error(L, "msgpack.decode: " + "offset is out of bounds"); + } + if (check) { + const char *p = data + offset; + if (mp_check(&p, data + data_len) != 0) + return luaL_error(L, "msgpack.decode: invalid MsgPack"); + } struct luaL_serializer *cfg = luaL_checkserializer(L); - - b = data + offset; - luamp_decode(L, cfg, &b); - lua_pushinteger(L, b - data + 1); + const char *p = data + offset; + luamp_decode(L, cfg, &p); + lua_pushinteger(L, p - data + 1); return 2; } +static int +lua_msgpack_decode(lua_State *L) +{ + int index = lua_gettop(L); + int type = index >= 1 ? lua_type(L, 1) : LUA_TNONE; + switch (type) { + case LUA_TCDATA: + return lua_msgpack_decode_cdata(L, true); + case LUA_TSTRING: + return lua_msgpack_decode_string(L, true); + default: + return luaL_error(L, "msgpack.decode: " + "a Lua string or 'char *' expected"); + } +} + +static int +lua_msgpack_decode_unchecked(lua_State *L) +{ + int index = lua_gettop(L); + int type = index >= 1 ? lua_type(L, 1) : LUA_TNONE; + switch (type) { + case LUA_TCDATA: + return lua_msgpack_decode_cdata(L, false); + case LUA_TSTRING: + return lua_msgpack_decode_string(L, false); + default: + return luaL_error(L, "msgpack.decode: " + "a Lua string or 'char *' expected"); + } +} + static int lua_ibuf_msgpack_decode(lua_State *L) { @@ -494,9 +562,10 @@ lua_msgpack_new(lua_State *L); static const luaL_Reg msgpacklib[] = { { "encode", lua_msgpack_encode }, { "decode", lua_msgpack_decode }, + { "decode_unchecked", lua_msgpack_decode_unchecked }, { "ibuf_decode", lua_ibuf_msgpack_decode }, { "new", lua_msgpack_new }, - { NULL, NULL} + { NULL, NULL } }; static int @@ -509,6 +578,13 @@ lua_msgpack_new(lua_State *L) LUALIB_API int luaopen_msgpack(lua_State *L) { + int rc = luaL_cdef(L, "struct ibuf;"); + assert(rc == 0); + (void) rc; + CTID_STRUCT_IBUF = luaL_ctypeid(L, "struct ibuf"); + assert(CTID_STRUCT_IBUF != 0); + CTID_CHAR_PTR = luaL_ctypeid(L, "char *"); + assert(CTID_CHAR_PTR != 0); luaL_msgpack_default = luaL_newserializer(L, "msgpack", msgpacklib); return 1; } diff --git a/test/app/msgpack.result b/test/app/msgpack.result new file mode 100644 index 000000000000..b1fe4e53b4f3 --- /dev/null +++ b/test/app/msgpack.result @@ -0,0 +1,204 @@ +buffer = require 'buffer' +--- +... +msgpack = require 'msgpack' +--- +... +-- Arguments check. +buf = buffer.ibuf() +--- +... +msgpack.encode() +--- +- error: 'msgpack.encode: a Lua object expected' +... +msgpack.encode('test', 'str') +--- +- error: expected cdata as 2 argument +... +msgpack.encode('test', buf.buf) +--- +- error: 'msgpack.encode: argument 2 must be of type ''struct ibuf''' +... +msgpack.decode() +--- +- error: 'msgpack.decode: a Lua string or ''char *'' expected' +... +msgpack.decode(123) +--- +- error: 'msgpack.decode: a Lua string or ''char *'' expected' +... +msgpack.decode(buf) +--- +- error: 'msgpack.decode: a Lua string or ''char *'' expected' +... +msgpack.decode(buf.buf, 'size') +--- +- error: 'bad argument #2 to ''?'' (number expected, got string)' +... +msgpack.decode('test', 0) +--- +- error: 'msgpack.decode: offset is out of bounds' +... +msgpack.decode('test', 5) +--- +- error: 'msgpack.decode: offset is out of bounds' +... +msgpack.decode('test', 'offset') +--- +- error: 'bad argument #2 to ''?'' (number expected, got string)' +... +msgpack.decode_unchecked() +--- +- error: 'msgpack.decode: a Lua string or ''char *'' expected' +... +msgpack.decode_unchecked(123) +--- +- error: 'msgpack.decode: a Lua string or ''char *'' expected' +... +msgpack.decode_unchecked(buf) +--- +- error: 'msgpack.decode: a Lua string or ''char *'' expected' +... +msgpack.decode_unchecked('test', 0) +--- +- error: 'msgpack.decode: offset is out of bounds' +... +msgpack.decode_unchecked('test', 5) +--- +- error: 'msgpack.decode: offset is out of bounds' +... +msgpack.decode_unchecked('test', 'offset') +--- +- error: 'bad argument #2 to ''?'' (number expected, got string)' +... +-- Encode/decode a string. +s = msgpack.encode({1, 2, 3}) .. msgpack.encode({4, 5, 6}) +--- +... +obj, offset = msgpack.decode(s) +--- +... +obj +--- +- [1, 2, 3] +... +obj, offset = msgpack.decode(s, offset) +--- +... +obj +--- +- [4, 5, 6] +... +offset == #s + 1 +--- +- true +... +obj, offset = msgpack.decode_unchecked(s) +--- +... +obj +--- +- [1, 2, 3] +... +obj, offset = msgpack.decode_unchecked(s, offset) +--- +... +obj +--- +- [4, 5, 6] +... +offset == #s + 1 +--- +- true +... +-- Encode/decode a buffer. +buf = buffer.ibuf() +--- +... +len = msgpack.encode({1, 2, 3}, buf) +--- +... +len = msgpack.encode({4, 5, 6}, buf) + len +--- +... +buf:size() == len +--- +- true +... +orig_rpos = buf.rpos +--- +... +obj, rpos = msgpack.decode(buf.rpos, buf:size()) +--- +... +obj +--- +- [1, 2, 3] +... +buf.rpos = rpos +--- +... +obj, rpos = msgpack.decode(buf.rpos, buf:size()) +--- +... +obj +--- +- [4, 5, 6] +... +buf.rpos = rpos +--- +... +buf:size() == 0 +--- +- true +... +buf.rpos = orig_rpos +--- +... +obj, rpos = msgpack.decode_unchecked(buf.rpos, buf:size()) +--- +... +obj +--- +- [1, 2, 3] +... +buf.rpos = rpos +--- +... +obj, rpos = msgpack.decode_unchecked(buf.rpos, buf:size()) +--- +... +obj +--- +- [4, 5, 6] +... +buf.rpos = rpos +--- +... +buf:size() == 0 +--- +- true +... +-- Invalid msgpack. +s = msgpack.encode({1, 2, 3}) +--- +... +s = s:sub(1, -2) +--- +... +msgpack.decode(s) +--- +- error: 'msgpack.decode: invalid MsgPack' +... +buf = buffer.ibuf() +--- +... +msgpack.encode({1, 2, 3}, buf) +--- +- 4 +... +msgpack.decode(buf.rpos, buf:size() - 1) +--- +- error: 'msgpack.decode: invalid MsgPack' +... diff --git a/test/app/msgpack.test.lua b/test/app/msgpack.test.lua new file mode 100644 index 000000000000..09c3dec5d1a9 --- /dev/null +++ b/test/app/msgpack.test.lua @@ -0,0 +1,64 @@ +buffer = require 'buffer' +msgpack = require 'msgpack' + +-- Arguments check. +buf = buffer.ibuf() +msgpack.encode() +msgpack.encode('test', 'str') +msgpack.encode('test', buf.buf) +msgpack.decode() +msgpack.decode(123) +msgpack.decode(buf) +msgpack.decode(buf.buf, 'size') +msgpack.decode('test', 0) +msgpack.decode('test', 5) +msgpack.decode('test', 'offset') +msgpack.decode_unchecked() +msgpack.decode_unchecked(123) +msgpack.decode_unchecked(buf) +msgpack.decode_unchecked('test', 0) +msgpack.decode_unchecked('test', 5) +msgpack.decode_unchecked('test', 'offset') + +-- Encode/decode a string. +s = msgpack.encode({1, 2, 3}) .. msgpack.encode({4, 5, 6}) +obj, offset = msgpack.decode(s) +obj +obj, offset = msgpack.decode(s, offset) +obj +offset == #s + 1 +obj, offset = msgpack.decode_unchecked(s) +obj +obj, offset = msgpack.decode_unchecked(s, offset) +obj +offset == #s + 1 + +-- Encode/decode a buffer. +buf = buffer.ibuf() +len = msgpack.encode({1, 2, 3}, buf) +len = msgpack.encode({4, 5, 6}, buf) + len +buf:size() == len +orig_rpos = buf.rpos +obj, rpos = msgpack.decode(buf.rpos, buf:size()) +obj +buf.rpos = rpos +obj, rpos = msgpack.decode(buf.rpos, buf:size()) +obj +buf.rpos = rpos +buf:size() == 0 +buf.rpos = orig_rpos +obj, rpos = msgpack.decode_unchecked(buf.rpos, buf:size()) +obj +buf.rpos = rpos +obj, rpos = msgpack.decode_unchecked(buf.rpos, buf:size()) +obj +buf.rpos = rpos +buf:size() == 0 + +-- Invalid msgpack. +s = msgpack.encode({1, 2, 3}) +s = s:sub(1, -2) +msgpack.decode(s) +buf = buffer.ibuf() +msgpack.encode({1, 2, 3}, buf) +msgpack.decode(buf.rpos, buf:size() - 1)