Skip to content

Commit

Permalink
style(sdk): update mypy to 0.991 (#4546)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitryduev committed Nov 29, 2022
1 parent cb77d21 commit 928b1a3
Show file tree
Hide file tree
Showing 73 changed files with 418 additions and 337 deletions.
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ basepython=python3
skip_install = true
deps=
types-click==7.1.8
mypy==0.971
mypy==0.991
lxml
grpcio
setenv =
Expand Down
6 changes: 3 additions & 3 deletions wandb/apis/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def create_report(
title: Optional[str] = "Untitled Report",
description: Optional[str] = "",
width: Optional[str] = "readable",
blocks: "Optional[wandb.apis.reports.util.Block]" = None,
blocks: Optional["wandb.apis.reports.util.Block"] = None,
) -> "wandb.apis.reports.Report":
if entity == "":
entity = self.default_entity or ""
Expand Down Expand Up @@ -3807,7 +3807,7 @@ def __init__(
entity: str,
project: str,
type_name: str,
attrs: Mapping[str, Any] = None,
attrs: Optional[Mapping[str, Any]] = None,
):
self.client = client
self.entity = entity
Expand Down Expand Up @@ -5301,7 +5301,7 @@ class Job:
_project: str
_entrypoint: List[str]

def __init__(self, api: Api, name, path: str = None) -> None:
def __init__(self, api: Api, name, path: Optional[str] = None) -> None:
try:
self._job_artifact = api.artifact(name, type="job")
except CommError:
Expand Down
8 changes: 7 additions & 1 deletion wandb/apis/reports/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,13 @@ class RGBA(Base):
a: Union[int, float] = Attr(validators=[Between(0, 1)])

def __init__(
self, r: int, g: int, b: int, a: Union[int, float] = None, *args, **kwargs
self,
r: int,
g: int,
b: int,
a: Optional[Union[int, float]] = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.r = r
Expand Down
2 changes: 1 addition & 1 deletion wandb/apis/reports/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ class Panel(Base, SubclassOnlyABC):
layout: dict = Attr(json_path="spec.layout")

def __init__(
self, layout: Dict[str, int] = None, *args: Any, **kwargs: Any
self, layout: Optional[Dict[str, int]] = None, *args: Any, **kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self._spec["viewType"] = self.view_type
Expand Down
2 changes: 1 addition & 1 deletion wandb/docker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def run(
args: List[Any],
capture_stdout: bool = True,
capture_stderr: bool = True,
input: bytes = None,
input: Optional[bytes] = None,
return_stderr: bool = False,
env: Optional[Dict[str, str]] = None,
) -> Union[str, Tuple[str, str]]:
Expand Down
10 changes: 5 additions & 5 deletions wandb/docker/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def get_credential_store(authconfig: Dict, registry: str) -> Optional[str]:


class AuthConfig(dict):
def __init__(self, dct: Dict, credstore_env: Mapping = None) -> None:
def __init__(self, dct: Dict, credstore_env: Optional[Mapping] = None) -> None:
super().__init__(dct)
if "auths" not in dct:
dct["auths"] = {}
Expand Down Expand Up @@ -208,7 +208,7 @@ def load_config(
cls,
config_path: Optional[str],
config_dict: Optional[Dict[str, Any]],
credstore_env: Mapping = None,
credstore_env: Optional[Mapping] = None,
) -> "AuthConfig":
"""
Loads authentication data from a Docker configuration file in the given
Expand Down Expand Up @@ -401,9 +401,9 @@ def parse_auth(


def load_config(
config_path: str = None,
config_dict: Dict[str, Any] = None,
credstore_env: Mapping = None,
config_path: Optional[str] = None,
config_dict: Optional[Dict[str, Any]] = None,
credstore_env: Optional[Mapping] = None,
) -> AuthConfig:
return AuthConfig.load_config(config_path, config_dict, credstore_env)

Expand Down
6 changes: 3 additions & 3 deletions wandb/docker/www_authenticate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Taken from: https://github.com/alexsdutton/www-authenticate
import re
from collections import OrderedDict
from typing import Any
from typing import Any, Optional

_tokens = (
("token", re.compile(r"""^([!#$%&'*+\-.^_`|~\w/]+(?:={1,2}$)?)""")),
Expand Down Expand Up @@ -29,10 +29,10 @@ def __setitem__(self, key: str, value: Any) -> None:
def __contains__(self, key: object) -> bool:
return super().__contains__(_casefold(key)) # type: ignore

def get(self, key: str, default: Any = None) -> Any:
def get(self, key: str, default: Optional[Any] = None) -> Any:
return super().get(_casefold(key), default)

def pop(self, key: str, default: Any = None) -> Any:
def pop(self, key: str, default: Optional[Any] = None) -> Any:
return super().pop(_casefold(key), default)


Expand Down
80 changes: 51 additions & 29 deletions wandb/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def immutable_keys() -> List[str]:
]


def _env_as_bool(var: str, default: Optional[str] = None, env: Env = None) -> bool:
def _env_as_bool(
var: str, default: Optional[str] = None, env: Optional[Env] = None
) -> bool:
if env is None:
env = os.environ
val = env.get(var, default)
Expand All @@ -129,7 +131,7 @@ def _env_as_bool(var: str, default: Optional[str] = None, env: Env = None) -> bo
return val if isinstance(val, bool) else False


def is_debug(default: Optional[str] = None, env: Env = None) -> bool:
def is_debug(default: Optional[str] = None, env: Optional[Env] = None) -> bool:
return _env_as_bool(DEBUG, default=default, env=env)


Expand All @@ -143,23 +145,23 @@ def ssl_disabled() -> bool:

def get_error_reporting(
default: Union[bool, str] = True,
env: Env = None,
env: Optional[Env] = None,
) -> Union[bool, str]:
if env is None:
env = os.environ

return env.get(ERROR_REPORTING, default)


def get_run(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_run(default: Optional[str] = None, env: Optional[Env] = None) -> Optional[str]:
if env is None:
env = os.environ

return env.get(RUN_ID, default)


def get_args(
default: Optional[List[str]] = None, env: Env = None
default: Optional[List[str]] = None, env: Optional[Env] = None
) -> Optional[List[str]]:
if env is None:
env = os.environ
Expand All @@ -172,22 +174,24 @@ def get_args(
return default or sys.argv[1:]


def get_docker(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_docker(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(DOCKER, default)


def get_http_timeout(default: int = 10, env: Env = None) -> int:
def get_http_timeout(default: int = 10, env: Optional[Env] = None) -> int:
if env is None:
env = os.environ

return int(env.get(HTTP_TIMEOUT, default))


def get_ignore(
default: Optional[List[str]] = None, env: Env = None
default: Optional[List[str]] = None, env: Optional[Env] = None
) -> Optional[List[str]]:
if env is None:
env = os.environ
Expand All @@ -198,35 +202,45 @@ def get_ignore(
return default


def get_project(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_project(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(PROJECT, default)


def get_username(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_username(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(USERNAME, default)


def get_user_email(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_user_email(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(USER_EMAIL, default)


def get_entity(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_entity(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(ENTITY, default)


def get_base_url(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_base_url(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

Expand All @@ -235,48 +249,54 @@ def get_base_url(default: Optional[str] = None, env: Env = None) -> Optional[str
return base_url.rstrip("/") if base_url is not None else base_url


def get_app_url(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_app_url(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(APP_URL, default)


def get_show_run(default: Optional[str] = None, env: Env = None) -> bool:
def get_show_run(default: Optional[str] = None, env: Optional[Env] = None) -> bool:
if env is None:
env = os.environ

return bool(env.get(SHOW_RUN, default))


def get_description(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_description(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(DESCRIPTION, default)


def get_tags(default: str = "", env: Env = None) -> List[str]:
def get_tags(default: str = "", env: Optional[Env] = None) -> List[str]:
if env is None:
env = os.environ

return [tag for tag in env.get(TAGS, default).split(",") if tag]


def get_dir(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_dir(default: Optional[str] = None, env: Optional[Env] = None) -> Optional[str]:
if env is None:
env = os.environ
return env.get(DIR, default)


def get_config_paths(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_config_paths(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ
return env.get(CONFIG_PATHS, default)


def get_agent_report_interval(
default: Optional[str] = None, env: Env = None
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[int]:
if env is None:
env = os.environ
Expand All @@ -289,7 +309,7 @@ def get_agent_report_interval(


def get_agent_kill_delay(
default: Optional[str] = None, env: Env = None
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[int]:
if env is None:
env = os.environ
Expand All @@ -302,7 +322,7 @@ def get_agent_kill_delay(


def get_crash_nosync_time(
default: Optional[str] = None, env: Env = None
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[int]:
if env is None:
env = os.environ
Expand All @@ -314,30 +334,32 @@ def get_crash_nosync_time(
return val


def get_magic(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_magic(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ
val = env.get(MAGIC, default)
return val


def get_cache_dir(env: Env = None) -> str:
def get_cache_dir(env: Optional[Env] = None) -> str:
default_dir = os.path.expanduser(os.path.join("~", ".cache", "wandb"))
if env is None:
env = os.environ
val = env.get(CACHE_DIR, default_dir)
return val


def get_use_v1_artifacts(env: Env = None) -> bool:
def get_use_v1_artifacts(env: Optional[Env] = None) -> bool:
if env is None:
env = os.environ
val = bool(env.get(USE_V1_ARTIFACTS, False))
return val


def get_agent_max_initial_failures(
default: Optional[int] = None, env: Env = None
default: Optional[int] = None, env: Optional[Env] = None
) -> Optional[int]:
if env is None:
env = os.environ
Expand All @@ -349,13 +371,13 @@ def get_agent_max_initial_failures(
return val


def set_entity(value: str, env: Env = None) -> None:
def set_entity(value: str, env: Optional[Env] = None) -> None:
if env is None:
env = os.environ
env[ENTITY] = value


def set_project(value: str, env: Env = None) -> None:
def set_project(value: str, env: Optional[Env] = None) -> None:
if env is None:
env = os.environ
env[PROJECT] = value or "uncategorized"
Expand All @@ -367,7 +389,7 @@ def should_save_code() -> bool:
return save_code and not code_disabled


def disable_git(env: Env = None) -> bool:
def disable_git(env: Optional[Env] = None) -> bool:
if env is None:
env = os.environ
val = env.get(DISABLE_GIT, default="False")
Expand Down
2 changes: 1 addition & 1 deletion wandb/filesync/upload_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def push(self) -> bool:
message = str(e)
# TODO: this is usually XML, but could be JSON
if hasattr(e, "response"):
message = e.response.content # type: ignore[attr-defined]
message = e.response.content
wandb.termerror(
'Error uploading "{}": {}, {}'.format(
self.save_path, type(e).__name__, message
Expand Down
Loading

0 comments on commit 928b1a3

Please sign in to comment.