
Add RMM PyTorch allocator #1168

Merged Jan 5, 2023 (18 commits)

Conversation

@shwina (Contributor) commented Nov 29, 2022

Closes #1144

This PR adds an RMM-based allocator for PyTorch, rmm.rmm_torch_allocator.

This enables, e.g., using the same memory pool in code that uses both RAPIDS and PyTorch. It also enables PyTorch to use all of the different memory resources provided by RMM. For example:

import rmm
import torch

# Route all PyTorch CUDA allocations and deallocations through RMM
torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)

base_mr = rmm.mr.CudaMemoryResource()

def allocate_func(size):
    print(f"Allocating {size} bytes")
    return base_mr.allocate(size)

def deallocate_func(ptr, size):
    print(f"Deallocating {size} bytes")
    return base_mr.deallocate(ptr, size)

# Use a callback resource so each allocation PyTorch makes is visible
rmm.mr.set_current_device_resource(
    rmm.mr.CallbackMemoryResource(allocate_func, deallocate_func)
)

x = torch.tensor([1, 2]).cuda()
del x
y = torch.tensor([1, 2, 3]).cuda()
del y

Output:

Allocating 16 bytes
Deallocating 16 bytes
Allocating 24 bytes
Deallocating 24 bytes

@github-actions bot added the CMake and Python labels Nov 29, 2022
@shwina added the feature request and non-breaking labels Nov 30, 2022
@shwina marked this pull request as ready for review November 30, 2022 21:05
@shwina requested review from a team as code owners November 30, 2022 21:05
@leofang (Member) left a comment

LGTM overall, left a few questions.

except ImportError:
    rmm_torch_allocator = None
else:
    _alloc_free_lib_path = rmm._lib.torch_allocator.__file__
Member

This is neat!

Comment on lines +246 to +250
rmm_torch_allocator = CUDAPluggableAllocator(
    _alloc_free_lib_path,
    alloc_fn_name="allocate",
    free_fn_name="deallocate",
)
Member

Q: Would this honor rmm.reinitialize() if a user changes the MR?

Contributor Author

rmm.reinitialize() resets the default memory resource used by RMM. Each call to allocate() and deallocate() queries the default memory resource via a call to get_current_device_resource(), so -- yes.
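
A quick way to see this behavior (a sketch, not code from the PR; assumes a CUDA device and a PyTorch build with pluggable-allocator support):

import rmm
import torch

torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)

# Allocations made now go through whatever the current device resource is...
x = torch.tensor([1, 2]).cuda()
del x

# ...and because each hook looks up get_current_device_resource() at call
# time, switching the default MR via rmm.reinitialize() is picked up by
# subsequent PyTorch allocations.
rmm.reinitialize(pool_allocator=True)
y = torch.tensor([1, 2, 3]).cuda()
del y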

@bdice (Contributor) left a comment

This is awesome. A few comments.



cdef extern from "rmm/mr/device/per_device_resource.hpp" namespace "rmm" nogil:
    cdef device_memory_resource* get_current_device_resource \
Contributor

Should this use from rmm._lib.memory_resource cimport get_current_device_resource like in device_buffer.pyx?

Contributor Author

Ah yeah I think we should have a single declaration in memory_resource.pxd

Contributor Author

I added a new per_device_resource.pxd file from which we can cimport the function.

Note that we cannot get away with declaring get_current_device_resource in memory_resource.pxd, because memory_resource exports a cpdef function also named get_current_device_resource that is a wrapper around the C++ function.
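
For context, the Python-level cpdef wrapper referred to above is the one users call; a minimal sketch:

import rmm

# rmm.mr.get_current_device_resource() is the Python (cpdef) wrapper; the
# Cython allocator hooks cimport and call the C++ function of the same name.
mr = rmm.mr.get_current_device_resource()
rmm.mr.set_current_device_resource(mr)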


cuda_mr = rmm.mr.CudaMemoryResource()
@pytest.fixture
def stats_mr():
Contributor

Should this be set up as a yield fixture and reset the current device resource afterward?

Contributor Author

We already have an autouse function scoped fixture that does that: https://github.com/rapidsai/rmm/blob/branch-23.02/python/rmm/tests/test_rmm.py#L46. I'm guessing that should just work as expected?
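
For reference, a sketch of that kind of autouse fixture (the real one lives in test_rmm.py and may differ in detail):

import pytest
import rmm

@pytest.fixture(scope="function", autouse=True)
def rmm_auto_reinitialize():
    # Run the test...
    yield
    # ...then restore RMM's default memory resource so a custom MR set by
    # one test cannot leak into the next.
    rmm.reinitialize()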

Comment on lines 929 to 930
except RuntimeError:
    pass
Contributor

What's the except/pass here for? Can we call change_current_allocator in an else: block from the first try:?

try:
    from torch.cuda.memory import change_current_allocator
except ImportError:
    pytest.skip("pytorch pluggable allocator not available")
else:
    change_current_allocator(rmm.rmm_torch_allocator)

Alternatively can we use pytest.importorskip?

Contributor Author

Ah I think it's because a RuntimeError is raised if you set the torch allocator twice in the same session. Maybe we should use a session scoped fixture instead?
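
A session-scoped fixture along those lines might look like this (hypothetical sketch; the fixture name is made up):

import pytest
import rmm
from torch.cuda.memory import change_current_allocator

@pytest.fixture(scope="session", autouse=True)
def torch_uses_rmm_allocator():
    # change_current_allocator raises a RuntimeError if called twice in the
    # same process, so switch allocators exactly once per test session.
    change_current_allocator(rmm.rmm_torch_allocator)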

@harrism (Member) commented Dec 13, 2022

I didn't see any significant documentation of the new functionality added in this PR. Do you plan to add more?

@shwina (Contributor, Author) commented Dec 13, 2022

Thanks @harrism -- I added a blurb in our README.md, which is where we also have documentation on how to use RMM + Numba or RMM + CuPy

@codecov-commenter commented Dec 13, 2022

Codecov Report

❗ No coverage uploaded for pull request base (branch-23.02@c5c02fc).
Patch has no changes to coverable lines.

Additional details and impacted files
@@              Coverage Diff               @@
##             branch-23.02   #1168   +/-   ##
==============================================
  Coverage                ?   0.00%           
==============================================
  Files                   ?       6           
  Lines                   ?     421           
  Branches                ?       0           
==============================================
  Hits                    ?       0           
  Misses                  ?     421           
  Partials                ?       0           


@harrism (Member) left a comment

New docs show the power of RMM to bridge the memory needs of RAPIDS, PyTorch (and CuPy, and Numba). Nice!

@jakirkham (Member)

It's really great seeing this is now possible! 🎉

@VibhuJawa (Member)

Would it be possible to merge this soon? I want to start testing this right away for cugraph_dgl.

@vyasr (Contributor) commented Dec 19, 2022

Planning to review this later today.

…idsai#1170)

Closes rapidsai#1169.

Essentially, we are running into the situation described in https://cython.readthedocs.io/en/latest/src/userguide/extension_types.html#disabling-cycle-breaking-tp-clear with `UpstreamResourceAdaptor`.

The solution is to prevent clearing of `UpstreamResourceAdaptor` objects by decorating them with `no_gc_clear`.

Cython calls out the following:

> If you use no_gc_clear, it is important that any given reference cycle contains at least one object without no_gc_clear. Otherwise, the cycle cannot be broken, which is a memory leak.

The other object in RMM that we mark `@no_gc_clear` is `DeviceBuffer`, and a `DeviceBuffer` can keep a reference to an `UpstreamResourceAdaptor`. But, an `UpstreamResourceAdaptor` cannot keep a reference to a `DeviceBuffer`, so instances of the two cannot form a reference cycle AFAICT.

Authors:
  - Ashwin Srinath (https://github.com/shwina)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Mark Harris (https://github.com/harrism)

URL: rapidsai#1170
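
A minimal illustration of the reference direction described above (a sketch; the cycle-breaking details themselves live in the Cython layer):

import rmm

upstream = rmm.mr.CudaMemoryResource()
pool = rmm.mr.PoolMemoryResource(upstream)  # the adaptor keeps `upstream` alive
rmm.mr.set_current_device_resource(pool)

buf = rmm.DeviceBuffer(size=64)  # the buffer keeps its memory resource alive
del pool, upstream               # still reachable through `buf`
del buf                          # now everything can be collected
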
@shwina requested a review from a team as a code owner December 19, 2022 20:16
@ajschmidt8 removed the request for review from a team December 19, 2022 22:13
@ajschmidt8 (Member)

Removing ops-codeowners from the required reviews since it doesn't seem there are any file changes that we're responsible for. Feel free to add us back if necessary.

@vyasr (Contributor) commented Dec 20, 2022

I spent most of my time with this PR today helping Ashwin troubleshoot the test failures and didn't get around to reviewing. Probably won't get back to this until Wednesday or so.



cdef public void* allocate(
    ssize_t size, int device, void* stream
Member

Q: device is ignored by design? (I was reviewing cupy/cupy#7210 and noticed this.)

@shwina (Contributor, Author) commented Dec 21, 2022

Ah, great catch! This brings out a subtle problem:

In RMM, each device has its own memory resource. Thus, to do the allocation on a specified device with RMM, I would write the torch allocate function like this:

cdef public void* allocate(ssize_t size, int device, void* stream) except * with gil:
    cdef device_memory_resource* mr = get_per_device_resource(device)
    return mr[0].allocate(size, <cudaStream_t> stream)

Unfortunately, the deallocation function does not accept a device argument, so we cannot retrieve the memory resource that was used for allocation:

void deallocate(void* ptr, ssize_t size, void* stream)

I don't really see a way around this other than for the deallocate signature to include the device argument. cc: @emcastillo

Member

I would love to know too. TBH I am puzzled by PyTorch's (long-time) behavior of asking for device. It should just honor the current device...

Contributor Author

+1 to that

Contributor Author

I'll submit a follow-up PR adding support for device, once pytorch/pytorch#91398 is merged.

@jakirkham requested a review from bdice January 4, 2023 19:19
@bdice (Contributor) left a comment

One suggestion, otherwise LGTM!

Comment on lines +10 to +13
try:
    from torch.cuda.memory import change_current_allocator
except ImportError:
    pytest.skip("pytorch pluggable allocator not available")
Contributor

There's a pytest utility for this if you want to use it: pytest.importorskip.

Member

Does that handle importing specific functions?

This is using importorskip for torch generally above. The torch.cuda.memory module has been around for a while, though the functionality we need from it is pretty new.

Maybe in the future this could require a specific PyTorch version. There doesn't seem to be one yet that has what we need though.

@bdice (Contributor) Jan 4, 2023

No, pytest.importorskip only handles modules and you have to use attribute accessors to get functions. It's kind of messy. The current solution is probably easier to read, let's keep it as-is.
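
For illustration, the pytest.importorskip route being discussed would look roughly like this (a sketch):

import pytest

torch_cuda_memory = pytest.importorskip("torch.cuda.memory")
# importorskip only returns the module; the function has to be pulled out
# with attribute access, and a missing attribute still needs its own skip.
change_current_allocator = getattr(
    torch_cuda_memory, "change_current_allocator", None
)
if change_current_allocator is None:
    pytest.skip("pytorch pluggable allocator not available")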

@vyasr (Contributor) left a comment

Couple of minor questions, this is awesome!

Comment on lines 9 to 17
) except * with gil:
    cdef device_memory_resource* mr = get_current_device_resource()
    return mr[0].allocate(size, <cudaStream_t> stream)

cdef public void deallocate(
    void* ptr, ssize_t size, void* stream
) except * with gil:
    cdef device_memory_resource* mr = get_current_device_resource()
    mr[0].deallocate(ptr, size, <cudaStream_t> stream)
Contributor

Why are these GIL-requiring functions? It seems like it's all pure C code here, no Python objects etc.

Contributor

To be clear, if they're getting called from GIL-requiring code in PyTorch, so be it. I just don't see a reason that these functions need to explicitly acquire the GIL. If PyT can call these in a nogil context, is there a reason for us not to allow that?

Contributor Author

Because the allocate() and deallocate() methods can involve Python operations on Python objects, e.g., in CallbackMemoryResource or FailureCallbackResourceAdaptor.

Contributor

I was thinking that would automatically be handled when those callbacks are invoked, since the callbacks are stored as Python objects in the class, so any interaction with them should reacquire the GIL already, right? I guess the potential issue is that we cast these to void* pointers before passing them to the C++ classes, so at the point of the call we've lost Cython's safety net. Is that right? If so, we should consider (out of scope for this PR, of course) inserting the necessary Python C API calls into the relevant rmm C++ classes, i.e. in failure_callback_resource_adaptor::do_allocate.

Contributor

That was my expectation as well. If the callback touches Python objects, shouldn't it be the responsibility of the callback to acquire/release the GIL?

Contributor

My proposal above was wrong, there's no reason to embed this information in librmm where the callbacks could be anything (not necessarily Python objects). However, it should be the responsibility of the callbacks in rmm's Cython code to acquire the GIL as needed, and we do appear to do this correctly already. The _oom_callback_function used by the FailureCallbackResourceAdaptor acquires the GIL before calling the user-provided callback, as do both the allocate and deallocate callbacks used by the CallbackMemoryResource.
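
For example, at the user level those callbacks are plain Python, and the Cython wrapper re-acquires the GIL before invoking them (a usage sketch, not code from this PR):

import rmm

def on_oom(nbytes):
    # Invoked (with the GIL held by the Cython wrapper) when allocation fails.
    print(f"out of memory while allocating {nbytes} bytes")
    return False  # do not retry the allocation

rmm.mr.set_current_device_resource(
    rmm.mr.FailureCallbackResourceAdaptor(rmm.mr.CudaMemoryResource(), on_oom)
)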

Contributor Author

Yup -- the GIL should neither be released in C++, nor can it be released in Python. The Cython "wrapper" functions are what need to take on the responsibility of handling the GIL.

Contributor Author

Vyas and I did a bit more exploration of exactly why we need a with gil here and ended up quite deep in the CPython and Cython internals (still without a clear answer though).

The symptom though is clear. Take the example I have in the PR description:

Raise an error in allocate_func while removing the with gil, and you'll see that the error is uncaught and eventually this segfaults.
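
Concretely, the only change to the example is the allocation callback (a sketch):

def allocate_func(size):
    # Per the discussion above: with `with gil` on the Cython hook this error
    # is reported as a Python exception; without it, it goes uncaught and the
    # program eventually segfaults.
    raise MemoryError(f"refusing to allocate {size} bytes")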

    from torch.cuda.memory import change_current_allocator
except ImportError:
    pytest.skip("pytorch pluggable allocator not available")
change_current_allocator(rmm.rmm_torch_allocator)
Contributor

How does this function behave if passed None (the case where the torch allocator hasn't been defined)?

Contributor Author

Hmm - I wouldn't expect it to be None if we were able to import change_current_allocator, since the existence of change_current_allocator implies that rmm_torch_allocator was defined (although somewhat implicitly: change_current_allocator and CUDAPluggableAllocator in PyTorch were introduced together).

Should we also skip this test if rmm.rmm_torch_allocator is None for some reason?

Contributor

Ah OK I wasn't sure about that, I thought they weren't introduced entirely concurrently. Up to you on the extra skip, it sounds like it would be pedantically correct but not practically necessary.
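
The extra skip mentioned above would just be (hypothetical):

if rmm.rmm_torch_allocator is None:
    pytest.skip("rmm_torch_allocator is not available")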

try:
    from torch.cuda.memory import CUDAPluggableAllocator
except ImportError:
    rmm_torch_allocator = None
Contributor

Could this bite a user that accesses this attribute and passes it around thinking that it's fine when it's really just None? It might be safer to override __getattr__ for the module and have it raise an error to prevent the user from accessing this attribute when CUDAPluggableAllocator failed to import.

Contributor Author

Could we alternatively just pass on ImportError (so the attribute is never defined) to achieve the same effect as defining that module __getattr__?

Contributor

I think you'd get close to the same effect, just a slightly less user-friendly version. With a __getattr__ override you could provide a more friendly error message indicating that this happened because the torch allocator failed to import, whereas if you just avoid defining it the user will see an AttributeError without any additional diagnostics and may think it's a bug in rmm.

It's a very minor point though, I'm fine leaving this as is for now and only revisiting in the future if we get a lot of user questions about why the allocator is None.
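
A sketch of the module-level __getattr__ being discussed (hypothetical; the message wording is made up):

# In rmm/__init__.py, instead of setting rmm_torch_allocator = None:
def __getattr__(name):
    if name == "rmm_torch_allocator":
        raise AttributeError(
            "rmm.rmm_torch_allocator is unavailable because "
            "torch.cuda.memory.CUDAPluggableAllocator could not be imported; "
            "a PyTorch build with pluggable-allocator support is required"
        )
    raise AttributeError(f"module 'rmm' has no attribute {name!r}")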

@vyasr (Contributor) left a comment

One outstanding question regarding the GIL, but it's not a blocker. We can address it later if you agree that it's an issue since it's not strictly related to the PyT allocator.

@shwina (Contributor, Author) commented Jan 5, 2023

/merge

Labels: CMake, feature request, non-breaking, Python
Development

Successfully merging this pull request may close these issues.

Test PyTorch's Pluggable CUDA allocator with RMM