Skip to content

Commit

Permalink
fix optionals
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitryduev committed Nov 29, 2022
1 parent a95c275 commit 77dc647
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 39 deletions.
2 changes: 1 addition & 1 deletion wandb/sdk/data_types/base_types/wb_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def with_suffix(cls: Type["WBValue"], name: str, filetype: str = "json") -> str:
@staticmethod
def init_from_json(
json_obj: dict, source_artifact: "PublicArtifact"
) -> "Optional[WBValue]":
) -> Optional["WBValue"]:
"""Looks through all subclasses and tries to match the json obj with the class which created it. It will then
call that subclass' `from_json` method. Importantly, this function will set the return object's `source_artifact`
attribute to the passed in source artifact. This is critical for artifact bookkeeping. If you choose to create
Expand Down
4 changes: 2 additions & 2 deletions wandb/sdk/data_types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


def history_dict_to_json(
run: "Optional[LocalRun]",
run: Optional["LocalRun"],
payload: dict,
step: Optional[int] = None,
ignore_copy_err: Optional[bool] = None,
Expand Down Expand Up @@ -58,7 +58,7 @@ def history_dict_to_json(

# TODO: refine this
def val_to_json(
run: "Optional[LocalRun]",
run: Optional["LocalRun"],
key: str,
val: "ValToJsonType",
namespace: Optional[Union[str, int]] = None,
Expand Down
2 changes: 1 addition & 1 deletion wandb/sdk/interface/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, mailbox: Optional[mailbox.Mailbox] = None) -> None:
self._thread.start()

@abstractmethod
def _read_message(self) -> "Optional[pb.Result]":
def _read_message(self) -> Optional["pb.Result"]:
raise NotImplementedError

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion wandb/sdk/interface/router_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
self._response_queue = response_queue
super().__init__(mailbox=mailbox)

def _read_message(self) -> "Optional[pb.Result]":
def _read_message(self) -> Optional["pb.Result"]:
try:
msg = self._response_queue.get(timeout=1)
except queue.Empty:
Expand Down
2 changes: 1 addition & 1 deletion wandb/sdk/interface/router_sock.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, sock_client: SockClient, mailbox: Mailbox) -> None:
self._sock_client = sock_client
super().__init__(mailbox=mailbox)

def _read_message(self) -> "Optional[pb.Result]":
def _read_message(self) -> Optional["pb.Result"]:
try:
resp = self._sock_client.read_server_response(timeout=1)
except SockClientClosedError:
Expand Down
13 changes: 6 additions & 7 deletions wandb/sdk/internal/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import time
import traceback
from datetime import datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, List, Optional

import psutil

Expand All @@ -36,7 +36,6 @@
if TYPE_CHECKING:
from queue import Queue
from threading import Event
from typing import Any, List, Optional

from wandb.proto.wandb_internal_pb2 import Record, Result

Expand All @@ -51,8 +50,8 @@ def wandb_internal(
settings: "SettingsDict",
record_q: "Queue[Record]",
result_q: "Queue[Result]",
port: "Optional[int]" = None,
user_pid: "Optional[int]" = None,
port: Optional[int] = None,
user_pid: Optional[int] = None,
) -> None:
"""Internal process function entrypoint.
Expand Down Expand Up @@ -181,7 +180,7 @@ def _setup_tracelog() -> None:


def configure_logging(
log_fname: str, log_level: int, run_id: "Optional[str]" = None
log_fname: str, log_level: int, run_id: Optional[str] = None
) -> None:
# TODO: we may want make prints and stdout make it into the logs
# sys.stdout = open(settings.log_internal, "a")
Expand Down Expand Up @@ -362,9 +361,9 @@ def _debounce(self) -> None:
class ProcessCheck:
"""Class to help watch a process id to detect when it is dead."""

check_process_last: "Optional[float]"
check_process_last: Optional[float]

def __init__(self, settings: "SettingsStatic", user_pid: "Optional[int]") -> None:
def __init__(self, settings: "SettingsStatic", user_pid: Optional[int]) -> None:
self.settings = settings
self.pid = user_pid
self.check_process_last = None
Expand Down
7 changes: 3 additions & 4 deletions wandb/sdk/internal/internal_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@
import sys
import threading
import time
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple, Type, Union

from ..lib import tracelog

if TYPE_CHECKING:
from queue import Queue
from threading import Event
from types import TracebackType
from typing import Optional, Tuple, Type, Union

from wandb.proto.wandb_internal_pb2 import Record, Result

Expand All @@ -35,7 +34,7 @@ class ExceptionThread(threading.Thread):
"""Class to catch exceptions when running a thread."""

__stopped: "Event"
__exception: "Optional[ExceptionType]"
__exception: Optional["ExceptionType"]

def __init__(self, stopped: "Event") -> None:
threading.Thread.__init__(self)
Expand All @@ -54,7 +53,7 @@ def run(self) -> None:
if self.__exception and self.__stopped:
self.__stopped.set()

def get_exception(self) -> "Optional[ExceptionType]":
def get_exception(self) -> Optional["ExceptionType"]:
return self.__exception


Expand Down
18 changes: 9 additions & 9 deletions wandb/sdk/internal/sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,14 @@ class SendManager:
_partial_output: Dict[str, str]

_telemetry_obj: telemetry.TelemetryRecord
_fs: "Optional[file_stream.FileStreamApi]"
_run: "Optional[RunRecord]"
_entity: "Optional[str]"
_project: "Optional[str]"
_dir_watcher: "Optional[DirWatcher]"
_pusher: "Optional[FilePusher]"
_record_exit: "Optional[Record]"
_exit_result: "Optional[RunExitResult]"
_fs: Optional["file_stream.FileStreamApi"]
_run: Optional["RunRecord"]
_entity: Optional[str]
_project: Optional[str]
_dir_watcher: Optional["DirWatcher"]
_pusher: Optional["FilePusher"]
_record_exit: Optional["Record"]
_exit_result: Optional["RunExitResult"]
_resume_state: ResumeState
_cached_server_info: Dict[str, Any]
_cached_viewer: Dict[str, Any]
Expand Down Expand Up @@ -579,7 +579,7 @@ def send_request_server_info(self, record: "Record") -> None:

def _maybe_setup_resume(
self, run: "RunRecord"
) -> "Optional[wandb_internal_pb2.ErrorInfo]":
) -> Optional["wandb_internal_pb2.ErrorInfo"]:
"""This maybe queries the backend for a run and fails if the settings are
incompatible."""
if not self._settings.resume:
Expand Down
19 changes: 9 additions & 10 deletions wandb/sdk/internal/tb_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sys
import threading
import time
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import wandb
from wandb import util
Expand All @@ -21,7 +21,6 @@

if TYPE_CHECKING:
from queue import PriorityQueue
from typing import Dict, List, Optional

from tensorboard.backend.event_processing.event_file_loader import EventFileLoader
from tensorboard.compat.proto.event_pb2 import ProtoEvent
Expand Down Expand Up @@ -110,7 +109,7 @@ def __init__(
force: bool = False,
) -> None:
self._logdirs = {}
self._consumer: "Optional[TBEventConsumer]" = None
self._consumer: Optional["TBEventConsumer"] = None
self._settings = settings
self._interface = interface
self._run_proto = run_proto
Expand All @@ -119,8 +118,8 @@ def __init__(
self._watcher_queue = queue.PriorityQueue()
wandb.tensorboard.reset_state()

def _calculate_namespace(self, logdir: str, rootdir: str) -> "Optional[str]":
namespace: "Optional[str]"
def _calculate_namespace(self, logdir: str, rootdir: str) -> Optional[str]:
namespace: Optional[str]
dirs = list(self._logdirs) + [logdir]

if os.path.isfile(logdir):
Expand Down Expand Up @@ -182,7 +181,7 @@ def __init__(
tbwatcher: "TBWatcher",
logdir: str,
save: bool,
namespace: "Optional[str]",
namespace: Optional[str],
queue: "PriorityQueue",
force: bool = False,
) -> None:
Expand Down Expand Up @@ -227,7 +226,7 @@ def _is_our_tfevents_file(self, path: str) -> bool:
)

def _loader(
self, save: bool = True, namespace: "Optional[str]" = None
self, save: bool = True, namespace: Optional[str] = None
) -> "EventFileLoader":
"""Incredibly hacky class generator to optionally save / prefix tfevent files"""
_loader_interface = self._tbwatcher._interface
Expand Down Expand Up @@ -286,7 +285,7 @@ def _thread_except_body(self) -> None:

def _thread_body(self) -> None:
"""Check for new events every second"""
shutdown_time: "Optional[float]" = None
shutdown_time: Optional[float] = None
while True:
self._process_events()
if self._shutdown.is_set():
Expand Down Expand Up @@ -320,7 +319,7 @@ def finish(self) -> None:
class Event:
"""An event wrapper to enable priority queueing"""

def __init__(self, event: "ProtoEvent", namespace: "Optional[str]"):
def __init__(self, event: "ProtoEvent", namespace: Optional[str]):
self.event = event
self.namespace = namespace
self.created_at = time.time()
Expand Down Expand Up @@ -419,7 +418,7 @@ def _thread_body(self) -> None:
self._save_row(item)

def _handle_event(
self, event: "ProtoEvent", history: "Optional[TBHistory]" = None
self, event: "ProtoEvent", history: Optional["TBHistory"] = None
) -> None:
wandb.tensorboard._log(
event.event,
Expand Down
6 changes: 3 additions & 3 deletions wandb/sdk/wandb_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2140,7 +2140,7 @@ def _create_repo_job(
input_types: Dict[str, Any],
output_types: Dict[str, Any],
installed_packages_list: List[str],
) -> "Optional[Artifact]":
) -> Optional["Artifact"]:
"""Create a job version artifact from a repo."""
has_repo = self._remote_url is not None and self._commit is not None
program_relpath = self._settings.program_relpath
Expand Down Expand Up @@ -2182,7 +2182,7 @@ def _create_artifact_job(
input_types: Dict[str, Any],
output_types: Dict[str, Any],
installed_packages_list: List[str],
) -> "Optional[Artifact]":
) -> Optional["Artifact"]:
if (
self._code_artifact_info is None
or self._run_obj is None
Expand Down Expand Up @@ -2218,7 +2218,7 @@ def _create_image_job(
output_types: Dict[str, Any],
installed_packages_list: List[str],
docker_image_name: Optional[str] = None,
) -> "Optional[Artifact]":
) -> Optional["Artifact"]:
docker_image_name = docker_image_name or os.getenv("WANDB_DOCKER")

if not docker_image_name:
Expand Down

0 comments on commit 77dc647

Please sign in to comment.