
Correct signatures for torch allocator plug in #1407

Merged
rapids-bot merged 5 commits into rapidsai:branch-24.02 from wence/fix/1405 on Dec 14, 2023

Conversation

wence- (Contributor) commented Dec 12, 2023

Description

Since pytorch/pytorch#91398, the signature of the pluggable allocate and deallocate functions must accept the device id. The current version only accepts a device id for allocate, which means that when using a stream-ordered allocator with devices other than device zero, we pass an invalid stream into the deallocation function (the device id ends up in the parameter we treat as the stream). To fix this, adapt the signature to match the one PyTorch expects.

Now, since we have the device available during allocation and deallocation, we would like to use that device to obtain the appropriate memory resource.

Unfortunately, since RMM's cuda_device_id does not have a nullary constructor, we can't use it in Cython without some hacky workarounds.

However, since we don't actually need to build a Python module, but rather just a single shared library that offers two extern "C" functions, let's just write our allocator hooks directly in C++.
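
For illustration, a minimal sketch of what two such hooks look like with the corrected signatures (assuming rmm::mr::get_per_device_resource and rmm::cuda_stream_view; not necessarily the exact contents of the merged _torch_allocator.cpp):

#include <cstddef>
#include <cuda_runtime_api.h>
#include <rmm/cuda_device.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

// Both hooks now receive the device id, matching what PyTorch passes
// since pytorch/pytorch#91398.
extern "C" void* allocate(std::size_t size, int device, void* stream)
{
  auto* mr = rmm::mr::get_per_device_resource(rmm::cuda_device_id{device});
  return mr->allocate(size, rmm::cuda_stream_view{static_cast<cudaStream_t>(stream)});
}

extern "C" void deallocate(void* ptr, std::size_t size, int device, void* stream)
{
  auto* mr = rmm::mr::get_per_device_resource(rmm::cuda_device_id{device});
  mr->deallocate(ptr, size, rmm::cuda_stream_view{static_cast<cudaStream_t>(stream)});
}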

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

@github-actions github-actions bot added the Python Related to RMM Python API label Dec 12, 2023
@github-actions github-actions bot added the CMake label Dec 12, 2023
@wence- wence- added bug Something isn't working non-breaking Non-breaking change labels Dec 12, 2023
@wence- wence- marked this pull request as ready for review December 12, 2023 18:11
@wence- wence- requested review from a team as code owners December 12, 2023 18:12
@wence- wence- changed the title from "WIP: Correct signatures for torch allocator plug in" to "Correct signatures for torch allocator plug in" on Dec 12, 2023
Closes rapidsai#1405
wence- (Contributor, Author) commented Dec 12, 2023

@shwina I think there was some discussion when that pytorch PR went through about whether the device should be in the signature at all. Here I am explicitly honouring it, because I think that is the right thing to do.

harrism (Member) commented Dec 13, 2023

Unfortunately, since RMM's cuda_device_id does not have a nullary constructor, we can't use it in Cython without some hacky workarounds.

Why do you need a nullary ctor? It's probably OK to just add {} to the id_ member of that struct if needed.

However, since the device is being passed to the function, it may be better to use rmm::cuda_set_device_raii. This will call rmm::get_current_cuda_device(), which should be pretty cheap. It will only call cudaSetDevice() if the current device is different from the specified device. But this is a way to ensure we are safe no matter what PyTorch is doing with devices behind the scenes (since they don't document the semantics).

I really prefer NOT to do this if we can find out the PyTorch semantics, because we shouldn't have to query and set the current device on every device memory allocation/deallocation... But this would be a good solution if we need it.
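
For illustration, a minimal sketch of what such an RAII device guard does (a simplified stand-in, not RMM's actual cuda_set_device_raii implementation):

#include <cuda_runtime_api.h>

// Sketch: remember the current device, switch only if it differs from the
// requested one, and restore the original device when the guard goes out of scope.
struct set_device_guard {
  explicit set_device_guard(int device) {
    cudaGetDevice(&old_device_);
    needs_reset_ = (device >= 0) && (device != old_device_);
    if (needs_reset_) { cudaSetDevice(device); }
  }
  ~set_device_guard() {
    if (needs_reset_) { cudaSetDevice(old_device_); }
  }
  int old_device_{};
  bool needs_reset_{false};
};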

@wence-
Copy link
Contributor Author

wence- commented Dec 13, 2023

Why do you need a nullary ctor?

Due to the way Cython transpiles, it complains if you have:

cdef extern from "*":
   cdef cppclass cuda_device_id:
       cuda_device_id(int)


cdef public void allocate(int device):
    cdef cuda_device_id id_ = cuda_device_id(device)
    return

Produces:

Error compiling Cython file:
------------------------------------------------------------
...
   cdef cppclass cuda_device_id:
       cuda_device_id(int)


cdef public void allocate(int device):
    cdef cuda_device_id id_ = cuda_device_id(device)
                        ^
------------------------------------------------------------

cython_cons.pyx:7:24: C++ class must have a nullary constructor to be stack allocated

But having moved this plugin to be written directly in C++, I don't need to worry about workarounds for this.

PyTorch does not guarantee that plug-in allocator functions are called
with the active device matching the device on which the
allocation/deallocation is requested. Hence, we must explicitly handle
that here by selecting the appropriate device.
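
For illustration, a minimal sketch of the deallocation hook with this device selection added (assuming rmm::cuda_set_device_raii and rmm::mr::get_per_device_resource; the merged code may differ in detail):

#include <cstddef>
#include <cuda_runtime_api.h>
#include <rmm/cuda_device.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

extern "C" void deallocate(void* ptr, std::size_t size, int device, void* stream)
{
  rmm::cuda_device_id const device_id{device};
  // PyTorch may call us with a different device active, so make the
  // requested device current for the duration of this call.
  rmm::cuda_set_device_raii with_device{device_id};
  auto* mr = rmm::mr::get_per_device_resource(device_id);
  mr->deallocate(ptr, size, rmm::cuda_stream_view{static_cast<cudaStream_t>(stream)});
}
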
harrism (Member) left a comment

Looks good. Just one nit.

I wonder if we should benchmark before and after due to the repeated device queries this will result in -- I don't know if we have a benchmark though.

wence- (Contributor, Author) commented Dec 14, 2023

I wonder if we should benchmark before and after due to the repeated device queries this will result in -- I don't know if we have a benchmark though.

I wrote something simple. I can't easily benchmark directly from C++ (I have not attempted to build and link against torch...). I check for Amdahl-style overheads (so small allocations) on a two-device system; the allocation/deallocation I measure is just a loop doing torch.empty((100, 4), dtype=torch.int32), with:

  1. Pool enabled
  2. Pool disabled

And then use three different allocation functions:

  1. ORIG_CYTHON (code in trunk, with a fixed signature, rmm::mr::get_current_device_resource())
  2. BAD_CXX (rmm::mr::get_per_device_resource(device) without rmm::cuda_set_device_raii)
  3. GOOD_CXX (this PR, rmm::mr::get_per_device_resource(device) with rmm::cuda_set_device_raii)
Results

With the pool allocator, the new correct code is faster than the code in trunk. We pay a cost of 200ns if the devices match and ~800ns if the devices don't, relative to the BAD_CXX case for a call that takes ~3500ns.

Without the pool allocator the data are noisier and it's a wash in terms of performance.

Using which=<Which.ORIG_CYTHON: 0> allocation functions; use_pool=True
active_device=0 allocating_device=0: 3.628e-06s
active_device=0 allocating_device=1: 4.107e-06s
active_device=1 allocating_device=0: 4.053e-06s
active_device=1 allocating_device=1: 3.651e-06s

Using which=<Which.BAD_CXX: 1> allocation functions; use_pool=True
active_device=0 allocating_device=0: 3.269e-06s
active_device=0 allocating_device=1: 3.468e-06s
active_device=1 allocating_device=0: 3.376e-06s
active_device=1 allocating_device=1: 3.25e-06s

Using which=<Which.GOOD_CXX: 2> allocation functions; use_pool=True
active_device=0 allocating_device=0: 3.441e-06s
active_device=0 allocating_device=1: 4.224e-06s
active_device=1 allocating_device=0: 4.236e-06s
active_device=1 allocating_device=1: 3.523e-06s

Using which=<Which.ORIG_CYTHON: 0> allocation functions; use_pool=False
active_device=0 allocating_device=0: 9.812e-05s
active_device=0 allocating_device=1: 9.97e-05s
active_device=1 allocating_device=0: 9.93e-05s
active_device=1 allocating_device=1: 9.929e-05s

Using which=<Which.BAD_CXX: 1> allocation functions; use_pool=False
active_device=0 allocating_device=0: 9.898e-05s
active_device=0 allocating_device=1: 9.873e-05s
active_device=1 allocating_device=0: 9.713e-05s
active_device=1 allocating_device=1: 9.829e-05s

Using which=<Which.GOOD_CXX: 2> allocation functions; use_pool=False
active_device=0 allocating_device=0: 9.649e-05s
active_device=0 allocating_device=1: 9.853e-05s
active_device=1 allocating_device=0: 9.752e-05s
active_device=1 allocating_device=1: 9.822e-05s

Benchmark code
import itertools
import time
import rmm
from rmm.allocators.torch import (
    rmm_torch_allocator,        # New CXX
    bad_torch_allocator,        # Bad CXX
    old_torch_allocator,        # Old Cython
)
import torch
from enum import IntEnum


class Which(IntEnum):
    ORIG_CYTHON = 0
    BAD_CXX = 1
    GOOD_CXX = 2


use_pool = False
if use_pool:
    REPS = 500_000
else:
    REPS = 50_000

rmm.reinitialize(pool_allocator=use_pool, devices=[0, 1])

which = Which.ORIG_CYTHON
if which is Which.ORIG_CYTHON:
    # Old cython code, used rmm::mr::get_current_device_resource()
    torch.cuda.change_current_allocator(old_torch_allocator)
elif which is Which.BAD_CXX:
    # C++ uses rmm::mr::get_per_device_resource(device) but not cuda_set_device_raii
    torch.cuda.change_current_allocator(bad_torch_allocator)
elif which is Which.GOOD_CXX:
    # This is the only one that is semantically valid
    # C++ uses rmm::mr::get_per_device_resource(device) and cuda_set_device_raii
    torch.cuda.change_current_allocator(rmm_torch_allocator)
else:
    raise AssertionError()


def run(active_device, allocating_device, *, N):
    with torch.cuda.device(active_device):
        start = time.time()
        for _ in range(N):
            torch.empty((100, 4), dtype=torch.int32, device=f"cuda:{allocating_device}")
        end = time.time()
        print(f"{active_device=} {allocating_device=}: {(end - start)/N:.4g}s")


print(f"Using {which=} allocation functions; {use_pool=}")
for active, allocating in itertools.product((0, 1), (0, 1)):
    run(active, allocating, N=REPS)

wence- (Contributor, Author) commented Dec 14, 2023

/merge

harrism (Member) commented Dec 14, 2023

Strange, I don't see any reason the new code should be faster.

wence- (Contributor, Author) commented Dec 14, 2023

Strange, I don't see any reason the new code should be faster.

Since we wrote the original code in Cython, the transpiled C++ code was a little heavier weight (even with all the type annotations); it looked like this:

void *allocate(size_t size, int device, void *stream) {
  void *ptr = NULL;
  PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure();
  try {
    rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource();
    rmm::cuda_stream_view stream_view = rmm::cuda_stream_view((cudaStream_t)stream);
    ptr = mr->allocate(size, stream_view);
  } catch(...) {
    __Pyx_CppExn2PyErr();
    goto error;
  }

  goto success;

  error:;
  const char *filename = NULL;
  __Pyx_AddTraceback("rmm._lib.old_torch.allocate", filename);

  success:;
  __Pyx_PyGILState_Release(__pyx_gilstate_save);
  return ptr;
}

In particular, the main difference is that we always used to pay to acquire and release the Python GIL. As an aside, this was probably another bug-in-waiting in multithreaded environments.

@rapids-bot rapids-bot bot merged commit 0b931f6 into rapidsai:branch-24.02 Dec 14, 2023
47 checks passed
@wence- wence- deleted the wence/fix/1405 branch December 14, 2023 13:29
gmarkall added a commit to gmarkall/rapids-compose that referenced this pull request Dec 20, 2023
The RMM Python source now contains a non-generated C++ file,
`_torch_allocator.cpp`, from rapidsai/rmm#1407, so we need to avoid
deleting it when cleaning rmm-python. This approach might not be the
best if RMM Python goes on to include more C++ sources, but this fixes
the clean / build process for now.
Labels
bug (Something isn't working), CMake, non-breaking (Non-breaking change), Python (Related to RMM Python API)

Successfully merging this pull request may close these issues.

[BUG] Unexpected memory usage on GPU0