Skip to content

Commit

Permalink
feat: Support passing in shard metadata via data_cli flags
Browse files Browse the repository at this point in the history
Bug: b/311382140
Change-Id: I0453195f6025d1db611ee17cf2a1b4a7f59813ff
GitOrigin-RevId: a1c6676596b87c4428adda0eab1bc1825413f5ce
  • Loading branch information
kelvintatendagorekore authored and Privacy Sandbox Team committed Nov 17, 2023
1 parent 2cad819 commit 95c0b96
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 49 deletions.
2 changes: 2 additions & 0 deletions tools/data_cli/commands/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ cc_library(
"//public/data_loading/readers:delta_record_stream_reader",
"//public/data_loading/writers:delta_record_stream_writer",
"//public/data_loading/writers:delta_record_writer",
"//public/sharding:sharding_function",
"@com_github_google_glog//:glog",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
Expand Down Expand Up @@ -71,6 +72,7 @@ cc_library(
"//public/data_loading:riegeli_metadata_cc_proto",
"//public/data_loading/readers:delta_record_stream_reader",
"//public/data_loading/writers:snapshot_stream_writer",
"//public/sharding:sharding_function",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down
40 changes: 33 additions & 7 deletions tools/data_cli/commands/format_data_command.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "public/data_loading/csv/csv_delta_record_stream_writer.h"
#include "public/data_loading/readers/delta_record_stream_reader.h"
#include "public/data_loading/writers/delta_record_stream_writer.h"
#include "public/sharding/sharding_function.h"
#include "src/cpp/util/status_macro/status_macros.h"

namespace kv_server {
Expand Down Expand Up @@ -60,6 +61,14 @@ absl::Status ValidateParams(const FormatDataCommand::Params& params) {
"Input and output format must be different. Input format: ",
params.input_format, " Output format: ", params.output_format));
}
if (params.shard_number >= 0 &&
params.number_of_shards <= params.shard_number) {
return absl::InvalidArgumentError(absl::StrCat(
"Shard metadata is invalid. shard_number is ", params.shard_number,
" and number_of_shards is ", params.number_of_shards,
". Valid inputs must satisfy the requirement: 0 <= shard_number < "
"number_of_shards"));
}
return absl::OkStatus();
}

Expand Down Expand Up @@ -128,6 +137,10 @@ absl::StatusOr<std::unique_ptr<DeltaRecordWriter>> CreateRecordWriter(
}
if (lw_output_format == kDeltaFormat) {
KVFileMetadata metadata;
if (params.shard_number >= 0) {
auto* shard_metadata = metadata.mutable_sharding_metadata();
shard_metadata->set_shard_num(params.shard_number);
}
auto delta_record_writer = DeltaRecordStreamWriter<std::ostream>::Create(
output_stream, DeltaRecordWriter::Options{.metadata = metadata});
if (!delta_record_writer.ok()) {
Expand All @@ -142,8 +155,7 @@ absl::StatusOr<std::unique_ptr<DeltaRecordWriter>> CreateRecordWriter(
} // namespace

absl::StatusOr<std::unique_ptr<FormatDataCommand>> FormatDataCommand::Create(
const Params& params, std::istream& input_stream,
std::ostream& output_stream) {
Params params, std::istream& input_stream, std::ostream& output_stream) {
if (absl::Status status = ValidateParams(params); !status.ok()) {
return status;
}
Expand All @@ -155,21 +167,35 @@ absl::StatusOr<std::unique_ptr<FormatDataCommand>> FormatDataCommand::Create(
if (!record_writer.ok()) {
return record_writer.status();
}
return absl::WrapUnique(new FormatDataCommand(std::move(*record_reader),
std::move(*record_writer)));
return absl::WrapUnique(new FormatDataCommand(
std::move(*record_reader), std::move(*record_writer), params));
}

absl::Status FormatDataCommand::Execute() {
LOG(INFO) << "Formatting records ...";
int64_t records_count = 0;
ShardingFunction sharding_function(/*seed=*/"");
absl::Status status = record_reader_->ReadRecords(
[record_writer = record_writer_.get(),
&records_count](DataRecordStruct data_record) {
[&records_count, &sharding_function, this](DataRecordStruct data_record) {
if (params_.shard_number >= 0 &&
std::holds_alternative<KeyValueMutationRecordStruct>(
data_record.record)) {
KeyValueMutationRecordStruct record_struct =
std::get<KeyValueMutationRecordStruct>(data_record.record);
auto record_shard_num = sharding_function.GetShardNumForKey(
record_struct.key, params_.number_of_shards);
if (params_.shard_number != record_shard_num) {
LOG(INFO) << "Skipping record with key: " << record_struct.key
<< " . The record belongs to shard: " << record_shard_num
<< ", but shard_number is " << params_.shard_number;
return absl::OkStatus();
}
}
records_count++;
if ((double)std::rand() / RAND_MAX <= kSamplingThreshold) {
LOG(INFO) << "Formatting record: " << records_count;
}
return record_writer->WriteRecord(data_record);
return record_writer_->WriteRecord(data_record);
});
record_writer_->Close();
LOG(INFO) << "Sucessfully formated records.";
Expand Down
24 changes: 14 additions & 10 deletions tools/data_cli/commands/format_data_command.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,27 +55,31 @@ namespace kv_server {
class FormatDataCommand : public Command {
public:
struct Params {
std::string_view input_format;
std::string_view output_format;
char csv_column_delimiter;
char csv_value_delimiter;
std::string_view record_type;
std::string_view csv_encoding = "PLAINTEXT";
std::string input_format = "CSV";
std::string output_format = "DELTA";
char csv_column_delimiter = ',';
char csv_value_delimiter = '|';
std::string record_type = "KEY_VALUE_MUTATION_RECORD";
std::string csv_encoding = "PLAINTEXT";
int64_t shard_number = -1;
int64_t number_of_shards = -1;
};

static absl::StatusOr<std::unique_ptr<FormatDataCommand>> Create(
const Params& params, std::istream& input_stream,
std::ostream& output_stream);
Params params, std::istream& input_stream, std::ostream& output_stream);
absl::Status Execute() override;

private:
FormatDataCommand(std::unique_ptr<DeltaRecordReader> record_reader,
std::unique_ptr<DeltaRecordWriter> record_writer)
std::unique_ptr<DeltaRecordWriter> record_writer,
Params params)
: record_reader_(std::move(record_reader)),
record_writer_(std::move(record_writer)) {}
record_writer_(std::move(record_writer)),
params_(std::move(params)) {}

std::unique_ptr<DeltaRecordReader> record_reader_;
std::unique_ptr<DeltaRecordWriter> record_writer_;
Params params_;
};

} // namespace kv_server
Expand Down
47 changes: 46 additions & 1 deletion tools/data_cli/commands/format_data_command_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ FormatDataCommand::Params GetParams(
.output_format = "DELTA",
.csv_column_delimiter = ',',
.csv_value_delimiter = '|',
.record_type = std::move(record_type),
.record_type = std::string(record_type),
};
}

Expand Down Expand Up @@ -416,5 +416,50 @@ TEST(FormatDataCommandTest, ValidateIncorrectOutputParams) {
<< status;
}

TEST(FormatDataCommandTest,
ValidateGeneratingCsvToDeltaData_KVMutations_ShardNum) {
std::stringstream csv_stream;
std::stringstream delta_stream;
CsvDeltaRecordStreamWriter csv_writer(csv_stream);
const auto& record = GetDataRecord(GetKVMutationRecord());
EXPECT_TRUE(csv_writer.WriteRecord(record).ok());
EXPECT_TRUE(csv_writer.WriteRecord(record).ok());
EXPECT_TRUE(csv_writer.WriteRecord(record).ok());
csv_writer.Close();
EXPECT_FALSE(csv_stream.str().empty());
auto params = GetParams();
params.shard_number = 2;
params.number_of_shards = 3;
auto command = FormatDataCommand::Create(params, csv_stream, delta_stream);
EXPECT_TRUE(command.ok()) << command.status();
EXPECT_TRUE((*command)->Execute().ok());
DeltaRecordStreamReader delta_reader(delta_stream);
auto metadata = delta_reader.ReadMetadata();
EXPECT_TRUE(metadata.ok());
EXPECT_EQ(metadata->sharding_metadata().shard_num(), 2);
testing::MockFunction<absl::Status(DataRecordStruct)> record_callback;
EXPECT_CALL(record_callback, Call)
.Times(3)
.WillRepeatedly([&record](DataRecordStruct actual_record) {
EXPECT_EQ(actual_record, record);
return absl::OkStatus();
});
EXPECT_TRUE(delta_reader.ReadRecords(record_callback.AsStdFunction()).ok());
}

TEST(FormatDataCommandTest, ValidateIncorrectShardingMetadataParams) {
std::stringstream unused_stream;
auto params = GetParams();
params.shard_number = 2;
absl::Status status =
FormatDataCommand::Create(params, unused_stream, unused_stream).status();
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument) << status;
EXPECT_STREQ(status.message().data(),
"Shard metadata is invalid. shard_number is 2 and "
"number_of_shards is -1. Valid inputs must satisfy the "
"requirement: 0 <= shard_number < number_of_shards")
<< status;
}

} // namespace
} // namespace kv_server
61 changes: 43 additions & 18 deletions tools/data_cli/commands/generate_snapshot_command.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "public/data_loading/filename_utils.h"
#include "public/data_loading/readers/delta_record_stream_reader.h"
#include "public/data_loading/riegeli_metadata.pb.h"
#include "public/sharding/sharding_function.h"
#include "src/cpp/telemetry/telemetry_provider.h"

namespace kv_server {
Expand Down Expand Up @@ -80,6 +81,14 @@ absl::Status ValidateRequiredParams(GenerateSnapshotCommand::Params& params) {
!IsDeltaFilename(params.ending_delta_file)) {
return absl::InvalidArgumentError("Ending delta file is not valid.");
}
if (params.shard_number >= 0 &&
params.number_of_shards <= params.shard_number) {
return absl::InvalidArgumentError(absl::StrCat(
"Shard metadata is invalid. shard_number is ", params.shard_number,
" and number_of_shards is ", params.number_of_shards,
". Valid inputs must satisfy the requirement: 0 <= shard_number < "
"number_of_shards"));
}
return absl::OkStatus();
}

Expand Down Expand Up @@ -109,6 +118,10 @@ absl::StatusOr<KVFileMetadata> CreateSnapshotMetadata(
auto snapshot_metadata = metadata.mutable_snapshot();
*snapshot_metadata->mutable_starting_file() = params.starting_file;
*snapshot_metadata->mutable_ending_delta_file() = params.ending_delta_file;
if (params.shard_number >= 0) {
auto* sharding_metadata = metadata.mutable_sharding_metadata();
sharding_metadata->set_shard_num(params.shard_number);
}
return metadata;
}

Expand All @@ -117,6 +130,32 @@ void ResetInputStream(std::istream& istream) {
istream.seekg(0, std::ios::beg);
}

absl::Status WriteRecordsToSnapshotStream(
const GenerateSnapshotCommand::Params& params,
DeltaRecordStreamReader<std::istream>& record_reader,
SnapshotStreamWriter<std::ostream>& snapshot_writer) {
ShardingFunction sharding_function(/*seed=*/"");
return record_reader.ReadRecords(
[&params, &snapshot_writer,
&sharding_function](DataRecordStruct data_record) {
if (params.shard_number >= 0 &&
std::holds_alternative<KeyValueMutationRecordStruct>(
data_record.record)) {
KeyValueMutationRecordStruct record_struct =
std::get<KeyValueMutationRecordStruct>(data_record.record);
auto record_shard_num = sharding_function.GetShardNumForKey(
record_struct.key, params.number_of_shards);
if (params.shard_number != record_shard_num) {
LOG(INFO) << "Skipping record with key: " << record_struct.key
<< " . The record belongs to shard: " << record_shard_num
<< ", but shard_number is " << params.shard_number;
return absl::OkStatus();
}
}
return snapshot_writer.WriteRecord(data_record);
});
}

absl::StatusOr<std::string> WriteBaseSnapshotData(
const GenerateSnapshotCommand::Params& params,
BlobStorageClient& blob_client,
Expand All @@ -129,13 +168,8 @@ absl::StatusOr<std::string> WriteBaseSnapshotData(
if (!metadata.ok()) {
return metadata.status();
}
if (blob_reader->CanSeek()) {
ResetInputStream(blob_reader->Stream());
} else {
blob_reader = blob_client.GetBlobReader(
{.bucket = params.data_dir.data(), .key = params.starting_file.data()});
}
if (auto status = snapshot_writer.WriteRecordStream(blob_reader->Stream());
if (auto status =
WriteRecordsToSnapshotStream(params, record_reader, snapshot_writer);
!status.ok()) {
return status;
}
Expand Down Expand Up @@ -163,17 +197,8 @@ absl::Status WriteDeltaFilesToSnapshot(
auto blob_reader = blob_client.GetBlobReader(
{.bucket = params.data_dir.data(), .key = delta_file});
DeltaRecordStreamReader record_reader(blob_reader->Stream());
auto metadata = record_reader.ReadMetadata();
if (!metadata.ok()) {
return metadata.status();
}
if (blob_reader->CanSeek()) {
ResetInputStream(blob_reader->Stream());
} else {
blob_reader = blob_client.GetBlobReader(
{.bucket = params.data_dir.data(), .key = delta_file});
}
if (auto status = snapshot_writer.WriteRecordStream(blob_reader->Stream());
if (auto status = WriteRecordsToSnapshotStream(params, record_reader,
snapshot_writer);
!status.ok()) {
return status;
}
Expand Down
2 changes: 2 additions & 0 deletions tools/data_cli/commands/generate_snapshot_command.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class GenerateSnapshotCommand : public Command {
std::string ending_delta_file;
std::string snapshot_file;
bool in_memory_compaction;
int64_t shard_number = -1;
int64_t number_of_shards = -1;
};

~GenerateSnapshotCommand();
Expand Down
Loading

0 comments on commit 95c0b96

Please sign in to comment.