
Does the test accuracy need to be synchronized in distributed.py? #1

Closed
yifanjiang19 opened this issue Dec 25, 2019 · 7 comments

@yifanjiang19

If I directly output the test accuracy, will the code automatically synchronize the accuracy across the GPUs?

tczhangzhi added the question label Dec 25, 2019
tczhangzhi (Owner) commented Dec 25, 2019

Nope. If you really need it, you can use .share_memory_() to share a tensor's memory.
In general, most distributed libraries only help you handle the synchronization of data, parameters, and gradients.

yifanjiang19 (Author) commented Dec 26, 2019

Could you give a specific example?
Thanks!

@tczhangzhi (Owner)

Hmm, I'm afraid that's not quite right.
Here are two ways to communicate between torch.multiprocessing processes:

  1. If you don't care about the exact running results, you can use share_memory_() like this, which is faster:
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def evaluate(rank):
    # Stand-in for a real evaluation loop: just produce a per-GPU "accuracy".
    torch.manual_seed(rank)
    local_acc = torch.randn(1)[0].cuda(rank)

    print("local_acc:", local_acc)

    return local_acc

def main_worker(gpu, ngpus_per_node, args):
    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:23456', world_size=4, rank=gpu)

    local_acc = evaluate(gpu)

    # These tensors live in shared memory, so every process sees the same storage.
    global_acc, global_count = args['global_acc'], args['global_count']

    # Note: these in-place updates are not synchronized across processes,
    # so the result is only approximate.
    global_acc += local_acc.cpu()
    global_count += 1

    print("global_acc:", global_acc / global_count)

if __name__ == '__main__':
    global_acc = torch.tensor(.0)
    global_count = torch.tensor(.0)

    # Move the accumulators into shared memory before spawning the workers.
    global_acc.share_memory_()
    global_count.share_memory_()

    args = {
        'global_acc': global_acc,
        'global_count': global_count
    }

    # mp.spawn passes the process index as the first argument (gpu),
    # followed by the extra args: here ngpus_per_node=4 and the shared dict.
    mp.spawn(main_worker, nprocs=4, args=(4, args))

  2. But if you really need to synchronize the accuracy, I suggest an implementation like the following, or something else based on all_reduce:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def evaluate(rank):
    # Stand-in for a real evaluation loop: just produce a per-GPU "accuracy".
    torch.manual_seed(rank)
    local_acc = torch.randn(1)[0].cuda(rank)

    print("local_acc:", local_acc)

    return local_acc

def main_worker(gpu, ngpus_per_node, args):
    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:23456', world_size=4, rank=gpu)

    local_acc = evaluate(gpu)

    # Sum the per-GPU accuracies across all processes (in place, on every rank),
    # then divide by the number of GPUs to get the average.
    dist.all_reduce(local_acc, op=dist.ReduceOp.SUM)
    global_acc = local_acc / ngpus_per_node

    print("global:", global_acc)

if __name__ == '__main__':
    args = {}
    mp.spawn(main_worker, nprocs=4, args=(4, args))

@tczhangzhi (Owner)

I'm not sure if that was clear; if not, you can directly use this code:

acc1, acc5 = accuracy(output, target, topk=(1, 5))
...
# Sum this rank's top-1 accuracy across all processes, then divide by the
# world size (4 GPUs here) before feeding it to the meter.
dist.all_reduce(acc1, op=dist.ReduceOp.SUM)
...
top1.update(acc1[0] / 4, images.size(0))

By the way, I don't think we really need to calculate the averaged accuracy during training; it is a waste of time.
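
For context, here is a minimal self-contained sketch of the same idea inside an evaluation routine. The names model, loader, and gpu are assumed inputs (not from this repo), top-1 accuracy is computed with a plain argmax instead of the accuracy helper above, and world_size=4 matches the spawn calls in the earlier examples:

import torch
import torch.distributed as dist


def evaluate(model, loader, gpu, world_size=4):
    # Assumes dist.init_process_group(...) has already been called, as above,
    # and that each process gets its own shard of the validation data.
    model.eval()
    correct = torch.tensor(0.0).cuda(gpu)
    total = torch.tensor(0.0).cuda(gpu)
    with torch.no_grad():
        for images, target in loader:
            images, target = images.cuda(gpu), target.cuda(gpu)
            pred = model(images).argmax(dim=1)
            correct += (pred == target).float().sum()
            total += target.numel()
    local_acc = 100.0 * correct / total               # this GPU's top-1 accuracy
    dist.all_reduce(local_acc, op=dist.ReduceOp.SUM)  # sum over all processes
    return local_acc / world_size                     # average across GPUs

Note that dividing the summed accuracies by the world size gives the exact average only when every GPU evaluates the same number of samples.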

@yifanjiang19 (Author)

Thanks!
Should the code synchronize the loss between the GPUs before loss.backward(), or will the backward pass synchronize it automatically?

@tczhangzhi (Owner)

No. That's DistributedDataParallel's job.
Wrap your model with DistributedDataParallel and just call backward() as usual. During the backward pass, gradients from each process are averaged (which is what you mean by "synchronizing the loss"), so the parameters stay synchronized automatically.
Check it here: https://github.com/pytorch/pytorch/blob/46539eee0363e25ce5eb408c85cefd808cd6f878/torch/nn/parallel/distributed.py#L378-L382
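
For reference, a minimal sketch of that pattern (illustrative only, not code from this repo, assuming four GPUs as in the examples above): wrap the model in DistributedDataParallel and call backward() as usual, and the gradients are all-reduced for you.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn


def main_worker(gpu, ngpus_per_node, args):
    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:23456', world_size=ngpus_per_node, rank=gpu)
    torch.cuda.set_device(gpu)

    # Wrapping the model registers hooks that all-reduce (average) the gradients
    # across processes during backward(); no manual loss synchronization is needed.
    model = nn.Linear(10, 1).cuda(gpu)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    criterion = nn.MSELoss().cuda(gpu)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    images = torch.randn(8, 10).cuda(gpu)   # each rank trains on its own local batch
    target = torch.randn(8, 1).cuda(gpu)

    loss = criterion(model(images), target)  # local, unsynchronized loss
    optimizer.zero_grad()
    loss.backward()                          # gradients are averaged across ranks here
    optimizer.step()                         # every rank applies the same update

    print("rank", gpu, "loss:", loss.item())

if __name__ == '__main__':
    mp.spawn(main_worker, nprocs=4, args=(4, {}))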

tczhangzhi self-assigned this Dec 27, 2019
@yifanjiang19 (Author)

thanks
