Skip to content
This repository has been archived by the owner on Sep 24, 2020. It is now read-only.

Commit

Permalink
Merge branch 'master' into release/cli-0.10.0rc
Browse files Browse the repository at this point in the history
  • Loading branch information
raubitsj committed Sep 10, 2020
2 parents d968d7d + 60835d5 commit df05bf7
Show file tree
Hide file tree
Showing 175 changed files with 14,012 additions and 190 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ We use protocol buffers to communicate from the user process to the wandb backen
If you update any of the *.proto files in wandb/proto, you'll need to run:

```
cd wandb/proto && python wandb_internal_codegen.py
make proto
```

## Code checks
Expand Down
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ test-short:
format:
tox -e format

proto:
cd wandb/proto && python wandb_internal_codegen.py

clean-pyc: ## remove Python file artifacts
find . -name '*.pyc' -exec rm -f {} +
find . -name '*.pyo' -exec rm -f {} +
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ include = '''
| wandb/integration/keras/
| wandb/integration/sagemaker/
| wandb/sync/
| wandb/secretagent/
'''
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Click>=7.0
GitPython>=1.0.0
gql==0.2.0
python-dateutil>=2.6.1
requests>=2.0.0
requests>=2.0.0,<3
promise>=2.0,<3
shortuuid>=0.5.0
six>=1.13.0
watchdog>=0.8.3
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_settings(test_dir, mocker):
@pytest.fixture
def mocked_run(runner, test_settings):
""" A managed run object for tests with a mock backend """
run = wandb.wandb_sdk.wandb_run.RunManaged(settings=test_settings)
run = wandb.wandb_sdk.wandb_run.Run(settings=test_settings)
run._set_backend(MagicMock())
yield run

Expand Down
2 changes: 1 addition & 1 deletion tests/test_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_image_accepts_masks_without_class_labels(mocked_run):
def test_cant_serialize_to_other_run(mocked_run, test_settings):
"""This isn't implemented yet. Should work eventually.
"""
other_run = wandb.wandb_sdk.wandb_run.RunManaged(settings=test_settings)
other_run = wandb.wandb_sdk.wandb_run.Run(settings=test_settings)
other_run._set_backend(mocked_run._backend)
wb_image = wandb.Image(image)

Expand Down
34 changes: 28 additions & 6 deletions tests/utils/mock_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from datetime import datetime, timedelta
import json
import yaml
# HACK: re-insert the original first two sys.path entries at the front after importing wandb
save_path = sys.path[:2]
import wandb
sys.path[0:0] = save_path
import logging
from six.moves import urllib
import threading
Expand Down Expand Up @@ -51,6 +54,28 @@ def run(ctx):

stopped = ctx.get("stopped", False)

# for tests/wandb_test.py::test_restore_name_not_found
# if the last recorded graphql request asked for fileNames == ["nofile.h5"],
# return an empty file node. otherwise, return the usual weights.h5
if ctx.get('graphql'):
fileNames = ctx['graphql'][-1]['variables'].get('fileNames')
else:
fileNames = None
if fileNames == ["nofile.h5"]:
fileNode = {
"name": "nofile.h5",
"sizeBytes": 0,
"md5": "0",
"url": request.url_root + "/storage?file=nofile.h5",
}
else:
fileNode = {
"name": "weights.h5",
"sizeBytes": 20,
"md5": "XXX",
"url": request.url_root + "/storage?file=weights.h5",
}

return {
"id": "test",
"name": "wild-test",
Expand All @@ -71,12 +96,7 @@ def run(ctx):
# Special weights url meant to be used with api_mocks#download_url
"edges": [
{
"node": {
"name": "weights.h5",
"sizeBytes": 20,
"md5": "XXX",
"url": request.url_root + "/storage?file=weights.h5",
}
"node": fileNode,
}
]
},
Expand Down Expand Up @@ -563,6 +583,8 @@ def storage():
size = ctx["files"].get(request.args.get("file"))
if request.method == "GET" and size:
return os.urandom(size), 200
# make sure to read the data
data = request.get_data()
if file == "wandb_manifest.json":
return {
"version": 1,
Expand Down
6 changes: 6 additions & 0 deletions tests/wandb_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,12 @@ def test_restore(runner, mock_server, wandb_init_run):
assert os.path.getsize(res.name) == 10000


def test_restore_name_not_found(runner, mock_server, wandb_init_run):
    """Check that ``wandb.restore("nofile.h5")`` raises ``ValueError``.

    The call is made inside the runner's isolated filesystem.
    """
    # Both context managers are entered left to right, exactly as if nested.
    with runner.isolated_filesystem(), pytest.raises(ValueError):
        wandb.restore("nofile.h5")


@pytest.mark.wandb_args(env={"WANDB_RUN_ID": "123456"})
def test_run_id(wandb_init_run):
assert wandb.run.id == "123456"
Expand Down
16 changes: 10 additions & 6 deletions wandb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from wandb.wandb_controller import sweep, controller

from wandb import superagent
from wandb.jupyteragent import jupyteragent as _secretagent

# from wandb.core import *
from wandb.viz import visualize
Expand Down Expand Up @@ -109,22 +110,25 @@ def _is_internal_process():
config = _preinit.PreInitObject("wandb.config")
summary = _preinit.PreInitObject("wandb.summary")
log = _preinit.PreInitCallable(
"wandb.log", wandb_sdk.wandb_run.RunManaged.log
"wandb.log", wandb_sdk.wandb_run.Run.log
)
join = _preinit.PreInitCallable(
"wandb.join", wandb_sdk.wandb_run.RunManaged.join
"wandb.join", wandb_sdk.wandb_run.Run.join
)
finish = _preinit.PreInitCallable(
"wandb.finish", wandb_sdk.wandb_run.Run.finish
)
save = _preinit.PreInitCallable(
"wandb.save", wandb_sdk.wandb_run.RunManaged.save
"wandb.save", wandb_sdk.wandb_run.Run.save
)
restore = _preinit.PreInitCallable(
"wandb.restore", wandb_sdk.wandb_run.RunManaged.restore
"wandb.restore", wandb_sdk.wandb_run.Run.restore
)
use_artifact = _preinit.PreInitCallable(
"wandb.use_artifact", wandb_sdk.wandb_run.RunManaged.use_artifact
"wandb.use_artifact", wandb_sdk.wandb_run.Run.use_artifact
)
log_artifact = _preinit.PreInitCallable(
"wandb.log_artifact", wandb_sdk.wandb_run.RunManaged.log_artifact
"wandb.log_artifact", wandb_sdk.wandb_run.Run.log_artifact
)
# record of patched libraries
patched = {"tensorboard": [], "keras": [], "gym": []}
Expand Down
3 changes: 3 additions & 0 deletions wandb/apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
api.
"""

import sys
from wandb import util
util.vendor_setup()
from .internal import Api as InternalApi
from .public import Api as PublicApi

3 changes: 3 additions & 0 deletions wandb/apis/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def set_current_run_id(self, run_id):
def viewer(self):
return self.api.viewer()

def viewer_server_info(self):
    """Forward to ``self.api.viewer_server_info()`` and return its result unchanged."""
    info = self.api.viewer_server_info()
    return info

def list_projects(self, entity=None):
return self.api.list_projects(entity=entity)

Expand Down
6 changes: 3 additions & 3 deletions wandb/interface/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,10 @@ def _make_artifact(self, artifact):
proto_artifact.type = artifact.type
proto_artifact.name = artifact.name
proto_artifact.digest = artifact.digest
if proto_artifact.description:
if artifact.description:
proto_artifact.description = artifact.description
if proto_artifact.metadata:
proto_artifact.metadata = artifact.metadata
if artifact.metadata:
proto_artifact.metadata = json.dumps(artifact.metadata)
self._make_artifact_manifest(artifact.manifest, obj=proto_artifact.manifest)
return proto_artifact

Expand Down
7 changes: 6 additions & 1 deletion wandb/internal/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,12 @@ def handle_request_defer(self, record):

logger.info("handle defer: {}".format(state))
# only flush stats and flush tb are handled here (sender handles the rest)
if state == defer.FLUSH_TB:
if state == defer.FLUSH_STATS:
if self._system_stats:
# TODO(jhr): this could block, so we don't really want to call shutdown
# from the handler thread
self._system_stats.shutdown()
elif state == defer.FLUSH_TB:
if self._tb_watcher:
# shutdown tensorboard workers so we get all metrics flushed
self._tb_watcher.finish()
Expand Down
43 changes: 43 additions & 0 deletions wandb/internal/internal_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,48 @@ def viewer(self):
res = self.gql(query)
return res.get('viewer') or {}

@normalize_exceptions
def viewer_server_info(self):
    """Query the viewer record together with server info when available.

    The query is first sent with a ``serverInfo { cliVersionInfo }``
    selection. If that attempt fails with anything other than
    ``UsageError``, the same query is retried without the selection
    (presumably for servers whose schema lacks ``serverInfo`` -- confirm).

    Returns:
        A ``(viewer, server_info)`` tuple. Each element is the matching
        value from the response, or ``{}`` when the key is missing or falsy.

    Raises:
        UsageError: re-raised immediately; the fallback query is not tried.
        Exception: the error from the last attempt when every query failed.
    """
    cli_query = '''
serverInfo {
cliVersionInfo
}
'''
    query_str = '''
query Viewer{
viewer {
id
entity
flags
teams {
edges {
node {
name
}
}
}
}
_CLI_QUERY_
}
'''
    query_new = gql(query_str.replace("_CLI_QUERY_", cli_query))
    query_old = gql(query_str.replace("_CLI_QUERY_", ""))

    last_error = None
    for query in (query_new, query_old):
        try:
            res = self.gql(query)
        except UsageError:
            # Listed first so it bypasses the generic fallback handler below;
            # a bare raise keeps the original traceback intact.
            raise
        except Exception as e:
            # graphql schema exception is generic, so any other failure is
            # treated as "try the next (older) query".
            last_error = e
            continue
        # First successful response wins.
        break
    else:
        # Every query failed: surface the most recent error.
        raise last_error
    return res.get('viewer') or {}, res.get('serverInfo') or {}

@normalize_exceptions
def list_projects(self, entity=None):
"""Lists projects in W&B scoped by entity.
Expand Down Expand Up @@ -1003,6 +1045,7 @@ def upload_file(self, url, file, callback=None, extra_headers={}):
url, data=progress, headers=extra_headers)
response.raise_for_status()
except requests.exceptions.RequestException as e:
logger.error("upload_file exception {} {}".format(url, e))
status_code = e.response.status_code if e.response != None else 0
# We need to rewind the file for the next retry (the file passed in is seeked to 0)
progress.rewind()
Expand Down
6 changes: 3 additions & 3 deletions wandb/internal/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from wandb.sdk_py27 import wandb_run


class InternalRun(wandb_run.RunManaged):
class InternalRun(wandb_run.Run):
def __init__(self, run_obj, settings):
super(InternalRun, self).__init__(settings=settings)
self._run_obj = run_obj

# TODO: This undoes what's done in the constructor of RunManaged. Probably what
# really want is a common interface for RunManaged and InternalRun.
# TODO: This undoes what's done in the constructor of wandb_run.Run.
# We really want a common interface for wandb_run.Run and InternalRun.
data_types._datatypes_set_callback(None)

def _set_backend(self, backend):
Expand Down
15 changes: 13 additions & 2 deletions wandb/internal/sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,12 @@ def send_request_login(self, record):
# TODO: do something with api_key or anonymous?
# TODO: return an error if we aren't logged in?
self._api.reauth()
viewer = self._api.viewer()
viewer_tuple = self._api.viewer_server_info()
# self._login_flags = json.loads(viewer.get("flags", "{}"))
# self._login_entity = viewer.get("entity")
viewer, server_info = viewer_tuple
if server_info:
logger.info("Login server info: {}".format(server_info))
login_entity = viewer.get("entity")
if record.control.req_resp:
result = wandb_internal_pb2.Result(uuid=record.uuid)
Expand Down Expand Up @@ -163,6 +166,9 @@ def send_request_defer(self, data):
done = False
if state == defer.BEGIN:
pass
elif state == defer.FLUSH_STATS:
# NOTE: this is handled in handler.py:handle_request_defer()
pass
elif state == defer.FLUSH_TB:
# NOTE: this is handled in handler.py:handle_request_defer()
pass
Expand Down Expand Up @@ -346,6 +352,9 @@ def send_run(self, data):
storage_id = ups.get("id")
if storage_id:
self._run.storage_id = storage_id
id = ups.get("name")
if id:
self._api.set_current_run_id(id)
display_name = ups.get("displayName")
if display_name:
self._run.display_name = display_name
Expand All @@ -355,12 +364,14 @@ def send_run(self, data):
if project_name:
self._run.project = project_name
self._project = project_name
self._api.set_setting("project", project_name)
entity = project.get("entity")
if entity:
entity_name = entity.get("name")
if entity_name:
self._run.entity = entity_name
self._entity = entity_name
self._api.set_setting("entity", entity_name)
sweep_id = ups.get("sweepName")
if sweep_id:
self._run.sweep_id = sweep_id
Expand Down Expand Up @@ -514,7 +525,7 @@ def send_artifact(self, data):
saver.save(
type=artifact.type,
name=artifact.name,
metadata=artifact.metadata,
metadata=json.loads(artifact.metadata),
description=artifact.description,
aliases=artifact.aliases,
use_after_commit=artifact.use_after_commit,
Expand Down
3 changes: 2 additions & 1 deletion wandb/internal/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def _thread_body(self):
time.sleep(0.1)
seconds += 0.1
if self._shutdown:
break
self.flush()
return

def shutdown(self):
self._shutdown = True
Expand Down
3 changes: 3 additions & 0 deletions wandb/jupyteragent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .agent import agent as jupyteragent

__all__ = ["jupyteragent"]

0 comments on commit df05bf7

Please sign in to comment.