diff --git a/ignite/utils.py b/ignite/utils.py index b21f9c79b3b0..08a33f0aebf0 100644 --- a/ignite/utils.py +++ b/ignite/utils.py @@ -2,7 +2,6 @@ import functools import logging import random -import sys import warnings from typing import Any, Callable, Dict, Optional, TextIO, Tuple, Type, TypeVar, Union, cast @@ -12,46 +11,60 @@ def convert_tensor( - input_: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], + x: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False, ) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]: - """Move tensors to relevant device.""" + """Move tensors to relevant device. + + Args: + x: input tensor or mapping, or sequence of tensors. + device: device type to move ``x``. + non_blocking: convert a CPU Tensor with pinned memory to a CUDA Tensor + asynchronously with respect to the host if possible + """ def _func(tensor: torch.Tensor) -> torch.Tensor: return tensor.to(device=device, non_blocking=non_blocking) if device is not None else tensor - return apply_to_tensor(input_, _func) + return apply_to_tensor(x, _func) def apply_to_tensor( - input_: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], func: Callable + x: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], func: Callable ) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]: """Apply a function on a tensor or mapping, or sequence of tensors. + + Args: + x: input tensor or mapping, or sequence of tensors. + func: the function to apply on ``x``. """ - return apply_to_type(input_, torch.Tensor, func) + return apply_to_type(x, torch.Tensor, func) def apply_to_type( - input_: Union[Any, collections.Sequence, collections.Mapping, str, bytes], + x: Union[Any, collections.Sequence, collections.Mapping, str, bytes], input_type: Union[Type, Tuple[Type[Any], Any]], func: Callable, ) -> Union[Any, collections.Sequence, collections.Mapping, str, bytes]: - """Apply a function on a object of `input_type` or mapping, or sequence of objects of `input_type`. + """Apply a function on an object of `input_type` or mapping, or sequence of objects of `input_type`. + + Args: + x: object or mapping or sequence. + input_type: data type of ``x``. + func: the function to apply on ``x``. """ - if isinstance(input_, input_type): - return func(input_) - if isinstance(input_, (str, bytes)): - return input_ - if isinstance(input_, collections.Mapping): - return cast(Callable, type(input_))( - {k: apply_to_type(sample, input_type, func) for k, sample in input_.items()} - ) - if isinstance(input_, tuple) and hasattr(input_, "_fields"): # namedtuple - return cast(Callable, type(input_))(*(apply_to_type(sample, input_type, func) for sample in input_)) - if isinstance(input_, collections.Sequence): - return cast(Callable, type(input_))([apply_to_type(sample, input_type, func) for sample in input_]) - raise TypeError((f"input must contain {input_type}, dicts or lists; found {type(input_)}")) + if isinstance(x, input_type): + return func(x) + if isinstance(x, (str, bytes)): + return x + if isinstance(x, collections.Mapping): + return cast(Callable, type(x))({k: apply_to_type(sample, input_type, func) for k, sample in x.items()}) + if isinstance(x, tuple) and hasattr(x, "_fields"): # namedtuple + return cast(Callable, type(x))(*(apply_to_type(sample, input_type, func) for sample in x)) + if isinstance(x, collections.Sequence): + return cast(Callable, type(x))([apply_to_type(sample, input_type, func) for sample in x]) + raise TypeError((f"x must contain {input_type}, dicts or lists; found {type(x)}")) def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: @@ -59,6 +72,10 @@ def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: tensor of one-hot indicators of shape `(N, num_classes, ...) and of type uint8. Output's device is equal to the input's device`. + Args: + indices: input tensor to convert. + num_classes: number of classes for one-hot tensor. + .. versionchanged:: 0.4.3 This functions is now torchscriptable. """ @@ -78,12 +95,12 @@ def setup_logger( """Setups logger: name, level, format etc. Args: - name (str, optional): new name for the logger. If None, the standard logger is used. - level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG. - stream (TextIO, optional): logging stream. If None, the standard stream is used (sys.stderr). - format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`. - filepath (str, optional): Optional logging file path. If not None, logs are written to the file. - distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers. + name: new name for the logger. If None, the standard logger is used. + level: logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG. + stream: logging stream. If None, the standard stream is used (sys.stderr). + format: logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`. + filepath: Optional logging file path. If not None, logs are written to the file. + distributed_rank: Optional, rank in distributed configuration to avoid logger setup for workers. If None, distributed_rank is initialized to the rank of process. Returns: @@ -156,7 +173,7 @@ def manual_seed(seed: int) -> None: """Setup random state from a seed for `torch`, `random` and optionally `numpy` (if can be imported). Args: - seed (int): Random state seed + seed: Random state seed .. versionchanged:: 0.4.3 Added ``torch.cuda.manual_seed_all(seed)``.