Skip to content

Commit

Permalink
Engine docs (catalyst-team#1141)
Browse files Browse the repository at this point in the history
* typings & few docs

* typing fix; disabled `dist.barrier()` in optimizer step for ddp

* docs

* docs: fixed long lines with docs

* docs fixes

* optimizer args

* removed empty line

Co-authored-by: Dmytro Doroshenko <dimdoroshenko@gmail.com>
  • Loading branch information
2 people authored and zkid18 committed Jul 4, 2021
1 parent 41a7c7b commit 5389208
Show file tree
Hide file tree
Showing 4 changed files with 372 additions and 45 deletions.
57 changes: 43 additions & 14 deletions catalyst/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


@contextmanager
def nullcontext(enter_result: Any = None):
    """No-op context manager.

    Yields ``enter_result`` unchanged and performs no setup or
    cleanup — a stand-in for :func:`contextlib.nullcontext`.

    Args:
        enter_result: object bound by the ``with ... as`` clause.
    """
    yield enter_result

Expand Down Expand Up @@ -37,8 +37,7 @@ def rank(self) -> int:
@property
@abstractmethod
def world_size(self) -> int:
    """Total number of processes taking part in the run.

    Meaningful for distributed training (ddp); concrete engines
    must provide the value.
    """
    ...

@property
Expand All @@ -49,26 +48,28 @@ def is_ddp(self) -> bool:
@property
def is_master_process(self) -> bool:
    """Whether the current process is the master one.

    Distributed engines (ddp) override this check; outside of
    distributed training every process counts as the master, so
    this default implementation always answers ``True``.

    Returns:
        ``True`` if the current process is the master process,
        ``False`` otherwise.
    """
    return True

@property
def is_worker_process(self) -> bool:
    """Whether the current process is a worker one.

    Distributed engines (ddp) override this check; outside of
    distributed training no process is a worker, so this default
    implementation always answers ``False``.

    Returns:
        ``True`` if the current process is a worker process,
        ``False`` otherwise.
    """
    return False

@abstractmethod
def sync_device(self, tensor_or_module: Any) -> Any:
"""Moves ``tensor_or_module`` to Engine's deivce.
"""Moves ``tensor_or_module`` to Engine's device.
Args:
tensor_or_module: tensor to mode
Expand All @@ -89,23 +90,50 @@ def init_components(

@abstractmethod
def deinit_components(self):
    """Tear down the run's components.

    In distributed mode the implementation should also destroy
    the process group.
    """
    ...

@abstractmethod
def zero_grad(self, loss, model, optimizer) -> None:
    """Abstraction over ``model.zero_grad()`` step.

    Overload this when ``model.zero_grad()`` needs extra arguments
    (for example ``set_to_none=True``) or when a custom scheme
    should replace or improve the plain ``.zero_grad()`` call.

    Args:
        loss: tensor with loss value.
        model: model module.
        optimizer: model optimizer.
    """
    ...

@abstractmethod
def backward_loss(self, loss, model, optimizer) -> None:
    """Abstraction over ``loss.backward()`` step.

    Overload this when loss scaling is required — APEX and AMP
    engines are the typical examples.

    Args:
        loss: tensor with loss value.
        model: model module.
        optimizer: model optimizer.
    """
    ...

@abstractmethod
def optimizer_step(self, loss, model, optimizer) -> None:
    """Abstraction over ``optimizer.step()`` step.

    Overload this when gradient scaling is required — AMP is the
    typical example.

    Args:
        loss: tensor with loss value.
        model: model module.
        optimizer: model optimizer.
    """
    ...

@abstractmethod
Expand Down Expand Up @@ -174,7 +202,8 @@ def load_checkpoint(self, path: str) -> Dict:
pass

def autocast(self, *args, **kwargs):
"""AMP scaling context. Default autocast context does not scale anything.
"""AMP scaling context.
Default autocast context does not scale anything.
Args:
*args: some args
Expand Down
98 changes: 94 additions & 4 deletions catalyst/engines/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,35 @@ class AMPEngine(DeviceEngine):
Args:
device: used device, default is `"cuda"`.
Examples:
.. code-block:: python
from catalyst import dl
class MyRunner(dl.IRunner):
# ...
def get_engine(self):
return dl.AMPEngine("cuda:1")
# ...
.. code-block:: yaml
args:
logs: ...
model:
_target_: ...
...
engine:
_target_: AMPEngine
device: cuda:1
stages:
...
"""

def __init__(self, device: str = "cuda"):
Expand All @@ -36,7 +65,36 @@ def autocast(self):


class DataParallelAMPEngine(AMPEngine):
"""AMP multi-gpu training device engine."""
"""AMP multi-gpu training device engine.
Examples:
.. code-block:: python
from catalyst import dl
class MyRunner(dl.IRunner):
# ...
def get_engine(self):
return dl.DataParallelAMPEngine()
# ...
.. code-block:: yaml
args:
logs: ...
model:
_target_: ...
...
engine:
_target_: DataParallelAMPEngine
stages:
...
"""

def __init__(self):
"""Init."""
Expand Down Expand Up @@ -75,10 +133,42 @@ class DistributedDataParallelAMPEngine(DistributedDataParallelEngine):
"""Distributed AMP multi-gpu training device engine.
Args:
address: process address to use (required for PyTorch backend), default is `"localhost"`.
port: process port to listen (required for PyTorch backend), default is `"12345"`.
backend: multiprocessing backend to use, default is `"nccl"`.
address: process address to use
(required for PyTorch backend), default is `"localhost"`.
port: process port to listen
(required for PyTorch backend), default is `"12345"`.
backend: multiprocessing backend to use,
default is `"nccl"`.
world_size: number of processes.
Examples:
.. code-block:: python
from catalyst import dl
class MyRunner(dl.IRunner):
# ...
def get_engine(self):
return dl.DistributedDataParallelAMPEngine(port=12345)
# ...
.. code-block:: yaml
args:
logs: ...
model:
_target_: ...
...
engine:
_target_: DistributedDataParallelAMPEngine
port: 12345
stages:
...
"""

def __init__(
Expand Down

0 comments on commit 5389208

Please sign in to comment.