Skip to content

Commit

Permalink
More tests less code
Browse files Browse the repository at this point in the history
  • Loading branch information
vanpelt committed Sep 25, 2021
1 parent f60634e commit 52f4312
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 18 deletions.
9 changes: 8 additions & 1 deletion tests/test_lib.py
Expand Up @@ -4,8 +4,15 @@

def test_write_netrc():
api_key = "X" * 40
wandb_lib.apikey.write_netrc("http://localhost", "vanpelt", api_key)
res = wandb_lib.apikey.write_netrc("http://localhost", "vanpelt", api_key)
assert res
with open(os.path.expanduser("~/.netrc")) as f:
assert f.read() == (
"machine localhost\n" " login vanpelt\n" " password %s\n" % api_key
)


def test_write_netrc_invalid_host():
api_key = "X" * 40
res = wandb_lib.apikey.write_netrc("http://foo", "vanpelt", api_key)
assert res is None
28 changes: 27 additions & 1 deletion tests/test_public_api.py
Expand Up @@ -9,7 +9,7 @@
import json
import pytest
import platform
import sys
import requests

import wandb
from wandb import Api
Expand Down Expand Up @@ -564,11 +564,15 @@ def test_viewer(mock_server, api):
v = api.viewer
assert v.admin is False
assert v.username == "mock"
assert v.api_keys == []
assert v.teams == []


def test_create_service_account(mock_server, api):
t = api.team("test")
assert t.create_service_account("My service account").api_key == "Y" * 40
mock_server.set_context("graphql_conflict", True)
assert t.create_service_account("My service account") is None


def test_create_team(mock_server, api):
Expand All @@ -577,15 +581,25 @@ def test_create_team(mock_server, api):
assert repr(t) == "<Team test>"


def test_create_team_exists(mock_server, api):
mock_server.set_context("graphql_conflict", True)
with pytest.raises(requests.exceptions.HTTPError):
api.create_team("test")


def test_invite_user(mock_server, api):
t = api.team("test")
assert t.invite("test@test.com")
assert t.invite("test")
mock_server.set_context("graphql_conflict", True)
assert t.invite("conflict") == False


def test_delete_member(mock_server, api):
t = api.team("test")
assert t.members[0].delete()
mock_server.set_context("graphql_conflict", True)
assert t.invite("conflict") == False


def test_query_user(mock_server, api):
Expand All @@ -596,11 +610,23 @@ def test_query_user(mock_server, api):
assert repr(u) == "<User test@test.com>"


def test_query_user_multiple(mock_server, api):
mock_server.set_context("num_search_users", 2)
u = api.user("test@test.com")
assert u.email == "test@test.com"
users = api.users("test")
assert len(users) == 2


def test_delete_api_key(mock_server, api):
u = api.user("test@test.com")
assert u.delete_api_key("Y" * 40)
mock_server.set_context("graphql_conflict", True)
assert u.delete_api_key("Y" * 40) == False


def test_generate_api_key(mock_server, api):
u = api.user("test@test.com")
assert u.generate_api_key()
mock_server.set_context("graphql_conflict", True)
assert u.generate_api_key() is None
6 changes: 6 additions & 0 deletions tests/utils/mock_server.py
Expand Up @@ -43,6 +43,8 @@ def default_ctx():
"fail_graphql_count": 0, # used via "fail_graphql_times"
"fail_storage_count": 0, # used via "fail_storage_times"
"rate_limited_count": 0, # used via "rate_limited_times"
"graphql_conflict": False,
"num_search_users": 1,
"page_count": 0,
"page_times": 2,
"requested_file": "weights.h5",
Expand Down Expand Up @@ -374,6 +376,8 @@ def graphql():
if ctx["rate_limited_count"] < ctx["rate_limited_times"]:
ctx["rate_limited_count"] += 1
return json.dumps({"error": "rate limit exceeded"}), 429
if ctx["graphql_conflict"]:
return json.dumps({"error": "resource already exists"}), 409

# Setup artifact emulator (should this be somewhere else?)
emulate_random_str = ctx["emulate_artifacts"]
Expand Down Expand Up @@ -1118,6 +1122,7 @@ def graphql():
}
}
if "query SearchUsers" in body["query"]:

return {
"data": {
"users": {
Expand All @@ -1139,6 +1144,7 @@ def graphql():
}
}
]
* ctx["num_search_users"]
}
}
}
Expand Down
31 changes: 15 additions & 16 deletions wandb/apis/public.py
Expand Up @@ -570,6 +570,18 @@ def user(self, username_or_email):
)
return User(self._client, res["users"]["edges"][0]["node"])

def users(self, username_or_email):
"""Return all users from a partial username or email address query
Arguments:
username_or_email: (str) The prefix or suffix of the user you want to find
Returns:
An array of `User` objects
"""
res = self._client.execute(self.USERS_QUERY, {"query": username_or_email})
return [User(self._client, edge["node"]) for edge in res["users"]["edges"]]

def runs(self, path="", filters=None, order="-created_at", per_page=50):
"""
Return a set of runs from a project that match the filters provided.
Expand Down Expand Up @@ -850,13 +862,13 @@ def __init__(self, client, attrs):

@property
def api_keys(self):
if self._attrs["apiKeys"] is None:
if self._attrs.get("apiKeys") is None:
return []
return [k["node"]["name"] for k in self._attrs["apiKeys"]["edges"]]

@property
def teams(self):
if self._attrs["teams"] is None:
if self._attrs.get("teams") is None:
return []
return [k["node"]["name"] for k in self._attrs["teams"]["edges"]]

Expand Down Expand Up @@ -1031,7 +1043,7 @@ def create(cls, api, team, admin_username=None):
{"teamName": team, "teamAdminUserName": admin_username},
)
except requests.exceptions.HTTPError:
return api.team(team)
pass
return Team(api, team)

def invite(self, username_or_email, admin=False):
Expand Down Expand Up @@ -1913,19 +1925,6 @@ def url(self):
path.insert(2, "runs")
return self.client.app_url + "/".join(path)

def _get_run_url(self):
return self.url

def _get_sweep_url(self):
if self.sweep:
return self.sweep.url
return ""

def _get_project_url(self):
path = self.path
path.pop()
return self.client.app_url + "/".join(path + ["workspace"])

@property
def lastHistoryStep(self): # noqa: N802
query = gql(
Expand Down

0 comments on commit 52f4312

Please sign in to comment.