Skip to content

Commit

Permalink
[multistep] Activations in client (VowpalWabbit#397)
Browse files Browse the repository at this point in the history
* [Binary parser] reward functions for ccb format (VowpalWabbit#361)

* [Binary parser] refactoring rewards (VowpalWabbit#366)

* Example gen add ccb loop for e2e testing (VowpalWabbit#368)

* [Binary parser] add e2e ccb test to check that dsjson and fb logged files create the same model (VowpalWabbit#369)

* minor binary parser cleanup (VowpalWabbit#370)

* [Binary parser] add external parser test for apprentice mode cb (VowpalWabbit#374)

* CCB apprentice reward (VowpalWabbit#373)

* [Binary parser] add metrics for cb (VowpalWabbit#371)

* [Binary parser] don't log when skip learn, more tests, skip over unknown msg type (VowpalWabbit#375)

* [binary parser] ccb skip learn (VowpalWabbit#376)

* refactor: add error message to fix config file (VowpalWabbit#377)

* Fix CI's after flatbuffer version update to 2.0 (VowpalWabbit#390)

* try set fb span minimal

* add to preprocessor definitions

* add to unit_test project file

* Revert "mac ci: continue on error true (VowpalWabbit#327)" (VowpalWabbit#385)

* Fix python build path on windows, and formatting. (VowpalWabbit#383)

* Update build_docs.yml (VowpalWabbit#391)

* only convert timestamp to string before exiting (VowpalWabbit#382)

* ntohl is a define on osx, rename the function. (VowpalWabbit#386)

* Add a bunch of nice-to-have CLI options and fix FB 2.0 compat. (VowpalWabbit#387)

* our build requires CMP0074 due to usage of PackageName_ROOT variables. (VowpalWabbit#393)

* our build requires CMP0074 due to usage of PackageName_ROOT variables.

* try to use cmake_policy

* Activations in multistep: first PR with schema changes only (VowpalWabbit#392)

* deferred action to multistep schema

* Multistep to problem type

* try to set cmake policy for CMP0074

* OLD -> NEW

* try default policy for cmp0074

* build fix

* flags to request_episodic_decision

* episodic decision: deferred action implementation

* report_action_taken for secondary index

* formatting fixes

Co-authored-by: cheng-tan <chengtan2013@gmail.com>
Co-authored-by: olgavrou <olgavrou@gmail.com>
Co-authored-by: Griffin Bassman <griffinbassman@gmail.com>
Co-authored-by: Eduardo Salinas <edus@microsoft.com>
Co-authored-by: zwd-ms <71728747+zwd-ms@users.noreply.github.com>
Co-authored-by: Rodrigo Kumpera <kumpera@users.noreply.github.com>
  • Loading branch information
7 people committed Oct 27, 2021
1 parent 5710163 commit deab5ed
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 9 deletions.
14 changes: 13 additions & 1 deletion include/live_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ namespace reinforcement_learning {
int request_multi_slot_decision(const char * event_id, const char * context_json, unsigned int flags, multi_slot_response_detailed& resp, const int* baseline_actions, size_t baseline_actions_size, api_status* status = nullptr);

//multistep
int request_episodic_decision(const char *event_id, const char *previous_id, const char *context_json, ranking_response &resp, episode_state &episode, api_status *status = nullptr);
int request_episodic_decision(const char* event_id, const char* previous_id, const char* context_json, ranking_response& resp, episode_state& episode, api_status* status = nullptr);
int request_episodic_decision(const char* event_id, const char* previous_id, const char* context_json, unsigned int flags, ranking_response& resp, episode_state& episode, api_status* status = nullptr);

/**
* @brief Report that action was taken.
Expand All @@ -278,6 +279,17 @@ namespace reinforcement_learning {
*/
int report_action_taken(const char* event_id, api_status* status = nullptr);

/**
* @brief Report that the action was taken for one partial outcome of an event.
* Unlike the single-id overload, this one also takes a secondary index identifying the partial outcome.
*
* @param primary_id The unique primary_id used when choosing an action should be presented here. This is so that
* the action taken can be matched with feedback received.
* @param secondary_id Index of the partial outcome.
* @param status Optional field with detailed string description if there is an error
* @return int Return error code. This will also be returned in the api_status object
*/
int report_action_taken(const char* primary_id, const char* secondary_id, api_status* status = nullptr);

/**
* @brief Report the outcome for the top action.
*
Expand Down
13 changes: 12 additions & 1 deletion rlclientlib/live_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,17 @@ namespace reinforcement_learning

int live_model::request_episodic_decision(const char* event_id, const char* previous_id, const char* context_json, ranking_response& resp, episode_state& episode, api_status* status) {
INIT_CHECK();
return _pimpl->request_episodic_decision(event_id, previous_id, context_json, resp, episode, status);
return _pimpl->request_episodic_decision(event_id, previous_id, context_json, action_flags::DEFAULT, resp, episode, status);
}

// Episodic (multistep) decision with caller-supplied action flags.
// After the initialization check, every argument is handed unchanged to the implementation object.
int live_model::request_episodic_decision(const char* event_id, const char* previous_id, const char* context_json, unsigned int flags, ranking_response& resp, episode_state& episode, api_status* status) {
  INIT_CHECK();
  const int result = _pimpl->request_episodic_decision(
      event_id, previous_id, context_json, flags, resp, episode, status);
  return result;
}

// Reports that the action was taken for the partial outcome `secondary_id` of event `primary_id`.
// After the initialization check, the call is delegated to the implementation object.
int live_model::report_action_taken(const char* primary_id, const char* secondary_id, api_status* status) {
  INIT_CHECK();
  const int result = _pimpl->report_action_taken(primary_id, secondary_id, status);
  return result;
}

}
11 changes: 9 additions & 2 deletions rlclientlib/live_model_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,13 @@ namespace reinforcement_learning {
return _outcome_logger->report_action_taken(event_id, status);
}

int live_model_impl::report_action_taken(const char* primary_id, const char* secondary_id, api_status* status) {
// Clear previous errors if any
api_status::try_clear(status);
// Send the outcome event to the backend
return _outcome_logger->report_action_taken(primary_id, secondary_id, status);
}

int live_model_impl::report_outcome(const char* event_id, const char* outcome, api_status* status) {
// Check arguments
RETURN_IF_FAIL(check_null_or_empty(event_id, outcome, _trace_logger.get(), status));
Expand Down Expand Up @@ -585,7 +592,7 @@ namespace reinforcement_learning {
return refresh_model(status);
}

int live_model_impl::request_episodic_decision(const char* event_id, const char* previous_id, const char* context_json, ranking_response& resp, episode_state& episode, api_status* status) {
int live_model_impl::request_episodic_decision(const char* event_id, const char* previous_id, const char* context_json, unsigned int flags, ranking_response& resp, episode_state& episode, api_status* status) {
resp.clear();
//clear previous errors if any
api_status::try_clear(status);
Expand All @@ -608,7 +615,7 @@ namespace reinforcement_learning {
resp.set_event_id(event_id);

RETURN_IF_FAIL(episode.update(event_id, previous_id, context_json, resp, status));
RETURN_IF_FAIL(_interaction_logger->log(episode.get_episode_id(), previous_id, context_patched.c_str(), resp, status));
RETURN_IF_FAIL(_interaction_logger->log(episode.get_episode_id(), previous_id, context_patched.c_str(), flags, resp, status));
return error_code::success;
}

Expand Down
2 changes: 2 additions & 0 deletions rlclientlib/live_model_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ namespace reinforcement_learning
int request_multi_slot_decision(const char* context_json, unsigned int flags, multi_slot_response& resp, const std::vector<int>& baseline_actions, api_status* status = nullptr);
int request_multi_slot_decision(const char* event_id, const char* context_json, unsigned int flags, multi_slot_response_detailed& resp, const std::vector<int>& baseline_actions, api_status* status = nullptr);
int request_multi_slot_decision(const char* context_json, unsigned int flags, multi_slot_response_detailed& resp, const std::vector<int>& baseline_actions, api_status* status = nullptr);
// Episodic (multistep) decision; `flags` is passed on to the interaction logger together with the response.
int request_episodic_decision(const char *event_id, const char *previous_id, const char *context_json, unsigned int flags, ranking_response &resp, episode_state &episode, api_status *status = nullptr);
// NOTE(review): the definition in live_model_impl.cc now takes `flags`; this flag-less declaration
// appears to have no matching definition any more - confirm whether it should be removed.
int request_episodic_decision(const char *event_id, const char *previous_id, const char *context_json, ranking_response &resp, episode_state &episode, api_status *status = nullptr);

// Reports that the action chosen for `event_id` was taken; forwarded to the outcome logger.
int report_action_taken(const char* event_id, api_status* status);
// Same, for the partial outcome `secondary_id` of event `primary_id`.
int report_action_taken(const char* primary_id, const char *secondary_id, api_status* status);

int report_outcome(const char* event_id, const char* outcome_data, api_status* status);
int report_outcome(const char* event_id, float reward, api_status* status);
Expand Down
11 changes: 9 additions & 2 deletions rlclientlib/logger/logger_facade.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ namespace reinforcement_learning {
}
}

int interaction_logger_facade::log(const char* episode_id, const char* previous_id, const char* context, const ranking_response& response, api_status* status) {
int interaction_logger_facade::log(const char* episode_id, const char* previous_id, const char* context, unsigned int flags, const ranking_response& response, api_status* status) {
switch (_version) {
case 2: {
generic_event::object_list_t actions;
generic_event::payload_buffer_t payload;
event_content_type content_type;

RETURN_IF_FAIL(wrap_log_call(_ext, _multistep_serializer, context, actions, payload, content_type, status, previous_id, response));
RETURN_IF_FAIL(wrap_log_call(_ext, _multistep_serializer, context, actions, payload, content_type, status, previous_id, flags, response));
return _v2->log(episode_id, std::move(payload), _multistep_serializer.type, content_type, std::move(actions), status);
}
default: return protocol_not_supported(status);
Expand Down Expand Up @@ -262,5 +262,12 @@ namespace reinforcement_learning {
default: return protocol_not_supported(status);
}
}

// Logs an "action taken" observation for the partial outcome `secondary_id` of event `primary_id`.
// Only protocol version 2 is handled; any other version yields protocol_not_supported.
int observation_logger_facade::report_action_taken(const char* primary_id, const char* secondary_id, api_status* status) {
  if (_version != 2) {
    return protocol_not_supported(status);
  }

  // Serialize first, then hand the buffer to the v2 logger keyed by the primary id.
  auto payload = _serializer.report_action_taken(secondary_id);
  return _v2->log(primary_id, std::move(payload), _serializer.type, event_content_type::IDENTITY, status);
}
}
}
3 changes: 2 additions & 1 deletion rlclientlib/logger/logger_facade.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ namespace reinforcement_learning
//CB v1/v2
int log(const char* context, unsigned int flags, const ranking_response& response, api_status* status, learning_mode learning_mode = ONLINE);

int log(const char* episode_id, const char* previous_id, const char* context, const ranking_response& response, api_status* status);
int log(const char* episode_id, const char* previous_id, const char* context, unsigned int flags, const ranking_response& response, api_status* status);
const multistep_serializer _multistep_serializer;
int log_decisions(std::vector<const char*>& event_ids, const char* context, unsigned int flags, const std::vector<std::vector<uint32_t>>& action_ids,
const std::vector<std::vector<float>>& pdfs, const std::string& model_version, api_status* status);
Expand Down Expand Up @@ -108,6 +108,7 @@ namespace reinforcement_learning
int log(const char* event_id, const char* index, const char* outcome, api_status* status);

// Logs that the action for `event_id` was taken.
int report_action_taken(const char* event_id, api_status* status);
// Logs that the action was taken for one partial outcome of `event_id`;
// `index` is serialized as the literal index of the outcome event.
int report_action_taken(const char* event_id, const char* index, api_status* status);
private:
const int _version;
int _serializer_shared_state;
Expand Down
13 changes: 11 additions & 2 deletions rlclientlib/serialization/payload_serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,18 @@ namespace reinforcement_learning {
fbb.Finish(fb);
return fbb.Release();
}

// Builds the payload for an "action taken" report addressed by a literal (string) index.
// The event carries no outcome value (OutcomeValue_NONE); the trailing `true` is presumably
// the action-taken flag of OutcomeEvent - confirm against the schema.
static generic_event::payload_buffer_t report_action_taken(const char* index) {
  flatbuffers::FlatBufferBuilder builder;
  const auto literal_index = builder.CreateString(index).Union();
  const auto outcome_event =
      v2::CreateOutcomeEvent(builder, v2::OutcomeValue_NONE, 0, v2::IndexValue_literal, literal_index, true);
  builder.Finish(outcome_event);
  return builder.Release();
}
};

struct multistep_serializer : payload_serializer<generic_event::payload_type_t::PayloadType_MultiStep> {
static generic_event::payload_buffer_t event(const char* context, const char* previous_id, const ranking_response& response) {
static generic_event::payload_buffer_t event(const char* context, const char* previous_id, unsigned int flags, const ranking_response& response) {
flatbuffers::FlatBufferBuilder fbb;
std::vector<uint64_t> action_ids;
std::vector<float> probabilities;
Expand All @@ -180,7 +188,8 @@ namespace reinforcement_learning {
std::string context_str(context);
copy(context_str.begin(), context_str.end(), std::back_inserter(_context));

auto fb = v2::CreateMultiStepEventDirect(fbb, response.get_event_id(), previous_id, &action_ids, &_context, &probabilities, response.get_model_id());
auto fb = v2::CreateMultiStepEventDirect(fbb, response.get_event_id(), previous_id, &action_ids,
&_context, &probabilities, response.get_model_id(), flags & action_flags::DEFERRED);
fbb.Finish(fb);
return fbb.Release();
}
Expand Down

0 comments on commit deab5ed

Please sign in to comment.