Conversation

vfdev-5 (Collaborator) commented Aug 3, 2019

Fixes #568

Description: I created the "distrib" branch to work on #568, adapting our code to compute metrics in a distributed configuration. The idea is to merge into this branch and test the code under various conditions. We can then improve the code iteratively by merging into this branch before merging to master.

Tests are added and pass on a single-node, 2-GPU configuration. The pytest fixture initializes an 'nccl'-backed process group for each test and runs various code paths. When all tests are executed in a single run they can sometimes get stuck, but they pass when run separately. A minimal sketch of such a fixture is shown below.
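
A minimal sketch of such a fixture, assuming one process per GPU launched e.g. via torch.distributed.launch (the fixture name is illustrative, not the exact test code):

import pytest
import torch.distributed as dist

@pytest.fixture()
def distributed_nccl_group():
    # Assumes the launcher sets MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE
    # in the environment for each process.
    dist.init_process_group("nccl", init_method="env://")
    yield
    dist.destroy_process_group()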

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

cc @zasdfgbnm, if you could take a look and give your opinion, that would be awesome! Thanks.

* Fixes issue pytorch#543

The previous confusion matrix implementation had a problem when the target contains non-contiguous indices.
The new implementation is largely taken from torchvision's https://github.com/pytorch/vision/blob/master/references/segmentation/utils.py#L75-L117

This commit also removes support for targets of shape (batch_size, num_categories, ...) where num_categories excludes the background class.
Confusion matrix computation would work almost the same way as for (batch_size, ...) targets, but when a target is all zeros (0, ..., 0), i.e. only the background class, the confusion matrix would not count any true/false predictions.
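
For context, a sketch of the bincount-based update used in the linked torchvision reference (paraphrased here, not the exact ignite code):

import torch

def update_confusion_matrix(mat, num_classes, target, prediction):
    # mat: (num_classes, num_classes) int64 tensor; target/prediction: flattened index tensors.
    valid = (target >= 0) & (target < num_classes)
    inds = num_classes * target[valid].to(torch.int64) + prediction[valid]
    mat += torch.bincount(inds, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return mat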

* Update confusion_matrix.py
zasdfgbnm (Contributor) commented:

Hi @vfdev-5,

I don't have any experience with distributed training, so I don't think I can give constructive suggestions. But I have some general comments. Ignore me if I am saying something wrong.

  1. Why should a metric worry about synchronization? Synchronization looks to me like something that should be taken care of by the engine instead.
  2. In _sync_all_reduce there is a torch.distributed.all_reduce(tensor) call, which assumes op=ReduceOp.SUM. Is this general behaviour?
  3. If I am a user who doesn't care about distributed training but wants to define custom metrics, is it still as easy as before? Especially since I now need to manually specify devices.

vfdev-5 (Collaborator, Author) commented Aug 6, 2019

@zasdfgbnm thank you for the review and your comments! I'll try to give my point of view on this.

Why should a metric worry about synchronization? Synchronization looks to me like something that should be taken care of by the engine instead.

Originally I saw the synchronization in maskrcnn-benchmark, and now in the torchvision references, where they print metrics: https://github.com/pytorch/vision/blob/c187c2b12d86c3909e59a40dbe49555d85b98703/references/classification/train.py#L69
and in other distributed examples as horovod: https://github.com/horovod/horovod/blob/a72bc96a0f87a8a7c666fff005f1afa21e95b972/examples/pytorch_imagenet_resnet50.py#L256

As far as I understand, synchronization goes together with a reduction op, so I did not think to separate them.
Another point: as the engine is not aware of the content of its handlers, it does not know a priori when to call the synchronization. By default it would make sense to call it on Events.EPOCH_COMPLETED and similar periodic custom events, but if the user calls metric.compute somewhere else, this could cause a problem.

In _sync_all_reduce there is a torch.distributed.all_reduce(tensor) call, which assumes op=ReduceOp.SUM. Is this general behaviour?

Yes, you are right. Most of our metrics are computed by accumulating quantities such as a sum and a number of samples, and in these cases ReduceOp.SUM works well for us. For metrics like EpochMetric, or non-averaged multilabel Precision or Recall, which store the whole history of predictions, ReduceOp.SUM won't give the correct answer; there we would need a gather operation (not implemented, and it may raise memory issues). The decorator on compute uses the sum reduction by default, and for those particular metrics we/the user need to handle things case by case. The sketch below illustrates the two situations.
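
A minimal sketch of the two situations, assuming an already initialized process group (tensor names are illustrative):

import torch
import torch.distributed as dist

# Counter-like state: summing the per-rank pieces gives the global result.
local_sum = torch.tensor([12.5], device="cuda")
local_num_examples = torch.tensor([32.0], device="cuda")
dist.all_reduce(local_sum)            # op defaults to ReduceOp.SUM
dist.all_reduce(local_num_examples)
global_average = local_sum / local_num_examples

# History-like state (e.g. EpochMetric): summing per-rank predictions is meaningless,
# the tensors would have to be gathered instead.
local_predictions = torch.rand(100, device="cuda")
gathered = [torch.empty_like(local_predictions) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, local_predictions)
all_predictions = torch.cat(gathered)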

If I am a user that doesn't care about distributed training, but I want to define my custom metrics. Is it still as easy as before? Especially, now I need to manually specify devices.

I tried to answer this question here: https://github.com/pytorch/ignite/blob/99a6b4a515acacc84bb7438f0e6afebcb6378d70/docs/source/metrics.rst#how-to-create-a-custom-metric

After rereading this, I think we need to add more info about how to use the distributed decorators and the device argument, as you asked.

In the case of non-distributed computation, I would say everything remains the same (even without the decorators).

Thank you again for the comments, and please let me know what you think.

vfdev-5 (Collaborator, Author) commented Aug 6, 2019

I changed the behaviour of device in the distributed case: previously it raised an error if nothing was specified; now it silently defaults to "cuda", roughly as sketched below.
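
Roughly, the new behaviour amounts to something like this (illustrative helper, not the exact ignite code):

import torch.distributed as dist

def resolve_metric_device(device=None):
    # In a distributed run, silently default to "cuda" instead of raising an error.
    if device is None and dist.is_available() and dist.is_initialized():
        return "cuda"
    return device or "cpu"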

Another remark: in RunningAverage we do not collect the result across devices, so when it is printed on rank=0 it shows the value for that rank only. Maybe we would like to collect the values, but as it is computed every iteration, this might add overhead...

zasdfgbnm (Contributor) left a comment:

Hi @vfdev-5, the change is big and I have just started looking at it. I am a little busy recently, so I won't finish very quickly, but I will keep reading and discussing. I do have some general suggestions besides my few comments on the code. In general, do you think it would be better to create an ignite/metrics/distributed.py and move the decorators and _sync_all_reduce into that file, so that the Metric class itself does not change? Also, if the methods that need to be decorated are the same, why not create a single decorator, so that the user could use it like:

from ignite.metrics.dist import all_reduced_synchronized

@all_reduced_synchronized('myvar1', 'myvar2')
class MyMetric:
    def update(self, output): ...
    def compute(self): ...
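
A minimal sketch of how such a class decorator might be implemented (the name all_reduced_synchronized comes from the suggestion above; the body is only illustrative):

import torch
import torch.distributed as dist

def all_reduced_synchronized(*attrs):
    # Class decorator: wrap Metric.compute so that the listed attributes are
    # all-reduced (summed) across processes before the original compute runs.
    def decorator(cls):
        original_compute = cls.compute

        def compute(self):
            if dist.is_available() and dist.is_initialized():
                for name in attrs:
                    tensor = getattr(self, name)
                    if isinstance(tensor, torch.Tensor):
                        dist.all_reduce(tensor)  # defaults to ReduceOp.SUM
            return original_compute(self)

        cls.compute = compute
        return cls

    return decorator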

if isinstance(tensor, torch.Tensor):
    # check if the tensor is on the specified device
    if tensor.device != self._device:
        tensor = tensor.to(self._device)
zasdfgbnm (Contributor):

I don't know about distributed training, but what would happen if we don't do so? This is only called in self.compute, correct?

vfdev-5 (Collaborator, Author):

As far as I understand, if we have multiple GPUs and call an op like reduce, then internally it collects data from all GPUs. If the tensor is not defined on one of the GPUs, the operation would hang...

So, in update we do not force all internal variables to be on the associated device (rank), but in compute it is necessary.

return wrapper


def reinit_is_reduced(func):
zasdfgbnm (Contributor):

The name reinit_is_reduced might look confusing to users. Without looking at the source code of Metric, it is hard to know what is_reduced is and why it needs to be reinitialized. BTW: I only see code that writes this variable, but nothing reading it. Am I correct?

vfdev-5 (Collaborator, Author) commented Aug 6, 2019:

This is really an ignite-internal method and should not be used by users except to decorate custom metrics. Its purpose is to ensure that compute always returns the same output.

"""
pass

def _sync_all_reduce(self, tensor):
zasdfgbnm (Contributor):

Could we define this inside the body of sync_all_reduce, instead of here?

vfdev-5 (Collaborator, Author) commented Aug 6, 2019

@zasdfgbnm thank you for the review !

In general, do you think it would be better to create an ignite/metrics/distributed.py and move the decorators and _sync_all_reduce into that file, so that the Metric class itself does not change?

I thought about this in the beginning, but as ignite/metrics/distributed.py would be a sort of helper module, I thought it could be confusing for users. Since we already have the metric.py file, which contains only the abstract class, I thought we could put the other pieces there.

Also, if the methods that need to be decorated are the same, why not create a single decorator, so that the user could use it like...

Yes, this could be helpful for the majority of metrics, except for some corner cases such as Precision, Recall, etc., where we need to use reinit_is_reduced and not sync_all_reduce (or use it only in some cases). Actually, I have just found out that I forgot to set up _is_reduced properly in these classes.

The purpose of the _is_reduced flag is to keep the compute method behaving the same in the distributed case as in the non-distributed one:
a) not dist

r1 = m.compute()
r2 = m.compute()
assert r1 == r2

b) dist without reinit_is_reduced

r1 = m.compute()
r2 = m.compute()
assert r1 != r2

This is because all_reduce would collect the values again on every compute call; the flag guards against that, as sketched below.
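
A minimal sketch of that guard, with illustrative attribute names (not the exact ignite implementation):

import torch
import torch.distributed as dist

class MyAverage:
    def __init__(self):
        self._sum = torch.tensor(0.0, device="cuda")
        self._num_examples = torch.tensor(0.0, device="cuda")
        self._is_reduced = False

    def update(self, value, n):
        self._sum += value * n
        self._num_examples += n
        self._is_reduced = False  # new data invalidates the previously reduced state

    def compute(self):
        # All-reduce only once per accumulation round, so repeated compute()
        # calls return the same value in the distributed case as well.
        if dist.is_available() and dist.is_initialized() and not self._is_reduced:
            dist.all_reduce(self._sum)
            dist.all_reduce(self._num_examples)
            self._is_reduced = True
        return (self._sum / self._num_examples).item()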

* Added mlflow logger without tests

* Added mlflow tests, updated mlflow logger code and other tests

* Updated docs and added mlflow in travis

* Added tests for mlflow OptimizerParamsHandler
- additionally added OptimizerParamsHandler for plx with tests
vfdev-5 and others added 6 commits August 15, 2019 19:40
* Update .travis.yml

* Update .travis.yml

* Fixed tests and improved travis
* Update .travis.yml

* Update .travis.yml

* Fixed tests and improved travis

* Fixes SSL problem to download model weights
* Add tests for event removable handle.

Add feature tests for engine.add_event_handler returning removable event
handles.

* Return RemovableEventHandle from Engine.add_event_handler.

* Fixup removable event handle test in python 2.7.

Explicitly trigger gc, allowing cycle detection between engine and
state, in removable handle weakref test. Python 2.7 cycle detection
appears to be less aggressive than python 3+.

* Add removable event handler docs.

Add autodoc configuration for RemovableEventHandler, expand "concepts"
documentation with event remove example following event add example.

* Update concepts.rst
vfdev-5 removed the request for review from anmolsjoshi August 29, 2019 22:24
vfdev-5 merged commit 2036075 into pytorch:distrib Aug 30, 2019
vfdev-5 added a commit that referenced this pull request Oct 24, 2019
* [WIP] Added cifar10 distributed example

* [WIP] Metric with all reduce decorator and tests

* [WIP] Added tests for accumulation metric

* [WIP] Updated with reinit_is_reduced

* [WIP] Distrib adaptation for other metrics

* [WIP] Warnings for EpochMetric and Precision/Recall when distrib

* Updated metrics and tests to run on distributed configuration
- Test on 2 GPUS single node
- Added cmd in .travis.yml to indicate how to test locally
- Updated travis to run tests in 4 processes

* Minor fixes and cosmetics

* Fixed bugs and improved contrib/cifar10 example

* Updated docs

* Update metrics.rst

* Updated docs and set device as "cuda" in distributed instead of raising error

* [WIP] Fix missing _is_reduced in precision/recall with tests

* Updated other tests

* Updated travis and renamed tbptt test gpu -> cuda

* Distrib (#573)

* [WIP] Added cifar10 distributed example

* [WIP] Metric with all reduce decorator and tests

* [WIP] Added tests for accumulation metric

* [WIP] Updated with reinit_is_reduced

* [WIP] Distrib adaptation for other metrics

* [WIP] Warnings for EpochMetric and Precision/Recall when distrib

* Updated metrics and tests to run on distributed configuration
- Test on 2 GPUS single node
- Added cmd in .travis.yml to indicate how to test locally
- Updated travis to run tests in 4 processes

* Minor fixes and cosmetics

* Fixed bugs and improved contrib/cifar10 example

* Updated docs

* Fixes issue #543 (#572)

* Fixes issue #543

The previous confusion matrix implementation had a problem when the target contains non-contiguous indices.
The new implementation is largely taken from torchvision's https://github.com/pytorch/vision/blob/master/references/segmentation/utils.py#L75-L117

This commit also removes support for targets of shape (batch_size, num_categories, ...) where num_categories excludes the background class.
Confusion matrix computation would work almost the same way as for (batch_size, ...) targets, but when a target is all zeros (0, ..., 0), i.e. only the background class, the confusion matrix would not count any true/false predictions.

* Update confusion_matrix.py

* Update metrics.rst

* Updated docs and set device as "cuda" in distributed instead of raising error

* [WIP] Fix missing _is_reduced in precision/recall with tests

* Updated other tests

* Added mlflow logger (#558)

* Added mlflow logger without tests

* Added mlflow tests, updated mlflow logger code and other tests

* Updated docs and added mlflow in travis

* Added tests for mlflow OptimizerParamsHandler
- additionally added OptimizerParamsHandler for plx with tests

* Update to PyTorch v1.2.0 (#580)

* Update .travis.yml

* Update .travis.yml

* Fixed tests and improved travis

* Fix SSL problem of failing travis (#581)

* Update .travis.yml

* Update .travis.yml

* Fixed tests and improved travis

* Fixes SSL problem to download model weights

* Fixed travis for deploy and nightly

* Fixes #583 (#584)

* Fixes docs build warnings (#585)

* Return removable handle from Engine.add_event_handler(). (#588)

* Add tests for event removable handle.

Add feature tests for engine.add_event_handler returning removable event
handles.

* Return RemovableEventHandle from Engine.add_event_handler.

* Fixup removable event handle test in python 2.7.

Explicitly trigger gc, allowing cycle detection between engine and
state, in removable handle weakref test. Python 2.7 cycle detection
appears to be less aggressive than python 3+.

* Add removable event handler docs.

Add autodoc configuration for RemovableEventHandler, expand "concepts"
documentation with event remove example following event add example.

* Update concepts.rst

* Updated travis and renamed tbptt test gpu -> cuda

* Compute IoU, Precision, Recall based on CM on CPU

* Fixes incomplete merge with 1856c8e

* Update distrib branch and CIFAR10 example (#647)

* Added tests with gloo, minor updates and fixes

* Added single/multi node tests with gloo and [WIP] with nccl

* Added tests for multi-node nccl, improved examples/contrib/cifar10 example

* Experiments: 1n1gpu, 1n2gpus, 2n2gpus

* Fix flake8

* Fixes #645 (#646)

- fix CI and improve create_lr_scheduler_with_warmup

* Fix tests for python 2.7

* Finalized Cifar10 example (#649)

* Added gcp tb logger image and updated README

* Added gcp ai platform scripts to run trainings

* Improved docs and readmes