diff --git a/backends/xnnpack/runtime/XNNWeightsCache.cpp b/backends/xnnpack/runtime/XNNWeightsCache.cpp index 1a230c19976..06216b721c1 100644 --- a/backends/xnnpack/runtime/XNNWeightsCache.cpp +++ b/backends/xnnpack/runtime/XNNWeightsCache.cpp @@ -11,6 +11,9 @@ #include #include #include +#include +#include +#include #include #include @@ -155,21 +158,45 @@ size_t XNNWeightsCache::look_up( return packed_weight_entry->second.offset; } +/** + * Reserve space in the weight cache for n bytes of weight data, aligned to + * context->kPackedAllocationAlignment. This function will return nullptr if + * the allocation fails. + */ void* XNNWeightsCache::reserve_space(XNNWeightsCache* context, size_t n) { // MemoryAllocator* allocator = context->runtime_allocator_; // void* reserved_pointer = allocator->allocate(n, // context->kPackedAllocationAlignment); // return reserved_pointer; - std::string data_container; - data_container.resize(n + context->kPackedAllocationAlignment); - void* maybe_aligned_space = data_container.data(); - void* aligned_space = (void*)((intptr_t)maybe_aligned_space + 64 - - (intptr_t)maybe_aligned_space % 64); - - context->packed_pointer_to_container_[aligned_space] = - std::move(data_container); - return aligned_space; + try { + std::string data_container; + size_t raw_allocation_size = n + context->kPackedAllocationAlignment - 1; + data_container.resize(raw_allocation_size); + + void* maybe_aligned_space = data_container.data(); + void* aligned_space = std::align( + context->kPackedAllocationAlignment, + n, + maybe_aligned_space, + raw_allocation_size // Note that std::align mutates this value. + ); + ET_CHECK_MSG(aligned_space != nullptr, "Memory alignment failed."); + + context->packed_pointer_to_container_[aligned_space] = + std::move(data_container); + return aligned_space; + } catch (std::bad_alloc& e) { + // XNNPACK can gracefully handle allocation failures, so return nullptr. + // We want to be able to recover from a failed attempt to load a large + // model without a crash. + ET_LOG( + Error, + "XNN weight cache failed to allocate %zu bytes: %s.", + n, + e.what()); + return nullptr; + } } size_t XNNWeightsCache::look_up_or_insert(