Fixing logging and errors blocking multi GPU training of Torch models (#1509)

* added fix for multi GPU as per https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#automatic-logging

* trying to add complete logging in the distributed case to avoid deadlock

* fixing the logging on epoch end for multigpu training

* Black fixes for formatting errors

* Added description of multi GPU setup to User Guide.

---------

Co-authored-by: Julien Herzen <julien@unit8.co>
Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com>
3 people committed Feb 21, 2023
1 parent 690b6f4 commit 955e2b5
Showing 2 changed files with 53 additions and 2 deletions.
17 changes: 15 additions & 2 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -150,7 +150,13 @@ def training_step(self, train_batch, batch_idx) -> torch.Tensor:
             -1
         ]  # By convention target is always the last element returned by datasets
         loss = self._compute_loss(output, target)
-        self.log("train_loss", loss, batch_size=train_batch[0].shape[0], prog_bar=True)
+        self.log(
+            "train_loss",
+            loss,
+            batch_size=train_batch[0].shape[0],
+            prog_bar=True,
+            sync_dist=True,
+        )
         self._calculate_metrics(output, target, self.train_metrics)
         return loss

@@ -159,7 +165,13 @@ def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
         output = self._produce_train_output(val_batch[:-1])
         target = val_batch[-1]
         loss = self._compute_loss(output, target)
-        self.log("val_loss", loss, batch_size=val_batch[0].shape[0], prog_bar=True)
+        self.log(
+            "val_loss",
+            loss,
+            batch_size=val_batch[0].shape[0],
+            prog_bar=True,
+            sync_dist=True,
+        )
         self._calculate_metrics(output, target, self.val_metrics)
         return loss

@@ -274,6 +286,7 @@ def _calculate_metrics(self, output, target, metrics):
             on_step=False,
             logger=True,
             prog_bar=True,
+            sync_dist=True,
         )

     def configure_optimizers(self):
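For context on the change above, here is a minimal sketch (not part of the commit, and assuming the standard PyTorch Lightning `self.log` API) of the same logging pattern in a generic `LightningModule`: under a DDP-style strategy, `sync_dist=True` reduces the logged value across all processes, so every rank reports the same `train_loss` instead of its rank-local value.

```python
# Illustrative sketch only -- assumes pytorch_lightning >= 1.6 and its standard
# self.log API; this is not Darts code, just the same logging pattern.
import torch
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)
        self.criterion = torch.nn.MSELoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.criterion(self.layer(x), y)
        # batch_size avoids Lightning's batch-size inference warning;
        # sync_dist=True averages the logged value across GPUs under DDP,
        # keeping the metric consistent on every rank.
        self.log("train_loss", loss, batch_size=x.shape[0], prog_bar=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```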
38 changes: 38 additions & 0 deletions docs/userguide/gpu_and_tpu_usage.md
@@ -120,6 +120,44 @@ Epoch 299: 100% 8/8 [00:00<00:00, 39.81it/s, loss=0.00285, v_num=logs]

From the output we can see that the GPU is both available and used. The rest of the code doesn't require any change, i.e. it's irrelevant if we are using a GPU or CPU.

### Multi GPU support

Darts relies on [Lightning's multi GPU capabilities](https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu_intermediate.html) to take advantage of scalable hardware.

Multiple parallelization strategies exist for multi GPU training. Because they handle multiprocessing and data differently, these strategies interact strongly with the execution environment.

Currently, `ddp_spawn` is the only distribution strategy tested with Darts.

As described in the [Lightning documentation](https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn), this strategy has some noteworthy limitations, e.g. it __cannot run__ in:

- Jupyter Notebook, Google COLAB, Kaggle, etc.

- a nested script without a root package

In practice this means that training has to run from a separate `.py` script, with the following general structure wrapping the code that executes the training:

```python
import torch

if __name__ == '__main__':

    torch.multiprocessing.freeze_support()
```

The `__main__` pattern is necessary (see [this](https://pytorch.org/docs/stable/notes/windows.html#multiprocessing-error-without-if-clause-protection)) even when your execution __does not__ happen in a Windows environment.

Beyond this, no other major modification to your models is necessary, other than enabling multi GPU training in `pl_trainer_kwargs`, for example:

`pl_trainer_kwargs = {"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`

This method automatically selects all available GPUs for training. Manual setting of the number of devices is also possible.
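
Putting the pieces together, a complete training script could look like the minimal sketch below. The model (`NBEATSModel`) and dataset (`AirPassengersDataset`) are illustrative assumptions, not requirements of multi GPU training:

```python
# train_multi_gpu.py -- minimal sketch; model and dataset choices are illustrative.
import numpy as np
import torch

from darts.datasets import AirPassengersDataset
from darts.models import NBEATSModel

if __name__ == "__main__":
    torch.multiprocessing.freeze_support()

    series = AirPassengersDataset().load().astype(np.float32)

    model = NBEATSModel(
        input_chunk_length=24,
        output_chunk_length=12,
        n_epochs=10,
        pl_trainer_kwargs={
            "accelerator": "gpu",
            "devices": -1,  # use all available GPUs
            "auto_select_gpus": True,
        },
    )
    model.fit(series)
```

The `__main__` guard is what allows the `ddp_spawn` strategy to create its worker processes safely.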

The `ddp` family of strategies creates an individual subprocess for each GPU, so the contents of memory (notably the `DataLoader`) get copied over to each of them. Thus, as per the [description in the Lightning docs](https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu_intermediate.html#distributed-data-parallel), caution is advised when setting `Dataloader(num_workers=N)` too high, since according to it:

"Dataloader(num_workers=N), where N is large, bottlenecks training with DDP… ie: it will be VERY slow or won’t work at all. This is a PyTorch limitation."

Other distribution strategies _might_ very well work with Darts, but they are currently untested and subject to individual setup / experimentation.

## Use a TPU

Tensor Processing Unit (TPU) is an AI accelerator application-specific integrated circuit (ASIC) developed by Google specifically for neural network machine learning.
