From 0a424e17604abd525320542b0e61371c0aa317ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patryk=20J=C4=99drzejczak?= Date: Thu, 6 Jul 2023 09:50:13 +0200 Subject: [PATCH 1/2] db: hints: add the version_size constant The next commit changes the format of encoding sync points to V2. The new format appends the checksum to the encoded sync points and its implementation uses the checksum_size constant - the number of bytes required to store the checksum. To increase consistency and readability, we can additionally add and use the version_size constant. Definitions of sync_point::decode and sync_point::encode are slightly changed so that they don't depend on the version_size value and make implementation of the V2 format easier. --- db/hints/sync_point.cc | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/db/hints/sync_point.cc b/db/hints/sync_point.cc index 0f40da240c71..1351c260ee87 100644 --- a/db/hints/sync_point.cc +++ b/db/hints/sync_point.cc @@ -41,6 +41,8 @@ namespace hints { // Flattened representation was chosen in order to save space on // vector lengths etc. 
+static constexpr size_t version_size = sizeof(uint8_t); + static std::vector decode_one_type_v1(uint16_t shard_count, const per_manager_sync_point_v1& v1) { std::vector ret; @@ -72,11 +74,14 @@ sync_point sync_point::decode(sstring_view s) { if (raw.empty()) { throw std::runtime_error("Could not decode the sync point - not a valid hex string"); } - if (raw[0] != 1) { - throw std::runtime_error(format("Unsupported sync point format version: {}", int(raw[0]))); + + seastar::simple_memory_input_stream in{reinterpret_cast(raw.data()), raw.size()}; + + uint8_t version = ser::serializer::read(in); + if (version != 1) { + throw std::runtime_error(format("Unsupported sync point format version: {}", int(version))); } - seastar::simple_memory_input_stream in{reinterpret_cast(raw.data()) + 1, raw.size() - 1}; sync_point_v1 v1 = ser::serializer::read(in); return sync_point{ @@ -133,10 +138,10 @@ sstring sync_point::encode() const { seastar::measuring_output_stream measure; ser::serializer::write(measure, v1); - // Reserve 1 byte for the version - bytes serialized{bytes::initialized_later{}, 1 + measure.size()}; - serialized[0] = 1; - seastar::simple_memory_output_stream out{reinterpret_cast(serialized.data()), measure.size(), 1}; + // Reserve version_size bytes for the version + bytes serialized{bytes::initialized_later{}, version_size + measure.size()}; + seastar::simple_memory_output_stream out{reinterpret_cast(serialized.data()), serialized.size()}; + ser::serializer::write(out, 1); ser::serializer::write(out, v1); return base64_encode(serialized); From 02618831efad618866094f79c446745e378d0b4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patryk=20J=C4=99drzejczak?= Date: Wed, 5 Jul 2023 15:41:55 +0200 Subject: [PATCH 2/2] db: hints: add checksum to sync point encoding The sync point API, when provided with an incorrect sync point id, might allocate a crazy amount of memory and fail with std::bad_alloc. To fix this, we can check if the encoded sync point has been modified before decoding. 
We can achieve this by calculating a checksum before encoding, appending it to the encoded sync point, and comparing it with a checksum calculated in db::hints::sync_point::decode before decoding. --- db/hints/sync_point.cc | 47 +++++++++++++++++++++++----- test/rest_api/test_hinted_handoff.py | 24 ++++++++++++++ 2 files changed, 64 insertions(+), 7 deletions(-) create mode 100644 test/rest_api/test_hinted_handoff.py diff --git a/db/hints/sync_point.cc b/db/hints/sync_point.cc index 1351c260ee87..614e271b37a3 100644 --- a/db/hints/sync_point.cc +++ b/db/hints/sync_point.cc @@ -17,13 +17,22 @@ #include "idl/hinted_handoff.dist.hh" #include "idl/hinted_handoff.dist.impl.hh" #include "utils/base64.hh" +#include "utils/xx_hasher.hh" namespace db { namespace hints { - +// Sync points can be encoded in two formats: V1 and V2. V2 extends V1 by adding +// a checksum. Currently, we use the V2 format, but sync points encoded in the V1 +// format still can be safely decoded. +// // Format V1 (encoded in base64): // uint8_t 0x01 - version of format -// sync_point_v1 - encoded using IMR +// sync_point_v1 - encoded using IDL +// +// Format V2 (encoded in base64): +// uint8_t 0x02 - version of format +// sync_point_v1 - encoded using IDL +// uint64_t - checksum computed using the xxHash algorithm // // sync_point_v1: // UUID host_id - ID of the host which created the sync point @@ -42,6 +51,7 @@ namespace hints { // vector lengths etc. 
static constexpr size_t version_size = sizeof(uint8_t); +static constexpr size_t checksum_size = sizeof(uint64_t); static std::vector decode_one_type_v1(uint16_t shard_count, const per_manager_sync_point_v1& v1) { std::vector ret; @@ -69,16 +79,34 @@ static std::vector decode_one_type_v1(uint16_t shard_coun return ret; } +static uint64_t calculate_checksum(const sstring_view s) { + xx_hasher h; + h.update(s.data(), s.size()); + return h.finalize_uint64(); +} + sync_point sync_point::decode(sstring_view s) { bytes raw = base64_decode(s); if (raw.empty()) { throw std::runtime_error("Could not decode the sync point - not a valid hex string"); } - seastar::simple_memory_input_stream in{reinterpret_cast(raw.data()), raw.size()}; + sstring_view raw_s(reinterpret_cast(raw.data()), raw.size()); + seastar::simple_memory_input_stream in{raw_s.data(), raw_s.size()}; uint8_t version = ser::serializer::read(in); - if (version != 1) { + if (version == 2) { + if (raw_s.size() < version_size + checksum_size) { + throw std::runtime_error("Could not decode the sync point encoded in the V2 format - serialized blob is too short"); + } + + seastar::simple_memory_input_stream in_checksum{raw_s.end() - checksum_size, checksum_size}; + uint64_t checksum = ser::serializer::read(in_checksum); + if (checksum != calculate_checksum(raw_s.substr(0, raw_s.size() - checksum_size))) { + throw std::runtime_error("Could not decode the sync point encoded in the V2 format - wrong checksum"); + } + } + else if (version != 1) { throw std::runtime_error(format("Unsupported sync point format version: {}", int(version))); } @@ -138,11 +166,16 @@ sstring sync_point::encode() const { seastar::measuring_output_stream measure; ser::serializer::write(measure, v1); - // Reserve version_size bytes for the version - bytes serialized{bytes::initialized_later{}, version_size + measure.size()}; + // Reserve version_size bytes for the version and checksum_size bytes for the checksum + bytes 
serialized{bytes::initialized_later{}, version_size + measure.size() + checksum_size}; + + // Encode using V2 format seastar::simple_memory_output_stream out{reinterpret_cast(serialized.data()), serialized.size()}; - ser::serializer::write(out, 1); + ser::serializer::write(out, 2); ser::serializer::write(out, v1); + sstring_view serialized_s(reinterpret_cast(serialized.data()), version_size + measure.size()); + uint64_t checksum = calculate_checksum(serialized_s); + ser::serializer::write(out, checksum); return base64_encode(serialized); } diff --git a/test/rest_api/test_hinted_handoff.py b/test/rest_api/test_hinted_handoff.py new file mode 100644 index 000000000000..378378c3fa2d --- /dev/null +++ b/test/rest_api/test_hinted_handoff.py @@ -0,0 +1,24 @@ +import requests +import urllib.parse +import base64 + +def test_sync_point_checksum(rest_api): + resp = rest_api.send('POST', "hinted_handoff/sync_point") + sync_point = resp.json() + # Decode the sync_point to bytes to ensure that every modification changes the data + # (multiple base64 encoded strings may represent a single binary value) + sync_point_b = base64.b64decode(sync_point.encode('ascii')) + + resp = rest_api.send('GET', "hinted_handoff/sync_point", { "id": urllib.parse.quote(sync_point) }) + assert resp.ok + + # Modify each sync_point's byte (except the first one) and send an incorrect request + # The first byte representing version is omitted, because changing it causes a different error + for i in range(1, len(sync_point_b)): + bad_sync_point_b = sync_point_b[:i] + bytes([(sync_point_b[i] + 1) % 255]) + sync_point_b[i + 1:] + bad_sync_point = base64.b64encode(bad_sync_point_b).decode('ascii') + + # Expect that checksum is different + resp = rest_api.send('GET', "hinted_handoff/sync_point", { "id": urllib.parse.quote(bad_sync_point) }) + assert resp.status_code == requests.codes.bad_request + assert "wrong checksum" in resp.json()['message']