Skip to content

Commit

Permalink
make pulling sweeps optional when using public api to query for runs (#…
Browse files Browse the repository at this point in the history
…4186)

* make pulling sweeps optional when using public api to query for runs
  • Loading branch information
kptkin committed Sep 7, 2022
1 parent dfbcbb7 commit 643b25d
Showing 1 changed file with 34 additions and 6 deletions.
40 changes: 34 additions & 6 deletions wandb/apis/public.py
Expand Up @@ -23,7 +23,7 @@
import urllib
from collections import namedtuple
from functools import partial
from typing import Dict, List, Optional
from typing import Dict, List, Mapping, Optional

import requests
from wandb_gql import Client, gql
Expand Down Expand Up @@ -755,7 +755,14 @@ def users(self, username_or_email):
res = self._client.execute(self.USERS_QUERY, {"query": username_or_email})
return [User(self._client, edge["node"]) for edge in res["users"]["edges"]]

def runs(self, path=None, filters=None, order="-created_at", per_page=50):
def runs(
self,
path: Optional[str] = None,
filters: Optional[str] = None,
order: str = "-created_at",
per_page: int = 50,
include_sweeps: bool = True,
):
"""
Return a set of runs from a project that match the filters provided.
Expand Down Expand Up @@ -823,6 +830,7 @@ def runs(self, path=None, filters=None, order="-created_at", per_page=50):
filters=filters,
order=order,
per_page=per_page,
include_sweeps=include_sweeps,
)
return self._runs[key]

Expand Down Expand Up @@ -1530,12 +1538,22 @@ class Runs(Paginator):
% RUN_FRAGMENT
)

def __init__(self, client, entity, project, filters=None, order=None, per_page=50):
def __init__(
self,
client: "RetryingClient",
entity: str,
project: str,
filters: Optional[str] = None,
order: Optional[str] = None,
per_page: int = 50,
include_sweeps: bool = True,
):
self.entity = entity
self.project = project
self.filters = filters or {}
self.order = order
self._sweeps = {}
self._include_sweeps = include_sweeps
variables = {
"project": self.project,
"entity": self.entity,
Expand Down Expand Up @@ -1576,10 +1594,11 @@ def convert_objects(self):
self.project,
run_response["node"]["name"],
run_response["node"],
include_sweeps=self._include_sweeps,
)
objs.append(run)

if run.sweep_name:
if self._include_sweeps and run.sweep_name:
if run.sweep_name in self._sweeps:
sweep = self._sweeps[run.sweep_name]
else:
Expand Down Expand Up @@ -1627,7 +1646,15 @@ class Run(Attrs):
with `wandb.log({key: value})`
"""

def __init__(self, client, entity, project, run_id, attrs=None):
def __init__(
self,
client: "RetryingClient",
entity: str,
project: str,
run_id: str,
attrs: Optional[Mapping] = None,
include_sweeps: bool = True,
):
"""
Run is always initialized by calling api.runs() where api is an instance of wandb.Api
"""
Expand All @@ -1640,6 +1667,7 @@ def __init__(self, client, entity, project, run_id, attrs=None):
self._base_dir = env.get_dir(tempfile.gettempdir())
self.id = run_id
self.sweep = None
self._include_sweeps = include_sweeps
self.dir = os.path.join(self._base_dir, *self.path)
try:
os.makedirs(self.dir)
Expand Down Expand Up @@ -1756,7 +1784,7 @@ def load(self, force=False):
self._attrs = response["project"]["run"]
self._state = self._attrs["state"]

if self.sweep_name and not self.sweep:
if self._include_sweeps and self.sweep_name and not self.sweep:
# There may be a lot of runs. Don't bother pulling them all
# just for the sake of this one.
self.sweep = Sweep.get(
Expand Down

0 comments on commit 643b25d

Please sign in to comment.