Skip to content
This repository has been archived by the owner on Sep 24, 2020. It is now read-only.

Commit

Permalink
Merge branch 'master' of github.com:wandb/client-ng into artifacts/meta
Browse files Browse the repository at this point in the history
  • Loading branch information
annirudh committed Sep 9, 2020
2 parents ed52736 + 37da261 commit 0fb2ba9
Show file tree
Hide file tree
Showing 161 changed files with 13,590 additions and 55 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ include = '''
| wandb/integration/keras/
| wandb/integration/sagemaker/
| wandb/sync/
| wandb/secretagent/
'''
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Click>=7.0
GitPython>=1.0.0
gql==0.2.0
python-dateutil>=2.6.1
requests>=2.0.0
requests>=2.0.0,<3
promise>=2.0,<3
shortuuid>=0.5.0
six>=1.13.0
watchdog>=0.8.3
Expand Down
34 changes: 28 additions & 6 deletions tests/utils/mock_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from datetime import datetime, timedelta
import json
import yaml
# HACK: re-insert the first two sys.path entries at the front after importing wandb
save_path = sys.path[:2]
import wandb
sys.path[0:0] = save_path
import logging
from six.moves import urllib
import threading
Expand Down Expand Up @@ -51,6 +54,28 @@ def run(ctx):

stopped = ctx.get("stopped", False)

# for tests/wandb_test.py::test_restore_name_not_found
# if the latest graphql request asked for fileNames == ["nofile.h5"],
# return an empty file node. otherwise, return the usual weights.h5
if ctx.get('graphql'):
fileNames = ctx['graphql'][-1]['variables'].get('fileNames')
else:
fileNames = None
if fileNames == ["nofile.h5"]:
fileNode = {
"name": "nofile.h5",
"sizeBytes": 0,
"md5": "0",
"url": request.url_root + "/storage?file=nofile.h5",
}
else:
fileNode = {
"name": "weights.h5",
"sizeBytes": 20,
"md5": "XXX",
"url": request.url_root + "/storage?file=weights.h5",
}

return {
"id": "test",
"name": "wild-test",
Expand All @@ -71,12 +96,7 @@ def run(ctx):
# Special weights url meant to be used with api_mocks#download_url
"edges": [
{
"node": {
"name": "weights.h5",
"sizeBytes": 20,
"md5": "XXX",
"url": request.url_root + "/storage?file=weights.h5",
}
"node": fileNode,
}
]
},
Expand Down Expand Up @@ -563,6 +583,8 @@ def storage():
size = ctx["files"].get(request.args.get("file"))
if request.method == "GET" and size:
return os.urandom(size), 200
# make sure to read the data
data = request.get_data()
if file == "wandb_manifest.json":
return {
"version": 1,
Expand Down
6 changes: 6 additions & 0 deletions tests/wandb_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,12 @@ def test_restore(runner, mock_server, wandb_init_run):
assert os.path.getsize(res.name) == 10000


def test_restore_name_not_found(runner, mock_server, wandb_init_run):
    """``wandb.restore`` must raise ``ValueError`` when asked for ``nofile.h5``."""
    missing_file = "nofile.h5"
    # Both context managers are entered in the same order as the nested form.
    with runner.isolated_filesystem(), pytest.raises(ValueError):
        wandb.restore(missing_file)


@pytest.mark.wandb_args(env={"WANDB_RUN_ID": "123456"})
def test_run_id(wandb_init_run):
assert wandb.run.id == "123456"
Expand Down
1 change: 1 addition & 0 deletions wandb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from wandb.wandb_controller import sweep, controller

from wandb import superagent
from wandb.secretagent import secretagent

# from wandb.core import *
from wandb.viz import visualize
Expand Down
3 changes: 3 additions & 0 deletions wandb/apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
api.
"""

import sys
from wandb import util
util.vendor_setup()
from .internal import Api as InternalApi
from .public import Api as PublicApi

3 changes: 3 additions & 0 deletions wandb/apis/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def set_current_run_id(self, run_id):
def viewer(self):
return self.api.viewer()

def viewer_server_info(self):
    """Return the result of ``viewer_server_info()`` on the wrapped ``self.api``."""
    backend = self.api
    return backend.viewer_server_info()

def list_projects(self, entity=None):
return self.api.list_projects(entity=entity)

Expand Down
43 changes: 43 additions & 0 deletions wandb/internal/internal_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,48 @@ def viewer(self):
res = self.gql(query)
return res.get('viewer') or {}

@normalize_exceptions
def viewer_server_info(self):
    """Fetch the viewer record together with the server info, when available.

    The viewer query is first sent with an extra ``serverInfo`` selection.
    If that request fails with anything other than ``UsageError``, the same
    query is retried without the ``serverInfo`` selection.

    Returns:
        A ``(viewer, server_info)`` tuple taken from the ``viewer`` and
        ``serverInfo`` keys of the response; each falls back to ``{}`` when
        the key is missing or falsy.

    Raises:
        UsageError: re-raised immediately, without trying the fallback query.
        Exception: the error from the last attempt when both queries fail.
    """
    cli_query = '''
        serverInfo {
          cliVersionInfo
        }
    '''
    query_str = '''
    query Viewer{
        viewer {
            id
            entity
            flags
            teams {
                edges {
                    node {
                        name
                    }
                }
            }
        }
        _CLI_QUERY_
    }
    '''
    with_server_info = gql(query_str.replace("_CLI_QUERY_", cli_query))
    without_server_info = gql(query_str.replace("_CLI_QUERY_", ""))

    last_error = None
    for candidate in (with_server_info, without_server_info):
        try:
            res = self.gql(candidate)
        except UsageError:
            raise
        except Exception as exc:
            # graphql schema exception is generic, so any failure here
            # moves on to the next (older) form of the query
            last_error = exc
        else:
            break
    else:
        # every candidate failed: surface the most recent error
        raise last_error
    return res.get("viewer") or {}, res.get("serverInfo") or {}

@normalize_exceptions
def list_projects(self, entity=None):
"""Lists projects in W&B scoped by entity.
Expand Down Expand Up @@ -1003,6 +1045,7 @@ def upload_file(self, url, file, callback=None, extra_headers={}):
url, data=progress, headers=extra_headers)
response.raise_for_status()
except requests.exceptions.RequestException as e:
logger.error("upload_file exception {} {}".format(url, e))
status_code = e.response.status_code if e.response != None else 0
# We need to rewind the file for the next retry (the file passed in is seeked to 0)
progress.rewind()
Expand Down
8 changes: 7 additions & 1 deletion wandb/internal/sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,10 @@ def send_request_login(self, record):
# TODO: do something with api_key or anonymous?
# TODO: return an error if we aren't logged in?
self._api.reauth()
viewer = self._api.viewer()
viewer_tuple = self._api.viewer_server_info()
# self._login_flags = json.loads(viewer.get("flags", "{}"))
# self._login_entity = viewer.get("entity")
viewer, server_info = viewer_tuple
login_entity = viewer.get("entity")
if record.control.req_resp:
result = wandb_internal_pb2.Result(uuid=record.uuid)
Expand Down Expand Up @@ -346,6 +347,9 @@ def send_run(self, data):
storage_id = ups.get("id")
if storage_id:
self._run.storage_id = storage_id
id = ups.get("name")
if id:
self._api.set_current_run_id(id)
display_name = ups.get("displayName")
if display_name:
self._run.display_name = display_name
Expand All @@ -355,12 +359,14 @@ def send_run(self, data):
if project_name:
self._run.project = project_name
self._project = project_name
self._api.set_setting("project", project_name)
entity = project.get("entity")
if entity:
entity_name = entity.get("name")
if entity_name:
self._run.entity = entity_name
self._entity = entity_name
self._api.set_setting("entity", entity_name)
sweep_id = ups.get("sweepName")
if sweep_id:
self._run.sweep_id = sweep_id
Expand Down
57 changes: 53 additions & 4 deletions wandb/lib/redirect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
util/redirect.
"""

import io
import logging
import os
import sys
Expand Down Expand Up @@ -30,6 +31,33 @@ def __getattr__(self, attr):
return getattr(self.stream, attr)


class StreamFork(object):
    """Fan a file-like write interface out to several streams.

    ``write`` and ``writelines`` are forwarded to every stream in
    ``output_streams``, in order. When ``unbuffered`` is true, each stream is
    flushed right after it has been written to. Any attribute not defined on
    this object is looked up on the first stream.

    Args:
        output_streams: sequence of objects providing ``write`` and
            ``writelines`` (and ``flush`` when ``unbuffered`` is set).
        unbuffered: flush each stream after every write.
    """

    def __init__(self, output_streams, unbuffered=False):
        self.output_streams = output_streams
        self.unbuffered = unbuffered

    def write(self, data):
        """Write ``data`` to every output stream."""
        output_streams = object.__getattribute__(self, "output_streams")
        unbuffered = object.__getattribute__(self, "unbuffered")
        for stream in output_streams:
            stream.write(data)
            if unbuffered:
                stream.flush()

    def writelines(self, datas):
        """Write every item of ``datas`` to every output stream."""
        output_streams = object.__getattribute__(self, "output_streams")
        unbuffered = object.__getattribute__(self, "unbuffered")
        # Materialize once: a one-shot iterable (e.g. a generator) would be
        # drained by the first stream and leave the others with nothing.
        lines = list(datas)
        for stream in output_streams:
            stream.writelines(lines)
            if unbuffered:
                stream.flush()

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. object.__getattribute__ is
        # used so a missing ``output_streams`` raises AttributeError instead
        # of re-entering __getattr__.
        output_streams = object.__getattribute__(self, "output_streams")
        if not output_streams:
            # Keep the __getattr__ contract so hasattr()/getattr(default)
            # work even when there is no stream to delegate to.
            raise AttributeError(attr)
        return getattr(output_streams[0], attr)


class StreamWrapper(object):
def __init__(self, name, cb, output_writer=None):
self.name = name
Expand Down Expand Up @@ -112,16 +140,37 @@ def __init__(self, src, dest, unbuffered=False, tee=False):
self._old_fd = None
self._old_fp = None

_src = getattr(sys, src)
if _src != getattr(sys, "__%s__" % src):
if hasattr(_src, "fileno"):
try:
_src.fileno()
self._io_wrapped = False
except io.UnsupportedOperation:
self._io_wrapped = True
else:
self._io_wrapped = True
else:
self._io_wrapped = False

def _redirect(self, to_fd, unbuffered=False, close=False):
if close:
fp = getattr(sys, self._stream)
# TODO(jhr): does this still work under windows? are we leaking a fd?
# Do not close old filedescriptor as others might be using it
fp.close()
os.dup2(to_fd, self._old_fd)
setattr(sys, self._stream, os.fdopen(self._old_fd, "w"))
if unbuffered:
setattr(sys, self._stream, Unbuffered(getattr(sys, self._stream)))
if self._io_wrapped:
if close:
setattr(sys, self._stream, getattr(sys, self._stream).output_streams[0])
else:
setattr(sys, self._stream, StreamFork([getattr(sys, self._stream),
os.fdopen(self._old_fd, "w")],
unbuffered=unbuffered))
else:
setattr(sys, self._stream, os.fdopen(self._old_fd, "w"))
if unbuffered:
setattr(sys, self._stream, Unbuffered(getattr(sys, self._stream)))

def install(self):
if self._installed:
Expand All @@ -137,7 +186,7 @@ def install(self):

logger.info("install start")

fp = getattr(sys, self._stream)
fp = getattr(sys, "__%s__" % self._stream)
fd = fp.fileno()
old_fp = os.fdopen(os.dup(fd), "w")

Expand Down
8 changes: 4 additions & 4 deletions wandb/lib/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ def __init__(self, api=None):

def query_with_timeout(self, timeout=None):
timeout = timeout or 5
async_viewer = util.async_call(self._api.viewer, timeout=timeout)
viewer, viewer_thread = async_viewer()
async_viewer = util.async_call(self._api.viewer_server_info, timeout=timeout)
viewer_tuple, viewer_thread = async_viewer()
if viewer_thread.is_alive():
self._error_network = True
return
self._error_network = False
# TODO(jhr): should we kill the thread?
self._viewer = viewer
self._flags = json.loads(viewer.get("flags", "{}"))
self._viewer, self._serverinfo = viewer_tuple
self._flags = json.loads(self._viewer.get("flags", "{}"))

def is_valid(self):
if self._error_network is None:
Expand Down
29 changes: 24 additions & 5 deletions wandb/sdk/wandb_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from wandb.errors import Error
from wandb.interface.summary_record import SummaryRecord
from wandb.lib import filenames, module, proto_util, redirect, sparkline
from wandb.util import sentry_set_scope, to_forward_slash_path
from wandb.util import add_import_hook, sentry_set_scope, to_forward_slash_path
from wandb.viz import Visualize

from . import wandb_config
Expand Down Expand Up @@ -836,6 +836,7 @@ def restore(
Raises:
wandb.CommError if it can't find the run
ValueError if the file is not found
"""

# TODO: handle restore outside of a run context?
Expand All @@ -849,6 +850,9 @@ def restore(
files = api_run.files([name])
if len(files) == 0:
return None
# if the file does not exist, the file has an md5 of 0
if files[0].md5 == "0":
raise ValueError("File {} not found.".format(path))
return files[0].download(root=root, replace=True)

def finish(self, exit_code=None):
Expand Down Expand Up @@ -968,8 +972,9 @@ def _display_run(self):
if not self._settings._offline:
wandb.termlog("Run `wandb off` to turn off syncing.")

def _redirect(self, stdout_slave_fd, stderr_slave_fd):
console = self._settings._console
def _redirect(self, stdout_slave_fd, stderr_slave_fd, console=None):
if console is None:
console = self._settings._console
logger.info("redirect: %s", console)

if console == self._settings.Console.REDIRECT:
Expand All @@ -986,8 +991,22 @@ def _redirect(self, stdout_slave_fd, stderr_slave_fd):
err_redir = redirect.Redirect(
src="stderr", dest=err_cap, unbuffered=True, tee=True
)
elif console == self._settings.Console.NOTEBOOK:
logger.info("Redirecting notebook output.")
if os.name == "nt":

def wrap_fallback():
self._out_redir.uninstall()
self._err_redir.uninstall()
msg = (
"Tensorflow detected. Stream redirection is not supported "
"on Windows when tensorflow is imported. Falling back to "
"wrapping stdout/err."
)
wandb.termlog(msg)
self._redirect(None, None, console=self._settings.Console.WRAP)

add_import_hook("tensorflow", wrap_fallback)
elif console == self._settings.Console.WRAP:
logger.info("Wrapping output streams.")
out_redir = redirect.StreamWrapper(
name="stdout", cb=self._redirect_cb, output_writer=self._output_writer
)
Expand Down

0 comments on commit 0fb2ba9

Please sign in to comment.