73 changes: 45 additions & 28 deletions ignite/utils.py
@@ -2,7 +2,6 @@
 import functools
 import logging
 import random
-import sys
 import warnings
 from typing import Any, Callable, Dict, Optional, TextIO, Tuple, Type, TypeVar, Union, cast

@@ -12,53 +11,71 @@


 def convert_tensor(
-    input_: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes],
+    x: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes],
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
 ) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
-    """Move tensors to relevant device."""
+    """Move tensors to relevant device.
+
+    Args:
+        x: input tensor or mapping, or sequence of tensors.
+        device: device type to move ``x``.
+        non_blocking: convert a CPU Tensor with pinned memory to a CUDA Tensor
+            asynchronously with respect to the host if possible.
+    """

     def _func(tensor: torch.Tensor) -> torch.Tensor:
         return tensor.to(device=device, non_blocking=non_blocking) if device is not None else tensor

-    return apply_to_tensor(input_, _func)
+    return apply_to_tensor(x, _func)
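
A minimal usage sketch of the renamed API (editor's illustration, not part of the diff; assumes a CUDA device is available):

import torch
from ignite.utils import convert_tensor

batch = {"features": torch.rand(4, 3), "labels": torch.tensor([0, 1, 1, 0])}
# Every tensor nested in the dict is moved in one call; non-tensor leaves pass through.
moved = convert_tensor(batch, device="cuda", non_blocking=True)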


 def apply_to_tensor(
-    input_: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], func: Callable
+    x: Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], func: Callable
 ) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
     """Apply a function on a tensor or mapping, or sequence of tensors.
+
+    Args:
+        x: input tensor or mapping, or sequence of tensors.
+        func: the function to apply on ``x``.
     """
-    return apply_to_type(input_, torch.Tensor, func)
+    return apply_to_type(x, torch.Tensor, func)
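
Since apply_to_tensor only fixes the target type to torch.Tensor, any per-tensor transform works; a hedged sketch (not part of the diff):

import torch
from ignite.utils import apply_to_tensor

outputs = (torch.rand(2, requires_grad=True), torch.rand(2, requires_grad=True))
# The tuple structure is preserved; each tensor is detached from the autograd graph.
detached = apply_to_tensor(outputs, lambda t: t.detach())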


 def apply_to_type(
-    input_: Union[Any, collections.Sequence, collections.Mapping, str, bytes],
+    x: Union[Any, collections.Sequence, collections.Mapping, str, bytes],
     input_type: Union[Type, Tuple[Type[Any], Any]],
     func: Callable,
 ) -> Union[Any, collections.Sequence, collections.Mapping, str, bytes]:
-    """Apply a function on a object of `input_type` or mapping, or sequence of objects of `input_type`.
+    """Apply a function on an object of `input_type` or mapping, or sequence of objects of `input_type`.
+
+    Args:
+        x: object or mapping or sequence.
+        input_type: data type of ``x``.
+        func: the function to apply on ``x``.
     """
-    if isinstance(input_, input_type):
-        return func(input_)
-    if isinstance(input_, (str, bytes)):
-        return input_
-    if isinstance(input_, collections.Mapping):
-        return cast(Callable, type(input_))(
-            {k: apply_to_type(sample, input_type, func) for k, sample in input_.items()}
-        )
-    if isinstance(input_, tuple) and hasattr(input_, "_fields"):  # namedtuple
-        return cast(Callable, type(input_))(*(apply_to_type(sample, input_type, func) for sample in input_))
-    if isinstance(input_, collections.Sequence):
-        return cast(Callable, type(input_))([apply_to_type(sample, input_type, func) for sample in input_])
-    raise TypeError((f"input must contain {input_type}, dicts or lists; found {type(input_)}"))
+    if isinstance(x, input_type):
+        return func(x)
+    if isinstance(x, (str, bytes)):
+        return x
+    if isinstance(x, collections.Mapping):
+        return cast(Callable, type(x))({k: apply_to_type(sample, input_type, func) for k, sample in x.items()})
+    if isinstance(x, tuple) and hasattr(x, "_fields"):  # namedtuple
+        return cast(Callable, type(x))(*(apply_to_type(sample, input_type, func) for sample in x))
+    if isinstance(x, collections.Sequence):
+        return cast(Callable, type(x))([apply_to_type(sample, input_type, func) for sample in x])
+    raise TypeError((f"x must contain {input_type}, dicts or lists; found {type(x)}"))
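
A sketch of the recursion over nested containers (editor's illustration, not part of the diff); note that type(x) is used to rebuild each container, so dict, list and tuple shapes survive:

from ignite.utils import apply_to_type

nested = {"a": [1.0, 2.0], "b": (3.0,)}
# Doubles every float leaf while leaving the container types untouched.
doubled = apply_to_type(nested, float, lambda v: 2 * v)
# -> {"a": [2.0, 4.0], "b": (6.0,)}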


 def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
     """Convert a tensor of indices of any shape `(N, ...)` to a
     tensor of one-hot indicators of shape `(N, num_classes, ...)` and of type uint8. Output's device is equal to the
     input's device.

     Args:
         indices: input tensor to convert.
         num_classes: number of classes for one-hot tensor.

     .. versionchanged:: 0.4.3
         This function is now torchscriptable.
     """
@@ -78,12 +95,12 @@ def setup_logger(
"""Setups logger: name, level, format etc.

Args:
name (str, optional): new name for the logger. If None, the standard logger is used.
level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG.
stream (TextIO, optional): logging stream. If None, the standard stream is used (sys.stderr).
format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`.
filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
name: new name for the logger. If None, the standard logger is used.
level: logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG.
stream: logging stream. If None, the standard stream is used (sys.stderr).
format: logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`.
filepath: Optional logging file path. If not None, logs are written to the file.
distributed_rank: Optional, rank in distributed configuration to avoid logger setup for workers.
If None, distributed_rank is initialized to the rank of process.

Returns:
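
A hedged usage sketch for setup_logger (not part of the diff; the name "trainer" and path "train.log" are arbitrary illustrative values):

from ignite.utils import setup_logger

# Returns a configured logging.Logger; filepath additionally writes records to a file.
logger = setup_logger(name="trainer", filepath="train.log")
logger.info("training started")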
@@ -156,7 +173,7 @@ def manual_seed(seed: int) -> None:
"""Setup random state from a seed for `torch`, `random` and optionally `numpy` (if can be imported).

Args:
seed (int): Random state seed
seed: Random state seed

.. versionchanged:: 0.4.3
Added ``torch.cuda.manual_seed_all(seed)``.
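
A one-line usage sketch (not part of the diff):

from ignite.utils import manual_seed

# Seeds torch (CPU and, per the note above, all CUDA devices), random, and numpy if installed.
manual_seed(42)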