Profiler occupies unassigned GPU memory with DDP #37313

Closed

Godricly opened this issue Apr 26, 2020 · 4 comments

Labels
module: regression (It used to work, and now it doesn't) · oncall: profiler (profiler-related issues: cpu, gpu, kineto) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@Godricly commented Apr 26, 2020

🐛 Bug

The profiler allocates extra memory on GPUs that the process is not supposed to use.

To Reproduce

Steps to reproduce the behavior:

  • main.py
import logging
import torch
import torchvision
import comm
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel


def main():

    model = torchvision.models.resnet50(pretrained=False) # skip model download
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(comm.get_local_rank())
    model = DistributedDataParallel(
        model,
        device_ids=[comm.get_local_rank()],
        broadcast_buffers=False,
        find_unused_parameters=True,
    )
           
    # enabled=True, use_cuda=True, record_shapes=True; this profiler object is created but never entered
    profiler = torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=True)
    comm.synchronize()

    for i in range(100000):
        print(i)
        dummy_input = torch.zeros(4, 3, 400, 400)
        # profile a handful of iterations with CUDA profiling enabled
        if 10 < i < 20:
            with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=True) as prof:
                out = model(dummy_input)
        else:
            out = model(dummy_input)


    print("rank {}: finish".format(comm.get_local_rank()))

def _find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port

def _distributed_worker(
    local_rank, main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args
):
    assert torch.cuda.is_available(), "cuda is not available. Please check your installation."
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="NCCL", init_method=dist_url, world_size=world_size, rank=global_rank
        )
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        raise e
    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # Setup the local process group (which contains ranks within the same machine)
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg

    main_func(*args)

if __name__ == "__main__":
    port = _find_free_port()
    dist_url = f"tcp://127.0.0.1:{port}"

    mp.spawn(
        _distributed_worker,
        nprocs=2,
        # (main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args)
        args=(main, 2, 2, 0, dist_url, ()),
        daemon=False,
    )

  • comm.py
import torch.distributed as dist


_LOCAL_PROCESS_GROUP = None


# distributed helpers adapted from maskrcnn-benchmark
def get_local_rank() -> int:
    """The rank of the current process within the local machine.

    Returns:
        int: The rank of the current process within the local (per-machine) process group.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)

def synchronize():
    """Helper function to synchronize (barrier) all processes when
    using distributed training.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()

  1. Download the code above (main.py and comm.py).
  2. Run python main.py and watch per-process GPU memory, as shown below.
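
The per-process GPU memory can be watched with nvidia-smi while the script runs, or dumped programmatically with a small sketch like the one below (an assumption: it relies on the pynvml package from nvidia-ml-py, which is not part of the reproduction itself):

import pynvml

# Print every compute process and its memory usage on every visible GPU,
# to check whether a process holds memory on GPUs it was never assigned.
pynvml.nvmlInit()
for i in range(pynvml.nvmlDeviceGetCount()):
    handle = pynvml.nvmlDeviceGetHandleByIndex(i)
    for proc in pynvml.nvmlDeviceGetComputeRunningProcesses(handle):
        used_mib = (proc.usedGpuMemory or 0) // (1024 * 1024)
        print("GPU {}: pid {} uses {} MiB".format(i, proc.pid, used_mib))
pynvml.nvmlShutdown()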

Expected behavior

Each process should only allocate memory on the GPU it was assigned to. Instead, nvidia-smi shows each of the two spawned processes also holding about 609 MiB on every other GPU:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.40       Driver Version: 430.40       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 108...  Off  | 00000000:04:00.0 Off |                  N/A |
| 29%   29C    P2    55W / 250W |   3662MiB / 11178MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX 108...  Off  | 00000000:05:00.0 Off |                  N/A |
| 29%   30C    P2    54W / 250W |   3642MiB / 11178MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  GeForce GTX 108...  Off  | 00000000:08:00.0 Off |                  N/A |
| 29%   29C    P2    53W / 250W |   1228MiB / 11178MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   3  GeForce GTX 108...  Off  | 00000000:09:00.0 Off |                  N/A |
| 29%   29C    P2    53W / 250W |   1228MiB / 11178MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     43400      C   /home/yan_li/python3env/bin/python          3043MiB |
|    0     43401      C   /home/yan_li/python3env/bin/python           609MiB |
|    1     43400      C   /home/yan_li/python3env/bin/python           609MiB |
|    1     43401      C   /home/yan_li/python3env/bin/python          3023MiB |
|    2     43400      C   /home/yan_li/python3env/bin/python           609MiB |
|    2     43401      C   /home/yan_li/python3env/bin/python           609MiB |
|    3     43400      C   /home/yan_li/python3env/bin/python           609MiB |
|    3     43401      C   /home/yan_li/python3env/bin/python           609MiB |
+-----------------------------------------------------------------------------+
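
To narrow the extra allocation down to CUDA profiling, one possible check (a sketch, not part of the original report) is to run the same profiled branch of main.py with use_cuda=False and compare the nvidia-smi output:

# Hypothetical diagnostic variant of the profiled branch in main.py:
# keep the profiler enabled but skip CUDA event collection.
with torch.autograd.profiler.profile(enabled=True, use_cuda=False, record_shapes=True) as prof:
    out = model(dummy_input)

If the ~609 MiB on each unassigned GPU disappears, the allocation is tied to CUDA event recording rather than to DDP itself.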

Environment

Collecting environment information...
PyTorch version: 1.6.0.dev20200424+cu101
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: 
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
GPU 2: GeForce GTX 1080 Ti
GPU 3: GeForce GTX 1080 Ti

Nvidia driver version: 430.40
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.18.0
[pip3] torch==1.6.0.dev20200424+cu101
[pip3] torchvision==0.7.0.dev20200424+cu101
[conda] Could not collect

Additional context

Some of the code is shared with #36136.

@Godricly (Author)

PyTorch 1.4 works fine, FYI.

@izdeby added the oncall: profiler (profiler-related issues: cpu, gpu, kineto), module: regression (It used to work, and now it doesn't), and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels on Apr 26, 2020
@xwang233 (Collaborator)

Now that the profiler event fix #39962 is merged, can you verify whether a nightly release fixes the issue?
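
For reference, the exact build under test can be confirmed from Python before re-running the reproduction (a generic version check, nothing specific to this issue):

import torch

# The build string encodes the nightly date, e.g. 1.6.0.dev20200625+cu101.
print(torch.__version__, torch.version.cuda)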

@Godricly (Author) commented Jul 1, 2020

It's not fixed for me. I used the 20200625 nightly build.

@Quentin-Anthony

Same issue with PyTorch 1.8.0a0+810a1b9.

Are there any current plans to resolve this?
