From ed71525c9967a0ba7c2978ed27c1f1fd5e5072ea Mon Sep 17 00:00:00 2001 From: Sorin Jianu Date: Wed, 5 Jun 2019 09:44:22 -0700 Subject: [PATCH] Implement the Omaha Cloud Policies Fetcher. --- omaha/base/constants.h | 12 +- omaha/common/config_manager.cc | 9 + omaha/common/config_manager.h | 5 + omaha/common/const_goopdate.h | 1 + omaha/goopdate/dm_client.cc | 204 +++++++++++++++----- omaha/goopdate/dm_client.h | 27 ++- omaha/goopdate/dm_client_unittest.cc | 258 +++++++++++++++++++++----- omaha/goopdate/dm_messages.cc | 78 +++++++- omaha/goopdate/dm_messages.h | 13 ++ omaha/goopdate/dm_storage.cc | 120 +++++++++++- omaha/goopdate/dm_storage.h | 42 ++++- omaha/goopdate/dm_storage_unittest.cc | 187 +++++++++++++------ omaha/goopdate/goopdate.cc | 50 +++-- omaha/setup/setup_google_update.cc | 21 +-- 14 files changed, 839 insertions(+), 188 deletions(-) diff --git a/omaha/base/constants.h b/omaha/base/constants.h index ce18890a9..577469e9f 100644 --- a/omaha/base/constants.h +++ b/omaha/base/constants.h @@ -178,19 +178,25 @@ const TCHAR* const kChromeAppId = CHROME_APP_ID; // // Directory names // -#define OFFLINE_DIR_NAME _T("Offline") +#define OFFLINE_DIR_NAME _T("Offline") +#define DOWNLOAD_DIR_NAME _T("Download") +#define INSTALL_WORKING_DIR_NAME _T("Install") +// Directories relative to \Google #define OMAHA_REL_COMPANY_DIR SHORT_COMPANY_NAME #define OMAHA_REL_CRASH_DIR OMAHA_REL_COMPANY_DIR _T("\\CrashReports") +#define OMAHA_REL_POLICY_RESPONSES_DIR OMAHA_REL_COMPANY_DIR _T("\\Policies") + +// Directories relative to \Google\Update #define OMAHA_REL_GOOPDATE_INSTALL_DIR \ OMAHA_REL_COMPANY_DIR _T("\\") PRODUCT_NAME #define OMAHA_REL_LOG_DIR OMAHA_REL_GOOPDATE_INSTALL_DIR _T("\\Log") #define OMAHA_REL_OFFLINE_STORAGE_DIR \ OMAHA_REL_GOOPDATE_INSTALL_DIR _T("\\") OFFLINE_DIR_NAME #define OMAHA_REL_DOWNLOAD_STORAGE_DIR \ - OMAHA_REL_GOOPDATE_INSTALL_DIR _T("\\Download") + OMAHA_REL_GOOPDATE_INSTALL_DIR _T("\\") DOWNLOAD_DIR_NAME #define OMAHA_REL_INSTALL_WORKING_DIR \ - OMAHA_REL_GOOPDATE_INSTALL_DIR _T("\\Install") + OMAHA_REL_GOOPDATE_INSTALL_DIR _T("\\") INSTALL_WORKING_DIR_NAME // This directory is relative to the user profile app data local. #define LOCAL_APPDATA_REL_TEMP_DIR _T("\\Temp") diff --git a/omaha/common/config_manager.cc b/omaha/common/config_manager.cc index afcf6faf2..251d62d06 100644 --- a/omaha/common/config_manager.cc +++ b/omaha/common/config_manager.cc @@ -524,6 +524,15 @@ HRESULT ConfigManager::GetDeviceManagementUrl(CString* url) const { return S_OK; } +CPath ConfigManager::GetPolicyResponsesDir() const { + CString path; + VERIFY1(SUCCEEDED(GetDir32(CSIDL_PROGRAM_FILES, + CString(OMAHA_REL_POLICY_RESPONSES_DIR), + true, + &path))); + return CPath(path); +} + #endif // defined(HAS_DEVICE_MANAGEMENT) // Returns the override from the registry locations if present. Otherwise, diff --git a/omaha/common/config_manager.h b/omaha/common/config_manager.h index a377cc32b..96e45e550 100644 --- a/omaha/common/config_manager.h +++ b/omaha/common/config_manager.h @@ -24,6 +24,7 @@ #define OMAHA_COMMON_CONFIG_MANAGER_H_ #include +#include #include #include "base/basictypes.h" #include "omaha/base/constants.h" @@ -175,6 +176,10 @@ class ConfigManager { #if defined(HAS_DEVICE_MANAGEMENT) // Returns the Device Management API url. HRESULT GetDeviceManagementUrl(CString* url) const; + + // Returns the directory under which the Device Management policies are + // persisted. + CPath GetPolicyResponsesDir() const; #endif // Returns the time interval between update checks in seconds. diff --git a/omaha/common/const_goopdate.h b/omaha/common/const_goopdate.h index 813ec6635..eac72b8ef 100644 --- a/omaha/common/const_goopdate.h +++ b/omaha/common/const_goopdate.h @@ -332,6 +332,7 @@ const int kNetworkRequestEventId = 20; // Device management events. const int kEnrollmentFailedEventId = 30; const int kEnrollmentRequiresNetworkEventId = 31; +const int kRefreshPoliciesFailedEventId = 32; // Maximum value the server can respond for elapsed_seconds attribute in // element. The value is one day plus an hour ("fall back" diff --git a/omaha/goopdate/dm_client.cc b/omaha/goopdate/dm_client.cc index b1de83ef2..0469ef60b 100644 --- a/omaha/goopdate/dm_client.cc +++ b/omaha/goopdate/dm_client.cc @@ -14,7 +14,9 @@ #include "omaha/goopdate/dm_client.h" +#include #include +#include #include "omaha/base/app_util.h" #include "omaha/base/constants.h" @@ -55,6 +57,13 @@ HRESULT RegisterIfNeeded(DmStorage* dm_storage) { ASSERT1(dm_storage); OPT_LOG(L1, (_T("[DmClient::RegisterIfNeeded]"))); + // No work to be done if the process is not running as an administrator, since + // we will not be able to persist anything. + if (!::IsUserAnAdmin()) { + OPT_LOG(L1, (_T("[RegisterIfNeeded][Process not Admin, exiting early]"))); + return S_FALSE; + } + // No work to be done if a DM token was found. CStringA dm_token = dm_storage->GetDmToken(); if (!dm_token.IsEmpty()) { @@ -76,8 +85,10 @@ HRESULT RegisterIfNeeded(DmStorage* dm_storage) { return E_FAIL; } + // RegisterWithRequest owns the SimpleRequest being created here. HRESULT hr = internal::RegisterWithRequest(new SimpleRequest, - enrollment_token, device_id, + enrollment_token, + device_id, &dm_token); if (FAILED(hr)) { return hr; @@ -94,6 +105,53 @@ HRESULT RegisterIfNeeded(DmStorage* dm_storage) { return S_OK; } +HRESULT RefreshPolicies() { + // No work to be done if the process is not running as an administrator, since + // we will not be able to persist anything. + if (!::IsUserAnAdmin()) { + OPT_LOG(L1, (_T("[RefreshPolicies][Process not Admin, exiting early]"))); + return S_FALSE; + } + + DmStorage* const dm_storage = DmStorage::Instance(); + const CString dm_token = CString(dm_storage->GetDmToken()); + if (dm_token.IsEmpty()) { + OPT_LOG(L1, (_T("[Skipping RefreshPolicies as there is no DMToken]"))); + return S_FALSE; + } + + const CString device_id = dm_storage->GetDeviceId(); + if (device_id.IsEmpty()) { + REPORT_LOG(LE, (_T("[Device ID not found]"))); + return E_FAIL; + } + + PolicyResponsesMap responses; + + // FetchPolicies owns the SimpleRequest being created here. + HRESULT hr = internal::FetchPolicies(new SimpleRequest, + dm_token, + device_id, + &responses); + if (FAILED(hr)) { + REPORT_LOG(LE, (_T("[FetchPolicies failed][%#x]"), hr)); + return hr; + } + + const CPath policy_responses_dir( + ConfigManager::Instance()->GetPolicyResponsesDir()); + + hr = DmStorage::PersistPolicies(policy_responses_dir, responses); + if (FAILED(hr)) { + REPORT_LOG(LE, (_T("[PersistPolicies failed][%#x]"), hr)); + return hr; + } + + OPT_LOG(L1, (_T("[RefreshPolicies complete]"))); + + return S_OK; +} + namespace internal { HRESULT RegisterWithRequest(HttpRequestInterface* http_request, @@ -102,6 +160,94 @@ HRESULT RegisterWithRequest(HttpRequestInterface* http_request, CStringA* dm_token) { ASSERT1(http_request); ASSERT1(dm_token); + + std::vector> query_params = { + {_T("request"), _T("register_policy_agent")}, + }; + + // Make the request payload. + CStringA payload = SerializeRegisterBrowserRequest( + WideToUtf8(app_util::GetHostName()), + CStringA("Windows"), + internal::GetOsVersion()); + if (payload.IsEmpty()) { + REPORT_LOG(LE, (_T("[SerializeRegisterBrowserRequest failed]"))); + return E_FAIL; + } + + std::vector response; + HRESULT hr = SendDeviceManagementRequest( + http_request, + payload, + internal::FormatEnrollmentTokenAuthorizationHeader(enrollment_token), + device_id, + std::move(query_params), + &response); + if (FAILED(hr)) { + REPORT_LOG(LE, (_T("[SendDeviceManagementRequest failed][%#x]"), hr)); + return hr; + } + + hr = ParseDeviceRegisterResponse(response, dm_token); + if (FAILED(hr)) { + REPORT_LOG(LE, (_T("[ParseDeviceRegisterResponse failed][%#x]"), hr)); + return hr; + } + + return S_OK; +} + +HRESULT FetchPolicies(HttpRequestInterface* http_request, + const CString& dm_token, + const CString& device_id, + PolicyResponsesMap* responses) { + ASSERT1(http_request); + ASSERT1(!dm_token.IsEmpty()); + ASSERT1(responses); + + std::vector> query_params = { + {_T("request"), _T("policy")}, + }; + + CStringA payload = SerializePolicyFetchRequest( + CStringA(kGoogleUpdateMachineLevelApps)); + if (payload.IsEmpty()) { + REPORT_LOG(LE, (_T("[SerializePolicyFetchRequest failed]"))); + return E_FAIL; + } + + std::vector response; + HRESULT hr = SendDeviceManagementRequest( + http_request, + payload, + FormatDMTokenAuthorizationHeader(dm_token), + device_id, + std::move(query_params), + &response); + if (FAILED(hr)) { + REPORT_LOG(LE, (_T("[SendDeviceManagementRequest failed][%#x]"), hr)); + return hr; + } + + hr = ParseDevicePolicyResponse(response, responses); + if (FAILED(hr)) { + REPORT_LOG(LE, (_T("[ParseDeviceRegisterResponse failed][%#x]"), hr)); + return hr; + } + + return S_OK; +} + +HRESULT SendDeviceManagementRequest( + HttpRequestInterface* http_request, + const CStringA& payload, + const CString& authorization_header, + const CString& device_id, + std::vector> query_params, + std::vector* response) { + ASSERT1(http_request); + ASSERT1(response); + // Get the network configuration. NetworkConfig* network_config = NULL; NetworkConfigManager& network_config_manager = @@ -115,10 +261,7 @@ HRESULT RegisterWithRequest(HttpRequestInterface* http_request, // Create a network request and configure its headers. std::unique_ptr request( new NetworkRequest(network_config->session())); - // DeviceManagementRequestJobImpl::ConfigureRequest. - request->AddHeader(L"Authorization", - internal::FormatEnrollmentTokenAuthorizationHeader( - enrollment_token)); + request->AddHeader(_T("Authorization"), authorization_header); // Set it up request->AddHttpRequest(http_request); @@ -131,22 +274,10 @@ HRESULT RegisterWithRequest(HttpRequestInterface* http_request, return hr; } - std::vector> query_params; - // DeviceManagementRequestJob::DeviceManagementRequestJob. - // kParamRequest = kValueRequestTokenEnrollment. - query_params.push_back(std::make_pair(_T("request"), - _T("register_policy_agent"))); - // kParamAppType = kValueAppType. - query_params.push_back(std::make_pair(_T("apptype"), _T("Chrome"))); - // kParamAgent. - query_params.push_back(std::make_pair(_T("agent"), internal::GetAgent())); - // kParamPlatform. - query_params.push_back(std::make_pair(_T("platform"), - internal::GetPlatform())); - - // DeviceManagementRequestJob::SetClientID. - // kParamDeviceID. - query_params.push_back(std::make_pair(_T("deviceid"), device_id)); + query_params.emplace_back(_T("agent"), internal::GetAgent()); + query_params.emplace_back(_T("apptype"), _T("Chrome")); + query_params.emplace_back(_T("deviceid"), device_id); + query_params.emplace_back(_T("platform"), internal::GetPlatform()); hr = internal::AppendQueryParamsToUrl(query_params, &url); if (FAILED(hr)) { @@ -154,19 +285,7 @@ HRESULT RegisterWithRequest(HttpRequestInterface* http_request, return hr; } - // Make the request payload. - // DeviceManagementRequest.RegisterBrowserRequest: - CStringA payload = SerializeRegisterBrowserRequest( - WideToUtf8(app_util::GetHostName()), // policy::GetMachineName - CStringA("Windows"), // policy::GetOSPlatform - internal::GetOsVersion()); // policy::GetOSVersion - if (payload.IsEmpty()) { - REPORT_LOG(LE, (_T("[SerializeRegisterBrowserRequest failed]"))); - return E_FAIL; - } - - std::vector response; - hr = request->Post(url, payload, payload.GetLength(), &response); + hr = request->Post(url, payload, payload.GetLength(), response); if (FAILED(hr)) { REPORT_LOG(LE, (_T("[NetworkRequest::Post failed][%#x, %s]"), hr, url)); return hr; @@ -177,31 +296,23 @@ HRESULT RegisterWithRequest(HttpRequestInterface* http_request, REPORT_LOG(LE, (_T("[NetworkRequest::Post failed][status code %d]"), http_status_code)); CStringA error_message; - hr = ParseDeviceManagementResponseError(response, &error_message); + hr = ParseDeviceManagementResponseError(*response, &error_message); if (SUCCEEDED(hr)) { OPT_LOG(LE, (_T("[Server returned: %S]"), error_message)); } return E_FAIL; } - hr = ParseDeviceRegisterResponse(response, dm_token); - if (FAILED(hr)) { - REPORT_LOG(LE, (_T("[ParseDeviceRegisterResponse failed][%#x]"), hr)); - return hr; - } - return S_OK; } CString GetAgent() { - // DeviceManagementServiceConfiguration::GetAgentParameter. CString agent; SafeCStringFormat(&agent, _T("%s %s()"), kAppName, GetVersionString()); return agent; } CString GetPlatform() { - // DeviceManagementServiceConfiguration::GetPlatformParameter. const DWORD architecture = SystemInfo::GetProcessorArchitecture(); int major_version = 0; @@ -258,6 +369,13 @@ CString FormatEnrollmentTokenAuthorizationHeader( return header_value; } +CString FormatDMTokenAuthorizationHeader( + const CString& token) { + CString header_value; + SafeCStringFormat(&header_value, _T("GoogleDMToken token=%s"), token); + return header_value; +} + } // namespace internal } // namespace dm_client } // namespace omaha diff --git a/omaha/goopdate/dm_client.h b/omaha/goopdate/dm_client.h index d8e906f23..896e32541 100644 --- a/omaha/goopdate/dm_client.h +++ b/omaha/goopdate/dm_client.h @@ -19,6 +19,7 @@ #include #include #include +#include "omaha/goopdate/dm_messages.h" namespace omaha { @@ -27,6 +28,10 @@ class HttpRequestInterface; namespace dm_client { +// The policy type that supports getting the policies for all Machine +// applications from the DMServer. +const char kGoogleUpdateMachineLevelApps[] = "google/machine-level-apps"; + enum RegistrationState { // This client appears to not be managed. In particular, neither a device // management token nor an enrollment token can be found. @@ -48,19 +53,39 @@ RegistrationState GetRegistrationState(DmStorage* dm_storage); // enrollment token is found), or a failure HRESULT in case of error. HRESULT RegisterIfNeeded(DmStorage* dm_storage); +// Retrieve and persist locally the policies from the Device Management Server. +HRESULT RefreshPolicies(); + namespace internal { HRESULT RegisterWithRequest(HttpRequestInterface* http_request, const CString& enrollment_token, const CString& device_id, CStringA* dm_token); + +// Fetch policies from the DMServer. The policies are returned in |responses| +// containing elements in the following format: +// {policy_type}=>{SerializeToString-PolicyFetchResponse}. +HRESULT FetchPolicies(HttpRequestInterface* http_request, + const CString& dm_token, + const CString& device_id, + PolicyResponsesMap* responses); + +HRESULT SendDeviceManagementRequest( + HttpRequestInterface* http_request, + const CStringA& payload, + const CString& authorization_header, + const CString& device_id, + std::vector> query_params, + std::vector* response); CString GetAgent(); CString GetPlatform(); CStringA GetOsVersion(); HRESULT AppendQueryParamsToUrl( - const std::vector>& query_params, + const std::vector>& query_params, CString* url); CString FormatEnrollmentTokenAuthorizationHeader(const CString& token); +CString FormatDMTokenAuthorizationHeader(const CString& token); } // namespace internal } // namespace dm_client diff --git a/omaha/goopdate/dm_client_unittest.cc b/omaha/goopdate/dm_client_unittest.cc index 68f873a48..26bb241d1 100644 --- a/omaha/goopdate/dm_client_unittest.cc +++ b/omaha/goopdate/dm_client_unittest.cc @@ -15,11 +15,14 @@ #include "omaha/goopdate/dm_client.h" #include +#include +#include #include #include #include "base/basictypes.h" #include "gtest/gtest-matchers.h" +#include "omaha/base/scope_guard.h" #include "omaha/base/string.h" #include "omaha/common/config_manager.h" #include "omaha/goopdate/dm_storage.h" @@ -49,8 +52,9 @@ namespace { class IsValidRequestUrlMatcher : public ::testing::MatcherInterface { public: - IsValidRequestUrlMatcher(const TCHAR* request_type, const TCHAR* device_id) - : request_type_(request_type), device_id_(device_id) { + IsValidRequestUrlMatcher( + std::vector> query_params) + : query_params_(std::move(query_params)) { ConfigManager::Instance()->GetDeviceManagementUrl(&device_management_url_); } @@ -79,56 +83,53 @@ class IsValidRequestUrlMatcher } // Check that the required params are present. - static const TCHAR* kRequiredParams[] = { - _T("agent"), - _T("apptype"), - _T("deviceid"), - _T("platform"), - _T("request"), - }; - for (size_t i = 0; i < arraysize(kRequiredParams); ++i) { - const TCHAR* p = kRequiredParams[i]; + for (const auto& query_param : query_params_) { + const TCHAR* p = query_param.first; if (query_params.find(p) == query_params.end()) { *listener << "the url is missing the \"" << WideToUtf8(p) << "\" query parameter"; return false; } - } - // Check the value of the request param. - if (query_params[_T("request")] != request_type_) { - *listener << "the request query parameter is \"" - << query_params[_T("request")] << "\""; - return false; - } + CString expected_param_value; + HRESULT hr = StringEscape(query_param.second, + false, + &expected_param_value); + if (FAILED(hr)) { + *listener << "failed to StringEscape \"" + << WideToUtf8(query_param.second) + << "\" query parameter"; + return false; + } - // Check the value of the device_id param. - if (query_params[_T("deviceid")] != device_id_) { - *listener << "the device_id query parameter is \"" - << query_params[_T("device_id")] << "\""; - return false; + if (query_params[p] != expected_param_value) { + *listener << "the actual request query parameter is \"" + << WideToUtf8(query_params[p]) << "\"" + << " and does not match the expected query parameter of \"" + << WideToUtf8(expected_param_value) << "\""; + return false; + } } return true; } - virtual void DescribeTo(::std::ostream* os) const { + virtual void DescribeTo(std::ostream* os) const { *os << "string contains a valid device management request URL"; } private: - const TCHAR* const request_type_; - const TCHAR* const device_id_; + const std::vector> query_params_; CString device_management_url_; }; // Returns an IsValidRequestUrl matcher, which takes a CString and matches if -// it is an URL leading to the device management server endpoint, contains all -// required query parameters, and has the given |request_type| and |device_id|. -::testing::Matcher IsValidRequestUrl(const TCHAR* request_type, - const TCHAR* device_id) { +// it is an URL leading to the device management server endpoint, and contains +// all the required query parameters in |query_params|. +::testing::Matcher IsValidRequestUrl( + std::vector> query_params) { return ::testing::MakeMatcher( - new IsValidRequestUrlMatcher(request_type, device_id)); + new IsValidRequestUrlMatcher(std::move(query_params))); } // A Google Mock matcher that returns true if a buffer contains a valid @@ -169,11 +170,59 @@ class IsRegisterBrowserRequestMatcher return true; } - virtual void DescribeTo(::std::ostream* os) const { + virtual void DescribeTo(std::ostream* os) const { *os << "buffer contains a valid serialized RegisterBrowserRequest"; } }; +// A Google Mock matcher that returns true if a buffer contains a valid +// serialized DevicePolicyRequest message. While the presence of each field +// in the request is checked, the exact value of each is not. +class IsFetchPoliciesRequestMatcher + : public ::testing::MatcherInterface& > { + public: + virtual bool MatchAndExplain( + const ::testing::tuple& buffer, + ::testing::MatchResultListener* listener) const { + enterprise_management::DeviceManagementRequest request; + if (!request.ParseFromArray( + ::testing::get<0>(buffer), + static_cast(::testing::get<1>(buffer)))) { + *listener << "parse failure"; + return false; + } + if (!request.has_policy_request()) { + *listener << "missing policy_request"; + return false; + } + if (!request.policy_request().requests_size()) { + *listener << "unexpected requests_size() == 0"; + return false; + } + const enterprise_management::PolicyFetchRequest& policy_request = + request.policy_request().requests(0); + if (!policy_request.has_policy_type()) { + *listener << "missing policy_request.has_policy_type"; + return false; + } + if (!policy_request.has_signature_type()) { + *listener << "missing policy_request.has_signature_type"; + return false; + } + if (!policy_request.has_verification_key_hash()) { + *listener << "missing policy_request.has_verification_key_hash"; + return false; + } + return true; + } + + + virtual void DescribeTo(std::ostream* os) const { + *os << "buffer contains a valid serialized DevicePolicyRequest"; + } +}; + // Returns an IsRegisterBrowserRequest matcher, which takes a tuple of a pointer // to a buffer and a buffer size. ::testing::Matcher& > @@ -181,6 +230,13 @@ IsRegisterBrowserRequest() { return ::testing::MakeMatcher(new IsRegisterBrowserRequestMatcher); } +// Returns an IsFetchPoliciesRequest matcher, which takes a tuple of a pointer +// to a buffer and a buffer size. +::testing::Matcher& > +IsFetchPoliciesRequest() { + return ::testing::MakeMatcher(new IsFetchPoliciesRequestMatcher); +} + class MockHttpRequest : public HttpRequestInterface { public: MOCK_METHOD0(Close, HRESULT()); @@ -220,10 +276,11 @@ class DmClientRequestTest : public ::testing::Test { virtual ~DmClientRequestTest() {} // Populates |request| with a mock HttpRequest that behaves as if the server - // successfully processed a RegisterBrowserRequest, returning a - // DeviceRegisterResponse containing |dm_token|. + // successfully processed a HTTP request, returning a HTTP response containing + // |response_data|. // Note: always wrap calls to this with ASSERT_NO_FATAL_FAILURE. - void MakeSuccessHttpRequest(const char* dm_token, MockHttpRequest** request) { + template + void MakeSuccessHttpRequest(T response_data, MockHttpRequest** request) { *request = new ::testing::NiceMock(); // The server responds with 200. @@ -232,7 +289,7 @@ class DmClientRequestTest : public ::testing::Test { // And a valid response. std::vector response; - ASSERT_NO_FATAL_FAILURE(MakeSuccessResponseBody(dm_token, &response)); + ASSERT_NO_FATAL_FAILURE(MakeSuccessResponseBody(response_data, &response)); ON_CALL(**request, GetResponse()).WillByDefault(Return(response)); } @@ -248,6 +305,26 @@ class DmClientRequestTest : public ::testing::Test { body->assign(response_string.begin(), response_string.end()); } + // Populates |body| with a valid serialized DevicePolicyResponse. + // Note: always wrap calls to this with ASSERT_NO_FATAL_FAILURE. + void MakeSuccessResponseBody(const PolicyResponsesMap& responses, + std::vector* body) { + enterprise_management::DeviceManagementResponse dm_response; + + for (const auto& response : responses) { + enterprise_management::PolicyFetchResponse* policy_response = + dm_response.mutable_policy_response()->add_responses(); + enterprise_management::PolicyData policy_data; + policy_data.set_policy_type(response.first); + policy_data.set_policy_value(response.second); + policy_response->set_policy_data(policy_data.SerializeAsString()); + } + + std::string response_string; + ASSERT_TRUE(dm_response.SerializeToString(&response_string)); + body->assign(response_string.begin(), response_string.end()); + } + DISALLOW_COPY_AND_ASSIGN(DmClientRequestTest); }; @@ -260,10 +337,17 @@ TEST_F(DmClientRequestTest, RegisterWithRequest) { MockHttpRequest* mock_http_request = nullptr; ASSERT_NO_FATAL_FAILURE(MakeSuccessHttpRequest(kDmToken, &mock_http_request)); + std::vector> query_params = { + {_T("request"), _T("register_policy_agent")}, + {_T("agent"), internal::GetAgent()}, + {_T("apptype"), _T("Chrome")}, + {_T("deviceid"), kDeviceId}, + {_T("platform"), internal::GetPlatform()}, + }; + // Expect the proper URL with query params. EXPECT_CALL(*mock_http_request, - set_url(IsValidRequestUrl(_T("register_policy_agent"), - kDeviceId))); + set_url(IsValidRequestUrl(std::move(query_params)))); // Expect that the request headers contain the enrollment token. EXPECT_CALL(*mock_http_request, @@ -285,33 +369,101 @@ TEST_F(DmClientRequestTest, RegisterWithRequest) { EXPECT_STREQ(dm_token.GetString(), kDmToken); } +// Test that DmClient can send a reasonable DevicePolicyRequest and handle a +// corresponding DevicePolicyResponse. +TEST_F(DmClientRequestTest, FetchPolicies) { + static const TCHAR kDeviceId[] = _T("device_id"); + + PolicyResponsesMap expected_responses = { + {"google/chrome/machine-level-user", "test-data-chrome"}, + {"google/drive/machine-level-user", "test-data-drive"}, + {"google/earth/machine-level-user", "test-data-earth"}, + }; + + MockHttpRequest* mock_http_request = nullptr; + ASSERT_NO_FATAL_FAILURE(MakeSuccessHttpRequest(expected_responses, + &mock_http_request)); + + std::vector> query_params = { + {_T("request"), _T("policy")}, + {_T("agent"), internal::GetAgent()}, + {_T("apptype"), _T("Chrome")}, + {_T("deviceid"), kDeviceId}, + {_T("platform"), internal::GetPlatform()}, + }; + + // Expect the proper URL with query params. + EXPECT_CALL(*mock_http_request, + set_url(IsValidRequestUrl(std::move(query_params)))); + + // Expect that the request headers contain the DMToken. + EXPECT_CALL(*mock_http_request, + set_additional_headers( + CStringHasSubstr(_T("Authorization: GoogleDMToken ") + _T("token=dm_token")))); + + // Expect that the body of the request contains a well-formed fetch policies + // request. + EXPECT_CALL(*mock_http_request, set_request_buffer(_, _)) + .With(AllArgs(IsFetchPoliciesRequest())); + + // Fetch Policies should succeed, providing the expected PolicyResponsesMap. + PolicyResponsesMap responses; + ASSERT_HRESULT_SUCCEEDED(internal::FetchPolicies(mock_http_request, + _T("dm_token"), + kDeviceId, + &responses)); + + EXPECT_EQ(expected_responses.size(), responses.size()); + for (const auto& expected_response : expected_responses) { + enterprise_management::PolicyFetchResponse response; + EXPECT_TRUE(response.ParseFromString( + responses[expected_response.first.c_str()])); + + enterprise_management::PolicyData policy_data; + EXPECT_TRUE(policy_data.ParseFromString(response.policy_data())); + EXPECT_TRUE(policy_data.IsInitialized()); + EXPECT_TRUE(policy_data.has_policy_type()); + + EXPECT_STREQ(expected_response.first.c_str(), + policy_data.policy_type().c_str()); + EXPECT_STREQ(expected_response.second.c_str(), + policy_data.policy_value().c_str()); + } +} + class DmClientRegistryTest : public RegistryProtectedTest { }; TEST_F(DmClientRegistryTest, GetRegistrationState) { // No enrollment token. { - DmStorage dm_storage((CString())); - EXPECT_EQ(GetRegistrationState(&dm_storage), kNotManaged); + EXPECT_HRESULT_SUCCEEDED(DmStorage::CreateInstance(CString())); + ON_SCOPE_EXIT(DmStorage::DeleteInstance); + EXPECT_EQ(GetRegistrationState(DmStorage::Instance()), kNotManaged); } // Enrollment token without device management token. { - DmStorage dm_storage(_T("enrollment_token")); - EXPECT_EQ(GetRegistrationState(&dm_storage), kRegistrationPending); + EXPECT_HRESULT_SUCCEEDED(DmStorage::CreateInstance(_T("enrollment_token"))); + ON_SCOPE_EXIT(DmStorage::DeleteInstance); + EXPECT_EQ(GetRegistrationState(DmStorage::Instance()), + kRegistrationPending); } // Enrollment token and device management token. ASSERT_NO_FATAL_FAILURE(WriteCompanyDmToken("dm_token")); { - DmStorage dm_storage(_T("enrollment_token")); - EXPECT_EQ(GetRegistrationState(&dm_storage), kRegistered); + EXPECT_HRESULT_SUCCEEDED(DmStorage::CreateInstance(_T("enrollment_token"))); + ON_SCOPE_EXIT(DmStorage::DeleteInstance); + EXPECT_EQ(GetRegistrationState(DmStorage::Instance()), kRegistered); } // Device management token without enrollment token. { - DmStorage dm_storage((CString())); - EXPECT_EQ(GetRegistrationState(&dm_storage), kRegistered); + EXPECT_HRESULT_SUCCEEDED(DmStorage::CreateInstance(CString())); + ON_SCOPE_EXIT(DmStorage::DeleteInstance); + EXPECT_EQ(GetRegistrationState(DmStorage::Instance()), kRegistered); } } @@ -329,9 +481,11 @@ TEST(DmClientTest, GetOsVersion) { TEST(DmClientTest, AppendQueryParamsToUrl) { static const TCHAR kUrl[] = _T("https://some.net/endpoint"); - std::vector> params; - params.push_back(std::make_pair(_T("one"), _T("1"))); - params.push_back(std::make_pair(_T("2"), _T("two"))); + std::vector> params = { + {_T("one"), _T("1")}, + {_T("2"), _T("two")}, + }; + CString url(kUrl); EXPECT_HRESULT_SUCCEEDED(internal::AppendQueryParamsToUrl(params, &url)); EXPECT_EQ(url, CString(kUrl) + _T("?one=1&2=two")); @@ -343,5 +497,11 @@ TEST(DmClientTest, FormatEnrollmentTokenAuthorizationHeader) { _T("GoogleEnrollmentToken token=token")); } +TEST(DmClientTest, FormatDMTokenAuthorizationHeader) { + static const TCHAR kToken[] = _T("token"); + EXPECT_EQ(internal::FormatDMTokenAuthorizationHeader(kToken), + _T("GoogleDMToken token=token")); +} + } // namespace dm_client } // namespace omaha diff --git a/omaha/goopdate/dm_messages.cc b/omaha/goopdate/dm_messages.cc index 78dab9a42..b322bc486 100644 --- a/omaha/goopdate/dm_messages.cc +++ b/omaha/goopdate/dm_messages.cc @@ -15,8 +15,10 @@ #include "omaha/goopdate/dm_messages.h" #include +#include #include "omaha/base/debug.h" +#include "omaha/base/logging.h" #include "wireless/android/enterprise/devicemanagement/proto/dm_api.pb.h" namespace omaha { @@ -39,7 +41,7 @@ void SerializeToCStringA(const ::google::protobuf_opensource::Message& message, output->ReleaseBufferSetLength(end - buffer); } -} // namespace +} // namespace CStringA SerializeRegisterBrowserRequest(const CStringA& machine_name, const CStringA& os_platform, @@ -57,6 +59,25 @@ CStringA SerializeRegisterBrowserRequest(const CStringA& machine_name, return result; } +CStringA SerializePolicyFetchRequest(const CStringA& policy_type) { + // Request signed policy blobs. kPolicyVerificationKeyHash needs to be kept in + // sync with the corresponding value in Chromium's cloud_policy_constants.cc. + static constexpr char kPolicyVerificationKeyHash[] = "1:356l7w"; + + enterprise_management::DeviceManagementRequest policy_request; + + enterprise_management::PolicyFetchRequest* policy_fetch_request = + policy_request.mutable_policy_request()->add_requests(); + policy_fetch_request->set_policy_type(policy_type); + policy_fetch_request->set_signature_type( + enterprise_management::PolicyFetchRequest::SHA1_RSA); + policy_fetch_request->set_verification_key_hash(kPolicyVerificationKeyHash); + + CStringA result; + SerializeToCStringA(policy_request, &result); + return result; +} + HRESULT ParseDeviceRegisterResponse(const std::vector& response, CStringA* dm_token) { ASSERT1(dm_token); @@ -89,6 +110,61 @@ HRESULT ParseDeviceRegisterResponse(const std::vector& response, return S_OK; } +HRESULT ParseDevicePolicyResponse(const std::vector& dm_response_array, + PolicyResponsesMap* response_map) { + ASSERT1(response_map); + enterprise_management::DeviceManagementResponse dm_response; + + if (dm_response_array.size() > + static_cast(std::numeric_limits::max())) { + return E_FAIL; + } + + if (!dm_response.ParseFromArray(dm_response_array.data(), + static_cast(dm_response_array.size()))) { + return E_FAIL; + } + + if (!dm_response.has_policy_response() || + dm_response.policy_response().responses_size() == 0) { + return E_FAIL; + } + + const enterprise_management::DevicePolicyResponse& policy_response = + dm_response.policy_response(); + PolicyResponsesMap responses; + for (int i = 0; i < policy_response.responses_size(); ++i) { + const enterprise_management::PolicyFetchResponse& response = + policy_response.responses(i); + enterprise_management::PolicyData policy_data; + if (!policy_data.ParseFromString(response.policy_data()) || + !policy_data.IsInitialized() || + !policy_data.has_policy_type()) { + OPT_LOG(LW, (_T("Ignoring invalid PolicyData"))); + continue; + } + + const std::string& type = policy_data.policy_type(); + if (responses.find(type) != responses.end()) { + OPT_LOG(LW, (_T("Duplicate PolicyFetchResponse for type: %S"), + type.c_str())); + continue; + } + + std::string policy_fetch_response; + if (!response.SerializeToString(&policy_fetch_response)) { + OPT_LOG(LW, (_T("Failed to serialize response for type: %S"), + type.c_str())); + continue; + } + + responses[type] = std::move(policy_fetch_response); + } + + *response_map = std::move(responses); + return S_OK; +} + HRESULT ParseDeviceManagementResponseError(const std::vector& response, CStringA* error_message) { ASSERT1(error_message); diff --git a/omaha/goopdate/dm_messages.h b/omaha/goopdate/dm_messages.h index 160f5a220..4f2a27488 100644 --- a/omaha/goopdate/dm_messages.h +++ b/omaha/goopdate/dm_messages.h @@ -16,19 +16,32 @@ #define OMAHA_GOOPDATE_DM_MESSAGES_H__ #include +#include +#include #include #include "base/basictypes.h" namespace omaha { +// Maps policy types to their corresponding serialized PolicyFetchResponses. +using PolicyResponsesMap = std::map; + CStringA SerializeRegisterBrowserRequest(const CStringA& machine_name, const CStringA& os_platform, const CStringA& os_version); +CStringA SerializePolicyFetchRequest(const CStringA& policy_type); + HRESULT ParseDeviceRegisterResponse(const std::vector& response, CStringA* dm_token); +// Parses the policies from the DMServer, and return the PolicyFetchResponses in +// |responses|. |responses| contains elements in the following format: +// {policy_type}=>{SerializeToString-PolicyFetchResponse}. +HRESULT ParseDevicePolicyResponse(const std::vector& dm_response_array, + PolicyResponsesMap* response_map); + HRESULT ParseDeviceManagementResponseError(const std::vector& response, CStringA* error_message); diff --git a/omaha/goopdate/dm_storage.cc b/omaha/goopdate/dm_storage.cc index ee8e3936b..f5a4a7789 100644 --- a/omaha/goopdate/dm_storage.cc +++ b/omaha/goopdate/dm_storage.cc @@ -14,10 +14,17 @@ #include "omaha/goopdate/dm_storage.h" +#include + #include "omaha/base/const_utils.h" #include "omaha/base/debug.h" +#include "omaha/base/file.h" #include "omaha/base/logging.h" +#include "omaha/base/path.h" #include "omaha/base/reg_key.h" +#include "omaha/base/safe_format.h" +#include "omaha/base/string.h" +#include "omaha/base/utils.h" #include "omaha/common/app_registry_utils.h" #include "omaha/common/config_manager.h" #include "omaha/common/const_goopdate.h" @@ -111,12 +118,56 @@ HRESULT StoreDmTokenInKey(const CStringA& dm_token, const TCHAR* path) { return hr; } +HRESULT DeleteObsoletePolicies(const CPath& policy_responses_dir, + const std::set& policy_types_base64) { + std::vector files; + HRESULT hr = FindFiles(policy_responses_dir, _T("*"), &files); + if (FAILED(hr)) { + return hr; + } + + for (const auto& file : files) { + if (file == _T(".") || + file == _T("..") || + policy_types_base64.count(file)) { + continue; + } + + CPath path = policy_responses_dir; + VERIFY1(path.Append(file)); + REPORT_LOG(L1, (_T("[DeleteObsoletePolicies][Deleting][%s]"), path)); + VERIFY1(SUCCEEDED(DeleteBeforeOrAfterReboot(path))); + } + + return S_OK; +} + } // namespace -DmStorage::DmStorage(const CString& runtime_enrollment_token) - : runtime_enrollment_token_(runtime_enrollment_token), - enrollment_token_source_(kETokenSourceNone), - dm_token_source_(kDmTokenSourceNone) { +DmStorage* DmStorage::instance_ = NULL; + +// There should not be any contention on creation because only GoopdateImpl +// should create DmStorage during its initialization. +HRESULT DmStorage::CreateInstance(const CString& enrollment_token) { + ASSERT1(!instance_); + + DmStorage* instance = new DmStorage(enrollment_token); + if (!instance) { + return E_OUTOFMEMORY; + } + + instance_ = instance; + return S_OK; +} + +void DmStorage::DeleteInstance() { + delete instance_; + instance_ = NULL; +} + +DmStorage* DmStorage::Instance() { + ASSERT1(instance_); + return instance_; } CString DmStorage::GetEnrollmentToken() { @@ -155,6 +206,7 @@ CStringA DmStorage::GetDmToken() { HRESULT DmStorage::StoreDmToken(const CStringA& dm_token) { HRESULT hr = StoreDmTokenInKey(dm_token, kRegKeyCompanyEnrollment); if (SUCCEEDED(hr)) { + dm_token_ = dm_token; dm_token_source_ = kDmTokenSourceCompany; #if defined(HAS_LEGACY_DM_CLIENT) hr = StoreDmTokenInKey(dm_token, kRegKeyLegacyEnrollment); @@ -170,6 +222,66 @@ CString DmStorage::GetDeviceId() { return device_id_; } +HRESULT DmStorage::PersistPolicies(const CPath& policy_responses_dir, + const PolicyResponsesMap& responses) { + std::set policy_types_base64; + + for (const auto& response : responses) { + CStringA encoded_policy_response_dirname; + Base64Escape(response.first.c_str(), + static_cast(response.first.length()), + &encoded_policy_response_dirname, + true); + + CString dirname(encoded_policy_response_dirname); + policy_types_base64.emplace(dirname); + CPath policy_response_dir(policy_responses_dir); + policy_response_dir.Append(dirname); + HRESULT hr = CreateDir(policy_response_dir, NULL); + if (FAILED(hr)) { + REPORT_LOG(LW, (_T("[PersistPolicies][Failed to create dir][%s][%#x]"), + policy_response_dir, hr)); + continue; + } + + CPath policy_response_file(policy_response_dir); + policy_response_file.Append(kPolicyResponseFileName); + + File file; + hr = file.Open(policy_response_file, true, false); + if (FAILED(hr)) { + REPORT_LOG(LW, (_T("[PersistPolicies][Failed to open][%s][%#x]"), + policy_response_file, hr)); + continue; + } + + uint32 bytes_written = 0; + hr = file.WriteAt(0, + reinterpret_cast(response.second.c_str()), + static_cast(response.second.length()), + 0, + &bytes_written); + if (FAILED(hr)) { + REPORT_LOG(LW, (_T("[PersistPolicies][Failed to write][%s][%#x]"), + policy_response_file, hr)); + continue; + } + + ASSERT1(bytes_written == response.second.length()); + VERIFY1(SUCCEEDED(file.SetLength(bytes_written, false))); + } + + VERIFY1(SUCCEEDED(DeleteObsoletePolicies(policy_responses_dir, + policy_types_base64))); + return S_OK; +} + +DmStorage::DmStorage(const CString& runtime_enrollment_token) + : runtime_enrollment_token_(runtime_enrollment_token), + enrollment_token_source_(kETokenSourceNone), + dm_token_source_(kDmTokenSourceNone) { +} + void DmStorage::LoadEnrollmentTokenFromStorage() { // Load from most to least preferred, stopping when one is found. enrollment_token_ = LoadEnrollmentTokenFromCompanyPolicy(); diff --git a/omaha/goopdate/dm_storage.h b/omaha/goopdate/dm_storage.h index be484b897..97f7910d3 100644 --- a/omaha/goopdate/dm_storage.h +++ b/omaha/goopdate/dm_storage.h @@ -16,13 +16,19 @@ #define OMAHA_GOOPDATE_DM_STORAGE_H__ #include +#include #include #include "base/basictypes.h" #include "omaha/base/constants.h" +#include "omaha/goopdate/dm_messages.h" namespace omaha { +// This is the standard name for the file that PersistPolicies() uses for each +// {policy_type} that it receives from the DMServer. +const TCHAR kPolicyResponseFileName[] = _T("PolicyFetchResponse"); + // A handler for storage related to cloud-based device management of Omaha. This // class provides access to an enrollment token, a device management token, and // a device identifier. @@ -41,9 +47,10 @@ class DmStorage { kETokenSourceInstall, }; - // Constructs an instance with a runtime-provided enrollment token (e.g., one - // obtained via the etoken extra arg). - explicit DmStorage(const CString& runtime_enrollment_token); + static HRESULT CreateInstance(const CString& enrollment_token); + static void DeleteInstance(); + + static DmStorage* Instance(); // Returns the current enrollment token, reading from sources as-needed to // find one. Returns an empty string if no enrollment token is found. @@ -51,7 +58,7 @@ class DmStorage { // Returns the origin of the current enrollment token, or kETokenSourceNone if // none has been found. - EnrollmentTokenSource enrollment_token_source () const { + EnrollmentTokenSource enrollment_token_source() const { return enrollment_token_source_; } @@ -71,7 +78,30 @@ class DmStorage { // Returns the device identifier, or an empty string in case of error. CString GetDeviceId(); + // Persists each PolicyFetchResponse in |responses| into a subdirectory within + // |policy_responses_dir|. Each PolicyFetchResponse is stored within a + // subdirectory named {Base64Encoded{policy_type}}, with a fixed file name of + // "PolicyFetchResponse", where the file contents are + // {SerializeToString-PolicyFetchResponse}}. + // + // Each PolicyFetchResponse file is opened in exclusive mode. If we are unable + // to open or write to this file, the caller is expected to try again later. + // For instance, if UA is calling us, UA will retry at the next UA interval. + // + // Client applications could use ::FindFirstChangeNotificationW on the + // subdirectory corresponding to their respective policy_type to watch for + // changes. They can then read and apply the policies within this file. + // To minimize the number of notifications for existing PolicyFetchResponse + // files, the files are first modified in-place if the response includes them, + // and then the files that do not have a corresponding response are deleted. + static HRESULT PersistPolicies(const CPath& policy_responses_dir, + const PolicyResponsesMap& responses); + private: + // Constructs an instance with a runtime-provided enrollment token (e.g., one + // obtained via the etoken extra arg). + explicit DmStorage(const CString& runtime_enrollment_token); + // The possible sources of a device management token, sorted by decreasing // precedence. enum DmTokenSource { @@ -104,6 +134,10 @@ class DmStorage { // The origin of the current device management token. DmTokenSource dm_token_source_; + static DmStorage* instance_; + + friend class DmStorageTest; + DISALLOW_COPY_AND_ASSIGN(DmStorage); }; diff --git a/omaha/goopdate/dm_storage_unittest.cc b/omaha/goopdate/dm_storage_unittest.cc index f172a7f3d..22aec06e6 100644 --- a/omaha/goopdate/dm_storage_unittest.cc +++ b/omaha/goopdate/dm_storage_unittest.cc @@ -14,6 +14,12 @@ #include "omaha/goopdate/dm_storage.h" +#include "omaha/base/app_util.h" +#include "omaha/base/file.h" +#include "omaha/base/path.h" +#include "omaha/base/scope_guard.h" +#include "omaha/base/string.h" +#include "omaha/base/utils.h" #include "omaha/goopdate/dm_storage_test_utils.h" #include "omaha/testing/unit_test.h" @@ -31,6 +37,41 @@ class DmStorageTest : public RegistryProtectedTest { static const TCHAR kETOldLegacyPolicy[]; static const char kDmTLegacy[]; #endif // defined(HAS_LEGACY_DM_CLIENT) + + DmStorage* NewDmStorage(const CString& enrollment_token) { + return new DmStorage(enrollment_token); + } + + CPath GetPolicyResponseFilePath(const CPath& policy_responses_dir, + const std::string& policy_type) { + CStringA encoded_policy_response_dirname; + Base64Escape(policy_type.c_str(), + static_cast(policy_type.length()), + &encoded_policy_response_dirname, + true); + + CPath policy_response_file(policy_responses_dir); + policy_response_file.Append(CString(encoded_policy_response_dirname)); + policy_response_file.Append(kPolicyResponseFileName); + return policy_response_file; + } + + void VerifyPolicies(const CPath& policy_responses_dir, + const PolicyResponsesMap& expected_responses) { + for (const auto& expected_response : expected_responses) { + CPath policy_response_file = GetPolicyResponseFilePath( + policy_responses_dir, expected_response.first); + + std::vector raw_policy_response; + ASSERT_HRESULT_SUCCEEDED(ReadEntireFileShareMode( + policy_response_file, 0, FILE_SHARE_READ, &raw_policy_response)); + const std::string policy_response( + reinterpret_cast(&raw_policy_response[0]), + raw_policy_response.size()); + + ASSERT_STREQ(expected_response.second.c_str(), policy_response.c_str()); + } + } }; const TCHAR DmStorageTest::kETRuntime[] = _T("runtime"); @@ -45,32 +86,33 @@ const char DmStorageTest::kDmTLegacy[] = "legacy"; // Test that empty strings are returned when the registry holds nothing. TEST_F(DmStorageTest, NoEnrollmentToken) { - DmStorage dm_storage((CString())); - EXPECT_EQ(dm_storage.GetEnrollmentToken(), CString()); - EXPECT_EQ(dm_storage.enrollment_token_source(), DmStorage::kETokenSourceNone); + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_EQ(dm_storage->GetEnrollmentToken(), CString()); + EXPECT_EQ(dm_storage->enrollment_token_source(), + DmStorage::kETokenSourceNone); } // Test the individual sources. TEST_F(DmStorageTest, EnrollmentTokenFromRuntime) { - DmStorage dm_storage(kETRuntime); - EXPECT_STREQ(dm_storage.GetEnrollmentToken(), kETRuntime); - EXPECT_EQ(dm_storage.enrollment_token_source(), + std::unique_ptr dm_storage(NewDmStorage(kETRuntime)); + EXPECT_STREQ(dm_storage->GetEnrollmentToken(), kETRuntime); + EXPECT_EQ(dm_storage->enrollment_token_source(), DmStorage::kETokenSourceRuntime); } TEST_F(DmStorageTest, EnrollmentTokenFromInstall) { ASSERT_NO_FATAL_FAILURE(WriteInstallToken(kETInstall)); - DmStorage dm_storage((CString())); - EXPECT_STREQ(dm_storage.GetEnrollmentToken(), kETInstall); - EXPECT_EQ(dm_storage.enrollment_token_source(), + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_STREQ(dm_storage->GetEnrollmentToken(), kETInstall); + EXPECT_EQ(dm_storage->enrollment_token_source(), DmStorage::kETokenSourceInstall); } TEST_F(DmStorageTest, EnrollmentTokenFromCompanyPolicy) { ASSERT_NO_FATAL_FAILURE(WriteCompanyPolicyToken(kETCompanyPolicy)); - DmStorage dm_storage((CString())); - EXPECT_STREQ(dm_storage.GetEnrollmentToken(), kETCompanyPolicy); - EXPECT_EQ(dm_storage.enrollment_token_source(), + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_STREQ(dm_storage->GetEnrollmentToken(), kETCompanyPolicy); + EXPECT_EQ(dm_storage->enrollment_token_source(), DmStorage::kETokenSourceCompanyPolicy); } @@ -78,17 +120,17 @@ TEST_F(DmStorageTest, EnrollmentTokenFromCompanyPolicy) { TEST_F(DmStorageTest, EnrollmentTokenFromLegacyPolicy) { ASSERT_NO_FATAL_FAILURE(WriteLegacyPolicyToken(kETLegacyPolicy)); - DmStorage dm_storage((CString())); - EXPECT_STREQ(dm_storage.GetEnrollmentToken(), kETLegacyPolicy); - EXPECT_EQ(dm_storage.enrollment_token_source(), + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_STREQ(dm_storage->GetEnrollmentToken(), kETLegacyPolicy); + EXPECT_EQ(dm_storage->enrollment_token_source(), DmStorage::kETokenSourceLegacyPolicy); } TEST_F(DmStorageTest, EnrollmentTokenFromOldLegacyPolicy) { ASSERT_NO_FATAL_FAILURE(WriteOldLegacyPolicyToken(kETOldLegacyPolicy)); - DmStorage dm_storage((CString())); - EXPECT_STREQ(dm_storage.GetEnrollmentToken(), kETOldLegacyPolicy); - EXPECT_EQ(dm_storage.enrollment_token_source(), + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_STREQ(dm_storage->GetEnrollmentToken(), kETOldLegacyPolicy); + EXPECT_EQ(dm_storage->enrollment_token_source(), DmStorage::kETokenSourceOldLegacyPolicy); } @@ -98,85 +140,85 @@ TEST_F(DmStorageTest, EnrollmentTokenPrecedence) { // Add the sources from lowest to highest priority. ASSERT_NO_FATAL_FAILURE(WriteInstallToken(kETInstall)); { - DmStorage dm_storage((CString())); - EXPECT_STREQ(dm_storage.GetEnrollmentToken(), kETInstall); - EXPECT_EQ(dm_storage.enrollment_token_source(), + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_STREQ(dm_storage->GetEnrollmentToken(), kETInstall); + EXPECT_EQ(dm_storage->enrollment_token_source(), DmStorage::kETokenSourceInstall); } { - DmStorage dm_storage(kETRuntime); - EXPECT_STREQ(dm_storage.GetEnrollmentToken(), kETRuntime); - EXPECT_EQ(dm_storage.enrollment_token_source(), + std::unique_ptr dm_storage(NewDmStorage(kETRuntime)); + EXPECT_STREQ(dm_storage->GetEnrollmentToken(), kETRuntime); + EXPECT_EQ(dm_storage->enrollment_token_source(), DmStorage::kETokenSourceRuntime); } #if defined(HAS_LEGACY_DM_CLIENT) ASSERT_NO_FATAL_FAILURE(WriteOldLegacyPolicyToken(kETOldLegacyPolicy)); { - DmStorage dm_storage(kETRuntime); - EXPECT_STREQ(dm_storage.GetEnrollmentToken(), kETOldLegacyPolicy); - EXPECT_EQ(dm_storage.enrollment_token_source(), + std::unique_ptr dm_storage(NewDmStorage(kETRuntime)); + EXPECT_STREQ(dm_storage->GetEnrollmentToken(), kETOldLegacyPolicy); + EXPECT_EQ(dm_storage->enrollment_token_source(), DmStorage::kETokenSourceOldLegacyPolicy); } ASSERT_NO_FATAL_FAILURE(WriteLegacyPolicyToken(kETLegacyPolicy)); { - DmStorage dm_storage(kETRuntime); - EXPECT_STREQ(dm_storage.GetEnrollmentToken(), kETLegacyPolicy); - EXPECT_EQ(dm_storage.enrollment_token_source(), + std::unique_ptr dm_storage(NewDmStorage(kETRuntime)); + EXPECT_STREQ(dm_storage->GetEnrollmentToken(), kETLegacyPolicy); + EXPECT_EQ(dm_storage->enrollment_token_source(), DmStorage::kETokenSourceLegacyPolicy); } #endif // defined(HAS_LEGACY_DM_CLIENT) ASSERT_NO_FATAL_FAILURE(WriteCompanyPolicyToken(kETCompanyPolicy)); { - DmStorage dm_storage(kETRuntime); - EXPECT_STREQ(dm_storage.GetEnrollmentToken(), kETCompanyPolicy); - EXPECT_EQ(dm_storage.enrollment_token_source(), + std::unique_ptr dm_storage(NewDmStorage(kETRuntime)); + EXPECT_STREQ(dm_storage->GetEnrollmentToken(), kETCompanyPolicy); + EXPECT_EQ(dm_storage->enrollment_token_source(), DmStorage::kETokenSourceCompanyPolicy); } } TEST_F(DmStorageTest, RuntimeEnrollmentTokenForInstall) { { - DmStorage dm_storage(kETRuntime); - EXPECT_STREQ(dm_storage.GetEnrollmentToken(), kETRuntime); - EXPECT_EQ(dm_storage.StoreRuntimeEnrollmentTokenForInstall(), S_OK); + std::unique_ptr dm_storage(NewDmStorage(kETRuntime)); + EXPECT_STREQ(dm_storage->GetEnrollmentToken(), kETRuntime); + EXPECT_EQ(dm_storage->StoreRuntimeEnrollmentTokenForInstall(), S_OK); } - DmStorage dm_storage((CString())); - EXPECT_STREQ(dm_storage.GetEnrollmentToken(), kETRuntime); - EXPECT_EQ(dm_storage.enrollment_token_source(), + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_STREQ(dm_storage->GetEnrollmentToken(), kETRuntime); + EXPECT_EQ(dm_storage->enrollment_token_source(), DmStorage::kETokenSourceInstall); } TEST_F(DmStorageTest, PolicyEnrollmentTokenForInstall) { { ASSERT_NO_FATAL_FAILURE(WriteCompanyPolicyToken(kETCompanyPolicy)); - DmStorage dm_storage((CString())); - EXPECT_EQ(dm_storage.StoreRuntimeEnrollmentTokenForInstall(), S_FALSE); + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_EQ(dm_storage->StoreRuntimeEnrollmentTokenForInstall(), S_FALSE); } - DmStorage dm_storage((CString())); - EXPECT_STREQ(dm_storage.GetEnrollmentToken(), kETCompanyPolicy); - EXPECT_EQ(dm_storage.enrollment_token_source(), + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_STREQ(dm_storage->GetEnrollmentToken(), kETCompanyPolicy); + EXPECT_EQ(dm_storage->enrollment_token_source(), DmStorage::kETokenSourceCompanyPolicy); } TEST_F(DmStorageTest, NoDmToken) { - DmStorage dm_storage((CString())); - EXPECT_EQ(dm_storage.GetDmToken(), CStringA()); + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_EQ(dm_storage->GetDmToken(), CStringA()); } TEST_F(DmStorageTest, DmTokenFromCompany) { ASSERT_NO_FATAL_FAILURE(WriteCompanyDmToken(kDmTCompany)); - DmStorage dm_storage((CString())); - EXPECT_EQ(dm_storage.GetDmToken(), kDmTCompany); + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_EQ(dm_storage->GetDmToken(), kDmTCompany); } #if defined(HAS_LEGACY_DM_CLIENT) TEST_F(DmStorageTest, DmTokenFromLegacy) { ASSERT_NO_FATAL_FAILURE(WriteLegacyDmToken(kDmTLegacy)); - DmStorage dm_storage((CString())); - EXPECT_EQ(dm_storage.GetDmToken(), kDmTLegacy); + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_EQ(dm_storage->GetDmToken(), kDmTLegacy); } #endif // defined(HAS_LEGACY_DM_CLIENT) @@ -186,21 +228,54 @@ TEST_F(DmStorageTest, DmTokenPrecedence) { #if defined(HAS_LEGACY_DM_CLIENT) ASSERT_NO_FATAL_FAILURE(WriteLegacyDmToken(kDmTLegacy)); { - DmStorage dm_storage((CString())); - EXPECT_EQ(dm_storage.GetDmToken(), kDmTLegacy); + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_EQ(dm_storage->GetDmToken(), kDmTLegacy); } #endif // defined(HAS_LEGACY_DM_CLIENT) ASSERT_NO_FATAL_FAILURE(WriteCompanyDmToken(kDmTCompany)); - DmStorage dm_storage((CString())); - EXPECT_EQ(dm_storage.GetDmToken(), kDmTCompany); + std::unique_ptr dm_storage(NewDmStorage((CString()))); + EXPECT_EQ(dm_storage->GetDmToken(), kDmTCompany); +} + +TEST_F(DmStorageTest, PersistPolicies) { + PolicyResponsesMap old_responses = { + {"google/chrome/machine-level-user", "test-data-chrome"}, + {"google/drive/machine-level-user", "test-data-drive"}, + {"google/earth/machine-level-user", "test-data-earth"}, + }; + + const CPath policy_responses_dir = CPath(ConcatenatePath( + app_util::GetCurrentModuleDirectory(), + _T("Policies"))); + + ASSERT_HRESULT_SUCCEEDED(DmStorage::PersistPolicies(policy_responses_dir, + old_responses)); + VerifyPolicies(policy_responses_dir, old_responses); + + PolicyResponsesMap new_responses = { + {"google/chrome/machine-level-user", "test-data-chr"}, // Shorter data. + // {"google/drive/machine-level-user", "test-data-drive"}, // Obsolete. + {"google/earth/machine-level-user", + "test-data-earth-foo-bar-baz-foo-bar-baz-foo-bar-baz"}, // Longer data. + {"google/newdrive/machine-level-user", "test-data-newdrive"}, // New. + }; + + ASSERT_HRESULT_SUCCEEDED(DmStorage::PersistPolicies(policy_responses_dir, + new_responses)); + VerifyPolicies(policy_responses_dir, new_responses); + EXPECT_FALSE(GetPolicyResponseFilePath( + policy_responses_dir, "google/drive/machine-level-user").FileExists()); + + EXPECT_HRESULT_SUCCEEDED(DeleteDirectory(policy_responses_dir)); } // This test must access the true registry, so it doesn't use the DmStorageTest // fixture. TEST(DmStorageDeviceIdTest, GetDeviceId) { - DmStorage dm_storage((CString())); - EXPECT_FALSE(dm_storage.GetDeviceId().IsEmpty()); + EXPECT_HRESULT_SUCCEEDED(DmStorage::CreateInstance(CString())); + ON_SCOPE_EXIT(DmStorage::DeleteInstance); + EXPECT_FALSE(DmStorage::Instance()->GetDeviceId().IsEmpty()); } } // namespace omaha diff --git a/omaha/goopdate/goopdate.cc b/omaha/goopdate/goopdate.cc index bb2bbcfcd..9f113ccf8 100644 --- a/omaha/goopdate/goopdate.cc +++ b/omaha/goopdate/goopdate.cc @@ -313,8 +313,6 @@ class GoopdateImpl { // a failure HRESULT if registration was mandatory and failed. HRESULT RegisterForDeviceManagement(); - DmStorage* GetDmStorage(); - #endif // defined(HAS_DEVICE_MANAGEMENT) // Called by operator new or operator new[] when they cannot satisfy @@ -342,10 +340,6 @@ class GoopdateImpl { std::unique_ptr exception_handler_; std::unique_ptr thread_pool_; -#if defined(HAS_DEVICE_MANAGEMENT) - std::unique_ptr dm_storage_; -#endif - Goopdate* goopdate_; DISALLOW_COPY_AND_ASSIGN(GoopdateImpl); @@ -406,6 +400,10 @@ GoopdateImpl::~GoopdateImpl() { Stop(); +#if defined(HAS_DEVICE_MANAGEMENT) + DmStorage::DeleteInstance(); +#endif + // Bug 994348 does not repro anymore. // If the assert fires, clean up the key, and fix the code if we have unit // tests or application code that create the key. @@ -761,6 +759,12 @@ HRESULT GoopdateImpl::ExecuteMode(bool* has_ui_been_displayed) { VERIFY1(SUCCEEDED(SetBackgroundPriorityIfNeeded(mode))); +#if defined(HAS_DEVICE_MANAGEMENT) + // Reference the DmStorage instance here so the singleton can be created + // before use. + VERIFY1(SUCCEEDED(DmStorage::CreateInstance(args_.extra.enrollment_token))); +#endif + #pragma warning(push) // C4061: enumerator 'xxx' in switch of enum 'yyy' is not explicitly handled by // a case label. @@ -1180,6 +1184,17 @@ HRESULT GoopdateImpl::DoInstall(bool* has_ui_been_displayed) { if (FAILED(hr)) { return hr; // Mandatory registration failed. } + + // TODO(ganesh): It is desirable to separate the execution paths of + // installs/updates and policy fetch. Once we have the Firebase Messaging + // feature solidified, we can move the policy fetch logic over there. + hr = dm_client::RefreshPolicies(); + if (FAILED(hr)) { + OPT_LOG(LE, (_T("[RefreshPolicies failed][%#x]"), hr)); + LogErrorWithHResult(kRefreshPoliciesFailedEventId, + _T("Device management policy refresh failed"), + hr); + } } #endif // defined(HAS_DEVICE_MANAGEMENT) @@ -1320,8 +1335,11 @@ HRESULT GoopdateImpl::DoUpdateAllApps(bool* has_ui_been_displayed ) { // - Non-mandatory registration failed during installation. // - An enrollment token was provisioned to the machine via Group Policy after // installation. + // TODO(ganesh): It is desirable to separate the execution paths of + // installs/updates and policy fetch. Once we have the Firebase Messaging + // feature solidified, we can move the policy fetch logic over there. if (is_machine_) { - hr = dm_client::RegisterIfNeeded(GetDmStorage()); + hr = dm_client::RegisterIfNeeded(DmStorage::Instance()); if (FAILED(hr)) { OPT_LOG(LE, (_T("[Registration failed][%#x]"), hr)); // Emit to the Event Log. The entry will include details by way of @@ -1329,6 +1347,14 @@ HRESULT GoopdateImpl::DoUpdateAllApps(bool* has_ui_been_displayed ) { LogErrorWithHResult(kEnrollmentFailedEventId, _T("Device management enrollment failed"), hr); + } else { + hr = dm_client::RefreshPolicies(); + if (FAILED(hr)) { + OPT_LOG(LE, (_T("[RefreshPolicies failed][%#x]"), hr)); + LogErrorWithHResult(kRefreshPoliciesFailedEventId, + _T("Device management policy refresh failed"), + hr); + } } } #endif // defined(HAS_DEVICE_MANAGEMENT) @@ -1621,7 +1647,7 @@ HRESULT GoopdateImpl::RegisterForDeviceManagement() { return S_FALSE; } - DmStorage* const dm_storage = GetDmStorage(); + DmStorage* const dm_storage = DmStorage::Instance(); const bool is_enrollment_mandatory = ConfigManager::Instance()->IsCloudManagementEnrollmentMandatory(); @@ -1668,14 +1694,6 @@ HRESULT GoopdateImpl::RegisterForDeviceManagement() { return is_enrollment_mandatory ? hr : S_FALSE; } -DmStorage* GoopdateImpl::GetDmStorage() { - ASSERT1(is_machine_); - if (!dm_storage_.get()) { - dm_storage_.reset(new DmStorage(args_.extra.enrollment_token)); - } - return dm_storage_.get(); -} - #endif // defined(HAS_DEVICE_MANAGEMENT) void GoopdateImpl::OutOfMemoryHandler() { diff --git a/omaha/setup/setup_google_update.cc b/omaha/setup/setup_google_update.cc index 34657367b..409b52d19 100644 --- a/omaha/setup/setup_google_update.cc +++ b/omaha/setup/setup_google_update.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include "base/basictypes.h" #include "omaha/base/app_util.h" @@ -694,12 +695,14 @@ HRESULT SetupGoogleUpdate::UninstallPreviousVersions() { return HRESULT_FROM_WIN32(err); } - // The download and install directories are left alone here. They are cleaned - // up by DownloadManager::Initialize() and InstallManager::InstallManager(). - CPath download_dir(OMAHA_REL_DOWNLOAD_STORAGE_DIR); - download_dir.StripPath(); - CPath install_dir(OMAHA_REL_INSTALL_WORKING_DIR); - install_dir.StripPath(); + // The following subdirectories are not deleted here. + std::set excluded_directories = { + _T(".."), + _T("."), + this_version_, + DOWNLOAD_DIR_NAME, // Managed by DownloadManager::Initialize(). + INSTALL_WORKING_DIR_NAME, // Managed by InstallManager::InstallManager(). + }; BOOL found_next = TRUE; for (; found_next; found_next = ::FindNextFile(get(find_handle), @@ -711,11 +714,7 @@ HRESULT SetupGoogleUpdate::UninstallPreviousVersions() { if (_tcsicmp(file_data.cFileName, kOmahaShellFileName)) { DeleteBeforeOrAfterReboot(file_or_directory); } - } else if (_tcscmp(file_data.cFileName, _T("..")) && - _tcscmp(file_data.cFileName, _T(".")) && - _tcsicmp(file_data.cFileName, this_version_) && - _tcsicmp(file_data.cFileName, download_dir) && - _tcsicmp(file_data.cFileName, install_dir)) { + } else if (!excluded_directories.count(file_data.cFileName)) { // Unregister the previous version OneClick if it exists. Ignore // failures. The file is named npGoogleOneClick*.dll. CPath old_oneclick(file_or_directory);