Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

net: mqtt: Several bugfixes #23821

Merged
merged 3 commits into from Mar 28, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Next
net: mqtt: Fix packet length decryption
The standard allows up to 4 bytes of packet length data, while current
implementation parsed up to 5 bytes.

Add additional unit test, which verifies that error is reported in case
of invalid packet length.

Signed-off-by: Robert Lubos <robert.lubos@nordicsemi.no>
  • Loading branch information
rlubos committed Mar 26, 2020
commit 11b7a37d9a0b438270421b224221d91929843de4
8 changes: 6 additions & 2 deletions subsys/net/lib/mqtt/mqtt_decoder.c
Expand Up @@ -158,14 +158,14 @@ static int unpack_data(u32_t length, struct buf_ctx *buf,
* @retval -EINVAL if the length decoding would use more that 4 bytes.
* @retval -EAGAIN if the buffer would be exceeded during the read.
*/
int packet_length_decode(struct buf_ctx *buf, u32_t *length)
static int packet_length_decode(struct buf_ctx *buf, u32_t *length)
{
u8_t shift = 0U;
u8_t bytes = 0U;

*length = 0U;
do {
if (bytes > MQTT_MAX_LENGTH_BYTES) {
if (bytes >= MQTT_MAX_LENGTH_BYTES) {
return -EINVAL;
}

Expand All @@ -179,6 +179,10 @@ int packet_length_decode(struct buf_ctx *buf, u32_t *length)
bytes++;
} while ((*(buf->cur++) & MQTT_LENGTH_CONTINUATION_BIT) != 0U);

if (*length > MQTT_MAX_PAYLOAD_SIZE) {
return -EINVAL;
}

MQTT_TRC("length:0x%08x", *length);

return 0;
Expand Down
68 changes: 68 additions & 0 deletions tests/net/lib/mqtt_packet/src/mqtt_packet.c
Expand Up @@ -170,6 +170,24 @@ static int eval_msg_unsuback(struct mqtt_test *mqtt_test);
*/
static int eval_msg_disconnect(struct mqtt_test *mqtt_test);

/**
* @brief eval_max_pkt_len Evaluate header with maximum allowed packet
* length.
* @param [in] mqtt_test MQTT test structure
* @return TC_PASS on success
* @return TC_FAIL on error
*/
static int eval_max_pkt_len(struct mqtt_test *mqtt_test);

/**
* @brief eval_corrupted_pkt_len Evaluate header exceeding maximum
* allowed packet length.
* @param [in] mqtt_test MQTT test structure
* @return TC_PASS on success
* @return TC_FAIL on error
*/
static int eval_corrupted_pkt_len(struct mqtt_test *mqtt_test);

/**
* @brief eval_buffers Evaluate if two given buffers are equal
* @param [in] buf Input buffer 1, mostly used as the 'computed'
Expand All @@ -182,6 +200,7 @@ static int eval_msg_disconnect(struct mqtt_test *mqtt_test);
static int eval_buffers(const struct buf_ctx *buf,
const u8_t *expected, u16_t len);


/**
* @brief print_array Prints the array 'a' of 'size' elements
* @param a The array
Expand Down Expand Up @@ -513,6 +532,19 @@ static ZTEST_DMEM
u8_t unsuback1[] = {0xb0, 0x02, 0x00, 0x01};
static ZTEST_DMEM struct mqtt_unsuback_param msg_unsuback1 = {.message_id = 1};

static ZTEST_DMEM
u8_t max_pkt_len[] = {0x30, 0xff, 0xff, 0xff, 0x7f};
pfl marked this conversation as resolved.
Show resolved Hide resolved
static ZTEST_DMEM struct buf_ctx max_pkt_len_buf = {
.cur = max_pkt_len, .end = max_pkt_len + sizeof(max_pkt_len)
};

static ZTEST_DMEM
u8_t corrupted_pkt_len[] = {0x30, 0xff, 0xff, 0xff, 0xff, 0x01};
static ZTEST_DMEM struct buf_ctx corrupted_pkt_len_buf = {
.cur = corrupted_pkt_len,
.end = corrupted_pkt_len + sizeof(corrupted_pkt_len)
};

static ZTEST_DMEM
struct mqtt_test mqtt_tests[] = {

Expand Down Expand Up @@ -608,6 +640,12 @@ struct mqtt_test mqtt_tests[] = {
.ctx = &msg_unsuback1, .eval_fcn = eval_msg_unsuback,
.expected = unsuback1, .expected_len = sizeof(unsuback1)},

{.test_name = "Maximum packet length",
.ctx = &max_pkt_len_buf, .eval_fcn = eval_max_pkt_len},

{.test_name = "Corrupted packet length",
.ctx = &corrupted_pkt_len_buf, .eval_fcn = eval_corrupted_pkt_len},

/* last test case, do not remove it */
{.test_name = NULL}
};
Expand Down Expand Up @@ -1018,6 +1056,36 @@ static int eval_msg_unsuback(struct mqtt_test *mqtt_test)
return TC_PASS;
}

static int eval_max_pkt_len(struct mqtt_test *mqtt_test)
{
struct buf_ctx *buf = (struct buf_ctx *)mqtt_test->ctx;
int rc;
u8_t flags;
u32_t length;

rc = fixed_header_decode(buf, &flags, &length);

zassert_equal(rc, 0, "fixed_header_decode failed");
zassert_equal(length, MQTT_MAX_PAYLOAD_SIZE,
"Invalid packet length decoded");

return TC_PASS;
}

static int eval_corrupted_pkt_len(struct mqtt_test *mqtt_test)
{
struct buf_ctx *buf = (struct buf_ctx *)mqtt_test->ctx;
int rc;
u8_t flags;
u32_t length;

rc = fixed_header_decode(buf, &flags, &length);

zassert_equal(rc, -EINVAL, "fixed_header_decode should fail");

return TC_PASS;
}

void test_mqtt_packet(void)
{
TC_START("MQTT Library test");
Expand Down