Methods for checking CUDA memory usage (#4511)
* gpu mem allocated

* add test

* addressed some of @apaszke's comments

* cache stats

* add more comments about test
ssnl authored and soumith committed Jan 9, 2018
1 parent f4a75de commit 5918243
Showing 9 changed files with 361 additions and 8 deletions.
74 changes: 72 additions & 2 deletions aten/src/THC/THCCachingAllocator.cpp
@@ -1,6 +1,7 @@
#include "THCCachingAllocator.h"

#include <cuda_runtime_api.h>
#include <algorithm>
#include <deque>
#include <map>
#include <memory>
@@ -43,6 +44,35 @@ const size_t kRoundSmall = 512; // round up small allocs to 512 bytes
const size_t kRoundLarge = 131072; // round up large allocs to 128 KiB
const size_t kSmallAlloc = 1048576; // largest "small" allocation is 1 MiB

struct DeviceStats {
uint64_t amount_allocated; // total amount allocated in bytes
uint64_t max_amount_allocated; // max total amount allocated in bytes
uint64_t amount_cached; // total amount in cache in bytes
uint64_t max_amount_cached; // max total amount in cache in bytes

DeviceStats() :
amount_allocated(0), max_amount_allocated(0),
amount_cached(0), max_amount_cached(0) { }

void increaseAllocated(size_t delta) {
amount_allocated += delta;
max_amount_allocated = std::max(max_amount_allocated, amount_allocated);
}

void decreaseAllocated(size_t delta) {
amount_allocated -= delta;
}

void increaseCached(size_t delta) {
amount_cached += delta;
max_amount_cached = std::max(max_amount_cached, amount_cached);
}

void decreaseCached(size_t delta) {
amount_cached -= delta;
}
};
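
To make the bookkeeping above easier to follow, here is a minimal Python sketch of the same idea: each counter keeps a current value plus a running high-water mark, and only increases update the peak. The class below is a hypothetical illustration, not code from this commit.

class MemoryCounter:
    """Hypothetical mirror of one DeviceStats counter pair (illustration only)."""

    def __init__(self):
        self.current = 0  # bytes currently accounted for
        self.peak = 0     # maximum value `current` has ever reached

    def increase(self, delta):
        self.current += delta
        self.peak = max(self.peak, self.current)  # peak only moves on increases

    def decrease(self, delta):
        self.current -= delta  # peak is deliberately left unchanged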

struct Block {
int device; // gpu
cudaStream_t stream; // allocation stream
@@ -80,6 +110,9 @@ struct THCCachingAllocator
typedef bool (*Comparison)(const Block*, const Block*);
typedef std::set<Block*, Comparison> FreeBlocks;

// device statistics
std::vector<DeviceStats> device_stats;

// lock around all operations
std::mutex mutex;

@@ -102,6 +135,14 @@ struct THCCachingAllocator
large_blocks(BlockComparator),
small_blocks(BlockComparator) {}

DeviceStats &get_stats_for_device(int device) {
THAssert(device >= 0);
if ((size_t) device >= device_stats.size()) {
device_stats.resize(device + 1);
}
return device_stats.at(device);
}

/** allocates a block which is safe to use from the provided stream */
cudaError_t malloc(void** devPtr, size_t size, cudaStream_t stream)
{
@@ -121,6 +162,8 @@ struct THCCachingAllocator
size = round_size(size);
bool small = size <= kSmallAlloc;

DeviceStats &stats = get_stats_for_device(device);

Block search_key(device, stream, size);
auto& free_blocks = small ? large_blocks : small_blocks;

@@ -138,6 +181,7 @@ struct THCCachingAllocator
if (err != cudaSuccess) {
return err;
}
stats.increaseCached(alloc_size);
block = new Block(device, stream, alloc_size, (char*)ptr);
}

@@ -161,6 +205,8 @@ struct THCCachingAllocator
allocated_blocks[block->ptr] = block;

*devPtr = (void*)block->ptr;

stats.increaseAllocated(block->size);
return cudaSuccess;
}

@@ -180,6 +226,7 @@ struct THCCachingAllocator
allocated_blocks.erase(it);
block->allocated = false;

get_stats_for_device(block->device).decreaseAllocated(block->size);
if (!block->stream_uses.empty()) {
return insert_events(block);
}
@@ -358,6 +405,7 @@ struct THCCachingAllocator
if (err != cudaSuccess) {
return err;
}
get_stats_for_device(block->device).decreaseCached(block->size);
auto cur = it;
++it;
blocks.erase(cur);
@@ -496,8 +544,30 @@ THC_API std::mutex* THCCachingAllocator_getCudaFreeMutex()
return &caching_allocator.cuda_free_mutex;
}

THC_API cudaError_t THCCachingAllocator_emptyCache(void)
{
return caching_allocator.emptyCache();
}

static inline void assertValidDevice(int device) {
int device_count;
THCudaCheck(cudaGetDeviceCount(&device_count));
THAssertMsg(0 <= device && device < device_count, "Invalid device argument.");
}

THC_API uint64_t THCCachingAllocator_currentMemoryAllocated(int device)
{
assertValidDevice(device);
return caching_allocator.get_stats_for_device(device).amount_allocated;
}

THC_API uint64_t THCCachingAllocator_maxMemoryAllocated(int device) {
assertValidDevice(device);
return caching_allocator.get_stats_for_device(device).max_amount_allocated;
}

THC_API uint64_t THCCachingAllocator_currentMemoryCached(int device)
{
assertValidDevice(device);
return caching_allocator.get_stats_for_device(device).amount_cached;
}

THC_API uint64_t THCCachingAllocator_maxMemoryCached(int device) {
assertValidDevice(device);
return caching_allocator.get_stats_for_device(device).max_amount_cached;
}
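
These per-device counters are what the new Python helpers in this commit read back (see the docs and test changes below). The correspondence sketched here between the C entry points and the torch.cuda functions is an assumption based on the rest of the diff, not something shown in this hunk:

import torch

# Assumed mapping (illustration only):
#   THCCachingAllocator_currentMemoryAllocated -> torch.cuda.memory_allocated
#   THCCachingAllocator_maxMemoryAllocated     -> torch.cuda.max_memory_allocated
#   THCCachingAllocator_currentMemoryCached    -> torch.cuda.memory_cached
#   THCCachingAllocator_maxMemoryCached        -> torch.cuda.max_memory_cached
device = 0  # any valid device index; invalid indices are rejected by assertValidDevice above
print(torch.cuda.memory_allocated(device), torch.cuda.max_memory_allocated(device))
print(torch.cuda.memory_cached(device), torch.cuda.max_memory_cached(device))
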
5 changes: 4 additions & 1 deletion aten/src/THC/THCCachingAllocator.h
@@ -11,7 +11,10 @@
THC_API THCDeviceAllocator* THCCachingAllocator_get(void);
THC_API void* THCCachingAllocator_getBaseAllocation(void *ptr, size_t *size);
THC_API void THCCachingAllocator_recordStream(void *ptr, THCStream* stream);
THC_API cudaError_t THCCachingAllocator_emptyCache(void);
THC_API uint64_t THCCachingAllocator_currentMemoryAllocated(int device);
THC_API uint64_t THCCachingAllocator_maxMemoryAllocated(int device);
THC_API uint64_t THCCachingAllocator_currentMemoryCached(int device);
THC_API uint64_t THCCachingAllocator_maxMemoryCached(int device);

#if (__cplusplus >= 201103L) || (defined(_MSC_VER) && defined(__cplusplus))
THC_API std::mutex* THCCachingAllocator_getCudaFreeMutex();
4 changes: 4 additions & 0 deletions docs/source/cuda.rst
@@ -40,6 +40,10 @@ Streams and events
Memory management
-----------------
.. autofunction:: empty_cache
.. autofunction:: memory_allocated
.. autofunction:: max_memory_allocated
.. autofunction:: memory_cached
.. autofunction:: max_memory_cached

NVIDIA Tools Extension (NVTX)
-----------------------------
8 changes: 6 additions & 2 deletions docs/source/notes/cuda.rst
@@ -48,8 +48,12 @@ Memory management
PyTorch uses a caching memory allocator to speed up memory allocations. This
allows fast memory deallocation without device synchronizations. However, the
unused memory managed by the allocator will still show as if used in
`nvidia-smi`. Calling :meth:`~torch.cuda.empty_cache` can release all unused
cached memory from PyTorch so that those can be used by other GPU applications.
`nvidia-smi`. You can use :meth:`~torch.cuda.memory_allocated` and
:meth:`~torch.cuda.max_memory_allocated` to monitor memory occupied by
tensors, and use :meth:`~torch.cuda.memory_cached` and
:meth:`~torch.cuda.max_memory_cached` to monitor memory managed by the caching
allocator. Calling :meth:`~torch.cuda.empty_cache` can release all unused cached
memory from PyTorch so that it can be used by other GPU applications.
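
The example below sketches this monitoring workflow end to end; it assumes a CUDA-capable build, uses only the functions listed above, and checks relative changes rather than exact byte counts (which depend on the allocator's rounding).

import torch

dev = torch.cuda.current_device()
_ = torch.cuda.FloatTensor(1)                      # warm-up so CUDA is initialized
del _
before = torch.cuda.memory_allocated(dev)

x = torch.cuda.FloatTensor(1024, 1024)             # roughly 4 MiB of GPU memory
assert torch.cuda.memory_allocated(dev) > before
assert torch.cuda.memory_cached(dev) >= torch.cuda.memory_allocated(dev)
assert torch.cuda.max_memory_allocated(dev) >= torch.cuda.memory_allocated(dev)

del x                                              # memory goes back to the cache, not to the driver
assert torch.cuda.memory_allocated(dev) == before
torch.cuda.empty_cache()                           # hand unused cached blocks back to the driver
assert torch.cuda.memory_cached(dev) <= torch.cuda.max_memory_cached(dev)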


Best practices
2 changes: 1 addition & 1 deletion docs/source/torch.rst
@@ -69,7 +69,7 @@ There are a few more in-place random sampling functions defined on Tensors as we
- :func:`torch.Tensor.log_normal_` - samples from the log-normal distribution
- :func:`torch.Tensor.normal_` - in-place version of :func:`torch.normal`
- :func:`torch.Tensor.random_` - numbers sampled from the discrete uniform distribution
- :func:`torch.Tensor.uniform_` - numbers sampled from the uniform distribution
- :func:`torch.Tensor.uniform_` - numbers sampled from the continuous uniform distribution


Serialization
165 changes: 165 additions & 0 deletions test/test_cuda.py
@@ -425,6 +425,171 @@ def tmp(self):


class TestCuda(TestCase):

@staticmethod
def _test_memory_stats_generator(self, device=None, N=35):
if device is None:
device = torch.cuda.current_device()

m0 = torch.cuda.memory_allocated(device)
last_m_arr = [torch.cuda.memory_allocated(device)]
max_m_arr = [torch.cuda.max_memory_allocated(device)]
last_c_arr = [torch.cuda.memory_cached(device)]
max_c_arr = [torch.cuda.max_memory_cached(device)]

def alloc(*size):
with torch.cuda.device(device):
# NOTE: do **not** use methods that can have additional
# memory overhead, e.g., inplace random sampling methods.
# they can leave some memory occupied even after being
# deallocated, e.g., initialized RNG state, causing some
# memory checks below to fail.
return torch.cuda.FloatTensor(*size)

def assert_change(comp=1, empty_cache=False):
# comp > 0: increased
# comp = 0: equal
# comp < 0: decreased
new_m = torch.cuda.memory_allocated(device)
new_max_m = torch.cuda.max_memory_allocated(device)
if comp > 0:
self.assertGreater(new_m, last_m_arr[0])
elif comp < 0:
self.assertLess(new_m, last_m_arr[0])
else:
self.assertEqual(new_m, last_m_arr[0])
self.assertLessEqual(new_m, new_max_m)
self.assertGreaterEqual(new_max_m, max_m_arr[0])
last_m_arr[0] = new_m
max_m_arr[0] = new_max_m

new_c = torch.cuda.memory_cached(device)
new_max_c = torch.cuda.max_memory_cached(device)
# emptying cache may happen (due to allocation or empty_cache), so
# we can't assert new_c >= last_c
self.assertLessEqual(new_c, new_max_c)
self.assertGreaterEqual(new_max_c, max_c_arr[0])
last_c_arr[0] = new_c
max_c_arr[0] = new_max_c

if empty_cache:
torch.cuda.empty_cache()
new_c = torch.cuda.memory_cached(device)
new_max_c = torch.cuda.max_memory_cached(device)
self.assertLessEqual(new_c, last_c_arr[0])
self.assertLessEqual(new_c, new_max_c)
self.assertEqual(new_max_c, max_c_arr[0])
last_c_arr[0] = new_c

assert_change(0)
assert_change(0)
yield

tensors1 = [alloc(1), alloc(10, 20), alloc(200, 300, 2000)]
m1 = torch.cuda.memory_allocated(device)
assert_change(1)
yield

tensors2 = []

for i in range(1, int(N / 2) + 1):
# small ones
tensors2.append(alloc(i, i * 4))
assert_change(1)
yield

for i in range(5, int(N / 2) + 5):
# large ones
tensors2.append(alloc(i, i * 7, i * 9, i * 11))
assert_change(1)
yield

tensors2.append(alloc(0, 0, 0))
assert_change(0)
yield

permute = []
for i in torch.randperm(len(tensors2)):
permute.append(tensors2[i])
assert_change(0)
yield

del tensors2
assert_change(0)
yield
tensors2 = permute
assert_change(0)
yield
del permute
assert_change(0)
yield

for i in range(int(N / 2)):
x = tensors2[i].numel()
del tensors2[i]
assert_change(-x)  # -x is 0 when tensors2[i] is empty, so no change is expected then
yield

for i in range(2, int(2 * N / 3) + 2):
tensors2.append(alloc(i, i * 3, i * 8))
assert_change(1)
yield

del tensors2
assert_change(-1)
assert_change(0)
self.assertEqual(torch.cuda.memory_allocated(device), m1)
yield True

del tensors1
assert_change(-1)
self.assertEqual(torch.cuda.memory_allocated(device), m0)

# test empty_cache
assert_change(0, empty_cache=True)

def test_memory_stats(self):
torch.cuda.empty_cache()
for _ in self._test_memory_stats_generator(self):
pass

@unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
def test_memory_stats_multigpu(self):
# advance a generator with an end flag
def advance(gen, end):
if not end:
try:
next(gen)
except StopIteration:
end = True
return end

# interleave
torch.cuda.empty_cache()
gen0 = self._test_memory_stats_generator(self, device=0, N=35)
gen1 = self._test_memory_stats_generator(self, device=1, N=35)
end0 = end1 = False
while not (end0 and end1):
end0 = advance(gen0, end0)
end1 = advance(gen1, end1)

# semi-random order
torch.cuda.empty_cache()
gen0 = self._test_memory_stats_generator(self, device=0, N=35)
gen1 = self._test_memory_stats_generator(self, device=1, N=35)
end0 = end1 = False

while not (end0 and end1):
end0 = advance(gen0, end0)
if not end0:
gen1_max_times = torch.LongTensor(1).random_(0, 3)[0]
else:
gen1_max_times = float('inf')
t = 0
while t < gen1_max_times and not end1:
end1 = advance(gen1, end1)
t += 1

@unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
def _test_autogpu(self, TensorCtor):
x = TensorCtor().cuda()
4 changes: 2 additions & 2 deletions torch/_tensor_docs.py
@@ -1795,10 +1795,10 @@ def callable(a, b) -> number
r"""
uniform_(from=0, to=1) -> Tensor
Fills :attr:`self` tensor with numbers sampled from the uniform distribution:
Fills :attr:`self` tensor with numbers sampled from the continuous uniform
distribution:
.. math::
P(x) = \dfrac{1}{to - from}
""")
