Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style(sdk): update mypy to 0.991 #4546

Merged
merged 8 commits into from
Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ basepython=python3
skip_install = true
deps=
types-click==7.1.8
mypy==0.971
mypy==0.991
lxml
grpcio
setenv =
Expand Down
6 changes: 3 additions & 3 deletions wandb/apis/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def create_report(
title: Optional[str] = "Untitled Report",
description: Optional[str] = "",
width: Optional[str] = "readable",
blocks: "Optional[wandb.apis.reports.util.Block]" = None,
blocks: Optional["wandb.apis.reports.util.Block"] = None,
) -> "wandb.apis.reports.Report":
if entity == "":
entity = self.default_entity or ""
Expand Down Expand Up @@ -3807,7 +3807,7 @@ def __init__(
entity: str,
project: str,
type_name: str,
attrs: Mapping[str, Any] = None,
attrs: Optional[Mapping[str, Any]] = None,
):
self.client = client
self.entity = entity
Expand Down Expand Up @@ -5301,7 +5301,7 @@ class Job:
_project: str
_entrypoint: List[str]

def __init__(self, api: Api, name, path: str = None) -> None:
def __init__(self, api: Api, name, path: Optional[str] = None) -> None:
try:
self._job_artifact = api.artifact(name, type="job")
except CommError:
Expand Down
8 changes: 7 additions & 1 deletion wandb/apis/reports/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,13 @@ class RGBA(Base):
a: Union[int, float] = Attr(validators=[Between(0, 1)])

def __init__(
self, r: int, g: int, b: int, a: Union[int, float] = None, *args, **kwargs
self,
r: int,
g: int,
b: int,
a: Optional[Union[int, float]] = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.r = r
Expand Down
2 changes: 1 addition & 1 deletion wandb/apis/reports/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ class Panel(Base, SubclassOnlyABC):
layout: dict = Attr(json_path="spec.layout")

def __init__(
self, layout: Dict[str, int] = None, *args: Any, **kwargs: Any
self, layout: Optional[Dict[str, int]] = None, *args: Any, **kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self._spec["viewType"] = self.view_type
Expand Down
2 changes: 1 addition & 1 deletion wandb/docker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def run(
args: List[Any],
capture_stdout: bool = True,
capture_stderr: bool = True,
input: bytes = None,
input: Optional[bytes] = None,
return_stderr: bool = False,
env: Optional[Dict[str, str]] = None,
) -> Union[str, Tuple[str, str]]:
Expand Down
10 changes: 5 additions & 5 deletions wandb/docker/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def get_credential_store(authconfig: Dict, registry: str) -> Optional[str]:


class AuthConfig(dict):
def __init__(self, dct: Dict, credstore_env: Mapping = None) -> None:
def __init__(self, dct: Dict, credstore_env: Optional[Mapping] = None) -> None:
super().__init__(dct)
if "auths" not in dct:
dct["auths"] = {}
Expand Down Expand Up @@ -208,7 +208,7 @@ def load_config(
cls,
config_path: Optional[str],
config_dict: Optional[Dict[str, Any]],
credstore_env: Mapping = None,
credstore_env: Optional[Mapping] = None,
) -> "AuthConfig":
"""
Loads authentication data from a Docker configuration file in the given
Expand Down Expand Up @@ -401,9 +401,9 @@ def parse_auth(


def load_config(
config_path: str = None,
config_dict: Dict[str, Any] = None,
credstore_env: Mapping = None,
config_path: Optional[str] = None,
config_dict: Optional[Dict[str, Any]] = None,
credstore_env: Optional[Mapping] = None,
) -> AuthConfig:
return AuthConfig.load_config(config_path, config_dict, credstore_env)

Expand Down
6 changes: 3 additions & 3 deletions wandb/docker/www_authenticate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Taken from: https://github.com/alexsdutton/www-authenticate
import re
from collections import OrderedDict
from typing import Any
from typing import Any, Optional

_tokens = (
("token", re.compile(r"""^([!#$%&'*+\-.^_`|~\w/]+(?:={1,2}$)?)""")),
Expand Down Expand Up @@ -29,10 +29,10 @@ def __setitem__(self, key: str, value: Any) -> None:
def __contains__(self, key: object) -> bool:
return super().__contains__(_casefold(key)) # type: ignore

def get(self, key: str, default: Any = None) -> Any:
def get(self, key: str, default: Optional[Any] = None) -> Any:
return super().get(_casefold(key), default)

def pop(self, key: str, default: Any = None) -> Any:
def pop(self, key: str, default: Optional[Any] = None) -> Any:
return super().pop(_casefold(key), default)


Expand Down
80 changes: 51 additions & 29 deletions wandb/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def immutable_keys() -> List[str]:
]


def _env_as_bool(var: str, default: Optional[str] = None, env: Env = None) -> bool:
def _env_as_bool(
var: str, default: Optional[str] = None, env: Optional[Env] = None
) -> bool:
if env is None:
env = os.environ
val = env.get(var, default)
Expand All @@ -129,7 +131,7 @@ def _env_as_bool(var: str, default: Optional[str] = None, env: Env = None) -> bo
return val if isinstance(val, bool) else False


def is_debug(default: Optional[str] = None, env: Env = None) -> bool:
def is_debug(default: Optional[str] = None, env: Optional[Env] = None) -> bool:
return _env_as_bool(DEBUG, default=default, env=env)


Expand All @@ -143,23 +145,23 @@ def ssl_disabled() -> bool:

def get_error_reporting(
default: Union[bool, str] = True,
env: Env = None,
env: Optional[Env] = None,
) -> Union[bool, str]:
if env is None:
env = os.environ

return env.get(ERROR_REPORTING, default)


def get_run(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_run(default: Optional[str] = None, env: Optional[Env] = None) -> Optional[str]:
if env is None:
env = os.environ

return env.get(RUN_ID, default)


def get_args(
default: Optional[List[str]] = None, env: Env = None
default: Optional[List[str]] = None, env: Optional[Env] = None
) -> Optional[List[str]]:
if env is None:
env = os.environ
Expand All @@ -172,22 +174,24 @@ def get_args(
return default or sys.argv[1:]


def get_docker(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_docker(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(DOCKER, default)


def get_http_timeout(default: int = 10, env: Env = None) -> int:
def get_http_timeout(default: int = 10, env: Optional[Env] = None) -> int:
if env is None:
env = os.environ

return int(env.get(HTTP_TIMEOUT, default))


def get_ignore(
default: Optional[List[str]] = None, env: Env = None
default: Optional[List[str]] = None, env: Optional[Env] = None
) -> Optional[List[str]]:
if env is None:
env = os.environ
Expand All @@ -198,35 +202,45 @@ def get_ignore(
return default


def get_project(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_project(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(PROJECT, default)


def get_username(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_username(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(USERNAME, default)


def get_user_email(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_user_email(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(USER_EMAIL, default)


def get_entity(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_entity(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(ENTITY, default)


def get_base_url(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_base_url(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

Expand All @@ -235,48 +249,54 @@ def get_base_url(default: Optional[str] = None, env: Env = None) -> Optional[str
return base_url.rstrip("/") if base_url is not None else base_url


def get_app_url(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_app_url(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(APP_URL, default)


def get_show_run(default: Optional[str] = None, env: Env = None) -> bool:
def get_show_run(default: Optional[str] = None, env: Optional[Env] = None) -> bool:
if env is None:
env = os.environ

return bool(env.get(SHOW_RUN, default))


def get_description(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_description(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ

return env.get(DESCRIPTION, default)


def get_tags(default: str = "", env: Env = None) -> List[str]:
def get_tags(default: str = "", env: Optional[Env] = None) -> List[str]:
if env is None:
env = os.environ

return [tag for tag in env.get(TAGS, default).split(",") if tag]


def get_dir(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_dir(default: Optional[str] = None, env: Optional[Env] = None) -> Optional[str]:
if env is None:
env = os.environ
return env.get(DIR, default)


def get_config_paths(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_config_paths(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ
return env.get(CONFIG_PATHS, default)


def get_agent_report_interval(
default: Optional[str] = None, env: Env = None
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[int]:
if env is None:
env = os.environ
Expand All @@ -289,7 +309,7 @@ def get_agent_report_interval(


def get_agent_kill_delay(
default: Optional[str] = None, env: Env = None
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[int]:
if env is None:
env = os.environ
Expand All @@ -302,7 +322,7 @@ def get_agent_kill_delay(


def get_crash_nosync_time(
default: Optional[str] = None, env: Env = None
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[int]:
if env is None:
env = os.environ
Expand All @@ -314,30 +334,32 @@ def get_crash_nosync_time(
return val


def get_magic(default: Optional[str] = None, env: Env = None) -> Optional[str]:
def get_magic(
default: Optional[str] = None, env: Optional[Env] = None
) -> Optional[str]:
if env is None:
env = os.environ
val = env.get(MAGIC, default)
return val


def get_cache_dir(env: Env = None) -> str:
def get_cache_dir(env: Optional[Env] = None) -> str:
default_dir = os.path.expanduser(os.path.join("~", ".cache", "wandb"))
if env is None:
env = os.environ
val = env.get(CACHE_DIR, default_dir)
return val


def get_use_v1_artifacts(env: Env = None) -> bool:
def get_use_v1_artifacts(env: Optional[Env] = None) -> bool:
if env is None:
env = os.environ
val = bool(env.get(USE_V1_ARTIFACTS, False))
return val


def get_agent_max_initial_failures(
default: Optional[int] = None, env: Env = None
default: Optional[int] = None, env: Optional[Env] = None
) -> Optional[int]:
if env is None:
env = os.environ
Expand All @@ -349,13 +371,13 @@ def get_agent_max_initial_failures(
return val


def set_entity(value: str, env: Env = None) -> None:
def set_entity(value: str, env: Optional[Env] = None) -> None:
if env is None:
env = os.environ
env[ENTITY] = value


def set_project(value: str, env: Env = None) -> None:
def set_project(value: str, env: Optional[Env] = None) -> None:
if env is None:
env = os.environ
env[PROJECT] = value or "uncategorized"
Expand All @@ -367,7 +389,7 @@ def should_save_code() -> bool:
return save_code and not code_disabled


def disable_git(env: Env = None) -> bool:
def disable_git(env: Optional[Env] = None) -> bool:
if env is None:
env = os.environ
val = env.get(DISABLE_GIT, default="False")
Expand Down
2 changes: 1 addition & 1 deletion wandb/filesync/upload_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def push(self) -> bool:
message = str(e)
# TODO: this is usually XML, but could be JSON
if hasattr(e, "response"):
message = e.response.content # type: ignore[attr-defined]
message = e.response.content
wandb.termerror(
'Error uploading "{}": {}, {}'.format(
self.save_path, type(e).__name__, message
Expand Down
Loading