cfg.DIST_ENABLE false fail #36
Open
bhack opened this issue Mar 3, 2023 · 21 comments

Comments

bhack commented Mar 3, 2023

With cfg.DIST_ENABLE set to false, not all of the distributed-specific parts are wrapped in a condition that checks whether distributed training is enabled.

z-x-yang (Collaborator) commented Mar 3, 2023

self.DIST_ENABLE is necessary for multi-GPU training.

bhack (Author) commented Mar 3, 2023

I meant that I am trying to test single-GPU training.
In many places cfg.DIST_ENABLE is checked to safely take the non-distributed code path.

But in many other places it is not:

self.train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset)
self.train_loader = DataLoader(train_dataset,
                               batch_size=int(cfg.TRAIN_BATCH_SIZE /
                                              cfg.TRAIN_GPUS),
                               shuffle=False,
                               num_workers=cfg.DATA_WORKERS,
                               pin_memory=True,
                               sampler=self.train_sampler,
                               drop_last=True,
                               prefetch_factor=4)
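
Something like the following would cover the single-GPU case (just a sketch on my side, reusing the names from the snippet above; the exact trainer structure is an assumption):

# Sketch: only build a DistributedSampler when distributed training is
# enabled; otherwise let the DataLoader shuffle on its own.
if cfg.DIST_ENABLE:
    self.train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)
else:
    self.train_sampler = None
self.train_loader = DataLoader(train_dataset,
                               batch_size=int(cfg.TRAIN_BATCH_SIZE /
                                              cfg.TRAIN_GPUS),
                               shuffle=(self.train_sampler is None),
                               num_workers=cfg.DATA_WORKERS,
                               pin_memory=True,
                               sampler=self.train_sampler,
                               drop_last=True,
                               prefetch_factor=4)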

bhack (Author) commented Mar 7, 2023

E.g. here, by contrast, the code does check cfg.DIST_ENABLE:

if cfg.DIST_ENABLE:
    dist.init_process_group(backend=cfg.DIST_BACKEND,
                            init_method=cfg.DIST_URL,
                            world_size=cfg.TRAIN_GPUS,
                            rank=rank,
                            timeout=datetime.timedelta(seconds=300))
    self.model.encoder = nn.SyncBatchNorm.convert_sync_batchnorm(
        self.model.encoder).cuda(self.gpu)
    self.dist_engine = torch.nn.parallel.DistributedDataParallel(
        self.engine,
        device_ids=[self.gpu],
        output_device=self.gpu,
        find_unused_parameters=True,
        broadcast_buffers=False)
else:
    self.dist_engine = self.engine

self.use_frozen_bn = False
if 'swin' in cfg.MODEL_ENCODER:
    self.print_log('Use LN in Encoder!')
elif not cfg.MODEL_FREEZE_BN:
    if cfg.DIST_ENABLE:
        self.print_log('Use Sync BN in Encoder!')

bhack changed the title from "self.DIST_ENABLE false fail" to "cfg.DIST_ENABLE false fail" on Mar 7, 2023
z-x-yang (Collaborator) commented Mar 7, 2023

The distributed sampler is useless and meaningless for evaluation, where the GPUs run asynchronously rather than synchronously: the video lengths always differ across GPUs.

bhack (Author) commented Mar 7, 2023

Isn't this in the trainer?
I am trying to train a single-GPU job with cfg.DIST_ENABLE=False.

z-x-yang (Collaborator) commented Mar 7, 2023 via email

bhack (Author) commented Mar 7, 2023

Yes, that is what I meant. Aren't we going to have issues if we don't conditionally wrap torch.nn.parallel.DistributedDataParallel in the trainer?

z-x-yang (Collaborator) commented Mar 7, 2023 via email

bhack (Author) commented Mar 7, 2023

Isn't that going to require init_process_group? But that call is conditionally wrapped in the trainer:

if cfg.DIST_ENABLE:
    dist.init_process_group(backend=cfg.DIST_BACKEND,
                            init_method=cfg.DIST_URL,
                            world_size=cfg.TRAIN_GPUS,
                            rank=rank,
                            timeout=datetime.timedelta(seconds=300))

z-x-yang (Collaborator) commented Mar 7, 2023 via email

bhack (Author) commented Mar 7, 2023

E.g. with self.DIST_ENABLE = False in configs/default.py we fail right away at:

self.train_sampler = torch.utils.data.distributed.DistributedSampler(

RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

Cause:
#36 (comment)

This is because the cfg.DIST_ENABLE safeguards and the corresponding non-distributed code path do not yet cover everything.
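
For reference, the error is easy to reproduce outside the trainer, since DistributedSampler (like DistributedDataParallel) needs the default process group to be initialized. A minimal standalone sketch (not code from the repo, just an illustration):

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.zeros(4, 1))

# Without dist.init_process_group(...) there is no default process group,
# so constructing the sampler raises the same RuntimeError as above.
try:
    sampler = DistributedSampler(dataset)
except RuntimeError as e:
    print(e)  # Default process group has not been initialized ...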

bhack (Author) commented Mar 8, 2023

> Sorry, I don't understand what you mean.

Is it clear now?

z-x-yang (Collaborator) commented Mar 8, 2023

I have updated trainer.py, and it should be ok to set self.DIST_ENABLE = False and train with a single GPU.

bhack (Author) commented Mar 8, 2023

Thanks, I've checked your changes and they are the same ones I had made locally on my side over the past few days.

What do you think now about this (pytorch/pytorch#37444)?

# Use torch.multiprocessing.spawn to launch distributed processes
mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg, args.amp))
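
One option for the single-GPU case (just a sketch, assuming main_worker keeps the signature main_worker(rank, cfg, enable_amp) implied by the spawn call above) would be to skip the spawn entirely:

if cfg.DIST_ENABLE:
    # Use torch.multiprocessing.spawn to launch distributed processes
    mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg, args.amp))
else:
    # Single GPU: run the worker directly in the current process and
    # avoid the spawn/join machinery entirely.
    main_worker(0, cfg, args.amp)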

bhack (Author) commented Mar 8, 2023

Besides my previous comment, I think we still have an issue with the board keys when DIST_ENABLE=False:

for key in boards['image'].keys():
    tmp = boards['image'][key].cpu().numpy()
    self.tblogger.add_image('S{}/' + key, tmp, step)
for key in boards['scalar'].keys():
    tmp = boards['scalar'][key].cpu().numpy()
    self.tblogger.add_scalar('S{}/' + key, tmp, step)

    for key in boards['image'].keys():
AttributeError: 'list' object has no attribute 'keys'
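
A possible defensive tweak would be to tolerate both shapes (just a sketch, under the assumption that in the non-distributed path boards['image'] is simply a list of the per-step tensors):

# Sketch: accept boards['image'] both as a dict (key -> tensor) and as a
# plain list of tensors, which seems to be what the single-GPU path yields.
images = boards['image']
items = images.items() if isinstance(images, dict) else enumerate(images)
for key, value in items:
    self.tblogger.add_image('S{}/' + str(key), value.cpu().numpy(), step)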

z-x-yang (Collaborator) commented Mar 8, 2023

It should be ok now.

bhack (Author) commented Mar 8, 2023

It seems that we have two issues:

  • The first is that the trainer seems to get "randomly" deadlocked on different runs of the same code.
  • The images in img_logs no longer look correct: they are only the binary masks, whereas they used to be "composited", if I remember correctly.

bhack (Author) commented Mar 8, 2023

For the first issue, this is the stack trace of one of the deadlocks, and it could be related to #36 (comment):

  File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 109, in join
    ready = multiprocessing.connection.wait(
  File "/opt/conda/lib/python3.10/multiprocessing/connection.py", line 936, in wait
    ready = selector.select(timeout)
  File "/opt/conda/lib/python3.10/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)

z-x-yang (Collaborator) commented Mar 8, 2023

I guess the first problem is due to torch.spawn, and I have modified the code related to it. Please give it a try; I hope this will work for you.

As for the second issue, img_logs should be images where target objects are marked with colorful masks, referring to here.

bhack (Author) commented Mar 8, 2023

> I guess the first problem is due to torch.spawn, and I have modified the code related to it. Please give it a try; I hope this will work for you.

Yes, it seems ok now.

> As for the second issue, img_logs should be images where target objects are marked with colorful masks, referring to here.

Yes, sorry, this was a false positive related to the image range in the compositing phase. Thanks for double-checking.
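
For anyone hitting the same thing, the compositing itself is simple once the range is right (a generic sketch, assuming float images in [0, 1] and a binary mask; not the repo's implementation):

import numpy as np

def composite(image, mask, color=(255, 0, 0), alpha=0.5):
    # image: HxWx3 float array in [0, 1]; mask: HxW boolean/0-1 array.
    # Convert to uint8 [0, 255] before blending: a range mismatch here is
    # what can make the logged images look like bare masks.
    img = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    overlay = img.copy()
    overlay[mask.astype(bool)] = color
    return (alpha * overlay + (1 - alpha) * img).astype(np.uint8)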

bhack closed this as completed on Mar 8, 2023
bhack (Author) commented Dec 7, 2023

The same needs to be fixed in the PAOT branch.

bhack reopened this on Dec 7, 2023