Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable GCS filesystem for Windows #14856

Merged
merged 1 commit into from
Dec 9, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions tensorflow/core/platform/cloud/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_copts",
)

filegroup(
Expand All @@ -29,6 +30,7 @@ filegroup(
cc_library(
name = "expiring_lru_cache",
hdrs = ["expiring_lru_cache.h"],
copts = tf_copts(),
visibility = ["//tensorflow:__subpackages__"],
deps = ["//tensorflow/core:lib"],
)
Expand All @@ -37,6 +39,7 @@ cc_library(
name = "file_block_cache",
srcs = ["file_block_cache.cc"],
hdrs = ["file_block_cache.h"],
copts = tf_copts(),
visibility = ["//tensorflow:__subpackages__"],
deps = ["//tensorflow/core:lib"],
)
Expand All @@ -45,6 +48,7 @@ cc_library(
name = "gcs_dns_cache",
srcs = ["gcs_dns_cache.cc"],
hdrs = ["gcs_dns_cache.h"],
copts = tf_copts(),
visibility = ["//tensorflow:__subpackages__"],
deps = [
":http_request",
Expand All @@ -56,6 +60,7 @@ cc_library(
name = "gcs_file_system",
srcs = ["gcs_file_system.cc"],
hdrs = ["gcs_file_system.h"],
copts = tf_copts(),
linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669
visibility = ["//visibility:public"],
deps = [
Expand All @@ -78,6 +83,7 @@ cc_library(
cc_library(
name = "http_request",
hdrs = ["http_request.h"],
copts = tf_copts(),
visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/core:framework_headers_lib",
Expand All @@ -89,6 +95,7 @@ cc_library(
name = "curl_http_request",
srcs = ["curl_http_request.cc"],
hdrs = ["curl_http_request.h"],
copts = tf_copts(),
visibility = ["//tensorflow:__subpackages__"],
deps = [
":http_request",
Expand All @@ -104,6 +111,7 @@ cc_library(
hdrs = [
"http_request_fake.h",
],
copts = tf_copts(),
visibility = ["//tensorflow:__subpackages__"],
deps = [
":curl_http_request",
Expand All @@ -121,6 +129,7 @@ cc_library(
"auth_provider.h",
"google_auth_provider.h",
],
copts = tf_copts(),
visibility = ["//tensorflow:__subpackages__"],
deps = [
":curl_http_request",
Expand All @@ -136,6 +145,7 @@ cc_library(
name = "now_seconds_env",
testonly = 1,
hdrs = ["now_seconds_env.h"],
copts = tf_copts(),
visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/core:lib",
Expand All @@ -151,6 +161,7 @@ cc_library(
hdrs = [
"oauth_client.h",
],
copts = tf_copts(),
deps = [
":curl_http_request",
":http_request",
Expand All @@ -169,6 +180,7 @@ cc_library(
hdrs = [
"retrying_utils.h",
],
copts = tf_copts(),
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib_internal",
Expand All @@ -183,6 +195,7 @@ cc_library(
hdrs = [
"retrying_file_system.h",
],
copts = tf_copts(),
deps = [
":retrying_utils",
"//tensorflow/core:framework_headers_lib",
Expand All @@ -198,6 +211,7 @@ cc_library(
hdrs = [
"time_util.h",
],
copts = tf_copts(),
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib_internal",
Expand Down
31 changes: 22 additions & 9 deletions tensorflow/core/platform/cloud/gcs_dns_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/cloud/gcs_dns_cache.h"

#ifndef _WIN32
#include <arpa/inet.h>
#include <netdb.h>
#else
#include <winsock2.h>
#include <ws2tcpip.h>
#include <Windows.h>
#endif
#include <sys/types.h>

namespace tensorflow {
Expand All @@ -26,6 +31,20 @@ namespace {
constexpr char kStorageHost[] = "storage.googleapis.com";
constexpr char kWwwHost[] = "www.googleapis.com";

inline void print_getaddrinfo_error(const string& name, int error_code) {
#ifndef _WIN32
if (error_code == EAI_SYSTEM) {
LOG(ERROR) << "Error resolving " << name
<< " (EAI_SYSTEM): " << strerror(errno);
} else {
LOG(ERROR) << "Error resolving " << name << ": "
<< gai_strerror(error_code);
}
#else
// TODO:WSAGetLastError is better than gai_strerror
LOG(ERROR) << "Error resolving " << name << ": " << gai_strerror(error_code);
#endif
}
} // namespace

GcsDnsCache::GcsDnsCache(Env* env, int64 refresh_rate_secs)
Expand Down Expand Up @@ -77,7 +96,7 @@ Status GcsDnsCache::AnnotateRequest(HttpRequest* request) {

std::vector<string> output;
if (return_code == 0) {
for (addrinfo* i = result; i != nullptr; i = i->ai_next) {
for (const addrinfo* i = result; i != nullptr; i = i->ai_next) {
if (i->ai_family != AF_INET || i->ai_addr->sa_family != AF_INET) {
LOG(WARNING) << "Non-IPv4 address returned. ai_family: " << i->ai_family
<< ". sa_family: " << i->ai_addr->sa_family << ".";
Expand All @@ -96,13 +115,7 @@ Status GcsDnsCache::AnnotateRequest(HttpRequest* request) {
}
}
} else {
if (return_code == EAI_SYSTEM) {
LOG(ERROR) << "Error resolving " << name
<< " (EAI_SYSTEM): " << strerror(errno);
} else {
LOG(ERROR) << "Error resolving " << name << ": "
<< gai_strerror(return_code);
}
print_getaddrinfo_error(name, return_code);
}
if (result != nullptr) {
freeaddrinfo(result);
Expand Down
21 changes: 20 additions & 1 deletion tensorflow/core/platform/cloud/gcs_file_system.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ limitations under the License.
#include <cstring>
#include <fstream>
#include <vector>
#ifdef _WIN32
#include <io.h> //for _mktemp
#endif
#include "include/json/json.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
Expand All @@ -40,6 +43,12 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/thread_annotations.h"

#ifdef _WIN32
#ifdef DeleteFile
#undef DeleteFile
#endif
#endif

namespace tensorflow {

namespace {
Expand Down Expand Up @@ -95,16 +104,25 @@ const FileStatistics DIRECTORY_STAT(0, 0, true);
// userspace DNS cache.
constexpr char kResolveCacheSecs[] = "GCS_RESOLVE_REFRESH_SECS";

// TODO: DO NOT use a hardcoded path
Status GetTmpFilename(string* filename) {
if (!filename) {
return errors::Internal("'filename' cannot be nullptr.");
}
#ifndef _WIN32
char buffer[] = "/tmp/gcs_filesystem_XXXXXX";
int fd = mkstemp(buffer);
if (fd < 0) {
return errors::Internal("Failed to create a temporary file.");
}
close(fd);
#else
char buffer[] = "/tmp/gcs_filesystem_XXXXXX";
char* ret = _mktemp(buffer);
if (ret == nullptr) {
return errors::Internal("Failed to create a temporary file.");
}
#endif
*filename = buffer;
return Status::OK();
}
Expand Down Expand Up @@ -292,6 +310,7 @@ class GcsWritableFile : public WritableFile {
file_cache_erase_(std::move(file_cache_erase)),
sync_needed_(true),
initial_retry_delay_usec_(initial_retry_delay_usec) {
// TODO: to make it safer, outfile_ should be constructed from an FD
if (GetTmpFilename(&tmp_content_filename_).ok()) {
outfile_.open(tmp_content_filename_,
std::ofstream::binary | std::ofstream::app);
Expand Down Expand Up @@ -416,7 +435,7 @@ class GcsWritableFile : public WritableFile {
return errors::Internal("'size' cannot be nullptr");
}
const auto tellp = outfile_.tellp();
if (tellp == -1) {
if (tellp == static_cast<std::streampos>(-1)) {
return errors::Internal(
"Could not get the size of the internal temporary file.");
}
Expand Down
5 changes: 4 additions & 1 deletion tensorflow/core/platform/cloud/google_auth_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/cloud/google_auth_provider.h"
#ifndef _WIN32
#include <pwd.h>
#include <sys/types.h>
#include <unistd.h>
#else
#include <sys/types.h>
#endif
#include <fstream>
#include "include/json/json.h"
#include "tensorflow/core/lib/core/errors.h"
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/core/platform/cloud/oauth_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/cloud/oauth_client.h"
#ifndef _WIN32
#include <pwd.h>
#include <sys/types.h>
#include <unistd.h>
#else
#include <sys/types.h>
#endif
#include <fstream>
#include <openssl/bio.h>
#include <openssl/evp.h>
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/core/platform/cloud/time_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ limitations under the License.
#include <cmath>
#include <cstdio>
#include <ctime>
#ifdef _WIN32
#define timegm _mkgmtime
#endif
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
Expand Down
1 change: 0 additions & 1 deletion tensorflow/core/platform/default/build_config.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,6 @@ def tf_additional_lib_deps():

def tf_additional_core_deps():
return select({
"//tensorflow:with_gcp_support_windows_override": [],
"//tensorflow:with_gcp_support_android_override": [],
"//tensorflow:with_gcp_support_ios_override": [],
"//tensorflow:with_gcp_support": [
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ function run_configure_for_cpu_build {
export TF_NEED_MKL=0
fi
export TF_NEED_VERBS=0
export TF_NEED_GCP=0
export TF_NEED_GCP=1
export TF_NEED_HDFS=0
export TF_NEED_OPENCL_SYCL=0
echo "" | ./configure
Expand Down
26 changes: 23 additions & 3 deletions third_party/curl.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ CURL_WIN_COPTS = [
"/DHAVE_CONFIG_H",
"/DCURL_DISABLE_FTP",
"/DCURL_DISABLE_NTLM",
"/DCURL_DISABLE_PROXY",
"/DHAVE_LIBZ",
"/DHAVE_ZLIB_H",
# Defining _USING_V110_SDK71_ is hackery to defeat curl's incorrect
Expand All @@ -23,6 +24,8 @@ CURL_WIN_SRCS = [
"lib/asyn-thread.c",
"lib/inet_ntop.c",
"lib/system_win32.c",
"lib/vtls/schannel.c",
"lib/idn_win32.c",
]

cc_library(
Expand Down Expand Up @@ -276,6 +279,7 @@ cc_library(
"-DCURL_MAX_WRITE_SIZE=65536",
],
}),
defines = ["CURL_STATICLIB"],
includes = ["include"],
linkopts = select({
"@org_tensorflow//tensorflow:android": [
Expand All @@ -289,10 +293,16 @@ cc_library(
],
"@org_tensorflow//tensorflow:ios": [],
"@org_tensorflow//tensorflow:windows": [
"-Wl,ws2_32.lib",
"-DEFAULTLIB:ws2_32.lib",
"-DEFAULTLIB:advapi32.lib",
"-DEFAULTLIB:crypt32.lib",
"-DEFAULTLIB:Normaliz.lib",
],
"@org_tensorflow//tensorflow:windows_msvc": [
"-Wl,ws2_32.lib",
"-DEFAULTLIB:ws2_32.lib",
"-DEFAULTLIB:advapi32.lib",
"-DEFAULTLIB:crypt32.lib",
"-DEFAULTLIB:Normaliz.lib",
],
"//conditions:default": [
"-lrt",
Expand Down Expand Up @@ -438,12 +448,22 @@ genrule(
"# include \"lib/config-win32.h\"",
"# define BUILDING_LIBCURL 1",
"# define CURL_DISABLE_CRYPTO_AUTH 1",
"# define CURL_DISABLE_DICT 1",
"# define CURL_DISABLE_FILE 1",
"# define CURL_DISABLE_GOPHER 1",
"# define CURL_DISABLE_IMAP 1",
"# define CURL_DISABLE_LDAP 1",
"# define CURL_DISABLE_LDAPS 1",
"# define CURL_DISABLE_POP3 1",
"# define CURL_PULL_WS2TCPIP_H 1",
"# define HTTP_ONLY 1",
"# define CURL_DISABLE_SMTP 1",
"# define CURL_DISABLE_TELNET 1",
"# define CURL_DISABLE_TFTP 1",
"# define CURL_PULL_WS2TCPIP_H 1",
"# define USE_WINDOWS_SSPI 1",
"# define USE_WIN32_IDN 1",
"# define USE_SCHANNEL 1",
"# define WANT_IDN_PROTOTYPES 1",
"#elif defined(__APPLE__)",
"# define HAVE_FSETXATTR_6 1",
"# define HAVE_SETMODE 1",
Expand Down