Make DeviceCachingAllocator's error handling more defensive and a bit easier to read #51158

Closed · wants to merge 3 commits
127 changes: 72 additions & 55 deletions c10/cuda/CUDACachingAllocator.cpp
@@ -57,12 +57,12 @@ namespace {

using stream_set = std::unordered_set<cuda::CUDAStream>;

constexpr size_t kMinBlockSize = 512;       // all sizes are rounded to at least 512 bytes
constexpr size_t kSmallSize = 1048576;      // largest "small" allocation is 1 MiB
constexpr size_t kSmallBuffer = 2097152;    // "small" allocations are packed in 2 MiB blocks
constexpr size_t kLargeBuffer = 20971520;   // "large" allocations may be packed in 20 MiB blocks
constexpr size_t kMinLargeAlloc = 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kRoundLarge = 2097152;     // round up large allocations to 2 MiB

typedef std::bitset<static_cast<size_t>(StatType::NUM_TYPES)> StatTypes;

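For context (an illustration, not part of this diff): the constants above drive the allocator's size rounding and segment sizing. A minimal sketch of that policy, assuming helpers along the lines of the allocator's own round_size/get_allocation_size elsewhere in this file:

// Illustrative sketch only; mirrors the policy the constants imply.
size_t round_size(size_t size) {
  // Every request is rounded up to a nonzero multiple of kMinBlockSize (512 B).
  if (size < kMinBlockSize) {
    return kMinBlockSize;
  }
  return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
}

size_t get_allocation_size(size_t size) {
  if (size <= kSmallSize) {
    return kSmallBuffer;  // <= 1 MiB requests live in 2 MiB "small" segments
  } else if (size < kMinLargeAlloc) {
    return kLargeBuffer;  // 1-10 MiB requests live in 20 MiB "large" segments
  } else {
    // >= 10 MiB requests round up to a multiple of 2 MiB
    return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
  }
}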
@@ -242,56 +242,57 @@ class DeviceCachingAllocator {
       // Free all non-split cached blocks and retry alloc.
       || (free_cached_blocks() && alloc_block(params, true));
 
-    TORCH_INTERNAL_ASSERT((!block_found && params.err != cudaSuccess) || params.block);
     if (!block_found) {
-      if (params.err == cudaErrorMemoryAllocation) {
-        size_t device_free;
-        size_t device_total;
-        C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
-        std::string allowed_info;
-
-        if (set_fraction) {
-          allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
-        }
-
-        stats.num_ooms += 1;
-
-        // "total capacity": total global memory on GPU
-        // "allowed": memory the program is allowed to use, as set by the memory fraction
-        // "already allocated": memory allocated by the program using the
-        //                      caching allocator
-        // "free": free memory as reported by the CUDA API
-        // "cached": memory held by the allocator but not used by the program
-        //
-        // The "allocated" amount does not include memory allocated outside
-        // of the caching allocator, such as memory allocated by other programs
-        // or memory held by the driver.
-        //
-        // The sum of "allocated" + "free" + "cached" may be less than the
-        // total capacity due to memory held by the driver and usage by other
-        // programs.
-        //
-        // Note that at this point free_cached_blocks has already returned all
-        // possible "cached" memory to the driver. The only remaining "cached"
-        // memory is split from a larger block that is partially in-use.
-        TORCH_CHECK_WITH(CUDAOutOfMemoryError, false,
-          "CUDA out of memory. Tried to allocate ", format_size(alloc_size),
-          " (GPU ", device, "; ",
-          format_size(device_total), " total capacity; ",
-          format_size(stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current),
-          " already allocated; ",
-          format_size(device_free), " free; ",
-          allowed_info,
-          format_size(stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current),
-          " reserved in total by PyTorch)");
-      } else {
-        C10_CUDA_CHECK(params.err);
-      }
+      // For any error code other than cudaErrorMemoryAllocation,
+      // alloc_block should have thrown an exception already.
+      TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation);
+
+      size_t device_free;
+      size_t device_total;
+      C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
+      std::string allowed_info;
+
+      if (set_fraction) {
+        allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
+      }
+
+      stats.num_ooms += 1;
+
+      // "total capacity": total global memory on GPU
+      // "allowed": memory the program is allowed to use, as set by the memory fraction
+      // "already allocated": memory allocated by the program using the
+      //                      caching allocator
+      // "free": free memory as reported by the CUDA API
+      // "cached": memory held by the allocator but not used by the program
+      //
+      // The "allocated" amount does not include memory allocated outside
+      // of the caching allocator, such as memory allocated by other programs
+      // or memory held by the driver.
+      //
+      // The sum of "allocated" + "free" + "cached" may be less than the
+      // total capacity due to memory held by the driver and usage by other
+      // programs.
+      //
+      // Note that at this point free_cached_blocks has already returned all
+      // possible "cached" memory to the driver. The only remaining "cached"
+      // memory is split from a larger block that is partially in-use.
+      TORCH_CHECK_WITH(CUDAOutOfMemoryError, false,
+        "CUDA out of memory. Tried to allocate ", format_size(alloc_size),
+        " (GPU ", device, "; ",
+        format_size(device_total), " total capacity; ",
+        format_size(stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current),
+        " already allocated; ",
+        format_size(device_free), " free; ",
+        allowed_info,
+        format_size(stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current),
+        " reserved in total by PyTorch)");
     }
 
+    TORCH_INTERNAL_ASSERT(params.err == cudaSuccess &&
+                          params.block != nullptr &&
+                          params.block->ptr != nullptr);
     Block* block = params.block;
     Block* remaining = nullptr;
-    TORCH_INTERNAL_ASSERT(block);
 
     const bool already_split = block->is_split();
     if (should_split(block, size)) {
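For context (an illustration, not part of this diff): the two asserts added above pin down a simple contract. On the failure path, params.err can only be cudaErrorMemoryAllocation, since any other error already threw inside alloc_block; on the success path, a non-null block with a non-null ptr is guaranteed before the split logic runs. A hedged sketch of the caller-side recovery loop this contract enables; try_alloc and release_cached_memory are hypothetical stand-ins, not PyTorch APIs:

#include <cstddef>
#include <stdexcept>

struct OutOfMemoryError : std::runtime_error {
  using std::runtime_error::runtime_error;
};

void* try_alloc(std::size_t nbytes);  // hypothetical: throws OutOfMemoryError on OOM
void release_cached_memory();         // hypothetical: frees whatever the caller can spare

void* allocate_with_retry(std::size_t nbytes) {
  try {
    return try_alloc(nbytes);
  } catch (const OutOfMemoryError&) {
    // The allocator cleared CUDA's sticky error state before throwing,
    // so the caller may free memory and simply try again.
    release_cached_memory();
    return try_alloc(nbytes);  // second attempt; may throw again
  }
}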
@@ -647,30 +648,46 @@ class DeviceCachingAllocator {
   }
 
   bool alloc_block(AllocParams& p, bool isRetry) {
+    // Defensively checks for preexisting CUDA error state.
+    C10_CUDA_CHECK(cudaGetLastError());
+
     size_t size = p.alloc_size;
     void* ptr;
 
     if (isRetry) {
       stats.num_alloc_retries += 1;
     }
 
     if (set_fraction && total_allocated_memory + size > allowed_memory_maximum) {
       p.err = cudaErrorMemoryAllocation;
+      return false;
     } else {
       p.err = cudaMalloc(&ptr, size);
-    }
-
-    if (p.err != cudaSuccess) {
-      if (!isRetry || p.err == cudaErrorMemoryAllocation)
-        cudaGetLastError(); // clear CUDA error
-      return false;
+      if (p.err != cudaSuccess) {
+        if (p.err == cudaErrorMemoryAllocation) {
+          // If this is the first attempt (!isRetry), we can forgive and clear CUDA's
+          // internal error state.
+          // If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH will take
+          // over to throw a helpful exception. The user can choose to catch the exception,
+          // free some stuff in their script, and attempt their allocation again.
+          // In this case, we can also forgive and clear CUDA's internal error state.
+          cudaGetLastError();
+        } else {
+          // If the error's unrelated to memory allocation, we should throw immediately.
+          C10_CUDA_CHECK(p.err);
+        }
+        return false;
+      }
     }
 
     total_allocated_memory += size;
     p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr);
     update_stat_array(stats.segment, 1, p.stat_types);
     update_stat_array(stats.reserved_bytes, size, p.stat_types);
 
-    return (p.block != nullptr);
+    // p.block came from new, not cudaMalloc. It should not be nullptr here.
+    TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
+    return true;
   }

bool free_cached_blocks()
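For context (an illustration, not part of this diff): both the new C10_CUDA_CHECK(cudaGetLastError()) at the top of alloc_block and the cudaGetLastError() call on the forgiven OOM path manage CUDA's per-thread last-error state. A failed cudaMalloc leaves cudaErrorMemoryAllocation there, and if nobody clears it, a later check via cudaGetLastError() (for example, after a kernel launch) reports the stale error. A minimal standalone sketch of that behavior:

#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

int main() {
  void* p = nullptr;
  // Force cudaErrorMemoryAllocation with an absurdly large request.
  cudaError_t err = cudaMalloc(&p, SIZE_MAX);
  std::printf("huge cudaMalloc: %s\n", cudaGetErrorString(err));

  // Without this line, the failure would linger in the per-thread error
  // state and surface from the next cudaGetLastError() check.
  (void)cudaGetLastError();  // forgive and clear, as alloc_block now does

  err = cudaMalloc(&p, 1024);  // an unrelated, reasonable allocation
  std::printf("small cudaMalloc: %s\n", cudaGetErrorString(err));
  cudaFree(p);
  return 0;
}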